enet weight loading

This commit is contained in:
George Hotz 2020-10-27 21:01:48 -07:00
parent e84ad3e27d
commit 0ec279951f
2 changed files with 47 additions and 16 deletions

View File

@ -4,6 +4,7 @@
# a rough copy of
# https://github.com/lukemelas/EfficientNet-PyTorch/blob/master/efficientnet_pytorch/model.py
from tinygrad.tensor import Tensor
from tinygrad.utils import fetch
def swish(x):
return x.mul(x.sigmoid())
@ -13,6 +14,9 @@ class BatchNorm2D:
self.weight = Tensor.zeros(sz)
self.bias = Tensor.zeros(sz)
# TODO: need running_mean and running_var
self.running_mean = Tensor.zeros(sz)
self.running_var = Tensor.zeros(sz)
self.num_batches_tracked = Tensor.zeros(0)
def __call__(self, x):
# this work at inference?
@ -84,19 +88,43 @@ class EfficientNet:
self._conv_head = Tensor.zeros(1280, 320, 1, 1)
self._bn1 = BatchNorm2D(1280)
self._fc = Tensor.zeros(1280, 1000)
self._fc_bias = Tensor.zeros(1000)
def forward(self, x):
x = x.pad2d(padding=(0,1,0,1))
x = self._bn0(x.conv2d(self._conv_stem, stride=2))
for b in self._blocks:
print(x.shape)
x = b(x)
x = self._bn1(x.conv2d(self._conv_head))
x = x.avg_pool2d(kernel_size=x.shape[2:4]).reshape(shape=(-1, 1280))
#x = x.dropout(0.2)
return swish(x.dot(self._fc))
return swish(x.dot(self._fc).add(self._fc_bias))
if __name__ == "__main__":
# instantiate my net
model = EfficientNet()
# load b0
import io, torch
b0 = fetch("https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth")
b0 = torch.load(io.BytesIO(b0))
for k,v in b0.items():
if '_blocks.' in k:
k = "%s[%s].%s" % tuple(k.split(".", 2))
mk = "model."+k
print(k, v.shape)
try:
mv = eval(mk)
except AttributeError:
try:
mv = eval(mk.replace(".weight", ""))
except AttributeError:
mv = eval(mk.replace(".bias", "_bias"))
mv.data[:] = v.numpy() if k != '_fc.weight' else v.numpy().T
#b0 = pickle.loads(b0)
out = model.forward(Tensor.zeros(1, 3, 224, 224))
print(out)

View File

@ -9,21 +9,24 @@ def layer_init_uniform(*x):
ret = np.random.uniform(-1., 1., size=x)/np.sqrt(np.prod(x))
return ret.astype(np.float32)
def fetch(url):
import requests, os, hashlib
fp = os.path.join("/tmp", hashlib.md5(url.encode('utf-8')).hexdigest())
if os.path.isfile(fp):
with open(fp, "rb") as f:
dat = f.read()
else:
with open(fp, "wb") as f:
dat = requests.get(url).content
f.write(dat)
return dat
def fetch_mnist():
def fetch(url):
import requests, gzip, os, hashlib, numpy
fp = os.path.join("/tmp", hashlib.md5(url.encode('utf-8')).hexdigest())
if os.path.isfile(fp):
with open(fp, "rb") as f:
dat = f.read()
else:
with open(fp, "wb") as f:
dat = requests.get(url).content
f.write(dat)
return numpy.frombuffer(gzip.decompress(dat), dtype=numpy.uint8).copy()
X_train = fetch("http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz")[0x10:].reshape((-1, 28, 28))
Y_train = fetch("http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz")[8:]
X_test = fetch("http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz")[0x10:].reshape((-1, 28, 28))
Y_test = fetch("http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz")[8:]
import gzip
parse = lambda dat: np.frombuffer(gzip.decompress(dat), dtype=np.uint8).copy()
X_train = parse(fetch("http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz"))[0x10:].reshape((-1, 28, 28))
Y_train = parse(fetch("http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz"))[8:]
X_test = parse(fetch("http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz"))[0x10:].reshape((-1, 28, 28))
Y_test = parse(fetch("http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz"))[8:]
return X_train, Y_train, X_test, Y_test