mirror of https://github.com/commaai/tinygrad.git
remove NumNode (#7035)
This commit is contained in:
parent c4c806a210
commit bd8ecf7fd6
@@ -7,7 +7,6 @@ from tinygrad.ops import UOp, UOps, KernelInfo
 from tinygrad.dtype import dtypes, PtrDType
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import View
-from tinygrad.ops import NumNode
 inf, nan = float('inf'), float('nan')
 
 # kernel unpacker
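The substitution is mechanical throughout this commit: every `NumNode(val)` becomes `UOp.const(dtypes.int, val)`, and `NumNode(n)+x` becomes `x+n`. A minimal sketch of the new spelling, using only imports that appear in the hunks (an illustration, not part of the diff):

    from tinygrad import dtypes
    from tinygrad.ops import UOp

    # before this commit: expr = NumNode(0)
    # after: an integer constant is just a const UOp
    expr = UOp.const(dtypes.int, 0)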
@@ -1,8 +1,8 @@
 import itertools
 import random
-from tinygrad import Variable
+from tinygrad import Variable, dtypes
+from tinygrad.ops import UOp
 from tinygrad.helpers import DEBUG
-from tinygrad.ops import NumNode
 random.seed(42)
 
 def add_v(expr, rng=None):
@@ -57,7 +57,7 @@ if __name__ == "__main__":
     tape = [random.choice(ops) for _ in range(random.randint(2, 30))]
     # 10% of the time, add one of lt, le, gt, ge
     if random.random() < 0.1: tape.append(random.choice([lt, le, gt, ge]))
-    expr = NumNode(0)
+    expr = UOp.const(dtypes.int, 0)
     rngs = []
     for t in tape:
       expr, rng = t(expr)
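The fuzz harness above threads one expression through a random tape of transforms, each returning the new expression plus the random value it used. A hedged sketch of that protocol with a hypothetical `add_num` transform (illustrative, not from the file):

    import random
    from tinygrad import dtypes
    from tinygrad.ops import UOp

    def add_num(expr, rng=None):
      # illustrative transform in the style of add_v: fold in a random constant
      if rng is None: rng = random.randint(1, 9)
      return expr + rng, rng

    expr = UOp.const(dtypes.int, 0)  # the new seed value, as in the hunk above
    for t in [add_num, add_num, add_num]:
      expr, rng = t(expr)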
@@ -1,7 +1,6 @@
 import unittest
 from tinygrad.shape.shapetracker import ShapeTracker, View
 from tinygrad import Variable
-from tinygrad.ops import NumNode
 from tinygrad.tensor import Tensor
 
 class TestSymbolic(unittest.TestCase):
@@ -16,17 +15,17 @@ class TestSymbolic(unittest.TestCase):
 
   @unittest.expectedFailure
   def test_real_strides_0(self):
-    st = ShapeTracker(views=(View(shape=(2, (NumNode(1)+Variable('start_pos', 1, 8)), 1, 1), strides=(8, 1, 0, 0), offset=0, mask=((0, 2), (0, Variable('start_pos', 1, 8)), (0, 1), (0, 1)), contiguous=False), View(shape=(2, (NumNode(1)+Variable('start_pos', 1, 8))), strides=((NumNode(1)+Variable('start_pos', 1, 8)), 1), offset=0, mask=None, contiguous=True))) # noqa: E501
+    st = ShapeTracker(views=(View(shape=(2, (Variable('start_pos', 1, 8)+1), 1, 1), strides=(8, 1, 0, 0), offset=0, mask=((0, 2), (0, Variable('start_pos', 1, 8)), (0, 1), (0, 1)), contiguous=False), View(shape=(2, (Variable('start_pos', 1, 8)+1)), strides=((Variable('start_pos', 1, 8)+1), 1), offset=0, mask=None, contiguous=True))) # noqa: E501
     self.assertEqual(st.real_strides(), (8, None))
 
   @unittest.expectedFailure
   def test_real_strides_1(self):
-    st = ShapeTracker(views=(View(shape=(3, (NumNode(2)+Variable('i', 1, 10))), strides=(Variable('i', 1, 10), 1), offset=NumNode(0), mask=((0, 3), (0, Variable('i', 1, 10))), contiguous=False),)) # noqa: E501
+    st = ShapeTracker(views=(View(shape=(3, (Variable('i', 1, 10)+2)), strides=(Variable('i', 1, 10), 1), offset=0, mask=((0, 3), (0, Variable('i', 1, 10))), contiguous=False),)) # noqa: E501
     self.assertEqual(st.real_strides(), (Variable('i', 1, 10), None))
 
   @unittest.expectedFailure
   def test_real_strides_2(self):
-    st = ShapeTracker(views=(View(shape=(3, (Variable('i', 1, 10)+Variable('j', 1, 10))), strides=(Variable('i', 1, 10), 1), offset=NumNode(0), mask=((0, 3), (0, Variable('i', 1, 10))), contiguous=False),)) # noqa: E501
+    st = ShapeTracker(views=(View(shape=(3, (Variable('i', 1, 10)+Variable('j', 1, 10))), strides=(Variable('i', 1, 10), 1), offset=0, mask=((0, 3), (0, Variable('i', 1, 10))), contiguous=False),)) # noqa: E501
     self.assertEqual(st.real_strides(), (Variable('i', 1, 10), None))
 
   def test_cat_dim0_strides(self):
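Both spellings denote the same symbolic value; with NumNode gone, plain-int arithmetic on a Variable (itself now a UOp) is the only spelling. A quick illustration (the rendered string is indicative, not guaranteed verbatim):

    from tinygrad import Variable

    v = Variable('i', 1, 10)
    expr = v + 2          # replaces the old NumNode(2) + v
    print(expr.render())  # renders as something like "(i+2)"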
@@ -152,12 +151,12 @@ class TestSymbolicReshapeFromContiguous(unittest.TestCase):
   def test_symbolic_mask(self):
     # taken from gpt2 single kvcache
     # these two caused problems in gpt2 if reshape merged views
-    view = View(shape=(1, (NumNode(1)+Variable('start_pos', 1, 128).bind(2)), 16, 64), strides=(0, 0, 64, 1), offset=NumNode(1024), mask=((0, 1), (Variable('start_pos', 1, 128).bind(2), (NumNode(1)+Variable('start_pos', 1, 128).bind(2))), (0, 16), (0, 64)), contiguous=False) # noqa: E501
-    new_shape = (1, 1, (NumNode(1)+Variable('start_pos', 1, 128).bind(2)), 16, 64)
+    view = View(shape=(1, (Variable('start_pos', 1, 128).bind(2)+1), 16, 64), strides=(0, 0, 64, 1), offset=1024, mask=((0, 1), (Variable('start_pos', 1, 128).bind(2), (Variable('start_pos', 1, 128).bind(2)+1)), (0, 16), (0, 64)), contiguous=False) # noqa: E501
+    new_shape = (1, 1, (Variable('start_pos', 1, 128).bind(2)+1), 16, 64)
     assert view.reshape(new_shape) is None
 
-    view = View(shape=(2, 1, (NumNode(1)+Variable('start_pos', 1, 128)), 16, 64), strides=(0, 0, 1024, 64, 1), offset=131072, mask=((1, 2), (0, 1), (0, (NumNode(1)+Variable('start_pos', 1, 128))), (0, 16), (0, 64)), contiguous=False) # noqa: E501
-    new_shape = (2, (NumNode(1)+Variable('start_pos', 1, 128)), 16, 64)
+    view = View(shape=(2, 1, (Variable('start_pos', 1, 128)+1), 16, 64), strides=(0, 0, 1024, 64, 1), offset=131072, mask=((1, 2), (0, 1), (0, (Variable('start_pos', 1, 128)+1)), (0, 16), (0, 64)), contiguous=False) # noqa: E501
+    new_shape = (2, (Variable('start_pos', 1, 128)+1), 16, 64)
     assert view.reshape(new_shape) is None
 
 class TestSymbolicReshapeFromNonContiguous(unittest.TestCase):
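The `.bind(2)` in these views pins a concrete runtime value onto the symbolic `start_pos`, the way the gpt2 kv-cache does each decode step; as the hunk itself shows, integer arithmetic such as `+1` still composes with a bound variable. A minimal sketch:

    from tinygrad import Variable

    start_pos = Variable('start_pos', 1, 128).bind(2)  # symbolic, bound to 2
    dim = start_pos + 1  # the (start_pos+1) shape entry used in the views above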
@@ -5,7 +5,6 @@ from tinygrad.helpers import Context, ContextVar
 from tinygrad.helpers import merge_dicts, strip_parens, prod, round_up, fetch, fully_flatten, from_mv, to_mv
 from tinygrad.tensor import get_shape
 from tinygrad.codegen.lowerer import get_contraction
-from tinygrad.ops import NumNode
 import numpy as np
 
 VARIABLE = ContextVar("VARIABLE", 0)
@@ -141,7 +140,6 @@ class TestProd(unittest.TestCase):
   def test_ints(self): self.assertEqual(30, prod((2, 3, 5)))
   def test_variable(self): self.assertEqual("(a*12)", prod((Variable("a", 1, 5), 3, 4)).render())
   def test_variable_order(self): self.assertEqual("(a*12)", prod((3, 4, Variable("a", 1, 5))).render())
-  def test_num_nodes(self): self.assertEqual(NumNode(6).render(), prod((NumNode(2), NumNode(3))).render())
 
 class TestRoundUp(unittest.TestCase):
   def test_round_up(self):
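With `test_num_nodes` deleted, `prod` needs no NumNode case: over plain ints it returns an int, and a Variable anywhere in the tuple makes the result symbolic, which is exactly what the surviving tests assert. Restated as a standalone check:

    from tinygrad import Variable
    from tinygrad.helpers import prod

    assert prod((2, 3, 5)) == 30                                   # ints stay ints
    assert prod((Variable("a", 1, 5), 3, 4)).render() == "(a*12)"  # symbolic product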
@@ -5,7 +5,6 @@ from tinygrad.dtype import dtypes
 from tinygrad.helpers import prod
 from tinygrad.shape.shapetracker import ShapeTracker, View
 from tinygrad import Variable
-from tinygrad.ops import NumNode
 from tinygrad.ops import UOp, UOps, graph_rewrite
 from tinygrad.codegen.uopgraph import sym
 from itertools import product
@@ -154,13 +153,12 @@ class TestViewMinify(unittest.TestCase):
     assert len(View.create((10,10,10,10)).permute((1,0,2,3)).minify().shape) == 3
 
 class TestIndexExpressions2d(unittest.TestCase):
 
   def setUp(self):
     shapes = [(30, 5), (15, 10), (15, 1), (5, 10), (5, 1)] # Make sure dim0 is a multiple of 5, one of the tests divides this dimension by 5
     offsets = [0, 1, 15, 28, 10000]
     self.sts = [ShapeTracker.from_shape((prod(base_shape)+offset,)).shrink(((offset, offset+prod(base_shape)),)).\
                 reshape(base_shape) for base_shape in shapes for offset in offsets]
-    self.offset = [NumNode(offset) for base_shape in shapes for offset in offsets]
+    self.offset = [offset for base_shape in shapes for offset in offsets]
     self.shapes = [shape for shape in shapes for offset in offsets]
     self.idxs_exprs = []
@@ -791,8 +789,8 @@ class TestShapeTrackerSize(unittest.TestCase):
     self.assertEqual(st.real_size(), 9950) # careful here
 
   def test_size_variable(self):
-    st = ShapeTracker(views=(View(shape=(1, 1, 1, (NumNode(1)+Variable('start_pos', 0, 8192)), 1, 8, 4, 128), strides=(0, 0, 0, 1024, 0, 128, 0, 1),
-                             offset=0, mask=None, contiguous=False), View(shape=(1, 32, 1, (NumNode(1)+Variable('start_pos', 0, 8192)), 128),
+    st = ShapeTracker(views=(View(shape=(1, 1, 1, (Variable('start_pos', 0, 8192)+1), 1, 8, 4, 128), strides=(0, 0, 0, 1024, 0, 128, 0, 1),
+                             offset=0, mask=None, contiguous=False), View(shape=(1, 32, 1, (Variable('start_pos', 0, 8192)+1), 128),
                              strides=(0, 128, 0, 4096, 1), offset=0, mask=None, contiguous=False)))
     self.assertEqual(st.real_size(), 8389632)
 
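A back-of-envelope check of the expected 8389632, on the assumption that `real_size` takes each symbolic dimension at its variable's maximum (so `start_pos+1` counts as 8193):

    # per-dimension worst case is (size-1)*stride; zero-stride dims contribute nothing
    max_idx = 8192*1024 + 7*128 + 3*0 + 127*1  # = 8389631, from the first view
    assert max_idx + 1 == 8389632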
@@ -1,8 +1,7 @@
 from __future__ import annotations
 from typing import Final, Optional, ClassVar, Set, Tuple, Dict, Union, Callable
-import math, struct, ctypes
+import math, struct, ctypes, functools
 from dataclasses import dataclass
-import functools
 from tinygrad.helpers import getenv
 
 ConstType = Union[float, int, bool]
@@ -976,6 +976,5 @@ renderer = PatternMatcher([
 sint = Union[int, UOp]
 Variable = UOp
 
-def NumNode(val:int): return UOp.const(dtypes.int, val)
 def sym_infer(uop: Union[UOp, int], var_vals: Dict[UOp, int]) -> int:
   return int(uop.substitute({k:k.const_like(v) for k,v in var_vals.items()})) if isinstance(uop, UOp) else uop
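`sym_infer` is untouched apart from losing its NumNode neighbor: substitute a concrete value for every variable, then collapse the graph to an int. A usage sketch based on the definition above (assuming `int()` on the substituted graph folds the constants, as that definition implies):

    from tinygrad import Variable
    from tinygrad.ops import sym_infer

    v = Variable('start_pos', 1, 128)
    assert sym_infer(v + 1, {v: 10}) == 11  # substitute, then fold to an int
    assert sym_infer(42, {}) == 42          # plain ints pass straight through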
@@ -3,7 +3,7 @@ import functools, operator, itertools, math
 from dataclasses import dataclass
 from typing import Tuple, List, Optional, Dict, Set, cast, Union
 from tinygrad.dtype import dtypes
-from tinygrad.ops import resolve, UOp, NumNode, Variable, sint, sym_infer
+from tinygrad.ops import resolve, UOp, Variable, sint, sym_infer
 from tinygrad.helpers import prod, all_int, argsort
 
 @functools.lru_cache(maxsize=None)
@@ -177,14 +177,14 @@ class View:
     # Merge dimensions in vm2 if required.
     # NB: Merging too many dimensions can make it difficult to project vm2's mask, hence only combining when required.
     idxs: List[UOp] = [UOp.variable(f"idx{i}", 0, s-1) for i,s in enumerate(vm1.shape)]
-    merged_size, merged_term = 1, NumNode(0)
+    merged_size, merged_term = 1, UOp.const(dtypes.int, 0)
     extents: List[Tuple[sint, UOp]] = []
     for term, s, o in zip(reversed(terms), reversed(vm2.shape), reversed(origin)):
       merged_term += sum([idxs[d1] * (s1 * merged_size) for d1, s1 in term]) + o * merged_size
       merged_size *= s
       if not resolve(merged_term >= merged_size) and not resolve(merged_term < 0):
         extents.append((merged_size, merged_term))
-        merged_size, merged_term = 1, NumNode(0)
+        merged_size, merged_term = 1, UOp.const(dtypes.int, 0)
     if resolve(merged_term != 0): return None
     if (vm2_shape := tuple(s for s,_ in reversed(extents))) != vm2.shape:
       return (reshaped_vm2 := vm2.reshape(vm2_shape)) and reshaped_vm2 + vm1
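For intuition on the loop above: walking vm2's dimensions right to left, it accumulates one linear index expression (`merged_term`) and its extent (`merged_size`), and closes off a merged dimension whenever the term is provably inside [0, merged_size); `UOp.const(dtypes.int, 0)` is simply the new spelling of the reset that `NumNode(0)` used to provide. A toy version of the in-range test with plain ints standing in for UOps:

    # (size, stride) pairs seen right to left: a contiguous (3, 4) view
    dims = [(4, 1), (3, 4)]
    merged_size, max_term = 1, 0
    for s, stride in dims:
      max_term += (s - 1) * stride  # worst-case contribution of this dim
      merged_size *= s
    assert max_term < merged_size   # provably in range: dims merge into one extent of 12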