import abc
import asyncio
import dataclasses
import logging
from typing import Callable, Awaitable, Dict, List, Any, Optional

import aiortc
from aiortc.contrib.media import MediaRelay

from teleoprtc.tracks import parse_video_track_id


@dataclasses.dataclass
class StreamingOffer:
    """Local offer passed to a ``ConnectionProvider`` to obtain the remote answer."""

    sdp: str  # SDP text of the local description
    video: List[str]  # camera types the offering side expects to receive


# Coroutine that delivers a StreamingOffer to the remote side and returns its answer.
ConnectionProvider = Callable[[StreamingOffer], Awaitable[aiortc.RTCSessionDescription]]
# Coroutine registered as a listener for "message" events on the data channel.
MessageHandler = Callable[[bytes], Awaitable[None]]


class WebRTCBaseStream(abc.ABC):
    """Shared state and plumbing for one ``aiortc.RTCPeerConnection``.

    Subclasses implement :meth:`start`, which performs the offer/answer
    exchange. This base class keeps track of incoming/outgoing media tracks and
    the data channel, and exposes events describing connection progress.
    """

    def __init__(self,
                 consumed_camera_types: List[str],
                 consume_audio: bool,
                 video_producer_tracks: List[aiortc.MediaStreamTrack],
                 audio_producer_tracks: List[aiortc.MediaStreamTrack],
                 should_add_data_channel: bool):
        """Create the peer connection and register its event listeners.

        Args:
            consumed_camera_types: Camera types expected as incoming video.
            consume_audio: Whether an incoming audio track is expected.
            video_producer_tracks: Video tracks to send to the remote peer.
            audio_producer_tracks: Audio tracks to send to the remote peer.
            should_add_data_channel: Whether this side creates the data channel
                itself (otherwise a remotely created one is accepted).
        """
        self.peer_connection = aiortc.RTCPeerConnection()
        self.media_relay = MediaRelay()

        self.expected_incoming_camera_types = consumed_camera_types
        self.expected_incoming_audio = consume_audio
        # Filled in by _parse_incoming_streams() once the remote SDP is known.
        self.expected_number_of_incoming_media: Optional[int] = None

        self.incoming_camera_tracks: Dict[str, aiortc.MediaStreamTrack] = {}
        self.incoming_audio_tracks: List[aiortc.MediaStreamTrack] = []
        self.outgoing_video_tracks: List[aiortc.MediaStreamTrack] = video_producer_tracks
        self.outgoing_audio_tracks: List[aiortc.MediaStreamTrack] = audio_producer_tracks

        self.should_add_data_channel = should_add_data_channel
        self.messaging_channel: Optional[aiortc.RTCDataChannel] = None
        self.incoming_message_handlers: List[MessageHandler] = []

        self.incoming_media_ready_event = asyncio.Event()
        self.messaging_channel_ready_event = asyncio.Event()
        self.connection_attempted_event = asyncio.Event()
        self.connection_stopped_event = asyncio.Event()

        self.peer_connection.on("connectionstatechange", self._on_connectionstatechange)
        self.peer_connection.on("datachannel", self._on_incoming_datachannel)
        self.peer_connection.on("track", self._on_incoming_track)

        self.logger = logging.getLogger("WebRTCStream")

    def _log_debug(self, msg: Any, *args):
        """Log ``msg`` at DEBUG level, prefixed with the concrete stream type.

        ``args`` are forwarded to the logger as %-style arguments.
        """
        self.logger.debug(f"{type(self)}() {msg}", *args)

    @property
    def _number_of_incoming_media(self) -> int:
        """Number of incoming media received so far (tracks plus remote data channel)."""
        media = len(self.incoming_camera_tracks) + len(self.incoming_audio_tracks)
        # A data channel counts as incoming media only when this stream did not
        # create it itself, i.e. it was opened by the remote side.
        if not self.should_add_data_channel and self.messaging_channel is not None:
            media += 1
        return media

    def _add_consumer_transceivers(self):
        """Add one recvonly transceiver per expected camera, plus one for audio if expected."""
        for _ in self.expected_incoming_camera_types:
            self.peer_connection.addTransceiver("video", direction="recvonly")
        if self.expected_incoming_audio:
            self.peer_connection.addTransceiver("audio", direction="recvonly")

    def _find_trackless_transceiver(self, kind: str) -> Optional[aiortc.RTCRtpTransceiver]:
        """Return the first transceiver of ``kind`` whose sender has no track, else ``None``."""
        for transceiver in self.peer_connection.getTransceivers():
            if transceiver.kind == kind and transceiver.sender.track is None:
                return transceiver
        return None

    def _add_producer_track(self, track: aiortc.MediaStreamTrack):
        """Attach one outgoing track and return the sender from ``addTrack``.

        A sendonly transceiver is created first when no transceiver of the
        track's kind is free.
        """
        # NOTE(review): presumably this makes addTrack() reuse a sendonly
        # transceiver instead of creating a new one with another direction -
        # confirm against aiortc's addTrack behavior.
        if self._find_trackless_transceiver(track.kind) is None:
            self.peer_connection.addTransceiver(track.kind, direction="sendonly")
        return self.peer_connection.addTrack(track)

    def _add_producer_tracks(self):
        """Attach all outgoing video and audio tracks to the peer connection.

        Video tracks exposing a non-None ``codec_preference()`` additionally get
        that codec forced on their transceiver.
        """
        for track in self.outgoing_video_tracks:
            sender = self._add_producer_track(track)
            codec = track.codec_preference() if hasattr(track, "codec_preference") else None
            if codec is not None:
                transceiver = next(t for t in self.peer_connection.getTransceivers() if t.sender == sender)
                self._force_codec(transceiver, codec, "video")

        for track in self.outgoing_audio_tracks:
            self._add_producer_track(track)

    def _add_messaging_channel(self, channel: Optional[aiortc.RTCDataChannel] = None):
        """Adopt ``channel`` (or create a new ordered "data" channel) for messaging.

        Already registered message handlers are attached to the channel, and
        ``messaging_channel_ready_event`` is set as soon as the channel is open.
        """
        if not channel:
            channel = self.peer_connection.createDataChannel("data", ordered=True)

        for handler in self.incoming_message_handlers:
            channel.on("message", handler)

        if channel.readyState == "open":
            self.messaging_channel_ready_event.set()
        else:
            channel.on("open", self.messaging_channel_ready_event.set)
        self.messaging_channel = channel

    def _force_codec(self, transceiver: aiortc.RTCRtpTransceiver, codec: str, stream_type: str):
        """Restrict ``transceiver`` to the sender capabilities matching ``codec``.

        Args:
            transceiver: Transceiver whose codec preferences are replaced.
            codec: Codec name; upper-cased to build the MIME type.
            stream_type: Media kind used as MIME prefix, e.g. ``"video"``.
        """
        codec_mime = f"{stream_type}/{codec.upper()}"
        capabilities = aiortc.RTCRtpSender.getCapabilities(stream_type).codecs
        # NOTE(review): if nothing matches, an empty list is passed on; check
        # what setCodecPreferences([]) does in the aiortc version in use.
        matching = [c for c in capabilities if c.mimeType == codec_mime]
        transceiver.setCodecPreferences(matching)

    def _on_connectionstatechange(self):
        """Update the attempted/stopped events from the current connection state."""
        state = self.peer_connection.connectionState
        self._log_debug("connection state is %s", state)
        if state in ("connected", "failed"):
            self.connection_attempted_event.set()
        if state in ("disconnected", "closed", "failed"):
            self.connection_stopped_event.set()

    def _on_incoming_track(self, track: aiortc.MediaStreamTrack):
        """Store an incoming track if it is expected, then re-check media readiness."""
        self._log_debug("got track: %s %s", track.kind, track.id)
        if track.kind == "video":
            camera_type, _ = parse_video_track_id(track.id)
            if camera_type in self.expected_incoming_camera_types:
                self.incoming_camera_tracks[camera_type] = track
        elif track.kind == "audio":
            if self.expected_incoming_audio:
                self.incoming_audio_tracks.append(track)
        self._on_after_media()

    def _on_incoming_datachannel(self, channel: aiortc.RTCDataChannel):
        """Adopt the first remote channel labelled "data", then re-check media readiness."""
        self._log_debug("got data channel: %s", channel.label)
        if channel.label == "data" and self.messaging_channel is None:
            self._add_messaging_channel(channel)
        self._on_after_media()

    def _on_after_media(self):
        """Set ``incoming_media_ready_event`` once all expected incoming media arrived."""
        if self._number_of_incoming_media == self.expected_number_of_incoming_media:
            self.incoming_media_ready_event.set()

    def _parse_incoming_streams(self, remote_sdp: str):
        """Derive the expected number of incoming media from the remote SDP.

        Counts media sections the remote side sends on (sendonly/sendrecv);
        "application" sections are added when this stream does not create the
        data channel itself.
        """
        desc = aiortc.sdp.SessionDescription.parse(remote_sdp)
        expected = sum(1 for m in desc.media if m.direction in ("sendonly", "sendrecv"))
        if not self.should_add_data_channel:
            expected += sum(1 for m in desc.media if m.kind == "application")
        self.expected_number_of_incoming_media = expected

    def has_incoming_video_track(self, camera_type: str) -> bool:
        """Return whether a video track for ``camera_type`` has been received."""
        return camera_type in self.incoming_camera_tracks

    def has_incoming_audio_track(self) -> bool:
        """Return whether at least one audio track has been received."""
        return len(self.incoming_audio_tracks) > 0

    def has_messaging_channel(self) -> bool:
        """Return whether a messaging channel is set on this stream."""
        return self.messaging_channel is not None

    def get_incoming_video_track(self, camera_type: str, buffered: bool = False) -> aiortc.MediaStreamTrack:
        """Return a relay subscription to the incoming video track of ``camera_type``.

        Raises:
            AssertionError: If no such track was received or the stream is not started.
        """
        assert camera_type in self.incoming_camera_tracks, "Video tracks are not enabled on this stream"
        assert self.is_started, "Stream must be started"

        track = self.incoming_camera_tracks[camera_type]
        return self.media_relay.subscribe(track, buffered=buffered)

    def get_incoming_audio_track(self, buffered: bool = False) -> aiortc.MediaStreamTrack:
        """Return a relay subscription to the first incoming audio track.

        Raises:
            AssertionError: If no audio track was received or the stream is not started.
        """
        assert len(self.incoming_audio_tracks) > 0, "Audio tracks are not enabled on this stream"
        assert self.is_started, "Stream must be started"

        track = self.incoming_audio_tracks[0]
        return self.media_relay.subscribe(track, buffered=buffered)

    def get_messaging_channel(self) -> aiortc.RTCDataChannel:
        """Return the messaging data channel.

        Raises:
            AssertionError: If there is no messaging channel or the stream is not started.
        """
        assert self.messaging_channel is not None, "Messaging channel is not enabled on this stream"
        assert self.is_started, "Stream must be started"

        return self.messaging_channel

    def set_message_handler(self, message_handler: MessageHandler):
        """Register ``message_handler`` for incoming data-channel messages.

        The handler is remembered for channels attached later and is attached
        immediately when a channel already exists. Earlier handlers are kept.
        """
        self.incoming_message_handlers.append(message_handler)
        if self.messaging_channel is not None:
            self.messaging_channel.on("message", message_handler)

    @property
    def is_started(self) -> bool:
        """Whether both descriptions are set and the connection is not closed."""
        return self.peer_connection is not None and \
            self.peer_connection.localDescription is not None and \
            self.peer_connection.remoteDescription is not None and \
            self.peer_connection.connectionState != "closed"

    @property
    def is_connected_and_ready(self) -> bool:
        """Whether the connection is "connected" and all expected incoming media arrived."""
        return self.peer_connection is not None and \
            self.peer_connection.connectionState == "connected" and \
            self.expected_number_of_incoming_media != 0 and self.incoming_media_ready_event.is_set()

    async def wait_for_connection(self):
        """Wait until the connection is established and its media/channel are ready.

        Raises:
            AssertionError: If the stream has not been started.
            ValueError: If the connection attempt did not end in "connected".
        """
        assert self.is_started
        await self.connection_attempted_event.wait()
        if self.peer_connection.connectionState != 'connected':
            raise ValueError("Connection failed.")
        # Only wait for media when the remote SDP announced some.
        if self.expected_number_of_incoming_media:
            await self.incoming_media_ready_event.wait()
        if self.messaging_channel is not None:
            await self.messaging_channel_ready_event.wait()

    async def wait_for_disconnection(self):
        """Wait until the connection state becomes disconnected, closed or failed.

        Raises:
            AssertionError: If the stream is not connected and ready.
        """
        assert self.is_connected_and_ready
        await self.connection_stopped_event.wait()

    async def stop(self):
        """Close the underlying peer connection."""
        await self.peer_connection.close()

    @abc.abstractmethod
    async def start(self) -> aiortc.RTCSessionDescription:
        """Run the offer/answer exchange and return a session description."""
        raise NotImplementedError


class WebRTCOfferStream(WebRTCBaseStream):
    """Stream that creates the offer and obtains the answer via a ``ConnectionProvider``."""

    def __init__(self, session_provider: ConnectionProvider, *args, **kwargs):
        """Store ``session_provider``; remaining arguments go to ``WebRTCBaseStream``."""
        super().__init__(*args, **kwargs)
        self.session_provider = session_provider

    async def start(self) -> aiortc.RTCSessionDescription:
        """Create and send the offer, apply the remote answer, and return it.

        Returns:
            The peer connection's remote description after the answer is applied.
        """
        self._add_consumer_transceivers()
        if self.should_add_data_channel:
            self._add_messaging_channel()
        self._add_producer_tracks()

        offer = await self.peer_connection.createOffer()
        await self.peer_connection.setLocalDescription(offer)
        actual_offer = self.peer_connection.localDescription

        streaming_offer = StreamingOffer(
            sdp=actual_offer.sdp,
            video=list(self.expected_incoming_camera_types),
        )
        remote_answer = await self.session_provider(streaming_offer)
        # The expected media count must be known before tracks start arriving.
        self._parse_incoming_streams(remote_sdp=remote_answer.sdp)
        await self.peer_connection.setRemoteDescription(remote_answer)

        return self.peer_connection.remoteDescription


class WebRTCAnswerStream(WebRTCBaseStream):
    """Stream that answers a remote offer supplied at construction time."""

    def __init__(self, session: aiortc.RTCSessionDescription, *args, **kwargs):
        """Store the remote ``session``; remaining arguments go to ``WebRTCBaseStream``."""
        super().__init__(*args, **kwargs)
        self.session = session

    def _probe_video_codecs(self) -> List[str]:
        """Collect the non-None ``codec_preference()`` values of outgoing video tracks."""
        codecs = []
        for track in self.outgoing_video_tracks:
            if not hasattr(track, "codec_preference"):
                continue
            codec = track.codec_preference()
            if codec is not None:
                codecs.append(codec)
        return codecs

    def _override_incoming_video_codecs(self, remote_sdp: str, codecs: List[str]) -> str:
        """Return ``remote_sdp`` with every video section limited to ``codecs``.

        Codec names are matched case-insensitively against the MIME types in
        the SDP, in line with ``_force_codec`` which normalizes the name's case.

        Raises:
            ValueError: If a video section offers none of the requested codecs.
        """
        desc = aiortc.sdp.SessionDescription.parse(remote_sdp)
        wanted_mimes = {f"video/{c}".lower() for c in codecs}
        for m in desc.media:
            if m.kind != "video":
                continue

            preferred_codecs: List[aiortc.RTCRtpCodecParameters] = [
                c for c in m.rtp.codecs if c.mimeType.lower() in wanted_mimes
            ]
            if len(preferred_codecs) == 0:
                # Report the requested codecs; the filtered list is empty here.
                raise ValueError(f"None of {codecs} codecs is supported in remote SDP")

            m.rtp.codecs = preferred_codecs
            m.fmt = [c.payloadType for c in preferred_codecs]

        return str(desc)

    async def start(self) -> aiortc.RTCSessionDescription:
        """Apply the remote offer, create the answer, and return it.

        Returns:
            The peer connection's local description after the answer is set.

        Raises:
            AssertionError: If a remote description is already set.
            ValueError: If a preferred video codec is missing from the remote offer.
        """
        assert self.peer_connection.remoteDescription is None, "Connection already established"

        self._add_consumer_transceivers()

        # Since we send already-encoded frames in some cases (e.g. livestream
        # video tracks are in H264), we need to force aiortc to actually use
        # that codec. We do that by overriding the supported codec information
        # in the incoming SDP.
        preferred_codecs = self._probe_video_codecs()
        if len(preferred_codecs) > 0:
            self.session.sdp = self._override_incoming_video_codecs(self.session.sdp, preferred_codecs)

        self._parse_incoming_streams(remote_sdp=self.session.sdp)
        await self.peer_connection.setRemoteDescription(self.session)

        self._add_producer_tracks()

        answer = await self.peer_connection.createAnswer()
        await self.peer_connection.setLocalDescription(answer)

        return self.peer_connection.localDescription