mirror of https://github.com/commaai/tinygrad.git
UOp.variable (#7010)
* UOp.variable [pr]
* fix tests
* clean
* improve name rendering
* last bug
This commit is contained in:
parent f79e05cac0
commit 5ae2de9845
@@ -1,12 +1,13 @@
 # stuff needed to unpack a kernel
 from typing import Tuple
 from extra.ops import LazyOp, TernaryOps, BinaryOps, UnaryOps, ReduceOps, BufferOps, MemBuffer, ConstBuffer, MetaOps
+from tinygrad import Variable
 from tinygrad.codegen.kernel import Opt, OptOps
 from tinygrad.ops import UOp, UOps, KernelInfo
 from tinygrad.dtype import dtypes, PtrDType
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import View
-from tinygrad.shape.symbolic import Variable, NumNode
+from tinygrad.shape.symbolic import NumNode
 inf, nan = float('inf'), float('nan')

 # kernel unpacker
@@ -2,8 +2,7 @@ import unittest
 import numpy as np

 from tinygrad.helpers import BEAM, Timing, CI
-from tinygrad.shape.symbolic import Variable
-from tinygrad.tensor import Tensor
+from tinygrad import Variable, Tensor
 from tinygrad.nn import Conv2d

 def rand(*shape):
@@ -1,7 +1,8 @@
 import itertools
 import random
+from tinygrad import Variable
 from tinygrad.helpers import DEBUG
-from tinygrad.shape.symbolic import Variable, NumNode
+from tinygrad.shape.symbolic import NumNode
 random.seed(42)

 def add_v(expr, rng=None):
@@ -3,9 +3,8 @@ import numpy as np
 from tinygrad.nn import optim
 from tinygrad.nn.state import get_parameters
 from tinygrad.engine.jit import TinyJit
-from tinygrad import Tensor, Device, GlobalCounters, dtypes
+from tinygrad import Tensor, Device, GlobalCounters, dtypes, Variable
 from tinygrad.helpers import CI, Context
-from tinygrad.shape.symbolic import Variable
 from extra.lr_scheduler import OneCycleLR
 from test.helpers import derandomize_model, is_dtype_supported

@@ -1,7 +1,6 @@
 import unittest
 import numpy as np
-from tinygrad.tensor import Tensor
-from tinygrad.shape.symbolic import Variable
+from tinygrad import Tensor, Variable

 class TestSample(unittest.TestCase):
   def test_sample(self):
@@ -1,9 +1,7 @@
 import unittest

 from test.helpers import assert_jit_cache_len
-from tinygrad.engine.jit import TinyJit
-from tinygrad.shape.symbolic import Variable
-from tinygrad.tensor import Tensor
+from tinygrad import Variable, Tensor, TinyJit
 import numpy as np

 class TestSymbolicJit(unittest.TestCase):
@@ -1,5 +1,5 @@
 import unittest
-from tinygrad.shape.symbolic import Variable
+from tinygrad import Variable
 from tinygrad.helpers import getenv
 from tinygrad.tensor import Tensor
 from examples.gpt2 import Attention
@@ -1,6 +1,7 @@
 import unittest
 from tinygrad.shape.shapetracker import ShapeTracker, View
-from tinygrad.shape.symbolic import Variable, NumNode
+from tinygrad import Variable
+from tinygrad.shape.symbolic import NumNode
 from tinygrad.tensor import Tensor

 class TestSymbolic(unittest.TestCase):
@@ -12,7 +12,6 @@ from tinygrad.engine.schedule import create_schedule, reduceop_fusor
 from tinygrad.engine.realize import CompiledRunner, lower_schedule_item, get_kernel
 from tinygrad.codegen.linearize import linearize_uop
 from tinygrad.codegen.uopgraph import full_graph_rewrite, sym
-from tinygrad.shape.symbolic import Variable
 from test.helpers import is_dtype_supported, assert_equiv_uops

 def to_uops_list(u:List[UOp], opts=None, skip_check=False) -> List[UOp]: return linearize_uop(full_graph_rewrite(UOp.sink(*u), opts), skip_check)
@@ -360,7 +359,7 @@ class TestUOpMethod(unittest.TestCase):
     assert (add < mul) or (mul < add), "add and mul with same src should have an order"

   def test_uop_variables(self):
-    a = Variable("a", 1, 10)
+    a = UOp.variable("a", 1, 10)
     uop_var = UOp.const(dtypes.int, a)
     st_var = UOp(UOps.LOAD, dtypes.float, (UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0),
                  ShapeTracker.from_shape((2, a)).to_uop()))
@@ -1,10 +1,11 @@
 import gzip, unittest
 from PIL import Image
+from tinygrad import Variable
 from tinygrad.helpers import Context, ContextVar
 from tinygrad.helpers import merge_dicts, strip_parens, prod, round_up, fetch, fully_flatten, from_mv, to_mv
 from tinygrad.tensor import get_shape
 from tinygrad.codegen.lowerer import get_contraction
-from tinygrad.shape.symbolic import Variable, NumNode
+from tinygrad.shape.symbolic import NumNode
 import numpy as np

 VARIABLE = ContextVar("VARIABLE", 0)
@@ -4,7 +4,8 @@ import numpy as np
 from tinygrad.dtype import dtypes
 from tinygrad.helpers import prod
 from tinygrad.shape.shapetracker import ShapeTracker, View
-from tinygrad.shape.symbolic import Variable, NumNode
+from tinygrad import Variable
+from tinygrad.shape.symbolic import NumNode
 from tinygrad.ops import UOp, UOps, graph_rewrite
 from tinygrad.codegen.uopgraph import sym
 from itertools import product
@@ -3,7 +3,7 @@ from typing import List
 from tinygrad.helpers import prod
 from tinygrad.shape.view import View
 from tinygrad.shape.shapetracker import ShapeTracker
-from tinygrad.shape.symbolic import Variable
+from tinygrad import Variable
 from test.unit.test_shapetracker import shapetracker_getitem

 class MultiShapeTracker:
@@ -12,7 +12,7 @@ from tinygrad.dtype import dtypes, PtrDType, ConstType
 from tinygrad.codegen.linearize import linearize_uop
 from tinygrad.codegen.uopgraph import full_graph_rewrite
 from tinygrad.ops import BinaryOps, UOp, UOps, print_uops
-from tinygrad.shape.symbolic import Variable
+from tinygrad import Variable
 import functools

 def render(self) -> Tuple[str, ConstType, ConstType]:
@@ -4,7 +4,8 @@ if int(os.getenv("TYPED", "0")):
   install_import_hook(__name__)
 from tinygrad.tensor import Tensor # noqa: F401
 from tinygrad.engine.jit import TinyJit # noqa: F401
-from tinygrad.shape.symbolic import Variable # noqa: F401
+from tinygrad.ops import UOp
+Variable = UOp.variable
 from tinygrad.dtype import dtypes # noqa: F401
 from tinygrad.helpers import GlobalCounters, fetch, Context, getenv # noqa: F401
 from tinygrad.device import Device # noqa: F401
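With this hunk, tinygrad.Variable becomes a plain alias for UOp.variable, so existing call sites keep the same shape but now get back a DEFINE_VAR UOp. A minimal sketch of what that means for user code, using only fields shown elsewhere in this diff:

from tinygrad import Variable
from tinygrad.ops import UOps

v = Variable("i", 1, 10)          # same call shape as the old symbolic.Variable
assert v.op is UOps.DEFINE_VAR    # but it is now just a UOp
assert v.arg == ("i", 1, 10)      # (name, min, max) packed into arg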
@@ -603,7 +603,8 @@ class Kernel:
     # kernel name (before late upcast)
     name = ("r" if self.reduceop is not None else ("C" if all(x.op in BUFFER_UOPS for x in self.ast.parents) else "E")) + \
            (f"{len(self.ast.src)}_" if len(self.ast.src) > 1 else "_") + \
-           colored('_', 'BLACK').join([colored(str(x), c) for x,c in zip(self.full_shape, self.colors())])
+           colored('_', 'BLACK').join([colored(str(x.render() if isinstance(x, UOp) else x), c) \
+             for x,c in zip(self.full_shape, self.colors())])

     # name the function something unique
     Kernel.kernel_cnt[(function_name := to_function_name(name))] += 1
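This is the "improve name rendering" part of the commit: symbolic dimensions in full_shape are now UOps, so the kernel name calls x.render() on them instead of str(x). A rough sketch of the difference, assuming render() of a bare DEFINE_VAR comes back as its short expression string (the exact output is not shown in this diff):

from tinygrad import Variable

i = Variable("i", 1, 10)
# str(i) would embed the full UOp repr in the kernel name,
# while i.render() is assumed to give the short expression, e.g. "i"
print(str(i))
print(i.render())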
@@ -185,6 +185,7 @@ class UOp(MathTrait):
   __slots__ = ["op", "dtype", "src", "arg"]
   def __init__(self, op: UOps, dtype:DType=dtypes.void, src: Tuple[UOp,...]=tuple(), arg:Any=None):
     # TODO: instant check rules here make debugging easier
+    #assert op in UOps and isinstance(dtype, DType), f"bad UOp creation with {op} {dtype}"
     #if op is UOps.ALU and arg is BinaryOps.CMPNE: assert dtype.scalar() == dtypes.bool
     #if op is UOps.VECTORIZE and dtype != dtypes.void: assert len(src) == dtype.count, f"{len(src)} invalid for {dtype}"
     #if op is UOps.ALU and arg not in (BinaryOps.CMPNE, BinaryOps.CMPLT, TernaryOps.WHERE): assert all_same([dtype] + [x.dtype for x in src])
@@ -276,14 +277,35 @@ class UOp(MathTrait):
     if isinstance(b, UOp): return b.unbind()[0] if b.op is UOps.BIND else b
     if isinstance(b, tuple) and all_same(b): b = b[0]  # doesn't have to be a VCONST if they are all the same
     return UOp(UOps.VCONST if isinstance(b, tuple) else UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b) # type: ignore
+
+  # *** Variable stuff ***
+
+  @staticmethod
+  def variable(name:str, min_val:ConstType, max_val:ConstType): return UOp(UOps.DEFINE_VAR, dtypes.int, arg=(name, min_val, max_val))
   @staticmethod
   def define_var(name:str, dtype:DType, min_val:ConstType, max_val:ConstType): return UOp(UOps.DEFINE_VAR, dtype, arg=(name, min_val, max_val))
+  @property
+  def expr(self):
+    assert self.op is UOps.DEFINE_VAR, f"op is {self.op}, need DEFINE_VAR"
+    return self.arg[0]
+  def bind(self, val:int):
+    assert self.op is UOps.DEFINE_VAR, f"op is {self.op}, need DEFINE_VAR"
+    assert self.arg[1] <= val and val <= self.arg[2], f"bind {val} not in range {self.arg[1]}-{self.arg[2]}"
+    return UOp(UOps.BIND, self.dtype, (self, self.const_like(val)))
   def unbind(self) -> Tuple[Variable, int]:
     assert self.op is UOps.BIND and self.src[0].op is UOps.DEFINE_VAR and self.src[1].op is UOps.CONST, f"can't unbind {self}"
-    from tinygrad.shape.symbolic import Variable
-    return cast(Variable, self.src[0]), self.src[1].arg
+    return self.src[0], self.src[1].arg
   @property
   def val(self) -> int: return self.unbind()[1]
+  def vars(self) -> Set[UOp]:
+    bound_vars = set([x for x in self.sparents if x.op is UOps.BIND and x.src[0].op is UOps.DEFINE_VAR])
+    bound_var_base = set(x.src[0] for x in bound_vars)
+    all_vars = set([x for x in self.sparents if x.op is UOps.DEFINE_VAR])
+    return bound_vars.union(set([x for x in all_vars if x not in bound_var_base]))
+  def variables(self) -> List[Variable]:
+    st_vars: List[Set[Variable]] = [x.st_arg.vars() for x in self.sparents if x.op in BUFFER_UOPS]
+    return sorted(set.union(*st_vars, [x.unbind()[0] if x.op is not UOps.DEFINE_VAR else x for x in self.vars()]), key=lambda v: v.arg)
+
   # TODO: this is context rewrite
   def substitute(self, dvars:Dict[UOp, UOp]):
     if self in dvars: return dvars[self]
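Taken together, this hunk is the whole new Variable API on UOp: variable constructs a DEFINE_VAR, expr reads the name back out of arg, bind wraps it in a BIND with a range check, and unbind recovers the pair. A short sketch using only the methods added above:

from tinygrad.ops import UOp, UOps

v = UOp.variable("i", 1, 10)        # UOp(UOps.DEFINE_VAR, dtypes.int, arg=("i", 1, 10))
assert v.expr == "i"                # name comes from arg[0]

b = v.bind(4)                       # BIND node over (DEFINE_VAR, CONST 4); asserts 1 <= 4 <= 10
var, val = b.unbind()               # gives back the DEFINE_VAR and the bound value
assert var.expr == "i" and val == 4 and b.val == 4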
@@ -299,15 +321,6 @@ class UOp(MathTrait):
   @functools.cached_property
   def full_shape(self) -> Tuple[sint, ...]:
     return self.arg.shape if self.op is UOps.VIEW else tuple(smax(x) for x in zip(*[x.full_shape for x in self.src if x.has_st]))
-  def vars(self) -> Set[UOp]:
-    bound_vars = set([x for x in self.sparents if x.op is UOps.BIND and x.src[0].op is UOps.DEFINE_VAR])
-    bound_var_base = set(x.src[0] for x in bound_vars)
-    all_vars = set([x for x in self.sparents if x.op is UOps.DEFINE_VAR])
-    return bound_vars.union(set([x for x in all_vars if x not in bound_var_base]))
-  def variables(self) -> List[Variable]:
-    st_vars: List[Set[Variable]] = [x.st_arg.vars() for x in self.sparents if x.op in BUFFER_UOPS]
-    from tinygrad.shape.symbolic import Variable
-    return sorted(set.union(*st_vars, [x.unbind()[0] if not isinstance(x, Variable) else x for x in self.vars()]), key=lambda v: v.arg)
   def const_factor(self) -> int:
     """largest known int that divides self"""
     if self.op is UOps.CONST: return self.arg
@@ -43,7 +43,7 @@ class Program:
     if not self._ran_post_init and self.uops is not None:
       # single pass through the uops
      for u in self.uops:
-        if u.op is UOps.DEFINE_VAR: self.vars.append(Variable(u.arg[0], u.arg[1], u.arg[2]))
+        if u.op is UOps.DEFINE_VAR: self.vars.append(u)
        if u.op is UOps.DEFINE_GLOBAL: self.globals.append(u.arg)
        if u.op is UOps.STORE: self.outs.extend([x.arg for x in u.src[0].sparents if x.op is UOps.DEFINE_GLOBAL])
        if u.op is UOps.SPECIAL:
@@ -1,28 +1,17 @@
 from __future__ import annotations
-from typing import Union, Optional, Dict, cast
+from typing import Union, Optional, Dict
 from tinygrad.dtype import dtypes
-from tinygrad.ops import UOp, UOps, exec_alu, ConstType
+from tinygrad.ops import UOp, UOps, exec_alu

 sint = Union[int, UOp]

 def NumNode(val:int): return UOp.const(dtypes.int, val)
-class Variable(UOp):
-  def __reduce__(self): return Variable, self.arg
-  def __new__(cls, expr:str, nmin:ConstType, nmax:ConstType): # pylint: disable=signature-differs
-    return super().__new__(cls, UOps.DEFINE_VAR, dtypes.int, arg=(expr, nmin, nmax))
-  def __init__(self, expr:str, nmin:ConstType, nmax:ConstType):
-    super().__init__(UOps.DEFINE_VAR, dtypes.int, arg=(expr, nmin, nmax))
-  def bind(self, val:int):
-    assert self.op is UOps.DEFINE_VAR, f"op is {self.op}"
-    assert self.arg[1] <= val and val <= self.arg[2], f"bind {val} not in range {self.arg[1]}-{self.arg[2]}"
-    return UOp(UOps.BIND, self.dtype, (self, self.const_like(val)))
-  @property
-  def expr(self): return self.arg[0]
+Variable = UOp

 def sym_infer(uop: Union[UOp, int], var_vals: Optional[Dict[Variable, int]]) -> int:
   if isinstance(uop, (int, float)): return uop  # TODO: ugh, the float is a hack for qcom
   if uop.op == UOps.CONST: return uop.arg
-  if uop.op == UOps.DEFINE_VAR and var_vals is not None: return var_vals[cast(Variable, uop)]
+  if uop.op == UOps.DEFINE_VAR and var_vals is not None: return var_vals[uop]
   if uop.op == UOps.BIND: return uop.src[1].arg  # bound variable returns bound value
   if uop.op == UOps.ALU:
     src_values = [sym_infer(src, var_vals) for src in uop.src]
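After this hunk, Variable in tinygrad.shape.symbolic is literally UOp, and sym_infer indexes var_vals by the UOp itself. A small sketch of evaluating a symbolic expression with it; the arithmetic on i relies on UOp's MathTrait operators, which this diff assumes but does not touch:

from tinygrad import Variable
from tinygrad.shape.symbolic import sym_infer

i = Variable("i", 1, 10)
expr = i * 2 + 1                    # an ALU UOp tree over the DEFINE_VAR
print(sym_infer(expr, {i: 4}))      # CONST/DEFINE_VAR/ALU cases -> 9
print(sym_infer(i.bind(3), None))   # BIND returns its bound value -> 3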
@@ -164,7 +164,7 @@ class View:

     # Merge dimensions in vm2 if required.
     # NB: Merging too many dimensions can make it difficult to project vm2's mask, hence only combining when required.
-    idxs: List[UOp] = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(vm1.shape)]
+    idxs: List[UOp] = [UOp.variable(f"idx{i}", 0, s-1) for i,s in enumerate(vm1.shape)]
     merged_size, merged_term = 1, NumNode(0)
     extents: List[Tuple[sint, UOp]] = []
     for term, s, o in zip(reversed(terms), reversed(vm2.shape), reversed(origin)):