From 97d708252a13820820591ec6f492ebb95b14602e Mon Sep 17 00:00:00 2001
From: wozeparrot
Date: Wed, 7 Aug 2024 15:08:49 -0700
Subject: [PATCH] remove realize from threefry (#5969)

---
Reviewer notes. Text between the "---" separator above and the first
"diff --git" line is commentary: it is not part of the commit message and
does not affect how the hunks apply.

What the hunks below do:

* test/test_jit.py: adds test_jit_multiple_random_regen. The jitted function
  draws three Tensor.randn tensors per call and returns two realized results.
  It is run five times under seed 1234, again under seed 1234 with a fresh
  TinyJit, and once more under seed 3421. Each run must yield 10 distinct
  first elements; the two 1234 runs must produce equal sets, and the 3421 run
  must differ from them.

* test/test_linearizer.py: the inputs a, b and d are now built with
  Tensor.empty((4,4)) instead of Tensor.rand((4,4)).

* tinygrad/tensor.py, hunk at line 401: the seeding routine (presumably
  Tensor.manual_seed; its def line is outside the hunk) now resets
  Tensor._rng_counter to None instead of building a new zero tensor.

* tinygrad/tensor.py, Tensor.rand:
    * the uint32 counter tensor is created only when Tensor._rng_counter
      is None;
    * a ValueError is raised when any requested dimension is negative
      (this check is carried as an added line in this hunk);
    * when the counter already existed, num is added to it with assign()
      before counts1 is built;
    * the .realize() call on counts1 is dropped, and the old
      assign(...).realize() line after counts2 is removed.

NOTE(review): had_counter is bound to "Tensor._rng_counter is None", so it is
True when the counter did NOT exist yet. The later "if not had_counter" guard
therefore runs when a counter was already present. The logic is coherent but
the name reads inverted; consider renaming it in a follow-up.

NOTE(review): on calls after the first, the counter is advanced by the
current call's num before it is read. If draw sizes vary between calls,
please confirm that the counter ranges used by consecutive calls cannot
overlap.

 test/test_jit.py        | 41 +++++++++++++++++++++++++++++++++++++++++
 test/test_linearizer.py |  6 +++---
 tinygrad/tensor.py      |  9 +++++----
 3 files changed, 49 insertions(+), 7 deletions(-)

diff --git a/test/test_jit.py b/test/test_jit.py
index c66079a9..6cb13409 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -266,6 +266,47 @@ class TestJit(unittest.TestCase):
     assert len(res3) == 5, "All values should be different, rand works in jit."
     assert res3 != res2, "Jit rand is diff with diff seeds"
 
+  def test_jit_multiple_random_regen(self):
+    def f(a, b):
+      rn = Tensor.randn(*a.shape)
+      rn = rn * a
+      rn2 = Tensor.randn(*a.shape)
+      rn2 = rn2 * b
+      rn = rn + rn2
+      rn2 = rn2 + Tensor.randn(*a.shape)
+      return ((a+b)*rn).realize(), ((a+b)*rn2).realize()
+    a = Tensor.randn(10, 10).realize()  # realize these before resetting the random seed
+    b = Tensor.randn(10, 10).realize()
+
+    Tensor.manual_seed(1234)
+    jf = TinyJit(f)
+    res = set()
+    for _ in range(5):
+      o1, o2 = jf(a, b)
+      res.add(o1.numpy()[0][0])
+      res.add(o2.numpy()[0][0])
+    assert len(res) == 10, "All values should be different, rand works in jit."
+
+    Tensor.manual_seed(1234)
+    jf2 = TinyJit(f)
+    res2 = set()
+    for _ in range(5):
+      o1, o2 = jf2(a, b)
+      res2.add(o1.numpy()[0][0])
+      res2.add(o2.numpy()[0][0])
+    assert len(res2) == 10, "All values should be different, rand works in jit."
+    assert res == res2, "Jit rand is not reproducible with the same seed"
+
+    Tensor.manual_seed(3421)
+    jf3 = TinyJit(f)
+    res3 = set()
+    for _ in range(5):
+      o1, o2 = jf3(a, b)
+      res3.add(o1.numpy()[0][0])
+      res3.add(o2.numpy()[0][0])
+    assert len(res3) == 10, "All values should be different, rand works in jit."
+    assert res3 != res2, "Jit rand is diff with diff seeds"
+
   def test_jit_realization_and_sampling(self):
     w = Tensor.eye(5)
 
diff --git a/test/test_linearizer.py b/test/test_linearizer.py
index 6845a37e..cf3ed42c 100644
--- a/test/test_linearizer.py
+++ b/test/test_linearizer.py
@@ -854,9 +854,9 @@ class TestLinearizer(unittest.TestCase):
       lin = Kernel(sched[0].ast)
       assert sum(u.arg is UnaryOps.RECIP for u in lin.linearize().uops) == max_ops, msg
 
-    a = Tensor.rand((4,4))
-    b = Tensor.rand((4,4))
-    d = Tensor.rand((4,4))
+    a = Tensor.empty((4,4))
+    b = Tensor.empty((4,4))
+    d = Tensor.empty((4,4))
 
     c = (a*b)/b
     helper(c, "found UnaryOps.RECIP in (a*b)/b operation")
 
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 7d7be09f..9ce3593b 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -401,7 +401,7 @@ class Tensor:
     print(Tensor.rand(5).numpy())
     ```
     """
-    Tensor._seed, Tensor._rng_counter = seed, Tensor([0], dtype=dtypes.uint32, requires_grad=False)
+    Tensor._seed, Tensor._rng_counter = seed, None
 
   @staticmethod
   def rand(*shape, device:Optional[Union[Tuple[str, ...], str]]=None, dtype:Optional[DTypeLike]=None, **kwargs):
@@ -417,7 +417,8 @@ class Tensor:
     print(t.numpy())
     ```
     """
-    if Tensor._rng_counter is None: Tensor._rng_counter = Tensor([0], dtype=dtypes.uint32, requires_grad=False)
+    if (had_counter := Tensor._rng_counter is None): Tensor._rng_counter = Tensor([0], dtype=dtypes.uint32, requires_grad=False)
+    if not all(s >= 0 for s in argfix(*shape)): raise ValueError(f"cannot create tensor with negative dimension in {shape=}")
     if not THREEFRY.value:
       # for bfloat16, numpy rand passes buffer in float
       if to_dtype(dtype or dtypes.default_float) == dtypes.bfloat16:
@@ -426,9 +427,9 @@ class Tensor:
 
     # threefry
     if (num := prod((shape:=argfix(*shape)))) == 0: return Tensor.zeros(shape, device=device, dtype=dtype, **kwargs)
-    counts1 = (Tensor.arange(math.ceil(num / 2), device=device, dtype=dtypes.uint32, requires_grad=False)+Tensor._rng_counter.to(device)).realize()
+    if not had_counter: Tensor._rng_counter.assign(Tensor._rng_counter + num)
+    counts1 = (Tensor.arange(math.ceil(num / 2), device=device, dtype=dtypes.uint32, requires_grad=False)+Tensor._rng_counter.to(device))
     counts2 = counts1 + math.ceil(num / 2)
-    Tensor._rng_counter.assign(Tensor._rng_counter + num).realize()
 
     x = counts2.cast(dtypes.uint64) << 32 | counts1.cast(dtypes.uint64)
     x = F.Threefry.apply(*x._broadcasted(Tensor._seed))