skip grouped store for umatching upcasts (#3723)

* skip if upcasts dont match

* outputs match now

* this ast is hardcoded

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
qazal 2024-03-14 07:18:31 +02:00 committed by GitHub
parent 199f7c4342
commit 43953c0ba9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 18 additions and 2 deletions

View File

@ -843,5 +843,21 @@ class TestLinearizerUOptimize(unittest.TestCase):
# the global store doesn't change
assert stores[1].vin[-1].dtype == dtypes.float
def test_skip_unmatching_upcasts(self):
if not Device[Device.DEFAULT].compiler.linearizer_opts.has_local or not Device[Device.DEFAULT].compiler.linearizer_opts.supports_float4:
self.skipTest("Needs locals and float4")
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(1, 240, 0, 0), offset=0, mask=None, contiguous=False),)))),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(40, 1, 0, 0), offset=0, mask=None, contiguous=True),)))) # noqa: E501
opts = [
Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=16),
Opt(op=OptOps.LOCAL, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=3, amt=2)
]
k = Linearizer(ast)
for opt in opts: k.apply_opt(opt)
k.linearize()
out = [u for u in k.uops if u.uop == UOps.STORE][0]
assert out.vin[-1].uop is UOps.CAST and out.vin[-1].dtype == dtypes.float.vec(4)
if __name__ == '__main__':
unittest.main()

View File

@ -180,8 +180,7 @@ class TestLinearizerFailures(unittest.TestCase):
def test_failure_23(self):
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(1, 240, 0, 0), offset=0, mask=None, contiguous=False),)))),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(40, 1, 0, 0), offset=0, mask=None, contiguous=True),))))
opts = [Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=16), Opt(op=OptOps.LOCAL, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=3, amt=2)]
# Output does not match...
helper_test_lin(Linearizer(ast), opts, failed_platforms=["CUDA", "HIP", "HSA", "METAL", "GPU"])
helper_test_lin(Linearizer(ast), opts, failed_platforms=[])
if __name__ == '__main__':
unittest.main()

View File

@ -336,6 +336,7 @@ class UOpGraph:
if all(el.uop is UOps.GEP for el in val.vin): replaced_stores[u] = val.vin[0].vin[0]
elif all(el.uop is UOps.PHI for el in val.vin): replaced_stores[u] = phi_resolve_acc(val)
for prev,new in replaced_stores.items():
if prev.vin[-1].dtype != new.dtype: continue
try: self.uops.remove(prev.vin[-1]) # remove the old upcast NOTE: the upcast's vins become childless now
except ValueError: pass # already removed
self.uops[self.uops.index(prev)].vin = (prev.vin[0],prev.vin[1],new) # replace with the float4 value