diff --git a/docs/tensor/ops.md b/docs/tensor/ops.md
index 5fbf376f..2c04a115 100644
--- a/docs/tensor/ops.md
+++ b/docs/tensor/ops.md
@@ -13,6 +13,7 @@
 ::: tinygrad.Tensor.softmax
 ::: tinygrad.Tensor.log_softmax
 ::: tinygrad.Tensor.logsumexp
+::: tinygrad.Tensor.logcumsumexp
 ::: tinygrad.Tensor.argmax
 ::: tinygrad.Tensor.argmin
 
diff --git a/test/test_ops.py b/test/test_ops.py
index 342b27a4..c4494aa3 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -1072,6 +1072,14 @@ class TestOps(unittest.TestCase):
     helper_test_op([()], lambda x: torch.logsumexp(x, dim=0), lambda x: x.logsumexp(0), atol=1e-7, grad_atol=1e-7)
     helper_test_op([()], lambda x: torch.logsumexp(x, dim=-1), lambda x: x.logsumexp(-1), atol=1e-7, grad_atol=1e-7)
 
+  def test_logcumsumexp(self):
+    helper_test_op([(45,65)], lambda x: torch.logcumsumexp(x, dim=0), lambda x: x.logcumsumexp(0), atol=1e-7, grad_atol=1e-7)
+    helper_test_op([(45,65)], lambda x: torch.logcumsumexp(x, dim=1), lambda x: x.logcumsumexp(1), atol=1e-7, grad_atol=1e-7)
+    helper_test_op([(45,)], lambda x: torch.logcumsumexp(x, dim=0), lambda x: x.logcumsumexp(0), atol=1e-7, grad_atol=1e-7)
+    helper_test_op([()], lambda x: torch.logcumsumexp(x, dim=0), lambda x: x.logcumsumexp(0), atol=1e-7, grad_atol=1e-7)
+    helper_test_op([()], lambda x: torch.logcumsumexp(x, dim=0), lambda x: x.logcumsumexp(), atol=1e-7, grad_atol=1e-7)
+    helper_test_op([()], lambda x: torch.logcumsumexp(x, dim=-1), lambda x: x.logcumsumexp(-1), atol=1e-7, grad_atol=1e-7)
+
   def test_sinh(self):
     helper_test_op([(45,65)], lambda x: x.sinh(), grad_atol=1e-6)
     # TODO: backward nan instead of inf
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 8a2e38aa..507fcc25 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -1753,6 +1753,33 @@ class Tensor:
     m = self.max(axis=axis, keepdim=True)
     return (self - m).exp().sum(axis=axis, keepdim=keepdim).log() + m.squeeze(axis)
 
+  def logcumsumexp(self, axis=0):
+    """
+    Computes the log-cumsum-exp of the tensor along the specified axis.
+
+    The log-cumsum-exp function is a numerically stable way to compute the logarithm of the cumulative sum of exponentials.
+
+    You can pass in the `axis` keyword argument to control the axis along which
+    the log-cumsum-exp is computed.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    Tensor.manual_seed(42)
+    t = Tensor.randn(2, 3)
+    print(t.numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.logcumsumexp().numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.logcumsumexp(axis=0).numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.logcumsumexp(axis=1).numpy())
+    ```
+    """
+    m = self.max(axis=axis, keepdim=True)
+    return (self - m).exp().cumsum(axis=axis).log() + m
+
   def argmax(self, axis=None, keepdim=False):
     """
     Returns the indices of the maximum value of the tensor along the specified axis.
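
Not part of the diff: the stability claim in the new docstring rests on the standard max-shift identity, log(Σᵢ exp(xᵢ)) = m + log(Σᵢ exp(xᵢ - m)). Below is a minimal self-contained sketch of that trick, using NumPy for illustration; `logcumsumexp_ref` is a hypothetical reference helper, not tinygrad or NumPy API.

```python
import numpy as np

def logcumsumexp_ref(x: np.ndarray, axis: int = 0) -> np.ndarray:
  # mirror of the diff's implementation: shift by the per-axis max so exp()
  # cannot overflow, then add the shift back after the log
  m = x.max(axis=axis, keepdims=True)
  return np.log(np.cumsum(np.exp(x - m), axis=axis)) + m

x = np.array([1000.0, 1000.5, 1001.0])
naive = np.log(np.cumsum(np.exp(x)))   # exp(1000) overflows float64 -> [inf, inf, inf]
stable = logcumsumexp_ref(x)           # -> approx [1000.0, 1000.974, 1001.680]
print(naive, stable)
```

One design note: shifting by the single maximum along the axis, as the diff does, guards against overflow, the dominant failure mode. Entries far below that maximum can still underflow to `-inf` in the cumulative sum; a running maximum maintained inside the scan would avoid that as well, at the cost of a more complex kernel.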