mirror of https://github.com/commaai/tinygrad.git
fix mean underflow for half tensor (#4377)
* Fix mean underflow for half tensors: divide only by the reduce factor. Added a unit test and a non-NaN assertion in resnet training; also added a failing test case for symbolic shape var. * Skip for the python backend.
This commit is contained in:
parent
dce7ac0160
commit
826cccd54d
|
@ -176,6 +176,7 @@ def train_resnet():
|
|||
i += 1
|
||||
|
||||
if i == BENCHMARK:
|
||||
assert not math.isnan(loss)
|
||||
median_step_time = sorted(step_times)[(BENCHMARK + 1) // 2] # in seconds
|
||||
estimated_total_minutes = int(median_step_time * steps_in_train_epoch * epochs / 60)
|
||||
print(f"Estimated training time: {estimated_total_minutes // 60}h{estimated_total_minutes % 60}m")
|
||||
|
|
|
@ -621,9 +621,17 @@ class TestAutoCastType(unittest.TestCase):
|
|||
t.reshape(2, 1).expand(2, 10001).max().backward()
|
||||
np.testing.assert_allclose(t.grad.numpy(), [1, 0])
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT=="PYTHON", "very slow")
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
def test_mean_half_precision_underflow(self):
  """Mean over a large axis must not underflow in half precision.

  Summing N=10000 values of 0.001 and dividing by prod(shape) would
  underflow fp16; mean() must divide only by the reduced dimension so
  each row's mean stays ~0.001.
  """
  N = 10000
  x = 0.001
  t = Tensor([[x]], dtype=dtypes.half, requires_grad=True).expand(N, N).contiguous()
  # each row is N copies of x, so the per-row mean should be x (within fp16 tolerance)
  np.testing.assert_allclose(t.mean(axis=1).numpy(), np.array([x] * N, dtype=np.float16), rtol=1e-3)
|
||||
|
||||
@unittest.skip("TODO: fix this")
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
def test_mean_half_precision_overflow(self):
  """Mean of large half values must not overflow in the intermediate sum.

  60000 is near the fp16 max (~65504); summing three of them overflows
  before the divide. Currently broken, hence the skip above.
  """
  t = Tensor([60000, 60000, 60000], dtype=dtypes.half, requires_grad=True)
  np.testing.assert_allclose(t.mean().numpy(), 60000)
  # backward through square().mean() also exercises the overflow path
  t.square().mean().backward()
|
||||
|
|
|
@ -32,6 +32,14 @@ class TestTensorVariable(unittest.TestCase):
|
|||
ret = t.mean().item()
|
||||
assert ret == 1
|
||||
|
||||
@unittest.skip("symbolic var isn't supported")
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
def test_symbolic_var(self):
  """var() on a tensor reshaped with a bound symbolic dim.

  An all-ones tensor has zero variance; currently skipped because var()
  asserts all-int shapes (no symbolic support yet).
  """
  vv = Variable("a", 1, 10)
  vv.bind(2)
  t = Tensor.ones(2, 2).contiguous().reshape(2, vv)
  ret = t.var().item()
  assert ret == 0
|
||||
|
||||
def test_symbolic_mean_2d(self):
|
||||
vv = Variable("a", 1, 10)
|
||||
vv.bind(2)
|
||||
|
|
|
def mean(self, axis=None, keepdim=False):
  """Return the mean of the tensor over `axis` (all axes when None).

  Divides the sum only by the product of the *reduced* dimensions
  instead of `prod(self.shape)`; dividing by the full element count and
  multiplying back underflows for low-precision dtypes such as half.
  `keepdim=True` keeps the reduced axes as size 1.
  """
  out = self.sum(axis=axis, keepdim=keepdim)
  # the reduced dims are exactly those where the keepdim sum's shape differs from self.shape
  return out.div(prod([si for si, so in zip(self.shape, self.sum(axis=axis, keepdim=True).shape) if si != so]))
|
||||
def var(self, axis=None, keepdim=False, correction=1):
|
||||
assert all_int(self.shape), "does not support symbolic shape"
|
||||
square_sum = ((self - self.mean(axis=axis, keepdim=True)).square()).sum(axis=axis, keepdim=keepdim)
|
||||
|
|
Loading…
Reference in New Issue