mirror of https://github.com/commaai/tinygrad.git
use where in dropout (#6758)
should save memory since we only store mask in bool instead of the upcasted used in mul
This commit is contained in:
parent
76b3c1e818
commit
bc82f8c5be
|
@ -3090,7 +3090,7 @@ class Tensor:
|
|||
```
|
||||
"""
|
||||
if not Tensor.training or p == 0: return self
|
||||
return self * (Tensor.rand_like(self, requires_grad=False, dtype=dtypes.default_float) >= p) * (1/(1.0 - p))
|
||||
return (Tensor.rand_like(self, requires_grad=False, dtype=dtypes.default_float) >= p).where(self, 0) * (1/(1.0 - p))
|
||||
|
||||
def one_hot(self, num_classes:int=-1) -> Tensor:
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue