# type: ignore
"""Helpers for fetching files and loading PyTorch checkpoints into tinygrad Tensors."""
import pickle
import platform
import tempfile
from collections import defaultdict
from pathlib import Path

import numpy as np
from tqdm import tqdm

from tinygrad.helpers import prod, getenv, DEBUG, dtypes
from tinygrad.ops import Device, GlobalCounters
from tinygrad.tensor import Tensor
from tinygrad.shape.view import strides_for_shape

OSX = platform.system() == "Darwin"
WINDOWS = platform.system() == "Windows"


def temp(x: str) -> str:
    """Return the POSIX-style path of *x* inside the system temp directory."""
    return (Path(tempfile.gettempdir()) / x).as_posix()


def _is_local_path(url):
    """Return True when *url* is treated as a local path (starts with '/' or '.')."""
    return url.startswith(("/", "."))


def _download_to_cache(url):
    """Download *url* into the temp-dir cache and return the cached file path.

    The cache filename is the md5 hex digest of the URL (used only as a stable
    name, not for security). An existing non-empty cached file is reused unless
    the NOCACHE environment flag is set.
    """
    import hashlib
    fp = temp(hashlib.md5(url.encode('utf-8')).hexdigest())
    download_file(url, fp, skip_if_exists=not getenv("NOCACHE"))
    return fp


def fetch(url):
    """Return the bytes of *url*, which is either a local path or a remote URL."""
    path = url if _is_local_path(url) else _download_to_cache(url)
    with open(path, "rb") as f:
        return f.read()


def fetch_as_file(url):
    """Return a filesystem path holding the contents of *url*.

    Local paths are returned unchanged; remote URLs are downloaded to the
    temp-dir cache and the cache path is returned.
    """
    # The previous implementation returned the file's *bytes* for local paths,
    # which is inconsistent with the remote branch (and with the function name).
    if _is_local_path(url):
        return url
    return _download_to_cache(url)


def download_file(url, fp, skip_if_exists=True):
    """Stream *url* to the file *fp*, showing a progress bar.

    Args:
        url: address passed to ``requests.get``.
        fp: destination path; parent directories are created as needed.
        skip_if_exists: when true, do nothing if *fp* is already a non-empty file.

    Raises:
        AssertionError: if the HTTP status code is not 200.
    """
    import requests
    target = Path(fp)
    if skip_if_exists and target.is_file() and target.stat().st_size > 0:
        return
    r = requests.get(url, stream=True)
    try:
        assert r.status_code == 200, f"download of {url} failed with status {r.status_code}"
        (path := target.parent).mkdir(parents=True, exist_ok=True)
        tmp_name = None
        try:
            with tqdm(total=int(r.headers.get('content-length', 0)), unit='B', unit_scale=True, desc=url) as progress_bar, \
                    tempfile.NamedTemporaryFile(dir=path, delete=False) as f:
                tmp_name = f.name
                for chunk in r.iter_content(chunk_size=16384):
                    progress_bar.update(f.write(chunk))
                # close before moving: an open file cannot be moved on Windows
                f.close()
                # replace() overwrites an existing target on every platform,
                # whereas rename() raises on Windows if the target exists
                Path(f.name).replace(fp)
        except BaseException:
            # do not leave a partially written temp file next to the target
            if tmp_name is not None:
                Path(tmp_name).unlink(missing_ok=True)
            raise
    finally:
        r.close()


def my_unpickle(fb0):
    """Unpickle a torch-style pickle stream without importing torch.

    Tensor rebuild calls are intercepted: each one allocates an empty tinygrad
    Tensor (or None for unsupported storage types) and records what is needed
    to fill it in later.

    Args:
        fb0: binary file-like object positioned at the pickle.

    Returns:
        ``(obj, key_prelookup)`` where *obj* is the unpickled object and
        *key_prelookup* maps a storage key to a list of
        ``(storage_type, obj_size, tensor_or_None, size, stride, storage_offset)``.

    SECURITY: classes that are not special-cased are resolved through the
    regular ``pickle.Unpickler.find_class``, so a malicious checkpoint can
    execute arbitrary code. Only load files from trusted sources.
    """
    key_prelookup = defaultdict(list)

    def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None):
        # `storage` is the persistent id returned unchanged by persistent_load below
        ident, storage_type, obj_key, location, obj_size = storage[0:5]
        assert ident == 'storage'
        assert prod(size) <= (obj_size - storage_offset)

        if storage_type not in [np.float16, np.float32]:
            if DEBUG:
                print(f"unsupported type {storage_type} on {obj_key} with shape {size}")
            ret = None
        else:
            ret = Tensor.empty(*size, dtype=dtypes.from_np(storage_type))
        key_prelookup[obj_key].append((storage_type, obj_size, ret, size, stride, storage_offset))
        return ret

    def _rebuild_parameter(*args):
        # parameters are not reconstructed; the call evaluates to None
        pass

    class Dummy:
        """Placeholder for classes that cannot (or should not) be imported."""

    class MyPickle(pickle.Unpickler):
        def find_class(self, module, name):
            # torch storage classes are mapped onto the matching numpy dtype
            if name == 'FloatStorage': return np.float32
            if name == 'LongStorage': return np.int64
            if name == 'IntStorage': return np.int32
            if name == 'HalfStorage': return np.float16
            if module == "torch._utils":
                if name == "_rebuild_tensor_v2": return _rebuild_tensor_v2
                if name == "_rebuild_parameter": return _rebuild_parameter
                # NOTE(review): any other torch._utils name falls through and
                # returns None here -- confirm that this is intended
            else:
                if module.startswith('pytorch_lightning'): return Dummy
                try:
                    return super().find_class(module, name)
                except Exception:
                    # best effort: unknown classes become an inert placeholder
                    return Dummy

        def persistent_load(self, pid):
            # hand the raw persistent id through to _rebuild_tensor_v2
            return pid

    return MyPickle(fb0).load(), key_prelookup


def load_single_weight(t: Tensor, myfile, shape, strides, dtype, storage_offset, mmap_allowed=False):
    """Fill tensor *t* with data read from *myfile*.

    Args:
        t: destination Tensor, or None to skip over this weight's bytes.
        myfile: binary file-like object holding the raw storage.
        shape: expected shape of the weight.
        strides: element strides of the weight inside the storage.
        dtype: numpy dtype of the stored elements.
        storage_offset: element offset to seek to (absolute), or None to read
            from the current position.
        mmap_allowed: permit memory-mapping the file instead of reading it
            (only used for the "GPU"/"CUDA" devices when not on macOS).
    """
    bytes_size = np.dtype(dtype).itemsize
    if t is None:
        # skip forward, relative to the current position (whence=1)
        myfile.seek(prod(shape) * bytes_size, 1)
        return

    bytes_offset = 0
    if storage_offset is not None:
        bytes_offset = storage_offset * bytes_size
        myfile.seek(bytes_offset)

    assert t.shape == shape or shape == tuple(), f"shape mismatch {t.shape} != {shape}"
    assert t.dtype.np == dtype and t.dtype.itemsize == bytes_size
    if any(s != 1 and st1 != st2 for s, st1, st2 in zip(shape, strides_for_shape(shape), strides)):
        # slow path: strides differ from strides_for_shape(shape), so read the
        # spanned bytes and build a strided numpy view over them
        itemsize = t.dtype.itemsize
        buffer_size = sum(stride * itemsize * (dim - 1) for dim, stride in zip(shape, strides))
        buffer_size += itemsize
        np_array = np.frombuffer(myfile.read(buffer_size), t.dtype.np)
        np_array = np.lib.stride_tricks.as_strided(
            np_array, shape=shape, strides=[i * itemsize for i in strides])

        # presumably `fxn` is the loader callback invoked by realize() -- TODO confirm
        lna = t.lazydata.op.arg
        lna.fxn = lambda _: np_array
        t.realize()
        return

    # ["METAL", "CLANG", "LLVM"] support readinto for more speed
    # ["GPU", "CUDA"] use _mmap since they have to copy in to the GPU anyway
    # this needs real APIs
    if t.device in ["METAL", "CLANG", "LLVM"]:
        del t.lazydata.op
        t.lazydata.realized = Device[t.lazydata.device].buffer(prod(t.shape), dtype=t.dtype)
        myfile.readinto(t.lazydata.realized._buffer())
    else:
        def _mmap(lna):
            # relies on private attributes of the zip member file object
            assert myfile._compress_type == 0, "compressed data can't be mmaped"
            return np.memmap(myfile._fileobj._file, dtype=lna.dtype, mode='r',
                             offset=myfile._orig_compress_start + bytes_offset, shape=lna.shape)

        def _read(lna):
            ret = np.empty(lna.shape, dtype=lna.dtype)
            myfile.readinto(ret.data)
            return ret

        if mmap_allowed and not OSX and t.device in ["GPU", "CUDA"]:
            t.lazydata.op.arg.fxn = _mmap
        else:
            t.lazydata.op.arg.fxn = _read
        t.realize()


def fake_torch_load_zipped(fb0, load_weights=True, multithreaded=True):
    """Load a zip-format torch checkpoint from the file-like object *fb0*.

    Args:
        fb0: binary file-like object containing the zip archive.
        load_weights: when false, only the object structure is unpickled and
            the tensors are left unfilled.
        multithreaded: load storages with a 2-worker thread pool; forced off
            when ``Device.DEFAULT`` is "TORCH", "GPU" or "CUDA".

    Returns:
        The unpickled object.
    """
    # multithreaded doesn't work with CUDA or TORCH. for GPU it's a wash with _mmap
    if Device.DEFAULT in ["TORCH", "GPU", "CUDA"]:
        multithreaded = False

    import zipfile
    with zipfile.ZipFile(fb0, 'r') as myzip:
        # the archive's top-level directory name prefixes every member
        base_name = myzip.namelist()[0].split('/', 1)[0]
        with myzip.open(f'{base_name}/data.pkl') as myfile:
            ret = my_unpickle(myfile)
        if load_weights:
            def load_weight(k, vv):
                # one storage file may back several tensors
                with myzip.open(f'{base_name}/data/{k}') as myfile:
                    for v in vv:
                        load_single_weight(v[2], myfile, v[3], v[4], v[0], v[5], mmap_allowed=True)

            if multithreaded:
                import concurrent.futures
                # 2 seems fastest
                with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
                    futures = {executor.submit(load_weight, k, v): k for k, v in ret[1].items()}
                    for future in (t := tqdm(concurrent.futures.as_completed(futures), total=len(futures))):
                        if (exc := future.exception()) is not None:
                            raise exc
                        k = futures[future]
                        t.set_description(f"loading {k} ram used: {GlobalCounters.mem_used/1e9:5.2f} GB")
            else:
                for k, v in (t := tqdm(ret[1].items())):
                    t.set_description(f"loading {k} ram used: {GlobalCounters.mem_used/1e9:5.2f} GB")
                    load_weight(k, v)
    return ret[0]


def fake_torch_load(b0):
    """Load a torch checkpoint from the bytes *b0* and return the unpickled object.

    Data starting with the zip magic ``PK`` is delegated to
    ``fake_torch_load_zipped``; anything else is parsed as a sequence of
    pickles followed by size-prefixed raw storages.

    SECURITY: this unpickles *b0*; only use it on trusted data.
    """
    import io
    import struct

    # convert it to a file
    fb0 = io.BytesIO(b0)
    if b0[0:2] == b"\x50\x4b":
        return fake_torch_load_zipped(fb0)

    # skip three junk pickles
    for _ in range(3):
        pickle.load(fb0)

    ret, key_prelookup = my_unpickle(fb0)

    # create key_lookup: the pickled key list gives the order of the storages
    key_lookup = pickle.load(fb0)
    key_real = [None] * len(key_lookup)
    for k, v in key_prelookup.items():
        assert len(v) == 1
        key_real[key_lookup.index(k)] = v[0]

    # read in the actual data; each storage is preceded by an 8-byte size
    for storage_type, obj_size, tensor, np_shape, np_strides, storage_offset in key_real:
        ll = struct.unpack("Q", fb0.read(8))[0]
        assert ll == obj_size, f"size mismatch {ll} != {obj_size}"
        assert storage_offset == 0, "not implemented"
        load_single_weight(tensor, fb0, np_shape, np_strides, storage_type, None)

    return ret


def get_child(parent, key):
    """Resolve the dotted path *key* starting from *parent*.

    Numeric components index with ``int``, other components are dict keys when
    the current object is a dict and attribute names otherwise.
    """
    obj = parent
    for k in key.split('.'):
        if k.isnumeric():
            obj = obj[int(k)]
        elif isinstance(obj, dict):
            obj = obj[k]
        else:
            obj = getattr(obj, k)
    return obj


def _tree(lazydata, prefix=""):
    """Return the lines of a box-drawing tree for a LazyBuffer or op node."""
    if type(lazydata).__name__ == "LazyBuffer":
        if lazydata.realized:
            return [f"━━ realized {lazydata.dtype.name} {lazydata.shape}"]
        return _tree(lazydata.op, "LB ")
    label = f"{prefix}{lazydata.op.name} {lazydata.arg if lazydata.arg else ''}"
    if len(lazydata.src) == 0:
        return [f"━━ {label}"]
    lines = [f"━┳ {label}"]
    childs = [_tree(c) for c in lazydata.src[:]]
    for c in childs[:-1]:
        lines += [f" ┣{c[0]}"] + [f" ┃{l}" for l in c[1:]]
    last = childs[-1]
    # continuation lines of the last child need a two-column pad so they line
    # up with the two-column " ┣" / " ┃" / " ┗" prefixes used everywhere else
    return lines + [" ┗" + last[0]] + ["  " + l for l in last[1:]]


def print_tree(tensor: Tensor):
    """Print the tree of *tensor* (a Tensor or a lazy node) with line numbers."""
    root = tensor.lazydata if isinstance(tensor, Tensor) else tensor
    print("\n".join([f"{str(i).rjust(3)} {s}" for i, s in enumerate(_tree(root))]))