Files
onepilot/sunnypilot/models/manager.py
github-actions[bot] 7fa972be6a sunnypilot v2026.02.09-4080
version: sunnypilot v2025.003.000 (dev)
date: 2026-02-09T02:04:38
master commit: 254f55ac15a40343d7255f2f098de3442e0c4a6f
2026-02-09 02:04:38 +00:00

227 lines
8.7 KiB
Python

"""
Copyright (c) 2021-, Haibin Wen, sunnypilot, and a number of other contributors.
This file is part of sunnypilot and is licensed under the MIT License.
See the LICENSE.md file in the root directory for more details.
"""
import asyncio
import os
import time
import aiohttp
from openpilot.common.params import Params
from openpilot.common.realtime import Ratekeeper
from openpilot.common.swaglog import cloudlog
from openpilot.system.hardware.hw import Paths
from cereal import messaging, custom
from openpilot.sunnypilot.models.fetcher import ModelFetcher
from openpilot.sunnypilot.models.helpers import verify_file, get_active_bundle
class ModelManagerSP:
  """Manages model bundle downloads and publishes their status on the ``modelManagerSP`` topic.

  ``main_thread`` loops forever: it refreshes the available bundles, starts a download when the
  ``ModelManager_DownloadIndex`` param is set, clears the model cache when ``ModelManager_ClearCache``
  is set, and publishes the current status every iteration.
  """

  def __init__(self):
    self.params = Params()
    self.model_fetcher = ModelFetcher(self.params)
    self.pm = messaging.PubMaster(["modelManagerSP"])

    self.available_models: list[custom.ModelManagerSP.ModelBundle] = []
    # Bundle currently being downloaded; None when no download is in progress.
    self.selected_bundle: custom.ModelManagerSP.ModelBundle | None = None
    # Bundle recorded as active in params; may be None (presumably when the default model is active).
    self.active_bundle: custom.ModelManagerSP.ModelBundle | None = get_active_bundle(self.params)

    self._chunk_size = 128 * 1000  # 128 kB (decimal) read size for streamed downloads
    self._download_start_times: dict[str, float] = {}  # monotonic start time per file name, used for ETA

  def _calculate_eta(self, filename: str, progress: float) -> int:
    """Estimate the remaining download time for ``filename`` in seconds.

    Extrapolates linearly from the time elapsed since the download started and the current
    ``progress`` percentage.

    :param filename: key into the per-file start times recorded by ``_download_file``.
    :param progress: completion percentage (0-100).
    :return: 60 when no estimate is possible yet (unknown start, no progress, or no elapsed
             time); otherwise the estimated remaining seconds, never less than 1.
    """
    start_time = self._download_start_times.get(filename)
    if start_time is None or progress <= 0:
      return 60  # Default ETA for new downloads

    elapsed_time = time.monotonic() - start_time
    if elapsed_time <= 0:
      return 60

    # If we're at X% after Y seconds, we can estimate total time as (Y / X) * 100
    total_estimated_time = (elapsed_time / progress) * 100
    eta = total_estimated_time - elapsed_time
    return max(1, int(eta))  # Return at least 1 second while the download is ongoing

  async def _download_file(self, url: str, path: str, model) -> None:
    """Stream ``url`` into ``path``, publishing progress and ETA through ``model.downloadProgress``.

    Progress is only reported when the response carries a non-zero Content-Length header.

    :raises aiohttp.ClientResponseError: if the server answers with an error status.
    :raises Exception: "Download cancelled" when ``ModelManager_DownloadIndex`` is cleared mid-download.
    """
    self._download_start_times[model.fileName] = time.monotonic()
    try:
      async with aiohttp.ClientSession() as session:
        async with session.get(url) as response:
          response.raise_for_status()
          total_size = int(response.headers.get("content-length", 0))
          bytes_downloaded = 0

          with open(path, 'wb') as f:
            async for chunk in response.content.iter_chunked(self._chunk_size):  # type: bytes
              f.write(chunk)
              bytes_downloaded += len(chunk)

              # Clearing the download index param is the cancellation signal
              if not self.params.get("ModelManager_DownloadIndex"):
                raise Exception("Download cancelled")

              if total_size > 0:
                # Content-Length may count encoded bytes while aiohttp yields decoded bytes,
                # so clamp to keep the reported percentage within 0-100.
                progress = min((bytes_downloaded / total_size) * 100, 100.0)
                model.downloadProgress.status = custom.ModelManagerSP.DownloadStatus.downloading
                model.downloadProgress.progress = progress
                model.downloadProgress.eta = self._calculate_eta(model.fileName, progress)
                self._report_status()
    finally:
      # Drop the start time whether the download completed, failed or was cancelled
      self._download_start_times.pop(model.fileName, None)

  async def _process_artifact(self, artifact, destination_path: str) -> None:
    """Ensure one artifact exists in ``destination_path`` with the expected SHA-256.

    An existing file that already verifies is reused (status ``cached``); otherwise the file is
    downloaded and verified (status ``downloaded``). Artifacts without a download URI are skipped.

    On any failure the file is removed, the artifact (and the selected bundle, if any) is marked
    ``failed``, status is published, and the exception is re-raised.

    :raises ValueError: if the downloaded file fails hash validation.
    """
    if not artifact.downloadUri.uri:
      return None

    url = artifact.downloadUri.uri
    expected_hash = artifact.downloadUri.sha256
    filename = artifact.fileName
    full_path = os.path.join(destination_path, filename)

    try:
      # Reuse an existing file if its hash already matches
      if os.path.exists(full_path) and await verify_file(full_path, expected_hash):
        artifact.downloadProgress.status = custom.ModelManagerSP.DownloadStatus.cached
        artifact.downloadProgress.progress = 100
        artifact.downloadProgress.eta = 0
        self._report_status()
        return

      await self._download_file(url, full_path, artifact)
      if not await verify_file(full_path, expected_hash):
        raise ValueError(f"Hash validation failed for {filename}")

      artifact.downloadProgress.status = custom.ModelManagerSP.DownloadStatus.downloaded
      # Set explicitly: without a Content-Length header no progress was reported while downloading
      artifact.downloadProgress.progress = 100
      artifact.downloadProgress.eta = 0
      self._report_status()
    except Exception as e:
      cloudlog.error(f"Error downloading {filename}: {str(e)}")
      if os.path.exists(full_path):
        os.remove(full_path)
      artifact.downloadProgress.status = custom.ModelManagerSP.DownloadStatus.failed
      artifact.downloadProgress.eta = 0
      # Guarded so a missing selection cannot mask the original error with an AttributeError
      if self.selected_bundle is not None:
        self.selected_bundle.status = custom.ModelManagerSP.DownloadStatus.failed
      self._report_status()
      raise

  async def _process_model(self, model, destination_path: str) -> None:
    """Fetch and verify a model's metadata artifact, then its main artifact, one after the other."""
    model_artifact = model.artifact
    metadata_artifact = model.metadata
    await self._process_artifact(metadata_artifact, destination_path)
    await self._process_artifact(model_artifact, destination_path)

  def _report_status(self) -> None:
    """Publish the selected, active and available bundles on ``modelManagerSP``."""
    msg = messaging.new_message('modelManagerSP', valid=True)
    model_manager_state = msg.modelManagerSP
    if self.selected_bundle:
      model_manager_state.selectedBundle = self.selected_bundle
    if self.active_bundle:
      model_manager_state.activeBundle = self.active_bundle
    model_manager_state.availableBundles = self.available_models
    self.pm.send('modelManagerSP', msg)

  async def _download_bundle(self, model_bundle: custom.ModelManagerSP.ModelBundle, destination_path: str) -> None:
    """Download all models of ``model_bundle`` concurrently and make it the active bundle.

    On success the bundle is stored in the ``ModelManager_ActiveBundle`` param and the selection
    is cleared. On failure the selected bundle is marked ``failed`` and the error is re-raised.
    Status is published in both cases.
    """
    self.selected_bundle = model_bundle
    self.selected_bundle.status = custom.ModelManagerSP.DownloadStatus.downloading
    os.makedirs(destination_path, exist_ok=True)

    try:
      tasks = [self._process_model(model, destination_path) for model in self.selected_bundle.models]
      await asyncio.gather(*tasks)

      self.active_bundle = self.selected_bundle
      self.active_bundle.status = custom.ModelManagerSP.DownloadStatus.downloaded
      self.params.put("ModelManager_ActiveBundle", self.active_bundle.to_dict())
      self.selected_bundle = None
    except Exception:
      self.selected_bundle.status = custom.ModelManagerSP.DownloadStatus.failed
      raise
    finally:
      self._report_status()

  def download(self, model_bundle: custom.ModelManagerSP.ModelBundle, destination_path: str) -> None:
    """Synchronously download ``model_bundle`` into ``destination_path`` (runs its own event loop)."""
    asyncio.run(self._download_bundle(model_bundle, destination_path))

  def main_thread(self) -> None:
    """Run the model manager loop forever, paced by a 1 Hz Ratekeeper.

    Each iteration refreshes the available and active bundles, services a pending download
    request (``ModelManager_DownloadIndex``) and cache-clear request (``ModelManager_ClearCache``),
    then publishes status. Errors are logged and the loop keeps running.
    """
    rk = Ratekeeper(1, print_delay_threshold=None)

    while True:
      try:
        self.available_models = self.model_fetcher.get_available_bundles()
        self.active_bundle = get_active_bundle(self.params)

        if index_to_download := self.params.get("ModelManager_DownloadIndex"):
          # NOTE(review): when no available bundle matches the index, the param stays set and the
          # lookup is retried next iteration — presumably to wait for the bundle list; confirm.
          if model_to_download := next((model for model in self.available_models if model.index == index_to_download), None):
            try:
              self.download(model_to_download, Paths.model_root())
            except Exception as e:
              cloudlog.exception(e)
            finally:
              # The request is consumed whether the download succeeded or failed
              self.params.remove("ModelManager_DownloadIndex")
              self.selected_bundle = None

        if self.params.get("ModelManager_ClearCache"):
          self.clear_model_cache()
          self.params.remove("ModelManager_ClearCache")

        self._report_status()
        rk.keep_time()
      except Exception as e:
        cloudlog.exception(f"Error in main thread: {str(e)}")
        rk.keep_time()

  def clear_model_cache(self) -> None:
    """Delete every regular file in the model directory except those used by the active bundle.

    Errors while listing or deleting are logged, not raised.
    """
    # Files referenced by the active bundle. Stays empty when no bundle is active
    # (presumably the default model), in which case every cached file is removed.
    active_files: set[str] = set()
    if self.active_bundle is not None:
      for model in self.active_bundle.models:
        if hasattr(model, 'artifact') and model.artifact.fileName:
          active_files.add(model.artifact.fileName)
        if hasattr(model, 'metadata') and model.metadata.fileName:
          active_files.add(model.metadata.fileName)

    model_dir = Paths.model_root()
    try:
      for filename in os.listdir(model_dir):
        if filename in active_files:
          continue
        file_path = os.path.join(model_dir, filename)
        if os.path.isfile(file_path):
          os.remove(file_path)
      cloudlog.info("Model cache cleared, keeping active model files")
    except Exception as e:
      cloudlog.exception(f"Error clearing model cache: {str(e)}")
def main():
  """Process entry point: build a ModelManagerSP and hand control to its main loop."""
  manager = ModelManagerSP()
  manager.main_thread()
# Run the model manager when this module is executed directly as a script.
if __name__ == "__main__":
  main()