diff --git a/test/test_dtype.py b/test/test_dtype.py
index bc2463e9..8952eaa9 100644
--- a/test/test_dtype.py
+++ b/test/test_dtype.py
@@ -14,7 +14,9 @@ core_dtypes = list(DTYPES_DICT.values())
 if Device.DEFAULT == "CPU": core_dtypes.remove(dtypes.bfloat16) # NOTE: this is for teenygrad, don't remove
 floats = [dt for dt in core_dtypes if dtypes.is_float(dt)]
 def is_dtype_supported(dtype: DType, device: str = Device.DEFAULT):
-  if dtype == dtypes.bfloat16: return False # numpy doesn't support bf16, tested separately in TestBFloat16DType
+  if dtype == dtypes.bfloat16:
+    # NOTE: this requires bf16 buffer support
+    return device in ["HIP"]
   if device in ["WEBGPU", "WEBGL"]: return dtype in [dtypes.float, dtypes.int32, dtypes.uint32]
   # for CI GPU, cl_khr_fp16 isn't supported
   # for CI LLVM, it segfaults because it can't link to the casting function
@@ -44,7 +46,7 @@ def _assert_eq(tensor:Tensor, target_dtype:DType, target):
   if DEBUG >= 2: print(tensor.numpy())
   try:
     assert tensor.dtype == target_dtype
-    np.testing.assert_allclose(tensor.numpy(), target, rtol=1e-3 if target_dtype == dtypes.float16 else 1e-7)
+    np.testing.assert_allclose(tensor.numpy(), target, rtol={dtypes.float16:1e-3, dtypes.bfloat16:1e-2}.get(target_dtype, 1e-7))
   except AssertionError as e:
     raise AssertionError(f"\ntensor {tensor.numpy()} dtype {tensor.dtype} does not match target {target} with dtype {target_dtype}") from e
 
@@ -53,6 +55,7 @@ def _test_op(fxn, target_dtype:DType, target):
 def _test_cast(a:Tensor, target_dtype:DType):
   _test_op(lambda: a.cast(target_dtype), target_dtype, list(a.numpy().astype(target_dtype.np)))
 def _test_bitcast(a:Tensor, target_dtype:DType, target=None):
+  if target_dtype == dtypes.bfloat16: raise unittest.SkipTest("no test for bf16 bitcast yet")
   _test_op(lambda: a.bitcast(target_dtype), target_dtype, target or a.numpy().view(target_dtype.np).tolist())
 
 class TestDType(unittest.TestCase):
@@ -157,8 +160,7 @@ class TestBFloat16DType(unittest.TestCase):
     _test_cast(Tensor([100000], dtype=dtypes.bfloat16), dtypes.float32)
 
   def test_float_to_bf16(self):
-    with self.assertRaises(AssertionError):
-      _test_cast(Tensor([100000], dtype=dtypes.float32), dtypes.bfloat16)
+    _test_cast(Tensor([100000], dtype=dtypes.float32), dtypes.bfloat16)
 
   # torch.tensor([10000, -1, -1000, -10000, 20]).type(torch.bfloat16)
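
For context on the `_assert_eq` tolerance change: a minimal standalone sketch of the per-dtype rtol lookup the patch introduces. It uses hypothetical string stand-ins instead of tinygrad's real `dtypes` objects so it runs without the library; only the dict-plus-`.get()` fallback pattern is taken from the diff.

```python
# Illustrative sketch, not part of the patch: per-dtype relative tolerance lookup.
# FLOAT16/BFLOAT16/FLOAT32 are hypothetical stand-ins for tinygrad's dtypes.* objects.
FLOAT16, BFLOAT16, FLOAT32 = "half", "bfloat16", "float"

def rtol_for(target_dtype):
  # looser tolerances for low-precision floats, strict 1e-7 default for everything else
  return {FLOAT16: 1e-3, BFLOAT16: 1e-2}.get(target_dtype, 1e-7)

assert rtol_for(BFLOAT16) == 1e-2  # bf16 keeps ~8 mantissa bits, so comparisons must be loose
assert rtol_for(FLOAT16) == 1e-3
assert rtol_for(FLOAT32) == 1e-7   # falls through to the default
```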