mirror of https://github.com/commaai/tinygrad.git
conv stride support
This commit is contained in:
parent
2a55d7402b
commit
1654008c1f
|
@ -40,7 +40,6 @@ class TestOps(unittest.TestCase):
|
||||||
lambda x,w: torch.nn.functional.conv2d(x,w).relu(),
|
lambda x,w: torch.nn.functional.conv2d(x,w).relu(),
|
||||||
lambda x,w: Tensor.conv2d(x,w).relu(), atol=2e-5, grad_atol=2e-6)
|
lambda x,w: Tensor.conv2d(x,w).relu(), atol=2e-5, grad_atol=2e-6)
|
||||||
|
|
||||||
@unittest.skip("please write stride support")
|
|
||||||
def test_strided_conv2d(self):
|
def test_strided_conv2d(self):
|
||||||
bs = 4
|
bs = 4
|
||||||
cin = 3
|
cin = 3
|
||||||
|
|
|
@ -102,16 +102,21 @@ register('logsoftmax', LogSoftmax)
|
||||||
|
|
||||||
class Conv2D(Function):
|
class Conv2D(Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(ctx, x, w):
|
def forward(ctx, x, w, stride=1):
|
||||||
|
if type(ctx.stride) == int:
|
||||||
|
ctx.stride = (ctx.stride, ctx.stride)
|
||||||
|
|
||||||
cout,cin,H,W = w.shape
|
cout,cin,H,W = w.shape
|
||||||
tw = w.reshape(cout, -1).T
|
tw = w.reshape(cout, -1).T
|
||||||
bs,oy,ox = x.shape[0], x.shape[2]-(H-1), x.shape[3]-(W-1)
|
ys,xs = ctx.stride
|
||||||
|
bs,oy,ox = x.shape[0], (x.shape[2]-(H-ys))//ys, (x.shape[3]-(W-xs))//xs
|
||||||
|
|
||||||
ctx.save_for_backward(x, w)
|
ctx.save_for_backward(x, w)
|
||||||
ret = np.zeros((bs, cout, oy, ox), dtype=w.dtype)
|
ret = np.zeros((bs, cout, oy, ox), dtype=w.dtype)
|
||||||
for Y in range(oy):
|
for Y in range(oy):
|
||||||
for X in range(ox):
|
for X in range(ox):
|
||||||
tx = x[:, :, Y:Y+H, X:X+W].reshape(bs, -1)
|
iY,iX = Y*ys, X*xs
|
||||||
|
tx = x[:, :, iY:iY+H, iX:iX+W].reshape(bs, -1)
|
||||||
ret[:, :, Y, X] = tx.dot(tw)
|
ret[:, :, Y, X] = tx.dot(tw)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
@ -121,14 +126,16 @@ class Conv2D(Function):
|
||||||
x, w = ctx.saved_tensors
|
x, w = ctx.saved_tensors
|
||||||
cout,cin,H,W = w.shape
|
cout,cin,H,W = w.shape
|
||||||
tw = w.reshape(cout, -1)
|
tw = w.reshape(cout, -1)
|
||||||
|
ys,xs = ctx.stride
|
||||||
|
|
||||||
dx, dw = np.zeros_like(x), np.zeros_like(w)
|
dx, dw = np.zeros_like(x), np.zeros_like(w)
|
||||||
for Y in range(grad_output.shape[2]):
|
for Y in range(grad_output.shape[2]):
|
||||||
for X in range(grad_output.shape[3]):
|
for X in range(grad_output.shape[3]):
|
||||||
|
iY,iX = Y*ys, X*xs
|
||||||
gg = grad_output[:, :, Y, X]
|
gg = grad_output[:, :, Y, X]
|
||||||
tx = x[:, :, Y:Y+H, X:X+W].reshape(x.shape[0], -1)
|
tx = x[:, :, iY:iY+H, iX:iX+W].reshape(x.shape[0], -1)
|
||||||
dw += gg.T.dot(tx).reshape(dw.shape)
|
dw += gg.T.dot(tx).reshape(dw.shape)
|
||||||
dx[:, :, Y:Y+H, X:X+W] += gg.dot(tw).reshape(dx.shape[0], dx.shape[1], H, W)
|
dx[:, :, iY:iY+H, iX:iX+W] += gg.dot(tw).reshape(dx.shape[0], dx.shape[1], H, W)
|
||||||
return dx, dw
|
return dx, dw
|
||||||
register('conv2d', Conv2D)
|
register('conv2d', Conv2D)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue