try to recognize cat. do not succeed

This commit is contained in:
George Hotz 2020-10-27 21:41:52 -07:00
parent 03d9c98f5b
commit cc17e3271a
2 changed files with 14 additions and 7 deletions

View File

@ -96,11 +96,11 @@ class EfficientNet:
def forward(self, x):
x = x.pad2d(padding=(0,1,0,1))
x = self._bn0(x.conv2d(self._conv_stem, stride=2))
x = swish(self._bn0(x.conv2d(self._conv_stem, stride=2)))
for b in self._blocks:
print(x.shape)
x = b(x)
x = self._bn1(x.conv2d(self._conv_head))
x = swish(self._bn1(x.conv2d(self._conv_head)))
x = x.avg_pool2d(kernel_size=x.shape[2:4]).reshape(shape=(-1, 1280))
#x = x.dropout(0.2)
return swish(x.dot(self._fc).add(self._fc_bias))
@ -130,8 +130,15 @@ if __name__ == "__main__":
mv = eval(mk.replace(".bias", "_bias"))
mv.data[:] = v.numpy() if k != '_fc.weight' else v.numpy().T
#b0 = pickle.loads(b0)
img = np.zeros((1, 3, 224, 224), np.float32) + 0.5
out = model.forward(Tensor(img))
print(out.data[:, 0:10])
# load cat image
from PIL import Image
img = Image.open(io.BytesIO(fetch("https://c.files.bbci.co.uk/12A9B/production/_111434467_gettyimages-1143489763.jpg")))
img = img.resize((224, 224))
img = np.moveaxis(np.array(img), [2,0,1], [0,1,2])
img = img.astype(np.float32).reshape(1,3,224,224)
print(img.shape)
#b0 = pickle.loads(b0)
out = model.forward(Tensor(img))
print(np.argmax(out.data), np.max(out.data))

View File

@ -167,7 +167,7 @@ class Conv2D(Function):
cout,cin,H,W = w.shape
if groups > 1:
w = np.repeat(w, groups, axis=1)
w = np.repeat(w, groups, axis=1) / groups
tw = w.reshape(cout, -1).T
ys,xs = ctx.stride
bs,oy,ox = x.shape[0], (x.shape[2]-(H-ys))//ys, (x.shape[3]-(W-xs))//xs