mirror of https://github.com/commaai/tinygrad.git
skip grouped store for umatching upcasts (#3723)
* skip if upcasts dont match * outputs match now * this ast is hardcoded --------- Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
parent
199f7c4342
commit
43953c0ba9
|
@ -843,5 +843,21 @@ class TestLinearizerUOptimize(unittest.TestCase):
|
|||
# the global store doesn't change
|
||||
assert stores[1].vin[-1].dtype == dtypes.float
|
||||
|
||||
def test_skip_unmatching_upcasts(self):
|
||||
if not Device[Device.DEFAULT].compiler.linearizer_opts.has_local or not Device[Device.DEFAULT].compiler.linearizer_opts.supports_float4:
|
||||
self.skipTest("Needs locals and float4")
|
||||
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(1, 240, 0, 0), offset=0, mask=None, contiguous=False),)))),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(40, 1, 0, 0), offset=0, mask=None, contiguous=True),)))) # noqa: E501
|
||||
opts = [
|
||||
Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=16),
|
||||
Opt(op=OptOps.LOCAL, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=3, amt=2)
|
||||
]
|
||||
|
||||
k = Linearizer(ast)
|
||||
for opt in opts: k.apply_opt(opt)
|
||||
k.linearize()
|
||||
|
||||
out = [u for u in k.uops if u.uop == UOps.STORE][0]
|
||||
assert out.vin[-1].uop is UOps.CAST and out.vin[-1].dtype == dtypes.float.vec(4)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
@ -180,8 +180,7 @@ class TestLinearizerFailures(unittest.TestCase):
|
|||
def test_failure_23(self):
|
||||
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(1, 240, 0, 0), offset=0, mask=None, contiguous=False),)))),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(40, 1, 0, 0), offset=0, mask=None, contiguous=True),))))
|
||||
opts = [Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=16), Opt(op=OptOps.LOCAL, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=3, amt=2)]
|
||||
# Output does not match...
|
||||
helper_test_lin(Linearizer(ast), opts, failed_platforms=["CUDA", "HIP", "HSA", "METAL", "GPU"])
|
||||
helper_test_lin(Linearizer(ast), opts, failed_platforms=[])
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
@ -336,6 +336,7 @@ class UOpGraph:
|
|||
if all(el.uop is UOps.GEP for el in val.vin): replaced_stores[u] = val.vin[0].vin[0]
|
||||
elif all(el.uop is UOps.PHI for el in val.vin): replaced_stores[u] = phi_resolve_acc(val)
|
||||
for prev,new in replaced_stores.items():
|
||||
if prev.vin[-1].dtype != new.dtype: continue
|
||||
try: self.uops.remove(prev.vin[-1]) # remove the old upcast NOTE: the upcast's vins become childless now
|
||||
except ValueError: pass # already removed
|
||||
self.uops[self.uops.index(prev)].vin = (prev.vin[0],prev.vin[1],new) # replace with the float4 value
|
||||
|
|
Loading…
Reference in New Issue