fix output of truncate_fp16 (#6381)

make sure the non-inf path returns the truncated value
This commit is contained in:
chenyu 2024-09-05 22:55:43 -04:00 committed by GitHub
parent c88329244b
commit 002303c145
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 8 additions and 4 deletions

View File

@ -6,6 +6,7 @@ from tinygrad.helpers import getenv, DEBUG, CI
from tinygrad.dtype import DType, DTYPES_DICT, ImageDType, PtrDType, least_upper_float, least_upper_dtype
from tinygrad import Device, Tensor, dtypes
from tinygrad.tensor import _to_np_dtype
from tinygrad.ops import truncate_fp16
from hypothesis import given, settings, strategies as strat
from test.helpers import is_dtype_supported, rand_for_dtype
@ -382,6 +383,12 @@ class TestHelpers(unittest.TestCase):
np.testing.assert_equal(dtypes.min(dt), False)
np.testing.assert_equal(dtypes.max(dt), True)
def test_truncate_fp16(self):
self.assertEqual(truncate_fp16(1), 1)
self.assertEqual(truncate_fp16(65504), 65504)
self.assertEqual(truncate_fp16(65519.999), 65504)
self.assertEqual(truncate_fp16(65520), math.inf)
class TestTypeSpec(unittest.TestCase):
def setUp(self):
self.old_default_int, self.old_default_float = dtypes.default_int, dtypes.default_float

View File

@ -459,10 +459,7 @@ python_alu: Dict[Op, Callable] = {
TernaryOps.MULACC: lambda x,y,z: (x*y)+z, TernaryOps.WHERE: lambda x,y,z: y if x else z}
def truncate_fp16(x):
try:
x = float(x)
struct.pack("@e", x)
return x
try: return struct.unpack("@e", struct.pack("@e", float(x)))[0]
except OverflowError: return math.copysign(math.inf, x)
truncate: Dict[DType, Callable] = {dtypes.bool: bool,