mirror of https://github.com/commaai/tinygrad.git

commit 30f8132646
parent e5b2803b5d

reorder ops in ops cpu
@@ -111,7 +111,7 @@ You need to support 14 first class ops:
 Relu, Log, Exp # unary ops
 Add, Sub, Mul, Pow # binary ops (with broadcasting)
 Sum, Max # reduce ops (with axis argument)
-Reshape, Transpose, Slice # moving things around ops
+Reshape, Transpose, Slice # movement ops
 Matmul, Conv2D # heavy data processing ops
 ```
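The op table above is the contract a backend has to satisfy. As a rough sketch of how an op plugs into the CPU backend through the `Function`/`register` API that appears in the diff below (the import path and the `.neg()` method name are assumptions for illustration, not part of this commit), a hypothetical extra unary op could look like:

```python
import numpy as np
from tinygrad.tensor import Function, register  # assumed import path from outside the package

# Hypothetical op, for illustration only: elementwise negation.
class Neg(Function):
  @staticmethod
  def forward(ctx, input):
    # nothing needs saving: the gradient of -x does not depend on x
    return -input

  @staticmethod
  def backward(ctx, grad_output):
    return -grad_output

# registering should expose the op as a Tensor method (e.g. t.neg())
register('neg', Neg)
```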
@@ -2,7 +2,47 @@ import warnings
 import numpy as np
 from .tensor import Function, register

-# ************* basic ops *************
+# ************* unary ops *************
+
+class ReLU(Function):
+  @staticmethod
+  def forward(ctx, input):
+    ctx.save_for_backward(input)
+    return np.maximum(input, 0)
+
+  @staticmethod
+  def backward(ctx, grad_output):
+    input, = ctx.saved_tensors
+    return grad_output * (input >= 0)
+register('relu', ReLU)
+
+class Log(Function):
+  @staticmethod
+  def forward(ctx, input):
+    ctx.save_for_backward(input)
+    return np.log(input)
+
+  @staticmethod
+  def backward(ctx, grad_output):
+    input, = ctx.saved_tensors
+    return grad_output / input
+register('log', Log)
+
+class Exp(Function):
+  @staticmethod
+  def forward(ctx, input):
+    ret = np.exp(input)
+    ctx.save_for_backward(ret)
+    return ret
+
+  @staticmethod
+  def backward(ctx, grad_output):
+    ret, = ctx.saved_tensors
+    return grad_output * ret
+register('exp', Exp)
+
+# ************* binary ops *************
+
 def unbroadcast(out, in_sh):
   # adjoint operation to broadcast is sum. Need to sum all axis with 1 = in_sh[i] < out.shape[i]
   sum_axis = tuple([i for i in range(len(in_sh)) if in_sh[i]==1 and out.shape[i]>1]) if in_sh != (1,) else None
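The `unbroadcast` helper in the trailing context encodes the fact that the adjoint of broadcasting is summation over the broadcast axes: whatever forward repeated, backward has to sum back down to the original shape. A standalone sketch of that identity (plain NumPy, not tinygrad code):

```python
import numpy as np

x = np.random.randn(3, 1)          # gets broadcast to (3, 4) inside a binary op
grad_out = np.random.randn(3, 4)   # upstream gradient has the broadcast shape

# adjoint of the (3,1) -> (3,4) broadcast: sum over the axis that was repeated
grad_x = grad_out.sum(axis=1, keepdims=True)
assert grad_x.shape == x.shape
```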
@@ -57,6 +97,8 @@ class Pow(Function):
       unbroadcast((x**y) * np.log(x) * grad_output, y.shape)
 register('pow', Pow)

+# ************* reduce ops *************
+
 class Sum(Function):
   @staticmethod
   def forward(ctx, input, axis=None):
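The `(x**y) * np.log(x)` factor in Pow's backward is the calculus identity ∂(x^y)/∂y = x^y · ln x. A quick finite-difference sanity check of that term (a standalone sketch, not part of the commit):

```python
import numpy as np

x, y, eps = 2.0, 3.0, 1e-6
analytic = (x ** y) * np.log(x)
numeric = ((x ** (y + eps)) - (x ** (y - eps))) / (2 * eps)
assert abs(analytic - numeric) < 1e-4
```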
@@ -87,22 +129,6 @@ class Max(Function):
     return ret
 register('max', Max)

-# ************* GEMM *************
-
-class Matmul(Function):
-  @staticmethod
-  def forward(ctx, input, weight):
-    ctx.save_for_backward(input, weight)
-    return input @ weight
-
-  @staticmethod
-  def backward(ctx, grad_output):
-    input, weight = ctx.saved_tensors
-    grad_input = grad_output @ np.swapaxes(weight, -2, -1)
-    grad_weight = np.swapaxes(input, -2, -1) @ grad_output
-    return grad_input, grad_weight
-register('matmul', Matmul)
-
 # ************* movement ops *************

 def inner_slice(x, arg):
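The movement ops that follow this hunk's context (`inner_slice` backs the Slice op) all share one pattern: forward rearranges data, backward routes the gradient through the inverse rearrangement. As an illustration of that pattern only (Reshape is not shown in this diff, so this is a sketch under assumptions, not the commit's code):

```python
import numpy as np
from tinygrad.tensor import Function, register  # assumed import path

# Sketch of the movement-op pattern: backward undoes the forward rearrangement.
class Reshape(Function):
  @staticmethod
  def forward(ctx, x, shape):
    ctx.save_for_backward(x.shape)   # remember the original shape
    return x.reshape(shape)

  @staticmethod
  def backward(ctx, grad_output):
    in_shape, = ctx.saved_tensors
    return grad_output.reshape(in_shape)  # send the gradient back in the input's shape
register('reshape', Reshape)
```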
@@ -147,46 +173,21 @@ class Transpose(Function):
     return np.transpose(x, np.argsort(ctx.order))
 register('transpose', Transpose)

-# ************* activation ops *************
+# ************* processing ops *************

-class ReLU(Function):
+class Matmul(Function):
   @staticmethod
-  def forward(ctx, input):
-    ctx.save_for_backward(input)
-    return np.maximum(input, 0)
+  def forward(ctx, input, weight):
+    ctx.save_for_backward(input, weight)
+    return input @ weight

   @staticmethod
   def backward(ctx, grad_output):
-    input, = ctx.saved_tensors
-    return grad_output * (input >= 0)
-register('relu', ReLU)
-
-class Log(Function):
-  @staticmethod
-  def forward(ctx, input):
-    ctx.save_for_backward(input)
-    return np.log(input)
-
-  @staticmethod
-  def backward(ctx, grad_output):
-    input, = ctx.saved_tensors
-    return grad_output / input
-register('log', Log)
-
-class Exp(Function):
-  @staticmethod
-  def forward(ctx, input):
-    ret = np.exp(input)
-    ctx.save_for_backward(ret)
-    return ret
-
-  @staticmethod
-  def backward(ctx, grad_output):
-    ret, = ctx.saved_tensors
-    return grad_output * ret
-register('exp', Exp)
-
-# ************* conv ops *************
+    input, weight = ctx.saved_tensors
+    grad_input = grad_output @ np.swapaxes(weight, -2, -1)
+    grad_weight = np.swapaxes(input, -2, -1) @ grad_output
+    return grad_input, grad_weight
+register('matmul', Matmul)

 class Conv2D(Function):
   @staticmethod
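Matmul's backward is the standard matrix-calculus result: for Y = XW, ∂L/∂X = ∂L/∂Y · Wᵀ and ∂L/∂W = Xᵀ · ∂L/∂Y; the `swapaxes(-2, -1)` calls keep the transposes working for batched inputs. A quick NumPy finite-difference check of the grad_input formula (a standalone sketch, not part of the commit):

```python
import numpy as np

x = np.random.randn(2, 3)
w = np.random.randn(3, 4)
g = np.random.randn(2, 4)             # pretend upstream gradient dL/dY for Y = x @ w

grad_x = g @ np.swapaxes(w, -2, -1)   # formula used in Matmul.backward
grad_w = np.swapaxes(x, -2, -1) @ g

# finite-difference check of one entry of grad_x, using L = sum(g * (x @ w))
eps = 1e-6
x2 = x.copy(); x2[0, 0] += eps
numeric = (np.sum(g * (x2 @ w)) - np.sum(g * (x @ w))) / eps
assert abs(numeric - grad_x[0, 0]) < 1e-4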