mirror of https://github.com/commaai/tinygrad.git
ResNet training changes (update benchmark) (#2390)
* default arg for chunk * bring back to_ * good changes * new set * unused hash * fix optim * new torch loader * fix test lr scheduler
This commit is contained in:
parent
2dec86970a
commit
cbb8486779
|
@ -43,3 +43,4 @@ temp
|
|||
coverage.xml
|
||||
htmlcov
|
||||
outputs_yolov8
|
||||
wandb
|
||||
|
|
|
@ -265,6 +265,14 @@ try:
|
|||
hipMemcpyDeviceToDevice = 3
|
||||
hipMemcpyDefault = 4
|
||||
|
||||
_libhip.hipHostMalloc.restype = int
|
||||
_libhip.hipHostMalloc.argtypes = [ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t, ctypes.c_uint32]
|
||||
def hipHostMalloc(count, flags=0):
|
||||
ptr = ctypes.c_void_p()
|
||||
status = _libhip.hipHostMalloc(ctypes.byref(ptr), count, flags)
|
||||
hipCheckStatus(status)
|
||||
return ptr.value
|
||||
|
||||
_libhip.hipMemcpy.restype = int
|
||||
_libhip.hipMemcpy.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ctypes.c_int]
|
||||
def hipMemcpy(dst, src, count, direction):
|
||||
|
|
|
@ -6,7 +6,7 @@ from tinygrad.tensor import Tensor
|
|||
class LR_Scheduler:
|
||||
def __init__(self, optimizer: Optimizer):
|
||||
self.optimizer = optimizer
|
||||
self.epoch_counter = Tensor([0], requires_grad=False)
|
||||
self.epoch_counter = Tensor([0], requires_grad=False, device=self.optimizer.device)
|
||||
|
||||
def get_lr(self): pass
|
||||
|
||||
|
@ -61,7 +61,7 @@ class CosineAnnealingLR(LR_Scheduler):
|
|||
self.eta_max = optimizer.lr.numpy()[0]
|
||||
|
||||
def get_lr(self) -> Tensor:
|
||||
return Tensor([self.eta_min + 0.5 * (self.eta_max - self.eta_min) * (1 + math.cos((self.epoch_counter.numpy()[0]/self.T_max) * math.pi))])
|
||||
return Tensor([self.eta_min + 0.5 * (self.eta_max - self.eta_min) * (1 + math.cos((self.epoch_counter.numpy()[0]/self.T_max) * math.pi))], device=self.optimizer.device)
|
||||
|
||||
class OneCycleLR(LR_Scheduler):
|
||||
def __init__(self, optimizer: Optimizer, max_lr: float, div_factor: float, final_div_factor: float, total_steps: int, pct_start: float,
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
import tinygrad.nn as nn
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.nn.state import torch_load
|
||||
from tinygrad.helpers import fetch
|
||||
from extra.utils import get_child
|
||||
|
||||
class BasicBlock:
|
||||
|
@ -131,11 +133,8 @@ class ResNet:
|
|||
}
|
||||
|
||||
self.url = model_urls[(self.num, self.groups, self.base_width)]
|
||||
|
||||
from torch.hub import load_state_dict_from_url
|
||||
state_dict = load_state_dict_from_url(self.url, progress=True)
|
||||
for k, v in state_dict.items():
|
||||
obj = get_child(self, k)
|
||||
for k, v in torch_load(fetch(self.url)).items():
|
||||
obj: Tensor = get_child(self, k)
|
||||
dat = v.detach().numpy()
|
||||
|
||||
if 'fc.' in k and obj.shape != dat.shape:
|
||||
|
|
|
@ -57,7 +57,9 @@ def get_lrs(optim, sched, epochs, steps=1, accs=None):
|
|||
class TestLrScheduler(unittest.TestCase):
|
||||
def _test_lr_scheduler(self, tinygrad_sched, torch_sched, epochs, opts, atol, rtol):
|
||||
accs = opts.pop('accs', None)
|
||||
tinygrad_optim, torch_optim = Adam([], lr=0.01), torch.optim.Adam([torch.tensor([0.], requires_grad=True)], lr=0.01)
|
||||
test_tensor = Tensor([0], requires_grad=True) # NOTE: optimizers are broken on 0-dim tensors because it broadcasts to [lr]
|
||||
test_tensor.mean().backward()
|
||||
tinygrad_optim, torch_optim = Adam([test_tensor], lr=0.01), torch.optim.Adam([torch.tensor([0.], requires_grad=True)], lr=0.01)
|
||||
tinygrad_sched, torch_sched = tinygrad_sched(tinygrad_optim, **opts), torch_sched(torch_optim, **opts)
|
||||
|
||||
tinygrad_lrs = get_lrs(tinygrad_optim, tinygrad_sched, epochs, accs=accs)
|
||||
|
|
|
@ -0,0 +1,58 @@
|
|||
import unittest
|
||||
from tinygrad import Tensor
|
||||
from tinygrad.ops import Device
|
||||
from tinygrad.helpers import Timing, CI
|
||||
|
||||
N = 4096 if CI else 16384
|
||||
class TestCopySpeed(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls): Device[Device.DEFAULT].synchronize()
|
||||
|
||||
def testCopySHMtoDefault(self):
|
||||
t = Tensor.empty(N, N, device="disk:/dev/shm/test_X").realize()
|
||||
#t = Tensor.empty(N, N, device="disk:shm:test_X").realize()
|
||||
for _ in range(3):
|
||||
with Timing("sync: ", on_exit=lambda ns: f" @ {t.nbytes()/ns:.2f} GB/s"):
|
||||
with Timing("queue: "):
|
||||
t.to(Device.DEFAULT).realize()
|
||||
Device[Device.DEFAULT].synchronize()
|
||||
|
||||
def testCopyCPUtoDefault(self):
|
||||
t = Tensor.rand(N, N, device="cpu").realize()
|
||||
print(f"buffer: {t.nbytes()*1e-9:.2f} GB")
|
||||
for _ in range(3):
|
||||
with Timing("sync: ", on_exit=lambda ns: f" @ {t.nbytes()/ns:.2f} GB/s"):
|
||||
with Timing("queue: "):
|
||||
t.to(Device.DEFAULT).realize()
|
||||
Device[Device.DEFAULT].synchronize()
|
||||
|
||||
def testCopyCPUtoDefaultFresh(self):
|
||||
print("fresh copy")
|
||||
for _ in range(3):
|
||||
t = Tensor.rand(N, N, device="cpu").realize()
|
||||
with Timing("sync: ", on_exit=lambda ns: f" @ {t.nbytes()/ns:.2f} GB/s"):
|
||||
with Timing("queue: "):
|
||||
t.to(Device.DEFAULT).realize()
|
||||
Device[Device.DEFAULT].synchronize()
|
||||
del t
|
||||
|
||||
def testCopyDefaulttoCPU(self):
|
||||
t = Tensor.rand(N, N).realize()
|
||||
print(f"buffer: {t.nbytes()*1e-9:.2f} GB")
|
||||
for _ in range(3):
|
||||
with Timing("sync: ", on_exit=lambda ns: f" @ {t.nbytes()/ns:.2f} GB/s"):
|
||||
t.to('cpu').realize()
|
||||
|
||||
@unittest.skipIf(CI, "CI doesn't have 6 GPUs")
|
||||
def testCopyCPUto6GPUs(self):
|
||||
t = Tensor.rand(N, N, device="cpu").realize()
|
||||
print(f"buffer: {t.nbytes()*1e-9:.2f} GB")
|
||||
for _ in range(3):
|
||||
with Timing("sync: ", on_exit=lambda ns: f" @ {t.nbytes()/ns:.2f} GB/s ({t.nbytes()*6/ns:.2f} GB/s total)"):
|
||||
with Timing("queue: "):
|
||||
for g in range(6):
|
||||
t.to(f"gpu:{g}").realize()
|
||||
Device[f"gpu"].synchronize()
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
|
@ -36,6 +36,9 @@ def partition(lst:List[T], fxn:Callable[[T],bool]):
|
|||
b:List[T] = []
|
||||
for s in lst: (a if fxn(s) else b).append(s)
|
||||
return a,b
|
||||
def unwrap(x:Optional[T]) -> T:
|
||||
assert x is not None
|
||||
return x
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def getenv(key, default=0): return type(default)(os.getenv(key, default))
|
||||
|
|
|
@ -10,8 +10,10 @@ class Optimizer:
|
|||
if x.requires_grad is None: x.requires_grad = True
|
||||
|
||||
self.params: List[Tensor] = dedup([x for x in params if x.requires_grad])
|
||||
assert len(self.params) != 0, "optimizer must have at least one param"
|
||||
self.device = self.params[0].device
|
||||
self.buffers: List[Tensor] = dedup([x for x in params if not x.requires_grad]) # buffers are still realized
|
||||
self.lr = Tensor([lr], requires_grad=False).contiguous()
|
||||
self.lr = Tensor([lr], requires_grad=False, device=self.device).contiguous()
|
||||
|
||||
def zero_grad(self):
|
||||
for param in self.params: param.grad = None
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
import os, json, pathlib, zipfile, pickle
|
||||
import os, json, pathlib, zipfile, pickle, tarfile, struct
|
||||
from tqdm import tqdm
|
||||
from typing import Dict, Union, List, Optional, Any, Tuple
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.helpers import dtypes, prod, argsort, DEBUG, Timing, GlobalCounters, CI
|
||||
from tinygrad.helpers import dtypes, prod, argsort, DEBUG, Timing, GlobalCounters, CI, unwrap
|
||||
from tinygrad.shape.view import strides_for_shape
|
||||
from tinygrad.ops import Device
|
||||
|
||||
|
@ -64,9 +64,9 @@ def load_state_dict(model, state_dict, strict=True, verbose=True):
|
|||
def torch_load(fn:str):
|
||||
t = Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}")
|
||||
|
||||
offsets: Dict[str, int] = {}
|
||||
lens: Dict[str, int] = {}
|
||||
def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None):
|
||||
offsets: Dict[Union[str, int], int] = {}
|
||||
lens: Dict[Union[str, int], int] = {}
|
||||
def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad=None, backward_hooks=None, metadata=None):
|
||||
#print(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata)
|
||||
lens[storage[2]] = storage[4] * storage[1].itemsize
|
||||
if storage[2] not in offsets: return None
|
||||
|
@ -92,7 +92,12 @@ def torch_load(fn:str):
|
|||
|
||||
return ret.reshape(size)
|
||||
|
||||
intercept = {"HalfStorage": dtypes.float16, "FloatStorage": dtypes.float32, "BFloat16Storage": dtypes.bfloat16, "IntStorage": dtypes.int32, "LongStorage": dtypes.int64, "_rebuild_tensor_v2": _rebuild_tensor_v2}
|
||||
class Parameter:
|
||||
def __setstate__(self, state): self.tensor = state[0]
|
||||
|
||||
deserialized_objects: Dict[str, Any] = {}
|
||||
intercept = {"HalfStorage": dtypes.float16, "FloatStorage": dtypes.float32, "BFloat16Storage": dtypes.bfloat16, "IntStorage": dtypes.int32, "LongStorage": dtypes.int64,
|
||||
"_rebuild_tensor_v2": _rebuild_tensor_v2, "FloatTensor": None, "Parameter": Parameter}
|
||||
whitelist = {"torch", "collections", "numpy", "_codecs"} # NOTE: this is not for security, only speed
|
||||
class Dummy: pass
|
||||
class TorchPickle(pickle.Unpickler):
|
||||
|
@ -102,7 +107,7 @@ def torch_load(fn:str):
|
|||
if DEBUG >= 2: print(f"WARNING: returning Dummy for {module} {name}")
|
||||
return Dummy
|
||||
return intercept[name] if module_root == "torch" else super().find_class(module, name)
|
||||
def persistent_load(self, pid): return pid
|
||||
def persistent_load(self, pid): return deserialized_objects[pid] if pid in deserialized_objects else pid
|
||||
|
||||
if tuple(t[0:2].numpy()) == (0x50, 0x4b):
|
||||
myzip = zipfile.ZipFile(fn, 'r')
|
||||
|
@ -113,6 +118,20 @@ def torch_load(fn:str):
|
|||
offsets[n.split("/")[-1]] = myfile._orig_compress_start # type: ignore
|
||||
with myzip.open(f'{base_name}/data.pkl') as myfile:
|
||||
return TorchPickle(myfile).load()
|
||||
elif bytes(t[0:0xe].numpy()) == b"././@PaxHeader": # TODO: is this how you detect a tarfile?
|
||||
with tarfile.open(fn, "r") as tar:
|
||||
storages_offset = tar.getmember('storages').offset_data
|
||||
f = unwrap(tar.extractfile('storages'))
|
||||
for i in range(TorchPickle(f).load()): # num_storages
|
||||
(key, _, storage_type), sz = TorchPickle(f).load(), struct.unpack('<q', f.read(8))[0]
|
||||
offsets[key] = storages_offset + f.tell()
|
||||
f.seek(sz*storage_type.itemsize, 1)
|
||||
f = unwrap(tar.extractfile('tensors'))
|
||||
for _ in range(TorchPickle(f).load()): # num_tensors
|
||||
(key, storage_id, _), ndim, _ = TorchPickle(f).load(), struct.unpack('<i', f.read(4))[0], f.read(4)
|
||||
size, stride, storage_offset = struct.unpack(f'<{ndim}q', f.read(8 * ndim)), struct.unpack(f'<{ndim}q', f.read(8 * ndim)), struct.unpack('<q', f.read(8))[0]
|
||||
deserialized_objects[str(key)] = _rebuild_tensor_v2((None, storage_type, storage_id, None, -1), storage_offset, size, stride)
|
||||
return {k:v.tensor if isinstance(v, Parameter) else v for k,v in TorchPickle(unwrap(tar.extractfile('pickle'))).load().items()}
|
||||
else:
|
||||
with open(fn, "rb") as f:
|
||||
pkl = TorchPickle(f)
|
||||
|
|
|
@ -336,7 +336,7 @@ class Compiled:
|
|||
# all the rawbuffers
|
||||
rawbuffers = [output.realized] + [x.realized for x in inputs]
|
||||
|
||||
if ast not in self.method_cache: self.method_cache[ast] = get_optimized_program(self.linearizer_opts, self.to_program, ast, rawbuffers)
|
||||
if ast not in self.method_cache or getenv("DISABLE_METHOD_CACHE"): self.method_cache[ast] = get_optimized_program(self.linearizer_opts, self.to_program, ast, rawbuffers)
|
||||
self.method_cache[ast].exec(rawbuffers, var_vals)
|
||||
|
||||
def get_optimized_program(linearizer_opts:LinearizerOptions, to_program, ast:LazyOp, rawbuffers:List[RawBuffer]) -> CompiledASTRunner:
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from typing import List, cast, Dict, Callable
|
||||
import numpy as np
|
||||
from tinygrad.ops import ScheduleItem, LazyOp, LoadOps, Device, BufferOps
|
||||
from tinygrad.graph import log_schedule_item
|
||||
from tinygrad.graph import log_schedule_item, print_tree
|
||||
from tinygrad.lazy import LazyBuffer
|
||||
from tinygrad.helpers import DEBUG, prod, all_int, getenv, IMAGE
|
||||
|
||||
|
@ -23,6 +23,7 @@ def run_schedule(schedule:List[ScheduleItem], disable_logging=False):
|
|||
for i,s in enumerate(si.ast.src): assert isinstance(s, LazyOp) and s.op == BufferOps.MEM and s.arg.idx == i+1 and s.arg.st.contiguous, f"bad LoadOps src {i}: {s}"
|
||||
LOAD_OPS_DISPATCHER[cast(LoadOps, si.ast.op)](si.out, *si.inputs)
|
||||
else:
|
||||
assert all(si.out.device == x.device for x in si.inputs), f"all devices must be the same, {si.out.device} != {[x.device for x in si.inputs]} {print_tree(si.ast) or ''}"
|
||||
Device[si.out.device].exec_ast(si.ast, output=si.out, inputs=si.inputs, var_vals=si.var_vals, **si.out._device_extra_args())
|
||||
del si.out.op
|
||||
for v in si.out.views: del v.op
|
||||
|
|
|
@ -18,8 +18,6 @@ class RawBuffer: # pylint: disable=abstract-method
|
|||
if hasattr(self, '_memsz'): GlobalCounters.mem_used -= self._memsz
|
||||
if hasattr(self, '_allocator') and self._allocator: self._allocator.free(self._buf)
|
||||
def __repr__(self): return f"buffer<{self.size}, {self.dtype}, {id(self)}>"
|
||||
@property
|
||||
def key(self): return (self.size, self.dtype)
|
||||
|
||||
# NOTE: this interface allows for 0 copy
|
||||
@classmethod
|
||||
|
@ -57,7 +55,7 @@ class RawBufferCopyInOut(RawBufferCopyIn):
|
|||
return x
|
||||
|
||||
class RawBufferTransfer(RawBuffer):
|
||||
def _transfer(self, x) -> None: raise NotImplementedError("must be implemented")
|
||||
def _transfer(self, x:RawBuffer) -> None: raise NotImplementedError("must be implemented")
|
||||
|
||||
@classmethod
|
||||
def transfer(cls, x, shape, dtype, **kwargs):
|
||||
|
|
|
@ -7,6 +7,7 @@ from tinygrad.helpers import prod, DType, OSX
|
|||
from tinygrad.runtime.lib import RawBufferMapped
|
||||
from tinygrad.ops import Interpreted, Op, MovementOps, UnaryOps, BufferOps
|
||||
from tinygrad.shape.view import strides_for_shape
|
||||
MAP_LOCKED, MAP_POPULATE = 0x2000, 0x008000
|
||||
|
||||
class RawDiskBuffer(RawBufferMapped):
|
||||
def __init__(self, size, dtype:DType, device:Optional[str]=None, buf=None, shape=None, offset=0): # pylint: disable=super-init-not-called
|
||||
|
@ -22,7 +23,7 @@ class RawDiskBuffer(RawBufferMapped):
|
|||
else:
|
||||
fd = _posixshmem.shm_open(device[4:], os.O_RDWR, 0o600)
|
||||
# TODO: these flags are somewhat platform specific, but python doesn't expose the ones we need
|
||||
shm = mmap.mmap(fd, size * dtype.itemsize, flags=mmap.MAP_SHARED | 0x2000 | 0x008000)
|
||||
shm = mmap.mmap(fd, size * dtype.itemsize, flags=mmap.MAP_SHARED | MAP_LOCKED | MAP_POPULATE)
|
||||
shm.madvise(mmap.MADV_HUGEPAGE) # type: ignore # not on OSX
|
||||
os.close(fd)
|
||||
buf = [None, shm, 1]
|
||||
|
|
|
@ -15,7 +15,6 @@ OSX_TIMING_RATIO = (125/3) if OSX else 1.0 # see test/external/external_osx_pr
|
|||
|
||||
# TODO: if you fork and exit the child process after creating anything with cl on AMD, it hangs on e.wait()
|
||||
ROCM_LLVM_PATH = pathlib.Path("/opt/rocm/llvm/bin")
|
||||
#ROCM_LLVM_PATH = pathlib.Path(__file__).parents[3] / "extra/rocm/build/llvm-project/bin"
|
||||
if DEBUG >= 5:
|
||||
early_exec = fromimport("extra.helpers", "enable_early_exec")()
|
||||
|
||||
|
@ -49,15 +48,17 @@ if not getenv("DELAYED_RUNTIME_INIT", False): CL.post_init()
|
|||
|
||||
class CLBuffer(RawBufferCopyInOut, RawBufferTransfer):
|
||||
def __init__(self, size, dtype, device='0'): super().__init__(size, dtype, allocator=CL.cl_allocator, **{'device': device})
|
||||
def _clear_event(self, _): del self.event
|
||||
def _copyin(self, x:np.ndarray):
|
||||
assert not self.dtype.name.startswith("image"), f"can't copyin images {self.dtype}"
|
||||
self.event = cl.enqueue_copy(CL.cl_queue[self._buf.device], self._buf, np.require(x, requirements=['C', 'A']), is_blocking=False)
|
||||
self.event.set_callback(cl.command_execution_status.COMPLETE, self._clear_event)
|
||||
def _copyout(self, x:np.ndarray):
|
||||
assert not self.dtype.name.startswith("image"), f"can't copyout images {self.dtype}"
|
||||
CL.cl_allocator.ensure_has_free_space(self.size*self.dtype.itemsize, self._device)
|
||||
buf = cl.Buffer(CL.cl_ctxs[self._buf.device], cl.mem_flags.WRITE_ONLY | cl.mem_flags.USE_HOST_PTR, 0, hostbuf=x.data)
|
||||
mapped, event = cl.enqueue_map_buffer(CL.cl_queue[self._buf.device], buf, cl.map_flags.WRITE, 0, self.size, dtype=self.dtype.np, is_blocking=False)
|
||||
with mapped.base: cl.enqueue_copy(CL.cl_queue[self._buf.device], mapped, self._buf, is_blocking=True, wait_for=[event] + ([self.event] if hasattr(self, "event") else []))
|
||||
with mapped.base: cl.enqueue_copy(CL.cl_queue[self._buf.device], mapped, self._buf, is_blocking=True, wait_for=[event] + ([evt] if (evt:=getattr(self, "event", None)) else []))
|
||||
def _transfer(self, x):
|
||||
if "gfx" in CL.cl_ctxs[x._buf.device].devices[0].name:
|
||||
cl.enqueue_copy_buffer_p2p_amd(CL.cl_platform, CL.cl_queue[x._buf.device], x._buf, self._buf, x.size * x.dtype.itemsize).wait()
|
||||
|
@ -96,7 +97,7 @@ class CLProgram:
|
|||
for x in bufs:
|
||||
if x.__class__ is CLBuffer:
|
||||
cl_bufs.append(x._buf)
|
||||
if hasattr(x, "event"): wait_for.append(x.event)
|
||||
if (event:=getattr(x, "event",None)): wait_for.append(event)
|
||||
else: cl_bufs.append(x)
|
||||
e = self.clprgs[cl_bufs[0].device](CL.cl_queue[cl_bufs[0].device], [int(g*l) for g,l in zip(global_size, local_size)] if local_size is not None else global_size, local_size, *cl_bufs, wait_for=wait_for)
|
||||
if wait:
|
||||
|
|
|
@ -4,7 +4,7 @@ import extra.hip_wrapper as hip
|
|||
from typing import Tuple
|
||||
from tinygrad.helpers import DEBUG, getenv, diskcache
|
||||
from tinygrad.ops import Compiled
|
||||
from tinygrad.runtime.lib import RawBufferCopyInOut, LRUAllocator, RawBufferTransfer
|
||||
from tinygrad.runtime.lib import RawBuffer, RawBufferCopyInOut, LRUAllocator, RawBufferTransfer
|
||||
from tinygrad.codegen.kernel import LinearizerOptions
|
||||
from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage
|
||||
|
||||
|
@ -38,7 +38,7 @@ class RawHIPBuffer(RawBufferCopyInOut, RawBufferTransfer):
|
|||
def _copyout(self, x:np.ndarray):
|
||||
hip.hipSetDevice(self._device)
|
||||
hip.hipMemcpy(x.ctypes.data, self._buf, self.size * self.dtype.itemsize, hip.hipMemcpyDeviceToHost)
|
||||
def _transfer(self, x):
|
||||
def _transfer(self, x:RawBuffer):
|
||||
hip.hipSetDevice(x._device)
|
||||
hip.hipMemcpy(self._buf, x._buf, self.size * self.dtype.itemsize, hip.hipMemcpyDeviceToDevice)
|
||||
|
||||
|
|
|
@ -46,7 +46,7 @@ class Tensor:
|
|||
|
||||
no_grad: ClassVar[bool] = False
|
||||
default_type: ClassVar[DType] = dtypes.float32
|
||||
def __init__(self, data:Union[None, int, float, list, LazyBuffer, np.ndarray], device:Optional[str]=None, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None):
|
||||
def __init__(self, data:Union[None, int, float, list, LazyBuffer, np.ndarray, bytes], device:Optional[str]=None, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None):
|
||||
assert dtype is None or isinstance(dtype, DType), f"invalid dtype {dtype}"
|
||||
device = Device.canonicalize(device)
|
||||
# tensors have gradients, buffers do not
|
||||
|
@ -64,6 +64,8 @@ class Tensor:
|
|||
elif data is None or data.__class__ is list:
|
||||
assert dtype is None or dtype.np is not None, f"{dtype} doesn't have a numpy dtype"
|
||||
data = LazyBuffer.fromCPU(np.array([] if data is None else data, dtype=(dtype or Tensor.default_type).np))
|
||||
elif isinstance(data, bytes):
|
||||
data = LazyBuffer.fromCPU(np.frombuffer(data, np.uint8))
|
||||
elif isinstance(data, np.ndarray):
|
||||
assert dtype is None or dtype.np is not None, f"{dtype} doesn't have a numpy dtype"
|
||||
if data.shape == ():
|
||||
|
@ -124,11 +126,18 @@ class Tensor:
|
|||
return self.detach().cast(dtypes.from_np(self.dtype.np)).contiguous().to('CPU').realize().lazydata.realized.toCPU().reshape(self.shape)
|
||||
def item(self) -> Union[float, int]: return self.numpy().item()
|
||||
|
||||
def to(self, device:str) -> Tensor:
|
||||
def to(self, device:Optional[str]) -> Tensor:
|
||||
if device is None or device == self.device: return self
|
||||
ret = Tensor(self.lazydata, device)
|
||||
if self.grad: ret.grad = self.grad.to(device)
|
||||
return ret
|
||||
|
||||
def to_(self, device:Optional[str]):
|
||||
if device is None or device == self.device: return
|
||||
if self.grad: self.grad = self.grad.to_(device)
|
||||
_ret = Tensor(self.lazydata, device)
|
||||
self.lazydata = _ret.lazydata
|
||||
|
||||
# ***** creation llop entrypoint *****
|
||||
|
||||
@staticmethod
|
||||
|
@ -398,7 +407,7 @@ class Tensor:
|
|||
final_shape = [r*s for r,s in zip(repeats, base_shape)]
|
||||
return self.reshape(new_shape).expand(expand_shape).reshape(final_shape)
|
||||
|
||||
def chunk(self, num:int, dim:int) -> List[Tensor]:
|
||||
def chunk(self, num:int, dim:int=0) -> List[Tensor]:
|
||||
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
|
||||
dim, step = dim + self.ndim if dim < 0 else dim, math.ceil(self.shape[dim]/num)
|
||||
slice_params = [[slice(None)]*dim + [slice(k, k + step)] for k in range(0, self.shape[dim], step)]
|
||||
|
|
Loading…
Reference in New Issue