Variable.sum -> Node.sum, Variable.ands -> Node.ands (#2961)

This commit is contained in:
chenyu 2024-01-01 16:21:28 -05:00 committed by GitHub
parent 3d720b5761
commit 8291986959
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 41 additions and 41 deletions

View File

@ -8,7 +8,7 @@ from tinygrad.device import Compiled, Device, Buffer
from tinygrad.ops import BufferOps, MemBuffer, ConstBuffer, LazyOp, LoadOps, TernaryOps from tinygrad.ops import BufferOps, MemBuffer, ConstBuffer, LazyOp, LoadOps, TernaryOps
from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View from tinygrad.shape.view import View
from tinygrad.shape.symbolic import MulNode, SumNode, Variable, NumNode, create_rednode from tinygrad.shape.symbolic import MulNode, SumNode, Variable, NumNode, Node, create_rednode
from tinygrad.tensor import Tensor from tinygrad.tensor import Tensor
from tinygrad.jit import CacheCollector from tinygrad.jit import CacheCollector
from tinygrad.realize import run_schedule from tinygrad.realize import run_schedule
@ -576,7 +576,7 @@ class TestLinearizerHelper(unittest.TestCase):
b = Variable("b", 5, 7) b = Variable("b", 5, 7)
s1 = create_rednode(SumNode, [a, b]) s1 = create_rednode(SumNode, [a, b])
assert expand_node(s1) == [Variable.sum([NumNode(i),b]) for i in range(1,4)] assert expand_node(s1) == [Node.sum([NumNode(i),b]) for i in range(1,4)]
def test_multi_expand(self): def test_multi_expand(self):
a = Variable("a", 1, 3) a = Variable("a", 1, 3)

View File

@ -1,6 +1,6 @@
#!/usr/bin/env python #!/usr/bin/env python
import unittest, pickle import unittest, pickle
from tinygrad.shape.symbolic import MulNode, SumNode, Variable, NumNode, LtNode, ModNode, sym_render, sym_infer from tinygrad.shape.symbolic import MulNode, SumNode, Variable, NumNode, LtNode, ModNode, Node, sym_render, sym_infer
class TestSymbolicPickle(unittest.TestCase): class TestSymbolicPickle(unittest.TestCase):
def test_pickle_variable(self): def test_pickle_variable(self):
@ -39,15 +39,15 @@ class TestSymbolic(unittest.TestCase):
self.helper_test_variable(expr, 0, 1, "(idx<128)") self.helper_test_variable(expr, 0, 1, "(idx<128)")
def test_ge_divides_and(self): def test_ge_divides_and(self):
expr = Variable.ands([(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3)) < 512, expr = Node.ands([(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3)) < 512,
(Variable("idx2", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3)) < 512]) (Variable("idx2", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3)) < 512])
self.helper_test_variable(expr, 0, 1, "((idx1<128) and (idx2<128))") self.helper_test_variable(expr, 0, 1, "((idx1<128) and (idx2<128))")
expr = Variable.ands([(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3)) < 512, expr = Node.ands([(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3)) < 512,
(Variable("idx2", 0, 511)*4 + Variable("FLOAT8_INDEX", 0, 7)) < 512]) (Variable("idx2", 0, 511)*4 + Variable("FLOAT8_INDEX", 0, 7)) < 512])
self.helper_test_variable(expr//4, 0, 1, "((((FLOAT8_INDEX//4)+idx2)<128) and ((idx1//4)<32))") self.helper_test_variable(expr//4, 0, 1, "((((FLOAT8_INDEX//4)+idx2)<128) and ((idx1//4)<32))")
def test_lt_factors(self): def test_lt_factors(self):
expr = Variable.ands([(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 256)) < 512]) expr = Node.ands([(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 256)) < 512])
self.helper_test_variable(expr, 0, 1, "(((idx1*4)+FLOAT4_INDEX)<512)") self.helper_test_variable(expr, 0, 1, "(((idx1*4)+FLOAT4_INDEX)<512)")
def test_div_becomes_num(self): def test_div_becomes_num(self):
@ -121,39 +121,39 @@ class TestSymbolic(unittest.TestCase):
self.helper_test_variable(Variable("a", 0, 7) // -2, -3, 0, "((a//2)*-1)") self.helper_test_variable(Variable("a", 0, 7) // -2, -3, 0, "((a//2)*-1)")
def test_sum_div_min_max(self): def test_sum_div_min_max(self):
self.helper_test_variable(Variable.sum([Variable("a", 0, 7), Variable("b", 0, 3)]) // 2, 0, 5, "((a+b)//2)") self.helper_test_variable(Node.sum([Variable("a", 0, 7), Variable("b", 0, 3)]) // 2, 0, 5, "((a+b)//2)")
def test_sum_div_factor(self): def test_sum_div_factor(self):
self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*4, Variable("b", 0, 3)*4]) // 2, 0, 20, "((a*2)+(b*2))") self.helper_test_variable(Node.sum([Variable("a", 0, 7)*4, Variable("b", 0, 3)*4]) // 2, 0, 20, "((a*2)+(b*2))")
def test_sum_div_some_factor(self): def test_sum_div_some_factor(self):
self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*4]) // 2, 0, 23, "(((a*5)//2)+(b*2))") self.helper_test_variable(Node.sum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*4]) // 2, 0, 23, "(((a*5)//2)+(b*2))")
def test_sum_div_some_partial_factor(self): def test_sum_div_some_partial_factor(self):
self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 0, 5, "(((a*3)+(b*3))//8)") self.helper_test_variable(Node.sum([Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 0, 5, "(((a*3)+(b*3))//8)")
self.helper_test_variable(Variable.sum([NumNode(16), Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 1, 6, "((((a*3)+(b*3))//8)+1)") self.helper_test_variable(Node.sum([NumNode(16), Variable("a", 0, 7)*6, Variable("b", 0, 7)*6]) // 16, 1, 6, "((((a*3)+(b*3))//8)+1)")
def test_sum_div_no_factor(self): def test_sum_div_no_factor(self):
self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*5]) // 2, 0, 25, "(((a*5)+(b*5))//2)") self.helper_test_variable(Node.sum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*5]) // 2, 0, 25, "(((a*5)+(b*5))//2)")
def test_mod_factor(self): def test_mod_factor(self):
# NOTE: even though the mod max is 50, it can't know this without knowing about the mul # NOTE: even though the mod max is 50, it can't know this without knowing about the mul
self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*100, Variable("b", 0, 3)*50]) % 100, 0, 99, "((b*50)%100)") self.helper_test_variable(Node.sum([Variable("a", 0, 7)*100, Variable("b", 0, 3)*50]) % 100, 0, 99, "((b*50)%100)")
def test_mod_to_sub(self): def test_mod_to_sub(self):
# This is mod reduction # This is mod reduction
self.helper_test_variable((1+Variable("a",1,2))%2, 0, 1, (Variable("a",1,2)-1).render()) self.helper_test_variable((1+Variable("a",1,2))%2, 0, 1, (Variable("a",1,2)-1).render())
def test_sum_div_const(self): def test_sum_div_const(self):
self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*4, NumNode(3)]) // 4, 0, 7, "a") self.helper_test_variable(Node.sum([Variable("a", 0, 7)*4, NumNode(3)]) // 4, 0, 7, "a")
def test_sum_div_const_big(self): def test_sum_div_const_big(self):
self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*4, NumNode(3)]) // 16, 0, 1, "(a//4)") self.helper_test_variable(Node.sum([Variable("a", 0, 7)*4, NumNode(3)]) // 16, 0, 1, "(a//4)")
def test_sum_lt_fold(self): def test_sum_lt_fold(self):
self.helper_test_variable(Variable.sum([Variable("a", 0, 7) * 4, Variable("b", 0, 3)]) < 16, 0, 1, "(a<4)") self.helper_test_variable(Node.sum([Variable("a", 0, 7) * 4, Variable("b", 0, 3)]) < 16, 0, 1, "(a<4)")
self.helper_test_variable(Variable.sum([Variable("a", 0, 7) * 4, Variable("b", 0, 4)]) < 16, 0, 1, "(((a*4)+b)<16)") self.helper_test_variable(Node.sum([Variable("a", 0, 7) * 4, Variable("b", 0, 4)]) < 16, 0, 1, "(((a*4)+b)<16)")
self.helper_test_variable(Variable.sum([Variable("uidx", 0, 3), Variable("a", 0, 1529) * 12]) < (4 * 67), 0, 1, "(a<23)") self.helper_test_variable(Node.sum([Variable("uidx", 0, 3), Variable("a", 0, 1529) * 12]) < (4 * 67), 0, 1, "(a<23)")
def test_mod_mul(self): def test_mod_mul(self):
self.helper_test_variable((Variable("a", 0, 5)*10)%9, 0, 5, "a") self.helper_test_variable((Variable("a", 0, 5)*10)%9, 0, 5, "a")
@ -176,13 +176,13 @@ class TestSymbolic(unittest.TestCase):
self.helper_test_variable((Variable("a", 0, 1800)//10)//9, 0, 20, "(a//90)") self.helper_test_variable((Variable("a", 0, 1800)//10)//9, 0, 20, "(a//90)")
def test_distribute_mul(self): def test_distribute_mul(self):
self.helper_test_variable(Variable.sum([Variable("a", 0, 3), Variable("b", 0, 5)])*3, 0, 24, "((a*3)+(b*3))") self.helper_test_variable(Node.sum([Variable("a", 0, 3), Variable("b", 0, 5)])*3, 0, 24, "((a*3)+(b*3))")
def test_mod_mul_sum(self): def test_mod_mul_sum(self):
self.helper_test_variable(Variable.sum([Variable("b", 0, 2), Variable("a", 0, 5)*10])%9, 0, 7, "(a+b)") self.helper_test_variable(Node.sum([Variable("b", 0, 2), Variable("a", 0, 5)*10])%9, 0, 7, "(a+b)")
def test_sum_0(self): def test_sum_0(self):
self.helper_test_variable(Variable.sum([Variable("a", 0, 7)]), 0, 7, "a") self.helper_test_variable(Node.sum([Variable("a", 0, 7)]), 0, 7, "a")
def test_mod_remove(self): def test_mod_remove(self):
self.helper_test_variable(Variable("a", 0, 6)%100, 0, 6, "a") self.helper_test_variable(Variable("a", 0, 6)%100, 0, 6, "a")
@ -207,23 +207,23 @@ class TestSymbolic(unittest.TestCase):
self.helper_test_variable((Variable("a", 0, 6) + 2) < 3, 0, 1, "(a<1)") self.helper_test_variable((Variable("a", 0, 6) + 2) < 3, 0, 1, "(a<1)")
def test_and_fold(self): def test_and_fold(self):
self.helper_test_variable(Variable.ands([NumNode(0), Variable("a", 0, 1)]), 0, 0, "0") self.helper_test_variable(Node.ands([NumNode(0), Variable("a", 0, 1)]), 0, 0, "0")
def test_and_remove(self): def test_and_remove(self):
self.helper_test_variable(Variable.ands([NumNode(1), Variable("a", 0, 1)]), 0, 1, "a") self.helper_test_variable(Node.ands([NumNode(1), Variable("a", 0, 1)]), 0, 1, "a")
def test_mod_factor_negative(self): def test_mod_factor_negative(self):
self.helper_test_variable(Variable.sum([NumNode(-29), Variable("a", 0, 10), Variable("b", 0, 10)*28]) % 28, 0, 27, "((27+a)%28)") self.helper_test_variable(Node.sum([NumNode(-29), Variable("a", 0, 10), Variable("b", 0, 10)*28]) % 28, 0, 27, "((27+a)%28)")
self.helper_test_variable(Variable.sum([NumNode(-29), Variable("a", 0, 100), Variable("b", 0, 10)*28]) % 28, 0, 27, "((27+a)%28)") self.helper_test_variable(Node.sum([NumNode(-29), Variable("a", 0, 100), Variable("b", 0, 10)*28]) % 28, 0, 27, "((27+a)%28)")
def test_sum_combine_num(self): def test_sum_combine_num(self):
self.helper_test_variable(Variable.sum([NumNode(29), Variable("a", 0, 10), NumNode(-23)]), 6, 16, "(6+a)") self.helper_test_variable(Node.sum([NumNode(29), Variable("a", 0, 10), NumNode(-23)]), 6, 16, "(6+a)")
def test_sum_num_hoisted_and_factors_cancel_out(self): def test_sum_num_hoisted_and_factors_cancel_out(self):
self.helper_test_variable(Variable.sum([Variable("a", 0, 1) * -4 + 1, Variable("a", 0, 1) * 4]), 1, 1, "1") self.helper_test_variable(Node.sum([Variable("a", 0, 1) * -4 + 1, Variable("a", 0, 1) * 4]), 1, 1, "1")
def test_div_factor(self): def test_div_factor(self):
self.helper_test_variable(Variable.sum([NumNode(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40]) // 40, -1, 9, "(-1+b)") self.helper_test_variable(Node.sum([NumNode(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40]) // 40, -1, 9, "(-1+b)")
def test_mul_div(self): def test_mul_div(self):
self.helper_test_variable((Variable("a", 0, 10)*4)//4, 0, 10, "a") self.helper_test_variable((Variable("a", 0, 10)*4)//4, 0, 10, "a")
@ -235,7 +235,7 @@ class TestSymbolic(unittest.TestCase):
self.helper_test_variable((Variable("a", 0, 10)*4)//8, 0, 5, "(a//2)") self.helper_test_variable((Variable("a", 0, 10)*4)//8, 0, 5, "(a//2)")
def test_div_remove(self): def test_div_remove(self):
self.helper_test_variable(Variable.sum([Variable("idx0", 0, 127)*4, Variable("idx2", 0, 3)])//4, 0, 127, "idx0") self.helper_test_variable(Node.sum([Variable("idx0", 0, 127)*4, Variable("idx2", 0, 3)])//4, 0, 127, "idx0")
def test_div_numerator_negative(self): def test_div_numerator_negative(self):
self.helper_test_variable((Variable("idx", 0, 9)*-10)//11, -9, 0, "((((idx*-10)+99)//11)+-9)") self.helper_test_variable((Variable("idx", 0, 9)*-10)//11, -9, 0, "((((idx*-10)+99)//11)+-9)")

View File

@ -17,7 +17,7 @@ def expr_node_mask(view:View, idx:Node, valid:Optional[Node]=None) -> Node:
base = ((idx//acc)%d) base = ((idx//acc)%d)
expr += [base >= x, base < y] expr += [base >= x, base < y]
acc *= d acc *= d
return Variable.ands(expr) return Node.ands(expr)
# generate an expression if you have a single idx variable # generate an expression if you have a single idx variable
def expr_node(view:View, idx:Optional[Node]=None) -> Node: def expr_node(view:View, idx:Optional[Node]=None) -> Node:
@ -27,12 +27,12 @@ def expr_node(view:View, idx:Optional[Node]=None) -> Node:
for d,s,_ in reversed(_merge_dims(view.shape, view.strides)): for d,s,_ in reversed(_merge_dims(view.shape, view.strides)):
ret.append(((idx//acc)%d)*s) ret.append(((idx//acc)%d)*s)
acc *= d acc *= d
return Variable.sum(ret) return Node.sum(ret)
# generate an expression if you have a variable or expression for each index # generate an expression if you have a variable or expression for each index
def expr_idxs(view:View, idxs:Tuple[Node, ...]) -> Node: def expr_idxs(view:View, idxs:Tuple[Node, ...]) -> Node:
assert len(idxs) == len(view.shape), f"need an idx for all dimensions {idxs} vs {view.shape}" assert len(idxs) == len(view.shape), f"need an idx for all dimensions {idxs} vs {view.shape}"
return Variable.sum([NumNode(view.offset) if isinstance(view.offset, int) else view.offset] + [idx*st for idx,sh,st in zip(idxs, view.shape, view.strides) if sh != 1 and st != 0]) # noqa: E501 return Node.sum([NumNode(view.offset) if isinstance(view.offset, int) else view.offset] + [idx*st for idx,sh,st in zip(idxs, view.shape, view.strides) if sh != 1 and st != 0]) # noqa: E501
@functools.lru_cache(maxsize=None) @functools.lru_cache(maxsize=None)
def merge_views(vm2:View, vm1:View) -> Optional[View]: def merge_views(vm2:View, vm1:View) -> Optional[View]:
@ -49,7 +49,7 @@ def idxs_to_idx(shape:Tuple[int, ...], idxs:Tuple[Node, ...]) -> Node:
for tidx,d in zip(reversed(idxs), reversed(shape)): for tidx,d in zip(reversed(idxs), reversed(shape)):
ret.append(tidx * acc) ret.append(tidx * acc)
acc *= d acc *= d
return Variable.sum(ret) return Node.sum(ret)
@dataclass(frozen=True) @dataclass(frozen=True)
class ShapeTracker: class ShapeTracker:

View File

@ -32,7 +32,7 @@ class Node:
if not isinstance(other, Node): return NotImplemented if not isinstance(other, Node): return NotImplemented
return self.key == other.key return self.key == other.key
def __neg__(self): return self*-1 def __neg__(self): return self*-1
def __add__(self, b:Union[Node,int]): return Variable.sum([self, b if isinstance(b, Node) else NumNode(b)]) def __add__(self, b:Union[Node,int]): return Node.sum([self, b if isinstance(b, Node) else NumNode(b)])
def __radd__(self, b:int): return self+b def __radd__(self, b:int): return self+b
def __sub__(self, b:Union[Node,int]): return self+-b def __sub__(self, b:Union[Node,int]): return self+-b
def __rsub__(self, b:int): return -self+b def __rsub__(self, b:int): return -self+b
@ -283,12 +283,12 @@ class SumNode(RedNode):
# NOTE: gcd in python 3.8 takes exactly 2 args # NOTE: gcd in python 3.8 takes exactly 2 args
mul_gcd = b mul_gcd = b
for x in muls: mul_gcd = gcd(mul_gcd, x.b) # type: ignore # mypy cannot tell that x.b is int here due to assert above for x in muls: mul_gcd = gcd(mul_gcd, x.b) # type: ignore # mypy cannot tell that x.b is int here due to assert above
all_others = Variable.sum(others) all_others = Node.sum(others)
if all_others.min >= 0 and all_others.max < mul_gcd: if all_others.min >= 0 and all_others.max < mul_gcd:
lhs, b = Variable.sum([mul//mul_gcd for mul in muls]), b//mul_gcd lhs, b = Node.sum([mul//mul_gcd for mul in muls]), b//mul_gcd
return Node.__lt__(lhs, b) if isinstance(lhs, SumNode) else lhs < b return Node.__lt__(lhs, b) if isinstance(lhs, SumNode) else lhs < b
def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return Variable.sum([node.substitute(var_vals) for node in self.nodes]) def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return Node.sum([node.substitute(var_vals) for node in self.nodes])
@property @property
def flat_components(self): # recursively expand sumnode components def flat_components(self): # recursively expand sumnode components
@ -297,13 +297,13 @@ class SumNode(RedNode):
return new_nodes return new_nodes
class AndNode(RedNode): class AndNode(RedNode):
def __floordiv__(self, b: Union[Node, int], _=True): return Variable.ands([x//b for x in self.nodes]) def __floordiv__(self, b: Union[Node, int], _=True): return Node.ands([x//b for x in self.nodes])
def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node:
subed = [] subed = []
for node in self.nodes: for node in self.nodes:
if not (sub:=node.substitute(var_vals)): return NumNode(0) if not (sub:=node.substitute(var_vals)): return NumNode(0)
subed.append(sub) subed.append(sub)
return Variable.ands(subed) return Node.ands(subed)
def create_rednode(typ:Type[RedNode], nodes:List[Node]): def create_rednode(typ:Type[RedNode], nodes:List[Node]):
ret = typ(nodes) ret = typ(nodes)