mirror of https://github.com/commaai/tinygrad.git
some unary functions cast int input into float (#2740)
* some unary functions cast int input into float
* precision
* image dtype
parent 3e778fcc52
commit 2ef33abd20
test/test_dtype.py
@@ -1,11 +1,12 @@
 # ruff: noqa: E501
 import unittest
 import numpy as np
+import torch
 from tinygrad.helpers import CI, DTYPES_DICT, getenv, DType, DEBUG, ImageDType, PtrDType, OSX, temp, least_upper_dtype
 from tinygrad import Device
 from tinygrad.tensor import Tensor, dtypes
 from typing import Any, List
-from hypothesis import given, strategies as st
+from hypothesis import given, settings, strategies as st
 
 def is_dtype_supported(dtype: DType):
   # for GPU, cl_khr_fp16 isn't supported (except now we don't need it!)
@@ -267,6 +268,25 @@ class TestTypePromotion(unittest.TestCase):
     assert least_upper_dtype(dtypes.float16, dtypes.int64) == dtypes.float16
     assert least_upper_dtype(dtypes.float16, dtypes.uint64) == dtypes.float16
 
+class TestAutoCastType(unittest.TestCase):
+  @given(st.sampled_from([d for d in DTYPES_DICT.values() if dtypes.is_int(d) and is_dtype_supported(d)]))
+  @settings(deadline=None)
+  def test_int_to_float_unary_func(self, dtype):
+    for func in [
+      lambda t: t.exp(),
+      # lambda t: t.exp2(), # requires MUL
+      lambda t: t.log(),
+      lambda t: t.log2(),
+      lambda t: t.sqrt(),
+      # lambda t: t.rsqrt(), # requires DIV
+      lambda t: t.sin(),
+      # lambda t: t.cos(), # requires SUB
+      # lambda t: t.tan(), # requires .cos() to work
+      lambda t: t.sigmoid(),
+    ]:
+      a = [2, 3, 4]
+      np.testing.assert_allclose(func(Tensor(a, dtype=dtype)).numpy(), func(torch.tensor(a)), rtol=1e-4, atol=1e-4)
+
 
 if __name__ == '__main__':
   unittest.main()
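A minimal sketch of what the new test exercises, pulled out of the hypothesis harness (assumes a tinygrad checkout at this commit; the int32 dtype and the sample values are illustrative):

import numpy as np, torch
from tinygrad.tensor import Tensor, dtypes

a = [2, 3, 4]
t = Tensor(a, dtype=dtypes.int32)
# exp on an int tensor should now upcast to float and match torch, which also
# promotes integer inputs to float for unary float functions
np.testing.assert_allclose(t.exp().numpy(), torch.tensor(a).exp(), rtol=1e-4, atol=1e-4)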
tinygrad/helpers.py
@@ -195,7 +195,8 @@ promo_lattice = { dtypes.bool: [dtypes.int8, dtypes.uint8],
 def _get_recursive_parents(dtype:DType) -> Set[DType]:
   return set.union(*[_get_recursive_parents(d) for d in promo_lattice[dtype]], {dtype}) if dtype != dtypes.float64 else {dtypes.float64}
 @functools.lru_cache(None)
-def least_upper_dtype(*ds:DType) -> DType: return min(set.intersection(*[_get_recursive_parents(d) for d in ds]))
+def least_upper_dtype(*ds:DType) -> DType:
+  return min(set.intersection(*[_get_recursive_parents(d) for d in ds])) if not (images:=[d for d in ds if isinstance(d, ImageDType)]) else images[0]
 
 # HACK: staticmethods are not callable in 3.8 so we have to compare the class
 DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not k.startswith('__') and not callable(v) and v.__class__ is not staticmethod}
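A hedged sketch of how the updated promotion helper behaves (assumes this commit; the first assert mirrors the existing test above, the second is the int/float case the unary ops below rely on):

from tinygrad.helpers import dtypes, least_upper_dtype

assert least_upper_dtype(dtypes.float16, dtypes.int64) == dtypes.float16  # int/float pairs resolve to the float
assert least_upper_dtype(dtypes.int32, dtypes.float32) == dtypes.float32
# with the new branch, any ImageDType argument short-circuits the lattice walk
# and is returned as-is (images[0]) instead of being promoted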
tinygrad/mlops.py
@@ -1,6 +1,6 @@
 import math
 from typing import Tuple, Optional, cast
-from tinygrad.helpers import argsort, DType
+from tinygrad.helpers import argsort, DType, dtypes, least_upper_dtype
 from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps
 from tinygrad.tensor import Function
 from tinygrad.lazy import LazyBuffer
@@ -35,7 +35,7 @@ class Neg(Function):
 class Sin(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer:
     self.x = x
-    return x.e(UnaryOps.SIN)
+    return x.cast(least_upper_dtype(x.dtype, dtypes.float)).e(UnaryOps.SIN)
 
   def backward(self, grad:LazyBuffer) -> LazyBuffer:
     return self.x.const(math.pi / 2).e(BinaryOps.SUB, self.x).e(UnaryOps.SIN).e(BinaryOps.MUL, grad)
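The practical effect, sketched under the same assumptions as above (dtypes.float is float32 here, so an int32 input comes out as float32):

from tinygrad.tensor import Tensor, dtypes

t = Tensor([0, 1, 2], dtype=dtypes.int32)
# the cast happens inside forward, so the result dtype is float32 rather than int32
assert t.sin().dtype == dtypes.float32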
@@ -52,14 +52,14 @@ class Relu(Function):
 class Log(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer:
     self.x = x
-    return x.e(UnaryOps.LOG2).e(BinaryOps.MUL, x.const(math.log(2)))
+    return x.cast(ftype:=least_upper_dtype(x.dtype, dtypes.float)).e(UnaryOps.LOG2).e(BinaryOps.MUL, x.cast(ftype).const(math.log(2)))
 
   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
     return grad_output.e(BinaryOps.DIV, self.x)
 
 class Exp(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer:
-    self.ret = x.e(BinaryOps.MUL, x.const(1/math.log(2))).e(UnaryOps.EXP2)
+    self.ret = x.cast(ftype:=least_upper_dtype(x.dtype, dtypes.float)).e(BinaryOps.MUL, x.cast(ftype).const(1/math.log(2))).e(UnaryOps.EXP2)
     return self.ret
 
   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
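Both kernels build on exp2/log2: exp(x) = exp2(x / ln 2) and log(x) = log2(x) * ln 2, which is why the float constants 1/math.log(2) and math.log(2) need to live on the cast (float) buffer. A quick numpy check of those identities (illustrative only, not tinygrad code):

import math
import numpy as np

x = np.array([2.0, 3.0, 4.0])
np.testing.assert_allclose(np.exp2(x * (1 / math.log(2))), np.exp(x), rtol=1e-6)
np.testing.assert_allclose(np.log2(x) * math.log(2), np.log(x), rtol=1e-6)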
@@ -67,7 +67,7 @@ class Exp(Function):
 
 class Sqrt(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer:
-    self.ret = x.e(UnaryOps.SQRT)
+    self.ret = x.cast(least_upper_dtype(x.dtype, dtypes.float)).e(UnaryOps.SQRT)
     return self.ret
 
   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
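Same pattern as Sin: a sketch of the expected dtype behavior (assumes this commit; uint8 is an arbitrary illustrative choice):

from tinygrad.tensor import Tensor, dtypes

# least_upper_dtype(uint8, float32) resolves to float32, so sqrt of an int tensor is float
assert Tensor([4, 9, 16], dtype=dtypes.uint8).sqrt().dtype == dtypes.float32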
@@ -78,7 +78,8 @@ class Sqrt(Function):
 # TODO: have the backend automatically find this
 class Sigmoid(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer:
-    self.ret = x.const(1).e(BinaryOps.DIV, x.const(1).e(BinaryOps.ADD, x.e(BinaryOps.MUL, x.const(-1/math.log(2))).e(UnaryOps.EXP2)))
+    self.ret = x.cast(ftype:=least_upper_dtype(x.dtype, dtypes.float)).const(1).e(
+      BinaryOps.DIV, x.cast(ftype).const(1).e(BinaryOps.ADD, x.e(BinaryOps.MUL, x.cast(ftype).const(-1/math.log(2))).e(UnaryOps.EXP2)))
     return self.ret
 
   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
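The rewrite keeps the usual sigmoid-via-exp2 trick: 1 / (1 + exp2(-x / ln 2)) = 1 / (1 + exp(-x)). A small numpy sanity check of that identity (illustrative only):

import math
import numpy as np

x = np.array([-2.0, 0.0, 3.0])
np.testing.assert_allclose(1 / (1 + np.exp2(x * (-1 / math.log(2)))), 1 / (1 + np.exp(-x)), rtol=1e-6)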