mirror of https://github.com/commaai/tinygrad.git
move sum acc_dtype into lazy so it applies to backward (#4149)
* move sum acc_dtype into lazy so it applies to backward

* unit test
This commit is contained in:
parent
bbda20c0db
commit
380f27d629
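
A minimal sketch of the behavior this commit targets (assumes a tinygrad checkout at or after this commit; it mirrors the new unit test in the first hunk below). The accumulation dtype for SUM is now chosen inside the lazy layer, so reductions created during backward, not only the user-facing Tensor.sum, accumulate in the upcasted dtype:

from tinygrad import Tensor, dtypes
import numpy as np

# summing ~10001 half-precision terms loses precision in a float16 accumulator
# (small addends get absorbed once the partial sum grows), so the gradient would
# not come out to exactly [1, 0]; with this change the backward sum accumulates in float32
t = Tensor([5, -5], dtype=dtypes.half, requires_grad=True)
t.reshape(2, 1).expand(2, 10001).max().backward()
np.testing.assert_allclose(t.grad.numpy(), [1, 0])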
@@ -602,6 +602,13 @@ class TestAutoCastType(unittest.TestCase):
     dtypes.default_float = old_default_float
 
+  @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
+  def test_backward_sum_acc_dtype(self):
+    # test acc of sum in the backward is upcasted to float
+    t = Tensor([5, -5], dtype=dtypes.half, requires_grad=True)
+    t.reshape(2, 1).expand(2, 10001).max().backward()
+    np.testing.assert_allclose(t.grad.numpy(), [1, 0])
+
 class TestImplicitFunctionTypeChange(unittest.TestCase):
   def test_functions(self):
     result = []
@@ -112,15 +112,15 @@ class MultiLazyBuffer:
   def _shape_to_single_shard(self, shape:Tuple[sint, ...], lb:LazyBuffer) -> Tuple[sint, ...]:
     return tuple(lb.shape[self.axis] if a == self.axis else s for a,s in enumerate(shape))
 
-  def r(self, op:ReduceOps, axis:Tuple[int, ...]) -> MultiLazyBuffer:
+  def r(self, op:ReduceOps, axis:Tuple[int, ...], acc_dt:Optional[DType]=None) -> MultiLazyBuffer:
     if self.axis is not None and self.axis in axis:
       # all-reduce on sharded axes
       new_shape = tuple(1 if i in axis else s for i,s in enumerate(self.shape))
-      reduced_parts = [x.r(op, axis) if r else x.const(0, shape=new_shape) for x,r in zip(self.lbs, self.real)]
+      reduced_parts = [x.r(op, axis, acc_dt) if r else x.const(0, shape=new_shape) for x,r in zip(self.lbs, self.real)]
       if all(self.real): return MultiLazyBuffer(all_reduce(op, reduced_parts), None)
       return MultiLazyBuffer(reduced_parts, None, self.real)
     # reduce on non sharded axes, piecewise is fine. if axis is None this is also correct
-    return MultiLazyBuffer([x.r(op, axis) for x in self.lbs], self.axis, self.real)
+    return MultiLazyBuffer([x.r(op, axis, acc_dt) for x in self.lbs], self.axis, self.real)
 
   # *** movement ops ***
 
@@ -145,14 +145,14 @@ class Where(Function):
 # ************* reduce ops *************
 
 class Sum(Function):
-  def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
+  def forward(self, x:LazyBuffer, axis:Tuple[int, ...], acc_dtype:Optional[DType]=None) -> LazyBuffer:
     self.input_shape = x.shape
-    return x.r(ReduceOps.SUM, axis)
+    return x.r(ReduceOps.SUM, axis, acc_dtype)
 
   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.expand(self.input_shape)
 
 class Max(Function):
-  def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
+  def forward(self, x:LazyBuffer, axis:Tuple[int, ...], acc_dtype:Optional[DType]=None) -> LazyBuffer:
     self.x, self.ret, self.axis = x, x.r(ReduceOps.MAX, axis), axis
     return self.ret
@@ -1,7 +1,7 @@
 from __future__ import annotations
 import math
 from typing import Union, Optional, Any, Tuple, List
-from tinygrad.dtype import dtypes, DType, ConstType
+from tinygrad.dtype import dtypes, DType, ConstType, least_upper_dtype
 from tinygrad.helpers import prod, getenv, all_int, all_same
 from tinygrad.ops import LoadOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, Op, exec_alu, python_alu
 from tinygrad.shape.symbolic import sint
@@ -148,14 +148,14 @@ class LazyBuffer:
 
   # *** reduce ops ***
 
-  def _reduce_op(self, op:ReduceOps, axis:Tuple[int, ...]) -> LazyBuffer:
+  def _reduce_op(self, op:ReduceOps, axis:Tuple[int, ...], acc_dt:Optional[DType]=None) -> LazyBuffer:
     assert all(0 <= x < len(self.shape) for x in axis), f"axis args {axis} out of range for shape {self.shape}"
     axis = tuple(x for x in axis if self.shape[x] != 1)
     if len(axis) == 0: return self
     new_shape = tuple(1 if i in axis else s for i,s in enumerate(self.shape))
-    return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), self.dtype, op, axis, (self,))
+    return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), self.dtype, op, (axis, acc_dt), (self,))
 
-  def r(self, op:ReduceOps, axis:Tuple[int, ...]) -> LazyBuffer:
+  def r(self, op:ReduceOps, axis:Tuple[int, ...], acc_dt:Optional[DType]=None) -> LazyBuffer:
     new_shape = tuple(1 if i in axis else s for i,s in enumerate(self.shape))
     # TODO: this logic should move to the scheduler
     if self.size == 0 and 0 not in new_shape: return self.const({ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[op], new_shape)
@@ -163,17 +163,26 @@
     if self.is_unrealized_unmasked_const():
       return self.const(self.base.arg * {ReduceOps.SUM: prod(self.shape[i] for i in axis), ReduceOps.MAX: 1}[op], new_shape)
 
+    # upcast acc_dt here so if reduce is splitted, the intermediate dtype is upcasted
+    if op is ReduceOps.SUM and acc_dt is None:
+      acc_dt = least_upper_dtype(self.dtype, dtypes.uint) if dtypes.is_unsigned(self.dtype) else \
+        least_upper_dtype(self.dtype, dtypes.int) if (dtypes.is_int(self.dtype) or self.dtype==dtypes.bool) else \
+        least_upper_dtype(self.dtype, dtypes.float)
+    if acc_dt is not None and acc_dt != self.dtype:
+      # cast back to float16 or bfloat16 to match torch / jax behavior
+      return self.cast(acc_dt).r(op, axis, acc_dt).cast(self.dtype if self.dtype in [dtypes.float16, dtypes.bfloat16] else acc_dt)
+
     # TODO: can we split symbolic shape if the reduce axis is not symbolic?
     if not getenv("SPLIT_REDUCEOP", 1) or not all_int(self.shape) or (0 in self.shape) or \
       prod(self.shape) // prod(new_shape) < getenv("REDUCEOP_SPLIT_THRESHOLD", 32768):
-      return self._reduce_op(op, axis)
+      return self._reduce_op(op, axis, acc_dt)
     heuristic, divisor, dim_to_split = max(((divisor := math.gcd(256, s))/(st or math.inf), divisor, i) for i,(s,st) in \
       enumerate(zip(self.shape, self.st.real_strides())) if i in axis and (st is None or isinstance(st, int)))
-    if divisor < 16 or heuristic < 0.1: return self._reduce_op(op, axis)
+    if divisor < 16 or heuristic < 0.1: return self._reduce_op(op, axis, acc_dt)
     # choose largest divisor (>=16) to split on, penalize large strides
     def splitted_shape(dim_aft_div):
       return self.shape[:dim_to_split] + (self.shape[dim_to_split]//divisor,) + dim_aft_div + self.shape[dim_to_split+1:]
-    return self.reshape(splitted_shape((divisor,)))._reduce_op(op, (dim_to_split+1,)).reshape(splitted_shape(()))._reduce_op(op, axis)
+    return self.reshape(splitted_shape((divisor,)))._reduce_op(op, (dim_to_split+1,), acc_dt).reshape(splitted_shape(()))._reduce_op(op, axis, acc_dt)
 
   # *** movement ops ***
 
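The promotion rule added to LazyBuffer.r above can be read on its own: unsigned ints accumulate in at least uint32, signed ints and bool in at least int32, and floats in at least float32, with float16/bfloat16 results cast back to the input dtype. A standalone sketch of that selection follows; pick_sum_acc_dtype is a hypothetical helper, not part of the diff, and it assumes the promotion lattice of least_upper_dtype at this commit:

from tinygrad.dtype import dtypes, DType, least_upper_dtype

def pick_sum_acc_dtype(dt: DType) -> DType:
  # mirrors the rule in the diff: unsigned ints widen to at least uint32,
  # signed ints / bool to at least int32, floats to at least float32
  if dtypes.is_unsigned(dt): return least_upper_dtype(dt, dtypes.uint)
  if dtypes.is_int(dt) or dt == dtypes.bool: return least_upper_dtype(dt, dtypes.int)
  return least_upper_dtype(dt, dtypes.float)

assert pick_sum_acc_dtype(dtypes.half) == dtypes.float32   # half sums accumulate in float32
assert pick_sum_acc_dtype(dtypes.int8) == dtypes.int32     # small signed ints widen to int32
assert pick_sum_acc_dtype(dtypes.uint8) == dtypes.uint32   # unsigned ints widen to uint32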
@@ -610,21 +610,14 @@ class Tensor:
 
   # ***** reduce ops *****
 
-  def _reduce(self, fxn:Type[Function], axis:Optional[Union[int, Tuple[int, ...]]]=None, keepdim=False) -> Tensor:
+  def _reduce(self, fxn:Type[Function], axis:Optional[Union[int, Tuple[int, ...]]]=None, keepdim=False, acc_dtype:Optional[DType]=None) -> Tensor:
     axis_: Tuple[int, ...] = tuple(range(len(self.shape))) if axis is None else ((axis,) if isinstance(axis, int) else tuple(axis))
     axis_ = tuple(x if x >= 0 else x+len(self.shape) for x in axis_)
     shape = tuple(s for i,s in enumerate(self.shape) if i not in axis_)
-    ret = fxn.apply(self, axis=axis_)
-    return ret if keepdim else ret.reshape(shape=shape)
-
-  def sum(self, axis=None, keepdim=False, acc_dtype:Optional[DType]=None):
-    if acc_dtype is None: acc_dtype = least_upper_dtype(self.dtype, dtypes.uint) if dtypes.is_unsigned(self.dtype) else \
-      least_upper_dtype(self.dtype, dtypes.int) if (dtypes.is_int(self.dtype) or self.dtype==dtypes.bool) else \
-      least_upper_dtype(self.dtype, dtypes.float)
-    # cast back to float16 or bfloat16 to match torch / jax behavior, but we use float for acc
-    output_dtype = self.dtype if self.dtype in (dtypes.float16, dtypes.bfloat16) else acc_dtype
-    return self.cast(acc_dtype)._reduce(F.Sum, axis, keepdim).cast(output_dtype)
+    ret = fxn.apply(self, axis=axis_, acc_dtype=acc_dtype)
+    return ret if keepdim else ret.reshape(shape)
+
+  def sum(self, axis=None, keepdim=False, acc_dtype:Optional[DType]=None): return self._reduce(F.Sum, axis, keepdim, acc_dtype)
   def max(self, axis=None, keepdim=False): return self._reduce(F.Max, axis, keepdim)
   def min(self, axis=None, keepdim=False): return -((-self).max(axis=axis, keepdim=keepdim))
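For callers, Tensor.sum keeps its signature; the difference is that the chosen acc_dtype now rides on the lazy reduce op, so the same accumulator applies to reductions created during backward. A hedged usage sketch, with output dtypes assuming the casting rules in the hunks above:

from tinygrad import Tensor, dtypes

x = Tensor.ones(4096, dtype=dtypes.half)
print(x.sum().dtype)                          # float16: result cast back, accumulation happens in float32
print(x.sum(acc_dtype=dtypes.float32).dtype)  # float16: explicit acc_dtype changes the accumulator, not the output
y = Tensor.ones(10, dtype=dtypes.int8)
print(y.sum().dtype)                          # int32: integer sums widen to the default int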