mirror of https://github.com/commaai/tinygrad.git
tests for Tensor init data dtype and resulting dtype (#3247)
Co-authored-by: Hristo Georgiev <6043312+hristog@users.noreply.github.com>
This commit is contained in:
parent
3c728d1082
commit
3ae811af21
|
@ -106,6 +106,16 @@ class TestDType(unittest.TestCase):
|
|||
self.assertTrue(all(isinstance(value, DType) for value in fields.values()))
|
||||
self.assertTrue(all(issubclass(value.np, np.generic) for value in fields.values() if value.np is not None))
|
||||
|
||||
def test_resulting_and_init_dtypes_match(self):
|
||||
dtypes = list(map(np.dtype, ["bool", "uint8", "int8", "int16", "int32", "int64", "float32", "float64"]))
|
||||
data = [1., 2., 0., 0.5, -1.5, 5.25]
|
||||
for dt in dtypes:
|
||||
arr = np.asarray(data, dtype=dt)
|
||||
tin = Tensor(arr).numpy()
|
||||
tor = torch.as_tensor(arr).detach().numpy()
|
||||
assert dt is tin.dtype is tor.dtype, f"dtype mismatch: expected={dt} | tinygrad={tin.dtype} | torch={tor.dtype}"
|
||||
np.testing.assert_allclose(tin, tor, atol=1e-6, rtol=1e-3)
|
||||
|
||||
def _test_ops(a_dtype:DType, b_dtype:DType, target_dtype=None):
|
||||
target_dtype = target_dtype or least_upper_dtype(a_dtype, b_dtype)
|
||||
if not is_dtype_supported(a_dtype) or not is_dtype_supported(b_dtype) or not is_dtype_supported(target_dtype): return
|
||||
|
|
Loading…
Reference in New Issue