diff --git a/test/unit/test_uop_symbolic.py b/test/unit/test_uop_symbolic.py index 67c188db..39ce36d3 100644 --- a/test/unit/test_uop_symbolic.py +++ b/test/unit/test_uop_symbolic.py @@ -395,6 +395,11 @@ class TestSymbolic(unittest.TestCase): self.helper_test_variable((((alu2+150)//(-32))+4), 0, 0, "0") self.helper_test_variable((((alu2+158)//(-32))+4), 0, 0, "0") + def test_div_mod_recombine(self): + gidx = Variable("gidx", 0, 124) + self.helper_test_variable(gidx%4+(gidx//4)*4, 0, 124, "gidx") + self.helper_test_variable((gidx//4)*4+gidx%4, 0, 124, "gidx") + @unittest.skip("not supported on uops yet") class TestSymbolicNumeric(unittest.TestCase): def helper_test_numeric(self, f): diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index 8b2ad34c..e2909244 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -296,6 +296,8 @@ constant_folder = PatternMatcher([ ((NOp.cvar('c0')*NOp.var('x')) % NOp.cvar('c1'), lambda x,c0,c1: (x%(c1.arg//c0.arg))*c0 if c1.arg%c0.arg == 0 else None), # mod mod ((NOp.var('x') % NOp.cvar('c0')) % NOp.cvar('c1'), lambda x,c0,c1: x % c1 if c0.arg % c1.arg == 0 else None), + # (x%c)+(x//c)*c = x + (NOp.var('x')%NOp.cvar('c')+(NOp.var('x')//NOp.cvar('c'))*NOp.cvar('c'), lambda x,c: x), # ** combine terms ** # -(x+y) -> -x + -y (-(NOp.var("x") + NOp.var("y")), lambda x,y: (-x)+(-y)),