From 44e96c58b48b880ad02fde52eadf04bc389c3484 Mon Sep 17 00:00:00 2001 From: George Hotz Date: Wed, 25 Jan 2023 18:11:26 -0800 Subject: [PATCH] touch up pytorch speed tests --- test/test_speed_v_torch.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/test/test_speed_v_torch.py b/test/test_speed_v_torch.py index a43f08d9..d0ed90a6 100644 --- a/test/test_speed_v_torch.py +++ b/test/test_speed_v_torch.py @@ -85,14 +85,6 @@ class TestSpeed(unittest.TestCase): def f(a, b): return a-b helper_test_generic_square('sub', 4096, f, f) - def test_constant_sub(self): - def f(a, b): return 1.0-a - helper_test_generic_square('sub', 4096, f, f) - - def test_constant_zero_sub(self): - def f(a, b): return 0.0-a - helper_test_generic_square('sub', 4096, f, f) - def test_pow(self): def f(a, b): return a.pow(b) helper_test_generic_square('pow', 2048, f, f) @@ -151,6 +143,14 @@ class TestSpeed(unittest.TestCase): def f(a, b): return a + b helper_test_generic_square('add', N, f, f) + def test_add_constant(self): + def f(a, b): return a+2.0 + helper_test_generic_square('add_constant', 4096, f, f) + + def test_add_constant_zero(self): + def f(a, b): return a+0.0 + helper_test_generic_square('add_constant_zero', 4096, f, f) + def test_add_sq(self): def f(a, b): return a*a + b*b helper_test_generic_square('add_sq', 4096, f, f)