From bdd6325f315c62f3ab870502d58dc6e8d33340b2 Mon Sep 17 00:00:00 2001 From: Gabe Caldwell Date: Mon, 19 Aug 2024 15:07:14 -0400 Subject: [PATCH] default num_classes value for one_hot (#6182) * num_classes=-1 If num_classes set to -1, the number of classes will be inferred as one greater than the largest class value in the input tensor. * num_classes desc comment to explain num_classes default and what that means. * replacing ' with ` --- test/test_ops.py | 4 ++++ tinygrad/tensor.py | 5 ++++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/test/test_ops.py b/test/test_ops.py index 3c8a674f..9cc07ddf 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -2033,9 +2033,13 @@ class TestOps(unittest.TestCase): data = [1, 2, 4] helper_test_op([], lambda: torch.nn.functional.one_hot(torch.tensor(data), 6).type(torch.int32), lambda: Tensor(data).one_hot(6), forward_only=True) + helper_test_op([], lambda: torch.nn.functional.one_hot(torch.tensor(data)).type(torch.int32), + lambda: Tensor(data).one_hot(), forward_only=True) data = [[[1, 2, 3], [0, 3, 5]], [[1, 2, 3], [0, 3, 5]]] helper_test_op([], lambda: torch.nn.functional.one_hot(torch.tensor(data), 8).type(torch.int32), lambda: Tensor(data).one_hot(8), forward_only=True) + helper_test_op([], lambda: torch.nn.functional.one_hot(torch.tensor(data)).type(torch.int32), + lambda: Tensor(data).one_hot(), forward_only=True) def test_masked_fill(self): helper_test_op([(32,10)], lambda x: x.masked_fill((x>0.1).detach(), -math.inf)) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index a26acfe6..c7f04e4c 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -2948,15 +2948,18 @@ class Tensor: if not Tensor.training or p == 0: return self return self * (Tensor.rand(*self.shape, requires_grad=False, dtype=dtypes.default_float, device=self.device) >= p) * (1/(1.0 - p)) - def one_hot(self, num_classes:int) -> Tensor: + def one_hot(self, num_classes:int=-1) -> Tensor: """ Converts `self` to a one-hot tensor. + `num_classes` defaults to -1, which means num_classes will be inferred as max(self) + 1. + ```python exec="true" source="above" session="tensor" result="python" t = Tensor([0, 1, 3, 3, 4]) print(t.one_hot(5).numpy()) ``` """ + if num_classes == -1: num_classes = (self.max()+1).item() return (self[..., None] == Tensor.arange(num_classes, requires_grad=False, device=self.device)).where(1, 0) def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Optional[Tensor]=None,