diff --git a/test/test_dtype.py b/test/test_dtype.py index e28a1972..3027e6a7 100644 --- a/test/test_dtype.py +++ b/test/test_dtype.py @@ -638,10 +638,11 @@ class TestAutoCastType(unittest.TestCase): @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half") def test_mean_half_precision_overflow(self): - t = Tensor([60000, 60000, 60000], dtype=dtypes.half, requires_grad=True) + N = 256 + t = Tensor([60000] * N*N, dtype=dtypes.half, requires_grad=True).reshape(N, N) np.testing.assert_allclose(t.mean().numpy(), 60000) t.square().mean().backward() - np.testing.assert_allclose(t.grad.numpy(), [60000 * 2 / 3] * 3) + np.testing.assert_allclose(t.grad.numpy().flatten(), [60000 * 2 / (N*N)] * N*N) class TestImplicitFunctionTypeChange(unittest.TestCase): def test_functions(self):