rename max_pool2d to match torch, remove more fast conv crap

This commit is contained in:
George Hotz 2020-10-25 17:16:47 -07:00
parent 5d1373c71b
commit 567707a5f6
4 changed files with 10 additions and 54 deletions

View File

@ -1,15 +1,4 @@
#!/usr/bin/env python
# if you'd like to use the line profiler
try:
import line_profiler
prof = line_profiler.LineProfiler()
import builtins
builtins.__dict__['profile'] = prof
# add @profile decorator to probe
except ImportError:
prof = None
import time
import cProfile
import pstats
@ -18,20 +7,6 @@ import numpy as np
import torch
from tinygrad.tensor import Tensor
def profile_conv(bs, chans, conv, cnt=10):
img = Tensor.zeros(bs, 1, 28, 28)
conv = Tensor.randn(chans, 1, conv, conv)
fpt, bpt = 0.0, 0.0
for i in range(cnt):
et0 = time.time()
out = img.conv2d(conv)
et1 = time.time()
g = out.mean().backward()
et2 = time.time()
fpt += (et1-et0)
bpt += (et2-et1)
return fpt/cnt, bpt/cnt
def start_profile():
import time
pr = cProfile.Profile(timer=lambda: int(time.time()*1e9), timeunit=1e-6)
@ -45,27 +20,12 @@ def stop_profile(pr, sort='cumtime'):
ps.sort_stats(sort)
ps.print_stats(0.2)
if prof is not None:
prof.print_stats()
class TestConvSpeed(unittest.TestCase):
def test_forward_backward_3x3(self):
# warmup
profile_conv(128, 16, 3, cnt=1)
pr = start_profile()
fpt, bpt = profile_conv(128, 16, 3)
stop_profile(pr)
print("forward pass: %.3f ms" % (fpt*1000))
print("backward pass: %.3f ms" % (bpt*1000))
def test_mnist(self):
# https://keras.io/examples/vision/mnist_convnet/
conv = 3
inter_chan, out_chan = 32, 64
# ****** torch baseline *******
torch.backends.mkldnn.enabled = False
@ -83,7 +43,7 @@ class TestConvSpeed(unittest.TestCase):
with torch.autograd.profiler.profile(record_shapes=True) as tprof:
cnt = 5
fpt, bpt = 0.0, 0.0
for i in range(1+cnt):
for i in range(cnt):
et0 = time.time()
x = torch.randn(128, 1, 28, 28, requires_grad=True)
x = mp(c2d(x,c1).relu())
@ -94,13 +54,9 @@ class TestConvSpeed(unittest.TestCase):
et1 = time.time()
out.backward()
et2 = time.time()
if i == 0:
pr = start_profile()
else:
fpt += (et1-et0)
bpt += (et2-et1)
fpt += (et1-et0)
bpt += (et2-et1)
stop_profile(pr, sort='time')
fpt_baseline = (fpt*1000/cnt)
bpt_baseline = (bpt*1000/cnt)
print("torch forward pass: %.3f ms" % fpt_baseline)
@ -119,8 +75,8 @@ class TestConvSpeed(unittest.TestCase):
for i in range(1+cnt):
et0 = time.time()
x = Tensor.randn(128, 1, 28, 28)
x = x.conv2d(c1).relu().maxpool2x2()
x = x.conv2d(c2).relu().maxpool2x2()
x = x.conv2d(c1).relu().max_pool2d()
x = x.conv2d(c2).relu().max_pool2d()
x = x.reshape(Tensor(np.array((x.shape[0], -1))))
out = x.dot(l1).logsoftmax()
out = out.mean()

View File

@ -32,8 +32,8 @@ class TinyConvNet:
def forward(self, x):
x.data = x.data.reshape((-1, 1, 28, 28)) # hacks
x = x.conv2d(self.c1).relu().maxpool2x2()
x = x.conv2d(self.c2).relu().maxpool2x2()
x = x.conv2d(self.c1).relu().max_pool2d()
x = x.conv2d(self.c2).relu().max_pool2d()
x = x.reshape(Tensor(np.array((x.shape[0], -1))))
return x.dot(self.l1).logsoftmax()

View File

@ -90,7 +90,7 @@ class TestOps(unittest.TestCase):
xt = Tensor(x.detach().numpy())
# in tinygrad
ret = xt.maxpool2x2()
ret = xt.max_pool2d()
assert ret.shape == (5,2,10//2,8//2)
ret.mean().backward()

View File

@ -124,7 +124,7 @@ class Conv2D(Function):
return dx, dw
register('conv2d', Conv2D)
class MaxPool2x2(Function):
class MaxPool2D(Function):
@staticmethod
def forward(ctx, x):
my, mx = (x.shape[2]//2)*2, (x.shape[3]//2)*2
@ -147,5 +147,5 @@ class MaxPool2x2(Function):
for X in range(2):
ret[:, :, Y:my:2, X:mx:2] = grad_output * (idxs == (Y*2+X))
return ret
register('maxpool2x2', MaxPool2x2)
register('max_pool2d', MaxPool2D)