diff --git a/test/test_dtype.py b/test/test_dtype.py
index a1acfca7..c9fe72da 100644
--- a/test/test_dtype.py
+++ b/test/test_dtype.py
@@ -1,11 +1,12 @@
 # ruff: noqa: E501
 import unittest
 import numpy as np
+import torch
 from tinygrad.helpers import CI, DTYPES_DICT, getenv, DType, DEBUG, ImageDType, PtrDType, OSX, temp, least_upper_dtype
 from tinygrad import Device
 from tinygrad.tensor import Tensor, dtypes
 from typing import Any, List
-from hypothesis import given, strategies as st
+from hypothesis import given, settings, strategies as st
 
 def is_dtype_supported(dtype: DType):
   # for GPU, cl_khr_fp16 isn't supported (except now we don't need it!)
@@ -267,6 +268,25 @@ class TestTypePromotion(unittest.TestCase):
     assert least_upper_dtype(dtypes.float16, dtypes.int64) == dtypes.float16
     assert least_upper_dtype(dtypes.float16, dtypes.uint64) == dtypes.float16
 
+class TestAutoCastType(unittest.TestCase):
+  @given(st.sampled_from([d for d in DTYPES_DICT.values() if dtypes.is_int(d) and is_dtype_supported(d)]))
+  @settings(deadline=None)
+  def test_int_to_float_unary_func(self, dtype):
+    for func in [
+      lambda t: t.exp(),
+      # lambda t: t.exp2(),  # requires MUL
+      lambda t: t.log(),
+      lambda t: t.log2(),
+      lambda t: t.sqrt(),
+      # lambda t: t.rsqrt(),  # requires DIV
+      lambda t: t.sin(),
+      # lambda t: t.cos(),  # requires SUB
+      # lambda t: t.tan(),  # requires .cos() to work
+      lambda t: t.sigmoid(),
+    ]:
+      a = [2, 3, 4]
+      np.testing.assert_allclose(func(Tensor(a, dtype=dtype)).numpy(), func(torch.tensor(a)), rtol=1e-4, atol=1e-4)
+
 if __name__ == '__main__':
   unittest.main()
diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py
index 25241680..0e0f820d 100644
--- a/tinygrad/helpers.py
+++ b/tinygrad/helpers.py
@@ -195,7 +195,8 @@ promo_lattice = { dtypes.bool: [dtypes.int8, dtypes.uint8],
 def _get_recursive_parents(dtype:DType) -> Set[DType]:
   return set.union(*[_get_recursive_parents(d) for d in promo_lattice[dtype]], {dtype}) if dtype != dtypes.float64 else {dtypes.float64}
 @functools.lru_cache(None)
-def least_upper_dtype(*ds:DType) -> DType: return min(set.intersection(*[_get_recursive_parents(d) for d in ds]))
+def least_upper_dtype(*ds:DType) -> DType:
+  return min(set.intersection(*[_get_recursive_parents(d) for d in ds])) if not (images:=[d for d in ds if isinstance(d, ImageDType)]) else images[0]
 
 # HACK: staticmethods are not callable in 3.8 so we have to compare the class
 DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not k.startswith('__') and not callable(v) and v.__class__ is not staticmethod}
diff --git a/tinygrad/mlops.py b/tinygrad/mlops.py
index ae937da9..db98e81e 100644
--- a/tinygrad/mlops.py
+++ b/tinygrad/mlops.py
@@ -1,6 +1,6 @@
 import math
 from typing import Tuple, Optional, cast
-from tinygrad.helpers import argsort, DType
+from tinygrad.helpers import argsort, DType, dtypes, least_upper_dtype
 from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps
 from tinygrad.tensor import Function
 from tinygrad.lazy import LazyBuffer
@@ -35,7 +35,7 @@ class Neg(Function):
 class Sin(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer:
     self.x = x
-    return x.e(UnaryOps.SIN)
+    return x.cast(least_upper_dtype(x.dtype, dtypes.float)).e(UnaryOps.SIN)
 
   def backward(self, grad:LazyBuffer) -> LazyBuffer:
     return self.x.const(math.pi / 2).e(BinaryOps.SUB, self.x).e(UnaryOps.SIN).e(BinaryOps.MUL, grad)
@@ -52,14 +52,14 @@ class Relu(Function):
 class Log(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer:
     self.x = x
-    return x.e(UnaryOps.LOG2).e(BinaryOps.MUL, x.const(math.log(2)))
+    return x.cast(ftype:=least_upper_dtype(x.dtype, dtypes.float)).e(UnaryOps.LOG2).e(BinaryOps.MUL, x.cast(ftype).const(math.log(2)))
 
   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
     return grad_output.e(BinaryOps.DIV, self.x)
 
 class Exp(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer:
-    self.ret = x.e(BinaryOps.MUL, x.const(1/math.log(2))).e(UnaryOps.EXP2)
+    self.ret = x.cast(ftype:=least_upper_dtype(x.dtype, dtypes.float)).e(BinaryOps.MUL, x.cast(ftype).const(1/math.log(2))).e(UnaryOps.EXP2)
     return self.ret
 
   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
@@ -67,7 +67,7 @@ class Exp(Function):
 
 class Sqrt(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer:
-    self.ret = x.e(UnaryOps.SQRT)
+    self.ret = x.cast(least_upper_dtype(x.dtype, dtypes.float)).e(UnaryOps.SQRT)
     return self.ret
 
   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
@@ -78,7 +78,8 @@ class Sqrt(Function):
 # TODO: have the backend automatically find this
 class Sigmoid(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer:
-    self.ret = x.const(1).e(BinaryOps.DIV, x.const(1).e(BinaryOps.ADD, x.e(BinaryOps.MUL, x.const(-1/math.log(2))).e(UnaryOps.EXP2)))
+    self.ret = x.cast(ftype:=least_upper_dtype(x.dtype, dtypes.float)).const(1).e(
+      BinaryOps.DIV, x.cast(ftype).const(1).e(BinaryOps.ADD, x.cast(ftype).e(BinaryOps.MUL, x.cast(ftype).const(-1/math.log(2))).e(UnaryOps.EXP2)))
     return self.ret
 
   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
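
For context, a minimal usage sketch of what this patch is expected to enable (not part of the diff; it assumes the tinygrad API shown above and a float32 default float dtype, and the exact dtypes available depend on the backend):

# hypothetical illustration, not part of the patch
from tinygrad.tensor import Tensor, dtypes
from tinygrad.helpers import least_upper_dtype

# int32 and float32 meet at float32 in the promotion lattice used by least_upper_dtype
assert least_upper_dtype(dtypes.int32, dtypes.float32) == dtypes.float32

# with the mlops changes, unary ops that mathematically return floats
# should upcast integer inputs instead of running the op in an int dtype
t = Tensor([2, 3, 4], dtype=dtypes.int32)
print(t.sqrt().numpy())  # roughly [1.414, 1.732, 2.0]
print(t.exp().dtype)     # expected to be a float dtype after this change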