mirror of https://github.com/commaai/tinygrad.git
If feasible, do not truncate float64 down to float32 in cstyle renderer (#3420)
* do not truncate float64 precision * use l suffix to try avoid overload confusion * long line, ruff bloats the function otherwise * fmt * remove long double suffix (l), it's sufficient to have the float32 (f) suffix to avoid function overload ambigouity; add test showcasing rtol=1e-12 precision increase, the test fails without the renderer changes * use more reasonable test values, same as test_int_to_float_unary_func * disable test for CUDACPU, does not support half and segfaults on some operations per dtypes_alu test * disable test for HIP, renderer does not support f64 precision * do not use noqa E501, break up condition
This commit is contained in:
parent
30f26279c5
commit
2d702ca073
|
@ -161,7 +161,25 @@ class TestHalfDtype(TestDType): DTYPE = dtypes.half
|
|||
|
||||
class TestFloatDType(TestDType): DTYPE = dtypes.float
|
||||
|
||||
class TestDoubleDtype(TestDType): DTYPE = dtypes.double
|
||||
class TestDoubleDtype(TestDType):
|
||||
DTYPE = dtypes.double
|
||||
@unittest.skipIf(getenv("CUDACPU",0)==1, "conversion not supported on CUDACPU")
|
||||
@unittest.skipIf(getenv("HIP",0)==1, "HIP renderer does not support f64 precision")
|
||||
def test_float64_increased_precision(self):
|
||||
for func in [
|
||||
lambda t: t.exp(),
|
||||
lambda t: t.exp2(),
|
||||
lambda t: t.log(),
|
||||
lambda t: t.log2(),
|
||||
lambda t: t.sqrt(),
|
||||
lambda t: t.rsqrt(),
|
||||
lambda t: t.sin(),
|
||||
lambda t: t.cos(),
|
||||
lambda t: t.tan(),
|
||||
lambda t: t.sigmoid(),
|
||||
]:
|
||||
a = [2, 3, 4]
|
||||
np.testing.assert_allclose(func(Tensor(a, dtype=self.DTYPE)).numpy(), func(torch.tensor(a, dtype=torch.float64)), rtol=1e-12, atol=1e-12)
|
||||
|
||||
class TestInt8Dtype(TestDType):
|
||||
DTYPE = dtypes.int8
|
||||
|
|
|
@ -45,6 +45,7 @@ class CStyleLanguage(NamedTuple):
|
|||
def render_const(self, x:Union[float,int,bool], var_dtype) -> str:
|
||||
if math.isnan(x): val = "NAN"
|
||||
elif math.isinf(x): val = ("-" if x < 0 else "") + "INFINITY"
|
||||
elif var_dtype == dtypes.float64: val = f"{float(x)}"
|
||||
else: val = f"{float(x)}f" if dtypes.is_float(var_dtype) else f"{int(x)}" if dtypes.is_int(var_dtype) else f"{bool(x)}".lower()
|
||||
return (self.render_cast([val]*var_dtype.count, var_dtype)
|
||||
if var_dtype.count > 1 or var_dtype not in [dtypes.float, dtypes.int, dtypes.bool] else val)
|
||||
|
|
Loading…
Reference in New Issue