From 86c2f267d48408e2931177224fe8679f6111de9b Mon Sep 17 00:00:00 2001 From: chenyu Date: Sat, 16 Dec 2023 15:14:50 -0500 Subject: [PATCH] Tensor.randint is Tensor.uniform with dtypes.int32 (#2801) --- test/test_randomness.py | 4 ++++ tinygrad/tensor.py | 3 +-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/test/test_randomness.py b/test/test_randomness.py index f7954745..f4b13d71 100644 --- a/test/test_randomness.py +++ b/test/test_randomness.py @@ -63,6 +63,10 @@ class TestRandomness(unittest.TestCase): self.assertTrue(normal_test(Tensor.randn)) self.assertTrue(equal_distribution(Tensor.randn, torch.randn, lambda x: np.random.randn(*x))) + def test_randint(self): + self.assertFalse(normal_test(Tensor.randint)) + self.assertTrue(equal_distribution(partial(Tensor.randint, low=-2, high=5), numpy_func=lambda x: np.random.randint(low=-2, high=5, size=x))) + def test_normal(self): self.assertTrue(normal_test(Tensor.normal)) self.assertTrue(equal_distribution(Tensor.normal, lambda x: torch.nn.init.normal_(torch.empty(x), mean=0, std=1), lambda x: np.random.normal(loc=0, scale=1, size=x))) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index d3072e66..13f18864 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -199,8 +199,7 @@ class Tensor: return src[0].mul(2*math.pi).cos().mul((1 - src[1]).log().mul(-2).sqrt()).cast(Tensor.default_type if dtype is None else dtype) @staticmethod - def randint(*shape, low=0, high=10, **kwargs) -> Tensor: - return (Tensor.rand(*shape, **kwargs)*(high-low)+low).cast(dtypes.int32) + def randint(*shape, low=0, high=10, **kwargs) -> Tensor: return Tensor.uniform(shape, low=low, high=high, dtype=dtypes.int32) @staticmethod def normal(*shape, mean=0.0, std=1.0, **kwargs) -> Tensor: return (std * Tensor.randn(*shape, **kwargs)) + mean