mirror of https://github.com/commaai/tinygrad.git
refactor efficientnet loading
This commit is contained in:
parent
7472a7ebe2
commit
7d12482d80
|
@ -0,0 +1,6 @@
|
||||||
|
This is where we scope out adding accelerators to tinygrad
|
||||||
|
|
||||||
|
ane -- Apple Neural Engine, in the M1 + newer iPhones
|
||||||
|
cherry -- Largely defunct custom hardware based on a RISC-V extension
|
||||||
|
tpu -- Google's TPU, available for rent in Google Cloud
|
||||||
|
|
|
@ -4,7 +4,16 @@ from tinygrad.tensor import Tensor
|
||||||
from tinygrad.nn import BatchNorm2D
|
from tinygrad.nn import BatchNorm2D
|
||||||
from extra.utils import fetch, fake_torch_load
|
from extra.utils import fetch, fake_torch_load
|
||||||
|
|
||||||
USE_TORCH = False
|
model_urls = {
|
||||||
|
0: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth",
|
||||||
|
1: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b1-f1951068.pth",
|
||||||
|
2: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth",
|
||||||
|
3: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b3-5fb5a3c3.pth",
|
||||||
|
4: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth",
|
||||||
|
5: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth",
|
||||||
|
6: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b6-c76e70fd.pth",
|
||||||
|
7: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth"
|
||||||
|
}
|
||||||
|
|
||||||
class MBConvBlock:
|
class MBConvBlock:
|
||||||
def __init__(self, kernel_size, strides, expand_ratio, input_filters, output_filters, se_ratio, has_se):
|
def __init__(self, kernel_size, strides, expand_ratio, input_filters, output_filters, se_ratio, has_se):
|
||||||
|
@ -126,26 +135,9 @@ class EfficientNet:
|
||||||
#x = x.dropout(0.2)
|
#x = x.dropout(0.2)
|
||||||
return x.dot(self._fc).add(self._fc_bias.reshape(shape=[1,-1]))
|
return x.dot(self._fc).add(self._fc_bias.reshape(shape=[1,-1]))
|
||||||
|
|
||||||
def load_weights_from_torch(self):
|
|
||||||
# load b0
|
|
||||||
# https://github.com/lukemelas/EfficientNet-PyTorch/blob/master/efficientnet_pytorch/utils.py#L551
|
|
||||||
if self.number == 0:
|
|
||||||
b0 = fetch("https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth")
|
|
||||||
elif self.number == 2:
|
|
||||||
b0 = fetch("https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth")
|
|
||||||
elif self.number == 4:
|
|
||||||
b0 = fetch("https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth")
|
|
||||||
elif self.number == 7:
|
|
||||||
b0 = fetch("https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth")
|
|
||||||
else:
|
|
||||||
raise Exception("no pretrained weights")
|
|
||||||
|
|
||||||
if USE_TORCH:
|
def load_weights_from_torch(self):
|
||||||
import io
|
b0 = fake_torch_load(fetch(model_urls[self.number]))
|
||||||
import torch
|
|
||||||
b0 = torch.load(io.BytesIO(b0))
|
|
||||||
else:
|
|
||||||
b0 = fake_torch_load(b0)
|
|
||||||
|
|
||||||
for k,v in b0.items():
|
for k,v in b0.items():
|
||||||
for cat in ['_conv_head', '_conv_stem', '_depthwise_conv', '_expand_conv', '_fc', '_project_conv', '_se_reduce', '_se_expand']:
|
for cat in ['_conv_head', '_conv_stem', '_depthwise_conv', '_expand_conv', '_fc', '_project_conv', '_se_reduce', '_se_expand']:
|
||||||
|
@ -155,7 +147,7 @@ class EfficientNet:
|
||||||
|
|
||||||
#print(k, v.shape)
|
#print(k, v.shape)
|
||||||
mv = _get_child(self, k)
|
mv = _get_child(self, k)
|
||||||
vnp = v.numpy().astype(np.float32) if USE_TORCH else v.astype(np.float32)
|
vnp = v.astype(np.float32)
|
||||||
vnp = vnp if k != '_fc' else vnp.T
|
vnp = vnp if k != '_fc' else vnp.T
|
||||||
vnp = vnp if vnp.shape != () else np.array([vnp])
|
vnp = vnp if vnp.shape != () else np.array([vnp])
|
||||||
|
|
||||||
|
|
|
@ -14,6 +14,7 @@ model_urls = {
|
||||||
|
|
||||||
def load_from_pretrained(model, url):
|
def load_from_pretrained(model, url):
|
||||||
state_dict = load_state_dict_from_url(url, progress=True)
|
state_dict = load_state_dict_from_url(url, progress=True)
|
||||||
|
#state_dict = fake_torch_load(fetch(url))
|
||||||
layers_not_loaded = []
|
layers_not_loaded = []
|
||||||
for k, v in state_dict.items():
|
for k, v in state_dict.items():
|
||||||
par_name = ['model']
|
par_name = ['model']
|
||||||
|
|
Loading…
Reference in New Issue