diff --git a/test/test_dtype.py b/test/test_dtype.py index 1e083694..1f6236d6 100644 --- a/test/test_dtype.py +++ b/test/test_dtype.py @@ -216,6 +216,9 @@ class TestHelpers(unittest.TestCase): assert not dtypes.is_int(dtype.vec(amt) if amt > 1 else dtype) assert not dtypes.is_unsigned(dtype.vec(amt) if amt > 1 else dtype) + def test_bf16_is_float(self): + assert dtypes.is_float(dtypes.bfloat16) + @given(st.sampled_from([d for d in DTYPES_DICT.values() if dtypes.is_float(d) or dtypes.is_int(d)]), st.integers(min_value=2, max_value=8)) def test_scalar(self, dtype, amt): assert dtype.vec(amt).scalar() == dtype diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index 40a3c674..62d124e2 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -137,7 +137,7 @@ class PtrDType(DType): class dtypes: @staticmethod - def is_float(x: DType) -> bool: return x.scalar() in (dtypes.float16, dtypes.float32, dtypes.float64) + def is_float(x: DType) -> bool: return x.scalar() in (dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64) @staticmethod # static methds on top, or bool in the type info will refer to dtypes.bool def is_int(x: DType) -> bool: return x.scalar() in (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64) or dtypes.is_unsigned(x) @staticmethod