tinygrad/extra/utils.py

146 lines
3.9 KiB
Python
Raw Normal View History

2020-12-07 02:45:04 +08:00
from tinygrad.tensor import Tensor
2021-01-03 04:51:45 +08:00
import pickle
import numpy as np
2022-09-04 01:08:42 +08:00
from tinygrad.helpers import prod
2020-10-19 05:36:29 +08:00
2020-10-28 12:01:48 +08:00
def fetch(url):
    """Return the raw bytes behind `url`, caching downloads on disk.

    A `url` that starts with "/" is treated as a local file path and read
    directly. Anything else is downloaded with `requests` and cached in the
    system temp directory under the MD5 hex digest of the URL. A non-empty
    cache file is reused unless the NOCACHE environment variable is set.

    Raises:
        AssertionError: the server answered with a status other than 200.
    """
    if url.startswith("/"):
        with open(url, "rb") as f:
            return f.read()

    import hashlib
    import os
    import tempfile

    # MD5 is only used to derive a stable cache file name, not for security.
    fp = os.path.join(tempfile.gettempdir(), hashlib.md5(url.encode('utf-8')).hexdigest())
    if os.path.isfile(fp) and os.stat(fp).st_size > 0 and os.getenv("NOCACHE", None) is None:
        with open(fp, "rb") as f:
            return f.read()

    # Imported only when a download is needed, so cache hits and local reads
    # work even where `requests` is not installed.
    import requests

    print("fetching %s" % url)
    r = requests.get(url)
    # Explicit raise instead of a bare `assert`: asserts are stripped under
    # `python -O`, which would let an error page be cached as the payload.
    # AssertionError is kept so existing callers see the same exception type.
    if r.status_code != 200:
        raise AssertionError("fetching %s failed with HTTP status %d" % (url, r.status_code))
    dat = r.content

    # Write to a side file, then move it into place so that an interrupted
    # write never leaves a truncated cache entry behind.
    with open(fp + ".tmp", "wb") as f:
        f.write(dat)
    # os.replace overwrites an existing target on every platform; os.rename
    # fails on Windows when the cache file already exists (e.g. with NOCACHE).
    os.replace(fp + ".tmp", fp)
    return dat
# TODO: move this to optim.py?
def get_parameters(obj):
    """Recursively collect every Tensor reachable from `obj`.

    A Tensor yields itself. Lists and tuples are searched element by element.
    Any other object that has a `__dict__` is searched through its attribute
    values. Everything else contributes nothing.

    Returns:
        A new list of the Tensors found, in traversal order.
    """
    if isinstance(obj, Tensor):
        return [obj]

    if isinstance(obj, (list, tuple)):
        children = obj
    elif hasattr(obj, '__dict__'):
        children = obj.__dict__.values()
    else:
        return []

    found = []
    for child in children:
        found.extend(get_parameters(child))
    return found
2021-01-03 04:51:45 +08:00
def my_unpickle(fb0):
    """Unpickle a torch checkpoint stream without importing torch.

    Torch storage classes are replaced by numpy dtypes, and tensor rebuild
    calls by zero-filled numpy arrays that are filled in later by the caller.

    NOTE: this runs the pickle machinery on `fb0`; only use it on trusted
    checkpoint files.

    Args:
        fb0: binary file-like object positioned at the start of the pickle.

    Returns:
        A tuple `(obj, key_prelookup)`. `obj` is the unpickled structure.
        `key_prelookup` maps each storage key to
        `(storage_type, obj_size, array, args[2], args[3])`, where `array` is
        the placeholder handed out for that storage and `args` are the
        arguments of the tensor rebuild call (the callers treat `args[2]` as
        the shape and `args[3]` as the strides).
    """
    key_prelookup = {}

    # torch storage class name -> numpy dtype standing in for it
    storage_dtypes = {
        'FloatStorage': np.float32,
        'LongStorage': np.int64,
        'HalfStorage': np.float16,
    }

    class HackTensor:
        """Stand-in for torch._utils._rebuild_tensor_v2."""

        def __new__(cls, *args):
            # args[0] is the persistent id returned verbatim by persistent_load
            ident, storage_type, obj_key, _location, obj_size = args[0][0:5]
            assert ident == 'storage'
            shape = args[2]
            assert prod(shape) == obj_size
            placeholder = np.zeros(shape, dtype=storage_type)
            key_prelookup[obj_key] = (storage_type, obj_size, placeholder, shape, args[3])
            return placeholder

    class HackParameter:
        """Stand-in for torch._utils._rebuild_parameter; yields None."""

        def __new__(cls, *args):
            return None

    class Dummy:
        """Placeholder for any class that cannot be resolved."""

    class MyPickle(pickle.Unpickler):
        def find_class(self, module, name):
            # Storage names win regardless of the module they come from.
            if name in storage_dtypes:
                return storage_dtypes[name]
            if module != "torch._utils":
                try:
                    return super().find_class(module, name)
                except Exception:
                    # best effort: unknown classes become an inert placeholder
                    return Dummy
            if name == "_rebuild_tensor_v2":
                return HackTensor
            if name == "_rebuild_parameter":
                return HackParameter
            # any other torch._utils name resolves to None
            return None

        def persistent_load(self, pid):
            # hand the raw persistent id to HackTensor untouched
            return pid

    unpickled = MyPickle(fb0).load()
    return unpickled, key_prelookup
2022-09-04 01:08:42 +08:00
def fake_torch_load_zipped(fb0, load_weights=True):
    """Load a zip-format torch checkpoint into numpy arrays.

    The archive holds `<root>/data.pkl` (the pickled structure) and one
    `<root>/data/<key>` member per storage. `<root>` is "archive" in the
    files this was first written for, but it is not fixed, so it is detected
    from the zip listing and "archive" is kept as the fallback.

    Args:
        fb0: path or binary file-like object accepted by `zipfile.ZipFile`.
        load_weights: when False, arrays are returned zero-filled and the
            storage members are not read.

    Returns:
        The unpickled structure, with tensors replaced by numpy arrays.

    Raises:
        KeyError: the pickle or a storage member is missing from the archive.
    """
    import zipfile

    with zipfile.ZipFile(fb0, 'r') as myzip:
        pkl_name = 'archive/data.pkl'
        names = myzip.namelist()
        if pkl_name not in names:
            # Accept any single top-level directory that contains data.pkl.
            pkl_name = next((n for n in names if n.split('/')[1:] == ['data.pkl']), pkl_name)
        root = pkl_name[:-len('data.pkl')]  # keeps the trailing "/"

        with myzip.open(pkl_name) as myfile:
            ret = my_unpickle(myfile)

        if load_weights:
            for k, v in ret[1].items():
                # v is (storage_type, obj_size, array, shape, strides)
                if v[2].dtype == "object":
                    # Unknown storage type: the array cannot be filled from
                    # raw bytes, so report it and skip reading the member.
                    print(f"issue assigning object on {k}")
                    continue
                with myzip.open(f'{root}data/{k}') as myfile:
                    np.copyto(v[2], np.frombuffer(myfile.read(), v[2].dtype).reshape(v[3]))
    return ret[0]
2021-01-03 04:51:45 +08:00
def fake_torch_load(b0):
    """Load a torch checkpoint held in memory without importing torch.

    Zip-format checkpoints (leading bytes "PK") are delegated to
    `fake_torch_load_zipped`. Otherwise the stream is read as: three leading
    pickles that are discarded, the pickled structure, a pickled list of
    storage keys, then for each key an 8-byte element count followed by the
    raw element data.

    NOTE: this unpickles `b0`; only use it on trusted checkpoint files.

    Args:
        b0: the complete checkpoint as a bytes-like object.

    Returns:
        The unpickled structure, with tensors replaced by numpy arrays.
    """
    import io
    import struct

    # convert it to a file
    fb0 = io.BytesIO(b0)
    if b0[0:2] == b"\x50\x4b":
        return fake_torch_load_zipped(fb0)

    # skip three junk pickles
    # NOTE(review): presumably the header of the legacy torch.save format;
    # their contents are never inspected here.
    for _ in range(3):
        pickle.load(fb0)

    ret, key_prelookup = my_unpickle(fb0)

    # create key_lookup: storages appear in the stream in this key order
    key_lookup = pickle.load(fb0)
    key_real = [None] * len(key_lookup)
    for k, v in key_prelookup.items():
        key_real[key_lookup.index(k)] = v

    # read in the actual data
    for storage_type, obj_size, np_array, np_shape, np_strides in key_real:
        ll = struct.unpack("Q", fb0.read(8))[0]
        assert ll == obj_size
        # Ask numpy for the element size instead of a fixed table, so every
        # dtype my_unpickle can produce (including np.float16) is handled.
        bytes_size = np.dtype(storage_type).itemsize
        mydat = fb0.read(ll * bytes_size)
        np.copyto(np_array, np.frombuffer(mydat, storage_type).reshape(np_shape))
        # numpy stores its strides in bytes, the checkpoint in elements.
        # Set in place: `ret` already holds references to this exact array.
        np_array.strides = tuple(x * bytes_size for x in np_strides)
    return ret
2021-12-01 05:14:54 +08:00
def get_child(parent, key):
    """Walk a dotted path such as "layers.0.weight" starting at `parent`.

    Each component is resolved against the current object in this order:
      1. a dict that contains the component as a key is indexed by it
         (so dicts keyed by strings like "0" work);
      2. a decimal component is used as an integer index / key;
      3. any other dict is indexed by the component (raises KeyError);
      4. otherwise the component is read as an attribute.

    Args:
        parent: root object (any mix of objects, dicts, lists, tuples).
        key: dot-separated path.

    Returns:
        The object found at the end of the path.

    Raises:
        KeyError, IndexError or AttributeError when a component is missing.
    """
    obj = parent
    for k in key.split('.'):
        if isinstance(obj, dict) and k in obj:
            obj = obj[k]
        elif k.isdecimal():
            # isdecimal matches what int() accepts; isnumeric also accepts
            # characters such as "²" that int() rejects.
            obj = obj[int(k)]
        elif isinstance(obj, dict):
            obj = obj[k]
        else:
            obj = getattr(obj, k)
    return obj