mirror of https://github.com/commaai/tinygrad.git
trim const in UOp div_folding (#5982)
simplify `(4*x+4*y+7)//16` to `(x+y+1)//4`. fixed `GPU=1 UOP_IS_SYMBOLIC=1 IMAGE=2 python -m pytest test/test_ops.py -k conv`
This commit is contained in:
parent
e6d41b0ce7
commit
62c77a2831
|
@ -145,6 +145,9 @@ class TestSymbolic(unittest.TestCase):
|
||||||
def test_sum_div_some_factor(self):
|
def test_sum_div_some_factor(self):
|
||||||
self.helper_test_variable(Node.sum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*4]) // 2, 0, 23, "(((a*5)//2)+(b*2))")
|
self.helper_test_variable(Node.sum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*4]) // 2, 0, 23, "(((a*5)//2)+(b*2))")
|
||||||
|
|
||||||
|
def test_sum_div_trim_const(self):
|
||||||
|
self.helper_test_variable((Variable("a", 0, 7)*4 + Variable("b", 0, 3)*4 + 7) // 16, 0, 2, "((1+a+b)//4)")
|
||||||
|
|
||||||
def test_sum_div_some_partial_factor(self):
|
def test_sum_div_some_partial_factor(self):
|
||||||
self.helper_test_variable(Node.sum([Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 0, 5, "(((a*3)+(b*3))//8)")
|
self.helper_test_variable(Node.sum([Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 0, 5, "(((a*3)+(b*3))//8)")
|
||||||
self.helper_test_variable(Node.sum([NumNode(16), Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 1, 6, "((((a*3)+(b*3))//8)+1)")
|
self.helper_test_variable(Node.sum([NumNode(16), Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 1, 6, "((((a*3)+(b*3))//8)+1)")
|
||||||
|
|
|
@ -203,6 +203,9 @@ class TestSymbolic(unittest.TestCase):
|
||||||
def test_sum_div_some_factor(self):
|
def test_sum_div_some_factor(self):
|
||||||
self.helper_test_variable(Node.sum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*4]) // 2, 0, 23, {"(((a*5)//2)+(b*2))", "((b*2)+((a*5)//2))"})
|
self.helper_test_variable(Node.sum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*4]) // 2, 0, 23, {"(((a*5)//2)+(b*2))", "((b*2)+((a*5)//2))"})
|
||||||
|
|
||||||
|
def test_sum_div_trim_const(self):
|
||||||
|
self.helper_test_variable((Variable("a", 0, 7)*4 + Variable("b", 0, 3)*4 + 7) // 16, 0, 2, {"((1+a+b)//4)", "((a+b+1)//4)"})
|
||||||
|
|
||||||
def test_sum_div_some_partial_factor(self):
|
def test_sum_div_some_partial_factor(self):
|
||||||
self.helper_test_variable(Node.sum([Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 0, 5, "(((a*3)+(b*3))//8)")
|
self.helper_test_variable(Node.sum([Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 0, 5, "(((a*3)+(b*3))//8)")
|
||||||
self.helper_test_variable(Node.sum([NumNode(16), Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 1, 6, "((((a*3)+(b*3))//8)+1)")
|
self.helper_test_variable(Node.sum([NumNode(16), Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 1, 6, "((((a*3)+(b*3))//8)+1)")
|
||||||
|
|
|
@ -115,12 +115,13 @@ def div_folding(x:UOp, c:int) -> Optional[UOp]:
|
||||||
something_changed = True
|
something_changed = True
|
||||||
quotient.append(x.const(rem_const//c))
|
quotient.append(x.const(rem_const//c))
|
||||||
rem_const = rem_const%c
|
rem_const = rem_const%c
|
||||||
|
# make const a multiple of gcd
|
||||||
|
if c > 0 and rem_const > 0 and rem_const % gcd != 0:
|
||||||
|
something_changed = True
|
||||||
|
rem_const = (rem_const//gcd)*gcd
|
||||||
if rem_const != 0:
|
if rem_const != 0:
|
||||||
if 0 < rem_const < gcd: something_changed = True # cancel the const
|
gcd = math.gcd(gcd, rem_const)
|
||||||
else:
|
remainder.append(x.const(rem_const))
|
||||||
# only include rem_const in remainder if it's not cancelled
|
|
||||||
gcd = math.gcd(gcd, rem_const)
|
|
||||||
remainder.append(x.const(rem_const))
|
|
||||||
|
|
||||||
if not something_changed: return cast(UOp, x.divides(gcd))//(c//gcd) if gcd != c and gcd != 1 else None
|
if not something_changed: return cast(UOp, x.divides(gcd))//(c//gcd) if gcd != c and gcd != 1 else None
|
||||||
rem:Optional[UOp] = functools.reduce(operator.add, remainder) if remainder else None
|
rem:Optional[UOp] = functools.reduce(operator.add, remainder) if remainder else None
|
||||||
|
|
Loading…
Reference in New Issue