From 999e3780e9295ddc85ba3202b269943ea1a556a0 Mon Sep 17 00:00:00 2001 From: chenyu Date: Sun, 6 Oct 2024 19:40:42 -0400 Subject: [PATCH] dropout contiguous after >= p (#6892) make it a bool buffer --- tinygrad/tensor.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 0adbb557..50da24a6 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -454,7 +454,7 @@ class Tensor: return counts0.cat(counts1) @staticmethod - def rand(*shape, device:Optional[str]=None, dtype:Optional[DTypeLike]=None, **kwargs) -> Tensor: + def rand(*shape, device:Optional[str]=None, dtype:Optional[DTypeLike]=None, contiguous:bool=True, **kwargs) -> Tensor: """ Creates a tensor with the given shape, filled with random values from a uniform distribution over the interval `[0, 1)`. @@ -509,7 +509,7 @@ class Tensor: if getenv("MOCKGPU") and _device: out = out.to(_device) out.requires_grad = kwargs.get("requires_grad") - return out.contiguous() + return out.contiguous() if contiguous else out # ***** creation helper functions ***** @@ -673,13 +673,15 @@ class Tensor: """ dtype = kwargs.pop("dtype", self.dtype) device = kwargs.pop("device", self.device) + contiguous = kwargs.pop("contiguous", True) if isinstance(self.device, tuple): assert isinstance(self.lazydata, MultiLazyBuffer) if self.lazydata.axis is not None: - rands = [cast(LazyBuffer, Tensor.rand(*lb.shape, device=lb.device, dtype=dtype).lazydata) for lb in self.lazydata.lbs] + rands = [cast(LazyBuffer, Tensor.rand(*lb.shape, device=lb.device, dtype=dtype, contiguous=contiguous, **kwargs).lazydata) \ + for lb in self.lazydata.lbs] return Tensor(MultiLazyBuffer(rands, self.lazydata.axis), device=self.device, dtype=dtype, **kwargs) - return Tensor.rand(*self.shape, dtype=dtype, **kwargs).shard(self.device) - return Tensor.rand(*self.shape, device=device, dtype=dtype, **kwargs) + return Tensor.rand(*self.shape, dtype=dtype, contiguous=contiguous, **kwargs).shard(self.device) + return Tensor.rand(*self.shape, device=device, dtype=dtype, contiguous=contiguous, **kwargs) # ***** rng hlops ***** @@ -3131,7 +3133,7 @@ class Tensor: ``` """ if not Tensor.training or p == 0: return self - return (Tensor.rand_like(self, requires_grad=False, dtype=dtypes.default_float) >= p).where(self, 0) * (1/(1.0 - p)) + return (Tensor.rand_like(self, requires_grad=False, dtype=dtypes.default_float, contiguous=False) >= p).contiguous().where(self, 0) / (1.0 - p) def one_hot(self, num_classes:int=-1) -> Tensor: """