fix conv args. fix spacing

This commit is contained in:
George Hotz 2022-06-05 14:35:31 -07:00
parent 365e62a609
commit b49bfb6e02
2 changed files with 34 additions and 37 deletions

View File

@ -192,36 +192,36 @@ def inner_slice(ctx, x, arg):
# c = a@b
def matmul(a, b, c, transpose_a=False, transpose_b=False):
cnt = np.prod(a.shape[0:-2]) if len(a.shape) > 2 else 1
isize, msize, osize = i32(a.shape[-2]), i32(a.shape[-1]), i32(c.shape[-1])
if transpose_a: isize,msize = msize,isize
assert isize == c.shape[-2]
assert (msize == b.shape[-1]) if transpose_b else (msize == b.shape[-2])
assert (osize == b.shape[-2]) if transpose_b else (osize == b.shape[-1])
matmul_prg = clbuild("matmul", """
__kernel void matmul(
__global const float *input, __global const float *weight, __global float *res,
int isize, int is0, int is1, int msize, int ws0, int ws1, int osize
) {
int stride = get_global_id(2);
cnt = np.prod(a.shape[0:-2]) if len(a.shape) > 2 else 1
isize, msize, osize = i32(a.shape[-2]), i32(a.shape[-1]), i32(c.shape[-1])
if transpose_a: isize,msize = msize,isize
assert isize == c.shape[-2]
assert (msize == b.shape[-1]) if transpose_b else (msize == b.shape[-2])
assert (osize == b.shape[-2]) if transpose_b else (osize == b.shape[-1])
matmul_prg = clbuild("matmul", """
__kernel void matmul(
__global const float *input, __global const float *weight, __global float *res,
int isize, int is0, int is1, int msize, int ws0, int ws1, int osize
) {
int stride = get_global_id(2);
int X = get_global_id(0); // isize
int Y = get_global_id(1); // osize
int X = get_global_id(0); // isize
int Y = get_global_id(1); // osize
float ret = 0.0;
for (int x = 0; x < msize; x++) {
ret += input[X * is0 + x * is1 + isize*msize*stride] *
weight[Y * ws0 + x * ws1 + msize*osize*stride];
}
float ret = 0.0;
for (int x = 0; x < msize; x++) {
ret += input[X * is0 + x * is1 + isize*msize*stride] *
weight[Y * ws0 + x * ws1 + msize*osize*stride];
}
res[X * osize + Y + isize*osize*stride] = ret;
}""")
res[X * osize + Y + isize*osize*stride] = ret;
}""")
matmul_prg([isize, osize, cnt], None,
a.cl, b.cl, c.cl,
isize,
msize if not transpose_a else i32(1), i32(1) if not transpose_a else isize,
msize,
i32(1) if not transpose_b else msize, osize if not transpose_b else i32(1),
osize)
matmul_prg([isize, osize, cnt], None,
a.cl, b.cl, c.cl,
isize,
msize if not transpose_a else i32(1), i32(1) if not transpose_a else isize,
msize,
i32(1) if not transpose_b else msize, osize if not transpose_b else i32(1),
osize)

View File

@ -199,11 +199,8 @@ class Conv2D(Function):
output[B*groups*rcout*oy*ox + g*rcout*oy*ox + c*oy*ox + Y*ox + X] = acc;
}""")
conv([bs*groups*rcout, oy, ox], None,
x.cl, w.cl, ret.cl,
i32(H), i32(W), i32(groups), i32(rcout), i32(cin),
i32(oy), i32(ox), i32(iy), i32(ix), i32(ys), i32(xs)
)
conv_args = H, W, groups, rcout, cin, oy, ox, iy, ix, ys, xs
conv([bs*groups*rcout, oy, ox], None, x.cl, w.cl, ret.cl, *[i32(x) for x in conv_args])
return ret
def backward(ctx, grad_output):
@ -270,7 +267,7 @@ class Conv2D(Function):
}
""")
conv_args = i32(H), i32(W), i32(ctx.groups), i32(rcout), i32(cin), i32(oy), i32(ox), i32(iy), i32(ix), i32(ys), i32(xs), i32(bs)
convw([ctx.groups*rcout*cin, H, W], None, x.cl, grad_output.cl, dw.cl, *conv_args)
convx([bs, ctx.groups, cin], None, w.cl, grad_output.cl, dx.cl, *conv_args)
conv_args = H, W, ctx.groups, rcout, cin, oy, ox, iy, ix, ys, xs, bs
convw([ctx.groups*rcout*cin, H, W], None, x.cl, grad_output.cl, dw.cl, *[i32(x) for x in conv_args])
convx([bs, ctx.groups, cin], None, w.cl, grad_output.cl, dx.cl, *[i32(x) for x in conv_args])
return dx, dw