use allclose instead of equals in test_jit (#1504)

Closes #1503
This commit is contained in:
nimlgen 2023-08-09 08:22:17 +03:00 committed by GitHub
parent 827d13e64e
commit dabfd7569a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 10 additions and 10 deletions

View File

@ -17,7 +17,7 @@ class TestJit(unittest.TestCase):
a = Tensor.randn(10, 10)
b = Tensor.randn(10, 10)
c = add(a, b)
np.testing.assert_equal(c.numpy(), a.numpy()+b.numpy())
np.testing.assert_allclose(c.numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
assert len(add.jit_cache) == 1
def test_jit_multiple_outputs(self):
@ -27,9 +27,9 @@ class TestJit(unittest.TestCase):
a = Tensor.randn(10, 10)
b = Tensor.randn(10, 10)
c, d, e = f(a, b)
np.testing.assert_equal(c.numpy(), a.numpy()+b.numpy())
np.testing.assert_equal(d.numpy(), a.numpy()-b.numpy())
np.testing.assert_equal(e.numpy(), a.numpy()*b.numpy())
np.testing.assert_allclose(c.numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
np.testing.assert_allclose(d.numpy(), a.numpy()-b.numpy(), atol=1e-4, rtol=1e-5)
np.testing.assert_allclose(e.numpy(), a.numpy()*b.numpy(), atol=1e-4, rtol=1e-5)
assert len(f.jit_cache) == 3
def test_nothing_jitted(self):
@ -67,7 +67,7 @@ class TestJit(unittest.TestCase):
a = Tensor.randn(10, 10)
b = Tensor.randn(10, 10)
c = add_kwargs(first=a, second=b)
np.testing.assert_equal(c.numpy(), a.numpy()+b.numpy())
np.testing.assert_allclose(c.numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
assert len(add_kwargs.jit_cache) == 1
def test_array_jit(self):
@ -80,9 +80,9 @@ class TestJit(unittest.TestCase):
c = add_array(a, [b])
if i >= 2:
# should fail once jitted since jit can't handle arrays
np.testing.assert_equal(np.any(np.not_equal(c.numpy(),a.numpy()+b.numpy())), True)
np.testing.assert_allclose(np.any(np.not_equal(c.numpy(),a.numpy()+b.numpy())), True, atol=1e-4, rtol=1e-5)
else:
np.testing.assert_equal(c.numpy(), a.numpy()+b.numpy())
np.testing.assert_allclose(c.numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
assert len(add_array.jit_cache) == 1
def test_method_jit(self):
@ -96,7 +96,7 @@ class TestJit(unittest.TestCase):
for _ in range(5):
b = Tensor.randn(10, 10)
c = fun(b)
np.testing.assert_equal(c.numpy(), fun.a.numpy()+b.numpy())
np.testing.assert_allclose(c.numpy(), fun.a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
assert len(fun.__call__.func.__self__.jit_cache) == 1
def test_jit_size1_input(self):
@ -104,7 +104,7 @@ class TestJit(unittest.TestCase):
def f(a, b): return (a+b).realize()
a = Tensor([1, 2, 3])
for i in range(5):
np.testing.assert_equal(f(a, Tensor([i])).cpu().numpy(), (a+i).cpu().numpy())
np.testing.assert_allclose(f(a, Tensor([i])).cpu().numpy(), (a+i).cpu().numpy(), atol=1e-4, rtol=1e-5)
assert len(f.jit_cache) == 1
def test_jit_output_non_tensor_fail(self):
@ -120,7 +120,7 @@ class TestJit(unittest.TestCase):
output2.append(o2)
expect1.append(a.numpy().copy()+b.numpy().copy())
expect2.append(i)
np.testing.assert_equal(output1, expect1)
np.testing.assert_allclose(output1, expect1, atol=1e-4, rtol=1e-5)
# the jit only works with Tensor outputs
assert output2 != expect2
assert len(f.jit_cache) == 1