move CONVDW out of llops

This commit is contained in:
George Hotz 2022-06-15 12:05:11 -07:00
parent fef6c82491
commit 6d98366214
6 changed files with 24 additions and 54 deletions

View File

@ -137,8 +137,8 @@ Buffer # class of memory on this devic
unary_op (RELU, EXP, LOG, NEG, SIGN) # A -> A
reduce_op (SUM, MAX) # A -> B (smaller size, B has 1 in shape)
binary_op (ADD, SUB, MUL, DIV, POW, CMPEQ) # A + B -> C (all the same size)
movement_op (RESHAPE, PERMUTE, SLICE, EXPAND) # A -> B (different size)
processing_op (CONV, CONVT, CONVDW) # A + B -> C
movement_op (RESHAPE, PERMUTE, SLICE, EXPAND, FLIP) # A -> B (different size)
processing_op (CONV, CONVT) # A + B -> C
```
When tinygrad moves to lazy evaluation, optimizations will happen here.

View File

@ -77,15 +77,6 @@ def conv(x,w,ret,C):
tmp[:,g] += np.tensordot(tx[:,g], tw[g], ((1,4,5),(1,2,3)))
ret[:] = np.moveaxis(tmp,4,2).reshape(C.bs, C.groups*C.rcout, C.oy, C.ox)
def convdw(x,grad_output,dw,C):
tx = get_tx(x, C)
ggg = grad_output.reshape(C.bs, C.groups, C.rcout, C.oy, C.ox)
gdw = dw.reshape((C.groups, C.rcout, C.cin, C.H, C.W))
gdw[:] = 0
for g in range(C.groups):
#'ikYX,ijYXyx -> kjyx'
gdw[g] += np.tensordot(ggg[:,g], tx[:,g], ((0,2,3),(0,2,3)))
def convdx(grad_output,w,dx,C):
ggg = grad_output.reshape(C.bs, C.groups, C.rcout, C.oy, C.ox)
tw = w.reshape(C.groups, C.rcout, C.cin, C.H, C.W)
@ -102,4 +93,3 @@ def convdx(grad_output,w,dx,C):
def processing_op(op,a,b,ret,C):
if op == ProcessingOps.CONV: conv(a,b,ret,C)
elif op == ProcessingOps.CONVT: convdx(a,b,ret,C)
elif op == ProcessingOps.CONVDW: convdw(a,b,ret,C)

View File

@ -159,28 +159,6 @@ def conv(x,w,ret,C):
# tensw = (groups*rcout, cin, H, W)
# ggg = (bs, groups*rout, oy, ox)
def convdw(x,grad_output,dw,C):
convdw_prg = clbuild("convdw", """
__kernel void convdw(__global const float *tensx, __global const float *ggg, __global float *dw,
int H, int W, int groups, int rcout, int cin, int oy, int ox, int iy, int ix, int ys, int xs, int bs) {
int g = get_global_id(0)/(rcout*cin) ; // range 0-groups
int c = (get_global_id(0)/(cin)) %rcout; // range 0-rcout
int ci = get_global_id(0) % cin; // range 0-cin
int y = get_global_id(1); // range 0-H
int x = get_global_id(2); // range 0-W
float acc = 0.0;
for (int Y = 0; Y < oy; Y++) { for (int X = 0; X < ox; X++) {
for (int B = 0; B < bs; B++) {
acc += ggg[B*groups*rcout*oy*ox + +g*rcout*oy*ox + c*oy*ox + Y*ox + X] * \
tensx[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + (Y*ys+y)*ix + X*xs+x];
}
} }
dw[get_global_id(0)*H*W + y*W + x] = acc;
}""")
convdw_prg([C.groups*C.rcout*C.cin, C.H, C.W], None, x.cl, grad_output.cl, dw.cl, *[i32(x) for x in C[0:12]])
def convdx(grad_output,w,dx,C):
convdx_prg = clbuild("convdx", """
__kernel void convdx(__global const float *tensw, __global const float *ggg, __global float *dx,

View File

@ -25,17 +25,6 @@ from tinygrad.llops.ops_cpu import unary_op, binary_op, reduce_op, movement_op
from tinygrad.ops import ProcessingOps
def convdw(x,grad_output,dw,C):
# NOTE: torch.nn.grad.conv2d_weight is wrong for groups in pytorch, wonder who it affects
# https://github.com/pytorch/pytorch/issues/51430
grad_output = grad_output.reshape(C.bs * C.groups, C.rcout, C.oy, C.ox).repeat(1, C.cin, 1, 1)
grad_output = grad_output.reshape(C.bs * C.groups * C.cin * C.rcout, 1, C.oy, C.ox)
x = x.reshape(1, C.bs * C.groups * C.cin, C.iy, C.ix)
# NOTE: this conv2d always has batch size 1.
grad_weight = torch.conv2d(x, grad_output, stride=(C.dy, C.dx), dilation=(C.ys, C.xs), groups=C.bs*C.groups*C.cin)
grad_weight = grad_weight.reshape(C.bs, C.groups, C.cin, C.rcout, *grad_weight.shape[2:]).transpose(3, 2).sum(dim=0)
dw[:] = grad_weight.reshape(C.groups*C.rcout, C.cin, *grad_weight.shape[3:])[:, :, :dw.shape[2], :dw.shape[3]]
def processing_op(op,x,w,ret,C):
stride, groups, dilation = (C.ys, C.xs), C.groups, (C.dy, C.dx)
if op == ProcessingOps.CONV:
@ -48,5 +37,3 @@ def processing_op(op,x,w,ret,C):
else:
output_padding = [ret.shape[d+2] - ((x.shape[d+2] - 1) * stride[d] + 1 + dilation[d] * (w.shape[d+2] - 1)) for d in range(2)]
ret[:] = torch.conv_transpose2d(x, w, stride=stride, groups=groups, output_padding=output_padding, dilation=dilation)
elif op == ProcessingOps.CONVDW:
convdw(x,w,ret,C)

View File

@ -170,10 +170,25 @@ class Conv2D(Function):
def forward(ctx, x, w, stride=1, groups=1, dilation=1):
C = get_conv_args(x.shape, w.shape, stride, groups, dilation=dilation)
ctx.save_for_backward(x,w,C)
return ctx.processing_op(ProcessingOps.CONV, x, w, (C.bs, C.groups*C.rcout, C.oy, C.ox), C)
return ctx.processing_op(ProcessingOps.CONV, x, w, (C.bs, C.cout, C.oy, C.ox), C)
def backward(ctx, grad_output):
x, w, C = ctx.saved_tensors
dx = ctx.processing_op(ProcessingOps.CONVT, grad_output, w, x.shape, C) if ctx.needs_input_grad[0] else None
dw = ctx.processing_op(ProcessingOps.CONVDW, x, grad_output, w.shape, C) if ctx.needs_input_grad[1] else None
# compute derivative of weights using ProcessingOps.CONV
xdw = ctx.movement_op(MovementOps.RESHAPE, x, (1, C.bs * C.groups * C.cin, C.iy, C.ix))
grad_output_dw = ctx.movement_op(MovementOps.RESHAPE, grad_output, (C.bs * C.groups, 1, C.rcout, C.oy, C.ox))
# this expand is slow
grad_output_dw = ctx.movement_op(MovementOps.EXPAND, grad_output_dw, (C.bs * C.groups, C.cin, C.rcout, C.oy, C.ox))
grad_output_dw = ctx.movement_op(MovementOps.RESHAPE, grad_output_dw, (C.bs * C.groups * C.cin * C.rcout, 1, C.oy, C.ox))
Cdw = get_conv_args(xdw.shape, grad_output_dw.shape, stride=(C.dy, C.dx), dilation=(C.ys, C.xs), groups=C.bs*C.groups*C.cin)
grad_weight = ctx.processing_op(ProcessingOps.CONV, xdw, grad_output_dw, (Cdw.bs, Cdw.cout, Cdw.oy, Cdw.ox), Cdw)
grad_weight = ctx.movement_op(MovementOps.RESHAPE, grad_weight, (C.bs, C.groups, C.cin, C.rcout, Cdw.oy, Cdw.ox))
# sum across the batch dimension
grad_weight = ctx.reduce_op(ReduceOps.SUM, grad_weight, (1, *grad_weight.shape[1:]))
# flip channels out and in
grad_weight = ctx.movement_op(MovementOps.PERMUTE, grad_weight, (0,1,3,2,4,5))
grad_weight = ctx.movement_op(MovementOps.RESHAPE, grad_weight, (C.groups*C.rcout, C.cin, Cdw.oy, Cdw.ox))
dw = ctx.movement_op(MovementOps.SLICE, grad_weight, ((0, grad_weight.shape[0]), (0, grad_weight.shape[1]), (0, w.shape[2]), (0, w.shape[3])))
return dx, dw

View File

@ -4,7 +4,7 @@ UnaryOps = Enum("UnaryOps", ["RELU", "EXP", "LOG", "NEG", "SIGN"])
BinaryOps = Enum("BinaryOps", ["ADD", "SUB", "MUL", "DIV", "POW", "CMPEQ"])
ReduceOps = Enum("ReduceOps", ["SUM", "MAX"])
MovementOps = Enum("MovementOps", ["RESHAPE", "PERMUTE", "SLICE", "EXPAND", "FLIP"])
ProcessingOps = Enum("ProcessingOps", ["CONV", "CONVT", "CONVDW"])
ProcessingOps = Enum("ProcessingOps", ["CONV", "CONVT"])
from tinygrad.shapetracker import ShapeTracker