mirror of https://github.com/commaai/tinygrad.git
parent
e7f6b654ad
commit
6fd24561d1
|
@ -312,6 +312,12 @@ class TestSymbolic(unittest.TestCase):
|
|||
self.helper_test_variable((gidx0*4+lidx2*2+lidx3)//12, 0, 4, "(((lidx2//2)+gidx0)//3)")
|
||||
self.helper_test_variable((lidx2*2+gidx0*4+lidx3)//12, 0, 4, "(((lidx2//2)+gidx0)//3)")
|
||||
|
||||
def test_sum_mul_distribute(self):
|
||||
gidx0 = Variable("gidx0", 0, 7)
|
||||
lidx2 = Variable("lidx2", 0, 12)
|
||||
lidx3 = Variable("lidx3", 0, 1)
|
||||
self.helper_test_variable((gidx0+lidx2+lidx3)*4, 0, 80, "((gidx0*4)+(lidx2*4)+(lidx3*4))")
|
||||
|
||||
class TestSymbolicNumeric(unittest.TestCase):
|
||||
def helper_test_numeric(self, f):
|
||||
# TODO: why are the negative tests broken? (even if we did support negative variables)
|
||||
|
|
|
@ -373,6 +373,12 @@ class TestSymbolic(unittest.TestCase):
|
|||
self.helper_test_variable((gidx0*4+lidx2*2+lidx3)//12, 0, 4, "(((lidx2//2)+gidx0)//3)")
|
||||
self.helper_test_variable((lidx2*2+gidx0*4+lidx3)//12, 0, 4, "(((lidx2//2)+gidx0)//3)")
|
||||
|
||||
def test_sum_mul_distribute(self):
|
||||
gidx0 = Variable("gidx0", 0, 7)
|
||||
lidx2 = Variable("lidx2", 0, 12)
|
||||
lidx3 = Variable("lidx3", 0, 1)
|
||||
self.helper_test_variable((gidx0+lidx2+lidx3)*4, 0, 80, "((gidx0*4)+(lidx2*4)+(lidx3*4))")
|
||||
|
||||
# *** below are uop_symbolic only
|
||||
|
||||
# NOTE: tests are not correct in symbolic
|
||||
|
|
|
@ -312,11 +312,8 @@ constant_folder = PatternMatcher([
|
|||
((NOp.var("x") / NOp.var("x2")) / NOp.var("x3"), lambda x,x2,x3: x/(x2*x3)), # (x/x2)/x3 -> x/(x2*x3)
|
||||
(-(NOp.var("x") + NOp.var("y")), lambda x,y: (-x)+(-y)), # -(x+y) -> -x + -y
|
||||
((NOp.cvar("c0") + NOp.var("x")).lt(NOp.cvar("c1")), lambda x,c0,c1: UOp.lt(x, c1-c0)), # c0 + x < c1 -> x < c1 - c0
|
||||
# (x+c0)*c1 -> x*c1+c0*c1. only for signed int, float have inf*0=nan issue
|
||||
((NOp.var("x") + NOp.cvar("c0")) * NOp.cvar("c1"), lambda x,c0,c1:
|
||||
x*c1+c0*c1 if dtypes.is_int(x.dtype) and not dtypes.is_unsigned(x.dtype) else None),
|
||||
# (x*c0)+(y*c0) -> (x+y)*c0
|
||||
#((NOp.var("x") * NOp.cvar("c0")) + (NOp.var("y") * NOp.cvar("c0")), lambda x,y,c0: c0*(x+y)),
|
||||
# (x+y)*c -> x*c+y*c. only for int, float has inf*0=nan issue
|
||||
((NOp.var("x") + NOp.var("y")) * NOp.cvar("c"), lambda x,y,c: x*c+y*c if dtypes.is_int(x.dtype) else None),
|
||||
# x!=0 -> (bool)x
|
||||
(NOp.var("x").ne(0), lambda x: x.cast(dtypes.bool)),
|
||||
# TODO: can do the invert of this (flip alt/load) when we fix double ops
|
||||
|
|
Loading…
Reference in New Issue