mirror of https://github.com/commaai/tinygrad.git
remove realize from threefry (#5969)
This commit is contained in:
parent
bf8ec23b00
commit
97d708252a
|
@ -266,6 +266,47 @@ class TestJit(unittest.TestCase):
|
|||
assert len(res3) == 5, "All values should be different, rand works in jit."
|
||||
assert res3 != res2, "Jit rand is diff with diff seeds"
|
||||
|
||||
def test_jit_multiple_random_regen(self):
|
||||
def f(a, b):
|
||||
rn = Tensor.randn(*a.shape)
|
||||
rn = rn * a
|
||||
rn2 = Tensor.randn(*a.shape)
|
||||
rn2 = rn2 * b
|
||||
rn = rn + rn2
|
||||
rn2 = rn2 + Tensor.randn(*a.shape)
|
||||
return ((a+b)*rn).realize(), ((a+b)*rn2).realize()
|
||||
a = Tensor.randn(10, 10).realize() # realize these before resetting the random seed
|
||||
b = Tensor.randn(10, 10).realize()
|
||||
|
||||
Tensor.manual_seed(1234)
|
||||
jf = TinyJit(f)
|
||||
res = set()
|
||||
for _ in range(5):
|
||||
o1, o2 = jf(a, b)
|
||||
res.add(o1.numpy()[0][0])
|
||||
res.add(o2.numpy()[0][0])
|
||||
assert len(res) == 10, "All values should be different, rand works in jit."
|
||||
|
||||
Tensor.manual_seed(1234)
|
||||
jf2 = TinyJit(f)
|
||||
res2 = set()
|
||||
for _ in range(5):
|
||||
o1, o2 = jf2(a, b)
|
||||
res2.add(o1.numpy()[0][0])
|
||||
res2.add(o2.numpy()[0][0])
|
||||
assert len(res2) == 10, "All values should be different, rand works in jit."
|
||||
assert res == res2, "Jit rand is not reproducible with the same seed"
|
||||
|
||||
Tensor.manual_seed(3421)
|
||||
jf3 = TinyJit(f)
|
||||
res3 = set()
|
||||
for _ in range(5):
|
||||
o1, o2 = jf3(a, b)
|
||||
res3.add(o1.numpy()[0][0])
|
||||
res3.add(o2.numpy()[0][0])
|
||||
assert len(res3) == 10, "All values should be different, rand works in jit."
|
||||
assert res3 != res2, "Jit rand is diff with diff seeds"
|
||||
|
||||
def test_jit_realization_and_sampling(self):
|
||||
w = Tensor.eye(5)
|
||||
|
||||
|
|
|
@ -854,9 +854,9 @@ class TestLinearizer(unittest.TestCase):
|
|||
lin = Kernel(sched[0].ast)
|
||||
assert sum(u.arg is UnaryOps.RECIP for u in lin.linearize().uops) == max_ops, msg
|
||||
|
||||
a = Tensor.rand((4,4))
|
||||
b = Tensor.rand((4,4))
|
||||
d = Tensor.rand((4,4))
|
||||
a = Tensor.empty((4,4))
|
||||
b = Tensor.empty((4,4))
|
||||
d = Tensor.empty((4,4))
|
||||
|
||||
c = (a*b)/b
|
||||
helper(c, "found UnaryOps.RECIP in (a*b)/b operation")
|
||||
|
|
|
@ -401,7 +401,7 @@ class Tensor:
|
|||
print(Tensor.rand(5).numpy())
|
||||
```
|
||||
"""
|
||||
Tensor._seed, Tensor._rng_counter = seed, Tensor([0], dtype=dtypes.uint32, requires_grad=False)
|
||||
Tensor._seed, Tensor._rng_counter = seed, None
|
||||
|
||||
@staticmethod
|
||||
def rand(*shape, device:Optional[Union[Tuple[str, ...], str]]=None, dtype:Optional[DTypeLike]=None, **kwargs):
|
||||
|
@ -417,7 +417,8 @@ class Tensor:
|
|||
print(t.numpy())
|
||||
```
|
||||
"""
|
||||
if Tensor._rng_counter is None: Tensor._rng_counter = Tensor([0], dtype=dtypes.uint32, requires_grad=False)
|
||||
if (had_counter := Tensor._rng_counter is None): Tensor._rng_counter = Tensor([0], dtype=dtypes.uint32, requires_grad=False)
|
||||
if not all(s >= 0 for s in argfix(*shape)): raise ValueError(f"cannot create tensor with negative dimension in {shape=}")
|
||||
if not THREEFRY.value:
|
||||
# for bfloat16, numpy rand passes buffer in float
|
||||
if to_dtype(dtype or dtypes.default_float) == dtypes.bfloat16:
|
||||
|
@ -426,9 +427,9 @@ class Tensor:
|
|||
|
||||
# threefry
|
||||
if (num := prod((shape:=argfix(*shape)))) == 0: return Tensor.zeros(shape, device=device, dtype=dtype, **kwargs)
|
||||
counts1 = (Tensor.arange(math.ceil(num / 2), device=device, dtype=dtypes.uint32, requires_grad=False)+Tensor._rng_counter.to(device)).realize()
|
||||
if not had_counter: Tensor._rng_counter.assign(Tensor._rng_counter + num)
|
||||
counts1 = (Tensor.arange(math.ceil(num / 2), device=device, dtype=dtypes.uint32, requires_grad=False)+Tensor._rng_counter.to(device))
|
||||
counts2 = counts1 + math.ceil(num / 2)
|
||||
Tensor._rng_counter.assign(Tensor._rng_counter + num).realize()
|
||||
|
||||
x = counts2.cast(dtypes.uint64) << 32 | counts1.cast(dtypes.uint64)
|
||||
x = F.Threefry.apply(*x._broadcasted(Tensor._seed))
|
||||
|
|
Loading…
Reference in New Issue