wow dilation support was simple

George Hotz 2022-06-15 11:38:23 -07:00
parent 0b182029dd
commit fef6c82491
3 changed files with 21 additions and 14 deletions


@@ -268,19 +268,26 @@ class TestOps(unittest.TestCase):
         lambda x,w: torch.nn.functional.conv2d(x,w,padding=padding).relu(),
         lambda x,w: Tensor.conv2d(x,w,padding=padding).relu(), atol=1e-4)
 
+  def test_dilated_conv2d_forward(self):
+    bs = 4
+    cin = 3
+    H,W = 3,3
+    for d in [2, (2,1)]:
+      with self.subTest(dilation := d):
+        helper_test_op([(bs,cin,11,28), (4,cin,H,W)],
+          lambda x,w: torch.nn.functional.conv2d(x,w,dilation=dilation).relu(),
+          lambda x,w: Tensor.conv2d(x,w,dilation=dilation).relu(), atol=1e-4, forward_only=True)
+
   @unittest.skipUnless(Device.DEFAULT == Device.TORCH, "Not Implemented")
   def test_dilated_conv2d(self):
     bs = 4
     cin = 3
     H,W = 3,3
-    with self.subTest(dilation := 2):
-      helper_test_op([(bs,cin,11,28), (4,cin,H,W)],
-        lambda x,w: torch.nn.functional.conv2d(x,w,dilation=dilation).relu(),
-        lambda x,w: Tensor.conv2d(x,w,dilation=dilation).relu(), atol=1e-4)
-    with self.subTest(dilation := (2,1)):
-      helper_test_op([(bs,cin,11,28), (4,cin,H,W)],
-        lambda x,w: torch.nn.functional.conv2d(x,w,dilation=dilation).relu(),
-        lambda x,w: Tensor.conv2d(x,w,dilation=dilation).relu(), atol=1e-4)
+    for d in [2, (2,1)]:
+      with self.subTest(dilation := d):
+        helper_test_op([(bs,cin,11,28), (4,cin,H,W)],
+          lambda x,w: torch.nn.functional.conv2d(x,w,dilation=dilation).relu(),
+          lambda x,w: Tensor.conv2d(x,w,dilation=dilation).relu(), atol=1e-4)
 
   def test_maxpool2d(self):
     for ksz in [(2,2), (3,3), (3,2), (5,5), (5,1)]:
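For context on what the new tests check (a sketch, not part of the commit): a kxk kernel dilated by d spans d*(k-1)+1 input pixels, so with the 3x3 kernel and 11x28 input used above the output shrinks by 2*dy rows and 2*dx columns. A quick torch-only shape check, with sizes copied from the test:

# Illustrative sketch, not repo code; sizes match the test above.
import torch
x = torch.randn(4, 3, 11, 28)   # (bs, cin, iy, ix) as in the test
w = torch.randn(4, 3, 3, 3)     # (rcout, cin, H, W) as in the test
for d in [2, (2,1)]:
  dy, dx = (d, d) if isinstance(d, int) else d
  out = torch.nn.functional.conv2d(x, w, dilation=d)
  # effective kernel is (dy*2+1) x (dx*2+1) at stride 1, so oy = 11 - 2*dy, ox = 28 - 2*dx
  assert out.shape == (4, 4, 11 - 2*dy, 28 - 2*dx)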


@@ -64,7 +64,7 @@ def get_tx(x, C):
   gx = x.reshape(C.bs,C.groups,C.cin,x.shape[2],x.shape[3])
   return np.lib.stride_tricks.as_strided(gx,
     shape=(C.bs, C.groups, C.cin, C.oy, C.ox, C.H, C.W),
-    strides=(*gx.strides[0:3], gx.strides[3]*C.ys, gx.strides[4]*C.xs, *gx.strides[3:5]),
+    strides=(*gx.strides[0:3], gx.strides[3]*C.ys, gx.strides[4]*C.xs, gx.strides[3]*C.dy, gx.strides[4]*C.dx),
     writeable=False,
   )
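This one-line change is the whole CPU-side implementation: the last two strides of the as_strided view step through the kernel window, so scaling them by C.dy and C.dx makes consecutive kernel taps land dy rows and dx columns apart in the input. A toy check of the stride trick (a sketch with made-up sizes, not repo code):

import numpy as np
iy, ix, H, W, ys, xs, dy, dx = 10, 12, 3, 3, 1, 1, 2, 2   # made-up sizes
oy, ox = (iy - dy*(H-1) - 1)//ys + 1, (ix - dx*(W-1) - 1)//xs + 1
g = np.arange(iy*ix, dtype=np.float32).reshape(iy, ix)
tx = np.lib.stride_tricks.as_strided(g,
  shape=(oy, ox, H, W),
  strides=(g.strides[0]*ys, g.strides[1]*xs, g.strides[0]*dy, g.strides[1]*dx),
  writeable=False)
# window element (ky,kx) of output pixel (y,x) reads input[y*ys + ky*dy, x*xs + kx*dx]
assert tx[1, 2, 1, 1] == g[1*ys + 1*dy, 2*xs + 1*dx]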


@@ -132,7 +132,7 @@ def conv(x,w,ret,C):
   # output = (bs, groups, rcout, oy, ox)
   conv_prg = clbuild("conv", """
   __kernel void conv(__global const float *input, __global const float *weight, __global float *output,
-    int H, int W, int groups, int rcout, int cin, int oy, int ox, int iy, int ix, int ys, int xs, int bs) {
+    int H, int W, int groups, int rcout, int cin, int oy, int ox, int iy, int ix, int ys, int xs, int bs, int dx, int dy) {
 
     int B = get_global_id(0)/(groups*rcout);  // range 0-bs
     int g = (get_global_id(0)/rcout)%groups;
@@ -145,15 +145,15 @@ def conv(x,w,ret,C):
     float acc = 0.0;
     for (int ci = 0; ci < cin; ci++) {
-      for (int y = IY; y < IY+H; y++) { for (int x = IX; x < IX+W; x++) {
-        acc += input[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + y*ix + x] * \
-          weight[g*rcout*cin*H*W + c*cin*H*W + ci*H*W + (y-IY)*W + (x-IX)];
+      for (int y = 0; y < H; y++) { for (int x = 0; x < W; x++) {
+        acc += input[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + (y*dy+IY)*ix + (x*dx+IX)] * \
+          weight[g*rcout*cin*H*W + c*cin*H*W + ci*H*W + y*W + x];
      } }
     }
     output[B*groups*rcout*oy*ox + g*rcout*oy*ox + c*oy*ox + Y*ox + X] = acc;
   }""")
-  conv_prg([C.bs*C.groups*C.rcout, C.oy, C.ox], None, x.cl, w.cl, ret.cl, *[i32(x) for x in C[0:12]])
+  conv_prg([C.bs*C.groups*C.rcout, C.oy, C.ox], None, x.cl, w.cl, ret.cl, *[i32(x) for x in list(C[0:12])+[C.dx, C.dy]])
 
   # tensx = (bs, groups*cin, iy, ix)
   # tensw = (groups*rcout, cin, H, W)
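On the GPU side the same idea appears as index arithmetic: the inner loop now walks the H x W kernel taps directly, and the dilation scales the input offset, (y*dy+IY)*ix + (x*dx+IX), while the weight index is undilated. A small Python mirror of that flat indexing for a single output element, checked against torch (a sketch with made-up sizes, not repo code):

import torch
bs, groups, rcout, cin, H, W = 1, 1, 1, 2, 3, 3   # made-up sizes
iy, ix, ys, xs, dy, dx = 9, 9, 1, 1, 2, 2
xt = torch.randn(bs, groups*cin, iy, ix)
w = torch.randn(groups*rcout, cin, H, W)
xf, wf = xt.flatten(), w.flatten()
B, g, c, Y, X = 0, 0, 0, 1, 2                     # one output element
IY, IX = Y*ys, X*xs
acc = 0.0
for ci in range(cin):
  for y in range(H):
    for x in range(W):
      # same flat-index arithmetic as the kernel, dilation applied on the input side only
      acc += xf[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + (y*dy+IY)*ix + (x*dx+IX)] * \
             wf[g*rcout*cin*H*W + c*cin*H*W + ci*H*W + y*W + x]
ref = torch.nn.functional.conv2d(xt, w, dilation=(dy, dx))[B, g*rcout + c, Y, X]
assert torch.isclose(acc, ref, atol=1e-5)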