rand_for_dtype helper (#4459)

This commit is contained in:
qazal 2024-05-07 05:03:42 +08:00 committed by GitHub
parent a3140c9767
commit 35dfbc6354
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 12 additions and 10 deletions

View File

@ -1,4 +1,5 @@
import sys
import numpy as np
from tinygrad import Tensor, Device, dtypes
from tinygrad.device import Runner
from tinygrad.dtype import DType
@ -38,3 +39,12 @@ def is_dtype_supported(dtype: DType, device: str = Device.DEFAULT):
if device == "PYTHON": return sys.version_info >= (3, 12)
if dtype == dtypes.float64: return device != "METAL" and not (OSX and device == "GPU")
return True
def rand_for_dtype(dt:DType, size:int):
if dtypes.is_unsigned(dt):
return np.random.randint(0, 100, size=size, dtype=dt.np)
elif dtypes.is_int(dt):
return np.random.randint(-100, 100, size=size, dtype=dt.np)
elif dt == dtypes.bool:
return np.random.choice([True, False], size=size)
return np.random.uniform(-10, 10, size=size).astype(dt.np)

View File

@ -6,7 +6,7 @@ from tinygrad.helpers import getenv, DEBUG
from tinygrad.dtype import DType, DTYPES_DICT, ImageDType, PtrDType, least_upper_float, least_upper_dtype
from tinygrad import Device, Tensor, dtypes
from hypothesis import given, settings, strategies as strat
from test.helpers import is_dtype_supported
from test.helpers import is_dtype_supported, rand_for_dtype
settings.register_profile("my_profile", max_examples=200, deadline=None)
settings.load_profile("my_profile")
@ -59,15 +59,7 @@ class TestDType(unittest.TestCase):
@classmethod
def setUpClass(cls):
if not cls.DTYPE or not is_dtype_supported(cls.DTYPE): raise unittest.SkipTest("dtype not supported")
DATA_SIZE = 10
if dtypes.is_unsigned(cls.DTYPE):
cls.DATA = np.random.randint(0, 100, size=DATA_SIZE, dtype=cls.DTYPE.np)
elif dtypes.is_int(cls.DTYPE):
cls.DATA = np.random.randint(-100, 100, size=DATA_SIZE, dtype=cls.DTYPE.np)
elif cls.DTYPE == dtypes.bool:
cls.DATA = np.random.choice([True, False], size=DATA_SIZE)
else:
cls.DATA = np.random.uniform(-10, 10, size=DATA_SIZE).astype(cls.DTYPE.np)
cls.DATA = rand_for_dtype(cls.DTYPE, 10)
def setUp(self):
if self.DTYPE is None: raise unittest.SkipTest("base class")