From 30f81326463f493ae2218d13ed894e0b1333a33a Mon Sep 17 00:00:00 2001
From: George Hotz
Date: Wed, 30 Dec 2020 11:00:01 -0500
Subject: [PATCH] reorder ops in ops cpu

---
 README.md           |   2 +-
 tinygrad/ops_cpu.py | 105 ++++++++++++++++++++++----------------------
 2 files changed, 54 insertions(+), 53 deletions(-)

diff --git a/README.md b/README.md
index 725f948f..fa442f2d 100644
--- a/README.md
+++ b/README.md
@@ -111,7 +111,7 @@ You need to support 14 first class ops:
 Relu, Log, Exp                    # unary ops
 Add, Sub, Mul, Pow                # binary ops (with broadcasting)
 Sum, Max                          # reduce ops (with axis argument)
-Reshape, Transpose, Slice         # moving things around ops
+Reshape, Transpose, Slice         # movement ops
 Matmul, Conv2D                    # heavy data processing ops
 ```

diff --git a/tinygrad/ops_cpu.py b/tinygrad/ops_cpu.py
index 70bb8a21..78691ca4 100644
--- a/tinygrad/ops_cpu.py
+++ b/tinygrad/ops_cpu.py
@@ -2,7 +2,47 @@ import warnings
 import numpy as np
 from .tensor import Function, register

-# ************* basic ops *************
+# ************* unary ops *************
+
+class ReLU(Function):
+  @staticmethod
+  def forward(ctx, input):
+    ctx.save_for_backward(input)
+    return np.maximum(input, 0)
+
+  @staticmethod
+  def backward(ctx, grad_output):
+    input, = ctx.saved_tensors
+    return grad_output * (input >= 0)
+register('relu', ReLU)
+
+class Log(Function):
+  @staticmethod
+  def forward(ctx, input):
+    ctx.save_for_backward(input)
+    return np.log(input)
+
+  @staticmethod
+  def backward(ctx, grad_output):
+    input, = ctx.saved_tensors
+    return grad_output / input
+register('log', Log)
+
+class Exp(Function):
+  @staticmethod
+  def forward(ctx, input):
+    ret = np.exp(input)
+    ctx.save_for_backward(ret)
+    return ret
+
+  @staticmethod
+  def backward(ctx, grad_output):
+    ret, = ctx.saved_tensors
+    return grad_output * ret
+register('exp', Exp)
+
+# ************* binary ops *************
+
 def unbroadcast(out, in_sh):
   # adjoint operation to broadcast is sum. Need to sum all axis with 1 = in_sh[i] < out.shape[i]
   sum_axis = tuple([i for i in range(len(in_sh)) if in_sh[i]==1 and out.shape[i]>1]) if in_sh != (1,) else None
@@ -57,6 +97,8 @@ class Pow(Function):
     unbroadcast((x**y) * np.log(x) * grad_output, y.shape)
 register('pow', Pow)

+# ************* reduce ops *************
+
 class Sum(Function):
   @staticmethod
   def forward(ctx, input, axis=None):
@@ -87,22 +129,6 @@ class Max(Function):
     return ret
 register('max', Max)

-# ************* GEMM *************
-
-class Matmul(Function):
-  @staticmethod
-  def forward(ctx, input, weight):
-    ctx.save_for_backward(input, weight)
-    return input @ weight
-
-  @staticmethod
-  def backward(ctx, grad_output):
-    input, weight = ctx.saved_tensors
-    grad_input = grad_output @ np.swapaxes(weight, -2, -1)
-    grad_weight = np.swapaxes(input, -2, -1) @ grad_output
-    return grad_input, grad_weight
-register('matmul', Matmul)
-
 # ************* movement ops *************

 def inner_slice(x, arg):
@@ -147,46 +173,21 @@ class Transpose(Function):
     return np.transpose(x, np.argsort(ctx.order))
 register('transpose', Transpose)

-# ************* activation ops *************
+# ************* processing ops *************

-class ReLU(Function):
+class Matmul(Function):
   @staticmethod
-  def forward(ctx, input):
-    ctx.save_for_backward(input)
-    return np.maximum(input, 0)
+  def forward(ctx, input, weight):
+    ctx.save_for_backward(input, weight)
+    return input @ weight

   @staticmethod
   def backward(ctx, grad_output):
-    input, = ctx.saved_tensors
-    return grad_output * (input >= 0)
-register('relu', ReLU)
-
-class Log(Function):
-  @staticmethod
-  def forward(ctx, input):
-    ctx.save_for_backward(input)
-    return np.log(input)
-
-  @staticmethod
-  def backward(ctx, grad_output):
-    input, = ctx.saved_tensors
-    return grad_output / input
-register('log', Log)
-
-class Exp(Function):
-  @staticmethod
-  def forward(ctx, input):
-    ret = np.exp(input)
-    ctx.save_for_backward(ret)
-    return ret
-
-  @staticmethod
-  def backward(ctx, grad_output):
-    ret, = ctx.saved_tensors
-    return grad_output * ret
-register('exp', Exp)
-
-# ************* conv ops *************
+    input, weight = ctx.saved_tensors
+    grad_input = grad_output @ np.swapaxes(weight, -2, -1)
+    grad_weight = np.swapaxes(input, -2, -1) @ grad_output
+    return grad_input, grad_weight
+register('matmul', Matmul)

 class Conv2D(Function):
   @staticmethod
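
Not part of the patch, but for context: register() in tinygrad/tensor.py attaches each Function above as a Tensor method, so this commit only regroups the classes by op category and the dispatch is unchanged. A minimal usage sketch of the reordered CPU ops, assuming the tinygrad package at this commit is importable:

  import numpy as np
  from tinygrad.tensor import Tensor

  x = Tensor(np.eye(3, dtype=np.float32))
  y = Tensor(np.array([[2.0, 0, -2.0]], dtype=np.float32))
  z = y.matmul(x).relu().sum()   # processing op -> unary op -> reduce op
  z.backward()                   # gradients come from each Function's backward() above
  print(x.grad, y.grad)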