fix mean underflow for half tensor (#4377)

* fix mean underflow for half tensor

divide only the reduce factor. added unit test and non-nan assertion in resnet training. also added a failed test cast for symbolic shape var

* skip for python backend
This commit is contained in:
chenyu 2024-05-01 13:38:57 -04:00 committed by GitHub
parent dce7ac0160
commit 826cccd54d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 19 additions and 2 deletions

View File

@ -176,6 +176,7 @@ def train_resnet():
i += 1
if i == BENCHMARK:
assert not math.isnan(loss)
median_step_time = sorted(step_times)[(BENCHMARK + 1) // 2] # in seconds
estimated_total_minutes = int(median_step_time * steps_in_train_epoch * epochs / 60)
print(f"Estimated training time: {estimated_total_minutes // 60}h{estimated_total_minutes % 60}m")

View File

@ -621,9 +621,17 @@ class TestAutoCastType(unittest.TestCase):
t.reshape(2, 1).expand(2, 10001).max().backward()
np.testing.assert_allclose(t.grad.numpy(), [1, 0])
@unittest.skipIf(Device.DEFAULT=="PYTHON", "very slow")
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
def test_mean_half_precision_underflow(self):
N = 10000
x = 0.001
t = Tensor([[x]], dtype=dtypes.half, requires_grad=True).expand(N, N).contiguous()
np.testing.assert_allclose(t.mean(axis=1).numpy(), np.array([x] * N, dtype=np.float16), rtol=1e-3)
@unittest.skip("TODO: fix this")
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
def test_mean_half_precision(self):
def test_mean_half_precision_overflow(self):
t = Tensor([60000, 60000, 60000], dtype=dtypes.half, requires_grad=True)
np.testing.assert_allclose(t.mean().numpy(), 60000)
t.square().mean().backward()

View File

@ -32,6 +32,14 @@ class TestTensorVariable(unittest.TestCase):
ret = t.mean().item()
assert ret == 1
@unittest.skip("symbolic var isn't supported")
def test_symbolic_var(self):
vv = Variable("a", 1, 10)
vv.bind(2)
t = Tensor.ones(2, 2).contiguous().reshape(2, vv)
ret = t.var().item()
assert ret == 0
def test_symbolic_mean_2d(self):
vv = Variable("a", 1, 10)
vv.bind(2)

View File

@ -926,7 +926,7 @@ class Tensor:
def mean(self, axis=None, keepdim=False):
out = self.sum(axis=axis, keepdim=keepdim)
return out.div(prod(self.shape)).mul(prod(out.shape)) if 0 not in out.shape else out
return out.div(prod([si for si, so in zip(self.shape, self.sum(axis=axis, keepdim=True).shape) if si != so]))
def var(self, axis=None, keepdim=False, correction=1):
assert all_int(self.shape), "does not support symbolic shape"
square_sum = ((self - self.mean(axis=axis, keepdim=True)).square()).sum(axis=axis, keepdim=keepdim)