diff --git a/test/unit/test_symbolic.py b/test/unit/test_symbolic.py index 65dc2dae..00d3f53b 100644 --- a/test/unit/test_symbolic.py +++ b/test/unit/test_symbolic.py @@ -312,6 +312,12 @@ class TestSymbolic(unittest.TestCase): self.helper_test_variable((gidx0*4+lidx2*2+lidx3)//12, 0, 4, "(((lidx2//2)+gidx0)//3)") self.helper_test_variable((lidx2*2+gidx0*4+lidx3)//12, 0, 4, "(((lidx2//2)+gidx0)//3)") + def test_sum_mul_distribute(self): + gidx0 = Variable("gidx0", 0, 7) + lidx2 = Variable("lidx2", 0, 12) + lidx3 = Variable("lidx3", 0, 1) + self.helper_test_variable((gidx0+lidx2+lidx3)*4, 0, 80, "((gidx0*4)+(lidx2*4)+(lidx3*4))") + class TestSymbolicNumeric(unittest.TestCase): def helper_test_numeric(self, f): # TODO: why are the negative tests broken? (even if we did support negative variables) diff --git a/test/unit/test_uop_symbolic.py b/test/unit/test_uop_symbolic.py index 34246c7b..6f35e57e 100644 --- a/test/unit/test_uop_symbolic.py +++ b/test/unit/test_uop_symbolic.py @@ -373,6 +373,12 @@ class TestSymbolic(unittest.TestCase): self.helper_test_variable((gidx0*4+lidx2*2+lidx3)//12, 0, 4, "(((lidx2//2)+gidx0)//3)") self.helper_test_variable((lidx2*2+gidx0*4+lidx3)//12, 0, 4, "(((lidx2//2)+gidx0)//3)") + def test_sum_mul_distribute(self): + gidx0 = Variable("gidx0", 0, 7) + lidx2 = Variable("lidx2", 0, 12) + lidx3 = Variable("lidx3", 0, 1) + self.helper_test_variable((gidx0+lidx2+lidx3)*4, 0, 80, "((gidx0*4)+(lidx2*4)+(lidx3*4))") + # *** below are uop_symbolic only # NOTE: tests are not correct in symbolic diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index 6185a1a3..7dea1e3d 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -312,11 +312,8 @@ constant_folder = PatternMatcher([ ((NOp.var("x") / NOp.var("x2")) / NOp.var("x3"), lambda x,x2,x3: x/(x2*x3)), # (x/x2)/x3 -> x/(x2*x3) (-(NOp.var("x") + NOp.var("y")), lambda x,y: (-x)+(-y)), # -(x+y) -> -x + -y ((NOp.cvar("c0") + NOp.var("x")).lt(NOp.cvar("c1")), lambda x,c0,c1: UOp.lt(x, c1-c0)), # c0 + x < c1 -> x < c1 - c0 - # (x+c0)*c1 -> x*c1+c0*c1. only for signed int, float have inf*0=nan issue - ((NOp.var("x") + NOp.cvar("c0")) * NOp.cvar("c1"), lambda x,c0,c1: - x*c1+c0*c1 if dtypes.is_int(x.dtype) and not dtypes.is_unsigned(x.dtype) else None), - # (x*c0)+(y*c0) -> (x+y)*c0 - #((NOp.var("x") * NOp.cvar("c0")) + (NOp.var("y") * NOp.cvar("c0")), lambda x,y,c0: c0*(x+y)), + # (x+y)*c -> x*c+y*c. only for int, float has inf*0=nan issue + ((NOp.var("x") + NOp.var("y")) * NOp.cvar("c"), lambda x,y,c: x*c+y*c if dtypes.is_int(x.dtype) else None), # x!=0 -> (bool)x (NOp.var("x").ne(0), lambda x: x.cast(dtypes.bool)), # TODO: can do the invert of this (flip alt/load) when we fix double ops