mirror of https://github.com/commaai/tinygrad.git
rewrite bool ADD to OR and MUL to AND (#6084)
* rewrite bool ADD to OR and MUL to AND. This fixes running `tinyphysics.onnx`, which contains a getitem from a boolean tensor. It can only be reproduced through BEAM_COMPARE, which I think is a different bug in test_linearizer_failure * fold those, and fix tests * only for bool * move dtypes.bool
This commit is contained in:
parent
b765996d54
commit
5accfe26a0
|
@ -474,5 +474,22 @@ class TestLinearizerFailures(unittest.TestCase):
|
|||
opts = [Opt(op=OptOps.TC, axis=0, amt=2), Opt(op=OptOps.UPCAST, axis=0, amt=2)]
|
||||
helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=[])
|
||||
|
||||
def test_failure_50(self):
|
||||
# from BEAM_COMPARE=2 running tinyphysics.onnx model
|
||||
ast = LazyOp(MetaOps.KERNEL, arg=None, src=(
|
||||
LazyOp(BufferOps.STORE, arg=MemBuffer(idx=0, dtype=dtypes.bool, st=ShapeTracker(views=(View(shape=(1, 1, 20, 1, 20), strides=(0, 0, 20, 0, 1), offset=0, mask=None, contiguous=True),))), src=(
|
||||
LazyOp(BinaryOps.CMPNE, arg=None, src=(
|
||||
LazyOp(ReduceOps.SUM, arg=(3,), src=(
|
||||
LazyOp(BinaryOps.MUL, arg=None, src=(
|
||||
LazyOp(BufferOps.LOAD, arg=MemBuffer(idx=1, dtype=dtypes.bool, st=ShapeTracker(views=(View(shape=(1, 1, 20, 20, 20), strides=(0, 0, 0, 20, 1), offset=0, mask=None, contiguous=False),))), src=()),
|
||||
LazyOp(BinaryOps.CMPNE, arg=None, src=(
|
||||
LazyOp(BinaryOps.CMPNE, arg=None, src=(
|
||||
LazyOp(BufferOps.LOAD, arg=MemBuffer(idx=2, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1, 1, 20, 20, 20), strides=(0, 0, 1, 0, 0), offset=0, mask=None, contiguous=False),))), src=()),
|
||||
LazyOp(BufferOps.LOAD, arg=MemBuffer(idx=3, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1, 1, 20, 20, 20), strides=(0, 0, 0, 1, 0), offset=0, mask=None, contiguous=False),))), src=()),)),
|
||||
LazyOp(BufferOps.CONST, arg=ConstBuffer(val=True, dtype=dtypes.bool, st=ShapeTracker(views=(View(shape=(1, 1, 20, 20, 20), strides=(0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))), src=()),)),)),)),
|
||||
LazyOp(BufferOps.CONST, arg=ConstBuffer(val=True, dtype=dtypes.bool, st=ShapeTracker(views=(View(shape=(1, 1, 20, 1, 20), strides=(0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))), src=()),)),)),))
|
||||
opts = [Opt(op=OptOps.UPCAST, axis=1, amt=2)]
|
||||
helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=[])
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
@ -566,7 +566,7 @@ class TestIFUOps(TestUOps):
|
|||
sbuf = UOp(UOps.DEFINE_LOCAL, PtrDType(dtypes.float), (), ("smem", 4))
|
||||
valid = UOp(UOps.SPECIAL, dtypes.int, (), ("gidx0", 10)).lt(5)
|
||||
lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 4))
|
||||
gate = valid*(lidx.ne(2))
|
||||
gate = valid&(lidx.ne(2))
|
||||
idx = UOp.const(dtypes.int, 0)
|
||||
st = UOp(UOps.STORE, None, (sbuf, idx, UOp.const(dtypes.float, 42)))
|
||||
barrier = UOp(UOps.BARRIER, None, (st,))
|
||||
|
@ -585,7 +585,7 @@ class TestIFUOps(TestUOps):
|
|||
sbuf = UOp(UOps.DEFINE_LOCAL, PtrDType(dtypes.float), (), ("smem", 16))
|
||||
valid = UOp(UOps.SPECIAL, dtypes.int, (), ("gidx0", 4)).lt(1)
|
||||
lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 16))
|
||||
gate = valid*(lidx.ne(2))
|
||||
gate = valid&(lidx.ne(2))
|
||||
st = UOp(UOps.STORE, None, (sbuf, lidx, UOp.const(dtypes.float, 42)))
|
||||
barrier = UOp(UOps.BARRIER, None, (st,))
|
||||
lbufs = [UOp(UOps.LOAD, dtypes.float, (sbuf, UOp.const(dtypes.int, i), barrier)) for i in range(4)]
|
||||
|
@ -604,7 +604,7 @@ class TestIFUOps(TestUOps):
|
|||
buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0)
|
||||
valid = UOp(UOps.SPECIAL, dtypes.int, (), ("gidx0", 10)).lt(5)
|
||||
lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 4))
|
||||
gate = valid*(lidx.ne(2))
|
||||
gate = valid&(lidx.ne(2))
|
||||
stores = [UOp(UOps.STORE, None, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate)) for i in range(4)]
|
||||
sink = UOp(UOps.SINK, None, tuple(stores))
|
||||
sink = gate_rewrite(sink)
|
||||
|
|
|
@ -89,7 +89,7 @@ class TestSymbolic(unittest.TestCase):
|
|||
def test_ge_divides_and(self):
|
||||
expr = Node.ands([create_lt_node(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3), 512),
|
||||
create_lt_node(Variable("idx2", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3), 512)])
|
||||
self.helper_test_variable(expr, 0, 1, {"((idx1<128) and (idx2<128))", "((idx1<128)*(idx2<128))"})
|
||||
self.helper_test_variable(expr, 0, 1, {"((idx1<128) and (idx2<128))", "((idx1<128)&(idx2<128))"})
|
||||
# # bool divided by int is not allowed
|
||||
# expr = Node.ands([create_lt_node(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3), 512),
|
||||
# create_lt_node(Variable("idx2", 0, 511)*4 + Variable("FLOAT8_INDEX", 0, 7), 512)])
|
||||
|
|
|
@ -186,6 +186,9 @@ def index_collapse(idx,rng,buf,add,mul,ld,reduce):
|
|||
|
||||
# this is symbolic 2.0
|
||||
constant_folder = PatternMatcher([
|
||||
# bool ADD is OR, MUL is AND. this prevents other rules from rewriting bool ADD/MUL incorrectly
|
||||
(UPat(UOps.ALU, BinaryOps.ADD, dtype=dtypes.bool, name="x"), lambda x: UOp(x.op, x.dtype, x.src, BinaryOps.OR)),
|
||||
(UPat(UOps.ALU, BinaryOps.MUL, dtype=dtypes.bool, name="x"), lambda x: UOp(x.op, x.dtype, x.src, BinaryOps.AND)),
|
||||
# VECTORIZE/GEP
|
||||
(NOp(UOps.GEP, src=(NOp(UOps.VECTORIZE, name="cast"),), name="gep"), lambda gep, cast: cast.src[gep.arg]),
|
||||
*[(NOp(UOps.VECTORIZE, dtypes.float.vec(i), tuple(NOp(UOps.GEP, dtypes.float,
|
||||
|
@ -258,6 +261,8 @@ constant_folder = PatternMatcher([
|
|||
(NOp.var('x') // 1, lambda x: x), # x//1 -> x
|
||||
(NOp.var('x') // -1, lambda x: -x), # x//-1 -> -x
|
||||
(NOp.var('x') / NOp.var('x'), lambda x: x.const(1)), # x/x -> 1
|
||||
(NOp.var('x', dtype=dtypes.bool) & NOp.cvar('c'), lambda x,c: x if c.arg else c),
|
||||
(NOp.var('x', dtype=dtypes.bool) | NOp.cvar('c'), lambda x,c: c if c.arg else x),
|
||||
# ** zero folding **
|
||||
# x*0 -> 0 or 0*x -> 0
|
||||
# if x is nan or inf it should render the nan value.
|
||||
|
|
Loading…
Reference in New Issue