diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py
index d803648a..6aebccd7 100644
--- a/tinygrad/lazy.py
+++ b/tinygrad/lazy.py
@@ -1,5 +1,5 @@
 from __future__ import annotations
-import sys, operator, math
+import sys, math
 from typing import Callable, Optional, Tuple, Union, List, Dict, Any, cast, Mapping, Set
 from weakref import ref, WeakSet, WeakValueDictionary
 
@@ -266,9 +266,11 @@ class LazyBuffer:
     return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), ReduceOps, LazyOp(op, srcs, unbound_new_shape), self.dtype)
 
   def r(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[sint, ...]) -> LazyBuffer:
-    if not all_int(self.shape) or (0 in self.shape) or prod(self.shape) // prod(new_shape) < getenv("REDUCEOP_SPLIT_THRESHOLD", 32768): return self._reduce_op(op, new_shape) # The amount of work should be big enough to take the benefit of "2 kernels" approach.
+    # TODO: can we split symbolic shape if the reduce axis is not symbolic?
+    if not all_int(self.shape) or (0 in self.shape) or prod(self.shape) // prod(new_shape) < getenv("REDUCEOP_SPLIT_THRESHOLD", 32768): return self._reduce_op(op, new_shape)
     heuristic, divisor, dim_to_split = max(((divisor := math.gcd(256, old))/(stride or math.inf), divisor, i) for i, (old, new, stride) in enumerate(zip(self.shape, new_shape, self.st.real_strides())) if old != new) # type: ignore
-    if divisor < 16 or heuristic < 0.1: return self._reduce_op(op, new_shape) # Choose largest divisor (>=16) to split on, penalize large strides.
+    if divisor < 16 or heuristic < 0.1: return self._reduce_op(op, new_shape)
+    # choose largest divisor (>=16) to split on, penalize large strides
     def splitted_shape(dim_aft_div): return self.shape[:dim_to_split] + (self.shape[dim_to_split]//divisor,) + dim_aft_div + self.shape[dim_to_split+1:]
     return self.reshape(splitted_shape((divisor,)))._reduce_op(op, splitted_shape((1,))).reshape(splitted_shape(()))._reduce_op(op, new_shape)
 
@@ -332,7 +334,7 @@ class LazyBuffer:
 
   def stride(self:LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer:
     if all(a == 1 for a in arg): return self
-    if not self.realized and self.op.op == MovementOps.STRIDE: return self.op.src[0].stride(tuple(map(operator.mul, arg, self.op.arg)))
+    if not self.realized and self.op.op == MovementOps.STRIDE: return self.op.src[0].stride(tuple(a1*a2 for a1,a2 in zip(arg, self.op.arg)))
     return self._movement_op(self.st.stride(arg), MovementOps.STRIDE, arg)
 
   def replace_with_movement_ops(self: LazyBuffer, ops:List[Tuple[MovementOps, Any]]) -> LazyBuffer:
@@ -357,10 +359,6 @@ def _push_movement_ops(srcs:Tuple[LazyBuffer, ...]) -> Tuple[LazyBuffer, ...]:
   return tuple(new_srcs)
 
 MOVEMENT_OPS_DISPATCHER: Dict[MovementOps, Callable] = {
-  MovementOps.RESHAPE: LazyBuffer.reshape,
-  MovementOps.EXPAND: LazyBuffer.expand,
-  MovementOps.SHRINK: LazyBuffer.shrink,
-  MovementOps.PERMUTE: LazyBuffer.permute,
-  MovementOps.PAD: LazyBuffer.pad,
-  MovementOps.STRIDE: LazyBuffer.stride,
+  MovementOps.RESHAPE: LazyBuffer.reshape, MovementOps.EXPAND: LazyBuffer.expand, MovementOps.SHRINK: LazyBuffer.shrink,
+  MovementOps.PERMUTE: LazyBuffer.permute, MovementOps.PAD: LazyBuffer.pad, MovementOps.STRIDE: LazyBuffer.stride,
 }
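
For context on the two-kernel reduce split that `LazyBuffer.r` implements above, here is a standalone NumPy sketch (not part of the patch) of the same heuristic: pick the reduce dimension with the largest divisor of 256 (at least 16), penalize large strides, reduce that chunk in a first pass, then finish the reduction in a second pass. The function name `split_reduce_sum`, the NumPy element-stride computation, and the assumption that `new_shape` has 1s on reduced axes are illustrative choices, not tinygrad API.

```python
import math
import numpy as np

REDUCEOP_SPLIT_THRESHOLD = 32768  # same default as the getenv() call in the patch

def split_reduce_sum(x: np.ndarray, new_shape: tuple) -> np.ndarray:
  # new_shape is assumed to have the same rank as x.shape, with 1s on reduced axes
  axes = tuple(i for i, (o, n) in enumerate(zip(x.shape, new_shape)) if o != n)
  # not enough work to benefit from the two-pass approach: reduce directly
  if math.prod(x.shape) // math.prod(new_shape) < REDUCEOP_SPLIT_THRESHOLD: return x.sum(axis=axes, keepdims=True)
  # pick the reduce dim with the largest divisor of 256, penalizing large strides
  heuristic, divisor, dim = max(
    ((d := math.gcd(256, x.shape[i])) / ((x.strides[i] // x.itemsize) or math.inf), d, i) for i in axes)
  if divisor < 16 or heuristic < 0.1: return x.sum(axis=axes, keepdims=True)
  # first pass: split dim into (dim//divisor, divisor) and reduce only the divisor-sized chunk
  split = x.shape[:dim] + (x.shape[dim] // divisor, divisor) + x.shape[dim+1:]
  partial = x.reshape(split).sum(axis=dim+1)
  # second pass: finish the original reduction on the much smaller partial result
  return partial.sum(axis=axes, keepdims=True)
```

For example, reducing a contiguous (4096, 4096) float array to (1, 1) picks dim 1 with divisor 256, sums 256-wide chunks into a (4096, 16) partial result in the first pass, and sums that in the second pass, mirroring the reshape/_reduce_op/reshape/_reduce_op chain in the patch.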