Convert tests to pytests (#626)

* make test pass

* linter

---------

Co-authored-by: Maxime Desroches <desroches.maxime@gmail.com>
This commit is contained in:
Uku Loskit 2024-07-10 08:47:23 +03:00 committed by GitHub
parent 74074d650f
commit d7b99c4296
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 94 additions and 100 deletions

View File

@ -43,7 +43,7 @@ jobs:
msgq/test_runner && \ msgq/test_runner && \
msgq/visionipc/test_runner" msgq/visionipc/test_runner"
- name: python tests - name: python tests
run: $RUN_NAMED "${{ matrix.backend }}=1 coverage run -m unittest discover ." run: $RUN_NAMED "${{ matrix.backend }}=1 coverage run -m pytest"
- name: Upload coverage - name: Upload coverage
run: | run: |
docker commit msgq msgqci docker commit msgq msgqci

View File

@ -35,7 +35,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
zlib1g-dev \ zlib1g-dev \
&& rm -rf /var/lib/apt/lists/* && rm -rf /var/lib/apt/lists/*
RUN pip3 install --break-system-packages --no-cache-dir pyyaml Cython scons pycapnp pre-commit ruff parameterized coverage numpy RUN pip3 install --break-system-packages --no-cache-dir pyyaml Cython scons pycapnp pre-commit ruff parameterized coverage numpy pytest
WORKDIR /project/msgq/ WORKDIR /project/msgq/
RUN cd /tmp/ && \ RUN cd /tmp/ && \

View File

@ -1,5 +1,5 @@
import pytest
import os import os
import unittest
import multiprocessing import multiprocessing
import platform import platform
import msgq import msgq
@ -9,18 +9,18 @@ from typing import Optional
WAIT_TIMEOUT = 5 WAIT_TIMEOUT = 5
@unittest.skipIf(platform.system() == "Darwin", "Events not supported on macOS") @pytest.mark.skipif(condition=platform.system() == "Darwin", reason="Events not supported on macOS")
class TestEvents(unittest.TestCase): class TestEvents:
def test_mutation(self): def test_mutation(self):
handle = msgq.fake_event_handle("carState") handle = msgq.fake_event_handle("carState")
event = handle.recv_called_event event = handle.recv_called_event
self.assertFalse(event.peek()) assert not event.peek()
event.set() event.set()
self.assertTrue(event.peek()) assert event.peek()
event.clear() event.clear()
self.assertFalse(event.peek()) assert not event.peek()
del event del event
@ -31,9 +31,9 @@ class TestEvents(unittest.TestCase):
event.set() event.set()
try: try:
event.wait(WAIT_TIMEOUT) event.wait(WAIT_TIMEOUT)
self.assertTrue(event.peek()) assert event.peek()
except RuntimeError: except RuntimeError:
self.fail("event.wait() timed out") pytest.fail("event.wait() timed out")
def test_wait_multiprocess(self): def test_wait_multiprocess(self):
handle = msgq.fake_event_handle("carState") handle = msgq.fake_event_handle("carState")
@ -46,9 +46,9 @@ class TestEvents(unittest.TestCase):
p = multiprocessing.Process(target=set_event_run) p = multiprocessing.Process(target=set_event_run)
p.start() p.start()
event.wait(WAIT_TIMEOUT) event.wait(WAIT_TIMEOUT)
self.assertTrue(event.peek()) assert event.peek()
except RuntimeError: except RuntimeError:
self.fail("event.wait() timed out") pytest.fail("event.wait() timed out")
p.kill() p.kill()
@ -58,34 +58,34 @@ class TestEvents(unittest.TestCase):
try: try:
event.wait(0) event.wait(0)
self.fail("event.wait() did not time out") pytest.fail("event.wait() did not time out")
except RuntimeError: except RuntimeError:
self.assertFalse(event.peek()) assert not event.peek()
@unittest.skipIf(platform.system() == "Darwin", "FakeSockets not supported on macOS") @pytest.mark.skipif(condition=platform.system() == "Darwin", reason="FakeSockets not supported on macOS")
@unittest.skipIf("ZMQ" in os.environ, "FakeSockets not supported on ZMQ") @pytest.mark.skipif(condition="ZMQ" in os.environ, reason="FakeSockets not supported on ZMQ")
@parameterized_class([{"prefix": None}, {"prefix": "test"}]) @parameterized_class([{"prefix": None}, {"prefix": "test"}])
class TestFakeSockets(unittest.TestCase): class TestFakeSockets:
prefix: Optional[str] = None prefix: Optional[str] = None
def setUp(self): def setup_method(self):
msgq.toggle_fake_events(True) msgq.toggle_fake_events(True)
if self.prefix is not None: if self.prefix is not None:
msgq.set_fake_prefix(self.prefix) msgq.set_fake_prefix(self.prefix)
else: else:
msgq.delete_fake_prefix() msgq.delete_fake_prefix()
def tearDown(self): def teardown_method(self):
msgq.toggle_fake_events(False) msgq.toggle_fake_events(False)
msgq.delete_fake_prefix() msgq.delete_fake_prefix()
def test_event_handle_init(self): def test_event_handle_init(self):
handle = msgq.fake_event_handle("controlsState", override=True) handle = msgq.fake_event_handle("controlsState", override=True)
self.assertFalse(handle.enabled) assert not handle.enabled
self.assertGreaterEqual(handle.recv_called_event.fd, 0) assert handle.recv_called_event.fd >= 0
self.assertGreaterEqual(handle.recv_ready_event.fd, 0) assert handle.recv_ready_event.fd >= 0
def test_non_managed_socket_state(self): def test_non_managed_socket_state(self):
# non managed socket should have zero state # non managed socket should have zero state
@ -93,9 +93,9 @@ class TestFakeSockets(unittest.TestCase):
handle = msgq.fake_event_handle("ubloxGnss", override=False) handle = msgq.fake_event_handle("ubloxGnss", override=False)
self.assertFalse(handle.enabled) assert not handle.enabled
self.assertEqual(handle.recv_called_event.fd, 0) assert handle.recv_called_event.fd == 0
self.assertEqual(handle.recv_ready_event.fd, 0) assert handle.recv_ready_event.fd == 0
def test_managed_socket_state(self): def test_managed_socket_state(self):
# managed socket should not change anything about the state # managed socket should not change anything about the state
@ -108,9 +108,9 @@ class TestFakeSockets(unittest.TestCase):
_ = msgq.pub_sock("ubloxGnss") _ = msgq.pub_sock("ubloxGnss")
self.assertEqual(handle.enabled, expected_enabled) assert handle.enabled == expected_enabled
self.assertEqual(handle.recv_called_event.fd, expected_recv_called_fd) assert handle.recv_called_event.fd == expected_recv_called_fd
self.assertEqual(handle.recv_ready_event.fd, expected_recv_ready_fd) assert handle.recv_ready_event.fd == expected_recv_ready_fd
def test_sockets_enable_disable(self): def test_sockets_enable_disable(self):
carState_handle = msgq.fake_event_handle("ubloxGnss", enable=True) carState_handle = msgq.fake_event_handle("ubloxGnss", enable=True)
@ -125,16 +125,16 @@ class TestFakeSockets(unittest.TestCase):
recv_ready.set() recv_ready.set()
pub_sock.send(b"test") pub_sock.send(b"test")
_ = sub_sock.receive() _ = sub_sock.receive()
self.assertTrue(recv_called.peek()) assert recv_called.peek()
recv_called.clear() recv_called.clear()
carState_handle.enabled = False carState_handle.enabled = False
recv_ready.set() recv_ready.set()
pub_sock.send(b"test") pub_sock.send(b"test")
_ = sub_sock.receive() _ = sub_sock.receive()
self.assertFalse(recv_called.peek()) assert not recv_called.peek()
except RuntimeError: except RuntimeError:
self.fail("event.wait() timed out") pytest.fail("event.wait() timed out")
def test_synced_pub_sub(self): def test_synced_pub_sub(self):
def daemon_repub_process_run(): def daemon_repub_process_run():
@ -177,16 +177,12 @@ class TestFakeSockets(unittest.TestCase):
recv_called.wait(WAIT_TIMEOUT) recv_called.wait(WAIT_TIMEOUT)
msg = sub_sock.receive(non_blocking=True) msg = sub_sock.receive(non_blocking=True)
self.assertIsNotNone(msg) assert msg is not None
self.assertEqual(len(msg), 8) assert len(msg) == 8
frame = int.from_bytes(msg, 'little') frame = int.from_bytes(msg, 'little')
self.assertEqual(frame, i) assert frame == i
except RuntimeError: except RuntimeError:
self.fail("event.wait() timed out") pytest.fail("event.wait() timed out")
finally: finally:
p.kill() p.kill()
if __name__ == "__main__":
unittest.main()

30
msgq/tests/test_messaging.py Executable file → Normal file
View File

@ -1,10 +1,7 @@
#!/usr/bin/env python3
import os import os
import random import random
import threading
import time import time
import string import string
import unittest
import msgq import msgq
@ -18,20 +15,9 @@ def zmq_sleep(t=1):
if "ZMQ" in os.environ: if "ZMQ" in os.environ:
time.sleep(t) time.sleep(t)
def zmq_expected_failure(func): class TestPubSubSockets:
if "ZMQ" in os.environ:
return unittest.expectedFailure(func)
else:
return func
def delayed_send(delay, sock, dat): def setup_method(self):
def send_func():
sock.send(dat)
threading.Timer(delay, send_func).start()
class TestPubSubSockets(unittest.TestCase):
def setUp(self):
# ZMQ pub socket takes too long to die # ZMQ pub socket takes too long to die
# sleep to prevent multiple publishers error between tests # sleep to prevent multiple publishers error between tests
zmq_sleep() zmq_sleep()
@ -46,7 +32,7 @@ class TestPubSubSockets(unittest.TestCase):
msg = random_bytes() msg = random_bytes()
pub_sock.send(msg) pub_sock.send(msg)
recvd = sub_sock.receive() recvd = sub_sock.receive()
self.assertEqual(msg, recvd) assert msg == recvd
def test_conflate(self): def test_conflate(self):
sock = random_sock() sock = random_sock()
@ -65,10 +51,10 @@ class TestPubSubSockets(unittest.TestCase):
time.sleep(0.1) time.sleep(0.1)
recvd_msgs = msgq.drain_sock_raw(sub_sock) recvd_msgs = msgq.drain_sock_raw(sub_sock)
if conflate: if conflate:
self.assertEqual(len(recvd_msgs), 1) assert len(recvd_msgs) == 1
else: else:
# TODO: compare actual data # TODO: compare actual data
self.assertEqual(len(recvd_msgs), len(sent_msgs)) assert len(recvd_msgs) == len(sent_msgs)
def test_receive_timeout(self): def test_receive_timeout(self):
sock = random_sock() sock = random_sock()
@ -79,9 +65,5 @@ class TestPubSubSockets(unittest.TestCase):
start_time = time.monotonic() start_time = time.monotonic()
recvd = sub_sock.receive() recvd = sub_sock.receive()
self.assertLess(time.monotonic() - start_time, 0.2) assert (time.monotonic() - start_time) < 0.2
assert recvd is None assert recvd is None
if __name__ == "__main__":
unittest.main()

View File

@ -1,4 +1,4 @@
import unittest import pytest
import time import time
import msgq import msgq
import concurrent.futures import concurrent.futures
@ -20,7 +20,7 @@ def poller():
return r return r
class TestPoller(unittest.TestCase): class TestPoller:
def test_poll_once(self): def test_poll_once(self):
context = msgq.Context() context = msgq.Context()
@ -41,7 +41,7 @@ class TestPoller(unittest.TestCase):
del pub del pub
context.term() context.term()
self.assertEqual(result, [b"a"]) assert result == [b"a"]
def test_poll_and_create_many_subscribers(self): def test_poll_and_create_many_subscribers(self):
context = msgq.Context() context = msgq.Context()
@ -68,12 +68,12 @@ class TestPoller(unittest.TestCase):
del pub del pub
context.term() context.term()
self.assertEqual(result, [b"a"]) assert result == [b"a"]
def test_multiple_publishers_exception(self): def test_multiple_publishers_exception(self):
context = msgq.Context() context = msgq.Context()
with self.assertRaises(msgq.MultiplePublishersError): with pytest.raises(msgq.MultiplePublishersError):
pub1 = msgq.PubSocket() pub1 = msgq.PubSocket()
pub1.connect(context, SERVICE_NAME) pub1.connect(context, SERVICE_NAME)
@ -106,7 +106,7 @@ class TestPoller(unittest.TestCase):
r = sub.receive(non_blocking=True) r = sub.receive(non_blocking=True)
if r is not None: if r is not None:
self.assertEqual(b'a'*i, r) assert b'a'*i == r
msg_seen = True msg_seen = True
i += 1 i += 1
@ -131,12 +131,8 @@ class TestPoller(unittest.TestCase):
pub.send(b'a') pub.send(b'a')
pub.send(b'b') pub.send(b'b')
self.assertEqual(b'b', sub.receive()) assert b'b' == sub.receive()
del pub del pub
del sub del sub
context.term() context.term()
if __name__ == "__main__":
unittest.main()

61
msgq/visionipc/tests/test_visionipc.py Executable file → Normal file
View File

@ -1,8 +1,6 @@
#!/usr/bin/env python3
import os import os
import time import time
import random import random
import unittest
import numpy as np import numpy as np
from msgq.visionipc import VisionIpcServer, VisionIpcClient, VisionStreamType from msgq.visionipc import VisionIpcServer, VisionIpcClient, VisionStreamType
@ -11,7 +9,7 @@ def zmq_sleep(t=1):
time.sleep(t) time.sleep(t)
class TestVisionIpc(unittest.TestCase): class TestVisionIpc:
def setup_vipc(self, name, *stream_types, num_buffers=1, rgb=False, width=100, height=100, conflate=False): def setup_vipc(self, name, *stream_types, num_buffers=1, rgb=False, width=100, height=100, conflate=False):
self.server = VisionIpcServer(name) self.server = VisionIpcServer(name)
@ -21,7 +19,7 @@ class TestVisionIpc(unittest.TestCase):
if len(stream_types): if len(stream_types):
self.client = VisionIpcClient(name, stream_types[0], conflate) self.client = VisionIpcClient(name, stream_types[0], conflate)
self.assertTrue(self.client.connect(True)) assert self.client.connect(True)
else: else:
self.client = None self.client = None
@ -30,28 +28,37 @@ class TestVisionIpc(unittest.TestCase):
def test_connect(self): def test_connect(self):
self.setup_vipc("camerad", VisionStreamType.VISION_STREAM_ROAD) self.setup_vipc("camerad", VisionStreamType.VISION_STREAM_ROAD)
self.assertTrue(self.client.is_connected) assert self.client.is_connected
del self.client
del self.server
def test_available_streams(self): def test_available_streams(self):
for k in range(4): for k in range(4):
stream_types = set(random.choices([x.value for x in VisionStreamType], k=k)) stream_types = set(random.choices([x.value for x in VisionStreamType], k=k))
self.setup_vipc("camerad", *stream_types) self.setup_vipc("camerad", *stream_types)
available_streams = VisionIpcClient.available_streams("camerad", True) available_streams = VisionIpcClient.available_streams("camerad", True)
self.assertEqual(available_streams, stream_types) assert available_streams == stream_types
del self.client
del self.server
def test_buffers(self): def test_buffers(self):
width, height, num_buffers = 100, 200, 5 width, height, num_buffers = 100, 200, 5
self.setup_vipc("camerad", VisionStreamType.VISION_STREAM_ROAD, num_buffers=num_buffers, width=width, height=height) self.setup_vipc("camerad", VisionStreamType.VISION_STREAM_ROAD, num_buffers=num_buffers, width=width, height=height)
self.assertEqual(self.client.width, width) assert self.client.width == width
self.assertEqual(self.client.height, height) assert self.client.height == height
self.assertGreater(self.client.buffer_len, 0) assert self.client.buffer_len > 0
self.assertEqual(self.client.num_buffers, num_buffers) assert self.client.num_buffers == num_buffers
del self.client
del self.server
def test_yuv_rgb(self): def test_yuv_rgb(self):
_, client_yuv = self.setup_vipc("camerad", VisionStreamType.VISION_STREAM_ROAD, rgb=False) _, client_yuv = self.setup_vipc("camerad", VisionStreamType.VISION_STREAM_ROAD, rgb=False)
_, client_rgb = self.setup_vipc("navd", VisionStreamType.VISION_STREAM_MAP, rgb=True) _, client_rgb = self.setup_vipc("navd", VisionStreamType.VISION_STREAM_MAP, rgb=True)
self.assertTrue(client_rgb.rgb) assert client_rgb.rgb
self.assertFalse(client_yuv.rgb) assert not client_yuv.rgb
del client_yuv
del client_rgb
del self.server
def test_send_single_buffer(self): def test_send_single_buffer(self):
self.setup_vipc("camerad", VisionStreamType.VISION_STREAM_ROAD) self.setup_vipc("camerad", VisionStreamType.VISION_STREAM_ROAD)
@ -61,9 +68,11 @@ class TestVisionIpc(unittest.TestCase):
self.server.send(VisionStreamType.VISION_STREAM_ROAD, buf, frame_id=1337) self.server.send(VisionStreamType.VISION_STREAM_ROAD, buf, frame_id=1337)
recv_buf = self.client.recv() recv_buf = self.client.recv()
self.assertIsNot(recv_buf, None) assert recv_buf is not None
self.assertEqual(recv_buf.data.view('<i4')[0], 1234) assert recv_buf.data.view('<i4')[0] == 1234
self.assertEqual(self.client.frame_id, 1337) assert self.client.frame_id == 1337
del self.client
del self.server
def test_no_conflate(self): def test_no_conflate(self):
self.setup_vipc("camerad", VisionStreamType.VISION_STREAM_ROAD) self.setup_vipc("camerad", VisionStreamType.VISION_STREAM_ROAD)
@ -73,12 +82,14 @@ class TestVisionIpc(unittest.TestCase):
self.server.send(VisionStreamType.VISION_STREAM_ROAD, buf, frame_id=2) self.server.send(VisionStreamType.VISION_STREAM_ROAD, buf, frame_id=2)
recv_buf = self.client.recv() recv_buf = self.client.recv()
self.assertIsNot(recv_buf, None) assert recv_buf is not None
self.assertEqual(self.client.frame_id, 1) assert self.client.frame_id == 1
recv_buf = self.client.recv() recv_buf = self.client.recv()
self.assertIsNot(recv_buf, None) assert recv_buf is not None
self.assertEqual(self.client.frame_id, 2) assert self.client.frame_id == 2
del self.client
del self.server
def test_conflate(self): def test_conflate(self):
self.setup_vipc("camerad", VisionStreamType.VISION_STREAM_ROAD, conflate=True) self.setup_vipc("camerad", VisionStreamType.VISION_STREAM_ROAD, conflate=True)
@ -88,12 +99,10 @@ class TestVisionIpc(unittest.TestCase):
self.server.send(VisionStreamType.VISION_STREAM_ROAD, buf, frame_id=2) self.server.send(VisionStreamType.VISION_STREAM_ROAD, buf, frame_id=2)
recv_buf = self.client.recv() recv_buf = self.client.recv()
self.assertIsNot(recv_buf, None) assert recv_buf is not None
self.assertEqual(self.client.frame_id, 2) assert self.client.frame_id == 2
recv_buf = self.client.recv() recv_buf = self.client.recv()
self.assertIs(recv_buf, None) assert recv_buf is None
del self.client
del self.server
if __name__ == "__main__":
unittest.main()

View File

@ -7,6 +7,10 @@ lint.flake8-implicit-str-concat.allow-multiline=false
line-length = 160 line-length = 160
target-version="py311" target-version="py311"
[tool.ruff.lint.flake8-tidy-imports.banned-api]
"pytest.main".msg = "pytest.main requires special handling that is easy to mess up!"
"unittest".msg = "Use pytest"
[mypy.tool] [mypy.tool]
# third-party packages # third-party packages
ignore_missing_imports=true ignore_missing_imports=true
@ -19,3 +23,10 @@ warn_unused_ignores=true
# restrict dynamic typing # restrict dynamic typing
warn_return_any=true warn_return_any=true
check_untyped_defs=true check_untyped_defs=true
[tool.pytest.ini_options]
addopts = "--durations=10"
testpaths = [
"msgq/tests",
"msgq/visionipc/tests",
]