mirror of https://github.com/commaai/tinygrad.git
more symbolic symbolic ops (#1564)
* more symbolic symbolic ops * handle NumNode in __mul__
This commit is contained in:
parent
dfec16cc83
commit
be50b2fe8f
|
@ -139,7 +139,7 @@ class TestSymbolicShapeExpr(unittest.TestCase):
|
|||
view = View(shape, strides)
|
||||
st = ShapeTracker(shape, [view])
|
||||
idx, valid = st.expr_idxs(idx)
|
||||
assert idx.render() == "(((1+i)*1)+(lidx1*((i*4)+4))+gidx0)"
|
||||
assert idx.render() == "((lidx1*((i*4)+4))+1+gidx0+i)"
|
||||
|
||||
class TestShapeTrackerVarVals(unittest.TestCase):
|
||||
def test_reshape_reshape_updates_var_vals(self):
|
||||
|
|
|
@ -277,25 +277,29 @@ class TestSymRender(unittest.TestCase):
|
|||
assert sym_render(a*b) == "(a*b)"
|
||||
|
||||
class TestSymbolicSymbolicOps(unittest.TestCase):
|
||||
def test_node_div_node(self):
|
||||
def test_node_divmod_node(self):
|
||||
i = Variable("i", 1, 10)
|
||||
idx0 = Variable("idx0", 0, i*3-1)
|
||||
assert NumNode(0) // (Variable("i", 1, 10)*128) == 0
|
||||
assert NumNode(127) // (Variable("i", 1, 10)*128) == 0
|
||||
assert idx0 // (i*3) == 0
|
||||
assert i // i == 1
|
||||
|
||||
def test_node_mod_node(self):
|
||||
i = Variable("i", 1, 10)
|
||||
idx0 = Variable("idx0", 0, i*3-1)
|
||||
assert NumNode(0) % (Variable("i", 1, 10)*128) == 0
|
||||
assert NumNode(127) // (Variable("i", 1, 10)*128) == 0
|
||||
assert NumNode(127) % (Variable("i", 1, 10)*128) == 127
|
||||
assert NumNode(128) % (Variable("i", 1, 10)*128 + 128) == 128
|
||||
assert 0 % (Variable("i", 1, 10)*128) == 0
|
||||
assert 127 // (Variable("i", 1, 10)*128) == 0
|
||||
assert 127 % (Variable("i", 1, 10)*128) == 127
|
||||
assert NumNode(128) // (Variable("i", 1, 10)*128 + 128) == 0
|
||||
assert NumNode(128) % (Variable("i", 1, 10)*128 + 128) == 128
|
||||
assert 128 // (Variable("i", 1, 10)*128 + 128) == 0
|
||||
assert 128 % (Variable("i", 1, 10)*128 + 128) == 128
|
||||
assert 0 // (Variable("i", 1, 10)*128) == 0
|
||||
assert 0 % (Variable("i", 1, 10)*128) == 0
|
||||
assert idx0 // (i*3) == 0
|
||||
assert idx0 % (i*3) == idx0
|
||||
assert i // i == 1
|
||||
assert i % i == 0
|
||||
assert 128 // NumNode(4) == 32
|
||||
assert 128 % NumNode(4) == 0
|
||||
assert NumNode(128) // NumNode(4) == 32
|
||||
assert NumNode(128) % NumNode(4) == 0
|
||||
|
||||
def test_mulnode_divmod_node(self):
|
||||
i = Variable("i", 1, 10)
|
||||
|
@ -311,7 +315,26 @@ class TestSymbolicSymbolicOps(unittest.TestCase):
|
|||
idx2 = Variable("idx2", 0, i)
|
||||
assert (idx0*(i*4+4)+idx1*(i+1)+idx2) // (i+1) == idx0*4+idx1
|
||||
assert (idx0*(i*4+4)+idx1*(i+1)+idx2) % (i+1) == idx2
|
||||
assert (i+1) // (i*128+128) == 0
|
||||
assert (i+1) % (i*128+128) == (i+1)
|
||||
assert (i+1+idx2) // (i+1) == 1
|
||||
assert (i+1+idx2) % (i+1) == idx2
|
||||
assert (idx0*(i*4+4)+i+1+idx2) // (i+1) == idx0*4+1
|
||||
assert (idx0*(i*4+4)+i+1+idx2) % (i+1) == idx2
|
||||
assert (i*128+128)*2 // (i*128+128) == 2
|
||||
assert (i*128+128)*2 % (i*128+128) == 0
|
||||
|
||||
def test_sumnode_divmod_sumnode_complex(self):
|
||||
i = Variable("i", 1, 1024)
|
||||
gidx0 = Variable("gidx0", 0, i)
|
||||
lidx1 = Variable("lidx1", 0, 7)
|
||||
ridx2 = Variable("ridx1", 0, 31)
|
||||
assert ((i*128+128)*2 + gidx0*128 + lidx1*(i*512+512) + ridx2*4) // (i*128+128) == 2 + lidx1*4
|
||||
assert ((i*128+128)*2 + gidx0*128 + lidx1*(i*512+512) + ridx2*4) % (i*128+128) == gidx0*128 + ridx2*4
|
||||
assert ((gidx0*128+i*128+ridx2*4+129)) // (i*128+128) == 1
|
||||
assert ((gidx0*128+i*128+ridx2*4+129)) % (i*128+128) == gidx0*128 + ridx2*4 + 1
|
||||
assert (ridx2*(i*4+4)+1+i+gidx0) // (i*128+128) == 0
|
||||
assert (ridx2*(i*4+4)+1+i+gidx0) % (i*128+128) == (ridx2*(i*4+4)+1+i+gidx0)
|
||||
|
||||
def test_node_lt_node(self):
|
||||
a = Variable("a", 1, 5)
|
||||
|
@ -330,11 +353,16 @@ class TestSymbolicSymbolicOps(unittest.TestCase):
|
|||
assert a > 3
|
||||
|
||||
def test_num_node_mul_node(self):
|
||||
a = NumNode(2)
|
||||
b = Variable("b", 1, 5)
|
||||
c = a * b
|
||||
assert c == b * 2
|
||||
assert isinstance(c, MulNode)
|
||||
a = Variable("a", 1, 5)
|
||||
b = NumNode(2) * a
|
||||
assert b == a * 2
|
||||
assert isinstance(b, MulNode)
|
||||
b = NumNode(1) * a
|
||||
assert b == a
|
||||
assert isinstance(b, Variable)
|
||||
b = NumNode(0) * a
|
||||
assert b == 0
|
||||
assert isinstance(b, NumNode)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
@ -56,17 +56,21 @@ class Node:
|
|||
def __mul__(self, b:Union[Node, int]):
|
||||
if b == 0: return NumNode(0)
|
||||
if b == 1: return self
|
||||
if self.__class__ is NumNode: return NumNode(self.b*b) if isinstance(b, int) else create_node(MulNode(b, self.b))
|
||||
return create_node(MulNode(self, b))
|
||||
if self.__class__ is NumNode: return NumNode(self.b*b) if isinstance(b, int) else b*self.b
|
||||
return create_node(MulNode(self, b.b)) if isinstance(b, NumNode) else create_node(MulNode(self, b))
|
||||
def __rmul__(self, b:int): return self*b
|
||||
|
||||
# *** complex ops ***
|
||||
|
||||
def __rfloordiv__(self, b:int): raise RuntimeError(f"not supported: {b} // {self}")
|
||||
def __rfloordiv__(self, b:int):
|
||||
if self.min > b >= 0: return NumNode(0)
|
||||
if isinstance(self, NumNode): return NumNode(b // self.b)
|
||||
raise RuntimeError(f"not supported: {b} // {self}")
|
||||
def __floordiv__(self, b:Union[Node,int], factoring_allowed=True):
|
||||
if isinstance(b, Node):
|
||||
if b.__class__ is NumNode: return self // b.b
|
||||
if self == b: return NumNode(1)
|
||||
if (b > self).min > 0 and self.min >= 0: return NumNode(0)
|
||||
if (b - self).min > 0 and self.min >= 0: return NumNode(0) # b - self simplifies the node
|
||||
raise RuntimeError(f"not supported: {self} // {b}")
|
||||
assert b != 0
|
||||
if b < 0: return (self//-b)*-1
|
||||
|
@ -81,9 +85,11 @@ class Node:
|
|||
|
||||
def __rmod__(self, b:int):
|
||||
if self.min > b >= 0: return NumNode(b)
|
||||
if isinstance(self, NumNode): return NumNode(b % self.b)
|
||||
raise RuntimeError(f"not supported: {b} % {self}")
|
||||
def __mod__(self, b:Union[Node,int]):
|
||||
if isinstance(b, Node):
|
||||
if b.__class__ is NumNode: return self % b.b
|
||||
if self == b: return NumNode(0)
|
||||
if (b - self).min > 0 and self.min >= 0: return self # b - self simplifies the node
|
||||
raise RuntimeError(f"not supported: {self} % {b}")
|
||||
|
@ -208,12 +214,12 @@ class SumNode(RedNode):
|
|||
if isinstance(b, SumNode):
|
||||
nu_num = sum(node.b for node in self.flat_components if node.__class__ is NumNode)
|
||||
de_num = sum(node.b for node in b.flat_components if node.__class__ is NumNode)
|
||||
if de_num and nu_num % de_num == 0 and b * (d := nu_num // de_num) == self: return NumNode(d)
|
||||
if nu_num > 0 and de_num and (d:=nu_num//de_num) > 0: return NumNode(d) + (self-b*d) // b
|
||||
if isinstance(b, Node):
|
||||
for x in self.flat_components:
|
||||
if x % b == 0: fully_divided.append(x // b)
|
||||
else: rest.append(x)
|
||||
if (b > (sum_rest:=create_rednode(SumNode, rest))).min and (sum_rest >= 0).min: return create_rednode(SumNode, fully_divided)
|
||||
if (sum_fully_divided:=create_rednode(SumNode, fully_divided)) != 0: return sum_fully_divided + create_rednode(SumNode, rest) // b
|
||||
return Node.__floordiv__(self, b, False)
|
||||
if b == 1: return self
|
||||
if not factoring_allowed: return Node.__floordiv__(self, b, factoring_allowed)
|
||||
|
@ -238,7 +244,7 @@ class SumNode(RedNode):
|
|||
if isinstance(b, SumNode):
|
||||
nu_num = sum(node.b for node in self.flat_components if node.__class__ is NumNode)
|
||||
de_num = sum(node.b for node in b.flat_components if node.__class__ is NumNode)
|
||||
if de_num and nu_num % de_num == 0 and b * (nu_num // de_num) == self: return NumNode(0)
|
||||
if nu_num > 0 and de_num and (d:=nu_num//de_num) > 0: return (self-b*d) % b
|
||||
if isinstance(b, Node) and (b - self).min > 0: return self # b - self simplifies the node
|
||||
new_nodes: List[Node] = []
|
||||
for x in self.nodes:
|
||||
|
|
Loading…
Reference in New Issue