From d25046e66a17220b5d21ebffb81ddd39abb31207 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Tue, 22 Aug 2023 17:33:58 -0700 Subject: [PATCH] matvec tests (#1634) * matvec tests * f16 * f16 is broken --- test/test_speed_v_torch.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/test/test_speed_v_torch.py b/test/test_speed_v_torch.py index b1dce972..1b56664a 100644 --- a/test/test_speed_v_torch.py +++ b/test/test_speed_v_torch.py @@ -86,6 +86,17 @@ def helper_test_generic_square(name, N, f1, f2, onearg=False): helper_test_generic(f"{name:30s} {N:5d}x{N:5d}", f1, (torch_a, torch_b), TinyJit(lambda a,b:f2(a,b).realize()), (tiny_a, tiny_b)) +def helper_test_matvec(name, N, M): + torch.manual_seed(0) + dt = torch.float32 + torch_a = (torch.rand(N, dtype=dt) - 0.5).to(torch_device) + torch_b = (torch.rand(N, M, dtype=dt) - 0.5).to(torch_device) + + tiny_a = Tensor(torch_a.cpu().numpy()) + tiny_b = Tensor(torch_b.cpu().numpy()) + + helper_test_generic(f"{name:30s} {N:5d}x{M:5d}", lambda a,b: a@b, (torch_a, torch_b), TinyJit(lambda a,b:(a@b).realize()), (tiny_a, tiny_b)) + prefix = None def helper_test_generic(name, f1, f1_args, f2, f2_args): global prefix @@ -129,6 +140,8 @@ class TestBigSpeed(unittest.TestCase): def test_large_conv_1x1(self): helper_test_conv(bs=32, in_chans=128, out_chans=128, kernel_size=1, img_size_y=128, img_size_x=128) def test_large_conv_3x3(self): helper_test_conv(bs=4, in_chans=128, out_chans=128, kernel_size=3, img_size_y=130, img_size_x=130) def test_large_conv_5x5(self): helper_test_conv(bs=4, in_chans=128, out_chans=128, kernel_size=5, img_size_y=130, img_size_x=130) + def test_matvec_4096_16384(self): helper_test_matvec('matvec_4096_16384', 4096, 16384) + def test_matvec_16384_4096(self): helper_test_matvec('matvec_16384_4096', 16384, 4096) @unittest.skipIf(getenv("BIG") == 1, "only big tests") class TestSpeed(unittest.TestCase): @@ -247,6 +260,11 @@ class TestSpeed(unittest.TestCase): def f2(a, b): return (a.permute(1,0).reshape(N, 1, N).expand(N, N, N) * b.permute(1,0).reshape(1, N, N).expand(N, N, N)).sum(axis=2) helper_test_generic_square('gemm_unrolled_permute_lr', N, f1, f2) + def test_matvec_1024_1024(self): helper_test_matvec('matvec_1024_1024', 1024, 1024) + def test_matvec_1024_4096(self): helper_test_matvec('matvec_1024_4096', 1024, 4096) + def test_matvec_4096_1024(self): helper_test_matvec('matvec_4096_1024', 4096, 1024) + def test_matvec_4096_4096(self): helper_test_matvec('matvec_4096_4096', 4096, 4096) + def test_openpilot_conv2d(self): bs, in_chans, out_chans = 1,12,32 torch.manual_seed(0)