processing op

This commit is contained in:
George Hotz 2022-06-11 08:12:02 -07:00
parent 72186ebd5a
commit fc7eabb86f
6 changed files with 42 additions and 16 deletions

View File

@ -138,7 +138,7 @@ unary_op (RELU, EXP, LOG, NEG, SIGN) # A -> A
reduce_op (SUM, MAX) # A -> B (smaller size, B has 1 in shape)
binary_op (ADD, SUB, MUL, DIV, POW, A, CMPEQ) # A + B -> C (broadcasting supported)
movement_op (RESHAPE, PERMUTE, SLICE) # A -> B (different size)
conv, convdw, convdx # A + B -> C
processing_op (CONV, CONVT, CONVDW) # A + B -> C
```
When tinygrad moves to lazy evaluation, optimizations will happen here.

View File

@ -41,3 +41,4 @@ UnaryOps = Enum("UnaryOps", ["RELU", "EXP", "LOG", "NEG", "SIGN"])
BinaryOps = Enum("BinaryOps", ["ADD", "SUB", "MUL", "DIV", "POW", "A", "CMPEQ"])
ReduceOps = Enum("ReduceOps", ["SUM", "MAX"])
MovementOps = Enum("MovementOps", ["RESHAPE", "PERMUTE", "SLICE"])
ProcessingOps = Enum("ProcessingOps", ["CONV", "CONVT", "CONVDW"])

View File

@ -1,5 +1,5 @@
import numpy as np
from tinygrad.helpers import get_conv_args, UnaryOps, BinaryOps, ReduceOps, MovementOps
from tinygrad.helpers import get_conv_args, UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps
class CPUBuffer(np.ndarray):
def relu(x): return np.maximum(x, 0)
@ -89,7 +89,7 @@ def convdw(x,grad_output,dw,stride,groups):
gdw[g] += np.tensordot(ggg[:,g], tx[:,g], ((0,2,3),(0,2,3)))
return dw
def convdx(w,grad_output,dx,stride,groups):
def convdx(grad_output,w,dx,stride,groups):
C = get_conv_args(dx.shape, w.shape, stride, groups)
ggg = grad_output.reshape(C.bs, C.groups, C.rcout, C.oy, C.ox)
tw = w.reshape(C.groups, C.rcout, C.cin, C.H, C.W)
@ -103,3 +103,9 @@ def convdx(w,grad_output,dx,stride,groups):
tg = np.dot(ggg[:,g,:,Y,X].reshape(C.bs, -1), tw[g].reshape(C.rcout, -1))
gdx[:, g, :, iY:iY+C.H, iX:iX+C.W] += tg.reshape((C.bs, C.cin, C.H, C.W))
return dx
def processing_op(op,a,b,ret,stride,groups):
if op == ProcessingOps.CONV: conv(a,b,ret,stride,groups)
elif op == ProcessingOps.CONVT: convdx(a,b,ret,stride,groups)
elif op == ProcessingOps.CONVDW: convdw(a,b,ret,stride,groups)
return ret

View File

@ -1,7 +1,7 @@
import functools
import numpy as np
import pyopencl as cl
from tinygrad.helpers import binary_broadcast, get_conv_args, UnaryOps, BinaryOps, ReduceOps, MovementOps
from tinygrad.helpers import binary_broadcast, get_conv_args, UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps
cl_ctx, cl_queue = None, None
def require_init_gpu():
@ -256,7 +256,7 @@ def convdw(x,grad_output,dw,stride,groups):
convdw_prg([C.groups*C.rcout*C.cin, C.H, C.W], None, x.cl, grad_output.cl, dw.cl, *[i32(x) for x in C])
return dw
def convdx(w,grad_output,dx,stride,groups):
def convdx(grad_output,w,dx,stride,groups):
C = get_conv_args(dx.shape, w.shape, stride, groups)
convdx_prg = clbuild("convdx", """
__kernel void convdx(__global const float *tensw, __global const float *ggg, __global float *dx,
@ -284,3 +284,9 @@ def convdx(w,grad_output,dx,stride,groups):
""")
convdx_prg([C.bs, C.groups, C.cin], None, w.cl, grad_output.cl, dx.cl, *[i32(x) for x in C])
return dx
def processing_op(op,a,b,ret,stride,groups):
if op == ProcessingOps.CONV: conv(a,b,ret,stride,groups)
elif op == ProcessingOps.CONVT: convdx(a,b,ret,stride,groups)
elif op == ProcessingOps.CONVDW: convdw(a,b,ret,stride,groups)
return ret

View File

@ -1,6 +1,5 @@
import torch
import numpy as np
from tinygrad.helpers import get_conv_args
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
class TorchBuffer(torch.Tensor):
@ -24,18 +23,21 @@ from tinygrad.llops.ops_cpu import unary_op, binary_op, reduce_op, movement_op
# ************* processing ops *************
from tinygrad.helpers import get_conv_args, ProcessingOps
def conv(x,w,ret,stride,groups):
ret[:] = torch.nn.functional.conv2d(x, w, stride=stride, groups=groups)
ret[:] = torch.conv2d(x, w, stride=stride, groups=groups)
return ret
def convdw(input,grad_output,dw,stride,groups):
def convdw(x,grad_output,dw,stride,groups):
# NOTE: torch.nn.grad.conv2d_weight is wrong for groups in pytorch, wonder who it affects
# https://github.com/pytorch/pytorch/issues/51430
C = get_conv_args(input.shape, dw.shape, stride, groups)
C = get_conv_args(x.shape, dw.shape, stride, groups)
grad_output = grad_output.reshape(C.bs, C.groups, C.rcout, C.oy, C.ox).repeat(1, 1, C.cin, 1, 1)
grad_output = grad_output.reshape(C.bs * C.groups * C.rcout * C.cin, 1, C.oy, C.ox)
input = input.reshape(1, C.bs * C.groups * C.cin, C.iy, C.ix)
grad_weight = torch.nn.functional.conv2d(input, grad_output, dilation=stride, groups=C.bs*C.groups*C.cin)
x = x.reshape(1, C.bs * C.groups * C.cin, C.iy, C.ix)
#print(input.shape, grad_output.shape)
grad_weight = torch.conv2d(x, grad_output, dilation=stride, groups=C.bs*C.groups*C.cin)
grad_weight = grad_weight.reshape(C.bs, grad_weight.shape[1] // C.bs, *grad_weight.shape[2:]).sum(dim=0)
grad_weight = grad_weight.view(C.groups, C.cin, C.rcout, *grad_weight.shape[1:]).transpose(2, 1)
# narrow removes excess for strided
@ -43,6 +45,17 @@ def convdw(input,grad_output,dw,stride,groups):
2, 0, dw.shape[2]).narrow(3, 0, dw.shape[3])
return dw
def convdx(w,grad_output,dx,stride,groups):
def convdx(grad_output,w,dx,stride,groups):
dx[:] = torch.nn.grad.conv2d_input(dx.shape, w, grad_output, stride=stride, groups=groups)
# correct for non strided
# strided needs weird padding: https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
#C = get_conv_args(dx.shape, w.shape, stride, groups)
#w = w.reshape(C.groups, C.rcout, C.cin, C.H, C.W).flip(3, 4).transpose(2, 1).reshape(C.groups*C.cin, C.rcout, C.H, C.W)
#ret = torch.conv2d(grad_output, w, padding=(C.H-1,C.W-1), groups=groups)
return dx
def processing_op(op,a,b,ret,stride,groups):
if op == ProcessingOps.CONV: conv(a,b,ret,stride,groups)
elif op == ProcessingOps.CONVT: convdx(a,b,ret,stride,groups)
elif op == ProcessingOps.CONVDW: convdw(a,b,ret,stride,groups)
return ret

View File

@ -1,5 +1,5 @@
import numpy as np # TODO: remove this, it's used for np.prod and np.argsort
from tinygrad.helpers import binary_broadcast, get_conv_args, UnaryOps, BinaryOps, ReduceOps, MovementOps
from tinygrad.helpers import binary_broadcast, get_conv_args, UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps
from tinygrad.tensor import Function
# ************* unary ops *************
@ -162,10 +162,10 @@ class Conv2D(Function):
def forward(ctx, x, w, stride=1, groups=1):
C = get_conv_args(x.shape, w.shape, stride, groups)
ctx.save_for_backward(x,w,(C.ys,C.xs), C.groups)
return ctx.op.conv(x, w, ctx.buffer((C.bs, C.groups*C.rcout, C.oy, C.ox)), (C.ys,C.xs), C.groups)
return ctx.op.processing_op(ProcessingOps.CONV, x, w, ctx.buffer((C.bs, C.groups*C.rcout, C.oy, C.ox)), (C.ys,C.xs), C.groups)
def backward(ctx, grad_output):
x, w, stride, groups = ctx.saved_tensors
dx = ctx.op.convdx(w, grad_output, ctx.buffer(x.shape), stride, groups) if ctx.needs_input_grad[0] else None
dw = ctx.op.convdw(x, grad_output, ctx.buffer(w.shape), stride, groups) if ctx.needs_input_grad[1] else None
dx = ctx.op.processing_op(ProcessingOps.CONVT, grad_output, w, ctx.buffer(x.shape), stride, groups) if ctx.needs_input_grad[0] else None
dw = ctx.op.processing_op(ProcessingOps.CONVDW, x, grad_output, ctx.buffer(w.shape), stride, groups) if ctx.needs_input_grad[1] else None
return dx, dw