default num_classes value for one_hot (#6182)

* num_classes=-1

If num_classes is set to -1, the number of classes is inferred as one greater than the largest class value in the input tensor.
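For example (a quick sketch of the new default, assuming the `Tensor` API shown in the diff below):

```python
from tinygrad import Tensor

t = Tensor([0, 1, 3, 3, 4])
# num_classes omitted: inferred as max(t) + 1 = 5
print(t.one_hot().numpy())
# [[1 0 0 0 0]
#  [0 1 0 0 0]
#  [0 0 0 1 0]
#  [0 0 0 1 0]
#  [0 0 0 0 1]]
```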

* num_classes desc

Add a comment to the docstring explaining the num_classes default and what it means.

* replacing ' with `
Gabe Caldwell 2024-08-19 15:07:14 -04:00 committed by GitHub
parent 9328248610
commit bdd6325f31
2 changed files with 8 additions and 1 deletion

test/test_ops.py

@@ -2033,9 +2033,13 @@ class TestOps(unittest.TestCase):
     data = [1, 2, 4]
     helper_test_op([], lambda: torch.nn.functional.one_hot(torch.tensor(data), 6).type(torch.int32),
                        lambda: Tensor(data).one_hot(6), forward_only=True)
+    helper_test_op([], lambda: torch.nn.functional.one_hot(torch.tensor(data)).type(torch.int32),
+                       lambda: Tensor(data).one_hot(), forward_only=True)
     data = [[[1, 2, 3], [0, 3, 5]], [[1, 2, 3], [0, 3, 5]]]
     helper_test_op([], lambda: torch.nn.functional.one_hot(torch.tensor(data), 8).type(torch.int32),
                        lambda: Tensor(data).one_hot(8), forward_only=True)
+    helper_test_op([], lambda: torch.nn.functional.one_hot(torch.tensor(data)).type(torch.int32),
+                       lambda: Tensor(data).one_hot(), forward_only=True)
 
   def test_masked_fill(self):
     helper_test_op([(32,10)], lambda x: x.masked_fill((x>0.1).detach(), -math.inf))
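These tests pin the new default to PyTorch's reference behavior: `torch.nn.functional.one_hot` also treats `num_classes=-1` (its default) as "infer from the data". A standalone check of that reference behavior:

```python
import torch

data = [[[1, 2, 3], [0, 3, 5]], [[1, 2, 3], [0, 3, 5]]]
out = torch.nn.functional.one_hot(torch.tensor(data))  # num_classes defaults to -1
print(out.shape)  # torch.Size([2, 2, 3, 6]): max value is 5, so 6 classes are inferred
```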

tinygrad/tensor.py

@@ -2948,15 +2948,18 @@ class Tensor:
     if not Tensor.training or p == 0: return self
     return self * (Tensor.rand(*self.shape, requires_grad=False, dtype=dtypes.default_float, device=self.device) >= p) * (1/(1.0 - p))
 
-  def one_hot(self, num_classes:int) -> Tensor:
+  def one_hot(self, num_classes:int=-1) -> Tensor:
     """
     Converts `self` to a one-hot tensor.
 
+    `num_classes` defaults to -1, which means num_classes will be inferred as max(self) + 1.
+
     ```python exec="true" source="above" session="tensor" result="python"
     t = Tensor([0, 1, 3, 3, 4])
     print(t.one_hot(5).numpy())
     ```
     """
+    if num_classes == -1: num_classes = (self.max()+1).item()
     return (self[..., None] == Tensor.arange(num_classes, requires_grad=False, device=self.device)).where(1, 0)
 
   def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Optional[Tensor]=None,
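One trade-off worth noting (an observation, not something stated in the commit): the inference path calls `(self.max()+1).item()`, which has to evaluate the tensor and copy the max back to the host, so passing an explicit `num_classes` keeps the whole expression lazy. A minimal illustration:

```python
from tinygrad import Tensor

labels = Tensor([2, 0, 1, 3])
a = labels.one_hot(4)   # explicit num_classes: stays lazy until realized
b = labels.one_hot()    # inferred: .item() evaluates max(labels)+1 eagerly
assert (a.numpy() == b.numpy()).all()  # both yield the same 4x4 one-hot matrix
```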