mirror of https://github.com/1okko/openpilot.git
loggerd: speedup unit tests (#31115)
* first speed it up * pytestify * no sleep
This commit is contained in:
parent
3846130d8e
commit
694fc378dd
|
@ -5,7 +5,6 @@ import random
|
|||
import string
|
||||
import subprocess
|
||||
import time
|
||||
import unittest
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
@ -31,10 +30,7 @@ CEREAL_SERVICES = [f for f in log.Event.schema.union_fields if f in SERVICE_LIST
|
|||
and SERVICE_LIST[f].should_log and "encode" not in f.lower()]
|
||||
|
||||
|
||||
class TestLoggerd(unittest.TestCase):
|
||||
def setUp(self):
|
||||
os.environ.pop("LOG_ROOT", None)
|
||||
|
||||
class TestLoggerd:
|
||||
def _get_latest_log_dir(self):
|
||||
log_dirs = sorted(Path(Paths.log_root()).iterdir(), key=lambda f: f.stat().st_mtime)
|
||||
return log_dirs[-1]
|
||||
|
@ -68,21 +64,21 @@ class TestLoggerd(unittest.TestCase):
|
|||
|
||||
def _check_init_data(self, msgs):
|
||||
msg = msgs[0]
|
||||
self.assertEqual(msg.which(), 'initData')
|
||||
assert msg.which() == 'initData'
|
||||
|
||||
def _check_sentinel(self, msgs, route):
|
||||
start_type = SentinelType.startOfRoute if route else SentinelType.startOfSegment
|
||||
self.assertTrue(msgs[1].sentinel.type == start_type)
|
||||
assert msgs[1].sentinel.type == start_type
|
||||
|
||||
end_type = SentinelType.endOfRoute if route else SentinelType.endOfSegment
|
||||
self.assertTrue(msgs[-1].sentinel.type == end_type)
|
||||
assert msgs[-1].sentinel.type == end_type
|
||||
|
||||
def _publish_random_messages(self, services: List[str]) -> Dict[str, list]:
|
||||
pm = messaging.PubMaster(services)
|
||||
|
||||
managed_processes["loggerd"].start()
|
||||
for s in services:
|
||||
self.assertTrue(pm.wait_for_readers_to_update(s, timeout=5))
|
||||
assert pm.wait_for_readers_to_update(s, timeout=5)
|
||||
|
||||
sent_msgs = defaultdict(list)
|
||||
for _ in range(random.randint(2, 10) * 100):
|
||||
|
@ -93,10 +89,9 @@ class TestLoggerd(unittest.TestCase):
|
|||
m = messaging.new_message(s, random.randint(2, 10))
|
||||
pm.send(s, m)
|
||||
sent_msgs[s].append(m)
|
||||
time.sleep(0.01)
|
||||
|
||||
for s in services:
|
||||
self.assertTrue(pm.wait_for_readers_to_update(s, timeout=5))
|
||||
assert pm.wait_for_readers_to_update(s, timeout=5)
|
||||
managed_processes["loggerd"].stop()
|
||||
|
||||
return sent_msgs
|
||||
|
@ -121,15 +116,15 @@ class TestLoggerd(unittest.TestCase):
|
|||
lr = list(LogReader(str(self._gen_bootlog())))
|
||||
initData = lr[0].initData
|
||||
|
||||
self.assertTrue(initData.dirty != bool(os.environ["CLEAN"]))
|
||||
self.assertEqual(initData.version, get_version())
|
||||
assert initData.dirty != bool(os.environ["CLEAN"])
|
||||
assert initData.version == get_version()
|
||||
|
||||
if os.path.isfile("/proc/cmdline"):
|
||||
with open("/proc/cmdline") as f:
|
||||
self.assertEqual(list(initData.kernelArgs), f.read().strip().split(" "))
|
||||
assert list(initData.kernelArgs) == f.read().strip().split(" ")
|
||||
|
||||
with open("/proc/version") as f:
|
||||
self.assertEqual(initData.kernelVersion, f.read())
|
||||
assert initData.kernelVersion == f.read()
|
||||
|
||||
# check params
|
||||
logged_params = {entry.key: entry.value for entry in initData.params.entries}
|
||||
|
@ -137,8 +132,8 @@ class TestLoggerd(unittest.TestCase):
|
|||
assert set(logged_params.keys()) == expected_params, set(logged_params.keys()) ^ expected_params
|
||||
assert logged_params['AccessToken'] == b'', f"DONT_LOG param value was logged: {repr(logged_params['AccessToken'])}"
|
||||
for param_key, initData_key, v in fake_params:
|
||||
self.assertEqual(getattr(initData, initData_key), v)
|
||||
self.assertEqual(logged_params[param_key].decode(), v)
|
||||
assert getattr(initData, initData_key) == v
|
||||
assert logged_params[param_key].decode() == v
|
||||
|
||||
params.put("AccessToken", "")
|
||||
|
||||
|
@ -162,11 +157,10 @@ class TestLoggerd(unittest.TestCase):
|
|||
os.environ["LOGGERD_SEGMENT_LENGTH"] = str(length)
|
||||
managed_processes["loggerd"].start()
|
||||
managed_processes["encoderd"].start()
|
||||
time.sleep(1)
|
||||
assert pm.wait_for_readers_to_update("roadCameraState", timeout=5)
|
||||
|
||||
fps = 20.0
|
||||
for n in range(1, int(num_segs*length*fps)+1):
|
||||
time_start = time.monotonic()
|
||||
for stream_type, frame_spec, state in streams:
|
||||
dat = np.empty(frame_spec[2], dtype=np.uint8)
|
||||
vipc_server.send(stream_type, dat[:].flatten().tobytes(), n, n/fps, n/fps)
|
||||
|
@ -175,7 +169,9 @@ class TestLoggerd(unittest.TestCase):
|
|||
frame = getattr(camera_state, state)
|
||||
frame.frameId = n
|
||||
pm.send(state, camera_state)
|
||||
time.sleep(max((1.0/fps) - (time.monotonic() - time_start), 0))
|
||||
|
||||
for _, _, state in streams:
|
||||
assert pm.wait_for_readers_to_update(state, timeout=5, dt=0.001)
|
||||
|
||||
managed_processes["loggerd"].stop()
|
||||
managed_processes["encoderd"].stop()
|
||||
|
@ -185,7 +181,7 @@ class TestLoggerd(unittest.TestCase):
|
|||
p = Path(f"{route_path}--{n}")
|
||||
logged = {f.name for f in p.iterdir() if f.is_file()}
|
||||
diff = logged ^ expected_files
|
||||
self.assertEqual(len(diff), 0, f"didn't get all expected files. run={_} seg={n} {route_path=}, {diff=}\n{logged=} {expected_files=}")
|
||||
assert len(diff) == 0, f"didn't get all expected files. run={_} seg={n} {route_path=}, {diff=}\n{logged=} {expected_files=}"
|
||||
|
||||
def test_bootlog(self):
|
||||
# generate bootlog with fake launch log
|
||||
|
@ -216,7 +212,7 @@ class TestLoggerd(unittest.TestCase):
|
|||
with open(path, "rb") as f:
|
||||
expected_val = f.read()
|
||||
bootlog_val = [e.value for e in boot.pstore.entries if e.key == fn][0]
|
||||
self.assertEqual(expected_val, bootlog_val)
|
||||
assert expected_val == bootlog_val
|
||||
|
||||
def test_qlog(self):
|
||||
qlog_services = [s for s in CEREAL_SERVICES if SERVICE_LIST[s].decimation is not None]
|
||||
|
@ -242,11 +238,11 @@ class TestLoggerd(unittest.TestCase):
|
|||
|
||||
if s in no_qlog_services:
|
||||
# check services with no specific decimation aren't in qlog
|
||||
self.assertEqual(recv_cnt, 0, f"got {recv_cnt} {s} msgs in qlog")
|
||||
assert recv_cnt == 0, f"got {recv_cnt} {s} msgs in qlog"
|
||||
else:
|
||||
# check logged message count matches decimation
|
||||
expected_cnt = (len(msgs) - 1) // SERVICE_LIST[s].decimation + 1
|
||||
self.assertEqual(recv_cnt, expected_cnt, f"expected {expected_cnt} msgs for {s}, got {recv_cnt}")
|
||||
assert recv_cnt == expected_cnt, f"expected {expected_cnt} msgs for {s}, got {recv_cnt}"
|
||||
|
||||
def test_rlog(self):
|
||||
services = random.sample(CEREAL_SERVICES, random.randint(5, 10))
|
||||
|
@ -263,22 +259,19 @@ class TestLoggerd(unittest.TestCase):
|
|||
for m in lr:
|
||||
sent = sent_msgs[m.which()].pop(0)
|
||||
sent.clear_write_flag()
|
||||
self.assertEqual(sent.to_bytes(), m.as_builder().to_bytes())
|
||||
assert sent.to_bytes() == m.as_builder().to_bytes()
|
||||
|
||||
def test_preserving_flagged_segments(self):
|
||||
services = set(random.sample(CEREAL_SERVICES, random.randint(5, 10))) | {"userFlag"}
|
||||
self._publish_random_messages(services)
|
||||
|
||||
segment_dir = self._get_latest_log_dir()
|
||||
self.assertEqual(getxattr(segment_dir, PRESERVE_ATTR_NAME), PRESERVE_ATTR_VALUE)
|
||||
assert getxattr(segment_dir, PRESERVE_ATTR_NAME) == PRESERVE_ATTR_VALUE
|
||||
|
||||
def test_not_preserving_unflagged_segments(self):
|
||||
services = set(random.sample(CEREAL_SERVICES, random.randint(5, 10))) - {"userFlag"}
|
||||
self._publish_random_messages(services)
|
||||
|
||||
segment_dir = self._get_latest_log_dir()
|
||||
self.assertIsNone(getxattr(segment_dir, PRESERVE_ATTR_NAME))
|
||||
assert getxattr(segment_dir, PRESERVE_ATTR_NAME) is None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
Loading…
Reference in New Issue