move test_dtype tests to test dtype and output value (#3826)

chenyu 2024-03-19 16:31:27 -04:00 committed by GitHub
parent 131bbb6563
commit fa1921ec7d
1 changed file with 49 additions and 48 deletions
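The rewritten tests call a _assert_eq helper instead of bare dtype asserts, so each case now checks both the result dtype and the realized output values. The helper is defined elsewhere in test_dtype.py and does not appear in this diff; the following is only a minimal sketch of what such a helper would check, not the repository's exact implementation:

    import numpy as np
    from tinygrad import Tensor
    from tinygrad.dtype import DType

    def _assert_eq(tensor: Tensor, target_dtype: DType, target):
      # sketch: verify the dtype first, then realize the tensor and compare its values against numpy
      assert tensor.dtype == target_dtype, f"expected {target_dtype}, got {tensor.dtype}"
      np.testing.assert_allclose(tensor.numpy(), target)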


@@ -13,7 +13,8 @@ settings.load_profile("my_profile")
 core_dtypes = list(DTYPES_DICT.values())
 if Device.DEFAULT == "CPU": core_dtypes.remove(dtypes.bfloat16) # NOTE: this is for teenygrad, don't remove
-floats = [dt for dt in core_dtypes if dtypes.is_float(dt)]
+dtype_ints = [dt for dt in core_dtypes if dtypes.is_int(dt)]
+dtype_floats = [dt for dt in core_dtypes if dtypes.is_float(dt) and is_dtype_supported(dt)]
 def get_available_cast_dtypes(dtype: DType) -> List[DType]:
   if not is_dtype_supported(dtype): return []
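The single floats list is split into dtype_ints and dtype_floats, with dtype_floats filtered through is_dtype_supported so the hypothesis strategies built from these lists only sample dtypes the current backend can run. is_dtype_supported is defined elsewhere in the test code; as a rough, assumed illustration of that kind of device gating (not the repository's actual rules):

    from tinygrad import Device, dtypes
    from tinygrad.dtype import DType

    def is_dtype_supported_sketch(dtype: DType, device: str = Device.DEFAULT) -> bool:
      # hypothetical gating: some backends lack bfloat16 or float64 support
      if dtype == dtypes.bfloat16: return False
      if dtype == dtypes.float64 and device == "WEBGPU": return False
      return True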
@@ -331,66 +332,67 @@ class TestTypeSpec(unittest.TestCase):
     dtypes.default_int, dtypes.default_float = self.old_default_int, self.old_default_float
   def test_set_dtype_default(self):
-    dtypes.default_int = dtypes.int16
-    assert dtypes.default_int == dtypes.int16
-    dtypes.default_int = dtypes.int64
-    assert dtypes.default_int == dtypes.int64
-    dtypes.default_int = dtypes.int32
-    assert dtypes.default_int == dtypes.int32
-    dtypes.default_float = dtypes.float16
-    assert dtypes.default_float == dtypes.float16
-    dtypes.default_float = dtypes.float64
-    assert dtypes.default_float == dtypes.float64
+    for default_int in [dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64]:
+      dtypes.default_int = default_int
+      assert dtypes.default_int == default_int
-  @given(strat.sampled_from([dtypes.int8,dtypes.int16,dtypes.int32,dtypes.int64]), strat.sampled_from([dtypes.float16,dtypes.float32,dtypes.float64]))
+    for default_float in [dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]:
+      dtypes.default_float = default_float
+      assert dtypes.default_float == default_float
+  @given(strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
   def test_creation(self, default_int, default_float):
     dtypes.default_int, dtypes.default_float = default_int, default_float
-    assert Tensor(True).dtype == dtypes.bool
-    assert Tensor(None).dtype == dtypes.default_float
-    assert Tensor(2).dtype == dtypes.default_int
-    assert Tensor(2.34).dtype == dtypes.default_float
-    assert Tensor([]).dtype == dtypes.default_float
-    assert Tensor([1]).dtype == dtypes.default_int
-    assert Tensor([1.1]).dtype == dtypes.default_float
-    #assert Tensor([0,1], dtype=dtypes.bfloat16).dtype == dtypes.bfloat16
+    _assert_eq(Tensor(True), dtypes.bool, True)
+    _assert_eq(Tensor(None), dtypes.default_float, [])
+    _assert_eq(Tensor(2), dtypes.default_int, 2)
+    _assert_eq(Tensor(2.34), dtypes.default_float, 2.34)
+    _assert_eq(Tensor([]), dtypes.default_float, [])
+    _assert_eq(Tensor([1]), dtypes.default_int, [1])
+    _assert_eq(Tensor([1.1]), dtypes.default_float, [1.1])
-    assert Tensor.eye(0).dtype == dtypes.default_float
-    assert Tensor.eye(3).dtype == dtypes.default_float
-    assert Tensor.eye(3, dtype=dtypes.float16).dtype == dtypes.float16
-    assert Tensor.eye(3, dtype=dtypes.int64).dtype == dtypes.int64
+    _assert_eq(Tensor.eye(0), dtypes.default_float, np.eye(0))
+    _assert_eq(Tensor.eye(3), dtypes.default_float, np.eye(3))
+    _assert_eq(Tensor.eye(3, dtype=dtypes.int64), dtypes.int64, np.eye(3))
+    if is_dtype_supported(dtypes.float16):
+      _assert_eq(Tensor.eye(3, dtype=dtypes.float16), dtypes.float16, np.eye(3))
-  @given(strat.sampled_from([dtypes.int8,dtypes.int16,dtypes.int32,dtypes.int64]), strat.sampled_from([dtypes.float16,dtypes.float32,dtypes.float64]))
+  @given(strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
   def test_full(self, default_int, default_float):
     dtypes.default_int, dtypes.default_float = default_int, default_float
-    assert Tensor.ones([2,3]).dtype == dtypes.default_float
-    assert Tensor.zeros([2,3]).dtype == dtypes.default_float
-    assert Tensor.full([2,3], 3.3).dtype == dtypes.default_float
-    assert Tensor.full([2,3], 3).dtype == dtypes.default_int
-    assert Tensor.full([2,3], True).dtype == dtypes.bool
+    _assert_eq(Tensor.zeros((2, 3)), dtypes.default_float, np.zeros((2, 3)))
+    _assert_eq(Tensor.zeros((2, 3), dtype=dtypes.int64), dtypes.int64, np.zeros((2, 3)))
+    if is_dtype_supported(dtypes.float16):
+      _assert_eq(Tensor.zeros((2, 3), dtype=dtypes.float16), dtypes.float16, np.zeros((2, 3)))
-    assert Tensor.zeros(3, 3).dtype == dtypes.default_float
-    assert Tensor.zeros(3, 3, dtype=dtypes.float16).dtype == dtypes.float16
-    assert Tensor.zeros(3, 3, dtype=dtypes.int64).dtype == dtypes.int64
+    _assert_eq(Tensor.ones((2, 3)), dtypes.default_float, np.ones((2, 3)))
+    _assert_eq(Tensor.ones((2, 3), dtype=dtypes.int64), dtypes.int64, np.ones((2, 3)))
+    if is_dtype_supported(dtypes.float16):
+      _assert_eq(Tensor.ones((2, 3), dtype=dtypes.float16), dtypes.float16, np.ones((2, 3)))
-    assert Tensor.ones(3, 3).dtype == dtypes.default_float
-    assert Tensor.ones(3, 3, dtype=dtypes.float16).dtype == dtypes.float16
-    assert Tensor.ones(3, 3, dtype=dtypes.int64).dtype == dtypes.int64
+    _assert_eq(Tensor.full((2, 3), 3.0), dtypes.default_float, np.full((2, 3), 3.0))
+    _assert_eq(Tensor.full((2, 3), 3), dtypes.default_int, np.full((2, 3), 3))
+    _assert_eq(Tensor.full((2, 3), True), dtypes.bool, np.full((2, 3), True))
+    _assert_eq(Tensor.full((2, 3), 3, dtype=dtypes.int64), dtypes.int64, np.full((2, 3), 3))
+    _assert_eq(Tensor.full((2, 3), 3.0, dtype=dtypes.int64), dtypes.int64, np.full((2, 3), 3))
+    if is_dtype_supported(dtypes.float16):
+      _assert_eq(Tensor.full((2, 3), 3, dtype=dtypes.float16), dtypes.float16, np.full((2, 3), 3))
+      _assert_eq(Tensor.full((2, 3), 3.0, dtype=dtypes.float16), dtypes.float16, np.full((2, 3), 3))
-    assert Tensor.full((3, 3), 3).dtype == dtypes.default_int
-    assert Tensor.full((3, 3), 3.0).dtype == dtypes.default_float
-    assert Tensor.full((3, 3), 3, dtype=dtypes.float16).dtype == dtypes.float16
-    assert Tensor.full((3, 3), 3, dtype=dtypes.int64).dtype == dtypes.int64
+  @given(strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
+  def test_reduce_0d_default(self, default_int, default_float):
+    dtypes.default_int, dtypes.default_float = default_int, default_float
+    _assert_eq(Tensor.ones((2,3,0)).sum(2), dtypes.default_float, np.zeros((2, 3)))
+    # TODO: what should this one be?
+    # _assert_eq(Tensor.ones((2,3,0), dtype=dtypes.default_int).sum(2), dtypes.default_int, np.zeros((2, 3)))
+    _assert_eq(Tensor.ones((2,3,0), dtype=dtypes.int32).sum(2), dtypes.int32, np.zeros((2, 3)))
-  def test_reduce_0d_default(self):
-    assert Tensor.ones([2,3,0]).sum(2).dtype == dtypes.default_float
-    assert Tensor.ones([2,3,0], dtype=dtypes.int).sum(2).dtype == dtypes.int
-  @given(strat.sampled_from([dtypes.int8,dtypes.int16,dtypes.int32,dtypes.int64]), strat.sampled_from([dtypes.float16,dtypes.float32,dtypes.float64]))
+  @given(strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
   def test_arange(self, default_int, default_float):
     dtypes.default_int, dtypes.default_float = default_int, default_float
     # TODO: this might fail with different default dtype https://github.com/tinygrad/tinygrad/issues/3823
     assert Tensor.arange(5).dtype == dtypes.default_int
     assert Tensor.arange(5.0).dtype == dtypes.default_float
     assert Tensor.arange(5, dtype=dtypes.int16).dtype == dtypes.int16
@@ -399,7 +401,6 @@ class TestTypeSpec(unittest.TestCase):
     assert Tensor.arange(3, 9, 0.7).dtype == dtypes.default_float
     assert Tensor.arange(3, 8.5, 3).dtype == dtypes.default_float
   @unittest.skipIf(Device.DEFAULT == "WEBGPU", "WEBGPU doesn't follow the bool ops spec")
   @given(strat.sampled_from(core_dtypes), strat.sampled_from([operator.gt, operator.ge, operator.le, operator.lt, operator.eq, operator.ne]))
   def test_bool_ops(self, dtype, op):
     assert op(Tensor.rand(4, 4, dtype=dtype), Tensor.rand(4, 4, dtype=dtype)).dtype == dtypes.bool
@@ -442,7 +443,7 @@ class TestTypePromotion(unittest.TestCase):
     assert least_upper_dtype(dtypes.float16, dtypes.int64) == dtypes.float16
     assert least_upper_dtype(dtypes.float16, dtypes.uint64) == dtypes.float16
-  @given(strat.sampled_from(floats))
+  @given(strat.sampled_from(dtype_floats))
   def test_float_to_float(self, dt):
     assert least_upper_float(dt) == dt
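Assuming the single changed file is test/test_dtype.py (the summary above does not name it), the updated hypothesis-driven classes can be run in isolation with something like:

    python -m pytest test/test_dtype.py -k "TestTypeSpec or TestTypePromotion"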