tests for Tensor init data dtype and resulting dtype (#3247)

Co-authored-by: Hristo Georgiev <6043312+hristog@users.noreply.github.com>
This commit is contained in:
Hristo Georgiev 2024-01-27 10:13:42 +02:00 committed by GitHub
parent 3c728d1082
commit 3ae811af21
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 10 additions and 0 deletions

View File

@ -106,6 +106,16 @@ class TestDType(unittest.TestCase):
self.assertTrue(all(isinstance(value, DType) for value in fields.values()))
self.assertTrue(all(issubclass(value.np, np.generic) for value in fields.values() if value.np is not None))
def test_resulting_and_init_dtypes_match(self):
dtypes = list(map(np.dtype, ["bool", "uint8", "int8", "int16", "int32", "int64", "float32", "float64"]))
data = [1., 2., 0., 0.5, -1.5, 5.25]
for dt in dtypes:
arr = np.asarray(data, dtype=dt)
tin = Tensor(arr).numpy()
tor = torch.as_tensor(arr).detach().numpy()
assert dt is tin.dtype is tor.dtype, f"dtype mismatch: expected={dt} | tinygrad={tin.dtype} | torch={tor.dtype}"
np.testing.assert_allclose(tin, tor, atol=1e-6, rtol=1e-3)
def _test_ops(a_dtype:DType, b_dtype:DType, target_dtype=None):
target_dtype = target_dtype or least_upper_dtype(a_dtype, b_dtype)
if not is_dtype_supported(a_dtype) or not is_dtype_supported(b_dtype) or not is_dtype_supported(target_dtype): return