add profiling for mnist net

This commit is contained in:
George Hotz 2020-10-25 14:20:55 -07:00
parent 8fcada8071
commit 5c179d18ad
2 changed files with 43 additions and 11 deletions

View File

@ -14,6 +14,7 @@ import time
import cProfile import cProfile
import pstats import pstats
import unittest import unittest
import numpy as np
from tinygrad.tensor import Tensor from tinygrad.tensor import Tensor
def profile_conv(bs, chans, conv, cnt=10): def profile_conv(bs, chans, conv, cnt=10):
@ -30,27 +31,55 @@ def profile_conv(bs, chans, conv, cnt=10):
bpt += (et2-et1) bpt += (et2-et1)
return fpt/cnt, bpt/cnt return fpt/cnt, bpt/cnt
def start_profile():
import time
pr = cProfile.Profile(timer=lambda: int(time.time()*1e9), timeunit=1e-6)
pr.enable()
return pr
def stop_profile(pr, sort='cumtime'):
pr.disable()
ps = pstats.Stats(pr)
ps.strip_dirs()
ps.sort_stats(sort)
ps.print_stats(0.3)
if prof is not None:
prof.print_stats()
class TestConvSpeed(unittest.TestCase): class TestConvSpeed(unittest.TestCase):
def test_forward_backward_3x3(self): def test_forward_backward_3x3(self):
# warmup # warmup
profile_conv(128, 16, 3, cnt=1) profile_conv(128, 16, 3, cnt=1)
# profile pr = start_profile()
pr = cProfile.Profile(timer=lambda: int(time.time()*1e9), timeunit=1e-6)
pr.enable()
fpt, bpt = profile_conv(128, 16, 3) fpt, bpt = profile_conv(128, 16, 3)
pr.disable() stop_profile(pr)
ps = pstats.Stats(pr)
ps.strip_dirs()
ps.sort_stats('cumtime')
ps.print_stats(0.3)
if prof is not None:
prof.print_stats()
print("forward pass: %.3f ms" % (fpt*1000)) print("forward pass: %.3f ms" % (fpt*1000))
print("backward pass: %.3f ms" % (bpt*1000)) print("backward pass: %.3f ms" % (bpt*1000))
def test_mnist(self):
# https://keras.io/examples/vision/mnist_convnet/
conv = 3
inter_chan, out_chan = 32, 64
c1 = Tensor.randn(inter_chan,1,conv,conv)
c2 = Tensor.randn(out_chan,inter_chan,conv,conv)
l1 = Tensor.randn(out_chan*5*5, 10)
for i in range(6):
x = Tensor.randn(128, 1, 28, 28)
x = x.conv2d(c1).relu().maxpool2x2()
x = x.conv2d(c2).relu().maxpool2x2()
x = x.reshape(Tensor(np.array((x.shape[0], -1))))
out = x.dot(l1).logsoftmax()
out.mean().backward()
if i == 0:
pr = start_profile()
stop_profile(pr, sort='time')
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -68,6 +68,9 @@ def im2col(x, H, W):
tx[:, Y, X] = x[:, :, Y:Y+H, X:X+W].reshape(bs, -1) tx[:, Y, X] = x[:, :, Y:Y+H, X:X+W].reshape(bs, -1)
""" """
# all the time is spent here
tx = tx.ravel()
return tx.reshape(-1, cin*W*H) return tx.reshape(-1, cin*W*H)
def col2im(tx, H, W, OY, OX): def col2im(tx, H, W, OY, OX):