From 380f27d6292257f928b9816a7d5b5b4081a9be05 Mon Sep 17 00:00:00 2001 From: chenyu Date: Thu, 11 Apr 2024 14:43:56 -0400 Subject: [PATCH] move sum acc_dtype into lazy so it applies to backward (#4149) * move sum acc_dtype into lazy so it applies to backward * unit test --- test/test_dtype.py | 7 +++++++ tinygrad/features/multi.py | 6 +++--- tinygrad/function.py | 6 +++--- tinygrad/lazy.py | 23 ++++++++++++++++------- tinygrad/tensor.py | 15 ++++----------- 5 files changed, 33 insertions(+), 24 deletions(-) diff --git a/test/test_dtype.py b/test/test_dtype.py index 06710095..13438d52 100644 --- a/test/test_dtype.py +++ b/test/test_dtype.py @@ -602,6 +602,13 @@ class TestAutoCastType(unittest.TestCase): dtypes.default_float = old_default_float + @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half") + def test_backward_sum_acc_dtype(self): + # test acc of sum in the backward is upcasted to float + t = Tensor([5, -5], dtype=dtypes.half, requires_grad=True) + t.reshape(2, 1).expand(2, 10001).max().backward() + np.testing.assert_allclose(t.grad.numpy(), [1, 0]) + class TestImplicitFunctionTypeChange(unittest.TestCase): def test_functions(self): result = [] diff --git a/tinygrad/features/multi.py b/tinygrad/features/multi.py index 94d00a69..cac2915c 100644 --- a/tinygrad/features/multi.py +++ b/tinygrad/features/multi.py @@ -112,15 +112,15 @@ class MultiLazyBuffer: def _shape_to_single_shard(self, shape:Tuple[sint, ...], lb:LazyBuffer) -> Tuple[sint, ...]: return tuple(lb.shape[self.axis] if a == self.axis else s for a,s in enumerate(shape)) - def r(self, op:ReduceOps, axis:Tuple[int, ...]) -> MultiLazyBuffer: + def r(self, op:ReduceOps, axis:Tuple[int, ...], acc_dt:Optional[DType]=None) -> MultiLazyBuffer: if self.axis is not None and self.axis in axis: # all-reduce on sharded axes new_shape = tuple(1 if i in axis else s for i,s in enumerate(self.shape)) - reduced_parts = [x.r(op, axis) if r else x.const(0, shape=new_shape) for x,r in zip(self.lbs, self.real)] + reduced_parts = [x.r(op, axis, acc_dt) if r else x.const(0, shape=new_shape) for x,r in zip(self.lbs, self.real)] if all(self.real): return MultiLazyBuffer(all_reduce(op, reduced_parts), None) return MultiLazyBuffer(reduced_parts, None, self.real) # reduce on non sharded axes, piecewise is fine. if axis is None this is also correct - return MultiLazyBuffer([x.r(op, axis) for x in self.lbs], self.axis, self.real) + return MultiLazyBuffer([x.r(op, axis, acc_dt) for x in self.lbs], self.axis, self.real) # *** movement ops *** diff --git a/tinygrad/function.py b/tinygrad/function.py index 9dc6fd53..9c854fc4 100644 --- a/tinygrad/function.py +++ b/tinygrad/function.py @@ -145,14 +145,14 @@ class Where(Function): # ************* reduce ops ************* class Sum(Function): - def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer: + def forward(self, x:LazyBuffer, axis:Tuple[int, ...], acc_dtype:Optional[DType]=None) -> LazyBuffer: self.input_shape = x.shape - return x.r(ReduceOps.SUM, axis) + return x.r(ReduceOps.SUM, axis, acc_dtype) def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.expand(self.input_shape) class Max(Function): - def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer: + def forward(self, x:LazyBuffer, axis:Tuple[int, ...], acc_dtype:Optional[DType]=None) -> LazyBuffer: self.x, self.ret, self.axis = x, x.r(ReduceOps.MAX, axis), axis return self.ret diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py index 7e017ad4..8aaa23fd 100644 --- a/tinygrad/lazy.py +++ b/tinygrad/lazy.py @@ -1,7 +1,7 @@ from __future__ import annotations import math from typing import Union, Optional, Any, Tuple, List -from tinygrad.dtype import dtypes, DType, ConstType +from tinygrad.dtype import dtypes, DType, ConstType, least_upper_dtype from tinygrad.helpers import prod, getenv, all_int, all_same from tinygrad.ops import LoadOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, Op, exec_alu, python_alu from tinygrad.shape.symbolic import sint @@ -148,14 +148,14 @@ class LazyBuffer: # *** reduce ops *** - def _reduce_op(self, op:ReduceOps, axis:Tuple[int, ...]) -> LazyBuffer: + def _reduce_op(self, op:ReduceOps, axis:Tuple[int, ...], acc_dt:Optional[DType]=None) -> LazyBuffer: assert all(0 <= x < len(self.shape) for x in axis), f"axis args {axis} out of range for shape {self.shape}" axis = tuple(x for x in axis if self.shape[x] != 1) if len(axis) == 0: return self new_shape = tuple(1 if i in axis else s for i,s in enumerate(self.shape)) - return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), self.dtype, op, axis, (self,)) + return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), self.dtype, op, (axis, acc_dt), (self,)) - def r(self, op:ReduceOps, axis:Tuple[int, ...]) -> LazyBuffer: + def r(self, op:ReduceOps, axis:Tuple[int, ...], acc_dt:Optional[DType]=None) -> LazyBuffer: new_shape = tuple(1 if i in axis else s for i,s in enumerate(self.shape)) # TODO: this logic should move to the scheduler if self.size == 0 and 0 not in new_shape: return self.const({ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[op], new_shape) @@ -163,17 +163,26 @@ class LazyBuffer: if self.is_unrealized_unmasked_const(): return self.const(self.base.arg * {ReduceOps.SUM: prod(self.shape[i] for i in axis), ReduceOps.MAX: 1}[op], new_shape) + # upcast acc_dt here so if reduce is splitted, the intermediate dtype is upcasted + if op is ReduceOps.SUM and acc_dt is None: + acc_dt = least_upper_dtype(self.dtype, dtypes.uint) if dtypes.is_unsigned(self.dtype) else \ + least_upper_dtype(self.dtype, dtypes.int) if (dtypes.is_int(self.dtype) or self.dtype==dtypes.bool) else \ + least_upper_dtype(self.dtype, dtypes.float) + if acc_dt is not None and acc_dt != self.dtype: + # cast back to float16 or bfloat16 to match torch / jax behavior + return self.cast(acc_dt).r(op, axis, acc_dt).cast(self.dtype if self.dtype in [dtypes.float16, dtypes.bfloat16] else acc_dt) + # TODO: can we split symbolic shape if the reduce axis is not symbolic? if not getenv("SPLIT_REDUCEOP", 1) or not all_int(self.shape) or (0 in self.shape) or \ prod(self.shape) // prod(new_shape) < getenv("REDUCEOP_SPLIT_THRESHOLD", 32768): - return self._reduce_op(op, axis) + return self._reduce_op(op, axis, acc_dt) heuristic, divisor, dim_to_split = max(((divisor := math.gcd(256, s))/(st or math.inf), divisor, i) for i,(s,st) in \ enumerate(zip(self.shape, self.st.real_strides())) if i in axis and (st is None or isinstance(st, int))) - if divisor < 16 or heuristic < 0.1: return self._reduce_op(op, axis) + if divisor < 16 or heuristic < 0.1: return self._reduce_op(op, axis, acc_dt) # choose largest divisor (>=16) to split on, penalize large strides def splitted_shape(dim_aft_div): return self.shape[:dim_to_split] + (self.shape[dim_to_split]//divisor,) + dim_aft_div + self.shape[dim_to_split+1:] - return self.reshape(splitted_shape((divisor,)))._reduce_op(op, (dim_to_split+1,)).reshape(splitted_shape(()))._reduce_op(op, axis) + return self.reshape(splitted_shape((divisor,)))._reduce_op(op, (dim_to_split+1,), acc_dt).reshape(splitted_shape(()))._reduce_op(op, axis, acc_dt) # *** movement ops *** diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 24f6844c..7a16a93a 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -610,21 +610,14 @@ class Tensor: # ***** reduce ops ***** - def _reduce(self, fxn:Type[Function], axis:Optional[Union[int, Tuple[int, ...]]]=None, keepdim=False) -> Tensor: + def _reduce(self, fxn:Type[Function], axis:Optional[Union[int, Tuple[int, ...]]]=None, keepdim=False, acc_dtype:Optional[DType]=None) -> Tensor: axis_: Tuple[int, ...] = tuple(range(len(self.shape))) if axis is None else ((axis,) if isinstance(axis, int) else tuple(axis)) axis_ = tuple(x if x >= 0 else x+len(self.shape) for x in axis_) shape = tuple(s for i,s in enumerate(self.shape) if i not in axis_) - ret = fxn.apply(self, axis=axis_) - return ret if keepdim else ret.reshape(shape=shape) - - def sum(self, axis=None, keepdim=False, acc_dtype:Optional[DType]=None): - if acc_dtype is None: acc_dtype = least_upper_dtype(self.dtype, dtypes.uint) if dtypes.is_unsigned(self.dtype) else \ - least_upper_dtype(self.dtype, dtypes.int) if (dtypes.is_int(self.dtype) or self.dtype==dtypes.bool) else \ - least_upper_dtype(self.dtype, dtypes.float) - # cast back to float16 or bfloat16 to match torch / jax behavior, but we use float for acc - output_dtype = self.dtype if self.dtype in (dtypes.float16, dtypes.bfloat16) else acc_dtype - return self.cast(acc_dtype)._reduce(F.Sum, axis, keepdim).cast(output_dtype) + ret = fxn.apply(self, axis=axis_, acc_dtype=acc_dtype) + return ret if keepdim else ret.reshape(shape) + def sum(self, axis=None, keepdim=False, acc_dtype:Optional[DType]=None): return self._reduce(F.Sum, axis, keepdim, acc_dtype) def max(self, axis=None, keepdim=False): return self._reduce(F.Max, axis, keepdim) def min(self, axis=None, keepdim=False): return -((-self).max(axis=axis, keepdim=keepdim))