dtype promotion helpers (#2724)

* dtype promotion helpers

* better tests

* space
This commit is contained in:
chenyu 2023-12-11 23:14:23 -05:00 committed by GitHub
parent 0232db294d
commit ef6e942a23
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 48 additions and 2 deletions

View File

@ -1,6 +1,6 @@
import unittest
import numpy as np
from tinygrad.helpers import CI, DTYPES_DICT, getenv, DType, DEBUG, ImageDType, PtrDType, OSX, temp
from tinygrad.helpers import CI, DTYPES_DICT, getenv, DType, DEBUG, ImageDType, PtrDType, OSX, temp, least_upper_dtype
from tinygrad import Device
from tinygrad.tensor import Tensor, dtypes
from typing import Any, List
@ -234,5 +234,37 @@ class TestTypeSpec(unittest.TestCase):
assert Tensor.ones([2,3,0]).sum(2).dtype == Tensor.default_type
# assert Tensor.ones([2,3,0], dtype=dtypes.int).sum(2).dtype == dtypes.int
# TODO: better way to write a set of core dtypes?
core_types = [d for d in DTYPES_DICT.values() if d not in [dtypes._arg_int32]]
class TestTypePromotion(unittest.TestCase):
@given(st.sampled_from(core_types))
def test_self_promo_to_self(self, dtype):
assert least_upper_dtype(dtype) == dtype
assert least_upper_dtype(dtype, dtype) == dtype
assert least_upper_dtype(dtype, dtype, dtype) == dtype
@given(st.sampled_from(core_types), st.sampled_from(core_types))
def test_promo_resulted_higher_than_inputs(self, dtype1, dtype2):
result = least_upper_dtype(dtype1, dtype2)
assert result >= dtype1 and result >= dtype2
def test_dtype_promo(self):
assert least_upper_dtype(dtypes.bool, dtypes.int8) == dtypes.int8
assert least_upper_dtype(dtypes.int8, dtypes.uint8) == dtypes.int16
assert least_upper_dtype(dtypes.uint8, dtypes.int16) == dtypes.int16
assert least_upper_dtype(dtypes.int16, dtypes.uint16) == dtypes.int32
assert least_upper_dtype(dtypes.uint16, dtypes.int32) == dtypes.int32
assert least_upper_dtype(dtypes.int32, dtypes.uint32) == dtypes.int64
assert least_upper_dtype(dtypes.uint32, dtypes.int64) == dtypes.int64
# special!
assert least_upper_dtype(dtypes.int64, dtypes.uint64) == dtypes.float_scalar
assert least_upper_dtype(dtypes.float_scalar, dtypes.float16) == dtypes.float16
assert least_upper_dtype(dtypes.float16, dtypes.float32) == dtypes.float32
assert least_upper_dtype(dtypes.float32, dtypes.float64) == dtypes.float64
assert least_upper_dtype(dtypes.bool, dtypes.float32) == dtypes.float32
assert least_upper_dtype(dtypes.bool, dtypes.float64) == dtypes.float64
if __name__ == '__main__':
unittest.main()

View File

@ -3,7 +3,7 @@ import os, functools, platform, time, re, contextlib, operator, hashlib, pickle,
import numpy as np
from urllib import request
from tqdm import tqdm
from typing import Dict, Tuple, Union, List, NamedTuple, Final, ClassVar, Optional, Iterable, Any, TypeVar, TYPE_CHECKING, Callable
from typing import Dict, Tuple, Union, List, NamedTuple, Final, ClassVar, Optional, Iterable, Any, TypeVar, TYPE_CHECKING, Callable, Set
if TYPE_CHECKING: # TODO: remove this and import TypeGuard from typing once minimum python supported version is 3.10
from typing_extensions import TypeGuard
@ -148,12 +148,15 @@ class dtypes:
def fields() -> Dict[str, DType]: return DTYPES_DICT
bool: Final[DType] = DType(0, 1, "bool", np.bool_)
float16: Final[DType] = DType(9, 2, "half", np.float16)
# TODO: make float_scalar and int_scalar real and link to default float and int dtype
float_scalar = float16
half = float16
float32: Final[DType] = DType(10, 4, "float", np.float32)
float = float32
float64: Final[DType] = DType(11, 8, "double", np.float64)
double = float64
int8: Final[DType] = DType(1, 1, "char", np.int8)
int_scalar = int8
char = int8
int16: Final[DType] = DType(3, 2, "short", np.int16)
short = int16
@ -182,6 +185,17 @@ class dtypes:
@staticmethod
def imagef(shp): return ImageDType(100, 4, "imagef", np.float32, shp, dtypes.float32)
# https://jax.readthedocs.io/en/latest/jep/9407-type-promotion.html
promo_lattice = { dtypes.bool: [dtypes.int_scalar],
dtypes.int_scalar: [dtypes.uint8, dtypes.int8], dtypes.int8: [dtypes.int16], dtypes.int16: [dtypes.int32], dtypes.int32: [dtypes.int64], dtypes.int64: [dtypes.float_scalar],
dtypes.uint8: [dtypes.int16, dtypes.uint16], dtypes.uint16: [dtypes.int32, dtypes.uint32], dtypes.uint32: [dtypes.int64, dtypes.uint64], dtypes.uint64: [dtypes.float_scalar],
dtypes.float_scalar: [dtypes.float16, dtypes.bfloat16], dtypes.float16: [dtypes.float32], dtypes.bfloat16: [dtypes.float32], dtypes.float32: [dtypes.float64], }
@functools.lru_cache(None)
def _get_recursive_parents(dtype:DType) -> Set[DType]: return set.union(*[_get_recursive_parents(d) for d in promo_lattice[dtype]], {dtype}) if dtype != dtypes.float64 else {dtypes.float64}
@functools.lru_cache(None)
def least_upper_dtype(*ds:DType) -> DType: return min(set.intersection(*[_get_recursive_parents(d) for d in ds]))
# HACK: staticmethods are not callable in 3.8 so we have to compare the class
DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not k.startswith('__') and not callable(v) and v.__class__ is not staticmethod}
INVERSE_DTYPES_DICT = {v:k for k,v in DTYPES_DICT.items()}