openpilot1/selfdrive/frogpilot/assets/model_manager.py

269 lines
12 KiB
Python

import json
import os
import re
import requests
import shutil
import time
import urllib.parse
import urllib.request
from openpilot.common.basedir import BASEDIR
from openpilot.common.params import Params
from openpilot.selfdrive.frogpilot.assets.download_functions import GITHUB_URL, GITLAB_URL, download_file, get_repository_url, handle_error, handle_request_error, verify_download
from openpilot.selfdrive.frogpilot.frogpilot_functions import MODELS_PATH
from openpilot.selfdrive.frogpilot.frogpilot_utilities import delete_file
# Version tag of the model list; update_models and download_all_models fetch
# "Versions/model_names_<VERSION>.json" from the active repository.
VERSION = "v10"
# ID of the default model: copy_default_model stores the bundled classic model
# under "<DEFAULT_MODEL>.thneed", and validate_models falls back to this ID
# when the selected model is no longer available.
DEFAULT_MODEL = "north-dakota"
# Display name written to "ModelName" together with DEFAULT_MODEL on that fallback.
DEFAULT_MODEL_NAME = "North Dakota (Default)"
def process_model_name(model_name):
    """Reduce a model display name to a compact identifier.

    The marker emojis are removed first, then every character other than
    ASCII letters, digits, parentheses, and hyphens. Finally any remaining
    space, the "(Default)" tag, and hyphens are stripped.

    Args:
        model_name: Display name of the model, e.g. "North Dakota (Default)".

    Returns:
        The cleaned name, e.g. "NorthDakota".
    """
    without_emojis = re.sub(r'[🗺️👀📡]', '', model_name)
    identifier = re.sub(r'[^a-zA-Z0-9()-]', '', without_emojis)

    for unwanted in (' ', '(Default)', '-'):
        identifier = identifier.replace(unwanted, '')

    return identifier
class ModelManager:
    """Keep the ".thneed" driving models in MODELS_PATH in sync with the resources repository.

    ``params`` is the default Params store and holds the model lists and the
    selected model. ``params_memory`` (rooted at "/dev/shm/params") carries the
    pending download request, the progress text, and the cancel flag.
    """

    # Seconds to wait on a single repository API request before giving up.
    _REQUEST_TIMEOUT = 10
    # File extension of the model files managed here.
    _MODEL_EXTENSION = ".thneed"

    def __init__(self):
        self.params = Params()
        self.params_memory = Params("/dev/shm/params")

        # Keys used inside params_memory.
        self.cancel_download_param = "CancelModelDownload"
        self.download_param = "ModelToDownload"
        self.download_progress_param = "ModelDownloadProgress"

    @staticmethod
    def fetch_models(url):
        """Return the "models" list from the JSON document at *url*.

        Any failure (network, decoding, missing key) is handed to
        handle_request_error and an empty list is returned instead.
        """
        try:
            with urllib.request.urlopen(url, timeout=10) as response:
                return json.loads(response.read().decode('utf-8'))['models']
        except Exception as error:
            handle_request_error(error, None, None, None, None)
            return []

    @staticmethod
    def fetch_all_model_sizes(repo_url):
        """Map every model ID in the repository to the size of its model file in bytes.

        Args:
            repo_url: URL of the active repository; must contain "github" or "gitlab".

        Returns:
            A ``{model_id: size}`` dict. It is empty when the listing cannot be
            fetched, and models whose size cannot be determined are left out.

        Raises:
            ValueError: If *repo_url* is neither a GitHub nor a GitLab URL.
        """
        project_path = "FrogAi/FrogPilot-Resources"
        branch = "Models"
        extension = ModelManager._MODEL_EXTENSION
        timeout = ModelManager._REQUEST_TIMEOUT

        use_gitlab = False
        if "github" in repo_url:
            api_url = f"https://api.github.com/repos/{project_path}/contents?ref={branch}"
        elif "gitlab" in repo_url:
            use_gitlab = True
            gitlab_project = f"https://gitlab.com/api/v4/projects/{urllib.parse.quote_plus(project_path)}"
            # GitLab paginates listings (20 entries by default), so request the
            # largest page to avoid silently missing models.
            api_url = f"{gitlab_project}/repository/tree?ref={branch}&per_page=100"
        else:
            raise ValueError(f"Unsupported repository URL format: {repo_url}. Supported formats are GitHub and GitLab URLs.")

        try:
            response = requests.get(api_url, timeout=timeout)
            response.raise_for_status()
            thneed_files = [file for file in response.json() if file['name'].endswith(extension)]

            if not use_gitlab:
                # GitHub reports the size directly in the directory listing.
                return {file['name'][:-len(extension)]: file['size'] for file in thneed_files if 'size' in file}

            # GitLab's tree listing carries no sizes, so ask for each file's headers.
            model_sizes = {}
            for file in thneed_files:
                metadata_url = f"{gitlab_project}/repository/files/{urllib.parse.quote_plus(file['path'])}/raw?ref={branch}"
                metadata_response = requests.head(metadata_url, timeout=timeout)
                metadata_response.raise_for_status()

                content_length = metadata_response.headers.get('content-length')
                if content_length is None:
                    # An unknown size must not be recorded as 0: that would make
                    # every local copy look outdated and get it deleted.
                    continue
                model_sizes[file['name'][:-len(extension)]] = int(content_length)
            return model_sizes
        except Exception as error:
            # Best effort: callers treat a missing size as "unknown".
            print(f"Failed to fetch the model sizes: {error}")
            return {}

    @staticmethod
    def copy_default_model():
        """Copy the models bundled with the install into MODELS_PATH.

        The classic model is stored as "<DEFAULT_MODEL>.thneed" and the current
        model as "secret-good-openpilot.thneed". Missing sources are skipped.
        """
        # copyfile does not create the destination directory itself.
        os.makedirs(MODELS_PATH, exist_ok=True)

        bundled_models = (
            ("classic_modeld", f"{DEFAULT_MODEL}.thneed", "classic default model"),
            ("modeld", "secret-good-openpilot.thneed", "default model"),
        )
        for module_dir, target_name, label in bundled_models:
            source_path = os.path.join(BASEDIR, "selfdrive", module_dir, "models", "supercombo.thneed")
            target_path = os.path.join(MODELS_PATH, target_name)
            if os.path.isfile(source_path):
                shutil.copyfile(source_path, target_path)
                print(f"Copied the {label} from {source_path} to {target_path}")

    def _mark_download_complete(self):
        """Publish the success message and clear the pending download request."""
        self.params_memory.put(self.download_progress_param, "Downloaded!")
        self.params_memory.remove(self.download_param)

    def _wait_unless_cancelled(self, is_done):
        """Poll *is_done* once per second; return False if the cancel flag is set first."""
        while not is_done():
            if self.params_memory.get_bool(self.cancel_download_param):
                return False
            time.sleep(1)
        return True

    def handle_verification_failure(self, model, model_path, temp_model_path):
        """Retry a download from GitLab after its verification failed.

        Does nothing when the download was cancelled. A failed retry is
        reported through handle_error.
        """
        if self.params_memory.get_bool(self.cancel_download_param):
            return

        print(f"Verification failed for model {model}. Retrying from GitLab...")
        model_url = f"{GITLAB_URL}Models/{model}.thneed"
        download_file(self.cancel_download_param, model_path, temp_model_path, self.download_progress_param, model_url, self.download_param, self.params_memory)

        if verify_download(model_path, temp_model_path, model_url):
            print(f"Model {model} redownloaded and verified successfully from GitLab.")
            # Clear the request exactly like the regular success path does;
            # otherwise every later queue_model_download call waits forever.
            self._mark_download_complete()
        else:
            handle_error(model_path, "GitLab verification failed", "Verification failed", self.download_param, self.download_progress_param, self.params_memory)

    def download_model(self, model_to_download):
        """Download and verify a single model, retrying from GitLab if verification fails.

        Args:
            model_to_download: ID of the model; the file is "<id>.thneed".
        """
        model_path = os.path.join(MODELS_PATH, f"{model_to_download}.thneed")
        temp_model_path = f"{os.path.splitext(model_path)[0]}_temp.thneed"

        if os.path.isfile(model_path):
            handle_error(model_path, "Model already exists...", "Model already exists...", self.download_param, self.download_progress_param, self.params_memory)
            return

        repo_url = get_repository_url()
        if not repo_url:
            handle_error(temp_model_path, "GitHub and GitLab are offline...", "Repository unavailable", self.download_param, self.download_progress_param, self.params_memory)
            return

        model_url = f"{repo_url}Models/{model_to_download}.thneed"
        print(f"Downloading model: {model_to_download}")
        download_file(self.cancel_download_param, model_path, temp_model_path, self.download_progress_param, model_url, self.download_param, self.params_memory)

        if verify_download(model_path, temp_model_path, model_url):
            print(f"Model {model_to_download} downloaded and verified successfully!")
            self._mark_download_complete()
        else:
            self.handle_verification_failure(model_to_download, model_path, temp_model_path)

    def queue_model_download(self, model, model_name=None):
        """Request the download of *model* once no other download is pending.

        Blocks, polling once per second, while another request is still set.

        Args:
            model: ID of the model to request.
            model_name: Optional display name used for the progress text.
        """
        while self.params_memory.get(self.download_param, encoding='utf-8'):
            time.sleep(1)

        self.params_memory.put(self.download_param, model)
        if model_name:
            self.params_memory.put(self.download_progress_param, f"Downloading {model_name}...")

    def update_model_params(self, model_info, repo_url):
        """Store the model lists derived from *model_info* and refresh "ModelsDownloaded".

        Args:
            model_info: List of dicts with at least "id" and "name"; the
                optional "classic_model" and "experimental" keys flag a model.
            repo_url: Repository URL used to look up the expected file sizes.
        """
        def join_ids(models):
            return ','.join(model['id'] for model in models)

        available_models = [model['id'] for model in model_info]

        self.params.put_nonblocking("AvailableModels", ','.join(available_models))
        self.params.put_nonblocking("AvailableModelsNames", ','.join(model['name'] for model in model_info))
        self.params.put_nonblocking("ClassicModels", join_ids(model for model in model_info if model.get("classic_model", False)))
        self.params.put_nonblocking("ExperimentalModels", join_ids(model for model in model_info if model.get("experimental", False)))
        self.params.put_nonblocking("NavigationModels", join_ids(model for model in model_info if "🗺️" in model['name']))
        # Models whose name lacks the 📡 marker are listed as radarless.
        self.params.put_nonblocking("RadarlessModels", join_ids(model for model in model_info if "📡" not in model['name']))
        print("Models list updated successfully.")

        if available_models:
            models_downloaded = self.are_all_models_downloaded(available_models, repo_url)
            self.params.put_bool_nonblocking("ModelsDownloaded", models_downloaded)

    def are_all_models_downloaded(self, available_models, repo_url):
        """Report whether every model with a known size is present locally.

        When "AutomaticallyUpdateModels" is set, missing models are queued for
        download, and models whose local size differs from the repository's
        are deleted and queued again. Models without size data are skipped.

        Args:
            available_models: Iterable of model IDs to check.
            repo_url: Repository URL passed to fetch_all_model_sizes.

        Returns:
            False if any checked model is missing or was found outdated.
        """
        automatically_update_models = self.params.get_bool("AutomaticallyUpdateModels")
        all_models_downloaded = True
        model_sizes = self.fetch_all_model_sizes(repo_url)
        download_queue = []

        for model in available_models:
            model_path = os.path.join(MODELS_PATH, f"{model}.thneed")
            expected_size = model_sizes.get(model)

            if expected_size is None:
                print(f"Size data for {model} not available.")
                continue

            if os.path.isfile(model_path):
                local_size = os.path.getsize(model_path)
                if automatically_update_models and local_size != expected_size:
                    print(f"Model {model} is outdated. Re-downloading...")
                    delete_file(model_path)
                    download_queue.append(model)
                    all_models_downloaded = False
            else:
                if automatically_update_models:
                    print(f"Model {model} isn't downloaded. Downloading...")
                    download_queue.append(model)
                all_models_downloaded = False

        # Queue only after the scan: each call may block on a running download.
        for model in download_queue:
            self.queue_model_download(model)

        return all_models_downloaded

    def validate_models(self):
        """Reconcile the selected model and the local files with the available list.

        Drops a stale " (Default)" tag from the stored model name, downloads
        the selected model if its file is missing, and deletes every file in
        MODELS_PATH that does not match an available model. If the selected
        model itself is deleted, the selection falls back to the default.
        """
        current_model = self.params.get("Model", encoding='utf-8')
        current_model_name = self.params.get("ModelName", encoding='utf-8')

        # Either value may be unset, in which case there is nothing to fix up.
        if current_model_name and "(Default)" in current_model_name and current_model_name != DEFAULT_MODEL_NAME:
            self.params.put_nonblocking("ModelName", current_model_name.replace(" (Default)", ""))

        available_models = self.params.get("AvailableModels", encoding='utf-8')
        if not available_models:
            return
        available_ids = set(available_models.split(','))

        if current_model:
            current_model_path = os.path.join(MODELS_PATH, f"{current_model}.thneed")
            if not os.path.isfile(current_model_path):
                print(f"Model {current_model} is not downloaded. Downloading...")
                self.download_model(current_model)

        for model_file in os.listdir(MODELS_PATH):
            model_name = model_file.replace(".thneed", "")
            if model_name in available_ids:
                continue

            if model_name == current_model:
                self.params.put_nonblocking("Model", DEFAULT_MODEL)
                self.params.put_nonblocking("ModelName", DEFAULT_MODEL_NAME)

            delete_file(os.path.join(MODELS_PATH, model_file))
            print(f"Deleted model file: {model_file} - Reason: Model is not in the list of available models")

    def update_models(self, boot_run=False):
        """Refresh the stored model lists from the repository.

        Args:
            boot_run: When True, also copy the bundled models first and
                validate the local models afterwards.
        """
        if boot_run:
            self.copy_default_model()

        repo_url = get_repository_url()
        if repo_url is None:
            print("GitHub and GitLab are offline...")
            return

        model_info = self.fetch_models(f"{repo_url}Versions/model_names_{VERSION}.json")
        if model_info:
            self.update_model_params(model_info, repo_url)

        if boot_run:
            self.validate_models()

    def download_all_models(self):
        """Queue every available model that is missing locally, one at a time.

        Returns early, without declaring success, when the repository or the
        model list is unavailable or when the download is cancelled.
        """
        repo_url = get_repository_url()
        if not repo_url:
            handle_error(None, "GitHub and GitLab are offline...", "Repository unavailable", self.download_param, self.download_progress_param, self.params_memory)
            return

        model_info = self.fetch_models(f"{repo_url}Versions/model_names_{VERSION}.json")
        if not model_info:
            handle_error(None, "Unable to update model list...", "Model list unavailable", self.download_param, self.download_progress_param, self.params_memory)
            return

        available_models = self.params.get("AvailableModels", encoding='utf-8')
        if not available_models:
            handle_error(None, "There's no model to download...", "There's no model to download...", self.download_param, self.download_progress_param, self.params_memory)
            return

        available_models = available_models.split(',')
        # The names list may be unset or shorter than the ID list.
        available_model_names = (self.params.get("AvailableModelsNames", encoding='utf-8') or "").split(',')

        def is_downloaded(model):
            return os.path.isfile(os.path.join(MODELS_PATH, f"{model}.thneed"))

        def no_download_pending():
            return not self.params_memory.get(self.download_param, encoding='utf-8')

        for model_index, model in enumerate(available_models):
            if self.params_memory.get_bool(self.cancel_download_param):
                return
            if is_downloaded(model):
                continue

            # Fall back to the ID when no display name lines up with it.
            model_name = available_model_names[model_index] if model_index < len(available_model_names) else model
            cleaned_model_name = re.sub(r'[🗺️👀📡]', '', model_name).strip()
            print(f"Downloading model: {cleaned_model_name}")

            self.queue_model_download(model, cleaned_model_name)
            if not self._wait_unless_cancelled(no_download_pending):
                return

        # A failed download never produces its file, so this wait has to stay
        # cancellable instead of spinning forever.
        if not self._wait_unless_cancelled(lambda: all(is_downloaded(model) for model in available_models)):
            return

        self.params_memory.put(self.download_progress_param, "All models downloaded!")
        self.params_memory.remove("DownloadAllModels")
        self.params.put_bool_nonblocking("ModelsDownloaded", True)