Fix bug where Tensor.randn returns inf (#1192)

* fix randn inf bug

* add test

* more compact test

* clarify test purpose
This commit is contained in:
fluffy χατγιρλ 2023-07-08 21:03:46 +02:00 committed by GitHub
parent d9c1d81e99
commit 628ee46627
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 8 additions and 1 deletions

View File

@ -145,6 +145,13 @@ class TestTinygrad(unittest.TestCase):
b = random_fn(10,10).realize()
np.testing.assert_allclose(a.numpy(), b.numpy())
def test_randn_isnt_inf_on_zero(self):
# simulate failure case of rand handing a zero to randn
original_rand, Tensor.rand = Tensor.rand, Tensor.zeros
try: self.assertNotIn(np.inf, Tensor.randn(16).numpy())
except: raise
finally: Tensor.rand = original_rand
def test_zeros_like_has_same_dtype(self):
for datatype in [dtypes.float16, dtypes.float32, dtypes.int8, dtypes.int32, dtypes.int64, dtypes.uint8]:
a = Tensor([1, 2, 3], dtype=datatype)

View File

@ -177,7 +177,7 @@ class Tensor:
def randn(*shape, dtype:Optional[DType]=None, **kwargs) -> Tensor:
# https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform
src = Tensor.rand(2, *shape, **kwargs)
return src[0].mul(2*pi).cos().mul(src[1].log().mul(-2).sqrt()).cast(Tensor.default_type if dtype is None else dtype)
return src[0].mul(2*pi).cos().mul(((1 - src[1])).log().mul(-2).sqrt()).cast(Tensor.default_type if dtype is None else dtype)
@staticmethod
def uniform(*shape, low=-1.0, high=1.0, **kwargs) -> Tensor: return ((high-low) * Tensor.rand(*shape, **kwargs)) + low