From 5062c2c8ff3f0d85b8ace41a31c049bb23fdea6b Mon Sep 17 00:00:00 2001 From: George Hotz Date: Sun, 25 Oct 2020 11:11:00 -0700 Subject: [PATCH] profile conv better --- test/test_conv_speed.py | 31 +++++++++++++++++++++++-------- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/test/test_conv_speed.py b/test/test_conv_speed.py index a68f6f5d..4e70c2ec 100644 --- a/test/test_conv_speed.py +++ b/test/test_conv_speed.py @@ -1,35 +1,50 @@ #!/usr/bin/env python -import builtins +import time +import cProfile +import pstats +import unittest +from tinygrad.tensor import Tensor + try: import line_profiler prof = line_profiler.LineProfiler() + import builtins builtins.__dict__['profile'] = prof # add @profile decorator to probe except ImportError: prof = None -import cProfile -import unittest -from tinygrad.tensor import Tensor - def profile_conv(bs, chans, conv, cnt=10): img = Tensor.zeros(bs, 1, 28, 28) conv = Tensor.randn(chans, 1, conv, conv) + fpt, bpt = 0.0, 0.0 for i in range(cnt): + et0 = time.time() out = img.conv2d(conv) + et1 = time.time() g = out.mean().backward() + et2 = time.time() + fpt += (et1-et0) + bpt += (et2-et1) + return fpt/cnt, bpt/cnt class TestConvSpeed(unittest.TestCase): def test_forward_backward_3x3(self): - pr = cProfile.Profile() + pr = cProfile.Profile(timer=lambda: int(time.time()*1e9), timeunit=1e-6) pr.enable() - profile_conv(128, 16, 3) + fpt, bpt = profile_conv(128, 16, 3) pr.disable() - pr.print_stats(sort='time') + ps = pstats.Stats(pr) + ps.strip_dirs() + ps.sort_stats('cumtime') + ps.print_stats(0.3) if prof is not None: prof.print_stats() + print("forward pass: %.3f ms" % (fpt*1000)) + print("backward pass: %.3f ms" % (bpt*1000)) + if __name__ == '__main__': unittest.main()