From 826cccd54d91ac5bf8978ceedcbc8e8290ff20fc Mon Sep 17 00:00:00 2001
From: chenyu
Date: Wed, 1 May 2024 13:38:57 -0400
Subject: [PATCH] fix mean underflow for half tensor (#4377)

* fix mean underflow for half tensor

divide only by the reduce factor. added a unit test and a non-NaN
assertion in resnet training. also added a failing test case for
symbolic shape var

* skip for python backend
---
 examples/mlperf/model_train.py |  1 +
 test/test_dtype.py             | 10 +++++++++-
 test/test_tensor_variable.py   |  8 ++++++++
 tinygrad/tensor.py             |  2 +-
 4 files changed, 19 insertions(+), 2 deletions(-)

diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py
index a7c87486..1dee6a9b 100644
--- a/examples/mlperf/model_train.py
+++ b/examples/mlperf/model_train.py
@@ -176,6 +176,7 @@ def train_resnet():
       i += 1

       if i == BENCHMARK:
+        assert not math.isnan(loss)
         median_step_time = sorted(step_times)[(BENCHMARK + 1) // 2]  # in seconds
         estimated_total_minutes = int(median_step_time * steps_in_train_epoch * epochs / 60)
         print(f"Estimated training time: {estimated_total_minutes // 60}h{estimated_total_minutes % 60}m")
diff --git a/test/test_dtype.py b/test/test_dtype.py
index c76cb8fb..216952cb 100644
--- a/test/test_dtype.py
+++ b/test/test_dtype.py
@@ -621,9 +621,17 @@ class TestAutoCastType(unittest.TestCase):
     t.reshape(2, 1).expand(2, 10001).max().backward()
     np.testing.assert_allclose(t.grad.numpy(), [1, 0])

+  @unittest.skipIf(Device.DEFAULT=="PYTHON", "very slow")
+  @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
+  def test_mean_half_precision_underflow(self):
+    N = 10000
+    x = 0.001
+    t = Tensor([[x]], dtype=dtypes.half, requires_grad=True).expand(N, N).contiguous()
+    np.testing.assert_allclose(t.mean(axis=1).numpy(), np.array([x] * N, dtype=np.float16), rtol=1e-3)
+
   @unittest.skip("TODO: fix this")
   @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
-  def test_mean_half_precision(self):
+  def test_mean_half_precision_overflow(self):
     t = Tensor([60000, 60000, 60000], dtype=dtypes.half, requires_grad=True)
     np.testing.assert_allclose(t.mean().numpy(), 60000)
     t.square().mean().backward()
diff --git a/test/test_tensor_variable.py b/test/test_tensor_variable.py
index b1613b9c..45b19928 100644
--- a/test/test_tensor_variable.py
+++ b/test/test_tensor_variable.py
@@ -32,6 +32,14 @@ class TestTensorVariable(unittest.TestCase):
     ret = t.mean().item()
     assert ret == 1

+  @unittest.skip("symbolic var isn't supported")
+  def test_symbolic_var(self):
+    vv = Variable("a", 1, 10)
+    vv.bind(2)
+    t = Tensor.ones(2, 2).contiguous().reshape(2, vv)
+    ret = t.var().item()
+    assert ret == 0
+
   def test_symbolic_mean_2d(self):
     vv = Variable("a", 1, 10)
     vv.bind(2)
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index be13d100..844797ca 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -926,7 +926,7 @@ class Tensor:
   def mean(self, axis=None, keepdim=False):
     out = self.sum(axis=axis, keepdim=keepdim)
-    return out.div(prod(self.shape)).mul(prod(out.shape)) if 0 not in out.shape else out
+    return out.div(prod([si for si, so in zip(self.shape, self.sum(axis=axis, keepdim=True).shape) if si != so]))
   def var(self, axis=None, keepdim=False, correction=1):
     assert all_int(self.shape), "does not support symbolic shape"
     square_sum = ((self - self.mean(axis=axis, keepdim=True)).square()).sum(axis=axis, keepdim=keepdim)
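
Note: a minimal numpy sketch (not part of the patch) of why the old formula
underflowed; the constants mirror the new unit test, and each np.float16 cast
stands in for a half-precision intermediate on device:

import numpy as np

# Old mean: out.div(prod(self.shape)).mul(prod(out.shape)). For a
# (10000, 10000) half tensor reduced over axis=1, each row sum (~10.0)
# is first divided by 1e8, landing in half's subnormal range.
N = 10000
row_sum = np.float16(0.001) * np.float16(N)   # ~10.0, fine in half

tiny = np.float16(float(row_sum) / (N * N))   # ~1e-7 rounds to subnormal 1.19e-7
old  = np.float16(float(tiny) * N)            # ~0.00119, off by ~19%
new  = np.float16(float(row_sum) / N)         # divide only by the reduce factor

print(old)  # ~0.00119: precision lost below half's min normal (~6.1e-5)
print(new)  # ~0.001: stays in half's normal range

The renamed test_mean_half_precision_overflow is the mirror image: summing
[60000, 60000, 60000] exceeds half's max (~65504) before the divide, which the
reduce-factor-only divide does not address and stays TODO.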