From 7d12482d80bb5a367a6731747c16cf69c656323b Mon Sep 17 00:00:00 2001 From: George Hotz Date: Sat, 30 Oct 2021 17:02:17 -0700 Subject: [PATCH] refactor efficientnet loading --- accel/README | 6 ++++++ models/efficientnet.py | 34 +++++++++++++--------------------- models/resnet.py | 1 + 3 files changed, 20 insertions(+), 21 deletions(-) create mode 100644 accel/README diff --git a/accel/README b/accel/README new file mode 100644 index 00000000..a985ea09 --- /dev/null +++ b/accel/README @@ -0,0 +1,6 @@ +This is where we scope out adding accelerators to tinygrad + +ane -- Apple Neural Engine, in the M1 + newer iPhones +cherry -- Largely defunct custom hardware based on a RISC-V extension +tpu -- Google's TPU, available for rent in Google Cloud + diff --git a/models/efficientnet.py b/models/efficientnet.py index 3a4d71ff..e7f42215 100644 --- a/models/efficientnet.py +++ b/models/efficientnet.py @@ -4,7 +4,16 @@ from tinygrad.tensor import Tensor from tinygrad.nn import BatchNorm2D from extra.utils import fetch, fake_torch_load -USE_TORCH = False +model_urls = { + 0: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth", + 1: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b1-f1951068.pth", + 2: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth", + 3: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b3-5fb5a3c3.pth", + 4: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth", + 5: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth", + 6: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b6-c76e70fd.pth", + 7: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth" +} class MBConvBlock: def __init__(self, kernel_size, strides, expand_ratio, input_filters, output_filters, se_ratio, has_se): @@ -126,26 +135,9 @@ class EfficientNet: #x = x.dropout(0.2) return x.dot(self._fc).add(self._fc_bias.reshape(shape=[1,-1])) - def load_weights_from_torch(self): - # load b0 - # https://github.com/lukemelas/EfficientNet-PyTorch/blob/master/efficientnet_pytorch/utils.py#L551 - if self.number == 0: - b0 = fetch("https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth") - elif self.number == 2: - b0 = fetch("https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth") - elif self.number == 4: - b0 = fetch("https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth") - elif self.number == 7: - b0 = fetch("https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth") - else: - raise Exception("no pretrained weights") - if USE_TORCH: - import io - import torch - b0 = torch.load(io.BytesIO(b0)) - else: - b0 = fake_torch_load(b0) + def load_weights_from_torch(self): + b0 = fake_torch_load(fetch(model_urls[self.number])) for k,v in b0.items(): for cat in ['_conv_head', '_conv_stem', '_depthwise_conv', '_expand_conv', '_fc', '_project_conv', '_se_reduce', '_se_expand']: @@ -155,7 +147,7 @@ class EfficientNet: #print(k, v.shape) mv = _get_child(self, k) - vnp = v.numpy().astype(np.float32) if USE_TORCH else v.astype(np.float32) + vnp = v.astype(np.float32) vnp = vnp if k != '_fc' else vnp.T vnp = vnp if vnp.shape != () else np.array([vnp]) diff --git a/models/resnet.py b/models/resnet.py index 6a6a157f..d71deb38 100644 --- a/models/resnet.py +++ b/models/resnet.py @@ -14,6 +14,7 @@ model_urls = { def load_from_pretrained(model, url): state_dict = load_state_dict_from_url(url, progress=True) + #state_dict = fake_torch_load(fetch(url)) layers_not_loaded = [] for k, v in state_dict.items(): par_name = ['model']