hmm, no reason it should have been wrong, numpy weirdness

George Hotz 2020-10-25 13:14:12 -07:00
parent 96f9cdb8a0
commit 1255903519
1 changed file with 7 additions and 4 deletions


@@ -150,18 +150,21 @@ class FastConv2D(Function):
 bs,_,oy,ox = grad_output.shape
 tx, w = ctx.saved_tensors
 cout,cin,H,W = w.shape
+# grad_output.shape = (bs, cout, oy, ox)
+# tx.shape = (bs*oy*ox*cin, H*W)
 tw = w.reshape(w.shape[0], -1)
-# order correctly
-gg = np.moveaxis(grad_output, [0,1,2,3], [0,2,3,1]).reshape(-1, cout)
+# reshape correctly
+ggt = np.moveaxis(grad_output, [0,1,2,3], [1,0,2,3]).reshape(cout, -1)
 # dw is easy
-dw = gg.T.dot(tx).reshape(w.shape)
+dw = ggt.dot(tx).reshape(w.shape)
 # dx is harder
-dxi = gg.dot(tw)
+dxi = ggt.T.dot(tw)
 # if we im2col on the forward, we col2im on the backward
 # dxi should be (bs, oy, ox, cin, H, W)
 dx = col2im(dxi, H, W, oy+(H-1), ox+(W-1))
 return dx, dw
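
For context, below is a minimal self-contained NumPy sketch (not the commit's code) of the im2col/col2im backward that the added lines implement, assuming stride 1 and no padding. The names im2col, col2im, conv_forward and conv_backward are hypothetical stand-ins written for this example; the commit's comment writes tx's shape as (bs*oy*ox*cin, H*W), while this sketch assumes the (bs*oy*ox, cin*H*W) row layout that the ggt.dot(tx) matmul implies. The check at the end accumulates dw and dx directly over output positions.

import numpy as np

def im2col(x, H, W):
  # (bs, cin, iy, ix) -> (bs*oy*ox, cin*H*W), one row per output position
  bs, cin, iy, ix = x.shape
  oy, ox = iy-H+1, ix-W+1
  cols = np.empty((bs, oy, ox, cin, H, W), dtype=x.dtype)
  for i in range(oy):
    for j in range(ox):
      cols[:, i, j] = x[:, :, i:i+H, j:j+W]
  return cols.reshape(bs*oy*ox, cin*H*W)

def col2im(cols, H, W, iy, ix):
  # scatter-add the patches back: (bs*oy*ox, cin*H*W) -> (bs, cin, iy, ix)
  oy, ox = iy-H+1, ix-W+1
  bs = cols.shape[0] // (oy*ox)
  cin = cols.shape[1] // (H*W)
  cols = cols.reshape(bs, oy, ox, cin, H, W)
  x = np.zeros((bs, cin, iy, ix), dtype=cols.dtype)
  for i in range(oy):
    for j in range(ox):
      x[:, :, i:i+H, j:j+W] += cols[:, i, j]
  return x

def conv_forward(x, w):
  bs, cin, iy, ix = x.shape
  cout, _, H, W = w.shape
  oy, ox = iy-H+1, ix-W+1
  tx = im2col(x, H, W)                         # (bs*oy*ox, cin*H*W)
  out = tx.dot(w.reshape(cout, -1).T)          # (bs*oy*ox, cout)
  return out.reshape(bs, oy, ox, cout).transpose(0, 3, 1, 2), tx

def conv_backward(grad_output, tx, w):
  bs, cout, oy, ox = grad_output.shape
  _, _, H, W = w.shape
  tw = w.reshape(cout, -1)
  # same reordering as the patched code: cout first, then (bs, oy, ox)
  ggt = np.moveaxis(grad_output, [0,1,2,3], [1,0,2,3]).reshape(cout, -1)
  dw = ggt.dot(tx).reshape(w.shape)            # sum over all output positions
  dxi = ggt.T.dot(tw)                          # (bs*oy*ox, cin*H*W)
  dx = col2im(dxi, H, W, oy+(H-1), ox+(W-1))   # scatter back to input shape
  return dx, dw

# quick check against direct per-position accumulation
np.random.seed(0)
x, w = np.random.randn(2, 3, 5, 5), np.random.randn(4, 3, 3, 3)
out, tx = conv_forward(x, w)
g = np.random.randn(*out.shape)
dx, dw = conv_backward(g, tx, w)
dw_ref, dx_ref = np.zeros_like(w), np.zeros_like(x)
for i in range(out.shape[2]):
  for j in range(out.shape[3]):
    dw_ref += np.einsum('bc,bkyx->ckyx', g[:, :, i, j], x[:, :, i:i+3, j:j+3])
    dx_ref[:, :, i:i+3, j:j+3] += np.einsum('bc,ckyx->bkyx', g[:, :, i, j], w)
assert np.allclose(dw, dw_ref) and np.allclose(dx, dx_ref)

The reordering is the whole fix: the old gg used moveaxis destinations [0,2,3,1], which leaves cout on axis 2 rather than last, so reshape(-1, cout) mixes channel and spatial entries; the new ggt puts cout first and flattens (bs, oy, ox) contiguously, matching the row order of tx.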