tinygrad/test/test_speed_v_torch.py

267 lines
10 KiB
Python

import os
os.environ["NVIDIA_TF32_OVERRIDE"] = "0"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"
import unittest
import torch
torch.set_num_threads(1)
import time
import numpy as np
np.set_printoptions(linewidth=160)
from functools import partial
from tinygrad.lazy import Device
from tinygrad.ops import GlobalCounters
from tinygrad.tensor import Tensor
from tinygrad.nn import Conv2d
from tinygrad.helpers import colored, getenv, DEBUG, CI
from tinygrad.jit import TinyJit
import pytest
# Module-wide pytest marks. The exclude_*/webgpu markers are project-defined;
# how they are interpreted is not visible in this file.
pytestmark = [pytest.mark.exclude_cuda, pytest.mark.exclude_gpu, pytest.mark.exclude_clang, pytest.mark.webgpu]

# Input-channel counts swept by TestSpeed.test_conv2d, read through getenv as a
# comma-separated list (default "4,16,64").
IN_CHANS = [int(chans) for chans in getenv("IN_CHANS", "4,16,64").split(",")]

# Choose the torch device (MPS is checked first, then TORCHCUDA, otherwise CPU) together
# with the matching `sync` hook that helper_test_speed calls around the timed region.
if getenv("MPS", 0):
  torch_device = torch.device('mps')
  import torch.mps
  def sync():
    torch.mps.synchronize()
elif getenv("TORCHCUDA", 0):
  torch_device = torch.device('cuda')
  import torch.cuda
  def sync():
    torch.cuda.synchronize()
else:
  torch_device = torch.device('cpu')
  def sync():
    # no explicit synchronization is performed for the CPU device
    pass
def colorize_float(x):
  """Render the ratio ``x`` as ``'{x:7.2f}x'`` wrapped by ``colored``.

  Green when ``x < 0.75``, red when ``x > 1.15``, yellow for everything else.
  """
  text = f"{x:7.2f}x"
  if x < 0.75: color = 'green'
  elif x > 1.15: color = 'red'
  else: color = 'yellow'
  return colored(text, color)
# Counters captured by helper_test_speed from the most recent run that reported
# a non-zero GlobalCounters.global_ops; read by helper_test_generic for its report.
save_ops = 0
save_mem = 0

# Number of times helper_test_speed invokes the benchmarked function.
CNT = 8
def helper_test_speed(f1, *args):
  """Benchmark ``f1(*args)`` and return ``(result, best_ms)``.

  The call is repeated ``CNT`` times. Before every run each non-None argument is
  replaced by ``arg + 1`` (tinygrad Tensors are also realized), every argument is
  pulled to numpy so it has actually been computed, and a 32MB scratch array is
  touched. The timed region is the call itself plus a device synchronize. The
  timing of the first run is not recorded.

  Side effect: whenever a run leaves ``GlobalCounters.global_ops`` non-zero, the
  module-level ``save_ops``/``save_mem`` are overwritten with that run's counters.

  Returns the last result as ``ret.cpu().numpy()`` together with the minimum
  recorded time in milliseconds.
  """
  global save_ops, save_mem
  timings_ms = []
  ret = None
  # 2048*2048 float64 values = 32MB
  cache_defeat = np.zeros((2048,2048))
  torch_on_cpu = str(torch_device) == "cpu"

  def bump(x):
    # shift the value by one; tinygrad Tensors are realized right away
    if x is None: return None
    shifted = x+1
    return shifted.realize() if isinstance(x, Tensor) else shifted

  for run in range(CNT):
    # release the previous result before producing the next one
    del ret

    # operation cache defeats
    args = [bump(x) for x in args]

    # force syncing: converting every argument to numpy requires it to be computed
    for x in args:
      if x is None: continue
      if isinstance(x, Tensor) or torch_on_cpu: x.numpy()
      else: x.cpu().numpy()

    # clear 32MB global memory cache (CPU and global memory only)
    cache_defeat += 1

    # manual pre sync
    if not isinstance(args[0], Tensor): sync()
    else: Device[args[0].device].synchronize()

    GlobalCounters.global_ops = 0
    GlobalCounters.global_mem = 0

    started = time.perf_counter()
    ret = f1(*args)
    if not isinstance(ret, Tensor): sync()
    else: Device[ret.device].synchronize()
    elapsed_ms = (time.perf_counter() - started) * 1000

    # the timing of run 0 is dropped
    if run > 0: timings_ms.append(elapsed_ms)
    if GlobalCounters.global_ops:
      save_ops, save_mem = GlobalCounters.global_ops, GlobalCounters.global_mem
  return ret.cpu().numpy(), np.min(timings_ms)
def helper_test_generic_square(name, N, f1, f2, onearg=False):
  """Compare torch ``f1`` with tinygrad ``f2`` on seeded random ``N x N`` operands.

  Both callables take ``(a, b)``. The operands are drawn from ``torch.rand - 0.5``
  after ``torch.manual_seed(0)`` and copied into tinygrad Tensors, so both sides
  see the same values. With ``onearg=True`` no second operand is generated and
  ``b`` is ``None`` on both sides. ``f2`` is wrapped in ``TinyJit`` and its
  result is realized.
  """
  torch.manual_seed(0)

  def rand_square():
    return (torch.rand(N, N) - 0.5).to(torch_device)

  torch_a = rand_square()
  torch_b = None if onearg else rand_square()
  tiny_a = Tensor(torch_a.cpu().numpy())
  tiny_b = None if onearg else Tensor(torch_b.cpu().numpy())

  @TinyJit
  def jitted_f2(a, b):
    return f2(a, b).realize()

  label = f"{name:30s} {N:5d}x{N:5d}"
  helper_test_generic(label, f1, (torch_a, torch_b), jitted_f2, (tiny_a, tiny_b))
# Module-level placeholder: no code visible in this file assigns it a new value or reads it.
prefix = None
def helper_test_generic(name, f1, f1_args, f2, f2_args):
  """Time torch ``f1(*f1_args)`` and tinygrad ``f2(*f2_args)``, report, and compare.

  Both callables are measured with ``helper_test_speed``. Unless ``CI`` is set,
  one summary line is printed with both timings, the derived throughput figures
  and the colored tinygrad/torch time ratio. Finally the two results are checked
  with ``np.testing.assert_allclose(atol=1e-4, rtol=1e-3)``, which raises
  ``AssertionError`` on mismatch.
  """
  # only the torch measurement runs inside the no_grad context
  with torch.no_grad():
    val_torch, et_torch = helper_test_speed(f1, *f1_args)
  val_tinygrad, et_tinygrad = helper_test_speed(f2, *f2_args)

  if not CI:
    # save_ops/save_mem come from the last run that reported ops; 1e-6 scales them to
    # MOPS / MB, and dividing those by milliseconds yields GFLOPS / GB/s
    flops, mem = save_ops*1e-6, save_mem*1e-6
    desc = "faster" if et_tinygrad < et_torch else "slower"
    print(f"\r{name:42s} {et_torch:7.2f} ms ({flops/et_torch:8.2f} GFLOPS {mem/et_torch:8.2f} GB/s) in torch, {et_tinygrad:7.2f} ms ({flops/et_tinygrad:8.2f} GFLOPS {mem/et_tinygrad:8.2f} GB/s) in tinygrad, {colorize_float(et_tinygrad/et_torch)} {desc} {flops:10.2f} MOPS {mem:8.2f} MB")

  np.testing.assert_allclose(val_tinygrad, val_torch, atol=1e-4, rtol=1e-3)
def helper_test_conv(bs, in_chans, out_chans, kernel_size, img_size_y, img_size_x):
  """Benchmark a bias-free 2D convolution in torch against tinygrad.

  The input is a seeded ``torch.rand`` tensor of shape
  ``(bs, in_chans, img_size_y, img_size_x)``. The tinygrad ``Conv2d`` gets a copy
  of the torch layer's weights so both sides compute the same function; the
  tinygrad forward pass is wrapped in ``TinyJit`` and realized.
  """
  torch.manual_seed(0)
  # keep this order: the input is drawn before the torch layer is constructed
  torch_dat = torch.rand(bs, in_chans, img_size_y, img_size_x).to(torch_device)
  torch_conv = torch.nn.Conv2d(in_chans, out_chans, kernel_size, bias=None).to(torch_device)

  tiny_dat = Tensor(torch_dat.cpu().numpy())
  tiny_conv = Conv2d(in_chans, out_chans, kernel_size, bias=None)
  tiny_conv.weight = Tensor(torch_conv.weight.detach().cpu().numpy())

  @TinyJit
  def run_tiny(tiny_dat):
    return tiny_conv(tiny_dat).realize()

  label = f"conv bs:{bs:3d} chans:{in_chans:3d} -> {out_chans:3d} k:{kernel_size}"
  # the torch module is itself the callable for the torch side
  helper_test_generic(label, torch_conv, (torch_dat,), run_tiny, (tiny_dat,))
@unittest.skipIf(getenv("BIG") != 1, "no big tests")
class TestBigSpeed(unittest.TestCase):
  """Large-size torch-vs-tinygrad benchmarks; skipped unless getenv("BIG") equals 1."""

  def test_add(self):
    def add(a, b):
      return a + b
    helper_test_generic_square('add', 8192, add, add)

  def test_exp(self):
    # the second operand is None when onearg=True and is ignored
    def exp(a, _):
      return a.exp()
    helper_test_generic_square('exp', 8192, exp, exp, onearg=True)

  def test_gemm_2048(self):
    def matmul(a, b):
      return a @ b
    helper_test_generic_square('gemm', 2048, matmul, matmul)

  def test_gemm_4096(self):
    def matmul(a, b):
      return a @ b
    helper_test_generic_square('gemm', 4096, matmul, matmul)

  def test_large_conv_1x1(self):
    helper_test_conv(
      bs=32, in_chans=128, out_chans=128, kernel_size=1,
      img_size_y=128, img_size_x=128,
    )

  def test_large_conv_3x3(self):
    helper_test_conv(
      bs=4, in_chans=128, out_chans=128, kernel_size=3,
      img_size_y=130, img_size_x=130,
    )

  def test_large_conv_5x5(self):
    helper_test_conv(
      bs=4, in_chans=128, out_chans=128, kernel_size=5,
      img_size_y=130, img_size_x=130,
    )
@unittest.skipIf((getenv("BIG") == 1 or Device.DEFAULT == "WEBGPU"), "only big tests")
class TestSpeed(unittest.TestCase):
  """Regular-size torch-vs-tinygrad benchmarks.

  Skipped when getenv("BIG") equals 1 or when Device.DEFAULT is "WEBGPU". For the
  single-operand tests (onearg=True) the second argument of the benchmarked
  function is None and is ignored.
  """

  def test_sub(self):
    def sub(a, b):
      return a - b
    helper_test_generic_square('sub', 4096, sub, sub)

  def test_pow(self):
    def power(a, b):
      return a.pow(b)
    helper_test_generic_square('pow', 2048, power, power)

  def test_sum(self):
    def total(a, _):
      return a.sum()
    for size in (2048, 4096):
      helper_test_generic_square('sum', size, total, total, onearg=True)

  def test_partial_sum(self):
    R = 256
    # view the 4096x4096 input as (4096//R) rows of 4096*R elements and reduce each row
    rows, cols = 4096//R, 4096*R
    def row_sums(a, _):
      return a.reshape(rows, cols).sum(axis=1)
    helper_test_generic_square('partial_sum', 4096, row_sums, row_sums, onearg=True)

  @unittest.skip("not really used in models")
  def test_cumsum(self):
    def cumsum_rows(a, _):
      return a.cumsum(axis=0)
    def cumsum_cols(a, _):
      return a.cumsum(axis=1)
    helper_test_generic_square('cumsum_0', 256, cumsum_rows, cumsum_rows, onearg=True)
    helper_test_generic_square('cumsum_1', 256, cumsum_cols, cumsum_cols, onearg=True)

  def test_array_packing(self):
    N = 2048
    def pack(a, _):
      # split every row into N//32 groups of 32 and move the group axis to the front
      grouped = a.reshape(N, N // 32, 32)
      return grouped.permute(1,0,2).contiguous()
    helper_test_generic_square('array_packing', N, pack, pack, onearg=True)

  def test_permute(self):
    # at N=4096 this is a 64MB tensor, M1 L1 cache is 128kB
    # to fit easily in L1, rotations should be 128x128 chunks. 128x128 is also the AMX size
    def transpose(a, _):
      return a.permute(1,0).contiguous()
    for N in (1024, 4096):
      helper_test_generic_square('permute', N, transpose, transpose, onearg=True)

  def test_double_permute(self):
    N = 64
    torch.manual_seed(0)
    torch_a = (torch.rand(N, N, N, N) - 0.5).to(torch_device)
    tiny_a = Tensor(torch_a.cpu().numpy())

    # same expression for torch and tinygrad: swap the two leading and the two trailing axes
    def swap_pairs(a):
      return a.permute(1,0,3,2).contiguous()

    @TinyJit
    def jitted(a):
      return swap_pairs(a).realize()

    helper_test_generic(f"double_permute {tiny_a.shape}", swap_pairs, (torch_a,), jitted, (tiny_a,))

  def test_neg(self):
    def negate(a, _):
      return -a
    helper_test_generic_square('neg', 4096, negate, negate, onearg=True)

  def test_exp(self):
    def exp(a, _):
      return a.exp()
    helper_test_generic_square('exp', 2048, exp, exp, onearg=True)

  def test_relu(self):
    def relu(a, _):
      return a.relu()
    helper_test_generic_square('relu', 4096, relu, relu, onearg=True)

  def test_max(self):
    def maximum(a, _):
      return a.max()
    helper_test_generic_square('max', 4096, maximum, maximum, onearg=True)

  def test_mul_sum(self):
    def dot_all(a, b):
      return (a*b).sum()
    helper_test_generic_square('mul_sum', 4096, dot_all, dot_all)

  def test_add(self):
    def add(a, b):
      return a + b
    for N in (1, 1024, 4096):
      helper_test_generic_square('add', N, add, add)

  def test_add_constant(self):
    def add_two(a, _):
      return a+2.0
    helper_test_generic_square('add_constant', 4096, add_two, add_two, onearg=True)

  def test_add_sq(self):
    def sum_of_squares(a, b):
      return a*a + b*b
    helper_test_generic_square('add_sq', 4096, sum_of_squares, sum_of_squares)

  def test_gemm(self):
    def matmul(a, b):
      return a @ b
    helper_test_generic_square('gemm', 1024, matmul, matmul)

  def test_gemm_small(self):
    def matmul(a, b):
      return a @ b
    helper_test_generic_square('gemm', 256, matmul, matmul)

  def test_gemm_unrolled(self):
    N = 512
    def reference(a, b):
      return a@b.T
    def unrolled(a, b):
      # broadcast both operands to (N, N, N), multiply, and reduce the last axis
      lhs = a.reshape(N, 1, N).expand(N, N, N)
      rhs = b.reshape(1, N, N).expand(N, N, N)
      return (lhs * rhs).sum(axis=2)
    helper_test_generic_square('gemm_unrolled', N, reference, unrolled)

  def test_gemm_unrolled_permute_l(self):
    N = 512
    def reference(a, b):
      return a.T@b.T
    def unrolled(a, b):
      # as test_gemm_unrolled, with the left operand transposed first
      lhs = a.permute(1,0).reshape(N, 1, N).expand(N, N, N)
      rhs = b.reshape(1, N, N).expand(N, N, N)
      return (lhs * rhs).sum(axis=2)
    helper_test_generic_square('gemm_unrolled_permute_l', N, reference, unrolled)

  def test_gemm_unrolled_permute_r(self):
    N = 512
    def reference(a, b):
      return a@b
    def unrolled(a, b):
      # as test_gemm_unrolled, with the right operand transposed first
      lhs = a.reshape(N, 1, N).expand(N, N, N)
      rhs = b.permute(1,0).reshape(1, N, N).expand(N, N, N)
      return (lhs * rhs).sum(axis=2)
    helper_test_generic_square('gemm_unrolled_permute_r', N, reference, unrolled)

  def test_gemm_unrolled_permute_lr(self):
    N = 512
    def reference(a, b):
      return a.T@b
    def unrolled(a, b):
      # as test_gemm_unrolled, with both operands transposed first
      lhs = a.permute(1,0).reshape(N, 1, N).expand(N, N, N)
      rhs = b.permute(1,0).reshape(1, N, N).expand(N, N, N)
      return (lhs * rhs).sum(axis=2)
    helper_test_generic_square('gemm_unrolled_permute_lr', N, reference, unrolled)

  def test_openpilot_conv2d(self):
    bs, in_chans, out_chans = 1,12,32
    torch.manual_seed(0)
    # keep this order: the input is drawn before the torch layer is constructed
    torch_dat = torch.rand(bs, 64, 128, 12).to(torch_device)
    torch_conv = torch.nn.Conv2d(in_chans, out_chans, 3, bias=None, padding=1).to(torch_device)

    tiny_dat = Tensor(torch_dat.cpu().numpy())
    tiny_conv = Conv2d(in_chans, out_chans, 3, bias=None, padding=1)
    tiny_conv.weight = Tensor(torch_conv.weight.detach().cpu().numpy())

    # the input is stored with channels last; both sides move them to axis 1 inside the timed call
    def run_torch(torch_dat):
      return torch_conv(torch_dat.permute(0,3,1,2))

    @TinyJit
    def run_tiny(tiny_dat):
      return tiny_conv(tiny_dat.permute(0,3,1,2)).realize()

    label = f"conv bs:{bs:3d} chans:{in_chans:3d} -> {out_chans:3d} k:3"
    helper_test_generic(label, run_torch, (torch_dat,), run_tiny, (tiny_dat,))

  def test_conv2d(self):
    # batch size and output channels are fixed; only the input channel count is swept
    bs, out_chans = 32, 32
    for in_chans in IN_CHANS:
      helper_test_conv(bs, in_chans, out_chans, 3, 34, 34)
# Run the benchmarks through unittest when this file is executed as a script.
if __name__ == '__main__': unittest.main()