Merge remote-tracking branch 'upstream/master' into pytest-again

This commit is contained in:
Adrian Garcia Badaracco 2020-10-21 13:19:58 -05:00
commit 58b4f191a4
No known key found for this signature in database
GPG Key ID: D3DCF9EA27A66543
1 changed files with 20 additions and 4 deletions

View File

@ -2,6 +2,12 @@
from functools import partialmethod
import numpy as np
# optional jit
try:
from numba import jit
except ImportError:
jit = lambda x: x
# **** start with two base classes ****
class Tensor:
@ -170,10 +176,11 @@ class LogSoftmax(Function):
return grad_output - np.exp(output)*grad_output.sum(axis=1).reshape((-1, 1))
register('logsoftmax', LogSoftmax)
class Conv2D(Function):
@staticmethod
def forward(ctx, x, w):
ctx.save_for_backward(x, w)
@jit
def inner_forward(x, w):
cout,cin,H,W = w.shape
ret = np.zeros((x.shape[0], cout, x.shape[2]-(H-1), x.shape[3]-(W-1)), dtype=w.dtype)
for j in range(H):
@ -185,8 +192,8 @@ class Conv2D(Function):
return ret
@staticmethod
def backward(ctx, grad_output):
x, w = ctx.saved_tensors
@jit
def inner_backward(grad_output, x, w):
dx = np.zeros_like(x)
dw = np.zeros_like(w)
cout,cin,H,W = w.shape
@ -200,5 +207,14 @@ class Conv2D(Function):
dx[:, :, Y+j, X+i] += gg.dot(tw)
dw[:, :, j, i] += gg.T.dot(tx)
return dx, dw
@staticmethod
def forward(ctx, x, w):
ctx.save_for_backward(x, w)
return Conv2D.inner_forward(x, w)
@staticmethod
def backward(ctx, grad_output):
return Conv2D.inner_backward(grad_output, *ctx.saved_tensors)
register('conv2d', Conv2D)