Einsum space fix (#2927)

* space removal in formula and a single test to cover it

* space in torch einsum as well

* replacing spaces in a var formula to support truncating all the spaces
This commit is contained in:
Isalia20 2023-12-24 10:23:27 +04:00 committed by GitHub
parent b55b55d56e
commit 8de1fc2539
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 2 additions and 0 deletions

View File

@ -465,6 +465,7 @@ class TestOps(unittest.TestCase):
def test_einsum(self):
# matrix transpose
helper_test_op([(150,150)], lambda a: torch.einsum('ij->ji', a), lambda a: Tensor.einsum('ij->ji', a))
helper_test_op([(150,150)], lambda a: torch.einsum('ij -> ji', a), lambda a: Tensor.einsum('ij -> ji', a))
helper_test_op([(150,150)], lambda a: torch.einsum('ji', a), lambda a: Tensor.einsum('ji', a))
helper_test_op([(20,30,40)], lambda a: torch.einsum('jki', a), lambda a: Tensor.einsum('jki', a))
helper_test_op([(20,30,40)], lambda a: torch.einsum('dog', a), lambda a: Tensor.einsum('dog', a))

View File

@ -526,6 +526,7 @@ class Tensor:
@staticmethod
def einsum(formula:str, *raw_xs) -> Tensor:
xs:Tuple[Tensor] = argfix(*raw_xs)
formula = formula.replace(" ", "")
inputs_str, output = formula.split("->") if "->" in formula else (formula, sorted(formula))
inputs = [x for x in cast(str,inputs_str).split(',')]
assert len(xs) == len(inputs), f"number of inputs doesn't match number of operands in formula, expected {len(inputs)}, got {len(xs)}"