mirror of https://github.com/commaai/tinygrad.git
hmm, no reason it should have been wrong, numpy weirdness
This commit is contained in:
parent 96f9cdb8a0
commit 1255903519
@@ -150,18 +150,21 @@ class FastConv2D(Function):
     bs,_,oy,ox = grad_output.shape
     tx, w = ctx.saved_tensors
     cout,cin,H,W = w.shape
     # grad_output.shape = (bs, cout, oy, ox)
     # tx.shape = (bs*oy*ox*cin, H*W)
     tw = w.reshape(w.shape[0], -1)
 
-    # order correctly
-    gg = np.moveaxis(grad_output, [0,1,2,3], [0,2,3,1]).reshape(-1, cout)
+    # reshape correctly
+    ggt = np.moveaxis(grad_output, [0,1,2,3], [1,0,2,3]).reshape(cout, -1)
+
     # dw is easy
-    dw = gg.T.dot(tx).reshape(w.shape)
+    dw = ggt.dot(tx).reshape(w.shape)
+
     # dx is harder
-    dxi = gg.dot(tw)
+    dxi = ggt.T.dot(tw)
+
     # if we im2col on the forward, we col2im on the backward
     # dxi should be (bs, oy, ox, cin, H, W)
     dx = col2im(dxi, H, W, oy+(H-1), ox+(W-1))
 
     return dx, dw
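Note on the "numpy weirdness": np.moveaxis takes, for each listed source axis, its new position in the result. The old destination list [0,2,3,1] therefore sends cout to axis 2 and ox to axis 1, producing (bs, ox, cout, oy) rather than the (bs, oy, ox, cout) layout the trailing reshape assumed. The sketch below is not repo code; the sizes are made up and gg_intended is a hypothetical name for what "order correctly" was presumably aiming at. It only illustrates why the new ggt lines up with the im2col row order while the old gg did not.

import numpy as np

# A toy gradient with made-up sizes: bs=2, cout=3, oy=4, ox=5.
bs, cout, oy, ox = 2, 3, 4, 5
grad_output = np.arange(bs * cout * oy * ox).reshape(bs, cout, oy, ox)

# Old line: destination [0,2,3,1] puts cout at axis 2 and ox at axis 1,
# so the result is (bs, ox, cout, oy), not (bs, oy, ox, cout).
gg_old = np.moveaxis(grad_output, [0,1,2,3], [0,2,3,1])
print(gg_old.shape)   # (2, 5, 3, 4)

# What the old code presumably intended: rows ordered by (bs, oy, ox),
# one cout-vector per row (gg_intended is an illustrative name, not repo code).
gg_intended = np.moveaxis(grad_output, [0,1,2,3], [0,3,1,2]).reshape(-1, cout)
print(gg_intended.shape)   # (40, 3)

# New line from this commit: cout first, then flatten (bs, oy, ox).
ggt = np.moveaxis(grad_output, [0,1,2,3], [1,0,2,3]).reshape(cout, -1)
print(ggt.shape)   # (3, 40)

# ggt.T is exactly the correctly ordered matrix; the old gg was not.
print(np.array_equal(ggt.T, gg_intended))                     # True
print(np.array_equal(gg_old.reshape(-1, cout), gg_intended))  # False

With ggt.T matching the correctly ordered matrix, ggt.dot(tx) and ggt.T.dot(tw) are the dw and dxi products the old gg formulation appears to have been aiming for.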