check bfloat16 range with threefry (#6660)

This commit is contained in:
wozeparrot 2024-09-23 10:48:44 +08:00 committed by GitHub
parent d24e4b1042
commit 46e360fdc0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 4 additions and 6 deletions

View File

@ -90,12 +90,10 @@ class TestRandomness(unittest.TestCase):
N = 128
x = Tensor.rand((2, N, N), dtype=dtypes.bfloat16)
assert x.dtype == dtypes.bfloat16
# TODO: fix this property for bfloat16 random
# x = x.numpy()
# ones = np.take(x, np.where(x == 1))
# zeros = np.take(x, np.where(x == 0))
# self.assertTrue(ones.size == 0)
# self.assertTrue(zeros.size > 0)
if THREEFRY.value:
nx = x.numpy()
assert nx[nx == 1].size == 0
assert nx[nx == 0].size > 0
equal_distribution(lambda *x: Tensor.rand(*x, dtype=dtypes.bfloat16).float(), torch.rand, lambda x: np.random.rand(*x), shape=(2, N, N))
def test_randn(self):