node cleanup + local metal test speed [pr] (#6880)

* node cleanup [pr]

* fix tests, including the double one on metal

* no time tqdm tests
George Hotz 2024-10-04 18:14:23 +08:00 committed by GitHub
parent cdff1d75b6
commit a0cb16ac61
10 changed files with 37 additions and 45 deletions

View File

@@ -41,7 +41,6 @@ def assert_same_lin(l1, l2):
# get features
import math
from tinygrad.shape.symbolic import Node
MAX_DIMS = 16
MAX_BUFS = 9
@@ -58,7 +57,7 @@ def lin_to_feats(lin:Kernel, use_sts=True):
# first, the full shape, including the colors
for s,os,c in zip(lin.full_shape,lin.output_shape,lc):
if isinstance(s, Node):
if isinstance(s, UOp):
ret.append(False)
ret += [0]*9
else:

View File

@@ -483,7 +483,7 @@ class TestIndexing(unittest.TestCase):
def get_set_tensor(indexed: Tensor, indexer):
set_size = indexed[indexer].shape
set_count = indexed[indexer].numel()
set_tensor = Tensor.randint(set_count, high=set_count).reshape(set_size).cast(dtypes.float64)
set_tensor = Tensor.randint(set_count, high=set_count).reshape(set_size) #.cast(dtypes.float64)
return set_tensor
# Tensor is 0 1 2 3 4
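
Presumably this is the "double one on metal" from the commit message: Metal has no float64, so the helper now leaves the random tensor in its default integer dtype instead of casting. A minimal sketch (the commented line is the removed cast):

from tinygrad import Tensor, dtypes

t = Tensor.randint(6, high=6).reshape(2, 3)   # default int dtype, runs everywhere
# t = t.cast(dtypes.float64)                  # the removed cast: float64 kernels don't run on METAL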

View File

@@ -4,7 +4,7 @@ from tinygrad import Device
from tinygrad.helpers import Timing, CI, OSX
import multiprocessing.shared_memory as shared_memory
N = 4096 if CI else 16384
N = 4096
class TestCopySpeed(unittest.TestCase):
@classmethod
def setUpClass(cls): Device[Device.DEFAULT].synchronize()

View File

@@ -10,7 +10,7 @@ class TestTensorVariable(unittest.TestCase):
def test_inner_tvar_node(self):
vv = Variable("w", 0, 10).bind(2)
ret = Tensor.from_node(vv * 4).item()
ret = Tensor.from_uop(vv * 4).item()
assert ret == 8
def test_inner_tvar_mul(self):

View File

@@ -274,8 +274,8 @@ class TestIndexExpressions2d(unittest.TestCase):
def test_reshape_combining_4(self):
# interestingly this one is quite slow
self.st = CheckingShapeTracker((1,1,5,5,1,1,5))
self.st.pad(((3,6), (0,0), (0,5), (0,0), (3,6), (0,0), (0,5)))
self.st.reshape((100,5,100))
self.st.pad(((2,1), (0,0), (0,2), (0,0), (2,1), (0,0), (0,2)))
self.st.reshape((28,5,28))
assert len(self.st.views) == 1
self.st.assert_same()
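
A quick check of the new numbers in plain Python (smaller pads shrink the views, which is presumably what makes this slow test faster locally):

shape = (1, 1, 5, 5, 1, 1, 5)
pads  = ((2,1), (0,0), (0,2), (0,0), (2,1), (0,0), (0,2))
padded = tuple(s + lo + hi for s, (lo, hi) in zip(shape, pads))
assert padded == (4, 1, 7, 5, 4, 1, 7)
assert (padded[0]*padded[1]*padded[2], padded[3], padded[4]*padded[5]*padded[6]) == (28, 5, 28)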

View File

@@ -6,6 +6,8 @@ from tqdm import tqdm
from tinygrad.helpers import tqdm as tinytqdm, trange as tinytrange
import numpy as np
SLEEP_TIME = 0 # NOTE: this was 0.01, disabled tests that are flaky with time
class TestProgressBar(unittest.TestCase):
def _compare_bars(self, bar1, bar2):
prefix1, prog1, suffix1 = bar1.split("|")
@@ -43,7 +45,7 @@ class TestProgressBar(unittest.TestCase):
# compare bars at each iteration (only when tinytqdm bar has been updated)
for n in (bar := tinytqdm(range(total), desc="Test")):
time.sleep(0.01)
time.sleep(SLEEP_TIME)
if bar.i % bar.skip != 0: continue
tinytqdm_output = mock_stderr.getvalue().split("\r")[-1].rstrip()
iters_per_sec = float(tinytqdm_output.split("it/s")[-2].split(" ")[-1]) if n>0 else 0
@@ -60,6 +62,7 @@ class TestProgressBar(unittest.TestCase):
@patch('sys.stderr', new_callable=StringIO)
@patch('shutil.get_terminal_size')
@unittest.skip("flaky without sleep time")
def test_unit_scale(self, mock_terminal_size, mock_stderr):
for unit_scale in [True, False]:
# NOTE: numpy comparison raises TypeError if exponent > 22
@@ -72,7 +75,7 @@ class TestProgressBar(unittest.TestCase):
# compare bars at each iteration (only when tinytqdm bar has been updated)
for n in tinytqdm(range(total), desc="Test", total=total, unit_scale=unit_scale):
time.sleep(0.01)
time.sleep(SLEEP_TIME)
tinytqdm_output = mock_stderr.getvalue().split("\r")[-1].rstrip()
iters_per_sec = float(tinytqdm_output.split("it/s")[-2].split(" ")[-1]) if n>0 else 0
elapsed = n/iters_per_sec if n>0 else 0
@@ -93,7 +96,7 @@ class TestProgressBar(unittest.TestCase):
expected_prefix = "Test"
# compare bars at each iteration (only when tinytqdm bar has been updated)
for i,n in enumerate(bar := tinytqdm(range(total), desc="Test")):
time.sleep(0.01)
time.sleep(SLEEP_TIME)
if bar.i % bar.skip != 0: continue
tinytqdm_output = mock_stderr.getvalue().split("\r")[-1].rstrip()
iters_per_sec = float(tinytqdm_output.split("it/s")[-2].split(" ")[-1]) if n>0 else 0
@@ -120,7 +123,7 @@ class TestProgressBar(unittest.TestCase):
# compare bars at each iteration (only when tinytqdm bar has been updated)
for n in (bar := tinytrange(total, desc="Test")):
time.sleep(0.01)
time.sleep(SLEEP_TIME)
if bar.i % bar.skip != 0: continue
tiny_output = mock_stderr.getvalue().split("\r")[-1].rstrip()
iters_per_sec = float(tiny_output.split("it/s")[-2].split(" ")[-1]) if n>0 else 0
@@ -147,7 +150,7 @@ class TestProgressBar(unittest.TestCase):
bar = tinytqdm(total=total, desc="Test")
n = 0
while n < total:
time.sleep(0.01)
time.sleep(SLEEP_TIME)
incr = (total // 10) + random.randint(0, 100)
if n + incr > total: incr = total - n
bar.update(incr, close=n+incr==total)
@@ -172,7 +175,7 @@ class TestProgressBar(unittest.TestCase):
bar = tinytqdm(total=0, desc="Test")
n = 0
while n < total:
time.sleep(0.01)
time.sleep(SLEEP_TIME)
incr = (total // 10) + random.randint(0, 100)
if n + incr > total: incr = total - n
bar.update(incr, close=n+incr==total)
@@ -187,6 +190,7 @@ class TestProgressBar(unittest.TestCase):
@patch('sys.stderr', new_callable=StringIO)
@patch('shutil.get_terminal_size')
@unittest.skip("flaky without sleep time")
def test_tqdm_output_custom_nolen_total(self, mock_terminal_size, mock_stderr):
for unit_scale in [True, False]:
for _ in range(3):
@@ -198,7 +202,7 @@ class TestProgressBar(unittest.TestCase):
# compare bars at each iteration (only when tinytqdm bar has been updated)
for n,g in enumerate(tinytqdm(gen, desc="Test", unit_scale=unit_scale)):
assert g == n
time.sleep(0.01)
time.sleep(SLEEP_TIME)
tinytqdm_output = mock_stderr.getvalue().split("\r")[-1].rstrip()
if n:
iters_per_sec = float(tinytqdm_output.split("it/s")[-2].split(" ")[-1])
@@ -211,11 +215,11 @@ class TestProgressBar(unittest.TestCase):
def test_tqdm_perf(self):
st = time.perf_counter()
for _ in tqdm(range(100)): time.sleep(0.01)
for _ in tqdm(range(100)): time.sleep(SLEEP_TIME)
tqdm_time = time.perf_counter() - st
st = time.perf_counter()
for _ in tinytqdm(range(100)): time.sleep(0.01)
for _ in tinytqdm(range(100)): time.sleep(SLEEP_TIME)
tinytqdm_time = time.perf_counter() - st
assert tinytqdm_time < 2 * tqdm_time
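
The pattern used throughout this file, in isolation (hypothetical TestNoSleep class): with SLEEP_TIME = 0 the loops no longer depend on wall-clock time, and the assertions that did are skipped rather than left to flake.

import time, unittest
from tinygrad.helpers import tqdm as tinytqdm

SLEEP_TIME = 0   # was 0.01; raise it locally to exercise the timing-dependent checks

class TestNoSleep(unittest.TestCase):
  def test_loop(self):
    for _ in tinytqdm(range(100), desc="Test"): time.sleep(SLEEP_TIME)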

View File

@@ -11,6 +11,7 @@ from tinygrad.helpers import DEBUG
from tinygrad.dtype import dtypes, PtrDType, ConstType
from tinygrad.codegen.uopgraph import linearize_uop, full_graph_rewrite
from tinygrad.ops import BinaryOps, UOp, UOps, print_uops
from tinygrad.shape.symbolic import Variable
import functools
def render(self) -> Tuple[str, ConstType, ConstType]:
@@ -26,8 +27,6 @@ def render(self) -> Tuple[str, ConstType, ConstType]:
return fxn.split("data0[0] = ")[1].split(";")[0], rewritten_uop.vmin, rewritten_uop.vmax
def NumNode(val): return UOp.const(dtypes.int, val)
def Variable(expr, nmin, nmax):
return UOp.define_var(expr, dtypes.int, nmin, nmax if isinstance(nmax, int) else nmax.arg)
class Node:
@staticmethod
def sum(ops): return functools.reduce(lambda x,y: x+y, ops)

View File

@@ -5,16 +5,7 @@ from tinygrad.ops import UOp, UOps, exec_alu, ConstType
sint = Union[int, UOp]
# broken
Node = UOp
MulNode = UOp
SumNode = UOp
DivNode = UOp
ModNode = UOp
LtNode = UOp
AndNode = UOp
def NumNode(val:int): return UOp.const(dtypes.int, val)
class Variable(UOp):
def __reduce__(self): return Variable, self.arg
def __new__(cls, expr:str, nmin:ConstType, nmax:ConstType): # pylint: disable=signature-differs
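
A minimal sketch of the shims that survive the cleanup, NumNode and the UOp-backed Variable (the Node/*Node aliases appear to be the lines dropped in this hunk), assuming UOp's usual arithmetic operators:

from tinygrad.ops import UOp, UOps
from tinygrad.shape.symbolic import NumNode, Variable

v = Variable("i", 0, 10)        # a DEFINE_VAR UOp with range [0, 10]
expr = v * 4 + NumNode(2)       # arithmetic now builds a small UOp graph
assert isinstance(expr, UOp) and NumNode(2).op is UOps.CONST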

View File

@@ -4,7 +4,7 @@ from dataclasses import dataclass
from typing import Tuple, List, Optional, Dict, Set, cast, Union
from tinygrad.ops import resolve, UOp
from tinygrad.helpers import prod, all_int, argsort
from tinygrad.shape.symbolic import Node, NumNode, Variable, sint, sym_infer
from tinygrad.shape.symbolic import NumNode, Variable, sint, sym_infer
@functools.lru_cache(maxsize=None)
def canonicalize_strides(shape:Tuple[sint, ...], strides:Tuple[sint, ...]) -> Tuple[sint, ...]:
@@ -93,7 +93,7 @@ class View:
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
def size(self) -> int:
# NOTE: Variable and the Node derived from it in symbolic shapes can only have int as max.
ret = prod([x.vmax if isinstance(x, Node) else x for x in self.shape])
ret = prod([x.vmax if isinstance(x, UOp) else x for x in self.shape])
assert isinstance(ret, int), f"{ret=} is not int"
return ret
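
A hedged sketch of the size() behavior above, assuming View.create accepts UOp dims the way the rest of this file does: a symbolic dim contributes its upper bound, so the buffer is sized for the worst case.

from tinygrad.shape.view import View
from tinygrad.shape.symbolic import Variable

v = View.create((Variable("n", 1, 8), 4))
print(v.size())   # expected 32 == vmax of n (8) * 4
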
@@ -127,7 +127,7 @@ class View:
@functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
def vars(self) -> Set[Variable]:
flatten_mask = tuple(x for m in self.mask for x in m) if self.mask is not None else tuple()
return functools.reduce(operator.or_, [x.vars() for x in self.shape+self.strides+(self.offset,)+flatten_mask if isinstance(x, Node)], set())
return functools.reduce(operator.or_, [x.vars() for x in self.shape+self.strides+(self.offset,)+flatten_mask if isinstance(x, UOp)], set())
@functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
def unbind(self) -> Tuple[View, Dict[Variable, int]]:
@@ -164,9 +164,9 @@ class View:
# Merge dimensions in vm2 if required.
# NB: Merging too many dimensions can make it difficult to project vm2's mask, hence only combining when required.
idxs: List[Node] = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(vm1.shape)]
idxs: List[UOp] = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(vm1.shape)]
merged_size, merged_term = 1, NumNode(0)
extents: List[Tuple[sint, Node]] = []
extents: List[Tuple[sint, UOp]] = []
for term, s, o in zip(reversed(terms), reversed(vm2.shape), reversed(origin)):
merged_term += sum([idxs[d1] * (s1 * merged_size) for d1, s1 in term]) + o * merged_size
merged_size *= s

View File

@@ -12,7 +12,7 @@ from tinygrad.lazy import LazyBuffer
from tinygrad.multi import MultiLazyBuffer
from tinygrad.ops import MetaOps, truncate, smax, resolve, UOp, UOps, BinaryOps
from tinygrad.device import Device, Buffer, BufferOptions
from tinygrad.shape.symbolic import sint, Variable, Node
from tinygrad.shape.symbolic import sint, Variable
from tinygrad.engine.realize import run_schedule, memory_planner
from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars
@@ -137,7 +137,9 @@ class Tensor:
# create a LazyBuffer from the different types of inputs
if isinstance(data, LazyBuffer): assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported"
elif isinstance(data, get_args(ConstType)): data = _metaop(MetaOps.CONST, tuple(), dtype or dtypes.from_py(data), device, data)
elif isinstance(data, UOp): data = _metaop(MetaOps.CONST, tuple(), dtype or data.dtype, device, data)
elif isinstance(data, UOp):
assert data.op is UOps.ASSIGN and data.src[0].op is UOps.DEFINE_VAR and data.src[1].op is UOps.CONST, f"can't create tensor from UOp {data}"
data = _metaop(MetaOps.CONST, tuple(), dtype or data.dtype, device, data)
elif isinstance(data, bytes): data = _frompy(data, dtypes.uint8 if dtype is None else dtype)
elif isinstance(data, (list, tuple)):
if dtype is None:
@@ -375,15 +377,12 @@ class Tensor:
return self
@staticmethod
def from_node(y:UOp, **kwargs) -> Tensor:
# NOTE: we only support Tensors from DEFINE_VAR or CONST
def from_uop(y:UOp, **kwargs) -> Tensor:
if y.op is UOps.ASSIGN: return Tensor(y, **kwargs, requires_grad=False) # this is the only UOp allowed in Tensor
if y.op is UOps.CONST: return Tensor(y.arg, **kwargs, requires_grad=False)
if y.op is UOps.ASSIGN:
assert y.src[0].op is UOps.DEFINE_VAR
return Tensor(y, **kwargs, requires_grad=False)
if y.op is UOps.ALU:
if y.arg is BinaryOps.MUL: return Tensor.from_node(y.src[0]) * Tensor.from_node(y.src[1])
if y.arg is BinaryOps.ADD: return Tensor.from_node(y.src[0]) + Tensor.from_node(y.src[1])
if y.arg is BinaryOps.MUL: return Tensor.from_uop(y.src[0]) * Tensor.from_uop(y.src[1])
if y.arg is BinaryOps.ADD: return Tensor.from_uop(y.src[0]) + Tensor.from_uop(y.src[1])
raise RuntimeError(f"unhandled Node {y}")
# ***** creation entrypoint *****
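
A usage sketch mirroring test_inner_tvar_node above: a bound Variable is an ASSIGN(DEFINE_VAR, CONST) UOp, the only form from_uop passes straight into Tensor, while MUL/ADD recurse into from_uop on their sources.

from tinygrad.tensor import Tensor
from tinygrad.shape.symbolic import Variable

vv = Variable("w", 0, 10).bind(2)            # ASSIGN(DEFINE_VAR, CONST)
assert Tensor.from_uop(vv * 4).item() == 8   # the MUL recurses into from_uop on each src
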
@@ -2696,14 +2695,14 @@ class Tensor:
raise ValueError(f"cannot broadcast from shape={self.shape} to {shape=}")
return F.Expand.apply(self.reshape(padded), shape=shape)
def _broadcasted(self, y:Union[Tensor, Node, ConstType], reverse:bool=False, match_dtype:bool=True) -> Tuple[Tensor, Tensor]:
def _broadcasted(self, y:Union[Tensor, UOp, ConstType], reverse:bool=False, match_dtype:bool=True) -> Tuple[Tensor, Tensor]:
x: Tensor = self
if not isinstance(y, Tensor):
# make y a Tensor
assert isinstance(y, (*get_args(ConstType), Node)), f"{type(y)=}, {y=}"
assert isinstance(y, (*get_args(ConstType), UOp)), f"{type(y)=}, {y=}"
if isinstance(x.dtype, ImageDType) or dtypes.is_float(x.dtype) or (dtypes.is_int(x.dtype) and isinstance(y, int)): y_dtype = x.dtype
elif not isinstance(y, Node): y_dtype = dtypes.from_py(y)
if isinstance(y, Node): y = Tensor.from_node(y, device=x.device)
elif not isinstance(y, UOp): y_dtype = dtypes.from_py(y)
if isinstance(y, UOp): y = Tensor.from_uop(y, device=x.device)
else: y = Tensor(dtypes.as_const(y, y_dtype), x.device, y_dtype, requires_grad=False)
if match_dtype and x.dtype != y.dtype:
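
A hedged sketch of what the UOp branch in _broadcasted enables: a bound symbolic Variable can sit directly on one side of an elementwise op and is wrapped into a Tensor via Tensor.from_uop.

from tinygrad.tensor import Tensor
from tinygrad.shape.symbolic import Variable

s = Variable("s", 1, 10).bind(3)
out = Tensor([1, 2, 3]) * s    # y is a UOp here, so _broadcasted calls Tensor.from_uop(y)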