distribute MUL const into ADD for int (#6361)

pre-req for real_stride
This commit is contained in:
chenyu 2024-09-05 01:36:57 -04:00 committed by GitHub
parent e7f6b654ad
commit 6fd24561d1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 14 additions and 5 deletions

View File

@ -312,6 +312,12 @@ class TestSymbolic(unittest.TestCase):
self.helper_test_variable((gidx0*4+lidx2*2+lidx3)//12, 0, 4, "(((lidx2//2)+gidx0)//3)")
self.helper_test_variable((lidx2*2+gidx0*4+lidx3)//12, 0, 4, "(((lidx2//2)+gidx0)//3)")
def test_sum_mul_distribute(self):
gidx0 = Variable("gidx0", 0, 7)
lidx2 = Variable("lidx2", 0, 12)
lidx3 = Variable("lidx3", 0, 1)
self.helper_test_variable((gidx0+lidx2+lidx3)*4, 0, 80, "((gidx0*4)+(lidx2*4)+(lidx3*4))")
class TestSymbolicNumeric(unittest.TestCase):
def helper_test_numeric(self, f):
# TODO: why are the negative tests broken? (even if we did support negative variables)

View File

@ -373,6 +373,12 @@ class TestSymbolic(unittest.TestCase):
self.helper_test_variable((gidx0*4+lidx2*2+lidx3)//12, 0, 4, "(((lidx2//2)+gidx0)//3)")
self.helper_test_variable((lidx2*2+gidx0*4+lidx3)//12, 0, 4, "(((lidx2//2)+gidx0)//3)")
def test_sum_mul_distribute(self):
gidx0 = Variable("gidx0", 0, 7)
lidx2 = Variable("lidx2", 0, 12)
lidx3 = Variable("lidx3", 0, 1)
self.helper_test_variable((gidx0+lidx2+lidx3)*4, 0, 80, "((gidx0*4)+(lidx2*4)+(lidx3*4))")
# *** below are uop_symbolic only
# NOTE: tests are not correct in symbolic

View File

@ -312,11 +312,8 @@ constant_folder = PatternMatcher([
((NOp.var("x") / NOp.var("x2")) / NOp.var("x3"), lambda x,x2,x3: x/(x2*x3)), # (x/x2)/x3 -> x/(x2*x3)
(-(NOp.var("x") + NOp.var("y")), lambda x,y: (-x)+(-y)), # -(x+y) -> -x + -y
((NOp.cvar("c0") + NOp.var("x")).lt(NOp.cvar("c1")), lambda x,c0,c1: UOp.lt(x, c1-c0)), # c0 + x < c1 -> x < c1 - c0
# (x+c0)*c1 -> x*c1+c0*c1. only for signed int, float have inf*0=nan issue
((NOp.var("x") + NOp.cvar("c0")) * NOp.cvar("c1"), lambda x,c0,c1:
x*c1+c0*c1 if dtypes.is_int(x.dtype) and not dtypes.is_unsigned(x.dtype) else None),
# (x*c0)+(y*c0) -> (x+y)*c0
#((NOp.var("x") * NOp.cvar("c0")) + (NOp.var("y") * NOp.cvar("c0")), lambda x,y,c0: c0*(x+y)),
# (x+y)*c -> x*c+y*c. only for int, float has inf*0=nan issue
((NOp.var("x") + NOp.var("y")) * NOp.cvar("c"), lambda x,y,c: x*c+y*c if dtypes.is_int(x.dtype) else None),
# x!=0 -> (bool)x
(NOp.var("x").ne(0), lambda x: x.cast(dtypes.bool)),
# TODO: can do the invert of this (flip alt/load) when we fix double ops