symbolic shapetracker (#1506)

* symbolic shapetracker

* no need

* keep only symbolic and clean up

* explicit // and % Node support

* NumNode * Node
This commit is contained in:
chenyu 2023-08-12 12:22:58 -07:00 committed by GitHub
parent 875da762a8
commit 3e0c2d256f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 138 additions and 25 deletions

View File

@ -1,5 +1,5 @@
import unittest
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.shapetracker import ShapeTracker, View
from tinygrad.shape.symbolic import Variable
from tinygrad.tensor import Tensor
@ -98,4 +98,21 @@ class TestSymbolicExpand(unittest.TestCase):
vi = Variable("i", 1, 10)
a = Tensor.rand(3, 4).reshape(3, vi)
a = a + 1
assert a.shape == (3, vi)
assert a.shape == (3, vi)
class TestSymbolicShapeExpr(unittest.TestCase):
def test_symbolic_expr_idxs(self):
# taken from symbolic shape llama
i = Variable("i", 1, 120)
gidx0 = Variable("gidx0", 0, i)
lidx1 = Variable("lidx1", 0, 7)
idx = (gidx0, lidx1, Variable.num(1))
shape = (i+1, 8, 4)
strides = (1, (i*4)+4, i+1)
view = View(shape, strides)
st = ShapeTracker(shape, [view])
idx, valid = st.expr_idxs(idx)
assert idx.render() == "(((1+i)*1)+(lidx1*((i*4)+4))+gidx0)"
if __name__ == '__main__':
unittest.main()

View File

@ -1,6 +1,6 @@
#!/usr/bin/env python
import unittest
from tinygrad.shape.symbolic import MulNode, SumNode, Variable, NumNode, sym_vars
from tinygrad.shape.symbolic import MulNode, SumNode, Variable, NumNode, LtNode, sym_vars, sym_render
class TestSymbolic(unittest.TestCase):
def helper_test_variable(self, v, n, m, s):
@ -257,8 +257,7 @@ class TestSymbolicVars(unittest.TestCase):
a = Variable("a", 0, 10)
b = Variable("b", 0, 10)
c = Variable("c", 0, 10)
# TODO: update this after we support symbolic * symbolic
assert (a + b * c).vars() == [a, b]
assert (a + b * c).vars() == [a, b, c]
assert (a % 3 + b // 5).vars() == [a, b]
assert (a + b + c - a).vars() == [b, c]
@ -276,6 +275,73 @@ class TestSymbolicMinMax(unittest.TestCase):
assert max(1, a) == max(a, 1) == a
assert min(1, a) == min(a, 1) == 1
class TestSymRender(unittest.TestCase):
def test_sym_render(self):
a = Variable("a", 1, 8)
b = Variable("b", 1, 10)
assert sym_render(a) == "a"
assert sym_render(1) == "1"
assert sym_render(a+1) == "(1+a)"
assert sym_render(a*b) == "(a*b)"
class TestSymbolicSymbolicOps(unittest.TestCase):
def test_node_div_node(self):
i = Variable("i", 1, 10)
idx0 = Variable("idx0", 0, i*3-1)
assert NumNode(0) // (Variable("i", 1, 10)*128) == 0
assert NumNode(127) // (Variable("i", 1, 10)*128) == 0
assert idx0 // (i*3) == 0
def test_node_mod_node(self):
i = Variable("i", 1, 10)
idx0 = Variable("idx0", 0, i*3-1)
assert NumNode(0) % (Variable("i", 1, 10)*128) == 0
assert NumNode(127) % (Variable("i", 1, 10)*128) == 127
assert NumNode(128) % (Variable("i", 1, 10)*128 + 128) == 128
assert 0 % (Variable("i", 1, 10)*128) == 0
assert 127 % (Variable("i", 1, 10)*128) == 127
assert 128 % (Variable("i", 1, 10)*128 + 128) == 128
assert idx0 % (i*3) == idx0
assert i % i == 0
def test_mulnode_divmod_node(self):
i = Variable("i", 1, 10)
idx0 = Variable("idx0", 0, 31)
assert (idx0*(i*4+4)) // (i+1) == (idx0*4)
assert (idx0*(i*4+4)) % (i+1) == 0
assert (idx0*i) % i == 0
def test_sumnode_divmod_sumnode(self):
i = Variable("i", 1, 10)
idx0 = Variable("idx0", 0, 7)
idx1 = Variable("idx1", 0, 3)
idx2 = Variable("idx2", 0, i)
assert (idx0*(i*4+4)+idx1*(i+1)+idx2) // (i+1) == idx0*4+idx1
assert (idx0*(i*4+4)+idx1*(i+1)+idx2) % (i+1) == idx2
assert (i+1) % (i*128+128) == (i+1)
def test_node_lt_node(self):
a = Variable("a", 1, 5)
b = Variable("b", 6, 9)
c = Variable("c", 1, 10)
# if the value is always the same, it folds to num
assert (a < b) == 1
# if it remains as a LtNode, bool is always true and we need to test against min to test if it always evals to True
assert (a < c).__class__ is LtNode and (a < c).min == 0 and (a < c).max == 1
assert a < c
assert not (a < c).min
assert (a > c).__class__ is LtNode and (a > c).min == 0 and (a > c).max == 1
assert not (a > c).min
# same when comparing with a constant
assert a < 3
assert a > 3
def test_num_node_mul_node(self):
a = NumNode(2)
b = Variable("b", 1, 5)
c = a * b
assert c == b * 2
assert isinstance(c, MulNode)
if __name__ == '__main__':
unittest.main()

View File

@ -8,7 +8,7 @@ from typing import List, Dict, Callable, Tuple, Type, Union, Optional, Any
# NOTE: Python has different behavior for negative mod and floor div than c
# symbolic matches the Python behavior, but the code output is agnostic, and will never have negative numbers in div or mod
def is_sym_int(x: Any) -> bool: return isinstance(x, int) or isinstance(x, Node)
def is_sym_int(x: Any) -> bool: return isinstance(x, (int, Node))
def sym_vars(x: Union[Node, int]) -> List[Variable]: return [] if isinstance(x, int) else x.vars()
class Node:
@ -40,9 +40,8 @@ class Node:
def __gt__(self, b:Union[Node,int]): return (-self) < (-b)
def __ge__(self, b:Union[Node,int]): return (-self) < (-b+1)
def __lt__(self, b:Union[Node,int]):
if self == b: return NumNode(0)
lhs = self
if isinstance(lhs, SumNode):
if isinstance(lhs, SumNode) and isinstance(b, int):
muls, others = partition(lhs.nodes, lambda x: isinstance(x, MulNode) and x.b > 0 and x.max >= b)
if len(muls):
# NOTE: gcd in python 3.8 takes exactly 2 args
@ -58,12 +57,17 @@ class Node:
def __mul__(self, b:Union[Node, int]):
if b == 0: return NumNode(0)
if b == 1: return self
if self.__class__ is NumNode: return NumNode(self.b*b) if isinstance(b, int) else create_node(MulNode(b, self.b))
return create_node(MulNode(self, b))
def __rmul__(self, b:int): return self*b
# *** complex ops ***
def __floordiv__(self, b:int, factoring_allowed=True):
def __rfloordiv__(self, b:int): raise RuntimeError(f"not supported: {b} // {self}")
def __floordiv__(self, b:Union[Node,int], factoring_allowed=True):
if isinstance(b, Node):
if (b > self).min > 0 and self.min >= 0: return NumNode(0)
raise RuntimeError(f"not supported: {self} // {b}")
assert b != 0
if b < 0: return (self//-b)*-1
if b == 1: return self
@ -75,7 +79,14 @@ class Node:
return (self + -offset*b).__floordiv__(b, factoring_allowed=False) + offset
return create_node(DivNode(self, b))
def __mod__(self, b:int):
def __rmod__(self, b:int):
if self.min > b >= 0: return NumNode(b)
raise RuntimeError(f"not supported: {b} % {self}")
def __mod__(self, b:Union[Node,int]):
if isinstance(b, Node):
if self == b: return NumNode(0)
if (b - self).min > 0 and self.min >= 0: return self # b - self simplifies the node
raise RuntimeError(f"not supported: {self} % {b}")
assert b > 0
if b == 1: return NumNode(0)
if self.min >= 0 and self.max < b: return self
@ -137,6 +148,8 @@ class NumNode(Node):
def __init__(self, num:int):
self.b, self.min, self.max = num, num, num
def __int__(self): return self.b
def __eq__(self, other): return self.b == other
def __hash__(self): return self.hash # needed with __eq__ override
def create_node(ret:Node):
assert ret.min <= ret.max, f"min greater than max! {ret.min} {ret.max} when creating {type(ret)} {ret}"
@ -147,38 +160,37 @@ class OpNode(Node):
def __init__(self, a:Node, b:Union[Node, int]):
self.a, self.b = a, b
self.min, self.max = self.get_bounds()
def vars(self): return self.a.vars()
def vars(self): return self.a.vars() + (self.b.vars() if isinstance(self.b, Node) else [])
@abstractmethod
def get_bounds(self) -> Tuple[int, int]: pass
class LtNode(OpNode):
def __mul__(self, b: Union[Node, int]): return (self.a*b) < (self.b*b)
def __floordiv__(self, b: int, _=False): return (self.a//b) < (self.b//b)
def __floordiv__(self, b: Union[Node, int], _=False): return (self.a//b) < (self.b//b)
def get_bounds(self) -> Tuple[int, int]:
if isinstance(self.b, int): return int(self.a.max < self.b), int(self.a.min < self.b)
return (1, 1) if self.a.max < self.b.min else (0, 0) if self.a.min > self.b.max else (0, 1)
class MulNode(OpNode):
def __mul__(self, b: Union[Node, int]): return self.a*(self.b*b) # two muls in one mul
def __floordiv__(self, b: int, factoring_allowed=False): # NOTE: mod negative isn't handled right
assert isinstance(self.b, int)
def __floordiv__(self, b: Union[Node, int], factoring_allowed=False): # NOTE: mod negative isn't handled right
if self.b % b == 0: return self.a*(self.b//b)
if b % self.b == 0 and self.b > 0: return self.a//(b//self.b)
return Node.__floordiv__(self, b, factoring_allowed)
def __mod__(self, b: int):
def __mod__(self, b: Union[Node, int]):
a = (self.a * (self.b%b))
return Node.__mod__(a, b)
def get_bounds(self) -> Tuple[int, int]:
return (self.a.min*self.b, self.a.max*self.b) if self.b >= 0 else (self.a.max*self.b, self.a.min*self.b)
class DivNode(OpNode):
def __floordiv__(self, b: int, _=False): return self.a//(self.b*b) # two divs is one div
def __floordiv__(self, b: Union[Node, int], _=False): return self.a//(self.b*b) # two divs is one div
def get_bounds(self) -> Tuple[int, int]:
assert self.a.min >= 0 and isinstance(self.b, int)
return self.a.min//self.b, self.a.max//self.b
class ModNode(OpNode):
def __floordiv__(self, b: int, factoring_allowed=True):
def __floordiv__(self, b: Union[Node, int], factoring_allowed=True):
if (self.b % b == 0): return (self.a//b) % (self.b//b) # put the div inside mod
return Node.__floordiv__(self, b, factoring_allowed)
def get_bounds(self) -> Tuple[int, int]:
@ -191,11 +203,22 @@ class RedNode(Node):
class SumNode(RedNode):
def __mul__(self, b: Union[Node, int]): return Node.sum([x*b for x in self.nodes]) # distribute mul into sum
def __floordiv__(self, b: int, factoring_allowed=True):
if b == 1: return self
if not factoring_allowed: return Node.__floordiv__(self, b, factoring_allowed)
def __floordiv__(self, b: Union[Node, int], factoring_allowed=True):
fully_divided: List[Node] = []
rest: List[Node] = []
if isinstance(b, SumNode):
nu_num = sum(node.b for node in self.flat_components if node.__class__ is NumNode)
de_num = sum(node.b for node in b.flat_components if node.__class__ is NumNode)
if de_num and nu_num % de_num == 0 and b * (d := nu_num // de_num) == self: return NumNode(d)
if isinstance(b, Node):
for x in self.flat_components:
if x % b == 0: fully_divided.append(x // b)
else: rest.append(x)
if (b > (sum_rest:=create_rednode(SumNode, rest))).min and (sum_rest >= 0).min: return create_rednode(SumNode, fully_divided)
return Node.__floordiv__(self, b, False)
if b == 1: return self
if not factoring_allowed: return Node.__floordiv__(self, b, factoring_allowed)
fully_divided, rest = [], []
_gcd = b
divisor = 1
for x in self.flat_components:
@ -212,7 +235,12 @@ class SumNode(RedNode):
if divisor > 1: return Node.sum(fully_divided) + Node.sum(rest).__floordiv__(divisor) // (b//divisor)
return Node.sum(fully_divided) + Node.__floordiv__(Node.sum(rest), b)
def __mod__(self, b: int):
def __mod__(self, b: Union[Node, int]):
if isinstance(b, SumNode):
nu_num = sum(node.b for node in self.flat_components if node.__class__ is NumNode)
de_num = sum(node.b for node in b.flat_components if node.__class__ is NumNode)
if de_num and nu_num % de_num == 0 and b * (nu_num // de_num) == self: return NumNode(0)
if isinstance(b, Node) and (b - self).min > 0: return self # b - self simplifies the node
new_nodes: List[Node] = []
for x in self.nodes:
if x.__class__ is NumNode: new_nodes.append(Variable.num(x.b%b))
@ -228,7 +256,7 @@ class SumNode(RedNode):
class AndNode(RedNode):
def __mul__(self, b: Union[Node, int]): Variable.ands([x*b for x in self.nodes])
def __floordiv__(self, b: int, _=True): return Variable.ands([x//b for x in self.nodes])
def __floordiv__(self, b: Union[Node, int], _=True): return Variable.ands([x//b for x in self.nodes])
def create_rednode(typ:Type[RedNode], nodes:List[Node]):
ret = typ(nodes)
@ -236,13 +264,15 @@ def create_rednode(typ:Type[RedNode], nodes:List[Node]):
elif typ == AndNode: ret.min, ret.max = (min([x.min for x in nodes]), max([x.max for x in nodes]))
return create_node(ret)
def sym_render(a: Union[Node, int], ops=None, ctx=None) -> str: return str(a) if isinstance(a, int) else a.render(ops, ctx)
render_python: Dict[Type, Callable] = {
Variable: lambda self,ops,ctx: f"{self.expr}[{self.min}-{self.max}]" if ctx == "DEBUG" else f"{self.expr}",
NumNode: lambda self,ops,ctx: f"{self.b}",
MulNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}*{self.b})",
MulNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}*{sym_render(self.b,ops,ctx)})",
DivNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}//{self.b})",
ModNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}%{self.b})",
LtNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}<{self.b})",
LtNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}<{sym_render(self.b,ops,ctx)})",
SumNode: lambda self,ops,ctx: f"({'+'.join(sorted([x.render(ops,ctx) for x in self.nodes]))})",
AndNode: lambda self,ops,ctx: f"({' and '.join(sorted([x.render(ops,ctx) for x in self.nodes]))})"
}