mirror of https://github.com/commaai/tinygrad.git
minor lazy tweak before rewrite (#2573)
This commit is contained in:
parent
fa1d4dd14b
commit
875c34bfc4
|
@ -1,5 +1,5 @@
|
|||
from __future__ import annotations
|
||||
import sys, operator, math
|
||||
import sys, math
|
||||
from typing import Callable, Optional, Tuple, Union, List, Dict, Any, cast, Mapping, Set
|
||||
from weakref import ref, WeakSet, WeakValueDictionary
|
||||
|
||||
|
@ -266,9 +266,11 @@ class LazyBuffer:
|
|||
return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), ReduceOps, LazyOp(op, srcs, unbound_new_shape), self.dtype)
|
||||
|
||||
def r(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[sint, ...]) -> LazyBuffer:
|
||||
if not all_int(self.shape) or (0 in self.shape) or prod(self.shape) // prod(new_shape) < getenv("REDUCEOP_SPLIT_THRESHOLD", 32768): return self._reduce_op(op, new_shape) # The amount of work should be big enough to take the benefit of "2 kernels" approach.
|
||||
# TODO: can we split symbolic shape if the reduce axis is not symbolic?
|
||||
if not all_int(self.shape) or (0 in self.shape) or prod(self.shape) // prod(new_shape) < getenv("REDUCEOP_SPLIT_THRESHOLD", 32768): return self._reduce_op(op, new_shape)
|
||||
heuristic, divisor, dim_to_split = max(((divisor := math.gcd(256, old))/(stride or math.inf), divisor, i) for i, (old, new, stride) in enumerate(zip(self.shape, new_shape, self.st.real_strides())) if old != new) # type: ignore
|
||||
if divisor < 16 or heuristic < 0.1: return self._reduce_op(op, new_shape) # Choose largest divisor (>=16) to split on, penalize large strides.
|
||||
if divisor < 16 or heuristic < 0.1: return self._reduce_op(op, new_shape)
|
||||
# choose largest divisor (>=16) to split on, penalize large strides
|
||||
def splitted_shape(dim_aft_div): return self.shape[:dim_to_split] + (self.shape[dim_to_split]//divisor,) + dim_aft_div + self.shape[dim_to_split+1:]
|
||||
return self.reshape(splitted_shape((divisor,)))._reduce_op(op, splitted_shape((1,))).reshape(splitted_shape(()))._reduce_op(op, new_shape)
|
||||
|
||||
|
@ -332,7 +334,7 @@ class LazyBuffer:
|
|||
|
||||
def stride(self:LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer:
|
||||
if all(a == 1 for a in arg): return self
|
||||
if not self.realized and self.op.op == MovementOps.STRIDE: return self.op.src[0].stride(tuple(map(operator.mul, arg, self.op.arg)))
|
||||
if not self.realized and self.op.op == MovementOps.STRIDE: return self.op.src[0].stride(tuple(a1*a2 for a1,a2 in zip(arg, self.op.arg)))
|
||||
return self._movement_op(self.st.stride(arg), MovementOps.STRIDE, arg)
|
||||
|
||||
def replace_with_movement_ops(self: LazyBuffer, ops:List[Tuple[MovementOps, Any]]) -> LazyBuffer:
|
||||
|
@ -357,10 +359,6 @@ def _push_movement_ops(srcs:Tuple[LazyBuffer, ...]) -> Tuple[LazyBuffer, ...]:
|
|||
return tuple(new_srcs)
|
||||
|
||||
MOVEMENT_OPS_DISPATCHER: Dict[MovementOps, Callable] = {
|
||||
MovementOps.RESHAPE: LazyBuffer.reshape,
|
||||
MovementOps.EXPAND: LazyBuffer.expand,
|
||||
MovementOps.SHRINK: LazyBuffer.shrink,
|
||||
MovementOps.PERMUTE: LazyBuffer.permute,
|
||||
MovementOps.PAD: LazyBuffer.pad,
|
||||
MovementOps.STRIDE: LazyBuffer.stride,
|
||||
MovementOps.RESHAPE: LazyBuffer.reshape, MovementOps.EXPAND: LazyBuffer.expand, MovementOps.SHRINK: LazyBuffer.shrink,
|
||||
MovementOps.PERMUTE: LazyBuffer.permute, MovementOps.PAD: LazyBuffer.pad, MovementOps.STRIDE: LazyBuffer.stride,
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue