test output dtypes match in test_ops (#3703)

need to cast some torch outputs to int32 because torch defaults to int64 for index-related functions

close #2797
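
For context: torch returns int64 from index-producing ops (argmax, argmin, multinomial, one_hot) and from integer arange, while tinygrad's default int dtype is int32, so the torch side of each test is cast down. A minimal sketch of the mismatch (illustrative only, not part of the diff; assumes torch and tinygrad are importable):

    import torch
    from tinygrad import Tensor

    print(torch.rand(10).argmax().dtype)   # torch.int64, torch's default index dtype
    print(torch.arange(10).dtype)          # torch.int64 for integer ranges too
    print(Tensor.rand(10).argmax().dtype)  # int32, tinygrad's default int dtype
    print(Tensor.arange(10).dtype)         # int32

    # casting the torch output is what makes the new dtype assertion in compare() pass
    torch_out = torch.rand(10).argmax().type(torch.int32)  # now torch.int32 on both sides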
chenyu 2024-03-12 12:44:40 -04:00 committed by GitHub
parent 798970cfad
commit f599c6e7f4
1 changed file with 54 additions and 34 deletions


@@ -30,8 +30,9 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn=None, atol=1e-6, rtol=1e-3, gra
   def compare(s, tinygrad_output, torch_output, atol, rtol):
     if PRINT_TENSORS: print(s, tinygrad_output, torch_output)
-    assert tinygrad_output.shape == torch_output.shape, f"shape mismatch: tinygrad={tinygrad_output.shape} | torch={torch_output.shape}"
     try:
+      assert tinygrad_output.shape == torch_output.shape, f"shape mismatch: tinygrad={tinygrad_output.shape} | torch={torch_output.shape}"
+      assert tinygrad_output.dtype == torch_output.dtype, f"dtype mismatch: tinygrad={tinygrad_output.dtype} | torch={torch_output.dtype}"
       np.testing.assert_allclose(tinygrad_output, torch_output, atol=atol, rtol=rtol)
     except Exception as e:
       raise Exception(f"{s} failed shape {tinygrad_output.shape}: {e}")
@@ -83,19 +84,29 @@ class TestOps(unittest.TestCase):
     if not CI: print("\ntesting %40r torch/tinygrad exception: %s / %s" % (shps, torch_cm.exception, tinygrad_cm.exception), end="")
 
   def test_full_like(self):
-    a = Tensor([[1,2,3],[4,5,6]])
-    b = torch.tensor([[1,2,3],[4,5,6]])
+    a = Tensor([[1,2,3],[4,5,6]], dtype=dtypes.float32)
+    b = torch.tensor([[1,2,3],[4,5,6]], dtype=torch.float32)
     helper_test_op([], lambda: torch.full_like(b, 4), lambda: Tensor.full_like(a, 4), forward_only=True)
+    a = Tensor([[1,2,3],[4,5,6]], dtype=dtypes.int32)
+    b = torch.tensor([[1,2,3],[4,5,6]], dtype=torch.int32)
+    helper_test_op([], lambda: torch.full_like(b, 4), lambda: Tensor.full_like(a, 4), forward_only=True)
 
   def test_full(self):
-    helper_test_op([], lambda: torch.full((45,65), 4), lambda: Tensor.full((45,65), 4), forward_only=True)
+    helper_test_op([], lambda: torch.full((45,65), 4, dtype=torch.int32), lambda: Tensor.full((45,65), 4), forward_only=True)
 
   def test_zeros(self):
     helper_test_op([], lambda: torch.zeros(45,65), lambda: Tensor.zeros(45,65), forward_only=True)
     helper_test_op([], lambda: torch.zeros([45,65]), lambda: Tensor.zeros([45,65]), forward_only=True)
     helper_test_op([], lambda: torch.zeros([]), lambda: Tensor.zeros([]), forward_only=True)
 
   def test_zeros_like(self):
-    a = Tensor([[1,2,3],[4,5,6]])
-    b = torch.tensor([[1,2,3],[4,5,6]])
+    a = Tensor([[1,2,3],[4,5,6]], dtype=dtypes.float32)
+    b = torch.tensor([[1,2,3],[4,5,6]], dtype=torch.float32)
     helper_test_op([], lambda: torch.zeros_like(b), lambda: Tensor.zeros_like(a), forward_only=True)
+    a = Tensor([[1,2,3],[4,5,6]], dtype=dtypes.int32)
+    b = torch.tensor([[1,2,3],[4,5,6]], dtype=torch.int32)
+    helper_test_op([], lambda: torch.zeros_like(b), lambda: Tensor.zeros_like(a), forward_only=True)
 
   def test_empty_0(self):
@@ -105,9 +116,14 @@ class TestOps(unittest.TestCase):
     helper_test_op([], lambda: torch.ones(45,65), lambda: Tensor.ones(45,65), forward_only=True)
     helper_test_op([], lambda: torch.ones([45,65]), lambda: Tensor.ones([45,65]), forward_only=True)
     helper_test_op([], lambda: torch.ones([]), lambda: Tensor.ones([]), forward_only=True)
 
   def test_ones_like(self):
-    a = Tensor([[1,2,3],[4,5,6]])
-    b = torch.tensor([[1,2,3],[4,5,6]])
+    a = Tensor([[1,2,3],[4,5,6]], dtype=dtypes.float32)
+    b = torch.tensor([[1,2,3],[4,5,6]], dtype=torch.float32)
     helper_test_op([], lambda: torch.ones_like(b), lambda: Tensor.ones_like(a), forward_only=True)
+    a = Tensor([[1,2,3],[4,5,6]], dtype=dtypes.int32)
+    b = torch.tensor([[1,2,3],[4,5,6]], dtype=torch.int32)
+    helper_test_op([], lambda: torch.ones_like(b), lambda: Tensor.ones_like(a), forward_only=True)
 
   def test_eye(self):
@@ -116,7 +132,7 @@ class TestOps(unittest.TestCase):
     helper_test_op([], lambda: torch.eye(0), lambda: Tensor.eye(0), forward_only=True)
 
   def test_split(self):
-    def tensor(s): return torch.arange(math.prod(s)).reshape(s), Tensor.arange(math.prod(s)).reshape(s)
+    def tensor(s): return torch.arange(math.prod(s), dtype=torch.int32).reshape(s), Tensor.arange(math.prod(s)).reshape(s)
     test_cases = [
       (tensor((10,)), 5, {}),
       (tensor((10,)), [1,4,5], {}),
@@ -135,42 +151,42 @@ class TestOps(unittest.TestCase):
       helper_test_op([], lambda: tor_chunk, lambda: ten_chunk, forward_only=True)
 
   def test_chunk(self):
-    tor = torch.arange(13).repeat(8, 1).chunk(6, 1)
+    tor = torch.arange(13, dtype=torch.int32).repeat(8, 1).chunk(6, 1)
     ten = Tensor.arange(13).repeat((8, 1)).chunk(6, 1)
     assert len(tor) == len(ten)
     for i in range(len(tor)):
       helper_test_op([], lambda: tor[i], lambda: ten[i], forward_only=True)
 
-    tor = torch.arange(13).repeat(8, 1).chunk(6, 0)
+    tor = torch.arange(13, dtype=torch.int32).repeat(8, 1).chunk(6, 0)
     ten = Tensor.arange(13).repeat((8, 1)).chunk(6, 0)
     assert len(tor) == len(ten)
     for i in range(len(tor)):
       helper_test_op([], lambda: tor[i], lambda: ten[i], forward_only=True)
 
-    tor = torch.arange(13).repeat(8, 1).chunk(3, -1)
+    tor = torch.arange(13, dtype=torch.int32).repeat(8, 1).chunk(3, -1)
     ten = Tensor.arange(13).repeat((8, 1)).chunk(3, -1)
     assert len(tor) == len(ten)
     for i in range(len(tor)):
       helper_test_op([], lambda: tor[i], lambda: ten[i], forward_only=True)
 
-    tor = torch.arange(13).repeat(8, 3, 3).chunk(3, -2)
+    tor = torch.arange(13, dtype=torch.int32).repeat(8, 3, 3).chunk(3, -2)
     ten = Tensor.arange(13).repeat((8, 3, 3)).chunk(3, -2)
     assert len(tor) == len(ten)
     for i in range(len(tor)):
       helper_test_op([], lambda: tor[i], lambda: ten[i], forward_only=True)
 
   def test_arange(self):
-    helper_test_op([], lambda: torch.arange(10), lambda: Tensor.arange(10), forward_only=True)
-    helper_test_op([], lambda: torch.arange(5, 10, 3), lambda: Tensor.arange(5, 10, 3), forward_only=True)
-    helper_test_op([], lambda: torch.arange(10, 5, -3), lambda: Tensor.arange(10, 5, -3), forward_only=True)
-    helper_test_op([], lambda: torch.arange(11, 5, -3), lambda: Tensor.arange(11, 5, -3), forward_only=True)
-    helper_test_op([], lambda: torch.arange(1, 78, 2), lambda: Tensor.arange(1, 78, 2), forward_only=True)
+    helper_test_op([], lambda: torch.arange(10, dtype=torch.int32), lambda: Tensor.arange(10), forward_only=True)
+    helper_test_op([], lambda: torch.arange(5, 10, 3, dtype=torch.int32), lambda: Tensor.arange(5, 10, 3), forward_only=True)
+    helper_test_op([], lambda: torch.arange(10, 5, -3, dtype=torch.int32), lambda: Tensor.arange(10, 5, -3), forward_only=True)
+    helper_test_op([], lambda: torch.arange(11, 5, -3, dtype=torch.int32), lambda: Tensor.arange(11, 5, -3), forward_only=True)
+    helper_test_op([], lambda: torch.arange(1, 78, 2, dtype=torch.int32), lambda: Tensor.arange(1, 78, 2), forward_only=True)
     helper_test_op([], lambda: torch.arange(5.5, 175.5, 2.5), lambda: Tensor.arange(5.5, 175.5, 2.5), forward_only=True)
     helper_test_op([], lambda: torch.arange(-30.2, -0.3, 0.75), lambda: Tensor.arange(-30.2, -0.3, 0.75), forward_only=True)
     helper_test_op([], lambda: torch.arange(-50.3, -380.2, -2.25), lambda: Tensor.arange(-50.3, -380.2, -2.25), forward_only=True)
 
   def test_arange_big(self):
-    helper_test_op([], lambda: torch.arange(256), lambda: Tensor.arange(256), forward_only=True)
+    helper_test_op([], lambda: torch.arange(256, dtype=torch.int32), lambda: Tensor.arange(256), forward_only=True)
 
   def test_sum_fake(self):
     helper_test_op([(256, 1)], lambda x: x.sum(axis=1))
@@ -196,7 +212,7 @@ class TestOps(unittest.TestCase):
   def test_where(self):
     helper_test_op(
       [(100,)],
-      lambda x: torch.where(x > 0.5, 4, 2),
+      lambda x: torch.where(x > 0.5, 4, 2).type(torch.int32),
      lambda x: (x > 0.5).where(4, 2), forward_only=True)
 
     for shps in [[(8,),(1,),(1,)], [(10,10),(10,),(10,)], [(100,)]*3, [(10,10)]*3]:
@@ -208,7 +224,7 @@ class TestOps(unittest.TestCase):
   def test_where_permute(self):
     helper_test_op(
       [(5, 5)],
-      lambda x: torch.where(x > 0.5, 4, 2).permute((1, 0)),
+      lambda x: torch.where(x > 0.5, 4, 2).type(torch.int32).permute((1, 0)),
      lambda x: (x > 0.5).where(4, 2).permute((1, 0)), forward_only=True)
 
   def _test_cmp(self, fxn, reverse=True):
@@ -503,7 +519,7 @@ class TestOps(unittest.TestCase):
   def test_multinomial(self):
     # NOTE: this is random, so it has a very large atol
-    helper_test_op([(1000,)], lambda x: torch.multinomial(x.clip(0,1), num_samples=1),
+    helper_test_op([(1000,)], lambda x: torch.multinomial(x.clip(0,1), num_samples=1).type(torch.int32),
                    lambda x: Tensor.multinomial(x.clip(0,1)), forward_only=True, atol=1000.)
 
   def test_small_cumsum(self):
@@ -524,17 +540,17 @@ class TestOps(unittest.TestCase):
   def test_argmax(self):
     self.assertEqual(torch.tensor([2,2]).argmax().numpy(), Tensor([2,2]).argmax().numpy()) # check if returns first index for same max
-    helper_test_op([(10,20)], lambda x: x.argmax(), forward_only=True)
-    helper_test_op([(10,20)], lambda x: x.argmax(0, False), forward_only=True)
-    helper_test_op([(10,20)], lambda x: x.argmax(1, False), forward_only=True)
-    helper_test_op([(10,20)], lambda x: x.argmax(1, True), forward_only=True)
+    helper_test_op([(10,20)], lambda x: x.argmax().type(torch.int32), lambda x: x.argmax(), forward_only=True)
+    helper_test_op([(10,20)], lambda x: x.argmax(0, False).type(torch.int32), lambda x: x.argmax(0, False), forward_only=True)
+    helper_test_op([(10,20)], lambda x: x.argmax(1, False).type(torch.int32), lambda x: x.argmax(1, False), forward_only=True)
+    helper_test_op([(10,20)], lambda x: x.argmax(1, True).type(torch.int32), lambda x: x.argmax(1, True), forward_only=True)
 
   def test_argmin(self):
     self.assertEqual(torch.tensor([2, 2]).argmin().numpy(), Tensor([2, 2]).argmin().numpy())
-    helper_test_op([(10,20)], lambda x: x.argmin(), forward_only=True)
-    helper_test_op([(10,20)], lambda x: x.argmin(0, False), forward_only=True)
-    helper_test_op([(10,20)], lambda x: x.argmin(1, False), forward_only=True)
-    helper_test_op([(10,20)], lambda x: x.argmin(1, True), forward_only=True)
+    helper_test_op([(10,20)], lambda x: x.argmin().type(torch.int32), lambda x: x.argmin(), forward_only=True)
+    helper_test_op([(10,20)], lambda x: x.argmin(0, False).type(torch.int32), lambda x: x.argmin(0, False), forward_only=True)
+    helper_test_op([(10,20)], lambda x: x.argmin(1, False).type(torch.int32), lambda x: x.argmin(1, False), forward_only=True)
+    helper_test_op([(10,20)], lambda x: x.argmin(1, True).type(torch.int32), lambda x: x.argmin(1, True), forward_only=True)
 
   def test_einsum(self):
     # matrix transpose
@@ -765,8 +781,10 @@ class TestOps(unittest.TestCase):
     helper_test_op([(10,10,10)], lambda x: x.softmax(2), atol=1e-7, grad_atol=1e-7)
 
   @unittest.skipIf(CI and Device.DEFAULT in ["CLANG", "PYTHON"], "Broken ISSUE #3552")
   def test_softmax_argmax(self):
-    helper_test_op([(45,65)], lambda x: x.softmax(0).argmax(), forward_only=True, atol=1e-7, grad_atol=1e-7)
-    helper_test_op([(45,65)], lambda x: x.softmax(1).argmax(), forward_only=True, atol=1e-7, grad_atol=1e-7)
+    helper_test_op([(45,65)], lambda x: x.softmax(0).argmax().type(torch.int32),
+                   lambda x: x.softmax(0).argmax(), forward_only=True, atol=1e-7, grad_atol=1e-7)
+    helper_test_op([(45,65)], lambda x: x.softmax(1).argmax().type(torch.int32),
+                   lambda x: x.softmax(1).argmax(), forward_only=True, atol=1e-7, grad_atol=1e-7)
 
   def test_log_softmax(self):
     helper_test_op([(45,65)], torch.nn.LogSoftmax(dim=1), Tensor.log_softmax, atol=1e-7, grad_atol=1e-7)
     helper_test_op([(45)], torch.nn.LogSoftmax(dim=0), Tensor.log_softmax, atol=1e-7, grad_atol=1e-7)
@@ -1634,9 +1652,11 @@ class TestOps(unittest.TestCase):
   def test_one_hot(self):
     data = [1, 2, 4]
-    helper_test_op([], lambda: torch.nn.functional.one_hot(torch.tensor(data), 6), lambda: Tensor(data).one_hot(6), forward_only=True)
+    helper_test_op([], lambda: torch.nn.functional.one_hot(torch.tensor(data), 6).type(torch.int32),
+                   lambda: Tensor(data).one_hot(6), forward_only=True)
     data = [[[1, 2, 3], [0, 3, 5]], [[1, 2, 3], [0, 3, 5]]]
-    helper_test_op([], lambda: torch.nn.functional.one_hot(torch.tensor(data), 8), lambda: Tensor(data).one_hot(8), forward_only=True)
+    helper_test_op([], lambda: torch.nn.functional.one_hot(torch.tensor(data), 8).type(torch.int32),
+                   lambda: Tensor(data).one_hot(8), forward_only=True)
 
 if __name__ == '__main__':
   np.random.seed(1337)