mirror of https://github.com/commaai/tinygrad.git
seed in tensor (#6869)
This commit is contained in:
parent
f9e32f2bb2
commit
9eb6eef441
|
@ -21,24 +21,26 @@ class TestGC(unittest.TestCase):
|
|||
(a*b).mean().backward()
|
||||
assert (tensors_allocated() > 0)
|
||||
del a,b
|
||||
assert (tensors_allocated() == 1) # one for Tensor._device_rng_counters
|
||||
assert (tensors_allocated() == 2) # one for Tensor._device_rng_counters, and one for Tensor._device_seeds
|
||||
Tensor.manual_seed(0)
|
||||
|
||||
def test_gc_complex(self):
|
||||
Tensor.manual_seed(0)
|
||||
a = Tensor(np.zeros((4, 4), dtype=np.float32), requires_grad=True)
|
||||
b = Tensor.rand(4, 4, requires_grad=True)
|
||||
assert (tensors_allocated() == 4)
|
||||
(a*b).mean().backward()
|
||||
assert (tensors_allocated() == 5)
|
||||
(a*b).mean().backward()
|
||||
assert (tensors_allocated() == 6)
|
||||
del b
|
||||
assert (tensors_allocated() == 3)
|
||||
assert (tensors_allocated() == 4)
|
||||
b = Tensor(np.zeros((4, 4), dtype=np.float32), requires_grad=True)
|
||||
print(tensors_allocated())
|
||||
(a*b).mean().backward()
|
||||
print(tensors_allocated())
|
||||
assert (tensors_allocated() == 5)
|
||||
assert (tensors_allocated() == 6)
|
||||
del b
|
||||
assert (tensors_allocated() == 3)
|
||||
assert (tensors_allocated() == 4)
|
||||
Tensor.manual_seed(0)
|
||||
|
||||
def test_schedule_gc(self):
|
||||
init = bufs_allocated()
|
||||
|
|
|
@ -307,6 +307,14 @@ class TestJit(unittest.TestCase):
|
|||
assert len(res3) == 10, "All values should be different, rand works in jit."
|
||||
assert res3 != res2, "Jit rand is diff with diff seeds"
|
||||
|
||||
def test_jit_random_after_unrealized_random(self):
|
||||
@TinyJit
|
||||
def f(): return Tensor.rand()
|
||||
Tensor.manual_seed(1234)
|
||||
Tensor.rand()
|
||||
res = [f().numpy() for _ in range(3)]
|
||||
assert res[1] != res[2]
|
||||
|
||||
def test_jit_realization_and_sampling(self):
|
||||
w = Tensor.eye(5)
|
||||
|
||||
|
|
|
@ -76,7 +76,7 @@ class TestRandomness(unittest.TestCase):
|
|||
equal_distribution(lambda *x: Tensor.rand(*x, dtype=dtypes.float16), torch.rand, lambda x: np.random.rand(*x), shape=(2, N, N))
|
||||
|
||||
@unittest.skipIf(CI and Device.DEFAULT == "NV", "gpuocelot doesn't support certain ops needed for threefry")
|
||||
def test_threefly_against_reference(self):
|
||||
def test_threefry_against_reference(self):
|
||||
Tensor.manual_seed(1337)
|
||||
|
||||
# reference generated using
|
||||
|
@ -92,11 +92,11 @@ class TestRandomness(unittest.TestCase):
|
|||
|
||||
counts = Tensor.arange(20, dtype=dtypes.uint32)
|
||||
counts0, counts1 = counts.chunk(2)
|
||||
r = Tensor._threefry_random_bits(1337, 0, counts0, counts1).numpy()
|
||||
r = Tensor._threefry_random_bits(1337 << 32, counts0, counts1).numpy()
|
||||
|
||||
np.testing.assert_allclose(jr, r)
|
||||
|
||||
def test_threefly_against_reference_full(self):
|
||||
def test_threefry_against_reference_full(self):
|
||||
Tensor.manual_seed(1337)
|
||||
|
||||
# reference generated using
|
||||
|
@ -118,7 +118,7 @@ class TestRandomness(unittest.TestCase):
|
|||
np.testing.assert_allclose(jr, r, atol=1e-5, rtol=1e-5)
|
||||
|
||||
@unittest.skipIf(CI and Device.DEFAULT in ("GPU", "CUDA", "METAL", "NV"), "no GPU CI")
|
||||
def test_threefly_tensors_cnt(self):
|
||||
def test_threefry_tensors_cnt(self):
|
||||
Tensor.manual_seed(1337)
|
||||
|
||||
Tensor.rand(20).realize()
|
||||
|
@ -136,6 +136,31 @@ class TestRandomness(unittest.TestCase):
|
|||
assert len(Tensor._device_rng_counters) == 0
|
||||
assert len(Tensor._device_seeds) == 0
|
||||
|
||||
@unittest.skipIf(CI and Device.DEFAULT in ("GPU", "CUDA", "METAL", "NV"), "no GPU CI")
|
||||
def test_threefry_same_kernels(self):
|
||||
Tensor.manual_seed(0)
|
||||
|
||||
Tensor.rand(1).realize()
|
||||
|
||||
s = Tensor.rand(20).schedule()
|
||||
s2 = Tensor.rand(20).schedule()
|
||||
|
||||
assert len(s) == len(s2), f"{len(s)} != {len(s2)}"
|
||||
for x,y in zip(s, s2):
|
||||
if not (x.ast == y.ast):
|
||||
print(f"{x.ast} != {y.ast}")
|
||||
|
||||
Tensor.rand(1, device=f"{Device.DEFAULT}:1").realize()
|
||||
|
||||
s3 = Tensor.rand(20, device=f"{Device.DEFAULT}:1").schedule()
|
||||
s4 = Tensor.rand(20, device=f"{Device.DEFAULT}:1").schedule()
|
||||
|
||||
assert len(s3) == len(s4), f"{len(s3)} != {len(s4)}"
|
||||
assert len(s2) == len(s4), f"{len(s)} != {len(s3)}"
|
||||
for x,y in zip(s3, s4):
|
||||
if not (x.ast == y.ast):
|
||||
print(f"{x.ast} != {y.ast}")
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), "need bfloat16 support")
|
||||
def test_rand_bfloat16(self):
|
||||
N = 128
|
||||
|
|
|
@ -426,7 +426,7 @@ class Tensor:
|
|||
return r
|
||||
|
||||
_seed: int = int(time.time())
|
||||
_device_seeds: Dict[str, int] = {}
|
||||
_device_seeds: Dict[str, Tensor] = {}
|
||||
_device_rng_counters: Dict[str, Tensor] = {}
|
||||
@staticmethod
|
||||
def manual_seed(seed=0):
|
||||
|
@ -447,9 +447,8 @@ class Tensor:
|
|||
Tensor._seed, Tensor._device_seeds, Tensor._device_rng_counters = seed, {}, {}
|
||||
|
||||
@staticmethod
|
||||
def _threefry_random_bits(key0, key1, counts0, counts1):
|
||||
def _threefry_random_bits(key, counts0, counts1):
|
||||
x = (counts1.cast(dtypes.uint64) << 32) | counts0.cast(dtypes.uint64)
|
||||
key = (Tensor([key0], device=x.device, dtype=dtypes.uint64, requires_grad=False) << 32) | key1
|
||||
x = F.Threefry.apply(*x._broadcasted(key))
|
||||
counts0, counts1 = (x & 0xffffffff).cast(dtypes.uint32), ((x >> 32) & 0xffffffff).cast(dtypes.uint32)
|
||||
return counts0.cat(counts1)
|
||||
|
@ -478,7 +477,9 @@ class Tensor:
|
|||
|
||||
# generate per device seeds and rng counter if we haven't seen this device yet
|
||||
if device not in Tensor._device_seeds:
|
||||
Tensor._device_seeds[device] = int.from_bytes(hashlib.sha256(len(Tensor._device_seeds).to_bytes(4, "big")).digest(), "big") & 0xffffffff
|
||||
Tensor._device_seeds[device] = Tensor([((Tensor._seed & 0xffffffff) << 32) \
|
||||
| int.from_bytes(hashlib.sha256(len(Tensor._device_seeds).to_bytes(4, "big")).digest(), "big") & 0xffffffff],
|
||||
device=device, dtype=dtypes.uint64, requires_grad=False)
|
||||
Tensor._device_rng_counters[device] = Tensor([0], device=device, dtype=dtypes.uint32, requires_grad=False)
|
||||
had_counter = False
|
||||
else: had_counter = True
|
||||
|
@ -487,12 +488,12 @@ class Tensor:
|
|||
if (num := ceildiv(((num_ := prod(shape)) * dtype.itemsize), 4)) == 0: return Tensor.zeros(shape, device=_device, dtype=dtype, **kwargs)
|
||||
|
||||
# increment rng counter for devices
|
||||
if had_counter: Tensor._device_rng_counters[device].assign(Tensor._device_rng_counters[device] + num)
|
||||
if had_counter: Tensor._device_rng_counters[device].assign(Tensor._device_rng_counters[device] + num).contiguous()
|
||||
|
||||
# threefry random bits
|
||||
counts0 = (Tensor.arange(ceildiv(num, 2), device=device, dtype=dtypes.uint32, requires_grad=False)+Tensor._device_rng_counters[device])
|
||||
counts1 = counts0 + ceildiv(num, 2)
|
||||
bits = Tensor._threefry_random_bits(Tensor._seed, Tensor._device_seeds[device], counts0, counts1)[:num]
|
||||
bits = Tensor._threefry_random_bits(Tensor._device_seeds[device], counts0, counts1)[:num]
|
||||
|
||||
# bitcast to uint with same number of bits
|
||||
_, nmant = dtypes.finfo(dtype)
|
||||
|
|
Loading…
Reference in New Issue