factor out convs

This commit is contained in:
George Hotz 2022-06-05 14:48:42 -07:00
parent b49bfb6e02
commit f01bad36c2
2 changed files with 103 additions and 92 deletions

View File

@ -225,3 +225,100 @@ def matmul(a, b, c, transpose_a=False, transpose_b=False):
msize,
i32(1) if not transpose_b else msize, osize if not transpose_b else i32(1),
osize)
return c
# TODO: combine any of these three?
def conv(x,w,ret,conv_args):
H, W, groups, rcout, cin, oy, ox, iy, ix, ys, xs, bs = conv_args
# input = (bs, groups, cin, iy, ix)
# weight = (groups, rcout, cin, H, W)
# output = (bs, groups, rcout, oy, ox)
conv_prg = clbuild("conv", """
__kernel void conv(__global const float *input, __global const float *weight, __global float *output,
int H, int W, int groups, int rcout, int cin, int oy, int ox, int iy, int ix, int ys, int xs, int bs) {
int B = get_global_id(0)/(groups*rcout); // range 0-bs
int g = (get_global_id(0)/rcout)%groups;
int c = get_global_id(0) % rcout;
int Y = get_global_id(1); // range 0-oy
int X = get_global_id(2); // range 0-ox
int IY = Y*ys;
int IX = X*xs;
float acc = 0.0;
for (int ci = 0; ci < cin; ci++) {
for (int y = IY; y < IY+H; y++) {
for (int x = IX; x < IX+W; x++) {
acc += input[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + y*ix + x] * \
weight[g*rcout*cin*H*W + c*cin*H*W + ci*H*W + (y-IY)*W + (x-IX)];
}
}
}
output[B*groups*rcout*oy*ox + g*rcout*oy*ox + c*oy*ox + Y*ox + X] = acc;
}""")
conv_prg([bs*groups*rcout, oy, ox], None, x.cl, w.cl, ret.cl, *[i32(x) for x in conv_args])
return ret
# tensx = (bs, groups*cin, iy, ix)
# tensw = (groups*rcout, cin, H, W)
# ggg = (bs, groups*rout, oy, ox)
def convdw(x,grad_output,dw,conv_args):
H, W, groups, rcout, cin, oy, ox, iy, ix, ys, xs, bs = conv_args
convdw_prg = clbuild("convdw", """
__kernel void convdw(__global const float *tensx, __global const float *ggg, __global float *dw,
int H, int W, int groups, int rcout, int cin, int oy, int ox, int iy, int ix, int ys, int xs, int bs) {
int g = get_global_id(0)/(rcout*cin) ; // range 0-groups
int c = (get_global_id(0)/(cin)) %rcout; // range 0-rcout
int ci = get_global_id(0) % cin; // range 0-cin
int y = get_global_id(1); // range 0-H
int x = get_global_id(2); // range 0-W
float acc = 0.0;
for (int Y = 0; Y < oy; Y++) {
for (int X = 0; X < ox; X++) {
for (int B = 0; B < bs; B++) {
acc += ggg[B*groups*rcout*oy*ox + +g*rcout*oy*ox + c*oy*ox + Y*ox + X] * \
tensx[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + (Y*ys+y)*ix + X*xs+x];
}
}
}
dw[get_global_id(0)*H*W + y*W + x] = acc;
}""")
convdw_prg([groups*rcout*cin, H, W], None, x.cl, grad_output.cl, dw.cl, *[i32(x) for x in conv_args])
return dw
def convdx(w,grad_output,dx,conv_args):
H, W, groups, rcout, cin, oy, ox, iy, ix, ys, xs, bs = conv_args
convdx_prg = clbuild("convdx", """
__kernel void convdx(__global const float *tensw, __global const float *ggg, __global float *dx,
int H, int W, int groups, int rcout, int cin, int oy, int ox, int iy, int ix, int ys, int xs, int bs) {
int B = get_global_id(0);
int g = get_global_id(1);
int ci = get_global_id(2);
for (int Y = 0; Y < oy; Y++) {
for (int X = 0; X < ox; X++) {
for (int y = 0; y < H; y++) {
for (int x = 0; x < W; x++) {
float acc = 0.0;
for (int c = 0; c < rcout; c++) {
acc += ggg[B*groups*rcout*oy*ox + g*rcout*oy*ox + c*oy*ox + Y*ox + X] * \
tensw[g*rcout*cin*H*W + c*cin*H*W + ci*H*W + y*W + x];
}
dx[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + (Y*ys+y)*ix + X*xs+x] += acc;
}
}
}
}
}
""")
convdx_prg([bs, groups, cin], None, w.cl, grad_output.cl, dx.cl, *[i32(x) for x in conv_args])
return dx

View File

@ -1,7 +1,7 @@
import pyopencl as cl
import numpy as np
from ..tensor import Function
from ..llops.gpu import GPUBuffer, clbuild, buffer_new, unary_op, binary_op, reduce_op, perm_axis, inner_slice, matmul
from ..llops.gpu import GPUBuffer, clbuild, buffer_new, unary_op, binary_op, reduce_op, perm_axis, inner_slice, matmul, conv, convdw, convdx
i32 = np.int32
@ -143,8 +143,7 @@ class Matmul(Function):
assert input.shape[-1] == weight.shape[-2]
ret = buffer_new(ctx, list(input.shape[0:-1])+[weight.shape[-1]])
ctx.save_for_backward(input, weight)
matmul(input, weight, ret)
return ret
return matmul(input, weight, ret)
def backward(ctx, grad_output):
input, weight = ctx.saved_tensors
@ -168,40 +167,8 @@ class Conv2D(Function):
ctx.save_for_backward(x,w)
# output buffer
ret = buffer_new(ctx, (bs, cout, oy, ox))
# input = (bs, groups, cin, iy, ix)
# weight = (groups, rcout, cin, H, W)
# output = (bs, groups, rcout, oy, ox)
conv = clbuild("conv", """
__kernel void conv(__global const float *input, __global const float *weight, __global float *output,
int H, int W, int groups, int rcout, int cin, int oy, int ox, int iy, int ix, int ys, int xs) {
int B = get_global_id(0)/(groups*rcout); // range 0-bs
int g = (get_global_id(0)/rcout)%groups;
int c = get_global_id(0) % rcout;
int Y = get_global_id(1); // range 0-oy
int X = get_global_id(2); // range 0-ox
int IY = Y*ys;
int IX = X*xs;
float acc = 0.0;
for (int ci = 0; ci < cin; ci++) {
for (int y = IY; y < IY+H; y++) {
for (int x = IX; x < IX+W; x++) {
acc += input[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + y*ix + x] * \
weight[g*rcout*cin*H*W + c*cin*H*W + ci*H*W + (y-IY)*W + (x-IX)];
}
}
}
output[B*groups*rcout*oy*ox + g*rcout*oy*ox + c*oy*ox + Y*ox + X] = acc;
}""")
conv_args = H, W, groups, rcout, cin, oy, ox, iy, ix, ys, xs
conv([bs*groups*rcout, oy, ox], None, x.cl, w.cl, ret.cl, *[i32(x) for x in conv_args])
return ret
conv_args = H, W, groups, rcout, cin, oy, ox, iy, ix, ys, xs, bs
return conv(x, w, buffer_new(ctx, (bs, cout, oy, ox)), conv_args)
def backward(ctx, grad_output):
bs,_,oy,ox = grad_output.shape
@ -214,60 +181,7 @@ class Conv2D(Function):
assert cout % ctx.groups == 0
rcout = cout//ctx.groups
dx = buffer_new(ctx, (bs, cin_, iy, ix), zero=True)
dw = buffer_new(ctx, (cout, cin, H, W))
# tensx = (bs, groups*cin, iy, ix)
# tensw = (groups*rcout, cin, H, W)
# ggg = (bs, groups*rout, oy, ox)
convw = clbuild("convw", """
__kernel void convw(__global const float *tensx, __global const float *ggg, __global float *dw,
int H, int W, int groups, int rcout, int cin, int oy, int ox, int iy, int ix, int ys, int xs, int bs) {
int g = get_global_id(0)/(rcout*cin) ; // range 0-groups
int c = (get_global_id(0)/(cin)) %rcout; // range 0-rcout
int ci = get_global_id(0) % cin; // range 0-cin
int y = get_global_id(1); // range 0-H
int x = get_global_id(2); // range 0-W
float acc = 0.0;
for (int Y = 0; Y < oy; Y++) {
for (int X = 0; X < ox; X++) {
for (int B = 0; B < bs; B++) {
acc += ggg[B*groups*rcout*oy*ox + +g*rcout*oy*ox + c*oy*ox + Y*ox + X] * \
tensx[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + (Y*ys+y)*ix + X*xs+x];
}
}
}
dw[get_global_id(0)*H*W + y*W + x] = acc;
}""")
convx = clbuild("convx", """
__kernel void convx(__global const float *tensw, __global const float *ggg, __global float *dx,
int H, int W, int groups, int rcout, int cin, int oy, int ox, int iy, int ix, int ys, int xs, int bs) {
int B = get_global_id(0);
int g = get_global_id(1);
int ci = get_global_id(2);
for (int Y = 0; Y < oy; Y++) {
for (int X = 0; X < ox; X++) {
for (int y = 0; y < H; y++) {
for (int x = 0; x < W; x++) {
float acc = 0.0;
for (int c = 0; c < rcout; c++) {
acc += ggg[B*groups*rcout*oy*ox + g*rcout*oy*ox + c*oy*ox + Y*ox + X] * \
tensw[g*rcout*cin*H*W + c*cin*H*W + ci*H*W + y*W + x];
}
dx[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + (Y*ys+y)*ix + X*xs+x] += acc;
}
}
}
}
}
""")
conv_args = H, W, ctx.groups, rcout, cin, oy, ox, iy, ix, ys, xs, bs
convw([ctx.groups*rcout*cin, H, W], None, x.cl, grad_output.cl, dw.cl, *[i32(x) for x in conv_args])
convx([bs, ctx.groups, cin], None, w.cl, grad_output.cl, dx.cl, *[i32(x) for x in conv_args])
dw = convdw(x, grad_output, buffer_new(ctx, (cout, cin, H, W)), conv_args)
dx = convdx(w, grad_output, buffer_new(ctx, (bs, cin_, iy, ix), zero=True), conv_args)
return dx, dw