diff --git a/extra/gemm/simple_matmul.py b/extra/gemm/simple_matmul.py index 825089a9..7ca3f112 100644 --- a/extra/gemm/simple_matmul.py +++ b/extra/gemm/simple_matmul.py @@ -12,4 +12,4 @@ for i in range(CNT): c = (a.reshape(N, 1, N) * b.permute(1,0).reshape(1, N, N)).float().sum(axis=2).realize() if getenv("ACCUM_FP32") else (a @ b).realize() comp = a.numpy().astype(np.float32) @ b.numpy().astype(np.float32) nc = c.numpy() -np.testing.assert_allclose(nc, comp, atol=1e-4, rtol=1e-2) +np.testing.assert_allclose(nc, comp, atol=1e-4, rtol=3e-2)