increase atol and rtol for test_gemm_fp16 (#5866)

* increase atol and rtol for test_gemm_fp16

made it pass with NOOPT, which has larger accumulated error

* revert that
chenyu authored 2024-08-01 19:09:58 -04:00, committed by GitHub
parent b03b8e18c2
commit b392b8edc3
1 changed file with 1 addition and 1 deletion


@@ -839,7 +839,7 @@ class TestOps(unittest.TestCase):
     helper_test_op(None, lambda x,y: x.matmul(y), lambda x,y: x@y, vals=[np.eye(8).astype(np.float32), np.eye(8).astype(np.float32)])
   @unittest.skipIf(CI and Device.DEFAULT in ["NV", "LLVM", "GPU", "CUDA"], "not supported on these in CI")
   def test_gemm_fp16(self):
-    helper_test_op([(64,64), (64,64)], lambda x,y: x.half().matmul(y.half()))
+    helper_test_op([(64,64), (64,64)], lambda x,y: x.half().matmul(y.half()), atol=5e-3, rtol=5e-3)
   def test_gemm(self):
     helper_test_op([(64,64), (64,64)], lambda x,y: x.matmul(y))
   def test_big_gemm(self):
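
For context on why the tolerances were loosened: each element of a 64x64 fp16 matmul sums 64 half-precision products, and under NOOPT the accumulation runs in a naive order, so rounding error builds up well past fp32-scale defaults. The sketch below illustrates that effect in plain NumPy; it is not tinygrad's helper_test_op, and the allclose-style check, the explicit fp16 inner loop, and the baseline tolerances (atol=1e-6, rtol=1e-3) are assumptions chosen for demonstration.

import numpy as np

# Standalone sketch (assumed setup, not tinygrad code): round products and
# running sums to fp16 at every step, mimicking a naive half-precision GEMM.
rng = np.random.default_rng(0)
a = rng.standard_normal((64, 64)).astype(np.float16)
b = rng.standard_normal((64, 64)).astype(np.float16)

ref = a.astype(np.float32) @ b.astype(np.float32)  # higher-precision reference
out = np.zeros((64, 64), dtype=np.float16)
for k in range(64):
  # rank-1 (outer product) update, rounded to fp16 after multiply and add
  out = (out + a[:, k:k+1] * b[k:k+1, :]).astype(np.float16)

# allclose-style criterion: |out - ref| <= atol + rtol * |ref|
err = np.abs(out.astype(np.float32) - ref)
for atol, rtol in [(1e-6, 1e-3), (5e-3, 5e-3)]:
  bad = int((err > atol + rtol * np.abs(ref)).sum())
  print(f"atol={atol} rtol={rtol}: {bad}/{err.size} elements out of tolerance")

Run as-is, the tighter baseline typically flags a number of out-of-tolerance elements, while 5e-3/5e-3 flags far fewer, which matches the intent of the one-line change above.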