mirror of https://github.com/commaai/tinygrad.git
Variable.sum -> Node.sum, Variable.ands -> Node.ands (#2961)
This commit is contained in:
parent
3d720b5761
commit
8291986959
|
@ -8,7 +8,7 @@ from tinygrad.device import Compiled, Device, Buffer
|
|||
from tinygrad.ops import BufferOps, MemBuffer, ConstBuffer, LazyOp, LoadOps, TernaryOps
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import View
|
||||
from tinygrad.shape.symbolic import MulNode, SumNode, Variable, NumNode, create_rednode
|
||||
from tinygrad.shape.symbolic import MulNode, SumNode, Variable, NumNode, Node, create_rednode
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.jit import CacheCollector
|
||||
from tinygrad.realize import run_schedule
|
||||
|
@ -576,7 +576,7 @@ class TestLinearizerHelper(unittest.TestCase):
|
|||
b = Variable("b", 5, 7)
|
||||
|
||||
s1 = create_rednode(SumNode, [a, b])
|
||||
assert expand_node(s1) == [Variable.sum([NumNode(i),b]) for i in range(1,4)]
|
||||
assert expand_node(s1) == [Node.sum([NumNode(i),b]) for i in range(1,4)]
|
||||
|
||||
def test_multi_expand(self):
|
||||
a = Variable("a", 1, 3)
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
#!/usr/bin/env python
|
||||
import unittest, pickle
|
||||
from tinygrad.shape.symbolic import MulNode, SumNode, Variable, NumNode, LtNode, ModNode, sym_render, sym_infer
|
||||
from tinygrad.shape.symbolic import MulNode, SumNode, Variable, NumNode, LtNode, ModNode, Node, sym_render, sym_infer
|
||||
|
||||
class TestSymbolicPickle(unittest.TestCase):
|
||||
def test_pickle_variable(self):
|
||||
|
@ -39,15 +39,15 @@ class TestSymbolic(unittest.TestCase):
|
|||
self.helper_test_variable(expr, 0, 1, "(idx<128)")
|
||||
|
||||
def test_ge_divides_and(self):
|
||||
expr = Variable.ands([(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3)) < 512,
|
||||
(Variable("idx2", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3)) < 512])
|
||||
expr = Node.ands([(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3)) < 512,
|
||||
(Variable("idx2", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3)) < 512])
|
||||
self.helper_test_variable(expr, 0, 1, "((idx1<128) and (idx2<128))")
|
||||
expr = Variable.ands([(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3)) < 512,
|
||||
(Variable("idx2", 0, 511)*4 + Variable("FLOAT8_INDEX", 0, 7)) < 512])
|
||||
expr = Node.ands([(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3)) < 512,
|
||||
(Variable("idx2", 0, 511)*4 + Variable("FLOAT8_INDEX", 0, 7)) < 512])
|
||||
self.helper_test_variable(expr//4, 0, 1, "((((FLOAT8_INDEX//4)+idx2)<128) and ((idx1//4)<32))")
|
||||
|
||||
def test_lt_factors(self):
|
||||
expr = Variable.ands([(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 256)) < 512])
|
||||
expr = Node.ands([(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 256)) < 512])
|
||||
self.helper_test_variable(expr, 0, 1, "(((idx1*4)+FLOAT4_INDEX)<512)")
|
||||
|
||||
def test_div_becomes_num(self):
|
||||
|
@ -121,39 +121,39 @@ class TestSymbolic(unittest.TestCase):
|
|||
self.helper_test_variable(Variable("a", 0, 7) // -2, -3, 0, "((a//2)*-1)")
|
||||
|
||||
def test_sum_div_min_max(self):
|
||||
self.helper_test_variable(Variable.sum([Variable("a", 0, 7), Variable("b", 0, 3)]) // 2, 0, 5, "((a+b)//2)")
|
||||
self.helper_test_variable(Node.sum([Variable("a", 0, 7), Variable("b", 0, 3)]) // 2, 0, 5, "((a+b)//2)")
|
||||
|
||||
def test_sum_div_factor(self):
|
||||
self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*4, Variable("b", 0, 3)*4]) // 2, 0, 20, "((a*2)+(b*2))")
|
||||
self.helper_test_variable(Node.sum([Variable("a", 0, 7)*4, Variable("b", 0, 3)*4]) // 2, 0, 20, "((a*2)+(b*2))")
|
||||
|
||||
def test_sum_div_some_factor(self):
|
||||
self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*4]) // 2, 0, 23, "(((a*5)//2)+(b*2))")
|
||||
self.helper_test_variable(Node.sum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*4]) // 2, 0, 23, "(((a*5)//2)+(b*2))")
|
||||
|
||||
def test_sum_div_some_partial_factor(self):
|
||||
self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 0, 5, "(((a*3)+(b*3))//8)")
|
||||
self.helper_test_variable(Variable.sum([NumNode(16), Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 1, 6, "((((a*3)+(b*3))//8)+1)")
|
||||
self.helper_test_variable(Node.sum([Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 0, 5, "(((a*3)+(b*3))//8)")
|
||||
self.helper_test_variable(Node.sum([NumNode(16), Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 1, 6, "((((a*3)+(b*3))//8)+1)")
|
||||
|
||||
def test_sum_div_no_factor(self):
|
||||
self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*5]) // 2, 0, 25, "(((a*5)+(b*5))//2)")
|
||||
self.helper_test_variable(Node.sum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*5]) // 2, 0, 25, "(((a*5)+(b*5))//2)")
|
||||
|
||||
def test_mod_factor(self):
|
||||
# NOTE: even though the mod max is 50, it can't know this without knowing about the mul
|
||||
self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*100, Variable("b", 0, 3)*50]) % 100, 0, 99, "((b*50)%100)")
|
||||
self.helper_test_variable(Node.sum([Variable("a", 0, 7)*100, Variable("b", 0, 3)*50]) % 100, 0, 99, "((b*50)%100)")
|
||||
|
||||
def test_mod_to_sub(self):
|
||||
# This is mod reduction
|
||||
self.helper_test_variable((1+Variable("a",1,2))%2, 0, 1, (Variable("a",1,2)-1).render())
|
||||
|
||||
def test_sum_div_const(self):
|
||||
self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*4, NumNode(3)]) // 4, 0, 7, "a")
|
||||
self.helper_test_variable(Node.sum([Variable("a", 0, 7)*4, NumNode(3)]) // 4, 0, 7, "a")
|
||||
|
||||
def test_sum_div_const_big(self):
|
||||
self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*4, NumNode(3)]) // 16, 0, 1, "(a//4)")
|
||||
self.helper_test_variable(Node.sum([Variable("a", 0, 7)*4, NumNode(3)]) // 16, 0, 1, "(a//4)")
|
||||
|
||||
def test_sum_lt_fold(self):
|
||||
self.helper_test_variable(Variable.sum([Variable("a", 0, 7) * 4, Variable("b", 0, 3)]) < 16, 0, 1, "(a<4)")
|
||||
self.helper_test_variable(Variable.sum([Variable("a", 0, 7) * 4, Variable("b", 0, 4)]) < 16, 0, 1, "(((a*4)+b)<16)")
|
||||
self.helper_test_variable(Variable.sum([Variable("uidx", 0, 3), Variable("a", 0, 1529) * 12]) < (4 * 67), 0, 1, "(a<23)")
|
||||
self.helper_test_variable(Node.sum([Variable("a", 0, 7) * 4, Variable("b", 0, 3)]) < 16, 0, 1, "(a<4)")
|
||||
self.helper_test_variable(Node.sum([Variable("a", 0, 7) * 4, Variable("b", 0, 4)]) < 16, 0, 1, "(((a*4)+b)<16)")
|
||||
self.helper_test_variable(Node.sum([Variable("uidx", 0, 3), Variable("a", 0, 1529) * 12]) < (4 * 67), 0, 1, "(a<23)")
|
||||
|
||||
def test_mod_mul(self):
|
||||
self.helper_test_variable((Variable("a", 0, 5)*10)%9, 0, 5, "a")
|
||||
|
@ -176,13 +176,13 @@ class TestSymbolic(unittest.TestCase):
|
|||
self.helper_test_variable((Variable("a", 0, 1800)//10)//9, 0, 20, "(a//90)")
|
||||
|
||||
def test_distribute_mul(self):
|
||||
self.helper_test_variable(Variable.sum([Variable("a", 0, 3), Variable("b", 0, 5)])*3, 0, 24, "((a*3)+(b*3))")
|
||||
self.helper_test_variable(Node.sum([Variable("a", 0, 3), Variable("b", 0, 5)])*3, 0, 24, "((a*3)+(b*3))")
|
||||
|
||||
def test_mod_mul_sum(self):
|
||||
self.helper_test_variable(Variable.sum([Variable("b", 0, 2), Variable("a", 0, 5)*10])%9, 0, 7, "(a+b)")
|
||||
self.helper_test_variable(Node.sum([Variable("b", 0, 2), Variable("a", 0, 5)*10])%9, 0, 7, "(a+b)")
|
||||
|
||||
def test_sum_0(self):
|
||||
self.helper_test_variable(Variable.sum([Variable("a", 0, 7)]), 0, 7, "a")
|
||||
self.helper_test_variable(Node.sum([Variable("a", 0, 7)]), 0, 7, "a")
|
||||
|
||||
def test_mod_remove(self):
|
||||
self.helper_test_variable(Variable("a", 0, 6)%100, 0, 6, "a")
|
||||
|
@ -207,23 +207,23 @@ class TestSymbolic(unittest.TestCase):
|
|||
self.helper_test_variable((Variable("a", 0, 6) + 2) < 3, 0, 1, "(a<1)")
|
||||
|
||||
def test_and_fold(self):
|
||||
self.helper_test_variable(Variable.ands([NumNode(0), Variable("a", 0, 1)]), 0, 0, "0")
|
||||
self.helper_test_variable(Node.ands([NumNode(0), Variable("a", 0, 1)]), 0, 0, "0")
|
||||
|
||||
def test_and_remove(self):
|
||||
self.helper_test_variable(Variable.ands([NumNode(1), Variable("a", 0, 1)]), 0, 1, "a")
|
||||
self.helper_test_variable(Node.ands([NumNode(1), Variable("a", 0, 1)]), 0, 1, "a")
|
||||
|
||||
def test_mod_factor_negative(self):
|
||||
self.helper_test_variable(Variable.sum([NumNode(-29), Variable("a", 0, 10), Variable("b", 0, 10)*28]) % 28, 0, 27, "((27+a)%28)")
|
||||
self.helper_test_variable(Variable.sum([NumNode(-29), Variable("a", 0, 100), Variable("b", 0, 10)*28]) % 28, 0, 27, "((27+a)%28)")
|
||||
self.helper_test_variable(Node.sum([NumNode(-29), Variable("a", 0, 10), Variable("b", 0, 10)*28]) % 28, 0, 27, "((27+a)%28)")
|
||||
self.helper_test_variable(Node.sum([NumNode(-29), Variable("a", 0, 100), Variable("b", 0, 10)*28]) % 28, 0, 27, "((27+a)%28)")
|
||||
|
||||
def test_sum_combine_num(self):
|
||||
self.helper_test_variable(Variable.sum([NumNode(29), Variable("a", 0, 10), NumNode(-23)]), 6, 16, "(6+a)")
|
||||
self.helper_test_variable(Node.sum([NumNode(29), Variable("a", 0, 10), NumNode(-23)]), 6, 16, "(6+a)")
|
||||
|
||||
def test_sum_num_hoisted_and_factors_cancel_out(self):
|
||||
self.helper_test_variable(Variable.sum([Variable("a", 0, 1) * -4 + 1, Variable("a", 0, 1) * 4]), 1, 1, "1")
|
||||
self.helper_test_variable(Node.sum([Variable("a", 0, 1) * -4 + 1, Variable("a", 0, 1) * 4]), 1, 1, "1")
|
||||
|
||||
def test_div_factor(self):
|
||||
self.helper_test_variable(Variable.sum([NumNode(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40]) // 40, -1, 9, "(-1+b)")
|
||||
self.helper_test_variable(Node.sum([NumNode(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40]) // 40, -1, 9, "(-1+b)")
|
||||
|
||||
def test_mul_div(self):
|
||||
self.helper_test_variable((Variable("a", 0, 10)*4)//4, 0, 10, "a")
|
||||
|
@ -235,7 +235,7 @@ class TestSymbolic(unittest.TestCase):
|
|||
self.helper_test_variable((Variable("a", 0, 10)*4)//8, 0, 5, "(a//2)")
|
||||
|
||||
def test_div_remove(self):
|
||||
self.helper_test_variable(Variable.sum([Variable("idx0", 0, 127)*4, Variable("idx2", 0, 3)])//4, 0, 127, "idx0")
|
||||
self.helper_test_variable(Node.sum([Variable("idx0", 0, 127)*4, Variable("idx2", 0, 3)])//4, 0, 127, "idx0")
|
||||
|
||||
def test_div_numerator_negative(self):
|
||||
self.helper_test_variable((Variable("idx", 0, 9)*-10)//11, -9, 0, "((((idx*-10)+99)//11)+-9)")
|
||||
|
|
|
@ -17,7 +17,7 @@ def expr_node_mask(view:View, idx:Node, valid:Optional[Node]=None) -> Node:
|
|||
base = ((idx//acc)%d)
|
||||
expr += [base >= x, base < y]
|
||||
acc *= d
|
||||
return Variable.ands(expr)
|
||||
return Node.ands(expr)
|
||||
|
||||
# generate an expression if you have a single idx variable
|
||||
def expr_node(view:View, idx:Optional[Node]=None) -> Node:
|
||||
|
@ -27,12 +27,12 @@ def expr_node(view:View, idx:Optional[Node]=None) -> Node:
|
|||
for d,s,_ in reversed(_merge_dims(view.shape, view.strides)):
|
||||
ret.append(((idx//acc)%d)*s)
|
||||
acc *= d
|
||||
return Variable.sum(ret)
|
||||
return Node.sum(ret)
|
||||
|
||||
# generate an expression if you have a variable or expression for each index
|
||||
def expr_idxs(view:View, idxs:Tuple[Node, ...]) -> Node:
|
||||
assert len(idxs) == len(view.shape), f"need an idx for all dimensions {idxs} vs {view.shape}"
|
||||
return Variable.sum([NumNode(view.offset) if isinstance(view.offset, int) else view.offset] + [idx*st for idx,sh,st in zip(idxs, view.shape, view.strides) if sh != 1 and st != 0]) # noqa: E501
|
||||
return Node.sum([NumNode(view.offset) if isinstance(view.offset, int) else view.offset] + [idx*st for idx,sh,st in zip(idxs, view.shape, view.strides) if sh != 1 and st != 0]) # noqa: E501
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def merge_views(vm2:View, vm1:View) -> Optional[View]:
|
||||
|
@ -49,7 +49,7 @@ def idxs_to_idx(shape:Tuple[int, ...], idxs:Tuple[Node, ...]) -> Node:
|
|||
for tidx,d in zip(reversed(idxs), reversed(shape)):
|
||||
ret.append(tidx * acc)
|
||||
acc *= d
|
||||
return Variable.sum(ret)
|
||||
return Node.sum(ret)
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ShapeTracker:
|
||||
|
|
|
@ -32,7 +32,7 @@ class Node:
|
|||
if not isinstance(other, Node): return NotImplemented
|
||||
return self.key == other.key
|
||||
def __neg__(self): return self*-1
|
||||
def __add__(self, b:Union[Node,int]): return Variable.sum([self, b if isinstance(b, Node) else NumNode(b)])
|
||||
def __add__(self, b:Union[Node,int]): return Node.sum([self, b if isinstance(b, Node) else NumNode(b)])
|
||||
def __radd__(self, b:int): return self+b
|
||||
def __sub__(self, b:Union[Node,int]): return self+-b
|
||||
def __rsub__(self, b:int): return -self+b
|
||||
|
@ -283,12 +283,12 @@ class SumNode(RedNode):
|
|||
# NOTE: gcd in python 3.8 takes exactly 2 args
|
||||
mul_gcd = b
|
||||
for x in muls: mul_gcd = gcd(mul_gcd, x.b) # type: ignore # mypy cannot tell that x.b is int here due to assert above
|
||||
all_others = Variable.sum(others)
|
||||
all_others = Node.sum(others)
|
||||
if all_others.min >= 0 and all_others.max < mul_gcd:
|
||||
lhs, b = Variable.sum([mul//mul_gcd for mul in muls]), b//mul_gcd
|
||||
lhs, b = Node.sum([mul//mul_gcd for mul in muls]), b//mul_gcd
|
||||
return Node.__lt__(lhs, b) if isinstance(lhs, SumNode) else lhs < b
|
||||
|
||||
def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return Variable.sum([node.substitute(var_vals) for node in self.nodes])
|
||||
def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return Node.sum([node.substitute(var_vals) for node in self.nodes])
|
||||
|
||||
@property
|
||||
def flat_components(self): # recursively expand sumnode components
|
||||
|
@ -297,13 +297,13 @@ class SumNode(RedNode):
|
|||
return new_nodes
|
||||
|
||||
class AndNode(RedNode):
|
||||
def __floordiv__(self, b: Union[Node, int], _=True): return Variable.ands([x//b for x in self.nodes])
|
||||
def __floordiv__(self, b: Union[Node, int], _=True): return Node.ands([x//b for x in self.nodes])
|
||||
def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node:
|
||||
subed = []
|
||||
for node in self.nodes:
|
||||
if not (sub:=node.substitute(var_vals)): return NumNode(0)
|
||||
subed.append(sub)
|
||||
return Variable.ands(subed)
|
||||
return Node.ands(subed)
|
||||
|
||||
def create_rednode(typ:Type[RedNode], nodes:List[Node]):
|
||||
ret = typ(nodes)
|
||||
|
|
Loading…
Reference in New Issue