From 46e360fdc0a5ff715fee50d8484f09bc56dfaf79 Mon Sep 17 00:00:00 2001 From: wozeparrot Date: Mon, 23 Sep 2024 10:48:44 +0800 Subject: [PATCH] check bfloat16 range with threefry (#6660) --- test/test_randomness.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/test/test_randomness.py b/test/test_randomness.py index 159584a9..7fe5e156 100644 --- a/test/test_randomness.py +++ b/test/test_randomness.py @@ -90,12 +90,10 @@ class TestRandomness(unittest.TestCase): N = 128 x = Tensor.rand((2, N, N), dtype=dtypes.bfloat16) assert x.dtype == dtypes.bfloat16 - # TODO: fix this property for bfloat16 random - # x = x.numpy() - # ones = np.take(x, np.where(x == 1)) - # zeros = np.take(x, np.where(x == 0)) - # self.assertTrue(ones.size == 0) - # self.assertTrue(zeros.size > 0) + if THREEFRY.value: + nx = x.numpy() + assert nx[nx == 1].size == 0 + assert nx[nx == 0].size > 0 equal_distribution(lambda *x: Tensor.rand(*x, dtype=dtypes.bfloat16).float(), torch.rand, lambda x: np.random.rand(*x), shape=(2, N, N)) def test_randn(self):