fix dtypes helpers for integers (#2716)

* scalar

* maybe do this instead

* Revert "scalar"

everything is a scalar

* add tests in test_dtype

* fuzz testing + fix unsigned ints

* fuzz everything
This commit is contained in:
qazal 2023-12-11 19:28:19 +02:00 committed by GitHub
parent bc3c4ce50b
commit a43bc78804
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 41 additions and 5 deletions

View File

@ -4,6 +4,7 @@ from tinygrad.helpers import CI, DTYPES_DICT, getenv, DType, DEBUG, ImageDType,
from tinygrad import Device
from tinygrad.tensor import Tensor, dtypes
from typing import Any, List
from hypothesis import given, strategies as st
def is_dtype_supported(dtype: DType):
# for GPU, cl_khr_fp16 isn't supported (except now we don't need it!)
@ -49,7 +50,7 @@ class TestDType(unittest.TestCase):
DATA: Any = None
@classmethod
def setUpClass(cls):
if not is_dtype_supported(cls.DTYPE): raise unittest.SkipTest("dtype not supported")
if not cls.DTYPE or not is_dtype_supported(cls.DTYPE): raise unittest.SkipTest("dtype not supported")
cls.DATA = np.random.randint(0, 100, size=10, dtype=cls.DTYPE.np).tolist() if dtypes.is_int(cls.DTYPE) else np.random.choice([True, False], size=10).tolist() if cls.DTYPE == dtypes.bool else np.random.uniform(0, 1, size=10).tolist()
def setUp(self):
if self.DTYPE is None: raise unittest.SkipTest("base class")
@ -189,5 +190,33 @@ class TestEqStrDType(unittest.TestCase):
self.assertEqual(str(dtypes.imagef((1,2,4))), "dtypes.imagef((1, 2, 4))")
self.assertEqual(str(PtrDType(dtypes.float32)), "ptr.dtypes.float")
class TestHelpers(unittest.TestCase):
signed_ints = (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64)
uints = (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
floats = (dtypes.float16, dtypes.float32, dtypes.float64)
@given(st.sampled_from(signed_ints+uints), st.integers(min_value=1, max_value=8))
def test_is_int(self, dtype, amt):
assert dtypes.is_int(dtype.vec(amt) if amt > 1 else dtype)
assert not dtypes.is_float(dtype.vec(amt) if amt > 1 else dtype)
@given(st.sampled_from(uints), st.integers(min_value=1, max_value=8))
def test_is_unsigned_uints(self, dtype, amt):
assert dtypes.is_unsigned(dtype.vec(amt) if amt > 1 else dtype)
@given(st.sampled_from(signed_ints), st.integers(min_value=1, max_value=8))
def test_is_unsigned_signed_ints(self, dtype, amt):
assert not dtypes.is_unsigned(dtype.vec(amt) if amt > 1 else dtype)
@given(st.sampled_from(floats), st.integers(min_value=1, max_value=8))
def test_is_float(self, dtype, amt):
assert dtypes.is_float(dtype.vec(amt) if amt > 1 else dtype)
assert not dtypes.is_int(dtype.vec(amt) if amt > 1 else dtype)
assert not dtypes.is_unsigned(dtype.vec(amt) if amt > 1 else dtype)
@given(st.sampled_from([d for d in DTYPES_DICT.values() if dtypes.is_float(d) or dtypes.is_int(d)]), st.integers(min_value=2, max_value=8))
def test_scalar(self, dtype, amt):
assert dtype.vec(amt).scalar() == dtype
if __name__ == '__main__':
unittest.main()

View File

@ -112,7 +112,7 @@ class DType(NamedTuple):
def __repr__(self): return f"dtypes.{INVERSE_DTYPES_DICT[self]}" if self.sz == 1 else f"dtypes._{INVERSE_DTYPES_DICT[self.scalar()]}{self.sz}"
def vec(self, sz:int):
assert sz > 1 and self.sz == 1, f"can't vectorize {self} with size {sz}"
return DType(self.priority, self.itemsize*sz, self.name+str(sz), None, sz)
return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self]}{str(sz)}", None, sz)
def scalar(self): return DTYPES_DICT[self.name[:-len(str(self.sz))]] if self.sz > 1 else self
# dependent typing?
@ -137,11 +137,11 @@ class PtrDType(DType):
class dtypes:
@staticmethod # static methds on top, or bool in the type info will refer to dtypes.bool
def is_int(x: DType)-> bool: return x in (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64, dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
def is_int(x: DType)-> bool: return x.scalar() in (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64, dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
@staticmethod
def is_float(x: DType) -> bool: return x in (dtypes.float16, dtypes.float32, dtypes.float64, dtypes.half.vec(4), dtypes.float.vec(2), dtypes.float.vec(4))
def is_float(x: DType) -> bool: return x.scalar() in (dtypes.float16, dtypes.float32, dtypes.float64)
@staticmethod
def is_unsigned(x: DType) -> bool: return x in (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
def is_unsigned(x: DType) -> bool: return x.scalar() in (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
@staticmethod
def from_np(x) -> DType: return DTYPES_DICT[np.dtype(x).name]
@staticmethod
@ -154,14 +154,21 @@ class dtypes:
float64: Final[DType] = DType(11, 8, "double", np.float64)
double = float64
int8: Final[DType] = DType(1, 1, "char", np.int8)
char = int8
int16: Final[DType] = DType(3, 2, "short", np.int16)
short = int16
int32: Final[DType] = DType(5, 4, "int", np.int32)
int = int32
int64: Final[DType] = DType(7, 8, "long", np.int64)
long = int64
uint8: Final[DType] = DType(2, 1, "unsigned char", np.uint8)
uchar = uint8
uint16: Final[DType] = DType(4, 2, "unsigned short", np.uint16)
ushort = uint16
uint32: Final[DType] = DType(6, 4, "unsigned int", np.uint32)
uint = uint32
uint64: Final[DType] = DType(8, 8, "unsigned long", np.uint64)
ulong = uint64
# NOTE: bfloat16 isn't supported in numpy
bfloat16: Final[DType] = DType(9, 2, "__bf16", None)