diff --git a/test/test_dtype.py b/test/test_dtype.py
index 42577acd..54f57a77 100644
--- a/test/test_dtype.py
+++ b/test/test_dtype.py
@@ -6,6 +6,7 @@
 from tinygrad.helpers import getenv, DEBUG, CI
 from tinygrad.dtype import DType, DTYPES_DICT, ImageDType, PtrDType, least_upper_float, least_upper_dtype
 from tinygrad import Device, Tensor, dtypes
 from tinygrad.tensor import _to_np_dtype
+from tinygrad.ops import truncate_fp16
 from hypothesis import given, settings, strategies as strat
 from test.helpers import is_dtype_supported, rand_for_dtype
@@ -382,6 +383,12 @@ class TestHelpers(unittest.TestCase):
     np.testing.assert_equal(dtypes.min(dt), False)
     np.testing.assert_equal(dtypes.max(dt), True)
 
+  def test_truncate_fp16(self):
+    self.assertEqual(truncate_fp16(1), 1)
+    self.assertEqual(truncate_fp16(65504), 65504)
+    self.assertEqual(truncate_fp16(65519.999), 65504)
+    self.assertEqual(truncate_fp16(65520), math.inf)
+
 class TestTypeSpec(unittest.TestCase):
   def setUp(self):
     self.old_default_int, self.old_default_float = dtypes.default_int, dtypes.default_float
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index ff3160d7..5fe880a3 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -459,10 +459,7 @@ python_alu: Dict[Op, Callable] = {
   TernaryOps.MULACC: lambda x,y,z: (x*y)+z, TernaryOps.WHERE: lambda x,y,z: y if x else z}
 
 def truncate_fp16(x):
-  try:
-    x = float(x)
-    struct.pack("@e", x)
-    return x
+  try: return struct.unpack("@e", struct.pack("@e", float(x)))[0]
   except OverflowError: return math.copysign(math.inf, x)
 
 truncate: Dict[DType, Callable] = {dtypes.bool: bool,