From 5accfe26a0ad7e2294619c428bcd378939c6fa72 Mon Sep 17 00:00:00 2001 From: chenyu Date: Thu, 15 Aug 2024 10:11:57 -0400 Subject: [PATCH] rewrite bool ADD to OR and MUL to AND (#6084) * rewrite bool ADD to OR and MUL to AND fixed running `tinyphysics.onnx`, which contains a getitem from a boolean tensor. only can repro through BEAM_COMPARE, which i think is a different bug in test_linearizer_failure * fold those, and fix tests * only for bool * move dtypes.bool --- test/test_linearizer_failures.py | 17 +++++++++++++++++ test/test_uop_graph.py | 6 +++--- test/unit/test_uop_symbolic.py | 2 +- tinygrad/codegen/uopgraph.py | 5 +++++ 4 files changed, 26 insertions(+), 4 deletions(-) diff --git a/test/test_linearizer_failures.py b/test/test_linearizer_failures.py index 0bc180dd..6103746e 100644 --- a/test/test_linearizer_failures.py +++ b/test/test_linearizer_failures.py @@ -474,5 +474,22 @@ class TestLinearizerFailures(unittest.TestCase): opts = [Opt(op=OptOps.TC, axis=0, amt=2), Opt(op=OptOps.UPCAST, axis=0, amt=2)] helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=[]) + def test_failure_50(self): + # from BEAM_COMPARE=2 running tinyphysics.onnx model + ast = LazyOp(MetaOps.KERNEL, arg=None, src=( + LazyOp(BufferOps.STORE, arg=MemBuffer(idx=0, dtype=dtypes.bool, st=ShapeTracker(views=(View(shape=(1, 1, 20, 1, 20), strides=(0, 0, 20, 0, 1), offset=0, mask=None, contiguous=True),))), src=( + LazyOp(BinaryOps.CMPNE, arg=None, src=( + LazyOp(ReduceOps.SUM, arg=(3,), src=( + LazyOp(BinaryOps.MUL, arg=None, src=( + LazyOp(BufferOps.LOAD, arg=MemBuffer(idx=1, dtype=dtypes.bool, st=ShapeTracker(views=(View(shape=(1, 1, 20, 20, 20), strides=(0, 0, 0, 20, 1), offset=0, mask=None, contiguous=False),))), src=()), + LazyOp(BinaryOps.CMPNE, arg=None, src=( + LazyOp(BinaryOps.CMPNE, arg=None, src=( + LazyOp(BufferOps.LOAD, arg=MemBuffer(idx=2, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1, 1, 20, 20, 20), strides=(0, 0, 1, 0, 0), 
offset=0, mask=None, contiguous=False),))), src=()), + LazyOp(BufferOps.LOAD, arg=MemBuffer(idx=3, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1, 1, 20, 20, 20), strides=(0, 0, 0, 1, 0), offset=0, mask=None, contiguous=False),))), src=()),)), + LazyOp(BufferOps.CONST, arg=ConstBuffer(val=True, dtype=dtypes.bool, st=ShapeTracker(views=(View(shape=(1, 1, 20, 20, 20), strides=(0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))), src=()),)),)),)), + LazyOp(BufferOps.CONST, arg=ConstBuffer(val=True, dtype=dtypes.bool, st=ShapeTracker(views=(View(shape=(1, 1, 20, 1, 20), strides=(0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))), src=()),)),)),)) + opts = [Opt(op=OptOps.UPCAST, axis=1, amt=2)] + helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=[]) + if __name__ == '__main__': unittest.main() diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py index 652ae30c..ef053d84 100644 --- a/test/test_uop_graph.py +++ b/test/test_uop_graph.py @@ -566,7 +566,7 @@ class TestIFUOps(TestUOps): sbuf = UOp(UOps.DEFINE_LOCAL, PtrDType(dtypes.float), (), ("smem", 4)) valid = UOp(UOps.SPECIAL, dtypes.int, (), ("gidx0", 10)).lt(5) lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 4)) - gate = valid*(lidx.ne(2)) + gate = valid&(lidx.ne(2)) idx = UOp.const(dtypes.int, 0) st = UOp(UOps.STORE, None, (sbuf, idx, UOp.const(dtypes.float, 42))) barrier = UOp(UOps.BARRIER, None, (st,)) @@ -585,7 +585,7 @@ class TestIFUOps(TestUOps): sbuf = UOp(UOps.DEFINE_LOCAL, PtrDType(dtypes.float), (), ("smem", 16)) valid = UOp(UOps.SPECIAL, dtypes.int, (), ("gidx0", 4)).lt(1) lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 16)) - gate = valid*(lidx.ne(2)) + gate = valid&(lidx.ne(2)) st = UOp(UOps.STORE, None, (sbuf, lidx, UOp.const(dtypes.float, 42))) barrier = UOp(UOps.BARRIER, None, (st,)) lbufs = [UOp(UOps.LOAD, dtypes.float, (sbuf, UOp.const(dtypes.int, i), barrier)) for i in range(4)] @@ -604,7 +604,7 @@ class 
TestIFUOps(TestUOps): buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0) valid = UOp(UOps.SPECIAL, dtypes.int, (), ("gidx0", 10)).lt(5) lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 4)) - gate = valid*(lidx.ne(2)) + gate = valid&(lidx.ne(2)) stores = [UOp(UOps.STORE, None, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate)) for i in range(4)] sink = UOp(UOps.SINK, None, tuple(stores)) sink = gate_rewrite(sink) diff --git a/test/unit/test_uop_symbolic.py b/test/unit/test_uop_symbolic.py index a12b9440..6da112f1 100644 --- a/test/unit/test_uop_symbolic.py +++ b/test/unit/test_uop_symbolic.py @@ -89,7 +89,7 @@ class TestSymbolic(unittest.TestCase): def test_ge_divides_and(self): expr = Node.ands([create_lt_node(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3), 512), create_lt_node(Variable("idx2", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3), 512)]) - self.helper_test_variable(expr, 0, 1, {"((idx1<128) and (idx2<128))", "((idx1<128)*(idx2<128))"}) + self.helper_test_variable(expr, 0, 1, {"((idx1<128) and (idx2<128))", "((idx1<128)&(idx2<128))"}) # # bool divided by int is not allowed # expr = Node.ands([create_lt_node(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3), 512), # create_lt_node(Variable("idx2", 0, 511)*4 + Variable("FLOAT8_INDEX", 0, 7), 512)]) diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index e483c40a..7ac6b5c4 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -186,6 +186,9 @@ def index_collapse(idx,rng,buf,add,mul,ld,reduce): # this is symbolic 2.0 constant_folder = PatternMatcher([ + # bool ADD is OR, MUL is AND. 
prevents other rules from rewriting bool ADD/MUL incorrectly + (UPat(UOps.ALU, BinaryOps.ADD, dtype=dtypes.bool, name="x"), lambda x: UOp(x.op, x.dtype, x.src, BinaryOps.OR)), + (UPat(UOps.ALU, BinaryOps.MUL, dtype=dtypes.bool, name="x"), lambda x: UOp(x.op, x.dtype, x.src, BinaryOps.AND)), # VECTORIZE/GEP (NOp(UOps.GEP, src=(NOp(UOps.VECTORIZE, name="cast"),), name="gep"), lambda gep, cast: cast.src[gep.arg]), *[(NOp(UOps.VECTORIZE, dtypes.float.vec(i), tuple(NOp(UOps.GEP, dtypes.float, @@ -258,6 +261,8 @@ constant_folder = PatternMatcher([ (NOp.var('x') // 1, lambda x: x), # x//1 -> x (NOp.var('x') // -1, lambda x: -x), # x//-1 -> -x (NOp.var('x') / NOp.var('x'), lambda x: x.const(1)), # x/x -> 1 + (NOp.var('x', dtype=dtypes.bool) & NOp.cvar('c'), lambda x,c: x if c.arg else c), + (NOp.var('x', dtype=dtypes.bool) | NOp.cvar('c'), lambda x,c: c if c.arg else x), # ** zero folding ** # x*0 -> 0 or 0*x -> 0 # if x is nan or inf it should render the nan value.