UOp.variable (#7010)

* UOp.variable [pr]

* fix tests

* clean

* improve name rendering

* last bug
George Hotz 2024-10-12 18:20:44 +08:00 committed by GitHub
parent f79e05cac0
commit 5ae2de9845
19 changed files with 52 additions and 49 deletions

View File

@@ -1,12 +1,13 @@
# stuff needed to unpack a kernel
from typing import Tuple
from extra.ops import LazyOp, TernaryOps, BinaryOps, UnaryOps, ReduceOps, BufferOps, MemBuffer, ConstBuffer, MetaOps
+from tinygrad import Variable
from tinygrad.codegen.kernel import Opt, OptOps
from tinygrad.ops import UOp, UOps, KernelInfo
from tinygrad.dtype import dtypes, PtrDType
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
-from tinygrad.shape.symbolic import Variable, NumNode
+from tinygrad.shape.symbolic import NumNode
inf, nan = float('inf'), float('nan')
# kernel unpacker

View File

@@ -2,8 +2,7 @@ import unittest
import numpy as np
from tinygrad.helpers import BEAM, Timing, CI
-from tinygrad.shape.symbolic import Variable
-from tinygrad.tensor import Tensor
+from tinygrad import Variable, Tensor
from tinygrad.nn import Conv2d
def rand(*shape):

View File

@@ -1,7 +1,8 @@
import itertools
import random
+from tinygrad import Variable
from tinygrad.helpers import DEBUG
-from tinygrad.shape.symbolic import Variable, NumNode
+from tinygrad.shape.symbolic import NumNode
random.seed(42)
def add_v(expr, rng=None):

View File

@@ -3,9 +3,8 @@ import numpy as np
from tinygrad.nn import optim
from tinygrad.nn.state import get_parameters
from tinygrad.engine.jit import TinyJit
-from tinygrad import Tensor, Device, GlobalCounters, dtypes
+from tinygrad import Tensor, Device, GlobalCounters, dtypes, Variable
from tinygrad.helpers import CI, Context
-from tinygrad.shape.symbolic import Variable
from extra.lr_scheduler import OneCycleLR
from test.helpers import derandomize_model, is_dtype_supported

View File

@@ -1,7 +1,6 @@
import unittest
import numpy as np
-from tinygrad.tensor import Tensor
-from tinygrad.shape.symbolic import Variable
+from tinygrad import Tensor, Variable
class TestSample(unittest.TestCase):
def test_sample(self):

View File

@@ -1,9 +1,7 @@
import unittest
from test.helpers import assert_jit_cache_len
-from tinygrad.engine.jit import TinyJit
-from tinygrad.shape.symbolic import Variable
-from tinygrad.tensor import Tensor
+from tinygrad import Variable, Tensor, TinyJit
import numpy as np
class TestSymbolicJit(unittest.TestCase):

View File

@@ -1,5 +1,5 @@
import unittest
-from tinygrad.shape.symbolic import Variable
+from tinygrad import Variable
from tinygrad.helpers import getenv
from tinygrad.tensor import Tensor
from examples.gpt2 import Attention

View File

@@ -1,6 +1,7 @@
import unittest
from tinygrad.shape.shapetracker import ShapeTracker, View
-from tinygrad.shape.symbolic import Variable, NumNode
+from tinygrad import Variable
+from tinygrad.shape.symbolic import NumNode
from tinygrad.tensor import Tensor
class TestSymbolic(unittest.TestCase):

View File

@@ -12,7 +12,6 @@ from tinygrad.engine.schedule import create_schedule, reduceop_fusor
from tinygrad.engine.realize import CompiledRunner, lower_schedule_item, get_kernel
from tinygrad.codegen.linearize import linearize_uop
from tinygrad.codegen.uopgraph import full_graph_rewrite, sym
-from tinygrad.shape.symbolic import Variable
from test.helpers import is_dtype_supported, assert_equiv_uops
def to_uops_list(u:List[UOp], opts=None, skip_check=False) -> List[UOp]: return linearize_uop(full_graph_rewrite(UOp.sink(*u), opts), skip_check)
@@ -360,7 +359,7 @@ class TestUOpMethod(unittest.TestCase):
assert (add < mul) or (mul < add), "add and mul with same src should have an order"
def test_uop_variables(self):
-a = Variable("a", 1, 10)
+a = UOp.variable("a", 1, 10)
uop_var = UOp.const(dtypes.int, a)
st_var = UOp(UOps.LOAD, dtypes.float, (UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0),
ShapeTracker.from_shape((2, a)).to_uop()))

View File

@@ -1,10 +1,11 @@
import gzip, unittest
from PIL import Image
+from tinygrad import Variable
from tinygrad.helpers import Context, ContextVar
from tinygrad.helpers import merge_dicts, strip_parens, prod, round_up, fetch, fully_flatten, from_mv, to_mv
from tinygrad.tensor import get_shape
from tinygrad.codegen.lowerer import get_contraction
-from tinygrad.shape.symbolic import Variable, NumNode
+from tinygrad.shape.symbolic import NumNode
import numpy as np
VARIABLE = ContextVar("VARIABLE", 0)

View File

@@ -4,7 +4,8 @@ import numpy as np
from tinygrad.dtype import dtypes
from tinygrad.helpers import prod
from tinygrad.shape.shapetracker import ShapeTracker, View
-from tinygrad.shape.symbolic import Variable, NumNode
+from tinygrad import Variable
+from tinygrad.shape.symbolic import NumNode
from tinygrad.ops import UOp, UOps, graph_rewrite
from tinygrad.codegen.uopgraph import sym
from itertools import product

View File

@@ -3,7 +3,7 @@ from typing import List
from tinygrad.helpers import prod
from tinygrad.shape.view import View
from tinygrad.shape.shapetracker import ShapeTracker
-from tinygrad.shape.symbolic import Variable
+from tinygrad import Variable
from test.unit.test_shapetracker import shapetracker_getitem
class MultiShapeTracker:
class MultiShapeTracker:

View File

@@ -12,7 +12,7 @@ from tinygrad.dtype import dtypes, PtrDType, ConstType
from tinygrad.codegen.linearize import linearize_uop
from tinygrad.codegen.uopgraph import full_graph_rewrite
from tinygrad.ops import BinaryOps, UOp, UOps, print_uops
-from tinygrad.shape.symbolic import Variable
+from tinygrad import Variable
import functools
def render(self) -> Tuple[str, ConstType, ConstType]:

View File

@@ -4,7 +4,8 @@ if int(os.getenv("TYPED", "0")):
install_import_hook(__name__)
from tinygrad.tensor import Tensor # noqa: F401
from tinygrad.engine.jit import TinyJit # noqa: F401
-from tinygrad.shape.symbolic import Variable # noqa: F401
+from tinygrad.ops import UOp
+Variable = UOp.variable
from tinygrad.dtype import dtypes # noqa: F401
from tinygrad.helpers import GlobalCounters, fetch, Context, getenv # noqa: F401
from tinygrad.device import Device # noqa: F401
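With this hunk the package-level `Variable` is simply `UOp.variable`, so existing `from tinygrad import Variable` call sites keep working. A minimal sketch of the new entry point (the variable name and bounds below are illustrative, not taken from the diff):

# sketch: creating a symbolic variable through the re-exported alias
from tinygrad import Variable

bs = Variable("batch", 1, 16)   # a DEFINE_VAR UOp with inclusive range [1, 16]
print(bs.op, bs.arg)            # expected: UOps.DEFINE_VAR ('batch', 1, 16)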

View File

@@ -603,7 +603,8 @@ class Kernel:
# kernel name (before late upcast)
name = ("r" if self.reduceop is not None else ("C" if all(x.op in BUFFER_UOPS for x in self.ast.parents) else "E")) + \
(f"{len(self.ast.src)}_" if len(self.ast.src) > 1 else "_") + \
-colored('_', 'BLACK').join([colored(str(x), c) for x,c in zip(self.full_shape, self.colors())])
+colored('_', 'BLACK').join([colored(str(x.render() if isinstance(x, UOp) else x), c) \
+for x,c in zip(self.full_shape, self.colors())])
# name the function something unique
Kernel.kernel_cnt[(function_name := to_function_name(name))] += 1
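The name-rendering hunk above ("improve name rendering" in the commit message) calls `x.render()` when a dimension of `full_shape` is a UOp, so a symbolic dimension contributes its rendered expression instead of the full UOp repr. Roughly, under the assumption that rendering a DEFINE_VAR yields its name (the exact string depends on the renderer patterns):

# sketch: how a symbolic dim now contributes to the kernel name (colored() omitted)
from tinygrad.ops import UOp

n = UOp.variable("n", 1, 512)
dims = (64, n)
print("_".join(str(x.render() if isinstance(x, UOp) else x) for x in dims))  # e.g. "64_n"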

View File

@@ -185,6 +185,7 @@ class UOp(MathTrait):
__slots__ = ["op", "dtype", "src", "arg"]
def __init__(self, op: UOps, dtype:DType=dtypes.void, src: Tuple[UOp,...]=tuple(), arg:Any=None):
# TODO: instant check rules here make debugging easier
#assert op in UOps and isinstance(dtype, DType), f"bad UOp creation with {op} {dtype}"
#if op is UOps.ALU and arg is BinaryOps.CMPNE: assert dtype.scalar() == dtypes.bool
#if op is UOps.VECTORIZE and dtype != dtypes.void: assert len(src) == dtype.count, f"{len(src)} invalid for {dtype}"
#if op is UOps.ALU and arg not in (BinaryOps.CMPNE, BinaryOps.CMPLT, TernaryOps.WHERE): assert all_same([dtype] + [x.dtype for x in src])
@@ -276,14 +277,35 @@
if isinstance(b, UOp): return b.unbind()[0] if b.op is UOps.BIND else b
if isinstance(b, tuple) and all_same(b): b = b[0] # doesn't have to be a VCONST if they are all the same
return UOp(UOps.VCONST if isinstance(b, tuple) else UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b) # type: ignore
+# *** Variable stuff ***
+@staticmethod
+def variable(name:str, min_val:ConstType, max_val:ConstType): return UOp(UOps.DEFINE_VAR, dtypes.int, arg=(name, min_val, max_val))
+@staticmethod
+def define_var(name:str, dtype:DType, min_val:ConstType, max_val:ConstType): return UOp(UOps.DEFINE_VAR, dtype, arg=(name, min_val, max_val))
+@property
+def expr(self):
+assert self.op is UOps.DEFINE_VAR, f"op is {self.op}, need DEFINE_VAR"
+return self.arg[0]
+def bind(self, val:int):
+assert self.op is UOps.DEFINE_VAR, f"op is {self.op}, need DEFINE_VAR"
+assert self.arg[1] <= val and val <= self.arg[2], f"bind {val} not in range {self.arg[1]}-{self.arg[2]}"
+return UOp(UOps.BIND, self.dtype, (self, self.const_like(val)))
def unbind(self) -> Tuple[Variable, int]:
assert self.op is UOps.BIND and self.src[0].op is UOps.DEFINE_VAR and self.src[1].op is UOps.CONST, f"can't unbind {self}"
-from tinygrad.shape.symbolic import Variable
-return cast(Variable, self.src[0]), self.src[1].arg
+return self.src[0], self.src[1].arg
@property
def val(self) -> int: return self.unbind()[1]
+def vars(self) -> Set[UOp]:
+bound_vars = set([x for x in self.sparents if x.op is UOps.BIND and x.src[0].op is UOps.DEFINE_VAR])
+bound_var_base = set(x.src[0] for x in bound_vars)
+all_vars = set([x for x in self.sparents if x.op is UOps.DEFINE_VAR])
+return bound_vars.union(set([x for x in all_vars if x not in bound_var_base]))
+def variables(self) -> List[Variable]:
+st_vars: List[Set[Variable]] = [x.st_arg.vars() for x in self.sparents if x.op in BUFFER_UOPS]
+return sorted(set.union(*st_vars, [x.unbind()[0] if x.op is not UOps.DEFINE_VAR else x for x in self.vars()]), key=lambda v: v.arg)
# TODO: this is context rewrite
def substitute(self, dvars:Dict[UOp, UOp]):
if self in dvars: return dvars[self]
@@ -299,15 +321,6 @@
@functools.cached_property
def full_shape(self) -> Tuple[sint, ...]:
return self.arg.shape if self.op is UOps.VIEW else tuple(smax(x) for x in zip(*[x.full_shape for x in self.src if x.has_st]))
-def vars(self) -> Set[UOp]:
-bound_vars = set([x for x in self.sparents if x.op is UOps.BIND and x.src[0].op is UOps.DEFINE_VAR])
-bound_var_base = set(x.src[0] for x in bound_vars)
-all_vars = set([x for x in self.sparents if x.op is UOps.DEFINE_VAR])
-return bound_vars.union(set([x for x in all_vars if x not in bound_var_base]))
-def variables(self) -> List[Variable]:
-st_vars: List[Set[Variable]] = [x.st_arg.vars() for x in self.sparents if x.op in BUFFER_UOPS]
-from tinygrad.shape.symbolic import Variable
-return sorted(set.union(*st_vars, [x.unbind()[0] if not isinstance(x, Variable) else x for x in self.vars()]), key=lambda v: v.arg)
def const_factor(self) -> int:
"""largest known int that divides self"""
if self.op is UOps.CONST: return self.arg
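Taken together, the two hunks above move the symbolic-variable machinery onto UOp itself: `UOp.variable` builds a DEFINE_VAR, `bind`/`unbind`/`val` attach and recover a concrete value, and `vars`/`variables` are relocated from further down the class. A minimal usage sketch of that API (the variable name and values are illustrative):

# sketch: the DEFINE_VAR / BIND round trip introduced in this commit
from tinygrad.ops import UOp

v = UOp.variable("i", 0, 9)   # DEFINE_VAR, dtype int, arg=("i", 0, 9)
assert v.expr == "i"          # expr property reads the name back out of arg
b = v.bind(3)                 # BIND UOp over (v, const 3); binding 10 would trip the range assert
base, val = b.unbind()        # recover the DEFINE_VAR and its bound value
assert base is v and val == 3 and b.val == 3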

View File

@@ -43,7 +43,7 @@ class Program:
if not self._ran_post_init and self.uops is not None:
# single pass through the uops
for u in self.uops:
-if u.op is UOps.DEFINE_VAR: self.vars.append(Variable(u.arg[0], u.arg[1], u.arg[2]))
+if u.op is UOps.DEFINE_VAR: self.vars.append(u)
if u.op is UOps.DEFINE_GLOBAL: self.globals.append(u.arg)
if u.op is UOps.STORE: self.outs.extend([x.arg for x in u.src[0].sparents if x.op is UOps.DEFINE_GLOBAL])
if u.op is UOps.SPECIAL:

View File

@@ -1,28 +1,17 @@
from __future__ import annotations
-from typing import Union, Optional, Dict, cast
+from typing import Union, Optional, Dict
from tinygrad.dtype import dtypes
-from tinygrad.ops import UOp, UOps, exec_alu, ConstType
+from tinygrad.ops import UOp, UOps, exec_alu
sint = Union[int, UOp]
def NumNode(val:int): return UOp.const(dtypes.int, val)
-class Variable(UOp):
-def __reduce__(self): return Variable, self.arg
-def __new__(cls, expr:str, nmin:ConstType, nmax:ConstType): # pylint: disable=signature-differs
-return super().__new__(cls, UOps.DEFINE_VAR, dtypes.int, arg=(expr, nmin, nmax))
-def __init__(self, expr:str, nmin:ConstType, nmax:ConstType):
-super().__init__(UOps.DEFINE_VAR, dtypes.int, arg=(expr, nmin, nmax))
-def bind(self, val:int):
-assert self.op is UOps.DEFINE_VAR, f"op is {self.op}"
-assert self.arg[1] <= val and val <= self.arg[2], f"bind {val} not in range {self.arg[1]}-{self.arg[2]}"
-return UOp(UOps.BIND, self.dtype, (self, self.const_like(val)))
-@property
-def expr(self): return self.arg[0]
+Variable = UOp
def sym_infer(uop: Union[UOp, int], var_vals: Optional[Dict[Variable, int]]) -> int:
if isinstance(uop, (int, float)): return uop # TODO: ugh, the float is a hack for qcom
if uop.op == UOps.CONST: return uop.arg
-if uop.op == UOps.DEFINE_VAR and var_vals is not None: return var_vals[cast(Variable, uop)]
+if uop.op == UOps.DEFINE_VAR and var_vals is not None: return var_vals[uop]
if uop.op == UOps.BIND: return uop.src[1].arg # bound variable returns bound value
if uop.op == UOps.ALU:
src_values = [sym_infer(src, var_vals) for src in uop.src]
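After this hunk `Variable` is a plain alias for `UOp`, and `sym_infer` indexes `var_vals` with the DEFINE_VAR UOp directly instead of casting. A small sketch of how a symbolic expression evaluates under the new alias (the variable name and values are illustrative):

# sketch: evaluating a symbolic expression with Variable = UOp
from tinygrad.ops import UOp
from tinygrad.shape.symbolic import sym_infer

n = UOp.variable("n", 1, 100)
expr = n * 2 + 1                       # MathTrait operators build an ALU UOp tree
assert sym_infer(expr, {n: 10}) == 21  # DEFINE_VAR looked up in var_vals, CONSTs folded via exec_alu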

View File

@@ -164,7 +164,7 @@ class View:
# Merge dimensions in vm2 if required.
# NB: Merging too many dimensions can make it difficult to project vm2's mask, hence only combining when required.
-idxs: List[UOp] = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(vm1.shape)]
+idxs: List[UOp] = [UOp.variable(f"idx{i}", 0, s-1) for i,s in enumerate(vm1.shape)]
merged_size, merged_term = 1, NumNode(0)
extents: List[Tuple[sint, UOp]] = []
for term, s, o in zip(reversed(terms), reversed(vm2.shape), reversed(origin)):