seed in tensor (#6869)

This commit is contained in:
wozeparrot 2024-10-06 14:46:58 -04:00 committed by GitHub
parent f9e32f2bb2
commit 9eb6eef441
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 52 additions and 16 deletions

View File

@ -21,24 +21,26 @@ class TestGC(unittest.TestCase):
(a*b).mean().backward()
assert (tensors_allocated() > 0)
del a,b
assert (tensors_allocated() == 1) # one for Tensor._device_rng_counters
assert (tensors_allocated() == 2) # one for Tensor._device_rng_counters, and one for Tensor._device_seeds
Tensor.manual_seed(0)
def test_gc_complex(self):
Tensor.manual_seed(0)
a = Tensor(np.zeros((4, 4), dtype=np.float32), requires_grad=True)
b = Tensor.rand(4, 4, requires_grad=True)
assert (tensors_allocated() == 4)
(a*b).mean().backward()
assert (tensors_allocated() == 5)
(a*b).mean().backward()
assert (tensors_allocated() == 6)
del b
assert (tensors_allocated() == 3)
assert (tensors_allocated() == 4)
b = Tensor(np.zeros((4, 4), dtype=np.float32), requires_grad=True)
print(tensors_allocated())
(a*b).mean().backward()
print(tensors_allocated())
assert (tensors_allocated() == 5)
assert (tensors_allocated() == 6)
del b
assert (tensors_allocated() == 3)
assert (tensors_allocated() == 4)
Tensor.manual_seed(0)
def test_schedule_gc(self):
init = bufs_allocated()

View File

@ -307,6 +307,14 @@ class TestJit(unittest.TestCase):
assert len(res3) == 10, "All values should be different, rand works in jit."
assert res3 != res2, "Jit rand is diff with diff seeds"
def test_jit_random_after_unrealized_random(self):
  """Jitted Tensor.rand() must keep producing fresh values after a discarded rand.

  A plain Tensor.rand() is created and thrown away right after seeding,
  then a TinyJit-wrapped sampler is called three times. The second and
  third samples are required to differ.
  """
  @TinyJit
  def sample(): return Tensor.rand()

  Tensor.manual_seed(1234)
  # Result is intentionally dropped; per the test name this leaves an
  # unrealized random tensor behind before the jitted calls start.
  Tensor.rand()

  draws = []
  for _ in range(3):
    draws.append(sample().numpy())

  # Only calls 2 and 3 are compared (presumably the earlier call(s) are JIT
  # warm-up/capture -- TODO confirm against TinyJit's behavior).
  assert draws[1] != draws[2]
def test_jit_realization_and_sampling(self):
w = Tensor.eye(5)

View File

@ -76,7 +76,7 @@ class TestRandomness(unittest.TestCase):
equal_distribution(lambda *x: Tensor.rand(*x, dtype=dtypes.float16), torch.rand, lambda x: np.random.rand(*x), shape=(2, N, N))
@unittest.skipIf(CI and Device.DEFAULT == "NV", "gpuocelot doesn't support certain ops needed for threefry")
def test_threefly_against_reference(self):
def test_threefry_against_reference(self):
Tensor.manual_seed(1337)
# reference generated using
@ -92,11 +92,11 @@ class TestRandomness(unittest.TestCase):
counts = Tensor.arange(20, dtype=dtypes.uint32)
counts0, counts1 = counts.chunk(2)
r = Tensor._threefry_random_bits(1337, 0, counts0, counts1).numpy()
r = Tensor._threefry_random_bits(1337 << 32, counts0, counts1).numpy()
np.testing.assert_allclose(jr, r)
def test_threefly_against_reference_full(self):
def test_threefry_against_reference_full(self):
Tensor.manual_seed(1337)
# reference generated using
@ -118,7 +118,7 @@ class TestRandomness(unittest.TestCase):
np.testing.assert_allclose(jr, r, atol=1e-5, rtol=1e-5)
@unittest.skipIf(CI and Device.DEFAULT in ("GPU", "CUDA", "METAL", "NV"), "no GPU CI")
def test_threefly_tensors_cnt(self):
def test_threefry_tensors_cnt(self):
Tensor.manual_seed(1337)
Tensor.rand(20).realize()
@ -136,6 +136,31 @@ class TestRandomness(unittest.TestCase):
assert len(Tensor._device_rng_counters) == 0
assert len(Tensor._device_seeds) == 0
@unittest.skipIf(CI and Device.DEFAULT in ("GPU", "CUDA", "METAL", "NV"), "no GPU CI")
def test_threefry_same_kernels(self):
  """Check that repeated Tensor.rand(20) calls produce schedules of equal length.

  The check is done twice: once on the default device and once on a second
  device named f"{Device.DEFAULT}:1". The schedule length on the second
  device is also compared against the default device.

  NOTE(review): differing kernel ASTs are only printed, never asserted, so
  this test fails on a schedule-length mismatch but not on an AST mismatch.
  Left as-is to avoid changing what the test enforces -- confirm intent.
  """
  Tensor.manual_seed(0)

  # Realize one rand up front; presumably this gets the first-use per-device
  # seed/counter setup out of the way before schedules are compared.
  Tensor.rand(1).realize()

  s = Tensor.rand(20).schedule()
  s2 = Tensor.rand(20).schedule()

  assert len(s) == len(s2), f"{len(s)} != {len(s2)}"
  for x,y in zip(s, s2):
    if not (x.ast == y.ast):
      print(f"{x.ast} != {y.ast}")

  # Same procedure on a second device.
  Tensor.rand(1, device=f"{Device.DEFAULT}:1").realize()

  s3 = Tensor.rand(20, device=f"{Device.DEFAULT}:1").schedule()
  s4 = Tensor.rand(20, device=f"{Device.DEFAULT}:1").schedule()

  assert len(s3) == len(s4), f"{len(s3)} != {len(s4)}"
  # Cross-device comparison: the message reports the same two lengths that
  # the condition compares (it previously printed len(s) and len(s3)).
  assert len(s2) == len(s4), f"{len(s2)} != {len(s4)}"
  for x,y in zip(s3, s4):
    if not (x.ast == y.ast):
      print(f"{x.ast} != {y.ast}")
@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), "need bfloat16 support")
def test_rand_bfloat16(self):
N = 128

View File

@ -426,7 +426,7 @@ class Tensor:
return r
_seed: int = int(time.time())
_device_seeds: Dict[str, int] = {}
_device_seeds: Dict[str, Tensor] = {}
_device_rng_counters: Dict[str, Tensor] = {}
@staticmethod
def manual_seed(seed=0):
@ -447,9 +447,8 @@ class Tensor:
Tensor._seed, Tensor._device_seeds, Tensor._device_rng_counters = seed, {}, {}
@staticmethod
def _threefry_random_bits(key0, key1, counts0, counts1):
def _threefry_random_bits(key, counts0, counts1):
x = (counts1.cast(dtypes.uint64) << 32) | counts0.cast(dtypes.uint64)
key = (Tensor([key0], device=x.device, dtype=dtypes.uint64, requires_grad=False) << 32) | key1
x = F.Threefry.apply(*x._broadcasted(key))
counts0, counts1 = (x & 0xffffffff).cast(dtypes.uint32), ((x >> 32) & 0xffffffff).cast(dtypes.uint32)
return counts0.cat(counts1)
@ -478,7 +477,9 @@ class Tensor:
# generate per device seeds and rng counter if we haven't seen this device yet
if device not in Tensor._device_seeds:
Tensor._device_seeds[device] = int.from_bytes(hashlib.sha256(len(Tensor._device_seeds).to_bytes(4, "big")).digest(), "big") & 0xffffffff
Tensor._device_seeds[device] = Tensor([((Tensor._seed & 0xffffffff) << 32) \
| int.from_bytes(hashlib.sha256(len(Tensor._device_seeds).to_bytes(4, "big")).digest(), "big") & 0xffffffff],
device=device, dtype=dtypes.uint64, requires_grad=False)
Tensor._device_rng_counters[device] = Tensor([0], device=device, dtype=dtypes.uint32, requires_grad=False)
had_counter = False
else: had_counter = True
@ -487,12 +488,12 @@ class Tensor:
if (num := ceildiv(((num_ := prod(shape)) * dtype.itemsize), 4)) == 0: return Tensor.zeros(shape, device=_device, dtype=dtype, **kwargs)
# increment rng counter for devices
if had_counter: Tensor._device_rng_counters[device].assign(Tensor._device_rng_counters[device] + num)
if had_counter: Tensor._device_rng_counters[device].assign(Tensor._device_rng_counters[device] + num).contiguous()
# threefry random bits
counts0 = (Tensor.arange(ceildiv(num, 2), device=device, dtype=dtypes.uint32, requires_grad=False)+Tensor._device_rng_counters[device])
counts1 = counts0 + ceildiv(num, 2)
bits = Tensor._threefry_random_bits(Tensor._seed, Tensor._device_seeds[device], counts0, counts1)[:num]
bits = Tensor._threefry_random_bits(Tensor._device_seeds[device], counts0, counts1)[:num]
# bitcast to uint with same number of bits
_, nmant = dtypes.finfo(dtype)