diff --git a/test/models/test_cifar.py b/test/models/test_cifar.py deleted file mode 100644 index 3da04365..00000000 --- a/test/models/test_cifar.py +++ /dev/null @@ -1,43 +0,0 @@ -import unittest -from tinygrad.tensor import Tensor -from tinygrad.nn import optim -from tinygrad.state import get_parameters -from examples.hlb_cifar10 import SpeedyResNet -from tinygrad.jit import TinyJit, JIT_SUPPORTED_DEVICE -from tinygrad.ops import GlobalCounters -from tinygrad.lazy import Device - -class TestCifar(unittest.TestCase): - @unittest.skipUnless(Device.DEFAULT in JIT_SUPPORTED_DEVICE, "needs JIT") - def test_train_step(self): - # TODO: with default device - #old_default = Device.DEFAULT - #Device.DEFAULT = "FAKE" - #Device['fake'].codegen = Device[old_default].codegen - - # TODO: with train - old_training = Tensor.training - Tensor.training = True - - model = SpeedyResNet() - optimizer = optim.SGD(get_parameters(model), lr=0.01, momentum=0.8, nesterov=True, weight_decay=0.15) - - @TinyJit - def train(X): - out = model(X) - loss = out.mean() - optimizer.zero_grad() - loss.backward() - optimizer.step() - for _ in range(3): train(Tensor.randn(32, 3, 32, 32)) - - print(f"used {GlobalCounters.mem_used/1e9} GB and {len(train.jit_cache)} kernels") - assert GlobalCounters.mem_used/1e9 < 0.55, "CIFAR training used more than 0.55 GB" - assert len(train.jit_cache) <= 236, "CIFAR training used more than 236 kernels" - - # reset device - Tensor.training = old_training - #Device.DEFAULT = old_default - -if __name__ == '__main__': - unittest.main() diff --git a/test/models/test_real_world.py b/test/models/test_real_world.py new file mode 100644 index 00000000..1ec83df0 --- /dev/null +++ b/test/models/test_real_world.py @@ -0,0 +1,97 @@ +import unittest, time +from tinygrad.tensor import Tensor +from tinygrad.nn import optim +from tinygrad.state import get_parameters +from tinygrad.jit import TinyJit, JIT_SUPPORTED_DEVICE +from tinygrad.ops import GlobalCounters, LazyOp, LoadOps +from 
def helper_test(nm, gen, train, max_memory_allowed, max_kernels_allowed):
  """Run `train` four times and check its memory and kernel budget.

  nm: label used in the printed summary and in the assertion messages.
  gen: zero-argument callable returning the tuple of positional arguments for `train`.
  train: the callable under test; when it exposes `jit_cache`, its length is the kernel count.
  max_memory_allowed: upper bound (exclusive) on GlobalCounters.mem_used, in GB.
  max_kernels_allowed: upper bound (inclusive) on the number of cached kernels.

  Raises AssertionError when either budget is exceeded.
  """
  tms = []
  for _ in range(4):
    # counters are reset on every iteration, so the memory checked below is that of the last run only
    GlobalCounters.reset()
    # synchronize on both sides of the call so the timing covers the whole run
    Device[Device.DEFAULT].synchronize()
    st = time.perf_counter_ns()
    train(*gen())
    Device[Device.DEFAULT].synchronize()
    tms.append(time.perf_counter_ns() - st)

  kernels_used = len(train.jit_cache) if hasattr(train, "jit_cache") else None
  print(f"{nm}: used {GlobalCounters.mem_used/1e9:.2f} GB and {kernels_used} kernels in {min(tms)/1e6:.2f} ms")
  assert GlobalCounters.mem_used/1e9 < max_memory_allowed, f"{nm} used more than {max_memory_allowed:.2f} GB"
  # the kernel check is skipped when there is no jit_cache (None) or when it is empty (0)
  assert not kernels_used or kernels_used <= max_kernels_allowed, f"{nm} used more than {max_kernels_allowed} kernels"

# for speed
def derandomize(x):
  """Rewrite every LoadOps.RAND in the op tree under `x` to LoadOps.EMPTY, in place, and return `x`.

  A LazyOp has its own op patched and its sources rewritten recursively; any other
  object is expected to carry the tree in its `.op` attribute.
  """
  if isinstance(x, LazyOp):
    if x.op == LoadOps.RAND: x.op = LoadOps.EMPTY
    x.src = [derandomize(s) for s in x.src]
  else:
    x.op = derandomize(x.op)
  return x

def derandomize_model(model):
  """Derandomize and realize every parameter of `model`, so weights are allocated without being sampled."""
  for p in get_parameters(model):
    p.lazydata = derandomize(p.lazydata)
    p.realize()

class TestRealWorld(unittest.TestCase):
  """Memory / kernel-count regression tests on full-size example models."""

  @unittest.skipUnless(not CI, "too big for CI")
  def test_stable_diffusion(self):
    model = UNetModel()
    derandomize_model(model)
    @TinyJit
    def test(t, t2): return model(t, 801, t2).realize()
    helper_test("test_sd", lambda: (Tensor.randn(1, 4, 64, 64),Tensor.randn(1, 77, 768)), test, 14.04, 912)

  @unittest.skipUnless(Device.DEFAULT in JIT_SUPPORTED_DEVICE, "needs JIT")
  def test_llama(self):
    old_type = Tensor.default_type
    Tensor.default_type = dtypes.float16
    try:
      # CI runs a reduced configuration; the full 7B arguments are used everywhere else
      args_tiny = {"dim": 1024, "multiple_of": 256, "n_heads": 8, "n_layers": 8, "norm_eps": 1e-05, "vocab_size": 1000}
      model = Transformer(**(args_tiny if CI else args_7B))
      derandomize_model(model)
      @TinyJit
      def test(t): return model(t, 0).realize()
      helper_test("test_llama", lambda: (Tensor([[1,]]),), test, 0.22 if CI else 13.5, 126 if CI else 486)
    finally:
      # default_type is process-wide: restore it even when a budget assertion fails,
      # otherwise every later test would silently run with float16 defaults
      Tensor.default_type = old_type

  @unittest.skipUnless(Device.DEFAULT in JIT_SUPPORTED_DEVICE, "needs JIT")
  def test_train_cifar(self):
    # TODO: with default device
    #old_default = Device.DEFAULT
    #Device.DEFAULT = "FAKE"
    #Device['fake'].codegen = Device[old_default].codegen

    # TODO: with train
    old_training = Tensor.training
    Tensor.training = True
    try:
      model = SpeedyResNet()
      optimizer = optim.SGD(get_parameters(model), lr=0.01, momentum=0.8, nesterov=True, weight_decay=0.15)

      BS = 32 if CI else 512

      @TinyJit
      def train(X):
        out = model(X)
        loss = out.mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

      # the memory budget is 0.55 GB at batch size 32 and is scaled linearly with BS
      helper_test("train_cifar", lambda: (Tensor.randn(BS, 3, 32, 32),), train, (0.55/32)*BS, 236)
    finally:
      # reset device; training mode is process-wide, so restore it even when the test fails
      Tensor.training = old_training
      #Device.DEFAULT = old_default

if __name__ == '__main__':
  unittest.main()
def _realize_rand(buffer: LazyBuffer) -> None:
  """Fill `buffer.realized` with uniform random values from a generator seeded with `buffer.op.arg`."""
  rng = np.random.default_rng(buffer.op.arg)
  from_cpu = Device[buffer.device].buffer.fromCPU
  # Generator.random only produces float32/float64, so sample in float32 and cast to the
  # buffer's dtype; copy=False avoids a second allocation when no conversion is needed
  samples = rng.random(size=buffer.shape, dtype=np.float32)
  host_data = samples.astype(dtype=buffer.dtype.np, copy=False)
  buffer.realized = from_cpu(host_data, **buffer._device_extra_args()) # type: ignore
BinaryOps.DIV: lambda x, y: np.divide(*match_types(x, y)), UnaryOps.SQRT: np.sqrt,