# tinygrad/lazy.py
from __future__ import annotations
import sys, operator, math, functools
from typing import Callable, Optional, Tuple, Union, List, Dict, Any, cast, Mapping
from weakref import ref, WeakSet, WeakValueDictionary
import numpy as np
from tinygrad.helpers import prod, getenv, DType, dtypes, flatten, dedup, merge_dicts, all_int
from tinygrad.ops import ScheduleItem, UnaryOps, BinaryOps, TernaryOps, ReduceOps, MovementOps, LoadOps, OpType, LazyOp, MemBuffer, ConstBuffer, BufferOps
from tinygrad.shape.shapetracker import ShapeTracker, get_contraction
from tinygrad.shape.symbolic import Variable, sint
from tinygrad.runtime.lib import RawBuffer
from tinygrad.runtime.ops_cpu import RawNumpyBuffer
# lazy can recurse a lot
sys.setrecursionlimit(10000)
OPT = getenv("OPT", 2)
LAZYCACHE = getenv("LAZYCACHE", 1)
# TODO: movement ops that only change shape are really nops. treat them as such
REMOVE_MOVEMENT_NOPS, MERGE_ELEMENTWISE_INTO_REDUCE, SHUFFLE_MOVEMENT_OPS, MERGE_ELEMENTWISE_OPS = OPT>=1, OPT>=1, OPT>=1, OPT>=1
MERGE_ONE_REDUCE_INTO_ELEMENTWISE, SHUFFLE_PAD_OPS = OPT>=2, OPT>=2
PUSH_PERMUTES, PUSH_CONTIGUOUS = OPT>=3, OPT>=3
PUSH_RESHAPES = OPT>=4
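# a quick restatement of the flags above: OPT=0 disables all rewrites, OPT=1 enables the basic merges and
# movement-op shuffling, OPT=2 (the default) adds reduce-into-elementwise merging and pad shuffling, and
# OPT=3/4 enable the more aggressive permute/contiguous/reshape pushing, e.g. run with OPT=3 to try PUSH_PERMUTES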
# **** ast fixing functions ****
def _ast_reduceops(op:LazyOp) -> LazyOp:
# TODO: this can also corealize a binary op after the reduce, not just before
src = op.src[0]
if not src.realized:
assert isinstance(src.op, LazyOp), "if not src.realized, then src.op must be a LazyOp"
if MERGE_ELEMENTWISE_INTO_REDUCE and src.optype is BinaryOps and len(src.children) <= 1: src = src.op
return LazyOp(op.op, (src,), op.arg)
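# e.g. (just an illustrative sketch) with MERGE_ELEMENTWISE_INTO_REDUCE, something like (a*b).sum() keeps the
# unrealized MUL as the SUM's src, giving a single LazyOp tree SUM(MUL(a, b)) that lowers to one kernel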
# this supports late merging an upstream Reduce op and even an Elementwise op above that
def _ast_binaryops(op:LazyOp, shape: Tuple[sint, ...]) -> LazyOp:
real_srcs: Dict[LazyBuffer, Optional[Union[LazyOp, LazyBuffer]]] = {x:None for x in op.buffers}
  # NOTE: contiguous does not always mean the same size with SHRINK. this is still mergeable, but requires more thought on how
# TODO: this can also support late fusion of BinaryOps, required for test_fold_conv_sgd
psrcs: List[Tuple[LazyBuffer, LazyBuffer]] = [(k,x) for k,x in zip(real_srcs.keys(), map(get_movementroot_contiguous, real_srcs.keys())) if x.optype == ReduceOps and not x.realized and prod(k.shape) == prod(x.shape) and len(x.children) <= 1 and len(k.children) <= 1]
intermediate_shape: Tuple[sint, ...] = shape
if MERGE_ONE_REDUCE_INTO_ELEMENTWISE and psrcs:
psrc = psrcs[0] # NOTE: right now we can't handle multiple, as we'd have to check for loop
if psrc[1].optype == ReduceOps:
top = _ast_reduceops(psrc[1].op)
real_srcs[psrc[0]] = top
real_srcs.update({x:x for x in top.buffers}) # the reduce op buffers are not modified
# if the ReduceOp is followed by a reshape, we push this reshape before all the ElementwiseOp inputs
if psrc[0].shape != psrc[1].shape:
intermediate_shape = psrc[1].shape
assert psrc[0].shape == shape, f"shape mismatch {psrc[0].shape} != {shape}"
# reshape all the late ops into the output shape
# NOTE: these RESHAPEs will return self if they don't change the shape
for x in real_srcs.keys():
if real_srcs[x] is None: real_srcs[x] = x.reshape(intermediate_shape)
# NOTE: cast the type to remove the Optional
ast = op.map_buffers(cast(Dict[LazyBuffer, Union[LazyOp, LazyBuffer]], real_srcs))
return LazyOp(MovementOps.RESHAPE, (ast, ), shape) if intermediate_shape != shape else ast
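# e.g. (illustrative sketch) an elementwise op consuming an unrealized reduce, like a.sum(axis) + b, can have
# the SUM inlined here via _ast_reduceops, with the other srcs reshaped to the intermediate shape, so the
# reduce and the add lower to a single kernel (subject to MERGE_ONE_REDUCE_INTO_ELEMENTWISE and the checks above)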
def _replace_bufferops(op:LazyOp) -> Tuple[LazyOp, List[LazyBuffer]]:
replacements:Dict[LazyBuffer, LazyOp] = {}
base_bufs = dedup([x.base for x in op.buffers if not x.is_unrealized_const()])
for x in op.buffers:
st = x.st.simplify().unbind()
if x.base in base_bufs:
replacements[x] = LazyOp(BufferOps.MEM, (), MemBuffer(base_bufs.index(x.base)+1, x.dtype, st))
elif not x.realized and x.base.op.op == LoadOps.CONST:
replacements[x] = LazyOp(BufferOps.CONST, (), ConstBuffer(float(x.base.op.arg), x.dtype, st))
else:
raise NotImplementedError(f"not handled {x}")
return (op.src[0] if op.op == MovementOps.RESHAPE else op).map_buffers(replacements), base_bufs
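# NOTE: after this pass the AST references memory only through BufferOps: MEM with idx 1..N for the deduped
# input bases (idx 0 is reserved for the output, hence the +1) and CONST for unrealized consts, each carrying
# its simplified, unbound ShapeTracker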
# **** lazy operations ****
def get_single_root(root:LazyBuffer) -> LazyBuffer: return get_single_root(cast(LazyBuffer, root.op.src[0])) if getattr(root, 'op', None) and len(root.op.src) == 1 and isinstance(root.op.src[0], LazyBuffer) else root
def get_movementroot(root:LazyBuffer, allow_contiguous=False) -> LazyBuffer: return get_movementroot(cast(LazyBuffer, root.op.src[0]), allow_contiguous) if not root.realized and (root.optype == MovementOps or (root.op.op == LoadOps.CONTIGUOUS and allow_contiguous and root.op.src[0].st.contiguous)) else root
def get_movementroot_contiguous(x:LazyBuffer) -> LazyBuffer: return get_movementroot_contiguous(cast(LazyBuffer, x.op.src[0])) if not x.realized and x.op.op == LoadOps.CONTIGUOUS else (get_movementroot(x, True) if x.optype == MovementOps and x.st.contiguous else x)
def vars_from_ast(ast:LazyOp) -> List[Variable]: return dedup(functools.reduce(operator.add, [x.arg.st.vars() for x in ast.get_lazyops() if x.op in BufferOps], []))
lazycache: WeakValueDictionary = WeakValueDictionary()
def create_lazybuffer(device:str, st:ShapeTracker, optype:OpType, op:LazyOp, dtype:DType, base:Optional[LazyBuffer]=None):
  # EMPTY/RAND/CONST LoadOps aren't cached
if not LAZYCACHE or (optype is LoadOps and op.op in {LoadOps.EMPTY, LoadOps.RAND, LoadOps.CONST}): return LazyBuffer(device, st, optype, op, dtype, base=base)
# wop is the deduping key. i feel this used to compare more deeply
wop = (device, dtype, optype, ref(op), ref(base) if base else None)
if wop in lazycache:
for x in op.buffers: x.children.add(lazycache[wop])
return lazycache[wop]
lazycache[wop] = ret = LazyBuffer(device, st, optype, op, dtype, base=base)
return ret
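# e.g. (a sketch of the intent, not a guarantee) creating the same unrealized op twice over the same srcs may
# dedup to one LazyBuffer here while something still references it; since lazycache holds weakrefs, entries
# vanish once the buffer is otherwise unreachable, so this only saves work within a live expression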
UNSAFE_PAD_OPS = {BinaryOps.DIV, BinaryOps.CMPLT, UnaryOps.LOG2, UnaryOps.EXP2, UnaryOps.RECIP}
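# NOTE: these are ops where f(0) != 0 (or f(0, 0) != 0), e.g. EXP2(0) == 1 and RECIP(0) == inf, so a zero PAD
# cannot be pushed below them (see the check in _push_movement_ops)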
class LazyBuffer:
__deletable__ = ('op',)
def __init__(self, device:str, st:ShapeTracker, optype:OpType, op:Optional[LazyOp], dtype:DType, src:Optional[RawBuffer]=None, base:Optional[LazyBuffer]=None):
self.st: ShapeTracker = st
self.device, self.shape, self.optype, self._dtype = device, self.st.shape, optype, dtype
self._realized: Optional[RawBuffer] = src
self.output_buffer: Optional[RawBuffer] = None # TODO: do we really need this? or can we just use realized
# TODO: does children have to be a ref count instead of a set? can a Buffer be a double child?
self.children: WeakSet = WeakSet()
self.views: WeakSet = WeakSet()
    # NOTE: op should be read-only after construction of LazyBuffer; with the schedule this is now the case
if op is not None:
self.op: LazyOp = op
for x in op.buffers: x.children.add(self)
assert optype != MovementOps or (base is not None and base.optype != MovementOps), "MovementOps must be based"
self._base = base
if base: base.views.add(self)
else: assert st.contiguous, "unbased LazyBuffers must be contiguous"
@property
def base(self): return self._base if self._base is not None else self
def is_unrealized_const(self): return not self.realized and self.base.op.op == LoadOps.CONST
@property
def realized(self): return self.base._realized
@realized.setter
def realized(self, val):
assert self._base is None, "no setting realized of based LazyBuffers"
self._realized = val
@property
def dtype(self): return self.base._dtype
@dtype.setter
def dtype(self, val):
assert self._base is None, "no setting dtype of based LazyBuffers"
self._dtype = val
def __repr__(self): return f"<LB {self.shape} {self.dtype} op={self.op.op if hasattr(self, 'op') else self._realized} st={self.st}>"
@property
def key(self):
if self.realized: return (self.dtype, self.realized.key, self.st)
return (self.dtype, self.op.op, self.st)
def _device_extra_args(self) -> Dict[str, str]: return {"device": self.device.split(":", 1)[1]} if ":" in self.device else {}
@property
def buffers(self) -> Tuple[LazyBuffer, ...]: return (self,)
def map_buffers(self, real_srcs: Mapping[Any, Union[LazyBuffer, LazyOp]]): return real_srcs.get(self, self)
def get_lazyops(self) -> List[LazyOp]: return []
# *** scheduling ***
def schedule(self, seen=None) -> List[ScheduleItem]:
if seen is None: seen = set()
if self in seen or self.realized or self.is_unrealized_const(): return []
seen.add(self)
if self.base != self: return self.base.schedule(seen)
# rewrite unbased CONTIGUOUS into UnaryOps.NOOP
op = self.op if self.op.op != LoadOps.CONTIGUOUS else LazyOp(UnaryOps.NOOP, self.op.src)
if self.optype is BinaryOps: op = _ast_binaryops(op, self.shape)
elif self.optype is ReduceOps: op = _ast_reduceops(op)
# schedule the past
ret = []
for x in op.buffers: ret += x.schedule(seen)
var_vals = dict(sorted(merge_dicts([self.st.var_vals] + [buf.st.var_vals for buf in op.buffers]).items(), key=lambda kv:cast(Variable,kv[0]).key))
# run the ast and log the op
op, base_bufs = _replace_bufferops(op)
return ret + [ScheduleItem(op, self, tuple(base_bufs), {k:var_vals[k] for k in vars_from_ast(op)})]
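  # NOTE: callers (e.g. the realize step) walk the returned list in order; each ScheduleItem carries the ast,
  # the output LazyBuffer, the deduped input base buffers and the bound Variable values it needs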
# *** creation/special ops ***
@staticmethod
def loadop(op, shape, dtype, device, arg=None, src=None) -> LazyBuffer:
return create_lazybuffer(device, ShapeTracker.from_shape(tuple(shape)), LoadOps, LazyOp(op, tuple() if src is None else (src,), arg), dtype)
# create a constant with the shape and dtype of self
def const(self, val:Union[float, int]) -> LazyBuffer:
# NOTE: dtypes.from_np(self.dtype.np) to deal with image types
return self.loadop(LoadOps.CONST, tuple(), dtypes.from_np(self.dtype.np), self.device, arg=val).reshape((1,)*len(self.shape)).expand(self.shape)
def copy_to_device(self, device:str) -> LazyBuffer:
# back off a FROM if it's a double FROM
if not self.realized and self.op.op == LoadOps.FROM and cast(LazyBuffer, self.op.src[0]).device == device: return cast(LazyBuffer, self.op.src[0])
return LazyBuffer.loadop(LoadOps.FROM, self.shape, self.dtype, device, src=self.contiguous())
def contiguous(self:LazyBuffer) -> LazyBuffer:
if not self.realized and self.op.op in LoadOps and self.op.op != LoadOps.CONST: return self # all LoadOps are already contiguous (except CONST)
if self.st.contiguous and self.st.size() == self.base.st.size() and not self.is_unrealized_const():
# this will turn into nothing, it's based and a copy
# TODO: based lazybuffers shouldn't take dtype or var_vals, same issue in movementops
return create_lazybuffer(self.device, ShapeTracker.from_shape(tuple(self.shape)), LoadOps, LazyOp(LoadOps.CONTIGUOUS, (self,), None), self.dtype, base=self.base)
# real contiguous, this will turn into a UnaryOps.NOOP
return self.loadop(LoadOps.CONTIGUOUS, self.shape, self.dtype, self.device, src=self)
@staticmethod
def fromCPU(x: np.ndarray) -> LazyBuffer:
return LazyBuffer("CPU", ShapeTracker.from_shape(x.shape), LoadOps, None, dtypes.from_np(x.dtype), RawNumpyBuffer.fromCPU(x))
def cast(self, dtype:DType, bitcast:bool=False):
return self.e(UnaryOps.CAST, arg=(dtype, bitcast))
# *** elementwise ops ***
def e(self:LazyBuffer, op:Union[UnaryOps, BinaryOps, TernaryOps], *srcs:LazyBuffer, arg:Optional[Any]=None) -> LazyBuffer:
# srcs includes self
srcs = (self,)+srcs
# if we are separated from other binary ops by movement ops, we push those movement ops above those binaryops
if SHUFFLE_MOVEMENT_OPS: srcs = _push_movement_ops(srcs)
# get outputs now
out_device, out_shape, out_dtype = srcs[0].device, srcs[0].shape, max([x.dtype for x in srcs]) if op != UnaryOps.CAST else cast(Tuple[DType, bool], arg)[0]
# push all contiguous to the end of BinaryOps. kernels 198 -> 196
if PUSH_CONTIGUOUS and any(not x.realized and x.op.op == LoadOps.CONTIGUOUS and len(x.op.src[0].children) <= 1 for x in srcs):
new_srcs: List[LazyBuffer] = []
for x in srcs:
if not x.realized and x.op.op == LoadOps.CONTIGUOUS and len(x.op.src[0].children) <= 1:
x.op.src[0].children.discard(x)
new_srcs.append(cast(LazyBuffer, x.op.src[0]))
else:
new_srcs.append(x)
return new_srcs[0].e(op, *new_srcs[1:], arg=arg).contiguous()
if MERGE_ELEMENTWISE_OPS:
# remove the buffers from any (childless) BinaryOps that feed into this
_srcs = tuple([x.op if x.optype == BinaryOps and not x.children and not x.realized else x for x in srcs]) # type: ignore
# TODO: needs general merge limiting
if out_device != "WEBGPU" or len(dedup([x.base for _src in _srcs for x in _src.buffers if not x.is_unrealized_const()])) < 7: srcs = _srcs # type: ignore
return create_lazybuffer(out_device, ShapeTracker.from_shape(out_shape), BinaryOps, LazyOp(op, srcs, arg), out_dtype)
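  # e.g. (illustrative only) with MERGE_ELEMENTWISE_OPS, chaining a.e(BinaryOps.MUL, b).e(BinaryOps.ADD, c)
  # folds the childless MUL into the ADD, giving one LazyOp tree ADD(MUL(a, b), c) that lowers to one kernel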
# *** reduce ops ***
def _reduce_op(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[sint, ...]) -> LazyBuffer:
if self.shape == tuple(new_shape): return self
srcs = _push_movement_ops((self,)) if SHUFFLE_MOVEMENT_OPS else (self,)
unbound_new_shape = tuple(s.unbind()[0] if not isinstance(s, int) else s for s in new_shape)
return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), ReduceOps, LazyOp(op, srcs, unbound_new_shape), self.dtype)
def r(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[sint, ...]) -> LazyBuffer:
    if not all_int(self.shape) or prod(self.shape) // prod(new_shape) < getenv("REDUCEOP_SPLIT_THRESHOLD", 32768): return self._reduce_op(op, new_shape) # the amount of work must be large enough to benefit from the "2 kernels" approach
heuristic, divisor, dim_to_split = max(((divisor := math.gcd(256, old))/(stride or math.inf), divisor, i) for i, (old, new, stride) in enumerate(zip(self.shape, new_shape, self.st.real_strides())) if old != new) # type: ignore
if divisor < 16 or heuristic < 0.1: return self._reduce_op(op, new_shape) # Choose largest divisor (>=16) to split on, penalize large strides.
def splitted_shape(dim_aft_div): return self.shape[:dim_to_split] + (self.shape[dim_to_split]//divisor,) + dim_aft_div + self.shape[dim_to_split+1:]
return self.reshape(splitted_shape((divisor,)))._reduce_op(op, splitted_shape((1,))).reshape(splitted_shape(()))._reduce_op(op, new_shape)
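  # e.g. (worked illustration of the split) summing a contiguous (256, 256) buffer down to (1, 1): the work is
  # over the threshold and the best axis gives divisor 256, so it is reshaped to (256, 1, 256), partially
  # reduced to (256, 1, 1), reshaped to (256, 1), then reduced to (1, 1) -- two smaller kernels instead of one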
# *** movement ops ***
def _movement_op(self, st: ShapeTracker, op: MovementOps, arg: Union[Tuple[sint, ...], Tuple[Tuple[sint, sint], ...]]) -> LazyBuffer:
if SHUFFLE_MOVEMENT_OPS and not self.realized and self.optype == BinaryOps and not self.children:
if op in {MovementOps.SHRINK, MovementOps.STRIDE, MovementOps.PERMUTE} or (op == MovementOps.RESHAPE and (self.op.op in UnaryOps or PUSH_RESHAPES)):
return self.op.replace_with_movement_ops([(op, arg)])
if REMOVE_MOVEMENT_NOPS and not self.realized and st.contiguous:
# MovementOps aren't stacked any more, they each have one parent, find the root
root = get_movementroot(self)
if root.st.contiguous and root != self and prod(st.shape) == prod(root.shape):
return root.reshape(st.shape)
return create_lazybuffer(self.device, st, MovementOps, LazyOp(op, (self,), arg), self.dtype, base=self.base)
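  # e.g. (illustrative sketch) with REMOVE_MOVEMENT_NOPS, a stack of movement ops whose net effect is contiguous
  # again, say a PAD that a later SHRINK exactly undoes, collapses back to a reshape of its movement root
  # instead of keeping the intermediate views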
def reshape(self:LazyBuffer, arg:Tuple[sint, ...]) -> LazyBuffer:
if self.shape == arg: return self
if not self.realized and self.op.op == MovementOps.RESHAPE:
assert isinstance(self.op.src[0], LazyBuffer)
self.op.src[0].children.discard(self) # NOTE: this is only required in reshape and when pushing permutes, why??
return self.op.src[0].reshape(arg)
return self._movement_op(self.st.reshape(arg), MovementOps.RESHAPE, arg)
def pad(self:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
if all(b == 0 and e == 0 for b,e in arg): return self
if not self.realized and self.op.op == MovementOps.PAD: return self.op.src[0].pad(tuple([(b1+b2, e1+e2) for (b1,e1),(b2,e2) in zip(self.op.arg, arg)]))
return self._movement_op(self.st.pad(arg), MovementOps.PAD, arg)
def expand(self: LazyBuffer, arg:Tuple[sint, ...]) -> LazyBuffer:
if self.shape == arg: return self
if not self.realized and self.op.op == MovementOps.EXPAND: return self.op.src[0].expand(arg)
return self._movement_op(self.st.expand(arg), MovementOps.EXPAND, arg)
def permute(self: LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer:
if arg == tuple(range(len(self.shape))): return self
if not self.realized and self.op.op == MovementOps.PERMUTE: return self.op.src[0].permute(tuple([self.op.arg[i] for i in arg]))
if SHUFFLE_MOVEMENT_OPS and not self.realized:
if PUSH_PERMUTES and self.optype == ReduceOps:
# reduceops have one buffer input, permute it
narg = tuple([self.op.arg[a] for a in arg])
src, rop = self.op.src[0], self.op.op
src.children.discard(self)
del self # TODO: why doesn't this delete remove it from the children
return src.permute(arg).r(cast(ReduceOps, rop), narg)
# move permutes before expands (always, this is safe)
if self.op.op == MovementOps.EXPAND:
return self.op.src[0].permute(arg).expand(tuple([self.op.arg[a] for a in arg]))
# move permutes before reshapes if we can
if PUSH_PERMUTES and self.op.op == MovementOps.RESHAPE and isinstance(self.op.src[0], LazyBuffer):
if shape_idx_groups := get_contraction(self.op.src[0].shape, self.shape):
self.op.src[0].children.discard(self) # NOTE: this is only required in reshape and when pushing permutes, why??
return self.op.src[0].permute(tuple(flatten(shape_idx_groups[i] for i in arg))).reshape(self.st.permute(arg).shape)
return self._movement_op(self.st.permute(arg), MovementOps.PERMUTE, arg)
def shrink(self:LazyBuffer, arg:Tuple[Tuple[sint, sint], ...]) -> LazyBuffer:
if all(b - a == s for s, (a, b) in zip(self.shape, arg)): return self
if not self.realized and self.op.op == MovementOps.SHRINK: return self.op.src[0].shrink(tuple([(b1+b2, b1+e2) for (b1,_),(b2,e2) in zip(self.op.arg, arg)]))
return self._movement_op(self.st.shrink(arg), MovementOps.SHRINK, arg)
def stride(self:LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer:
if all(a == 1 for a in arg): return self
if not self.realized and self.op.op == MovementOps.STRIDE: return self.op.src[0].stride(tuple(map(operator.mul, arg, self.op.arg)))
return self._movement_op(self.st.stride(arg), MovementOps.STRIDE, arg)
def replace_with_movement_ops(self: LazyBuffer, ops:List[Tuple[MovementOps, Any]]) -> LazyBuffer:
y = self
for op, arg in ops: y = MOVEMENT_OPS_DISPATCHER[op](y, arg)
return y
def _push_movement_ops(srcs:Tuple[LazyBuffer, ...]) -> Tuple[LazyBuffer, ...]:
new_srcs = []
for x in srcs:
mops: List[Tuple[MovementOps, Any]] = []
bx = x
    # backwalk all the movement ops. don't push EXPAND (or PAD, unless SHUFFLE_PAD_OPS)
while not bx.realized and bx.optype is MovementOps and bx.op.op is not MovementOps.EXPAND and (SHUFFLE_PAD_OPS or bx.op.op is not MovementOps.PAD) and len(bx.children) <= 1:
assert isinstance(bx.op.op, MovementOps)
mops.append((bx.op.op, bx.op.arg))
assert isinstance(bx.op.src[0], LazyBuffer)
bx = bx.op.src[0]
# NOTE: can't push pads past anything where f(0, 0) != 0 or f(0) != 0
if mops and not bx.realized and bx.optype is BinaryOps and len(bx.children) <= 1 and (all(y[0] is not MovementOps.PAD for y in mops) or all(y.op not in UNSAFE_PAD_OPS for y in bx.op.get_lazyops())):
new_srcs.append(bx.op.replace_with_movement_ops(mops[::-1]))
else:
new_srcs.append(x)
return tuple(new_srcs)
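# e.g. (illustrative sketch) with SHUFFLE_MOVEMENT_OPS, using (a + b).permute(p) in another elementwise op
# rewrites the src here as a.permute(p) + b.permute(p): the collected movement ops are replayed on the inputs
# of the childless, unrealized BinaryOp so that the elementwise chain stays fusable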
MOVEMENT_OPS_DISPATCHER: Dict[MovementOps, Callable] = {
MovementOps.RESHAPE: LazyBuffer.reshape,
MovementOps.EXPAND: LazyBuffer.expand,
MovementOps.SHRINK: LazyBuffer.shrink,
MovementOps.PERMUTE: LazyBuffer.permute,
MovementOps.PAD: LazyBuffer.pad,
MovementOps.STRIDE: LazyBuffer.stride,
}