mirror of https://github.com/commaai/tinygrad.git

commit fc7eabb86f (parent 72186ebd5a)

    processing op
@@ -138,7 +138,7 @@ unary_op (RELU, EXP, LOG, NEG, SIGN) # A -> A
 reduce_op (SUM, MAX) # A -> B (smaller size, B has 1 in shape)
 binary_op (ADD, SUB, MUL, DIV, POW, A, CMPEQ) # A + B -> C (broadcasting supported)
 movement_op (RESHAPE, PERMUTE, SLICE) # A -> B (different size)
-conv, convdw, convdx # A + B -> C
+processing_op (CONV, CONVT, CONVDW) # A + B -> C
 ```
 
 When tinygrad moves to lazy evaluation, optimizations will happen here.
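The documentation change above reflects the new uniform entry point. As a rough sketch of the shared signature that each backend now implements (only the signature is taken from this diff; the stub body below is illustrative):

```python
from tinygrad.helpers import ProcessingOps

def processing_op(op, a, b, ret, stride, groups):
  # Shared llop entry point: `op` selects CONV/CONVT/CONVDW, `a` and `b` are
  # the two input buffers, and `ret` is a preallocated output buffer that the
  # backend fills in place and returns.
  ...
```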
@@ -41,3 +41,4 @@ UnaryOps = Enum("UnaryOps", ["RELU", "EXP", "LOG", "NEG", "SIGN"])
 BinaryOps = Enum("BinaryOps", ["ADD", "SUB", "MUL", "DIV", "POW", "A", "CMPEQ"])
 ReduceOps = Enum("ReduceOps", ["SUM", "MAX"])
 MovementOps = Enum("MovementOps", ["RESHAPE", "PERMUTE", "SLICE"])
+ProcessingOps = Enum("ProcessingOps", ["CONV", "CONVT", "CONVDW"])
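For reference, the three enum values map onto the conv forward pass and its two gradients; the mapping below is read off the Conv2D.forward/backward hunk at the end of this diff, and the comments are editorial:

```python
from enum import Enum

# CONV   -> forward convolution,        processing_op(CONV,   x,           w,           out, ...)
# CONVT  -> gradient w.r.t. the input,  processing_op(CONVT,  grad_output, w,           dx,  ...)
# CONVDW -> gradient w.r.t. the weight, processing_op(CONVDW, x,           grad_output, dw,  ...)
ProcessingOps = Enum("ProcessingOps", ["CONV", "CONVT", "CONVDW"])
```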
@@ -1,5 +1,5 @@
 import numpy as np
-from tinygrad.helpers import get_conv_args, UnaryOps, BinaryOps, ReduceOps, MovementOps
+from tinygrad.helpers import get_conv_args, UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps
 
 class CPUBuffer(np.ndarray):
   def relu(x): return np.maximum(x, 0)
@@ -89,7 +89,7 @@ def convdw(x,grad_output,dw,stride,groups):
     gdw[g] += np.tensordot(ggg[:,g], tx[:,g], ((0,2,3),(0,2,3)))
   return dw
 
-def convdx(w,grad_output,dx,stride,groups):
+def convdx(grad_output,w,dx,stride,groups):
   C = get_conv_args(dx.shape, w.shape, stride, groups)
   ggg = grad_output.reshape(C.bs, C.groups, C.rcout, C.oy, C.ox)
   tw = w.reshape(C.groups, C.rcout, C.cin, C.H, C.W)
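The tensordot in the context line above contracts the batch and output-spatial axes to accumulate one group's weight gradient. An equivalent einsum spelling, checked on placeholder shapes (the shapes and names here are illustrative, assuming ggg[:,g] is (bs, rcout, oy, ox) and tx[:,g] is (bs, cin, oy, ox, H, W)):

```python
import numpy as np

# placeholder shapes: bs=2, rcout=3, cin=4, oy=ox=5, H=W=3
ggg_g = np.random.randn(2, 3, 5, 5)
tx_g  = np.random.randn(2, 4, 5, 5, 3, 3)

a = np.tensordot(ggg_g, tx_g, ((0, 2, 3), (0, 2, 3)))
b = np.einsum('irYX,icYXyx->rcyx', ggg_g, tx_g)
assert np.allclose(a, b)  # both give a (rcout, cin, H, W) weight-gradient block
```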
@@ -103,3 +103,9 @@ def convdx(w,grad_output,dx,stride,groups):
         tg = np.dot(ggg[:,g,:,Y,X].reshape(C.bs, -1), tw[g].reshape(C.rcout, -1))
         gdx[:, g, :, iY:iY+C.H, iX:iX+C.W] += tg.reshape((C.bs, C.cin, C.H, C.W))
   return dx
+
+def processing_op(op,a,b,ret,stride,groups):
+  if op == ProcessingOps.CONV: conv(a,b,ret,stride,groups)
+  elif op == ProcessingOps.CONVT: convdx(a,b,ret,stride,groups)
+  elif op == ProcessingOps.CONVDW: convdw(a,b,ret,stride,groups)
+  return ret
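Note that convdx now takes (grad_output, w, ...) so all three kernels share processing_op's (a, b) argument order. A minimal usage sketch for the CPU backend follows; the output-shape arithmetic assumes the unpadded convolution used by get_conv_args (oy = (iy - H)//stride + 1), and the shapes are illustrative:

```python
import numpy as np
from tinygrad.helpers import ProcessingOps
from tinygrad.llops.ops_cpu import CPUBuffer, processing_op

bs, cin, cout, iy, ix, H, W, stride = 1, 3, 8, 10, 10, 3, 3, 2
x = np.random.randn(bs, cin, iy, ix).astype(np.float32).view(CPUBuffer)
w = np.random.randn(cout, cin, H, W).astype(np.float32).view(CPUBuffer)
oy, ox = (iy - H)//stride + 1, (ix - W)//stride + 1   # assumed: no padding
ret = np.zeros((bs, cout, oy, ox), dtype=np.float32).view(CPUBuffer)

out = processing_op(ProcessingOps.CONV, x, w, ret, (stride, stride), groups=1)
print(out.shape)  # (1, 8, 4, 4)
```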
@@ -1,7 +1,7 @@
 import functools
 import numpy as np
 import pyopencl as cl
-from tinygrad.helpers import binary_broadcast, get_conv_args, UnaryOps, BinaryOps, ReduceOps, MovementOps
+from tinygrad.helpers import binary_broadcast, get_conv_args, UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps
 
 cl_ctx, cl_queue = None, None
 def require_init_gpu():
@@ -256,7 +256,7 @@ def convdw(x,grad_output,dw,stride,groups):
   convdw_prg([C.groups*C.rcout*C.cin, C.H, C.W], None, x.cl, grad_output.cl, dw.cl, *[i32(x) for x in C])
   return dw
 
-def convdx(w,grad_output,dx,stride,groups):
+def convdx(grad_output,w,dx,stride,groups):
   C = get_conv_args(dx.shape, w.shape, stride, groups)
   convdx_prg = clbuild("convdx", """
   __kernel void convdx(__global const float *tensw, __global const float *ggg, __global float *dx,
@@ -284,3 +284,9 @@ def convdx(w,grad_output,dx,stride,groups):
   """)
   convdx_prg([C.bs, C.groups, C.cin], None, w.cl, grad_output.cl, dx.cl, *[i32(x) for x in C])
   return dx
+
+def processing_op(op,a,b,ret,stride,groups):
+  if op == ProcessingOps.CONV: conv(a,b,ret,stride,groups)
+  elif op == ProcessingOps.CONVT: convdx(a,b,ret,stride,groups)
+  elif op == ProcessingOps.CONVDW: convdw(a,b,ret,stride,groups)
+  return ret
@@ -1,6 +1,5 @@
 import torch
 import numpy as np
-from tinygrad.helpers import get_conv_args
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 class TorchBuffer(torch.Tensor):
@@ -24,18 +23,21 @@ from tinygrad.llops.ops_cpu import unary_op, binary_op, reduce_op, movement_op
 
 # ************* processing ops *************
 
+from tinygrad.helpers import get_conv_args, ProcessingOps
+
 def conv(x,w,ret,stride,groups):
-  ret[:] = torch.nn.functional.conv2d(x, w, stride=stride, groups=groups)
+  ret[:] = torch.conv2d(x, w, stride=stride, groups=groups)
   return ret
 
-def convdw(input,grad_output,dw,stride,groups):
+def convdw(x,grad_output,dw,stride,groups):
   # NOTE: torch.nn.grad.conv2d_weight is wrong for groups in pytorch, wonder who it affects
   # https://github.com/pytorch/pytorch/issues/51430
-  C = get_conv_args(input.shape, dw.shape, stride, groups)
+  C = get_conv_args(x.shape, dw.shape, stride, groups)
   grad_output = grad_output.reshape(C.bs, C.groups, C.rcout, C.oy, C.ox).repeat(1, 1, C.cin, 1, 1)
   grad_output = grad_output.reshape(C.bs * C.groups * C.rcout * C.cin, 1, C.oy, C.ox)
-  input = input.reshape(1, C.bs * C.groups * C.cin, C.iy, C.ix)
-  grad_weight = torch.nn.functional.conv2d(input, grad_output, dilation=stride, groups=C.bs*C.groups*C.cin)
+  x = x.reshape(1, C.bs * C.groups * C.cin, C.iy, C.ix)
+  #print(input.shape, grad_output.shape)
+  grad_weight = torch.conv2d(x, grad_output, dilation=stride, groups=C.bs*C.groups*C.cin)
   grad_weight = grad_weight.reshape(C.bs, grad_weight.shape[1] // C.bs, *grad_weight.shape[2:]).sum(dim=0)
   grad_weight = grad_weight.view(C.groups, C.cin, C.rcout, *grad_weight.shape[1:]).transpose(2, 1)
   # narrow removes excess for strided
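Given the NOTE above about torch.nn.grad.conv2d_weight being unreliable with groups, one cheap way to sanity-check convdw is an autograd oracle: a convolution is linear in its weights, so the weight gradient does not depend on the weight values, and backpropagating through a zero weight of the right shape yields a reference dw. This is a hedged test sketch, not part of the commit:

```python
import torch

def conv2d_weight_reference(x, w_shape, grad_output, stride=1, groups=1):
  # d(conv2d)/dw is independent of w, so a zero weight works as a probe.
  w = torch.zeros(w_shape, requires_grad=True)
  out = torch.conv2d(x.detach(), w, stride=stride, groups=groups)
  out.backward(grad_output)
  return w.grad

# e.g. torch.allclose(convdw(x, grad_output, dw, stride, groups),
#                     conv2d_weight_reference(x, dw.shape, grad_output, stride, groups))
```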
@@ -43,6 +45,17 @@ def convdw(input,grad_output,dw,stride,groups):
     2, 0, dw.shape[2]).narrow(3, 0, dw.shape[3])
   return dw
 
-def convdx(w,grad_output,dx,stride,groups):
+def convdx(grad_output,w,dx,stride,groups):
   dx[:] = torch.nn.grad.conv2d_input(dx.shape, w, grad_output, stride=stride, groups=groups)
+  # correct for non strided
+  # strided needs weird padding: https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
+  #C = get_conv_args(dx.shape, w.shape, stride, groups)
+  #w = w.reshape(C.groups, C.rcout, C.cin, C.H, C.W).flip(3, 4).transpose(2, 1).reshape(C.groups*C.cin, C.rcout, C.H, C.W)
+  #ret = torch.conv2d(grad_output, w, padding=(C.H-1,C.W-1), groups=groups)
   return dx
+
+def processing_op(op,a,b,ret,stride,groups):
+  if op == ProcessingOps.CONV: conv(a,b,ret,stride,groups)
+  elif op == ProcessingOps.CONVT: convdx(a,b,ret,stride,groups)
+  elif op == ProcessingOps.CONVDW: convdw(a,b,ret,stride,groups)
+  return ret
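The commented-out lines in convdx sketch the classic alternative: for stride=1, the input gradient is a full-padding convolution of grad_output with the flipped, channel-transposed weights; with stride > 1, grad_output would first need zeros inserted between its elements (the "weird padding" the comment links to), which is why the code keeps torch.nn.grad.conv2d_input. A hedged groups=1, stride=1 check of that equivalence, with illustrative shapes:

```python
import torch

x  = torch.randn(2, 3, 8, 8)
w  = torch.randn(5, 3, 3, 3)
go = torch.randn(2, 5, 6, 6)                      # conv2d output shape for stride=1

dx_ref = torch.nn.grad.conv2d_input(x.shape, w, go, stride=1)
wf     = w.flip(2, 3).transpose(0, 1)             # (cin, cout, H, W), kernels flipped
dx_alt = torch.conv2d(go, wf, padding=(2, 2))     # "full" padding of H-1, W-1
print(torch.allclose(dx_ref, dx_alt, atol=1e-5))  # True
```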
@@ -1,5 +1,5 @@
 import numpy as np # TODO: remove this, it's used for np.prod and np.argsort
-from tinygrad.helpers import binary_broadcast, get_conv_args, UnaryOps, BinaryOps, ReduceOps, MovementOps
+from tinygrad.helpers import binary_broadcast, get_conv_args, UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps
 from tinygrad.tensor import Function
 
 # ************* unary ops *************
@@ -162,10 +162,10 @@ class Conv2D(Function):
   def forward(ctx, x, w, stride=1, groups=1):
     C = get_conv_args(x.shape, w.shape, stride, groups)
     ctx.save_for_backward(x,w,(C.ys,C.xs), C.groups)
-    return ctx.op.conv(x, w, ctx.buffer((C.bs, C.groups*C.rcout, C.oy, C.ox)), (C.ys,C.xs), C.groups)
+    return ctx.op.processing_op(ProcessingOps.CONV, x, w, ctx.buffer((C.bs, C.groups*C.rcout, C.oy, C.ox)), (C.ys,C.xs), C.groups)
 
   def backward(ctx, grad_output):
     x, w, stride, groups = ctx.saved_tensors
-    dx = ctx.op.convdx(w, grad_output, ctx.buffer(x.shape), stride, groups) if ctx.needs_input_grad[0] else None
-    dw = ctx.op.convdw(x, grad_output, ctx.buffer(w.shape), stride, groups) if ctx.needs_input_grad[1] else None
+    dx = ctx.op.processing_op(ProcessingOps.CONVT, grad_output, w, ctx.buffer(x.shape), stride, groups) if ctx.needs_input_grad[0] else None
+    dw = ctx.op.processing_op(ProcessingOps.CONVDW, x, grad_output, ctx.buffer(w.shape), stride, groups) if ctx.needs_input_grad[1] else None
     return dx, dw
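End to end, a single conv2d forward+backward now routes through all three ProcessingOps: CONV in Conv2D.forward, then CONVT (for dx) and CONVDW (for dw) in Conv2D.backward. A rough sketch of exercising that path from the Tensor API (the exact Tensor constructor and requires_grad conventions at this commit are assumed, not verified):

```python
import numpy as np
from tinygrad.tensor import Tensor

x = Tensor(np.random.randn(1, 4, 9, 9).astype(np.float32))
w = Tensor(np.random.randn(8, 4, 3, 3).astype(np.float32))

out = x.conv2d(w, stride=2)   # ProcessingOps.CONV
out.mean().backward()         # ProcessingOps.CONVT and ProcessingOps.CONVDW
print(out.shape, x.grad.shape, w.grad.shape)
```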