Remove pytest markers (#2831)

* remove pytest marker

* fix some, skip some

* tweak

* fix

* skip slow

* skip more
This commit is contained in:
chenyu 2023-12-18 18:53:28 -05:00 committed by GitHub
parent 264fe9c93f
commit 73cadfbb3c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 22 additions and 40 deletions

View File

@ -414,13 +414,13 @@ jobs:
DEBUG=5 PYTHONPATH=${{ github.workspace }} FORWARD_ONLY=1 python3 test/test_ops.py TestOps.test_add
- name: Run pytest (not cuda)
if: matrix.backend!='cuda' && matrix.backend!='ptx' && matrix.backend!='triton'
run: python -m pytest -n=auto test/ -m 'not exclude_${{matrix.backend}}' --durations=20
run: python -m pytest -n=auto test/ --durations=20
- name: Run ONNX (only LLVM)
if: matrix.backend == 'llvm'
run: python -m pytest -n=auto test/external/external_test_onnx_backend.py --durations=20
- name: Run pytest (cuda)
if: matrix.backend=='cuda'||matrix.backend=='ptx'||matrix.backend=='triton'
run: python -m pytest -n=auto test/ -k 'not (half or test_efficientnet_safetensors)' -m 'not exclude_cuda' --ignore=test/external --ignore=test/models --durations=20
run: python -m pytest -n=auto test/ -k 'not (half or test_efficientnet_safetensors)' --ignore=test/external --ignore=test/models --durations=20
#testunicorn:
# name: ARM64 unicorn Test

View File

@ -3,7 +3,7 @@ import gc
from tinygrad.helpers import prod
from tinygrad.tensor import Tensor
from tinygrad.lazy import LazyBuffer
from tinygrad.runtime.ops_gpu import CLBuffer
from tinygrad.device import Buffer
from tinygrad.helpers import GlobalCounters
def print_objects():
@ -11,7 +11,7 @@ def print_objects():
tensors = [x for x in gc.get_objects() if isinstance(x, Tensor)]
tensor_ram_used = sum([prod(x.shape)*4 for x in tensors])
lazybuffers = [x for x in gc.get_objects() if isinstance(x, LazyBuffer)]
gpubuffers = [x for x in gc.get_objects() if isinstance(x, CLBuffer)]
gpubuffers = [x for x in gc.get_objects() if isinstance(x, Buffer)]
realized_buffers = [x.realized for x in lazybuffers if x.realized]
gpubuffers_orphaned = [x for x in gpubuffers if x not in realized_buffers]

View File

@ -1,6 +0,0 @@
[pytest]
markers =
exclude_cuda
exclude_gpu
exclude_clang
onnx_coverage

View File

@ -7,9 +7,6 @@ from tinygrad.nn.optim import Adam
from extra.lr_scheduler import MultiStepLR, ReduceLROnPlateau, CosineAnnealingLR, OneCycleLR
from extra.training import train, evaluate
from extra.datasets import fetch_mnist
import pytest
pytestmark = [pytest.mark.exclude_cuda, pytest.mark.exclude_gpu]
np.random.seed(1337)
Tensor.manual_seed(1337)

View File

@ -4,11 +4,12 @@ import math, unittest, random, copy
# import warnings
import numpy as np
from tinygrad import Tensor, dtypes
from tinygrad import Tensor, dtypes, Device
# from tinygrad import TinyJit
from tinygrad.lazy import LazyBuffer
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
from tinygrad.helpers import CI
random.seed(42)
@ -187,6 +188,7 @@ class TestIndexing(unittest.TestCase):
# def delitem(): del reference[0]
# self.assertRaises(TypeError, delitem)
@unittest.skipIf(CI and Device.DEFAULT in ["CLANG", "GPU"], "slow")
def test_advancedindex(self):
# integer array indexing

View File

@ -6,9 +6,6 @@ from tinygrad.tensor import Tensor
from tinygrad.nn import optim, BatchNorm2d
from extra.training import train, evaluate
from extra.datasets import fetch_mnist
import pytest
pytestmark = [pytest.mark.exclude_gpu, pytest.mark.exclude_clang]
# load the mnist dataset
X_train, Y_train, X_test, Y_test = fetch_mnist()

View File

@ -7,9 +7,6 @@ import onnx
from extra.onnx import get_run_onnx
from tinygrad.tensor import Tensor
from tinygrad.helpers import CI, fetch, temp
import pytest
pytestmark = [pytest.mark.exclude_gpu, pytest.mark.exclude_clang]
def run_onnx_torch(onnx_model, inputs):
import torch
@ -99,6 +96,7 @@ class TestOnnxModel(unittest.TestCase):
print(tinygrad_out, torch_out)
np.testing.assert_allclose(torch_out, tinygrad_out, atol=1e-4, rtol=1e-2)
@unittest.skip("slow")
def test_efficientnet(self):
input_name, input_new = "images:0", True
self._test_model(fetch("https://github.com/onnx/models/raw/main/archive/vision/classification/efficientnet-lite4/model/efficientnet-lite4-11.onnx"), input_name, input_new) # noqa: E501

View File

@ -4,16 +4,13 @@ import numpy as np
from tinygrad.nn.state import get_parameters
from tinygrad.nn import optim
from tinygrad.tensor import Device
from tinygrad.helpers import getenv
from tinygrad.helpers import getenv, CI
from extra.training import train
from extra.models.convnext import ConvNeXt
from extra.models.efficientnet import EfficientNet
from extra.models.transformer import Transformer
from extra.models.vit import ViT
from extra.models.resnet import ResNet18
import pytest
pytestmark = [pytest.mark.exclude_gpu, pytest.mark.exclude_clang]
BS = getenv("BS", 2)
@ -42,6 +39,7 @@ class TestTrain(unittest.TestCase):
train_one_step(model,X,Y)
check_gc()
@unittest.skipIf(CI, "slow")
def test_efficientnet(self):
model = EfficientNet(0)
X = np.zeros((BS,3,224,224), dtype=np.float32)
@ -49,6 +47,7 @@ class TestTrain(unittest.TestCase):
train_one_step(model,X,Y)
check_gc()
@unittest.skipIf(CI, "slow")
@unittest.skipIf(Device.DEFAULT in ["METAL", "WEBGPU"], "too many buffers for webgpu and metal")
def test_vit(self):
model = ViT()
@ -66,6 +65,7 @@ class TestTrain(unittest.TestCase):
train_one_step(model,X,Y)
check_gc()
@unittest.skipIf(CI, "slow")
def test_resnet(self):
X = np.zeros((BS, 3, 224, 224), dtype=np.float32)
Y = np.zeros((BS), dtype=np.int32)

View File

@ -2,12 +2,10 @@
import time
import unittest
import torch
from tinygrad.tensor import Tensor
from tinygrad.helpers import Profiling
import pytest
pytestmark = [pytest.mark.exclude_cuda, pytest.mark.exclude_gpu, pytest.mark.exclude_clang]
from tinygrad import Tensor, Device
from tinygrad.helpers import Profiling, CI
@unittest.skipIf(CI and Device.DEFAULT == "CUDA", "slow")
class TestConvSpeed(unittest.TestCase):
def test_mnist(self):

View File

@ -6,10 +6,8 @@ from tinygrad.jit import TinyJit
from tinygrad.tensor import Tensor, Device
from tinygrad.nn import BatchNorm2d, Conv1d,ConvTranspose1d, Conv2d,ConvTranspose2d, Linear, GroupNorm, LayerNorm,LayerNorm2d, Embedding, InstanceNorm
import torch
import pytest
pytestmark = [pytest.mark.exclude_cuda]
@unittest.skipIf(CI and Device.DEFAULT == "CUDA", "slow")
class TestNN(unittest.TestCase):
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "no int64 on WebGPU")
def test_sparse_cat_cross_entropy(self):

View File

@ -1,11 +1,9 @@
import numpy as np
import torch
import unittest
from tinygrad.tensor import Tensor
from tinygrad import Tensor, Device
from tinygrad.nn.optim import Adam, SGD, AdamW
import pytest
pytestmark = pytest.mark.exclude_cuda
from tinygrad.helpers import CI
np.random.seed(1337)
x_init = np.random.randn(1,4).astype(np.float32)
@ -35,6 +33,7 @@ def step(tensor, optim, steps=1, kwargs={}):
optim.step()
return net.x.detach().numpy(), net.W.detach().numpy()
@unittest.skipIf(CI and Device.DEFAULT == "CUDA", "slow")
class TestOptim(unittest.TestCase):
def _test_optim(self, tinygrad_optim, torch_optim, steps, opts, atol, rtol):

View File

@ -15,9 +15,6 @@ from tinygrad.tensor import Tensor
from tinygrad.nn import Conv2d
from tinygrad.helpers import colored, getenv, CI
from tinygrad.jit import TinyJit
import pytest
pytestmark = [pytest.mark.exclude_cuda, pytest.mark.exclude_gpu, pytest.mark.exclude_clang]
IN_CHANS = [int(x) for x in getenv("IN_CHANS", "4,16,64").split(",")]
@ -123,6 +120,7 @@ def helper_test_conv(bs, in_chans, out_chans, kernel_size, img_size_y, img_size_
helper_test_generic(f"conv bs:{bs:3d} chans:{in_chans:3d} -> {out_chans:3d} k:{kernel_size}", f1, (torch_dat,), TinyJit(f2), (tiny_dat,))
@unittest.skipIf(getenv("BIG") == 0, "no big tests")
@unittest.skipIf(getenv("CUDACPU"), "no CUDACPU")
class TestBigSpeed(unittest.TestCase):
def test_add(self):
def f(a, b): return a+b
@ -143,6 +141,7 @@ class TestBigSpeed(unittest.TestCase):
def test_matvec_16384_4096(self): helper_test_matvec('matvec_16384_4096', 16384, 4096)
@unittest.skipIf(getenv("BIG") == 1, "only big tests")
@unittest.skipIf(getenv("CUDACPU"), "no CUDACPU")
class TestSpeed(unittest.TestCase):
def test_sub(self):
def f(a, b): return a-b

View File

@ -22,7 +22,7 @@ class TestZeroCopy(unittest.TestCase):
t2 = time_tensor_numpy(out)
gbps = out.nbytes()*1e-9/max(t2-t1, 1e-10)
print(f"time(base): {t1*1e3:.2f} ms, time(copy): {t2*1e3:.2f} ms : copy speed {gbps:.2f} GB/s")
self.assertGreater(gbps, 1000) # more than 1000 GB/s = no copy
self.assertGreater(gbps, 600) # more than 600 GB/s = no copy
if __name__ == '__main__':
unittest.main(verbosity=2)