mirror of https://github.com/commaai/tinygrad.git
Quickstart: Upgrade section "Training" to new code (#1663)
Co-authored-by: Dave Farago <dfarago@innoopract.com>
parent 29adae84eb
commit 1ba8f0dca3

@@ -141,18 +141,15 @@ For our loss function we will be using sparse categorical cross entropy loss.

Old code:

```python
# from tinygrad.tensor import sparse_categorical_crossentropy
def sparse_categorical_crossentropy(out, Y, ignore_index=-1):
  loss_mask = Y != ignore_index
  num_classes = out.shape[-1]
  y_counter = Tensor.arange(num_classes, requires_grad=False).unsqueeze(0).expand(Y.numel(), num_classes)
  y = (y_counter == Y.flatten().reshape(-1, 1)).where(-1.0, 0)
  y = y * loss_mask.reshape(-1, 1)
  y = y.reshape(*Y.shape, num_classes)
  return out.log_softmax().mul(y).sum() / loss_mask.sum()
```

New code:

```python
# from tinygrad.tensor import sparse_categorical_crossentropy
def sparse_categorical_crossentropy(self, Y, ignore_index=-1) -> Tensor:
  loss_mask = Y != ignore_index
  y_counter = Tensor.arange(self.shape[-1], dtype=dtypes.int32, requires_grad=False, device=self.device).unsqueeze(0).expand(Y.numel(), self.shape[-1])
  y = ((y_counter == Y.flatten().reshape(-1, 1)).where(-1.0, 0) * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1])
  return self.log_softmax().mul(y).sum() / loss_mask.sum()
```
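As a quick sanity check of the new signature, here is a minimal sketch of calling it on a batch of logits (the toy shapes, the int32 labels, and the assumption that `Tensor` and `dtypes` are already imported as elsewhere in the quickstart are illustrative, not part of this commit):

```python
# minimal sketch, assuming `Tensor` and `dtypes` are in scope as in the rest of the quickstart
out = Tensor.randn(4, 10)                      # logits for 4 samples over 10 classes
Y = Tensor([3, 1, -1, 7], dtype=dtypes.int32)  # integer labels; -1 is dropped via ignore_index

# with `self` as the first parameter, the function can still be called as a plain function
loss = sparse_categorical_crossentropy(out, Y)
print(loss.numpy())
```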

As we can see in this implementation of cross entropy loss, there are certain operations that tinygrad does not support.
Namely, operations that are load/store like indexing a tensor with another tensor or assigning a value to a tensor at a certain index.
Load/store ops are not supported in tinygrad because they add complexity when trying to port to different backends and 90% of the models out there don't use/need them.
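To make that concrete, here is a small sketch of the tensor-indexing pattern tinygrad avoids next to the broadcasted-comparison trick the loss above uses instead (NumPy is used here purely for exposition; none of this appears in the quickstart):

```python
import numpy as np

Y = np.array([2, 0, 1])  # integer class labels for 3 samples
num_classes = 4

# fancy indexing / storing at an index: the kind of load/store op tinygrad avoids
one_hot_indexed = np.zeros((3, num_classes))
one_hot_indexed[np.arange(3), Y] = 1.0

# the indexing-free equivalent: compare an arange against the labels and broadcast
one_hot_compared = (np.arange(num_classes)[None, :] == Y[:, None]).astype(np.float64)

assert np.array_equal(one_hot_indexed, one_hot_compared)
```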

For our optimizer we will be using the traditional stochastic gradient descent optimizer with a learning rate of 3e-4.
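A minimal sketch of what that optimizer setup and one training step could look like; the `SGD` import path, the explicit weight list, and the `net`, `X`, `Y` names are assumptions carried over from earlier in the quickstart, not part of this commit:

```python
from tinygrad.nn.optim import SGD  # import path is an assumption; it has moved between tinygrad versions

# assumed: `net` is the two-layer model defined earlier in the quickstart
opt = SGD([net.l1.weight, net.l2.weight], lr=3e-4)

Tensor.training = True             # enable training behaviour (e.g. dropout); exact API may differ by version
out = net(X)                       # forward pass on a batch X
loss = sparse_categorical_crossentropy(out, Y)
opt.zero_grad()
loss.backward()
opt.step()
```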