stronger test case for half mean overflow (#4470)

This commit is contained in:
chenyu 2024-05-07 22:40:09 -04:00 committed by GitHub
parent ca7300c783
commit 7eb035e7c5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 3 additions and 2 deletions

View File

@ -638,10 +638,11 @@ class TestAutoCastType(unittest.TestCase):
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
def test_mean_half_precision_overflow(self):
t = Tensor([60000, 60000, 60000], dtype=dtypes.half, requires_grad=True)
N = 256
t = Tensor([60000] * N*N, dtype=dtypes.half, requires_grad=True).reshape(N, N)
np.testing.assert_allclose(t.mean().numpy(), 60000)
t.square().mean().backward()
np.testing.assert_allclose(t.grad.numpy(), [60000 * 2 / 3] * 3)
np.testing.assert_allclose(t.grad.numpy().flatten(), [60000 * 2 / (N*N)] * N*N)
class TestImplicitFunctionTypeChange(unittest.TestCase):
def test_functions(self):