From ef6e942a23fdcf0fb7ff5a7b82d27e201e7ad928 Mon Sep 17 00:00:00 2001
From: chenyu
Date: Mon, 11 Dec 2023 23:14:23 -0500
Subject: [PATCH] dtype promotion helpers (#2724)

* dtype promotion helpers

* better tests

* space
---
Reviewer note (text between '---' and the first 'diff --git' is not part of
the commit): line breaks and indentation were restored from a
whitespace-collapsed copy of this patch. Hunk line counts and the diffstat
were checked against the restored text; the two-space indentation and the
spacing inside context lines are inferred, so run `git apply --check` (or
apply with `--ignore-whitespace`) before relying on it.

 test/test_dtype.py  | 34 +++++++++++++++++++++++++++++++++-
 tinygrad/helpers.py | 16 +++++++++++++++-
 2 files changed, 48 insertions(+), 2 deletions(-)

diff --git a/test/test_dtype.py b/test/test_dtype.py
index c4126a66..1735740f 100644
--- a/test/test_dtype.py
+++ b/test/test_dtype.py
@@ -1,6 +1,6 @@
 import unittest
 import numpy as np
-from tinygrad.helpers import CI, DTYPES_DICT, getenv, DType, DEBUG, ImageDType, PtrDType, OSX, temp
+from tinygrad.helpers import CI, DTYPES_DICT, getenv, DType, DEBUG, ImageDType, PtrDType, OSX, temp, least_upper_dtype
 from tinygrad import Device
 from tinygrad.tensor import Tensor, dtypes
 from typing import Any, List
@@ -234,5 +234,37 @@ class TestTypeSpec(unittest.TestCase):
     assert Tensor.ones([2,3,0]).sum(2).dtype == Tensor.default_type
     # assert Tensor.ones([2,3,0], dtype=dtypes.int).sum(2).dtype == dtypes.int
 
+# TODO: better way to write a set of core dtypes?
+core_types = [d for d in DTYPES_DICT.values() if d not in [dtypes._arg_int32]]
+class TestTypePromotion(unittest.TestCase):
+  @given(st.sampled_from(core_types))
+  def test_self_promo_to_self(self, dtype):
+    assert least_upper_dtype(dtype) == dtype
+    assert least_upper_dtype(dtype, dtype) == dtype
+    assert least_upper_dtype(dtype, dtype, dtype) == dtype
+
+  @given(st.sampled_from(core_types), st.sampled_from(core_types))
+  def test_promo_resulted_higher_than_inputs(self, dtype1, dtype2):
+    result = least_upper_dtype(dtype1, dtype2)
+    assert result >= dtype1 and result >= dtype2
+
+  def test_dtype_promo(self):
+    assert least_upper_dtype(dtypes.bool, dtypes.int8) == dtypes.int8
+    assert least_upper_dtype(dtypes.int8, dtypes.uint8) == dtypes.int16
+    assert least_upper_dtype(dtypes.uint8, dtypes.int16) == dtypes.int16
+    assert least_upper_dtype(dtypes.int16, dtypes.uint16) == dtypes.int32
+    assert least_upper_dtype(dtypes.uint16, dtypes.int32) == dtypes.int32
+    assert least_upper_dtype(dtypes.int32, dtypes.uint32) == dtypes.int64
+    assert least_upper_dtype(dtypes.uint32, dtypes.int64) == dtypes.int64
+    # special!
+    assert least_upper_dtype(dtypes.int64, dtypes.uint64) == dtypes.float_scalar
+    assert least_upper_dtype(dtypes.float_scalar, dtypes.float16) == dtypes.float16
+    assert least_upper_dtype(dtypes.float16, dtypes.float32) == dtypes.float32
+    assert least_upper_dtype(dtypes.float32, dtypes.float64) == dtypes.float64
+
+    assert least_upper_dtype(dtypes.bool, dtypes.float32) == dtypes.float32
+    assert least_upper_dtype(dtypes.bool, dtypes.float64) == dtypes.float64
+
+
 if __name__ == '__main__':
   unittest.main()
diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py
index 1d7ec968..bc1ff517 100644
--- a/tinygrad/helpers.py
+++ b/tinygrad/helpers.py
@@ -3,7 +3,7 @@ import os, functools, platform, time, re, contextlib, operator, hashlib, pickle,
 import numpy as np
 from urllib import request
 from tqdm import tqdm
-from typing import Dict, Tuple, Union, List, NamedTuple, Final, ClassVar, Optional, Iterable, Any, TypeVar, TYPE_CHECKING, Callable
+from typing import Dict, Tuple, Union, List, NamedTuple, Final, ClassVar, Optional, Iterable, Any, TypeVar, TYPE_CHECKING, Callable, Set
 if TYPE_CHECKING: # TODO: remove this and import TypeGuard from typing once minimum python supported version is 3.10
   from typing_extensions import TypeGuard
 
@@ -148,12 +148,15 @@ class dtypes:
   def fields() -> Dict[str, DType]: return DTYPES_DICT
   bool: Final[DType] = DType(0, 1, "bool", np.bool_)
   float16: Final[DType] = DType(9, 2, "half", np.float16)
+  # TODO: make float_scalar and int_scalar real and link to default float and int dtype
+  float_scalar = float16
   half = float16
   float32: Final[DType] = DType(10, 4, "float", np.float32)
   float = float32
   float64: Final[DType] = DType(11, 8, "double", np.float64)
   double = float64
   int8: Final[DType] = DType(1, 1, "char", np.int8)
+  int_scalar = int8
   char = int8
   int16: Final[DType] = DType(3, 2, "short", np.int16)
   short = int16
@@ -182,6 +185,17 @@ class dtypes:
   @staticmethod
   def imagef(shp): return ImageDType(100, 4, "imagef", np.float32, shp, dtypes.float32)
 
+# https://jax.readthedocs.io/en/latest/jep/9407-type-promotion.html
+promo_lattice = { dtypes.bool: [dtypes.int_scalar],
+  dtypes.int_scalar: [dtypes.uint8, dtypes.int8], dtypes.int8: [dtypes.int16], dtypes.int16: [dtypes.int32], dtypes.int32: [dtypes.int64], dtypes.int64: [dtypes.float_scalar],
+  dtypes.uint8: [dtypes.int16, dtypes.uint16], dtypes.uint16: [dtypes.int32, dtypes.uint32], dtypes.uint32: [dtypes.int64, dtypes.uint64], dtypes.uint64: [dtypes.float_scalar],
+  dtypes.float_scalar: [dtypes.float16, dtypes.bfloat16], dtypes.float16: [dtypes.float32], dtypes.bfloat16: [dtypes.float32], dtypes.float32: [dtypes.float64], }
+
+@functools.lru_cache(None)
+def _get_recursive_parents(dtype:DType) -> Set[DType]: return set.union(*[_get_recursive_parents(d) for d in promo_lattice[dtype]], {dtype}) if dtype != dtypes.float64 else {dtypes.float64}
+@functools.lru_cache(None)
+def least_upper_dtype(*ds:DType) -> DType: return min(set.intersection(*[_get_recursive_parents(d) for d in ds]))
+
 # HACK: staticmethods are not callable in 3.8 so we have to compare the class
 DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not k.startswith('__') and not callable(v) and v.__class__ is not staticmethod}
 INVERSE_DTYPES_DICT = {v:k for k,v in DTYPES_DICT.items()}