mirror of https://github.com/commaai/tinygrad.git
improve readability (#4809)
This commit is contained in:
parent eecfdd2f6e
commit 2dae657415
@@ -104,13 +104,14 @@ class Tensor:
    # None (the default) will be updated to True if it's put in an optimizer
    self.requires_grad: Optional[bool] = requires_grad

    # internal variables used for autograd graph construction
    # internal variable used for autograd graph construction
    self._ctx: Optional[Function] = None

    # create a LazyBuffer from the different types of inputs
    if isinstance(data, LazyBuffer): assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported"
    elif isinstance(data, get_args(ConstType)): data = _loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_py(data), device, data)
    elif isinstance(data, Variable): data = _loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_py(data.unbind()[1]), device, data)
    elif isinstance(data, bytes): data = _fromcpu(np.frombuffer(data, np.uint8))
    elif data is None: data = _loadop(LoadOps.EMPTY, (0,), dtype or dtypes.default_float, device)
    elif isinstance(data, list):
      if dtype is None:
        if (d := fully_flatten(data)) and all(isinstance(s, bool) for s in d): dtype = dtypes.bool
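As context for the constructor dispatch above (commentary, not part of the commit): a minimal usage sketch of the input types it handles, assuming a stock tinygrad install with default dtypes.

import numpy as np
from tinygrad import Tensor, dtypes

t_scalar = Tensor(3.0)                 # ConstType path -> CONST loadop
t_bools  = Tensor([[True, False]])     # list path: all bools -> dtypes.bool
t_bytes  = Tensor(b"\x01\x02\x03")     # bytes path: np.frombuffer as uint8
t_empty  = Tensor(None)                # None path -> EMPTY loadop, shape (0,)
t_numpy  = Tensor(np.arange(6, dtype=np.float32).reshape(2, 3))  # np.ndarray path

assert t_bools.dtype == dtypes.bool and t_empty.shape == (0,)
print(t_scalar.item(), t_bytes.numpy(), t_numpy.shape)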
@@ -120,16 +121,25 @@ class Tensor:
    elif isinstance(data, np.ndarray):
      if data.shape == (): data = _loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_np(data.dtype), device, data.item())
      else: data = _fromcpu(data.astype(dtype.np) if dtype is not None and dtype.np is not None else data)
    elif data is None: data = _loadop(LoadOps.EMPTY, (0,), dtype or dtypes.default_float, device)

    # by this point, it has to be a LazyBuffer
    if not isinstance(data, (LazyBuffer, MultiLazyBuffer)):
      raise RuntimeError(f"can't create Tensor from {data!r} with type {type(data)}")

    # data is a LazyBuffer, but it might be on the wrong device
    if not isinstance(data, (LazyBuffer, MultiLazyBuffer)): raise RuntimeError(f"can't create Tensor from {data!r} with type {type(data)}")
    if isinstance(device, tuple):
      # TODO: what if it's a MultiLazyBuffer on other devices?
      self.lazydata: Union[LazyBuffer, MultiLazyBuffer] = MultiLazyBuffer.from_sharded(data, device, None) if isinstance(data, LazyBuffer) else data
      # if device is a tuple, we should have/construct a MultiLazyBuffer
      if isinstance(data, MultiLazyBuffer):
        assert data.device == device, f"MultiLazyBuffer device mismatch, {data.device} != {device}"
        self.lazydata: Union[LazyBuffer, MultiLazyBuffer] = data
      else:
        self.lazydata = MultiLazyBuffer.from_sharded(data, device, None)
    else:
      self.lazydata = data if data.device == device else data.copy_to_device(device)

  def __repr__(self): return f"<Tensor {self.lazydata!r} on {self.device} with grad {(self.grad.lazydata if self.grad is not None else None)!r}>"
  def __repr__(self):
    return f"<Tensor {self.lazydata!r} on {self.device} with grad {(self.grad.lazydata if self.grad is not None else None)!r}>"

  # Python has a non moving GC, so this should be okay
  def __hash__(self): return id(self)
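Again commentary rather than part of the diff: the device-tuple branch above is where a sharded tensor lands. A hedged sketch, assuming a tinygrad version that has the CLANG backend and the Tensor.shard API.

from tinygrad import Tensor
from tinygrad.multi import MultiLazyBuffer

t = Tensor.arange(8, device="CLANG")
# shard() distributes the LazyBuffer over the device tuple; axis=0 splits rows,
# axis=None would replicate (both go through MultiLazyBuffer.from_sharded).
sharded = t.shard(("CLANG", "CLANG"), axis=0)
assert isinstance(sharded.lazydata, MultiLazyBuffer) and sharded.device == ("CLANG", "CLANG")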
@@ -196,6 +206,7 @@ class Tensor:
    if not self.lazydata.is_realized(): return self.replace(x)
    self.lazydata = self.lazydata.assign(x.lazydata)
    return self

  def detach(self) -> Tensor:
    """
    Returns a new tensor with the same data as this tensor, but detached from the autograd graph.
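For reference (not part of the commit), a hedged sketch of how the assign path above and detach behave in user code; it assumes the default float dtype on both sides so assign's shape/dtype checks pass.

from tinygrad import Tensor

w = Tensor.ones(4).contiguous().realize()
w.assign(w - 0.1)        # realized target -> self.lazydata = self.lazydata.assign(...)
w2 = Tensor.ones(4)      # unrealized target -> assign falls back to replace(x)
w2.assign(Tensor.zeros(4))
frozen = w.detach()      # same data, requires_grad=False, cut from the autograd graph
print(w.numpy(), w2.numpy(), frozen.requires_grad)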