rewrite bool ADD to OR and MUL to AND (#6084)

* rewrite bool ADD to OR and MUL to AND

fixed running `tinyphysics.onnx`, which contains a getitem from a boolean tensor.

only can repro through BEAM_COMPARE, which i think is a different bug in test_linearizer_failure

* fold those, and fix tests

* only for bool

* move dtypes.bool
This commit is contained in:
chenyu 2024-08-15 10:11:57 -04:00 committed by GitHub
parent b765996d54
commit 5accfe26a0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 26 additions and 4 deletions

View File

@ -474,5 +474,22 @@ class TestLinearizerFailures(unittest.TestCase):
opts = [Opt(op=OptOps.TC, axis=0, amt=2), Opt(op=OptOps.UPCAST, axis=0, amt=2)]
helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=[])
def test_failure_50(self):
# from BEAM_COMPARE=2 running tinyphysics.onnx model
ast = LazyOp(MetaOps.KERNEL, arg=None, src=(
LazyOp(BufferOps.STORE, arg=MemBuffer(idx=0, dtype=dtypes.bool, st=ShapeTracker(views=(View(shape=(1, 1, 20, 1, 20), strides=(0, 0, 20, 0, 1), offset=0, mask=None, contiguous=True),))), src=(
LazyOp(BinaryOps.CMPNE, arg=None, src=(
LazyOp(ReduceOps.SUM, arg=(3,), src=(
LazyOp(BinaryOps.MUL, arg=None, src=(
LazyOp(BufferOps.LOAD, arg=MemBuffer(idx=1, dtype=dtypes.bool, st=ShapeTracker(views=(View(shape=(1, 1, 20, 20, 20), strides=(0, 0, 0, 20, 1), offset=0, mask=None, contiguous=False),))), src=()),
LazyOp(BinaryOps.CMPNE, arg=None, src=(
LazyOp(BinaryOps.CMPNE, arg=None, src=(
LazyOp(BufferOps.LOAD, arg=MemBuffer(idx=2, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1, 1, 20, 20, 20), strides=(0, 0, 1, 0, 0), offset=0, mask=None, contiguous=False),))), src=()),
LazyOp(BufferOps.LOAD, arg=MemBuffer(idx=3, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1, 1, 20, 20, 20), strides=(0, 0, 0, 1, 0), offset=0, mask=None, contiguous=False),))), src=()),)),
LazyOp(BufferOps.CONST, arg=ConstBuffer(val=True, dtype=dtypes.bool, st=ShapeTracker(views=(View(shape=(1, 1, 20, 20, 20), strides=(0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))), src=()),)),)),)),
LazyOp(BufferOps.CONST, arg=ConstBuffer(val=True, dtype=dtypes.bool, st=ShapeTracker(views=(View(shape=(1, 1, 20, 1, 20), strides=(0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))), src=()),)),)),))
opts = [Opt(op=OptOps.UPCAST, axis=1, amt=2)]
helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=[])
if __name__ == '__main__':
unittest.main()

View File

@ -566,7 +566,7 @@ class TestIFUOps(TestUOps):
sbuf = UOp(UOps.DEFINE_LOCAL, PtrDType(dtypes.float), (), ("smem", 4))
valid = UOp(UOps.SPECIAL, dtypes.int, (), ("gidx0", 10)).lt(5)
lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 4))
gate = valid*(lidx.ne(2))
gate = valid&(lidx.ne(2))
idx = UOp.const(dtypes.int, 0)
st = UOp(UOps.STORE, None, (sbuf, idx, UOp.const(dtypes.float, 42)))
barrier = UOp(UOps.BARRIER, None, (st,))
@ -585,7 +585,7 @@ class TestIFUOps(TestUOps):
sbuf = UOp(UOps.DEFINE_LOCAL, PtrDType(dtypes.float), (), ("smem", 16))
valid = UOp(UOps.SPECIAL, dtypes.int, (), ("gidx0", 4)).lt(1)
lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 16))
gate = valid*(lidx.ne(2))
gate = valid&(lidx.ne(2))
st = UOp(UOps.STORE, None, (sbuf, lidx, UOp.const(dtypes.float, 42)))
barrier = UOp(UOps.BARRIER, None, (st,))
lbufs = [UOp(UOps.LOAD, dtypes.float, (sbuf, UOp.const(dtypes.int, i), barrier)) for i in range(4)]
@ -604,7 +604,7 @@ class TestIFUOps(TestUOps):
buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0)
valid = UOp(UOps.SPECIAL, dtypes.int, (), ("gidx0", 10)).lt(5)
lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 4))
gate = valid*(lidx.ne(2))
gate = valid&(lidx.ne(2))
stores = [UOp(UOps.STORE, None, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate)) for i in range(4)]
sink = UOp(UOps.SINK, None, tuple(stores))
sink = gate_rewrite(sink)

View File

@ -89,7 +89,7 @@ class TestSymbolic(unittest.TestCase):
def test_ge_divides_and(self):
expr = Node.ands([create_lt_node(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3), 512),
create_lt_node(Variable("idx2", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3), 512)])
self.helper_test_variable(expr, 0, 1, {"((idx1<128) and (idx2<128))", "((idx1<128)*(idx2<128))"})
self.helper_test_variable(expr, 0, 1, {"((idx1<128) and (idx2<128))", "((idx1<128)&(idx2<128))"})
# # bool divided by int is not allowed
# expr = Node.ands([create_lt_node(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3), 512),
# create_lt_node(Variable("idx2", 0, 511)*4 + Variable("FLOAT8_INDEX", 0, 7), 512)])

View File

@ -186,6 +186,9 @@ def index_collapse(idx,rng,buf,add,mul,ld,reduce):
# this is symbolic 2.0
constant_folder = PatternMatcher([
# bool ADD is OR, MUL is AND. prevents other rules to rewrite bool ADD/MUL incorrectly
(UPat(UOps.ALU, BinaryOps.ADD, dtype=dtypes.bool, name="x"), lambda x: UOp(x.op, x.dtype, x.src, BinaryOps.OR)),
(UPat(UOps.ALU, BinaryOps.MUL, dtype=dtypes.bool, name="x"), lambda x: UOp(x.op, x.dtype, x.src, BinaryOps.AND)),
# VECTORIZE/GEP
(NOp(UOps.GEP, src=(NOp(UOps.VECTORIZE, name="cast"),), name="gep"), lambda gep, cast: cast.src[gep.arg]),
*[(NOp(UOps.VECTORIZE, dtypes.float.vec(i), tuple(NOp(UOps.GEP, dtypes.float,
@ -258,6 +261,8 @@ constant_folder = PatternMatcher([
(NOp.var('x') // 1, lambda x: x), # x//1 -> x
(NOp.var('x') // -1, lambda x: -x), # x//-1 -> -x
(NOp.var('x') / NOp.var('x'), lambda x: x.const(1)), # x/x -> 1
(NOp.var('x', dtype=dtypes.bool) & NOp.cvar('c'), lambda x,c: x if c.arg else c),
(NOp.var('x', dtype=dtypes.bool) | NOp.cvar('c'), lambda x,c: c if c.arg else x),
# ** zero folding **
# x*0 -> 0 or 0*x -> 0
# if x is nan or inf it should render the nan value.