mirror of https://github.com/commaai/tinygrad.git
use numba to double conv speed
This commit is contained in:
parent
dc325af392
commit
a68ead09c0
|
@ -73,7 +73,7 @@ for i in (t := trange(steps)):
|
|||
|
||||
# evaluate
|
||||
def numpy_eval():
|
||||
Y_test_preds_out = model.forward(Tensor(X_test.reshape((-1, 28*28))))
|
||||
Y_test_preds_out = model.forward(Tensor(X_test.reshape((-1, 28*28)).astype(np.float32)))
|
||||
Y_test_preds = np.argmax(Y_test_preds_out.data, axis=1)
|
||||
return (Y_test == Y_test_preds).mean()
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
# inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
|
||||
from functools import partialmethod
|
||||
import numpy as np
|
||||
from numba import jit, float32
|
||||
|
||||
# **** start with two base classes ****
|
||||
|
||||
|
@ -170,10 +171,9 @@ class LogSoftmax(Function):
|
|||
return grad_output - np.exp(output)*grad_output.sum(axis=1).reshape((-1, 1))
|
||||
register('logsoftmax', LogSoftmax)
|
||||
|
||||
class Conv2D(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x, w):
|
||||
ctx.save_for_backward(x, w)
|
||||
|
||||
@jit(nopython=True)
|
||||
def conv2d_inner_forward(x, w):
|
||||
cout,cin,H,W = w.shape
|
||||
ret = np.zeros((x.shape[0], cout, x.shape[2]-(H-1), x.shape[3]-(W-1)), dtype=w.dtype)
|
||||
for j in range(H):
|
||||
|
@ -184,9 +184,8 @@ class Conv2D(Function):
|
|||
ret[:, :, Y, X] += x[:, :, Y+j, X+i].dot(tw.T)
|
||||
return ret
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
x, w = ctx.saved_tensors
|
||||
@jit(nopython=True)
|
||||
def conv2d_inner_backward(grad_output, x, w):
|
||||
dx = np.zeros_like(x)
|
||||
dw = np.zeros_like(w)
|
||||
cout,cin,H,W = w.shape
|
||||
|
@ -200,5 +199,15 @@ class Conv2D(Function):
|
|||
dx[:, :, Y+j, X+i] += gg.dot(tw)
|
||||
dw[:, :, j, i] += gg.T.dot(tx)
|
||||
return dx, dw
|
||||
|
||||
class Conv2D(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x, w):
|
||||
ctx.save_for_backward(x, w)
|
||||
return conv2d_inner_forward(x, w)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
return conv2d_inner_backward(grad_output, *ctx.saved_tensors)
|
||||
register('conv2d', Conv2D)
|
||||
|
||||
|
|
Loading…
Reference in New Issue