From 6bbbeb93ac4cd1e669789bbfea11b0d35f373b00 Mon Sep 17 00:00:00 2001 From: chenyu Date: Wed, 10 Apr 2024 02:00:34 -0400 Subject: [PATCH] skip a few clang test that took > 30 seconds in CI (#4126) * skip slow CLANG test test_train_cifar * skip those too * and that * only CI * one more --- test/models/test_mnist.py | 4 +++- test/models/test_real_world.py | 6 ++---- test/test_fuzz_shape_ops.py | 8 ++++---- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/test/models/test_mnist.py b/test/models/test_mnist.py index 1c1179c3..e51555f5 100644 --- a/test/models/test_mnist.py +++ b/test/models/test_mnist.py @@ -1,8 +1,9 @@ #!/usr/bin/env python import unittest import numpy as np +from tinygrad import Tensor, Device +from tinygrad.helpers import CI from tinygrad.nn.state import get_parameters -from tinygrad.tensor import Tensor from tinygrad.nn import optim, BatchNorm2d from extra.training import train, evaluate from extra.datasets import fetch_mnist @@ -48,6 +49,7 @@ class TinyConvNet: x = x.reshape(shape=[x.shape[0], -1]) return x.dot(self.l1) +@unittest.skipIf(CI and Device.DEFAULT == "CLANG", "slow") class TestMNIST(unittest.TestCase): def test_sgd_onestep(self): np.random.seed(1337) diff --git a/test/models/test_real_world.py b/test/models/test_real_world.py index aee49587..2b6e42b5 100644 --- a/test/models/test_real_world.py +++ b/test/models/test_real_world.py @@ -64,7 +64,6 @@ class TestRealWorld(unittest.TestCase): return t.realize() helper_test("test_mini_sd", lambda: (Tensor.empty(4, 16, 8, 8), Tensor.empty(1, 24)), test, 0.01, 43) - @unittest.skipIf(Device.DEFAULT == "LLVM", "LLVM segmentation fault") @unittest.skipUnless(is_dtype_supported(dtypes.float16), "need dtypes.float16") def test_llama(self): dtypes.default_float = dtypes.float16 @@ -89,6 +88,7 @@ class TestRealWorld(unittest.TestCase): def test(t, v): return model(t, v).realize() helper_test("test_gpt2", lambda: (Tensor([[1,]]),Variable("pos", 1, 100).bind(1)), test, 0.23 if CI else 0.9, 164 if CI else 468, all_jitted=True) + @unittest.skipIf(CI and Device.DEFAULT == "CLANG", "slow") def test_train_mnist(self): from examples.beautiful_mnist import Model with Tensor.train(): @@ -106,8 +106,7 @@ class TestRealWorld(unittest.TestCase): helper_test("train_mnist", lambda: (Tensor.randn(BS, 1, 28, 28),), train, 0.07, 127) - @unittest.skipIf(Device.DEFAULT == "LLVM", "LLVM segmentation fault") - @unittest.skipUnless(is_dtype_supported(dtypes.float16), "need dtypes.float16") + @unittest.skipIf(CI and Device.DEFAULT in {"CLANG", "GPU"}, "slow") def test_train_cifar(self): with Tensor.train(): model = SpeedyResNet(Tensor.ones((12,3,2,2))) @@ -125,7 +124,6 @@ class TestRealWorld(unittest.TestCase): helper_test("train_cifar", lambda: (Tensor.randn(BS, 3, 32, 32),), train, (1.0/48)*BS, 142 if CI else 154) # it's 154 on metal - @unittest.skipIf(Device.DEFAULT == "LLVM", "LLVM segmentation fault") @unittest.skipUnless(is_dtype_supported(dtypes.float16), "need dtypes.float16") def test_train_cifar_hyp(self): dtypes.default_float = dtypes.float16 diff --git a/test/test_fuzz_shape_ops.py b/test/test_fuzz_shape_ops.py index d86d6547..8f20efc5 100644 --- a/test/test_fuzz_shape_ops.py +++ b/test/test_fuzz_shape_ops.py @@ -6,7 +6,7 @@ from hypothesis.extra import numpy as stn import numpy as np import torch -import tinygrad +from tinygrad import Tensor, Device from tinygrad.helpers import CI @@ -26,9 +26,9 @@ def st_shape(draw) -> tuple[int, ...]: return s -def tensors_for_shape(s:tuple[int, ...]) -> tuple[torch.tensor, tinygrad.Tensor]: +def tensors_for_shape(s:tuple[int, ...]) -> tuple[torch.tensor, Tensor]: x = np.arange(prod(s)).reshape(s) - return torch.from_numpy(x), tinygrad.Tensor(x) + return torch.from_numpy(x), Tensor(x) def apply(tor, ten, tor_fn, ten_fn=None): ok = True @@ -38,7 +38,7 @@ def apply(tor, ten, tor_fn, ten_fn=None): except: ten, ok = None, not ok # noqa: E722 return tor, ten, ok - +@unittest.skipIf(CI and Device.DEFAULT == "CLANG", "slow") class TestShapeOps(unittest.TestCase): @settings.get_profile(__file__) @given(st_shape(), st_int32, st.one_of(st_int32, st.lists(st_int32)))