diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 250cedf0..0f90ba74 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -228,6 +228,9 @@ jobs: - if: ${{ matrix.task == 'onnx' }} name: Test MLPerf optimizers run: GPU=1 python -m pytest -n=auto test/external/external_test_optim.py --durations=20 + - if: ${{ matrix.task == 'onnx' }} + name: Test THREEFRY + run: PYTHONPATH=. THREEFRY=1 GPU=1 python3 -m pytest test/test_randomness.py test/test_jit.py #testwebgpu: # name: WebGPU Tests diff --git a/extra/optimization/test_beam_search.py b/extra/optimization/test_beam_search.py index 14d8c356..300d2bc6 100644 --- a/extra/optimization/test_beam_search.py +++ b/extra/optimization/test_beam_search.py @@ -6,6 +6,10 @@ from tinygrad.shape.symbolic import Variable from tinygrad.tensor import Tensor from tinygrad.nn import Conv2d +def rand(*shape): + if CI: return Tensor(np.random.rand(*shape)) + return Tensor.rand(*shape) + class TestBeamSearch(unittest.TestCase): def setUp(self): self.old_beam = BEAM.value @@ -14,44 +18,44 @@ class TestBeamSearch(unittest.TestCase): BEAM.value = self.old_beam def test_variable_ast_beam(self): - a = Tensor.rand(3, 3).reshape((Variable("a", 1, 10).bind(3), 3)) + a = rand(3, 3).reshape((Variable("a", 1, 10).bind(3), 3)) a = (a+1).realize() - def test_big_prime_number_matmul(self): - a = Tensor.rand(367, 367) - b = Tensor.rand(367, 367) + def test_big_prime_number(self): + a = rand(367, 367) + b = rand(367, 367) c = (a@b).realize() np.testing.assert_allclose(c.numpy(), a.numpy() @ b.numpy(), atol=1e-4, rtol=1e-4) def test_big_prime_number_max(self): - a = -Tensor.rand(367, 367) - b = Tensor.rand(367, 367) + a = -rand(367, 367) + b = rand(367, 367) # if incorrectly padded 0, the max would be 0 instead of a negative number c = (a*b).max(1) np.testing.assert_allclose(c.numpy(), (a.numpy() * b.numpy()).max(1), atol=1e-4, rtol=1e-4) def test_big_prime_number_sum(self): - a = Tensor.rand(367, 367) - b = Tensor.rand(367, 367) + a = rand(367, 367) + b = rand(367, 367) # if incorrectly padded 0, the sum would be inf c = (a/b).sum(1).realize() np.testing.assert_allclose(c.numpy(), (a.numpy() / b.numpy()).sum(1), atol=1e-4, rtol=1e-4) def test_variable_big_prime_number(self): v = Variable("v", 1, 400).bind(367) - a = Tensor.rand(367, 367) - b = Tensor.rand(367, 367) + a = rand(367, 367) + b = rand(367, 367) c = (a.reshape(367, v) @ b.reshape(v, 367)).realize() np.testing.assert_allclose(c.numpy(), a.numpy() @ b.numpy(), atol=1e-4, rtol=1e-4) def test_variable_shrink_prime_number(self): v = Variable("v", 1, 400).bind(367) - a = Tensor.rand(400, 367) + a = rand(400, 367) b = (a.shrink(((0,v), None))+1).reshape(367,367).realize() np.testing.assert_allclose(b.numpy(), a.numpy()[:367]+1, atol=1e-4, rtol=1e-4) def test_no_mutate_rawbuffers(self): - a = Tensor.rand(3, 3).realize() + a = rand(3, 3).realize() desired = a.numpy() + 1 a.assign(a+1) actual = a.numpy() @@ -60,7 +64,7 @@ class TestBeamSearch(unittest.TestCase): @unittest.skipIf(CI, "flaky. CL_OUT_OF_RESOURCES") def test_conv_beam(self): c = Conv2d(3, 16, (3,3)) - x = Tensor.rand(1,3,32,32) + x = rand(1,3,32,32) with Timing(): c(x).realize() diff --git a/test/helpers.py b/test/helpers.py index ab198d5d..74392baf 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -18,6 +18,8 @@ def assert_jit_cache_len(fxn, expected_len): assert len(fxn.jit_cache) == expected_len else: assert len(fxn.jit_cache) == 1 + # until we have a better way of typing the prg in JitItem + assert type(fxn.jit_cache[0].prg).__name__.endswith('Graph') assert len(fxn.jit_cache[0].prg.jit_cache) == expected_len def is_dtype_supported(dtype: DType, device: str = Device.DEFAULT): @@ -33,4 +35,4 @@ def is_dtype_supported(dtype: DType, device: str = Device.DEFAULT): if device in ["GPU", "LLVM", "CUDA"]: return not CI if device == "PYTHON": return sys.version_info >= (3, 12) if dtype == dtypes.float64: return device != "METAL" and not (OSX and device == "GPU") - return True \ No newline at end of file + return True diff --git a/test/test_gc.py b/test/test_gc.py index 49773dd9..357c41dc 100644 --- a/test/test_gc.py +++ b/test/test_gc.py @@ -10,28 +10,28 @@ def tensors_allocated(): class TestGC(unittest.TestCase): def test_gc(self): - a = Tensor.zeros(4, 4, requires_grad=True) + a = Tensor.rand(4, 4, requires_grad=True) b = Tensor.zeros(4, 4, requires_grad=True) (a*b).mean().backward() assert(tensors_allocated() > 0) del a,b - assert(tensors_allocated() == 0) + assert(tensors_allocated() == 1) # one for Tensor._rng_counter def test_gc_complex(self): a = Tensor(np.zeros((4, 4), dtype=np.float32), requires_grad=True) - b = Tensor(np.zeros((4, 4), dtype=np.float32), requires_grad=True) - assert(tensors_allocated() == 2) + b = Tensor.rand(4, 4, requires_grad=True) + assert(tensors_allocated() == 3) (a*b).mean().backward() - assert(tensors_allocated() == 4) + assert(tensors_allocated() == 5) del b - assert(tensors_allocated() == 2) + assert(tensors_allocated() == 3) b = Tensor(np.zeros((4, 4), dtype=np.float32), requires_grad=True) print(tensors_allocated()) (a*b).mean().backward() print(tensors_allocated()) - assert(tensors_allocated() == 4) + assert(tensors_allocated() == 5) del b - assert(tensors_allocated() == 2) + assert(tensors_allocated() == 3) if __name__ == '__main__': unittest.main() diff --git a/test/test_jit.py b/test/test_jit.py index b900e1eb..a28433c4 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -182,7 +182,7 @@ class TestJit(unittest.TestCase): a = Tensor.randn(10, 10).realize() # realize these before resetting the random seed b = Tensor.randn(10, 10).realize() - Tensor._seed = 1234 + Tensor.manual_seed(1234) jf = TinyJit(f) res = set() for _ in range(5): @@ -190,7 +190,7 @@ class TestJit(unittest.TestCase): res.add(o1.numpy()[0][0]) assert len(res) == 5, "All values should be different, rand works in jit." - Tensor._seed = 1234 + Tensor.manual_seed(1234) jf2 = TinyJit(f) res2 = set() for _ in range(5): @@ -199,7 +199,7 @@ class TestJit(unittest.TestCase): assert len(res2) == 5, "All values should be different, rand works in jit." assert res == res2, "Jit rand is not reproducible with the same seed" - Tensor._seed = 3421 + Tensor.manual_seed(3421) jf3 = TinyJit(f) res3 = set() for _ in range(5): diff --git a/test/test_linearizer.py b/test/test_linearizer.py index 7ab58c12..e9e6e88b 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -199,7 +199,7 @@ class TestLinearizer(unittest.TestCase): np.testing.assert_allclose(np_c, r.numpy(), atol=tc_atol, rtol=tc_rtol) def test_limit_dims_to_max_5d_global(self): - t = Tensor.rand(3, 4, 5, 6, 7).pad(((1, 1), (1, 1), (1, 1), (1, 1), (1, 1))) + 1 + t = Tensor.empty(3, 4, 5, 6, 7).pad(((1, 1), (1, 1), (1, 1), (1, 1), (1, 1))) + 1 sched = [si for si in create_schedule([t.lazydata]) if si.ast[0].op not in LoadOps] assert len(sched) == 1 lin = Linearizer(*sched[0].ast) @@ -740,7 +740,7 @@ class TestLinearizerOpts(unittest.TestCase): def test_padto_where(self): N = 17 * 17 - a = (Tensor.rand(N, N).max(axis=0, keepdim=True) > 1).where(1, 0) + a = (Tensor.empty(N, N).max(axis=0, keepdim=True) > 1).where(1, 0) helper_linearizer_opt(a.max(0), [ [Opt(OptOps.PADTO, 0, 32)], [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),], diff --git a/test/test_randomness.py b/test/test_randomness.py index 2d04b446..d0864420 100644 --- a/test/test_randomness.py +++ b/test/test_randomness.py @@ -5,6 +5,7 @@ from functools import partial import numpy as np import torch from tinygrad import nn, dtypes, Tensor +from tinygrad.helpers import THREEFRY from test.helpers import is_dtype_supported # https://gist.github.com/devries/11405101 @@ -41,7 +42,7 @@ def kstest(l1, l2): prob = ksprob((nesq + 0.12 + 0.11 / nesq) * d) return prob -def equal_distribution(tiny_func, torch_func=None, numpy_func=None, shape=(20, 23), alpha=0.05): +def equal_distribution(tiny_func, torch_func=None, numpy_func=None, shape=(20, 23), alpha=0.04): Tensor.manual_seed(1337) torch.manual_seed(1337) np.random.seed(1337) @@ -60,6 +61,7 @@ class TestRandomness(unittest.TestCase): self.assertFalse(normal_test(Tensor.rand)) self.assertTrue(equal_distribution(Tensor.rand, torch.rand, lambda x: np.random.rand(*x))) + @unittest.skipIf(THREEFRY.value, "broken with threefry") def test_rand_half(self): N = 128 x = Tensor.rand((2, N, N), dtype=dtypes.half) @@ -71,6 +73,16 @@ class TestRandomness(unittest.TestCase): self.assertTrue(zeros.size > 0) equal_distribution(lambda *x: Tensor.rand(*x, dtype=dtypes.float16), torch.rand, lambda x: np.random.rand(*x), shape=(2, N, N)) + @unittest.skipIf(not THREEFRY.value, "not using threefry") + def test_threefly_against_reference(self): + Tensor.manual_seed(1337) + # generated using + # (jax.extend.random.threefry_2x32((np.uint32(1337), np.uint32(0x0)), np.arange(20, dtype=np.uint32)) >> 8).astype(float) / np.float32(2**24) + jr = np.array([0.30984968, 0.42723763, 0.92448753, 0.27268296, 0.48820806, 0.29587173, 0.3213513, 0.05805135, 0.4954177, 0.23303074, + 0.62478125, 0.51861334, 0.24712527, 0.12718695, 0.5236074, 0.50704265, 0.9166272, 0.6918763, 0.6530086, 0.34640658]) + r = Tensor.rand(20).numpy() + np.testing.assert_allclose(jr, r, atol=1e-5, rtol=1e-5) + @unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), "need bfloat16 support") def test_rand_bfloat16(self): N = 128 @@ -115,16 +127,10 @@ class TestRandomness(unittest.TestCase): lambda x: np.random.uniform(-1, 1, size=x) * math.sqrt(6 / (x[0] + math.prod(x[1:]))))) def test_kaiming_uniform(self): - Tensor.manual_seed(1337) - torch.manual_seed(1337) - np.random.seed(1337) for shape in [(128, 64, 3, 3), (20, 24)]: self.assertTrue(equal_distribution(Tensor.kaiming_uniform, lambda x: torch.nn.init.kaiming_uniform_(torch.empty(x)), shape=shape)) def test_kaiming_normal(self): - Tensor.manual_seed(1337) - torch.manual_seed(1337) - np.random.seed(1337) for shape in [(128, 64, 3, 3), (20, 24)]: self.assertTrue(equal_distribution(Tensor.kaiming_normal, lambda x: torch.nn.init.kaiming_normal_(torch.empty(x)), shape=shape)) diff --git a/test/test_schedule.py b/test/test_schedule.py index a3c65904..609a6db8 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -181,7 +181,7 @@ class TestSchedule(unittest.TestCase): # run img = Tensor.rand(2,3,64,64) out = c1(img).elu() - check_schedule(out, 1, [c1.weight, c1.bias]) + check_schedule(out, 1, [c1.weight, c1.bias, img]) def test_two_sum(self): img = Tensor.empty(64,64) @@ -336,7 +336,7 @@ class TestSchedule(unittest.TestCase): out = bn1(conv1(x)).relu() out = bn2(conv2(out)) out = (out + x).relu() - check_schedule(out, 4) + check_schedule(out, 2, [conv1.weight, conv2.weight]) def test_contiguous_while_contiguous(self): x = Tensor.empty(1, 64, 32, 32) diff --git a/test/test_tensor.py b/test/test_tensor.py index 05097ae2..6da4d176 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -378,17 +378,17 @@ class TestMoveTensor(unittest.TestCase): class TestZeroShapeTensor(unittest.TestCase): def test_shape_stride(self): - t = Tensor.rand(3, 2, 0) + t = Tensor.empty(3, 2, 0) assert t.shape == (3, 2, 0) # numpy has stride 0, 0, 0; torch has stride 2, 1, 1 assert t.lazydata.st.real_strides() == (0, 0, 1) - t = Tensor.rand(3, 0, 2) + t = Tensor.empty(3, 0, 2) assert t.shape == (3, 0, 2) # numpy has stride 0, 0, 0; torch has stride 2, 2, 1 assert t.lazydata.st.real_strides() == (0, 2, 1) - t = Tensor.rand(0, 0, 0) + t = Tensor.empty(0, 0, 0) assert t.shape == (0, 0, 0) # numpy has stride 0, 0, 0; torch has stride 1, 1, 1 assert t.lazydata.st.real_strides() == (0, 0, 1) diff --git a/test/unit/test_disk_tensor.py b/test/unit/test_disk_tensor.py index 50e08c28..0e95d98c 100644 --- a/test/unit/test_disk_tensor.py +++ b/test/unit/test_disk_tensor.py @@ -120,7 +120,7 @@ class TestSafetensors(unittest.TestCase): for dtype in dtypes.fields().values(): if dtype in [dtypes.bfloat16]: continue # not supported in numpy path = temp(f"ones.{dtype}.safetensors") - ones = Tensor.rand((10,10), dtype=dtype) + ones = Tensor(np.random.rand(10,10), dtype=dtype) safe_save(get_state_dict(ones), path) np.testing.assert_equal(ones.numpy(), list(safe_load(path).values())[0].numpy()) diff --git a/tinygrad/features/search.py b/tinygrad/features/search.py index e029253a..b4d9a854 100644 --- a/tinygrad/features/search.py +++ b/tinygrad/features/search.py @@ -41,7 +41,7 @@ def _time_program(variables:List[Variable], rdev:Compiled, lib:bytes, global_siz tms = [] for _ in range(cnt): if clear_l2: - with Context(DEBUG=0): Tensor.rand(1024,1024).realize() + with Context(DEBUG=0, BEAM=0): Tensor.ones(1024,1024).contiguous().realize() tms.append(cast(float, car(rawbufs, var_vals, wait=True, do_update_stats=False))*factor) if early_stop is not None and early_stop < tms[-1]: break return tms diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index 47bdff2d..9465deb4 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -94,7 +94,8 @@ class ContextVar: def __gt__(self, x): return self.value > x def __lt__(self, x): return self.value < x -DEBUG, IMAGE, WINO, BEAM, NOOPT = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0), ContextVar("WINO", 0), ContextVar("BEAM", 0), ContextVar("NOOPT", 0) +DEBUG, IMAGE, BEAM, NOOPT = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0), ContextVar("BEAM", 0), ContextVar("NOOPT", 0) +WINO, THREEFRY = ContextVar("WINO", 0), ContextVar("THREEFRY", 0) GRAPH, GRAPHPATH = ContextVar("GRAPH", 0), getenv("GRAPHPATH", "/tmp/net") class Timing(contextlib.ContextDecorator): diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 034368d8..b7d2b449 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -7,11 +7,12 @@ from collections import defaultdict import numpy as np from tinygrad.dtype import DType, dtypes, ImageDType, Scalar, least_upper_float, least_upper_dtype, cast_scalar -from tinygrad.helpers import argfix, make_pair, IMAGE, DEBUG, WINO, flatten, prod, all_int, round_up, merge_dicts, fully_flatten, flat_mv +from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up, merge_dicts, fully_flatten, flat_mv +from tinygrad.helpers import IMAGE, DEBUG, WINO, THREEFRY from tinygrad.lazy import LazyBuffer from tinygrad.features.multi import MultiLazyBuffer from tinygrad.ops import LoadOps -from tinygrad.device import Device, Buffer +from tinygrad.device import Buffer, Device from tinygrad.shape.symbolic import sint from tinygrad.realize import run_schedule, create_schedule @@ -212,7 +213,7 @@ class Tensor: # ***** creation llop entrypoint ***** @staticmethod - def _loadop(op, shape, device:Optional[Union[Tuple[str], str]]=None, dtype:Optional[DType]=None, arg=None, **kwargs): + def _loadop(op, shape, device:Optional[Union[Tuple[str, ...], str]]=None, dtype:Optional[DType]=None, arg=None, **kwargs): if isinstance(device, tuple): return Tensor(MultiLazyBuffer([LazyBuffer.loadop(op, shape, dtype or dtypes.default_float, Device.canonicalize(d), arg) \ for d in device], None), device, dtype, **kwargs) @@ -222,16 +223,34 @@ class Tensor: def empty(*shape, **kwargs): return Tensor._loadop(LoadOps.EMPTY, argfix(*shape), **kwargs) _seed: int = int(time.time()) + _rng_counter: Optional[Tensor] = None @staticmethod - def manual_seed(seed=0): Tensor._seed = seed + def manual_seed(seed=0): Tensor._seed, Tensor._rng_counter = seed, Tensor([0], dtype=dtypes.uint32, requires_grad=False) @staticmethod - def rand(*shape, **kwargs): - if kwargs.get("dtype") == dtypes.bfloat16: - # TODO: remove this once we use threefry for rand. - kwargs.pop("dtype") - return Tensor.rand(*shape, **kwargs, dtype=dtypes.float).cast(dtypes.bfloat16) - return Tensor._loadop(LoadOps.CUSTOM, argfix(*shape), arg=custom_random, **kwargs) + def rand(*shape, device:Optional[Union[Tuple[str, ...], str]]=None, dtype:Optional[DType]=None, **kwargs): + if Tensor._rng_counter is None: Tensor._rng_counter = Tensor([0], dtype=dtypes.uint32, requires_grad=False) + if not THREEFRY.value: + if dtype == dtypes.bfloat16: + return Tensor.rand(*shape, **kwargs, device=device, dtype=dtypes.float).cast(dtypes.bfloat16) + return Tensor._loadop(LoadOps.CUSTOM, argfix(*shape), arg=custom_random, device=device, dtype=dtype, **kwargs) + + # threefry + if (num := prod((shape:=argfix(*shape)))) == 0: return Tensor.zeros(shape, device=device, dtype=dtype, **kwargs) + counts = (Tensor.arange(num, device=device, dtype=dtypes.uint32, requires_grad=False)+Tensor._rng_counter.to(device)).realize().pad(((0,num%2),)) + Tensor._rng_counter.assign(Tensor._rng_counter + num).realize() + + rotations = [[13, 15, 26, 6], [17, 29, 16, 24]] + ks = [0x0, Tensor._seed ^ 0x0 ^ 0x1BD11BDA, Tensor._seed] + + x = [(c := counts.chunk(2))[0] + ks[-1], c[1] + ks[0]] + for i in range(5): + for r in rotations[i % 2]: x[0], x[1] = (x0 := x[0] + x[1]), x0 ^ ((x[1] * (2 ** r)) + (x[1].div(2 ** (32 - r), upcast=False))) + x = [(x[0] + ks[i % 3]), (x[1] + ks[(i + 1) % 3] + i + 1)] + out = x[0].cat(x[1])[:num].div(2 ** 8, upcast=False).cast(dtypes.float32).div(2 ** 24) + out = out.reshape(shape).cast(dtypes.default_float if dtype is None else dtype) + out.requires_grad = kwargs.get("requires_grad") + return out.contiguous() # ***** creation helper functions ***** @@ -248,6 +267,7 @@ class Tensor: @staticmethod def arange(start, stop=None, step=1, **kwargs): if stop is None: stop, start = start, 0 + assert all(isinstance(s, (int, float)) for s in (start, stop, step)), "symbolic arange not supported" dtype = kwargs.pop("dtype", dtypes.default_float if any(isinstance(x, float) for x in (start, stop, step)) else dtypes.default_int) return (Tensor.full((math.ceil((stop-start)/step),), step, dtype=dtype, **kwargs)._cumsum() + (start - step)).cast(dtype) @@ -866,10 +886,10 @@ class Tensor: if not isinstance(x, Tensor) and x == 0.0: return mlops.Zero.apply(self) if not isinstance(x, Tensor) and x == -1.0: return -self return mlops.Mul.apply(*self._broadcasted(x, reverse)) if isinstance(x, Tensor) or x != 1.0 else self - def div(self, x:Union[Tensor, Scalar], reverse=False) -> Tensor: + def div(self, x:Union[Tensor, Scalar], reverse=False, upcast=True) -> Tensor: x = self._to_const_val(x) - if not isinstance(x, Tensor) and not reverse and x != 0: return self.mul(1/x) - if isinstance(x, Tensor) and dtypes.is_float(x.dtype): return mlops.Div.apply(*self._broadcasted(x, reverse)) + if not isinstance(x, Tensor) and not reverse and x != 0 and upcast: return self.mul(1/x) + if (isinstance(x, Tensor) and dtypes.is_float(x.dtype)) or not upcast: return mlops.Div.apply(*self._broadcasted(x, reverse)) return mlops.Div.apply(*self.cast(least_upper_float(self.dtype))._broadcasted(x, reverse)) def xor(self, x:Tensor, reverse=False) -> Tensor: return mlops.Xor.apply(*self._broadcasted(x, reverse)) @@ -1022,7 +1042,7 @@ if IMAGE: setattr(Tensor, "conv2d", image_conv2d) setattr(Tensor, "dot", image_dot) -# TODO: remove the custom op and replace with threefry +# TODO: eventually remove this def custom_random(out:Buffer): Tensor._seed += 1 if DEBUG >= 2: print(f"*** {out.device} rand seed {Tensor._seed} size {out.size:<15d} dtype {out.dtype}")