dropout contiguous after >= p (#6892)

make it a bool buffer
This commit is contained in:
chenyu 2024-10-06 19:40:42 -04:00 committed by GitHub
parent 9eb6eef441
commit 999e3780e9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 8 additions and 6 deletions

View File

@ -454,7 +454,7 @@ class Tensor:
return counts0.cat(counts1) return counts0.cat(counts1)
@staticmethod @staticmethod
def rand(*shape, device:Optional[str]=None, dtype:Optional[DTypeLike]=None, **kwargs) -> Tensor: def rand(*shape, device:Optional[str]=None, dtype:Optional[DTypeLike]=None, contiguous:bool=True, **kwargs) -> Tensor:
""" """
Creates a tensor with the given shape, filled with random values from a uniform distribution over the interval `[0, 1)`. Creates a tensor with the given shape, filled with random values from a uniform distribution over the interval `[0, 1)`.
@ -509,7 +509,7 @@ class Tensor:
if getenv("MOCKGPU") and _device: out = out.to(_device) if getenv("MOCKGPU") and _device: out = out.to(_device)
out.requires_grad = kwargs.get("requires_grad") out.requires_grad = kwargs.get("requires_grad")
return out.contiguous() return out.contiguous() if contiguous else out
# ***** creation helper functions ***** # ***** creation helper functions *****
@ -673,13 +673,15 @@ class Tensor:
""" """
dtype = kwargs.pop("dtype", self.dtype) dtype = kwargs.pop("dtype", self.dtype)
device = kwargs.pop("device", self.device) device = kwargs.pop("device", self.device)
contiguous = kwargs.pop("contiguous", True)
if isinstance(self.device, tuple): if isinstance(self.device, tuple):
assert isinstance(self.lazydata, MultiLazyBuffer) assert isinstance(self.lazydata, MultiLazyBuffer)
if self.lazydata.axis is not None: if self.lazydata.axis is not None:
rands = [cast(LazyBuffer, Tensor.rand(*lb.shape, device=lb.device, dtype=dtype).lazydata) for lb in self.lazydata.lbs] rands = [cast(LazyBuffer, Tensor.rand(*lb.shape, device=lb.device, dtype=dtype, contiguous=contiguous, **kwargs).lazydata) \
for lb in self.lazydata.lbs]
return Tensor(MultiLazyBuffer(rands, self.lazydata.axis), device=self.device, dtype=dtype, **kwargs) return Tensor(MultiLazyBuffer(rands, self.lazydata.axis), device=self.device, dtype=dtype, **kwargs)
return Tensor.rand(*self.shape, dtype=dtype, **kwargs).shard(self.device) return Tensor.rand(*self.shape, dtype=dtype, contiguous=contiguous, **kwargs).shard(self.device)
return Tensor.rand(*self.shape, device=device, dtype=dtype, **kwargs) return Tensor.rand(*self.shape, device=device, dtype=dtype, contiguous=contiguous, **kwargs)
# ***** rng hlops ***** # ***** rng hlops *****
@ -3131,7 +3133,7 @@ class Tensor:
``` ```
""" """
if not Tensor.training or p == 0: return self if not Tensor.training or p == 0: return self
return (Tensor.rand_like(self, requires_grad=False, dtype=dtypes.default_float) >= p).where(self, 0) * (1/(1.0 - p)) return (Tensor.rand_like(self, requires_grad=False, dtype=dtypes.default_float, contiguous=False) >= p).contiguous().where(self, 0) / (1.0 - p)
def one_hot(self, num_classes:int=-1) -> Tensor: def one_hot(self, num_classes:int=-1) -> Tensor:
""" """