mirror of https://github.com/commaai/tinygrad.git
dtype promotion helpers (#2724)
* dtype promotion helpers * better tests * space
This commit is contained in:
parent
0232db294d
commit
ef6e942a23
|
@ -1,6 +1,6 @@
|
|||
import unittest
|
||||
import numpy as np
|
||||
from tinygrad.helpers import CI, DTYPES_DICT, getenv, DType, DEBUG, ImageDType, PtrDType, OSX, temp
|
||||
from tinygrad.helpers import CI, DTYPES_DICT, getenv, DType, DEBUG, ImageDType, PtrDType, OSX, temp, least_upper_dtype
|
||||
from tinygrad import Device
|
||||
from tinygrad.tensor import Tensor, dtypes
|
||||
from typing import Any, List
|
||||
|
@ -234,5 +234,37 @@ class TestTypeSpec(unittest.TestCase):
|
|||
assert Tensor.ones([2,3,0]).sum(2).dtype == Tensor.default_type
|
||||
# assert Tensor.ones([2,3,0], dtype=dtypes.int).sum(2).dtype == dtypes.int
|
||||
|
||||
# TODO: better way to write a set of core dtypes?
|
||||
core_types = [d for d in DTYPES_DICT.values() if d not in [dtypes._arg_int32]]
|
||||
class TestTypePromotion(unittest.TestCase):
|
||||
@given(st.sampled_from(core_types))
|
||||
def test_self_promo_to_self(self, dtype):
|
||||
assert least_upper_dtype(dtype) == dtype
|
||||
assert least_upper_dtype(dtype, dtype) == dtype
|
||||
assert least_upper_dtype(dtype, dtype, dtype) == dtype
|
||||
|
||||
@given(st.sampled_from(core_types), st.sampled_from(core_types))
|
||||
def test_promo_resulted_higher_than_inputs(self, dtype1, dtype2):
|
||||
result = least_upper_dtype(dtype1, dtype2)
|
||||
assert result >= dtype1 and result >= dtype2
|
||||
|
||||
def test_dtype_promo(self):
|
||||
assert least_upper_dtype(dtypes.bool, dtypes.int8) == dtypes.int8
|
||||
assert least_upper_dtype(dtypes.int8, dtypes.uint8) == dtypes.int16
|
||||
assert least_upper_dtype(dtypes.uint8, dtypes.int16) == dtypes.int16
|
||||
assert least_upper_dtype(dtypes.int16, dtypes.uint16) == dtypes.int32
|
||||
assert least_upper_dtype(dtypes.uint16, dtypes.int32) == dtypes.int32
|
||||
assert least_upper_dtype(dtypes.int32, dtypes.uint32) == dtypes.int64
|
||||
assert least_upper_dtype(dtypes.uint32, dtypes.int64) == dtypes.int64
|
||||
# special!
|
||||
assert least_upper_dtype(dtypes.int64, dtypes.uint64) == dtypes.float_scalar
|
||||
assert least_upper_dtype(dtypes.float_scalar, dtypes.float16) == dtypes.float16
|
||||
assert least_upper_dtype(dtypes.float16, dtypes.float32) == dtypes.float32
|
||||
assert least_upper_dtype(dtypes.float32, dtypes.float64) == dtypes.float64
|
||||
|
||||
assert least_upper_dtype(dtypes.bool, dtypes.float32) == dtypes.float32
|
||||
assert least_upper_dtype(dtypes.bool, dtypes.float64) == dtypes.float64
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
@ -3,7 +3,7 @@ import os, functools, platform, time, re, contextlib, operator, hashlib, pickle,
|
|||
import numpy as np
|
||||
from urllib import request
|
||||
from tqdm import tqdm
|
||||
from typing import Dict, Tuple, Union, List, NamedTuple, Final, ClassVar, Optional, Iterable, Any, TypeVar, TYPE_CHECKING, Callable
|
||||
from typing import Dict, Tuple, Union, List, NamedTuple, Final, ClassVar, Optional, Iterable, Any, TypeVar, TYPE_CHECKING, Callable, Set
|
||||
if TYPE_CHECKING: # TODO: remove this and import TypeGuard from typing once minimum python supported version is 3.10
|
||||
from typing_extensions import TypeGuard
|
||||
|
||||
|
@ -148,12 +148,15 @@ class dtypes:
|
|||
def fields() -> Dict[str, DType]: return DTYPES_DICT
|
||||
bool: Final[DType] = DType(0, 1, "bool", np.bool_)
|
||||
float16: Final[DType] = DType(9, 2, "half", np.float16)
|
||||
# TODO: make float_scalar and int_scalar real and link to default float and int dtype
|
||||
float_scalar = float16
|
||||
half = float16
|
||||
float32: Final[DType] = DType(10, 4, "float", np.float32)
|
||||
float = float32
|
||||
float64: Final[DType] = DType(11, 8, "double", np.float64)
|
||||
double = float64
|
||||
int8: Final[DType] = DType(1, 1, "char", np.int8)
|
||||
int_scalar = int8
|
||||
char = int8
|
||||
int16: Final[DType] = DType(3, 2, "short", np.int16)
|
||||
short = int16
|
||||
|
@ -182,6 +185,17 @@ class dtypes:
|
|||
@staticmethod
|
||||
def imagef(shp): return ImageDType(100, 4, "imagef", np.float32, shp, dtypes.float32)
|
||||
|
||||
# https://jax.readthedocs.io/en/latest/jep/9407-type-promotion.html
|
||||
promo_lattice = { dtypes.bool: [dtypes.int_scalar],
|
||||
dtypes.int_scalar: [dtypes.uint8, dtypes.int8], dtypes.int8: [dtypes.int16], dtypes.int16: [dtypes.int32], dtypes.int32: [dtypes.int64], dtypes.int64: [dtypes.float_scalar],
|
||||
dtypes.uint8: [dtypes.int16, dtypes.uint16], dtypes.uint16: [dtypes.int32, dtypes.uint32], dtypes.uint32: [dtypes.int64, dtypes.uint64], dtypes.uint64: [dtypes.float_scalar],
|
||||
dtypes.float_scalar: [dtypes.float16, dtypes.bfloat16], dtypes.float16: [dtypes.float32], dtypes.bfloat16: [dtypes.float32], dtypes.float32: [dtypes.float64], }
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def _get_recursive_parents(dtype:DType) -> Set[DType]: return set.union(*[_get_recursive_parents(d) for d in promo_lattice[dtype]], {dtype}) if dtype != dtypes.float64 else {dtypes.float64}
|
||||
@functools.lru_cache(None)
|
||||
def least_upper_dtype(*ds:DType) -> DType: return min(set.intersection(*[_get_recursive_parents(d) for d in ds]))
|
||||
|
||||
# HACK: staticmethods are not callable in 3.8 so we have to compare the class
|
||||
DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not k.startswith('__') and not callable(v) and v.__class__ is not staticmethod}
|
||||
INVERSE_DTYPES_DICT = {v:k for k,v in DTYPES_DICT.items()}
|
||||
|
|
Loading…
Reference in New Issue