mirror of https://github.com/commaai/tinygrad.git
symbolic stride (#1326)
This commit is contained in:
parent
2d4e182294
commit
aa05495620
|
@ -32,6 +32,14 @@ def ge(expr, rng=None):
|
|||
if rng is None: rng = random.randint(-4,4)
|
||||
return expr >= rng, rng
|
||||
|
||||
def le(expr, rng=None):
  """Apply `<=` to `expr` against a bound and report which bound was used.

  When `rng` is None a bound is drawn with random.randint(-4, 4); any other
  value (including 0) is used as given.
  Returns the pair (expr <= bound, bound).
  """
  bound = random.randint(-4, 4) if rng is None else rng
  return (expr <= bound), bound
|
||||
|
||||
def gt(expr, rng=None):
  """Apply `>` to `expr` against a bound and report which bound was used.

  When `rng` is None a bound is drawn with random.randint(-4, 4); any other
  value (including 0) is used as given.
  Returns the pair (expr > bound, bound).
  """
  bound = random.randint(-4, 4) if rng is None else rng
  return (expr > bound), bound
|
||||
|
||||
if __name__ == "__main__":
|
||||
ops = [add_v, div, mul, add_num, mod]
|
||||
for _ in range(1000):
|
||||
|
@ -41,9 +49,8 @@ if __name__ == "__main__":
|
|||
u3 = Variable("v3", 0, random.choice(upper_bounds))
|
||||
v = [u1,u2,u3]
|
||||
tape = [random.choice(ops) for _ in range(random.randint(2, 30))]
|
||||
# 10% of the time, add a less than or greater than
|
||||
if random.random() < 0.05: tape.append(lt)
|
||||
elif random.random() < 0.05: tape.append(ge)
|
||||
# 10% of the time, add one of lt, le, gt, ge
|
||||
if random.random() < 0.1: tape.append(random.choice([lt, le, gt, ge]))
|
||||
expr = Variable.num(0)
|
||||
rngs = []
|
||||
for t in tape:
|
||||
|
|
|
@ -0,0 +1,36 @@
|
|||
import unittest
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.symbolic import Variable
|
||||
from tinygrad.tensor import Tensor
|
||||
|
||||
class TestSymbolic(unittest.TestCase):
  """Shape, stride and index-expression checks for shapes that contain a symbolic Variable.

  The checks use unittest's assertion methods rather than bare `assert`, so they
  are not compiled away under `python -O` and report both operands on failure.
  """

  def test_symbolic_st(self):
    """A tracker built from the shape (x, 3) reports that shape and the strides (3, 1)."""
    x = Variable("x", 1, 100)
    st = ShapeTracker((x, 3))
    self.assertEqual(st.shape, (x, 3))
    self.assertEqual(st.real_strides(), (3, 1))

  def test_expr_idxs(self):
    """expr_idxs renders the expected index/valid strings before and after permute((1, 0))."""
    x = Variable("x", 1, 100)
    st = ShapeTracker((x, 3))
    # index variables handed to expr_idxs; the first one reuses the name "x" but starts at 0
    idxs = [Variable("x", 0, 100), Variable("y", 0, 100)]

    e1, e2 = st.expr_idxs(idxs)
    self.assertEqual(e1.render(), "((x*3)+y)")
    self.assertEqual(e2.render(), "1")

    # after swapping the two axes, the rendered index has x and y exchanged
    st.permute((1, 0))
    e1, e2 = st.expr_idxs(idxs)
    self.assertEqual(e1.render(), "((y*3)+x)")
    self.assertEqual(e2.render(), "1")

  def test_cat_strides(self):
    """Concatenating symbolically reshaped tensors gives a summed symbolic dim and matching strides."""
    i = Variable("i", 1, 5)
    j = Variable("j", 1, 5)
    k = Variable("k", 1, 5)

    # concatenate along dim 0: the symbolic sizes add up in the leading dimension
    t1 = (
      Tensor.rand(3, 4).reshape(i, 4)
      .cat(Tensor.rand(3, 4).reshape(j, 4), dim=0)
      .cat(Tensor.rand(3, 4).reshape(k, 4), dim=0)
    )
    st = t1.lazydata.st
    self.assertEqual(st.shape, (i+j+k, 4))
    self.assertEqual(st.real_strides(), (4, 1))

    # concatenate along dim 1: the leading stride becomes the symbolic row length
    t1 = (
      Tensor.rand(3, 4).reshape(3, i)
      .cat(Tensor.rand(3, 4).reshape(3, j), dim=1)
      .cat(Tensor.rand(3, 4).reshape(3, k), dim=1)
    )
    st = t1.lazydata.st
    self.assertEqual(st.shape, (3, i+j+k))
    self.assertEqual(st.real_strides(), (i+j+k, 1))
|
|
@ -4,7 +4,7 @@ from enum import Enum, auto
|
|||
import functools
|
||||
from typing import Dict, Tuple, Union, List, Optional, Callable, cast, NamedTuple
|
||||
from tinygrad.helpers import prod, DEBUG
|
||||
from tinygrad.shape.symbolic import Variable, MulNode, NumNode, Node, SumNode
|
||||
from tinygrad.shape.symbolic import Variable, MulNode, NumNode, Node, SumNode, is_sym_int
|
||||
|
||||
# these ops live here
|
||||
class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto() # noqa: E702
|
||||
|
@ -86,8 +86,8 @@ def strides_for_shape(shape:Tuple[int, ...]) -> Tuple[int, ...]:
|
|||
return tuple([st if s != 1 else 0 for st, s in zip(strides, shape)])
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def view_from_shape(shape:Tuple[int, ...]) -> View:
|
||||
assert all(isinstance(x, int) for x in shape)
|
||||
def view_from_shape(shape:Tuple[Union[Node, int], ...]) -> View:
|
||||
assert all(is_sym_int(x) for x in shape)
|
||||
return View(tuple(shape), strides_for_shape(shape))
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
|
@ -158,11 +158,11 @@ class ShapeTracker:
|
|||
return real_offset.b
|
||||
|
||||
# NOTE: if a stride is not always valid, it will be None
|
||||
def real_strides(self, ignore_valid=False) -> Tuple[Optional[int], ...]:
|
||||
def real_strides(self, ignore_valid=False) -> Tuple[Optional[Union[Node, int]], ...]:
|
||||
if len(self.views) == 1 and self.views[-1].mask is None: return self.views[-1].strides
|
||||
idxs = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(self.shape)]
|
||||
idx, valid = self.expr_idxs(idxs)
|
||||
ret: List[Optional[int]] = [None] * len(self.views[-1].shape)
|
||||
ret: List[Optional[Union[Node, int]]] = [None] * len(self.views[-1].shape)
|
||||
for this_dim in (idx.nodes if isinstance(idx, SumNode) else [idx]):
|
||||
if isinstance(this_dim, MulNode) and isinstance(this_dim.a, Variable):
|
||||
ret[idxs.index(this_dim.a)] = this_dim.b
|
||||
|
@ -235,8 +235,10 @@ class ShapeTracker:
|
|||
|
||||
def reshape(self, new_shape: Tuple[int, ...]):
|
||||
if self.views[-1].shape == new_shape: return self
|
||||
assert all(isinstance(x, int) and x > 0 for x in new_shape), f"shape must be ints and can't contain 0 or negative numbers {new_shape}"
|
||||
assert prod(self.shape) == prod(new_shape), f"can't reshape {self.shape} -> {new_shape}"
|
||||
assert all(is_sym_int(x) and x > 0 for x in new_shape), f"shape must be symbolic ints and can't contain 0 or negative numbers {new_shape}"
|
||||
# only check size for int shapes. we don't check symbolic here as long as the reshape itself can be done
|
||||
if all(isinstance(s, int) for s in self.shape) and all(isinstance(s, int) for s in new_shape):
|
||||
assert prod(self.shape) == prod(new_shape), f"can't reshape {self.shape} -> {new_shape}"
|
||||
new_view, extra = _reshape(self.views[-1], new_shape)
|
||||
if extra: self.views.append(new_view)
|
||||
else: self.views[-1] = new_view
|
||||
|
|
|
@ -3,13 +3,15 @@ from abc import abstractmethod
|
|||
import functools
|
||||
from math import gcd
|
||||
from tinygrad.helpers import partition
|
||||
from typing import List, Dict, Callable, Tuple, Type, Union, Optional
|
||||
from typing import List, Dict, Callable, Tuple, Type, Union, Optional, Any
|
||||
|
||||
# NOTE: Python has different behavior for negative mod and floor div than c
|
||||
# symbolic matches the Python behavior, but the code output is agnostic, and will never have negative numbers in div or mod
|
||||
|
||||
def is_sym_int(x: Any) -> bool:
  """Return True when `x` is a plain int or a symbolic Node, False otherwise."""
  if isinstance(x, int): return True
  return isinstance(x, Node)
|
||||
|
||||
class Node:
|
||||
b: int
|
||||
b: Union[Node, int]
|
||||
min: int
|
||||
max: int
|
||||
def render(self, ops=None, ctx=None, strip_parens=False) -> str:
|
||||
|
@ -29,10 +31,14 @@ class Node:
|
|||
if not isinstance(other, Node): return NotImplemented
|
||||
return self.key == other.key
|
||||
def __neg__(self): return self*-1
|
||||
def __add__(self, b:Union[Node, int]): return Variable.sum([self, b if isinstance(b, Node) else Variable.num(b)])
|
||||
def __sub__(self, b:Union[Node, int]): return self+-b
|
||||
def __ge__(self, b:int): return (-self) < (-b+1)
|
||||
def __lt__(self, b:int):
|
||||
def __add__(self, b:Union[Node,int]): return Variable.sum([self, b if isinstance(b, Node) else Variable.num(b)])
|
||||
def __radd__(self, b:int): return self+b
|
||||
def __sub__(self, b:Union[Node,int]): return self+-b
|
||||
def __le__(self, b:Union[Node,int]): return self < (b+1)
|
||||
def __gt__(self, b:Union[Node,int]): return (-self) < (-b)
|
||||
def __ge__(self, b:Union[Node,int]): return (-self) < (-b+1)
|
||||
def __lt__(self, b:Union[Node,int]):
|
||||
if self == b: return False
|
||||
lhs = self
|
||||
if isinstance(lhs, SumNode):
|
||||
muls, others = partition(lhs.nodes, lambda x: isinstance(x, MulNode) and x.b > 0 and x.max >= b)
|
||||
|
@ -47,10 +53,11 @@ class Node:
|
|||
# TODO: should we divide both by mul_gcd here?
|
||||
lhs = Variable.sum(muls)
|
||||
return create_node(LtNode(lhs, b))
|
||||
def __mul__(self, b:int):
|
||||
def __mul__(self, b:Union[Node, int]):
|
||||
if b == 0: return NumNode(0)
|
||||
if b == 1: return self
|
||||
return create_node(MulNode(self, b))
|
||||
def __rmul__(self, b:int): return self*b
|
||||
|
||||
# *** complex ops ***
|
||||
|
||||
|
@ -132,6 +139,7 @@ class Variable(Node):
|
|||
class NumNode(Node):
|
||||
def __init__(self, num:int):
|
||||
self.b, self.min, self.max = num, num, num
|
||||
def __int__(self): return self.b
|
||||
|
||||
def create_node(ret:Node):
|
||||
assert ret.min <= ret.max, f"min greater than max! {ret.min} {ret.max} when creating {type(ret)} {ret}"
|
||||
|
@ -139,7 +147,7 @@ def create_node(ret:Node):
|
|||
return ret
|
||||
|
||||
class OpNode(Node):
|
||||
def __init__(self, a:Node, b:int):
|
||||
def __init__(self, a:Node, b:Union[Node, int]):
|
||||
self.a, self.b = a, b
|
||||
self.min, self.max = self.get_bounds()
|
||||
def vars(self): return self.a.vars()
|
||||
|
@ -147,13 +155,14 @@ class OpNode(Node):
|
|||
def get_bounds(self) -> Tuple[int, int]: pass
|
||||
|
||||
class LtNode(OpNode):
|
||||
def __mul__(self, b: int): return (self.a*b) < (self.b*b)
|
||||
def __mul__(self, b: Union[Node, int]): return (self.a*b) < (self.b*b)
|
||||
def __floordiv__(self, b: int, _=False): return (self.a//b) < (self.b//b)
|
||||
def get_bounds(self) -> Tuple[int, int]: return int(self.a.max < self.b), int(self.a.min < self.b)
|
||||
|
||||
class MulNode(OpNode):
|
||||
def __mul__(self, b: int): return self.a*(self.b*b) # two muls in one mul
|
||||
def __mul__(self, b: Union[Node, int]): return self.a*(self.b*b) # two muls in one mul
|
||||
def __floordiv__(self, b: int, factoring_allowed=False): # NOTE: mod negative isn't handled right
|
||||
assert isinstance(self.b, int)
|
||||
if self.b % b == 0: return self.a*(self.b//b)
|
||||
if b % self.b == 0 and self.b > 0: return self.a//(b//self.b)
|
||||
return Node.__floordiv__(self, b, factoring_allowed)
|
||||
|
@ -166,7 +175,7 @@ class MulNode(OpNode):
|
|||
class DivNode(OpNode):
|
||||
def __floordiv__(self, b: int, _=False): return self.a//(self.b*b) # two divs is one div
|
||||
def get_bounds(self) -> Tuple[int, int]:
|
||||
assert self.a.min >= 0
|
||||
assert self.a.min >= 0 and isinstance(self.b, int)
|
||||
return self.a.min//self.b, self.a.max//self.b
|
||||
|
||||
class ModNode(OpNode):
|
||||
|
@ -174,7 +183,7 @@ class ModNode(OpNode):
|
|||
if (self.b % b == 0): return (self.a//b) % (self.b//b) # put the div inside mod
|
||||
return Node.__floordiv__(self, b, factoring_allowed)
|
||||
def get_bounds(self) -> Tuple[int, int]:
|
||||
assert self.a.min >= 0
|
||||
assert self.a.min >= 0 and isinstance(self.b, int)
|
||||
return (0, self.b-1) if self.a.max - self.a.min >= self.b or (self.a.min != self.a.max and self.a.min%self.b >= self.a.max%self.b) else (self.a.min%self.b, self.a.max%self.b)
|
||||
|
||||
class RedNode(Node):
|
||||
|
@ -182,7 +191,7 @@ class RedNode(Node):
|
|||
def vars(self): return functools.reduce(lambda l,x: l+x.vars(), self.nodes, [])
|
||||
|
||||
class SumNode(RedNode):
|
||||
def __mul__(self, b: int): return Node.sum([x*b for x in self.nodes]) # distribute mul into sum
|
||||
def __mul__(self, b: Union[Node, int]): return Node.sum([x*b for x in self.nodes]) # distribute mul into sum
|
||||
def __floordiv__(self, b: int, factoring_allowed=True):
|
||||
if b == 1: return self
|
||||
if not factoring_allowed: return Node.__floordiv__(self, b, factoring_allowed)
|
||||
|
@ -219,7 +228,7 @@ class SumNode(RedNode):
|
|||
return new_nodes
|
||||
|
||||
class AndNode(RedNode):
|
||||
def __mul__(self, b: int): Variable.ands([x*b for x in self.nodes])
|
||||
def __mul__(self, b: Union[Node, int]):
  """Multiply every operand in `self.nodes` by `b` and combine the results with Variable.ands."""
  # The one-line form built this expression but never returned it, so `and_node * b`
  # evaluated to None; the sibling __floordiv__ already returns its result.
  return Variable.ands([x*b for x in self.nodes])
|
||||
def __floordiv__(self, b: int, _=True): return Variable.ands([x//b for x in self.nodes])
|
||||
|
||||
def create_rednode(typ:Type[RedNode], nodes:List[Node]):
|
||||
|
|
Loading…
Reference in New Issue