mirror of https://github.com/commaai/tinygrad.git
backward pass for conv2d, lol i mostly guessed and made shapes match
This commit is contained in:
parent
9172bc9c38
commit
e3110c9922
16
test/test.py
16
test/test.py
|
@ -33,12 +33,20 @@ class TestTinygrad(unittest.TestCase):
|
|||
np.testing.assert_allclose(x, y, atol=1e-5)
|
||||
|
||||
def test_conv2d(self):
  """Check Conv2D's forward output and both gradients against torch's conv2d."""
  # torch-side leaves track gradients so x.grad / w.grad exist after backward().
  x = torch.randn((5,2,10,7), requires_grad=True)
  w = torch.randn((4,2,3,3), requires_grad=True)
  # detach() first: torch refuses .numpy() on a tensor that requires grad.
  xt = Tensor(x.detach().numpy())
  wt = Tensor(w.detach().numpy())

  # Forward pass must agree.
  out = torch.nn.functional.conv2d(x,w)
  ret = Conv2D.apply(Conv2D, xt, wt)
  np.testing.assert_allclose(ret.data, out.detach().numpy(), atol=1e-5)

  # Reduce both outputs to a scalar the same way, then backpropagate.
  out.mean().backward()
  ret.mean().backward()

  # Gradients w.r.t. the weights and the input must agree.
  np.testing.assert_allclose(w.grad, wt.grad, atol=1e-5)
  np.testing.assert_allclose(x.grad, xt.grad, atol=1e-5)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -157,6 +157,7 @@ register('logsoftmax', LogSoftmax)
|
|||
class Conv2D(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x, w):
|
||||
ctx.save_for_backward(x, w)
|
||||
cout,cin,H,W = w.shape
|
||||
ret = np.zeros((x.shape[0], cout, x.shape[2]-(H-1), x.shape[3]-(W-1)), dtype=w.dtype)
|
||||
for j in range(H):
|
||||
|
@ -169,6 +170,19 @@ class Conv2D(Function):
|
|||
|
||||
@staticmethod
def backward(ctx, grad_output):
  """Compute the gradients of the convolution w.r.t. its input and weights.

  ctx.saved_tensors supplies the (x, w) pair stored by forward. w is unpacked
  as (cout, cin, H, W); x and grad_output are indexed as 4-D arrays, with the
  last two axes of grad_output being the output rows and columns.

  Returns (dx, dw), shaped like x and w respectively.
  """
  x, w = ctx.saved_tensors
  dx = np.zeros_like(x)
  dw = np.zeros_like(w)
  _, _, H, W = w.shape
  out_h, out_w = grad_output.shape[2], grad_output.shape[3]
  for j in range(H):
    for i in range(W):
      # Every output pixel (Y, X) reads x[:, :, Y+j, X+i] through kernel tap
      # (j, i), so the whole set of touched inputs is one shifted window.
      rows = slice(j, j + out_h)
      cols = slice(i, i + out_w)
      # dx[n, c, Y+j, X+i] += sum_o grad_output[n, o, Y, X] * w[o, c, j, i]
      dx[:, :, rows, cols] += np.einsum('noyx,oc->ncyx', grad_output, w[:, :, j, i])
      # dw[o, c, j, i] += sum_{n, Y, X} grad_output[n, o, Y, X] * x[n, c, Y+j, X+i]
      dw[:, :, j, i] += np.einsum('noyx,ncyx->oc', grad_output, x[:, :, rows, cols])
  return dx, dw
|
||||
# Register Conv2D under the op name 'conv2d'. `register` is defined outside this
# excerpt; presumably it exposes the op as a Tensor method -- TODO confirm.
register('conv2d', Conv2D)
|
||||
|
||||
|
|
Loading…
Reference in New Issue