UOp.variable (#7010)

* UOp.variable [pr]

* fix tests

* clean

* improve name rendering

* last bug
George Hotz 2024-10-12 18:20:44 +08:00 committed by GitHub
parent f79e05cac0
commit 5ae2de9845
19 changed files with 52 additions and 49 deletions
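
The change in one line: the symbolic Variable class is folded into UOp. UOp.variable(name, min, max) builds a DEFINE_VAR UOp directly, and "from tinygrad import Variable" now points at that constructor. A minimal sketch of the user-facing API after this commit (shapes and values are illustrative, not taken from the diff):

from tinygrad import Tensor, Variable

v = Variable("batch", 1, 16)           # a DEFINE_VAR UOp; there is no separate Variable class anymore
bv = v.bind(8)                         # a BIND UOp carrying (var, value)
var, val = bv.unbind()                 # recovers the DEFINE_VAR and the bound value 8
t = Tensor.rand(8, 4).reshape(bv, 4)   # symbolic shapes should keep working through reshape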


@@ -1,12 +1,13 @@
 # stuff needed to unpack a kernel
 from typing import Tuple
 from extra.ops import LazyOp, TernaryOps, BinaryOps, UnaryOps, ReduceOps, BufferOps, MemBuffer, ConstBuffer, MetaOps
+from tinygrad import Variable
 from tinygrad.codegen.kernel import Opt, OptOps
 from tinygrad.ops import UOp, UOps, KernelInfo
 from tinygrad.dtype import dtypes, PtrDType
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import View
-from tinygrad.shape.symbolic import Variable, NumNode
+from tinygrad.shape.symbolic import NumNode
 inf, nan = float('inf'), float('nan')
 # kernel unpacker


@@ -2,8 +2,7 @@ import unittest
 import numpy as np
 from tinygrad.helpers import BEAM, Timing, CI
-from tinygrad.shape.symbolic import Variable
-from tinygrad.tensor import Tensor
+from tinygrad import Variable, Tensor
 from tinygrad.nn import Conv2d
 def rand(*shape):


@@ -1,7 +1,8 @@
 import itertools
 import random
+from tinygrad import Variable
 from tinygrad.helpers import DEBUG
-from tinygrad.shape.symbolic import Variable, NumNode
+from tinygrad.shape.symbolic import NumNode
 random.seed(42)
 def add_v(expr, rng=None):


@@ -3,9 +3,8 @@ import numpy as np
 from tinygrad.nn import optim
 from tinygrad.nn.state import get_parameters
 from tinygrad.engine.jit import TinyJit
-from tinygrad import Tensor, Device, GlobalCounters, dtypes
+from tinygrad import Tensor, Device, GlobalCounters, dtypes, Variable
 from tinygrad.helpers import CI, Context
-from tinygrad.shape.symbolic import Variable
 from extra.lr_scheduler import OneCycleLR
 from test.helpers import derandomize_model, is_dtype_supported


@@ -1,7 +1,6 @@
 import unittest
 import numpy as np
-from tinygrad.tensor import Tensor
-from tinygrad.shape.symbolic import Variable
+from tinygrad import Tensor, Variable
 class TestSample(unittest.TestCase):
   def test_sample(self):


@@ -1,9 +1,7 @@
 import unittest
 from test.helpers import assert_jit_cache_len
-from tinygrad.engine.jit import TinyJit
-from tinygrad.shape.symbolic import Variable
-from tinygrad.tensor import Tensor
+from tinygrad import Variable, Tensor, TinyJit
 import numpy as np
 class TestSymbolicJit(unittest.TestCase):


@@ -1,5 +1,5 @@
 import unittest
-from tinygrad.shape.symbolic import Variable
+from tinygrad import Variable
 from tinygrad.helpers import getenv
 from tinygrad.tensor import Tensor
 from examples.gpt2 import Attention


@@ -1,6 +1,7 @@
 import unittest
 from tinygrad.shape.shapetracker import ShapeTracker, View
-from tinygrad.shape.symbolic import Variable, NumNode
+from tinygrad import Variable
+from tinygrad.shape.symbolic import NumNode
 from tinygrad.tensor import Tensor
 class TestSymbolic(unittest.TestCase):


@@ -12,7 +12,6 @@ from tinygrad.engine.schedule import create_schedule, reduceop_fusor
 from tinygrad.engine.realize import CompiledRunner, lower_schedule_item, get_kernel
 from tinygrad.codegen.linearize import linearize_uop
 from tinygrad.codegen.uopgraph import full_graph_rewrite, sym
-from tinygrad.shape.symbolic import Variable
 from test.helpers import is_dtype_supported, assert_equiv_uops
 def to_uops_list(u:List[UOp], opts=None, skip_check=False) -> List[UOp]: return linearize_uop(full_graph_rewrite(UOp.sink(*u), opts), skip_check)
@@ -360,7 +359,7 @@ class TestUOpMethod(unittest.TestCase):
     assert (add < mul) or (mul < add), "add and mul with same src should have an order"
   def test_uop_variables(self):
-    a = Variable("a", 1, 10)
+    a = UOp.variable("a", 1, 10)
     uop_var = UOp.const(dtypes.int, a)
     st_var = UOp(UOps.LOAD, dtypes.float, (UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0),
                                            ShapeTracker.from_shape((2, a)).to_uop()))


@@ -1,10 +1,11 @@
 import gzip, unittest
 from PIL import Image
+from tinygrad import Variable
 from tinygrad.helpers import Context, ContextVar
 from tinygrad.helpers import merge_dicts, strip_parens, prod, round_up, fetch, fully_flatten, from_mv, to_mv
 from tinygrad.tensor import get_shape
 from tinygrad.codegen.lowerer import get_contraction
-from tinygrad.shape.symbolic import Variable, NumNode
+from tinygrad.shape.symbolic import NumNode
 import numpy as np
 VARIABLE = ContextVar("VARIABLE", 0)


@@ -4,7 +4,8 @@ import numpy as np
 from tinygrad.dtype import dtypes
 from tinygrad.helpers import prod
 from tinygrad.shape.shapetracker import ShapeTracker, View
-from tinygrad.shape.symbolic import Variable, NumNode
+from tinygrad import Variable
+from tinygrad.shape.symbolic import NumNode
 from tinygrad.ops import UOp, UOps, graph_rewrite
 from tinygrad.codegen.uopgraph import sym
 from itertools import product


@@ -3,7 +3,7 @@ from typing import List
 from tinygrad.helpers import prod
 from tinygrad.shape.view import View
 from tinygrad.shape.shapetracker import ShapeTracker
-from tinygrad.shape.symbolic import Variable
+from tinygrad import Variable
 from test.unit.test_shapetracker import shapetracker_getitem
 class MultiShapeTracker:


@@ -12,7 +12,7 @@ from tinygrad.dtype import dtypes, PtrDType, ConstType
 from tinygrad.codegen.linearize import linearize_uop
 from tinygrad.codegen.uopgraph import full_graph_rewrite
 from tinygrad.ops import BinaryOps, UOp, UOps, print_uops
-from tinygrad.shape.symbolic import Variable
+from tinygrad import Variable
 import functools
 def render(self) -> Tuple[str, ConstType, ConstType]:


@@ -4,7 +4,8 @@ if int(os.getenv("TYPED", "0")):
   install_import_hook(__name__)
 from tinygrad.tensor import Tensor # noqa: F401
 from tinygrad.engine.jit import TinyJit # noqa: F401
-from tinygrad.shape.symbolic import Variable # noqa: F401
+from tinygrad.ops import UOp
+Variable = UOp.variable
 from tinygrad.dtype import dtypes # noqa: F401
 from tinygrad.helpers import GlobalCounters, fetch, Context, getenv # noqa: F401
 from tinygrad.device import Device # noqa: F401


@@ -603,7 +603,8 @@ class Kernel:
     # kernel name (before late upcast)
     name = ("r" if self.reduceop is not None else ("C" if all(x.op in BUFFER_UOPS for x in self.ast.parents) else "E")) + \
            (f"{len(self.ast.src)}_" if len(self.ast.src) > 1 else "_") + \
-           colored('_', 'BLACK').join([colored(str(x), c) for x,c in zip(self.full_shape, self.colors())])
+           colored('_', 'BLACK').join([colored(str(x.render() if isinstance(x, UOp) else x), c) \
+             for x,c in zip(self.full_shape, self.colors())])
     # name the function something unique
     Kernel.kernel_cnt[(function_name := to_function_name(name))] += 1
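
Name rendering: a symbolic dimension in full_shape is now a UOp, and str() of a UOp is its full repr, so the kernel name uses x.render() to get the compact expression instead. A rough sketch of the difference, assuming render() of a bare DEFINE_VAR comes back as just the variable name:

from tinygrad.ops import UOp

i = UOp.variable("i", 1, 10)
print(str(i))      # the long UOp repr, not something you want in a kernel name
print(i.render())  # expected to print just "i" (assumption), which is what the name uses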


@@ -185,6 +185,7 @@ class UOp(MathTrait):
   __slots__ = ["op", "dtype", "src", "arg"]
   def __init__(self, op: UOps, dtype:DType=dtypes.void, src: Tuple[UOp,...]=tuple(), arg:Any=None):
     # TODO: instant check rules here make debugging easier
+    #assert op in UOps and isinstance(dtype, DType), f"bad UOp creation with {op} {dtype}"
     #if op is UOps.ALU and arg is BinaryOps.CMPNE: assert dtype.scalar() == dtypes.bool
     #if op is UOps.VECTORIZE and dtype != dtypes.void: assert len(src) == dtype.count, f"{len(src)} invalid for {dtype}"
     #if op is UOps.ALU and arg not in (BinaryOps.CMPNE, BinaryOps.CMPLT, TernaryOps.WHERE): assert all_same([dtype] + [x.dtype for x in src])
@@ -276,14 +277,35 @@ class UOp(MathTrait):
     if isinstance(b, UOp): return b.unbind()[0] if b.op is UOps.BIND else b
     if isinstance(b, tuple) and all_same(b): b = b[0] # doesn't have to be a VCONST if they are all the same
     return UOp(UOps.VCONST if isinstance(b, tuple) else UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b) # type: ignore
+  # *** Variable stuff ***
+  @staticmethod
+  def variable(name:str, min_val:ConstType, max_val:ConstType): return UOp(UOps.DEFINE_VAR, dtypes.int, arg=(name, min_val, max_val))
   @staticmethod
   def define_var(name:str, dtype:DType, min_val:ConstType, max_val:ConstType): return UOp(UOps.DEFINE_VAR, dtype, arg=(name, min_val, max_val))
+  @property
+  def expr(self):
+    assert self.op is UOps.DEFINE_VAR, f"op is {self.op}, need DEFINE_VAR"
+    return self.arg[0]
+  def bind(self, val:int):
+    assert self.op is UOps.DEFINE_VAR, f"op is {self.op}, need DEFINE_VAR"
+    assert self.arg[1] <= val and val <= self.arg[2], f"bind {val} not in range {self.arg[1]}-{self.arg[2]}"
+    return UOp(UOps.BIND, self.dtype, (self, self.const_like(val)))
   def unbind(self) -> Tuple[Variable, int]:
     assert self.op is UOps.BIND and self.src[0].op is UOps.DEFINE_VAR and self.src[1].op is UOps.CONST, f"can't unbind {self}"
-    from tinygrad.shape.symbolic import Variable
-    return cast(Variable, self.src[0]), self.src[1].arg
+    return self.src[0], self.src[1].arg
   @property
   def val(self) -> int: return self.unbind()[1]
+  def vars(self) -> Set[UOp]:
+    bound_vars = set([x for x in self.sparents if x.op is UOps.BIND and x.src[0].op is UOps.DEFINE_VAR])
+    bound_var_base = set(x.src[0] for x in bound_vars)
+    all_vars = set([x for x in self.sparents if x.op is UOps.DEFINE_VAR])
+    return bound_vars.union(set([x for x in all_vars if x not in bound_var_base]))
+  def variables(self) -> List[Variable]:
+    st_vars: List[Set[Variable]] = [x.st_arg.vars() for x in self.sparents if x.op in BUFFER_UOPS]
+    return sorted(set.union(*st_vars, [x.unbind()[0] if x.op is not UOps.DEFINE_VAR else x for x in self.vars()]), key=lambda v: v.arg)
   # TODO: this is context rewrite
   def substitute(self, dvars:Dict[UOp, UOp]):
     if self in dvars: return dvars[self]
@@ -299,15 +321,6 @@ class UOp(MathTrait):
   @functools.cached_property
   def full_shape(self) -> Tuple[sint, ...]:
     return self.arg.shape if self.op is UOps.VIEW else tuple(smax(x) for x in zip(*[x.full_shape for x in self.src if x.has_st]))
-  def vars(self) -> Set[UOp]:
-    bound_vars = set([x for x in self.sparents if x.op is UOps.BIND and x.src[0].op is UOps.DEFINE_VAR])
-    bound_var_base = set(x.src[0] for x in bound_vars)
-    all_vars = set([x for x in self.sparents if x.op is UOps.DEFINE_VAR])
-    return bound_vars.union(set([x for x in all_vars if x not in bound_var_base]))
-  def variables(self) -> List[Variable]:
-    st_vars: List[Set[Variable]] = [x.st_arg.vars() for x in self.sparents if x.op in BUFFER_UOPS]
-    from tinygrad.shape.symbolic import Variable
-    return sorted(set.union(*st_vars, [x.unbind()[0] if not isinstance(x, Variable) else x for x in self.vars()]), key=lambda v: v.arg)
   def const_factor(self) -> int:
     """largest known int that divides self"""
     if self.op is UOps.CONST: return self.arg
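
The methods added to UOp above give DEFINE_VAR and BIND nodes the old Variable surface: expr, bind, unbind, val, plus vars() and variables() for collecting variables out of a graph. A small sketch of how they compose (names and values made up for illustration):

from tinygrad.ops import UOp

a = UOp.variable("a", 1, 10)       # UOps.DEFINE_VAR with arg=("a", 1, 10)
assert a.expr == "a"               # expr is just arg[0], guarded by a DEFINE_VAR assert
b = a.bind(3)                      # UOps.BIND with src=(a, const 3); bind checks 1 <= 3 <= 10
assert b.val == 3                  # val is unbind()[1]
assert b.unbind() == (a, 3)
assert (a * 2 + 1).vars() == {a}   # vars() walks sparents for DEFINE_VAR/BIND nodes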


@@ -43,7 +43,7 @@ class Program:
     if not self._ran_post_init and self.uops is not None:
       # single pass through the uops
       for u in self.uops:
-        if u.op is UOps.DEFINE_VAR: self.vars.append(Variable(u.arg[0], u.arg[1], u.arg[2]))
+        if u.op is UOps.DEFINE_VAR: self.vars.append(u)
         if u.op is UOps.DEFINE_GLOBAL: self.globals.append(u.arg)
         if u.op is UOps.STORE: self.outs.extend([x.arg for x in u.src[0].sparents if x.op is UOps.DEFINE_GLOBAL])
         if u.op is UOps.SPECIAL:


@@ -1,28 +1,17 @@
 from __future__ import annotations
-from typing import Union, Optional, Dict, cast
+from typing import Union, Optional, Dict
 from tinygrad.dtype import dtypes
-from tinygrad.ops import UOp, UOps, exec_alu, ConstType
+from tinygrad.ops import UOp, UOps, exec_alu
 sint = Union[int, UOp]
 def NumNode(val:int): return UOp.const(dtypes.int, val)
-class Variable(UOp):
-  def __reduce__(self): return Variable, self.arg
-  def __new__(cls, expr:str, nmin:ConstType, nmax:ConstType): # pylint: disable=signature-differs
-    return super().__new__(cls, UOps.DEFINE_VAR, dtypes.int, arg=(expr, nmin, nmax))
-  def __init__(self, expr:str, nmin:ConstType, nmax:ConstType):
-    super().__init__(UOps.DEFINE_VAR, dtypes.int, arg=(expr, nmin, nmax))
-  def bind(self, val:int):
-    assert self.op is UOps.DEFINE_VAR, f"op is {self.op}"
-    assert self.arg[1] <= val and val <= self.arg[2], f"bind {val} not in range {self.arg[1]}-{self.arg[2]}"
-    return UOp(UOps.BIND, self.dtype, (self, self.const_like(val)))
-  @property
-  def expr(self): return self.arg[0]
+Variable = UOp
 def sym_infer(uop: Union[UOp, int], var_vals: Optional[Dict[Variable, int]]) -> int:
   if isinstance(uop, (int, float)): return uop # TODO: ugh, the float is a hack for qcom
   if uop.op == UOps.CONST: return uop.arg
-  if uop.op == UOps.DEFINE_VAR and var_vals is not None: return var_vals[cast(Variable, uop)]
+  if uop.op == UOps.DEFINE_VAR and var_vals is not None: return var_vals[uop]
   if uop.op == UOps.BIND: return uop.src[1].arg # bound variable returns bound value
   if uop.op == UOps.ALU:
     src_values = [sym_infer(src, var_vals) for src in uop.src]
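
tinygrad.shape.symbolic shrinks to a compatibility shim: Variable becomes a plain alias for UOp (useful for type annotations), NumNode wraps UOp.const, and sym_infer evaluates a UOp expression given values for its DEFINE_VARs. A quick sketch of sym_infer under the new scheme (numbers are illustrative):

from tinygrad import Variable                 # this name is now UOp.variable, a constructor function
from tinygrad.shape.symbolic import NumNode, sym_infer

x = Variable("x", 0, 100)                     # DEFINE_VAR UOp
expr = x * 3 + NumNode(2)                     # ordinary UOp arithmetic (ALU nodes)
assert sym_infer(expr, {x: 4}) == 14          # DEFINE_VAR is looked up in var_vals, ALU folds recursively
assert sym_infer(x.bind(5), None) == 5        # a BIND returns its bound value directly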


@@ -164,7 +164,7 @@ class View:
     # Merge dimensions in vm2 if required.
     # NB: Merging too many dimensions can make it difficult to project vm2's mask, hence only combining when required.
-    idxs: List[UOp] = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(vm1.shape)]
+    idxs: List[UOp] = [UOp.variable(f"idx{i}", 0, s-1) for i,s in enumerate(vm1.shape)]
     merged_size, merged_term = 1, NumNode(0)
     extents: List[Tuple[sint, UOp]] = []
     for term, s, o in zip(reversed(terms), reversed(vm2.shape), reversed(origin)):