bf16 is float (#2786)

* add bfloat16 to is_float check

* and test
This commit is contained in:
chenyu 2023-12-15 21:41:30 -05:00 committed by GitHub
parent e4bbbc5bc3
commit 0703075357
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 4 additions and 1 deletions

View File

@ -216,6 +216,9 @@ class TestHelpers(unittest.TestCase):
assert not dtypes.is_int(dtype.vec(amt) if amt > 1 else dtype)
assert not dtypes.is_unsigned(dtype.vec(amt) if amt > 1 else dtype)
def test_bf16_is_float(self):
assert dtypes.is_float(dtypes.bfloat16)
@given(st.sampled_from([d for d in DTYPES_DICT.values() if dtypes.is_float(d) or dtypes.is_int(d)]), st.integers(min_value=2, max_value=8))
def test_scalar(self, dtype, amt):
assert dtype.vec(amt).scalar() == dtype

View File

@ -137,7 +137,7 @@ class PtrDType(DType):
class dtypes:
@staticmethod
def is_float(x: DType) -> bool: return x.scalar() in (dtypes.float16, dtypes.float32, dtypes.float64)
def is_float(x: DType) -> bool: return x.scalar() in (dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64)
@staticmethod # static methds on top, or bool in the type info will refer to dtypes.bool
def is_int(x: DType) -> bool: return x.scalar() in (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64) or dtypes.is_unsigned(x)
@staticmethod