mirror of https://github.com/commaai/tinygrad.git
206 lines
7.5 KiB
Python
206 lines
7.5 KiB
Python
# type: ignore
|
|
import pickle, hashlib, zipfile, io, requests, struct, tempfile, platform, concurrent.futures
|
|
import numpy as np
|
|
from tqdm import tqdm
|
|
from pathlib import Path
|
|
from collections import defaultdict
|
|
from typing import Union
|
|
|
|
from tinygrad.helpers import prod, getenv, DEBUG, dtypes
|
|
from tinygrad.helpers import GlobalCounters
|
|
from tinygrad.tensor import Tensor
|
|
from tinygrad.lazy import LazyBuffer
|
|
from tinygrad.ops import Device
|
|
from tinygrad.shape.view import strides_for_shape
|
|
OSX = platform.system() == "Darwin"
|
|
WINDOWS = platform.system() == "Windows"
|
|
|
|
def temp(x:str) -> str: return (Path(tempfile.gettempdir()) / x).as_posix()
|
|
|
|
def fetch(url):
|
|
if url.startswith("/") or url.startswith("."):
|
|
with open(url, "rb") as f:
|
|
return f.read()
|
|
fp = temp(hashlib.md5(url.encode('utf-8')).hexdigest())
|
|
download_file(url, fp, skip_if_exists=not getenv("NOCACHE"))
|
|
with open(fp, "rb") as f:
|
|
return f.read()
|
|
|
|
def fetch_as_file(url):
|
|
if url.startswith("/") or url.startswith("."):
|
|
with open(url, "rb") as f:
|
|
return f.read()
|
|
fp = temp(hashlib.md5(url.encode('utf-8')).hexdigest())
|
|
download_file(url, fp, skip_if_exists=not getenv("NOCACHE"))
|
|
return fp
|
|
|
|
def download_file(url, fp, skip_if_exists=True):
|
|
if skip_if_exists and Path(fp).is_file() and Path(fp).stat().st_size > 0:
|
|
return
|
|
r = requests.get(url, stream=True)
|
|
assert r.status_code == 200
|
|
progress_bar = tqdm(total=int(r.headers.get('content-length', 0)), unit='B', unit_scale=True, desc=url)
|
|
(path := Path(fp).parent).mkdir(parents=True, exist_ok=True)
|
|
with tempfile.NamedTemporaryFile(dir=path, delete=False) as f:
|
|
for chunk in r.iter_content(chunk_size=16384):
|
|
progress_bar.update(f.write(chunk))
|
|
f.close()
|
|
Path(f.name).rename(fp)
|
|
|
|
def my_unpickle(fb0):
|
|
key_prelookup = defaultdict(list)
|
|
def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None):
|
|
#print(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata)
|
|
ident, storage_type, obj_key, location, obj_size = storage[0:5]
|
|
assert ident == 'storage'
|
|
assert prod(size) <= (obj_size - storage_offset)
|
|
|
|
if storage_type not in [np.float16, np.float32]:
|
|
if DEBUG: print(f"unsupported type {storage_type} on {obj_key} with shape {size}")
|
|
ret = None
|
|
else:
|
|
ret = Tensor.empty(*size, dtype=dtypes.from_np(storage_type))
|
|
key_prelookup[obj_key].append((storage_type, obj_size, ret, size, stride, storage_offset))
|
|
return ret
|
|
|
|
def _rebuild_parameter(*args):
|
|
#print(args)
|
|
pass
|
|
|
|
class Dummy: pass
|
|
|
|
class MyPickle(pickle.Unpickler):
|
|
def find_class(self, module, name):
|
|
#print(module, name)
|
|
if name == 'FloatStorage': return np.float32
|
|
if name == 'LongStorage': return np.int64
|
|
if name == 'IntStorage': return np.int32
|
|
if name == 'HalfStorage': return np.float16
|
|
if module == "torch._utils":
|
|
if name == "_rebuild_tensor_v2": return _rebuild_tensor_v2
|
|
if name == "_rebuild_parameter": return _rebuild_parameter
|
|
else:
|
|
if module.startswith('pytorch_lightning'): return Dummy
|
|
try:
|
|
return super().find_class(module, name)
|
|
except Exception:
|
|
return Dummy
|
|
|
|
def persistent_load(self, pid):
|
|
return pid
|
|
|
|
return MyPickle(fb0).load(), key_prelookup
|
|
|
|
def load_single_weight(t:Tensor, myfile, shape, strides, dtype, storage_offset, mmap_allowed=False):
|
|
bytes_size = np.dtype(dtype).itemsize
|
|
if t is None:
|
|
myfile.seek(prod(shape) * bytes_size, 1)
|
|
return
|
|
|
|
bytes_offset = 0
|
|
if storage_offset is not None:
|
|
bytes_offset = storage_offset * bytes_size
|
|
myfile.seek(bytes_offset)
|
|
|
|
assert t.shape == shape or shape == tuple(), f"shape mismatch {t.shape} != {shape}"
|
|
assert t.dtype.np == dtype and t.dtype.itemsize == bytes_size
|
|
if any(s != 1 and st1 != st2 for s, st1, st2 in zip(shape, strides_for_shape(shape), strides)):
|
|
# slow path
|
|
buffer_size = sum(strides[i]*t.dtype.itemsize * (shape[i] - 1) for i in range(len(shape)))
|
|
buffer_size += t.dtype.itemsize
|
|
np_array = np.frombuffer(myfile.read(buffer_size), t.dtype.np)
|
|
|
|
np_array = np.lib.stride_tricks.as_strided(
|
|
np_array, shape=shape, strides=[i*t.dtype.itemsize for i in strides])
|
|
|
|
lna = t.lazydata.op.arg
|
|
lna.fxn = lambda _: np_array
|
|
t.realize()
|
|
return
|
|
|
|
# ["METAL", "CLANG", "LLVM"] support readinto for more speed
|
|
# ["GPU", "CUDA"] use _mmap since they have to copy in to the GPU anyway
|
|
# this needs real APIs
|
|
if t.device in ["METAL", "CLANG", "LLVM"]:
|
|
del t.lazydata.op
|
|
t.lazydata.realized = Device[t.lazydata.device].buffer(prod(t.shape), dtype=t.dtype)
|
|
myfile.readinto(t.lazydata.realized._buffer())
|
|
else:
|
|
def _mmap(lna):
|
|
assert myfile._compress_type == 0, "compressed data can't be mmaped"
|
|
return np.memmap(myfile._fileobj._file, dtype=lna.dtype, mode='r', offset=myfile._orig_compress_start + bytes_offset, shape=lna.shape)
|
|
def _read(lna):
|
|
ret = np.empty(lna.shape, dtype=lna.dtype)
|
|
myfile.readinto(ret.data)
|
|
return ret
|
|
if mmap_allowed and not OSX and t.device in ["GPU", "CUDA"]: t.lazydata.op.arg.fxn = _mmap
|
|
else: t.lazydata.op.arg.fxn = _read
|
|
t.realize()
|
|
|
|
def fake_torch_load_zipped(fb0, load_weights=True, multithreaded=True):
|
|
if Device.DEFAULT in ["TORCH", "GPU", "CUDA"]: multithreaded = False # multithreaded doesn't work with CUDA or TORCH. for GPU it's a wash with _mmap
|
|
with zipfile.ZipFile(fb0, 'r') as myzip:
|
|
base_name = myzip.namelist()[0].split('/', 1)[0]
|
|
with myzip.open(f'{base_name}/data.pkl') as myfile:
|
|
ret = my_unpickle(myfile)
|
|
if load_weights:
|
|
def load_weight(k, vv):
|
|
with myzip.open(f'{base_name}/data/{k}') as myfile:
|
|
for v in vv:
|
|
load_single_weight(v[2], myfile, v[3], v[4], v[0], v[5], mmap_allowed=True)
|
|
if multithreaded:
|
|
# 2 seems fastest
|
|
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
|
|
futures = {executor.submit(load_weight, k, v):k for k,v in ret[1].items()}
|
|
for future in (t:=tqdm(concurrent.futures.as_completed(futures), total=len(futures))):
|
|
if future.exception() is not None: raise future.exception()
|
|
k = futures[future]
|
|
t.set_description(f"loading {k} ram used: {GlobalCounters.mem_used/1e9:5.2f} GB")
|
|
else:
|
|
for k,v in (t := tqdm(ret[1].items())):
|
|
t.set_description(f"loading {k} ram used: {GlobalCounters.mem_used/1e9:5.2f} GB")
|
|
load_weight(k,v)
|
|
return ret[0]
|
|
|
|
def fake_torch_load(b0):
|
|
|
|
# convert it to a file
|
|
fb0 = io.BytesIO(b0)
|
|
|
|
if b0[0:2] == b"\x50\x4b":
|
|
return fake_torch_load_zipped(fb0)
|
|
|
|
# skip three junk pickles
|
|
pickle.load(fb0)
|
|
pickle.load(fb0)
|
|
pickle.load(fb0)
|
|
|
|
ret, key_prelookup = my_unpickle(fb0)
|
|
|
|
# create key_lookup
|
|
key_lookup = pickle.load(fb0)
|
|
key_real = [None] * len(key_lookup)
|
|
for k,v in key_prelookup.items():
|
|
assert len(v) == 1
|
|
key_real[key_lookup.index(k)] = v[0]
|
|
|
|
# read in the actual data
|
|
for storage_type, obj_size, tensor, np_shape, np_strides, storage_offset in key_real:
|
|
ll = struct.unpack("Q", fb0.read(8))[0]
|
|
assert ll == obj_size, f"size mismatch {ll} != {obj_size}"
|
|
assert storage_offset == 0, "not implemented"
|
|
load_single_weight(tensor, fb0, np_shape, np_strides, storage_type, None)
|
|
|
|
return ret
|
|
|
|
def get_child(parent, key):
|
|
obj = parent
|
|
for k in key.split('.'):
|
|
if k.isnumeric():
|
|
obj = obj[int(k)]
|
|
elif isinstance(obj, dict):
|
|
obj = obj[k]
|
|
else:
|
|
obj = getattr(obj, k)
|
|
return obj
|