mirror of https://github.com/commaai/tinygrad.git
rand_for_dtype helper (#4459)
This commit is contained in:
parent
a3140c9767
commit
35dfbc6354
|
@ -1,4 +1,5 @@
|
|||
import sys
|
||||
import numpy as np
|
||||
from tinygrad import Tensor, Device, dtypes
|
||||
from tinygrad.device import Runner
|
||||
from tinygrad.dtype import DType
|
||||
|
@ -38,3 +39,12 @@ def is_dtype_supported(dtype: DType, device: str = Device.DEFAULT):
|
|||
if device == "PYTHON": return sys.version_info >= (3, 12)
|
||||
if dtype == dtypes.float64: return device != "METAL" and not (OSX and device == "GPU")
|
||||
return True
|
||||
|
||||
def rand_for_dtype(dt:DType, size:int):
|
||||
if dtypes.is_unsigned(dt):
|
||||
return np.random.randint(0, 100, size=size, dtype=dt.np)
|
||||
elif dtypes.is_int(dt):
|
||||
return np.random.randint(-100, 100, size=size, dtype=dt.np)
|
||||
elif dt == dtypes.bool:
|
||||
return np.random.choice([True, False], size=size)
|
||||
return np.random.uniform(-10, 10, size=size).astype(dt.np)
|
||||
|
|
|
@ -6,7 +6,7 @@ from tinygrad.helpers import getenv, DEBUG
|
|||
from tinygrad.dtype import DType, DTYPES_DICT, ImageDType, PtrDType, least_upper_float, least_upper_dtype
|
||||
from tinygrad import Device, Tensor, dtypes
|
||||
from hypothesis import given, settings, strategies as strat
|
||||
from test.helpers import is_dtype_supported
|
||||
from test.helpers import is_dtype_supported, rand_for_dtype
|
||||
|
||||
settings.register_profile("my_profile", max_examples=200, deadline=None)
|
||||
settings.load_profile("my_profile")
|
||||
|
@ -59,15 +59,7 @@ class TestDType(unittest.TestCase):
|
|||
@classmethod
|
||||
def setUpClass(cls):
|
||||
if not cls.DTYPE or not is_dtype_supported(cls.DTYPE): raise unittest.SkipTest("dtype not supported")
|
||||
DATA_SIZE = 10
|
||||
if dtypes.is_unsigned(cls.DTYPE):
|
||||
cls.DATA = np.random.randint(0, 100, size=DATA_SIZE, dtype=cls.DTYPE.np)
|
||||
elif dtypes.is_int(cls.DTYPE):
|
||||
cls.DATA = np.random.randint(-100, 100, size=DATA_SIZE, dtype=cls.DTYPE.np)
|
||||
elif cls.DTYPE == dtypes.bool:
|
||||
cls.DATA = np.random.choice([True, False], size=DATA_SIZE)
|
||||
else:
|
||||
cls.DATA = np.random.uniform(-10, 10, size=DATA_SIZE).astype(cls.DTYPE.np)
|
||||
cls.DATA = rand_for_dtype(cls.DTYPE, 10)
|
||||
def setUp(self):
|
||||
if self.DTYPE is None: raise unittest.SkipTest("base class")
|
||||
|
||||
|
|
Loading…
Reference in New Issue