mirror of https://github.com/commaai/tinygrad.git
implement logcumsumexp (#6921)
* implement logcumsumexp * change axis=None to axis=0
This commit is contained in:
parent
f588169fdc
commit
19a7e41113
|
@ -13,6 +13,7 @@
|
|||
::: tinygrad.Tensor.softmax
|
||||
::: tinygrad.Tensor.log_softmax
|
||||
::: tinygrad.Tensor.logsumexp
|
||||
::: tinygrad.Tensor.logcumsumexp
|
||||
::: tinygrad.Tensor.argmax
|
||||
::: tinygrad.Tensor.argmin
|
||||
|
||||
|
|
|
@ -1072,6 +1072,14 @@ class TestOps(unittest.TestCase):
|
|||
helper_test_op([()], lambda x: torch.logsumexp(x, dim=0), lambda x: x.logsumexp(0), atol=1e-7, grad_atol=1e-7)
|
||||
helper_test_op([()], lambda x: torch.logsumexp(x, dim=-1), lambda x: x.logsumexp(-1), atol=1e-7, grad_atol=1e-7)
|
||||
|
||||
def test_logcumsumexp(self):
|
||||
helper_test_op([(45,65)], lambda x: torch.logcumsumexp(x, dim=0), lambda x: x.logcumsumexp(0), atol=1e-7, grad_atol=1e-7)
|
||||
helper_test_op([(45,65)], lambda x: torch.logcumsumexp(x, dim=1), lambda x: x.logcumsumexp(1), atol=1e-7, grad_atol=1e-7)
|
||||
helper_test_op([(45)], lambda x: torch.logcumsumexp(x, dim=0), lambda x: x.logcumsumexp(0), atol=1e-7, grad_atol=1e-7)
|
||||
helper_test_op([()], lambda x: torch.logcumsumexp(x, dim=0), lambda x: x.logcumsumexp(0), atol=1e-7, grad_atol=1e-7)
|
||||
helper_test_op([()], lambda x: torch.logcumsumexp(x, dim=0), lambda x: x.logcumsumexp(), atol=1e-7, grad_atol=1e-7)
|
||||
helper_test_op([()], lambda x: torch.logcumsumexp(x, dim=-1), lambda x: x.logcumsumexp(-1), atol=1e-7, grad_atol=1e-7)
|
||||
|
||||
def test_sinh(self):
|
||||
helper_test_op([(45,65)], lambda x: x.sinh(), grad_atol=1e-6)
|
||||
# TODO: backward nan instead of inf
|
||||
|
|
|
@ -1753,6 +1753,33 @@ class Tensor:
|
|||
m = self.max(axis=axis, keepdim=True)
|
||||
return (self - m).exp().sum(axis=axis, keepdim=keepdim).log() + m.squeeze(axis)
|
||||
|
||||
def logcumsumexp(self, axis=0):
|
||||
"""
|
||||
Computes the log-cumsum-exp of the tensor along the specified axis or axes.
|
||||
|
||||
The log-cumsum-exp function is a numerically stable way to compute the logarithm of the cumulative sum of exponentials.
|
||||
|
||||
You can pass in the `axis` keyword argument to control the axis along which
|
||||
the log-cum-sum-exp is computed.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
t = Tensor.randn(2, 3)
|
||||
print(t.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.logcumsumexp().numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.logcumsumexp(axis=0).numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.logcumsumexp(axis=1).numpy())
|
||||
```
|
||||
"""
|
||||
m = self.max(axis=axis, keepdim=True)
|
||||
return (self - m).exp().cumsum(axis=axis).log() + m
|
||||
|
||||
def argmax(self, axis=None, keepdim=False):
|
||||
"""
|
||||
Returns the indices of the maximum value of the tensor along the specified axis.
|
||||
|
|
Loading…
Reference in New Issue