bf16 tests in test_dtype.py (#3749)

With bf16 creation and bf16 to numpy conversion, we can test bf16 in test_dtype.
Only HIP is supported for now, as it needs bf16 buffer support. The rtol is also slightly larger for bf16.
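
A rough illustration of why bf16 needs the looser tolerance (not part of this change; it uses the third-party ml_dtypes package purely to emulate bf16 rounding in numpy): bfloat16 stores only 7 explicit mantissa bits, so a round trip through bf16 can introduce a relative error of up to about 2**-8 ≈ 4e-3, while float16's 10 mantissa bits keep it under roughly 2**-11 ≈ 5e-4.

    # Illustrative only: compare round-trip rounding error of bf16 vs fp16.
    # Assumes the third-party ml_dtypes package; the real tests use tinygrad's own bf16 support.
    import numpy as np
    import ml_dtypes

    x = np.array([3.14159, 1e-3, 12345.678], dtype=np.float32)
    bf16_err = np.abs(x.astype(ml_dtypes.bfloat16).astype(np.float32) - x) / np.abs(x)
    fp16_err = np.abs(x.astype(np.float16).astype(np.float32) - x) / np.abs(x)

    print(bf16_err.max())  # bounded by ~2**-8 ≈ 4e-3, so rtol=1e-2 has headroom
    print(fp16_err.max())  # bounded by ~2**-11 ≈ 5e-4, so rtol=1e-3 has headroom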
chenyu 2024-03-15 00:17:11 -04:00 committed by GitHub
parent 33c01c9db0
commit d3a6319630
1 changed file with 6 additions and 4 deletions

@@ -14,7 +14,9 @@ core_dtypes = list(DTYPES_DICT.values())
if Device.DEFAULT == "CPU": core_dtypes.remove(dtypes.bfloat16) # NOTE: this is for teenygrad, don't remove
floats = [dt for dt in core_dtypes if dtypes.is_float(dt)]
def is_dtype_supported(dtype: DType, device: str = Device.DEFAULT):
if dtype == dtypes.bfloat16: return False # numpy doesn't support bf16, tested separately in TestBFloat16DType
if dtype == dtypes.bfloat16:
# NOTE: this requires bf16 buffer support
return device in ["HIP"]
if device in ["WEBGPU", "WEBGL"]: return dtype in [dtypes.float, dtypes.int32, dtypes.uint32]
# for CI GPU, cl_khr_fp16 isn't supported
# for CI LLVM, it segfaults because it can't link to the casting function
@@ -44,7 +46,7 @@ def _assert_eq(tensor:Tensor, target_dtype:DType, target):
if DEBUG >= 2: print(tensor.numpy())
try:
assert tensor.dtype == target_dtype
np.testing.assert_allclose(tensor.numpy(), target, rtol=1e-3 if target_dtype == dtypes.float16 else 1e-7)
np.testing.assert_allclose(tensor.numpy(), target, rtol={dtypes.float16:1e-3, dtypes.bfloat16:1e-2}.get(target_dtype, 1e-7))
except AssertionError as e:
raise AssertionError(f"\ntensor {tensor.numpy()} dtype {tensor.dtype} does not match target {target} with dtype {target_dtype}") from e
@@ -53,6 +55,7 @@ def _test_op(fxn, target_dtype:DType, target):
def _test_cast(a:Tensor, target_dtype:DType):
_test_op(lambda: a.cast(target_dtype), target_dtype, list(a.numpy().astype(target_dtype.np)))
def _test_bitcast(a:Tensor, target_dtype:DType, target=None):
if target_dtype == dtypes.bfloat16: raise unittest.SkipTest("no test for bf16 bitcast yet")
_test_op(lambda: a.bitcast(target_dtype), target_dtype, target or a.numpy().view(target_dtype.np).tolist())
class TestDType(unittest.TestCase):
@@ -157,7 +160,6 @@ class TestBFloat16DType(unittest.TestCase):
_test_cast(Tensor([100000], dtype=dtypes.bfloat16), dtypes.float32)
def test_float_to_bf16(self):
with self.assertRaises(AssertionError):
_test_cast(Tensor([100000], dtype=dtypes.float32), dtypes.bfloat16)
# torch.tensor([10000, -1, -1000, -10000, 20]).type(torch.bfloat16)
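
For reference on the 100000 constant used in these bf16 cast tests, a small illustrative check (again leaning on ml_dtypes rather than tinygrad itself): 100000 is not exactly representable in bfloat16, so a cast lands on the nearest bf16 value.

    import numpy as np
    import ml_dtypes

    # Illustrative only: 100000 rounds to the nearest bfloat16 value, 99840.
    v = np.float32(100000).astype(ml_dtypes.bfloat16)
    print(np.float32(v))  # 99840.0, a relative error of about 1.6e-3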