diff --git a/.gitignore b/.gitignore index 163d6977..eca5fc91 100644 --- a/.gitignore +++ b/.gitignore @@ -43,3 +43,4 @@ temp coverage.xml htmlcov outputs_yolov8 +wandb diff --git a/extra/hip_wrapper.py b/extra/hip_wrapper.py index 19f5e571..6f95454d 100644 --- a/extra/hip_wrapper.py +++ b/extra/hip_wrapper.py @@ -265,6 +265,14 @@ try: hipMemcpyDeviceToDevice = 3 hipMemcpyDefault = 4 + _libhip.hipHostMalloc.restype = int + _libhip.hipHostMalloc.argtypes = [ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t, ctypes.c_uint32] + def hipHostMalloc(count, flags=0): + ptr = ctypes.c_void_p() + status = _libhip.hipHostMalloc(ctypes.byref(ptr), count, flags) + hipCheckStatus(status) + return ptr.value + _libhip.hipMemcpy.restype = int _libhip.hipMemcpy.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ctypes.c_int] def hipMemcpy(dst, src, count, direction): diff --git a/extra/lr_scheduler.py b/extra/lr_scheduler.py index 25badc17..a1bb9316 100644 --- a/extra/lr_scheduler.py +++ b/extra/lr_scheduler.py @@ -6,7 +6,7 @@ from tinygrad.tensor import Tensor class LR_Scheduler: def __init__(self, optimizer: Optimizer): self.optimizer = optimizer - self.epoch_counter = Tensor([0], requires_grad=False) + self.epoch_counter = Tensor([0], requires_grad=False, device=self.optimizer.device) def get_lr(self): pass @@ -61,7 +61,7 @@ class CosineAnnealingLR(LR_Scheduler): self.eta_max = optimizer.lr.numpy()[0] def get_lr(self) -> Tensor: - return Tensor([self.eta_min + 0.5 * (self.eta_max - self.eta_min) * (1 + math.cos((self.epoch_counter.numpy()[0]/self.T_max) * math.pi))]) + return Tensor([self.eta_min + 0.5 * (self.eta_max - self.eta_min) * (1 + math.cos((self.epoch_counter.numpy()[0]/self.T_max) * math.pi))], device=self.optimizer.device) class OneCycleLR(LR_Scheduler): def __init__(self, optimizer: Optimizer, max_lr: float, div_factor: float, final_div_factor: float, total_steps: int, pct_start: float, diff --git a/extra/models/resnet.py b/extra/models/resnet.py index 68d273f3..8f14cbba 100644 --- a/extra/models/resnet.py +++ b/extra/models/resnet.py @@ -1,5 +1,7 @@ import tinygrad.nn as nn from tinygrad.tensor import Tensor +from tinygrad.nn.state import torch_load +from tinygrad.helpers import fetch from extra.utils import get_child class BasicBlock: @@ -131,11 +133,8 @@ class ResNet: } self.url = model_urls[(self.num, self.groups, self.base_width)] - - from torch.hub import load_state_dict_from_url - state_dict = load_state_dict_from_url(self.url, progress=True) - for k, v in state_dict.items(): - obj = get_child(self, k) + for k, v in torch_load(fetch(self.url)).items(): + obj: Tensor = get_child(self, k) dat = v.detach().numpy() if 'fc.' in k and obj.shape != dat.shape: diff --git a/test/extra/test_lr_scheduler.py b/test/extra/test_lr_scheduler.py index 9aa9b863..3bff3660 100644 --- a/test/extra/test_lr_scheduler.py +++ b/test/extra/test_lr_scheduler.py @@ -57,7 +57,9 @@ def get_lrs(optim, sched, epochs, steps=1, accs=None): class TestLrScheduler(unittest.TestCase): def _test_lr_scheduler(self, tinygrad_sched, torch_sched, epochs, opts, atol, rtol): accs = opts.pop('accs', None) - tinygrad_optim, torch_optim = Adam([], lr=0.01), torch.optim.Adam([torch.tensor([0.], requires_grad=True)], lr=0.01) + test_tensor = Tensor([0], requires_grad=True) # NOTE: optimizers are broken on 0-dim tensors because it broadcasts to [lr] + test_tensor.mean().backward() + tinygrad_optim, torch_optim = Adam([test_tensor], lr=0.01), torch.optim.Adam([torch.tensor([0.], requires_grad=True)], lr=0.01) tinygrad_sched, torch_sched = tinygrad_sched(tinygrad_optim, **opts), torch_sched(torch_optim, **opts) tinygrad_lrs = get_lrs(tinygrad_optim, tinygrad_sched, epochs, accs=accs) diff --git a/test/test_copy_speed.py b/test/test_copy_speed.py new file mode 100644 index 00000000..dc4b532a --- /dev/null +++ b/test/test_copy_speed.py @@ -0,0 +1,58 @@ +import unittest +from tinygrad import Tensor +from tinygrad.ops import Device +from tinygrad.helpers import Timing, CI + +N = 4096 if CI else 16384 +class TestCopySpeed(unittest.TestCase): + @classmethod + def setUpClass(cls): Device[Device.DEFAULT].synchronize() + + def testCopySHMtoDefault(self): + t = Tensor.empty(N, N, device="disk:/dev/shm/test_X").realize() + #t = Tensor.empty(N, N, device="disk:shm:test_X").realize() + for _ in range(3): + with Timing("sync: ", on_exit=lambda ns: f" @ {t.nbytes()/ns:.2f} GB/s"): + with Timing("queue: "): + t.to(Device.DEFAULT).realize() + Device[Device.DEFAULT].synchronize() + + def testCopyCPUtoDefault(self): + t = Tensor.rand(N, N, device="cpu").realize() + print(f"buffer: {t.nbytes()*1e-9:.2f} GB") + for _ in range(3): + with Timing("sync: ", on_exit=lambda ns: f" @ {t.nbytes()/ns:.2f} GB/s"): + with Timing("queue: "): + t.to(Device.DEFAULT).realize() + Device[Device.DEFAULT].synchronize() + + def testCopyCPUtoDefaultFresh(self): + print("fresh copy") + for _ in range(3): + t = Tensor.rand(N, N, device="cpu").realize() + with Timing("sync: ", on_exit=lambda ns: f" @ {t.nbytes()/ns:.2f} GB/s"): + with Timing("queue: "): + t.to(Device.DEFAULT).realize() + Device[Device.DEFAULT].synchronize() + del t + + def testCopyDefaulttoCPU(self): + t = Tensor.rand(N, N).realize() + print(f"buffer: {t.nbytes()*1e-9:.2f} GB") + for _ in range(3): + with Timing("sync: ", on_exit=lambda ns: f" @ {t.nbytes()/ns:.2f} GB/s"): + t.to('cpu').realize() + + @unittest.skipIf(CI, "CI doesn't have 6 GPUs") + def testCopyCPUto6GPUs(self): + t = Tensor.rand(N, N, device="cpu").realize() + print(f"buffer: {t.nbytes()*1e-9:.2f} GB") + for _ in range(3): + with Timing("sync: ", on_exit=lambda ns: f" @ {t.nbytes()/ns:.2f} GB/s ({t.nbytes()*6/ns:.2f} GB/s total)"): + with Timing("queue: "): + for g in range(6): + t.to(f"gpu:{g}").realize() + Device[f"gpu"].synchronize() + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index 1fea8411..fb8c75ee 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -36,6 +36,9 @@ def partition(lst:List[T], fxn:Callable[[T],bool]): b:List[T] = [] for s in lst: (a if fxn(s) else b).append(s) return a,b +def unwrap(x:Optional[T]) -> T: + assert x is not None + return x @functools.lru_cache(maxsize=None) def getenv(key, default=0): return type(default)(os.getenv(key, default)) diff --git a/tinygrad/nn/optim.py b/tinygrad/nn/optim.py index 84948c7f..52acb462 100644 --- a/tinygrad/nn/optim.py +++ b/tinygrad/nn/optim.py @@ -10,8 +10,10 @@ class Optimizer: if x.requires_grad is None: x.requires_grad = True self.params: List[Tensor] = dedup([x for x in params if x.requires_grad]) + assert len(self.params) != 0, "optimizer must have at least one param" + self.device = self.params[0].device self.buffers: List[Tensor] = dedup([x for x in params if not x.requires_grad]) # buffers are still realized - self.lr = Tensor([lr], requires_grad=False).contiguous() + self.lr = Tensor([lr], requires_grad=False, device=self.device).contiguous() def zero_grad(self): for param in self.params: param.grad = None diff --git a/tinygrad/nn/state.py b/tinygrad/nn/state.py index beabe4bd..26480e43 100644 --- a/tinygrad/nn/state.py +++ b/tinygrad/nn/state.py @@ -1,8 +1,8 @@ -import os, json, pathlib, zipfile, pickle +import os, json, pathlib, zipfile, pickle, tarfile, struct from tqdm import tqdm from typing import Dict, Union, List, Optional, Any, Tuple from tinygrad.tensor import Tensor -from tinygrad.helpers import dtypes, prod, argsort, DEBUG, Timing, GlobalCounters, CI +from tinygrad.helpers import dtypes, prod, argsort, DEBUG, Timing, GlobalCounters, CI, unwrap from tinygrad.shape.view import strides_for_shape from tinygrad.ops import Device @@ -64,9 +64,9 @@ def load_state_dict(model, state_dict, strict=True, verbose=True): def torch_load(fn:str): t = Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}") - offsets: Dict[str, int] = {} - lens: Dict[str, int] = {} - def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None): + offsets: Dict[Union[str, int], int] = {} + lens: Dict[Union[str, int], int] = {} + def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad=None, backward_hooks=None, metadata=None): #print(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata) lens[storage[2]] = storage[4] * storage[1].itemsize if storage[2] not in offsets: return None @@ -92,7 +92,12 @@ def torch_load(fn:str): return ret.reshape(size) - intercept = {"HalfStorage": dtypes.float16, "FloatStorage": dtypes.float32, "BFloat16Storage": dtypes.bfloat16, "IntStorage": dtypes.int32, "LongStorage": dtypes.int64, "_rebuild_tensor_v2": _rebuild_tensor_v2} + class Parameter: + def __setstate__(self, state): self.tensor = state[0] + + deserialized_objects: Dict[str, Any] = {} + intercept = {"HalfStorage": dtypes.float16, "FloatStorage": dtypes.float32, "BFloat16Storage": dtypes.bfloat16, "IntStorage": dtypes.int32, "LongStorage": dtypes.int64, + "_rebuild_tensor_v2": _rebuild_tensor_v2, "FloatTensor": None, "Parameter": Parameter} whitelist = {"torch", "collections", "numpy", "_codecs"} # NOTE: this is not for security, only speed class Dummy: pass class TorchPickle(pickle.Unpickler): @@ -102,7 +107,7 @@ def torch_load(fn:str): if DEBUG >= 2: print(f"WARNING: returning Dummy for {module} {name}") return Dummy return intercept[name] if module_root == "torch" else super().find_class(module, name) - def persistent_load(self, pid): return pid + def persistent_load(self, pid): return deserialized_objects[pid] if pid in deserialized_objects else pid if tuple(t[0:2].numpy()) == (0x50, 0x4b): myzip = zipfile.ZipFile(fn, 'r') @@ -113,6 +118,20 @@ def torch_load(fn:str): offsets[n.split("/")[-1]] = myfile._orig_compress_start # type: ignore with myzip.open(f'{base_name}/data.pkl') as myfile: return TorchPickle(myfile).load() + elif bytes(t[0:0xe].numpy()) == b"././@PaxHeader": # TODO: is this how you detect a tarfile? + with tarfile.open(fn, "r") as tar: + storages_offset = tar.getmember('storages').offset_data + f = unwrap(tar.extractfile('storages')) + for i in range(TorchPickle(f).load()): # num_storages + (key, _, storage_type), sz = TorchPickle(f).load(), struct.unpack(' CompiledASTRunner: diff --git a/tinygrad/realize.py b/tinygrad/realize.py index b8b4a67c..fb152342 100644 --- a/tinygrad/realize.py +++ b/tinygrad/realize.py @@ -1,7 +1,7 @@ from typing import List, cast, Dict, Callable import numpy as np from tinygrad.ops import ScheduleItem, LazyOp, LoadOps, Device, BufferOps -from tinygrad.graph import log_schedule_item +from tinygrad.graph import log_schedule_item, print_tree from tinygrad.lazy import LazyBuffer from tinygrad.helpers import DEBUG, prod, all_int, getenv, IMAGE @@ -23,6 +23,7 @@ def run_schedule(schedule:List[ScheduleItem], disable_logging=False): for i,s in enumerate(si.ast.src): assert isinstance(s, LazyOp) and s.op == BufferOps.MEM and s.arg.idx == i+1 and s.arg.st.contiguous, f"bad LoadOps src {i}: {s}" LOAD_OPS_DISPATCHER[cast(LoadOps, si.ast.op)](si.out, *si.inputs) else: + assert all(si.out.device == x.device for x in si.inputs), f"all devices must be the same, {si.out.device} != {[x.device for x in si.inputs]} {print_tree(si.ast) or ''}" Device[si.out.device].exec_ast(si.ast, output=si.out, inputs=si.inputs, var_vals=si.var_vals, **si.out._device_extra_args()) del si.out.op for v in si.out.views: del v.op diff --git a/tinygrad/runtime/lib.py b/tinygrad/runtime/lib.py index a3af8cc1..3d74be2e 100644 --- a/tinygrad/runtime/lib.py +++ b/tinygrad/runtime/lib.py @@ -18,8 +18,6 @@ class RawBuffer: # pylint: disable=abstract-method if hasattr(self, '_memsz'): GlobalCounters.mem_used -= self._memsz if hasattr(self, '_allocator') and self._allocator: self._allocator.free(self._buf) def __repr__(self): return f"buffer<{self.size}, {self.dtype}, {id(self)}>" - @property - def key(self): return (self.size, self.dtype) # NOTE: this interface allows for 0 copy @classmethod @@ -57,7 +55,7 @@ class RawBufferCopyInOut(RawBufferCopyIn): return x class RawBufferTransfer(RawBuffer): - def _transfer(self, x) -> None: raise NotImplementedError("must be implemented") + def _transfer(self, x:RawBuffer) -> None: raise NotImplementedError("must be implemented") @classmethod def transfer(cls, x, shape, dtype, **kwargs): diff --git a/tinygrad/runtime/ops_disk.py b/tinygrad/runtime/ops_disk.py index 11acab1f..7df7499a 100644 --- a/tinygrad/runtime/ops_disk.py +++ b/tinygrad/runtime/ops_disk.py @@ -7,6 +7,7 @@ from tinygrad.helpers import prod, DType, OSX from tinygrad.runtime.lib import RawBufferMapped from tinygrad.ops import Interpreted, Op, MovementOps, UnaryOps, BufferOps from tinygrad.shape.view import strides_for_shape +MAP_LOCKED, MAP_POPULATE = 0x2000, 0x008000 class RawDiskBuffer(RawBufferMapped): def __init__(self, size, dtype:DType, device:Optional[str]=None, buf=None, shape=None, offset=0): # pylint: disable=super-init-not-called @@ -22,7 +23,7 @@ class RawDiskBuffer(RawBufferMapped): else: fd = _posixshmem.shm_open(device[4:], os.O_RDWR, 0o600) # TODO: these flags are somewhat platform specific, but python doesn't expose the ones we need - shm = mmap.mmap(fd, size * dtype.itemsize, flags=mmap.MAP_SHARED | 0x2000 | 0x008000) + shm = mmap.mmap(fd, size * dtype.itemsize, flags=mmap.MAP_SHARED | MAP_LOCKED | MAP_POPULATE) shm.madvise(mmap.MADV_HUGEPAGE) # type: ignore # not on OSX os.close(fd) buf = [None, shm, 1] diff --git a/tinygrad/runtime/ops_gpu.py b/tinygrad/runtime/ops_gpu.py index 7d4b99f5..13cc2853 100644 --- a/tinygrad/runtime/ops_gpu.py +++ b/tinygrad/runtime/ops_gpu.py @@ -15,7 +15,6 @@ OSX_TIMING_RATIO = (125/3) if OSX else 1.0 # see test/external/external_osx_pr # TODO: if you fork and exit the child process after creating anything with cl on AMD, it hangs on e.wait() ROCM_LLVM_PATH = pathlib.Path("/opt/rocm/llvm/bin") -#ROCM_LLVM_PATH = pathlib.Path(__file__).parents[3] / "extra/rocm/build/llvm-project/bin" if DEBUG >= 5: early_exec = fromimport("extra.helpers", "enable_early_exec")() @@ -49,15 +48,17 @@ if not getenv("DELAYED_RUNTIME_INIT", False): CL.post_init() class CLBuffer(RawBufferCopyInOut, RawBufferTransfer): def __init__(self, size, dtype, device='0'): super().__init__(size, dtype, allocator=CL.cl_allocator, **{'device': device}) + def _clear_event(self, _): del self.event def _copyin(self, x:np.ndarray): assert not self.dtype.name.startswith("image"), f"can't copyin images {self.dtype}" self.event = cl.enqueue_copy(CL.cl_queue[self._buf.device], self._buf, np.require(x, requirements=['C', 'A']), is_blocking=False) + self.event.set_callback(cl.command_execution_status.COMPLETE, self._clear_event) def _copyout(self, x:np.ndarray): assert not self.dtype.name.startswith("image"), f"can't copyout images {self.dtype}" CL.cl_allocator.ensure_has_free_space(self.size*self.dtype.itemsize, self._device) buf = cl.Buffer(CL.cl_ctxs[self._buf.device], cl.mem_flags.WRITE_ONLY | cl.mem_flags.USE_HOST_PTR, 0, hostbuf=x.data) mapped, event = cl.enqueue_map_buffer(CL.cl_queue[self._buf.device], buf, cl.map_flags.WRITE, 0, self.size, dtype=self.dtype.np, is_blocking=False) - with mapped.base: cl.enqueue_copy(CL.cl_queue[self._buf.device], mapped, self._buf, is_blocking=True, wait_for=[event] + ([self.event] if hasattr(self, "event") else [])) + with mapped.base: cl.enqueue_copy(CL.cl_queue[self._buf.device], mapped, self._buf, is_blocking=True, wait_for=[event] + ([evt] if (evt:=getattr(self, "event", None)) else [])) def _transfer(self, x): if "gfx" in CL.cl_ctxs[x._buf.device].devices[0].name: cl.enqueue_copy_buffer_p2p_amd(CL.cl_platform, CL.cl_queue[x._buf.device], x._buf, self._buf, x.size * x.dtype.itemsize).wait() @@ -96,7 +97,7 @@ class CLProgram: for x in bufs: if x.__class__ is CLBuffer: cl_bufs.append(x._buf) - if hasattr(x, "event"): wait_for.append(x.event) + if (event:=getattr(x, "event",None)): wait_for.append(event) else: cl_bufs.append(x) e = self.clprgs[cl_bufs[0].device](CL.cl_queue[cl_bufs[0].device], [int(g*l) for g,l in zip(global_size, local_size)] if local_size is not None else global_size, local_size, *cl_bufs, wait_for=wait_for) if wait: diff --git a/tinygrad/runtime/ops_hip.py b/tinygrad/runtime/ops_hip.py index 1258c177..77f9453c 100644 --- a/tinygrad/runtime/ops_hip.py +++ b/tinygrad/runtime/ops_hip.py @@ -4,7 +4,7 @@ import extra.hip_wrapper as hip from typing import Tuple from tinygrad.helpers import DEBUG, getenv, diskcache from tinygrad.ops import Compiled -from tinygrad.runtime.lib import RawBufferCopyInOut, LRUAllocator, RawBufferTransfer +from tinygrad.runtime.lib import RawBuffer, RawBufferCopyInOut, LRUAllocator, RawBufferTransfer from tinygrad.codegen.kernel import LinearizerOptions from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage @@ -38,7 +38,7 @@ class RawHIPBuffer(RawBufferCopyInOut, RawBufferTransfer): def _copyout(self, x:np.ndarray): hip.hipSetDevice(self._device) hip.hipMemcpy(x.ctypes.data, self._buf, self.size * self.dtype.itemsize, hip.hipMemcpyDeviceToHost) - def _transfer(self, x): + def _transfer(self, x:RawBuffer): hip.hipSetDevice(x._device) hip.hipMemcpy(self._buf, x._buf, self.size * self.dtype.itemsize, hip.hipMemcpyDeviceToDevice) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index d8fcce92..20bba090 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -46,7 +46,7 @@ class Tensor: no_grad: ClassVar[bool] = False default_type: ClassVar[DType] = dtypes.float32 - def __init__(self, data:Union[None, int, float, list, LazyBuffer, np.ndarray], device:Optional[str]=None, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None): + def __init__(self, data:Union[None, int, float, list, LazyBuffer, np.ndarray, bytes], device:Optional[str]=None, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None): assert dtype is None or isinstance(dtype, DType), f"invalid dtype {dtype}" device = Device.canonicalize(device) # tensors have gradients, buffers do not @@ -64,6 +64,8 @@ class Tensor: elif data is None or data.__class__ is list: assert dtype is None or dtype.np is not None, f"{dtype} doesn't have a numpy dtype" data = LazyBuffer.fromCPU(np.array([] if data is None else data, dtype=(dtype or Tensor.default_type).np)) + elif isinstance(data, bytes): + data = LazyBuffer.fromCPU(np.frombuffer(data, np.uint8)) elif isinstance(data, np.ndarray): assert dtype is None or dtype.np is not None, f"{dtype} doesn't have a numpy dtype" if data.shape == (): @@ -124,11 +126,18 @@ class Tensor: return self.detach().cast(dtypes.from_np(self.dtype.np)).contiguous().to('CPU').realize().lazydata.realized.toCPU().reshape(self.shape) def item(self) -> Union[float, int]: return self.numpy().item() - def to(self, device:str) -> Tensor: + def to(self, device:Optional[str]) -> Tensor: + if device is None or device == self.device: return self ret = Tensor(self.lazydata, device) if self.grad: ret.grad = self.grad.to(device) return ret + def to_(self, device:Optional[str]): + if device is None or device == self.device: return + if self.grad: self.grad = self.grad.to_(device) + _ret = Tensor(self.lazydata, device) + self.lazydata = _ret.lazydata + # ***** creation llop entrypoint ***** @staticmethod @@ -398,7 +407,7 @@ class Tensor: final_shape = [r*s for r,s in zip(repeats, base_shape)] return self.reshape(new_shape).expand(expand_shape).reshape(final_shape) - def chunk(self, num:int, dim:int) -> List[Tensor]: + def chunk(self, num:int, dim:int=0) -> List[Tensor]: assert all_int(self.shape), f"does not support symbolic shape {self.shape}" dim, step = dim + self.ndim if dim < 0 else dim, math.ceil(self.shape[dim]/num) slice_params = [[slice(None)]*dim + [slice(k, k + step)] for k in range(0, self.shape[dim], step)]