Convert tests to use pytest instead of unittest (#43)

* Convert tests to use pytest instead of unittest

* Augment PYTHONPATH for pytest via pyproject.toml

* cleanup

---------

Co-authored-by: Maxime Desroches <desroches.maxime@gmail.com>
This commit is contained in:
Uku Loskit 2024-07-10 03:36:23 +03:00 committed by GitHub
parent 72b3479bab
commit 3ad6816953
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 23 additions and 18 deletions

View File

@ -17,8 +17,8 @@ jobs:
- name: Static analysis - name: Static analysis
run: ${{ env.RUN }} "git init && git add -A && pre-commit run --all" run: ${{ env.RUN }} "git init && git add -A && pre-commit run --all"
- name: Unit Tests - name: Unit Tests
run: ${{ env.RUN }} "cd /project/examples; python -m unittest discover" run: ${{ env.RUN }} "pytest"
docker_push: docker_push:
name: docker push name: docker push
runs-on: ubuntu-latest runs-on: ubuntu-latest

View File

@ -1,9 +1,9 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import pytest
import os import os
import sys import sys
import sympy as sp import sympy as sp
import numpy as np import numpy as np
import unittest
if __name__ == '__main__': # generating sympy code if __name__ == '__main__': # generating sympy code
from rednose.helpers.ekf_sym import gen_code from rednose.helpers.ekf_sym import gen_code
@ -83,7 +83,7 @@ class CompareFilter:
return R return R
class TestCompare(unittest.TestCase): class TestCompare:
def test_compare(self): def test_compare(self):
np.random.seed(0) np.random.seed(0)
@ -115,9 +115,9 @@ class TestCompare(unittest.TestCase):
kf.filter_py.predict_and_update_batch(t, ObservationKind.POSITION, z, R) kf.filter_py.predict_and_update_batch(t, ObservationKind.POSITION, z, R)
kf.filter_pyx.predict_and_update_batch(t, ObservationKind.POSITION, z, R) kf.filter_pyx.predict_and_update_batch(t, ObservationKind.POSITION, z, R)
self.assertAlmostEqual(kf.filter_py.get_filter_time(), kf.filter_pyx.get_filter_time()) assert kf.filter_py.get_filter_time() == pytest.approx(kf.filter_pyx.get_filter_time())
self.assertTrue(np.allclose(kf.filter_py.state(), kf.filter_pyx.state())) assert np.allclose(kf.filter_py.state(), kf.filter_pyx.state())
self.assertTrue(np.allclose(kf.filter_py.covs(), kf.filter_pyx.covs())) assert np.allclose(kf.filter_py.covs(), kf.filter_pyx.covs())
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -1,12 +1,12 @@
import pytest
import os import os
import numpy as np import numpy as np
import unittest
from kinematic_kf import KinematicKalman, ObservationKind, States # pylint: disable=import-error from .kinematic_kf import KinematicKalman, ObservationKind, States
GENERATED_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), 'generated')) GENERATED_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), 'generated'))
class TestKinematic(unittest.TestCase): class TestKinematic:
def test_kinematic_kf(self): def test_kinematic_kf(self):
np.random.seed(0) np.random.seed(0)
@ -49,10 +49,10 @@ class TestKinematic(unittest.TestCase):
xs, xs_meas, xs_kf, vs_kf, xs_kf_std, vs_kf_std = (np.asarray(a) for a in (xs, xs_meas, xs_kf, vs_kf, xs_kf_std, vs_kf_std)) xs, xs_meas, xs_kf, vs_kf, xs_kf_std, vs_kf_std = (np.asarray(a) for a in (xs, xs_meas, xs_kf, vs_kf, xs_kf_std, vs_kf_std))
self.assertAlmostEqual(xs_kf[-1], -0.010866289677966417) assert xs_kf[-1] == pytest.approx(-0.010866289677966417)
self.assertAlmostEqual(xs_kf_std[-1], 0.04477103863330089) assert xs_kf_std[-1] == pytest.approx(0.04477103863330089)
self.assertAlmostEqual(vs_kf[-1], -0.8553720537261753) assert vs_kf[-1] == pytest.approx(-0.8553720537261753)
self.assertAlmostEqual(vs_kf_std[-1], 0.6695762270974388) assert vs_kf_std[-1] == pytest.approx(0.6695762270974388)
if "PLOT" in os.environ: if "PLOT" in os.environ:
import matplotlib.pyplot as plt # pylint: disable=import-error import matplotlib.pyplot as plt # pylint: disable=import-error
@ -80,7 +80,3 @@ class TestKinematic(unittest.TestCase):
plt.legend() plt.legend()
plt.show() plt.show()
if __name__ == "__main__":
unittest.main()

View File

@ -7,3 +7,10 @@ target-version="py311"
select = ["E", "F", "W", "PIE", "C4", "ISC", "RUF100", "A"] select = ["E", "F", "W", "PIE", "C4", "ISC", "RUF100", "A"]
ignore = ["W292", "E741", "E402", "C408", "ISC003"] ignore = ["W292", "E741", "E402", "C408", "ISC003"]
flake8-implicit-str-concat.allow-multiline=false flake8-implicit-str-concat.allow-multiline=false
[tool.ruff.lint.flake8-tidy-imports.banned-api]
"pytest.main".msg = "pytest.main requires special handling that is easy to mess up!"
"unittest".msg = "Use pytest"
[tool.pytest.ini_options]
addopts = "--durations=10 -n auto"

View File

@ -7,3 +7,5 @@ cffi
scons scons
pre-commit pre-commit
Cython Cython
pytest
pytest-xdist