From 12559035197cfdf2b417b6b223a255f193699ded Mon Sep 17 00:00:00 2001
From: George Hotz
Date: Sun, 25 Oct 2020 13:14:12 -0700
Subject: [PATCH] hmm, no reason it should have been wrong, numpy weirdness

---
 tinygrad/ops.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index ec9a9da8..456eb087 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -150,18 +150,21 @@ class FastConv2D(Function):
     bs,_,oy,ox = grad_output.shape
     tx, w = ctx.saved_tensors
     cout,cin,H,W = w.shape
+    # grad_output.shape = (bs, cout, oy, ox)
+    # tx.shape = (bs*oy*ox, cin*H*W)
     tw = w.reshape(w.shape[0], -1)
-    # order correctly
-    gg = np.moveaxis(grad_output, [0,1,2,3], [0,2,3,1]).reshape(-1, cout)
+    # reshape correctly
+    ggt = np.moveaxis(grad_output, [0,1,2,3], [1,0,2,3]).reshape(cout, -1)
     # dw is easy
-    dw = gg.T.dot(tx).reshape(w.shape)
+    dw = ggt.dot(tx).reshape(w.shape)
     # dx is harder
-    dxi = gg.dot(tw)
+    dxi = ggt.T.dot(tw)
     # if we im2col on the forward, we col2im on the backward
+    # dxi should be (bs, oy, ox, cin, H, W)
     dx = col2im(dxi, H, W, oy+(H-1), ox+(W-1))
     return dx, dw
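
For context, below is a minimal NumPy sketch of the im2col/col2im pair this backward pass relies on. It is NOT tinygrad's actual implementation (only the names im2col/col2im, the call signature col2im(cols, H, W, OY, OX), and the (bs*oy*ox, cin*H*W) column layout come from the patch above; the bodies here are illustrative). The point it demonstrates: col2im must scatter-ADD, not assign, because each input pixel appears in up to H*W overlapping windows, so their gradients accumulate.

import numpy as np

def im2col(x, H, W):
  # gather: (bs, cin, iy, ix) -> (bs*oy*ox, cin*H*W), valid convolution,
  # so oy = iy-(H-1) and ox = ix-(W-1)
  bs, cin, iy, ix = x.shape
  oy, ox = iy-(H-1), ix-(W-1)
  tx = np.empty((bs, oy, ox, cin*H*W), dtype=x.dtype)
  for Y in range(oy):
    for X in range(ox):
      tx[:, Y, X] = x[:, :, Y:Y+H, X:X+W].reshape(bs, -1)
  return tx.reshape(-1, cin*H*W)

def col2im(cols, H, W, OY, OX):
  # scatter: adjoint of im2col. OY, OX are the input image dims
  # (oy+(H-1), ox+(W-1) in the patch); overlapping patches are summed
  oy, ox = OY-(H-1), OX-(W-1)
  bs, cin = cols.shape[0] // (oy*ox), cols.shape[1] // (H*W)
  x = np.zeros((bs, cin, OY, OX), dtype=cols.dtype)
  cols = cols.reshape(bs, oy, ox, cin, H, W)
  for Y in range(oy):
    for X in range(ox):
      x[:, :, Y:Y+H, X:X+W] += cols[:, Y, X]  # += is the whole trick
  return x

# sanity check: im2col and col2im are adjoint linear maps,
# i.e. <im2col(x), g> == <x, col2im(g)> for any x and g
x = np.random.randn(2, 3, 5, 5).astype(np.float32)
g = np.random.randn(2*3*3, 3*3*3).astype(np.float32)  # H=W=3 -> oy=ox=3
lhs = (im2col(x, 3, 3) * g).sum()
rhs = (x * col2im(g, 3, 3, 5, 5)).sum()
assert abs(lhs - rhs) < 1e-3

The adjoint identity is exactly why "if we im2col on the forward, we col2im on the backward": the forward matmul against the im2col columns is a linear map, so its gradient is the transpose map, which is the col2im scatter-add.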