mirror of https://github.com/commaai/tinygrad.git
fix half mean and its backward (#4469)
* fix half mean and its backward: cast to sum_acc_dtype, sum, div, then cast back
* mean dtype tests
This commit is contained in:
parent 7da1b41f38
commit ca7300c783
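The failure mode: summing float16 values overflows the accumulator (float16 tops out at 65504) even when the mean itself is representable. A minimal numpy sketch of the bug and of the fix this commit applies (cast to the accumulation dtype, sum, divide, cast back); the values mirror test_mean_half_precision_overflow below:

import numpy as np

t = np.array([60000, 60000, 60000], dtype=np.float16)
# accumulating in float16: 60000 + 60000 = 120000 > 65504, so the running sum overflows to inf
naive = t.sum(dtype=np.float16) / t.size
# the fix: accumulate in float32, divide, then cast back to the input dtype
fixed = (t.astype(np.float32).sum() / t.size).astype(np.float16)
print(naive, fixed)  # inf 60000.0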
test/test_dtype.py
@@ -523,6 +523,21 @@ class TestAutoCastType(unittest.TestCase):
     assert (Tensor([0, 1], dtype=dtypes.float32)).sum().dtype == dtypes.float32
     assert (Tensor([0, 1], dtype=dtypes.float64)).sum().dtype == dtypes.float64
 
+  def test_mean(self):
+    assert (Tensor([0, 1], dtype=dtypes.bool)).mean().dtype == dtypes.float32
+    assert (Tensor([0, 1], dtype=dtypes.int8)).mean().dtype == dtypes.float32
+    assert (Tensor([0, 1], dtype=dtypes.int16)).mean().dtype == dtypes.float32
+    assert (Tensor([0, 1], dtype=dtypes.int32)).mean().dtype == dtypes.float32
+    assert (Tensor([0, 1], dtype=dtypes.int64)).mean().dtype == dtypes.float32
+    assert (Tensor([0, 1], dtype=dtypes.uint8)).mean().dtype == dtypes.float32
+    assert (Tensor([0, 1], dtype=dtypes.uint16)).mean().dtype == dtypes.float32
+    assert (Tensor([0, 1], dtype=dtypes.uint32)).mean().dtype == dtypes.float32
+    assert (Tensor([0, 1], dtype=dtypes.uint64)).mean().dtype == dtypes.float32
+    assert (Tensor([0, 1], dtype=dtypes.float16)).mean().dtype == dtypes.float16
+    #assert (Tensor([0, 1], dtype=dtypes.bfloat16)).mean().dtype == dtypes.bfloat16
+    assert (Tensor([0, 1], dtype=dtypes.float32)).mean().dtype == dtypes.float32
+    assert (Tensor([0, 1], dtype=dtypes.float64)).mean().dtype == dtypes.float64
+
   def test_cumsum(self):
     assert (Tensor([0, 1], dtype=dtypes.bool)).cumsum(0).dtype == dtypes.int32
     assert (Tensor([0, 1], dtype=dtypes.int8)).cumsum(0).dtype == dtypes.int32
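A quick check of the promotion rule the new test_mean asserts: integer and bool inputs produce a float32 mean, while float inputs keep their dtype. A usage sketch, assuming a tinygrad install where Tensor and dtypes are importable from the top-level package:

from tinygrad import Tensor, dtypes

assert Tensor([0, 1], dtype=dtypes.int8).mean().dtype == dtypes.float32     # ints and bools promote to float32
assert Tensor([0, 1], dtype=dtypes.float16).mean().dtype == dtypes.float16  # floats keep their dtype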
@@ -621,7 +636,6 @@ class TestAutoCastType(unittest.TestCase):
     t = Tensor([[x]], dtype=dtypes.half, requires_grad=True).expand(N, N).contiguous()
     np.testing.assert_allclose(t.mean(axis=1).numpy(), np.array([x] * N, dtype=np.float16), rtol=1e-3)
 
-  @unittest.skip("TODO: fix this")
   @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
   def test_mean_half_precision_overflow(self):
     t = Tensor([60000, 60000, 60000], dtype=dtypes.half, requires_grad=True)
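The backward pass is covered by the same change: mean's gradient flows back through the cast/sum/div/cast chain, so it no longer inherits an inf from an overflowed forward sum. A hedged sketch, again assuming a working tinygrad install:

from tinygrad import Tensor, dtypes

t = Tensor([60000, 60000, 60000], dtype=dtypes.half, requires_grad=True)
t.mean().backward()
print(t.grad.numpy())  # expect roughly [1/3, 1/3, 1/3] in half, not NaN/inf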
tinygrad/tensor.py
@@ -927,8 +927,9 @@ class Tensor:
   def min(self, axis=None, keepdim=False): return -((-self).max(axis=axis, keepdim=keepdim))
 
   def mean(self, axis=None, keepdim=False):
-    out = self.sum(axis=axis, keepdim=keepdim)
-    return out.div(prod([si for si, so in zip(self.shape, self.sum(axis=axis, keepdim=True).shape) if si != so]))
+    output_dtype = self.dtype if dtypes.is_float(self.dtype) else dtypes.float32
+    numerator = self.cast(sum_acc_dtype(self.dtype)).sum(axis=axis, keepdim=keepdim)
+    return numerator.div(prod([si for si, so in zip(self.shape, self.sum(axis=axis, keepdim=True).shape) if si != so])).cast(output_dtype)
   def var(self, axis=None, keepdim=False, correction=1):
     assert all_int(self.shape), "does not support symbolic shape"
     square_sum = ((self - self.mean(axis=axis, keepdim=True)).square()).sum(axis=axis, keepdim=keepdim)
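The new mean restated as a standalone numpy sketch; mean_with_wide_acc and acc_dtype are hypothetical names, with acc_dtype playing the role of sum_acc_dtype(self.dtype):

import numpy as np

def mean_with_wide_acc(x: np.ndarray, axis=None, acc_dtype=np.float32):
  # non-float inputs produce a float32 mean, floats keep their dtype (mirrors output_dtype above)
  out_dtype = x.dtype if np.issubdtype(x.dtype, np.floating) else np.float32
  n = x.size if axis is None else x.shape[axis]
  # cast to the wide accumulator, sum, divide, cast back
  return (x.astype(acc_dtype).sum(axis=axis) / n).astype(out_dtype)

print(mean_with_wide_acc(np.array([0, 1], dtype=np.int8)))  # 0.5 (float32), matching test_mean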