remove CUDACPU flag in tests [run_process_replay] (#5902)

The CUDACPU flag is no longer used.
Authored by chenyu, 2024-08-04 16:06:38 -04:00; committed by GitHub
parent 996ff0c135
commit 4a65010de8
7 changed files with 20 additions and 21 deletions
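
The same mechanical substitution appears in every file below: the standalone CUDACPU environment check is deleted, and only the MOCKGPU-emulator-on-NV case keeps its skip. As a rough before/after sketch of that pattern (illustrative only, not an exact excerpt from any one file):

# Illustrative sketch of the pattern this commit applies; not an exact excerpt.
from tinygrad import Device
from tinygrad.helpers import getenv

# before: skip when either the CUDACPU flag or the MOCKGPU emulator on NV is active
if getenv("CUDACPU") or (getenv("MOCKGPU") and Device.DEFAULT == "NV"):
    pass  # skip / return early in the test

# after: CUDACPU is gone, so only MOCKGPU on NV still triggers the skip
if getenv("MOCKGPU") and Device.DEFAULT == "NV":
    pass  # skip / return early in the test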

View File

@@ -35,7 +35,7 @@ def is_dtype_supported(dtype: DType, device: str = Device.DEFAULT):
if device in ["WEBGPU", "WEBGL"]: return dtype in [dtypes.float, dtypes.int32, dtypes.uint32]
# for CI GPU and OSX, cl_khr_fp16 isn't supported
# for CI LLVM, it segfaults because it can't link to the casting function
-# CUDACPU architecture is sm_35 but we need at least sm_70 to run fp16 ALUs
+# CI CUDA architecture is sm_35 but we need at least sm_70 to run fp16 ALUs
# PYTHON supports half memoryview in 3.12+ https://github.com/python/cpython/issues/90751
if dtype == dtypes.half:
if device == "GPU": return not CI and not OSX

View File

@@ -202,7 +202,7 @@ class TestFloatDType(TestDType):
class TestDoubleDType(TestDType):
DTYPE = dtypes.double
-@unittest.skipIf((CI and Device.DEFAULT in {"CUDA", "NV"}) or getenv("PTX"), "conversion not supported on CUDACPU and PTX") # TODO: why not?
+@unittest.skipIf((CI and Device.DEFAULT in {"CUDA", "NV"}) or getenv("PTX"), "conversion not supported on CI CUDA and PTX") # TODO: why not?
def test_float64_increased_precision(self):
for func in [
lambda t: t.exp(),

View File

@@ -39,8 +39,8 @@ unary_operations = [(Tensor.exp, np.exp), (Tensor.log, np.log), operator.neg, (T
# TODO: (a+b)/2 in tensor.py's maximum can overflow. This requires a new implementation of maximum that can be backpropagated
#binary_operations += [(Tensor.maximum, np.maximum)]
-# TODO: CUDACPU segfaults on sin
-if getenv("CUDACPU") or (getenv("MOCKGPU") and Device.DEFAULT == "NV"): unary_operations.remove((Tensor.sin, np.sin))
+# TODO: CI CUDA segfaults on sin
+if getenv("MOCKGPU") and Device.DEFAULT == "NV": unary_operations.remove((Tensor.sin, np.sin))
class ht:
float64 = strat.floats(width=64, allow_subnormal=False)
@@ -144,8 +144,8 @@ class TestDTypeALU(unittest.TestCase):
@given(ht.int32, ht.int32, ht.float32, strat.sampled_from(integer_binary_operations), strat.sampled_from(binary_operations))
def test_int32_midcast_float(self, a, b, c, op1, op2): universal_test_midcast(a, b, c, op1, op2, dtypes.int32, dtypes.float32)
-# Metal and CUDACPU and HIP behave differently than numpy in CI for overflows
-skip_overflow = CI and (Device.DEFAULT in {"AMD", "NV"} or getenv("CUDACPU"))
+# Metal and CUDA and HIP behave differently than numpy in CI for overflows
+skip_overflow = CI and Device.DEFAULT in {"AMD", "NV"}
@given(strat.floats(width=32, min_value=0, max_value=10.0) if skip_overflow else ht.float32,
strat.floats(width=32, min_value=0, max_value=10.0) if skip_overflow else ht.float32,
ht.int32, strat.sampled_from(binary_operations), strat.sampled_from(integer_binary_operations))
@@ -165,8 +165,8 @@ class TestFromFuzzer(unittest.TestCase):
def test_sin(self, dtype):
if not is_dtype_supported(dtype): return
if dtype == dtypes.float64:
-# crashes in CUDACPU
-if (getenv("CUDACPU") or (getenv("MOCKGPU") and Device.DEFAULT == "NV")): return
+# crashes in CI CUDA
+if getenv("MOCKGPU") and Device.DEFAULT == "NV": return
def _test_value(n: float, unit: float=1.0):
next_float = np.nextafter(1.0, 2.0, dtype=_to_np_dtype(dtype))
ulp = next_float - 1.0
@@ -185,8 +185,8 @@ class TestFromFuzzer(unittest.TestCase):
def test_log2(self, dtype):
if not is_dtype_supported(dtype): return
if dtype == dtypes.float64:
-# crashes in CUDACPU
-if (getenv("CUDACPU") or (getenv("MOCKGPU") and Device.DEFAULT == "NV")): return
+# crashes in CI CUDA
+if getenv("MOCKGPU") and Device.DEFAULT == "NV": return
def _test_value(n: float, unit: float=1.0):
next_float = np.nextafter(1.0, 2.0, dtype=_to_np_dtype(dtype))
ulp = next_float - 1.0

View File

@@ -77,7 +77,7 @@ def prepare_test_op(low, high, shps, vals, forward_only=False):
class TestOps(unittest.TestCase):
def helper_test_exception(self, shps, torch_fxn, tinygrad_fxn, expected, exact=False, vals=None, low=-1.5, high=1.5):
if getenv("CUDACPU") or (getenv("MOCKGPU") and Device.DEFAULT == "NV"): self.skipTest('helper_test_exception fails in CUDACPU')
if getenv("MOCKGPU") and Device.DEFAULT == "NV": self.skipTest('helper_test_exception fails in CI CUDA')
ts, tst = prepare_test_op(low, high, shps, vals)
with self.assertRaises(expected) as torch_cm:
torch_fxn(*ts)
@@ -559,15 +559,15 @@ class TestOps(unittest.TestCase):
def test_sin(self):
helper_test_op([(45,65)], lambda x: x.sin())
helper_test_op([()], lambda x: x.sin())
-# works on real CUDA but not CUDACPU
-if not (getenv("CUDACPU") or (getenv("MOCKGPU") and Device.DEFAULT == "NV")):
+# works on real CUDA but not CI
+if not (getenv("MOCKGPU") and Device.DEFAULT == "NV"):
helper_test_op(None, lambda x: x.sin(), vals=[[math.nan, math.inf, -math.inf]])
helper_test_op(None, lambda x: x.sin(), vals=[[1e1, 1e2, 1e3, 1e4, 1e5, 1e6, -1e1, -1e2, -1e3, -1e4, -1e5, -1e6]],
atol=3e-3, rtol=3e-3, grad_atol=3e-3, grad_rtol=3e-3)
def test_cos(self):
helper_test_op([(45,65)], lambda x: x.cos())
helper_test_op([()], lambda x: x.cos())
if not (getenv("CUDACPU") or (getenv("MOCKGPU") and Device.DEFAULT == "NV")):
if not (getenv("MOCKGPU") and Device.DEFAULT == "NV"):
helper_test_op(None, lambda x: x.cos(), vals=[[1e1, 1e2, 1e3, 1e4, 1e5, 1e6, -1e1, -1e2, -1e3, -1e4, -1e5, -1e6]],
atol=3e-3, rtol=3e-3, grad_atol=3e-3, grad_rtol=3e-3)
def test_tan(self):
@@ -575,7 +575,7 @@ class TestOps(unittest.TestCase):
helper_test_op([(45,65)], lambda x: x.tan(), low=-1.5, high=1.5)
helper_test_op([(45,65)], lambda x: x.tan(), low=-5, high=5, forward_only=True)
helper_test_op([()], lambda x: x.tan())
if not (getenv("CUDACPU") or (getenv("MOCKGPU") and Device.DEFAULT == "NV")):
if not (getenv("MOCKGPU") and Device.DEFAULT == "NV"):
helper_test_op(None, lambda x: x.cos(), vals=[[1e1, 1e2, 1e3, 1e4, 1e5, 1e6, -1e1, -1e2, -1e3, -1e4, -1e5, -1e6]],
atol=3e-3, rtol=3e-3, grad_atol=3e-3, grad_rtol=3e-3)

View File

@@ -108,7 +108,7 @@ def helper_test_conv(bs, in_chans, out_chans, kernel_size, img_size_y, img_size_
helper_test_generic(f"conv bs:{bs:3d} chans:{in_chans:3d} -> {out_chans:3d} k:{kernel_size}", f1, (torch_dat,), TinyJit(f2), (tiny_dat,))
@unittest.skipIf(getenv("BIG") == 0, "no big tests")
@unittest.skipIf(getenv("CUDACPU") or getenv("MOCKGPU"), "no CUDACPU or MOCKGPUs")
@unittest.skipIf(getenv("MOCKGPU"), "no MOCKGPUs")
class TestBigSpeed(unittest.TestCase):
def test_add(self):
def f(a, b): return a+b
@@ -129,7 +129,7 @@ class TestBigSpeed(unittest.TestCase):
def test_matvec_16384_4096(self): helper_test_matvec('matvec_16384_4096', 16384, 4096)
@unittest.skipIf(getenv("BIG") == 1, "only big tests")
@unittest.skipIf(getenv("CUDACPU") or getenv("MOCKGPU"), "no CUDACPU or MOCKGPUs")
@unittest.skipIf(getenv("MOCKGPU"), "no MOCKGPUs")
class TestSpeed(unittest.TestCase):
def test_sub(self):
def f(a, b): return a-b

View File

@@ -1,6 +1,5 @@
import unittest
from tinygrad import Device, dtypes, Tensor
-from tinygrad.helpers import getenv
from tinygrad.device import Buffer
from tinygrad.lazy import view_supported_devices
@@ -41,7 +40,7 @@ class TestSubBuffer(unittest.TestCase):
out = (vt + 100).tolist()
assert out == [102, 103]
-@unittest.skipIf(Device.DEFAULT not in {"CUDA", "NV", "AMD"} or getenv("CUDACPU"), "only NV, AMD, CUDA but not CUDACPU")
+@unittest.skipIf(Device.DEFAULT not in {"CUDA", "NV", "AMD"}, "only NV, AMD, CUDA")
def test_subbuffer_transfer(self):
t = Tensor.arange(0, 10, dtype=dtypes.uint8).realize()
vt = t[2:5].contiguous().realize()

View File

@@ -13,7 +13,7 @@ settings.load_profile("my_profile")
class TestTranscendentalMath(unittest.TestCase):
@unittest.skipUnless(is_dtype_supported(dtypes.float64, Device.DEFAULT), f"no float64 on {Device.DEFAULT}")
@unittest.skipIf(getenv("CUDACPU") or (getenv("MOCKGPU") and Device.DEFAULT == "NV"), "crashed")
@unittest.skipIf(getenv("MOCKGPU") and Device.DEFAULT == "NV", "crashed")
@given(ht.float64, strat.sampled_from([(Tensor.exp, np.exp), (Tensor.log, np.log), (Tensor.sin, np.sin)]))
def test_float64(self, x, op):
if op[0] == Tensor.sin:
@@ -25,7 +25,7 @@ class TestTranscendentalMath(unittest.TestCase):
op[1](np.array([x], dtype=_to_np_dtype(dtypes.float64))),
atol=3e-2, rtol=1e-5) # sin can have bigger atol for very big x
@unittest.skipIf(getenv("CUDACPU") or (getenv("MOCKGPU") and Device.DEFAULT == "NV"), "crashed")
@unittest.skipIf(getenv("MOCKGPU") and Device.DEFAULT == "NV", "crashed")
@given(ht.float32, strat.sampled_from([(Tensor.exp, np.exp), (Tensor.log, np.log), (Tensor.sin, np.sin)]))
def test_float32(self, x, op):
with Context(TRANSCENDENTAL=2):
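
The final hunk wraps its assertions in Context(TRANSCENDENTAL=2). As a small usage sketch (assuming the Context helper and the TRANSCENDENTAL flag behave as they are used in the test above), the same knob can be set around any single op:

# Illustrative sketch only: force the transcendental approximation path around one op,
# mirroring how test_float32 above wraps its assertions.
import numpy as np
from tinygrad import Tensor
from tinygrad.helpers import Context

with Context(TRANSCENDENTAL=2):
    out = Tensor([0.5, 1.0, 2.0]).sin().numpy()
np.testing.assert_allclose(out, np.sin(np.array([0.5, 1.0, 2.0], dtype=np.float32)), atol=1e-4)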