symbolic stride (#1326)

This commit is contained in:
chenyu 2023-07-23 15:41:22 -04:00 committed by GitHub
parent 2d4e182294
commit aa05495620
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 78 additions and 24 deletions

View File

@ -32,6 +32,14 @@ def ge(expr, rng=None):
if rng is None: rng = random.randint(-4,4)
return expr >= rng, rng
def le(expr, rng=None):
  """Apply a `<=` comparison to `expr`.

  Args:
    expr: left-hand side of the comparison.
    rng: right-hand side; when None, a random integer in [-4, 4] is drawn.

  Returns:
    The tuple `(expr <= rng, rng)` so the caller also learns which bound was used.
  """
  if rng is None: rng = random.randint(-4,4)
  return expr <= rng, rng
def gt(expr, rng=None):
  """Apply a `>` comparison to `expr`.

  Args:
    expr: left-hand side of the comparison.
    rng: right-hand side; when None, a random integer in [-4, 4] is drawn.

  Returns:
    The tuple `(expr > rng, rng)` so the caller also learns which bound was used.
  """
  if rng is None: rng = random.randint(-4,4)
  return expr > rng, rng
if __name__ == "__main__":
ops = [add_v, div, mul, add_num, mod]
for _ in range(1000):
@ -41,9 +49,8 @@ if __name__ == "__main__":
u3 = Variable("v3", 0, random.choice(upper_bounds))
v = [u1,u2,u3]
tape = [random.choice(ops) for _ in range(random.randint(2, 30))]
# 10% of the time, add a less than or greater than
if random.random() < 0.05: tape.append(lt)
elif random.random() < 0.05: tape.append(ge)
# 10% of the time, add one of lt, le, gt, ge
if random.random() < 0.1: tape.append(random.choice([lt, le, gt, ge]))
expr = Variable.num(0)
rngs = []
for t in tape:

View File

@ -0,0 +1,36 @@
import unittest
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import Variable
from tinygrad.tensor import Tensor
class TestSymbolic(unittest.TestCase):
  """Checks ShapeTracker shapes, strides and index expressions when dimensions are symbolic Variables."""

  def test_symbolic_st(self):
    # a (x, 3) shape with symbolic x keeps the symbolic shape and reports strides (3, 1)
    x = Variable("x", 1, 100)
    st = ShapeTracker((x, 3))
    self.assertEqual(st.shape, (x, 3))
    self.assertEqual(st.real_strides(), (3, 1))

  def test_expr_idxs(self):
    x = Variable("x", 1, 100)
    st = ShapeTracker((x, 3))
    idxs = [Variable("x", 0, 100), Variable("y", 0, 100)]
    e1, e2 = st.expr_idxs(idxs)
    self.assertEqual(e1.render(), "((x*3)+y)")
    self.assertEqual(e2.render(), "1")
    # the result of permute is not rebound; the same `st` is queried again and
    # is expected to render with the two index variables swapped
    st.permute((1, 0))
    e1, e2 = st.expr_idxs(idxs)
    self.assertEqual(e1.render(), "((y*3)+x)")
    self.assertEqual(e2.render(), "1")

  def test_cat_strides(self):
    i = Variable("i", 1, 5)
    j = Variable("j", 1, 5)
    k = Variable("k", 1, 5)
    # concatenate along dim 0: the symbolic sizes add up in the first dimension
    t1 = Tensor.rand(3, 4).reshape(i, 4).cat(Tensor.rand(3, 4).reshape(j, 4), dim=0).cat(Tensor.rand(3, 4).reshape(k, 4), dim=0)
    st = t1.lazydata.st
    self.assertEqual(st.shape, (i+j+k, 4))
    self.assertEqual(st.real_strides(), (4, 1))
    # concatenate along dim 1: the leading stride becomes the symbolic sum
    t1 = Tensor.rand(3, 4).reshape(3, i).cat(Tensor.rand(3, 4).reshape(3, j), dim=1).cat(Tensor.rand(3, 4).reshape(3, k), dim=1)
    st = t1.lazydata.st
    self.assertEqual(st.shape, (3, i+j+k))
    self.assertEqual(st.real_strides(), (i+j+k, 1))

View File

@ -4,7 +4,7 @@ from enum import Enum, auto
import functools
from typing import Dict, Tuple, Union, List, Optional, Callable, cast, NamedTuple
from tinygrad.helpers import prod, DEBUG
from tinygrad.shape.symbolic import Variable, MulNode, NumNode, Node, SumNode
from tinygrad.shape.symbolic import Variable, MulNode, NumNode, Node, SumNode, is_sym_int
# these ops live here
class MovementOps(Enum):
  """The movement ops, numbered by auto() in declaration order."""
  RESHAPE = auto()
  PERMUTE = auto()
  EXPAND = auto()
  PAD = auto()
  SHRINK = auto()
  STRIDE = auto()
@ -86,8 +86,8 @@ def strides_for_shape(shape:Tuple[int, ...]) -> Tuple[int, ...]:
return tuple([st if s != 1 else 0 for st, s in zip(strides, shape)])
@functools.lru_cache(maxsize=None)
def view_from_shape(shape:Tuple[Union[Node, int], ...]) -> View:
  """Build the View for `shape`, with strides taken from strides_for_shape(shape).

  Every dimension must satisfy is_sym_int (a plain int or a symbolic Node).
  Results are memoized by functools.lru_cache, so `shape` must be hashable.
  """
  assert all(is_sym_int(x) for x in shape)
  return View(tuple(shape), strides_for_shape(shape))
@functools.lru_cache(maxsize=None)
@ -158,11 +158,11 @@ class ShapeTracker:
return real_offset.b
# NOTE: if a stride is not always valid, it will be None
def real_strides(self, ignore_valid=False) -> Tuple[Optional[int], ...]:
def real_strides(self, ignore_valid=False) -> Tuple[Optional[Union[Node, int]], ...]:
if len(self.views) == 1 and self.views[-1].mask is None: return self.views[-1].strides
idxs = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(self.shape)]
idx, valid = self.expr_idxs(idxs)
ret: List[Optional[int]] = [None] * len(self.views[-1].shape)
ret: List[Optional[Union[Node, int]]] = [None] * len(self.views[-1].shape)
for this_dim in (idx.nodes if isinstance(idx, SumNode) else [idx]):
if isinstance(this_dim, MulNode) and isinstance(this_dim.a, Variable):
ret[idxs.index(this_dim.a)] = this_dim.b
@ -235,8 +235,10 @@ class ShapeTracker:
def reshape(self, new_shape: Tuple[int, ...]):
if self.views[-1].shape == new_shape: return self
assert all(isinstance(x, int) and x > 0 for x in new_shape), f"shape must be ints and can't contain 0 or negative numbers {new_shape}"
assert prod(self.shape) == prod(new_shape), f"can't reshape {self.shape} -> {new_shape}"
assert all(is_sym_int(x) and x > 0 for x in new_shape), f"shape must be symbolic ints and can't contain 0 or negative numbers {new_shape}"
# only check size for int shapes. we don't check symbolic here as long as the reshape itself can be done
if all(isinstance(s, int) for s in self.shape) and all(isinstance(s, int) for s in new_shape):
assert prod(self.shape) == prod(new_shape), f"can't reshape {self.shape} -> {new_shape}"
new_view, extra = _reshape(self.views[-1], new_shape)
if extra: self.views.append(new_view)
else: self.views[-1] = new_view

View File

@ -3,13 +3,15 @@ from abc import abstractmethod
import functools
from math import gcd
from tinygrad.helpers import partition
from typing import List, Dict, Callable, Tuple, Type, Union, Optional
from typing import List, Dict, Callable, Tuple, Type, Union, Optional, Any
# NOTE: Python has different behavior for negative mod and floor div than C.
# symbolic matches the Python behavior, but the code output is agnostic, and will never have negative numbers in div or mod.
def is_sym_int(x: Any) -> bool:
  """Return True when `x` is a Python int or a symbolic Node."""
  if isinstance(x, int): return True
  return isinstance(x, Node)
class Node:
b: int
b: Union[Node, int]
min: int
max: int
def render(self, ops=None, ctx=None, strip_parens=False) -> str:
@ -29,10 +31,14 @@ class Node:
if not isinstance(other, Node): return NotImplemented
return self.key == other.key
def __neg__(self): return self*-1
def __add__(self, b:Union[Node, int]): return Variable.sum([self, b if isinstance(b, Node) else Variable.num(b)])
def __sub__(self, b:Union[Node, int]): return self+-b
def __ge__(self, b:int): return (-self) < (-b+1)
def __lt__(self, b:int):
def __add__(self, b:Union[Node,int]): return Variable.sum([self, b if isinstance(b, Node) else Variable.num(b)])
def __radd__(self, b:int): return self+b
def __sub__(self, b:Union[Node,int]): return self+-b
def __le__(self, b:Union[Node,int]): return self < (b+1)
def __gt__(self, b:Union[Node,int]): return (-self) < (-b)
def __ge__(self, b:Union[Node,int]): return (-self) < (-b+1)
def __lt__(self, b:Union[Node,int]):
if self == b: return False
lhs = self
if isinstance(lhs, SumNode):
muls, others = partition(lhs.nodes, lambda x: isinstance(x, MulNode) and x.b > 0 and x.max >= b)
@ -47,10 +53,11 @@ class Node:
# TODO: should we divide both by mul_gcd here?
lhs = Variable.sum(muls)
return create_node(LtNode(lhs, b))
def __mul__(self, b:int):
def __mul__(self, b:Union[Node, int]):
if b == 0: return NumNode(0)
if b == 1: return self
return create_node(MulNode(self, b))
def __rmul__(self, b:int): return self*b
# *** complex ops ***
@ -132,6 +139,7 @@ class Variable(Node):
class NumNode(Node):
  """Node for a constant: `b`, `min` and `max` are all set to the same number."""
  def __init__(self, num:int):
    self.b, self.min, self.max = num, num, num
  # lets int(node) recover the stored value
  def __int__(self): return self.b
def create_node(ret:Node):
assert ret.min <= ret.max, f"min greater than max! {ret.min} {ret.max} when creating {type(ret)} {ret}"
@ -139,7 +147,7 @@ def create_node(ret:Node):
return ret
class OpNode(Node):
def __init__(self, a:Node, b:int):
def __init__(self, a:Node, b:Union[Node, int]):
self.a, self.b = a, b
self.min, self.max = self.get_bounds()
def vars(self): return self.a.vars()
@ -147,13 +155,14 @@ class OpNode(Node):
def get_bounds(self) -> Tuple[int, int]: pass
class LtNode(OpNode):
  """OpNode for the comparison `a < b`."""
  # scaling or dividing a comparison applies the operation to both sides and compares again
  def __mul__(self, b: Union[Node, int]): return (self.a*b) < (self.b*b)
  def __floordiv__(self, b: int, _=False): return (self.a//b) < (self.b//b)
  # bounds are the 0/1 results of comparing a.max and a.min against b
  def get_bounds(self) -> Tuple[int, int]: return int(self.a.max < self.b), int(self.a.min < self.b)
class MulNode(OpNode):
def __mul__(self, b: int): return self.a*(self.b*b) # two muls in one mul
def __mul__(self, b: Union[Node, int]): return self.a*(self.b*b) # two muls in one mul
def __floordiv__(self, b: int, factoring_allowed=False): # NOTE: mod negative isn't handled right
assert isinstance(self.b, int)
if self.b % b == 0: return self.a*(self.b//b)
if b % self.b == 0 and self.b > 0: return self.a//(b//self.b)
return Node.__floordiv__(self, b, factoring_allowed)
@ -166,7 +175,7 @@ class MulNode(OpNode):
class DivNode(OpNode):
  """OpNode for the floor division `a // b`."""
  def __floordiv__(self, b: int, _=False): return self.a//(self.b*b) # two divs is one div
  def get_bounds(self) -> Tuple[int, int]:
    # bounds are only computed for a non-negative numerator and a plain int divisor
    assert self.a.min >= 0 and isinstance(self.b, int)
    return self.a.min//self.b, self.a.max//self.b
class ModNode(OpNode):
@ -174,7 +183,7 @@ class ModNode(OpNode):
if (self.b % b == 0): return (self.a//b) % (self.b//b) # put the div inside mod
return Node.__floordiv__(self, b, factoring_allowed)
def get_bounds(self) -> Tuple[int, int]:
assert self.a.min >= 0
assert self.a.min >= 0 and isinstance(self.b, int)
return (0, self.b-1) if self.a.max - self.a.min >= self.b or (self.a.min != self.a.max and self.a.min%self.b >= self.a.max%self.b) else (self.a.min%self.b, self.a.max%self.b)
class RedNode(Node):
@ -182,7 +191,7 @@ class RedNode(Node):
def vars(self): return functools.reduce(lambda l,x: l+x.vars(), self.nodes, [])
class SumNode(RedNode):
def __mul__(self, b: int): return Node.sum([x*b for x in self.nodes]) # distribute mul into sum
def __mul__(self, b: Union[Node, int]): return Node.sum([x*b for x in self.nodes]) # distribute mul into sum
def __floordiv__(self, b: int, factoring_allowed=True):
if b == 1: return self
if not factoring_allowed: return Node.__floordiv__(self, b, factoring_allowed)
@ -219,7 +228,7 @@ class SumNode(RedNode):
return new_nodes
class AndNode(RedNode):
  """RedNode that combines its `nodes` through Variable.ands."""
  # the operation is applied to every child and the results are and-ed back together;
  # __mul__ previously built that result but never returned it, so callers got None
  def __mul__(self, b: Union[Node, int]): return Variable.ands([x*b for x in self.nodes])
  def __floordiv__(self, b: int, _=True): return Variable.ands([x//b for x in self.nodes])
def create_rednode(typ:Type[RedNode], nodes:List[Node]):