real matvec test

This commit is contained in:
George Hotz 2023-08-31 17:27:25 -07:00
parent 453e437598
commit e3a062ad17
1 changed files with 4 additions and 1 deletions

View File

@ -1128,9 +1128,12 @@ class TestOps(unittest.TestCase):
def test_clip(self):
helper_test_op([(45,65)], lambda x: x.clip(-2.3, 1.2), lambda x: x.clip(-2.3, 1.2))
def test_matvec(self):
def test_matvecmat(self):
helper_test_op([(1,128), (128,128), (128,128)], lambda x,y,z: (x@y).relu()@z, atol=1e-4)
def test_matvec(self):
helper_test_op([(1,128), (128,128)], lambda x,y: (x@y).relu(), atol=1e-4)
# this was the failure in llama early realizing freqs_cis
def test_double_slice(self):
helper_test_op([(4,4)], lambda x: x[:, 1:2][1:2])