mirror of https://github.com/commaai/tinygrad.git
stronger test case for half mean overflow (#4470)
This commit is contained in:
parent
ca7300c783
commit
7eb035e7c5
|
@ -638,10 +638,11 @@ class TestAutoCastType(unittest.TestCase):
|
|||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
|
||||
def test_mean_half_precision_overflow(self):
|
||||
t = Tensor([60000, 60000, 60000], dtype=dtypes.half, requires_grad=True)
|
||||
N = 256
|
||||
t = Tensor([60000] * N*N, dtype=dtypes.half, requires_grad=True).reshape(N, N)
|
||||
np.testing.assert_allclose(t.mean().numpy(), 60000)
|
||||
t.square().mean().backward()
|
||||
np.testing.assert_allclose(t.grad.numpy(), [60000 * 2 / 3] * 3)
|
||||
np.testing.assert_allclose(t.grad.numpy().flatten(), [60000 * 2 / (N*N)] * N*N)
|
||||
|
||||
class TestImplicitFunctionTypeChange(unittest.TestCase):
|
||||
def test_functions(self):
|
||||
|
|
Loading…
Reference in New Issue