remove NumNode (#7035)

This commit is contained in:
chenyu 2024-10-13 16:42:19 -04:00 committed by GitHub
parent c4c806a210
commit bd8ecf7fd6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 17 additions and 25 deletions

View File

@ -7,7 +7,6 @@ from tinygrad.ops import UOp, UOps, KernelInfo
from tinygrad.dtype import dtypes, PtrDType
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
from tinygrad.ops import NumNode
inf, nan = float('inf'), float('nan')
# kernel unpacker

View File

@ -1,8 +1,8 @@
import itertools
import random
from tinygrad import Variable
from tinygrad import Variable, dtypes
from tinygrad.ops import UOp
from tinygrad.helpers import DEBUG
from tinygrad.ops import NumNode
random.seed(42)
def add_v(expr, rng=None):
@ -57,7 +57,7 @@ if __name__ == "__main__":
tape = [random.choice(ops) for _ in range(random.randint(2, 30))]
# 10% of the time, add one of lt, le, gt, ge
if random.random() < 0.1: tape.append(random.choice([lt, le, gt, ge]))
expr = NumNode(0)
expr = UOp.const(dtypes.int, 0)
rngs = []
for t in tape:
expr, rng = t(expr)

View File

@ -1,7 +1,6 @@
import unittest
from tinygrad.shape.shapetracker import ShapeTracker, View
from tinygrad import Variable
from tinygrad.ops import NumNode
from tinygrad.tensor import Tensor
class TestSymbolic(unittest.TestCase):
@ -16,17 +15,17 @@ class TestSymbolic(unittest.TestCase):
@unittest.expectedFailure
def test_real_strides_0(self):
st = ShapeTracker(views=(View(shape=(2, (NumNode(1)+Variable('start_pos', 1, 8)), 1, 1), strides=(8, 1, 0, 0), offset=0, mask=((0, 2), (0, Variable('start_pos', 1, 8)), (0, 1), (0, 1)), contiguous=False), View(shape=(2, (NumNode(1)+Variable('start_pos', 1, 8))), strides=((NumNode(1)+Variable('start_pos', 1, 8)), 1), offset=0, mask=None, contiguous=True))) # noqa: E501
st = ShapeTracker(views=(View(shape=(2, (Variable('start_pos', 1, 8)+1), 1, 1), strides=(8, 1, 0, 0), offset=0, mask=((0, 2), (0, Variable('start_pos', 1, 8)), (0, 1), (0, 1)), contiguous=False), View(shape=(2, (Variable('start_pos', 1, 8)+1)), strides=((Variable('start_pos', 1, 8)+1), 1), offset=0, mask=None, contiguous=True))) # noqa: E501
self.assertEqual(st.real_strides(), (8, None))
@unittest.expectedFailure
def test_real_strides_1(self):
st = ShapeTracker(views=(View(shape=(3, (NumNode(2)+Variable('i', 1, 10))), strides=(Variable('i', 1, 10), 1), offset=NumNode(0), mask=((0, 3), (0, Variable('i', 1, 10))), contiguous=False),)) # noqa: E501
st = ShapeTracker(views=(View(shape=(3, (Variable('i', 1, 10)+2)), strides=(Variable('i', 1, 10), 1), offset=0, mask=((0, 3), (0, Variable('i', 1, 10))), contiguous=False),)) # noqa: E501
self.assertEqual(st.real_strides(), (Variable('i', 1, 10), None))
@unittest.expectedFailure
def test_real_strides_2(self):
st = ShapeTracker(views=(View(shape=(3, (Variable('i', 1, 10)+Variable('j', 1, 10))), strides=(Variable('i', 1, 10), 1), offset=NumNode(0), mask=((0, 3), (0, Variable('i', 1, 10))), contiguous=False),)) # noqa: E501
st = ShapeTracker(views=(View(shape=(3, (Variable('i', 1, 10)+Variable('j', 1, 10))), strides=(Variable('i', 1, 10), 1), offset=0, mask=((0, 3), (0, Variable('i', 1, 10))), contiguous=False),)) # noqa: E501
self.assertEqual(st.real_strides(), (Variable('i', 1, 10), None))
def test_cat_dim0_strides(self):
@ -152,12 +151,12 @@ class TestSymbolicReshapeFromContiguous(unittest.TestCase):
def test_symbolic_mask(self):
# taken from gpt2 single kvcache
# these two caused problems in gpt2 if reshape merged views
view = View(shape=(1, (NumNode(1)+Variable('start_pos', 1, 128).bind(2)), 16, 64), strides=(0, 0, 64, 1), offset=NumNode(1024), mask=((0, 1), (Variable('start_pos', 1, 128).bind(2), (NumNode(1)+Variable('start_pos', 1, 128).bind(2))), (0, 16), (0, 64)), contiguous=False) # noqa: E501
new_shape = (1, 1, (NumNode(1)+Variable('start_pos', 1, 128).bind(2)), 16, 64)
view = View(shape=(1, (Variable('start_pos', 1, 128).bind(2)+1), 16, 64), strides=(0, 0, 64, 1), offset=1024, mask=((0, 1), (Variable('start_pos', 1, 128).bind(2), (Variable('start_pos', 1, 128).bind(2)+1)), (0, 16), (0, 64)), contiguous=False) # noqa: E501
new_shape = (1, 1, (Variable('start_pos', 1, 128).bind(2)+1), 16, 64)
assert view.reshape(new_shape) is None
view = View(shape=(2, 1, (NumNode(1)+Variable('start_pos', 1, 128)), 16, 64), strides=(0, 0, 1024, 64, 1), offset=131072, mask=((1, 2), (0, 1), (0, (NumNode(1)+Variable('start_pos', 1, 128))), (0, 16), (0, 64)), contiguous=False) # noqa: E501
new_shape = (2, (NumNode(1)+Variable('start_pos', 1, 128)), 16, 64)
view = View(shape=(2, 1, (Variable('start_pos', 1, 128)+1), 16, 64), strides=(0, 0, 1024, 64, 1), offset=131072, mask=((1, 2), (0, 1), (0, (Variable('start_pos', 1, 128)+1)), (0, 16), (0, 64)), contiguous=False) # noqa: E501
new_shape = (2, (Variable('start_pos', 1, 128)+1), 16, 64)
assert view.reshape(new_shape) is None
class TestSymbolicReshapeFromNonContiguous(unittest.TestCase):

View File

@ -5,7 +5,6 @@ from tinygrad.helpers import Context, ContextVar
from tinygrad.helpers import merge_dicts, strip_parens, prod, round_up, fetch, fully_flatten, from_mv, to_mv
from tinygrad.tensor import get_shape
from tinygrad.codegen.lowerer import get_contraction
from tinygrad.ops import NumNode
import numpy as np
VARIABLE = ContextVar("VARIABLE", 0)
@ -141,7 +140,6 @@ class TestProd(unittest.TestCase):
def test_ints(self): self.assertEqual(30, prod((2, 3, 5)))
def test_variable(self): self.assertEqual("(a*12)", prod((Variable("a", 1, 5), 3, 4)).render())
def test_variable_order(self): self.assertEqual("(a*12)", prod((3, 4, Variable("a", 1, 5))).render())
def test_num_nodes(self): self.assertEqual(NumNode(6).render(), prod((NumNode(2), NumNode(3))).render())
class TestRoundUp(unittest.TestCase):
def test_round_up(self):

View File

@ -5,7 +5,6 @@ from tinygrad.dtype import dtypes
from tinygrad.helpers import prod
from tinygrad.shape.shapetracker import ShapeTracker, View
from tinygrad import Variable
from tinygrad.ops import NumNode
from tinygrad.ops import UOp, UOps, graph_rewrite
from tinygrad.codegen.uopgraph import sym
from itertools import product
@ -154,13 +153,12 @@ class TestViewMinify(unittest.TestCase):
assert len(View.create((10,10,10,10)).permute((1,0,2,3)).minify().shape) == 3
class TestIndexExpressions2d(unittest.TestCase):
def setUp(self):
shapes = [(30, 5), (15, 10), (15, 1), (5, 10), (5, 1)] # Make sure dim0 is a multiple of 5, one of the tests divides this dimension by 5
offsets = [0, 1, 15, 28, 10000]
self.sts = [ShapeTracker.from_shape((prod(base_shape)+offset,)).shrink(((offset, offset+prod(base_shape)),)).\
reshape(base_shape) for base_shape in shapes for offset in offsets]
self.offset = [NumNode(offset) for base_shape in shapes for offset in offsets]
self.offset = [offset for base_shape in shapes for offset in offsets]
self.shapes = [shape for shape in shapes for offset in offsets]
self.idxs_exprs = []
@ -791,8 +789,8 @@ class TestShapeTrackerSize(unittest.TestCase):
self.assertEqual(st.real_size(), 9950) # careful here
def test_size_variable(self):
st = ShapeTracker(views=(View(shape=(1, 1, 1, (NumNode(1)+Variable('start_pos', 0, 8192)), 1, 8, 4, 128), strides=(0, 0, 0, 1024, 0, 128, 0, 1),
offset=0, mask=None, contiguous=False), View(shape=(1, 32, 1, (NumNode(1)+Variable('start_pos', 0, 8192)), 128),
st = ShapeTracker(views=(View(shape=(1, 1, 1, (Variable('start_pos', 0, 8192)+1), 1, 8, 4, 128), strides=(0, 0, 0, 1024, 0, 128, 0, 1),
offset=0, mask=None, contiguous=False), View(shape=(1, 32, 1, (Variable('start_pos', 0, 8192)+1), 128),
strides=(0, 128, 0, 4096, 1), offset=0, mask=None, contiguous=False)))
self.assertEqual(st.real_size(), 8389632)

View File

@ -1,8 +1,7 @@
from __future__ import annotations
from typing import Final, Optional, ClassVar, Set, Tuple, Dict, Union, Callable
import math, struct, ctypes
import math, struct, ctypes, functools
from dataclasses import dataclass
import functools
from tinygrad.helpers import getenv
ConstType = Union[float, int, bool]

View File

@ -976,6 +976,5 @@ renderer = PatternMatcher([
sint = Union[int, UOp]
Variable = UOp
def NumNode(val:int): return UOp.const(dtypes.int, val)
def sym_infer(uop: Union[UOp, int], var_vals: Dict[UOp, int]) -> int:
return int(uop.substitute({k:k.const_like(v) for k,v in var_vals.items()})) if isinstance(uop, UOp) else uop

View File

@ -3,7 +3,7 @@ import functools, operator, itertools, math
from dataclasses import dataclass
from typing import Tuple, List, Optional, Dict, Set, cast, Union
from tinygrad.dtype import dtypes
from tinygrad.ops import resolve, UOp, NumNode, Variable, sint, sym_infer
from tinygrad.ops import resolve, UOp, Variable, sint, sym_infer
from tinygrad.helpers import prod, all_int, argsort
@functools.lru_cache(maxsize=None)
@ -177,14 +177,14 @@ class View:
# Merge dimensions in vm2 if required.
# NB: Merging too many dimensions can make it difficult to project vm2's mask, hence only combining when required.
idxs: List[UOp] = [UOp.variable(f"idx{i}", 0, s-1) for i,s in enumerate(vm1.shape)]
merged_size, merged_term = 1, NumNode(0)
merged_size, merged_term = 1, UOp.const(dtypes.int, 0)
extents: List[Tuple[sint, UOp]] = []
for term, s, o in zip(reversed(terms), reversed(vm2.shape), reversed(origin)):
merged_term += sum([idxs[d1] * (s1 * merged_size) for d1, s1 in term]) + o * merged_size
merged_size *= s
if not resolve(merged_term >= merged_size) and not resolve(merged_term < 0):
extents.append((merged_size, merged_term))
merged_size, merged_term = 1, NumNode(0)
merged_size, merged_term = 1, UOp.const(dtypes.int, 0)
if resolve(merged_term != 0): return None
if (vm2_shape := tuple(s for s,_ in reversed(extents))) != vm2.shape:
return (reshaped_vm2 := vm2.reshape(vm2_shape)) and reshaped_vm2 + vm1