mirror of https://github.com/commaai/tinygrad.git
remove convt and compute dx with conv
This commit is contained in:
parent
142c88f2e3
commit
ff648e9510
|
@ -138,7 +138,7 @@ unary_op (RELU, EXP, LOG, NEG, SIGN) # A -> A
|
|||
reduce_op (SUM, MAX) # A -> B (smaller size, B has 1 in shape)
|
||||
binary_op (ADD, SUB, MUL, DIV, POW, CMPEQ) # A + B -> C (all the same size)
|
||||
movement_op (RESHAPE, PERMUTE, SLICE, EXPAND, FLIP) # A -> B (different size)
|
||||
processing_op (CONV, CONVT) # A + B -> C
|
||||
processing_op (CONV) # A + B -> C
|
||||
```
|
||||
|
||||
When tinygrad moves to lazy evaluation, optimizations will happen here.
|
||||
|
|
|
@ -78,19 +78,5 @@ def conv(x,w,ret,C):
|
|||
tmp[:,g] += np.tensordot(tx[:,g], tw[g], ((1,4,5),(1,2,3)))
|
||||
ret[:] = np.moveaxis(tmp,4,2).reshape(C.bs, C.groups*C.rcout, C.oy, C.ox)
|
||||
|
||||
def convdx(grad_output,w,dx,C):
|
||||
ggg = grad_output.reshape(C.bs, C.groups, C.rcout, C.oy, C.ox)
|
||||
tw = w.reshape(C.groups, C.rcout, C.cin, C.H, C.W)
|
||||
gdx = dx.reshape((C.bs, C.groups, C.cin, C.iy, C.ix))
|
||||
gdx[:] = 0
|
||||
for k in range(C.oy*C.ox):
|
||||
Y, X = k//C.ox, k%C.ox
|
||||
iY,iX = Y*C.ys, X*C.xs
|
||||
#gdx[:,:,: , iY:iY+H, iX:iX+W] += np.einsum('igk,gkjyx->igjyx', ggg[:,:,:,Y,X], tw)
|
||||
for g in range(C.groups):
|
||||
tg = np.dot(ggg[:,g,:,Y,X].reshape(C.bs, -1), tw[g].reshape(C.rcout, -1))
|
||||
gdx[:, g, :, iY:iY+C.H, iX:iX+C.W] += tg.reshape((C.bs, C.cin, C.H, C.W))
|
||||
|
||||
def processing_op(op,a,b,ret,C):
|
||||
if op == ProcessingOps.CONV: conv(a,b,ret,C)
|
||||
elif op == ProcessingOps.CONVT: convdx(a,b,ret,C)
|
||||
|
|
|
@ -158,37 +158,5 @@ def conv(x,w,ret,C):
|
|||
|
||||
conv_prg([C.bs*C.groups*C.rcout, C.oy, C.ox], None, x.cl, w.cl, ret.cl, *[i32(x) for x in list(C[0:12])+[C.dx, C.dy, C.px, C.py]])
|
||||
|
||||
# tensx = (bs, groups*cin, iy, ix)
|
||||
# tensw = (groups*rcout, cin, H, W)
|
||||
# ggg = (bs, groups*rout, oy, ox)
|
||||
|
||||
def convdx(grad_output,w,dx,C):
|
||||
convdx_prg = clbuild("convdx", """
|
||||
__kernel void convdx(__global const float *tensw, __global const float *ggg, __global float *dx,
|
||||
int H, int W, int groups, int rcout, int cin, int oy, int ox, int iy, int ix, int ys, int xs, int bs) {
|
||||
|
||||
int B = get_global_id(0);
|
||||
int g = get_global_id(1);
|
||||
int ci = get_global_id(2);
|
||||
|
||||
for (int Y = 0; Y < iy; Y++) { for (int X = 0; X < ix; X++) {
|
||||
dx[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + Y*ix + X] = 0.0;
|
||||
} }
|
||||
|
||||
for (int Y = 0; Y < oy; Y++) { for (int X = 0; X < ox; X++) {
|
||||
for (int y = 0; y < H; y++) { for (int x = 0; x < W; x++) {
|
||||
float acc = 0.0;
|
||||
for (int c = 0; c < rcout; c++) {
|
||||
acc += ggg[B*groups*rcout*oy*ox + g*rcout*oy*ox + c*oy*ox + Y*ox + X] * \
|
||||
tensw[g*rcout*cin*H*W + c*cin*H*W + ci*H*W + y*W + x];
|
||||
}
|
||||
dx[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + (Y*ys+y)*ix + X*xs+x] += acc;
|
||||
} }
|
||||
} }
|
||||
}
|
||||
""")
|
||||
convdx_prg([C.bs, C.groups, C.cin], None, w.cl, grad_output.cl, dx.cl, *[i32(x) for x in C[0:12]])
|
||||
|
||||
def processing_op(op,a,b,ret,C):
|
||||
if op == ProcessingOps.CONV: conv(a,b,ret,C)
|
||||
elif op == ProcessingOps.CONVT: convdx(a,b,ret,C)
|
||||
|
|
|
@ -26,17 +26,6 @@ from tinygrad.llops.ops_cpu import unary_op, binary_op, reduce_op, movement_op
|
|||
from tinygrad.ops import ProcessingOps
|
||||
|
||||
def processing_op(op,x,w,ret,C):
|
||||
stride, groups, dilation, padding = (C.ys, C.xs), C.groups, (C.dy, C.dx), (C.py, C.px)
|
||||
# stride is the same as doing the full conv and slicing with stride at the end
|
||||
# dilation is the same as conving with a larger weight matrix with 0s added
|
||||
if op == ProcessingOps.CONV:
|
||||
ret[:] = torch.conv2d(x, w, stride=stride, groups=groups, dilation=dilation, padding=padding)
|
||||
elif op == ProcessingOps.CONVT:
|
||||
if stride == (1,1):
|
||||
# strided needs weird "unstride": https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
|
||||
# it's 0 insertion between the inputs
|
||||
w = w.reshape(C.groups, C.rcout, C.cin, C.H, C.W).flip(3, 4).transpose(2, 1).reshape(C.groups*C.cin, C.rcout, C.H, C.W)
|
||||
ret[:] = torch.conv2d(x, w, dilation=dilation, padding=((C.H-1)*C.dy-C.py,(C.W-1)*C.dx-C.px), groups=groups)
|
||||
else:
|
||||
output_padding = [ret.shape[d+2] - ((x.shape[d+2] - padding[d]*2 - 1) * stride[d] + 1 + dilation[d] * (w.shape[d+2] - 1)) for d in range(2)]
|
||||
ret[:] = torch.conv_transpose2d(x, w, padding=padding, stride=stride, groups=groups, output_padding=output_padding, dilation=dilation)
|
||||
ret[:] = torch.conv2d(x, w, stride=(C.ys, C.xs), groups=C.groups, dilation=(C.dy, C.dx), padding=(C.py, C.px))
|
||||
|
|
|
@ -174,7 +174,20 @@ class Conv2D(Function):
|
|||
|
||||
def backward(ctx, grad_output):
|
||||
x, w, C = ctx.saved_tensors
|
||||
dx = ctx.processing_op(ProcessingOps.CONVT, grad_output, w, x.shape, C) if ctx.needs_input_grad[0] else None
|
||||
#dx = ctx.processing_op(ProcessingOps.CONVT, grad_output, w, x.shape, C) if ctx.needs_input_grad[0] else None
|
||||
xt = grad_output
|
||||
if C.xs > 1 or C.ys > 1: # unstride
|
||||
xt = ctx.movement_op(MovementOps.RESHAPE, xt, (grad_output.shape[0], grad_output.shape[1], grad_output.shape[2], 1, grad_output.shape[3], 1))
|
||||
xt = ctx.movement_op(MovementOps.SLICE, xt, ((0,xt.shape[0]), (0,xt.shape[1]), (0,xt.shape[2]), (0,C.ys), (0,xt.shape[4]), (0,C.xs)))
|
||||
xt = ctx.movement_op(MovementOps.RESHAPE, xt, (xt.shape[0], xt.shape[1], xt.shape[2]*C.ys, xt.shape[4]*C.xs))
|
||||
wt = ctx.movement_op(MovementOps.RESHAPE, w, (C.groups, C.rcout, C.cin, C.H, C.W))
|
||||
wt = ctx.movement_op(MovementOps.FLIP, wt, (3, 4))
|
||||
wt = ctx.movement_op(MovementOps.PERMUTE, wt, (0, 2, 1, 3, 4))
|
||||
wt = ctx.movement_op(MovementOps.RESHAPE, wt, (C.groups*C.cin, C.rcout, C.H, C.W))
|
||||
Cdx = get_conv_args(xt.shape, wt.shape, dilation=(C.dy, C.dx), padding=((C.H-1)*C.dy-C.py,(C.W-1)*C.dx-C.px), groups=C.groups)
|
||||
# TODO: this shape can be wrong. support asymmetric padding to remove the slice
|
||||
dx = ctx.processing_op(ProcessingOps.CONV, xt, wt, (Cdx.bs, Cdx.cout, Cdx.oy, Cdx.ox), Cdx)
|
||||
dx = ctx.movement_op(MovementOps.SLICE, dx, [(0,s) for s in x.shape])
|
||||
|
||||
# compute derivative of weights using ProcessingOps.CONV
|
||||
# TODO: there has to be a way to do this without the expand/reduce for at least matmul
|
||||
|
|
|
@ -4,7 +4,7 @@ UnaryOps = Enum("UnaryOps", ["RELU", "EXP", "LOG", "NEG", "SIGN"])
|
|||
BinaryOps = Enum("BinaryOps", ["ADD", "SUB", "MUL", "DIV", "POW", "CMPEQ"])
|
||||
ReduceOps = Enum("ReduceOps", ["SUM", "MAX"])
|
||||
MovementOps = Enum("MovementOps", ["RESHAPE", "PERMUTE", "SLICE", "EXPAND", "FLIP"])
|
||||
ProcessingOps = Enum("ProcessingOps", ["CONV", "CONVT"])
|
||||
ProcessingOps = Enum("ProcessingOps", ["CONV"])
|
||||
|
||||
from tinygrad.shapetracker import ShapeTracker
|
||||
|
||||
|
|
Loading…
Reference in New Issue