fix Tensor.randint ignoring kwargs (#3350)

* fix Tensor.randint ignoring kwargs

* randint kwargs fix
This commit is contained in:
andresgit 2024-02-09 18:12:16 +02:00 committed by GitHub
parent ce21fdfb67
commit 28ba1c5406
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 2 additions and 1 deletions

View File

@ -65,6 +65,7 @@ class TestRandomness(unittest.TestCase):
def test_randint(self):
self.assertFalse(normal_test(Tensor.randint))
self.assertTrue(equal_distribution(partial(Tensor.randint, low=-2, high=5), numpy_func=lambda x: np.random.randint(low=-2, high=5, size=x)))
self.assertTrue(Tensor.randint(1,device="CLANG").device=="CLANG")
def test_normal(self):
self.assertTrue(normal_test(Tensor.normal))

View File

@ -229,7 +229,7 @@ class Tensor:
return src[0].mul(2*math.pi).cos().mul((1 - src[1]).log().mul(-2).sqrt()).cast(dtype or dtypes.default_float)
@staticmethod
def randint(*shape, low=0, high=10, **kwargs) -> Tensor: return Tensor.uniform(*shape, low=low, high=high, dtype=dtypes.int32)
def randint(*shape, low=0, high=10, **kwargs) -> Tensor: return Tensor.uniform(*shape, low=low, high=high, dtype=dtypes.int32, **kwargs)
@staticmethod
def normal(*shape, mean=0.0, std=1.0, **kwargs) -> Tensor: return (std * Tensor.randn(*shape, **kwargs)) + mean