fix dtypes helpers for integers (#2716)

* scalar

* maybe do this instead

* Revert "scalar"

everything is a scalar

* add tests in test_dtype

* fuzz testing + fix unsigned ints

* fuzz everything
This commit is contained in:
qazal 2023-12-11 19:28:19 +02:00 committed by GitHub
parent bc3c4ce50b
commit a43bc78804
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 41 additions and 5 deletions

View File

@ -4,6 +4,7 @@ from tinygrad.helpers import CI, DTYPES_DICT, getenv, DType, DEBUG, ImageDType,
from tinygrad import Device from tinygrad import Device
from tinygrad.tensor import Tensor, dtypes from tinygrad.tensor import Tensor, dtypes
from typing import Any, List from typing import Any, List
from hypothesis import given, strategies as st
def is_dtype_supported(dtype: DType): def is_dtype_supported(dtype: DType):
# for GPU, cl_khr_fp16 isn't supported (except now we don't need it!) # for GPU, cl_khr_fp16 isn't supported (except now we don't need it!)
@ -49,7 +50,7 @@ class TestDType(unittest.TestCase):
DATA: Any = None DATA: Any = None
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
if not is_dtype_supported(cls.DTYPE): raise unittest.SkipTest("dtype not supported") if not cls.DTYPE or not is_dtype_supported(cls.DTYPE): raise unittest.SkipTest("dtype not supported")
cls.DATA = np.random.randint(0, 100, size=10, dtype=cls.DTYPE.np).tolist() if dtypes.is_int(cls.DTYPE) else np.random.choice([True, False], size=10).tolist() if cls.DTYPE == dtypes.bool else np.random.uniform(0, 1, size=10).tolist() cls.DATA = np.random.randint(0, 100, size=10, dtype=cls.DTYPE.np).tolist() if dtypes.is_int(cls.DTYPE) else np.random.choice([True, False], size=10).tolist() if cls.DTYPE == dtypes.bool else np.random.uniform(0, 1, size=10).tolist()
def setUp(self): def setUp(self):
if self.DTYPE is None: raise unittest.SkipTest("base class") if self.DTYPE is None: raise unittest.SkipTest("base class")
@ -189,5 +190,33 @@ class TestEqStrDType(unittest.TestCase):
self.assertEqual(str(dtypes.imagef((1,2,4))), "dtypes.imagef((1, 2, 4))") self.assertEqual(str(dtypes.imagef((1,2,4))), "dtypes.imagef((1, 2, 4))")
self.assertEqual(str(PtrDType(dtypes.float32)), "ptr.dtypes.float") self.assertEqual(str(PtrDType(dtypes.float32)), "ptr.dtypes.float")
class TestHelpers(unittest.TestCase):
signed_ints = (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64)
uints = (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
floats = (dtypes.float16, dtypes.float32, dtypes.float64)
@given(st.sampled_from(signed_ints+uints), st.integers(min_value=1, max_value=8))
def test_is_int(self, dtype, amt):
assert dtypes.is_int(dtype.vec(amt) if amt > 1 else dtype)
assert not dtypes.is_float(dtype.vec(amt) if amt > 1 else dtype)
@given(st.sampled_from(uints), st.integers(min_value=1, max_value=8))
def test_is_unsigned_uints(self, dtype, amt):
assert dtypes.is_unsigned(dtype.vec(amt) if amt > 1 else dtype)
@given(st.sampled_from(signed_ints), st.integers(min_value=1, max_value=8))
def test_is_unsigned_signed_ints(self, dtype, amt):
assert not dtypes.is_unsigned(dtype.vec(amt) if amt > 1 else dtype)
@given(st.sampled_from(floats), st.integers(min_value=1, max_value=8))
def test_is_float(self, dtype, amt):
assert dtypes.is_float(dtype.vec(amt) if amt > 1 else dtype)
assert not dtypes.is_int(dtype.vec(amt) if amt > 1 else dtype)
assert not dtypes.is_unsigned(dtype.vec(amt) if amt > 1 else dtype)
@given(st.sampled_from([d for d in DTYPES_DICT.values() if dtypes.is_float(d) or dtypes.is_int(d)]), st.integers(min_value=2, max_value=8))
def test_scalar(self, dtype, amt):
assert dtype.vec(amt).scalar() == dtype
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -112,7 +112,7 @@ class DType(NamedTuple):
def __repr__(self): return f"dtypes.{INVERSE_DTYPES_DICT[self]}" if self.sz == 1 else f"dtypes._{INVERSE_DTYPES_DICT[self.scalar()]}{self.sz}" def __repr__(self): return f"dtypes.{INVERSE_DTYPES_DICT[self]}" if self.sz == 1 else f"dtypes._{INVERSE_DTYPES_DICT[self.scalar()]}{self.sz}"
def vec(self, sz:int): def vec(self, sz:int):
assert sz > 1 and self.sz == 1, f"can't vectorize {self} with size {sz}" assert sz > 1 and self.sz == 1, f"can't vectorize {self} with size {sz}"
return DType(self.priority, self.itemsize*sz, self.name+str(sz), None, sz) return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self]}{str(sz)}", None, sz)
def scalar(self): return DTYPES_DICT[self.name[:-len(str(self.sz))]] if self.sz > 1 else self def scalar(self): return DTYPES_DICT[self.name[:-len(str(self.sz))]] if self.sz > 1 else self
# dependent typing? # dependent typing?
@ -137,11 +137,11 @@ class PtrDType(DType):
class dtypes: class dtypes:
@staticmethod # static methds on top, or bool in the type info will refer to dtypes.bool @staticmethod # static methds on top, or bool in the type info will refer to dtypes.bool
def is_int(x: DType)-> bool: return x in (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64, dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64) def is_int(x: DType)-> bool: return x.scalar() in (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64, dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
@staticmethod @staticmethod
def is_float(x: DType) -> bool: return x in (dtypes.float16, dtypes.float32, dtypes.float64, dtypes.half.vec(4), dtypes.float.vec(2), dtypes.float.vec(4)) def is_float(x: DType) -> bool: return x.scalar() in (dtypes.float16, dtypes.float32, dtypes.float64)
@staticmethod @staticmethod
def is_unsigned(x: DType) -> bool: return x in (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64) def is_unsigned(x: DType) -> bool: return x.scalar() in (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
@staticmethod @staticmethod
def from_np(x) -> DType: return DTYPES_DICT[np.dtype(x).name] def from_np(x) -> DType: return DTYPES_DICT[np.dtype(x).name]
@staticmethod @staticmethod
@ -154,14 +154,21 @@ class dtypes:
float64: Final[DType] = DType(11, 8, "double", np.float64) float64: Final[DType] = DType(11, 8, "double", np.float64)
double = float64 double = float64
int8: Final[DType] = DType(1, 1, "char", np.int8) int8: Final[DType] = DType(1, 1, "char", np.int8)
char = int8
int16: Final[DType] = DType(3, 2, "short", np.int16) int16: Final[DType] = DType(3, 2, "short", np.int16)
short = int16
int32: Final[DType] = DType(5, 4, "int", np.int32) int32: Final[DType] = DType(5, 4, "int", np.int32)
int = int32 int = int32
int64: Final[DType] = DType(7, 8, "long", np.int64) int64: Final[DType] = DType(7, 8, "long", np.int64)
long = int64
uint8: Final[DType] = DType(2, 1, "unsigned char", np.uint8) uint8: Final[DType] = DType(2, 1, "unsigned char", np.uint8)
uchar = uint8
uint16: Final[DType] = DType(4, 2, "unsigned short", np.uint16) uint16: Final[DType] = DType(4, 2, "unsigned short", np.uint16)
ushort = uint16
uint32: Final[DType] = DType(6, 4, "unsigned int", np.uint32) uint32: Final[DType] = DType(6, 4, "unsigned int", np.uint32)
uint = uint32
uint64: Final[DType] = DType(8, 8, "unsigned long", np.uint64) uint64: Final[DType] = DType(8, 8, "unsigned long", np.uint64)
ulong = uint64
# NOTE: bfloat16 isn't supported in numpy # NOTE: bfloat16 isn't supported in numpy
bfloat16: Final[DType] = DType(9, 2, "__bf16", None) bfloat16: Final[DType] = DType(9, 2, "__bf16", None)