tinygrad/test/test_speed_v_torch.py

56 lines
2.0 KiB
Python
Raw Normal View History

2022-10-29 02:22:15 +08:00
import os
2022-10-11 07:06:00 +08:00
import unittest
import torch
import time
import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.nn import Conv2d
2022-10-29 02:22:15 +08:00
# Input-channel counts to benchmark. Override the default sweep by setting the
# IN_CHANS environment variable to a comma-separated list of integers.
IN_CHANS = list(map(int, os.getenv("IN_CHANS", "1,16,64").split(",")))
2022-10-11 07:06:00 +08:00
# Number of timed forward passes per configuration; the median time is reported.
CNT = 5
class TestSpeedVTorch(unittest.TestCase):
    """Benchmark tinygrad's Conv2d forward pass against torch and cross-check the outputs."""

    def test_conv2d(self):
        """Run the Conv2d comparison for every configured input-channel count.

        Each (batch size, input channels, output channels) combination runs in
        its own subTest, so one failing configuration is reported without
        hiding the results of the others.
        """
        torch.manual_seed(0)
        # Loop-invariant: choose the torch device (and matching image size) once.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        img_size = 64 if device == 'cuda' else 32
        for bs in [32]:
            for in_chans in IN_CHANS:
                for out_chans in [64]:
                    with self.subTest(bs=bs, in_chans=in_chans, out_chans=out_chans):
                        self._compare_conv2d(bs, in_chans, out_chans, img_size, device)

    def _compare_conv2d(self, bs, in_chans, out_chans, img_size, device):
        """Time one 3x3 convolution configuration in both frameworks and compare the sums.

        Both frameworks start from the same random input and share torch's
        randomly initialised weights. Each does one untimed forward pass, then
        CNT timed passes with the input incremented by one before every pass.
        The median times are printed, and the test fails when the sums of the
        final outputs differ by 1% or more (relative to torch).
        """
        src = torch.rand(bs, in_chans, img_size, img_size)
        # bias=None is falsy, so torch builds the layer without a bias term.
        src_conv = torch.nn.Conv2d(in_chans, out_chans, 3, bias=None)

        # --- torch ---
        dat = src.clone().to(device)
        conv = src_conv.to(device)
        ets_torch = []
        # The whole loop stays under no_grad: .numpy() cannot be called on an
        # output that requires grad.
        with torch.no_grad():
            val_torch = conv(dat).cpu().numpy().sum()  # untimed warm-up pass
            for _ in range(CNT):
                dat += 1
                if device == 'cuda':
                    # CUDA kernels are queued asynchronously; wait for the
                    # in-place add so it is not charged to the convolution.
                    torch.cuda.synchronize()
                st = time.perf_counter()
                val_torch = conv(dat).cpu().numpy().sum()
                ets_torch.append((time.perf_counter() - st) * 1000)

        # --- tinygrad ---
        # NOTE(review): the torch side runs under no_grad while this sets
        # tinygrad's flag to False -- confirm that is intended.
        Tensor.no_grad = False
        dat = Tensor(src.numpy())
        # NOTE(review): bias=None presumably disables the bias here as well;
        # verify against tinygrad's Conv2d constructor.
        conv = Conv2d(in_chans, out_chans, 3, bias=None)
        # Reuse torch's weights so both frameworks compute the same convolution.
        conv.weight = Tensor(src_conv.weight.detach().cpu().numpy())
        val_tinygrad = conv(dat).numpy().sum()  # untimed warm-up pass
        ets_tinygrad = []
        for _ in range(CNT):
            dat += 1
            # Presumably forces the pending add to execute before the timer
            # starts -- confirm against tinygrad's evaluation semantics.
            dat.realize()
            st = time.perf_counter()
            val_tinygrad = conv(dat).numpy().sum()
            ets_tinygrad.append((time.perf_counter() - st) * 1000)

        et_torch = np.median(ets_torch)
        et_tinygrad = np.median(ets_tinygrad)
        print(f"bs:{bs:3d} chans:{in_chans:3d} -> {out_chans:3d} {et_torch:7.2f} ms in torch({device}), {et_tinygrad:7.2f} ms in tinygrad, {et_tinygrad/et_torch:7.2f}x slower", val_torch, val_tinygrad)

        relative_error = abs((val_tinygrad - val_torch) / val_torch)
        # assertLess (unlike a bare assert) survives `python -O` and reports both values.
        self.assertLess(
            relative_error, 0.01,
            f"relative error {relative_error:.4g} between torch ({val_torch}) "
            f"and tinygrad ({val_tinygrad})",
        )
# Run the benchmark tests when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()