diff --git a/test/test_jit.py b/test/test_jit.py index 6c1d929d..269aa477 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -17,7 +17,7 @@ class TestJit(unittest.TestCase): a = Tensor.randn(10, 10) b = Tensor.randn(10, 10) c = add(a, b) - np.testing.assert_equal(c.numpy(), a.numpy()+b.numpy()) + np.testing.assert_allclose(c.numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5) assert len(add.jit_cache) == 1 def test_jit_multiple_outputs(self): @@ -27,9 +27,9 @@ class TestJit(unittest.TestCase): a = Tensor.randn(10, 10) b = Tensor.randn(10, 10) c, d, e = f(a, b) - np.testing.assert_equal(c.numpy(), a.numpy()+b.numpy()) - np.testing.assert_equal(d.numpy(), a.numpy()-b.numpy()) - np.testing.assert_equal(e.numpy(), a.numpy()*b.numpy()) + np.testing.assert_allclose(c.numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5) + np.testing.assert_allclose(d.numpy(), a.numpy()-b.numpy(), atol=1e-4, rtol=1e-5) + np.testing.assert_allclose(e.numpy(), a.numpy()*b.numpy(), atol=1e-4, rtol=1e-5) assert len(f.jit_cache) == 3 def test_nothing_jitted(self): @@ -67,7 +67,7 @@ class TestJit(unittest.TestCase): a = Tensor.randn(10, 10) b = Tensor.randn(10, 10) c = add_kwargs(first=a, second=b) - np.testing.assert_equal(c.numpy(), a.numpy()+b.numpy()) + np.testing.assert_allclose(c.numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5) assert len(add_kwargs.jit_cache) == 1 def test_array_jit(self): @@ -80,9 +80,9 @@ class TestJit(unittest.TestCase): c = add_array(a, [b]) if i >= 2: # should fail once jitted since jit can't handle arrays - np.testing.assert_equal(np.any(np.not_equal(c.numpy(),a.numpy()+b.numpy())), True) + np.testing.assert_allclose(np.any(np.not_equal(c.numpy(),a.numpy()+b.numpy())), True, atol=1e-4, rtol=1e-5) else: - np.testing.assert_equal(c.numpy(), a.numpy()+b.numpy()) + np.testing.assert_allclose(c.numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5) assert len(add_array.jit_cache) == 1 def test_method_jit(self): @@ -96,7 +96,7 @@ class TestJit(unittest.TestCase): for _ in range(5): b = Tensor.randn(10, 10) c = fun(b) - np.testing.assert_equal(c.numpy(), fun.a.numpy()+b.numpy()) + np.testing.assert_allclose(c.numpy(), fun.a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5) assert len(fun.__call__.func.__self__.jit_cache) == 1 def test_jit_size1_input(self): @@ -104,7 +104,7 @@ class TestJit(unittest.TestCase): def f(a, b): return (a+b).realize() a = Tensor([1, 2, 3]) for i in range(5): - np.testing.assert_equal(f(a, Tensor([i])).cpu().numpy(), (a+i).cpu().numpy()) + np.testing.assert_allclose(f(a, Tensor([i])).cpu().numpy(), (a+i).cpu().numpy(), atol=1e-4, rtol=1e-5) assert len(f.jit_cache) == 1 def test_jit_output_non_tensor_fail(self): @@ -120,7 +120,7 @@ class TestJit(unittest.TestCase): output2.append(o2) expect1.append(a.numpy().copy()+b.numpy().copy()) expect2.append(i) - np.testing.assert_equal(output1, expect1) + np.testing.assert_allclose(output1, expect1, atol=1e-4, rtol=1e-5) # the jit only works with Tensor outputs assert output2 != expect2 assert len(f.jit_cache) == 1