diff --git a/test/test_dtype.py b/test/test_dtype.py
index 2bdfffb0..798226c1 100644
--- a/test/test_dtype.py
+++ b/test/test_dtype.py
@@ -621,6 +621,11 @@ class TestAutoCastType(unittest.TestCase):
     t.reshape(2, 1).expand(2, 10001).max().backward()
     np.testing.assert_allclose(t.grad.numpy(), [1, 0])
 
+  @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
+  def test_mean_half_precision(self):
+    t = Tensor([60000, 60000, 60000], dtype=dtypes.half)
+    np.testing.assert_allclose(t.mean().numpy(), 60000)
+
 class TestImplicitFunctionTypeChange(unittest.TestCase):
   def test_functions(self):
     result = []
diff --git a/tinygrad/features/multi.py b/tinygrad/features/multi.py
index 7fc4b6ad..9cc512c2 100644
--- a/tinygrad/features/multi.py
+++ b/tinygrad/features/multi.py
@@ -113,15 +113,15 @@ class MultiLazyBuffer:
   def _shape_to_single_shard(self, shape:Tuple[sint, ...], lb:LazyBuffer) -> Tuple[sint, ...]:
     return tuple(lb.shape[self.axis] if a == self.axis else s for a,s in enumerate(shape))
 
-  def r(self, op:ReduceOps, axis:Tuple[int, ...], acc_dt:Optional[DType]=None) -> MultiLazyBuffer:
+  def r(self, op:ReduceOps, axis:Tuple[int, ...], acc_dt:Optional[DType]=None, downcast_half:bool=True) -> MultiLazyBuffer:
     if self.axis is not None and self.axis in axis:
       # all-reduce on sharded axes
       new_shape = tuple(1 if i in axis else s for i,s in enumerate(self.shape))
-      reduced_parts = [x.r(op, axis, acc_dt) if r else x.const(0, shape=new_shape) for x,r in zip(self.lbs, self.real)]
+      reduced_parts = [x.r(op, axis, acc_dt, downcast_half) if r else x.const(0, shape=new_shape) for x,r in zip(self.lbs, self.real)]
       if all(self.real): return MultiLazyBuffer(all_reduce(op, reduced_parts), None)
       return MultiLazyBuffer(reduced_parts, None, self.real)
     # reduce on non sharded axes, piecewise is fine. if axis is None this is also correct
-    return MultiLazyBuffer([x.r(op, axis, acc_dt) for x in self.lbs], self.axis, self.real)
+    return MultiLazyBuffer([x.r(op, axis, acc_dt, downcast_half) for x in self.lbs], self.axis, self.real)
 
   # *** movement ops ***
 
diff --git a/tinygrad/function.py b/tinygrad/function.py
index 15bbc280..750a96d3 100644
--- a/tinygrad/function.py
+++ b/tinygrad/function.py
@@ -146,14 +146,14 @@ class Where(Function):
 # ************* reduce ops *************
 
 class Sum(Function):
-  def forward(self, x:LazyBuffer, axis:Tuple[int, ...], acc_dtype:Optional[DType]=None) -> LazyBuffer:
+  def forward(self, x:LazyBuffer, axis:Tuple[int, ...], acc_dtype:Optional[DType]=None, downcast_half:bool=True) -> LazyBuffer:
     self.input_shape = x.shape
-    return x.r(ReduceOps.SUM, axis, acc_dtype)
+    return x.r(ReduceOps.SUM, axis, acc_dtype, downcast_half)
 
   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.expand(self.input_shape)
 
 class Max(Function):
-  def forward(self, x:LazyBuffer, axis:Tuple[int, ...], acc_dtype:Optional[DType]=None) -> LazyBuffer:
+  def forward(self, x:LazyBuffer, axis:Tuple[int, ...], acc_dtype:Optional[DType]=None, downcast_half:bool=True) -> LazyBuffer:
     self.x, self.ret, self.axis = x, x.r(ReduceOps.MAX, axis), axis
     return self.ret
 
diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py
index 45a80853..eb5806d2 100644
--- a/tinygrad/lazy.py
+++ b/tinygrad/lazy.py
@@ -160,7 +160,7 @@ class LazyBuffer:
     new_shape = tuple(1 if i in axis else s for i,s in enumerate(self.shape))
     return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), self.dtype, op, (axis, acc_dt), (self,))
 
-  def r(self, op:ReduceOps, axis:Tuple[int, ...], acc_dt:Optional[DType]=None) -> LazyBuffer:
+  def r(self, op:ReduceOps, axis:Tuple[int, ...], acc_dt:Optional[DType]=None, downcast_half:bool=True) -> LazyBuffer:
     new_shape = tuple(1 if i in axis else s for i,s in enumerate(self.shape))
     # TODO: this logic should move to the scheduler
     if self.size == 0 and 0 not in new_shape: return self.const({ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[op], new_shape)
@@ -175,7 +175,7 @@ class LazyBuffer:
       least_upper_dtype(self.dtype, dtypes.float)
     if acc_dt is not None and acc_dt != self.dtype:
       # cast back to float16 or bfloat16 to match torch / jax behavior
-      return self.cast(acc_dt).r(op, axis, acc_dt).cast(self.dtype if self.dtype in [dtypes.float16, dtypes.bfloat16] else acc_dt)
+      return self.cast(acc_dt).r(op, axis, acc_dt).cast(self.dtype if downcast_half and self.dtype in [dtypes.float16, dtypes.bfloat16] else acc_dt)
 
     # TODO: can we split symbolic shape if the reduce axis is not symbolic?
     if not getenv("SPLIT_REDUCEOP", 1) or not all_int(self.shape) or (0 in self.shape) or \
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 94aa0ce1..548ff967 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -907,21 +907,23 @@ class Tensor:
 
   # ***** reduce ops *****
 
-  def _reduce(self, fxn:Type[Function], axis:Optional[Union[int, Tuple[int, ...]]]=None, keepdim=False, acc_dtype:Optional[DType]=None) -> Tensor:
+  def _reduce(self, fxn:Type[Function], axis:Optional[Union[int, Tuple[int, ...]]]=None,
+              keepdim=False, acc_dtype:Optional[DType]=None, downcast_half:bool=True) -> Tensor:
     axis_: Tuple[int, ...] = tuple(range(len(self.shape))) if axis is None else ((axis,) if isinstance(axis, int) else tuple(axis))
     axis_ = tuple(x if x >= 0 else x+len(self.shape) for x in axis_)
     shape = tuple(s for i,s in enumerate(self.shape) if i not in axis_)
-    ret = fxn.apply(self, axis=axis_, acc_dtype=acc_dtype)
+    ret = fxn.apply(self, axis=axis_, acc_dtype=acc_dtype, downcast_half=downcast_half)
     return ret if keepdim else ret.reshape(shape)
 
-  def sum(self, axis=None, keepdim=False, acc_dtype:Optional[DType]=None): return self._reduce(F.Sum, axis, keepdim, acc_dtype)
+  def sum(self, axis=None, keepdim=False, acc_dtype:Optional[DType]=None, downcast_half:bool=True):
+    return self._reduce(F.Sum, axis, keepdim, acc_dtype, downcast_half)
   def max(self, axis=None, keepdim=False): return self._reduce(F.Max, axis, keepdim)
   def min(self, axis=None, keepdim=False): return -((-self).max(axis=axis, keepdim=keepdim))
 
   def mean(self, axis=None, keepdim=False):
     assert all_int(self.shape), "does not support symbolic shape"
-    out = self.sum(axis=axis, keepdim=keepdim)
-    return out.div(prod(self.shape) / prod(out.shape)) if 0 not in out.shape else out
+    out = self.sum(axis=axis, keepdim=keepdim, downcast_half=False)
+    return out.div(prod(self.shape) / prod(out.shape)).cast(self.dtype) if 0 not in out.shape else out.cast(self.dtype)
   def var(self, axis=None, keepdim=False, correction=1):
     assert all_int(self.shape), "does not support symbolic shape"
     square_sum = ((self - self.mean(axis=axis, keepdim=True)).square()).sum(axis=axis, keepdim=keepdim)
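
Note: the arithmetic behind downcast_half can be seen with plain NumPy, independent of tinygrad. This is a minimal sketch (NumPy only, not the tinygrad code path): summing three float16 values of 60000 is done in a float32 accumulator, but downcasting that sum back to float16 before the division overflows (float16 max is about 65504), while dividing first and downcasting last, as mean now does with downcast_half=False, gives the expected 60000.

import numpy as np

vals = np.array([60000, 60000, 60000], dtype=np.float16)

# accumulate in float32, then cast the *sum* back to float16 before dividing:
# 180000 does not fit in float16, so the mean comes out as inf
overflowed = vals.astype(np.float32).sum().astype(np.float16) / len(vals)

# keep the accumulator in float32 until after the division, then downcast:
# this mirrors what downcast_half=False enables inside Tensor.mean
correct = (vals.astype(np.float32).sum() / len(vals)).astype(np.float16)

print(overflowed, correct)  # inf 60000.0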