minor lazy tweak before rewrite (#2573)

chenyu 2023-12-02 18:23:33 -05:00 committed by GitHub
parent fa1d4dd14b
commit 875c34bfc4
1 changed file with 8 additions and 10 deletions

tinygrad/lazy.py

@@ -1,5 +1,5 @@
 from __future__ import annotations
-import sys, operator, math
+import sys, math
 from typing import Callable, Optional, Tuple, Union, List, Dict, Any, cast, Mapping, Set
 from weakref import ref, WeakSet, WeakValueDictionary
@@ -266,9 +266,11 @@ class LazyBuffer:
     return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), ReduceOps, LazyOp(op, srcs, unbound_new_shape), self.dtype)

   def r(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[sint, ...]) -> LazyBuffer:
-    if not all_int(self.shape) or (0 in self.shape) or prod(self.shape) // prod(new_shape) < getenv("REDUCEOP_SPLIT_THRESHOLD", 32768): return self._reduce_op(op, new_shape) # The amount of work should be big enough to take the benefit of "2 kernels" approach.
+    # TODO: can we split symbolic shape if the reduce axis is not symbolic?
+    if not all_int(self.shape) or (0 in self.shape) or prod(self.shape) // prod(new_shape) < getenv("REDUCEOP_SPLIT_THRESHOLD", 32768): return self._reduce_op(op, new_shape)
     heuristic, divisor, dim_to_split = max(((divisor := math.gcd(256, old))/(stride or math.inf), divisor, i) for i, (old, new, stride) in enumerate(zip(self.shape, new_shape, self.st.real_strides())) if old != new) # type: ignore
-    if divisor < 16 or heuristic < 0.1: return self._reduce_op(op, new_shape) # Choose largest divisor (>=16) to split on, penalize large strides.
+    if divisor < 16 or heuristic < 0.1: return self._reduce_op(op, new_shape)
+    # choose largest divisor (>=16) to split on, penalize large strides
     def splitted_shape(dim_aft_div): return self.shape[:dim_to_split] + (self.shape[dim_to_split]//divisor,) + dim_aft_div + self.shape[dim_to_split+1:]
     return self.reshape(splitted_shape((divisor,)))._reduce_op(op, splitted_shape((1,))).reshape(splitted_shape(()))._reduce_op(op, new_shape)
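
For intuition, here is a minimal standalone sketch of the "2 kernels" split that r performs when the reduction is large enough. It uses NumPy, and split_reduce_sum/divisor are illustrative names, not tinygrad API:

import numpy as np

# Sketch of the two-kernel reduce: reshape so most of the reduction runs as a
# first kernel over many independent groups, then a cheap second kernel
# combines the partial results.
def split_reduce_sum(x: np.ndarray, divisor: int = 256) -> np.ndarray:
  assert x.shape[0] % divisor == 0
  partial = x.reshape(x.shape[0] // divisor, divisor).sum(axis=1)  # "kernel 1"
  return partial.sum()                                             # "kernel 2"

x = np.arange(1_048_576, dtype=np.float64)
assert np.isclose(split_reduce_sum(x), x.sum())

The first reduce keeps x.shape[0] // divisor independent groups to parallelize over, and the leftover reduction of the partials is small, which is why the method only splits when the work ratio exceeds REDUCEOP_SPLIT_THRESHOLD.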
@@ -332,7 +334,7 @@ class LazyBuffer:
   def stride(self:LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer:
     if all(a == 1 for a in arg): return self
-    if not self.realized and self.op.op == MovementOps.STRIDE: return self.op.src[0].stride(tuple(map(operator.mul, arg, self.op.arg)))
+    if not self.realized and self.op.op == MovementOps.STRIDE: return self.op.src[0].stride(tuple(a1*a2 for a1,a2 in zip(arg, self.op.arg)))
     return self._movement_op(self.st.stride(arg), MovementOps.STRIDE, arg)

   def replace_with_movement_ops(self: LazyBuffer, ops:List[Tuple[MovementOps, Any]]) -> LazyBuffer:
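
The switch from operator.mul to an inline zip is behavior-preserving (both multiply the two stride tuples elementwise); it exists only to drop the last use of the operator import. A quick plain-Python check of the folding identity this line relies on, with a hypothetical 1-D stride_view helper standing in for MovementOps.STRIDE:

# Two successive stride ops compose into one whose per-axis step is the
# product of the individual steps (shown here in 1-D with list slicing).
def stride_view(xs, step):
  return xs[::step]

xs = list(range(64))
assert stride_view(stride_view(xs, 2), 4) == stride_view(xs, 2 * 4)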
@@ -357,10 +359,6 @@ def _push_movement_ops(srcs:Tuple[LazyBuffer, ...]) -> Tuple[LazyBuffer, ...]:
   return tuple(new_srcs)

 MOVEMENT_OPS_DISPATCHER: Dict[MovementOps, Callable] = {
-  MovementOps.RESHAPE: LazyBuffer.reshape,
-  MovementOps.EXPAND: LazyBuffer.expand,
-  MovementOps.SHRINK: LazyBuffer.shrink,
-  MovementOps.PERMUTE: LazyBuffer.permute,
-  MovementOps.PAD: LazyBuffer.pad,
-  MovementOps.STRIDE: LazyBuffer.stride,
+  MovementOps.RESHAPE: LazyBuffer.reshape, MovementOps.EXPAND: LazyBuffer.expand, MovementOps.SHRINK: LazyBuffer.shrink,
+  MovementOps.PERMUTE: LazyBuffer.permute, MovementOps.PAD: LazyBuffer.pad, MovementOps.STRIDE: LazyBuffer.stride,
 }
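
The dispatcher change is purely cosmetic, but the table itself is the classic dict-of-callables dispatch pattern: each MovementOps value maps to the unbound LazyBuffer method that implements it. A self-contained sketch of the pattern (the Shape class and Ops enum here are illustrative, not tinygrad code):

from enum import Enum, auto
from typing import Callable, Dict, Tuple

class Ops(Enum): RESHAPE = auto(); PERMUTE = auto()

class Shape:
  def __init__(self, dims: Tuple[int, ...]): self.dims = dims
  def reshape(self, arg): return Shape(arg)
  def permute(self, arg): return Shape(tuple(self.dims[i] for i in arg))

# Unbound methods as dict values: dispatch is a lookup plus a call with the
# instance passed explicitly as the first argument.
DISPATCHER: Dict[Ops, Callable] = {Ops.RESHAPE: Shape.reshape, Ops.PERMUTE: Shape.permute}

s = Shape((2, 3, 4))
for op, arg in [(Ops.PERMUTE, (2, 0, 1)), (Ops.RESHAPE, (8, 3))]:
  s = DISPATCHER[op](s, arg)
assert s.dims == (8, 3)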