mirror of https://github.com/commaai/tinygrad.git
fix dtypes helpers for integers (#2716)
* scalar * maybe do this instead * Revert "scalar" everything is a scalar * add tests in test_dtype * fuzz testing + fix unsigned ints * fuzz everything
This commit is contained in:
parent
bc3c4ce50b
commit
a43bc78804
|
@ -4,6 +4,7 @@ from tinygrad.helpers import CI, DTYPES_DICT, getenv, DType, DEBUG, ImageDType,
|
||||||
from tinygrad import Device
|
from tinygrad import Device
|
||||||
from tinygrad.tensor import Tensor, dtypes
|
from tinygrad.tensor import Tensor, dtypes
|
||||||
from typing import Any, List
|
from typing import Any, List
|
||||||
|
from hypothesis import given, strategies as st
|
||||||
|
|
||||||
def is_dtype_supported(dtype: DType):
|
def is_dtype_supported(dtype: DType):
|
||||||
# for GPU, cl_khr_fp16 isn't supported (except now we don't need it!)
|
# for GPU, cl_khr_fp16 isn't supported (except now we don't need it!)
|
||||||
|
@ -49,7 +50,7 @@ class TestDType(unittest.TestCase):
|
||||||
DATA: Any = None
|
DATA: Any = None
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
if not is_dtype_supported(cls.DTYPE): raise unittest.SkipTest("dtype not supported")
|
if not cls.DTYPE or not is_dtype_supported(cls.DTYPE): raise unittest.SkipTest("dtype not supported")
|
||||||
cls.DATA = np.random.randint(0, 100, size=10, dtype=cls.DTYPE.np).tolist() if dtypes.is_int(cls.DTYPE) else np.random.choice([True, False], size=10).tolist() if cls.DTYPE == dtypes.bool else np.random.uniform(0, 1, size=10).tolist()
|
cls.DATA = np.random.randint(0, 100, size=10, dtype=cls.DTYPE.np).tolist() if dtypes.is_int(cls.DTYPE) else np.random.choice([True, False], size=10).tolist() if cls.DTYPE == dtypes.bool else np.random.uniform(0, 1, size=10).tolist()
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
if self.DTYPE is None: raise unittest.SkipTest("base class")
|
if self.DTYPE is None: raise unittest.SkipTest("base class")
|
||||||
|
@ -189,5 +190,33 @@ class TestEqStrDType(unittest.TestCase):
|
||||||
self.assertEqual(str(dtypes.imagef((1,2,4))), "dtypes.imagef((1, 2, 4))")
|
self.assertEqual(str(dtypes.imagef((1,2,4))), "dtypes.imagef((1, 2, 4))")
|
||||||
self.assertEqual(str(PtrDType(dtypes.float32)), "ptr.dtypes.float")
|
self.assertEqual(str(PtrDType(dtypes.float32)), "ptr.dtypes.float")
|
||||||
|
|
||||||
|
class TestHelpers(unittest.TestCase):
|
||||||
|
signed_ints = (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64)
|
||||||
|
uints = (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
|
||||||
|
floats = (dtypes.float16, dtypes.float32, dtypes.float64)
|
||||||
|
|
||||||
|
@given(st.sampled_from(signed_ints+uints), st.integers(min_value=1, max_value=8))
|
||||||
|
def test_is_int(self, dtype, amt):
|
||||||
|
assert dtypes.is_int(dtype.vec(amt) if amt > 1 else dtype)
|
||||||
|
assert not dtypes.is_float(dtype.vec(amt) if amt > 1 else dtype)
|
||||||
|
|
||||||
|
@given(st.sampled_from(uints), st.integers(min_value=1, max_value=8))
|
||||||
|
def test_is_unsigned_uints(self, dtype, amt):
|
||||||
|
assert dtypes.is_unsigned(dtype.vec(amt) if amt > 1 else dtype)
|
||||||
|
|
||||||
|
@given(st.sampled_from(signed_ints), st.integers(min_value=1, max_value=8))
|
||||||
|
def test_is_unsigned_signed_ints(self, dtype, amt):
|
||||||
|
assert not dtypes.is_unsigned(dtype.vec(amt) if amt > 1 else dtype)
|
||||||
|
|
||||||
|
@given(st.sampled_from(floats), st.integers(min_value=1, max_value=8))
|
||||||
|
def test_is_float(self, dtype, amt):
|
||||||
|
assert dtypes.is_float(dtype.vec(amt) if amt > 1 else dtype)
|
||||||
|
assert not dtypes.is_int(dtype.vec(amt) if amt > 1 else dtype)
|
||||||
|
assert not dtypes.is_unsigned(dtype.vec(amt) if amt > 1 else dtype)
|
||||||
|
|
||||||
|
@given(st.sampled_from([d for d in DTYPES_DICT.values() if dtypes.is_float(d) or dtypes.is_int(d)]), st.integers(min_value=2, max_value=8))
|
||||||
|
def test_scalar(self, dtype, amt):
|
||||||
|
assert dtype.vec(amt).scalar() == dtype
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
|
@ -112,7 +112,7 @@ class DType(NamedTuple):
|
||||||
def __repr__(self): return f"dtypes.{INVERSE_DTYPES_DICT[self]}" if self.sz == 1 else f"dtypes._{INVERSE_DTYPES_DICT[self.scalar()]}{self.sz}"
|
def __repr__(self): return f"dtypes.{INVERSE_DTYPES_DICT[self]}" if self.sz == 1 else f"dtypes._{INVERSE_DTYPES_DICT[self.scalar()]}{self.sz}"
|
||||||
def vec(self, sz:int):
|
def vec(self, sz:int):
|
||||||
assert sz > 1 and self.sz == 1, f"can't vectorize {self} with size {sz}"
|
assert sz > 1 and self.sz == 1, f"can't vectorize {self} with size {sz}"
|
||||||
return DType(self.priority, self.itemsize*sz, self.name+str(sz), None, sz)
|
return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self]}{str(sz)}", None, sz)
|
||||||
def scalar(self): return DTYPES_DICT[self.name[:-len(str(self.sz))]] if self.sz > 1 else self
|
def scalar(self): return DTYPES_DICT[self.name[:-len(str(self.sz))]] if self.sz > 1 else self
|
||||||
|
|
||||||
# dependent typing?
|
# dependent typing?
|
||||||
|
@ -137,11 +137,11 @@ class PtrDType(DType):
|
||||||
|
|
||||||
class dtypes:
|
class dtypes:
|
||||||
@staticmethod # static methds on top, or bool in the type info will refer to dtypes.bool
|
@staticmethod # static methds on top, or bool in the type info will refer to dtypes.bool
|
||||||
def is_int(x: DType)-> bool: return x in (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64, dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
|
def is_int(x: DType)-> bool: return x.scalar() in (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64, dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def is_float(x: DType) -> bool: return x in (dtypes.float16, dtypes.float32, dtypes.float64, dtypes.half.vec(4), dtypes.float.vec(2), dtypes.float.vec(4))
|
def is_float(x: DType) -> bool: return x.scalar() in (dtypes.float16, dtypes.float32, dtypes.float64)
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def is_unsigned(x: DType) -> bool: return x in (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
|
def is_unsigned(x: DType) -> bool: return x.scalar() in (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_np(x) -> DType: return DTYPES_DICT[np.dtype(x).name]
|
def from_np(x) -> DType: return DTYPES_DICT[np.dtype(x).name]
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -154,14 +154,21 @@ class dtypes:
|
||||||
float64: Final[DType] = DType(11, 8, "double", np.float64)
|
float64: Final[DType] = DType(11, 8, "double", np.float64)
|
||||||
double = float64
|
double = float64
|
||||||
int8: Final[DType] = DType(1, 1, "char", np.int8)
|
int8: Final[DType] = DType(1, 1, "char", np.int8)
|
||||||
|
char = int8
|
||||||
int16: Final[DType] = DType(3, 2, "short", np.int16)
|
int16: Final[DType] = DType(3, 2, "short", np.int16)
|
||||||
|
short = int16
|
||||||
int32: Final[DType] = DType(5, 4, "int", np.int32)
|
int32: Final[DType] = DType(5, 4, "int", np.int32)
|
||||||
int = int32
|
int = int32
|
||||||
int64: Final[DType] = DType(7, 8, "long", np.int64)
|
int64: Final[DType] = DType(7, 8, "long", np.int64)
|
||||||
|
long = int64
|
||||||
uint8: Final[DType] = DType(2, 1, "unsigned char", np.uint8)
|
uint8: Final[DType] = DType(2, 1, "unsigned char", np.uint8)
|
||||||
|
uchar = uint8
|
||||||
uint16: Final[DType] = DType(4, 2, "unsigned short", np.uint16)
|
uint16: Final[DType] = DType(4, 2, "unsigned short", np.uint16)
|
||||||
|
ushort = uint16
|
||||||
uint32: Final[DType] = DType(6, 4, "unsigned int", np.uint32)
|
uint32: Final[DType] = DType(6, 4, "unsigned int", np.uint32)
|
||||||
|
uint = uint32
|
||||||
uint64: Final[DType] = DType(8, 8, "unsigned long", np.uint64)
|
uint64: Final[DType] = DType(8, 8, "unsigned long", np.uint64)
|
||||||
|
ulong = uint64
|
||||||
|
|
||||||
# NOTE: bfloat16 isn't supported in numpy
|
# NOTE: bfloat16 isn't supported in numpy
|
||||||
bfloat16: Final[DType] = DType(9, 2, "__bf16", None)
|
bfloat16: Final[DType] = DType(9, 2, "__bf16", None)
|
||||||
|
|
Loading…
Reference in New Issue