No separate pad2d kernel needed (#99)

Co-authored-by: holonomicjl <58403584+holonomicjl@users.noreply.github.com>
This commit is contained in:
adamritter 2020-11-10 14:47:53 +00:00 committed by GitHub
parent 2d4a5d5950
commit f27628b21c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 7 additions and 21 deletions

View File

@ -281,47 +281,33 @@ class Pad2D(Function):
prg = clbuild(ctx.cl_ctx, """
__kernel void pad2d(__global const float *input, __global float *output,
int py, int px, int oy, int ox, int iy, int ix) {
int ipx, int ipy, int py, int px, int oy, int ox, int iy, int ix) {
int BC = get_global_id(0);
int Y = get_global_id(1);
int X = get_global_id(2);
int iptr = BC*iy*ix + Y*ix + X;
int iptr = BC*iy*ix + (Y+ipy)*ix + ipx + X;
int optr = BC*oy*ox + (Y+py)*ox + px + X;
output[optr] = input[iptr];
}""")
ctx.save_for_backward(padding)
ctx.save_for_backward(padding, prg)
prg.pad2d(ctx.cl_queue, [bs*cin, iy, ix], None,
x, ret,
np.int32(padding[2]), np.int32(padding[0]),
np.int32(0), np.int32(0), np.int32(padding[2]), np.int32(padding[0]),
np.int32(oy), np.int32(ox), np.int32(iy), np.int32(ix)
)
return ret
@staticmethod
def backward(ctx, grad_output):
# Gradient of Pad2D: copies the un-padded interior of grad_output into a
# freshly allocated (bs, cin, oy, ox) buffer and returns that buffer.
#
# NOTE(review): this span is a rendered diff hunk. Removed and added lines
# sit next to each other with no +/- markers and no indentation, so it is
# not runnable as shown. The comments below say which line of each pair
# looks old and which looks new; confirm against the actual commit.
#
# Pair of unpackings: the first expects only `padding` to have been saved;
# the second also retrieves `prg`, matching the forward-pass line that calls
# ctx.save_for_backward(padding, prg).
padding, = ctx.saved_tensors
padding, prg = ctx.saved_tensors
# In this method iy/ix are the padded sizes (read from grad_output) and
# oy/ox are the cropped sizes: padding[2] and padding[3] are removed from
# the row axis, padding[0] and padding[1] from the column axis.
# Presumably padding is (left, right, top, bottom); verify against callers.
bs, cin, iy, ix = grad_output.shape
oy, ox = iy - padding[2] - padding[3], ix - padding[0] - padding[1]
ret = buffer_new(ctx, (bs, cin, oy, ox))
# Presumably the removed side of the diff: a dedicated crop kernel that
# reads three global ids (batch, channel, output row) and copies ox floats
# per work-item in a loop, starting (py rows, px columns) into the input.
prg = clbuild(ctx.cl_ctx, """
__kernel void pad2d(__global const float *input, __global float *output,
int cin, int py, int px, int oy, int ox, int iy, int ix) {
int B = get_global_id(0);
int C = get_global_id(1);
int Y = get_global_id(2);
int iptr = B*cin*iy*ix + C*iy*ix + (Y+py)*ix + px;
int optr = B*cin*oy*ox + C*oy*ox + Y*ox;
for (int x = 0; x < ox; x++) {
output[optr+x] = input[iptr+x];
}
}""")
# Pair of launch lines: the [bs, cin, oy] global size matches the three ids
# read by the kernel built just above; [bs*cin, oy, ox] matches the BC, Y, X
# ids read by the kernel defined in the forward pass (one work-item per
# output element).
prg.pad2d(ctx.cl_queue, [bs, cin, oy], None,
prg.pad2d(ctx.cl_queue, [bs*cin, oy, ox], None,
grad_output, ret,
# Pair of argument lines: the first supplies (cin, py, px) for the kernel
# built above; the second supplies four ints for the forward-pass kernel,
# with zero output offsets because the output here is the cropped buffer.
np.int32(cin), np.int32(padding[2]), np.int32(padding[0]),
# NOTE(review): the forward-pass kernel's parameter list begins
# `int ipx, int ipy` and uses ipy as the input row offset and ipx as the
# input column offset, yet the next line passes padding[2] (row-axis pad)
# first and padding[0] (column-axis pad) second. If kernel arguments bind
# positionally (TODO confirm), the two offsets are swapped whenever
# padding[0] != padding[2].
np.int32(padding[2]), np.int32(padding[0]), np.int32(0), np.int32(0),
np.int32(oy), np.int32(ox), np.int32(iy), np.int32(ix)
)
return ret