failed uop_symbolic divmod test by variable (#6414)

This commit is contained in:
chenyu 2024-09-08 23:08:58 -04:00 committed by GitHub
parent 88941bcf16
commit 25af78c593
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 20 additions and 1 deletions

View File

@ -318,6 +318,14 @@ class TestSymbolic(unittest.TestCase):
lidx3 = Variable("lidx3", 0, 1)
self.helper_test_variable((gidx0+lidx2+lidx3)*4, 0, 80, "((gidx0*4)+(lidx2*4)+(lidx3*4))")
def test_variable_divmod(self):
start_pos = Variable("start_pos", 0, 127)
v = start_pos + 1
idx0 = Variable("idx0", 0, 2)
idx1 = Variable("idx1", 0, start_pos)
self.helper_test_variable((idx0*v+idx1)//v, 0, 2, "(idx0)")
self.helper_test_variable((idx0*v+idx1)%v, 0, start_pos, "idx1")
class TestSymbolicNumeric(unittest.TestCase):
def helper_test_numeric(self, f):
# TODO: why are the negative tests broken? (even if we did support negative variables)

View File

@ -27,7 +27,9 @@ def render(self) -> Tuple[str, ConstType, ConstType]:
def NumNode(val): return UOp.const(dtypes.int, val)
def Variable(expr, nmin, nmax):
return UOp(UOps.DEFINE_VAR, dtypes.int, arg=(expr, UOp.const(dtypes.int, nmin), UOp.const(dtypes.int, nmax)))
vmin = UOp.const(dtypes.int, nmin)
vmax = UOp.const(dtypes.int, nmax) if isinstance(nmax, int) else nmax
return UOp(UOps.DEFINE_VAR, dtypes.int, arg=(expr, vmin, vmax))
class Node:
@staticmethod
def sum(ops): return functools.reduce(lambda x,y: x+y, ops)
@ -375,6 +377,15 @@ class TestSymbolic(unittest.TestCase):
lidx3 = Variable("lidx3", 0, 1)
self.helper_test_variable((gidx0+lidx2+lidx3)*4, 0, 80, "((gidx0*4)+(lidx2*4)+(lidx3*4))")
@unittest.expectedFailure
def test_variable_divmod(self):
start_pos = Variable("start_pos", 0, 127)
v = start_pos + 1
idx0 = Variable("idx0", 0, 2)
idx1 = Variable("idx1", 0, start_pos)
self.helper_test_variable((idx0*v+idx1)//v, 0, 2, "(idx0)")
self.helper_test_variable((idx0*v+idx1)%v, 0, start_pos, "idx1")
# *** below are uop_symbolic only
# NOTE: tests are not correct in symbolic