diff --git a/test/helpers.py b/test/helpers.py index 15ee0c56..be2472b1 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -1,4 +1,5 @@ import sys +import numpy as np from tinygrad import Tensor, Device, dtypes from tinygrad.device import Runner from tinygrad.dtype import DType @@ -38,3 +39,12 @@ def is_dtype_supported(dtype: DType, device: str = Device.DEFAULT): if device == "PYTHON": return sys.version_info >= (3, 12) if dtype == dtypes.float64: return device != "METAL" and not (OSX and device == "GPU") return True + +def rand_for_dtype(dt:DType, size:int): + if dtypes.is_unsigned(dt): + return np.random.randint(0, 100, size=size, dtype=dt.np) + elif dtypes.is_int(dt): + return np.random.randint(-100, 100, size=size, dtype=dt.np) + elif dt == dtypes.bool: + return np.random.choice([True, False], size=size) + return np.random.uniform(-10, 10, size=size).astype(dt.np) diff --git a/test/test_dtype.py b/test/test_dtype.py index 216952cb..a9f0cbbf 100644 --- a/test/test_dtype.py +++ b/test/test_dtype.py @@ -6,7 +6,7 @@ from tinygrad.helpers import getenv, DEBUG from tinygrad.dtype import DType, DTYPES_DICT, ImageDType, PtrDType, least_upper_float, least_upper_dtype from tinygrad import Device, Tensor, dtypes from hypothesis import given, settings, strategies as strat -from test.helpers import is_dtype_supported +from test.helpers import is_dtype_supported, rand_for_dtype settings.register_profile("my_profile", max_examples=200, deadline=None) settings.load_profile("my_profile") @@ -59,15 +59,7 @@ class TestDType(unittest.TestCase): @classmethod def setUpClass(cls): if not cls.DTYPE or not is_dtype_supported(cls.DTYPE): raise unittest.SkipTest("dtype not supported") - DATA_SIZE = 10 - if dtypes.is_unsigned(cls.DTYPE): - cls.DATA = np.random.randint(0, 100, size=DATA_SIZE, dtype=cls.DTYPE.np) - elif dtypes.is_int(cls.DTYPE): - cls.DATA = np.random.randint(-100, 100, size=DATA_SIZE, dtype=cls.DTYPE.np) - elif cls.DTYPE == dtypes.bool: - cls.DATA = np.random.choice([True, False], size=DATA_SIZE) - else: - cls.DATA = np.random.uniform(-10, 10, size=DATA_SIZE).astype(cls.DTYPE.np) + cls.DATA = rand_for_dtype(cls.DTYPE, 10) def setUp(self): if self.DTYPE is None: raise unittest.SkipTest("base class")