refactor efficientnet loading

This commit is contained in:
George Hotz 2021-10-30 17:02:17 -07:00
parent 7472a7ebe2
commit 7d12482d80
3 changed files with 20 additions and 21 deletions

6
accel/README Normal file
View File

@ -0,0 +1,6 @@
This is where we scope out adding accelerators to tinygrad:
ane -- Apple Neural Engine, found in the M1 and newer iPhones
cherry -- Largely defunct custom hardware based on a RISC-V extension
tpu -- Google's TPU, available for rent in Google Cloud

View File

@ -4,7 +4,16 @@ from tinygrad.tensor import Tensor
from tinygrad.nn import BatchNorm2D from tinygrad.nn import BatchNorm2D
from extra.utils import fetch, fake_torch_load from extra.utils import fetch, fake_torch_load
USE_TORCH = False model_urls = {
0: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth",
1: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b1-f1951068.pth",
2: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth",
3: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b3-5fb5a3c3.pth",
4: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth",
5: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth",
6: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b6-c76e70fd.pth",
7: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth"
}
class MBConvBlock: class MBConvBlock:
def __init__(self, kernel_size, strides, expand_ratio, input_filters, output_filters, se_ratio, has_se): def __init__(self, kernel_size, strides, expand_ratio, input_filters, output_filters, se_ratio, has_se):
@ -126,26 +135,9 @@ class EfficientNet:
#x = x.dropout(0.2) #x = x.dropout(0.2)
return x.dot(self._fc).add(self._fc_bias.reshape(shape=[1,-1])) return x.dot(self._fc).add(self._fc_bias.reshape(shape=[1,-1]))
def load_weights_from_torch(self):
# load b0
# https://github.com/lukemelas/EfficientNet-PyTorch/blob/master/efficientnet_pytorch/utils.py#L551
if self.number == 0:
b0 = fetch("https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth")
elif self.number == 2:
b0 = fetch("https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth")
elif self.number == 4:
b0 = fetch("https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth")
elif self.number == 7:
b0 = fetch("https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth")
else:
raise Exception("no pretrained weights")
if USE_TORCH: def load_weights_from_torch(self):
import io b0 = fake_torch_load(fetch(model_urls[self.number]))
import torch
b0 = torch.load(io.BytesIO(b0))
else:
b0 = fake_torch_load(b0)
for k,v in b0.items(): for k,v in b0.items():
for cat in ['_conv_head', '_conv_stem', '_depthwise_conv', '_expand_conv', '_fc', '_project_conv', '_se_reduce', '_se_expand']: for cat in ['_conv_head', '_conv_stem', '_depthwise_conv', '_expand_conv', '_fc', '_project_conv', '_se_reduce', '_se_expand']:
@ -155,7 +147,7 @@ class EfficientNet:
#print(k, v.shape) #print(k, v.shape)
mv = _get_child(self, k) mv = _get_child(self, k)
vnp = v.numpy().astype(np.float32) if USE_TORCH else v.astype(np.float32) vnp = v.astype(np.float32)
vnp = vnp if k != '_fc' else vnp.T vnp = vnp if k != '_fc' else vnp.T
vnp = vnp if vnp.shape != () else np.array([vnp]) vnp = vnp if vnp.shape != () else np.array([vnp])

View File

@ -14,6 +14,7 @@ model_urls = {
def load_from_pretrained(model, url): def load_from_pretrained(model, url):
state_dict = load_state_dict_from_url(url, progress=True) state_dict = load_state_dict_from_url(url, progress=True)
#state_dict = fake_torch_load(fetch(url))
layers_not_loaded = [] layers_not_loaded = []
for k, v in state_dict.items(): for k, v in state_dict.items():
par_name = ['model'] par_name = ['model']