mirror of https://github.com/commaai/tinygrad.git
start trying to load yolo v5
This commit is contained in:
parent
ece07a3d12
commit
895d142503
|
@ -0,0 +1,31 @@
|
|||
#!/usr/bin/env python3
|
||||
import io
|
||||
import pickle
|
||||
from extra.utils import fetch, my_unpickle
|
||||
|
||||
if __name__ == "__main__":
|
||||
dat = fetch('https://github.com/ultralytics/yolov5/releases/download/v3.0/yolov5s.pt')
|
||||
#import torch
|
||||
#td = torch.load(io.BytesIO(dat))
|
||||
#print(td)
|
||||
|
||||
import zipfile
|
||||
fp = zipfile.ZipFile(io.BytesIO(dat))
|
||||
#fp.printdir()
|
||||
data = fp.read('archive/data.pkl')
|
||||
|
||||
#import pickletools
|
||||
#pickletools.dis(io.BytesIO(data))
|
||||
|
||||
ret, out = my_unpickle(io.BytesIO(data))
|
||||
print(dir(ret['model']))
|
||||
for m in ret['model']._modules['model']:
|
||||
print(m)
|
||||
print(m._modules.keys())
|
||||
|
||||
"""
|
||||
weights = fake_torch_load(data)
|
||||
for k,v in weights:
|
||||
print(k)
|
||||
"""
|
||||
|
|
@ -1,4 +1,6 @@
|
|||
from tinygrad.tensor import Tensor
|
||||
import pickle
|
||||
import numpy as np
|
||||
|
||||
def fetch(url):
|
||||
import requests, os, hashlib, tempfile
|
||||
|
@ -26,11 +28,54 @@ def get_parameters(obj):
|
|||
parameters.extend(get_parameters(v))
|
||||
return parameters
|
||||
|
||||
def my_unpickle(fb0):
|
||||
key_prelookup = {}
|
||||
class HackTensor:
|
||||
def __new__(cls, *args):
|
||||
#print(args)
|
||||
ident, storage_type, obj_key, location, obj_size = args[0][0:5]
|
||||
assert ident == 'storage'
|
||||
|
||||
ret = np.zeros(obj_size, dtype=storage_type)
|
||||
key_prelookup[obj_key] = (storage_type, obj_size, ret, args[2], args[3])
|
||||
return ret
|
||||
|
||||
class HackParameter:
|
||||
def __new__(cls, *args):
|
||||
#print(args)
|
||||
pass
|
||||
|
||||
class Dummy:
|
||||
pass
|
||||
|
||||
class MyPickle(pickle.Unpickler):
|
||||
def find_class(self, module, name):
|
||||
print(module, name)
|
||||
if name == 'FloatStorage':
|
||||
return np.float32
|
||||
if name == 'LongStorage':
|
||||
return np.int64
|
||||
if name == 'HalfStorage':
|
||||
return np.float16
|
||||
if module == "torch._utils":
|
||||
if name == "_rebuild_tensor_v2":
|
||||
return HackTensor
|
||||
elif name == "_rebuild_parameter":
|
||||
return HackParameter
|
||||
else:
|
||||
try:
|
||||
return pickle.Unpickler.find_class(self, module, name)
|
||||
except Exception:
|
||||
return Dummy
|
||||
|
||||
def persistent_load(self, pid):
|
||||
return pid
|
||||
|
||||
return MyPickle(fb0).load(), key_prelookup
|
||||
|
||||
def fake_torch_load(b0):
|
||||
import io
|
||||
import pickle
|
||||
import struct
|
||||
import numpy as np
|
||||
|
||||
# convert it to a file
|
||||
fb0 = io.BytesIO(b0)
|
||||
|
@ -40,34 +85,7 @@ def fake_torch_load(b0):
|
|||
pickle.load(fb0)
|
||||
pickle.load(fb0)
|
||||
|
||||
key_prelookup = {}
|
||||
|
||||
class HackTensor:
|
||||
def __new__(cls, *args):
|
||||
#print(args)
|
||||
ident, storage_type, obj_key, location, obj_size, view_metadata = args[0]
|
||||
assert ident == 'storage'
|
||||
|
||||
ret = np.zeros(obj_size, dtype=storage_type)
|
||||
key_prelookup[obj_key] = (storage_type, obj_size, ret, args[2], args[3])
|
||||
return ret
|
||||
|
||||
class MyPickle(pickle.Unpickler):
|
||||
def find_class(self, module, name):
|
||||
#print(module, name)
|
||||
if name == 'FloatStorage':
|
||||
return np.float32
|
||||
if name == 'LongStorage':
|
||||
return np.int64
|
||||
if module == "torch._utils" or module == "torch":
|
||||
return HackTensor
|
||||
else:
|
||||
return pickle.Unpickler.find_class(self, module, name)
|
||||
|
||||
def persistent_load(self, pid):
|
||||
return pid
|
||||
|
||||
ret = MyPickle(fb0).load()
|
||||
ret, key_prelookup = my_unpickle(fb0)
|
||||
|
||||
# create key_lookup
|
||||
key_lookup = pickle.load(fb0)
|
||||
|
|
Loading…
Reference in New Issue