diff --git a/test/test_dtype.py b/test/test_dtype.py index db17a531..0f5297f1 100644 --- a/test/test_dtype.py +++ b/test/test_dtype.py @@ -13,7 +13,8 @@ settings.load_profile("my_profile") core_dtypes = list(DTYPES_DICT.values()) if Device.DEFAULT == "CPU": core_dtypes.remove(dtypes.bfloat16) # NOTE: this is for teenygrad, don't remove -floats = [dt for dt in core_dtypes if dtypes.is_float(dt)] +dtype_ints = [dt for dt in core_dtypes if dtypes.is_int(dt)] +dtype_floats = [dt for dt in core_dtypes if dtypes.is_float(dt) and is_dtype_supported(dt)] def get_available_cast_dtypes(dtype: DType) -> List[DType]: if not is_dtype_supported(dtype): return [] @@ -331,66 +332,67 @@ class TestTypeSpec(unittest.TestCase): dtypes.default_int, dtypes.default_float = self.old_default_int, self.old_default_float def test_set_dtype_default(self): - dtypes.default_int = dtypes.int16 - assert dtypes.default_int == dtypes.int16 - dtypes.default_int = dtypes.int64 - assert dtypes.default_int == dtypes.int64 - dtypes.default_int = dtypes.int32 - assert dtypes.default_int == dtypes.int32 - dtypes.default_float = dtypes.float16 - assert dtypes.default_float == dtypes.float16 - dtypes.default_float = dtypes.float64 - assert dtypes.default_float == dtypes.float64 + for default_int in [dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64]: + dtypes.default_int = default_int + assert dtypes.default_int == default_int - @given(strat.sampled_from([dtypes.int8,dtypes.int16,dtypes.int32,dtypes.int64]), strat.sampled_from([dtypes.float16,dtypes.float32,dtypes.float64])) + for default_float in [dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]: + dtypes.default_float = default_float + assert dtypes.default_float == default_float + + @given(strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats)) def test_creation(self, default_int, default_float): dtypes.default_int, dtypes.default_float = default_int, default_float - assert Tensor(True).dtype == dtypes.bool - assert Tensor(None).dtype == dtypes.default_float - assert Tensor(2).dtype == dtypes.default_int - assert Tensor(2.34).dtype == dtypes.default_float - assert Tensor([]).dtype == dtypes.default_float - assert Tensor([1]).dtype == dtypes.default_int - assert Tensor([1.1]).dtype == dtypes.default_float - #assert Tensor([0,1], dtype=dtypes.bfloat16).dtype == dtypes.bfloat16 + _assert_eq(Tensor(True), dtypes.bool, True) + _assert_eq(Tensor(None), dtypes.default_float, []) + _assert_eq(Tensor(2), dtypes.default_int, 2) + _assert_eq(Tensor(2.34), dtypes.default_float, 2.34) + _assert_eq(Tensor([]), dtypes.default_float, []) + _assert_eq(Tensor([1]), dtypes.default_int, [1]) + _assert_eq(Tensor([1.1]), dtypes.default_float, [1.1]) - assert Tensor.eye(0).dtype == dtypes.default_float - assert Tensor.eye(3).dtype == dtypes.default_float - assert Tensor.eye(3, dtype=dtypes.float16).dtype == dtypes.float16 - assert Tensor.eye(3, dtype=dtypes.int64).dtype == dtypes.int64 + _assert_eq(Tensor.eye(0), dtypes.default_float, np.eye(0)) + _assert_eq(Tensor.eye(3), dtypes.default_float, np.eye(3)) + _assert_eq(Tensor.eye(3, dtype=dtypes.int64), dtypes.int64, np.eye(3)) + if is_dtype_supported(dtypes.float16): + _assert_eq(Tensor.eye(3, dtype=dtypes.float16), dtypes.float16, np.eye(3)) - - @given(strat.sampled_from([dtypes.int8,dtypes.int16,dtypes.int32,dtypes.int64]), strat.sampled_from([dtypes.float16,dtypes.float32,dtypes.float64])) + @given(strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats)) def test_full(self, default_int, default_float): dtypes.default_int, dtypes.default_float = default_int, default_float - assert Tensor.ones([2,3]).dtype == dtypes.default_float - assert Tensor.zeros([2,3]).dtype == dtypes.default_float - assert Tensor.full([2,3], 3.3).dtype == dtypes.default_float - assert Tensor.full([2,3], 3).dtype == dtypes.default_int - assert Tensor.full([2,3], True).dtype == dtypes.bool + _assert_eq(Tensor.zeros((2, 3)), dtypes.default_float, np.zeros((2, 3))) + _assert_eq(Tensor.zeros((2, 3), dtype=dtypes.int64), dtypes.int64, np.zeros((2, 3))) + if is_dtype_supported(dtypes.float16): + _assert_eq(Tensor.zeros((2, 3), dtype=dtypes.float16), dtypes.float16, np.zeros((2, 3))) - assert Tensor.zeros(3, 3).dtype == dtypes.default_float - assert Tensor.zeros(3, 3, dtype=dtypes.float16).dtype == dtypes.float16 - assert Tensor.zeros(3, 3, dtype=dtypes.int64).dtype == dtypes.int64 + _assert_eq(Tensor.ones((2, 3)), dtypes.default_float, np.ones((2, 3))) + _assert_eq(Tensor.ones((2, 3), dtype=dtypes.int64), dtypes.int64, np.ones((2, 3))) + if is_dtype_supported(dtypes.float16): + _assert_eq(Tensor.ones((2, 3), dtype=dtypes.float16), dtypes.float16, np.ones((2, 3))) - assert Tensor.ones(3, 3).dtype == dtypes.default_float - assert Tensor.ones(3, 3, dtype=dtypes.float16).dtype == dtypes.float16 - assert Tensor.ones(3, 3, dtype=dtypes.int64).dtype == dtypes.int64 + _assert_eq(Tensor.full((2, 3), 3.0), dtypes.default_float, np.full((2, 3), 3.0)) + _assert_eq(Tensor.full((2, 3), 3), dtypes.default_int, np.full((2, 3), 3)) + _assert_eq(Tensor.full((2, 3), True), dtypes.bool, np.full((2, 3), True)) + _assert_eq(Tensor.full((2, 3), 3, dtype=dtypes.int64), dtypes.int64, np.full((2, 3), 3)) + _assert_eq(Tensor.full((2, 3), 3.0, dtype=dtypes.int64), dtypes.int64, np.full((2, 3), 3)) + if is_dtype_supported(dtypes.float16): + _assert_eq(Tensor.full((2, 3), 3, dtype=dtypes.float16), dtypes.float16, np.full((2, 3), 3)) + _assert_eq(Tensor.full((2, 3), 3.0, dtype=dtypes.float16), dtypes.float16, np.full((2, 3), 3)) - assert Tensor.full((3, 3), 3).dtype == dtypes.default_int - assert Tensor.full((3, 3), 3.0).dtype == dtypes.default_float - assert Tensor.full((3, 3), 3, dtype=dtypes.float16).dtype == dtypes.float16 - assert Tensor.full((3, 3), 3, dtype=dtypes.int64).dtype == dtypes.int64 + @given(strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats)) + def test_reduce_0d_default(self, default_int, default_float): + dtypes.default_int, dtypes.default_float = default_int, default_float + _assert_eq(Tensor.ones((2,3,0)).sum(2), dtypes.default_float, np.zeros((2, 3))) + # TODO: what should this one be? + # _assert_eq(Tensor.ones((2,3,0), dtype=dtypes.default_int).sum(2), dtypes.default_int, np.zeros((2, 3))) + _assert_eq(Tensor.ones((2,3,0), dtype=dtypes.int32).sum(2), dtypes.int32, np.zeros((2, 3))) - def test_reduce_0d_default(self): - assert Tensor.ones([2,3,0]).sum(2).dtype == dtypes.default_float - assert Tensor.ones([2,3,0], dtype=dtypes.int).sum(2).dtype == dtypes.int - - @given(strat.sampled_from([dtypes.int8,dtypes.int16,dtypes.int32,dtypes.int64]), strat.sampled_from([dtypes.float16,dtypes.float32,dtypes.float64])) + @given(strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats)) def test_arange(self, default_int, default_float): dtypes.default_int, dtypes.default_float = default_int, default_float + # TODO: this might fail with different default dtype https://github.com/tinygrad/tinygrad/issues/3823 assert Tensor.arange(5).dtype == dtypes.default_int assert Tensor.arange(5.0).dtype == dtypes.default_float assert Tensor.arange(5, dtype=dtypes.int16).dtype == dtypes.int16 @@ -399,7 +401,6 @@ class TestTypeSpec(unittest.TestCase): assert Tensor.arange(3, 9, 0.7).dtype == dtypes.default_float assert Tensor.arange(3, 8.5, 3).dtype == dtypes.default_float - @unittest.skipIf(Device.DEFAULT == "WEBGPU", "WEBGPU doesn't follow the bool ops spec") @given(strat.sampled_from(core_dtypes), strat.sampled_from([operator.gt, operator.ge, operator.le, operator.lt, operator.eq, operator.ne])) def test_bool_ops(self, dtype, op): assert op(Tensor.rand(4, 4, dtype=dtype), Tensor.rand(4, 4, dtype=dtype)).dtype == dtypes.bool @@ -442,7 +443,7 @@ class TestTypePromotion(unittest.TestCase): assert least_upper_dtype(dtypes.float16, dtypes.int64) == dtypes.float16 assert least_upper_dtype(dtypes.float16, dtypes.uint64) == dtypes.float16 - @given(strat.sampled_from(floats)) + @given(strat.sampled_from(dtype_floats)) def test_float_to_float(self, dt): assert least_upper_float(dt) == dt