mirror of https://github.com/commaai/tinygrad.git
try to recognize cat. do not succeed
This commit is contained in:
parent
03d9c98f5b
commit
cc17e3271a
|
@ -96,11 +96,11 @@ class EfficientNet:
|
|||
|
||||
def forward(self, x):
|
||||
x = x.pad2d(padding=(0,1,0,1))
|
||||
x = self._bn0(x.conv2d(self._conv_stem, stride=2))
|
||||
x = swish(self._bn0(x.conv2d(self._conv_stem, stride=2)))
|
||||
for b in self._blocks:
|
||||
print(x.shape)
|
||||
x = b(x)
|
||||
x = self._bn1(x.conv2d(self._conv_head))
|
||||
x = swish(self._bn1(x.conv2d(self._conv_head)))
|
||||
x = x.avg_pool2d(kernel_size=x.shape[2:4]).reshape(shape=(-1, 1280))
|
||||
#x = x.dropout(0.2)
|
||||
return swish(x.dot(self._fc).add(self._fc_bias))
|
||||
|
@ -130,8 +130,15 @@ if __name__ == "__main__":
|
|||
mv = eval(mk.replace(".bias", "_bias"))
|
||||
mv.data[:] = v.numpy() if k != '_fc.weight' else v.numpy().T
|
||||
|
||||
#b0 = pickle.loads(b0)
|
||||
img = np.zeros((1, 3, 224, 224), np.float32) + 0.5
|
||||
out = model.forward(Tensor(img))
|
||||
print(out.data[:, 0:10])
|
||||
# load cat image
|
||||
from PIL import Image
|
||||
img = Image.open(io.BytesIO(fetch("https://c.files.bbci.co.uk/12A9B/production/_111434467_gettyimages-1143489763.jpg")))
|
||||
img = img.resize((224, 224))
|
||||
img = np.moveaxis(np.array(img), [2,0,1], [0,1,2])
|
||||
img = img.astype(np.float32).reshape(1,3,224,224)
|
||||
print(img.shape)
|
||||
|
||||
#b0 = pickle.loads(b0)
|
||||
out = model.forward(Tensor(img))
|
||||
print(np.argmax(out.data), np.max(out.data))
|
||||
|
||||
|
|
|
@ -167,7 +167,7 @@ class Conv2D(Function):
|
|||
|
||||
cout,cin,H,W = w.shape
|
||||
if groups > 1:
|
||||
w = np.repeat(w, groups, axis=1)
|
||||
w = np.repeat(w, groups, axis=1) / groups
|
||||
tw = w.reshape(cout, -1).T
|
||||
ys,xs = ctx.stride
|
||||
bs,oy,ox = x.shape[0], (x.shape[2]-(H-ys))//ys, (x.shape[3]-(W-xs))//xs
|
||||
|
|
Loading…
Reference in New Issue