mirror of https://github.com/commaai/tinygrad.git
Fix bug where Tensor.randn returns inf (#1192)
* fix randn inf bug * add test * more compact test * clarify test purpose
This commit is contained in:
parent
d9c1d81e99
commit
628ee46627
|
@ -145,6 +145,13 @@ class TestTinygrad(unittest.TestCase):
|
|||
b = random_fn(10,10).realize()
|
||||
np.testing.assert_allclose(a.numpy(), b.numpy())
|
||||
|
||||
def test_randn_isnt_inf_on_zero(self):
|
||||
# simulate failure case of rand handing a zero to randn
|
||||
original_rand, Tensor.rand = Tensor.rand, Tensor.zeros
|
||||
try: self.assertNotIn(np.inf, Tensor.randn(16).numpy())
|
||||
except: raise
|
||||
finally: Tensor.rand = original_rand
|
||||
|
||||
def test_zeros_like_has_same_dtype(self):
|
||||
for datatype in [dtypes.float16, dtypes.float32, dtypes.int8, dtypes.int32, dtypes.int64, dtypes.uint8]:
|
||||
a = Tensor([1, 2, 3], dtype=datatype)
|
||||
|
|
|
@ -177,7 +177,7 @@ class Tensor:
|
|||
def randn(*shape, dtype:Optional[DType]=None, **kwargs) -> Tensor:
|
||||
# https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform
|
||||
src = Tensor.rand(2, *shape, **kwargs)
|
||||
return src[0].mul(2*pi).cos().mul(src[1].log().mul(-2).sqrt()).cast(Tensor.default_type if dtype is None else dtype)
|
||||
return src[0].mul(2*pi).cos().mul(((1 - src[1])).log().mul(-2).sqrt()).cast(Tensor.default_type if dtype is None else dtype)
|
||||
|
||||
@staticmethod
|
||||
def uniform(*shape, low=-1.0, high=1.0, **kwargs) -> Tensor: return ((high-low) * Tensor.rand(*shape, **kwargs)) + low
|
||||
|
|
Loading…
Reference in New Issue