mirror of https://github.com/commaai/tinygrad.git
rename max_pool2d to match torch, remove more fast conv crap
This commit is contained in:
parent
5d1373c71b
commit
567707a5f6
|
@ -1,15 +1,4 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
# if you'd like to use the line profiler
|
||||
try:
|
||||
import line_profiler
|
||||
prof = line_profiler.LineProfiler()
|
||||
import builtins
|
||||
builtins.__dict__['profile'] = prof
|
||||
# add @profile decorator to probe
|
||||
except ImportError:
|
||||
prof = None
|
||||
|
||||
import time
|
||||
import cProfile
|
||||
import pstats
|
||||
|
@ -18,20 +7,6 @@ import numpy as np
|
|||
import torch
|
||||
from tinygrad.tensor import Tensor
|
||||
|
||||
def profile_conv(bs, chans, conv, cnt=10):
|
||||
img = Tensor.zeros(bs, 1, 28, 28)
|
||||
conv = Tensor.randn(chans, 1, conv, conv)
|
||||
fpt, bpt = 0.0, 0.0
|
||||
for i in range(cnt):
|
||||
et0 = time.time()
|
||||
out = img.conv2d(conv)
|
||||
et1 = time.time()
|
||||
g = out.mean().backward()
|
||||
et2 = time.time()
|
||||
fpt += (et1-et0)
|
||||
bpt += (et2-et1)
|
||||
return fpt/cnt, bpt/cnt
|
||||
|
||||
def start_profile():
|
||||
import time
|
||||
pr = cProfile.Profile(timer=lambda: int(time.time()*1e9), timeunit=1e-6)
|
||||
|
@ -45,27 +20,12 @@ def stop_profile(pr, sort='cumtime'):
|
|||
ps.sort_stats(sort)
|
||||
ps.print_stats(0.2)
|
||||
|
||||
if prof is not None:
|
||||
prof.print_stats()
|
||||
|
||||
class TestConvSpeed(unittest.TestCase):
|
||||
def test_forward_backward_3x3(self):
|
||||
# warmup
|
||||
profile_conv(128, 16, 3, cnt=1)
|
||||
|
||||
pr = start_profile()
|
||||
fpt, bpt = profile_conv(128, 16, 3)
|
||||
stop_profile(pr)
|
||||
|
||||
print("forward pass: %.3f ms" % (fpt*1000))
|
||||
print("backward pass: %.3f ms" % (bpt*1000))
|
||||
|
||||
def test_mnist(self):
|
||||
# https://keras.io/examples/vision/mnist_convnet/
|
||||
conv = 3
|
||||
inter_chan, out_chan = 32, 64
|
||||
|
||||
|
||||
# ****** torch baseline *******
|
||||
|
||||
torch.backends.mkldnn.enabled = False
|
||||
|
@ -83,7 +43,7 @@ class TestConvSpeed(unittest.TestCase):
|
|||
with torch.autograd.profiler.profile(record_shapes=True) as tprof:
|
||||
cnt = 5
|
||||
fpt, bpt = 0.0, 0.0
|
||||
for i in range(1+cnt):
|
||||
for i in range(cnt):
|
||||
et0 = time.time()
|
||||
x = torch.randn(128, 1, 28, 28, requires_grad=True)
|
||||
x = mp(c2d(x,c1).relu())
|
||||
|
@ -94,13 +54,9 @@ class TestConvSpeed(unittest.TestCase):
|
|||
et1 = time.time()
|
||||
out.backward()
|
||||
et2 = time.time()
|
||||
if i == 0:
|
||||
pr = start_profile()
|
||||
else:
|
||||
fpt += (et1-et0)
|
||||
bpt += (et2-et1)
|
||||
|
||||
stop_profile(pr, sort='time')
|
||||
fpt_baseline = (fpt*1000/cnt)
|
||||
bpt_baseline = (bpt*1000/cnt)
|
||||
print("torch forward pass: %.3f ms" % fpt_baseline)
|
||||
|
@ -119,8 +75,8 @@ class TestConvSpeed(unittest.TestCase):
|
|||
for i in range(1+cnt):
|
||||
et0 = time.time()
|
||||
x = Tensor.randn(128, 1, 28, 28)
|
||||
x = x.conv2d(c1).relu().maxpool2x2()
|
||||
x = x.conv2d(c2).relu().maxpool2x2()
|
||||
x = x.conv2d(c1).relu().max_pool2d()
|
||||
x = x.conv2d(c2).relu().max_pool2d()
|
||||
x = x.reshape(Tensor(np.array((x.shape[0], -1))))
|
||||
out = x.dot(l1).logsoftmax()
|
||||
out = out.mean()
|
||||
|
|
|
@ -32,8 +32,8 @@ class TinyConvNet:
|
|||
|
||||
def forward(self, x):
|
||||
x.data = x.data.reshape((-1, 1, 28, 28)) # hacks
|
||||
x = x.conv2d(self.c1).relu().maxpool2x2()
|
||||
x = x.conv2d(self.c2).relu().maxpool2x2()
|
||||
x = x.conv2d(self.c1).relu().max_pool2d()
|
||||
x = x.conv2d(self.c2).relu().max_pool2d()
|
||||
x = x.reshape(Tensor(np.array((x.shape[0], -1))))
|
||||
return x.dot(self.l1).logsoftmax()
|
||||
|
||||
|
|
|
@ -90,7 +90,7 @@ class TestOps(unittest.TestCase):
|
|||
xt = Tensor(x.detach().numpy())
|
||||
|
||||
# in tinygrad
|
||||
ret = xt.maxpool2x2()
|
||||
ret = xt.max_pool2d()
|
||||
assert ret.shape == (5,2,10//2,8//2)
|
||||
ret.mean().backward()
|
||||
|
||||
|
|
|
@ -124,7 +124,7 @@ class Conv2D(Function):
|
|||
return dx, dw
|
||||
register('conv2d', Conv2D)
|
||||
|
||||
class MaxPool2x2(Function):
|
||||
class MaxPool2D(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x):
|
||||
my, mx = (x.shape[2]//2)*2, (x.shape[3]//2)*2
|
||||
|
@ -147,5 +147,5 @@ class MaxPool2x2(Function):
|
|||
for X in range(2):
|
||||
ret[:, :, Y:my:2, X:mx:2] = grad_output * (idxs == (Y*2+X))
|
||||
return ret
|
||||
register('maxpool2x2', MaxPool2x2)
|
||||
register('max_pool2d', MaxPool2D)
|
||||
|
||||
|
|
Loading…
Reference in New Issue