no more toCPU path (#1624)

This commit is contained in:
George Hotz 2023-08-22 11:07:26 -07:00 committed by GitHub
parent 463dece63e
commit de1fcc418f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 5 additions and 4 deletions

View File

@@ -184,8 +184,7 @@ class LazyBuffer:
# NOTE: dtypes.from_np(self.dtype.np) to deal with image types
return self.loadop(LoadOps.CONST, tuple(), dtypes.from_np(self.dtype.np), self.device, arg=val).reshape((1,)*len(self.shape)).expand(self.shape)
# NOTE: we also have to copy the numpy array on the way out; otherwise the underlying Tensor could be freed, leading to a use-after-free. improve this?
def toCPU(self):
def toCPU(self) -> np.ndarray:
  """Realize this buffer and return its contents as a NumPy array.

  The buffer is cast to the dtype derived from its own NumPy type, made
  contiguous and realized; the realized buffer's own ``toCPU()`` result is
  then reshaped to ``self.shape``.

  Raises:
    AssertionError: if ``self.dtype.np`` is falsy, i.e. the dtype has no
      NumPy counterpart.
  """
  assert self.dtype.np, f"{self.dtype} is not supported in toCPU"
  # Round-trip through from_np(self.dtype.np) rather than using self.dtype
  # directly (the neighbouring CONST code notes this is to deal with image types).
  np_backed_dtype = dtypes.from_np(self.dtype.np)
  contiguous_buf = self.cast((np_backed_dtype, False)).contiguous()
  raw = cast(RawBuffer, contiguous_buf.realize().realized)
  return raw.toCPU().reshape(self.shape)
@@ -371,6 +370,8 @@ def _realize_custom(buffer: LazyBuffer) -> None:
def _realize_from(buffer: LazyBuffer) -> None:
rawbuf = buffer.op.src[0].realize()
assert rawbuf.realized, "realize failed?"
if DEBUG >= 3: print(f"*** copy {buffer.device} <- {rawbuf.device} size {rawbuf.realized.size} dtype {rawbuf.realized.dtype}")
# TODO: make this generic
if isinstance(rawbuf.realized, RawDiskBuffer) and issubclass(Device[buffer.device].buffer, RawBufferMapped):
buffer.realized = Device[buffer.device].buffer(prod(buffer.shape), buffer.dtype, **buffer._device_extra_args())

View File

@@ -109,7 +109,7 @@ class Tensor:
return self
def detach(self):
  """Return a new Tensor over the same lazydata and device with requires_grad=False."""
  return Tensor(
    self.lazydata,
    device=self.device,
    requires_grad=False,
  )
def numpy(self) -> np.ndarray: return self.lazydata.toCPU()
def numpy(self) -> np.ndarray:
  """Return this tensor's data as a NumPy array.

  The tensor is first routed through ``self.to('CPU')``; the resulting
  tensor's lazydata is then asked for its ``toCPU()`` array.
  """
  cpu_tensor = self.to('CPU')
  return cpu_tensor.lazydata.toCPU()
# TODO: if things are realized this won't work
def to_(self, device:str):
@@ -117,7 +117,7 @@ class Tensor:
self.lazydata.device = device
if self.grad: self.grad.to_(device)
def to(self, device:str):
def to(self, device:str) -> Tensor:
  """Return a new Tensor built from this tensor's lazydata for ``device``.

  When ``self.grad`` is truthy, the gradient is converted with its own
  ``to(device)`` and attached to the returned tensor. ``self`` is not
  modified.
  """
  moved = Tensor(self.lazydata, device)
  if self.grad:
    moved.grad = self.grad.to(device)
  return moved