Store the per-device threefry key as a uint64 Tensor.

Tensor._device_seeds now maps each device to a one-element uint64 Tensor whose
high 32 bits are the global seed (Tensor._seed & 0xffffffff) and whose low
32 bits are the per-device sha256-derived value. Tensor._threefry_random_bits
takes that single `key` instead of separate key0/key1 arguments.

Tests: the test_threefly_* cases are renamed to test_threefry_*, the GC tensor
counts grow by one to account for the seed Tensor, and two tests are added
(jit rand after an unrealized rand; comparing the schedules of repeated rand
calls on the default device and on device ":1").

diff --git a/test/test_gc.py b/test/test_gc.py
index 37e632ac..8b7979e9 100644
--- a/test/test_gc.py
+++ b/test/test_gc.py
@@ -21,24 +21,26 @@ class TestGC(unittest.TestCase):
     (a*b).mean().backward()
     assert (tensors_allocated() > 0)
     del a,b
-    assert (tensors_allocated() == 1) # one for Tensor._device_rng_counters
+    assert (tensors_allocated() == 2) # one for Tensor._device_rng_counters, and one for Tensor._device_seeds
+    Tensor.manual_seed(0)
 
   def test_gc_complex(self):
     Tensor.manual_seed(0)
     a = Tensor(np.zeros((4, 4), dtype=np.float32), requires_grad=True)
     b = Tensor.rand(4, 4, requires_grad=True)
-    assert (tensors_allocated() == 4)
-    (a*b).mean().backward()
     assert (tensors_allocated() == 5)
+    (a*b).mean().backward()
+    assert (tensors_allocated() == 6)
     del b
-    assert (tensors_allocated() == 3)
+    assert (tensors_allocated() == 4)
     b = Tensor(np.zeros((4, 4), dtype=np.float32), requires_grad=True)
     print(tensors_allocated())
     (a*b).mean().backward()
     print(tensors_allocated())
-    assert (tensors_allocated() == 5)
+    assert (tensors_allocated() == 6)
     del b
-    assert (tensors_allocated() == 3)
+    assert (tensors_allocated() == 4)
+    Tensor.manual_seed(0)
 
   def test_schedule_gc(self):
     init = bufs_allocated()
diff --git a/test/test_jit.py b/test/test_jit.py
index 6cb13409..6988a5ee 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -307,6 +307,14 @@ class TestJit(unittest.TestCase):
     assert len(res3) == 10, "All values should be different, rand works in jit."
     assert res3 != res2, "Jit rand is diff with diff seeds"
 
+  def test_jit_random_after_unrealized_random(self):
+    @TinyJit
+    def f(): return Tensor.rand()
+    Tensor.manual_seed(1234)
+    Tensor.rand()
+    res = [f().numpy() for _ in range(3)]
+    assert res[1] != res[2]
+
   def test_jit_realization_and_sampling(self):
     w = Tensor.eye(5)
 
diff --git a/test/test_randomness.py b/test/test_randomness.py
index 43c6ea59..daa5effe 100644
--- a/test/test_randomness.py
+++ b/test/test_randomness.py
@@ -76,7 +76,7 @@ class TestRandomness(unittest.TestCase):
     equal_distribution(lambda *x: Tensor.rand(*x, dtype=dtypes.float16), torch.rand, lambda x: np.random.rand(*x), shape=(2, N, N))
 
   @unittest.skipIf(CI and Device.DEFAULT == "NV", "gpuocelot doesn't support certain ops needed for threefry")
-  def test_threefly_against_reference(self):
+  def test_threefry_against_reference(self):
     Tensor.manual_seed(1337)
 
     # reference generated using
@@ -92,11 +92,11 @@ class TestRandomness(unittest.TestCase):
 
     counts = Tensor.arange(20, dtype=dtypes.uint32)
     counts0, counts1 = counts.chunk(2)
-    r = Tensor._threefry_random_bits(1337, 0, counts0, counts1).numpy()
+    r = Tensor._threefry_random_bits(1337 << 32, counts0, counts1).numpy()
 
     np.testing.assert_allclose(jr, r)
 
-  def test_threefly_against_reference_full(self):
+  def test_threefry_against_reference_full(self):
     Tensor.manual_seed(1337)
 
     # reference generated using
@@ -118,7 +118,7 @@ class TestRandomness(unittest.TestCase):
     np.testing.assert_allclose(jr, r, atol=1e-5, rtol=1e-5)
 
   @unittest.skipIf(CI and Device.DEFAULT in ("GPU", "CUDA", "METAL", "NV"), "no GPU CI")
-  def test_threefly_tensors_cnt(self):
+  def test_threefry_tensors_cnt(self):
     Tensor.manual_seed(1337)
 
     Tensor.rand(20).realize()
@@ -136,6 +136,31 @@ class TestRandomness(unittest.TestCase):
     assert len(Tensor._device_rng_counters) == 0
     assert len(Tensor._device_seeds) == 0
 
+  @unittest.skipIf(CI and Device.DEFAULT in ("GPU", "CUDA", "METAL", "NV"), "no GPU CI")
+  def test_threefry_same_kernels(self):
+    Tensor.manual_seed(0)
+
+    Tensor.rand(1).realize()
+
+    s = Tensor.rand(20).schedule()
+    s2 = Tensor.rand(20).schedule()
+
+    assert len(s) == len(s2), f"{len(s)} != {len(s2)}"
+    for x,y in zip(s, s2):
+      if not (x.ast == y.ast):
+        print(f"{x.ast} != {y.ast}")
+
+    Tensor.rand(1, device=f"{Device.DEFAULT}:1").realize()
+
+    s3 = Tensor.rand(20, device=f"{Device.DEFAULT}:1").schedule()
+    s4 = Tensor.rand(20, device=f"{Device.DEFAULT}:1").schedule()
+
+    assert len(s3) == len(s4), f"{len(s3)} != {len(s4)}"
+    assert len(s2) == len(s4), f"{len(s2)} != {len(s4)}"
+    for x,y in zip(s3, s4):
+      if not (x.ast == y.ast):
+        print(f"{x.ast} != {y.ast}")
+
   @unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), "need bfloat16 support")
   def test_rand_bfloat16(self):
     N = 128
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 507fcc25..0adbb557 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -426,7 +426,7 @@ class Tensor:
     return r
 
   _seed: int = int(time.time())
-  _device_seeds: Dict[str, int] = {}
+  _device_seeds: Dict[str, Tensor] = {}
   _device_rng_counters: Dict[str, Tensor] = {}
   @staticmethod
   def manual_seed(seed=0):
@@ -447,9 +447,8 @@ class Tensor:
     Tensor._seed, Tensor._device_seeds, Tensor._device_rng_counters = seed, {}, {}
 
   @staticmethod
-  def _threefry_random_bits(key0, key1, counts0, counts1):
+  def _threefry_random_bits(key, counts0, counts1):
     x = (counts1.cast(dtypes.uint64) << 32) | counts0.cast(dtypes.uint64)
-    key = (Tensor([key0], device=x.device, dtype=dtypes.uint64, requires_grad=False) << 32) | key1
     x = F.Threefry.apply(*x._broadcasted(key))
     counts0, counts1 = (x & 0xffffffff).cast(dtypes.uint32), ((x >> 32) & 0xffffffff).cast(dtypes.uint32)
     return counts0.cat(counts1)
@@ -478,7 +477,9 @@ class Tensor:
 
     # generate per device seeds and rng counter if we haven't seen this device yet
     if device not in Tensor._device_seeds:
-      Tensor._device_seeds[device] = int.from_bytes(hashlib.sha256(len(Tensor._device_seeds).to_bytes(4, "big")).digest(), "big") & 0xffffffff
+      Tensor._device_seeds[device] = Tensor([((Tensor._seed & 0xffffffff) << 32) \
+        | int.from_bytes(hashlib.sha256(len(Tensor._device_seeds).to_bytes(4, "big")).digest(), "big") & 0xffffffff],
+        device=device, dtype=dtypes.uint64, requires_grad=False)
       Tensor._device_rng_counters[device] = Tensor([0], device=device, dtype=dtypes.uint32, requires_grad=False)
       had_counter = False
     else: had_counter = True
@@ -487,12 +488,12 @@ class Tensor:
     if (num := ceildiv(((num_ := prod(shape)) * dtype.itemsize), 4)) == 0: return Tensor.zeros(shape, device=_device, dtype=dtype, **kwargs)
 
     # increment rng counter for devices
-    if had_counter: Tensor._device_rng_counters[device].assign(Tensor._device_rng_counters[device] + num)
+    if had_counter: Tensor._device_rng_counters[device].assign(Tensor._device_rng_counters[device] + num).contiguous()
 
     # threefry random bits
     counts0 = (Tensor.arange(ceildiv(num, 2), device=device, dtype=dtypes.uint32, requires_grad=False)+Tensor._device_rng_counters[device])
     counts1 = counts0 + ceildiv(num, 2)
-    bits = Tensor._threefry_random_bits(Tensor._seed, Tensor._device_seeds[device], counts0, counts1)[:num]
+    bits = Tensor._threefry_random_bits(Tensor._device_seeds[device], counts0, counts1)[:num]
 
     # bitcast to uint with same number of bits
     _, nmant = dtypes.finfo(dtype)