From 8291986959502579f7c009572dcafd61c54b5e26 Mon Sep 17 00:00:00 2001 From: chenyu Date: Mon, 1 Jan 2024 16:21:28 -0500 Subject: [PATCH] Variable.sum -> Node.sum, Variable.ands -> Node.ands (#2961) --- test/test_linearizer.py | 4 +-- test/unit/test_symbolic.py | 58 +++++++++++++++++----------------- tinygrad/shape/shapetracker.py | 8 ++--- tinygrad/shape/symbolic.py | 12 +++---- 4 files changed, 41 insertions(+), 41 deletions(-) diff --git a/test/test_linearizer.py b/test/test_linearizer.py index 82e5479e..e509d922 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -8,7 +8,7 @@ from tinygrad.device import Compiled, Device, Buffer from tinygrad.ops import BufferOps, MemBuffer, ConstBuffer, LazyOp, LoadOps, TernaryOps from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.view import View -from tinygrad.shape.symbolic import MulNode, SumNode, Variable, NumNode, create_rednode +from tinygrad.shape.symbolic import MulNode, SumNode, Variable, NumNode, Node, create_rednode from tinygrad.tensor import Tensor from tinygrad.jit import CacheCollector from tinygrad.realize import run_schedule @@ -576,7 +576,7 @@ class TestLinearizerHelper(unittest.TestCase): b = Variable("b", 5, 7) s1 = create_rednode(SumNode, [a, b]) - assert expand_node(s1) == [Variable.sum([NumNode(i),b]) for i in range(1,4)] + assert expand_node(s1) == [Node.sum([NumNode(i),b]) for i in range(1,4)] def test_multi_expand(self): a = Variable("a", 1, 3) diff --git a/test/unit/test_symbolic.py b/test/unit/test_symbolic.py index 6e89c928..ef28eb26 100644 --- a/test/unit/test_symbolic.py +++ b/test/unit/test_symbolic.py @@ -1,6 +1,6 @@ #!/usr/bin/env python import unittest, pickle -from tinygrad.shape.symbolic import MulNode, SumNode, Variable, NumNode, LtNode, ModNode, sym_render, sym_infer +from tinygrad.shape.symbolic import MulNode, SumNode, Variable, NumNode, LtNode, ModNode, Node, sym_render, sym_infer class TestSymbolicPickle(unittest.TestCase): def test_pickle_variable(self): @@ -39,15 +39,15 @@ class TestSymbolic(unittest.TestCase): self.helper_test_variable(expr, 0, 1, "(idx<128)") def test_ge_divides_and(self): - expr = Variable.ands([(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3)) < 512, - (Variable("idx2", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3)) < 512]) + expr = Node.ands([(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3)) < 512, + (Variable("idx2", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3)) < 512]) self.helper_test_variable(expr, 0, 1, "((idx1<128) and (idx2<128))") - expr = Variable.ands([(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3)) < 512, - (Variable("idx2", 0, 511)*4 + Variable("FLOAT8_INDEX", 0, 7)) < 512]) + expr = Node.ands([(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3)) < 512, + (Variable("idx2", 0, 511)*4 + Variable("FLOAT8_INDEX", 0, 7)) < 512]) self.helper_test_variable(expr//4, 0, 1, "((((FLOAT8_INDEX//4)+idx2)<128) and ((idx1//4)<32))") def test_lt_factors(self): - expr = Variable.ands([(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 256)) < 512]) + expr = Node.ands([(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 256)) < 512]) self.helper_test_variable(expr, 0, 1, "(((idx1*4)+FLOAT4_INDEX)<512)") def test_div_becomes_num(self): @@ -121,39 +121,39 @@ class TestSymbolic(unittest.TestCase): self.helper_test_variable(Variable("a", 0, 7) // -2, -3, 0, "((a//2)*-1)") def test_sum_div_min_max(self): - self.helper_test_variable(Variable.sum([Variable("a", 0, 7), Variable("b", 0, 3)]) // 2, 0, 5, "((a+b)//2)") + self.helper_test_variable(Node.sum([Variable("a", 0, 7), Variable("b", 0, 3)]) // 2, 0, 5, "((a+b)//2)") def test_sum_div_factor(self): - self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*4, Variable("b", 0, 3)*4]) // 2, 0, 20, "((a*2)+(b*2))") + self.helper_test_variable(Node.sum([Variable("a", 0, 7)*4, Variable("b", 0, 3)*4]) // 2, 0, 20, "((a*2)+(b*2))") def test_sum_div_some_factor(self): - self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*4]) // 2, 0, 23, "(((a*5)//2)+(b*2))") + self.helper_test_variable(Node.sum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*4]) // 2, 0, 23, "(((a*5)//2)+(b*2))") def test_sum_div_some_partial_factor(self): - self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 0, 5, "(((a*3)+(b*3))//8)") - self.helper_test_variable(Variable.sum([NumNode(16), Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 1, 6, "((((a*3)+(b*3))//8)+1)") + self.helper_test_variable(Node.sum([Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 0, 5, "(((a*3)+(b*3))//8)") + self.helper_test_variable(Node.sum([NumNode(16), Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 1, 6, "((((a*3)+(b*3))//8)+1)") def test_sum_div_no_factor(self): - self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*5]) // 2, 0, 25, "(((a*5)+(b*5))//2)") + self.helper_test_variable(Node.sum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*5]) // 2, 0, 25, "(((a*5)+(b*5))//2)") def test_mod_factor(self): # NOTE: even though the mod max is 50, it can't know this without knowing about the mul - self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*100, Variable("b", 0, 3)*50]) % 100, 0, 99, "((b*50)%100)") + self.helper_test_variable(Node.sum([Variable("a", 0, 7)*100, Variable("b", 0, 3)*50]) % 100, 0, 99, "((b*50)%100)") def test_mod_to_sub(self): # This is mod reduction self.helper_test_variable((1+Variable("a",1,2))%2, 0, 1, (Variable("a",1,2)-1).render()) def test_sum_div_const(self): - self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*4, NumNode(3)]) // 4, 0, 7, "a") + self.helper_test_variable(Node.sum([Variable("a", 0, 7)*4, NumNode(3)]) // 4, 0, 7, "a") def test_sum_div_const_big(self): - self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*4, NumNode(3)]) // 16, 0, 1, "(a//4)") + self.helper_test_variable(Node.sum([Variable("a", 0, 7)*4, NumNode(3)]) // 16, 0, 1, "(a//4)") def test_sum_lt_fold(self): - self.helper_test_variable(Variable.sum([Variable("a", 0, 7) * 4, Variable("b", 0, 3)]) < 16, 0, 1, "(a<4)") - self.helper_test_variable(Variable.sum([Variable("a", 0, 7) * 4, Variable("b", 0, 4)]) < 16, 0, 1, "(((a*4)+b)<16)") - self.helper_test_variable(Variable.sum([Variable("uidx", 0, 3), Variable("a", 0, 1529) * 12]) < (4 * 67), 0, 1, "(a<23)") + self.helper_test_variable(Node.sum([Variable("a", 0, 7) * 4, Variable("b", 0, 3)]) < 16, 0, 1, "(a<4)") + self.helper_test_variable(Node.sum([Variable("a", 0, 7) * 4, Variable("b", 0, 4)]) < 16, 0, 1, "(((a*4)+b)<16)") + self.helper_test_variable(Node.sum([Variable("uidx", 0, 3), Variable("a", 0, 1529) * 12]) < (4 * 67), 0, 1, "(a<23)") def test_mod_mul(self): self.helper_test_variable((Variable("a", 0, 5)*10)%9, 0, 5, "a") @@ -176,13 +176,13 @@ class TestSymbolic(unittest.TestCase): self.helper_test_variable((Variable("a", 0, 1800)//10)//9, 0, 20, "(a//90)") def test_distribute_mul(self): - self.helper_test_variable(Variable.sum([Variable("a", 0, 3), Variable("b", 0, 5)])*3, 0, 24, "((a*3)+(b*3))") + self.helper_test_variable(Node.sum([Variable("a", 0, 3), Variable("b", 0, 5)])*3, 0, 24, "((a*3)+(b*3))") def test_mod_mul_sum(self): - self.helper_test_variable(Variable.sum([Variable("b", 0, 2), Variable("a", 0, 5)*10])%9, 0, 7, "(a+b)") + self.helper_test_variable(Node.sum([Variable("b", 0, 2), Variable("a", 0, 5)*10])%9, 0, 7, "(a+b)") def test_sum_0(self): - self.helper_test_variable(Variable.sum([Variable("a", 0, 7)]), 0, 7, "a") + self.helper_test_variable(Node.sum([Variable("a", 0, 7)]), 0, 7, "a") def test_mod_remove(self): self.helper_test_variable(Variable("a", 0, 6)%100, 0, 6, "a") @@ -207,23 +207,23 @@ class TestSymbolic(unittest.TestCase): self.helper_test_variable((Variable("a", 0, 6) + 2) < 3, 0, 1, "(a<1)") def test_and_fold(self): - self.helper_test_variable(Variable.ands([NumNode(0), Variable("a", 0, 1)]), 0, 0, "0") + self.helper_test_variable(Node.ands([NumNode(0), Variable("a", 0, 1)]), 0, 0, "0") def test_and_remove(self): - self.helper_test_variable(Variable.ands([NumNode(1), Variable("a", 0, 1)]), 0, 1, "a") + self.helper_test_variable(Node.ands([NumNode(1), Variable("a", 0, 1)]), 0, 1, "a") def test_mod_factor_negative(self): - self.helper_test_variable(Variable.sum([NumNode(-29), Variable("a", 0, 10), Variable("b", 0, 10)*28]) % 28, 0, 27, "((27+a)%28)") - self.helper_test_variable(Variable.sum([NumNode(-29), Variable("a", 0, 100), Variable("b", 0, 10)*28]) % 28, 0, 27, "((27+a)%28)") + self.helper_test_variable(Node.sum([NumNode(-29), Variable("a", 0, 10), Variable("b", 0, 10)*28]) % 28, 0, 27, "((27+a)%28)") + self.helper_test_variable(Node.sum([NumNode(-29), Variable("a", 0, 100), Variable("b", 0, 10)*28]) % 28, 0, 27, "((27+a)%28)") def test_sum_combine_num(self): - self.helper_test_variable(Variable.sum([NumNode(29), Variable("a", 0, 10), NumNode(-23)]), 6, 16, "(6+a)") + self.helper_test_variable(Node.sum([NumNode(29), Variable("a", 0, 10), NumNode(-23)]), 6, 16, "(6+a)") def test_sum_num_hoisted_and_factors_cancel_out(self): - self.helper_test_variable(Variable.sum([Variable("a", 0, 1) * -4 + 1, Variable("a", 0, 1) * 4]), 1, 1, "1") + self.helper_test_variable(Node.sum([Variable("a", 0, 1) * -4 + 1, Variable("a", 0, 1) * 4]), 1, 1, "1") def test_div_factor(self): - self.helper_test_variable(Variable.sum([NumNode(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40]) // 40, -1, 9, "(-1+b)") + self.helper_test_variable(Node.sum([NumNode(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40]) // 40, -1, 9, "(-1+b)") def test_mul_div(self): self.helper_test_variable((Variable("a", 0, 10)*4)//4, 0, 10, "a") @@ -235,7 +235,7 @@ class TestSymbolic(unittest.TestCase): self.helper_test_variable((Variable("a", 0, 10)*4)//8, 0, 5, "(a//2)") def test_div_remove(self): - self.helper_test_variable(Variable.sum([Variable("idx0", 0, 127)*4, Variable("idx2", 0, 3)])//4, 0, 127, "idx0") + self.helper_test_variable(Node.sum([Variable("idx0", 0, 127)*4, Variable("idx2", 0, 3)])//4, 0, 127, "idx0") def test_div_numerator_negative(self): self.helper_test_variable((Variable("idx", 0, 9)*-10)//11, -9, 0, "((((idx*-10)+99)//11)+-9)") diff --git a/tinygrad/shape/shapetracker.py b/tinygrad/shape/shapetracker.py index e71ce0e3..91120159 100644 --- a/tinygrad/shape/shapetracker.py +++ b/tinygrad/shape/shapetracker.py @@ -17,7 +17,7 @@ def expr_node_mask(view:View, idx:Node, valid:Optional[Node]=None) -> Node: base = ((idx//acc)%d) expr += [base >= x, base < y] acc *= d - return Variable.ands(expr) + return Node.ands(expr) # generate an expression if you have a single idx variable def expr_node(view:View, idx:Optional[Node]=None) -> Node: @@ -27,12 +27,12 @@ def expr_node(view:View, idx:Optional[Node]=None) -> Node: for d,s,_ in reversed(_merge_dims(view.shape, view.strides)): ret.append(((idx//acc)%d)*s) acc *= d - return Variable.sum(ret) + return Node.sum(ret) # generate an expression if you have a variable or expression for each index def expr_idxs(view:View, idxs:Tuple[Node, ...]) -> Node: assert len(idxs) == len(view.shape), f"need an idx for all dimensions {idxs} vs {view.shape}" - return Variable.sum([NumNode(view.offset) if isinstance(view.offset, int) else view.offset] + [idx*st for idx,sh,st in zip(idxs, view.shape, view.strides) if sh != 1 and st != 0]) # noqa: E501 + return Node.sum([NumNode(view.offset) if isinstance(view.offset, int) else view.offset] + [idx*st for idx,sh,st in zip(idxs, view.shape, view.strides) if sh != 1 and st != 0]) # noqa: E501 @functools.lru_cache(maxsize=None) def merge_views(vm2:View, vm1:View) -> Optional[View]: @@ -49,7 +49,7 @@ def idxs_to_idx(shape:Tuple[int, ...], idxs:Tuple[Node, ...]) -> Node: for tidx,d in zip(reversed(idxs), reversed(shape)): ret.append(tidx * acc) acc *= d - return Variable.sum(ret) + return Node.sum(ret) @dataclass(frozen=True) class ShapeTracker: diff --git a/tinygrad/shape/symbolic.py b/tinygrad/shape/symbolic.py index ee0d2db1..f0e65760 100644 --- a/tinygrad/shape/symbolic.py +++ b/tinygrad/shape/symbolic.py @@ -32,7 +32,7 @@ class Node: if not isinstance(other, Node): return NotImplemented return self.key == other.key def __neg__(self): return self*-1 - def __add__(self, b:Union[Node,int]): return Variable.sum([self, b if isinstance(b, Node) else NumNode(b)]) + def __add__(self, b:Union[Node,int]): return Node.sum([self, b if isinstance(b, Node) else NumNode(b)]) def __radd__(self, b:int): return self+b def __sub__(self, b:Union[Node,int]): return self+-b def __rsub__(self, b:int): return -self+b @@ -283,12 +283,12 @@ class SumNode(RedNode): # NOTE: gcd in python 3.8 takes exactly 2 args mul_gcd = b for x in muls: mul_gcd = gcd(mul_gcd, x.b) # type: ignore # mypy cannot tell that x.b is int here due to assert above - all_others = Variable.sum(others) + all_others = Node.sum(others) if all_others.min >= 0 and all_others.max < mul_gcd: - lhs, b = Variable.sum([mul//mul_gcd for mul in muls]), b//mul_gcd + lhs, b = Node.sum([mul//mul_gcd for mul in muls]), b//mul_gcd return Node.__lt__(lhs, b) if isinstance(lhs, SumNode) else lhs < b - def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return Variable.sum([node.substitute(var_vals) for node in self.nodes]) + def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return Node.sum([node.substitute(var_vals) for node in self.nodes]) @property def flat_components(self): # recursively expand sumnode components @@ -297,13 +297,13 @@ class SumNode(RedNode): return new_nodes class AndNode(RedNode): - def __floordiv__(self, b: Union[Node, int], _=True): return Variable.ands([x//b for x in self.nodes]) + def __floordiv__(self, b: Union[Node, int], _=True): return Node.ands([x//b for x in self.nodes]) def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: subed = [] for node in self.nodes: if not (sub:=node.substitute(var_vals)): return NumNode(0) subed.append(sub) - return Variable.ands(subed) + return Node.ands(subed) def create_rednode(typ:Type[RedNode], nodes:List[Node]): ret = typ(nodes)