improve readability (#4809)

This commit is contained in:
George Hotz 2024-06-03 14:57:57 +02:00 committed by GitHub
parent eecfdd2f6e
commit 2dae657415
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 17 additions and 6 deletions

View File

@ -104,13 +104,14 @@ class Tensor:
# None (the default) will be updated to True if it's put in an optimizer
self.requires_grad: Optional[bool] = requires_grad
# internal variables used for autograd graph construction
# internal variable used for autograd graph construction
self._ctx: Optional[Function] = None
# create a LazyBuffer from the different types of inputs
if isinstance(data, LazyBuffer): assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported"
elif isinstance(data, get_args(ConstType)): data = _loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_py(data), device, data)
elif isinstance(data, Variable): data = _loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_py(data.unbind()[1]), device, data)
elif isinstance(data, bytes): data = _fromcpu(np.frombuffer(data, np.uint8))
elif data is None: data = _loadop(LoadOps.EMPTY, (0,), dtype or dtypes.default_float, device)
elif isinstance(data, list):
if dtype is None:
if (d := fully_flatten(data)) and all(isinstance(s, bool) for s in d): dtype = dtypes.bool
@ -120,16 +121,25 @@ class Tensor:
elif isinstance(data, np.ndarray):
if data.shape == (): data = _loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_np(data.dtype), device, data.item())
else: data = _fromcpu(data.astype(dtype.np) if dtype is not None and dtype.np is not None else data)
elif data is None: data = _loadop(LoadOps.EMPTY, (0,), dtype or dtypes.default_float, device)
# by this point, it has to be a LazyBuffer
if not isinstance(data, (LazyBuffer, MultiLazyBuffer)):
raise RuntimeError(f"can't create Tensor from {data!r} with type {type(data)}")
# data is a LazyBuffer, but it might be on the wrong device
if not isinstance(data, (LazyBuffer, MultiLazyBuffer)): raise RuntimeError(f"can't create Tensor from {data!r} with type {type(data)}")
if isinstance(device, tuple):
# TODO: what if it's a MultiLazyBuffer on other devices?
self.lazydata: Union[LazyBuffer, MultiLazyBuffer] = MultiLazyBuffer.from_sharded(data, device, None) if isinstance(data, LazyBuffer) else data
# if device is a tuple, we should have/construct a MultiLazyBuffer
if isinstance(data, MultiLazyBuffer):
assert data.device == device, f"MultiLazyBuffer device mismatch, {data.device} != {device}"
self.lazydata: Union[LazyBuffer, MultiLazyBuffer] = data
else:
self.lazydata = MultiLazyBuffer.from_sharded(data, device, None)
else:
self.lazydata = data if data.device == device else data.copy_to_device(device)
def __repr__(self): return f"<Tensor {self.lazydata!r} on {self.device} with grad {(self.grad.lazydata if self.grad is not None else None)!r}>"
def __repr__(self):
return f"<Tensor {self.lazydata!r} on {self.device} with grad {(self.grad.lazydata if self.grad is not None else None)!r}>"
# Python has a non moving GC, so this should be okay
def __hash__(self): return id(self)
@ -196,6 +206,7 @@ class Tensor:
if not self.lazydata.is_realized(): return self.replace(x)
self.lazydata = self.lazydata.assign(x.lazydata)
return self
def detach(self) -> Tensor:
"""
Returns a new tensor with the same data as this tensor, but detached from the autograd graph.