import json
import os
import re
import requests
import shutil
import time
import urllib . request
from openpilot . common . basedir import BASEDIR
from openpilot . common . params import Params , UnknownKeyName
from openpilot . selfdrive . frogpilot . controls . lib . download_functions import GITHUB_URL , GITLAB_URL , download_file , get_repository_url , handle_error , handle_request_error , verify_download
from openpilot . selfdrive . frogpilot . controls . lib . frogpilot_functions import MODELS_PATH , delete_file

# Version suffix of the model catalogue fetched from the repository
# (Versions/model_names_<VERSION>.json).
VERSION = "v5"

# Fallback model: its id is the file stem under MODELS_PATH, and the display
# name is what "ModelName" is reset to when the selected model disappears.
DEFAULT_MODEL = "north-dakota-v2"
DEFAULT_MODEL_NAME = "North Dakota V2 (Default)"
def process_model_name(model_name):
  """Reduce a model's display name to the compact prefix used for its per-model param keys.

  Only ASCII letters, digits, parentheses and hyphens survive the first pass;
  then "(Default)" is dropped and finally every hyphen is removed. Emoji
  markers and spaces therefore vanish, e.g.
  "North Dakota V2 (Default)" -> "NorthDakotaV2".
  """
  # A single whitelist pass removes emoji markers and spaces along with any
  # other punctuation, so no separate emoji strip or space replace is needed.
  cleaned_name = re.sub(r'[^a-zA-Z0-9()-]', '', model_name)
  # Order matters: "(Default)" is removed before hyphens, as callers expect.
  return cleaned_name.replace('(Default)', '').replace('-', '')


class ModelManager:
  """Keep the local driving-model files in sync with the GitHub/GitLab model repository.

  Two Params stores are used: `params` holds the model catalogue and the
  selected model ("Model" / "ModelName"), while `params_memory` (rooted at
  /dev/shm/params) carries the pending download request, progress text and
  the cancel flag.
  """

  def __init__(self):
    self.params = Params()
    self.params_memory = Params("/dev/shm/params")

    # params_memory keys shared with whatever services download requests
    # (presumably a loop that calls download_model - not visible here).
    self.cancel_download_param = "CancelModelDownload"
    self.download_param = "ModelToDownload"
    self.download_progress_param = "ModelDownloadProgress"

    # Base URL of the reachable repository; refreshed by update_models,
    # download_model and download_all_models before it is used.
    self.repo_url = None

  @staticmethod
  def _model_path(model):
    """Return the on-disk path of `model`'s .thneed file."""
    return os.path.join(MODELS_PATH, f"{model}.thneed")

  def handle_verification_failure(self, model, model_path):
    """Re-download `model` from GitLab after a failed verification and report the result.

    Does nothing if the download was cancelled. On success the progress text is
    set to "Downloaded!" and the pending request is cleared, exactly like the
    first-try success path in download_model; on failure handle_error reports it.
    """
    if self.params_memory.get_bool(self.cancel_download_param):
      return

    print(f"Verification failed for model {model}. Retrying from GitLab...")
    model_url = f"{GITLAB_URL}Models/{model}.thneed"
    download_file(self.cancel_download_param, model_path, self.download_progress_param, model_url, self.download_param, self.params_memory)

    if verify_download(model_path, model_url):
      print(f"Model {model} redownloaded and verified successfully from GitLab.")
      # queue_model_download and download_all_models wait until download_param
      # is empty, so the request must be cleared here just as download_model does.
      self.params_memory.put(self.download_progress_param, "Downloaded!")
      self.params_memory.remove(self.download_param)
    else:
      handle_error(model_path, "GitLab verification failed", "Verification failed", self.download_param, self.download_progress_param, self.params_memory)

  def download_model(self, model_to_download):
    """Download `model_to_download` from the reachable repository and verify it.

    Reports through handle_error when the file already exists or when neither
    repository is reachable. A failed verification falls back to a GitLab retry.
    """
    model_path = self._model_path(model_to_download)
    if os.path.exists(model_path):
      # NOTE(review): handle_error is handed the path of an *existing* model here.
      # If handle_error deletes its `destination` argument (as it may for partial
      # downloads), this removes a valid model - confirm against download_functions.
      handle_error(model_path, "Model already exists...", "Model already exists...", self.download_param, self.download_progress_param, self.params_memory)
      return

    self.repo_url = get_repository_url()
    if not self.repo_url:
      handle_error(model_path, "GitHub and GitLab are offline...", "Repository unavailable", self.download_param, self.download_progress_param, self.params_memory)
      return

    model_url = f"{self.repo_url}Models/{model_to_download}.thneed"
    print(f"Downloading model: {model_to_download}")
    download_file(self.cancel_download_param, model_path, self.download_progress_param, model_url, self.download_param, self.params_memory)

    if verify_download(model_path, model_url):
      print(f"Model {model_to_download} downloaded and verified successfully!")
      self.params_memory.put(self.download_progress_param, "Downloaded!")
      self.params_memory.remove(self.download_param)
    else:
      self.handle_verification_failure(model_to_download, model_path)

  def fetch_models(self, url):
    """Return the "models" list of the JSON document at `url`, or [] on any error.

    Errors are passed to handle_request_error; the request times out after 10 s.
    """
    try:
      with urllib.request.urlopen(url, timeout=10) as response:
        return json.loads(response.read().decode('utf-8'))['models']
    except Exception as error:
      handle_request_error(error, None, None, None, None)
      return []

  def update_model_params(self, model_info):
    """Publish the model catalogue to params and refresh "ModelsDownloaded".

    Each entry of `model_info` is read as a mapping with "id" and "name" keys
    and an optional "experimental" flag. Emoji markers in the name classify the
    model: a name containing 🗺️ is a navigation model, and a name without 📡
    is recorded as radarless.
    """
    available_models, available_model_names = [], []
    experimental_models, navigation_models, radarless_models = [], [], []

    for model in model_info:
      available_models.append(model['id'])
      available_model_names.append(model['name'])

      if model.get("experimental", False):
        experimental_models.append(model['id'])

      if "🗺️" in model['name']:
        navigation_models.append(model['id'])

      if "📡" not in model['name']:
        radarless_models.append(model['id'])

    self.params.put_nonblocking("AvailableModels", ','.join(available_models))
    self.params.put_nonblocking("AvailableModelsNames", ','.join(available_model_names))
    self.params.put_nonblocking("ExperimentalModels", ','.join(experimental_models))
    self.params.put_nonblocking("NavigationModels", ','.join(navigation_models))
    self.params.put_nonblocking("RadarlessModels", ','.join(radarless_models))
    print("Models list updated successfully.")

    if available_models:
      models_downloaded = self.are_all_models_downloaded(available_models, available_model_names)
      self.params.put_bool_nonblocking("ModelsDownloaded", models_downloaded)

  def are_all_models_downloaded(self, available_models, available_model_names):
    """Return True when every model in `available_models` is present on disk.

    With "AutomaticallyUpdateModels" enabled, missing models are queued for
    download, and existing files that fail verify_download are treated as
    outdated: deleted, their per-model params removed, and re-queued.
    Queuing blocks until any earlier request has been picked up.
    """
    automatically_update_models = self.params.get_bool("AutomaticallyUpdateModels")
    all_models_downloaded = True

    for model in available_models:
      model_path = self._model_path(model)

      if os.path.exists(model_path):
        model_url = f"{self.repo_url}Models/{model}.thneed"
        if automatically_update_models and not verify_download(model_path, model_url):
          print(f"Model {model} is outdated. Re-downloading...")
          delete_file(model_path)
          self.remove_model_params(available_model_names, available_models, model)
          self.queue_model_download(model)
          all_models_downloaded = False
      else:
        if automatically_update_models:
          print(f"Model {model} isn't downloaded. Downloading...")
          self.remove_model_params(available_model_names, available_models, model)
          self.queue_model_download(model)
        # A missing file means "not downloaded" whether or not it was queued.
        all_models_downloaded = False

    return all_models_downloaded

  def remove_model_params(self, available_model_names, available_models, model):
    """Delete `model`'s "<prefix>CalibrationParams" and "<prefix>LiveTorqueParameters" params.

    The prefix is process_model_name() of the display name at the same index as
    `model`. If the calibration key is not a known param key, nothing is removed.
    """
    part_model_param = process_model_name(available_model_names[available_models.index(model)])
    try:
      self.params.check_key(part_model_param + "CalibrationParams")
    except UnknownKeyName:
      return

    self.params.remove(part_model_param + "CalibrationParams")
    self.params.remove(part_model_param + "LiveTorqueParameters")

  def queue_model_download(self, model, model_name=None):
    """Post a download request for `model`, first blocking until no request is pending.

    When `model_name` is given the progress text becomes "Downloading <name>...".
    """
    while self.params_memory.get(self.download_param, encoding='utf-8'):
      time.sleep(1)

    self.params_memory.put(self.download_param, model)
    if model_name:
      self.params_memory.put(self.download_progress_param, f"Downloading {model_name}...")

  def validate_models(self):
    """Tidy "ModelName" and delete model files that are no longer in the catalogue.

    A "(Default)" marker is stripped from ModelName unless it is the default
    model's name. If AvailableModels is populated, every file in MODELS_PATH
    whose id is not listed is deleted; when that id is the selected "Model",
    the selection is reset to the default model first.
    """
    current_model = self.params.get("Model", encoding='utf-8')
    current_model_name = self.params.get("ModelName", encoding='utf-8')

    # ModelName can be unset (None); `in` on None would raise TypeError.
    if current_model_name and "(Default)" in current_model_name and current_model_name != DEFAULT_MODEL_NAME:
      self.params.put_nonblocking("ModelName", current_model_name.replace(" (Default)", ""))

    available_models = self.params.get("AvailableModels", encoding='utf-8')
    if not available_models:
      return

    available_model_ids = set(available_models.split(','))
    for model_file in os.listdir(MODELS_PATH):
      model_id = model_file.replace(".thneed", "")
      if model_id in available_model_ids:
        continue

      # "Model" stores the bare id (cf. DEFAULT_MODEL), so compare ids, not file names.
      if model_id == current_model:
        self.params.put_nonblocking("Model", DEFAULT_MODEL)
        self.params.put_nonblocking("ModelName", DEFAULT_MODEL_NAME)

      delete_file(os.path.join(MODELS_PATH, model_file))
      print(f"Deleted model file: {model_file}")

  def copy_default_model(self):
    """Copy the bundled supercombo.thneed to the default model's path if it is absent there."""
    default_model_path = self._model_path(DEFAULT_MODEL)
    if os.path.exists(default_model_path):
      return

    source_path = os.path.join(BASEDIR, "selfdrive", "modeld", "models", "supercombo.thneed")
    if os.path.exists(source_path):
      shutil.copyfile(source_path, default_model_path)
      print(f"Copied default model from {source_path} to {default_model_path}")
    else:
      print(f"Source default model not found at {source_path}. Exiting...")

  def update_models(self, boot_run=True):
    """Refresh the model catalogue from the repository.

    On boot (`boot_run=True`) the bundled default model is copied into place,
    the repository is re-probed for up to ~60 seconds, and stale files are
    pruned via validate_models. If a repository is reachable, the catalogue is
    fetched and published with update_model_params.
    """
    self.repo_url = get_repository_url()

    if boot_run:
      self.copy_default_model()

      # Re-probe every second so the repository is picked up once the network
      # comes up; just sleeping without re-probing could never succeed.
      deadline = time.monotonic() + 60
      while not self.repo_url and time.monotonic() < deadline:
        time.sleep(1)
        self.repo_url = get_repository_url()

      self.validate_models()

    if not self.repo_url:
      print("GitHub and GitLab are offline...")
      return

    model_info = self.fetch_models(f"{self.repo_url}Versions/model_names_{VERSION}.json")
    if model_info:
      self.update_model_params(model_info)

  def download_all_models(self):
    """Download every model in AvailableModels that is not on disk yet.

    Requests are queued one at a time and each is waited on until the request
    slot is cleared. Returns early if the download is cancelled. On success
    the progress text reads "All models downloaded!", the "DownloadAllModels"
    request is cleared and ModelsDownloaded is set; problems are reported
    through handle_error.
    """
    self.repo_url = get_repository_url()
    if not self.repo_url:
      handle_error(None, "GitHub and GitLab are offline...", "Repository unavailable", self.download_param, self.download_progress_param, self.params_memory)
      return

    # The fetch confirms the model list is reachable; the ids themselves are
    # taken from the AvailableModels param below.
    model_info = self.fetch_models(f"{self.repo_url}Versions/model_names_{VERSION}.json")
    if not model_info:
      handle_error(None, "Unable to update model list...", "Model list unavailable", self.download_param, self.download_progress_param, self.params_memory)
      return

    available_models = self.params.get("AvailableModels", encoding='utf-8')
    if not available_models:
      handle_error(None, "There's no model to download...", "There's no model to download...", self.download_param, self.download_progress_param, self.params_memory)
      return

    available_models = available_models.split(',')
    available_model_names = (self.params.get("AvailableModelsNames", encoding='utf-8') or "").split(',')

    for index, model in enumerate(available_models):
      if self.params_memory.get_bool(self.cancel_download_param):
        return

      if os.path.exists(self._model_path(model)):
        continue

      # Names are stored in the same order as ids; fall back to the id if they are out of step.
      model_name = available_model_names[index] if index < len(available_model_names) else model
      cleaned_model_name = re.sub(r'[🗺️👀📡]', '', model_name).strip()
      print(f"Downloading model: {cleaned_model_name}")
      self.queue_model_download(model, cleaned_model_name)

      while self.params_memory.get(self.download_param, encoding='utf-8'):
        time.sleep(1)

    if self.params_memory.get_bool(self.cancel_download_param):
      return

    # Each request has been waited on above (assuming the consumer clears
    # ModelToDownload on failure too), so a missing file now means that download
    # failed - report it instead of polling forever for a file that never arrives.
    missing_models = [model for model in available_models if not os.path.exists(self._model_path(model))]
    if missing_models:
      handle_error(None, "Some models failed to download...", f"Failed to download: {', '.join(missing_models)}", self.download_param, self.download_progress_param, self.params_memory)
      return

    self.params_memory.put(self.download_progress_param, "All models downloaded!")
    self.params_memory.remove("DownloadAllModels")
    self.params.put_bool_nonblocking("ModelsDownloaded", True)