reorder a few things (#1915)

* reorder a few things

* huh, that has to be there

* move apply shapetracker

* BufferOps

* only for type checking
George Hotz 2023-09-25 10:17:21 +08:00 committed by GitHub
parent 25a767cd5d
commit c907efbf4a
10 changed files with 96 additions and 87 deletions
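In short, buffer reads inside kernel ASTs now use the new BufferOps enum (MEM for realized inputs, CONST for inlined constants) instead of LoadOps.BUFFER/LoadOps.CONST, _replace_loadops becomes _replace_bufferops, and the interpreted apply_shapetracker helper moves onto ShapeTracker as to_movement_ops(). A minimal sketch of the two new leaf shapes, built only from types that appear in this diff (the index and value are illustrative):

from tinygrad.ops import LazyOp, BufferOps, MemBuffer, ConstBuffer
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.helpers import dtypes

# a realized input buffer: idx is a 1-based index into the kernel's input list
mem = LazyOp(BufferOps.MEM, (), MemBuffer(1, dtypes.float32, ShapeTracker.from_shape((4,))))
# an inlined constant carries its value instead of an index
const = LazyOp(BufferOps.CONST, (), ConstBuffer(2.0, dtypes.float32, ShapeTracker.from_shape((1,))))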


@ -266,9 +266,9 @@ from tinygrad.tensor import Tensor
result = Tensor(2).realize() + Tensor(3).realize()
# use the real Linearizer to linearize 2+3
from tinygrad.lazy import _replace_loadops
from tinygrad.lazy import _replace_bufferops
from tinygrad.codegen.linearizer import Linearizer
op, _ = _replace_loadops(result.lazydata.op)
op, _ = _replace_bufferops(result.lazydata.op)
linearizer = Linearizer(op)
linearizer.linearize()
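As a quick follow-on (illustrative only, not part of the doc change): after linearize() the generated UOps can be inspected directly, which is an easy way to confirm the renamed helper still feeds the Linearizer the BufferOps leaves it expects.

from tinygrad.codegen.linearizer import UOps
print([u for u in linearizer.uops if u.uop is UOps.LOAD])  # the loads produced for the realized inputs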


@ -5,7 +5,7 @@ from tinygrad.codegen.linearizer import Linearizer, UOps
from tinygrad.ops import Compiled, Device, MovementOps, LazyOp
from tinygrad.tensor import Tensor
from tinygrad.jit import CacheCollector
from tinygrad.lazy import _replace_loadops
from tinygrad.lazy import _replace_bufferops
class TestLinearizer(unittest.TestCase):
def test_arg_dedup(self):
@ -31,7 +31,7 @@ class TestLinearizer(unittest.TestCase):
r = a[:-1] + a[1:]
ast = r.lazydata.op
r = r.realize() # realize an output buffer
k = Linearizer(_replace_loadops(ast)[0], Device[Device.DEFAULT].linearizer_opts)
k = Linearizer(_replace_bufferops(ast)[0], Device[Device.DEFAULT].linearizer_opts)
k.process()
k.upcast()
k.linearize()
@ -49,7 +49,7 @@ class TestLinearizer(unittest.TestCase):
r = a.expand([2]) + b.expand([2])
ast = r.lazydata.op
r = r.realize() # realize an output buffer
k = Linearizer(_replace_loadops(ast)[0], Device[Device.DEFAULT].linearizer_opts)
k = Linearizer(_replace_bufferops(ast)[0], Device[Device.DEFAULT].linearizer_opts)
k.process()
k.upcast()
k.linearize()
@ -64,7 +64,7 @@ class TestLinearizer(unittest.TestCase):
r = Tensor.stack([a, b])
ast = r.lazydata.op
r = r.realize() # realize an output buffer
k = Linearizer(_replace_loadops(ast)[0], Device[Device.DEFAULT].linearizer_opts)
k = Linearizer(_replace_bufferops(ast)[0], Device[Device.DEFAULT].linearizer_opts)
k.process()
k.upcast()
k.linearize()
@ -80,7 +80,7 @@ class TestLinearizer(unittest.TestCase):
r = a * b
ast = r.lazydata.op
r = r.realize() # realize an output buffer
k = Linearizer(_replace_loadops(ast)[0], Device[Device.DEFAULT].linearizer_opts)
k = Linearizer(_replace_bufferops(ast)[0], Device[Device.DEFAULT].linearizer_opts)
k.process()
k.linearize()
num_ops = len([uop for uop in k.uops if uop.uop in [UOps.LOAD, UOps.ALU]])


@ -1,13 +1,13 @@
#!/usr/bin/env python
import unittest
from tinygrad.ops import LazyOp, BinaryOps, ReduceOps, get_lazyop_info, LoadOps, MemBuffer
from tinygrad.ops import LazyOp, BinaryOps, ReduceOps, get_lazyop_info, BufferOps, MemBuffer
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.helpers import dtypes
class TestFlopCounter(unittest.TestCase):
def setUp(self):
self.buf0 = LazyOp(LoadOps.BUFFER, (), MemBuffer(1, dtypes.float32, ShapeTracker.from_shape((4,))))
self.buf1 = LazyOp(LoadOps.BUFFER, (), MemBuffer(2, dtypes.float32, ShapeTracker.from_shape((4,))))
self.buf0 = LazyOp(BufferOps.MEM, (), MemBuffer(1, dtypes.float32, ShapeTracker.from_shape((4,))))
self.buf1 = LazyOp(BufferOps.MEM, (), MemBuffer(2, dtypes.float32, ShapeTracker.from_shape((4,))))
def test_flops_add(self):
op0 = LazyOp(BinaryOps.ADD, (self.buf0,self.buf1,), None)
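For context, the flop-counting table in ops.py maps both BufferOps leaves to zero flops, so counting behaves as before on the renamed nodes. A small sketch mirroring this test's setup (the expected numbers assume one add per element):

from tinygrad.ops import LazyOp, BinaryOps, BufferOps, MemBuffer, get_lazyop_info
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.helpers import dtypes

buf0 = LazyOp(BufferOps.MEM, (), MemBuffer(1, dtypes.float32, ShapeTracker.from_shape((4,))))
buf1 = LazyOp(BufferOps.MEM, (), MemBuffer(2, dtypes.float32, ShapeTracker.from_shape((4,))))
info = get_lazyop_info(LazyOp(BinaryOps.ADD, (buf0, buf1), None))
print(info.shape, info.flops)  # expect (4,) and 4: the MEM leaves add no flops, the ADD adds prod(shape)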


@ -1,6 +1,6 @@
from typing import NamedTuple, Optional, List, Tuple, cast, Dict
import itertools
from tinygrad.ops import LazyOp, FlopCounter, get_lazyop_info, ReduceOps, LoadOps, MemBuffer
from tinygrad.ops import LazyOp, FlopCounter, get_lazyop_info, ReduceOps, MemBuffer, BufferOps
from tinygrad.helpers import dedup, dtypes, colored, ImageDType, DType, all_int
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import sint
@ -43,13 +43,13 @@ class Kernel:
self.reduceop = reduceops[0] if reduceops else None
# create new shapetrackers inside this kernel, we will permute them
self.bufs = [MemBuffer(0, self.info.dtype, ShapeTracker.from_shape(self.info.shape))] + dedup([x.arg for x in self.ast.get_lazyops() if x.op in LoadOps])
self.bufs = [MemBuffer(0, self.info.dtype, ShapeTracker.from_shape(self.info.shape))] + dedup([x.arg for x in self.ast.get_lazyops() if x.op in BufferOps])
self.sts: List[ShapeTracker] = [x.st for x in self.bufs]
self.mem_estimate: int = sum(x.dtype.itemsize*x.st.size() for x in self.bufs)
# get earlybufs, before the one reduce op
self.earlybufs = [x.arg for x in self.reduceop.get_lazyops() if x.op in LoadOps] if self.reduceop else []
self.earlybufs = [x.arg for x in self.reduceop.get_lazyops() if x.op in BufferOps] if self.reduceop else []
self.full_buf_index: int = self.bufs.index(self.earlybufs[0]) if self.earlybufs else 0
# parameters


@ -5,7 +5,7 @@ from collections import defaultdict
from enum import Enum, auto
from tinygrad.helpers import colored, ImageDType, DEBUG, dtypes, DType, prod, PtrDType, all_same
from tinygrad.ops import LazyOp, UnaryOps, LoadOps, ConstBuffer, MemBuffer
from tinygrad.ops import LazyOp, UnaryOps, ConstBuffer, MemBuffer, BufferOps
from tinygrad.ops import ReduceOps, BinaryOps, TernaryOps
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import Variable, NumNode, Node, SumNode, MulNode, DivNode, ModNode, LtNode, AndNode, sym_rename
@ -484,7 +484,7 @@ class Linearizer(OptimizedKernel):
def ast_parse(self, x, acc, loaded_buffers, do_reduce=False) -> List[UOp]:
if x.__class__ is not LazyOp: return loaded_buffers[x] # for LOCAL_BUFFER
if x.op in [LoadOps.BUFFER, LoadOps.CONST]: return loaded_buffers[x.arg]
if x.op in BufferOps: return loaded_buffers[x.arg]
if x.op in [UnaryOps.NOOP, UnaryOps.CAST]: return self.ast_parse(x.src[0], acc, loaded_buffers) # cast isn't an ALU op
if x.op in ReduceOps and not do_reduce: return acc
# MULACC fusion. TODO: this is copied from Interpreted


@ -1,7 +1,7 @@
from typing import Tuple, List, cast, Optional
import itertools, math, os
from tinygrad.helpers import DEBUG, prod, getenv, ImageDType, dtypes
from tinygrad.ops import ReduceOps, BinaryOps, UnaryOps, LazyOp, LoadOps
from tinygrad.ops import ReduceOps, BinaryOps, UnaryOps, LazyOp, BufferOps
from tinygrad.codegen.kernel import Kernel, LocalBuffer
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
@ -172,7 +172,7 @@ class OptimizedKernel(Kernel):
if getenv("TC", 1) != 0 and self.opts.device == "HIP" and self.reduceop and self.reduceop.op == ReduceOps.SUM and \
isinstance(self.reduceop.src[0], LazyOp) and self.reduceop.src[0].op == UnaryOps.CAST and \
isinstance(self.reduceop.src[0].src[0], LazyOp) and self.reduceop.src[0].src[0].op == BinaryOps.MUL and \
self.reduceop.src[0].src[0].src[0].op == LoadOps.BUFFER and self.reduceop.src[0].src[0].src[1].op == LoadOps.BUFFER and self.opts.has_local and \
self.reduceop.src[0].src[0].src[0].op == BufferOps.MEM and self.reduceop.src[0].src[0].src[1].op == BufferOps.MEM and self.opts.has_local and \
cast(LazyOp, self.reduceop.src[0].src[0].src[0]).arg.dtype == dtypes.half and cast(LazyOp, self.reduceop.src[0].src[0].src[1]).arg.dtype == dtypes.half:
# HIP tensor cores are 16x16x16
buf0 = self.bufs.index(cast(LazyOp, self.reduceop.src[0].src[0].src[0]).arg)
@ -232,7 +232,7 @@ class OptimizedKernel(Kernel):
tensor_cores_allowed = getenv("TC", 1) != 0 and (getenv("TC", 1) == 2 or (self.opts.device == "METAL" and os.uname().machine == "arm64"))
if tensor_cores_allowed and self.reduceop and self.reduceop.op == ReduceOps.SUM and \
isinstance(self.reduceop.src[0], LazyOp) and self.reduceop.src[0].op == BinaryOps.MUL and \
self.reduceop.src[0].src[0].op == LoadOps.BUFFER and self.reduceop.src[0].src[1].op == LoadOps.BUFFER and self.opts.has_local:
self.reduceop.src[0].src[0].op == BufferOps.MEM and self.reduceop.src[0].src[1].op == BufferOps.MEM and self.opts.has_local:
# METAL tensor cores are 8x8x8, with 2 elements per thread in the 32 thread warp
buf0 = self.bufs.index(cast(LazyOp, self.reduceop.src[0].src[0]).arg)
buf1 = self.bufs.index(cast(LazyOp, self.reduceop.src[0].src[1]).arg)


@ -6,7 +6,7 @@ from weakref import ref, WeakSet, WeakValueDictionary
import numpy as np
from tinygrad.graph import log_op
from tinygrad.helpers import GRAPH, DEBUG, prod, getenv, DType, dtypes, flatten, ImageDType, partition, all_int, dedup, merge_dicts
from tinygrad.ops import Device, Compiled, UnaryOps, BinaryOps, TernaryOps, ReduceOps, MovementOps, LoadOps, OpType, LazyOp, MemBuffer, ConstBuffer
from tinygrad.ops import Device, Compiled, UnaryOps, BinaryOps, TernaryOps, ReduceOps, MovementOps, LoadOps, OpType, LazyOp, MemBuffer, ConstBuffer, BufferOps
from tinygrad.shape.shapetracker import ShapeTracker, get_contraction
from tinygrad.shape.symbolic import Variable, sint
@ -63,15 +63,15 @@ def _ast_binaryops(self:LazyBuffer) -> LazyOp:
ast = self.op.map_buffers(cast(Dict[LazyBuffer, Union[LazyOp, LazyBuffer]], real_srcs))
return LazyOp(MovementOps.RESHAPE, (ast, ), self.shape) if intermediate_shape != self.shape else ast
def _replace_loadops(op:LazyOp) -> Tuple[LazyOp, List[LazyBuffer]]:
def _replace_bufferops(op:LazyOp) -> Tuple[LazyOp, List[LazyBuffer]]:
replacements:Dict[LazyBuffer, LazyOp] = {}
realized_bufs = dedup([x.realized for x in op.buffers if buf_is_kernel_arg(x)])
for x in op.buffers:
assert x.realized, "buffer isn't realized"
if isinstance(x.realized, RawConst):
replacements[x] = LazyOp(LoadOps.CONST, (), ConstBuffer(x.realized._buf, x.realized.dtype, x.st.simplify()))
replacements[x] = LazyOp(BufferOps.CONST, (), ConstBuffer(x.realized._buf, x.realized.dtype, x.st.simplify()))
elif x.realized in realized_bufs:
replacements[x] = LazyOp(LoadOps.BUFFER, (), MemBuffer(realized_bufs.index(x.realized)+1, x.realized.dtype, x.st.simplify()))
replacements[x] = LazyOp(BufferOps.MEM, (), MemBuffer(realized_bufs.index(x.realized)+1, x.realized.dtype, x.st.simplify()))
else:
raise NotImplementedError(f"not handled {x}")
return (op.src[0] if op.op == MovementOps.RESHAPE else op).map_buffers(replacements), realized_bufs
@ -138,8 +138,6 @@ class LazyBuffer:
elif self.optype is MovementOps: self.realized = self.op.src[0].realize().realized
# run the ast if we still have to, and log the op
if not self.realized:
for x in self.op.buffers: x.realize()
# HACK: image shape can be wrong, hot cast it back to a normal float
if isinstance(self.dtype, ImageDType) and self.optype != MovementOps and (prod(self.shape) != prod(self.dtype.shape) or not any(self.shape[x]%4 == 0 for x in self.st.unit_stride_axes())):
if self.op.op == MovementOps.RESHAPE:
@ -148,8 +146,11 @@ class LazyBuffer:
else:
self.op = LazyOp(UnaryOps.CAST, (self.op,), (dtypes.float32, False))
self.dtype = dtypes.float32
# realize the past and exec the AST
for x in self.op.buffers: x.realize()
self.var_vals = dict(sorted(merge_dicts([buf.var_vals for buf in self.op.buffers]).items(), key=lambda kv:cast(Variable,kv[0]).key))
op, realized_bufs = _replace_loadops(self.op)
op, realized_bufs = _replace_bufferops(self.op)
self.realized = Device[self.device].exec_ast(op, output=self, inputs=realized_bufs, var_vals=self.var_vals, **self._device_extra_args())
assert self.realized and isinstance(self.realized, (RawConst, Device[self.device].buffer)), f"device mismatch on realized got {type(self.realized)} expected {self.device}"
@ -216,16 +217,6 @@ class LazyBuffer:
return create_lazybuffer(out_device, ShapeTracker.from_shape(out_shape), BinaryOps, LazyOp(op, srcs, arg), out_dtype, self.var_vals)
def shuffle_and_prune_movement_ops(self, st: ShapeTracker, op: MovementOps, arg: Union[Tuple[sint, ...], Tuple[Tuple[sint, sint], ...]]) -> LazyBuffer:
if SHUFFLE_MOVEMENT_OPS and self.optype == BinaryOps and not self.realized and (op in {MovementOps.SHRINK, MovementOps.STRIDE, MovementOps.PERMUTE} or (op == MovementOps.RESHAPE and self.op.op in UnaryOps)) and not self.children:
return self.op.replace_with_movement_ops([(op, arg)])
if REMOVE_MOVEMENT_NOPS and not self.realized and st.contiguous:
# MovementOps aren't stacked any more, they each have one parent, find the root
root = get_movementroot(self)
if root.st.contiguous and root != self and prod(st.shape) == prod(root.shape):
return root.reshape(st.shape)
return create_lazybuffer(self.device, st, MovementOps, LazyOp(op, (self,), arg), self.dtype, self.var_vals)
def _reduce_op(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[sint, ...]) -> LazyBuffer:
if self.shape == tuple(new_shape): return self
srcs = _push_movement_ops((self,)) if SHUFFLE_MOVEMENT_OPS else (self,)
@ -238,6 +229,16 @@ class LazyBuffer:
def splitted_shape(dim_aft_div): return self.shape[:dim_to_split] + (self.shape[dim_to_split]//divisor,) + dim_aft_div + self.shape[dim_to_split+1:]
return self.reshape(splitted_shape((divisor,)))._reduce_op(op, splitted_shape((1,))).reshape(splitted_shape(()))._reduce_op(op, new_shape)
def shuffle_and_prune_movement_ops(self, st: ShapeTracker, op: MovementOps, arg: Union[Tuple[sint, ...], Tuple[Tuple[sint, sint], ...]]) -> LazyBuffer:
if SHUFFLE_MOVEMENT_OPS and self.optype == BinaryOps and not self.realized and (op in {MovementOps.SHRINK, MovementOps.STRIDE, MovementOps.PERMUTE} or (op == MovementOps.RESHAPE and self.op.op in UnaryOps)) and not self.children:
return self.op.replace_with_movement_ops([(op, arg)])
if REMOVE_MOVEMENT_NOPS and not self.realized and st.contiguous:
# MovementOps aren't stacked any more, they each have one parent, find the root
root = get_movementroot(self)
if root.st.contiguous and root != self and prod(st.shape) == prod(root.shape):
return root.reshape(st.shape)
return create_lazybuffer(self.device, st, MovementOps, LazyOp(op, (self,), arg), self.dtype, self.var_vals)
def reshape(self:LazyBuffer, arg:Tuple[sint, ...]) -> LazyBuffer:
if self.shape == arg: return self
new_ints, new_nodes = partition(arg, lambda s: isinstance(s, int))
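A usage sketch of the renamed helper (not taken from the diff; the tensors are arbitrary): once the inputs are realized, the rewritten AST carries BufferOps leaves whose MemBuffer indices are 1-based positions in the returned list of realized buffers.

from tinygrad.tensor import Tensor
from tinygrad.lazy import _replace_bufferops
from tinygrad.ops import BufferOps

a, b = Tensor.randn(4).realize(), Tensor.randn(4).realize()
ast = (a + b).lazydata.op
op, realized_bufs = _replace_bufferops(ast)
print([x.arg for x in op.get_lazyops() if x.op in BufferOps])  # two MemBuffers, idx 1 and 2
print(len(realized_bufs))                                      # the 2 realized input buffers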


@ -84,27 +84,6 @@ class Sigmoid(Function):
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
return self.ret.e(BinaryOps.MUL, self.ret.const(1).e(BinaryOps.SUB, self.ret)).e(BinaryOps.MUL, grad_output)
# ************* reduce ops *************
class Sum(Function):
def forward(self, x:LazyBuffer, new_shape:Tuple[int, ...]) -> LazyBuffer:
self.input_shape = x.shape
return x.r(ReduceOps.SUM, new_shape)
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
return grad_output.expand(self.input_shape)
class Max(Function):
def forward(self, x:LazyBuffer, new_shape:Tuple[int, ...]) -> LazyBuffer:
self.x, self.ret = x, x.r(ReduceOps.MAX, new_shape)
return self.ret
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
# 1s in locations where the max was chosen (can be two locations)
max_is_1s = self.x.const(1.0).e(BinaryOps.SUB, self.x.e(BinaryOps.CMPLT, self.ret.expand(self.x.shape)))
div = max_is_1s.r(ReduceOps.SUM, grad_output.shape).expand(self.x.shape)
return max_is_1s.e(BinaryOps.DIV, div).e(BinaryOps.MUL, grad_output.expand(self.x.shape))
# ************* binary ops *************
class Less(Function):
@ -157,6 +136,27 @@ class Where(Function):
self.x.e(TernaryOps.WHERE, grad_output, grad_output.const(0)) if self.needs_input_grad[1] else None, \
self.x.e(TernaryOps.WHERE, grad_output.const(0), grad_output) if self.needs_input_grad[2] else None
# ************* reduce ops *************
class Sum(Function):
def forward(self, x:LazyBuffer, new_shape:Tuple[int, ...]) -> LazyBuffer:
self.input_shape = x.shape
return x.r(ReduceOps.SUM, new_shape)
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
return grad_output.expand(self.input_shape)
class Max(Function):
def forward(self, x:LazyBuffer, new_shape:Tuple[int, ...]) -> LazyBuffer:
self.x, self.ret = x, x.r(ReduceOps.MAX, new_shape)
return self.ret
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
# 1s in locations where the max was chosen (can be two locations)
max_is_1s = self.x.const(1.0).e(BinaryOps.SUB, self.x.e(BinaryOps.CMPLT, self.ret.expand(self.x.shape)))
div = max_is_1s.r(ReduceOps.SUM, grad_output.shape).expand(self.x.shape)
return max_is_1s.e(BinaryOps.DIV, div).e(BinaryOps.MUL, grad_output.expand(self.x.shape))
# ************* movement ops *************
# NOTE: this is sum in reverse


@ -3,9 +3,7 @@ import time, importlib, inspect, functools, pathlib
from enum import Enum, auto
from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Optional, Dict, Callable, cast, Mapping
from tinygrad.helpers import ansilen, prod, DEBUG, getenv, GlobalCounters, DType, colored
from tinygrad.shape.shapetracker import ShapeTracker
from dataclasses import dataclass
if TYPE_CHECKING: from tinygrad.lazy import LazyBuffer
# these are the llops your accelerator must implement, along with toCpu
# the Enum class doesn't work with mypy, this is static. sorry it's ugly
@ -15,11 +13,17 @@ class UnaryOps(Enum): NOOP = auto(); EXP2 = auto(); LOG2 = auto(); CAST = auto()
class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto() # noqa: E702
class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702
class TernaryOps(Enum): MULACC = auto(); WHERE = auto() # noqa: E702
class BufferOps(Enum): MEM = auto(); CONST = auto() # noqa: E702
# Ops below this line are not allowed in ASTs
class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto(); AS_STRIDED = auto() # noqa: E702
class LoadOps(Enum): EMPTY = auto(); RAND = auto(); CONST = auto(); FROM = auto(); CONTIGUOUS = auto(); CUSTOM = auto(); BUFFER = auto() # noqa: E702
class LoadOps(Enum): EMPTY = auto(); RAND = auto(); CONST = auto(); FROM = auto(); CONTIGUOUS = auto(); CUSTOM = auto() # noqa: E702
Op = Union[UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, TernaryOps]
OpType = Union[Type[UnaryOps], Type[BinaryOps], Type[ReduceOps], Type[MovementOps], Type[LoadOps], Type[TernaryOps]]
Op = Union[UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, TernaryOps, BufferOps]
OpType = Union[Type[UnaryOps], Type[BinaryOps], Type[ReduceOps], Type[MovementOps], Type[LoadOps], Type[TernaryOps], Type[BufferOps]]
if TYPE_CHECKING:
from tinygrad.lazy import LazyBuffer
from tinygrad.shape.shapetracker import ShapeTracker
@dataclass(frozen=True)
class MemBuffer:
@ -101,28 +105,6 @@ Device = _Device()
# **************** for Interpreted Buffers ****************
def apply_shapetracker(fxn_for_op, ret, st):
for v in st.views:
real_shape = tuple(y-x for x,y in v.mask) if v.mask else v.shape
real_offset = v.offset + (sum(x*st for (x,_),st in zip(v.mask, v.strides)) if v.mask else 0)
# first, we apply the offset
# then, we make it the correct shape
# then, we apply permutations
# TODO: don't use as_strided
ret = fxn_for_op[MovementOps.AS_STRIDED](ret, ([s if st != 0 else 1 for s,st in zip(real_shape, v.strides)], v.strides, real_offset))
# then, we apply pre expand pads
if v.mask is not None:
pre_expand_pads = tuple((x,s-y) if st != 0 else (0,0) for (x,y),s,st in zip(v.mask, v.shape, v.strides))
post_expand_pads = tuple((x,s-y) if st == 0 else (0,0) for (x,y),s,st in zip(v.mask, v.shape, v.strides))
if any(x != (0,0) for x in pre_expand_pads):
ret = fxn_for_op[MovementOps.PAD](ret, pre_expand_pads)
real_shape = tuple(x+s[0]+s[1] for x,s in zip(real_shape, pre_expand_pads))
# then, we do any expands
if any(s != 1 and st == 0 for s,st in zip(real_shape, v.strides)): ret = fxn_for_op[MovementOps.EXPAND](ret, real_shape)
# lastly, we apply post expand pads
if v.mask is not None and any(x != (0,0) for x in post_expand_pads): ret = fxn_for_op[MovementOps.PAD](ret, post_expand_pads)
return ret
class Interpreted:
def __init__(self, buffer, fxn_for_op: Dict[Op, Callable], to_underlying=lambda x: x._buf, from_underlying=None):
self.buffer, self.fxn_for_op, self.to_underlying = buffer, fxn_for_op, to_underlying
@ -131,9 +113,11 @@ class Interpreted:
self.codegen = None
def exec_ast(self, ast:LazyOp, output=None, inputs=None, var_vals=None, context=None, **kwargs):
if ast.op == LoadOps.BUFFER and LoadOps.BUFFER not in self.fxn_for_op:
if ast.op == BufferOps.MEM and BufferOps.MEM not in self.fxn_for_op:
assert inputs[ast.arg.idx-1].dtype == ast.arg.dtype, "dtype mismatch"
return self.from_underlying(apply_shapetracker(self.fxn_for_op, self.to_underlying(inputs[ast.arg.idx-1]), ast.arg.st))
buf = self.to_underlying(inputs[ast.arg.idx-1])
for mop,arg in ast.arg.st.to_movement_ops(): buf = self.fxn_for_op[mop](buf, arg)
return self.from_underlying(buf)
if TernaryOps.MULACC in self.fxn_for_op and ast.op == ReduceOps.SUM and isinstance(ast.src[0], LazyOp) and ast.src[0].op == BinaryOps.MUL:
ast = LazyOp(TernaryOps.MULACC, ast.src[0].src, ast.arg)
created_context = context is None
@ -162,7 +146,7 @@ class FlopCounter:
self.flops, ret = 0, self.flops
return ret
shape_fxn_for_op: Dict[Op, Callable] = {
LoadOps.BUFFER: lambda arg: (arg.st.shape, arg.dtype, 0), LoadOps.CONST: lambda arg: (arg.st.shape, arg.dtype, 0),
BufferOps.MEM: lambda arg: (arg.st.shape, arg.dtype, 0), BufferOps.CONST: lambda arg: (arg.st.shape, arg.dtype, 0),
UnaryOps.CAST: lambda self,arg: (self.shape, arg[0], self.consume_flops()), # cast uses no flops
**{op:lambda self: (self.shape, self.dtype, self.consume_flops() + prod(self.shape)) for op in UnaryOps if op != UnaryOps.CAST},
**{op:lambda self,y: (self.shape, max(self.dtype, y.dtype), self.consume_flops() + y.consume_flops() + prod(self.shape)) for op in BinaryOps},
@ -241,7 +225,7 @@ class Compiled:
for i,a in enumerate(inputs):
# TODO: if this is contiguous it's fine
if a == output.realized:
if any(not x.arg.st.contiguous for x in ast.get_lazyops() if x.op == LoadOps.BUFFER and x.arg.idx == i+1):
if any(not x.arg.st.contiguous for x in ast.get_lazyops() if x.op == BufferOps.MEM and x.arg.idx == i+1):
output.realized = None
break
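The BufferOps.MEM branch above replaces the old apply_shapetracker call: an interpreted backend now just replays st.to_movement_ops() through its fxn_for_op table, so it only has to supply AS_STRIDED, PAD and EXPAND. A hypothetical numpy-backed table of that shape, following the arg conventions of to_movement_ops (a sketch for illustration, not the actual runtime code):

import numpy as np
from tinygrad.ops import MovementOps
from tinygrad.shape.shapetracker import ShapeTracker

numpy_fxn_for_op = {
  # arg = (shape, strides in elements, offset in elements)
  MovementOps.AS_STRIDED: lambda x, arg: np.ndarray(arg[0], dtype=x.dtype, buffer=np.require(x, requirements='C'),
                                                    offset=arg[2]*x.dtype.itemsize,
                                                    strides=tuple(s*x.dtype.itemsize for s in arg[1])),
  MovementOps.PAD: lambda x, arg: np.pad(x, arg),              # arg is ((before, after), ...) per axis
  MovementOps.EXPAND: lambda x, arg: np.broadcast_to(x, arg),  # arg is the target shape
}

buf = np.arange(16, dtype=np.float32).reshape(4, 4)
for mop, arg in ShapeTracker.from_shape((4, 4)).to_movement_ops():
  buf = numpy_fxn_for_op[mop](buf, arg)   # a contiguous view replays as a single AS_STRIDED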


@ -3,6 +3,7 @@ from __future__ import annotations
import functools
from dataclasses import dataclass
from typing import Tuple, List, Optional, cast
from tinygrad.ops import MovementOps
from tinygrad.helpers import prod, DEBUG
from tinygrad.shape.symbolic import Variable, MulNode, NumNode, Node, SumNode, sint
from tinygrad.shape.view import View
@ -81,6 +82,29 @@ class ShapeTracker:
# this is the real size (ish)
def size(self): return self.views[-1].size()
def to_movement_ops(self) -> List[Tuple[MovementOps, Tuple]]:
to_apply:List[Tuple[MovementOps, Tuple]] = []
for v in self.views:
real_shape = tuple(y-x for x,y in v.mask) if v.mask else v.shape
real_offset = v.offset + (sum(x*st for (x,_),st in zip(v.mask, v.strides)) if v.mask else 0)
# first, we apply the offset
# then, we make it the correct shape
# then, we apply permutations
# TODO: don't use as_strided
to_apply.append((MovementOps.AS_STRIDED, ([s if st != 0 else 1 for s,st in zip(real_shape, v.strides)], v.strides, real_offset)))
# then, we apply pre expand pads
if v.mask is not None:
pre_expand_pads = tuple((x,s-y) if st != 0 else (0,0) for (x,y),s,st in zip(v.mask, v.shape, v.strides))
post_expand_pads = tuple((x,s-y) if st == 0 else (0,0) for (x,y),s,st in zip(v.mask, v.shape, v.strides))
if any(x != (0,0) for x in pre_expand_pads):
to_apply.append((MovementOps.PAD, pre_expand_pads))
real_shape = tuple(x+s[0]+s[1] for x,s in zip(real_shape, pre_expand_pads))
# then, we do any expands
if any(s != 1 and st == 0 for s,st in zip(real_shape, v.strides)): to_apply.append((MovementOps.EXPAND, real_shape))
# lastly, we apply post expand pads
if v.mask is not None and any(x != (0,0) for x in post_expand_pads): to_apply.append((MovementOps.PAD, post_expand_pads))
return to_apply
# these are multiview strides, value is None if it's not a simple strided dimension
# TODO: this can be shared code between simplify and merge_views
def real_offset(self) -> sint:
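Since the decomposition above only ever emits AS_STRIDED, PAD and EXPAND, a tracker's replay is easy to inspect. A minimal sketch (assuming ShapeTracker.pad takes per-axis (before, after) pairs and mutates the tracker in place, as elsewhere in the repo):

from tinygrad.shape.shapetracker import ShapeTracker

st = ShapeTracker.from_shape((4,))
st.pad(((1, 1),))                 # assumed in-place pad of one element on each side
print(st.to_movement_ops())
# expect an AS_STRIDED over the 4 underlying elements followed by (MovementOps.PAD, ((1, 1),))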