more symbolic symbolic ops (#1564)

* more symbolic symbolic ops

* handle NumNode in __mul__
This commit is contained in:
chenyu 2023-08-18 09:21:41 -07:00 committed by GitHub
parent dfec16cc83
commit be50b2fe8f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 57 additions and 23 deletions

View File

@ -139,7 +139,7 @@ class TestSymbolicShapeExpr(unittest.TestCase):
view = View(shape, strides)
st = ShapeTracker(shape, [view])
idx, valid = st.expr_idxs(idx)
assert idx.render() == "(((1+i)*1)+(lidx1*((i*4)+4))+gidx0)"
assert idx.render() == "((lidx1*((i*4)+4))+1+gidx0+i)"
class TestShapeTrackerVarVals(unittest.TestCase):
def test_reshape_reshape_updates_var_vals(self):

View File

@ -277,25 +277,29 @@ class TestSymRender(unittest.TestCase):
assert sym_render(a*b) == "(a*b)"
class TestSymbolicSymbolicOps(unittest.TestCase):
def test_node_div_node(self):
def test_node_divmod_node(self):
i = Variable("i", 1, 10)
idx0 = Variable("idx0", 0, i*3-1)
assert NumNode(0) // (Variable("i", 1, 10)*128) == 0
assert NumNode(127) // (Variable("i", 1, 10)*128) == 0
assert idx0 // (i*3) == 0
assert i // i == 1
def test_node_mod_node(self):
i = Variable("i", 1, 10)
idx0 = Variable("idx0", 0, i*3-1)
assert NumNode(0) % (Variable("i", 1, 10)*128) == 0
assert NumNode(127) // (Variable("i", 1, 10)*128) == 0
assert NumNode(127) % (Variable("i", 1, 10)*128) == 127
assert NumNode(128) % (Variable("i", 1, 10)*128 + 128) == 128
assert 0 % (Variable("i", 1, 10)*128) == 0
assert 127 // (Variable("i", 1, 10)*128) == 0
assert 127 % (Variable("i", 1, 10)*128) == 127
assert NumNode(128) // (Variable("i", 1, 10)*128 + 128) == 0
assert NumNode(128) % (Variable("i", 1, 10)*128 + 128) == 128
assert 128 // (Variable("i", 1, 10)*128 + 128) == 0
assert 128 % (Variable("i", 1, 10)*128 + 128) == 128
assert 0 // (Variable("i", 1, 10)*128) == 0
assert 0 % (Variable("i", 1, 10)*128) == 0
assert idx0 // (i*3) == 0
assert idx0 % (i*3) == idx0
assert i // i == 1
assert i % i == 0
assert 128 // NumNode(4) == 32
assert 128 % NumNode(4) == 0
assert NumNode(128) // NumNode(4) == 32
assert NumNode(128) % NumNode(4) == 0
def test_mulnode_divmod_node(self):
i = Variable("i", 1, 10)
@ -311,7 +315,26 @@ class TestSymbolicSymbolicOps(unittest.TestCase):
idx2 = Variable("idx2", 0, i)
assert (idx0*(i*4+4)+idx1*(i+1)+idx2) // (i+1) == idx0*4+idx1
assert (idx0*(i*4+4)+idx1*(i+1)+idx2) % (i+1) == idx2
assert (i+1) // (i*128+128) == 0
assert (i+1) % (i*128+128) == (i+1)
assert (i+1+idx2) // (i+1) == 1
assert (i+1+idx2) % (i+1) == idx2
assert (idx0*(i*4+4)+i+1+idx2) // (i+1) == idx0*4+1
assert (idx0*(i*4+4)+i+1+idx2) % (i+1) == idx2
assert (i*128+128)*2 // (i*128+128) == 2
assert (i*128+128)*2 % (i*128+128) == 0
def test_sumnode_divmod_sumnode_complex(self):
i = Variable("i", 1, 1024)
gidx0 = Variable("gidx0", 0, i)
lidx1 = Variable("lidx1", 0, 7)
ridx2 = Variable("ridx1", 0, 31)
assert ((i*128+128)*2 + gidx0*128 + lidx1*(i*512+512) + ridx2*4) // (i*128+128) == 2 + lidx1*4
assert ((i*128+128)*2 + gidx0*128 + lidx1*(i*512+512) + ridx2*4) % (i*128+128) == gidx0*128 + ridx2*4
assert ((gidx0*128+i*128+ridx2*4+129)) // (i*128+128) == 1
assert ((gidx0*128+i*128+ridx2*4+129)) % (i*128+128) == gidx0*128 + ridx2*4 + 1
assert (ridx2*(i*4+4)+1+i+gidx0) // (i*128+128) == 0
assert (ridx2*(i*4+4)+1+i+gidx0) % (i*128+128) == (ridx2*(i*4+4)+1+i+gidx0)
def test_node_lt_node(self):
a = Variable("a", 1, 5)
@ -330,11 +353,16 @@ class TestSymbolicSymbolicOps(unittest.TestCase):
assert a > 3
def test_num_node_mul_node(self):
a = NumNode(2)
b = Variable("b", 1, 5)
c = a * b
assert c == b * 2
assert isinstance(c, MulNode)
a = Variable("a", 1, 5)
b = NumNode(2) * a
assert b == a * 2
assert isinstance(b, MulNode)
b = NumNode(1) * a
assert b == a
assert isinstance(b, Variable)
b = NumNode(0) * a
assert b == 0
assert isinstance(b, NumNode)
if __name__ == '__main__':
unittest.main()

View File

@ -56,17 +56,21 @@ class Node:
def __mul__(self, b:Union[Node, int]):
if b == 0: return NumNode(0)
if b == 1: return self
if self.__class__ is NumNode: return NumNode(self.b*b) if isinstance(b, int) else create_node(MulNode(b, self.b))
return create_node(MulNode(self, b))
if self.__class__ is NumNode: return NumNode(self.b*b) if isinstance(b, int) else b*self.b
return create_node(MulNode(self, b.b)) if isinstance(b, NumNode) else create_node(MulNode(self, b))
def __rmul__(self, b:int): return self*b
# *** complex ops ***
def __rfloordiv__(self, b:int): raise RuntimeError(f"not supported: {b} // {self}")
def __rfloordiv__(self, b:int):
if self.min > b >= 0: return NumNode(0)
if isinstance(self, NumNode): return NumNode(b // self.b)
raise RuntimeError(f"not supported: {b} // {self}")
def __floordiv__(self, b:Union[Node,int], factoring_allowed=True):
if isinstance(b, Node):
if b.__class__ is NumNode: return self // b.b
if self == b: return NumNode(1)
if (b > self).min > 0 and self.min >= 0: return NumNode(0)
if (b - self).min > 0 and self.min >= 0: return NumNode(0) # b - self simplifies the node
raise RuntimeError(f"not supported: {self} // {b}")
assert b != 0
if b < 0: return (self//-b)*-1
@ -81,9 +85,11 @@ class Node:
def __rmod__(self, b:int):
if self.min > b >= 0: return NumNode(b)
if isinstance(self, NumNode): return NumNode(b % self.b)
raise RuntimeError(f"not supported: {b} % {self}")
def __mod__(self, b:Union[Node,int]):
if isinstance(b, Node):
if b.__class__ is NumNode: return self % b.b
if self == b: return NumNode(0)
if (b - self).min > 0 and self.min >= 0: return self # b - self simplifies the node
raise RuntimeError(f"not supported: {self} % {b}")
@ -208,12 +214,12 @@ class SumNode(RedNode):
if isinstance(b, SumNode):
nu_num = sum(node.b for node in self.flat_components if node.__class__ is NumNode)
de_num = sum(node.b for node in b.flat_components if node.__class__ is NumNode)
if de_num and nu_num % de_num == 0 and b * (d := nu_num // de_num) == self: return NumNode(d)
if nu_num > 0 and de_num and (d:=nu_num//de_num) > 0: return NumNode(d) + (self-b*d) // b
if isinstance(b, Node):
for x in self.flat_components:
if x % b == 0: fully_divided.append(x // b)
else: rest.append(x)
if (b > (sum_rest:=create_rednode(SumNode, rest))).min and (sum_rest >= 0).min: return create_rednode(SumNode, fully_divided)
if (sum_fully_divided:=create_rednode(SumNode, fully_divided)) != 0: return sum_fully_divided + create_rednode(SumNode, rest) // b
return Node.__floordiv__(self, b, False)
if b == 1: return self
if not factoring_allowed: return Node.__floordiv__(self, b, factoring_allowed)
@ -238,7 +244,7 @@ class SumNode(RedNode):
if isinstance(b, SumNode):
nu_num = sum(node.b for node in self.flat_components if node.__class__ is NumNode)
de_num = sum(node.b for node in b.flat_components if node.__class__ is NumNode)
if de_num and nu_num % de_num == 0 and b * (nu_num // de_num) == self: return NumNode(0)
if nu_num > 0 and de_num and (d:=nu_num//de_num) > 0: return (self-b*d) % b
if isinstance(b, Node) and (b - self).min > 0: return self # b - self simplifies the node
new_nodes: List[Node] = []
for x in self.nodes: