backward pass for conv2d, lol i mostly guessed and made shapes match

This commit is contained in:
George Hotz 2020-10-21 08:45:35 -07:00
parent 9172bc9c38
commit e3110c9922
2 changed files with 27 additions and 5 deletions

View File

@ -33,12 +33,20 @@ class TestTinygrad(unittest.TestCase):
np.testing.assert_allclose(x, y, atol=1e-5)
def test_conv2d(self):
x = torch.randn((5,2,10,7))
w = torch.randn((4,2,3,3))
x = torch.randn((5,2,10,7), requires_grad=True)
w = torch.randn((4,2,3,3), requires_grad=True)
xt = Tensor(x.detach().numpy())
wt = Tensor(w.detach().numpy())
out = torch.nn.functional.conv2d(x,w)
ret = Conv2D.apply(Conv2D, Tensor(x.numpy()), Tensor(w.numpy()))
np.testing.assert_allclose(ret.data, out.numpy(), atol=1e-5)
ret = Conv2D.apply(Conv2D, xt, wt)
np.testing.assert_allclose(ret.data, out.detach().numpy(), atol=1e-5)
out.mean().backward()
ret.mean().backward()
np.testing.assert_allclose(w.grad, wt.grad, atol=1e-5)
np.testing.assert_allclose(x.grad, xt.grad, atol=1e-5)
if __name__ == '__main__':

View File

@ -157,6 +157,7 @@ register('logsoftmax', LogSoftmax)
class Conv2D(Function):
@staticmethod
def forward(ctx, x, w):
ctx.save_for_backward(x, w)
cout,cin,H,W = w.shape
ret = np.zeros((x.shape[0], cout, x.shape[2]-(H-1), x.shape[3]-(W-1)), dtype=w.dtype)
for j in range(H):
@ -169,6 +170,19 @@ class Conv2D(Function):
@staticmethod
def backward(ctx, grad_output):
raise Exception("please write backward pass for Conv2D")
x, w = ctx.saved_tensors
dx = np.zeros_like(x)
dw = np.zeros_like(w)
cout,cin,H,W = w.shape
for j in range(H):
for i in range(W):
tw = w[:, :, j, i]
for Y in range(grad_output.shape[2]):
for X in range(grad_output.shape[3]):
gg = grad_output[:, :, Y, X]
tx = x[:, :, Y+j, X+i]
dx[:, :, Y+j, X+i] += gg.dot(tw)
dw[:, :, j, i] += gg.T.dot(tx)
return dx, dw
register('conv2d', Conv2D)