mirror of https://github.com/commaai/tinygrad.git
Match Torch speed for sum reduction (#1387)
Co-authored-by: Alexander Edwards <alex@alexedw.com>
This commit is contained in:
parent 09ede08b23
commit 1ba8ae62a1
@@ -63,7 +63,7 @@ class TestInferenceMinKernels(unittest.TestCase):
     for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=p.dtype.np))
     img = Tensor.randn(1, 3, 224, 224)
     # TODO: this seems very high
-    with CLCache(115):
+    with CLCache(116):
       model.forward(img).realize()
 
   def test_resnet(self):
@@ -47,7 +47,7 @@ class TestRealWorld(unittest.TestCase):
     derandomize_model(model)
     @TinyJit
     def test(t, t2): return model(t, 801, t2).realize()
-    helper_test("test_sd", lambda: (Tensor.randn(1, 4, 64, 64),Tensor.randn(1, 77, 768)), test, 14.5, 924)
+    helper_test("test_sd", lambda: (Tensor.randn(1, 4, 64, 64),Tensor.randn(1, 77, 768)), test, 14.5, 967)
 
   @unittest.skipUnless(Device.DEFAULT in JIT_SUPPORTED_DEVICE, "needs JIT")
   def test_llama(self):
@@ -87,7 +87,7 @@ class TestRealWorld(unittest.TestCase):
       loss.backward()
       optimizer.step()
 
-    helper_test("train_cifar", lambda: (Tensor.randn(BS, 3, 32, 32),), train, (1.0/32)*BS, 236)
+    helper_test("train_cifar", lambda: (Tensor.randn(BS, 3, 32, 32),), train, (1.0/32)*BS, 267)
 
     # reset device
     Tensor.training = old_training
@@ -1,5 +1,5 @@
 from __future__ import annotations
-import operator
+import operator, math
 from typing import Callable, Optional, Tuple, Union, List, Dict, Any, cast
 import sys, importlib, inspect, functools, pathlib
 from weakref import ref
@@ -210,11 +210,18 @@ class LazyBuffer:
       return root.reshape(ret.st.shape)
     return ret
 
-  def reduce_op(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[int, ...]) -> LazyBuffer:
+  def _reduce_op(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[int, ...]) -> LazyBuffer:
     if self.shape == tuple(new_shape): return self
     srcs = _push_movement_ops((self,)) if SHUFFLE_MOVEMENT_OPS else (self,)
     return create_lazybuffer(self.device, ShapeTracker(new_shape), ReduceOps, LazyOp(op, srcs, new_shape), self.dtype)
 
+  def reduce_op(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[int, ...]) -> LazyBuffer:
+    if prod(self.shape) // prod(new_shape) < 32768: return self._reduce_op(op, new_shape)  # the amount of work should be big enough to benefit from the "2 kernels" approach
+    heuristic, divisor, dim_to_split = max(((divisor := math.gcd(256, old))/(stride or math.inf), divisor, i) for i, (old, new, stride) in enumerate(zip(self.shape, new_shape, self.st.real_strides())) if old != new)  # type: ignore
+    if divisor < 16 or heuristic < 0.125: return self._reduce_op(op, new_shape)  # choose the largest divisor (>=16) to split on, penalize large strides
+    def splitted_shape(dim_aft_div): return self.shape[:dim_to_split] + (self.shape[dim_to_split]//divisor,) + dim_aft_div + self.shape[dim_to_split+1:]
+    return self.reshape(splitted_shape((divisor,)))._reduce_op(op, splitted_shape((1,))).reshape(splitted_shape(()))._reduce_op(op, new_shape)
+
   def reshape(self:LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer:
     if self.shape == arg: return self
     if not self.realized and self.op.op == MovementOps.RESHAPE:
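For readers of this diff, below is a standalone NumPy sketch of the new two-kernel reduction path, under stated assumptions: it illustrates the idea rather than reproducing tinygrad code, split_reduce is a hypothetical helper, the input is a contiguous array being summed to a scalar, and NumPy strides (in elements) stand in for st.real_strides(). Kernel one reshapes the chosen dimension into (old // divisor, divisor) chunks and reduces each chunk to a small buffer of partial sums; kernel two reduces those partials to the final value. The split dimension is picked with the same gcd(256, dim) / stride heuristic the diff uses.

import math
import numpy as np

def split_reduce(x: np.ndarray, dim_to_split: int, divisor: int) -> np.ndarray:
  # kernel 1: reshape the chosen dim into (old // divisor, divisor) chunks and
  # reduce each chunk, producing a much smaller buffer of partial sums
  split = x.shape[:dim_to_split] + (x.shape[dim_to_split] // divisor, divisor) + x.shape[dim_to_split+1:]
  partial = x.reshape(split).sum(axis=dim_to_split + 1)
  # kernel 2: reduce the partial sums down to the final result
  return partial.sum()

x = np.random.randn(256, 4096)
assert x.size >= 32768  # only split when the reduction does enough work (the diff's threshold)

# pick the dim to split: prefer a large divisor of 256 (a nice chunk length for
# kernel 1) and penalize dims with large strides; zero-stride dims score 0
strides = tuple(s // x.itemsize for s in x.strides)
heuristic, divisor, dim_to_split = max(
  ((d := math.gcd(256, old)) / (stride or math.inf), d, i)
  for i, (old, stride) in enumerate(zip(x.shape, strides)))
assert divisor >= 16 and heuristic >= 0.125  # otherwise fall back to a single kernel

out = split_reduce(x, dim_to_split, divisor)
assert np.allclose(out, x.sum())  # the two-stage sum matches the one-kernel sum

The split helps because the first kernel exposes far more parallelism than one monolithic reduction loop, which is what lets large sums approach Torch's speed; it also means each large reduce now compiles to two kernels instead of one, which is presumably why the kernel-count budgets in the tests above were raised.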