remove convt and compute dx with conv

George Hotz 2022-06-15 19:54:15 -07:00
parent 142c88f2e3
commit ff648e9510
6 changed files with 17 additions and 61 deletions


@@ -138,7 +138,7 @@ unary_op (RELU, EXP, LOG, NEG, SIGN) # A -> A
reduce_op (SUM, MAX) # A -> B (smaller size, B has 1 in shape)
binary_op (ADD, SUB, MUL, DIV, POW, CMPEQ) # A + B -> C (all the same size)
movement_op (RESHAPE, PERMUTE, SLICE, EXPAND, FLIP) # A -> B (different size)
processing_op (CONV, CONVT) # A + B -> C
processing_op (CONV) # A + B -> C
```
When tinygrad moves to lazy evaluation, optimizations will happen here.
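
As a rough illustration of the shape contracts in that list (plain numpy, not the real llop signatures, which operate in place on backend buffers):

```python
import numpy as np

# rough numpy analogy of the op categories above; illustration only,
# the real llops work in place on backend-specific buffers
A = np.random.randn(4, 8).astype(np.float32)

relu   = np.maximum(A, 0)               # unary_op:    A -> A, same shape
summed = A.sum(axis=1, keepdims=True)   # reduce_op:   A -> B, reduced axes kept as size 1
added  = A + A                          # binary_op:   A + B -> C, all the same shape
moved  = A.reshape(8, 4)                # movement_op: A -> B, same data, different shape
# processing_op (CONV): input A and weights B -> output C; after this commit it is
# the only processing op, and conv backward is expressed with CONV + movement ops
```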


@@ -78,19 +78,5 @@ def conv(x,w,ret,C):
    tmp[:,g] += np.tensordot(tx[:,g], tw[g], ((1,4,5),(1,2,3)))
  ret[:] = np.moveaxis(tmp,4,2).reshape(C.bs, C.groups*C.rcout, C.oy, C.ox)
def convdx(grad_output,w,dx,C):
  ggg = grad_output.reshape(C.bs, C.groups, C.rcout, C.oy, C.ox)
  tw = w.reshape(C.groups, C.rcout, C.cin, C.H, C.W)
  gdx = dx.reshape((C.bs, C.groups, C.cin, C.iy, C.ix))
  gdx[:] = 0
  for k in range(C.oy*C.ox):
    Y, X = k//C.ox, k%C.ox
    iY,iX = Y*C.ys, X*C.xs
    #gdx[:,:,: , iY:iY+H, iX:iX+W] += np.einsum('igk,gkjyx->igjyx', ggg[:,:,:,Y,X], tw)
    for g in range(C.groups):
      tg = np.dot(ggg[:,g,:,Y,X].reshape(C.bs, -1), tw[g].reshape(C.rcout, -1))
      gdx[:, g, :, iY:iY+C.H, iX:iX+C.W] += tg.reshape((C.bs, C.cin, C.H, C.W))
def processing_op(op,a,b,ret,C):
  if op == ProcessingOps.CONV: conv(a,b,ret,C)
  elif op == ProcessingOps.CONVT: convdx(a,b,ret,C)
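
The deleted convdx scatters the gradient at each output position back into dx. It can go away because that scatter is itself a convolution: pad grad_output by the kernel size, flip the kernel spatially, and swap its input/output channels. A minimal numpy check of that equivalence, with hypothetical sizes, stride 1 and groups 1:

```python
import numpy as np

bs, cin, cout, iy, ix, H, W = 2, 3, 4, 6, 6, 3, 3
oy, ox = iy - H + 1, ix - W + 1
grad = np.random.randn(bs, cout, oy, ox).astype(np.float32)
w = np.random.randn(cout, cin, H, W).astype(np.float32)

# 1) the scatter-accumulate form the deleted convdx used:
#    dx[b,ci,Y+y,X+x] += grad[b,co,Y,X] * w[co,ci,y,x]
dx_scatter = np.zeros((bs, cin, iy, ix), dtype=np.float32)
for Y in range(oy):
  for X in range(ox):
    dx_scatter[:, :, Y:Y+H, X:X+W] += np.einsum('bo,oiyx->biyx', grad[:, :, Y, X], w)

# 2) the same thing as a plain correlation: pad grad by (H-1, W-1) and correlate
#    with the spatially flipped, channel-swapped weights
gpad = np.pad(grad, ((0,0), (0,0), (H-1,H-1), (W-1,W-1)))
wt = np.flip(w, axis=(2,3)).transpose(1, 0, 2, 3)   # (cin, cout, H, W)
dx_conv = np.zeros_like(dx_scatter)
for y in range(iy):
  for x in range(ix):
    dx_conv[:, :, y, x] = np.einsum('boyx,ioyx->bi', gpad[:, :, y:y+H, x:x+W], wt)

assert np.allclose(dx_scatter, dx_conv, atol=1e-4)
```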


@@ -158,37 +158,5 @@ def conv(x,w,ret,C):
  conv_prg([C.bs*C.groups*C.rcout, C.oy, C.ox], None, x.cl, w.cl, ret.cl, *[i32(x) for x in list(C[0:12])+[C.dx, C.dy, C.px, C.py]])
# tensx = (bs, groups*cin, iy, ix)
# tensw = (groups*rcout, cin, H, W)
# ggg = (bs, groups*rout, oy, ox)
def convdx(grad_output,w,dx,C):
  convdx_prg = clbuild("convdx", """
  __kernel void convdx(__global const float *tensw, __global const float *ggg, __global float *dx,
                       int H, int W, int groups, int rcout, int cin, int oy, int ox, int iy, int ix, int ys, int xs, int bs) {
    int B = get_global_id(0);
    int g = get_global_id(1);
    int ci = get_global_id(2);
    for (int Y = 0; Y < iy; Y++) { for (int X = 0; X < ix; X++) {
      dx[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + Y*ix + X] = 0.0;
    } }
    for (int Y = 0; Y < oy; Y++) { for (int X = 0; X < ox; X++) {
      for (int y = 0; y < H; y++) { for (int x = 0; x < W; x++) {
        float acc = 0.0;
        for (int c = 0; c < rcout; c++) {
          acc += ggg[B*groups*rcout*oy*ox + g*rcout*oy*ox + c*oy*ox + Y*ox + X] * \
                 tensw[g*rcout*cin*H*W + c*cin*H*W + ci*H*W + y*W + x];
        }
        dx[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + (Y*ys+y)*ix + X*xs+x] += acc;
      } }
    } }
  }
  """)
  convdx_prg([C.bs, C.groups, C.cin], None, w.cl, grad_output.cl, dx.cl, *[i32(x) for x in C[0:12]])
def processing_op(op,a,b,ret,C):
  if op == ProcessingOps.CONV: conv(a,b,ret,C)
  elif op == ProcessingOps.CONVT: convdx(a,b,ret,C)
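
The GPU kernel being deleted does the same scatter, parallelized over (batch, group, input channel), with the output positions walked serially inside each work-item. Its flat-index arithmetic is plain row-major indexing of the (bs, groups*cin, iy, ix) layout noted in the comments above; a quick sanity check with hypothetical sizes:

```python
import numpy as np

bs, groups, cin, iy, ix = 2, 3, 4, 5, 6
B, g, ci, Y, X = 1, 2, 3, 4, 5

# the kernel's hand-rolled offset into dx ...
kernel_idx = B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + Y*ix + X
# ... is just row-major indexing of a (bs, groups*cin, iy, ix) array
numpy_idx = np.ravel_multi_index((B, g*cin + ci, Y, X), (bs, groups*cin, iy, ix))
assert kernel_idx == numpy_idx
```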


@@ -26,17 +26,6 @@ from tinygrad.llops.ops_cpu import unary_op, binary_op, reduce_op, movement_op
from tinygrad.ops import ProcessingOps
def processing_op(op,x,w,ret,C):
  stride, groups, dilation, padding = (C.ys, C.xs), C.groups, (C.dy, C.dx), (C.py, C.px)
  # stride is the same as doing the full conv and slicing with stride at the end
  # dilation is the same as conving with a larger weight matrix with 0s added
  if op == ProcessingOps.CONV:
    ret[:] = torch.conv2d(x, w, stride=stride, groups=groups, dilation=dilation, padding=padding)
  elif op == ProcessingOps.CONVT:
    if stride == (1,1):
      # strided needs weird "unstride": https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
      # it's 0 insertion between the inputs
      w = w.reshape(C.groups, C.rcout, C.cin, C.H, C.W).flip(3, 4).transpose(2, 1).reshape(C.groups*C.cin, C.rcout, C.H, C.W)
      ret[:] = torch.conv2d(x, w, dilation=dilation, padding=((C.H-1)*C.dy-C.py,(C.W-1)*C.dx-C.px), groups=groups)
    else:
      output_padding = [ret.shape[d+2] - ((x.shape[d+2] - padding[d]*2 - 1) * stride[d] + 1 + dilation[d] * (w.shape[d+2] - 1)) for d in range(2)]
      ret[:] = torch.conv_transpose2d(x, w, padding=padding, stride=stride, groups=groups, output_padding=output_padding, dilation=dilation)
  ret[:] = torch.conv2d(x, w, stride=(C.ys, C.xs), groups=C.groups, dilation=(C.dy, C.dx), padding=(C.py, C.px))
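
The "unstride" the removed comments mention is easy to verify directly in torch: a strided transposed conv equals zero-inserting stride-1 zeros between the gradient elements, flipping the kernel spatially, swapping its in/out channels, and running an ordinary conv with (H-1, W-1) padding. A small sketch with hypothetical sizes (groups=1, dilation=1, no padding):

```python
import torch

torch.manual_seed(0)
bs, cin, cout, iy, ix, H, W, s = 2, 3, 4, 9, 9, 3, 3, 2
x = torch.randn(bs, cin, iy, ix)
w = torch.randn(cout, cin, H, W)

y = torch.conv2d(x, w, stride=s)         # forward conv
g = torch.randn_like(y)                  # stand-in for grad_output

# reference: what the removed CONVT branch computed
dx_ref = torch.conv_transpose2d(g, w, stride=s)

# "unstride": insert s-1 zeros between the elements of g
oy, ox = y.shape[2], y.shape[3]
gz = torch.zeros(bs, cout, (oy - 1) * s + 1, (ox - 1) * s + 1)
gz[:, :, ::s, ::s] = g

# flipped, channel-swapped weight + full padding turns it into a plain conv
wt = w.flip(2, 3).transpose(0, 1)        # (cin, cout, H, W)
dx = torch.conv2d(gz, wt, padding=(H - 1, W - 1))

assert torch.allclose(dx, dx_ref, atol=1e-4)
```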


@@ -174,7 +174,20 @@ class Conv2D(Function):
  def backward(ctx, grad_output):
    x, w, C = ctx.saved_tensors
    dx = ctx.processing_op(ProcessingOps.CONVT, grad_output, w, x.shape, C) if ctx.needs_input_grad[0] else None
    #dx = ctx.processing_op(ProcessingOps.CONVT, grad_output, w, x.shape, C) if ctx.needs_input_grad[0] else None
    xt = grad_output
    if C.xs > 1 or C.ys > 1: # unstride
      xt = ctx.movement_op(MovementOps.RESHAPE, xt, (grad_output.shape[0], grad_output.shape[1], grad_output.shape[2], 1, grad_output.shape[3], 1))
      xt = ctx.movement_op(MovementOps.SLICE, xt, ((0,xt.shape[0]), (0,xt.shape[1]), (0,xt.shape[2]), (0,C.ys), (0,xt.shape[4]), (0,C.xs)))
      xt = ctx.movement_op(MovementOps.RESHAPE, xt, (xt.shape[0], xt.shape[1], xt.shape[2]*C.ys, xt.shape[4]*C.xs))
    wt = ctx.movement_op(MovementOps.RESHAPE, w, (C.groups, C.rcout, C.cin, C.H, C.W))
    wt = ctx.movement_op(MovementOps.FLIP, wt, (3, 4))
    wt = ctx.movement_op(MovementOps.PERMUTE, wt, (0, 2, 1, 3, 4))
    wt = ctx.movement_op(MovementOps.RESHAPE, wt, (C.groups*C.cin, C.rcout, C.H, C.W))
    Cdx = get_conv_args(xt.shape, wt.shape, dilation=(C.dy, C.dx), padding=((C.H-1)*C.dy-C.py,(C.W-1)*C.dx-C.px), groups=C.groups)
    # TODO: this shape can be wrong. support asymmetric padding to remove the slice
    dx = ctx.processing_op(ProcessingOps.CONV, xt, wt, (Cdx.bs, Cdx.cout, Cdx.oy, Cdx.ox), Cdx)
    dx = ctx.movement_op(MovementOps.SLICE, dx, [(0,s) for s in x.shape])
    # compute derivative of weights using ProcessingOps.CONV
    # TODO: there has to be a way to do this without the expand/reduce for at least matmul
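
The non-obvious step in the new backward is the RESHAPE -> SLICE -> RESHAPE "unstride": tinygrad's SLICE zero-pads when the requested range runs past the end of a dimension, so growing the inserted size-1 axes to ys and xs leaves each gradient value followed by ys-1 (resp. xs-1) zeros, and the trailing excess is cut off by the final SLICE to x.shape. A numpy sketch of what that sequence produces, with hypothetical sizes:

```python
import numpy as np

bs, cout, oy, ox, ys, xs = 1, 1, 2, 3, 2, 2
g = np.arange(bs * cout * oy * ox, dtype=np.float32).reshape(bs, cout, oy, ox) + 1

gt = g.reshape(bs, cout, oy, 1, ox, 1)                              # RESHAPE
gt = np.pad(gt, ((0,0), (0,0), (0,0), (0,ys-1), (0,0), (0,xs-1)))   # SLICE past the end == zero pad
gt = gt.reshape(bs, cout, oy * ys, ox * xs)                         # RESHAPE

print(gt[0, 0])
# [[1. 0. 2. 0. 3. 0.]
#  [0. 0. 0. 0. 0. 0.]
#  [4. 0. 5. 0. 6. 0.]
#  [0. 0. 0. 0. 0. 0.]]
```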


@@ -4,7 +4,7 @@ UnaryOps = Enum("UnaryOps", ["RELU", "EXP", "LOG", "NEG", "SIGN"])
BinaryOps = Enum("BinaryOps", ["ADD", "SUB", "MUL", "DIV", "POW", "CMPEQ"])
ReduceOps = Enum("ReduceOps", ["SUM", "MAX"])
MovementOps = Enum("MovementOps", ["RESHAPE", "PERMUTE", "SLICE", "EXPAND", "FLIP"])
ProcessingOps = Enum("ProcessingOps", ["CONV", "CONVT"])
ProcessingOps = Enum("ProcessingOps", ["CONV"])
from tinygrad.shapetracker import ShapeTracker