diff --git a/test/test_dtype.py b/test/test_dtype.py index 391f722e..46b054c1 100644 --- a/test/test_dtype.py +++ b/test/test_dtype.py @@ -106,6 +106,16 @@ class TestDType(unittest.TestCase): self.assertTrue(all(isinstance(value, DType) for value in fields.values())) self.assertTrue(all(issubclass(value.np, np.generic) for value in fields.values() if value.np is not None)) + def test_resulting_and_init_dtypes_match(self): + dtypes = list(map(np.dtype, ["bool", "uint8", "int8", "int16", "int32", "int64", "float32", "float64"])) + data = [1., 2., 0., 0.5, -1.5, 5.25] + for dt in dtypes: + arr = np.asarray(data, dtype=dt) + tin = Tensor(arr).numpy() + tor = torch.as_tensor(arr).detach().numpy() + assert dt is tin.dtype is tor.dtype, f"dtype mismatch: expected={dt} | tinygrad={tin.dtype} | torch={tor.dtype}" + np.testing.assert_allclose(tin, tor, atol=1e-6, rtol=1e-3) + def _test_ops(a_dtype:DType, b_dtype:DType, target_dtype=None): target_dtype = target_dtype or least_upper_dtype(a_dtype, b_dtype) if not is_dtype_supported(a_dtype) or not is_dtype_supported(b_dtype) or not is_dtype_supported(target_dtype): return