tinygrad/test/test_rearrange_einops.py

322 lines
12 KiB
Python

# modified from
# https://github.com/arogozhnikov/einops/blob/master/tests/test_examples.py
# https://github.com/arogozhnikov/einops/blob/master/tests/test_ops.py
# https://github.com/arogozhnikov/einops/blob/master/tests/test_parsing.py
import numpy as np
import unittest
from tinygrad import Tensor
class test_rearrange_examples(unittest.TestCase):
def test1(self):
# transpose
x = Tensor(np.arange(10 * 20 * 30 * 40).reshape([10, 20, 30, 40]))
y = x.rearrange("b c h w -> b h w c")
assert tuple(y.shape) == (10, 30, 40, 20)
def test2(self):
# view / reshape
x = Tensor(np.arange(10 * 20 * 30 * 40).reshape([10, 20, 30, 40]))
y = x.rearrange("b c h w -> b (c h w)")
assert tuple(y.shape) == (10, 20 * 30 * 40)
def test3(self):
# depth-to-space
x = Tensor(np.arange(10 * 20 * 30 * 40).reshape([10, 20, 30, 40]))
y = x.rearrange("b (c h1 w1) h w -> b c (h h1) (w w1)", h1=2, w1=2)
assert tuple(y.shape) == (10, 5, 30 * 2, 40 * 2)
def test4(self):
# space-to-depth
x = Tensor(np.arange(10 * 20 * 30 * 40).reshape([10, 20, 30, 40]))
y = x.rearrange("b c (h h1) (w w1) -> b (h1 w1 c) h w", h1=2, w1=2)
assert tuple(y.shape) == (10, 20 * 4, 30 // 2, 40 // 2)
def test5(self):
# simple transposition
x = Tensor(np.arange(10 * 20 * 30 * 40).reshape([10, 20, 30, 40]))
y = x.rearrange("b1 sound b2 letter -> b1 b2 sound letter")
assert tuple(y.shape) == (10, 30, 20, 40)
def test6(self):
# parsing parameters
x = Tensor(np.arange(10 * 20 * 30 * 40).reshape([10, 20, 30, 40]))
t = x.rearrange("b c h w -> (b h w) c")
t = t[:, ::2] # replacement for dot-product, just changes size of second axis
assert tuple(t.shape) == (10 * 30 * 40, 10)
def test7(self):
x = Tensor(np.arange(10 * 20 * 30 * 40).reshape([10, 20, 30, 40]))
# split of embedding into groups
y1, y2 = x.rearrange("b (c g) h w -> g b c h w", g=2)
assert tuple(y1.shape) == (10, 10, 30, 40)
assert tuple(y2.shape) == (10, 10, 30, 40)
def test8(self):
x = Tensor(np.arange(10 * 20 * 1 * 1).reshape([10, 20, 1, 1]))
# squeeze - unsqueeze
y = x.rearrange("b c () () -> b c")
assert tuple(y.shape) == (10, 20)
y = y.rearrange("b c -> c b () ()")
assert tuple(y.shape) == (20, 10, 1, 1)
def test9(self):
x = Tensor(np.arange(10 * 20 * 1 * 1).reshape([10, 20, 1, 1]))
# squeeze - unsqueeze
y = x.rearrange("b c 1 1 -> b c")
assert tuple(y.shape) == (10, 20)
y = y.rearrange("b1 c -> c b1 1 1")
assert tuple(y.shape) == (20, 10, 1, 1)
def test_tensor_train_example_numpy(self):
# kept here just for a collection, only tested for numpy
# https://arxiv.org/pdf/1509.06569.pdf, (5)
x = Tensor.ones([3, 4, 5, 6])
rank = 4
# creating appropriate Gs
Gs = [Tensor.ones([d, d, rank, rank]) for d in x.shape]
Gs[0] = Gs[0][:, :, :1, :]
Gs[-1] = Gs[-1][:, :, :, :1]
# einsum way
y = x.reshape((1,) + x.shape)
for G in Gs:
# taking partial results left-to-right
# y = numpy.einsum('i j alpha beta, alpha i ... -> beta ... j', G, y)
y = Tensor(np.einsum("i j a b, a i ... -> b ... j", G.numpy(), y.numpy()))
y1 = y.reshape(-1)
# alternative way
y = x.reshape(-1)
for G in Gs:
i, j, alpha, beta = G.shape
y = y.rearrange("(i rest alpha) -> rest (alpha i)", alpha=alpha, i=i)
y = y @ G.rearrange("i j alpha beta -> (alpha i) (j beta)")
y = y.rearrange("rest (beta j) -> (beta rest j)", beta=beta, j=j)
y2 = y
assert np.allclose(y1.numpy(), y2.numpy())
# yet another way
y = x
for G in Gs:
i, j, alpha, beta = G.shape
y = y.rearrange("i ... (j alpha) -> ... j (alpha i)", alpha=alpha, i=i)
y = y @ G.rearrange("i j alpha beta -> (alpha i) (j beta)")
y3 = y.reshape(-1)
assert np.allclose(y1.numpy(), y3.numpy())
class test_rearrange_ops(unittest.TestCase):
def test_rearrange_errors(self):
x = Tensor.zeros([1, 1, 1, 1, 1])
x.rearrange("a b c d ... -> a b c ... d")
bad_patterns = [
"a b c d (...) -> a b c ... d", # collapsed ellipsis on input
"a b (c d ... -> a b c ... d", # unbalanced brackets
"a b* c d ... -> a b c ... d", # not alphanumeric
"a b c d -> a b c d -> a b c d", # two "->"
"a ... c ... -> ... a ... c", # two "..."
"a b c d e -> f b c d e", # name mismatch
]
for pattern in bad_patterns:
with self.assertRaises(AssertionError):
x.rearrange(pattern)
x.rearrange("... -> (...)")
with self.assertRaises(AssertionError):
x.rearrange("(...) -> (...)")
y = Tensor.zeros([8, 1])
y.rearrange("(a1 a2 a3) b -> b a3 a2 a1", a1=2, a2=2)
with self.assertRaises(RuntimeError):
## should fail as not enough dimensions specified
y.rearrange("(a1 a2 a3) b -> b a3 a2 a1", a1=2)
with self.assertRaises(ValueError):
## should fail as 6 does not divide 8
y.rearrange("(a1 a2 a3) b -> b a3 a2 a1", a1=3, a2=2)
with self.assertRaises(AssertionError):
## incorrect dimension provided for an axis that is only permuted
y.rearrange("(a1 a2 a3) b -> b a3 a2 a1", a1=2, a2=2, b=2)
with self.assertRaises(AssertionError):
## unused axis provided
y.rearrange("(a b c) d -> a b c d", b=2, c=2, e=2)
def test_rearrange_ellipsis_ops(self):
identity_patterns = [
"...->...",
"a b c d e-> a b c d e",
"a b c d e ...-> ... a b c d e",
"a b c d e ...-> a ... b c d e",
"... a b c d e -> ... a b c d e",
"a ... e-> a ... e",
"a ... -> a ... ",
"a ... c d e -> a (...) c d e",
]
equivalent_rearrange_patterns = [
("a b c d e -> (a b) c d e", "a b ... -> (a b) ... "),
("a b c d e -> a b (c d) e", "... c d e -> ... (c d) e"),
("a b c d e -> a b c d e", "... -> ... "),
("a b c d e -> (a b c d e)", "... -> (...)"),
("a b c d e -> b (c d e) a", "a b ... -> b (...) a"),
("a b c d e -> b (a c d) e", "a b ... e -> b (a ...) e"),
]
xnp = np.arange(2 * 3 * 4 * 5 * 6).reshape([2, 3, 4, 5, 6])
x = Tensor(xnp)
for pattern in identity_patterns:
assert np.array_equal(xnp, x.rearrange(pattern).numpy()), pattern
for pattern1, pattern2 in equivalent_rearrange_patterns:
assert np.array_equal(x.rearrange(pattern1).numpy(), x.rearrange(pattern2).numpy())
def test_rearrange_consistency(self):
shape = [1, 2, 3, 5, 7, 11]
xnp = np.arange(np.prod(shape)).reshape(shape)
x = Tensor(xnp)
for pattern in [
"a b c d e f -> a b c d e f",
"b a c d e f -> a b d e f c",
"a b c d e f -> f e d c b a",
"a b c d e f -> (f e) d (c b a)",
"a b c d e f -> (f e d c b a)",
]:
result = x.rearrange(pattern).numpy()
assert len(np.setdiff1d(xnp, result)) == 0
assert result.dtype == xnp.dtype
result = x.rearrange("a b c d e f -> a (b) (c d e) f").numpy()
assert np.array_equal(xnp.flatten(), result.flatten())
result = x.rearrange("a aa aa1 a1a1 aaaa a11 -> a aa aa1 a1a1 aaaa a11").numpy()
assert np.array_equal(xnp, result)
result1 = x.rearrange("a b c d e f -> f e d c b a").numpy()
result2 = x.rearrange("f e d c b a -> a b c d e f").numpy()
assert np.array_equal(result1, result2)
result = x.rearrange("a b c d e f -> (f d) c (e b) a").rearrange("(f d) c (e b) a -> a b c d e f", b=2, d=5).numpy()
assert np.array_equal(xnp, result)
sizes = dict(zip("abcdef", shape))
temp = x.rearrange("a b c d e f -> (f d) c (e b) a", **sizes)
result = temp.rearrange("(f d) c (e b) a -> a b c d e f", **sizes).numpy()
assert np.array_equal(xnp, result)
x2 = np.arange(2 * 3 * 4).reshape([2, 3, 4])
result = Tensor(x2).rearrange("a b c -> b c a").numpy()
assert x2[1, 2, 3] == result[2, 3, 1]
assert x2[0, 1, 2] == result[1, 2, 0]
def test_rearrange_permutations(self):
# tests random permutation of axes against two independent numpy ways
for n_axes in range(1, 10):
x = np.arange(2**n_axes).reshape([2] * n_axes)
permutation = np.random.permutation(n_axes)
left_expression = " ".join("i" + str(axis) for axis in range(n_axes))
right_expression = " ".join("i" + str(axis) for axis in permutation)
expression = left_expression + " -> " + right_expression
result = Tensor(x).rearrange(expression).numpy()
for pick in np.random.randint(0, 2, [10, n_axes]):
assert x[tuple(pick)] == result[tuple(pick[permutation])]
for n_axes in range(1, 10):
x = np.arange(2**n_axes).reshape([2] * n_axes)
permutation = np.random.permutation(n_axes)
left_expression = " ".join("i" + str(axis) for axis in range(n_axes)[::-1])
right_expression = " ".join("i" + str(axis) for axis in permutation[::-1])
expression = left_expression + " -> " + right_expression
result = Tensor(x).rearrange(expression).numpy()
assert result.shape == x.shape
expected_result = np.zeros_like(x)
for original_axis, result_axis in enumerate(permutation):
expected_result |= ((x >> original_axis) & 1) << result_axis
assert np.array_equal(result, expected_result)
def check_expression_helper(expression: str):
Tensor.ones((1, 2, 3, 4, 5, 6, 7))
class test_rearrange_parsing(unittest.TestCase):
def test_elementary_axis_name(self):
for name in [
"a",
"b",
"h",
"dx",
"h1",
"zz",
"i9123",
"somelongname",
"Alex",
"camelCase",
"u_n_d_e_r_score",
"unreasonablyLongAxisName",
]:
Tensor.ones((1,)).rearrange(f"{name} -> {name}")
for name in ["2b", "12", "_startWithUnderscore", "endWithUnderscore_", "_"]:
with self.assertRaises(AssertionError):
Tensor.ones((1,)).rearrange(f"{name} -> {name}")
with self.assertRaises(RuntimeError):
Tensor.ones((1,)).rearrange(" -> ")
def test_invalid_expressions(self):
# double ellipsis should raise an error
def _test_expression(expression: str):
Tensor.ones((2, 3, 4, 5, 6)).rearrange(f"{expression} -> {expression}")
_test_expression("... a b c d")
with self.assertRaises(AssertionError):
_test_expression("... a b c d ...")
with self.assertRaises(AssertionError):
_test_expression("... a b c (d ...)")
with self.assertRaises(AssertionError):
_test_expression("(... a) b c (d ...)")
# double/missing/enclosed parenthesis
Tensor.ones((2, 3, 4, 5, 6)).rearrange("a b c d ... -> (a) b c (d ...)")
with self.assertRaises(AssertionError):
_test_expression("(a)) b c (d ...)")
with self.assertRaises(AssertionError):
_test_expression("(a b c (d ...)")
with self.assertRaises(AssertionError):
_test_expression("(a) (()) b c (d ...)")
with self.assertRaises(AssertionError):
_test_expression("(a) ((b c) (d ...))")
# invalid identifiers
_test_expression("camelCase under_scored cApiTaLs ß ...")
with self.assertRaises(AssertionError):
_test_expression("1a")
with self.assertRaises(AssertionError):
_test_expression("_pre")
with self.assertRaises(AssertionError):
_test_expression("...pre")
with self.assertRaises(AssertionError):
_test_expression("pre...")
def test_unicode_ellipsis(self):
equivalent_rearrange_patterns = [
("a b … -> (a b) … ", "a b ... -> (a b) ... "),
("… c d e -> … (c d) e", "... c d e -> ... (c d) e"),
("… -> … ", "... -> ... "),
("… -> (…)", "... -> (...)"),
("a b … -> b (…) a", "a b ... -> b (...) a"),
("a b … e -> b (a …) e", "a b ... e -> b (a ...) e"),
]
xnp = np.arange(2 * 3 * 4 * 5 * 6).reshape([2, 3, 4, 5, 6])
x = Tensor(xnp)
for pattern1, pattern2 in equivalent_rearrange_patterns:
assert np.array_equal(x.rearrange(pattern1).numpy(), x.rearrange(pattern2).numpy())
if __name__ == "__main__":
unittest.main()