diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index c56bb5bd..771a6cd5 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -3090,7 +3090,7 @@ class Tensor: ``` """ if not Tensor.training or p == 0: return self - return self * (Tensor.rand_like(self, requires_grad=False, dtype=dtypes.default_float) >= p) * (1/(1.0 - p)) + return (Tensor.rand_like(self, requires_grad=False, dtype=dtypes.default_float) >= p).where(self, 0) * (1/(1.0 - p)) def one_hot(self, num_classes:int=-1) -> Tensor: """