mirror of https://github.com/commaai/tinygrad.git
reorder a few things (#1915)
* reorder a few things
* huh, that has to be there
* move apply shapetracker
* BufferOps
* only for type checking
parent 25a767cd5d
commit c907efbf4a
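In short: buffer access moves out of LoadOps into a new BufferOps enum (MEM and CONST), _replace_loadops becomes _replace_bufferops, the interpreted-backend apply_shapetracker helper moves into ShapeTracker.to_movement_ops, and the LazyBuffer/ShapeTracker imports in ops.py are kept only for type checking. A minimal sketch of the new AST leaf representation after this commit (the buffer indices and const value here are illustrative):

from tinygrad.ops import LazyOp, BufferOps, MemBuffer, ConstBuffer
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.helpers import dtypes

st = ShapeTracker.from_shape((4,))
# before this commit: LazyOp(LoadOps.BUFFER, (), MemBuffer(1, dtypes.float32, st))
buf = LazyOp(BufferOps.MEM, (), MemBuffer(1, dtypes.float32, st))
# before this commit: LazyOp(LoadOps.CONST, (), ConstBuffer(2.0, dtypes.float32, st))
const = LazyOp(BufferOps.CONST, (), ConstBuffer(2.0, dtypes.float32, st))
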
@@ -266,9 +266,9 @@ from tinygrad.tensor import Tensor
 result = Tensor(2).realize() + Tensor(3).realize()

 # use the real Linearizer to linearize 2+3
-from tinygrad.lazy import _replace_loadops
+from tinygrad.lazy import _replace_bufferops
 from tinygrad.codegen.linearizer import Linearizer
-op, _ = _replace_loadops(result.lazydata.op)
+op, _ = _replace_bufferops(result.lazydata.op)
 linearizer = Linearizer(op)
 linearizer.linearize()

@@ -5,7 +5,7 @@ from tinygrad.codegen.linearizer import Linearizer, UOps
 from tinygrad.ops import Compiled, Device, MovementOps, LazyOp
 from tinygrad.tensor import Tensor
 from tinygrad.jit import CacheCollector
-from tinygrad.lazy import _replace_loadops
+from tinygrad.lazy import _replace_bufferops

 class TestLinearizer(unittest.TestCase):
   def test_arg_dedup(self):
@@ -31,7 +31,7 @@ class TestLinearizer(unittest.TestCase):
     r = a[:-1] + a[1:]
     ast = r.lazydata.op
     r = r.realize() # realize an output buffer
-    k = Linearizer(_replace_loadops(ast)[0], Device[Device.DEFAULT].linearizer_opts)
+    k = Linearizer(_replace_bufferops(ast)[0], Device[Device.DEFAULT].linearizer_opts)
     k.process()
     k.upcast()
     k.linearize()
@@ -49,7 +49,7 @@ class TestLinearizer(unittest.TestCase):
     r = a.expand([2]) + b.expand([2])
     ast = r.lazydata.op
     r = r.realize() # realize an output buffer
-    k = Linearizer(_replace_loadops(ast)[0], Device[Device.DEFAULT].linearizer_opts)
+    k = Linearizer(_replace_bufferops(ast)[0], Device[Device.DEFAULT].linearizer_opts)
     k.process()
     k.upcast()
     k.linearize()
@@ -64,7 +64,7 @@ class TestLinearizer(unittest.TestCase):
     r = Tensor.stack([a, b])
     ast = r.lazydata.op
     r = r.realize() # realize an output buffer
-    k = Linearizer(_replace_loadops(ast)[0], Device[Device.DEFAULT].linearizer_opts)
+    k = Linearizer(_replace_bufferops(ast)[0], Device[Device.DEFAULT].linearizer_opts)
     k.process()
     k.upcast()
     k.linearize()
@@ -80,7 +80,7 @@ class TestLinearizer(unittest.TestCase):
     r = a * b
     ast = r.lazydata.op
     r = r.realize() # realize an output buffer
-    k = Linearizer(_replace_loadops(ast)[0], Device[Device.DEFAULT].linearizer_opts)
+    k = Linearizer(_replace_bufferops(ast)[0], Device[Device.DEFAULT].linearizer_opts)
     k.process()
     k.linearize()
     num_ops = len([uop for uop in k.uops if uop.uop in [UOps.LOAD, UOps.ALU]])

@@ -1,13 +1,13 @@
 #!/usr/bin/env python
 import unittest
-from tinygrad.ops import LazyOp, BinaryOps, ReduceOps, get_lazyop_info, LoadOps, MemBuffer
+from tinygrad.ops import LazyOp, BinaryOps, ReduceOps, get_lazyop_info, BufferOps, MemBuffer
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.helpers import dtypes

 class TestFlopCounter(unittest.TestCase):
   def setUp(self):
-    self.buf0 = LazyOp(LoadOps.BUFFER, (), MemBuffer(1, dtypes.float32, ShapeTracker.from_shape((4,))))
-    self.buf1 = LazyOp(LoadOps.BUFFER, (), MemBuffer(2, dtypes.float32, ShapeTracker.from_shape((4,))))
+    self.buf0 = LazyOp(BufferOps.MEM, (), MemBuffer(1, dtypes.float32, ShapeTracker.from_shape((4,))))
+    self.buf1 = LazyOp(BufferOps.MEM, (), MemBuffer(2, dtypes.float32, ShapeTracker.from_shape((4,))))

   def test_flops_add(self):
     op0 = LazyOp(BinaryOps.ADD, (self.buf0,self.buf1,), None)

@@ -1,6 +1,6 @@
 from typing import NamedTuple, Optional, List, Tuple, cast, Dict
 import itertools
-from tinygrad.ops import LazyOp, FlopCounter, get_lazyop_info, ReduceOps, LoadOps, MemBuffer
+from tinygrad.ops import LazyOp, FlopCounter, get_lazyop_info, ReduceOps, MemBuffer, BufferOps
 from tinygrad.helpers import dedup, dtypes, colored, ImageDType, DType, all_int
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.symbolic import sint
@@ -43,13 +43,13 @@ class Kernel:
     self.reduceop = reduceops[0] if reduceops else None

     # create new shapetrackers inside this kernel, we will permute them
-    self.bufs = [MemBuffer(0, self.info.dtype, ShapeTracker.from_shape(self.info.shape))] + dedup([x.arg for x in self.ast.get_lazyops() if x.op in LoadOps])
+    self.bufs = [MemBuffer(0, self.info.dtype, ShapeTracker.from_shape(self.info.shape))] + dedup([x.arg for x in self.ast.get_lazyops() if x.op in BufferOps])
     self.sts: List[ShapeTracker] = [x.st for x in self.bufs]

     self.mem_estimate: int = sum(x.dtype.itemsize*x.st.size() for x in self.bufs)

     # get earlybufs, before the one reduce op
-    self.earlybufs = [x.arg for x in self.reduceop.get_lazyops() if x.op in LoadOps] if self.reduceop else []
+    self.earlybufs = [x.arg for x in self.reduceop.get_lazyops() if x.op in BufferOps] if self.reduceop else []
     self.full_buf_index: int = self.bufs.index(self.earlybufs[0]) if self.earlybufs else 0

     # parameters

@@ -5,7 +5,7 @@ from collections import defaultdict
 from enum import Enum, auto

 from tinygrad.helpers import colored, ImageDType, DEBUG, dtypes, DType, prod, PtrDType, all_same
-from tinygrad.ops import LazyOp, UnaryOps, LoadOps, ConstBuffer, MemBuffer
+from tinygrad.ops import LazyOp, UnaryOps, ConstBuffer, MemBuffer, BufferOps
 from tinygrad.ops import ReduceOps, BinaryOps, TernaryOps
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.symbolic import Variable, NumNode, Node, SumNode, MulNode, DivNode, ModNode, LtNode, AndNode, sym_rename
@@ -484,7 +484,7 @@ class Linearizer(OptimizedKernel):

   def ast_parse(self, x, acc, loaded_buffers, do_reduce=False) -> List[UOp]:
     if x.__class__ is not LazyOp: return loaded_buffers[x] # for LOCAL_BUFFER
-    if x.op in [LoadOps.BUFFER, LoadOps.CONST]: return loaded_buffers[x.arg]
+    if x.op in BufferOps: return loaded_buffers[x.arg]
     if x.op in [UnaryOps.NOOP, UnaryOps.CAST]: return self.ast_parse(x.src[0], acc, loaded_buffers) # cast isn't an ALU op
     if x.op in ReduceOps and not do_reduce: return acc
     # MULACC fusion. TODO: this is copied from Interpreted

@@ -1,7 +1,7 @@
 from typing import Tuple, List, cast, Optional
 import itertools, math, os
 from tinygrad.helpers import DEBUG, prod, getenv, ImageDType, dtypes
-from tinygrad.ops import ReduceOps, BinaryOps, UnaryOps, LazyOp, LoadOps
+from tinygrad.ops import ReduceOps, BinaryOps, UnaryOps, LazyOp, BufferOps
 from tinygrad.codegen.kernel import Kernel, LocalBuffer
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import View
@@ -172,7 +172,7 @@ class OptimizedKernel(Kernel):
     if getenv("TC", 1) != 0 and self.opts.device == "HIP" and self.reduceop and self.reduceop.op == ReduceOps.SUM and \
        isinstance(self.reduceop.src[0], LazyOp) and self.reduceop.src[0].op == UnaryOps.CAST and \
        isinstance(self.reduceop.src[0].src[0], LazyOp) and self.reduceop.src[0].src[0].op == BinaryOps.MUL and \
-       self.reduceop.src[0].src[0].src[0].op == LoadOps.BUFFER and self.reduceop.src[0].src[0].src[1].op == LoadOps.BUFFER and self.opts.has_local and \
+       self.reduceop.src[0].src[0].src[0].op == BufferOps.MEM and self.reduceop.src[0].src[0].src[1].op == BufferOps.MEM and self.opts.has_local and \
        cast(LazyOp, self.reduceop.src[0].src[0].src[0]).arg.dtype == dtypes.half and cast(LazyOp, self.reduceop.src[0].src[0].src[1]).arg.dtype == dtypes.half:
       # HIP tensor cores are 16x16x16
       buf0 = self.bufs.index(cast(LazyOp, self.reduceop.src[0].src[0].src[0]).arg)
@@ -232,7 +232,7 @@ class OptimizedKernel(Kernel):
     tensor_cores_allowed = getenv("TC", 1) != 0 and (getenv("TC", 1) == 2 or (self.opts.device == "METAL" and os.uname().machine == "arm64"))
     if tensor_cores_allowed and self.reduceop and self.reduceop.op == ReduceOps.SUM and \
        isinstance(self.reduceop.src[0], LazyOp) and self.reduceop.src[0].op == BinaryOps.MUL and \
-       self.reduceop.src[0].src[0].op == LoadOps.BUFFER and self.reduceop.src[0].src[1].op == LoadOps.BUFFER and self.opts.has_local:
+       self.reduceop.src[0].src[0].op == BufferOps.MEM and self.reduceop.src[0].src[1].op == BufferOps.MEM and self.opts.has_local:
       # METAL tensor cores are 8x8x8, with 2 elements per thread in the 32 thread warp
       buf0 = self.bufs.index(cast(LazyOp, self.reduceop.src[0].src[0]).arg)
       buf1 = self.bufs.index(cast(LazyOp, self.reduceop.src[0].src[1]).arg)

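For reference, the HIP tensor-core branch above now matches a reduce AST of roughly the following shape. This is a hand-built sketch, not code from the commit; the shapes and the ReduceOps.SUM arg are illustrative:

from tinygrad.ops import LazyOp, BinaryOps, UnaryOps, ReduceOps, BufferOps, MemBuffer
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.helpers import dtypes

st = ShapeTracker.from_shape((16, 16))
a = LazyOp(BufferOps.MEM, (), MemBuffer(1, dtypes.half, st))  # was LoadOps.BUFFER before this commit
b = LazyOp(BufferOps.MEM, (), MemBuffer(2, dtypes.half, st))  # was LoadOps.BUFFER before this commit
mul = LazyOp(BinaryOps.MUL, (a, b), None)
cast_up = LazyOp(UnaryOps.CAST, (mul,), (dtypes.float32, False))
reduceop = LazyOp(ReduceOps.SUM, (cast_up,), (16, 1))
# matches the condition above: reduceop.op == ReduceOps.SUM, src[0].op == UnaryOps.CAST,
# src[0].src[0].op == BinaryOps.MUL, and both MUL sources are BufferOps.MEM half buffers
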
@@ -6,7 +6,7 @@ from weakref import ref, WeakSet, WeakValueDictionary
 import numpy as np
 from tinygrad.graph import log_op
 from tinygrad.helpers import GRAPH, DEBUG, prod, getenv, DType, dtypes, flatten, ImageDType, partition, all_int, dedup, merge_dicts
-from tinygrad.ops import Device, Compiled, UnaryOps, BinaryOps, TernaryOps, ReduceOps, MovementOps, LoadOps, OpType, LazyOp, MemBuffer, ConstBuffer
+from tinygrad.ops import Device, Compiled, UnaryOps, BinaryOps, TernaryOps, ReduceOps, MovementOps, LoadOps, OpType, LazyOp, MemBuffer, ConstBuffer, BufferOps
 from tinygrad.shape.shapetracker import ShapeTracker, get_contraction
 from tinygrad.shape.symbolic import Variable, sint

@@ -63,15 +63,15 @@ def _ast_binaryops(self:LazyBuffer) -> LazyOp:
   ast = self.op.map_buffers(cast(Dict[LazyBuffer, Union[LazyOp, LazyBuffer]], real_srcs))
   return LazyOp(MovementOps.RESHAPE, (ast, ), self.shape) if intermediate_shape != self.shape else ast

-def _replace_loadops(op:LazyOp) -> Tuple[LazyOp, List[LazyBuffer]]:
+def _replace_bufferops(op:LazyOp) -> Tuple[LazyOp, List[LazyBuffer]]:
   replacements:Dict[LazyBuffer, LazyOp] = {}
   realized_bufs = dedup([x.realized for x in op.buffers if buf_is_kernel_arg(x)])
   for x in op.buffers:
     assert x.realized, "buffer isn't realized"
     if isinstance(x.realized, RawConst):
-      replacements[x] = LazyOp(LoadOps.CONST, (), ConstBuffer(x.realized._buf, x.realized.dtype, x.st.simplify()))
+      replacements[x] = LazyOp(BufferOps.CONST, (), ConstBuffer(x.realized._buf, x.realized.dtype, x.st.simplify()))
     elif x.realized in realized_bufs:
-      replacements[x] = LazyOp(LoadOps.BUFFER, (), MemBuffer(realized_bufs.index(x.realized)+1, x.realized.dtype, x.st.simplify()))
+      replacements[x] = LazyOp(BufferOps.MEM, (), MemBuffer(realized_bufs.index(x.realized)+1, x.realized.dtype, x.st.simplify()))
     else:
       raise NotImplementedError(f"not handled {x}")
   return (op.src[0] if op.op == MovementOps.RESHAPE else op).map_buffers(replacements), realized_bufs
@@ -138,8 +138,6 @@ class LazyBuffer:
     elif self.optype is MovementOps: self.realized = self.op.src[0].realize().realized
     # run the ast if we still have to, and log the op
     if not self.realized:
-      for x in self.op.buffers: x.realize()
-
       # HACK: image shape can be wrong, hot cast it back to a normal float
       if isinstance(self.dtype, ImageDType) and self.optype != MovementOps and (prod(self.shape) != prod(self.dtype.shape) or not any(self.shape[x]%4 == 0 for x in self.st.unit_stride_axes())):
         if self.op.op == MovementOps.RESHAPE:
@@ -148,8 +146,11 @@ class LazyBuffer:
         else:
           self.op = LazyOp(UnaryOps.CAST, (self.op,), (dtypes.float32, False))
           self.dtype = dtypes.float32
+
+      # realize the past and exec the AST
+      for x in self.op.buffers: x.realize()
       self.var_vals = dict(sorted(merge_dicts([buf.var_vals for buf in self.op.buffers]).items(), key=lambda kv:cast(Variable,kv[0]).key))
-      op, realized_bufs = _replace_loadops(self.op)
+      op, realized_bufs = _replace_bufferops(self.op)
       self.realized = Device[self.device].exec_ast(op, output=self, inputs=realized_bufs, var_vals=self.var_vals, **self._device_extra_args())

     assert self.realized and isinstance(self.realized, (RawConst, Device[self.device].buffer)), f"device mismatch on realized got {type(self.realized)} expected {self.device}"
@@ -216,16 +217,6 @@ class LazyBuffer:

     return create_lazybuffer(out_device, ShapeTracker.from_shape(out_shape), BinaryOps, LazyOp(op, srcs, arg), out_dtype, self.var_vals)

-  def shuffle_and_prune_movement_ops(self, st: ShapeTracker, op: MovementOps, arg: Union[Tuple[sint, ...], Tuple[Tuple[sint, sint], ...]]) -> LazyBuffer:
-    if SHUFFLE_MOVEMENT_OPS and self.optype == BinaryOps and not self.realized and (op in {MovementOps.SHRINK, MovementOps.STRIDE, MovementOps.PERMUTE} or (op == MovementOps.RESHAPE and self.op.op in UnaryOps)) and not self.children:
-      return self.op.replace_with_movement_ops([(op, arg)])
-    if REMOVE_MOVEMENT_NOPS and not self.realized and st.contiguous:
-      # MovementOps aren't stacked any more, they each have one parent, find the root
-      root = get_movementroot(self)
-      if root.st.contiguous and root != self and prod(st.shape) == prod(root.shape):
-        return root.reshape(st.shape)
-    return create_lazybuffer(self.device, st, MovementOps, LazyOp(op, (self,), arg), self.dtype, self.var_vals)
-
   def _reduce_op(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[sint, ...]) -> LazyBuffer:
     if self.shape == tuple(new_shape): return self
     srcs = _push_movement_ops((self,)) if SHUFFLE_MOVEMENT_OPS else (self,)
@@ -238,6 +229,16 @@ class LazyBuffer:
     def splitted_shape(dim_aft_div): return self.shape[:dim_to_split] + (self.shape[dim_to_split]//divisor,) + dim_aft_div + self.shape[dim_to_split+1:]
     return self.reshape(splitted_shape((divisor,)))._reduce_op(op, splitted_shape((1,))).reshape(splitted_shape(()))._reduce_op(op, new_shape)

+  def shuffle_and_prune_movement_ops(self, st: ShapeTracker, op: MovementOps, arg: Union[Tuple[sint, ...], Tuple[Tuple[sint, sint], ...]]) -> LazyBuffer:
+    if SHUFFLE_MOVEMENT_OPS and self.optype == BinaryOps and not self.realized and (op in {MovementOps.SHRINK, MovementOps.STRIDE, MovementOps.PERMUTE} or (op == MovementOps.RESHAPE and self.op.op in UnaryOps)) and not self.children:
+      return self.op.replace_with_movement_ops([(op, arg)])
+    if REMOVE_MOVEMENT_NOPS and not self.realized and st.contiguous:
+      # MovementOps aren't stacked any more, they each have one parent, find the root
+      root = get_movementroot(self)
+      if root.st.contiguous and root != self and prod(st.shape) == prod(root.shape):
+        return root.reshape(st.shape)
+    return create_lazybuffer(self.device, st, MovementOps, LazyOp(op, (self,), arg), self.dtype, self.var_vals)
+
   def reshape(self:LazyBuffer, arg:Tuple[sint, ...]) -> LazyBuffer:
     if self.shape == arg: return self
     new_ints, new_nodes = partition(arg, lambda s: isinstance(s, int))

@@ -84,27 +84,6 @@ class Sigmoid(Function):
   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
     return self.ret.e(BinaryOps.MUL, self.ret.const(1).e(BinaryOps.SUB, self.ret)).e(BinaryOps.MUL, grad_output)

-# ************* reduce ops *************
-
-class Sum(Function):
-  def forward(self, x:LazyBuffer, new_shape:Tuple[int, ...]) -> LazyBuffer:
-    self.input_shape = x.shape
-    return x.r(ReduceOps.SUM, new_shape)
-
-  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return grad_output.expand(self.input_shape)
-
-class Max(Function):
-  def forward(self, x:LazyBuffer, new_shape:Tuple[int, ...]) -> LazyBuffer:
-    self.x, self.ret = x, x.r(ReduceOps.MAX, new_shape)
-    return self.ret
-
-  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    # 1s in locations where the max was chosen (can be two locations)
-    max_is_1s = self.x.const(1.0).e(BinaryOps.SUB, self.x.e(BinaryOps.CMPLT, self.ret.expand(self.x.shape)))
-    div = max_is_1s.r(ReduceOps.SUM, grad_output.shape).expand(self.x.shape)
-    return max_is_1s.e(BinaryOps.DIV, div).e(BinaryOps.MUL, grad_output.expand(self.x.shape))
-
 # ************* binary ops *************

 class Less(Function):
@@ -157,6 +136,27 @@ class Where(Function):
            self.x.e(TernaryOps.WHERE, grad_output, grad_output.const(0)) if self.needs_input_grad[1] else None, \
            self.x.e(TernaryOps.WHERE, grad_output.const(0), grad_output) if self.needs_input_grad[2] else None

+# ************* reduce ops *************
+
+class Sum(Function):
+  def forward(self, x:LazyBuffer, new_shape:Tuple[int, ...]) -> LazyBuffer:
+    self.input_shape = x.shape
+    return x.r(ReduceOps.SUM, new_shape)
+
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
+    return grad_output.expand(self.input_shape)
+
+class Max(Function):
+  def forward(self, x:LazyBuffer, new_shape:Tuple[int, ...]) -> LazyBuffer:
+    self.x, self.ret = x, x.r(ReduceOps.MAX, new_shape)
+    return self.ret
+
+  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
+    # 1s in locations where the max was chosen (can be two locations)
+    max_is_1s = self.x.const(1.0).e(BinaryOps.SUB, self.x.e(BinaryOps.CMPLT, self.ret.expand(self.x.shape)))
+    div = max_is_1s.r(ReduceOps.SUM, grad_output.shape).expand(self.x.shape)
+    return max_is_1s.e(BinaryOps.DIV, div).e(BinaryOps.MUL, grad_output.expand(self.x.shape))
+
 # ************* movement ops *************

 # NOTE: this is sum in reverse

@@ -3,9 +3,7 @@ import time, importlib, inspect, functools, pathlib
 from enum import Enum, auto
 from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Optional, Dict, Callable, cast, Mapping
 from tinygrad.helpers import ansilen, prod, DEBUG, getenv, GlobalCounters, DType, colored
-from tinygrad.shape.shapetracker import ShapeTracker
 from dataclasses import dataclass
-if TYPE_CHECKING: from tinygrad.lazy import LazyBuffer

 # these are the llops your accelerator must implement, along with toCpu
 # the Enum class doesn't work with mypy, this is static. sorry it's ugly
@@ -15,11 +13,17 @@ class UnaryOps(Enum): NOOP = auto(); EXP2 = auto(); LOG2 = auto(); CAST = auto()
 class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto() # noqa: E702
 class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702
 class TernaryOps(Enum): MULACC = auto(); WHERE = auto() # noqa: E702
+class BufferOps(Enum): MEM = auto(); CONST = auto() # noqa: E702
+# Ops below this line are not allowed in ASTs
 class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto(); AS_STRIDED = auto() # noqa: E702
-class LoadOps(Enum): EMPTY = auto(); RAND = auto(); CONST = auto(); FROM = auto(); CONTIGUOUS = auto(); CUSTOM = auto(); BUFFER = auto() # noqa: E702
+class LoadOps(Enum): EMPTY = auto(); RAND = auto(); CONST = auto(); FROM = auto(); CONTIGUOUS = auto(); CUSTOM = auto() # noqa: E702

-Op = Union[UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, TernaryOps]
-OpType = Union[Type[UnaryOps], Type[BinaryOps], Type[ReduceOps], Type[MovementOps], Type[LoadOps], Type[TernaryOps]]
+Op = Union[UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, TernaryOps, BufferOps]
+OpType = Union[Type[UnaryOps], Type[BinaryOps], Type[ReduceOps], Type[MovementOps], Type[LoadOps], Type[TernaryOps], Type[BufferOps]]

+if TYPE_CHECKING:
+  from tinygrad.lazy import LazyBuffer
+  from tinygrad.shape.shapetracker import ShapeTracker
+
 @dataclass(frozen=True)
 class MemBuffer:
@@ -101,28 +105,6 @@ Device = _Device()

 # **************** for Interpreted Buffers ****************

-def apply_shapetracker(fxn_for_op, ret, st):
-  for v in st.views:
-    real_shape = tuple(y-x for x,y in v.mask) if v.mask else v.shape
-    real_offset = v.offset + (sum(x*st for (x,_),st in zip(v.mask, v.strides)) if v.mask else 0)
-    # first, we apply the offset
-    # then, we make it the correct shape
-    # then, we apply permutations
-    # TODO: don't use as_strided
-    ret = fxn_for_op[MovementOps.AS_STRIDED](ret, ([s if st != 0 else 1 for s,st in zip(real_shape, v.strides)], v.strides, real_offset))
-    # then, we apply pre expand pads
-    if v.mask is not None:
-      pre_expand_pads = tuple((x,s-y) if st != 0 else (0,0) for (x,y),s,st in zip(v.mask, v.shape, v.strides))
-      post_expand_pads = tuple((x,s-y) if st == 0 else (0,0) for (x,y),s,st in zip(v.mask, v.shape, v.strides))
-      if any(x != (0,0) for x in pre_expand_pads):
-        ret = fxn_for_op[MovementOps.PAD](ret, pre_expand_pads)
-        real_shape = tuple(x+s[0]+s[1] for x,s in zip(real_shape, pre_expand_pads))
-    # then, we do any expands
-    if any(s != 1 and st == 0 for s,st in zip(real_shape, v.strides)): ret = fxn_for_op[MovementOps.EXPAND](ret, real_shape)
-    # lastly, we apply post expand pads
-    if v.mask is not None and any(x != (0,0) for x in post_expand_pads): ret = fxn_for_op[MovementOps.PAD](ret, post_expand_pads)
-  return ret
-
 class Interpreted:
   def __init__(self, buffer, fxn_for_op: Dict[Op, Callable], to_underlying=lambda x: x._buf, from_underlying=None):
     self.buffer, self.fxn_for_op, self.to_underlying = buffer, fxn_for_op, to_underlying
@@ -131,9 +113,11 @@ class Interpreted:
     self.codegen = None

   def exec_ast(self, ast:LazyOp, output=None, inputs=None, var_vals=None, context=None, **kwargs):
-    if ast.op == LoadOps.BUFFER and LoadOps.BUFFER not in self.fxn_for_op:
+    if ast.op == BufferOps.MEM and BufferOps.MEM not in self.fxn_for_op:
       assert inputs[ast.arg.idx-1].dtype == ast.arg.dtype, "dtype mismatch"
-      return self.from_underlying(apply_shapetracker(self.fxn_for_op, self.to_underlying(inputs[ast.arg.idx-1]), ast.arg.st))
+      buf = self.to_underlying(inputs[ast.arg.idx-1])
+      for mop,arg in ast.arg.st.to_movement_ops(): buf = self.fxn_for_op[mop](buf, arg)
+      return self.from_underlying(buf)
     if TernaryOps.MULACC in self.fxn_for_op and ast.op == ReduceOps.SUM and isinstance(ast.src[0], LazyOp) and ast.src[0].op == BinaryOps.MUL:
       ast = LazyOp(TernaryOps.MULACC, ast.src[0].src, ast.arg)
     created_context = context is None
@@ -162,7 +146,7 @@ class FlopCounter:
     self.flops, ret = 0, self.flops
     return ret
 shape_fxn_for_op: Dict[Op, Callable] = {
-  LoadOps.BUFFER: lambda arg: (arg.st.shape, arg.dtype, 0), LoadOps.CONST: lambda arg: (arg.st.shape, arg.dtype, 0),
+  BufferOps.MEM: lambda arg: (arg.st.shape, arg.dtype, 0), BufferOps.CONST: lambda arg: (arg.st.shape, arg.dtype, 0),
   UnaryOps.CAST: lambda self,arg: (self.shape, arg[0], self.consume_flops()), # cast uses no flops
   **{op:lambda self: (self.shape, self.dtype, self.consume_flops() + prod(self.shape)) for op in UnaryOps if op != UnaryOps.CAST},
   **{op:lambda self,y: (self.shape, max(self.dtype, y.dtype), self.consume_flops() + y.consume_flops() + prod(self.shape)) for op in BinaryOps},
@@ -241,7 +225,7 @@ class Compiled:
     for i,a in enumerate(inputs):
       # TODO: if this is contiguous it's fine
      if a == output.realized:
-        if any(not x.arg.st.contiguous for x in ast.get_lazyops() if x.op == LoadOps.BUFFER and x.arg.idx == i+1):
+        if any(not x.arg.st.contiguous for x in ast.get_lazyops() if x.op == BufferOps.MEM and x.arg.idx == i+1):
          output.realized = None
          break

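To illustrate the new Interpreted path above: when exec_ast hits a BufferOps.MEM leaf, it now replays the buffer's ShapeTracker as a list of MovementOps instead of calling apply_shapetracker. A rough numpy sketch of that loop follows; the fxn_for_op table here is illustrative (real backends define their own), and the AS_STRIDED offset is ignored for simplicity:

import numpy as np
from tinygrad.ops import MovementOps
from tinygrad.shape.shapetracker import ShapeTracker

# illustrative numpy implementations of the MovementOps that to_movement_ops() can emit
fxn_for_op = {
  MovementOps.AS_STRIDED: lambda x, arg: np.lib.stride_tricks.as_strided(
    x, arg[0], [s*x.dtype.itemsize for s in arg[1]]),  # arg = (shape, strides, offset); offset ignored in this sketch
  MovementOps.EXPAND: lambda x, arg: np.broadcast_to(x, arg),
  MovementOps.PAD: lambda x, arg: np.pad(x, arg),
}

buf = np.arange(16, dtype=np.float32).reshape(4, 4)
st = ShapeTracker.from_shape((4, 4))       # contiguous, so only a single AS_STRIDED is emitted
for mop, arg in st.to_movement_ops():
  buf = fxn_for_op[mop](buf, arg)
print(buf.shape)  # (4, 4)
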
@@ -3,6 +3,7 @@ from __future__ import annotations
 import functools
 from dataclasses import dataclass
 from typing import Tuple, List, Optional, cast
+from tinygrad.ops import MovementOps
 from tinygrad.helpers import prod, DEBUG
 from tinygrad.shape.symbolic import Variable, MulNode, NumNode, Node, SumNode, sint
 from tinygrad.shape.view import View
@@ -81,6 +82,29 @@ class ShapeTracker:
   # this is the real size (ish)
   def size(self): return self.views[-1].size()

+  def to_movement_ops(self) -> List[Tuple[MovementOps, Tuple]]:
+    to_apply:List[Tuple[MovementOps, Tuple]] = []
+    for v in self.views:
+      real_shape = tuple(y-x for x,y in v.mask) if v.mask else v.shape
+      real_offset = v.offset + (sum(x*st for (x,_),st in zip(v.mask, v.strides)) if v.mask else 0)
+      # first, we apply the offset
+      # then, we make it the correct shape
+      # then, we apply permutations
+      # TODO: don't use as_strided
+      to_apply.append((MovementOps.AS_STRIDED, ([s if st != 0 else 1 for s,st in zip(real_shape, v.strides)], v.strides, real_offset)))
+      # then, we apply pre expand pads
+      if v.mask is not None:
+        pre_expand_pads = tuple((x,s-y) if st != 0 else (0,0) for (x,y),s,st in zip(v.mask, v.shape, v.strides))
+        post_expand_pads = tuple((x,s-y) if st == 0 else (0,0) for (x,y),s,st in zip(v.mask, v.shape, v.strides))
+        if any(x != (0,0) for x in pre_expand_pads):
+          to_apply.append((MovementOps.PAD, pre_expand_pads))
+          real_shape = tuple(x+s[0]+s[1] for x,s in zip(real_shape, pre_expand_pads))
+      # then, we do any expands
+      if any(s != 1 and st == 0 for s,st in zip(real_shape, v.strides)): to_apply.append((MovementOps.EXPAND, real_shape))
+      # lastly, we apply post expand pads
+      if v.mask is not None and any(x != (0,0) for x in post_expand_pads): to_apply.append((MovementOps.PAD, post_expand_pads))
+    return to_apply
+
   # these are multiview strides, value is None if it's not a simple strided dimension
   # TODO: this can be shared code between simplify and merge_views
   def real_offset(self) -> sint:
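As a quick check of the method added above, a trivially contiguous tracker should collapse to a single AS_STRIDED entry (expected output shown as a comment, derived by hand from the code above):

from tinygrad.shape.shapetracker import ShapeTracker

st = ShapeTracker.from_shape((4, 4))
print(st.to_movement_ops())
# roughly: [(MovementOps.AS_STRIDED, ([4, 4], (4, 1), 0))]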