profile conv better

This commit is contained in:
George Hotz 2020-10-25 11:11:00 -07:00
parent c74764bac3
commit 5062c2c8ff
1 changed files with 23 additions and 8 deletions

View File

@ -1,35 +1,50 @@
#!/usr/bin/env python
import builtins
import time
import cProfile
import pstats
import unittest
from tinygrad.tensor import Tensor
try:
import line_profiler
prof = line_profiler.LineProfiler()
import builtins
builtins.__dict__['profile'] = prof
# add @profile decorator to probe
except ImportError:
prof = None
import cProfile
import unittest
from tinygrad.tensor import Tensor
def profile_conv(bs, chans, conv, cnt=10):
img = Tensor.zeros(bs, 1, 28, 28)
conv = Tensor.randn(chans, 1, conv, conv)
fpt, bpt = 0.0, 0.0
for i in range(cnt):
et0 = time.time()
out = img.conv2d(conv)
et1 = time.time()
g = out.mean().backward()
et2 = time.time()
fpt += (et1-et0)
bpt += (et2-et1)
return fpt/cnt, bpt/cnt
class TestConvSpeed(unittest.TestCase):
def test_forward_backward_3x3(self):
pr = cProfile.Profile()
pr = cProfile.Profile(timer=lambda: int(time.time()*1e9), timeunit=1e-6)
pr.enable()
profile_conv(128, 16, 3)
fpt, bpt = profile_conv(128, 16, 3)
pr.disable()
pr.print_stats(sort='time')
ps = pstats.Stats(pr)
ps.strip_dirs()
ps.sort_stats('cumtime')
ps.print_stats(0.3)
if prof is not None:
prof.print_stats()
print("forward pass: %.3f ms" % (fpt*1000))
print("backward pass: %.3f ms" % (bpt*1000))
if __name__ == '__main__':
unittest.main()