trim const in UOp div_folding (#5982)

simplify `(4*x+4*y+7)//16` to `(x+y+1)//4`.
fixed `GPU=1 UOP_IS_SYMBOLIC=1 IMAGE=2 python -m pytest test/test_ops.py -k conv`
This commit is contained in:
chenyu 2024-08-08 12:49:05 -04:00 committed by GitHub
parent e6d41b0ce7
commit 62c77a2831
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 12 additions and 5 deletions

View File

@ -145,6 +145,9 @@ class TestSymbolic(unittest.TestCase):
def test_sum_div_some_factor(self): def test_sum_div_some_factor(self):
self.helper_test_variable(Node.sum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*4]) // 2, 0, 23, "(((a*5)//2)+(b*2))") self.helper_test_variable(Node.sum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*4]) // 2, 0, 23, "(((a*5)//2)+(b*2))")
def test_sum_div_trim_const(self):
self.helper_test_variable((Variable("a", 0, 7)*4 + Variable("b", 0, 3)*4 + 7) // 16, 0, 2, "((1+a+b)//4)")
def test_sum_div_some_partial_factor(self): def test_sum_div_some_partial_factor(self):
self.helper_test_variable(Node.sum([Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 0, 5, "(((a*3)+(b*3))//8)") self.helper_test_variable(Node.sum([Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 0, 5, "(((a*3)+(b*3))//8)")
self.helper_test_variable(Node.sum([NumNode(16), Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 1, 6, "((((a*3)+(b*3))//8)+1)") self.helper_test_variable(Node.sum([NumNode(16), Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 1, 6, "((((a*3)+(b*3))//8)+1)")

View File

@ -203,6 +203,9 @@ class TestSymbolic(unittest.TestCase):
def test_sum_div_some_factor(self): def test_sum_div_some_factor(self):
self.helper_test_variable(Node.sum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*4]) // 2, 0, 23, {"(((a*5)//2)+(b*2))", "((b*2)+((a*5)//2))"}) self.helper_test_variable(Node.sum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*4]) // 2, 0, 23, {"(((a*5)//2)+(b*2))", "((b*2)+((a*5)//2))"})
def test_sum_div_trim_const(self):
self.helper_test_variable((Variable("a", 0, 7)*4 + Variable("b", 0, 3)*4 + 7) // 16, 0, 2, {"((1+a+b)//4)", "((a+b+1)//4)"})
def test_sum_div_some_partial_factor(self): def test_sum_div_some_partial_factor(self):
self.helper_test_variable(Node.sum([Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 0, 5, "(((a*3)+(b*3))//8)") self.helper_test_variable(Node.sum([Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 0, 5, "(((a*3)+(b*3))//8)")
self.helper_test_variable(Node.sum([NumNode(16), Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 1, 6, "((((a*3)+(b*3))//8)+1)") self.helper_test_variable(Node.sum([NumNode(16), Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 1, 6, "((((a*3)+(b*3))//8)+1)")

View File

@ -115,12 +115,13 @@ def div_folding(x:UOp, c:int) -> Optional[UOp]:
something_changed = True something_changed = True
quotient.append(x.const(rem_const//c)) quotient.append(x.const(rem_const//c))
rem_const = rem_const%c rem_const = rem_const%c
# make const a multiple of gcd
if c > 0 and rem_const > 0 and rem_const % gcd != 0:
something_changed = True
rem_const = (rem_const//gcd)*gcd
if rem_const != 0: if rem_const != 0:
if 0 < rem_const < gcd: something_changed = True # cancel the const gcd = math.gcd(gcd, rem_const)
else: remainder.append(x.const(rem_const))
# only include rem_const in remainder if it's not cancelled
gcd = math.gcd(gcd, rem_const)
remainder.append(x.const(rem_const))
if not something_changed: return cast(UOp, x.divides(gcd))//(c//gcd) if gcd != c and gcd != 1 else None if not something_changed: return cast(UOp, x.divides(gcd))//(c//gcd) if gcd != c and gcd != 1 else None
rem:Optional[UOp] = functools.reduce(operator.add, remainder) if remainder else None rem:Optional[UOp] = functools.reduce(operator.add, remainder) if remainder else None