mirror of https://github.com/commaai/tinygrad.git
test output dtypes match in test_ops (#3703)

need to cast some torch outputs to int32, because torch by default returns int64 for index-related functions. closes #2797
parent 798970cfad
commit f599c6e7f4
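For context, a minimal sketch of the mismatch being fixed (assuming current PyTorch and tinygrad defaults; exact behavior may vary by version):

    import torch
    from tinygrad.tensor import Tensor

    print(torch.arange(10).dtype)        # torch.int64: torch's default integer dtype
    print(torch.rand(10).argmax().dtype) # torch.int64: index-related ops return int64
    print(Tensor.arange(10).dtype)       # dtypes.int32: tinygrad's default int dtype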
@@ -30,8 +30,9 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn=None, atol=1e-6, rtol=1e-3, gra
   def compare(s, tinygrad_output, torch_output, atol, rtol):
     if PRINT_TENSORS: print(s, tinygrad_output, torch_output)
-    assert tinygrad_output.shape == torch_output.shape, f"shape mismatch: tinygrad={tinygrad_output.shape} | torch={torch_output.shape}"
     try:
+      assert tinygrad_output.shape == torch_output.shape, f"shape mismatch: tinygrad={tinygrad_output.shape} | torch={torch_output.shape}"
+      assert tinygrad_output.dtype == torch_output.dtype, f"dtype mismatch: tinygrad={tinygrad_output.dtype} | torch={torch_output.dtype}"
       np.testing.assert_allclose(tinygrad_output, torch_output, atol=atol, rtol=rtol)
     except Exception as e:
       raise Exception(f"{s} failed shape {tinygrad_output.shape}: {e}")
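The old compare only asserted shape, so an int64/int32 disagreement slipped through np.testing.assert_allclose, which compares values regardless of dtype. A hypothetical repro of what the new assert catches:

    import numpy as np
    a = np.zeros(3, dtype=np.int32)   # stand-in for a tinygrad output
    b = np.zeros(3, dtype=np.int64)   # stand-in for a torch output converted to numpy
    np.testing.assert_allclose(a, b)  # passes: values agree, dtype is ignored
    assert a.dtype == b.dtype, f"dtype mismatch: tinygrad={a.dtype} | torch={b.dtype}"  # AssertionError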
@@ -83,19 +84,29 @@ class TestOps(unittest.TestCase):
     if not CI: print("\ntesting %40r torch/tinygrad exception: %s / %s" % (shps, torch_cm.exception, tinygrad_cm.exception), end="")

   def test_full_like(self):
-    a = Tensor([[1,2,3],[4,5,6]])
-    b = torch.tensor([[1,2,3],[4,5,6]])
+    a = Tensor([[1,2,3],[4,5,6]], dtype=dtypes.float32)
+    b = torch.tensor([[1,2,3],[4,5,6]], dtype=torch.float32)
     helper_test_op([], lambda: torch.full_like(b, 4), lambda: Tensor.full_like(a, 4), forward_only=True)
+
+    a = Tensor([[1,2,3],[4,5,6]], dtype=dtypes.int32)
+    b = torch.tensor([[1,2,3],[4,5,6]], dtype=torch.int32)
+    helper_test_op([], lambda: torch.full_like(b, 4), lambda: Tensor.full_like(a, 4), forward_only=True)

   def test_full(self):
-    helper_test_op([], lambda: torch.full((45,65), 4), lambda: Tensor.full((45,65), 4), forward_only=True)
+    helper_test_op([], lambda: torch.full((45,65), 4, dtype=torch.int32), lambda: Tensor.full((45,65), 4), forward_only=True)

   def test_zeros(self):
     helper_test_op([], lambda: torch.zeros(45,65), lambda: Tensor.zeros(45,65), forward_only=True)
     helper_test_op([], lambda: torch.zeros([45,65]), lambda: Tensor.zeros([45,65]), forward_only=True)
     helper_test_op([], lambda: torch.zeros([]), lambda: Tensor.zeros([]), forward_only=True)

   def test_zeros_like(self):
-    a = Tensor([[1,2,3],[4,5,6]])
-    b = torch.tensor([[1,2,3],[4,5,6]])
+    a = Tensor([[1,2,3],[4,5,6]], dtype=dtypes.float32)
+    b = torch.tensor([[1,2,3],[4,5,6]], dtype=torch.float32)
     helper_test_op([], lambda: torch.zeros_like(b), lambda: Tensor.zeros_like(a), forward_only=True)
+
+    a = Tensor([[1,2,3],[4,5,6]], dtype=dtypes.int32)
+    b = torch.tensor([[1,2,3],[4,5,6]], dtype=torch.int32)
+    helper_test_op([], lambda: torch.zeros_like(b), lambda: Tensor.zeros_like(a), forward_only=True)

   def test_empty_0(self):
@@ -105,9 +116,14 @@ class TestOps(unittest.TestCase):
     helper_test_op([], lambda: torch.ones(45,65), lambda: Tensor.ones(45,65), forward_only=True)
     helper_test_op([], lambda: torch.ones([45,65]), lambda: Tensor.ones([45,65]), forward_only=True)
     helper_test_op([], lambda: torch.ones([]), lambda: Tensor.ones([]), forward_only=True)

   def test_ones_like(self):
-    a = Tensor([[1,2,3],[4,5,6]])
-    b = torch.tensor([[1,2,3],[4,5,6]])
+    a = Tensor([[1,2,3],[4,5,6]], dtype=dtypes.float32)
+    b = torch.tensor([[1,2,3],[4,5,6]], dtype=torch.float32)
     helper_test_op([], lambda: torch.ones_like(b), lambda: Tensor.ones_like(a), forward_only=True)
+
+    a = Tensor([[1,2,3],[4,5,6]], dtype=dtypes.int32)
+    b = torch.tensor([[1,2,3],[4,5,6]], dtype=torch.int32)
+    helper_test_op([], lambda: torch.ones_like(b), lambda: Tensor.ones_like(a), forward_only=True)

   def test_eye(self):
@@ -116,7 +132,7 @@ class TestOps(unittest.TestCase):
     helper_test_op([], lambda: torch.eye(0), lambda: Tensor.eye(0), forward_only=True)

   def test_split(self):
-    def tensor(s): return torch.arange(math.prod(s)).reshape(s), Tensor.arange(math.prod(s)).reshape(s)
+    def tensor(s): return torch.arange(math.prod(s), dtype=torch.int32).reshape(s), Tensor.arange(math.prod(s)).reshape(s)
     test_cases = [
       (tensor((10,)), 5, {}),
       (tensor((10,)), [1,4,5], {}),
@@ -135,42 +151,42 @@ class TestOps(unittest.TestCase):
       helper_test_op([], lambda: tor_chunk, lambda: ten_chunk, forward_only=True)

   def test_chunk(self):
-    tor = torch.arange(13).repeat(8, 1).chunk(6, 1)
+    tor = torch.arange(13, dtype=torch.int32).repeat(8, 1).chunk(6, 1)
     ten = Tensor.arange(13).repeat((8, 1)).chunk(6, 1)
     assert len(tor) == len(ten)
     for i in range(len(tor)):
       helper_test_op([], lambda: tor[i], lambda: ten[i], forward_only=True)

-    tor = torch.arange(13).repeat(8, 1).chunk(6, 0)
+    tor = torch.arange(13, dtype=torch.int32).repeat(8, 1).chunk(6, 0)
     ten = Tensor.arange(13).repeat((8, 1)).chunk(6, 0)
     assert len(tor) == len(ten)
     for i in range(len(tor)):
       helper_test_op([], lambda: tor[i], lambda: ten[i], forward_only=True)

-    tor = torch.arange(13).repeat(8, 1).chunk(3, -1)
+    tor = torch.arange(13, dtype=torch.int32).repeat(8, 1).chunk(3, -1)
     ten = Tensor.arange(13).repeat((8, 1)).chunk(3, -1)
     assert len(tor) == len(ten)
     for i in range(len(tor)):
       helper_test_op([], lambda: tor[i], lambda: ten[i], forward_only=True)

-    tor = torch.arange(13).repeat(8, 3, 3).chunk(3, -2)
+    tor = torch.arange(13, dtype=torch.int32).repeat(8, 3, 3).chunk(3, -2)
     ten = Tensor.arange(13).repeat((8, 3, 3)).chunk(3, -2)
     assert len(tor) == len(ten)
     for i in range(len(tor)):
       helper_test_op([], lambda: tor[i], lambda: ten[i], forward_only=True)

   def test_arange(self):
-    helper_test_op([], lambda: torch.arange(10), lambda: Tensor.arange(10), forward_only=True)
-    helper_test_op([], lambda: torch.arange(5, 10, 3), lambda: Tensor.arange(5, 10, 3), forward_only=True)
-    helper_test_op([], lambda: torch.arange(10, 5, -3), lambda: Tensor.arange(10, 5, -3), forward_only=True)
-    helper_test_op([], lambda: torch.arange(11, 5, -3), lambda: Tensor.arange(11, 5, -3), forward_only=True)
-    helper_test_op([], lambda: torch.arange(1, 78, 2), lambda: Tensor.arange(1, 78, 2), forward_only=True)
+    helper_test_op([], lambda: torch.arange(10, dtype=torch.int32), lambda: Tensor.arange(10), forward_only=True)
+    helper_test_op([], lambda: torch.arange(5, 10, 3, dtype=torch.int32), lambda: Tensor.arange(5, 10, 3), forward_only=True)
+    helper_test_op([], lambda: torch.arange(10, 5, -3, dtype=torch.int32), lambda: Tensor.arange(10, 5, -3), forward_only=True)
+    helper_test_op([], lambda: torch.arange(11, 5, -3, dtype=torch.int32), lambda: Tensor.arange(11, 5, -3), forward_only=True)
+    helper_test_op([], lambda: torch.arange(1, 78, 2, dtype=torch.int32), lambda: Tensor.arange(1, 78, 2), forward_only=True)
     helper_test_op([], lambda: torch.arange(5.5, 175.5, 2.5), lambda: Tensor.arange(5.5, 175.5, 2.5), forward_only=True)
     helper_test_op([], lambda: torch.arange(-30.2, -0.3, 0.75), lambda: Tensor.arange(-30.2, -0.3, 0.75), forward_only=True)
     helper_test_op([], lambda: torch.arange(-50.3, -380.2, -2.25), lambda: Tensor.arange(-50.3, -380.2, -2.25), forward_only=True)

   def test_arange_big(self):
-    helper_test_op([], lambda: torch.arange(256), lambda: Tensor.arange(256), forward_only=True)
+    helper_test_op([], lambda: torch.arange(256, dtype=torch.int32), lambda: Tensor.arange(256), forward_only=True)

   def test_sum_fake(self):
     helper_test_op([(256, 1)], lambda x: x.sum(axis=1))
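Only the integer-argument arange calls need the explicit dtype; float arguments already infer torch's default float32, which matches tinygrad. A quick check (assuming standard torch type inference):

    import torch
    print(torch.arange(10).dtype)               # torch.int64: needs the int32 cast
    print(torch.arange(5.5, 175.5, 2.5).dtype)  # torch.float32: already matches, left untouched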
@@ -196,7 +212,7 @@ class TestOps(unittest.TestCase):
   def test_where(self):
     helper_test_op(
       [(100,)],
-      lambda x: torch.where(x > 0.5, 4, 2),
+      lambda x: torch.where(x > 0.5, 4, 2).type(torch.int32),
       lambda x: (x > 0.5).where(4, 2), forward_only=True)

     for shps in [[(8,),(1,),(1,)], [(10,10),(10,),(10,)], [(100,)]*3, [(10,10)]*3]:
@@ -208,7 +224,7 @@ class TestOps(unittest.TestCase):
   def test_where_permute(self):
     helper_test_op(
       [(5, 5)],
-      lambda x: torch.where(x > 0.5, 4, 2).permute((1, 0)),
+      lambda x: torch.where(x > 0.5, 4, 2).type(torch.int32).permute((1, 0)),
       lambda x: (x > 0.5).where(4, 2).permute((1, 0)), forward_only=True)

   def _test_cmp(self, fxn, reverse=True):
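torch.where with Python int branches also promotes to int64, which is why the torch-side lambdas gain .type(torch.int32) while the tinygrad side is unchanged. A small sketch (assuming torch's scalar promotion rules):

    import torch
    x = torch.rand(4)
    print(torch.where(x > 0.5, 4, 2).dtype)  # torch.int64, vs tinygrad's int32 default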
@@ -503,7 +519,7 @@ class TestOps(unittest.TestCase):

   def test_multinomial(self):
     # NOTE: this is random, so it has a very large atol
-    helper_test_op([(1000,)], lambda x: torch.multinomial(x.clip(0,1), num_samples=1),
+    helper_test_op([(1000,)], lambda x: torch.multinomial(x.clip(0,1), num_samples=1).type(torch.int32),
                    lambda x: Tensor.multinomial(x.clip(0,1)), forward_only=True, atol=1000.)

   def test_small_cumsum(self):
@@ -524,17 +540,17 @@ class TestOps(unittest.TestCase):

   def test_argmax(self):
     self.assertEqual(torch.tensor([2,2]).argmax().numpy(), Tensor([2,2]).argmax().numpy()) # check if returns first index for same max
-    helper_test_op([(10,20)], lambda x: x.argmax(), forward_only=True)
-    helper_test_op([(10,20)], lambda x: x.argmax(0, False), forward_only=True)
-    helper_test_op([(10,20)], lambda x: x.argmax(1, False), forward_only=True)
-    helper_test_op([(10,20)], lambda x: x.argmax(1, True), forward_only=True)
+    helper_test_op([(10,20)], lambda x: x.argmax().type(torch.int32), lambda x: x.argmax(), forward_only=True)
+    helper_test_op([(10,20)], lambda x: x.argmax(0, False).type(torch.int32), lambda x: x.argmax(0, False), forward_only=True)
+    helper_test_op([(10,20)], lambda x: x.argmax(1, False).type(torch.int32), lambda x: x.argmax(1, False), forward_only=True)
+    helper_test_op([(10,20)], lambda x: x.argmax(1, True).type(torch.int32), lambda x: x.argmax(1, True), forward_only=True)

   def test_argmin(self):
     self.assertEqual(torch.tensor([2, 2]).argmin().numpy(), Tensor([2, 2]).argmin().numpy())
-    helper_test_op([(10,20)], lambda x: x.argmin(), forward_only=True)
-    helper_test_op([(10,20)], lambda x: x.argmin(0, False), forward_only=True)
-    helper_test_op([(10,20)], lambda x: x.argmin(1, False), forward_only=True)
-    helper_test_op([(10,20)], lambda x: x.argmin(1, True), forward_only=True)
+    helper_test_op([(10,20)], lambda x: x.argmin().type(torch.int32), lambda x: x.argmin(), forward_only=True)
+    helper_test_op([(10,20)], lambda x: x.argmin(0, False).type(torch.int32), lambda x: x.argmin(0, False), forward_only=True)
+    helper_test_op([(10,20)], lambda x: x.argmin(1, False).type(torch.int32), lambda x: x.argmin(1, False), forward_only=True)
+    helper_test_op([(10,20)], lambda x: x.argmin(1, True).type(torch.int32), lambda x: x.argmin(1, True), forward_only=True)

   def test_einsum(self):
     # matrix transpose
@@ -765,8 +781,10 @@ class TestOps(unittest.TestCase):
     helper_test_op([(10,10,10)], lambda x: x.softmax(2), atol=1e-7, grad_atol=1e-7)

   @unittest.skipIf(CI and Device.DEFAULT in ["CLANG", "PYTHON"], "Broken ISSUE #3552")
   def test_softmax_argmax(self):
-    helper_test_op([(45,65)], lambda x: x.softmax(0).argmax(), forward_only=True, atol=1e-7, grad_atol=1e-7)
-    helper_test_op([(45,65)], lambda x: x.softmax(1).argmax(), forward_only=True, atol=1e-7, grad_atol=1e-7)
+    helper_test_op([(45,65)], lambda x: x.softmax(0).argmax().type(torch.int32),
+                   lambda x: x.softmax(0).argmax(), forward_only=True, atol=1e-7, grad_atol=1e-7)
+    helper_test_op([(45,65)], lambda x: x.softmax(1).argmax().type(torch.int32),
+                   lambda x: x.softmax(1).argmax(), forward_only=True, atol=1e-7, grad_atol=1e-7)

   def test_log_softmax(self):
     helper_test_op([(45,65)], torch.nn.LogSoftmax(dim=1), Tensor.log_softmax, atol=1e-7, grad_atol=1e-7)
     helper_test_op([(45)], torch.nn.LogSoftmax(dim=0), Tensor.log_softmax, atol=1e-7, grad_atol=1e-7)
@@ -1634,9 +1652,11 @@ class TestOps(unittest.TestCase):

   def test_one_hot(self):
     data = [1, 2, 4]
-    helper_test_op([], lambda: torch.nn.functional.one_hot(torch.tensor(data), 6), lambda: Tensor(data).one_hot(6), forward_only=True)
+    helper_test_op([], lambda: torch.nn.functional.one_hot(torch.tensor(data), 6).type(torch.int32),
+                   lambda: Tensor(data).one_hot(6), forward_only=True)
     data = [[[1, 2, 3], [0, 3, 5]], [[1, 2, 3], [0, 3, 5]]]
-    helper_test_op([], lambda: torch.nn.functional.one_hot(torch.tensor(data), 8), lambda: Tensor(data).one_hot(8), forward_only=True)
+    helper_test_op([], lambda: torch.nn.functional.one_hot(torch.tensor(data), 8).type(torch.int32),
+                   lambda: Tensor(data).one_hot(8), forward_only=True)

 if __name__ == '__main__':
   np.random.seed(1337)
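To spot-check the updated expectations locally, something like the following should work (assuming a tinygrad checkout with pytest installed; the -k filter is only an example):

    python3 -m pytest test/test_ops.py -k "arange or argmax or one_hot"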