UOp more generic div folding (#5722)

old: `x // c` can fold if `0 <= x.vmin <= x.vmax < c`
new: `x // c` can fold if `0 < c and x.vmin // c == x.vmax // c`
This commit is contained in:
chenyu 2024-07-25 17:49:14 -04:00 committed by GitHub
parent fb8148077e
commit 845b0d1c9d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 5 additions and 5 deletions

View File

@ -45,8 +45,8 @@ class TestSymbolic(unittest.TestCase):
expr = create_lt_node(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 256), 512)
self.helper_test_variable(expr, 0, 1, "(((idx1*4)+FLOAT4_INDEX)<512)")
def test_div_becomes_num(self):
assert isinstance(Variable("a", 2, 3)//2, NumNode)
def test_div_reduction(self):
self.helper_test_variable(Variable("a", 2, 3)//2, 1, 1, "1")
def test_var_becomes_num(self):
assert isinstance(Variable("a", 2, 2), NumNode)

View File

@ -97,8 +97,8 @@ class TestSymbolic(unittest.TestCase):
expr = create_lt_node(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 256), 512)
self.helper_test_variable(expr, 0, 1, "(((idx1*4)+FLOAT4_INDEX)<512)")
#def test_div_becomes_num(self):
# assert isinstance(Variable("a", 2, 3)//2, NumNode)
def test_div_reduction(self):
self.helper_test_variable(Variable("a", 2, 3)//2, 1, 1, "1")
#def test_var_becomes_num(self):
# assert isinstance(Variable("a", 2, 2), NumNode)

View File

@ -247,7 +247,7 @@ constant_folder = PatternMatcher([
((UOp.var('x') - UOp.cvar('c1')) + UOp.cvar('c2'), lambda x,c1,c2: x+x.const(exec_alu(BinaryOps.ADD, x.dtype, [c2.arg, -c1.arg]))),
# *** rules from symbolic ***
# div folding
(UOp.var('x') // UOp.cvar('c'), lambda x,c: x.const(0) if 0 <= x.vmin.arg <= x.vmax.arg < c.arg else None),
(UOp.var('x') // UOp.cvar('c'), lambda x,c: x.const(x.vmin.arg//c.arg) if c.arg > 0 and x.vmin.arg//c.arg == x.vmax.arg//c.arg else None),
# mod folding
(UOp.var('x') % UOp.cvar('c'), lambda x,c: x if 0 <= x.vmin.arg <= x.vmax.arg < c.arg else None),
# mod reduction