implement logcumsumexp (#6921)

* implement logcumsumexp

* change axis=None to axis=0
jeffzh4ng 2024-10-06 10:45:36 -04:00 committed by GitHub
parent f588169fdc
commit 19a7e41113
3 changed files with 36 additions and 0 deletions


@@ -13,6 +13,7 @@
::: tinygrad.Tensor.softmax
::: tinygrad.Tensor.log_softmax
::: tinygrad.Tensor.logsumexp
::: tinygrad.Tensor.logcumsumexp
::: tinygrad.Tensor.argmax
::: tinygrad.Tensor.argmin


@@ -1072,6 +1072,14 @@ class TestOps(unittest.TestCase):
    helper_test_op([()], lambda x: torch.logsumexp(x, dim=0), lambda x: x.logsumexp(0), atol=1e-7, grad_atol=1e-7)
    helper_test_op([()], lambda x: torch.logsumexp(x, dim=-1), lambda x: x.logsumexp(-1), atol=1e-7, grad_atol=1e-7)
  def test_logcumsumexp(self):
    helper_test_op([(45,65)], lambda x: torch.logcumsumexp(x, dim=0), lambda x: x.logcumsumexp(0), atol=1e-7, grad_atol=1e-7)
    helper_test_op([(45,65)], lambda x: torch.logcumsumexp(x, dim=1), lambda x: x.logcumsumexp(1), atol=1e-7, grad_atol=1e-7)
    helper_test_op([(45,)], lambda x: torch.logcumsumexp(x, dim=0), lambda x: x.logcumsumexp(0), atol=1e-7, grad_atol=1e-7)
    helper_test_op([()], lambda x: torch.logcumsumexp(x, dim=0), lambda x: x.logcumsumexp(0), atol=1e-7, grad_atol=1e-7)
    helper_test_op([()], lambda x: torch.logcumsumexp(x, dim=0), lambda x: x.logcumsumexp(), atol=1e-7, grad_atol=1e-7)
    helper_test_op([()], lambda x: torch.logcumsumexp(x, dim=-1), lambda x: x.logcumsumexp(-1), atol=1e-7, grad_atol=1e-7)
  def test_sinh(self):
    helper_test_op([(45,65)], lambda x: x.sinh(), grad_atol=1e-6)
    # TODO: backward nan instead of inf
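The new test exercises 2-D, 1-D, and 0-D inputs against `torch.logcumsumexp`, including the default axis. As a quick illustration of the same comparison outside the `helper_test_op` harness (a sketch, not part of the diff; assumes NumPy and PyTorch are installed):

```python
# Hypothetical standalone check mirroring what test_logcumsumexp verifies:
# tinygrad's logcumsumexp should match torch.logcumsumexp elementwise.
import numpy as np
import torch
from tinygrad import Tensor

x = np.random.randn(45, 65).astype(np.float32)
out_tiny = Tensor(x).logcumsumexp(axis=0).numpy()
out_torch = torch.logcumsumexp(torch.tensor(x), dim=0).numpy()
np.testing.assert_allclose(out_tiny, out_torch, atol=1e-6, rtol=1e-6)
```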


@@ -1753,6 +1753,33 @@ class Tensor:
    m = self.max(axis=axis, keepdim=True)
    return (self - m).exp().sum(axis=axis, keepdim=keepdim).log() + m.squeeze(axis)

  def logcumsumexp(self, axis=0):
    """
    Computes the log-cumsum-exp of the tensor along the specified axis.

    The log-cumsum-exp function is a numerically stable way to compute the logarithm of the cumulative sum of exponentials.

    You can pass in the `axis` keyword argument to control the axis along which
    the log-cumsum-exp is computed. The default is `axis=0`.

    ```python exec="true" source="above" session="tensor" result="python"
    Tensor.manual_seed(42)
    t = Tensor.randn(2, 3)
    print(t.numpy())
    ```
    ```python exec="true" source="above" session="tensor" result="python"
    print(t.logcumsumexp().numpy())
    ```
    ```python exec="true" source="above" session="tensor" result="python"
    print(t.logcumsumexp(axis=0).numpy())
    ```
    ```python exec="true" source="above" session="tensor" result="python"
    print(t.logcumsumexp(axis=1).numpy())
    ```
    """
    m = self.max(axis=axis, keepdim=True)
    return (self - m).exp().cumsum(axis=axis).log() + m

  def argmax(self, axis=None, keepdim=False):
    """
    Returns the indices of the maximum value of the tensor along the specified axis.
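
For context on the implementation: subtracting the per-axis maximum `m` before exponentiating relies on the identity log(cumsum(exp(x))) = m + log(cumsum(exp(x - m))), which keeps the intermediate `exp` from overflowing. A minimal NumPy sketch (not part of the diff) showing the effect:

```python
# Naive vs. max-shifted logcumsumexp on inputs large enough to overflow exp().
import numpy as np

x = np.array([1000.0, 1000.5, 1001.0])           # exp(1000) overflows float64

with np.errstate(over="ignore"):
  naive = np.log(np.cumsum(np.exp(x)))           # -> [inf, inf, inf]

m = x.max()                                      # maximum along the (only) axis
stable = np.log(np.cumsum(np.exp(x - m))) + m    # stays finite

print(naive)   # [inf inf inf]
print(stable)  # roughly [1000.  1000.974  1001.68]
```

Because `m` is constant along the cumulative axis, adding it back after the log recovers the exact result; the shift only changes the scale of the intermediate values.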