some unary functions cast int input into float (#2740)

* some unary functions cast int input into float

* precision

* image dtype
chenyu 2023-12-13 00:10:29 -05:00 committed by GitHub
parent 3e778fcc52
commit 2ef33abd20
3 changed files with 30 additions and 8 deletions
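
For context, here is a minimal sketch (not part of the commit) of the behavior the change targets: applying a unary float function to an integer tensor should first upcast the input to a float dtype. The int32 input and the float32 result shown are illustrative; the exact dtype depends on the backend.

# Sketch only: int inputs to exp/log/log2/sqrt/sin/sigmoid now go through a float cast.
from tinygrad.tensor import Tensor, dtypes

t = Tensor([2, 3, 4], dtype=dtypes.int32)
print(t.exp().dtype)     # expected: a float dtype (e.g. dtypes.float32), not an int dtype
print(t.sqrt().numpy())  # roughly [1.414, 1.732, 2.0]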

test/test_dtype.py

@@ -1,11 +1,12 @@
# ruff: noqa: E501
import unittest
import numpy as np
import torch
from tinygrad.helpers import CI, DTYPES_DICT, getenv, DType, DEBUG, ImageDType, PtrDType, OSX, temp, least_upper_dtype
from tinygrad import Device
from tinygrad.tensor import Tensor, dtypes
from typing import Any, List
from hypothesis import given, strategies as st
from hypothesis import given, settings, strategies as st
def is_dtype_supported(dtype: DType):
  # for GPU, cl_khr_fp16 isn't supported (except now we don't need it!)
@@ -267,6 +268,25 @@ class TestTypePromotion(unittest.TestCase):
    assert least_upper_dtype(dtypes.float16, dtypes.int64) == dtypes.float16
    assert least_upper_dtype(dtypes.float16, dtypes.uint64) == dtypes.float16

class TestAutoCastType(unittest.TestCase):
  @given(st.sampled_from([d for d in DTYPES_DICT.values() if dtypes.is_int(d) and is_dtype_supported(d)]))
  @settings(deadline=None)
  def test_int_to_float_unary_func(self, dtype):
    for func in [
      lambda t: t.exp(),
      # lambda t: t.exp2(), # requires MUL
      lambda t: t.log(),
      lambda t: t.log2(),
      lambda t: t.sqrt(),
      # lambda t: t.rsqrt(), # requires DIV
      lambda t: t.sin(),
      # lambda t: t.cos(), # requires SUB
      # lambda t: t.tan(), # requires .cos() to work
      lambda t: t.sigmoid(),
    ]:
      a = [2, 3, 4]
      np.testing.assert_allclose(func(Tensor(a, dtype=dtype)).numpy(), func(torch.tensor(a)), rtol=1e-4, atol=1e-4)

if __name__ == '__main__':
  unittest.main()

tinygrad/helpers.py

@@ -195,7 +195,8 @@ promo_lattice = { dtypes.bool: [dtypes.int8, dtypes.uint8],
def _get_recursive_parents(dtype:DType) -> Set[DType]:
  return set.union(*[_get_recursive_parents(d) for d in promo_lattice[dtype]], {dtype}) if dtype != dtypes.float64 else {dtypes.float64}
@functools.lru_cache(None)
def least_upper_dtype(*ds:DType) -> DType: return min(set.intersection(*[_get_recursive_parents(d) for d in ds]))
def least_upper_dtype(*ds:DType) -> DType:
  return min(set.intersection(*[_get_recursive_parents(d) for d in ds])) if not (images:=[d for d in ds if isinstance(d, ImageDType)]) else images[0]
# HACK: staticmethods are not callable in 3.8 so we have to compare the class
DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not k.startswith('__') and not callable(v) and v.__class__ is not staticmethod}
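
A quick sketch of how the updated least_upper_dtype behaves. The first call mirrors an assert in the test above; the second assumes the dtypes.imageh helper from this file to build an ImageDType, with an arbitrary shape chosen only for illustration.

# Illustrative only: exercising the promotion helper changed above.
from tinygrad.helpers import dtypes, least_upper_dtype

print(least_upper_dtype(dtypes.float16, dtypes.int64))  # dtypes.float16, as asserted in the test
img = dtypes.imageh((4, 4, 4))                          # assumed ImageDType constructor; shape is arbitrary
print(least_upper_dtype(img, dtypes.float32))           # the ImageDType itself, via the new image branch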

tinygrad/mlops.py

@@ -1,6 +1,6 @@
import math
from typing import Tuple, Optional, cast
from tinygrad.helpers import argsort, DType
from tinygrad.helpers import argsort, DType, dtypes, least_upper_dtype
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps
from tinygrad.tensor import Function
from tinygrad.lazy import LazyBuffer
@@ -35,7 +35,7 @@ class Neg(Function):
class Sin(Function):
  def forward(self, x:LazyBuffer) -> LazyBuffer:
    self.x = x
    return x.e(UnaryOps.SIN)
    return x.cast(least_upper_dtype(x.dtype, dtypes.float)).e(UnaryOps.SIN)

  def backward(self, grad:LazyBuffer) -> LazyBuffer:
    return self.x.const(math.pi / 2).e(BinaryOps.SUB, self.x).e(UnaryOps.SIN).e(BinaryOps.MUL, grad)
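
The cast target used in these forwards, least_upper_dtype(x.dtype, dtypes.float), upcasts integer inputs to float32 while leaving wider floats (and, via the image branch above, ImageDTypes) untouched. A small sketch, assuming dtypes.float is the float32 alias in this tree:

# Illustrative only: what the cast target resolves to for a few input dtypes.
from tinygrad.helpers import dtypes, least_upper_dtype

print(least_upper_dtype(dtypes.int32, dtypes.float))    # float32: int inputs are upcast before SIN
print(least_upper_dtype(dtypes.float64, dtypes.float))  # float64: doubles are not downcast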
@@ -52,14 +52,14 @@ class Relu(Function):
class Log(Function):
  def forward(self, x:LazyBuffer) -> LazyBuffer:
    self.x = x
    return x.e(UnaryOps.LOG2).e(BinaryOps.MUL, x.const(math.log(2)))
    return x.cast(ftype:=least_upper_dtype(x.dtype, dtypes.float)).e(UnaryOps.LOG2).e(BinaryOps.MUL, x.cast(ftype).const(math.log(2)))

  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    return grad_output.e(BinaryOps.DIV, self.x)

class Exp(Function):
  def forward(self, x:LazyBuffer) -> LazyBuffer:
    self.ret = x.e(BinaryOps.MUL, x.const(1/math.log(2))).e(UnaryOps.EXP2)
    self.ret = x.cast(ftype:=least_upper_dtype(x.dtype, dtypes.float)).e(BinaryOps.MUL, x.cast(ftype).const(1/math.log(2))).e(UnaryOps.EXP2)
    return self.ret

  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
@@ -67,7 +67,7 @@ class Exp(Function):
class Sqrt(Function):
  def forward(self, x:LazyBuffer) -> LazyBuffer:
    self.ret = x.e(UnaryOps.SQRT)
    self.ret = x.cast(least_upper_dtype(x.dtype, dtypes.float)).e(UnaryOps.SQRT)
    return self.ret

  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
@@ -78,7 +78,8 @@ class Sqrt(Function):
# TODO: have the backend automatically find this
class Sigmoid(Function):
  def forward(self, x:LazyBuffer) -> LazyBuffer:
    self.ret = x.const(1).e(BinaryOps.DIV, x.const(1).e(BinaryOps.ADD, x.e(BinaryOps.MUL, x.const(-1/math.log(2))).e(UnaryOps.EXP2)))
    self.ret = x.cast(ftype:=least_upper_dtype(x.dtype, dtypes.float)).const(1).e(
      BinaryOps.DIV, x.cast(ftype).const(1).e(BinaryOps.ADD, x.e(BinaryOps.MUL, x.cast(ftype).const(-1/math.log(2))).e(UnaryOps.EXP2)))
    return self.ret

  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
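
These forwards keep the exp2-based identities already used in this file: exp(x) = 2**(x/ln 2) and sigmoid(x) = 1/(1 + 2**(-x/ln 2)). A quick numeric check with numpy, illustrative only:

# Sanity check of the identities behind Exp and Sigmoid above.
import math
import numpy as np

x = np.array([2.0, 3.0, 4.0])
np.testing.assert_allclose(np.exp(x), np.exp2(x * (1 / math.log(2))))
np.testing.assert_allclose(1 / (1 + np.exp(-x)), 1 / (1 + np.exp2(x * (-1 / math.log(2)))))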