fix half mean and its backward (#4469)

* fix half mean and its backward

cast to sum_acc_type, sum, div, then cast back

* mean dtype tests
This commit is contained in:
chenyu 2024-05-07 21:46:41 -04:00 committed by GitHub
parent 7da1b41f38
commit ca7300c783
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 18 additions and 3 deletions

View File

@@ -523,6 +523,21 @@ class TestAutoCastType(unittest.TestCase):
assert (Tensor([0, 1], dtype=dtypes.float32)).sum().dtype == dtypes.float32
assert (Tensor([0, 1], dtype=dtypes.float64)).sum().dtype == dtypes.float64
def test_mean(self):
    # mean() on bool/integer inputs promotes the result to float32;
    # floating-point inputs keep their own dtype.
    for int_like in (dtypes.bool, dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64,
                     dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64):
        assert (Tensor([0, 1], dtype=int_like)).mean().dtype == dtypes.float32
    assert (Tensor([0, 1], dtype=dtypes.float16)).mean().dtype == dtypes.float16
    #assert (Tensor([0, 1], dtype=dtypes.bfloat16)).mean().dtype == dtypes.bfloat16
    assert (Tensor([0, 1], dtype=dtypes.float32)).mean().dtype == dtypes.float32
    assert (Tensor([0, 1], dtype=dtypes.float64)).mean().dtype == dtypes.float64
def test_cumsum(self):
assert (Tensor([0, 1], dtype=dtypes.bool)).cumsum(0).dtype == dtypes.int32
assert (Tensor([0, 1], dtype=dtypes.int8)).cumsum(0).dtype == dtypes.int32
@@ -621,7 +636,6 @@ class TestAutoCastType(unittest.TestCase):
t = Tensor([[x]], dtype=dtypes.half, requires_grad=True).expand(N, N).contiguous()
np.testing.assert_allclose(t.mean(axis=1).numpy(), np.array([x] * N, dtype=np.float16), rtol=1e-3)
@unittest.skip("TODO: fix this")
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
def test_mean_half_precision_overflow(self):
t = Tensor([60000, 60000, 60000], dtype=dtypes.half, requires_grad=True)

View File

@@ -927,8 +927,9 @@ class Tensor:
def min(self, axis=None, keepdim=False):
    """Return the minimum along `axis`, expressed as the negated maximum of the negation."""
    negated_max = (-self).max(axis=axis, keepdim=keepdim)
    return -negated_max
def mean(self, axis=None, keepdim=False):
    """Return the mean of the tensor over `axis` (all elements when `axis` is None).

    Bool/integer inputs produce a float32 result; floating-point inputs keep
    their dtype. The sum is accumulated in `sum_acc_dtype(self.dtype)` so that
    low-precision inputs (e.g. half) do not overflow, then cast back.
    """
    # NOTE(review): the pre-patch `out = self.sum(...); return out.div(...)` lines
    # were diff residue — an unconditional return that made the fix below dead code.
    output_dtype = self.dtype if dtypes.is_float(self.dtype) else dtypes.float32
    # accumulate in a wider dtype to avoid overflow during the sum
    numerator = self.cast(sum_acc_dtype(self.dtype)).sum(axis=axis, keepdim=keepdim)
    # divisor: product of the sizes of the reduced axes (those whose size changed)
    return numerator.div(prod([si for si, so in zip(self.shape, self.sum(axis=axis, keepdim=True).shape) if si != so])).cast(output_dtype)
def var(self, axis=None, keepdim=False, correction=1):
assert all_int(self.shape), "does not support symbolic shape"
square_sum = ((self - self.mean(axis=axis, keepdim=True)).square()).sum(axis=axis, keepdim=keepdim)