mirror of https://github.com/commaai/tinygrad.git
move test_dtype tests to test dtype and output value (#3826)
This commit is contained in:
parent
131bbb6563
commit
fa1921ec7d
|
@ -13,7 +13,8 @@ settings.load_profile("my_profile")
|
|||
|
||||
core_dtypes = list(DTYPES_DICT.values())
|
||||
if Device.DEFAULT == "CPU": core_dtypes.remove(dtypes.bfloat16) # NOTE: this is for teenygrad, don't remove
|
||||
floats = [dt for dt in core_dtypes if dtypes.is_float(dt)]
|
||||
dtype_ints = [dt for dt in core_dtypes if dtypes.is_int(dt)]
|
||||
dtype_floats = [dt for dt in core_dtypes if dtypes.is_float(dt) and is_dtype_supported(dt)]
|
||||
|
||||
def get_available_cast_dtypes(dtype: DType) -> List[DType]:
|
||||
if not is_dtype_supported(dtype): return []
|
||||
|
@ -331,66 +332,67 @@ class TestTypeSpec(unittest.TestCase):
|
|||
dtypes.default_int, dtypes.default_float = self.old_default_int, self.old_default_float
|
||||
|
||||
def test_set_dtype_default(self):
|
||||
dtypes.default_int = dtypes.int16
|
||||
assert dtypes.default_int == dtypes.int16
|
||||
dtypes.default_int = dtypes.int64
|
||||
assert dtypes.default_int == dtypes.int64
|
||||
dtypes.default_int = dtypes.int32
|
||||
assert dtypes.default_int == dtypes.int32
|
||||
dtypes.default_float = dtypes.float16
|
||||
assert dtypes.default_float == dtypes.float16
|
||||
dtypes.default_float = dtypes.float64
|
||||
assert dtypes.default_float == dtypes.float64
|
||||
for default_int in [dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64]:
|
||||
dtypes.default_int = default_int
|
||||
assert dtypes.default_int == default_int
|
||||
|
||||
@given(strat.sampled_from([dtypes.int8,dtypes.int16,dtypes.int32,dtypes.int64]), strat.sampled_from([dtypes.float16,dtypes.float32,dtypes.float64]))
|
||||
for default_float in [dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]:
|
||||
dtypes.default_float = default_float
|
||||
assert dtypes.default_float == default_float
|
||||
|
||||
@given(strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
|
||||
def test_creation(self, default_int, default_float):
|
||||
dtypes.default_int, dtypes.default_float = default_int, default_float
|
||||
assert Tensor(True).dtype == dtypes.bool
|
||||
assert Tensor(None).dtype == dtypes.default_float
|
||||
assert Tensor(2).dtype == dtypes.default_int
|
||||
assert Tensor(2.34).dtype == dtypes.default_float
|
||||
assert Tensor([]).dtype == dtypes.default_float
|
||||
assert Tensor([1]).dtype == dtypes.default_int
|
||||
assert Tensor([1.1]).dtype == dtypes.default_float
|
||||
#assert Tensor([0,1], dtype=dtypes.bfloat16).dtype == dtypes.bfloat16
|
||||
_assert_eq(Tensor(True), dtypes.bool, True)
|
||||
_assert_eq(Tensor(None), dtypes.default_float, [])
|
||||
_assert_eq(Tensor(2), dtypes.default_int, 2)
|
||||
_assert_eq(Tensor(2.34), dtypes.default_float, 2.34)
|
||||
_assert_eq(Tensor([]), dtypes.default_float, [])
|
||||
_assert_eq(Tensor([1]), dtypes.default_int, [1])
|
||||
_assert_eq(Tensor([1.1]), dtypes.default_float, [1.1])
|
||||
|
||||
assert Tensor.eye(0).dtype == dtypes.default_float
|
||||
assert Tensor.eye(3).dtype == dtypes.default_float
|
||||
assert Tensor.eye(3, dtype=dtypes.float16).dtype == dtypes.float16
|
||||
assert Tensor.eye(3, dtype=dtypes.int64).dtype == dtypes.int64
|
||||
_assert_eq(Tensor.eye(0), dtypes.default_float, np.eye(0))
|
||||
_assert_eq(Tensor.eye(3), dtypes.default_float, np.eye(3))
|
||||
_assert_eq(Tensor.eye(3, dtype=dtypes.int64), dtypes.int64, np.eye(3))
|
||||
if is_dtype_supported(dtypes.float16):
|
||||
_assert_eq(Tensor.eye(3, dtype=dtypes.float16), dtypes.float16, np.eye(3))
|
||||
|
||||
|
||||
@given(strat.sampled_from([dtypes.int8,dtypes.int16,dtypes.int32,dtypes.int64]), strat.sampled_from([dtypes.float16,dtypes.float32,dtypes.float64]))
|
||||
@given(strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
|
||||
def test_full(self, default_int, default_float):
|
||||
dtypes.default_int, dtypes.default_float = default_int, default_float
|
||||
|
||||
assert Tensor.ones([2,3]).dtype == dtypes.default_float
|
||||
assert Tensor.zeros([2,3]).dtype == dtypes.default_float
|
||||
assert Tensor.full([2,3], 3.3).dtype == dtypes.default_float
|
||||
assert Tensor.full([2,3], 3).dtype == dtypes.default_int
|
||||
assert Tensor.full([2,3], True).dtype == dtypes.bool
|
||||
_assert_eq(Tensor.zeros((2, 3)), dtypes.default_float, np.zeros((2, 3)))
|
||||
_assert_eq(Tensor.zeros((2, 3), dtype=dtypes.int64), dtypes.int64, np.zeros((2, 3)))
|
||||
if is_dtype_supported(dtypes.float16):
|
||||
_assert_eq(Tensor.zeros((2, 3), dtype=dtypes.float16), dtypes.float16, np.zeros((2, 3)))
|
||||
|
||||
assert Tensor.zeros(3, 3).dtype == dtypes.default_float
|
||||
assert Tensor.zeros(3, 3, dtype=dtypes.float16).dtype == dtypes.float16
|
||||
assert Tensor.zeros(3, 3, dtype=dtypes.int64).dtype == dtypes.int64
|
||||
_assert_eq(Tensor.ones((2, 3)), dtypes.default_float, np.ones((2, 3)))
|
||||
_assert_eq(Tensor.ones((2, 3), dtype=dtypes.int64), dtypes.int64, np.ones((2, 3)))
|
||||
if is_dtype_supported(dtypes.float16):
|
||||
_assert_eq(Tensor.ones((2, 3), dtype=dtypes.float16), dtypes.float16, np.ones((2, 3)))
|
||||
|
||||
assert Tensor.ones(3, 3).dtype == dtypes.default_float
|
||||
assert Tensor.ones(3, 3, dtype=dtypes.float16).dtype == dtypes.float16
|
||||
assert Tensor.ones(3, 3, dtype=dtypes.int64).dtype == dtypes.int64
|
||||
_assert_eq(Tensor.full((2, 3), 3.0), dtypes.default_float, np.full((2, 3), 3.0))
|
||||
_assert_eq(Tensor.full((2, 3), 3), dtypes.default_int, np.full((2, 3), 3))
|
||||
_assert_eq(Tensor.full((2, 3), True), dtypes.bool, np.full((2, 3), True))
|
||||
_assert_eq(Tensor.full((2, 3), 3, dtype=dtypes.int64), dtypes.int64, np.full((2, 3), 3))
|
||||
_assert_eq(Tensor.full((2, 3), 3.0, dtype=dtypes.int64), dtypes.int64, np.full((2, 3), 3))
|
||||
if is_dtype_supported(dtypes.float16):
|
||||
_assert_eq(Tensor.full((2, 3), 3, dtype=dtypes.float16), dtypes.float16, np.full((2, 3), 3))
|
||||
_assert_eq(Tensor.full((2, 3), 3.0, dtype=dtypes.float16), dtypes.float16, np.full((2, 3), 3))
|
||||
|
||||
assert Tensor.full((3, 3), 3).dtype == dtypes.default_int
|
||||
assert Tensor.full((3, 3), 3.0).dtype == dtypes.default_float
|
||||
assert Tensor.full((3, 3), 3, dtype=dtypes.float16).dtype == dtypes.float16
|
||||
assert Tensor.full((3, 3), 3, dtype=dtypes.int64).dtype == dtypes.int64
|
||||
@given(strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
|
||||
def test_reduce_0d_default(self, default_int, default_float):
|
||||
dtypes.default_int, dtypes.default_float = default_int, default_float
|
||||
_assert_eq(Tensor.ones((2,3,0)).sum(2), dtypes.default_float, np.zeros((2, 3)))
|
||||
# TODO: what should this one be?
|
||||
# _assert_eq(Tensor.ones((2,3,0), dtype=dtypes.default_int).sum(2), dtypes.default_int, np.zeros((2, 3)))
|
||||
_assert_eq(Tensor.ones((2,3,0), dtype=dtypes.int32).sum(2), dtypes.int32, np.zeros((2, 3)))
|
||||
|
||||
def test_reduce_0d_default(self):
|
||||
assert Tensor.ones([2,3,0]).sum(2).dtype == dtypes.default_float
|
||||
assert Tensor.ones([2,3,0], dtype=dtypes.int).sum(2).dtype == dtypes.int
|
||||
|
||||
@given(strat.sampled_from([dtypes.int8,dtypes.int16,dtypes.int32,dtypes.int64]), strat.sampled_from([dtypes.float16,dtypes.float32,dtypes.float64]))
|
||||
@given(strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
|
||||
def test_arange(self, default_int, default_float):
|
||||
dtypes.default_int, dtypes.default_float = default_int, default_float
|
||||
|
||||
# TODO: this might fail with different default dtype https://github.com/tinygrad/tinygrad/issues/3823
|
||||
assert Tensor.arange(5).dtype == dtypes.default_int
|
||||
assert Tensor.arange(5.0).dtype == dtypes.default_float
|
||||
assert Tensor.arange(5, dtype=dtypes.int16).dtype == dtypes.int16
|
||||
|
@ -399,7 +401,6 @@ class TestTypeSpec(unittest.TestCase):
|
|||
assert Tensor.arange(3, 9, 0.7).dtype == dtypes.default_float
|
||||
assert Tensor.arange(3, 8.5, 3).dtype == dtypes.default_float
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "WEBGPU doesn't follow the bool ops spec")
|
||||
@given(strat.sampled_from(core_dtypes), strat.sampled_from([operator.gt, operator.ge, operator.le, operator.lt, operator.eq, operator.ne]))
|
||||
def test_bool_ops(self, dtype, op):
|
||||
assert op(Tensor.rand(4, 4, dtype=dtype), Tensor.rand(4, 4, dtype=dtype)).dtype == dtypes.bool
|
||||
|
@ -442,7 +443,7 @@ class TestTypePromotion(unittest.TestCase):
|
|||
assert least_upper_dtype(dtypes.float16, dtypes.int64) == dtypes.float16
|
||||
assert least_upper_dtype(dtypes.float16, dtypes.uint64) == dtypes.float16
|
||||
|
||||
@given(strat.sampled_from(floats))
|
||||
@given(strat.sampled_from(dtype_floats))
|
||||
def test_float_to_float(self, dt):
|
||||
assert least_upper_float(dt) == dt
|
||||
|
||||
|
|
Loading…
Reference in New Issue