mirror of https://github.com/commaai/tinygrad.git
fix output of truncate_fp16 (#6381)
make sure the non-inf path returns the truncated value
This commit is contained in:
parent
c88329244b
commit
002303c145
|
@ -6,6 +6,7 @@ from tinygrad.helpers import getenv, DEBUG, CI
|
|||
from tinygrad.dtype import DType, DTYPES_DICT, ImageDType, PtrDType, least_upper_float, least_upper_dtype
|
||||
from tinygrad import Device, Tensor, dtypes
|
||||
from tinygrad.tensor import _to_np_dtype
|
||||
from tinygrad.ops import truncate_fp16
|
||||
from hypothesis import given, settings, strategies as strat
|
||||
from test.helpers import is_dtype_supported, rand_for_dtype
|
||||
|
||||
|
@ -382,6 +383,12 @@ class TestHelpers(unittest.TestCase):
|
|||
np.testing.assert_equal(dtypes.min(dt), False)
|
||||
np.testing.assert_equal(dtypes.max(dt), True)
|
||||
|
||||
def test_truncate_fp16(self):
|
||||
self.assertEqual(truncate_fp16(1), 1)
|
||||
self.assertEqual(truncate_fp16(65504), 65504)
|
||||
self.assertEqual(truncate_fp16(65519.999), 65504)
|
||||
self.assertEqual(truncate_fp16(65520), math.inf)
|
||||
|
||||
class TestTypeSpec(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.old_default_int, self.old_default_float = dtypes.default_int, dtypes.default_float
|
||||
|
|
|
@ -459,10 +459,7 @@ python_alu: Dict[Op, Callable] = {
|
|||
TernaryOps.MULACC: lambda x,y,z: (x*y)+z, TernaryOps.WHERE: lambda x,y,z: y if x else z}
|
||||
|
||||
def truncate_fp16(x):
|
||||
try:
|
||||
x = float(x)
|
||||
struct.pack("@e", x)
|
||||
return x
|
||||
try: return struct.unpack("@e", struct.pack("@e", float(x)))[0]
|
||||
except OverflowError: return math.copysign(math.inf, x)
|
||||
|
||||
truncate: Dict[DType, Callable] = {dtypes.bool: bool,
|
||||
|
|
Loading…
Reference in New Issue