move CONVDW out of llops

George Hotz 2022-06-15 12:05:11 -07:00
parent fef6c82491
commit 6d98366214
6 changed files with 24 additions and 54 deletions


@@ -133,12 +133,12 @@ You no longer need to write mlops for a new accelerator
 The autodiff stuff is all in mlops now so you can focus on the raw operations
 ```
 Buffer # class of memory on this device
 unary_op (RELU, EXP, LOG, NEG, SIGN) # A -> A
 reduce_op (SUM, MAX) # A -> B (smaller size, B has 1 in shape)
 binary_op (ADD, SUB, MUL, DIV, POW, CMPEQ) # A + B -> C (all the same size)
-movement_op (RESHAPE, PERMUTE, SLICE, EXPAND) # A -> B (different size)
-processing_op (CONV, CONVT, CONVDW) # A + B -> C
+movement_op (RESHAPE, PERMUTE, SLICE, EXPAND, FLIP) # A -> B (different size)
+processing_op (CONV, CONVT) # A + B -> C
 ```
 When tinygrad moves to lazy evaluation, optimizations will happen here.
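
To make the op buckets above concrete, here is a rough numpy picture of the two less obvious ones: reduce_op keeps a 1 in the reduced dimension, and movement_op EXPAND grows a size-1 axis back out, which is exactly the pattern the new Conv2D backward below leans on. This is illustrative only; the real llops operate on device Buffers and have their own signatures.

```python
import numpy as np

a = np.arange(12, dtype=np.float32).reshape(3, 4)

# reduce_op SUM: the reduced dimension stays in the shape as a 1 ("B has 1 in shape")
s = a.sum(axis=1, keepdims=True)          # shape (3, 1)

# movement_op PERMUTE / EXPAND: change shape or strides, not values
p = a.T                                   # PERMUTE -> (4, 3)
e = np.broadcast_to(s, (3, 4))            # EXPAND the size-1 axis back to 4

print(s.shape, p.shape, e.shape)          # (3, 1) (4, 3) (3, 4)
```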


@@ -77,15 +77,6 @@ def conv(x,w,ret,C):
     tmp[:,g] += np.tensordot(tx[:,g], tw[g], ((1,4,5),(1,2,3)))
   ret[:] = np.moveaxis(tmp,4,2).reshape(C.bs, C.groups*C.rcout, C.oy, C.ox)
-def convdw(x,grad_output,dw,C):
-  tx = get_tx(x, C)
-  ggg = grad_output.reshape(C.bs, C.groups, C.rcout, C.oy, C.ox)
-  gdw = dw.reshape((C.groups, C.rcout, C.cin, C.H, C.W))
-  gdw[:] = 0
-  for g in range(C.groups):
-    #'ikYX,ijYXyx -> kjyx'
-    gdw[g] += np.tensordot(ggg[:,g], tx[:,g], ((0,2,3),(0,2,3)))
 def convdx(grad_output,w,dx,C):
   ggg = grad_output.reshape(C.bs, C.groups, C.rcout, C.oy, C.ox)
   tw = w.reshape(C.groups, C.rcout, C.cin, C.H, C.W)
@@ -102,4 +93,3 @@ def convdx(grad_output,w,dx,C):
 def processing_op(op,a,b,ret,C):
   if op == ProcessingOps.CONV: conv(a,b,ret,C)
   elif op == ProcessingOps.CONVT: convdx(a,b,ret,C)
-  elif op == ProcessingOps.CONVDW: convdw(a,b,ret,C)
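
For reference, the deleted CPU convdw boils down to one tensordot per group. A standalone numpy sketch of that computation for the stride-1, no-dilation case (shapes are hypothetical, and sliding_window_view stands in for the strided view get_tx builds):

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

# dw[g,k,c,y,x] = sum over B,Y,X of grad_output[B,g,k,Y,X] * x[B,g,c,Y+y,X+x]
bs, groups, cin, rcout, iy, ix, H, W = 2, 3, 4, 5, 9, 9, 3, 3
oy, ox = iy - H + 1, ix - W + 1
x = np.random.randn(bs, groups*cin, iy, ix).astype(np.float32)
grad_output = np.random.randn(bs, groups*rcout, oy, ox).astype(np.float32)

# tx[B,g,c,Y,X,y,x] = x[B,g,c,Y+y,X+x], a zero-copy strided view
tx = sliding_window_view(x.reshape(bs, groups, cin, iy, ix), (H, W), axis=(3, 4))
ggg = grad_output.reshape(bs, groups, rcout, oy, ox)
dw = np.einsum("BgkYX,BgcYXyx->gkcyx", ggg, tx).reshape(groups*rcout, cin, H, W)
print(dw.shape)  # (15, 4, 3, 3), i.e. (groups*rcout, cin, H, W)
```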


@@ -159,28 +159,6 @@ def conv(x,w,ret,C):
 # tensw = (groups*rcout, cin, H, W)
 # ggg = (bs, groups*rout, oy, ox)
-def convdw(x,grad_output,dw,C):
-  convdw_prg = clbuild("convdw", """
-  __kernel void convdw(__global const float *tensx, __global const float *ggg, __global float *dw,
-    int H, int W, int groups, int rcout, int cin, int oy, int ox, int iy, int ix, int ys, int xs, int bs) {
-    int g = get_global_id(0)/(rcout*cin) ; // range 0-groups
-    int c = (get_global_id(0)/(cin)) %rcout; // range 0-rcout
-    int ci = get_global_id(0) % cin; // range 0-cin
-    int y = get_global_id(1); // range 0-H
-    int x = get_global_id(2); // range 0-W
-    float acc = 0.0;
-    for (int Y = 0; Y < oy; Y++) { for (int X = 0; X < ox; X++) {
-      for (int B = 0; B < bs; B++) {
-        acc += ggg[B*groups*rcout*oy*ox + +g*rcout*oy*ox + c*oy*ox + Y*ox + X] * \
-          tensx[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + (Y*ys+y)*ix + X*xs+x];
-      }
-    } }
-    dw[get_global_id(0)*H*W + y*W + x] = acc;
-  }""")
-  convdw_prg([C.groups*C.rcout*C.cin, C.H, C.W], None, x.cl, grad_output.cl, dw.cl, *[i32(x) for x in C[0:12]])
 def convdx(grad_output,w,dx,C):
   convdx_prg = clbuild("convdx", """
   __kernel void convdx(__global const float *tensw, __global const float *ggg, __global float *dx,
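
The removed kernel packs (group, output channel, input channel) into a single global id and keeps (y, x) in the other two. A tiny standalone round-trip check of that index math, with hypothetical sizes:

```python
# gid = (g*rcout + c)*cin + ci is the flat index the kernel writes dw at (times H*W);
# the three divisions at the top of the kernel decode it back
def decode_gid0(gid, rcout, cin):
  return gid // (rcout * cin), (gid // cin) % rcout, gid % cin

groups, rcout, cin = 2, 4, 3
for g in range(groups):
  for c in range(rcout):
    for ci in range(cin):
      assert decode_gid0((g * rcout + c) * cin + ci, rcout, cin) == (g, c, ci)
print("index math round-trips")
```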


@@ -25,17 +25,6 @@ from tinygrad.llops.ops_cpu import unary_op, binary_op, reduce_op, movement_op
 from tinygrad.ops import ProcessingOps
-def convdw(x,grad_output,dw,C):
-  # NOTE: torch.nn.grad.conv2d_weight is wrong for groups in pytorch, wonder who it affects
-  # https://github.com/pytorch/pytorch/issues/51430
-  grad_output = grad_output.reshape(C.bs * C.groups, C.rcout, C.oy, C.ox).repeat(1, C.cin, 1, 1)
-  grad_output = grad_output.reshape(C.bs * C.groups * C.cin * C.rcout, 1, C.oy, C.ox)
-  x = x.reshape(1, C.bs * C.groups * C.cin, C.iy, C.ix)
-  # NOTE: this conv2d always has batch size 1.
-  grad_weight = torch.conv2d(x, grad_output, stride=(C.dy, C.dx), dilation=(C.ys, C.xs), groups=C.bs*C.groups*C.cin)
-  grad_weight = grad_weight.reshape(C.bs, C.groups, C.cin, C.rcout, *grad_weight.shape[2:]).transpose(3, 2).sum(dim=0)
-  dw[:] = grad_weight.reshape(C.groups*C.rcout, C.cin, *grad_weight.shape[3:])[:, :, :dw.shape[2], :dw.shape[3]]
 def processing_op(op,x,w,ret,C):
   stride, groups, dilation = (C.ys, C.xs), C.groups, (C.dy, C.dx)
   if op == ProcessingOps.CONV:
@@ -48,5 +37,3 @@ def processing_op(op,x,w,ret,C):
   else:
     output_padding = [ret.shape[d+2] - ((x.shape[d+2] - 1) * stride[d] + 1 + dilation[d] * (w.shape[d+2] - 1)) for d in range(2)]
     ret[:] = torch.conv_transpose2d(x, w, stride=stride, groups=groups, output_padding=output_padding, dilation=dilation)
-  elif op == ProcessingOps.CONVDW:
-    convdw(x,w,ret,C)
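
The deleted torch path and the new mlops backward both rest on the same identity: the weight gradient of a convolution is itself a convolution of x with grad_output, with the batch folded into the group dimension. A standalone PyTorch sketch of that identity for the simplest case (groups=1, stride=1, dilation=1, hypothetical shapes), checked against autograd:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
bs, cin, cout, iy, ix, H, W = 2, 3, 4, 8, 8, 3, 3
x = torch.randn(bs, cin, iy, ix, requires_grad=True)
w = torch.randn(cout, cin, H, W, requires_grad=True)
out = F.conv2d(x, w)                        # (bs, cout, 6, 6)
grad_output = torch.randn_like(out)
out.backward(grad_output)                   # reference dw lands in w.grad

# grad_output becomes the "filters", x becomes the "input", groups = bs*cin
go = grad_output.repeat(1, cin, 1, 1).reshape(bs * cin * cout, 1, *out.shape[2:])
xg = x.detach().reshape(1, bs * cin, iy, ix)
gw = F.conv2d(xg, go, groups=bs * cin)      # (1, bs*cin*cout, H, W)
gw = gw.reshape(bs, cin, cout, H, W).transpose(1, 2).sum(dim=0)
print(torch.allclose(w.grad, gw, atol=1e-4))  # True
```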


@@ -170,10 +170,25 @@ class Conv2D(Function):
   def forward(ctx, x, w, stride=1, groups=1, dilation=1):
     C = get_conv_args(x.shape, w.shape, stride, groups, dilation=dilation)
     ctx.save_for_backward(x,w,C)
-    return ctx.processing_op(ProcessingOps.CONV, x, w, (C.bs, C.groups*C.rcout, C.oy, C.ox), C)
+    return ctx.processing_op(ProcessingOps.CONV, x, w, (C.bs, C.cout, C.oy, C.ox), C)
   def backward(ctx, grad_output):
     x, w, C = ctx.saved_tensors
     dx = ctx.processing_op(ProcessingOps.CONVT, grad_output, w, x.shape, C) if ctx.needs_input_grad[0] else None
-    dw = ctx.processing_op(ProcessingOps.CONVDW, x, grad_output, w.shape, C) if ctx.needs_input_grad[1] else None
+    # compute derivative of weights using ProcessingOps.CONV
+    xdw = ctx.movement_op(MovementOps.RESHAPE, x, (1, C.bs * C.groups * C.cin, C.iy, C.ix))
+    grad_output_dw = ctx.movement_op(MovementOps.RESHAPE, grad_output, (C.bs * C.groups, 1, C.rcout, C.oy, C.ox))
+    # this expand is slow
+    grad_output_dw = ctx.movement_op(MovementOps.EXPAND, grad_output_dw, (C.bs * C.groups, C.cin, C.rcout, C.oy, C.ox))
+    grad_output_dw = ctx.movement_op(MovementOps.RESHAPE, grad_output_dw, (C.bs * C.groups * C.cin * C.rcout, 1, C.oy, C.ox))
+    Cdw = get_conv_args(xdw.shape, grad_output_dw.shape, stride=(C.dy, C.dx), dilation=(C.ys, C.xs), groups=C.bs*C.groups*C.cin)
+    grad_weight = ctx.processing_op(ProcessingOps.CONV, xdw, grad_output_dw, (Cdw.bs, Cdw.cout, Cdw.oy, Cdw.ox), Cdw)
+    grad_weight = ctx.movement_op(MovementOps.RESHAPE, grad_weight, (C.bs, C.groups, C.cin, C.rcout, Cdw.oy, Cdw.ox))
+    # sum across the batch dimension
+    grad_weight = ctx.reduce_op(ReduceOps.SUM, grad_weight, (1, *grad_weight.shape[1:]))
+    # flip channels out and in
+    grad_weight = ctx.movement_op(MovementOps.PERMUTE, grad_weight, (0,1,3,2,4,5))
+    grad_weight = ctx.movement_op(MovementOps.RESHAPE, grad_weight, (C.groups*C.rcout, C.cin, Cdw.oy, Cdw.ox))
+    dw = ctx.movement_op(MovementOps.SLICE, grad_weight, ((0, grad_weight.shape[0]), (0, grad_weight.shape[1]), (0, w.shape[2]), (0, w.shape[3])))
     return dx, dw
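
One non-obvious step above is the trailing SLICE: when the forward stride does not divide the input exactly, the grad_weight conv comes out slightly larger than the kernel and has to be cropped back to w.shape. A rough arithmetic check, assuming the standard "valid" convolution size formula for get_conv_args:

```python
iy, H, ys, dy = 10, 3, 2, 1                  # input height, kernel height, stride, dilation
oy = (iy - dy * (H - 1) - 1) // ys + 1       # forward output height -> 4
cdw_oy = (iy - ys * (oy - 1) - 1) // dy + 1  # height of the grad_weight conv -> 4 > H
print(oy, cdw_oy)                            # hence MovementOps.SLICE crops back to (H, W)
```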


@@ -4,7 +4,7 @@ UnaryOps = Enum("UnaryOps", ["RELU", "EXP", "LOG", "NEG", "SIGN"])
 BinaryOps = Enum("BinaryOps", ["ADD", "SUB", "MUL", "DIV", "POW", "CMPEQ"])
 ReduceOps = Enum("ReduceOps", ["SUM", "MAX"])
 MovementOps = Enum("MovementOps", ["RESHAPE", "PERMUTE", "SLICE", "EXPAND", "FLIP"])
-ProcessingOps = Enum("ProcessingOps", ["CONV", "CONVT", "CONVDW"])
+ProcessingOps = Enum("ProcessingOps", ["CONV", "CONVT"])
 from tinygrad.shapetracker import ShapeTracker