diff --git a/system/webrtc/tests/test_webrtcd.py b/system/webrtc/tests/test_webrtcd.py index c31b63d02b..e5742dba07 100755 --- a/system/webrtc/tests/test_webrtcd.py +++ b/system/webrtc/tests/test_webrtcd.py @@ -11,8 +11,15 @@ from openpilot.system.webrtc.webrtcd import get_stream import aiortc from teleoprtc import WebRTCOfferBuilder +from parameterized import parameterized_class +@parameterized_class(("in_services", "out_services"), [ + (["testJoystick"], ["carState"]), + ([], ["carState"]), + (["testJoystick"], []), + ([], []), +]) class TestWebrtcdProc(unittest.IsolatedAsyncioTestCase): async def assertCompletesWithTimeout(self, awaitable, timeout=1): try: @@ -24,7 +31,7 @@ class TestWebrtcdProc(unittest.IsolatedAsyncioTestCase): async def test_webrtcd(self): mock_request = MagicMock() async def connect(offer): - body = {'sdp': offer.sdp, 'cameras': offer.video, 'bridge_services_in': [], 'bridge_services_out': ['carState']} + body = {'sdp': offer.sdp, 'cameras': offer.video, 'bridge_services_in': self.in_services, 'bridge_services_out': self.out_services} mock_request.json.side_effect = AsyncMock(return_value=body) response = await get_stream(mock_request) response_json = json.loads(response.text) @@ -33,7 +40,8 @@ class TestWebrtcdProc(unittest.IsolatedAsyncioTestCase): builder = WebRTCOfferBuilder(connect) builder.offer_to_receive_video_stream("road") builder.offer_to_receive_audio_stream() - builder.add_messaging() + if len(self.in_services) > 0 or len(self.out_services) > 0: + builder.add_messaging() stream = builder.stream() @@ -42,7 +50,7 @@ class TestWebrtcdProc(unittest.IsolatedAsyncioTestCase): self.assertTrue(stream.has_incoming_video_track("road")) self.assertTrue(stream.has_incoming_audio_track()) - self.assertTrue(stream.has_messaging_channel()) + self.assertEqual(stream.has_messaging_channel(), len(self.in_services) > 0 or len(self.out_services) > 0) video_track, audio_track = stream.get_incoming_video_track("road"), stream.get_incoming_audio_track() await self.assertCompletesWithTimeout(video_track.recv()) diff --git a/system/webrtc/webrtcd.py b/system/webrtc/webrtcd.py index 6c1370eae9..95c8ae337d 100755 --- a/system/webrtc/webrtcd.py +++ b/system/webrtc/webrtcd.py @@ -128,9 +128,14 @@ class StreamSession: self.stream = builder.stream() self.identifier = str(uuid.uuid4()) - self.outgoing_bridge = CerealOutgoingMessageProxy(messaging.SubMaster(outgoing_services)) - self.incoming_bridge = CerealIncomingMessageProxy(messaging.PubMaster(incoming_services)) - self.outgoing_bridge_runner = CerealProxyRunner(self.outgoing_bridge) + self.incoming_bridge: CerealIncomingMessageProxy | None = None + self.outgoing_bridge: CerealOutgoingMessageProxy | None = None + self.outgoing_bridge_runner: CerealProxyRunner | None = None + if len(incoming_services) > 0: + self.incoming_bridge = CerealIncomingMessageProxy(messaging.PubMaster(incoming_services)) + if len(outgoing_services) > 0: + self.outgoing_bridge = CerealOutgoingMessageProxy(messaging.SubMaster(outgoing_services)) + self.outgoing_bridge_runner = CerealProxyRunner(self.outgoing_bridge) self.audio_output: AudioOutputSpeaker | MediaBlackhole | None = None self.run_task: asyncio.Task | None = None @@ -152,6 +157,7 @@ class StreamSession: return await self.stream.start() async def message_handler(self, message: bytes): + assert self.incoming_bridge is not None try: self.incoming_bridge.send(message) except Exception as ex: @@ -161,10 +167,12 @@ class StreamSession: try: await self.stream.wait_for_connection() if self.stream.has_messaging_channel(): - self.stream.set_message_handler(self.message_handler) - channel = self.stream.get_messaging_channel() - self.outgoing_bridge_runner.proxy.add_channel(channel) - self.outgoing_bridge_runner.start() + if self.incoming_bridge is not None: + self.stream.set_message_handler(self.message_handler) + if self.outgoing_bridge_runner is not None: + channel = self.stream.get_messaging_channel() + self.outgoing_bridge_runner.proxy.add_channel(channel) + self.outgoing_bridge_runner.start() if self.stream.has_incoming_audio_track(): track = self.stream.get_incoming_audio_track(buffered=False) self.audio_output = self.audio_output_cls() @@ -181,7 +189,8 @@ class StreamSession: async def post_run_cleanup(self): await self.stream.stop() - self.outgoing_bridge_runner.stop() + if self.outgoing_bridge is not None: + self.outgoing_bridge_runner.stop() if self.audio_output: self.audio_output.stop() diff --git a/teleoprtc_repo b/teleoprtc_repo index 8489ac3c5a..ab2f09706e 160000 --- a/teleoprtc_repo +++ b/teleoprtc_repo @@ -1 +1 @@ -Subproject commit 8489ac3c5aa0b0e2fec397694f9005e2b5a613e4 +Subproject commit ab2f09706e8f64390e196f079ac69e67131b07f5