reorder ops in ops cpu

This commit is contained in:
George Hotz 2020-12-30 11:00:01 -05:00
parent e5b2803b5d
commit 30f8132646
2 changed files with 54 additions and 53 deletions

View File

@ -111,7 +111,7 @@ You need to support 14 first class ops:
Relu, Log, Exp # unary ops
Add, Sub, Mul, Pow # binary ops (with broadcasting)
Sum, Max # reduce ops (with axis argument)
Reshape, Transpose, Slice # moving things around ops
Reshape, Transpose, Slice # movement ops
Matmul, Conv2D # heavy data processing ops
```

View File

@ -2,7 +2,47 @@ import warnings
import numpy as np
from .tensor import Function, register
# ************* basic ops *************
# ************* unary ops *************
class ReLU(Function):
@staticmethod
def forward(ctx, input):
ctx.save_for_backward(input)
return np.maximum(input, 0)
@staticmethod
def backward(ctx, grad_output):
input, = ctx.saved_tensors
return grad_output * (input >= 0)
register('relu', ReLU)
class Log(Function):
@staticmethod
def forward(ctx, input):
ctx.save_for_backward(input)
return np.log(input)
@staticmethod
def backward(ctx, grad_output):
input, = ctx.saved_tensors
return grad_output / input
register('log', Log)
class Exp(Function):
@staticmethod
def forward(ctx, input):
ret = np.exp(input)
ctx.save_for_backward(ret)
return ret
@staticmethod
def backward(ctx, grad_output):
ret, = ctx.saved_tensors
return grad_output * ret
register('exp', Exp)
# ************* binary ops *************
def unbroadcast(out, in_sh):
# adjoint operation to broadcast is sum. Need to sum all axis with 1 = in_sh[i] < out.shape[i]
sum_axis = tuple([i for i in range(len(in_sh)) if in_sh[i]==1 and out.shape[i]>1]) if in_sh != (1,) else None
@ -57,6 +97,8 @@ class Pow(Function):
unbroadcast((x**y) * np.log(x) * grad_output, y.shape)
register('pow', Pow)
# ************* reduce ops *************
class Sum(Function):
@staticmethod
def forward(ctx, input, axis=None):
@ -87,22 +129,6 @@ class Max(Function):
return ret
register('max', Max)
# ************* GEMM *************
class Matmul(Function):
@staticmethod
def forward(ctx, input, weight):
ctx.save_for_backward(input, weight)
return input @ weight
@staticmethod
def backward(ctx, grad_output):
input, weight = ctx.saved_tensors
grad_input = grad_output @ np.swapaxes(weight, -2, -1)
grad_weight = np.swapaxes(input, -2, -1) @ grad_output
return grad_input, grad_weight
register('matmul', Matmul)
# ************* movement ops *************
def inner_slice(x, arg):
@ -147,46 +173,21 @@ class Transpose(Function):
return np.transpose(x, np.argsort(ctx.order))
register('transpose', Transpose)
# ************* activation ops *************
# ************* processing ops *************
class ReLU(Function):
class Matmul(Function):
@staticmethod
def forward(ctx, input):
ctx.save_for_backward(input)
return np.maximum(input, 0)
def forward(ctx, input, weight):
ctx.save_for_backward(input, weight)
return input @ weight
@staticmethod
def backward(ctx, grad_output):
input, = ctx.saved_tensors
return grad_output * (input >= 0)
register('relu', ReLU)
class Log(Function):
@staticmethod
def forward(ctx, input):
ctx.save_for_backward(input)
return np.log(input)
@staticmethod
def backward(ctx, grad_output):
input, = ctx.saved_tensors
return grad_output / input
register('log', Log)
class Exp(Function):
@staticmethod
def forward(ctx, input):
ret = np.exp(input)
ctx.save_for_backward(ret)
return ret
@staticmethod
def backward(ctx, grad_output):
ret, = ctx.saved_tensors
return grad_output * ret
register('exp', Exp)
# ************* conv ops *************
input, weight = ctx.saved_tensors
grad_input = grad_output @ np.swapaxes(weight, -2, -1)
grad_weight = np.swapaxes(input, -2, -1) @ grad_output
return grad_input, grad_weight
register('matmul', Matmul)
class Conv2D(Function):
@staticmethod