2020-10-24 00:46:10 +08:00
|
|
|
#!/usr/bin/env python
|
2020-10-26 02:13:40 +08:00
|
|
|
import time
|
|
|
|
import unittest
|
2020-10-26 06:27:33 +08:00
|
|
|
import torch
|
2023-12-19 07:53:28 +08:00
|
|
|
from tinygrad import Tensor, Device
|
|
|
|
from tinygrad.helpers import Profiling, CI
|
2020-10-26 02:13:40 +08:00
|
|
|
|
2024-05-16 04:46:08 +08:00
|
|
|
def _avg_ms(step, cnt):
  """Call `step` `cnt` times and return its mean (forward, backward) time in milliseconds.

  `step` must return a (forward_seconds, backward_seconds) pair for a single iteration.
  """
  fpt, bpt = 0.0, 0.0
  for _ in range(cnt):
    f, b = step()
    fpt += f
    bpt += b
  return fpt*1000/cnt, bpt*1000/cnt


@unittest.skipIf(CI and Device.DEFAULT in {"CUDA", "NV"}, "slow")
class TestConvSpeed(unittest.TestCase):
  """Speed comparison of a small MNIST convnet in tinygrad against a torch baseline."""

  def test_mnist(self):
    """Time forward/backward of the Keras MNIST convnet in torch, then in tinygrad, and print the ratio.

    Model (https://keras.io/examples/vision/mnist_convnet/):
      conv3x3 -> relu -> maxpool2 -> conv3x3 -> relu -> maxpool2 -> flatten -> linear -> log_softmax -> mean
    Each framework gets one untimed warmup iteration followed by `cnt` timed iterations. Both run
    the same ops, so the printed "off baseline" ratios compare like with like.
    """
    conv = 3
    inter_chan, out_chan = 32, 64
    batch, cnt = 128, 5

    # ****** torch baseline *******

    # Turn off the mkldnn CPU backend for the baseline. Restore the previous process-wide
    # value once the test finishes so other tests do not inherit it.
    self.addCleanup(setattr, torch.backends.mkldnn, "enabled", torch.backends.mkldnn.enabled)
    torch.backends.mkldnn.enabled = False

    c1 = torch.randn(inter_chan, 1, conv, conv, requires_grad=True)
    c2 = torch.randn(out_chan, inter_chan, conv, conv, requires_grad=True)
    # spatial size: 28 -conv-> 26 -pool-> 13 -conv-> 11 -pool-> 5, so the flattened width is out_chan*5*5
    l1 = torch.randn(out_chan*5*5, 10, requires_grad=True)

    c2d = torch.nn.functional.conv2d
    mp = torch.nn.MaxPool2d((2,2))
    lsm = torch.nn.LogSoftmax(dim=1)

    def torch_step():
      et0 = time.perf_counter()
      # The input needs no gradient; the tinygrad input below does not track one either.
      # Backward therefore computes only the weight gradients in both frameworks.
      x = torch.randn(batch, 1, 28, 28)
      x = mp(c2d(x, c1).relu())
      x = mp(c2d(x, c2).relu())
      x = x.reshape(x.shape[0], -1)
      out = lsm(x.matmul(l1)).mean()
      et1 = time.perf_counter()
      out.backward()
      et2 = time.perf_counter()
      return et1-et0, et2-et1

    torch_step()  # untimed warmup, matching the tinygrad measurement below
    fpt_baseline, bpt_baseline = _avg_ms(torch_step, cnt)
    print("torch forward pass: %.3f ms" % fpt_baseline)
    print("torch backward pass: %.3f ms" % bpt_baseline)

    # ****** tinygrad compare *******

    # Same initial weights as the torch baseline. Separate names keep torch_step's closure intact.
    tc1 = Tensor(c1.detach().numpy(), requires_grad=True)
    tc2 = Tensor(c2.detach().numpy(), requires_grad=True)
    tl1 = Tensor(l1.detach().numpy(), requires_grad=True)

    def tiny_step():
      et0 = time.perf_counter()
      x = Tensor.randn(batch, 1, 28, 28)
      x = x.conv2d(tc1).relu().max_pool2d()
      x = x.conv2d(tc2).relu().max_pool2d()
      x = x.reshape(shape=(x.shape[0], -1))
      out = x.dot(tl1).log_softmax().mean()
      # realize() forces the computation to run here, so its cost lands in the forward timer
      out.realize()
      et1 = time.perf_counter()
      out.backward()
      for w in (tc1, tc2, tl1): w.grad.realize()
      et2 = time.perf_counter()
      return et1-et0, et2-et1

    tiny_step()  # untimed warmup; presumably absorbs one-time setup cost (e.g. compilation)
    with Profiling(sort='time', frac=0.2):
      fpt, bpt = _avg_ms(tiny_step, cnt)
    print("forward pass: %.3f ms, %.2fx off baseline %.3f ms" % (fpt, fpt/fpt_baseline, fpt_baseline))
    print("backward pass: %.3f ms, %.2fx off baseline %.3f ms" % (bpt, bpt/bpt_baseline, bpt_baseline))
|
2020-10-26 06:42:33 +08:00
|
|
|
|
2020-10-26 05:20:55 +08:00
|
|
|
|
2020-10-24 00:46:10 +08:00
|
|
|
# Script entry point: run this module's tests through the unittest CLI runner.
if __name__ == "__main__":
  unittest.main(module="__main__")
|