mirror of https://github.com/commaai/tinygrad.git
symbolic shapetracker (#1506)
* symbolic shapetracker * no need * keep only symbolic and clean up * explicit // and % Node support * NumNode * Node
This commit is contained in:
parent
875da762a8
commit
3e0c2d256f
|
@ -1,5 +1,5 @@
|
|||
import unittest
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.shapetracker import ShapeTracker, View
|
||||
from tinygrad.shape.symbolic import Variable
|
||||
from tinygrad.tensor import Tensor
|
||||
|
||||
|
@ -99,3 +99,20 @@ class TestSymbolicExpand(unittest.TestCase):
|
|||
a = Tensor.rand(3, 4).reshape(3, vi)
|
||||
a = a + 1
|
||||
assert a.shape == (3, vi)
|
||||
|
||||
class TestSymbolicShapeExpr(unittest.TestCase):
|
||||
def test_symbolic_expr_idxs(self):
|
||||
# taken from symbolic shape llama
|
||||
i = Variable("i", 1, 120)
|
||||
gidx0 = Variable("gidx0", 0, i)
|
||||
lidx1 = Variable("lidx1", 0, 7)
|
||||
idx = (gidx0, lidx1, Variable.num(1))
|
||||
shape = (i+1, 8, 4)
|
||||
strides = (1, (i*4)+4, i+1)
|
||||
view = View(shape, strides)
|
||||
st = ShapeTracker(shape, [view])
|
||||
idx, valid = st.expr_idxs(idx)
|
||||
assert idx.render() == "(((1+i)*1)+(lidx1*((i*4)+4))+gidx0)"
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
|
@ -1,6 +1,6 @@
|
|||
#!/usr/bin/env python
|
||||
import unittest
|
||||
from tinygrad.shape.symbolic import MulNode, SumNode, Variable, NumNode, sym_vars
|
||||
from tinygrad.shape.symbolic import MulNode, SumNode, Variable, NumNode, LtNode, sym_vars, sym_render
|
||||
|
||||
class TestSymbolic(unittest.TestCase):
|
||||
def helper_test_variable(self, v, n, m, s):
|
||||
|
@ -257,8 +257,7 @@ class TestSymbolicVars(unittest.TestCase):
|
|||
a = Variable("a", 0, 10)
|
||||
b = Variable("b", 0, 10)
|
||||
c = Variable("c", 0, 10)
|
||||
# TODO: update this after we support symbolic * symbolic
|
||||
assert (a + b * c).vars() == [a, b]
|
||||
assert (a + b * c).vars() == [a, b, c]
|
||||
assert (a % 3 + b // 5).vars() == [a, b]
|
||||
assert (a + b + c - a).vars() == [b, c]
|
||||
|
||||
|
@ -276,6 +275,73 @@ class TestSymbolicMinMax(unittest.TestCase):
|
|||
assert max(1, a) == max(a, 1) == a
|
||||
assert min(1, a) == min(a, 1) == 1
|
||||
|
||||
class TestSymRender(unittest.TestCase):
|
||||
def test_sym_render(self):
|
||||
a = Variable("a", 1, 8)
|
||||
b = Variable("b", 1, 10)
|
||||
assert sym_render(a) == "a"
|
||||
assert sym_render(1) == "1"
|
||||
assert sym_render(a+1) == "(1+a)"
|
||||
assert sym_render(a*b) == "(a*b)"
|
||||
|
||||
class TestSymbolicSymbolicOps(unittest.TestCase):
|
||||
def test_node_div_node(self):
|
||||
i = Variable("i", 1, 10)
|
||||
idx0 = Variable("idx0", 0, i*3-1)
|
||||
assert NumNode(0) // (Variable("i", 1, 10)*128) == 0
|
||||
assert NumNode(127) // (Variable("i", 1, 10)*128) == 0
|
||||
assert idx0 // (i*3) == 0
|
||||
|
||||
def test_node_mod_node(self):
|
||||
i = Variable("i", 1, 10)
|
||||
idx0 = Variable("idx0", 0, i*3-1)
|
||||
assert NumNode(0) % (Variable("i", 1, 10)*128) == 0
|
||||
assert NumNode(127) % (Variable("i", 1, 10)*128) == 127
|
||||
assert NumNode(128) % (Variable("i", 1, 10)*128 + 128) == 128
|
||||
assert 0 % (Variable("i", 1, 10)*128) == 0
|
||||
assert 127 % (Variable("i", 1, 10)*128) == 127
|
||||
assert 128 % (Variable("i", 1, 10)*128 + 128) == 128
|
||||
assert idx0 % (i*3) == idx0
|
||||
assert i % i == 0
|
||||
|
||||
def test_mulnode_divmod_node(self):
|
||||
i = Variable("i", 1, 10)
|
||||
idx0 = Variable("idx0", 0, 31)
|
||||
assert (idx0*(i*4+4)) // (i+1) == (idx0*4)
|
||||
assert (idx0*(i*4+4)) % (i+1) == 0
|
||||
assert (idx0*i) % i == 0
|
||||
|
||||
def test_sumnode_divmod_sumnode(self):
|
||||
i = Variable("i", 1, 10)
|
||||
idx0 = Variable("idx0", 0, 7)
|
||||
idx1 = Variable("idx1", 0, 3)
|
||||
idx2 = Variable("idx2", 0, i)
|
||||
assert (idx0*(i*4+4)+idx1*(i+1)+idx2) // (i+1) == idx0*4+idx1
|
||||
assert (idx0*(i*4+4)+idx1*(i+1)+idx2) % (i+1) == idx2
|
||||
assert (i+1) % (i*128+128) == (i+1)
|
||||
|
||||
def test_node_lt_node(self):
|
||||
a = Variable("a", 1, 5)
|
||||
b = Variable("b", 6, 9)
|
||||
c = Variable("c", 1, 10)
|
||||
# if the value is always the same, it folds to num
|
||||
assert (a < b) == 1
|
||||
# if it remains as a LtNode, bool is always true and we need to test against min to test if it always evals to True
|
||||
assert (a < c).__class__ is LtNode and (a < c).min == 0 and (a < c).max == 1
|
||||
assert a < c
|
||||
assert not (a < c).min
|
||||
assert (a > c).__class__ is LtNode and (a > c).min == 0 and (a > c).max == 1
|
||||
assert not (a > c).min
|
||||
# same when comparing with a constant
|
||||
assert a < 3
|
||||
assert a > 3
|
||||
|
||||
def test_num_node_mul_node(self):
|
||||
a = NumNode(2)
|
||||
b = Variable("b", 1, 5)
|
||||
c = a * b
|
||||
assert c == b * 2
|
||||
assert isinstance(c, MulNode)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ from typing import List, Dict, Callable, Tuple, Type, Union, Optional, Any
|
|||
# NOTE: Python has different behavior for negative mod and floor div than c
|
||||
# symbolic matches the Python behavior, but the code output is agnostic, and will never have negative numbers in div or mod
|
||||
|
||||
def is_sym_int(x: Any) -> bool: return isinstance(x, int) or isinstance(x, Node)
|
||||
def is_sym_int(x: Any) -> bool: return isinstance(x, (int, Node))
|
||||
def sym_vars(x: Union[Node, int]) -> List[Variable]: return [] if isinstance(x, int) else x.vars()
|
||||
|
||||
class Node:
|
||||
|
@ -40,9 +40,8 @@ class Node:
|
|||
def __gt__(self, b:Union[Node,int]): return (-self) < (-b)
|
||||
def __ge__(self, b:Union[Node,int]): return (-self) < (-b+1)
|
||||
def __lt__(self, b:Union[Node,int]):
|
||||
if self == b: return NumNode(0)
|
||||
lhs = self
|
||||
if isinstance(lhs, SumNode):
|
||||
if isinstance(lhs, SumNode) and isinstance(b, int):
|
||||
muls, others = partition(lhs.nodes, lambda x: isinstance(x, MulNode) and x.b > 0 and x.max >= b)
|
||||
if len(muls):
|
||||
# NOTE: gcd in python 3.8 takes exactly 2 args
|
||||
|
@ -58,12 +57,17 @@ class Node:
|
|||
def __mul__(self, b:Union[Node, int]):
|
||||
if b == 0: return NumNode(0)
|
||||
if b == 1: return self
|
||||
if self.__class__ is NumNode: return NumNode(self.b*b) if isinstance(b, int) else create_node(MulNode(b, self.b))
|
||||
return create_node(MulNode(self, b))
|
||||
def __rmul__(self, b:int): return self*b
|
||||
|
||||
# *** complex ops ***
|
||||
|
||||
def __floordiv__(self, b:int, factoring_allowed=True):
|
||||
def __rfloordiv__(self, b:int): raise RuntimeError(f"not supported: {b} // {self}")
|
||||
def __floordiv__(self, b:Union[Node,int], factoring_allowed=True):
|
||||
if isinstance(b, Node):
|
||||
if (b > self).min > 0 and self.min >= 0: return NumNode(0)
|
||||
raise RuntimeError(f"not supported: {self} // {b}")
|
||||
assert b != 0
|
||||
if b < 0: return (self//-b)*-1
|
||||
if b == 1: return self
|
||||
|
@ -75,7 +79,14 @@ class Node:
|
|||
return (self + -offset*b).__floordiv__(b, factoring_allowed=False) + offset
|
||||
return create_node(DivNode(self, b))
|
||||
|
||||
def __mod__(self, b:int):
|
||||
def __rmod__(self, b:int):
|
||||
if self.min > b >= 0: return NumNode(b)
|
||||
raise RuntimeError(f"not supported: {b} % {self}")
|
||||
def __mod__(self, b:Union[Node,int]):
|
||||
if isinstance(b, Node):
|
||||
if self == b: return NumNode(0)
|
||||
if (b - self).min > 0 and self.min >= 0: return self # b - self simplifies the node
|
||||
raise RuntimeError(f"not supported: {self} % {b}")
|
||||
assert b > 0
|
||||
if b == 1: return NumNode(0)
|
||||
if self.min >= 0 and self.max < b: return self
|
||||
|
@ -137,6 +148,8 @@ class NumNode(Node):
|
|||
def __init__(self, num:int):
|
||||
self.b, self.min, self.max = num, num, num
|
||||
def __int__(self): return self.b
|
||||
def __eq__(self, other): return self.b == other
|
||||
def __hash__(self): return self.hash # needed with __eq__ override
|
||||
|
||||
def create_node(ret:Node):
|
||||
assert ret.min <= ret.max, f"min greater than max! {ret.min} {ret.max} when creating {type(ret)} {ret}"
|
||||
|
@ -147,38 +160,37 @@ class OpNode(Node):
|
|||
def __init__(self, a:Node, b:Union[Node, int]):
|
||||
self.a, self.b = a, b
|
||||
self.min, self.max = self.get_bounds()
|
||||
def vars(self): return self.a.vars()
|
||||
def vars(self): return self.a.vars() + (self.b.vars() if isinstance(self.b, Node) else [])
|
||||
@abstractmethod
|
||||
def get_bounds(self) -> Tuple[int, int]: pass
|
||||
|
||||
class LtNode(OpNode):
|
||||
def __mul__(self, b: Union[Node, int]): return (self.a*b) < (self.b*b)
|
||||
def __floordiv__(self, b: int, _=False): return (self.a//b) < (self.b//b)
|
||||
def __floordiv__(self, b: Union[Node, int], _=False): return (self.a//b) < (self.b//b)
|
||||
def get_bounds(self) -> Tuple[int, int]:
|
||||
if isinstance(self.b, int): return int(self.a.max < self.b), int(self.a.min < self.b)
|
||||
return (1, 1) if self.a.max < self.b.min else (0, 0) if self.a.min > self.b.max else (0, 1)
|
||||
|
||||
class MulNode(OpNode):
|
||||
def __mul__(self, b: Union[Node, int]): return self.a*(self.b*b) # two muls in one mul
|
||||
def __floordiv__(self, b: int, factoring_allowed=False): # NOTE: mod negative isn't handled right
|
||||
assert isinstance(self.b, int)
|
||||
def __floordiv__(self, b: Union[Node, int], factoring_allowed=False): # NOTE: mod negative isn't handled right
|
||||
if self.b % b == 0: return self.a*(self.b//b)
|
||||
if b % self.b == 0 and self.b > 0: return self.a//(b//self.b)
|
||||
return Node.__floordiv__(self, b, factoring_allowed)
|
||||
def __mod__(self, b: int):
|
||||
def __mod__(self, b: Union[Node, int]):
|
||||
a = (self.a * (self.b%b))
|
||||
return Node.__mod__(a, b)
|
||||
def get_bounds(self) -> Tuple[int, int]:
|
||||
return (self.a.min*self.b, self.a.max*self.b) if self.b >= 0 else (self.a.max*self.b, self.a.min*self.b)
|
||||
|
||||
class DivNode(OpNode):
|
||||
def __floordiv__(self, b: int, _=False): return self.a//(self.b*b) # two divs is one div
|
||||
def __floordiv__(self, b: Union[Node, int], _=False): return self.a//(self.b*b) # two divs is one div
|
||||
def get_bounds(self) -> Tuple[int, int]:
|
||||
assert self.a.min >= 0 and isinstance(self.b, int)
|
||||
return self.a.min//self.b, self.a.max//self.b
|
||||
|
||||
class ModNode(OpNode):
|
||||
def __floordiv__(self, b: int, factoring_allowed=True):
|
||||
def __floordiv__(self, b: Union[Node, int], factoring_allowed=True):
|
||||
if (self.b % b == 0): return (self.a//b) % (self.b//b) # put the div inside mod
|
||||
return Node.__floordiv__(self, b, factoring_allowed)
|
||||
def get_bounds(self) -> Tuple[int, int]:
|
||||
|
@ -191,11 +203,22 @@ class RedNode(Node):
|
|||
|
||||
class SumNode(RedNode):
|
||||
def __mul__(self, b: Union[Node, int]): return Node.sum([x*b for x in self.nodes]) # distribute mul into sum
|
||||
def __floordiv__(self, b: int, factoring_allowed=True):
|
||||
if b == 1: return self
|
||||
if not factoring_allowed: return Node.__floordiv__(self, b, factoring_allowed)
|
||||
def __floordiv__(self, b: Union[Node, int], factoring_allowed=True):
|
||||
fully_divided: List[Node] = []
|
||||
rest: List[Node] = []
|
||||
if isinstance(b, SumNode):
|
||||
nu_num = sum(node.b for node in self.flat_components if node.__class__ is NumNode)
|
||||
de_num = sum(node.b for node in b.flat_components if node.__class__ is NumNode)
|
||||
if de_num and nu_num % de_num == 0 and b * (d := nu_num // de_num) == self: return NumNode(d)
|
||||
if isinstance(b, Node):
|
||||
for x in self.flat_components:
|
||||
if x % b == 0: fully_divided.append(x // b)
|
||||
else: rest.append(x)
|
||||
if (b > (sum_rest:=create_rednode(SumNode, rest))).min and (sum_rest >= 0).min: return create_rednode(SumNode, fully_divided)
|
||||
return Node.__floordiv__(self, b, False)
|
||||
if b == 1: return self
|
||||
if not factoring_allowed: return Node.__floordiv__(self, b, factoring_allowed)
|
||||
fully_divided, rest = [], []
|
||||
_gcd = b
|
||||
divisor = 1
|
||||
for x in self.flat_components:
|
||||
|
@ -212,7 +235,12 @@ class SumNode(RedNode):
|
|||
if divisor > 1: return Node.sum(fully_divided) + Node.sum(rest).__floordiv__(divisor) // (b//divisor)
|
||||
return Node.sum(fully_divided) + Node.__floordiv__(Node.sum(rest), b)
|
||||
|
||||
def __mod__(self, b: int):
|
||||
def __mod__(self, b: Union[Node, int]):
|
||||
if isinstance(b, SumNode):
|
||||
nu_num = sum(node.b for node in self.flat_components if node.__class__ is NumNode)
|
||||
de_num = sum(node.b for node in b.flat_components if node.__class__ is NumNode)
|
||||
if de_num and nu_num % de_num == 0 and b * (nu_num // de_num) == self: return NumNode(0)
|
||||
if isinstance(b, Node) and (b - self).min > 0: return self # b - self simplifies the node
|
||||
new_nodes: List[Node] = []
|
||||
for x in self.nodes:
|
||||
if x.__class__ is NumNode: new_nodes.append(Variable.num(x.b%b))
|
||||
|
@ -228,7 +256,7 @@ class SumNode(RedNode):
|
|||
|
||||
class AndNode(RedNode):
|
||||
def __mul__(self, b: Union[Node, int]): Variable.ands([x*b for x in self.nodes])
|
||||
def __floordiv__(self, b: int, _=True): return Variable.ands([x//b for x in self.nodes])
|
||||
def __floordiv__(self, b: Union[Node, int], _=True): return Variable.ands([x//b for x in self.nodes])
|
||||
|
||||
def create_rednode(typ:Type[RedNode], nodes:List[Node]):
|
||||
ret = typ(nodes)
|
||||
|
@ -236,13 +264,15 @@ def create_rednode(typ:Type[RedNode], nodes:List[Node]):
|
|||
elif typ == AndNode: ret.min, ret.max = (min([x.min for x in nodes]), max([x.max for x in nodes]))
|
||||
return create_node(ret)
|
||||
|
||||
def sym_render(a: Union[Node, int], ops=None, ctx=None) -> str: return str(a) if isinstance(a, int) else a.render(ops, ctx)
|
||||
|
||||
render_python: Dict[Type, Callable] = {
|
||||
Variable: lambda self,ops,ctx: f"{self.expr}[{self.min}-{self.max}]" if ctx == "DEBUG" else f"{self.expr}",
|
||||
NumNode: lambda self,ops,ctx: f"{self.b}",
|
||||
MulNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}*{self.b})",
|
||||
MulNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}*{sym_render(self.b,ops,ctx)})",
|
||||
DivNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}//{self.b})",
|
||||
ModNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}%{self.b})",
|
||||
LtNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}<{self.b})",
|
||||
LtNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}<{sym_render(self.b,ops,ctx)})",
|
||||
SumNode: lambda self,ops,ctx: f"({'+'.join(sorted([x.render(ops,ctx) for x in self.nodes]))})",
|
||||
AndNode: lambda self,ops,ctx: f"({' and '.join(sorted([x.render(ops,ctx) for x in self.nodes]))})"
|
||||
}
|
Loading…
Reference in New Issue