mirror of https://github.com/commaai/tinygrad.git
check bfloat16 range with threefry (#6660)
This commit is contained in:
parent
d24e4b1042
commit
46e360fdc0
|
@ -90,12 +90,10 @@ class TestRandomness(unittest.TestCase):
|
|||
N = 128
|
||||
x = Tensor.rand((2, N, N), dtype=dtypes.bfloat16)
|
||||
assert x.dtype == dtypes.bfloat16
|
||||
# TODO: fix this property for bfloat16 random
|
||||
# x = x.numpy()
|
||||
# ones = np.take(x, np.where(x == 1))
|
||||
# zeros = np.take(x, np.where(x == 0))
|
||||
# self.assertTrue(ones.size == 0)
|
||||
# self.assertTrue(zeros.size > 0)
|
||||
if THREEFRY.value:
|
||||
nx = x.numpy()
|
||||
assert nx[nx == 1].size == 0
|
||||
assert nx[nx == 0].size > 0
|
||||
equal_distribution(lambda *x: Tensor.rand(*x, dtype=dtypes.bfloat16).float(), torch.rand, lambda x: np.random.rand(*x), shape=(2, N, N))
|
||||
|
||||
def test_randn(self):
|
||||
|
|
Loading…
Reference in New Issue