mirror of https://github.com/commaai/openpilot.git
alerts: handle min duration properly (#23191)
* alerts: handle min duration properly
* add active
* tests
* cleanup test
* update refs
old-commit-hash: 07b971d473
This commit is contained in:
parent
853dc0d016
commit
700ad9ec50
|
@ -614,7 +614,7 @@ class Controls:
|
|||
|
||||
clear_event = ET.WARNING if ET.WARNING not in self.current_alert_types else None
|
||||
alerts = self.events.create_alerts(self.current_alert_types, [self.CP, self.sm, self.is_metric])
|
||||
self.AM.add_many(self.sm.frame, alerts, self.enabled)
|
||||
self.AM.add_many(self.sm.frame, alerts)
|
||||
self.AM.process_alerts(self.sm.frame, clear_event)
|
||||
CC.hudControl.visualAlert = self.AM.visual_alert
|
||||
|
||||
|
|
|
@ -8,7 +8,6 @@ from typing import List, Dict, Optional
|
|||
from cereal import car, log
|
||||
from common.basedir import BASEDIR
|
||||
from common.params import Params
|
||||
from common.realtime import DT_CTRL
|
||||
from selfdrive.controls.lib.events import Alert
|
||||
|
||||
|
||||
|
@ -33,12 +32,14 @@ class AlertEntry:
|
|||
start_frame: int = -1
|
||||
end_frame: int = -1
|
||||
|
||||
def active(self, frame: int) -> bool:
|
||||
return frame <= self.end_frame
|
||||
|
||||
class AlertManager:
|
||||
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
self.activealerts: Dict[str, AlertEntry] = defaultdict(AlertEntry)
|
||||
self.alerts: Dict[str, AlertEntry] = defaultdict(AlertEntry)
|
||||
|
||||
def reset(self) -> None:
|
||||
self.alert: Optional[Alert] = None
|
||||
|
@ -51,25 +52,27 @@ class AlertManager:
|
|||
self.audible_alert = car.CarControl.HUDControl.AudibleAlert.none
|
||||
self.alert_rate: float = 0.
|
||||
|
||||
def add_many(self, frame: int, alerts: List[Alert], enabled: bool = True) -> None:
|
||||
def add_many(self, frame: int, alerts: List[Alert]) -> None:
|
||||
for alert in alerts:
|
||||
self.activealerts[alert.alert_type].alert = alert
|
||||
self.activealerts[alert.alert_type].start_frame = frame
|
||||
self.activealerts[alert.alert_type].end_frame = frame + int(alert.duration / DT_CTRL)
|
||||
key = alert.alert_type
|
||||
self.alerts[key].alert = alert
|
||||
if not self.alerts[key].active(frame):
|
||||
self.alerts[key].start_frame = frame
|
||||
min_end_frame = self.alerts[key].start_frame + alert.duration
|
||||
self.alerts[key].end_frame = max(frame + 1, min_end_frame)
|
||||
|
||||
def process_alerts(self, frame: int, clear_event_type=None) -> None:
|
||||
current_alert = AlertEntry()
|
||||
for k, v in self.activealerts.items():
|
||||
for k, v in self.alerts.items():
|
||||
if v.alert is None:
|
||||
continue
|
||||
|
||||
if v.alert.event_type == clear_event_type:
|
||||
self.activealerts[k].end_frame = -1
|
||||
if clear_event_type is not None and v.alert.event_type == clear_event_type:
|
||||
self.alerts[k].end_frame = -1
|
||||
|
||||
# sort by priority first and then by start_frame
|
||||
active = self.activealerts[k].end_frame > frame
|
||||
greater = current_alert.alert is None or (v.alert.priority, v.start_frame) > (current_alert.alert.priority, current_alert.start_frame)
|
||||
if active and greater:
|
||||
if v.active(frame) and greater:
|
||||
current_alert = v
|
||||
|
||||
# clear current alert
|
||||
|
|
|
@ -123,7 +123,7 @@ class Alert:
|
|||
self.visual_alert = visual_alert
|
||||
self.audible_alert = audible_alert
|
||||
|
||||
self.duration = duration
|
||||
self.duration = int(duration / DT_CTRL)
|
||||
|
||||
self.alert_rate = alert_rate
|
||||
self.creation_delay = creation_delay
|
||||
|
|
|
@ -0,0 +1,45 @@
|
|||
#!/usr/bin/env python3
|
||||
import random
|
||||
import unittest
|
||||
|
||||
from selfdrive.controls.lib.events import Alert, EVENTS
|
||||
from selfdrive.controls.lib.alertmanager import AlertManager
|
||||
|
||||
|
||||
class TestAlertManager(unittest.TestCase):
|
||||
|
||||
def test_duration(self):
|
||||
"""
|
||||
Enforce that an alert lasts for max(alert duration, duration the alert is added)
|
||||
"""
|
||||
for duration in range(1, 100):
|
||||
alert = None
|
||||
while not isinstance(alert, Alert):
|
||||
event = random.choice([e for e in EVENTS.values() if len(e)])
|
||||
alert = random.choice(list(event.values()))
|
||||
|
||||
alert.duration = duration
|
||||
|
||||
# check two cases:
|
||||
# - alert is added to AM for <= the alert's duration
|
||||
# - alert is added to AM for > alert's duration
|
||||
for greater in (True, False):
|
||||
if greater:
|
||||
add_duration = duration + random.randint(1, 10)
|
||||
else:
|
||||
add_duration = random.randint(1, duration)
|
||||
show_duration = max(duration, add_duration)
|
||||
|
||||
AM = AlertManager()
|
||||
for frame in range(duration+10):
|
||||
if frame < add_duration:
|
||||
AM.add_many(frame, [alert, ])
|
||||
AM.process_alerts(frame)
|
||||
|
||||
shown = AM.alert is not None
|
||||
should_show = frame <= show_duration
|
||||
self.assertEqual(shown, should_show, msg=f"{frame=} {add_duration=} {duration=}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -1 +1 @@
|
|||
c09fc7b1409529a9991428845ee14f0e37d95b2d
|
||||
0ae46ae318a63476d8905aa0c32b0e587177868a
|
Loading…
Reference in New Issue