mirror of https://github.com/commaai/tinygrad.git
enet weight loading
This commit is contained in:
parent
e84ad3e27d
commit
0ec279951f
|
@ -4,6 +4,7 @@
|
|||
# a rough copy of
|
||||
# https://github.com/lukemelas/EfficientNet-PyTorch/blob/master/efficientnet_pytorch/model.py
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.utils import fetch
|
||||
|
||||
def swish(x):
    """Apply the swish activation, ``x * sigmoid(x)``.

    ``x`` only needs to provide ``sigmoid()`` and ``mul()`` methods; the
    result is whatever ``x.mul`` returns.
    """
    gate = x.sigmoid()
    return x.mul(gate)
|
||||
|
@ -13,6 +14,9 @@ class BatchNorm2D:
|
|||
self.weight = Tensor.zeros(sz)
|
||||
self.bias = Tensor.zeros(sz)
|
||||
# TODO: need running_mean and running_var
|
||||
self.running_mean = Tensor.zeros(sz)
|
||||
self.running_var = Tensor.zeros(sz)
|
||||
self.num_batches_tracked = Tensor.zeros(0)
|
||||
|
||||
def __call__(self, x):
|
||||
# does this work at inference?
|
||||
|
@ -84,19 +88,43 @@ class EfficientNet:
|
|||
self._conv_head = Tensor.zeros(1280, 320, 1, 1)
|
||||
self._bn1 = BatchNorm2D(1280)
|
||||
self._fc = Tensor.zeros(1280, 1000)
|
||||
self._fc_bias = Tensor.zeros(1000)
|
||||
|
||||
def forward(self, x):
    """Run the network on ``x`` and return the classifier output.

    The input is padded, passed through the stem convolution and ``_bn0``,
    then through every entry of ``self._blocks`` in order, then through the
    head convolution and ``_bn1``. The result is average-pooled over its
    full spatial extent, reshaped to ``(-1, 1280)``, multiplied by
    ``self._fc``, offset by ``self._fc_bias`` and passed through ``swish``.
    """
    # pad before the stride-2 stem convolution
    x = x.pad2d(padding=(0,1,0,1))
    x = self._bn0(x.conv2d(self._conv_stem, stride=2))
    for b in self._blocks:
        # debug trace: shape of the activation entering each block
        print(x.shape)
        x = b(x)
    x = self._bn1(x.conv2d(self._conv_head))
    # pooling kernel equals the spatial size, so each channel collapses to one value
    x = x.avg_pool2d(kernel_size=x.shape[2:4]).reshape(shape=(-1, 1280))
    # TODO: dropout is currently disabled
    #x = x.dropout(0.2)
    # NOTE(review): the reference EfficientNet-PyTorch model appears to apply
    # swish after the stem/head batch norms and to return the raw fc output;
    # confirm whether the activation placement here is intended.
    # The bias must be included: a bias-less early return previously made
    # this statement unreachable.
    return swish(x.dot(self._fc).add(self._fc_bias))
|
||||
|
||||
# Script entry point: build an EfficientNet, copy the pretrained b0 weights
# into it, and run one forward pass on an all-zeros input.
if __name__ == "__main__":
    # instantiate my net
    model = EfficientNet()

    # load b0
    import io, torch
    b0 = fetch("https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth")
    # NOTE(review): torch.load unpickles the downloaded bytes, which is only
    # safe when the download source is trusted.
    b0 = torch.load(io.BytesIO(b0))

    for k,v in b0.items():
        # rewrite "_blocks.<i>.<rest>" into "_blocks[<i>].<rest>" so the key
        # can be evaluated as an attribute path on the model
        if '_blocks.' in k:
            k = "%s[%s].%s" % tuple(k.split(".", 2))
        mk = "model."+k
        print(k, v.shape)
        # NOTE(review): eval() runs on key names taken from the downloaded
        # checkpoint; a getattr-based lookup would avoid executing them.
        # Resolution order: the path as-is, then without ".weight", then
        # with ".bias" rewritten to "_bias".
        try:
            mv = eval(mk)
        except AttributeError:
            try:
                mv = eval(mk.replace(".weight", ""))
            except AttributeError:
                mv = eval(mk.replace(".bias", "_bias"))
        # '_fc.weight' is transposed before copying; presumably the checkpoint
        # stores it in the opposite orientation to the model's _fc (TODO confirm)
        mv.data[:] = v.numpy() if k != '_fc.weight' else v.numpy().T

    #b0 = pickle.loads(b0)
    # smoke test: one forward pass on a zero-filled (1, 3, 224, 224) input
    out = model.forward(Tensor.zeros(1, 3, 224, 224))
    print(out)
|
||||
|
||||
|
|
|
@ -9,21 +9,24 @@ def layer_init_uniform(*x):
|
|||
ret = np.random.uniform(-1., 1., size=x)/np.sqrt(np.prod(x))
|
||||
return ret.astype(np.float32)
|
||||
|
||||
def fetch(url):
    """Return the bytes served at ``url``, caching them on disk under ``/tmp``.

    The cache file is named after the MD5 hex digest of the URL, so repeated
    calls with the same URL read the local copy instead of downloading again.

    Raises:
        requests.HTTPError: if the server answers with an error status, so
            that error pages are never stored in the cache.
    """
    import os, hashlib
    fp = os.path.join("/tmp", hashlib.md5(url.encode('utf-8')).hexdigest())
    if os.path.isfile(fp):
        with open(fp, "rb") as f:
            return f.read()

    # imported only on a cache miss so cached reads work without requests
    import requests
    resp = requests.get(url)
    resp.raise_for_status()
    dat = resp.content

    # Download first, then write to a temporary file and rename it into
    # place: a failed or interrupted download must not leave an empty or
    # truncated cache entry that later calls would treat as valid.
    tmp = "%s.%d.tmp" % (fp, os.getpid())
    try:
        with open(tmp, "wb") as f:
            f.write(dat)
        os.replace(tmp, fp)
    finally:
        # only present if the write or rename failed
        if os.path.exists(tmp):
            os.remove(tmp)
    return dat
|
||||
|
||||
def fetch_mnist():
    """Fetch the four MNIST files and return them as uint8 numpy arrays.

    Downloading and caching are delegated to the module-level ``fetch``;
    this function only decompresses and slices the payloads.

    Returns:
        ``(X_train, Y_train, X_test, Y_test)``. The image arrays are reshaped
        to ``(-1, 28, 28)``; the label arrays are one-dimensional.
    """
    import gzip

    def parse(dat):
        # gunzip the raw bytes; copy() because frombuffer over bytes is read-only
        return np.frombuffer(gzip.decompress(dat), dtype=np.uint8).copy()

    # the leading bytes skipped below (0x10 for images, 8 for labels) are
    # presumably the file headers - TODO confirm
    X_train = parse(fetch("http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz"))[0x10:].reshape((-1, 28, 28))
    Y_train = parse(fetch("http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz"))[8:]
    X_test = parse(fetch("http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz"))[0x10:].reshape((-1, 28, 28))
    Y_test = parse(fetch("http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz"))[8:]
    return X_train, Y_train, X_test, Y_test
|
||||
|
||||
|
|
Loading…
Reference in New Issue