diff --git a/docs/abstractions.py b/docs/abstractions.py index f3d52d9c..e527bae6 100644 --- a/docs/abstractions.py +++ b/docs/abstractions.py @@ -266,9 +266,9 @@ from tinygrad.tensor import Tensor result = Tensor(2).realize() + Tensor(3).realize() # use the real Linearizer to linearize 2+3 -from tinygrad.lazy import _replace_loadops +from tinygrad.lazy import _replace_bufferops from tinygrad.codegen.linearizer import Linearizer -op, _ = _replace_loadops(result.lazydata.op) +op, _ = _replace_bufferops(result.lazydata.op) linearizer = Linearizer(op) linearizer.linearize() diff --git a/test/test_linearizer.py b/test/test_linearizer.py index 9da77cdd..b011d867 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -5,7 +5,7 @@ from tinygrad.codegen.linearizer import Linearizer, UOps from tinygrad.ops import Compiled, Device, MovementOps, LazyOp from tinygrad.tensor import Tensor from tinygrad.jit import CacheCollector -from tinygrad.lazy import _replace_loadops +from tinygrad.lazy import _replace_bufferops class TestLinearizer(unittest.TestCase): def test_arg_dedup(self): @@ -31,7 +31,7 @@ class TestLinearizer(unittest.TestCase): r = a[:-1] + a[1:] ast = r.lazydata.op r = r.realize() # realize an output buffer - k = Linearizer(_replace_loadops(ast)[0], Device[Device.DEFAULT].linearizer_opts) + k = Linearizer(_replace_bufferops(ast)[0], Device[Device.DEFAULT].linearizer_opts) k.process() k.upcast() k.linearize() @@ -49,7 +49,7 @@ class TestLinearizer(unittest.TestCase): r = a.expand([2]) + b.expand([2]) ast = r.lazydata.op r = r.realize() # realize an output buffer - k = Linearizer(_replace_loadops(ast)[0], Device[Device.DEFAULT].linearizer_opts) + k = Linearizer(_replace_bufferops(ast)[0], Device[Device.DEFAULT].linearizer_opts) k.process() k.upcast() k.linearize() @@ -64,7 +64,7 @@ class TestLinearizer(unittest.TestCase): r = Tensor.stack([a, b]) ast = r.lazydata.op r = r.realize() # realize an output buffer - k = Linearizer(_replace_loadops(ast)[0], Device[Device.DEFAULT].linearizer_opts) + k = Linearizer(_replace_bufferops(ast)[0], Device[Device.DEFAULT].linearizer_opts) k.process() k.upcast() k.linearize() @@ -80,7 +80,7 @@ class TestLinearizer(unittest.TestCase): r = a * b ast = r.lazydata.op r = r.realize() # realize an output buffer - k = Linearizer(_replace_loadops(ast)[0], Device[Device.DEFAULT].linearizer_opts) + k = Linearizer(_replace_bufferops(ast)[0], Device[Device.DEFAULT].linearizer_opts) k.process() k.linearize() num_ops = len([uop for uop in k.uops if uop.uop in [UOps.LOAD, UOps.ALU]]) diff --git a/test/unit/test_flopcounter.py b/test/unit/test_flopcounter.py index bb9ae400..22f91ee3 100644 --- a/test/unit/test_flopcounter.py +++ b/test/unit/test_flopcounter.py @@ -1,13 +1,13 @@ #!/usr/bin/env python import unittest -from tinygrad.ops import LazyOp, BinaryOps, ReduceOps, get_lazyop_info, LoadOps, MemBuffer +from tinygrad.ops import LazyOp, BinaryOps, ReduceOps, get_lazyop_info, BufferOps, MemBuffer from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.helpers import dtypes class TestFlopCounter(unittest.TestCase): def setUp(self): - self.buf0 = LazyOp(LoadOps.BUFFER, (), MemBuffer(1, dtypes.float32, ShapeTracker.from_shape((4,)))) - self.buf1 = LazyOp(LoadOps.BUFFER, (), MemBuffer(2, dtypes.float32, ShapeTracker.from_shape((4,)))) + self.buf0 = LazyOp(BufferOps.MEM, (), MemBuffer(1, dtypes.float32, ShapeTracker.from_shape((4,)))) + self.buf1 = LazyOp(BufferOps.MEM, (), MemBuffer(2, dtypes.float32, ShapeTracker.from_shape((4,)))) def test_flops_add(self): op0 = LazyOp(BinaryOps.ADD, (self.buf0,self.buf1,), None) diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py index ba63907f..625087fd 100644 --- a/tinygrad/codegen/kernel.py +++ b/tinygrad/codegen/kernel.py @@ -1,6 +1,6 @@ from typing import NamedTuple, Optional, List, Tuple, cast, Dict import itertools -from tinygrad.ops import LazyOp, FlopCounter, get_lazyop_info, ReduceOps, LoadOps, MemBuffer +from tinygrad.ops import LazyOp, FlopCounter, get_lazyop_info, ReduceOps, MemBuffer, BufferOps from tinygrad.helpers import dedup, dtypes, colored, ImageDType, DType, all_int from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.symbolic import sint @@ -43,13 +43,13 @@ class Kernel: self.reduceop = reduceops[0] if reduceops else None # create new shapetrackers inside this kernel, we will permute them - self.bufs = [MemBuffer(0, self.info.dtype, ShapeTracker.from_shape(self.info.shape))] + dedup([x.arg for x in self.ast.get_lazyops() if x.op in LoadOps]) + self.bufs = [MemBuffer(0, self.info.dtype, ShapeTracker.from_shape(self.info.shape))] + dedup([x.arg for x in self.ast.get_lazyops() if x.op in BufferOps]) self.sts: List[ShapeTracker] = [x.st for x in self.bufs] self.mem_estimate: int = sum(x.dtype.itemsize*x.st.size() for x in self.bufs) # get earlybufs, before the one reduce op - self.earlybufs = [x.arg for x in self.reduceop.get_lazyops() if x.op in LoadOps] if self.reduceop else [] + self.earlybufs = [x.arg for x in self.reduceop.get_lazyops() if x.op in BufferOps] if self.reduceop else [] self.full_buf_index: int = self.bufs.index(self.earlybufs[0]) if self.earlybufs else 0 # parameters diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py index 42792edb..f69ed2c8 100644 --- a/tinygrad/codegen/linearizer.py +++ b/tinygrad/codegen/linearizer.py @@ -5,7 +5,7 @@ from collections import defaultdict from enum import Enum, auto from tinygrad.helpers import colored, ImageDType, DEBUG, dtypes, DType, prod, PtrDType, all_same -from tinygrad.ops import LazyOp, UnaryOps, LoadOps, ConstBuffer, MemBuffer +from tinygrad.ops import LazyOp, UnaryOps, ConstBuffer, MemBuffer, BufferOps from tinygrad.ops import ReduceOps, BinaryOps, TernaryOps from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.symbolic import Variable, NumNode, Node, SumNode, MulNode, DivNode, ModNode, LtNode, AndNode, sym_rename @@ -484,7 +484,7 @@ class Linearizer(OptimizedKernel): def ast_parse(self, x, acc, loaded_buffers, do_reduce=False) -> List[UOp]: if x.__class__ is not LazyOp: return loaded_buffers[x] # for LOCAL_BUFFER - if x.op in [LoadOps.BUFFER, LoadOps.CONST]: return loaded_buffers[x.arg] + if x.op in BufferOps: return loaded_buffers[x.arg] if x.op in [UnaryOps.NOOP, UnaryOps.CAST]: return self.ast_parse(x.src[0], acc, loaded_buffers) # cast isn't an ALU op if x.op in ReduceOps and not do_reduce: return acc # MULACC fusion. TODO: this is copied from Interpreted diff --git a/tinygrad/codegen/optimizer.py b/tinygrad/codegen/optimizer.py index 499ce25b..07f498eb 100644 --- a/tinygrad/codegen/optimizer.py +++ b/tinygrad/codegen/optimizer.py @@ -1,7 +1,7 @@ from typing import Tuple, List, cast, Optional import itertools, math, os from tinygrad.helpers import DEBUG, prod, getenv, ImageDType, dtypes -from tinygrad.ops import ReduceOps, BinaryOps, UnaryOps, LazyOp, LoadOps +from tinygrad.ops import ReduceOps, BinaryOps, UnaryOps, LazyOp, BufferOps from tinygrad.codegen.kernel import Kernel, LocalBuffer from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.view import View @@ -172,7 +172,7 @@ class OptimizedKernel(Kernel): if getenv("TC", 1) != 0 and self.opts.device == "HIP" and self.reduceop and self.reduceop.op == ReduceOps.SUM and \ isinstance(self.reduceop.src[0], LazyOp) and self.reduceop.src[0].op == UnaryOps.CAST and \ isinstance(self.reduceop.src[0].src[0], LazyOp) and self.reduceop.src[0].src[0].op == BinaryOps.MUL and \ - self.reduceop.src[0].src[0].src[0].op == LoadOps.BUFFER and self.reduceop.src[0].src[0].src[1].op == LoadOps.BUFFER and self.opts.has_local and \ + self.reduceop.src[0].src[0].src[0].op == BufferOps.MEM and self.reduceop.src[0].src[0].src[1].op == BufferOps.MEM and self.opts.has_local and \ cast(LazyOp, self.reduceop.src[0].src[0].src[0]).arg.dtype == dtypes.half and cast(LazyOp, self.reduceop.src[0].src[0].src[1]).arg.dtype == dtypes.half: # HIP tensor cores are 16x16x16 buf0 = self.bufs.index(cast(LazyOp, self.reduceop.src[0].src[0].src[0]).arg) @@ -232,7 +232,7 @@ class OptimizedKernel(Kernel): tensor_cores_allowed = getenv("TC", 1) != 0 and (getenv("TC", 1) == 2 or (self.opts.device == "METAL" and os.uname().machine == "arm64")) if tensor_cores_allowed and self.reduceop and self.reduceop.op == ReduceOps.SUM and \ isinstance(self.reduceop.src[0], LazyOp) and self.reduceop.src[0].op == BinaryOps.MUL and \ - self.reduceop.src[0].src[0].op == LoadOps.BUFFER and self.reduceop.src[0].src[1].op == LoadOps.BUFFER and self.opts.has_local: + self.reduceop.src[0].src[0].op == BufferOps.MEM and self.reduceop.src[0].src[1].op == BufferOps.MEM and self.opts.has_local: # METAL tensor cores are 8x8x8, with 2 elements per thread in the 32 thread warp buf0 = self.bufs.index(cast(LazyOp, self.reduceop.src[0].src[0]).arg) buf1 = self.bufs.index(cast(LazyOp, self.reduceop.src[0].src[1]).arg) diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py index bdf7d070..36e6ffbf 100644 --- a/tinygrad/lazy.py +++ b/tinygrad/lazy.py @@ -6,7 +6,7 @@ from weakref import ref, WeakSet, WeakValueDictionary import numpy as np from tinygrad.graph import log_op from tinygrad.helpers import GRAPH, DEBUG, prod, getenv, DType, dtypes, flatten, ImageDType, partition, all_int, dedup, merge_dicts -from tinygrad.ops import Device, Compiled, UnaryOps, BinaryOps, TernaryOps, ReduceOps, MovementOps, LoadOps, OpType, LazyOp, MemBuffer, ConstBuffer +from tinygrad.ops import Device, Compiled, UnaryOps, BinaryOps, TernaryOps, ReduceOps, MovementOps, LoadOps, OpType, LazyOp, MemBuffer, ConstBuffer, BufferOps from tinygrad.shape.shapetracker import ShapeTracker, get_contraction from tinygrad.shape.symbolic import Variable, sint @@ -63,15 +63,15 @@ def _ast_binaryops(self:LazyBuffer) -> LazyOp: ast = self.op.map_buffers(cast(Dict[LazyBuffer, Union[LazyOp, LazyBuffer]], real_srcs)) return LazyOp(MovementOps.RESHAPE, (ast, ), self.shape) if intermediate_shape != self.shape else ast -def _replace_loadops(op:LazyOp) -> Tuple[LazyOp, List[LazyBuffer]]: +def _replace_bufferops(op:LazyOp) -> Tuple[LazyOp, List[LazyBuffer]]: replacements:Dict[LazyBuffer, LazyOp] = {} realized_bufs = dedup([x.realized for x in op.buffers if buf_is_kernel_arg(x)]) for x in op.buffers: assert x.realized, "buffer isn't realized" if isinstance(x.realized, RawConst): - replacements[x] = LazyOp(LoadOps.CONST, (), ConstBuffer(x.realized._buf, x.realized.dtype, x.st.simplify())) + replacements[x] = LazyOp(BufferOps.CONST, (), ConstBuffer(x.realized._buf, x.realized.dtype, x.st.simplify())) elif x.realized in realized_bufs: - replacements[x] = LazyOp(LoadOps.BUFFER, (), MemBuffer(realized_bufs.index(x.realized)+1, x.realized.dtype, x.st.simplify())) + replacements[x] = LazyOp(BufferOps.MEM, (), MemBuffer(realized_bufs.index(x.realized)+1, x.realized.dtype, x.st.simplify())) else: raise NotImplementedError(f"not handled {x}") return (op.src[0] if op.op == MovementOps.RESHAPE else op).map_buffers(replacements), realized_bufs @@ -138,8 +138,6 @@ class LazyBuffer: elif self.optype is MovementOps: self.realized = self.op.src[0].realize().realized # run the ast if we still have to, and log the op if not self.realized: - for x in self.op.buffers: x.realize() - # HACK: image shape can be wrong, hot cast it back to a normal float if isinstance(self.dtype, ImageDType) and self.optype != MovementOps and (prod(self.shape) != prod(self.dtype.shape) or not any(self.shape[x]%4 == 0 for x in self.st.unit_stride_axes())): if self.op.op == MovementOps.RESHAPE: @@ -148,8 +146,11 @@ class LazyBuffer: else: self.op = LazyOp(UnaryOps.CAST, (self.op,), (dtypes.float32, False)) self.dtype = dtypes.float32 + + # realize the past and exec the AST + for x in self.op.buffers: x.realize() self.var_vals = dict(sorted(merge_dicts([buf.var_vals for buf in self.op.buffers]).items(), key=lambda kv:cast(Variable,kv[0]).key)) - op, realized_bufs = _replace_loadops(self.op) + op, realized_bufs = _replace_bufferops(self.op) self.realized = Device[self.device].exec_ast(op, output=self, inputs=realized_bufs, var_vals=self.var_vals, **self._device_extra_args()) assert self.realized and isinstance(self.realized, (RawConst, Device[self.device].buffer)), f"device mismatch on realized got {type(self.realized)} expected {self.device}" @@ -216,16 +217,6 @@ class LazyBuffer: return create_lazybuffer(out_device, ShapeTracker.from_shape(out_shape), BinaryOps, LazyOp(op, srcs, arg), out_dtype, self.var_vals) - def shuffle_and_prune_movement_ops(self, st: ShapeTracker, op: MovementOps, arg: Union[Tuple[sint, ...], Tuple[Tuple[sint, sint], ...]]) -> LazyBuffer: - if SHUFFLE_MOVEMENT_OPS and self.optype == BinaryOps and not self.realized and (op in {MovementOps.SHRINK, MovementOps.STRIDE, MovementOps.PERMUTE} or (op == MovementOps.RESHAPE and self.op.op in UnaryOps)) and not self.children: - return self.op.replace_with_movement_ops([(op, arg)]) - if REMOVE_MOVEMENT_NOPS and not self.realized and st.contiguous: - # MovementOps aren't stacked any more, they each have one parent, find the root - root = get_movementroot(self) - if root.st.contiguous and root != self and prod(st.shape) == prod(root.shape): - return root.reshape(st.shape) - return create_lazybuffer(self.device, st, MovementOps, LazyOp(op, (self,), arg), self.dtype, self.var_vals) - def _reduce_op(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[sint, ...]) -> LazyBuffer: if self.shape == tuple(new_shape): return self srcs = _push_movement_ops((self,)) if SHUFFLE_MOVEMENT_OPS else (self,) @@ -238,6 +229,16 @@ class LazyBuffer: def splitted_shape(dim_aft_div): return self.shape[:dim_to_split] + (self.shape[dim_to_split]//divisor,) + dim_aft_div + self.shape[dim_to_split+1:] return self.reshape(splitted_shape((divisor,)))._reduce_op(op, splitted_shape((1,))).reshape(splitted_shape(()))._reduce_op(op, new_shape) + def shuffle_and_prune_movement_ops(self, st: ShapeTracker, op: MovementOps, arg: Union[Tuple[sint, ...], Tuple[Tuple[sint, sint], ...]]) -> LazyBuffer: + if SHUFFLE_MOVEMENT_OPS and self.optype == BinaryOps and not self.realized and (op in {MovementOps.SHRINK, MovementOps.STRIDE, MovementOps.PERMUTE} or (op == MovementOps.RESHAPE and self.op.op in UnaryOps)) and not self.children: + return self.op.replace_with_movement_ops([(op, arg)]) + if REMOVE_MOVEMENT_NOPS and not self.realized and st.contiguous: + # MovementOps aren't stacked any more, they each have one parent, find the root + root = get_movementroot(self) + if root.st.contiguous and root != self and prod(st.shape) == prod(root.shape): + return root.reshape(st.shape) + return create_lazybuffer(self.device, st, MovementOps, LazyOp(op, (self,), arg), self.dtype, self.var_vals) + def reshape(self:LazyBuffer, arg:Tuple[sint, ...]) -> LazyBuffer: if self.shape == arg: return self new_ints, new_nodes = partition(arg, lambda s: isinstance(s, int)) diff --git a/tinygrad/mlops.py b/tinygrad/mlops.py index 0bbb9116..6c24d246 100644 --- a/tinygrad/mlops.py +++ b/tinygrad/mlops.py @@ -84,27 +84,6 @@ class Sigmoid(Function): def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return self.ret.e(BinaryOps.MUL, self.ret.const(1).e(BinaryOps.SUB, self.ret)).e(BinaryOps.MUL, grad_output) -# ************* reduce ops ************* - -class Sum(Function): - def forward(self, x:LazyBuffer, new_shape:Tuple[int, ...]) -> LazyBuffer: - self.input_shape = x.shape - return x.r(ReduceOps.SUM, new_shape) - - def backward(self, grad_output:LazyBuffer) -> LazyBuffer: - return grad_output.expand(self.input_shape) - -class Max(Function): - def forward(self, x:LazyBuffer, new_shape:Tuple[int, ...]) -> LazyBuffer: - self.x, self.ret = x, x.r(ReduceOps.MAX, new_shape) - return self.ret - - def backward(self, grad_output:LazyBuffer) -> LazyBuffer: - # 1s in locations where the max was chosen (can be two locations) - max_is_1s = self.x.const(1.0).e(BinaryOps.SUB, self.x.e(BinaryOps.CMPLT, self.ret.expand(self.x.shape))) - div = max_is_1s.r(ReduceOps.SUM, grad_output.shape).expand(self.x.shape) - return max_is_1s.e(BinaryOps.DIV, div).e(BinaryOps.MUL, grad_output.expand(self.x.shape)) - # ************* binary ops ************* class Less(Function): @@ -157,6 +136,27 @@ class Where(Function): self.x.e(TernaryOps.WHERE, grad_output, grad_output.const(0)) if self.needs_input_grad[1] else None, \ self.x.e(TernaryOps.WHERE, grad_output.const(0), grad_output) if self.needs_input_grad[2] else None +# ************* reduce ops ************* + +class Sum(Function): + def forward(self, x:LazyBuffer, new_shape:Tuple[int, ...]) -> LazyBuffer: + self.input_shape = x.shape + return x.r(ReduceOps.SUM, new_shape) + + def backward(self, grad_output:LazyBuffer) -> LazyBuffer: + return grad_output.expand(self.input_shape) + +class Max(Function): + def forward(self, x:LazyBuffer, new_shape:Tuple[int, ...]) -> LazyBuffer: + self.x, self.ret = x, x.r(ReduceOps.MAX, new_shape) + return self.ret + + def backward(self, grad_output:LazyBuffer) -> LazyBuffer: + # 1s in locations where the max was chosen (can be two locations) + max_is_1s = self.x.const(1.0).e(BinaryOps.SUB, self.x.e(BinaryOps.CMPLT, self.ret.expand(self.x.shape))) + div = max_is_1s.r(ReduceOps.SUM, grad_output.shape).expand(self.x.shape) + return max_is_1s.e(BinaryOps.DIV, div).e(BinaryOps.MUL, grad_output.expand(self.x.shape)) + # ************* movement ops ************* # NOTE: this is sum in reverse diff --git a/tinygrad/ops.py b/tinygrad/ops.py index e29b9ff1..bfb71a93 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -3,9 +3,7 @@ import time, importlib, inspect, functools, pathlib from enum import Enum, auto from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Optional, Dict, Callable, cast, Mapping from tinygrad.helpers import ansilen, prod, DEBUG, getenv, GlobalCounters, DType, colored -from tinygrad.shape.shapetracker import ShapeTracker from dataclasses import dataclass -if TYPE_CHECKING: from tinygrad.lazy import LazyBuffer # these are the llops your accelerator must implement, along with toCpu # the Enum class doesn't work with mypy, this is static. sorry it's ugly @@ -15,11 +13,17 @@ class UnaryOps(Enum): NOOP = auto(); EXP2 = auto(); LOG2 = auto(); CAST = auto() class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto() # noqa: E702 class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702 class TernaryOps(Enum): MULACC = auto(); WHERE = auto() # noqa: E702 +class BufferOps(Enum): MEM = auto(); CONST = auto() # noqa: E702 +# Ops below this line are not allowed in ASTs class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto(); AS_STRIDED = auto() # noqa: E702 -class LoadOps(Enum): EMPTY = auto(); RAND = auto(); CONST = auto(); FROM = auto(); CONTIGUOUS = auto(); CUSTOM = auto(); BUFFER = auto() # noqa: E702 +class LoadOps(Enum): EMPTY = auto(); RAND = auto(); CONST = auto(); FROM = auto(); CONTIGUOUS = auto(); CUSTOM = auto() # noqa: E702 -Op = Union[UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, TernaryOps] -OpType = Union[Type[UnaryOps], Type[BinaryOps], Type[ReduceOps], Type[MovementOps], Type[LoadOps], Type[TernaryOps]] +Op = Union[UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, TernaryOps, BufferOps] +OpType = Union[Type[UnaryOps], Type[BinaryOps], Type[ReduceOps], Type[MovementOps], Type[LoadOps], Type[TernaryOps], Type[BufferOps]] + +if TYPE_CHECKING: + from tinygrad.lazy import LazyBuffer + from tinygrad.shape.shapetracker import ShapeTracker @dataclass(frozen=True) class MemBuffer: @@ -101,28 +105,6 @@ Device = _Device() # **************** for Interpreted Buffers **************** -def apply_shapetracker(fxn_for_op, ret, st): - for v in st.views: - real_shape = tuple(y-x for x,y in v.mask) if v.mask else v.shape - real_offset = v.offset + (sum(x*st for (x,_),st in zip(v.mask, v.strides)) if v.mask else 0) - # first, we apply the offset - # then, we make it the correct shape - # then, we apply permutations - # TODO: don't use as_strided - ret = fxn_for_op[MovementOps.AS_STRIDED](ret, ([s if st != 0 else 1 for s,st in zip(real_shape, v.strides)], v.strides, real_offset)) - # then, we apply pre expand pads - if v.mask is not None: - pre_expand_pads = tuple((x,s-y) if st != 0 else (0,0) for (x,y),s,st in zip(v.mask, v.shape, v.strides)) - post_expand_pads = tuple((x,s-y) if st == 0 else (0,0) for (x,y),s,st in zip(v.mask, v.shape, v.strides)) - if any(x != (0,0) for x in pre_expand_pads): - ret = fxn_for_op[MovementOps.PAD](ret, pre_expand_pads) - real_shape = tuple(x+s[0]+s[1] for x,s in zip(real_shape, pre_expand_pads)) - # then, we do any expands - if any(s != 1 and st == 0 for s,st in zip(real_shape, v.strides)): ret = fxn_for_op[MovementOps.EXPAND](ret, real_shape) - # lastly, we apply post expand pads - if v.mask is not None and any(x != (0,0) for x in post_expand_pads): ret = fxn_for_op[MovementOps.PAD](ret, post_expand_pads) - return ret - class Interpreted: def __init__(self, buffer, fxn_for_op: Dict[Op, Callable], to_underlying=lambda x: x._buf, from_underlying=None): self.buffer, self.fxn_for_op, self.to_underlying = buffer, fxn_for_op, to_underlying @@ -131,9 +113,11 @@ class Interpreted: self.codegen = None def exec_ast(self, ast:LazyOp, output=None, inputs=None, var_vals=None, context=None, **kwargs): - if ast.op == LoadOps.BUFFER and LoadOps.BUFFER not in self.fxn_for_op: + if ast.op == BufferOps.MEM and BufferOps.MEM not in self.fxn_for_op: assert inputs[ast.arg.idx-1].dtype == ast.arg.dtype, "dtype mismatch" - return self.from_underlying(apply_shapetracker(self.fxn_for_op, self.to_underlying(inputs[ast.arg.idx-1]), ast.arg.st)) + buf = self.to_underlying(inputs[ast.arg.idx-1]) + for mop,arg in ast.arg.st.to_movement_ops(): buf = self.fxn_for_op[mop](buf, arg) + return self.from_underlying(buf) if TernaryOps.MULACC in self.fxn_for_op and ast.op == ReduceOps.SUM and isinstance(ast.src[0], LazyOp) and ast.src[0].op == BinaryOps.MUL: ast = LazyOp(TernaryOps.MULACC, ast.src[0].src, ast.arg) created_context = context is None @@ -162,7 +146,7 @@ class FlopCounter: self.flops, ret = 0, self.flops return ret shape_fxn_for_op: Dict[Op, Callable] = { - LoadOps.BUFFER: lambda arg: (arg.st.shape, arg.dtype, 0), LoadOps.CONST: lambda arg: (arg.st.shape, arg.dtype, 0), + BufferOps.MEM: lambda arg: (arg.st.shape, arg.dtype, 0), BufferOps.CONST: lambda arg: (arg.st.shape, arg.dtype, 0), UnaryOps.CAST: lambda self,arg: (self.shape, arg[0], self.consume_flops()), # cast uses no flops **{op:lambda self: (self.shape, self.dtype, self.consume_flops() + prod(self.shape)) for op in UnaryOps if op != UnaryOps.CAST}, **{op:lambda self,y: (self.shape, max(self.dtype, y.dtype), self.consume_flops() + y.consume_flops() + prod(self.shape)) for op in BinaryOps}, @@ -241,7 +225,7 @@ class Compiled: for i,a in enumerate(inputs): # TODO: if this is contiguous it's fine if a == output.realized: - if any(not x.arg.st.contiguous for x in ast.get_lazyops() if x.op == LoadOps.BUFFER and x.arg.idx == i+1): + if any(not x.arg.st.contiguous for x in ast.get_lazyops() if x.op == BufferOps.MEM and x.arg.idx == i+1): output.realized = None break diff --git a/tinygrad/shape/shapetracker.py b/tinygrad/shape/shapetracker.py index d58bf95c..a00b4d2b 100644 --- a/tinygrad/shape/shapetracker.py +++ b/tinygrad/shape/shapetracker.py @@ -3,6 +3,7 @@ from __future__ import annotations import functools from dataclasses import dataclass from typing import Tuple, List, Optional, cast +from tinygrad.ops import MovementOps from tinygrad.helpers import prod, DEBUG from tinygrad.shape.symbolic import Variable, MulNode, NumNode, Node, SumNode, sint from tinygrad.shape.view import View @@ -81,6 +82,29 @@ class ShapeTracker: # this is the real size (ish) def size(self): return self.views[-1].size() + def to_movement_ops(self) -> List[Tuple[MovementOps, Tuple]]: + to_apply:List[Tuple[MovementOps, Tuple]] = [] + for v in self.views: + real_shape = tuple(y-x for x,y in v.mask) if v.mask else v.shape + real_offset = v.offset + (sum(x*st for (x,_),st in zip(v.mask, v.strides)) if v.mask else 0) + # first, we apply the offset + # then, we make it the correct shape + # then, we apply permutations + # TODO: don't use as_strided + to_apply.append((MovementOps.AS_STRIDED, ([s if st != 0 else 1 for s,st in zip(real_shape, v.strides)], v.strides, real_offset))) + # then, we apply pre expand pads + if v.mask is not None: + pre_expand_pads = tuple((x,s-y) if st != 0 else (0,0) for (x,y),s,st in zip(v.mask, v.shape, v.strides)) + post_expand_pads = tuple((x,s-y) if st == 0 else (0,0) for (x,y),s,st in zip(v.mask, v.shape, v.strides)) + if any(x != (0,0) for x in pre_expand_pads): + to_apply.append((MovementOps.PAD, pre_expand_pads)) + real_shape = tuple(x+s[0]+s[1] for x,s in zip(real_shape, pre_expand_pads)) + # then, we do any expands + if any(s != 1 and st == 0 for s,st in zip(real_shape, v.strides)): to_apply.append((MovementOps.EXPAND, real_shape)) + # lastly, we apply post expand pads + if v.mask is not None and any(x != (0,0) for x in post_expand_pads): to_apply.append((MovementOps.PAD, post_expand_pads)) + return to_apply + # these are multiview strides, value is None if it's not a simple strided dimension # TODO: this can be shared code between simplify and merge_views def real_offset(self) -> sint: