diff --git a/test/unit/test_uop_symbolic.py b/test/unit/test_uop_symbolic.py index 8ea5f28c..9fcc619d 100644 --- a/test/unit/test_uop_symbolic.py +++ b/test/unit/test_uop_symbolic.py @@ -253,7 +253,6 @@ class TestSymbolic(unittest.TestCase): def test_distribute_mul(self): self.helper_test_variable(Node.sum([Variable("a", 0, 3), Variable("b", 0, 5)])*3, 0, 24, {"((a*3)+(b*3))", "((a+b)*3)"}) - @unittest.expectedFailure def test_mod_mul_sum(self): self.helper_test_variable(Node.sum([Variable("b", 0, 2), Variable("a", 0, 5)*10])%9, 0, 7, "(a+b)") diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index 99a5a535..124c17aa 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -254,7 +254,8 @@ constant_folder = PatternMatcher([ (UOp.var('x') % UOp.cvar('c'), lambda x,c: (x-(x.vmin.arg//c.arg)*c.arg)%c if 0 < c.arg <= x.vmin.arg else None), # mul -> (sum) -> mod ((UOp.cvar('c0')*UOp.var('x')) % UOp.cvar('c1'), lambda x,c0,c1: x*(c0.arg%c1.arg)%c1 if c0.arg >= c1.arg > 0 else None), - (((UOp.cvar('c')*UOp.var('x'))+UOp.var('x2')) % UOp.cvar('c'), lambda x,c,x2: x2%c), + (((UOp.cvar('c0')*UOp.var('x'))+UOp.var('x2')) % UOp.cvar('c1'), + lambda x,x2,c0,c1: x2%c1 if (r:=c0.arg%c1.arg) == 0 else (x*r+x2)%c1 if c0.arg >= c1.arg > 0 else None), # mod mod ((UOp.var('x') % UOp.cvar('c0')) % UOp.cvar('c1'), lambda x,c0,c1: x % c0 if 0 < c0.arg < c1.arg else x % c1 if c0.arg % c1.arg == 0 else None), (((UOp.var('x') * UOp.cvar('c0')) % UOp.cvar('c1')) % UOp.cvar('c0'), lambda x,c0,c1: x.const(0)),