mirror of https://github.com/commaai/tinygrad.git
UOps div folding (#5690)
#5689, with just div folding and new test cases
This commit is contained in:
parent
fb1b51811b
commit
85710e86cb
|
@ -118,6 +118,9 @@ class TestSymbolic(unittest.TestCase):
|
|||
def test_add_min_max(self):
|
||||
self.helper_test_variable(Variable("a", 0, 8) * 2 + 12, 12, 16+12, "((a*2)+12)")
|
||||
|
||||
def test_div_remove(self):
|
||||
self.helper_test_variable(Variable("a", 0, 7) // 20, 0, 0, "0")
|
||||
|
||||
def test_div_min_max(self):
|
||||
self.helper_test_variable(Variable("a", 0, 7) // 2, 0, 3, "(a//2)")
|
||||
|
||||
|
@ -125,6 +128,9 @@ class TestSymbolic(unittest.TestCase):
|
|||
self.helper_test_variable(Variable("a", 0, 7) // -2, -4, 0, "((((a*-1)+8)//2)+-4)")
|
||||
self.helper_test_variable(Variable("a", 0, 6) // -2, -3, 0, "((((a*-1)+6)//2)+-3)")
|
||||
|
||||
def test_sum_div_remove(self):
|
||||
self.helper_test_variable(Node.sum([Variable("a", 0, 7), Variable("b", 0, 3)]) // 20, 0, 0, "0")
|
||||
|
||||
def test_sum_div_min_max(self):
|
||||
self.helper_test_variable(Node.sum([Variable("a", 0, 7), Variable("b", 0, 3)]) // 2, 0, 5, "((a+b)//2)")
|
||||
|
||||
|
@ -242,7 +248,7 @@ class TestSymbolic(unittest.TestCase):
|
|||
def test_mul_div_factor_div(self):
|
||||
self.helper_test_variable((Variable("a", 0, 10)*4)//8, 0, 5, "(a//2)")
|
||||
|
||||
def test_div_remove(self):
|
||||
def test_sum_div_partial_remove(self):
|
||||
self.helper_test_variable(Node.sum([Variable("idx0", 0, 127)*4, Variable("idx2", 0, 3)])//4, 0, 127, "idx0")
|
||||
|
||||
def test_div_numerator_negative(self):
|
||||
|
|
|
@ -172,6 +172,9 @@ class TestSymbolic(unittest.TestCase):
|
|||
def test_add_min_max(self):
|
||||
self.helper_test_variable(Variable("a", 0, 8) * 2 + 12, 12, 16+12, "((a*2)+12)")
|
||||
|
||||
def test_div_remove(self):
|
||||
self.helper_test_variable(Variable("a", 0, 7) // 20, 0, 0, "0")
|
||||
|
||||
def test_div_min_max(self):
|
||||
self.helper_test_variable(Variable("a", 0, 7) // 2, 0, 3, "(a//2)")
|
||||
|
||||
|
@ -180,6 +183,10 @@ class TestSymbolic(unittest.TestCase):
|
|||
self.helper_test_variable(Variable("a", 0, 7) // -2, -4, 0, "((((a*-1)+8)//2)+-4)")
|
||||
self.helper_test_variable(Variable("a", 0, 6) // -2, -3, 0, "((((a*-1)+6)//2)+-3)")
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_sum_div_remove(self):
|
||||
self.helper_test_variable(Node.sum([Variable("a", 0, 7), Variable("b", 0, 3)]) // 20, 0, 0, "0")
|
||||
|
||||
def test_sum_div_min_max(self):
|
||||
self.helper_test_variable(Node.sum([Variable("a", 0, 7), Variable("b", 0, 3)]) // 2, 0, 5, "((a+b)//2)")
|
||||
|
||||
|
@ -309,7 +316,7 @@ class TestSymbolic(unittest.TestCase):
|
|||
self.helper_test_variable((Variable("a", 0, 10)*4)//8, 0, 5, "(a//2)")
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_div_remove(self):
|
||||
def test_sum_div_partial_remove(self):
|
||||
self.helper_test_variable(Node.sum([Variable("idx0", 0, 127)*4, Variable("idx2", 0, 3)])//4, 0, 127, "idx0")
|
||||
|
||||
@unittest.expectedFailure
|
||||
|
|
|
@ -246,6 +246,8 @@ constant_folder = PatternMatcher([
|
|||
((UOp.var('x') + UOp.cvar('c1')) + UOp.cvar('c2'), lambda x,c1,c2: x+x.const(exec_alu(BinaryOps.ADD, x.dtype, [c1.arg, c2.arg]))),
|
||||
((UOp.var('x') - UOp.cvar('c1')) + UOp.cvar('c2'), lambda x,c1,c2: x+x.const(exec_alu(BinaryOps.ADD, x.dtype, [c2.arg, -c1.arg]))),
|
||||
# *** rules from symbolic ***
|
||||
# div folding
|
||||
(UOp.var('x') // UOp.cvar('c'), lambda x,c: x.const(0) if 0 <= x.vmin.arg <= x.vmax.arg < c.arg else None),
|
||||
# mod folding
|
||||
(UOp.var('x') % UOp.cvar('c'), lambda x,c: x if 0 <= x.vmin.arg <= x.vmax.arg < c.arg else None),
|
||||
# mul -> (sum) -> mod
|
||||
|
|
|
@ -110,7 +110,8 @@ class UOp:
|
|||
if self.op is UOps.SPECIAL: return self.const(0), self.const(self.arg[1]-1) if isinstance(self.arg[1], int) else None
|
||||
if self.op is UOps.CONST: return self, self
|
||||
if self.op is UOps.ALU:
|
||||
if self.arg is UnaryOps.NEG and self.dtype != dtypes.bool: return self.const(-self.src[0].vmax.arg), self.const(-self.src[0].vmin.arg)
|
||||
if self.arg is UnaryOps.NEG and self.dtype != dtypes.bool and not dtypes.is_unsigned(cast(DType, self.dtype)):
|
||||
return self.const(-self.src[0].vmax.arg), self.const(-self.src[0].vmin.arg)
|
||||
return None, None
|
||||
|
||||
class UPat:
|
||||
|
|
Loading…
Reference in New Issue