diff --git a/test/unit/test_symbolic.py b/test/unit/test_symbolic.py index 8d950c81..d6740ccd 100644 --- a/test/unit/test_symbolic.py +++ b/test/unit/test_symbolic.py @@ -153,7 +153,7 @@ class TestSymbolic(unittest.TestCase): def test_mod_to_sub(self): # This is mod reduction - self.helper_test_variable((1+Variable("a",1,2))%2, 0, 1, (Variable("a",1,2)-1).render()) + self.helper_test_variable((1+Variable("a",1,2))%2, 0, 1, "(-1+a)") def test_sum_div_const(self): self.helper_test_variable(Node.sum([Variable("a", 0, 7)*4, NumNode(3)]) // 4, 0, 7, "a") diff --git a/test/unit/test_uop_symbolic.py b/test/unit/test_uop_symbolic.py index 6e108fc5..8ea5f28c 100644 --- a/test/unit/test_uop_symbolic.py +++ b/test/unit/test_uop_symbolic.py @@ -209,10 +209,9 @@ class TestSymbolic(unittest.TestCase): # NOTE: even though the mod max is 50, it can't know this without knowing about the mul self.helper_test_variable(Node.sum([Variable("a", 0, 7)*100, Variable("b", 0, 3)*50]) % 100, 0, 99, "((b*50)%100)") - @unittest.expectedFailure def test_mod_to_sub(self): # This is mod reduction - self.helper_test_variable((1+Variable("a",1,2))%2, 0, 1, (Variable("a",1,2)-1).render()) + self.helper_test_variable((1+Variable("a",1,2))%2, 0, 1, {"(-1+a)", "(a+(-1))"}) @unittest.expectedFailure def test_sum_div_const(self): diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index a5c89fd9..99a5a535 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -250,6 +250,8 @@ constant_folder = PatternMatcher([ (UOp.var('x') // UOp.cvar('c'), lambda x,c: x.const(0) if 0 <= x.vmin.arg <= x.vmax.arg < c.arg else None), # mod folding (UOp.var('x') % UOp.cvar('c'), lambda x,c: x if 0 <= x.vmin.arg <= x.vmax.arg < c.arg else None), + # mod reduction + (UOp.var('x') % UOp.cvar('c'), lambda x,c: (x-(x.vmin.arg//c.arg)*c.arg)%c if 0 < c.arg <= x.vmin.arg else None), # mul -> (sum) -> mod ((UOp.cvar('c0')*UOp.var('x')) % UOp.cvar('c1'), lambda x,c0,c1: x*(c0.arg%c1.arg)%c1 if c0.arg >= c1.arg > 0 else None), (((UOp.cvar('c')*UOp.var('x'))+UOp.var('x2')) % UOp.cvar('c'), lambda x,c,x2: x2%c),