From ca7300c783faa54317c940f2462bcf78f533cb1a Mon Sep 17 00:00:00 2001
From: chenyu
Date: Tue, 7 May 2024 21:46:41 -0400
Subject: [PATCH] fix half mean and its backward (#4469)

* fix half mean and its backward

cast to sum_acc_type, sum, div, then cast back

* mean dtype tests
---
 test/test_dtype.py | 16 +++++++++++++++-
 tinygrad/tensor.py |  5 +++--
 2 files changed, 18 insertions(+), 3 deletions(-)

diff --git a/test/test_dtype.py b/test/test_dtype.py
index a9f0cbbf..e28a1972 100644
--- a/test/test_dtype.py
+++ b/test/test_dtype.py
@@ -523,6 +523,21 @@ class TestAutoCastType(unittest.TestCase):
     assert (Tensor([0, 1], dtype=dtypes.float32)).sum().dtype == dtypes.float32
     assert (Tensor([0, 1], dtype=dtypes.float64)).sum().dtype == dtypes.float64
 
+  def test_mean(self):
+    assert (Tensor([0, 1], dtype=dtypes.bool)).mean().dtype == dtypes.float32
+    assert (Tensor([0, 1], dtype=dtypes.int8)).mean().dtype == dtypes.float32
+    assert (Tensor([0, 1], dtype=dtypes.int16)).mean().dtype == dtypes.float32
+    assert (Tensor([0, 1], dtype=dtypes.int32)).mean().dtype == dtypes.float32
+    assert (Tensor([0, 1], dtype=dtypes.int64)).mean().dtype == dtypes.float32
+    assert (Tensor([0, 1], dtype=dtypes.uint8)).mean().dtype == dtypes.float32
+    assert (Tensor([0, 1], dtype=dtypes.uint16)).mean().dtype == dtypes.float32
+    assert (Tensor([0, 1], dtype=dtypes.uint32)).mean().dtype == dtypes.float32
+    assert (Tensor([0, 1], dtype=dtypes.uint64)).mean().dtype == dtypes.float32
+    assert (Tensor([0, 1], dtype=dtypes.float16)).mean().dtype == dtypes.float16
+    #assert (Tensor([0, 1], dtype=dtypes.bfloat16)).mean().dtype == dtypes.bfloat16
+    assert (Tensor([0, 1], dtype=dtypes.float32)).mean().dtype == dtypes.float32
+    assert (Tensor([0, 1], dtype=dtypes.float64)).mean().dtype == dtypes.float64
+
   def test_cumsum(self):
     assert (Tensor([0, 1], dtype=dtypes.bool)).cumsum(0).dtype == dtypes.int32
     assert (Tensor([0, 1], dtype=dtypes.int8)).cumsum(0).dtype == dtypes.int32
@@ -621,7 +636,6 @@ class TestAutoCastType(unittest.TestCase):
     t = Tensor([[x]], dtype=dtypes.half, requires_grad=True).expand(N, N).contiguous()
     np.testing.assert_allclose(t.mean(axis=1).numpy(), np.array([x] * N, dtype=np.float16), rtol=1e-3)
 
-  @unittest.skip("TODO: fix this")
   @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
   def test_mean_half_precision_overflow(self):
     t = Tensor([60000, 60000, 60000], dtype=dtypes.half, requires_grad=True)
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 22b3e32b..cc229081 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -927,8 +927,9 @@ class Tensor:
   def min(self, axis=None, keepdim=False): return -((-self).max(axis=axis, keepdim=keepdim))
 
   def mean(self, axis=None, keepdim=False):
-    out = self.sum(axis=axis, keepdim=keepdim)
-    return out.div(prod([si for si, so in zip(self.shape, self.sum(axis=axis, keepdim=True).shape) if si != so]))
+    output_dtype = self.dtype if dtypes.is_float(self.dtype) else dtypes.float32
+    numerator = self.cast(sum_acc_dtype(self.dtype)).sum(axis=axis, keepdim=keepdim)
+    return numerator.div(prod([si for si, so in zip(self.shape, self.sum(axis=axis, keepdim=True).shape) if si != so])).cast(output_dtype)
   def var(self, axis=None, keepdim=False, correction=1):
     assert all_int(self.shape), "does not support symbolic shape"
     square_sum = ((self - self.mean(axis=axis, keepdim=True)).square()).sum(axis=axis, keepdim=keepdim)
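
Note (illustration, not part of the patch): float16 tops out at 65504, so summing [60000, 60000, 60000] directly in half precision overflows to inf before the division ever runs, which is what the re-enabled test_mean_half_precision_overflow exercises. The standalone sketch below uses plain numpy to mirror the old and new mean behavior; the numpy calls and the float32 accumulator are stand-ins for tinygrad's sum_acc_dtype machinery, chosen only for illustration.

# Standalone numpy sketch (assumption: not tinygrad code) of the overflow the patch fixes.
import numpy as np

vals = np.array([60000, 60000, 60000], dtype=np.float16)

# Old behavior: accumulate the sum in float16, then divide -> the sum is already inf.
naive_mean = vals.sum(dtype=np.float16) / np.float16(3)
print(naive_mean)   # inf

# Patched behavior: cast to a wider accumulator (float32 here), sum, divide,
# then cast the result back to the input dtype.
safe_mean = (vals.astype(np.float32).sum() / 3).astype(np.float16)
print(safe_mean)    # 60000.0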