From d7b99c4296f7a5c098e756bad73af8ef38370261 Mon Sep 17 00:00:00 2001 From: Uku Loskit Date: Wed, 10 Jul 2024 08:47:23 +0300 Subject: [PATCH] Convert tests to pytests (#626) * make test pass * linter --------- Co-authored-by: Maxime Desroches --- .github/workflows/tests.yml | 2 +- Dockerfile | 2 +- msgq/tests/test_fake.py | 70 ++++++++++++-------------- msgq/tests/test_messaging.py | 30 +++-------- msgq/tests/test_poller.py | 18 +++---- msgq/visionipc/tests/test_visionipc.py | 61 ++++++++++++---------- pyproject.toml | 11 ++++ 7 files changed, 94 insertions(+), 100 deletions(-) mode change 100755 => 100644 msgq/tests/test_messaging.py mode change 100755 => 100644 msgq/visionipc/tests/test_visionipc.py diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 16219b0..cab495a 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -43,7 +43,7 @@ jobs: msgq/test_runner && \ msgq/visionipc/test_runner" - name: python tests - run: $RUN_NAMED "${{ matrix.backend }}=1 coverage run -m unittest discover ." + run: $RUN_NAMED "${{ matrix.backend }}=1 coverage run -m pytest" - name: Upload coverage run: | docker commit msgq msgqci diff --git a/Dockerfile b/Dockerfile index 982d8fa..77ef04c 100644 --- a/Dockerfile +++ b/Dockerfile @@ -35,7 +35,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ zlib1g-dev \ && rm -rf /var/lib/apt/lists/* -RUN pip3 install --break-system-packages --no-cache-dir pyyaml Cython scons pycapnp pre-commit ruff parameterized coverage numpy +RUN pip3 install --break-system-packages --no-cache-dir pyyaml Cython scons pycapnp pre-commit ruff parameterized coverage numpy pytest WORKDIR /project/msgq/ RUN cd /tmp/ && \ diff --git a/msgq/tests/test_fake.py b/msgq/tests/test_fake.py index b5ed297..d2c5131 100644 --- a/msgq/tests/test_fake.py +++ b/msgq/tests/test_fake.py @@ -1,5 +1,5 @@ +import pytest import os -import unittest import multiprocessing import platform import msgq @@ -9,18 +9,18 @@ from typing import Optional WAIT_TIMEOUT = 5 -@unittest.skipIf(platform.system() == "Darwin", "Events not supported on macOS") -class TestEvents(unittest.TestCase): +@pytest.mark.skipif(condition=platform.system() == "Darwin", reason="Events not supported on macOS") +class TestEvents: def test_mutation(self): handle = msgq.fake_event_handle("carState") event = handle.recv_called_event - self.assertFalse(event.peek()) + assert not event.peek() event.set() - self.assertTrue(event.peek()) + assert event.peek() event.clear() - self.assertFalse(event.peek()) + assert not event.peek() del event @@ -31,9 +31,9 @@ class TestEvents(unittest.TestCase): event.set() try: event.wait(WAIT_TIMEOUT) - self.assertTrue(event.peek()) + assert event.peek() except RuntimeError: - self.fail("event.wait() timed out") + pytest.fail("event.wait() timed out") def test_wait_multiprocess(self): handle = msgq.fake_event_handle("carState") @@ -46,9 +46,9 @@ class TestEvents(unittest.TestCase): p = multiprocessing.Process(target=set_event_run) p.start() event.wait(WAIT_TIMEOUT) - self.assertTrue(event.peek()) + assert event.peek() except RuntimeError: - self.fail("event.wait() timed out") + pytest.fail("event.wait() timed out") p.kill() @@ -58,34 +58,34 @@ class TestEvents(unittest.TestCase): try: event.wait(0) - self.fail("event.wait() did not time out") + pytest.fail("event.wait() did not time out") except RuntimeError: - self.assertFalse(event.peek()) + assert not event.peek() -@unittest.skipIf(platform.system() == "Darwin", "FakeSockets not supported on macOS") -@unittest.skipIf("ZMQ" in os.environ, "FakeSockets not supported on ZMQ") +@pytest.mark.skipif(condition=platform.system() == "Darwin", reason="FakeSockets not supported on macOS") +@pytest.mark.skipif(condition="ZMQ" in os.environ, reason="FakeSockets not supported on ZMQ") @parameterized_class([{"prefix": None}, {"prefix": "test"}]) -class TestFakeSockets(unittest.TestCase): +class TestFakeSockets: prefix: Optional[str] = None - def setUp(self): + def setup_method(self): msgq.toggle_fake_events(True) if self.prefix is not None: msgq.set_fake_prefix(self.prefix) else: msgq.delete_fake_prefix() - def tearDown(self): + def teardown_method(self): msgq.toggle_fake_events(False) msgq.delete_fake_prefix() def test_event_handle_init(self): handle = msgq.fake_event_handle("controlsState", override=True) - self.assertFalse(handle.enabled) - self.assertGreaterEqual(handle.recv_called_event.fd, 0) - self.assertGreaterEqual(handle.recv_ready_event.fd, 0) + assert not handle.enabled + assert handle.recv_called_event.fd >= 0 + assert handle.recv_ready_event.fd >= 0 def test_non_managed_socket_state(self): # non managed socket should have zero state @@ -93,9 +93,9 @@ class TestFakeSockets(unittest.TestCase): handle = msgq.fake_event_handle("ubloxGnss", override=False) - self.assertFalse(handle.enabled) - self.assertEqual(handle.recv_called_event.fd, 0) - self.assertEqual(handle.recv_ready_event.fd, 0) + assert not handle.enabled + assert handle.recv_called_event.fd == 0 + assert handle.recv_ready_event.fd == 0 def test_managed_socket_state(self): # managed socket should not change anything about the state @@ -108,9 +108,9 @@ class TestFakeSockets(unittest.TestCase): _ = msgq.pub_sock("ubloxGnss") - self.assertEqual(handle.enabled, expected_enabled) - self.assertEqual(handle.recv_called_event.fd, expected_recv_called_fd) - self.assertEqual(handle.recv_ready_event.fd, expected_recv_ready_fd) + assert handle.enabled == expected_enabled + assert handle.recv_called_event.fd == expected_recv_called_fd + assert handle.recv_ready_event.fd == expected_recv_ready_fd def test_sockets_enable_disable(self): carState_handle = msgq.fake_event_handle("ubloxGnss", enable=True) @@ -125,16 +125,16 @@ class TestFakeSockets(unittest.TestCase): recv_ready.set() pub_sock.send(b"test") _ = sub_sock.receive() - self.assertTrue(recv_called.peek()) + assert recv_called.peek() recv_called.clear() carState_handle.enabled = False recv_ready.set() pub_sock.send(b"test") _ = sub_sock.receive() - self.assertFalse(recv_called.peek()) + assert not recv_called.peek() except RuntimeError: - self.fail("event.wait() timed out") + pytest.fail("event.wait() timed out") def test_synced_pub_sub(self): def daemon_repub_process_run(): @@ -177,16 +177,12 @@ class TestFakeSockets(unittest.TestCase): recv_called.wait(WAIT_TIMEOUT) msg = sub_sock.receive(non_blocking=True) - self.assertIsNotNone(msg) - self.assertEqual(len(msg), 8) + assert msg is not None + assert len(msg) == 8 frame = int.from_bytes(msg, 'little') - self.assertEqual(frame, i) + assert frame == i except RuntimeError: - self.fail("event.wait() timed out") + pytest.fail("event.wait() timed out") finally: p.kill() - - -if __name__ == "__main__": - unittest.main() diff --git a/msgq/tests/test_messaging.py b/msgq/tests/test_messaging.py old mode 100755 new mode 100644 index bbeeb3d..40dfd7f --- a/msgq/tests/test_messaging.py +++ b/msgq/tests/test_messaging.py @@ -1,10 +1,7 @@ -#!/usr/bin/env python3 import os import random -import threading import time import string -import unittest import msgq @@ -18,20 +15,9 @@ def zmq_sleep(t=1): if "ZMQ" in os.environ: time.sleep(t) -def zmq_expected_failure(func): - if "ZMQ" in os.environ: - return unittest.expectedFailure(func) - else: - return func +class TestPubSubSockets: -def delayed_send(delay, sock, dat): - def send_func(): - sock.send(dat) - threading.Timer(delay, send_func).start() - -class TestPubSubSockets(unittest.TestCase): - - def setUp(self): + def setup_method(self): # ZMQ pub socket takes too long to die # sleep to prevent multiple publishers error between tests zmq_sleep() @@ -46,7 +32,7 @@ class TestPubSubSockets(unittest.TestCase): msg = random_bytes() pub_sock.send(msg) recvd = sub_sock.receive() - self.assertEqual(msg, recvd) + assert msg == recvd def test_conflate(self): sock = random_sock() @@ -65,10 +51,10 @@ class TestPubSubSockets(unittest.TestCase): time.sleep(0.1) recvd_msgs = msgq.drain_sock_raw(sub_sock) if conflate: - self.assertEqual(len(recvd_msgs), 1) + assert len(recvd_msgs) == 1 else: # TODO: compare actual data - self.assertEqual(len(recvd_msgs), len(sent_msgs)) + assert len(recvd_msgs) == len(sent_msgs) def test_receive_timeout(self): sock = random_sock() @@ -79,9 +65,5 @@ class TestPubSubSockets(unittest.TestCase): start_time = time.monotonic() recvd = sub_sock.receive() - self.assertLess(time.monotonic() - start_time, 0.2) + assert (time.monotonic() - start_time) < 0.2 assert recvd is None - - -if __name__ == "__main__": - unittest.main() diff --git a/msgq/tests/test_poller.py b/msgq/tests/test_poller.py index a68ff4f..6ef2c04 100644 --- a/msgq/tests/test_poller.py +++ b/msgq/tests/test_poller.py @@ -1,4 +1,4 @@ -import unittest +import pytest import time import msgq import concurrent.futures @@ -20,7 +20,7 @@ def poller(): return r -class TestPoller(unittest.TestCase): +class TestPoller: def test_poll_once(self): context = msgq.Context() @@ -41,7 +41,7 @@ class TestPoller(unittest.TestCase): del pub context.term() - self.assertEqual(result, [b"a"]) + assert result == [b"a"] def test_poll_and_create_many_subscribers(self): context = msgq.Context() @@ -68,12 +68,12 @@ class TestPoller(unittest.TestCase): del pub context.term() - self.assertEqual(result, [b"a"]) + assert result == [b"a"] def test_multiple_publishers_exception(self): context = msgq.Context() - with self.assertRaises(msgq.MultiplePublishersError): + with pytest.raises(msgq.MultiplePublishersError): pub1 = msgq.PubSocket() pub1.connect(context, SERVICE_NAME) @@ -106,7 +106,7 @@ class TestPoller(unittest.TestCase): r = sub.receive(non_blocking=True) if r is not None: - self.assertEqual(b'a'*i, r) + assert b'a'*i == r msg_seen = True i += 1 @@ -131,12 +131,8 @@ class TestPoller(unittest.TestCase): pub.send(b'a') pub.send(b'b') - self.assertEqual(b'b', sub.receive()) + assert b'b' == sub.receive() del pub del sub context.term() - - -if __name__ == "__main__": - unittest.main() diff --git a/msgq/visionipc/tests/test_visionipc.py b/msgq/visionipc/tests/test_visionipc.py old mode 100755 new mode 100644 index 1c34613..c3b60b5 --- a/msgq/visionipc/tests/test_visionipc.py +++ b/msgq/visionipc/tests/test_visionipc.py @@ -1,8 +1,6 @@ -#!/usr/bin/env python3 import os import time import random -import unittest import numpy as np from msgq.visionipc import VisionIpcServer, VisionIpcClient, VisionStreamType @@ -11,7 +9,7 @@ def zmq_sleep(t=1): time.sleep(t) -class TestVisionIpc(unittest.TestCase): +class TestVisionIpc: def setup_vipc(self, name, *stream_types, num_buffers=1, rgb=False, width=100, height=100, conflate=False): self.server = VisionIpcServer(name) @@ -21,7 +19,7 @@ class TestVisionIpc(unittest.TestCase): if len(stream_types): self.client = VisionIpcClient(name, stream_types[0], conflate) - self.assertTrue(self.client.connect(True)) + assert self.client.connect(True) else: self.client = None @@ -30,28 +28,37 @@ class TestVisionIpc(unittest.TestCase): def test_connect(self): self.setup_vipc("camerad", VisionStreamType.VISION_STREAM_ROAD) - self.assertTrue(self.client.is_connected) + assert self.client.is_connected + del self.client + del self.server def test_available_streams(self): for k in range(4): stream_types = set(random.choices([x.value for x in VisionStreamType], k=k)) self.setup_vipc("camerad", *stream_types) available_streams = VisionIpcClient.available_streams("camerad", True) - self.assertEqual(available_streams, stream_types) + assert available_streams == stream_types + del self.client + del self.server def test_buffers(self): width, height, num_buffers = 100, 200, 5 self.setup_vipc("camerad", VisionStreamType.VISION_STREAM_ROAD, num_buffers=num_buffers, width=width, height=height) - self.assertEqual(self.client.width, width) - self.assertEqual(self.client.height, height) - self.assertGreater(self.client.buffer_len, 0) - self.assertEqual(self.client.num_buffers, num_buffers) + assert self.client.width == width + assert self.client.height == height + assert self.client.buffer_len > 0 + assert self.client.num_buffers == num_buffers + del self.client + del self.server def test_yuv_rgb(self): _, client_yuv = self.setup_vipc("camerad", VisionStreamType.VISION_STREAM_ROAD, rgb=False) _, client_rgb = self.setup_vipc("navd", VisionStreamType.VISION_STREAM_MAP, rgb=True) - self.assertTrue(client_rgb.rgb) - self.assertFalse(client_yuv.rgb) + assert client_rgb.rgb + assert not client_yuv.rgb + del client_yuv + del client_rgb + del self.server def test_send_single_buffer(self): self.setup_vipc("camerad", VisionStreamType.VISION_STREAM_ROAD) @@ -61,9 +68,11 @@ class TestVisionIpc(unittest.TestCase): self.server.send(VisionStreamType.VISION_STREAM_ROAD, buf, frame_id=1337) recv_buf = self.client.recv() - self.assertIsNot(recv_buf, None) - self.assertEqual(recv_buf.data.view('