mirror of https://github.com/commaai/tinygrad.git
add profiling for mnist net
This commit is contained in:
parent
8fcada8071
commit
5c179d18ad
|
@ -14,6 +14,7 @@ import time
|
|||
import cProfile
|
||||
import pstats
|
||||
import unittest
|
||||
import numpy as np
|
||||
from tinygrad.tensor import Tensor
|
||||
|
||||
def profile_conv(bs, chans, conv, cnt=10):
|
||||
|
@ -30,27 +31,55 @@ def profile_conv(bs, chans, conv, cnt=10):
|
|||
bpt += (et2-et1)
|
||||
return fpt/cnt, bpt/cnt
|
||||
|
||||
class TestConvSpeed(unittest.TestCase):
|
||||
def test_forward_backward_3x3(self):
|
||||
# warmup
|
||||
profile_conv(128, 16, 3, cnt=1)
|
||||
|
||||
# profile
|
||||
def start_profile():
    """Create a ``cProfile.Profile``, enable it, and return it.

    The profiler uses a custom wall-clock timer that returns an integer
    number of nanoseconds, while ``timeunit`` declares each tick to be
    1e-6.  The values pstats later prints in its "seconds" columns are
    therefore scaled by 1000, i.e. they read as milliseconds.

    The helper only starts profiling; the caller runs the workload and is
    responsible for passing the returned profiler to ``stop_profile``.

    Returns:
        The enabled ``cProfile.Profile`` instance.
    """
    # Function-local import so the helper does not rely on a module-level
    # ``time`` binding.
    import time

    pr = cProfile.Profile(timer=lambda: int(time.time()*1e9), timeunit=1e-6)
    pr.enable()
    return pr
|
||||
|
||||
def stop_profile(pr, sort='cumtime'):
    """Disable profiler ``pr`` and print a summary of what it collected.

    Args:
        pr: An enabled ``cProfile.Profile`` (e.g. from ``start_profile``).
        sort: Sort key handed to ``pstats.Stats.sort_stats``; defaults to
            ``'cumtime'``.

    Directory prefixes are stripped from file names, and ``print_stats(0.3)``
    limits the listing to the first 30% of the sorted entries.  Output goes
    to standard output; nothing is returned.
    """
    pr.disable()
    ps = pstats.Stats(pr)
    ps.strip_dirs()
    # Sort exactly once with the caller's key; an earlier hard-coded
    # 'cumtime' sort would simply be overridden by this call.
    ps.sort_stats(sort)
    ps.print_stats(0.3)
|
||||
|
||||
if prof is not None:
|
||||
prof.print_stats()
|
||||
|
||||
class TestConvSpeed(unittest.TestCase):
    """Profiling runs for convolution workloads.

    Both tests print profiler output rather than asserting on timings.
    The profiler is always stopped in a ``finally`` clause so that a
    failure inside the workload cannot leave it enabled for later tests.
    """

    def test_forward_backward_3x3(self):
        """Profile ``profile_conv(128, 16, 3)`` and print the pass times."""
        # warmup
        profile_conv(128, 16, 3, cnt=1)

        pr = start_profile()
        try:
            fpt, bpt = profile_conv(128, 16, 3)
        finally:
            # Disable the profiler (and dump what was collected) even if
            # the workload raised.
            stop_profile(pr)

        # profile_conv returns per-iteration averages; scale by 1000 for
        # the "ms" label used below.
        print("forward pass: %.3f ms" % (fpt*1000))
        print("backward pass: %.3f ms" % (bpt*1000))

    def test_mnist(self):
        """Profile forward+backward passes of a small MNIST-style convnet."""
        # https://keras.io/examples/vision/mnist_convnet/
        conv = 3
        inter_chan, out_chan = 32, 64
        c1 = Tensor.randn(inter_chan,1,conv,conv)
        c2 = Tensor.randn(out_chan,inter_chan,conv,conv)
        l1 = Tensor.randn(out_chan*5*5, 10)

        pr = None
        try:
            for i in range(6):
                x = Tensor.randn(128, 1, 28, 28)
                x = x.conv2d(c1).relu().maxpool2x2()
                x = x.conv2d(c2).relu().maxpool2x2()
                x = x.reshape(Tensor(np.array((x.shape[0], -1))))
                out = x.dot(l1).logsoftmax()
                out.mean().backward()
                if i == 0:
                    # The first iteration is an unprofiled warmup; the
                    # remaining five run under the profiler.
                    pr = start_profile()
        finally:
            # pr is still None if the warmup iteration itself failed.
            if pr is not None:
                stop_profile(pr, sort='time')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
|
|
|
@ -68,6 +68,9 @@ def im2col(x, H, W):
|
|||
tx[:, Y, X] = x[:, :, Y:Y+H, X:X+W].reshape(bs, -1)
|
||||
"""
|
||||
|
||||
# all the time is spent here
|
||||
tx = tx.ravel()
|
||||
|
||||
return tx.reshape(-1, cin*W*H)
|
||||
|
||||
def col2im(tx, H, W, OY, OX):
|
||||
|
|
Loading…
Reference in New Issue