touch up pytorch speed tests

This commit is contained in:
George Hotz 2023-01-25 18:11:26 -08:00
parent 8db345d846
commit 44e96c58b4
1 changed files with 8 additions and 8 deletions

View File

@ -85,14 +85,6 @@ class TestSpeed(unittest.TestCase):
def f(a, b): return a-b
helper_test_generic_square('sub', 4096, f, f)
def test_constant_sub(self):
def f(a, b): return 1.0-a
helper_test_generic_square('sub', 4096, f, f)
def test_constant_zero_sub(self):
def f(a, b): return 0.0-a
helper_test_generic_square('sub', 4096, f, f)
def test_pow(self):
def f(a, b): return a.pow(b)
helper_test_generic_square('pow', 2048, f, f)
@ -151,6 +143,14 @@ class TestSpeed(unittest.TestCase):
def f(a, b): return a + b
helper_test_generic_square('add', N, f, f)
def test_add_constant(self):
def f(a, b): return a+2.0
helper_test_generic_square('add_constant', 4096, f, f)
def test_add_constant_zero(self):
def f(a, b): return a+0.0
helper_test_generic_square('add_constant_zero', 4096, f, f)
def test_add_sq(self):
def f(a, b): return a*a + b*b
helper_test_generic_square('add_sq', 4096, f, f)