use where in dropout (#6758)

This should save memory, since we only store the mask as bool instead of the upcast mask used in the mul.
This commit is contained in:
chenyu 2024-09-27 11:11:43 -04:00 committed by GitHub
parent 76b3c1e818
commit bc82f8c5be
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 1 addition and 1 deletion

View File

@@ -3090,7 +3090,7 @@ class Tensor:
```
"""
if not Tensor.training or p == 0: return self
return self * (Tensor.rand_like(self, requires_grad=False, dtype=dtypes.default_float) >= p) * (1/(1.0 - p))
return (Tensor.rand_like(self, requires_grad=False, dtype=dtypes.default_float) >= p).where(self, 0) * (1/(1.0 - p))
def one_hot(self, num_classes:int=-1) -> Tensor:
"""