mirror of https://github.com/commaai/tinygrad.git
Einsum space fix (#2927)
* space removal in formula and a single test to cover it * space in torch einsum as well * replacing spaces in a var formula to support truncating all the spaces
This commit is contained in:
parent
b55b55d56e
commit
8de1fc2539
|
@ -465,6 +465,7 @@ class TestOps(unittest.TestCase):
|
|||
def test_einsum(self):
|
||||
# matrix transpose
|
||||
helper_test_op([(150,150)], lambda a: torch.einsum('ij->ji', a), lambda a: Tensor.einsum('ij->ji', a))
|
||||
helper_test_op([(150,150)], lambda a: torch.einsum('ij -> ji', a), lambda a: Tensor.einsum('ij -> ji', a))
|
||||
helper_test_op([(150,150)], lambda a: torch.einsum('ji', a), lambda a: Tensor.einsum('ji', a))
|
||||
helper_test_op([(20,30,40)], lambda a: torch.einsum('jki', a), lambda a: Tensor.einsum('jki', a))
|
||||
helper_test_op([(20,30,40)], lambda a: torch.einsum('dog', a), lambda a: Tensor.einsum('dog', a))
|
||||
|
|
|
@ -526,6 +526,7 @@ class Tensor:
|
|||
@staticmethod
|
||||
def einsum(formula:str, *raw_xs) -> Tensor:
|
||||
xs:Tuple[Tensor] = argfix(*raw_xs)
|
||||
formula = formula.replace(" ", "")
|
||||
inputs_str, output = formula.split("->") if "->" in formula else (formula, sorted(formula))
|
||||
inputs = [x for x in cast(str,inputs_str).split(',')]
|
||||
assert len(xs) == len(inputs), f"number of inputs doesn't match number of operands in formula, expected {len(inputs)}, got {len(xs)}"
|
||||
|
|
Loading…
Reference in New Issue