From cc17e3271a175fd12b71316a03776cd316eab5ad Mon Sep 17 00:00:00 2001 From: George Hotz Date: Tue, 27 Oct 2020 21:41:52 -0700 Subject: [PATCH] try to recognize cat. do not succeed --- examples/efficientnet.py | 19 +++++++++++++------ tinygrad/ops.py | 2 +- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/examples/efficientnet.py b/examples/efficientnet.py index de76c3c8..75b6b81d 100644 --- a/examples/efficientnet.py +++ b/examples/efficientnet.py @@ -96,11 +96,11 @@ class EfficientNet: def forward(self, x): x = x.pad2d(padding=(0,1,0,1)) - x = self._bn0(x.conv2d(self._conv_stem, stride=2)) + x = swish(self._bn0(x.conv2d(self._conv_stem, stride=2))) for b in self._blocks: print(x.shape) x = b(x) - x = self._bn1(x.conv2d(self._conv_head)) + x = swish(self._bn1(x.conv2d(self._conv_head))) x = x.avg_pool2d(kernel_size=x.shape[2:4]).reshape(shape=(-1, 1280)) #x = x.dropout(0.2) return swish(x.dot(self._fc).add(self._fc_bias)) @@ -130,8 +130,15 @@ if __name__ == "__main__": mv = eval(mk.replace(".bias", "_bias")) mv.data[:] = v.numpy() if k != '_fc.weight' else v.numpy().T - #b0 = pickle.loads(b0) - img = np.zeros((1, 3, 224, 224), np.float32) + 0.5 - out = model.forward(Tensor(img)) - print(out.data[:, 0:10]) + # load cat image + from PIL import Image + img = Image.open(io.BytesIO(fetch("https://c.files.bbci.co.uk/12A9B/production/_111434467_gettyimages-1143489763.jpg"))) + img = img.resize((224, 224)) + img = np.moveaxis(np.array(img), [2,0,1], [0,1,2]) + img = img.astype(np.float32).reshape(1,3,224,224) + print(img.shape) + + #b0 = pickle.loads(b0) + out = model.forward(Tensor(img)) + print(np.argmax(out.data), np.max(out.data)) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 4a8cf468..2cea9ad1 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -167,7 +167,7 @@ class Conv2D(Function): cout,cin,H,W = w.shape if groups > 1: - w = np.repeat(w, groups, axis=1) + w = np.repeat(w, groups, axis=1) / groups tw = w.reshape(cout, -1).T ys,xs = ctx.stride bs,oy,ox = x.shape[0], (x.shape[2]-(H-ys))//ys, (x.shape[3]-(W-xs))//xs