diff --git a/test/test_ops.py b/test/test_ops.py index 277f1bca..742f848e 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -839,7 +839,7 @@ class TestOps(unittest.TestCase): helper_test_op(None, lambda x,y: x.matmul(y), lambda x,y: x@y, vals=[np.eye(8).astype(np.float32), np.eye(8).astype(np.float32)]) @unittest.skipIf(CI and Device.DEFAULT in ["NV", "LLVM", "GPU", "CUDA"], "not supported on these in CI") def test_gemm_fp16(self): - helper_test_op([(64,64), (64,64)], lambda x,y: x.half().matmul(y.half())) + helper_test_op([(64,64), (64,64)], lambda x,y: x.half().matmul(y.half()), atol=5e-3, rtol=5e-3) def test_gemm(self): helper_test_op([(64,64), (64,64)], lambda x,y: x.matmul(y)) def test_big_gemm(self):