If feasible, do not truncate float64 down to float32 in cstyle renderer (#3420)

* do not truncate float64 precision

* use l suffix to try to avoid overload confusion

* long line, ruff bloats the function otherwise

* fmt

* remove long double suffix (l); it's sufficient to have the float32 (f) suffix to avoid function overload ambiguity; add a test showcasing the rtol=1e-12 precision increase, which fails without the renderer changes

* use more reasonable test values, same as test_int_to_float_unary_func

* disable test for CUDACPU; it does not support half and segfaults on some operations, per the dtypes_alu test

* disable test for HIP; its renderer does not support f64 precision

* do not use noqa E501, break up condition
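
For context, a quick sketch (not part of the commit) of how much precision a float32 constant gives up compared to float64:

import numpy as np

# Rounding 1/3 to float32 keeps ~7 significant digits; float64 keeps ~16.
# Emitting every constant with an "f" suffix in the generated C code forces
# the float32 version even inside double-precision kernels.
x = 1 / 3
print(float(np.float32(x)))  # 0.3333333432674408 (float32 rounding)
print(np.float64(x))         # 0.3333333333333333 (full float64 value)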
zku authored on 2024-02-16 10:08:59 +01:00, committed by GitHub
parent 30f26279c5
commit 2d702ca073
2 changed files with 20 additions and 1 deletion

test/test_dtype.py

@@ -161,7 +161,25 @@ class TestHalfDtype(TestDType): DTYPE = dtypes.half
 class TestFloatDType(TestDType): DTYPE = dtypes.float
-class TestDoubleDtype(TestDType): DTYPE = dtypes.double
+class TestDoubleDtype(TestDType):
+  DTYPE = dtypes.double
+  @unittest.skipIf(getenv("CUDACPU",0)==1, "conversion not supported on CUDACPU")
+  @unittest.skipIf(getenv("HIP",0)==1, "HIP renderer does not support f64 precision")
+  def test_float64_increased_precision(self):
+    for func in [
+      lambda t: t.exp(),
+      lambda t: t.exp2(),
+      lambda t: t.log(),
+      lambda t: t.log2(),
+      lambda t: t.sqrt(),
+      lambda t: t.rsqrt(),
+      lambda t: t.sin(),
+      lambda t: t.cos(),
+      lambda t: t.tan(),
+      lambda t: t.sigmoid(),
+    ]:
+      a = [2, 3, 4]
+      np.testing.assert_allclose(func(Tensor(a, dtype=self.DTYPE)).numpy(), func(torch.tensor(a, dtype=torch.float64)), rtol=1e-12, atol=1e-12)
 class TestInt8Dtype(TestDType):
   DTYPE = dtypes.int8
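
A minimal standalone version of what the new test checks, for one op (a sketch; the import paths and a float64-capable backend are assumptions, not part of this commit):

import numpy as np
import torch
from tinygrad import Tensor, dtypes

a = [2, 3, 4]
out = Tensor(a, dtype=dtypes.double).exp().numpy()
ref = torch.tensor(a, dtype=torch.float64).exp().numpy()
# With float32 constants in the generated kernel this only held at roughly
# float32 tolerances; with the renderer change it holds at 1e-12.
np.testing.assert_allclose(out, ref, rtol=1e-12, atol=1e-12)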

tinygrad/renderer/cstyle.py

@@ -45,6 +45,7 @@ class CStyleLanguage(NamedTuple):
   def render_const(self, x:Union[float,int,bool], var_dtype) -> str:
     if math.isnan(x): val = "NAN"
     elif math.isinf(x): val = ("-" if x < 0 else "") + "INFINITY"
+    elif var_dtype == dtypes.float64: val = f"{float(x)}"
     else: val = f"{float(x)}f" if dtypes.is_float(var_dtype) else f"{int(x)}" if dtypes.is_int(var_dtype) else f"{bool(x)}".lower()
     return (self.render_cast([val]*var_dtype.count, var_dtype)
       if var_dtype.count > 1 or var_dtype not in [dtypes.float, dtypes.int, dtypes.bool] else val)
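
To illustrate what the new branch changes, here is a hypothetical stripped-down mock of the constant rendering (render_const_sketch is made up for illustration, not the tinygrad API):

import math

def render_const_sketch(x: float, is_float64: bool) -> str:
  # NaN/Inf handling mirrors the real method above.
  if math.isnan(x): return "NAN"
  if math.isinf(x): return ("-" if x < 0 else "") + "INFINITY"
  # float64 constants get no "f" suffix, so the C compiler keeps a double literal.
  return f"{float(x)}" if is_float64 else f"{float(x)}f"

print(render_const_sketch(1/3, is_float64=False))  # 0.3333333333333333f, rounded to float by the compiler
print(render_const_sketch(1/3, is_float64=True))   # 0.3333333333333333, kept at double precision

The literal digits are identical; only the suffix differs, and that suffix is what decides whether the C compiler rounds the constant down to float32.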