From 6d983662145669e533e1dc16bdfc562ecc6b0d55 Mon Sep 17 00:00:00 2001 From: George Hotz Date: Wed, 15 Jun 2022 12:05:11 -0700 Subject: [PATCH] move CONVDW out of llops --- README.md | 12 ++++++------ tinygrad/llops/ops_cpu.py | 10 ---------- tinygrad/llops/ops_gpu.py | 22 ---------------------- tinygrad/llops/ops_torch.py | 13 ------------- tinygrad/mlops.py | 19 +++++++++++++++++-- tinygrad/ops.py | 2 +- 6 files changed, 24 insertions(+), 54 deletions(-) diff --git a/README.md b/README.md index a8da3260..47552076 100644 --- a/README.md +++ b/README.md @@ -133,12 +133,12 @@ You no longer need to write mlops for a new accelerator The autodiff stuff is all in mlops now so you can focus on the raw operations ``` -Buffer # class of memory on this device -unary_op (RELU, EXP, LOG, NEG, SIGN) # A -> A -reduce_op (SUM, MAX) # A -> B (smaller size, B has 1 in shape) -binary_op (ADD, SUB, MUL, DIV, POW, CMPEQ) # A + B -> C (all the same size) -movement_op (RESHAPE, PERMUTE, SLICE, EXPAND) # A -> B (different size) -processing_op (CONV, CONVT, CONVDW) # A + B -> C +Buffer # class of memory on this device +unary_op (RELU, EXP, LOG, NEG, SIGN) # A -> A +reduce_op (SUM, MAX) # A -> B (smaller size, B has 1 in shape) +binary_op (ADD, SUB, MUL, DIV, POW, CMPEQ) # A + B -> C (all the same size) +movement_op (RESHAPE, PERMUTE, SLICE, EXPAND, FLIP) # A -> B (different size) +processing_op (CONV, CONVT) # A + B -> C ``` When tinygrad moves to lazy evaluation, optimizations will happen here. diff --git a/tinygrad/llops/ops_cpu.py b/tinygrad/llops/ops_cpu.py index 965d3a76..b4ecac9a 100644 --- a/tinygrad/llops/ops_cpu.py +++ b/tinygrad/llops/ops_cpu.py @@ -77,15 +77,6 @@ def conv(x,w,ret,C): tmp[:,g] += np.tensordot(tx[:,g], tw[g], ((1,4,5),(1,2,3))) ret[:] = np.moveaxis(tmp,4,2).reshape(C.bs, C.groups*C.rcout, C.oy, C.ox) -def convdw(x,grad_output,dw,C): - tx = get_tx(x, C) - ggg = grad_output.reshape(C.bs, C.groups, C.rcout, C.oy, C.ox) - gdw = dw.reshape((C.groups, C.rcout, C.cin, C.H, C.W)) - gdw[:] = 0 - for g in range(C.groups): - #'ikYX,ijYXyx -> kjyx' - gdw[g] += np.tensordot(ggg[:,g], tx[:,g], ((0,2,3),(0,2,3))) - def convdx(grad_output,w,dx,C): ggg = grad_output.reshape(C.bs, C.groups, C.rcout, C.oy, C.ox) tw = w.reshape(C.groups, C.rcout, C.cin, C.H, C.W) @@ -102,4 +93,3 @@ def convdx(grad_output,w,dx,C): def processing_op(op,a,b,ret,C): if op == ProcessingOps.CONV: conv(a,b,ret,C) elif op == ProcessingOps.CONVT: convdx(a,b,ret,C) - elif op == ProcessingOps.CONVDW: convdw(a,b,ret,C) diff --git a/tinygrad/llops/ops_gpu.py b/tinygrad/llops/ops_gpu.py index 7593abd4..3ecad889 100644 --- a/tinygrad/llops/ops_gpu.py +++ b/tinygrad/llops/ops_gpu.py @@ -159,28 +159,6 @@ def conv(x,w,ret,C): # tensw = (groups*rcout, cin, H, W) # ggg = (bs, groups*rout, oy, ox) -def convdw(x,grad_output,dw,C): - convdw_prg = clbuild("convdw", """ - __kernel void convdw(__global const float *tensx, __global const float *ggg, __global float *dw, - int H, int W, int groups, int rcout, int cin, int oy, int ox, int iy, int ix, int ys, int xs, int bs) { - - int g = get_global_id(0)/(rcout*cin) ; // range 0-groups - int c = (get_global_id(0)/(cin)) %rcout; // range 0-rcout - int ci = get_global_id(0) % cin; // range 0-cin - int y = get_global_id(1); // range 0-H - int x = get_global_id(2); // range 0-W - - float acc = 0.0; - for (int Y = 0; Y < oy; Y++) { for (int X = 0; X < ox; X++) { - for (int B = 0; B < bs; B++) { - acc += ggg[B*groups*rcout*oy*ox + +g*rcout*oy*ox + c*oy*ox + Y*ox + X] * \ - tensx[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + (Y*ys+y)*ix + X*xs+x]; - } - } } - dw[get_global_id(0)*H*W + y*W + x] = acc; - }""") - convdw_prg([C.groups*C.rcout*C.cin, C.H, C.W], None, x.cl, grad_output.cl, dw.cl, *[i32(x) for x in C[0:12]]) - def convdx(grad_output,w,dx,C): convdx_prg = clbuild("convdx", """ __kernel void convdx(__global const float *tensw, __global const float *ggg, __global float *dx, diff --git a/tinygrad/llops/ops_torch.py b/tinygrad/llops/ops_torch.py index 5abf3f0f..4032f1ac 100644 --- a/tinygrad/llops/ops_torch.py +++ b/tinygrad/llops/ops_torch.py @@ -25,17 +25,6 @@ from tinygrad.llops.ops_cpu import unary_op, binary_op, reduce_op, movement_op from tinygrad.ops import ProcessingOps -def convdw(x,grad_output,dw,C): - # NOTE: torch.nn.grad.conv2d_weight is wrong for groups in pytorch, wonder who it affects - # https://github.com/pytorch/pytorch/issues/51430 - grad_output = grad_output.reshape(C.bs * C.groups, C.rcout, C.oy, C.ox).repeat(1, C.cin, 1, 1) - grad_output = grad_output.reshape(C.bs * C.groups * C.cin * C.rcout, 1, C.oy, C.ox) - x = x.reshape(1, C.bs * C.groups * C.cin, C.iy, C.ix) - # NOTE: this conv2d always has batch size 1. - grad_weight = torch.conv2d(x, grad_output, stride=(C.dy, C.dx), dilation=(C.ys, C.xs), groups=C.bs*C.groups*C.cin) - grad_weight = grad_weight.reshape(C.bs, C.groups, C.cin, C.rcout, *grad_weight.shape[2:]).transpose(3, 2).sum(dim=0) - dw[:] = grad_weight.reshape(C.groups*C.rcout, C.cin, *grad_weight.shape[3:])[:, :, :dw.shape[2], :dw.shape[3]] - def processing_op(op,x,w,ret,C): stride, groups, dilation = (C.ys, C.xs), C.groups, (C.dy, C.dx) if op == ProcessingOps.CONV: @@ -48,5 +37,3 @@ def processing_op(op,x,w,ret,C): else: output_padding = [ret.shape[d+2] - ((x.shape[d+2] - 1) * stride[d] + 1 + dilation[d] * (w.shape[d+2] - 1)) for d in range(2)] ret[:] = torch.conv_transpose2d(x, w, stride=stride, groups=groups, output_padding=output_padding, dilation=dilation) - elif op == ProcessingOps.CONVDW: - convdw(x,w,ret,C) diff --git a/tinygrad/mlops.py b/tinygrad/mlops.py index 84ccbdf3..f4700ee4 100644 --- a/tinygrad/mlops.py +++ b/tinygrad/mlops.py @@ -170,10 +170,25 @@ class Conv2D(Function): def forward(ctx, x, w, stride=1, groups=1, dilation=1): C = get_conv_args(x.shape, w.shape, stride, groups, dilation=dilation) ctx.save_for_backward(x,w,C) - return ctx.processing_op(ProcessingOps.CONV, x, w, (C.bs, C.groups*C.rcout, C.oy, C.ox), C) + return ctx.processing_op(ProcessingOps.CONV, x, w, (C.bs, C.cout, C.oy, C.ox), C) def backward(ctx, grad_output): x, w, C = ctx.saved_tensors dx = ctx.processing_op(ProcessingOps.CONVT, grad_output, w, x.shape, C) if ctx.needs_input_grad[0] else None - dw = ctx.processing_op(ProcessingOps.CONVDW, x, grad_output, w.shape, C) if ctx.needs_input_grad[1] else None + + # compute derivative of weights using ProcessingOps.CONV + xdw = ctx.movement_op(MovementOps.RESHAPE, x, (1, C.bs * C.groups * C.cin, C.iy, C.ix)) + grad_output_dw = ctx.movement_op(MovementOps.RESHAPE, grad_output, (C.bs * C.groups, 1, C.rcout, C.oy, C.ox)) + # this expand is slow + grad_output_dw = ctx.movement_op(MovementOps.EXPAND, grad_output_dw, (C.bs * C.groups, C.cin, C.rcout, C.oy, C.ox)) + grad_output_dw = ctx.movement_op(MovementOps.RESHAPE, grad_output_dw, (C.bs * C.groups * C.cin * C.rcout, 1, C.oy, C.ox)) + Cdw = get_conv_args(xdw.shape, grad_output_dw.shape, stride=(C.dy, C.dx), dilation=(C.ys, C.xs), groups=C.bs*C.groups*C.cin) + grad_weight = ctx.processing_op(ProcessingOps.CONV, xdw, grad_output_dw, (Cdw.bs, Cdw.cout, Cdw.oy, Cdw.ox), Cdw) + grad_weight = ctx.movement_op(MovementOps.RESHAPE, grad_weight, (C.bs, C.groups, C.cin, C.rcout, Cdw.oy, Cdw.ox)) + # sum across the batch dimension + grad_weight = ctx.reduce_op(ReduceOps.SUM, grad_weight, (1, *grad_weight.shape[1:])) + # flip channels out and in + grad_weight = ctx.movement_op(MovementOps.PERMUTE, grad_weight, (0,1,3,2,4,5)) + grad_weight = ctx.movement_op(MovementOps.RESHAPE, grad_weight, (C.groups*C.rcout, C.cin, Cdw.oy, Cdw.ox)) + dw = ctx.movement_op(MovementOps.SLICE, grad_weight, ((0, grad_weight.shape[0]), (0, grad_weight.shape[1]), (0, w.shape[2]), (0, w.shape[3]))) return dx, dw \ No newline at end of file diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 71d74ae1..3c1dae1e 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -4,7 +4,7 @@ UnaryOps = Enum("UnaryOps", ["RELU", "EXP", "LOG", "NEG", "SIGN"]) BinaryOps = Enum("BinaryOps", ["ADD", "SUB", "MUL", "DIV", "POW", "CMPEQ"]) ReduceOps = Enum("ReduceOps", ["SUM", "MAX"]) MovementOps = Enum("MovementOps", ["RESHAPE", "PERMUTE", "SLICE", "EXPAND", "FLIP"]) -ProcessingOps = Enum("ProcessingOps", ["CONV", "CONVT", "CONVDW"]) +ProcessingOps = Enum("ProcessingOps", ["CONV", "CONVT"]) from tinygrad.shapetracker import ShapeTracker