Revert "Revert "Tensor.randint is Tensor.uniform with dtypes.int32 (#2801)" (#2802)" (#2814)

This reverts commit fa84998244.
This commit is contained in:
chenyu 2023-12-17 12:14:17 -05:00 committed by GitHub
parent b4fa189c8c
commit 9c32474a1f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 5 additions and 2 deletions

View File

@ -63,6 +63,10 @@ class TestRandomness(unittest.TestCase):
self.assertTrue(normal_test(Tensor.randn))
self.assertTrue(equal_distribution(Tensor.randn, torch.randn, lambda x: np.random.randn(*x)))
def test_randint(self):
self.assertFalse(normal_test(Tensor.randint))
self.assertTrue(equal_distribution(partial(Tensor.randint, low=-2, high=5), numpy_func=lambda x: np.random.randint(low=-2, high=5, size=x)))
def test_normal(self):
self.assertTrue(normal_test(Tensor.normal))
self.assertTrue(equal_distribution(Tensor.normal, lambda x: torch.nn.init.normal_(torch.empty(x), mean=0, std=1), lambda x: np.random.normal(loc=0, scale=1, size=x)))

View File

@ -204,8 +204,7 @@ class Tensor:
return src[0].mul(2*math.pi).cos().mul((1 - src[1]).log().mul(-2).sqrt()).cast(Tensor.default_type if dtype is None else dtype)
@staticmethod
def randint(*shape, low=0, high=10, **kwargs) -> Tensor:
return (Tensor.rand(*shape, **kwargs)*(high-low)+low).cast(dtypes.int32)
def randint(*shape, low=0, high=10, **kwargs) -> Tensor: return Tensor.uniform(shape, low=low, high=high, dtype=dtypes.int32)
@staticmethod
def normal(*shape, mean=0.0, std=1.0, **kwargs) -> Tensor: return (std * Tensor.randn(*shape, **kwargs)) + mean