remove realize from threefry (#5969)

This commit is contained in:
wozeparrot 2024-08-07 15:08:49 -07:00 committed by GitHub
parent bf8ec23b00
commit 97d708252a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 49 additions and 7 deletions

View File

@ -266,6 +266,47 @@ class TestJit(unittest.TestCase):
assert len(res3) == 5, "All values should be different, rand works in jit."
assert res3 != res2, "Jit rand is diff with diff seeds"
def test_jit_multiple_random_regen(self):
  # Checks that a jitted function which draws several random tensors per call
  # (a) yields fresh values on every call, (b) reproduces the same values when
  # the seed is reset to the same number, and (c) yields different values
  # under a different seed.
  def f(a, b):
    # Three separate randn draws per call; the draw order is kept as-is.
    noise_a = Tensor.randn(*a.shape) * a
    noise_b = Tensor.randn(*a.shape) * b
    mixed = noise_a + noise_b
    shifted = noise_b + Tensor.randn(*a.shape)
    return ((a+b)*mixed).realize(), ((a+b)*shifted).realize()

  # realize these before resetting the random seed
  a = Tensor.randn(10, 10).realize()
  b = Tensor.randn(10, 10).realize()

  def draw(seed):
    # Seed the generator, wrap f in a fresh TinyJit, call it five times and
    # gather the [0][0] element of both outputs from every call.
    Tensor.manual_seed(seed)
    jitted = TinyJit(f)
    seen = set()
    for _ in range(5):
      out1, out2 = jitted(a, b)
      seen.add(out1.numpy()[0][0])
      seen.add(out2.numpy()[0][0])
    # 5 calls x 2 outputs -> 10 distinct samples expected.
    assert len(seen) == 10, "All values should be different, rand works in jit."
    return seen

  first = draw(1234)
  second = draw(1234)
  assert first == second, "Jit rand is not reproducible with the same seed"

  third = draw(3421)
  assert third != second, "Jit rand is diff with diff seeds"
def test_jit_realization_and_sampling(self):
w = Tensor.eye(5)

View File

@ -854,9 +854,9 @@ class TestLinearizer(unittest.TestCase):
lin = Kernel(sched[0].ast)
assert sum(u.arg is UnaryOps.RECIP for u in lin.linearize().uops) == max_ops, msg
a = Tensor.rand((4,4))
b = Tensor.rand((4,4))
d = Tensor.rand((4,4))
a = Tensor.empty((4,4))
b = Tensor.empty((4,4))
d = Tensor.empty((4,4))
c = (a*b)/b
helper(c, "found UnaryOps.RECIP in (a*b)/b operation")

View File

@ -401,7 +401,7 @@ class Tensor:
print(Tensor.rand(5).numpy())
```
"""
Tensor._seed, Tensor._rng_counter = seed, Tensor([0], dtype=dtypes.uint32, requires_grad=False)
Tensor._seed, Tensor._rng_counter = seed, None
@staticmethod
def rand(*shape, device:Optional[Union[Tuple[str, ...], str]]=None, dtype:Optional[DTypeLike]=None, **kwargs):
@ -417,7 +417,8 @@ class Tensor:
print(t.numpy())
```
"""
if Tensor._rng_counter is None: Tensor._rng_counter = Tensor([0], dtype=dtypes.uint32, requires_grad=False)
if (had_counter := Tensor._rng_counter is None): Tensor._rng_counter = Tensor([0], dtype=dtypes.uint32, requires_grad=False)
if not all(s >= 0 for s in argfix(*shape)): raise ValueError(f"cannot create tensor with negative dimension in {shape=}")
if not THREEFRY.value:
# for bfloat16, numpy rand passes buffer in float
if to_dtype(dtype or dtypes.default_float) == dtypes.bfloat16:
@ -426,9 +427,9 @@ class Tensor:
# threefry
if (num := prod((shape:=argfix(*shape)))) == 0: return Tensor.zeros(shape, device=device, dtype=dtype, **kwargs)
counts1 = (Tensor.arange(math.ceil(num / 2), device=device, dtype=dtypes.uint32, requires_grad=False)+Tensor._rng_counter.to(device)).realize()
if not had_counter: Tensor._rng_counter.assign(Tensor._rng_counter + num)
counts1 = (Tensor.arange(math.ceil(num / 2), device=device, dtype=dtypes.uint32, requires_grad=False)+Tensor._rng_counter.to(device))
counts2 = counts1 + math.ceil(num / 2)
Tensor._rng_counter.assign(Tensor._rng_counter + num).realize()
x = counts2.cast(dtypes.uint64) << 32 | counts1.cast(dtypes.uint64)
x = F.Threefry.apply(*x._broadcasted(Tensor._seed))