From ff648e951064dd9827dddd318c9cc4ad7c67f8ed Mon Sep 17 00:00:00 2001 From: George Hotz Date: Wed, 15 Jun 2022 19:54:15 -0700 Subject: [PATCH] remove convt and compute dx with conv --- README.md | 2 +- tinygrad/llops/ops_cpu.py | 14 -------------- tinygrad/llops/ops_gpu.py | 32 -------------------------------- tinygrad/llops/ops_torch.py | 13 +------------ tinygrad/mlops.py | 15 ++++++++++++++- tinygrad/ops.py | 2 +- 6 files changed, 17 insertions(+), 61 deletions(-) diff --git a/README.md b/README.md index 47552076..151f053d 100644 --- a/README.md +++ b/README.md @@ -138,7 +138,7 @@ unary_op (RELU, EXP, LOG, NEG, SIGN) # A -> A reduce_op (SUM, MAX) # A -> B (smaller size, B has 1 in shape) binary_op (ADD, SUB, MUL, DIV, POW, CMPEQ) # A + B -> C (all the same size) movement_op (RESHAPE, PERMUTE, SLICE, EXPAND, FLIP) # A -> B (different size) -processing_op (CONV, CONVT) # A + B -> C +processing_op (CONV) # A + B -> C ``` When tinygrad moves to lazy evaluation, optimizations will happen here. diff --git a/tinygrad/llops/ops_cpu.py b/tinygrad/llops/ops_cpu.py index 67795fb8..95727f57 100644 --- a/tinygrad/llops/ops_cpu.py +++ b/tinygrad/llops/ops_cpu.py @@ -78,19 +78,5 @@ def conv(x,w,ret,C): tmp[:,g] += np.tensordot(tx[:,g], tw[g], ((1,4,5),(1,2,3))) ret[:] = np.moveaxis(tmp,4,2).reshape(C.bs, C.groups*C.rcout, C.oy, C.ox) -def convdx(grad_output,w,dx,C): - ggg = grad_output.reshape(C.bs, C.groups, C.rcout, C.oy, C.ox) - tw = w.reshape(C.groups, C.rcout, C.cin, C.H, C.W) - gdx = dx.reshape((C.bs, C.groups, C.cin, C.iy, C.ix)) - gdx[:] = 0 - for k in range(C.oy*C.ox): - Y, X = k//C.ox, k%C.ox - iY,iX = Y*C.ys, X*C.xs - #gdx[:,:,: , iY:iY+H, iX:iX+W] += np.einsum('igk,gkjyx->igjyx', ggg[:,:,:,Y,X], tw) - for g in range(C.groups): - tg = np.dot(ggg[:,g,:,Y,X].reshape(C.bs, -1), tw[g].reshape(C.rcout, -1)) - gdx[:, g, :, iY:iY+C.H, iX:iX+C.W] += tg.reshape((C.bs, C.cin, C.H, C.W)) - def processing_op(op,a,b,ret,C): if op == ProcessingOps.CONV: conv(a,b,ret,C) - elif op == ProcessingOps.CONVT: convdx(a,b,ret,C) diff --git a/tinygrad/llops/ops_gpu.py b/tinygrad/llops/ops_gpu.py index 54be8e7d..8e1fe04a 100644 --- a/tinygrad/llops/ops_gpu.py +++ b/tinygrad/llops/ops_gpu.py @@ -158,37 +158,5 @@ def conv(x,w,ret,C): conv_prg([C.bs*C.groups*C.rcout, C.oy, C.ox], None, x.cl, w.cl, ret.cl, *[i32(x) for x in list(C[0:12])+[C.dx, C.dy, C.px, C.py]]) -# tensx = (bs, groups*cin, iy, ix) -# tensw = (groups*rcout, cin, H, W) -# ggg = (bs, groups*rout, oy, ox) - -def convdx(grad_output,w,dx,C): - convdx_prg = clbuild("convdx", """ - __kernel void convdx(__global const float *tensw, __global const float *ggg, __global float *dx, - int H, int W, int groups, int rcout, int cin, int oy, int ox, int iy, int ix, int ys, int xs, int bs) { - - int B = get_global_id(0); - int g = get_global_id(1); - int ci = get_global_id(2); - - for (int Y = 0; Y < iy; Y++) { for (int X = 0; X < ix; X++) { - dx[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + Y*ix + X] = 0.0; - } } - - for (int Y = 0; Y < oy; Y++) { for (int X = 0; X < ox; X++) { - for (int y = 0; y < H; y++) { for (int x = 0; x < W; x++) { - float acc = 0.0; - for (int c = 0; c < rcout; c++) { - acc += ggg[B*groups*rcout*oy*ox + g*rcout*oy*ox + c*oy*ox + Y*ox + X] * \ - tensw[g*rcout*cin*H*W + c*cin*H*W + ci*H*W + y*W + x]; - } - dx[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + (Y*ys+y)*ix + X*xs+x] += acc; - } } - } } - } - """) - convdx_prg([C.bs, C.groups, C.cin], None, w.cl, grad_output.cl, dx.cl, *[i32(x) for x in C[0:12]]) - def processing_op(op,a,b,ret,C): if op == ProcessingOps.CONV: conv(a,b,ret,C) - elif op == ProcessingOps.CONVT: convdx(a,b,ret,C) diff --git a/tinygrad/llops/ops_torch.py b/tinygrad/llops/ops_torch.py index 62944bbb..3826b328 100644 --- a/tinygrad/llops/ops_torch.py +++ b/tinygrad/llops/ops_torch.py @@ -26,17 +26,6 @@ from tinygrad.llops.ops_cpu import unary_op, binary_op, reduce_op, movement_op from tinygrad.ops import ProcessingOps def processing_op(op,x,w,ret,C): - stride, groups, dilation, padding = (C.ys, C.xs), C.groups, (C.dy, C.dx), (C.py, C.px) # stride is the same as doing the full conv and slicing with stride at the end # dilation is the same as conving with a larger weight matrix with 0s added - if op == ProcessingOps.CONV: - ret[:] = torch.conv2d(x, w, stride=stride, groups=groups, dilation=dilation, padding=padding) - elif op == ProcessingOps.CONVT: - if stride == (1,1): - # strided needs weird "unstride": https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md - # it's 0 insertion between the inputs - w = w.reshape(C.groups, C.rcout, C.cin, C.H, C.W).flip(3, 4).transpose(2, 1).reshape(C.groups*C.cin, C.rcout, C.H, C.W) - ret[:] = torch.conv2d(x, w, dilation=dilation, padding=((C.H-1)*C.dy-C.py,(C.W-1)*C.dx-C.px), groups=groups) - else: - output_padding = [ret.shape[d+2] - ((x.shape[d+2] - padding[d]*2 - 1) * stride[d] + 1 + dilation[d] * (w.shape[d+2] - 1)) for d in range(2)] - ret[:] = torch.conv_transpose2d(x, w, padding=padding, stride=stride, groups=groups, output_padding=output_padding, dilation=dilation) + ret[:] = torch.conv2d(x, w, stride=(C.ys, C.xs), groups=C.groups, dilation=(C.dy, C.dx), padding=(C.py, C.px)) diff --git a/tinygrad/mlops.py b/tinygrad/mlops.py index 459fe2f4..98b57594 100644 --- a/tinygrad/mlops.py +++ b/tinygrad/mlops.py @@ -174,7 +174,20 @@ class Conv2D(Function): def backward(ctx, grad_output): x, w, C = ctx.saved_tensors - dx = ctx.processing_op(ProcessingOps.CONVT, grad_output, w, x.shape, C) if ctx.needs_input_grad[0] else None + #dx = ctx.processing_op(ProcessingOps.CONVT, grad_output, w, x.shape, C) if ctx.needs_input_grad[0] else None + xt = grad_output + if C.xs > 1 or C.ys > 1: # unstride + xt = ctx.movement_op(MovementOps.RESHAPE, xt, (grad_output.shape[0], grad_output.shape[1], grad_output.shape[2], 1, grad_output.shape[3], 1)) + xt = ctx.movement_op(MovementOps.SLICE, xt, ((0,xt.shape[0]), (0,xt.shape[1]), (0,xt.shape[2]), (0,C.ys), (0,xt.shape[4]), (0,C.xs))) + xt = ctx.movement_op(MovementOps.RESHAPE, xt, (xt.shape[0], xt.shape[1], xt.shape[2]*C.ys, xt.shape[4]*C.xs)) + wt = ctx.movement_op(MovementOps.RESHAPE, w, (C.groups, C.rcout, C.cin, C.H, C.W)) + wt = ctx.movement_op(MovementOps.FLIP, wt, (3, 4)) + wt = ctx.movement_op(MovementOps.PERMUTE, wt, (0, 2, 1, 3, 4)) + wt = ctx.movement_op(MovementOps.RESHAPE, wt, (C.groups*C.cin, C.rcout, C.H, C.W)) + Cdx = get_conv_args(xt.shape, wt.shape, dilation=(C.dy, C.dx), padding=((C.H-1)*C.dy-C.py,(C.W-1)*C.dx-C.px), groups=C.groups) + # TODO: this shape can be wrong. support asymmetric padding to remove the slice + dx = ctx.processing_op(ProcessingOps.CONV, xt, wt, (Cdx.bs, Cdx.cout, Cdx.oy, Cdx.ox), Cdx) + dx = ctx.movement_op(MovementOps.SLICE, dx, [(0,s) for s in x.shape]) # compute derivative of weights using ProcessingOps.CONV # TODO: there has to be a way to do this without the expand/reduce for at least matmul diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 3c1dae1e..d3e7fa74 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -4,7 +4,7 @@ UnaryOps = Enum("UnaryOps", ["RELU", "EXP", "LOG", "NEG", "SIGN"]) BinaryOps = Enum("BinaryOps", ["ADD", "SUB", "MUL", "DIV", "POW", "CMPEQ"]) ReduceOps = Enum("ReduceOps", ["SUM", "MAX"]) MovementOps = Enum("MovementOps", ["RESHAPE", "PERMUTE", "SLICE", "EXPAND", "FLIP"]) -ProcessingOps = Enum("ProcessingOps", ["CONV", "CONVT"]) +ProcessingOps = Enum("ProcessingOps", ["CONV"]) from tinygrad.shapetracker import ShapeTracker