tensor reduction touchup (#6402)

- fix spacing
- use get_args to get the valid Literal values and raise ValueError to match torch, and add a test for that (a standalone sketch of the pattern follows the commit metadata below)
- use `Y` to be consistent
chenyu 2024-09-08 03:55:51 -04:00 committed by GitHub
parent 65da03e186
commit 7df4373fd9
2 changed files with 19 additions and 16 deletions
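
The get_args/ValueError change is the behavioral part of this touchup. A minimal standalone sketch of the pattern (illustrative names, not tinygrad's actual code): the set of accepted strings is derived from the `Literal` alias itself via `typing.get_args`, so the error message can never drift from the type hint.

```python
# Illustrative sketch (hypothetical do_reduction helper, not tinygrad code):
# validate a string argument against a Literal alias and raise ValueError on mismatch.
from typing import Callable, Dict, Literal, get_args

ReductionStr = Literal["mean", "sum", "none"]

def do_reduction(values, reduction: ReductionStr="mean"):
  if reduction not in get_args(ReductionStr): raise ValueError(f"{reduction=} must be one of {get_args(ReductionStr)}")
  reductions: Dict[str, Callable] = {"mean": lambda x: sum(x)/len(x), "sum": sum, "none": lambda x: x}
  return reductions[reduction](values)

print(do_reduction([1.0, 2.0, 3.0]))         # 2.0
print(do_reduction([1.0, 2.0, 3.0], "sum"))  # 6.0
do_reduction([1.0], "typo")                  # ValueError: reduction='typo' must be one of ('mean', 'sum', 'none')
```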

test/test_ops.py

@@ -2109,6 +2109,9 @@ class TestOps(unittest.TestCase):
     for r in ("mean", "sum", "none"):
       helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.cross_entropy(x, y, reduction=r),
                                          lambda x,y: x.cross_entropy(y, reduction=r))
+    self.helper_test_exception([(32,10), (32,10)], lambda x,y: torch.nn.functional.cross_entropy(x, y, reduction="typo"),
+                                                   lambda x,y: x.cross_entropy(y, reduction="typo"), expected=ValueError)
+
   def test_cross_entropy_smoothing(self):
     for ls in (0., 0.3, 0.7, 1.):
       helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.cross_entropy(x, y, label_smoothing=ls),
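
For reference, a hedged sketch of what the new error path looks like from user code (assuming tinygrad is installed; the `helper_test_exception` call above expects both the torch and tinygrad versions to raise ValueError):

```python
# Hedged usage sketch, assuming tinygrad is installed: with this commit an unknown
# reduction string raises ValueError instead of tripping an assert.
from tinygrad import Tensor

logits, target = Tensor.randn(32, 10), Tensor.randn(32, 10)
try:
  logits.cross_entropy(target, reduction="typo")
except ValueError as e:
  print(e)  # reduction='typo' must be one of ('mean', 'sum', 'none')
```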

tinygrad/tensor.py

@@ -3050,37 +3050,37 @@ class Tensor:
     return ((qk+attn_mask) if attn_mask is not None else qk).softmax(-1).cast(self.dtype).dropout(dropout_p) @ value
 
   def _do_reduction(self, reduction:ReductionStr="mean") -> Tensor:
-    assert reduction in ("mean", "sum", "none"), "reduction must be one of ['mean', 'sum', 'none']"
+    if reduction not in get_args(ReductionStr): raise ValueError(f"{reduction=} must be one of {get_args(ReductionStr)}")
     reductions: Dict[str, Callable[[Tensor], Tensor]] = {"mean": Tensor.mean, "sum": Tensor.sum, "none": lambda x: x}
-    return reductions [reduction](self)
+    return reductions[reduction](self)
 
-  def binary_crossentropy(self, y:Tensor, reduction:ReductionStr="mean") -> Tensor:
+  def binary_crossentropy(self, Y:Tensor, reduction:ReductionStr="mean") -> Tensor:
     """
-    Computes the binary cross-entropy loss between `self` and `y`.
+    Computes the binary cross-entropy loss between `self` and `Y`.
 
     See: https://pytorch.org/docs/stable/generated/torch.nn.BCELoss.html
 
     ```python exec="true" source="above" session="tensor" result="python"
     t = Tensor([0.1, 0.9, 0.2])
-    y = Tensor([0, 1, 0])
-    print(t.binary_crossentropy(y).item())
+    Y = Tensor([0, 1, 0])
+    print(t.binary_crossentropy(Y).item())
     ```
     """
-    return (-y*self.log() - (1-y)*(1-self).log())._do_reduction(reduction)
+    return (-Y*self.log() - (1-Y)*(1-self).log())._do_reduction(reduction)
 
-  def binary_crossentropy_logits(self, y:Tensor, reduction:ReductionStr="mean") -> Tensor:
+  def binary_crossentropy_logits(self, Y:Tensor, reduction:ReductionStr="mean") -> Tensor:
     """
-    Computes the binary cross-entropy loss between `self` and `y` where `self` is logits.
+    Computes the binary cross-entropy loss between `self` and `Y` where `self` is logits.
 
     See: https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html
 
     ```python exec="true" source="above" session="tensor" result="python"
     t = Tensor([-1, 2, -3])
-    y = Tensor([0, 1, 0])
-    print(t.binary_crossentropy_logits(y).item())
+    Y = Tensor([0, 1, 0])
+    print(t.binary_crossentropy_logits(Y).item())
     ```
     """
-    return (self.maximum(0) - y * self + (1 + self.abs().neg().exp()).log())._do_reduction(reduction)
+    return (self.maximum(0) - Y * self + (1 + self.abs().neg().exp()).log())._do_reduction(reduction)
 
   def sparse_categorical_crossentropy(self, Y:Tensor, ignore_index=-1, label_smoothing=0.0, reduction:ReductionStr="mean") -> Tensor:
     """
@@ -3107,7 +3107,7 @@ class Tensor:
     # NOTE: because of ignore_index, we can't use Tensor.mean (so can't use `_do_reduction` here)
     return -(unreduced.sum() / loss_mask.sum() if reduction == "mean" else (unreduced.sum() if reduction == "sum" else unreduced))
 
-  def cross_entropy(self, y:Tensor, reduction:ReductionStr="mean", label_smoothing:float=0.0) -> Tensor:
+  def cross_entropy(self, Y:Tensor, reduction:ReductionStr="mean", label_smoothing:float=0.0) -> Tensor:
     """
     Compute the cross entropy loss between input logits and target.
@@ -3127,9 +3127,9 @@ class Tensor:
     ```
     """
     assert 0.0 <= label_smoothing <= 1.0, "label_smoothing must be in [0.0, 1.0]"
-    y = y.one_hot(num_classes=cast(int, self.shape[1])) if y.ndim < 2 else y
-    y = (1 - label_smoothing)*y + label_smoothing / cast(int, y.shape[1])
-    ret = -self.log_softmax(axis=1).mul(y).sum(axis=1)
+    Y = Y.one_hot(num_classes=cast(int, self.shape[1])) if Y.ndim < 2 else Y
+    Y = (1 - label_smoothing)*Y + label_smoothing / cast(int, Y.shape[1])
+    ret = -self.log_softmax(axis=1).mul(Y).sum(axis=1)
     return ret._do_reduction(reduction)
 
   # ***** Tensor Properties *****
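
The last hunk only renames `y` to `Y`, but since it carries the label-smoothing arithmetic, here is a plain-Python sketch of what those three lines compute for a single sample (standalone, no tinygrad required):

```python
# Plain-Python sketch of the smoothed cross entropy above: the one-hot target is
# smoothed as Y = (1 - ls)*Y + ls/num_classes, then loss = -sum(Y * log_softmax(logits)).
import math

logits = [2.0, 0.5, -1.0]   # one sample, 3 classes
target_class, ls = 0, 0.3   # label_smoothing = 0.3

Y = [(1 - ls)*(1.0 if i == target_class else 0.0) + ls/len(logits) for i in range(len(logits))]
lse = math.log(sum(math.exp(x) for x in logits))     # log-sum-exp for log_softmax
loss = -sum(y*(x - lse) for y, x in zip(Y, logits))
print(loss)
```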