a couple new tests

This commit is contained in:
George Hotz 2023-06-13 12:36:05 -07:00
parent ba4eadb04c
commit 80e665bddb
1 changed files with 7 additions and 0 deletions

View File

@ -146,7 +146,10 @@ class TestOps(unittest.TestCase):
helper_test_op([(), ()], torch.minimum, Tensor.minimum)
def test_add(self):
helper_test_op([(45,65), (45,65)], lambda x,y: x+y, Tensor.add)
def test_add_number(self):
helper_test_op([(), ()], lambda x,y: x+y, Tensor.add)
def test_add3(self):
helper_test_op([(45,65), (45,65), (45,65)], lambda x,y,z: x+y+z)
def test_add_simple(self):
helper_test_op([(256), (256)], lambda x,y: x+y, Tensor.add, forward_only=True)
def test_broadcasted_add(self):
@ -185,6 +188,7 @@ class TestOps(unittest.TestCase):
helper_test_op([()], lambda x: x/2, lambda x: x/2)
helper_test_op([()], lambda x: 2/x, lambda x: 2/x)
def test_pow(self):
# TODO: why is a=0 for these tests?
helper_test_op([(45,65)], lambda x: x**2, lambda x: Tensor.pow(x,2), a=0)
helper_test_op([(45,65)], lambda x: x**3, lambda x: Tensor.pow(x,3), a=0)
helper_test_op([(45,65)], lambda x: x**-2, lambda x: Tensor.pow(x,-2), a=0)
@ -193,6 +197,7 @@ class TestOps(unittest.TestCase):
helper_test_op([()], lambda x: x**-2, lambda x: Tensor.pow(x,-2), a=0)
def test_pow_const(self):
helper_test_op([(45,65)], lambda x: x**1.0, lambda x: x**1.0)
helper_test_op([(45,65)], lambda x: x**-1.0, lambda x: x**-1.0)
helper_test_op([(45,65)], lambda x: 1.0**x, lambda x: 1.0**x)
helper_test_op([(45,65)], lambda x: x**2.0, lambda x: x**2.0)
helper_test_op([(45,65)], lambda x: 2.0**x, lambda x: 2.0**x)
@ -303,6 +308,8 @@ class TestOps(unittest.TestCase):
helper_test_op(None, lambda x: x.sum(), Tensor.sum, vals=[[1.,1.]])
def test_sum_full(self):
helper_test_op([(16384)], lambda x: x.sum(), lambda x: x.sum())
def test_sum_small_full(self):
helper_test_op([(45,3)], lambda x: x.sum(), Tensor.sum)
def test_sum_relu(self):
helper_test_op([(3,4,5)], lambda x: x.relu().sum().relu(), lambda x: x.relu().sum().relu())
def test_sum(self):