mirror of https://github.com/commaai/tinygrad.git
bf16 tests in test_dtype.py (#3749)
With bf16 creation and bf16 to numpy, we can test bf16 in test_dtype. Only support HIP now as it needs bf16 buffer support. Also the rtoal is slightly larger
This commit is contained in:
parent
33c01c9db0
commit
d3a6319630
|
@ -14,7 +14,9 @@ core_dtypes = list(DTYPES_DICT.values())
|
|||
if Device.DEFAULT == "CPU": core_dtypes.remove(dtypes.bfloat16) # NOTE: this is for teenygrad, don't remove
|
||||
floats = [dt for dt in core_dtypes if dtypes.is_float(dt)]
|
||||
def is_dtype_supported(dtype: DType, device: str = Device.DEFAULT):
|
||||
if dtype == dtypes.bfloat16: return False # numpy doesn't support bf16, tested separately in TestBFloat16DType
|
||||
if dtype == dtypes.bfloat16:
|
||||
# NOTE: this requires bf16 buffer support
|
||||
return device in ["HIP"]
|
||||
if device in ["WEBGPU", "WEBGL"]: return dtype in [dtypes.float, dtypes.int32, dtypes.uint32]
|
||||
# for CI GPU, cl_khr_fp16 isn't supported
|
||||
# for CI LLVM, it segfaults because it can't link to the casting function
|
||||
|
@ -44,7 +46,7 @@ def _assert_eq(tensor:Tensor, target_dtype:DType, target):
|
|||
if DEBUG >= 2: print(tensor.numpy())
|
||||
try:
|
||||
assert tensor.dtype == target_dtype
|
||||
np.testing.assert_allclose(tensor.numpy(), target, rtol=1e-3 if target_dtype == dtypes.float16 else 1e-7)
|
||||
np.testing.assert_allclose(tensor.numpy(), target, rtol={dtypes.float16:1e-3, dtypes.bfloat16:1e-2}.get(target_dtype, 1e-7))
|
||||
except AssertionError as e:
|
||||
raise AssertionError(f"\ntensor {tensor.numpy()} dtype {tensor.dtype} does not match target {target} with dtype {target_dtype}") from e
|
||||
|
||||
|
@ -53,6 +55,7 @@ def _test_op(fxn, target_dtype:DType, target):
|
|||
def _test_cast(a:Tensor, target_dtype:DType):
|
||||
_test_op(lambda: a.cast(target_dtype), target_dtype, list(a.numpy().astype(target_dtype.np)))
|
||||
def _test_bitcast(a:Tensor, target_dtype:DType, target=None):
|
||||
if target_dtype == dtypes.bfloat16: raise unittest.SkipTest("no test for bf16 bitcast yet")
|
||||
_test_op(lambda: a.bitcast(target_dtype), target_dtype, target or a.numpy().view(target_dtype.np).tolist())
|
||||
|
||||
class TestDType(unittest.TestCase):
|
||||
|
@ -157,7 +160,6 @@ class TestBFloat16DType(unittest.TestCase):
|
|||
_test_cast(Tensor([100000], dtype=dtypes.bfloat16), dtypes.float32)
|
||||
|
||||
def test_float_to_bf16(self):
|
||||
with self.assertRaises(AssertionError):
|
||||
_test_cast(Tensor([100000], dtype=dtypes.float32), dtypes.bfloat16)
|
||||
|
||||
# torch.tensor([10000, -1, -1000, -10000, 20]).type(torch.bfloat16)
|
||||
|
|
Loading…
Reference in New Issue