Match Torch speed for sum reduction (#1387)

Co-authored-by: Alexander Edwards <alex@alexedw.com>
This commit is contained in:
nimlgen 2023-08-06 08:27:33 +03:00 committed by GitHub
parent 09ede08b23
commit 1ba8ae62a1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 12 additions and 5 deletions

View File

@ -63,7 +63,7 @@ class TestInferenceMinKernels(unittest.TestCase):
for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=p.dtype.np))
img = Tensor.randn(1, 3, 224, 224)
# TODO: this seems very high
with CLCache(115):
with CLCache(116):
model.forward(img).realize()
def test_resnet(self):

View File

@ -47,7 +47,7 @@ class TestRealWorld(unittest.TestCase):
derandomize_model(model)
@TinyJit
def test(t, t2): return model(t, 801, t2).realize()
helper_test("test_sd", lambda: (Tensor.randn(1, 4, 64, 64),Tensor.randn(1, 77, 768)), test, 14.5, 924)
helper_test("test_sd", lambda: (Tensor.randn(1, 4, 64, 64),Tensor.randn(1, 77, 768)), test, 14.5, 967)
@unittest.skipUnless(Device.DEFAULT in JIT_SUPPORTED_DEVICE, "needs JIT")
def test_llama(self):
@ -87,7 +87,7 @@ class TestRealWorld(unittest.TestCase):
loss.backward()
optimizer.step()
helper_test("train_cifar", lambda: (Tensor.randn(BS, 3, 32, 32),), train, (1.0/32)*BS, 236)
helper_test("train_cifar", lambda: (Tensor.randn(BS, 3, 32, 32),), train, (1.0/32)*BS, 267)
# reset device
Tensor.training = old_training

View File

@ -1,5 +1,5 @@
from __future__ import annotations
import operator
import operator, math
from typing import Callable, Optional, Tuple, Union, List, Dict, Any, cast
import sys, importlib, inspect, functools, pathlib
from weakref import ref
@ -210,11 +210,18 @@ class LazyBuffer:
return root.reshape(ret.st.shape)
return ret
def reduce_op(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[int, ...]) -> LazyBuffer:
def _reduce_op(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[int, ...]) -> LazyBuffer:
if self.shape == tuple(new_shape): return self
srcs = _push_movement_ops((self,)) if SHUFFLE_MOVEMENT_OPS else (self,)
return create_lazybuffer(self.device, ShapeTracker(new_shape), ReduceOps, LazyOp(op, srcs, new_shape), self.dtype)
def reduce_op(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[int, ...]) -> LazyBuffer:
if prod(self.shape) // prod(new_shape) < 32768: return self._reduce_op(op, new_shape) # The amount of work should be big enough to take the benefit of "2 kernels" approach.
heuristic, divisor, dim_to_split = max(((divisor := math.gcd(256, old))/(stride or math.inf), divisor, i) for i, (old, new, stride) in enumerate(zip(self.shape, new_shape, self.st.real_strides())) if old != new) # type: ignore
if divisor < 16 or heuristic < 0.125: return self._reduce_op(op, new_shape) # Choose largest divisor (>=16) to split on, penalize large strides.
def splitted_shape(dim_aft_div): return self.shape[:dim_to_split] + (self.shape[dim_to_split]//divisor,) + dim_aft_div + self.shape[dim_to_split+1:]
return self.reshape(splitted_shape((divisor,)))._reduce_op(op, splitted_shape((1,))).reshape(splitted_shape(()))._reduce_op(op, new_shape)
def reshape(self:LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer:
if self.shape == arg: return self
if not self.realized and self.op.op == MovementOps.RESHAPE: