mirror of https://github.com/commaai/tinygrad.git
parent
e4bbbc5bc3
commit
0703075357
|
@ -216,6 +216,9 @@ class TestHelpers(unittest.TestCase):
|
|||
assert not dtypes.is_int(dtype.vec(amt) if amt > 1 else dtype)
|
||||
assert not dtypes.is_unsigned(dtype.vec(amt) if amt > 1 else dtype)
|
||||
|
||||
def test_bf16_is_float(self):
|
||||
assert dtypes.is_float(dtypes.bfloat16)
|
||||
|
||||
@given(st.sampled_from([d for d in DTYPES_DICT.values() if dtypes.is_float(d) or dtypes.is_int(d)]), st.integers(min_value=2, max_value=8))
|
||||
def test_scalar(self, dtype, amt):
|
||||
assert dtype.vec(amt).scalar() == dtype
|
||||
|
|
|
@ -137,7 +137,7 @@ class PtrDType(DType):
|
|||
|
||||
class dtypes:
|
||||
@staticmethod
|
||||
def is_float(x: DType) -> bool: return x.scalar() in (dtypes.float16, dtypes.float32, dtypes.float64)
|
||||
def is_float(x: DType) -> bool: return x.scalar() in (dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64)
|
||||
@staticmethod # static methds on top, or bool in the type info will refer to dtypes.bool
|
||||
def is_int(x: DType) -> bool: return x.scalar() in (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64) or dtypes.is_unsigned(x)
|
||||
@staticmethod
|
||||
|
|
Loading…
Reference in New Issue