merge gated stores spec (#5652)

* test_unmerged_ifs should merge ifs

* test_tiny_gate_store

* test_merge_ifs_alt

* assert assert asserts
This commit is contained in:
qazal 2024-07-23 18:53:27 +08:00 committed by GitHub
parent 4dcca0a6d4
commit 7cb67e6fb2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 46 additions and 23 deletions

View File

@ -4,6 +4,7 @@
import unittest
from tinygrad import Device, dtypes
from tinygrad.codegen.uops import UOps
from tinygrad.ops import LazyOp, BinaryOps, UnaryOps, ReduceOps, TernaryOps, BufferOps, MemBuffer, ConstBuffer, MetaOps # noqa: F401 # pylint: disable=unused-import
from tinygrad.shape.shapetracker import ShapeTracker, View
from tinygrad.engine.search import Opt, OptOps
@ -29,8 +30,13 @@ class TestLinearizerDumb(unittest.TestCase):
k.required_optimizations()
for opt in opts: k.apply_opt(opt)
prg = k.to_program()
prg.uops.print()
k.uops.print()
print(prg.src)
Device[Device.DEFAULT].compiler.compile_cached(prg.src)
with self.assertRaises(AssertionError):
gate_count = len([x for x in prg.src.splitlines() if "if" in x])
assert gate_count == 1, f"must have only one gate {gate_count} != 1"
assert len([u for u in k.uops if u.op is UOps.IF]) == 1, "must have a single IF"
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "need local")
def test_max_simplify_and_cancel(self):

View File

@ -237,43 +237,60 @@ class TestConstantFolding(unittest.TestCase):
assert any(uop.op is UOps.BITCAST for uop in ji.prg.p.uops), f"{[uop.op for uop in ji.prg.p.uops]} does not contain bitcast"
class TestGatedStoreRewrite(unittest.TestCase):
@unittest.skip("not yet implemented")
def test_wrap_store_parents(self):
# wraps all store parents in the valid branch
@unittest.expectedFailure
def test_tiny_gate_store(self):
gmem = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), (0, True))
gidx0 = UOp(UOps.SPECIAL, dtypes.int, (), (0, 'gidx0', 4))
idx = gidx0 * UOp.const(dtypes.int, 2)
value = UOp(UOps.CONST, dtypes.float, (), 42.0)
gate = UOp(UOps.ALU, dtypes.bool, (gidx0, UOp.const(dtypes.int, 1)), arg=BinaryOps.CMPLT)
uops = UOpGraph([UOp(UOps.STORE, None, (gmem, idx, value, gate))])
val = UOp.const(dtypes.float, 42.0)
gate = gidx0.lt(UOp.const(dtypes.int, 1))
store = UOp(UOps.STORE, None, (gmem, idx, val, gate))
uops = UOpGraph([store])
if DEBUG >= 4: print(Device[Device.DEFAULT].renderer.render("test", uops))
if_uop = next(u for u in uops if u.op is UOps.IF)
endif = next(u for u in uops if u.op is UOps.ENDIF)
assert endif.src[0] is if_uop
nested_uops = tuple(uops.uops[uops.uops.index(if_uop)+1:uops.uops.index(endif)])
assert nested_uops == (gmem, gidx0, idx, value)
gated_uops = tuple(uops.uops[uops.uops.index(if_uop)+1:uops.uops.index(endif)])
self.assertEqual(len(gated_uops), 1)
self.assertIs(gated_uops[-1].op, UOps.STORE)
@unittest.skip("not yet implemented")
def test_wrap_some_parents(self):
# some parents are used outside the branch
@unittest.expectedFailure
def test_gate_some_stores(self):
gmem0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), (0, True))
gmem1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), (1, True))
gidx0 = UOp(UOps.SPECIAL, dtypes.int, (), (0, 'gidx0', 4))
idx = gidx0 * UOp.const(dtypes.int, 2)
value0 = UOp(UOps.CONST, dtypes.float, (), 42.0)
value1 = UOp(UOps.CONST, dtypes.float, (), 43.0)
gate = UOp(UOps.ALU, dtypes.bool, (gidx0, UOp.const(dtypes.int, 1)), arg=BinaryOps.CMPLT)
outs = [UOp(UOps.STORE, None, (gmem0, idx, value0, gate))]
outs.append(UOp(UOps.STORE, None, (gmem1, idx, value1)))
uops = UOpGraph(outs)
idx = gidx0*UOp.const(dtypes.int, 2)
val = UOp.const(dtypes.float, 42.0)
gate = gidx0.lt(UOp.const(dtypes.int, 1))
stores = [UOp.store(gmem0, idx, val, gate), UOp.store(gmem1, idx, val)]
uops = UOpGraph(stores)
if DEBUG >= 4: print(Device[Device.DEFAULT].renderer.render("test", uops))
if_uop = next(u for u in uops if u.op is UOps.IF)
endif = next(u for u in uops if u.op is UOps.ENDIF)
assert endif.src[0] is if_uop
nested_uops = tuple(uops.uops[uops.uops.index(if_uop)+1:uops.uops.index(endif)])
assert nested_uops == (gmem0, value0)
gated_uops = tuple(uops.uops[uops.uops.index(if_uop)+1:uops.uops.index(endif)])
self.assertEqual(len(gated_uops), 1)
self.assertIs(gated_uops[-1].op, UOps.STORE)
# scaled down version of TestLinearizerDumb.test_unmerged_ifs
@unittest.expectedFailure
def test_merge_ifs_alt(self):
gmem0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), (0, True))
gmem1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), (1, True))
gidx0 = UOp(UOps.SPECIAL, dtypes.int, (), (0, 'gidx0', 4))
idx = gidx0*UOp.const(dtypes.int, 2)
val = UOp.const(dtypes.float, 42.0)
gate = gidx0.lt(UOp.const(dtypes.int, 1))
stores = [UOp.store(gmem0, idx, val, gate), UOp.store(gmem1, idx, val, gate)]
uops = UOpGraph(stores)
if DEBUG >= 4: print(Device[Device.DEFAULT].renderer.render("test", uops))
ifs = [u for u in uops if u.op is UOps.IF]
endifs = [u for u in uops if u.op is UOps.ENDIF]
self.assertEqual(len(ifs), 1)
self.assertEqual(len(endifs), 1)
gated_uops = tuple(uops.uops[uops.uops.index(ifs[0])+1:uops.uops.index(endifs[0])])
self.assertEqual(len(gated_uops), 2)
for x in gated_uops: self.assertIs(x.op, UOps.STORE)
class TestLocalAccess(unittest.TestCase):
# NOTE: this is failing on METAL CI, no idea why. Works locally.