mirror of https://github.com/commaai/tinygrad.git
refactor efficientnet loading
This commit is contained in:
parent
7472a7ebe2
commit
7d12482d80
|
@ -0,0 +1,6 @@
|
|||
This is where we scope out adding accelerators to tinygrad
|
||||
|
||||
ane -- Apple Neural Engine, in the M1 + newer iPhones
|
||||
cherry -- Largely defunct custom hardware based on a RISC-V extension
|
||||
tpu -- Google's TPU, available for rent in Google Cloud
|
||||
|
|
@ -4,7 +4,16 @@ from tinygrad.tensor import Tensor
|
|||
from tinygrad.nn import BatchNorm2D
|
||||
from extra.utils import fetch, fake_torch_load
|
||||
|
||||
USE_TORCH = False
|
||||
model_urls = {
|
||||
0: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth",
|
||||
1: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b1-f1951068.pth",
|
||||
2: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth",
|
||||
3: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b3-5fb5a3c3.pth",
|
||||
4: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth",
|
||||
5: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth",
|
||||
6: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b6-c76e70fd.pth",
|
||||
7: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth"
|
||||
}
|
||||
|
||||
class MBConvBlock:
|
||||
def __init__(self, kernel_size, strides, expand_ratio, input_filters, output_filters, se_ratio, has_se):
|
||||
|
@ -126,26 +135,9 @@ class EfficientNet:
|
|||
#x = x.dropout(0.2)
|
||||
return x.dot(self._fc).add(self._fc_bias.reshape(shape=[1,-1]))
|
||||
|
||||
def load_weights_from_torch(self):
|
||||
# load b0
|
||||
# https://github.com/lukemelas/EfficientNet-PyTorch/blob/master/efficientnet_pytorch/utils.py#L551
|
||||
if self.number == 0:
|
||||
b0 = fetch("https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth")
|
||||
elif self.number == 2:
|
||||
b0 = fetch("https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth")
|
||||
elif self.number == 4:
|
||||
b0 = fetch("https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth")
|
||||
elif self.number == 7:
|
||||
b0 = fetch("https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth")
|
||||
else:
|
||||
raise Exception("no pretrained weights")
|
||||
|
||||
if USE_TORCH:
|
||||
import io
|
||||
import torch
|
||||
b0 = torch.load(io.BytesIO(b0))
|
||||
else:
|
||||
b0 = fake_torch_load(b0)
|
||||
def load_weights_from_torch(self):
|
||||
b0 = fake_torch_load(fetch(model_urls[self.number]))
|
||||
|
||||
for k,v in b0.items():
|
||||
for cat in ['_conv_head', '_conv_stem', '_depthwise_conv', '_expand_conv', '_fc', '_project_conv', '_se_reduce', '_se_expand']:
|
||||
|
@ -155,7 +147,7 @@ class EfficientNet:
|
|||
|
||||
#print(k, v.shape)
|
||||
mv = _get_child(self, k)
|
||||
vnp = v.numpy().astype(np.float32) if USE_TORCH else v.astype(np.float32)
|
||||
vnp = v.astype(np.float32)
|
||||
vnp = vnp if k != '_fc' else vnp.T
|
||||
vnp = vnp if vnp.shape != () else np.array([vnp])
|
||||
|
||||
|
|
|
@ -14,6 +14,7 @@ model_urls = {
|
|||
|
||||
def load_from_pretrained(model, url):
|
||||
state_dict = load_state_dict_from_url(url, progress=True)
|
||||
#state_dict = fake_torch_load(fetch(url))
|
||||
layers_not_loaded = []
|
||||
for k, v in state_dict.items():
|
||||
par_name = ['model']
|
||||
|
|
Loading…
Reference in New Issue