diff --git a/test/test_rearrange_einops.py b/test/test_rearrange_einops.py index f2651fb6..faf159b9 100644 --- a/test/test_rearrange_einops.py +++ b/test/test_rearrange_einops.py @@ -61,6 +61,14 @@ class test_rearrange_examples(unittest.TestCase): y = y.rearrange("b c -> c b () ()") assert tuple(y.shape) == (20, 10, 1, 1) + def test9(self): + x = Tensor(np.arange(10 * 20 * 1 * 1).reshape([10, 20, 1, 1])) + # squeeze - unsqueeze + y = x.rearrange("b c 1 1 -> b c") + assert tuple(y.shape) == (10, 20) + y = y.rearrange("b1 c -> c b1 1 1") + assert tuple(y.shape) == (20, 10, 1, 1) + def test_tensor_train_example_numpy(self): # kept here just for a collection, only tested for numpy # https://arxiv.org/pdf/1509.06569.pdf, (5) @@ -131,6 +139,9 @@ class test_rearrange_ops(unittest.TestCase): with self.assertRaises(AssertionError): ## incorrect dimension provided for an axis that is only permuted y.rearrange("(a1 a2 a3) b -> b a3 a2 a1", a1=2, a2=2, b=2) + with self.assertRaises(AssertionError): + ## unused axis provided + y.rearrange("(a b c) d -> a b c d", b=2, c=2, e=2) def test_rearrange_ellipsis_ops(self): identity_patterns = [ diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 90d633e6..207903cb 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1657,15 +1657,17 @@ class Tensor: ``` """ def parse_formula(formula: str): - lparens, rparens = map(lambda x: [i for i, ch in enumerate(formula.split()) if ch == x], ("(", ")")) + tokens = f" {formula} ".replace("…", "...").replace("(", " ( ").replace(")", " ) ").replace(" ", " ").replace(" 1 ", " ( ) ").split() + lparens, rparens = map(lambda x: [i for i, ch in enumerate(tokens) if ch == x], ("(", ")")) pairs = list(zip(lparens, rparens)) assert len(lparens) == len(rparens) and sorted(flatten(pairs)) == flatten(pairs), "bracket mismatch" - return [name for name in formula.split() if name not in ("(", ")")], [(s - 2*i, e - 1 - 2*i) for i, (s, e) in enumerate(pairs)] + return [name for name in tokens if name not in ("(", ")")], [(s - 2*i, e - 1 - 2*i) for i, (s, e) in enumerate(pairs)] assert formula.count("->") == 1, 'need exactly one "->" in formula' - (lhs, unflatten_dims), (rhs, flatten_dims) = map(parse_formula, formula.replace("…", "...").replace("(", " ( ").replace(")", " ) ").split("->")) + (lhs, unflatten_dims), (rhs, flatten_dims) = map(parse_formula, formula.split("->")) + for name in sizes: assert name in lhs, f"axis {name} is not used in transform" assert sorted(lhs) == sorted(rhs) and len(lhs) == len(set(lhs)), f"name mismatch in {formula}" for name in flatten((lhs, rhs)): assert name == "..." or (name.isidentifier() and "_" not in (name[0], name[-1])), f"invalid axis name {name}" assert "..." not in flatten([lhs[s:e] for s, e in unflatten_dims]), f"cannot have collapsed ellipsis (...) in lhs of {formula}"