mirror of https://github.com/commaai/tinygrad.git
factor out convs
This commit is contained in:
parent
b49bfb6e02
commit
f01bad36c2
|
@ -225,3 +225,100 @@ def matmul(a, b, c, transpose_a=False, transpose_b=False):
|
|||
msize,
|
||||
i32(1) if not transpose_b else msize, osize if not transpose_b else i32(1),
|
||||
osize)
|
||||
return c
|
||||
|
||||
|
||||
# TODO: combine any of these three?
|
||||
def conv(x,w,ret,conv_args):
|
||||
H, W, groups, rcout, cin, oy, ox, iy, ix, ys, xs, bs = conv_args
|
||||
|
||||
# input = (bs, groups, cin, iy, ix)
|
||||
# weight = (groups, rcout, cin, H, W)
|
||||
# output = (bs, groups, rcout, oy, ox)
|
||||
conv_prg = clbuild("conv", """
|
||||
__kernel void conv(__global const float *input, __global const float *weight, __global float *output,
|
||||
int H, int W, int groups, int rcout, int cin, int oy, int ox, int iy, int ix, int ys, int xs, int bs) {
|
||||
|
||||
int B = get_global_id(0)/(groups*rcout); // range 0-bs
|
||||
int g = (get_global_id(0)/rcout)%groups;
|
||||
int c = get_global_id(0) % rcout;
|
||||
|
||||
int Y = get_global_id(1); // range 0-oy
|
||||
int X = get_global_id(2); // range 0-ox
|
||||
int IY = Y*ys;
|
||||
int IX = X*xs;
|
||||
|
||||
float acc = 0.0;
|
||||
for (int ci = 0; ci < cin; ci++) {
|
||||
for (int y = IY; y < IY+H; y++) {
|
||||
for (int x = IX; x < IX+W; x++) {
|
||||
acc += input[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + y*ix + x] * \
|
||||
weight[g*rcout*cin*H*W + c*cin*H*W + ci*H*W + (y-IY)*W + (x-IX)];
|
||||
}
|
||||
}
|
||||
}
|
||||
output[B*groups*rcout*oy*ox + g*rcout*oy*ox + c*oy*ox + Y*ox + X] = acc;
|
||||
}""")
|
||||
|
||||
conv_prg([bs*groups*rcout, oy, ox], None, x.cl, w.cl, ret.cl, *[i32(x) for x in conv_args])
|
||||
return ret
|
||||
|
||||
# tensx = (bs, groups*cin, iy, ix)
|
||||
# tensw = (groups*rcout, cin, H, W)
|
||||
# ggg = (bs, groups*rout, oy, ox)
|
||||
|
||||
def convdw(x,grad_output,dw,conv_args):
|
||||
H, W, groups, rcout, cin, oy, ox, iy, ix, ys, xs, bs = conv_args
|
||||
convdw_prg = clbuild("convdw", """
|
||||
__kernel void convdw(__global const float *tensx, __global const float *ggg, __global float *dw,
|
||||
int H, int W, int groups, int rcout, int cin, int oy, int ox, int iy, int ix, int ys, int xs, int bs) {
|
||||
|
||||
int g = get_global_id(0)/(rcout*cin) ; // range 0-groups
|
||||
int c = (get_global_id(0)/(cin)) %rcout; // range 0-rcout
|
||||
int ci = get_global_id(0) % cin; // range 0-cin
|
||||
int y = get_global_id(1); // range 0-H
|
||||
int x = get_global_id(2); // range 0-W
|
||||
|
||||
float acc = 0.0;
|
||||
for (int Y = 0; Y < oy; Y++) {
|
||||
for (int X = 0; X < ox; X++) {
|
||||
for (int B = 0; B < bs; B++) {
|
||||
acc += ggg[B*groups*rcout*oy*ox + +g*rcout*oy*ox + c*oy*ox + Y*ox + X] * \
|
||||
tensx[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + (Y*ys+y)*ix + X*xs+x];
|
||||
}
|
||||
}
|
||||
}
|
||||
dw[get_global_id(0)*H*W + y*W + x] = acc;
|
||||
}""")
|
||||
convdw_prg([groups*rcout*cin, H, W], None, x.cl, grad_output.cl, dw.cl, *[i32(x) for x in conv_args])
|
||||
return dw
|
||||
|
||||
def convdx(w,grad_output,dx,conv_args):
|
||||
H, W, groups, rcout, cin, oy, ox, iy, ix, ys, xs, bs = conv_args
|
||||
convdx_prg = clbuild("convdx", """
|
||||
__kernel void convdx(__global const float *tensw, __global const float *ggg, __global float *dx,
|
||||
int H, int W, int groups, int rcout, int cin, int oy, int ox, int iy, int ix, int ys, int xs, int bs) {
|
||||
|
||||
int B = get_global_id(0);
|
||||
int g = get_global_id(1);
|
||||
int ci = get_global_id(2);
|
||||
|
||||
for (int Y = 0; Y < oy; Y++) {
|
||||
for (int X = 0; X < ox; X++) {
|
||||
for (int y = 0; y < H; y++) {
|
||||
for (int x = 0; x < W; x++) {
|
||||
float acc = 0.0;
|
||||
for (int c = 0; c < rcout; c++) {
|
||||
acc += ggg[B*groups*rcout*oy*ox + g*rcout*oy*ox + c*oy*ox + Y*ox + X] * \
|
||||
tensw[g*rcout*cin*H*W + c*cin*H*W + ci*H*W + y*W + x];
|
||||
}
|
||||
dx[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + (Y*ys+y)*ix + X*xs+x] += acc;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
""")
|
||||
convdx_prg([bs, groups, cin], None, w.cl, grad_output.cl, dx.cl, *[i32(x) for x in conv_args])
|
||||
return dx
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import pyopencl as cl
|
||||
import numpy as np
|
||||
from ..tensor import Function
|
||||
from ..llops.gpu import GPUBuffer, clbuild, buffer_new, unary_op, binary_op, reduce_op, perm_axis, inner_slice, matmul
|
||||
from ..llops.gpu import GPUBuffer, clbuild, buffer_new, unary_op, binary_op, reduce_op, perm_axis, inner_slice, matmul, conv, convdw, convdx
|
||||
|
||||
i32 = np.int32
|
||||
|
||||
|
@ -143,8 +143,7 @@ class Matmul(Function):
|
|||
assert input.shape[-1] == weight.shape[-2]
|
||||
ret = buffer_new(ctx, list(input.shape[0:-1])+[weight.shape[-1]])
|
||||
ctx.save_for_backward(input, weight)
|
||||
matmul(input, weight, ret)
|
||||
return ret
|
||||
return matmul(input, weight, ret)
|
||||
|
||||
def backward(ctx, grad_output):
|
||||
input, weight = ctx.saved_tensors
|
||||
|
@ -168,40 +167,8 @@ class Conv2D(Function):
|
|||
ctx.save_for_backward(x,w)
|
||||
|
||||
# output buffer
|
||||
ret = buffer_new(ctx, (bs, cout, oy, ox))
|
||||
|
||||
# input = (bs, groups, cin, iy, ix)
|
||||
# weight = (groups, rcout, cin, H, W)
|
||||
# output = (bs, groups, rcout, oy, ox)
|
||||
|
||||
conv = clbuild("conv", """
|
||||
__kernel void conv(__global const float *input, __global const float *weight, __global float *output,
|
||||
int H, int W, int groups, int rcout, int cin, int oy, int ox, int iy, int ix, int ys, int xs) {
|
||||
|
||||
int B = get_global_id(0)/(groups*rcout); // range 0-bs
|
||||
int g = (get_global_id(0)/rcout)%groups;
|
||||
int c = get_global_id(0) % rcout;
|
||||
|
||||
int Y = get_global_id(1); // range 0-oy
|
||||
int X = get_global_id(2); // range 0-ox
|
||||
int IY = Y*ys;
|
||||
int IX = X*xs;
|
||||
|
||||
float acc = 0.0;
|
||||
for (int ci = 0; ci < cin; ci++) {
|
||||
for (int y = IY; y < IY+H; y++) {
|
||||
for (int x = IX; x < IX+W; x++) {
|
||||
acc += input[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + y*ix + x] * \
|
||||
weight[g*rcout*cin*H*W + c*cin*H*W + ci*H*W + (y-IY)*W + (x-IX)];
|
||||
}
|
||||
}
|
||||
}
|
||||
output[B*groups*rcout*oy*ox + g*rcout*oy*ox + c*oy*ox + Y*ox + X] = acc;
|
||||
}""")
|
||||
|
||||
conv_args = H, W, groups, rcout, cin, oy, ox, iy, ix, ys, xs
|
||||
conv([bs*groups*rcout, oy, ox], None, x.cl, w.cl, ret.cl, *[i32(x) for x in conv_args])
|
||||
return ret
|
||||
conv_args = H, W, groups, rcout, cin, oy, ox, iy, ix, ys, xs, bs
|
||||
return conv(x, w, buffer_new(ctx, (bs, cout, oy, ox)), conv_args)
|
||||
|
||||
def backward(ctx, grad_output):
|
||||
bs,_,oy,ox = grad_output.shape
|
||||
|
@ -214,60 +181,7 @@ class Conv2D(Function):
|
|||
assert cout % ctx.groups == 0
|
||||
rcout = cout//ctx.groups
|
||||
|
||||
dx = buffer_new(ctx, (bs, cin_, iy, ix), zero=True)
|
||||
dw = buffer_new(ctx, (cout, cin, H, W))
|
||||
|
||||
# tensx = (bs, groups*cin, iy, ix)
|
||||
# tensw = (groups*rcout, cin, H, W)
|
||||
# ggg = (bs, groups*rout, oy, ox)
|
||||
|
||||
convw = clbuild("convw", """
|
||||
__kernel void convw(__global const float *tensx, __global const float *ggg, __global float *dw,
|
||||
int H, int W, int groups, int rcout, int cin, int oy, int ox, int iy, int ix, int ys, int xs, int bs) {
|
||||
|
||||
int g = get_global_id(0)/(rcout*cin) ; // range 0-groups
|
||||
int c = (get_global_id(0)/(cin)) %rcout; // range 0-rcout
|
||||
int ci = get_global_id(0) % cin; // range 0-cin
|
||||
int y = get_global_id(1); // range 0-H
|
||||
int x = get_global_id(2); // range 0-W
|
||||
|
||||
float acc = 0.0;
|
||||
for (int Y = 0; Y < oy; Y++) {
|
||||
for (int X = 0; X < ox; X++) {
|
||||
for (int B = 0; B < bs; B++) {
|
||||
acc += ggg[B*groups*rcout*oy*ox + +g*rcout*oy*ox + c*oy*ox + Y*ox + X] * \
|
||||
tensx[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + (Y*ys+y)*ix + X*xs+x];
|
||||
}
|
||||
}
|
||||
}
|
||||
dw[get_global_id(0)*H*W + y*W + x] = acc;
|
||||
}""")
|
||||
convx = clbuild("convx", """
|
||||
__kernel void convx(__global const float *tensw, __global const float *ggg, __global float *dx,
|
||||
int H, int W, int groups, int rcout, int cin, int oy, int ox, int iy, int ix, int ys, int xs, int bs) {
|
||||
|
||||
int B = get_global_id(0);
|
||||
int g = get_global_id(1);
|
||||
int ci = get_global_id(2);
|
||||
|
||||
for (int Y = 0; Y < oy; Y++) {
|
||||
for (int X = 0; X < ox; X++) {
|
||||
for (int y = 0; y < H; y++) {
|
||||
for (int x = 0; x < W; x++) {
|
||||
float acc = 0.0;
|
||||
for (int c = 0; c < rcout; c++) {
|
||||
acc += ggg[B*groups*rcout*oy*ox + g*rcout*oy*ox + c*oy*ox + Y*ox + X] * \
|
||||
tensw[g*rcout*cin*H*W + c*cin*H*W + ci*H*W + y*W + x];
|
||||
}
|
||||
dx[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + (Y*ys+y)*ix + X*xs+x] += acc;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
""")
|
||||
|
||||
conv_args = H, W, ctx.groups, rcout, cin, oy, ox, iy, ix, ys, xs, bs
|
||||
convw([ctx.groups*rcout*cin, H, W], None, x.cl, grad_output.cl, dw.cl, *[i32(x) for x in conv_args])
|
||||
convx([bs, ctx.groups, cin], None, w.cl, grad_output.cl, dx.cl, *[i32(x) for x in conv_args])
|
||||
dw = convdw(x, grad_output, buffer_new(ctx, (cout, cin, H, W)), conv_args)
|
||||
dx = convdx(w, grad_output, buffer_new(ctx, (bs, cin_, iy, ix), zero=True), conv_args)
|
||||
return dx, dw
|
||||
|
|
Loading…
Reference in New Issue