From fc7eabb86fd055e91c20d85fc6339e3894bf0aa2 Mon Sep 17 00:00:00 2001 From: George Hotz Date: Sat, 11 Jun 2022 08:12:02 -0700 Subject: [PATCH] processing op --- README.md | 2 +- tinygrad/helpers.py | 1 + tinygrad/llops/ops_cpu.py | 10 ++++++++-- tinygrad/llops/ops_gpu.py | 10 ++++++++-- tinygrad/llops/ops_torch.py | 27 ++++++++++++++++++++------- tinygrad/mlops.py | 8 ++++---- 6 files changed, 42 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index 6526d9b9..2b329405 100644 --- a/README.md +++ b/README.md @@ -138,7 +138,7 @@ unary_op (RELU, EXP, LOG, NEG, SIGN) # A -> A reduce_op (SUM, MAX) # A -> B (smaller size, B has 1 in shape) binary_op (ADD, SUB, MUL, DIV, POW, A, CMPEQ) # A + B -> C (broadcasting supported) movement_op (RESHAPE, PERMUTE, SLICE) # A -> B (different size) -conv, convdw, convdx # A + B -> C +processing_op (CONV, CONVT, CONVDW) # A + B -> C ``` When tinygrad moves to lazy evaluation, optimizations will happen here. diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index 122e0ec6..b1651109 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -41,3 +41,4 @@ UnaryOps = Enum("UnaryOps", ["RELU", "EXP", "LOG", "NEG", "SIGN"]) BinaryOps = Enum("BinaryOps", ["ADD", "SUB", "MUL", "DIV", "POW", "A", "CMPEQ"]) ReduceOps = Enum("ReduceOps", ["SUM", "MAX"]) MovementOps = Enum("MovementOps", ["RESHAPE", "PERMUTE", "SLICE"]) +ProcessingOps = Enum("ProcessingOps", ["CONV", "CONVT", "CONVDW"]) diff --git a/tinygrad/llops/ops_cpu.py b/tinygrad/llops/ops_cpu.py index a190b3f1..fce1d548 100644 --- a/tinygrad/llops/ops_cpu.py +++ b/tinygrad/llops/ops_cpu.py @@ -1,5 +1,5 @@ import numpy as np -from tinygrad.helpers import get_conv_args, UnaryOps, BinaryOps, ReduceOps, MovementOps +from tinygrad.helpers import get_conv_args, UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps class CPUBuffer(np.ndarray): def relu(x): return np.maximum(x, 0) @@ -89,7 +89,7 @@ def convdw(x,grad_output,dw,stride,groups): gdw[g] += np.tensordot(ggg[:,g], tx[:,g], ((0,2,3),(0,2,3))) return dw -def convdx(w,grad_output,dx,stride,groups): +def convdx(grad_output,w,dx,stride,groups): C = get_conv_args(dx.shape, w.shape, stride, groups) ggg = grad_output.reshape(C.bs, C.groups, C.rcout, C.oy, C.ox) tw = w.reshape(C.groups, C.rcout, C.cin, C.H, C.W) @@ -103,3 +103,9 @@ def convdx(w,grad_output,dx,stride,groups): tg = np.dot(ggg[:,g,:,Y,X].reshape(C.bs, -1), tw[g].reshape(C.rcout, -1)) gdx[:, g, :, iY:iY+C.H, iX:iX+C.W] += tg.reshape((C.bs, C.cin, C.H, C.W)) return dx + +def processing_op(op,a,b,ret,stride,groups): + if op == ProcessingOps.CONV: conv(a,b,ret,stride,groups) + elif op == ProcessingOps.CONVT: convdx(a,b,ret,stride,groups) + elif op == ProcessingOps.CONVDW: convdw(a,b,ret,stride,groups) + return ret diff --git a/tinygrad/llops/ops_gpu.py b/tinygrad/llops/ops_gpu.py index 019ad7b6..c7e93fcf 100644 --- a/tinygrad/llops/ops_gpu.py +++ b/tinygrad/llops/ops_gpu.py @@ -1,7 +1,7 @@ import functools import numpy as np import pyopencl as cl -from tinygrad.helpers import binary_broadcast, get_conv_args, UnaryOps, BinaryOps, ReduceOps, MovementOps +from tinygrad.helpers import binary_broadcast, get_conv_args, UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps cl_ctx, cl_queue = None, None def require_init_gpu(): @@ -256,7 +256,7 @@ def convdw(x,grad_output,dw,stride,groups): convdw_prg([C.groups*C.rcout*C.cin, C.H, C.W], None, x.cl, grad_output.cl, dw.cl, *[i32(x) for x in C]) return dw -def convdx(w,grad_output,dx,stride,groups): +def convdx(grad_output,w,dx,stride,groups): C = get_conv_args(dx.shape, w.shape, stride, groups) convdx_prg = clbuild("convdx", """ __kernel void convdx(__global const float *tensw, __global const float *ggg, __global float *dx, @@ -284,3 +284,9 @@ def convdx(w,grad_output,dx,stride,groups): """) convdx_prg([C.bs, C.groups, C.cin], None, w.cl, grad_output.cl, dx.cl, *[i32(x) for x in C]) return dx + +def processing_op(op,a,b,ret,stride,groups): + if op == ProcessingOps.CONV: conv(a,b,ret,stride,groups) + elif op == ProcessingOps.CONVT: convdx(a,b,ret,stride,groups) + elif op == ProcessingOps.CONVDW: convdw(a,b,ret,stride,groups) + return ret diff --git a/tinygrad/llops/ops_torch.py b/tinygrad/llops/ops_torch.py index aa0c4db3..eb0fd188 100644 --- a/tinygrad/llops/ops_torch.py +++ b/tinygrad/llops/ops_torch.py @@ -1,6 +1,5 @@ import torch import numpy as np -from tinygrad.helpers import get_conv_args device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") class TorchBuffer(torch.Tensor): @@ -24,18 +23,21 @@ from tinygrad.llops.ops_cpu import unary_op, binary_op, reduce_op, movement_op # ************* processing ops ************* +from tinygrad.helpers import get_conv_args, ProcessingOps + def conv(x,w,ret,stride,groups): - ret[:] = torch.nn.functional.conv2d(x, w, stride=stride, groups=groups) + ret[:] = torch.conv2d(x, w, stride=stride, groups=groups) return ret -def convdw(input,grad_output,dw,stride,groups): +def convdw(x,grad_output,dw,stride,groups): # NOTE: torch.nn.grad.conv2d_weight is wrong for groups in pytorch, wonder who it affects # https://github.com/pytorch/pytorch/issues/51430 - C = get_conv_args(input.shape, dw.shape, stride, groups) + C = get_conv_args(x.shape, dw.shape, stride, groups) grad_output = grad_output.reshape(C.bs, C.groups, C.rcout, C.oy, C.ox).repeat(1, 1, C.cin, 1, 1) grad_output = grad_output.reshape(C.bs * C.groups * C.rcout * C.cin, 1, C.oy, C.ox) - input = input.reshape(1, C.bs * C.groups * C.cin, C.iy, C.ix) - grad_weight = torch.nn.functional.conv2d(input, grad_output, dilation=stride, groups=C.bs*C.groups*C.cin) + x = x.reshape(1, C.bs * C.groups * C.cin, C.iy, C.ix) + #print(input.shape, grad_output.shape) + grad_weight = torch.conv2d(x, grad_output, dilation=stride, groups=C.bs*C.groups*C.cin) grad_weight = grad_weight.reshape(C.bs, grad_weight.shape[1] // C.bs, *grad_weight.shape[2:]).sum(dim=0) grad_weight = grad_weight.view(C.groups, C.cin, C.rcout, *grad_weight.shape[1:]).transpose(2, 1) # narrow removes excess for strided @@ -43,6 +45,17 @@ def convdw(input,grad_output,dw,stride,groups): 2, 0, dw.shape[2]).narrow(3, 0, dw.shape[3]) return dw -def convdx(w,grad_output,dx,stride,groups): +def convdx(grad_output,w,dx,stride,groups): dx[:] = torch.nn.grad.conv2d_input(dx.shape, w, grad_output, stride=stride, groups=groups) + # correct for non strided + # strided needs weird padding: https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md + #C = get_conv_args(dx.shape, w.shape, stride, groups) + #w = w.reshape(C.groups, C.rcout, C.cin, C.H, C.W).flip(3, 4).transpose(2, 1).reshape(C.groups*C.cin, C.rcout, C.H, C.W) + #ret = torch.conv2d(grad_output, w, padding=(C.H-1,C.W-1), groups=groups) return dx + +def processing_op(op,a,b,ret,stride,groups): + if op == ProcessingOps.CONV: conv(a,b,ret,stride,groups) + elif op == ProcessingOps.CONVT: convdx(a,b,ret,stride,groups) + elif op == ProcessingOps.CONVDW: convdw(a,b,ret,stride,groups) + return ret diff --git a/tinygrad/mlops.py b/tinygrad/mlops.py index 3582336c..f904202d 100644 --- a/tinygrad/mlops.py +++ b/tinygrad/mlops.py @@ -1,5 +1,5 @@ import numpy as np # TODO: remove this, it's used for np.prod and np.argsort -from tinygrad.helpers import binary_broadcast, get_conv_args, UnaryOps, BinaryOps, ReduceOps, MovementOps +from tinygrad.helpers import binary_broadcast, get_conv_args, UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps from tinygrad.tensor import Function # ************* unary ops ************* @@ -162,10 +162,10 @@ class Conv2D(Function): def forward(ctx, x, w, stride=1, groups=1): C = get_conv_args(x.shape, w.shape, stride, groups) ctx.save_for_backward(x,w,(C.ys,C.xs), C.groups) - return ctx.op.conv(x, w, ctx.buffer((C.bs, C.groups*C.rcout, C.oy, C.ox)), (C.ys,C.xs), C.groups) + return ctx.op.processing_op(ProcessingOps.CONV, x, w, ctx.buffer((C.bs, C.groups*C.rcout, C.oy, C.ox)), (C.ys,C.xs), C.groups) def backward(ctx, grad_output): x, w, stride, groups = ctx.saved_tensors - dx = ctx.op.convdx(w, grad_output, ctx.buffer(x.shape), stride, groups) if ctx.needs_input_grad[0] else None - dw = ctx.op.convdw(x, grad_output, ctx.buffer(w.shape), stride, groups) if ctx.needs_input_grad[1] else None + dx = ctx.op.processing_op(ProcessingOps.CONVT, grad_output, w, ctx.buffer(x.shape), stride, groups) if ctx.needs_input_grad[0] else None + dw = ctx.op.processing_op(ProcessingOps.CONVDW, x, grad_output, ctx.buffer(w.shape), stride, groups) if ctx.needs_input_grad[1] else None return dx, dw \ No newline at end of file