mirror of https://github.com/commaai/tinygrad.git
No separate pad2d kernel needed (#99)
Co-authored-by: holonomicjl <58403584+holonomicjl@users.noreply.github.com>
This commit is contained in:
parent
2d4a5d5950
commit
f27628b21c
|
@ -281,47 +281,33 @@ class Pad2D(Function):
|
|||
|
||||
prg = clbuild(ctx.cl_ctx, """
|
||||
__kernel void pad2d(__global const float *input, __global float *output,
|
||||
int py, int px, int oy, int ox, int iy, int ix) {
|
||||
int ipx, int ipy, int py, int px, int oy, int ox, int iy, int ix) {
|
||||
int BC = get_global_id(0);
|
||||
int Y = get_global_id(1);
|
||||
int X = get_global_id(2);
|
||||
|
||||
int iptr = BC*iy*ix + Y*ix + X;
|
||||
int iptr = BC*iy*ix + (Y+ipy)*ix + ipx + X;
|
||||
int optr = BC*oy*ox + (Y+py)*ox + px + X;
|
||||
|
||||
output[optr] = input[iptr];
|
||||
}""")
|
||||
ctx.save_for_backward(padding)
|
||||
ctx.save_for_backward(padding, prg)
|
||||
prg.pad2d(ctx.cl_queue, [bs*cin, iy, ix], None,
|
||||
x, ret,
|
||||
np.int32(padding[2]), np.int32(padding[0]),
|
||||
np.int32(0), np.int32(0), np.int32(padding[2]), np.int32(padding[0]),
|
||||
np.int32(oy), np.int32(ox), np.int32(iy), np.int32(ix)
|
||||
)
|
||||
return ret
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
padding, = ctx.saved_tensors
|
||||
padding, prg = ctx.saved_tensors
|
||||
bs, cin, iy, ix = grad_output.shape
|
||||
oy, ox = iy - padding[2] - padding[3], ix - padding[0] - padding[1]
|
||||
ret = buffer_new(ctx, (bs, cin, oy, ox))
|
||||
prg = clbuild(ctx.cl_ctx, """
|
||||
__kernel void pad2d(__global const float *input, __global float *output,
|
||||
int cin, int py, int px, int oy, int ox, int iy, int ix) {
|
||||
int B = get_global_id(0);
|
||||
int C = get_global_id(1);
|
||||
int Y = get_global_id(2);
|
||||
|
||||
int iptr = B*cin*iy*ix + C*iy*ix + (Y+py)*ix + px;
|
||||
int optr = B*cin*oy*ox + C*oy*ox + Y*ox;
|
||||
|
||||
for (int x = 0; x < ox; x++) {
|
||||
output[optr+x] = input[iptr+x];
|
||||
}
|
||||
}""")
|
||||
prg.pad2d(ctx.cl_queue, [bs, cin, oy], None,
|
||||
prg.pad2d(ctx.cl_queue, [bs*cin, oy, ox], None,
|
||||
grad_output, ret,
|
||||
np.int32(cin), np.int32(padding[2]), np.int32(padding[0]),
|
||||
np.int32(padding[2]), np.int32(padding[0]), np.int32(0), np.int32(0),
|
||||
np.int32(oy), np.int32(ox), np.int32(iy), np.int32(ix)
|
||||
)
|
||||
return ret
|
||||
|
|
Loading…
Reference in New Issue