diff --git a/extra/optimization/helpers.py b/extra/optimization/helpers.py
index 56e66f27..66a64094 100644
--- a/extra/optimization/helpers.py
+++ b/extra/optimization/helpers.py
@@ -1,12 +1,13 @@
 # stuff needed to unpack a kernel
 from typing import Tuple
 from extra.ops import LazyOp, TernaryOps, BinaryOps, UnaryOps, ReduceOps, BufferOps, MemBuffer, ConstBuffer, MetaOps
+from tinygrad import Variable
 from tinygrad.codegen.kernel import Opt, OptOps
 from tinygrad.ops import UOp, UOps, KernelInfo
 from tinygrad.dtype import dtypes, PtrDType
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import View
-from tinygrad.shape.symbolic import Variable, NumNode
+from tinygrad.shape.symbolic import NumNode
 inf, nan = float('inf'), float('nan')

 # kernel unpacker
diff --git a/extra/optimization/test_beam_search.py b/extra/optimization/test_beam_search.py
index e4350286..c37e8e57 100644
--- a/extra/optimization/test_beam_search.py
+++ b/extra/optimization/test_beam_search.py
@@ -2,8 +2,7 @@
 import unittest
 import numpy as np
 from tinygrad.helpers import BEAM, Timing, CI
-from tinygrad.shape.symbolic import Variable
-from tinygrad.tensor import Tensor
+from tinygrad import Variable, Tensor
 from tinygrad.nn import Conv2d

 def rand(*shape):
diff --git a/test/external/fuzz_symbolic.py b/test/external/fuzz_symbolic.py
index 79eb52f9..a3c505f9 100644
--- a/test/external/fuzz_symbolic.py
+++ b/test/external/fuzz_symbolic.py
@@ -1,7 +1,8 @@
 import itertools
 import random
+from tinygrad import Variable
 from tinygrad.helpers import DEBUG
-from tinygrad.shape.symbolic import Variable, NumNode
+from tinygrad.shape.symbolic import NumNode
 random.seed(42)

 def add_v(expr, rng=None):
diff --git a/test/models/test_real_world.py b/test/models/test_real_world.py
index fdf80a76..9b543f0d 100644
--- a/test/models/test_real_world.py
+++ b/test/models/test_real_world.py
@@ -3,9 +3,8 @@
 import numpy as np
 from tinygrad.nn import optim
 from tinygrad.nn.state import get_parameters
 from tinygrad.engine.jit import TinyJit
-from tinygrad import Tensor, Device, GlobalCounters, dtypes
+from tinygrad import Tensor, Device, GlobalCounters, dtypes, Variable
 from tinygrad.helpers import CI, Context
-from tinygrad.shape.symbolic import Variable
 from extra.lr_scheduler import OneCycleLR
 from test.helpers import derandomize_model, is_dtype_supported
diff --git a/test/test_sample.py b/test/test_sample.py
index 3bb4cf76..9af7557c 100644
--- a/test/test_sample.py
+++ b/test/test_sample.py
@@ -1,7 +1,6 @@
 import unittest
 import numpy as np
-from tinygrad.tensor import Tensor
-from tinygrad.shape.symbolic import Variable
+from tinygrad import Tensor, Variable

 class TestSample(unittest.TestCase):
   def test_sample(self):
diff --git a/test/test_symbolic_jit.py b/test/test_symbolic_jit.py
index e4e87690..05608180 100644
--- a/test/test_symbolic_jit.py
+++ b/test/test_symbolic_jit.py
@@ -1,9 +1,7 @@
 import unittest

 from test.helpers import assert_jit_cache_len
-from tinygrad.engine.jit import TinyJit
-from tinygrad.shape.symbolic import Variable
-from tinygrad.tensor import Tensor
+from tinygrad import Variable, Tensor, TinyJit
 import numpy as np

 class TestSymbolicJit(unittest.TestCase):
diff --git a/test/test_symbolic_ops.py b/test/test_symbolic_ops.py
index 8142c513..ada84bd9 100644
--- a/test/test_symbolic_ops.py
+++ b/test/test_symbolic_ops.py
@@ -1,5 +1,5 @@
 import unittest
-from tinygrad.shape.symbolic import Variable
+from tinygrad import Variable
 from tinygrad.helpers import getenv
 from tinygrad.tensor import Tensor
 from examples.gpt2 import Attention
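
Usage sketch (not part of the patch): the hunks above only change where Variable is imported from; the symbolic tests keep exercising the same reshape-to-a-bound-Variable pattern. Assuming that existing Tensor API, with illustrative shapes and names:

    from tinygrad import Tensor, Variable

    vi = Variable("i", 1, 10).bind(3)           # symbolic dim in [1, 10], bound to 3
    a = Tensor.rand(3, 3)
    sym = (a.reshape(3, vi) + 1).reshape(3, 3)  # compute with a symbolic second dim
    print(sym.numpy())                          # same values as (a + 1).numpy()
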
diff --git a/test/test_symbolic_shapetracker.py b/test/test_symbolic_shapetracker.py
index 17f644f2..f89583b4 100644
--- a/test/test_symbolic_shapetracker.py
+++ b/test/test_symbolic_shapetracker.py
@@ -1,6 +1,7 @@
 import unittest
 from tinygrad.shape.shapetracker import ShapeTracker, View
-from tinygrad.shape.symbolic import Variable, NumNode
+from tinygrad import Variable
+from tinygrad.shape.symbolic import NumNode
 from tinygrad.tensor import Tensor

 class TestSymbolic(unittest.TestCase):
diff --git a/test/test_uops.py b/test/test_uops.py
index 870c6cc9..0ddc0266 100644
--- a/test/test_uops.py
+++ b/test/test_uops.py
@@ -12,7 +12,6 @@
 from tinygrad.engine.schedule import create_schedule, reduceop_fusor
 from tinygrad.engine.realize import CompiledRunner, lower_schedule_item, get_kernel
 from tinygrad.codegen.linearize import linearize_uop
 from tinygrad.codegen.uopgraph import full_graph_rewrite, sym
-from tinygrad.shape.symbolic import Variable
 from test.helpers import is_dtype_supported, assert_equiv_uops
 def to_uops_list(u:List[UOp], opts=None, skip_check=False) -> List[UOp]: return linearize_uop(full_graph_rewrite(UOp.sink(*u), opts), skip_check)
@@ -360,7 +359,7 @@ class TestUOpMethod(unittest.TestCase):
     assert (add < mul) or (mul < add), "add and mul with same src should have an order"

   def test_uop_variables(self):
-    a = Variable("a", 1, 10)
+    a = UOp.variable("a", 1, 10)
     uop_var = UOp.const(dtypes.int, a)
     st_var = UOp(UOps.LOAD, dtypes.float, (UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0),
                                            ShapeTracker.from_shape((2, a)).to_uop()))
diff --git a/test/unit/test_helpers.py b/test/unit/test_helpers.py
index 77580491..c8f4fa2b 100644
--- a/test/unit/test_helpers.py
+++ b/test/unit/test_helpers.py
@@ -1,10 +1,11 @@
 import gzip, unittest
 from PIL import Image
+from tinygrad import Variable
 from tinygrad.helpers import Context, ContextVar
 from tinygrad.helpers import merge_dicts, strip_parens, prod, round_up, fetch, fully_flatten, from_mv, to_mv
 from tinygrad.tensor import get_shape
 from tinygrad.codegen.lowerer import get_contraction
-from tinygrad.shape.symbolic import Variable, NumNode
+from tinygrad.shape.symbolic import NumNode
 import numpy as np

 VARIABLE = ContextVar("VARIABLE", 0)
diff --git a/test/unit/test_shapetracker.py b/test/unit/test_shapetracker.py
index f55dab82..d7082543 100644
--- a/test/unit/test_shapetracker.py
+++ b/test/unit/test_shapetracker.py
@@ -4,7 +4,8 @@
 import numpy as np
 from tinygrad.dtype import dtypes
 from tinygrad.helpers import prod
 from tinygrad.shape.shapetracker import ShapeTracker, View
-from tinygrad.shape.symbolic import Variable, NumNode
+from tinygrad import Variable
+from tinygrad.shape.symbolic import NumNode
 from tinygrad.ops import UOp, UOps, graph_rewrite
 from tinygrad.codegen.uopgraph import sym
 from itertools import product
diff --git a/test/unit/test_shapetracker_math.py b/test/unit/test_shapetracker_math.py
index cf9a7561..96af3d82 100644
--- a/test/unit/test_shapetracker_math.py
+++ b/test/unit/test_shapetracker_math.py
@@ -3,7 +3,7 @@
 from typing import List
 from tinygrad.helpers import prod
 from tinygrad.shape.view import View
 from tinygrad.shape.shapetracker import ShapeTracker
-from tinygrad.shape.symbolic import Variable
+from tinygrad import Variable
 from test.unit.test_shapetracker import shapetracker_getitem

 class MultiShapeTracker:
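
Usage sketch (not part of the patch): test_uop_variables above now calls UOp.variable directly, and with the alias added in tinygrad/__init__.py further down, Variable and UOp.variable build the same kind of node. Using only what this diff defines:

    from tinygrad import Variable
    from tinygrad.ops import UOp, UOps
    from tinygrad.shape.shapetracker import ShapeTracker

    a = UOp.variable("a", 1, 10)          # DEFINE_VAR uop: dtypes.int, arg=("a", 1, 10)
    b = Variable("a", 1, 10)              # the package-level alias builds the same node
    assert a.op is UOps.DEFINE_VAR and a.arg == b.arg == ("a", 1, 10)

    st = ShapeTracker.from_shape((2, a))  # symbolic dimension, as in test_uop_variables
    print(st.shape)                       # (2, a): the second dim is the DEFINE_VAR uop
    print(st.vars())                      # the shape variables this ShapeTracker depends on
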
diff --git a/test/unit/test_uop_symbolic.py b/test/unit/test_uop_symbolic.py
index a95d1770..f280b040 100644
--- a/test/unit/test_uop_symbolic.py
+++ b/test/unit/test_uop_symbolic.py
@@ -12,7 +12,7 @@
 from tinygrad.dtype import dtypes, PtrDType, ConstType
 from tinygrad.codegen.linearize import linearize_uop
 from tinygrad.codegen.uopgraph import full_graph_rewrite
 from tinygrad.ops import BinaryOps, UOp, UOps, print_uops
-from tinygrad.shape.symbolic import Variable
+from tinygrad import Variable
 import functools

 def render(self) -> Tuple[str, ConstType, ConstType]:
diff --git a/tinygrad/__init__.py b/tinygrad/__init__.py
index 841215de..947d03c9 100644
--- a/tinygrad/__init__.py
+++ b/tinygrad/__init__.py
@@ -4,7 +4,8 @@
 if int(os.getenv("TYPED", "0")): install_import_hook(__name__)
 from tinygrad.tensor import Tensor # noqa: F401
 from tinygrad.engine.jit import TinyJit # noqa: F401
-from tinygrad.shape.symbolic import Variable # noqa: F401
+from tinygrad.ops import UOp
+Variable = UOp.variable
 from tinygrad.dtype import dtypes # noqa: F401
 from tinygrad.helpers import GlobalCounters, fetch, Context, getenv # noqa: F401
 from tinygrad.device import Device # noqa: F401
diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py
index 6a6c4b3d..2eb5f473 100644
--- a/tinygrad/codegen/kernel.py
+++ b/tinygrad/codegen/kernel.py
@@ -603,7 +603,8 @@ class Kernel:
     # kernel name (before late upcast)
     name = ("r" if self.reduceop is not None else ("C" if all(x.op in BUFFER_UOPS for x in self.ast.parents) else "E")) + \
            (f"{len(self.ast.src)}_" if len(self.ast.src) > 1 else "_") + \
-           colored('_', 'BLACK').join([colored(str(x), c) for x,c in zip(self.full_shape, self.colors())])
+           colored('_', 'BLACK').join([colored(str(x.render() if isinstance(x, UOp) else x), c) \
+             for x,c in zip(self.full_shape, self.colors())])

     # name the function something unique
     Kernel.kernel_cnt[(function_name := to_function_name(name))] += 1
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index 8ed9838b..837b759c 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -185,6 +185,7 @@ class UOp(MathTrait):
   __slots__ = ["op", "dtype", "src", "arg"]
   def __init__(self, op: UOps, dtype:DType=dtypes.void, src: Tuple[UOp,...]=tuple(), arg:Any=None):
     # TODO: instant check rules here make debugging easier
+    #assert op in UOps and isinstance(dtype, DType), f"bad UOp creation with {op} {dtype}"
     #if op is UOps.ALU and arg is BinaryOps.CMPNE: assert dtype.scalar() == dtypes.bool
     #if op is UOps.VECTORIZE and dtype != dtypes.void: assert len(src) == dtype.count, f"{len(src)} invalid for {dtype}"
     #if op is UOps.ALU and arg not in (BinaryOps.CMPNE, BinaryOps.CMPLT, TernaryOps.WHERE): assert all_same([dtype] + [x.dtype for x in src])
@@ -276,14 +277,35 @@
     if isinstance(b, UOp): return b.unbind()[0] if b.op is UOps.BIND else b
     if isinstance(b, tuple) and all_same(b): b = b[0] # doesn't have to be a VCONST if they are all the same
     return UOp(UOps.VCONST if isinstance(b, tuple) else UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b) # type: ignore
+
+  # *** Variable stuff ***
+
+  @staticmethod
+  def variable(name:str, min_val:ConstType, max_val:ConstType): return UOp(UOps.DEFINE_VAR, dtypes.int, arg=(name, min_val, max_val))
   @staticmethod
   def define_var(name:str, dtype:DType, min_val:ConstType, max_val:ConstType): return UOp(UOps.DEFINE_VAR, dtype, arg=(name, min_val, max_val))
+  @property
+  def expr(self):
+    assert self.op is UOps.DEFINE_VAR, f"op is {self.op}, need DEFINE_VAR"
+    return self.arg[0]
+  def bind(self, val:int):
+    assert self.op is UOps.DEFINE_VAR, f"op is {self.op}, need DEFINE_VAR"
+    assert self.arg[1] <= val and val <= self.arg[2], f"bind {val} not in range {self.arg[1]}-{self.arg[2]}"
+    return UOp(UOps.BIND, self.dtype, (self, self.const_like(val)))
   def unbind(self) -> Tuple[Variable, int]:
     assert self.op is UOps.BIND and self.src[0].op is UOps.DEFINE_VAR and self.src[1].op is UOps.CONST, f"can't unbind {self}"
-    from tinygrad.shape.symbolic import Variable
-    return cast(Variable, self.src[0]), self.src[1].arg
+    return self.src[0], self.src[1].arg
   @property
   def val(self) -> int: return self.unbind()[1]
+  def vars(self) -> Set[UOp]:
+    bound_vars = set([x for x in self.sparents if x.op is UOps.BIND and x.src[0].op is UOps.DEFINE_VAR])
+    bound_var_base = set(x.src[0] for x in bound_vars)
+    all_vars = set([x for x in self.sparents if x.op is UOps.DEFINE_VAR])
+    return bound_vars.union(set([x for x in all_vars if x not in bound_var_base]))
+  def variables(self) -> List[Variable]:
+    st_vars: List[Set[Variable]] = [x.st_arg.vars() for x in self.sparents if x.op in BUFFER_UOPS]
+    return sorted(set.union(*st_vars, [x.unbind()[0] if x.op is not UOps.DEFINE_VAR else x for x in self.vars()]), key=lambda v: v.arg)
+  # TODO: this is context rewrite
   def substitute(self, dvars:Dict[UOp, UOp]):
     if self in dvars: return dvars[self]
@@ -299,15 +321,6 @@
   @functools.cached_property
   def full_shape(self) -> Tuple[sint, ...]:
     return self.arg.shape if self.op is UOps.VIEW else tuple(smax(x) for x in zip(*[x.full_shape for x in self.src if x.has_st]))
-  def vars(self) -> Set[UOp]:
-    bound_vars = set([x for x in self.sparents if x.op is UOps.BIND and x.src[0].op is UOps.DEFINE_VAR])
-    bound_var_base = set(x.src[0] for x in bound_vars)
-    all_vars = set([x for x in self.sparents if x.op is UOps.DEFINE_VAR])
-    return bound_vars.union(set([x for x in all_vars if x not in bound_var_base]))
-  def variables(self) -> List[Variable]:
-    st_vars: List[Set[Variable]] = [x.st_arg.vars() for x in self.sparents if x.op in BUFFER_UOPS]
-    from tinygrad.shape.symbolic import Variable
-    return sorted(set.union(*st_vars, [x.unbind()[0] if not isinstance(x, Variable) else x for x in self.vars()]), key=lambda v: v.arg)
   def const_factor(self) -> int:
     """largest known int that divides self"""
     if self.op is UOps.CONST: return self.arg
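
Usage sketch (not part of the patch): the "# *** Variable stuff ***" block above is the whole replacement for the old Variable subclass. The round trip it supports, using only the methods added in this hunk:

    from tinygrad.ops import UOp, UOps

    v = UOp.variable("batch", 1, 32)  # DEFINE_VAR with (name, min, max) in its arg
    assert v.expr == "batch"          # .expr reads the name back out of arg[0]

    b = v.bind(8)                     # BIND(DEFINE_VAR, CONST 8); asserts 1 <= 8 <= 32
    assert b.op is UOps.BIND
    var, val = b.unbind()             # recover the DEFINE_VAR uop and the bound value
    assert var is v and val == 8
    assert b.val == 8                 # .val is shorthand for unbind()[1]
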
diff --git a/tinygrad/renderer/__init__.py b/tinygrad/renderer/__init__.py
index 1a8006c1..985b1604 100644
--- a/tinygrad/renderer/__init__.py
+++ b/tinygrad/renderer/__init__.py
@@ -43,7 +43,7 @@ class Program:
     if not self._ran_post_init and self.uops is not None:
       # single pass through the uops
       for u in self.uops:
-        if u.op is UOps.DEFINE_VAR: self.vars.append(Variable(u.arg[0], u.arg[1], u.arg[2]))
+        if u.op is UOps.DEFINE_VAR: self.vars.append(u)
         if u.op is UOps.DEFINE_GLOBAL: self.globals.append(u.arg)
         if u.op is UOps.STORE: self.outs.extend([x.arg for x in u.src[0].sparents if x.op is UOps.DEFINE_GLOBAL])
         if u.op is UOps.SPECIAL:
diff --git a/tinygrad/shape/symbolic.py b/tinygrad/shape/symbolic.py
index 605e2141..c24dade8 100644
--- a/tinygrad/shape/symbolic.py
+++ b/tinygrad/shape/symbolic.py
@@ -1,28 +1,17 @@
 from __future__ import annotations
-from typing import Union, Optional, Dict, cast
+from typing import Union, Optional, Dict
 from tinygrad.dtype import dtypes
-from tinygrad.ops import UOp, UOps, exec_alu, ConstType
+from tinygrad.ops import UOp, UOps, exec_alu
 sint = Union[int, UOp]

 def NumNode(val:int): return UOp.const(dtypes.int, val)

-class Variable(UOp):
-  def __reduce__(self): return Variable, self.arg
-  def __new__(cls, expr:str, nmin:ConstType, nmax:ConstType): # pylint: disable=signature-differs
-    return super().__new__(cls, UOps.DEFINE_VAR, dtypes.int, arg=(expr, nmin, nmax))
-  def __init__(self, expr:str, nmin:ConstType, nmax:ConstType):
-    super().__init__(UOps.DEFINE_VAR, dtypes.int, arg=(expr, nmin, nmax))
-  def bind(self, val:int):
-    assert self.op is UOps.DEFINE_VAR, f"op is {self.op}"
-    assert self.arg[1] <= val and val <= self.arg[2], f"bind {val} not in range {self.arg[1]}-{self.arg[2]}"
-    return UOp(UOps.BIND, self.dtype, (self, self.const_like(val)))
-  @property
-  def expr(self): return self.arg[0]
+Variable = UOp

 def sym_infer(uop: Union[UOp, int], var_vals: Optional[Dict[Variable, int]]) -> int:
   if isinstance(uop, (int, float)): return uop # TODO: ugh, the float is a hack for qcom
   if uop.op == UOps.CONST: return uop.arg
-  if uop.op == UOps.DEFINE_VAR and var_vals is not None: return var_vals[cast(Variable, uop)]
+  if uop.op == UOps.DEFINE_VAR and var_vals is not None: return var_vals[uop]
   if uop.op == UOps.BIND: return uop.src[1].arg # bound variable returns bound value
   if uop.op == UOps.ALU:
     src_values = [sym_infer(src, var_vals) for src in uop.src]
diff --git a/tinygrad/shape/view.py b/tinygrad/shape/view.py
index 55dc5a8a..71970385 100644
--- a/tinygrad/shape/view.py
+++ b/tinygrad/shape/view.py
@@ -164,7 +164,7 @@ class View:
     # Merge dimensions in vm2 if required.
     # NB: Merging too many dimensions can make it difficult to project vm2's mask, hence only combining when required.
-    idxs: List[UOp] = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(vm1.shape)]
+    idxs: List[UOp] = [UOp.variable(f"idx{i}", 0, s-1) for i,s in enumerate(vm1.shape)]
     merged_size, merged_term = 1, NumNode(0)
     extents: List[Tuple[sint, UOp]] = []
     for term, s, o in zip(reversed(terms), reversed(vm2.shape), reversed(origin)):
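
Usage sketch (not part of the patch): with Variable = UOp in shape/symbolic.py, sym_infer works on plain uops, and var_vals is just a dict keyed by DEFINE_VAR uops (the same uops Program.vars now collects in the renderer hunk above). A minimal example under that assumption:

    from tinygrad import Variable
    from tinygrad.shape.symbolic import sym_infer

    i = Variable("i", 1, 10)
    expr = i * 4 + 2                   # UOp arithmetic builds an ALU tree over the DEFINE_VAR
    print(sym_infer(expr, {i: 3}))     # 14: the DEFINE_VAR is looked up in var_vals, ALU folds
    print(sym_infer(i.bind(5), None))  # 5: a BIND node just returns its bound value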