mirror of https://github.com/commaai/tinygrad.git
add profiling for mnist net
This commit is contained in:
parent
8fcada8071
commit
5c179d18ad
|
@ -14,6 +14,7 @@ import time
|
||||||
import cProfile
|
import cProfile
|
||||||
import pstats
|
import pstats
|
||||||
import unittest
|
import unittest
|
||||||
|
import numpy as np
|
||||||
from tinygrad.tensor import Tensor
|
from tinygrad.tensor import Tensor
|
||||||
|
|
||||||
def profile_conv(bs, chans, conv, cnt=10):
|
def profile_conv(bs, chans, conv, cnt=10):
|
||||||
|
@ -30,27 +31,55 @@ def profile_conv(bs, chans, conv, cnt=10):
|
||||||
bpt += (et2-et1)
|
bpt += (et2-et1)
|
||||||
return fpt/cnt, bpt/cnt
|
return fpt/cnt, bpt/cnt
|
||||||
|
|
||||||
|
def start_profile():
|
||||||
|
import time
|
||||||
|
pr = cProfile.Profile(timer=lambda: int(time.time()*1e9), timeunit=1e-6)
|
||||||
|
pr.enable()
|
||||||
|
return pr
|
||||||
|
|
||||||
|
def stop_profile(pr, sort='cumtime'):
|
||||||
|
pr.disable()
|
||||||
|
ps = pstats.Stats(pr)
|
||||||
|
ps.strip_dirs()
|
||||||
|
ps.sort_stats(sort)
|
||||||
|
ps.print_stats(0.3)
|
||||||
|
|
||||||
|
if prof is not None:
|
||||||
|
prof.print_stats()
|
||||||
|
|
||||||
class TestConvSpeed(unittest.TestCase):
|
class TestConvSpeed(unittest.TestCase):
|
||||||
def test_forward_backward_3x3(self):
|
def test_forward_backward_3x3(self):
|
||||||
# warmup
|
# warmup
|
||||||
profile_conv(128, 16, 3, cnt=1)
|
profile_conv(128, 16, 3, cnt=1)
|
||||||
|
|
||||||
# profile
|
pr = start_profile()
|
||||||
pr = cProfile.Profile(timer=lambda: int(time.time()*1e9), timeunit=1e-6)
|
|
||||||
pr.enable()
|
|
||||||
fpt, bpt = profile_conv(128, 16, 3)
|
fpt, bpt = profile_conv(128, 16, 3)
|
||||||
pr.disable()
|
stop_profile(pr)
|
||||||
ps = pstats.Stats(pr)
|
|
||||||
ps.strip_dirs()
|
|
||||||
ps.sort_stats('cumtime')
|
|
||||||
ps.print_stats(0.3)
|
|
||||||
|
|
||||||
if prof is not None:
|
|
||||||
prof.print_stats()
|
|
||||||
|
|
||||||
print("forward pass: %.3f ms" % (fpt*1000))
|
print("forward pass: %.3f ms" % (fpt*1000))
|
||||||
print("backward pass: %.3f ms" % (bpt*1000))
|
print("backward pass: %.3f ms" % (bpt*1000))
|
||||||
|
|
||||||
|
def test_mnist(self):
|
||||||
|
# https://keras.io/examples/vision/mnist_convnet/
|
||||||
|
conv = 3
|
||||||
|
inter_chan, out_chan = 32, 64
|
||||||
|
c1 = Tensor.randn(inter_chan,1,conv,conv)
|
||||||
|
c2 = Tensor.randn(out_chan,inter_chan,conv,conv)
|
||||||
|
l1 = Tensor.randn(out_chan*5*5, 10)
|
||||||
|
|
||||||
|
for i in range(6):
|
||||||
|
x = Tensor.randn(128, 1, 28, 28)
|
||||||
|
x = x.conv2d(c1).relu().maxpool2x2()
|
||||||
|
x = x.conv2d(c2).relu().maxpool2x2()
|
||||||
|
x = x.reshape(Tensor(np.array((x.shape[0], -1))))
|
||||||
|
out = x.dot(l1).logsoftmax()
|
||||||
|
out.mean().backward()
|
||||||
|
if i == 0:
|
||||||
|
pr = start_profile()
|
||||||
|
|
||||||
|
stop_profile(pr, sort='time')
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
||||||
|
|
|
@ -68,6 +68,9 @@ def im2col(x, H, W):
|
||||||
tx[:, Y, X] = x[:, :, Y:Y+H, X:X+W].reshape(bs, -1)
|
tx[:, Y, X] = x[:, :, Y:Y+H, X:X+W].reshape(bs, -1)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# all the time is spent here
|
||||||
|
tx = tx.ravel()
|
||||||
|
|
||||||
return tx.reshape(-1, cin*W*H)
|
return tx.reshape(-1, cin*W*H)
|
||||||
|
|
||||||
def col2im(tx, H, W, OY, OX):
|
def col2im(tx, H, W, OY, OX):
|
||||||
|
|
Loading…
Reference in New Issue