2024-08-01 10:09:53 +08:00
import json
import os
import re
2024-10-22 01:27:33 +08:00
import requests
2024-08-01 10:09:53 +08:00
import shutil
import time
2024-10-22 01:27:33 +08:00
import urllib . parse
2024-08-01 10:09:53 +08:00
import urllib . request
from openpilot . common . basedir import BASEDIR
2024-10-22 01:27:33 +08:00
from openpilot . common . params import Params
2024-08-01 10:09:53 +08:00
2024-10-22 01:27:33 +08:00
from openpilot . selfdrive . frogpilot . assets . download_functions import GITHUB_URL , GITLAB_URL , download_file , get_repository_url , handle_error , handle_request_error , verify_download
from openpilot . selfdrive . frogpilot . frogpilot_functions import MODELS_PATH
from openpilot . selfdrive . frogpilot . frogpilot_utilities import delete_file
2024-08-01 10:09:53 +08:00
2024-10-22 01:27:33 +08:00
# Version suffix of the remote model index (Versions/model_names_<VERSION>.json).
VERSION = "v10"

# Model id and display name that the model selection falls back to.
DEFAULT_MODEL = "north-dakota"
DEFAULT_MODEL_NAME = "North Dakota (Default)"
2024-08-01 10:09:53 +08:00
2024-10-22 01:27:33 +08:00
2024-08-01 10:09:53 +08:00
def process_model_name(model_name):
  """Turn a model's display name into a compact identifier.

  The tag emoji (🗺️, 👀, 📡) are removed first. Next, every character other than an ASCII
  letter, digit, parenthesis or hyphen is removed. Finally the "(Default)" marker and any
  remaining hyphens are dropped.

  >>> process_model_name("North Dakota (Default)")
  'NorthDakota'
  """
  without_tags = re.sub(r'[🗺️👀📡]', '', model_name)
  result = re.sub(r'[^a-zA-Z0-9()-]', '', without_tags)
  for fragment in (' ', '(Default)', '-'):
    result = result.replace(fragment, '')
  return result
2024-08-01 10:09:53 +08:00
2024-10-22 01:27:33 +08:00
2024-09-01 01:30:54 +08:00
class ModelManager:
  """Keep FrogPilot's model list and model files in sync with the remote resources repository.

  Two ``Params`` stores are used:

  * ``self.params`` (``Params()``) holds the model lists, the selected model and user settings.
  * ``self.params_memory`` (``Params("/dev/shm/params")``) coordinates downloads through three keys:

    * ``ModelToDownload`` (``download_param``) holds the id of the requested model.
    * ``ModelDownloadProgress`` (``download_progress_param``) holds user-facing progress text.
    * ``CancelModelDownload`` (``cancel_download_param``) is checked to stop downloading.

  NOTE(review): whatever reacts to ``ModelToDownload`` (presumably by calling
  ``download_model``) lives outside this class. Confirm against the caller.
  """

  def __init__(self):
    self.params = Params()
    self.params_memory = Params("/dev/shm/params")

    self.cancel_download_param = "CancelModelDownload"
    self.download_param = "ModelToDownload"
    self.download_progress_param = "ModelDownloadProgress"

  @staticmethod
  def fetch_models(url):
    """Download the JSON model index at ``url`` and return its ``models`` list.

    Any failure (network, decoding, missing key) is passed to ``handle_request_error``, and an
    empty list is returned instead.
    """
    try:
      with urllib.request.urlopen(url, timeout=10) as response:
        return json.loads(response.read().decode('utf-8'))['models']
    except Exception as error:
      handle_request_error(error, None, None, None, None)
      return []

  @staticmethod
  def _list_gitlab_tree(api_url, branch, timeout):
    """Return every entry of a GitLab repository-tree listing, following its pagination."""
    entries = []
    page = "1"
    while page:
      # The tree endpoint returns 20 entries per page by default. Ask for the maximum page size
      # and follow X-Next-Page, which GitLab leaves empty on the last page.
      response = requests.get(api_url, params={"ref": branch, "per_page": 100, "page": page}, timeout=timeout)
      response.raise_for_status()
      entries.extend(response.json())
      page = response.headers.get("X-Next-Page", "")
    return entries

  @staticmethod
  def fetch_all_model_sizes(repo_url, timeout=10):
    """Return ``{model_id: size_in_bytes}`` for the ``.thneed`` files on the ``Models`` branch.

    Args:
      repo_url: Repository base URL. Whether it contains "github" or "gitlab" selects the API
        that is queried.
      timeout: Seconds to wait on each HTTP request, so a stalled connection cannot block
        forever.

    Returns:
      The size mapping, omitting files whose size the API does not report. If any request or
      response parse fails, ``{}`` is returned.

    Raises:
      ValueError: if ``repo_url`` is neither a GitHub nor a GitLab URL.
    """
    project_path = "FrogAi/FrogPilot-Resources"
    branch = "Models"
    encoded_project = urllib.parse.quote_plus(project_path)

    if "github" in repo_url:
      is_gitlab = False
      api_url = f"https://api.github.com/repos/{project_path}/contents?ref={branch}"
    elif "gitlab" in repo_url:
      is_gitlab = True
      api_url = f"https://gitlab.com/api/v4/projects/{encoded_project}/repository/tree"
    else:
      raise ValueError(f"Unsupported repository URL format: {repo_url}. Supported formats are GitHub and GitLab URLs.")

    try:
      if is_gitlab:
        entries = ModelManager._list_gitlab_tree(api_url, branch, timeout)
      else:
        response = requests.get(api_url, timeout=timeout)
        response.raise_for_status()
        entries = response.json()

      thneed_files = [entry for entry in entries if entry['name'].endswith('.thneed')]

      if not is_gitlab:
        # The GitHub contents listing already carries each file's size.
        return {entry['name'].replace('.thneed', ''): entry['size'] for entry in thneed_files if 'size' in entry}

      # The GitLab tree listing has no sizes, so ask for each raw file's Content-Length instead.
      model_sizes = {}
      for entry in thneed_files:
        metadata_url = f"https://gitlab.com/api/v4/projects/{encoded_project}/repository/files/{urllib.parse.quote_plus(entry['path'])}/raw?ref={branch}"
        metadata_response = requests.head(metadata_url, timeout=timeout)
        metadata_response.raise_for_status()
        content_length = metadata_response.headers.get('content-length')
        # Skip files without a Content-Length instead of recording 0. A 0 would make
        # are_all_models_downloaded treat every local copy as outdated and delete it.
        if content_length is not None:
          model_sizes[entry['name'].replace('.thneed', '')] = int(content_length)
      return model_sizes
    except Exception as error:
      print(f"Failed to fetch model sizes from {api_url}: {error}")
      return {}

  @staticmethod
  def copy_default_model():
    """Copy the bundled ``supercombo.thneed`` builds into MODELS_PATH.

    The copies are made as follows:

    * ``selfdrive/classic_modeld/models/supercombo.thneed`` becomes ``<DEFAULT_MODEL>.thneed``.
    * ``selfdrive/modeld/models/supercombo.thneed`` becomes ``secret-good-openpilot.thneed``.

    A source that does not exist is skipped.
    """
    classic_default_model_path = os.path.join(MODELS_PATH, f"{DEFAULT_MODEL}.thneed")
    source_path = os.path.join(BASEDIR, "selfdrive", "classic_modeld", "models", "supercombo.thneed")
    if os.path.isfile(source_path):
      shutil.copyfile(source_path, classic_default_model_path)
      print(f"Copied the classic default model from {source_path} to {classic_default_model_path}")

    default_model_path = os.path.join(MODELS_PATH, "secret-good-openpilot.thneed")
    source_path = os.path.join(BASEDIR, "selfdrive", "modeld", "models", "supercombo.thneed")
    if os.path.isfile(source_path):
      shutil.copyfile(source_path, default_model_path)
      print(f"Copied the default model from {source_path} to {default_model_path}")

  def handle_verification_failure(self, model, model_path, temp_model_path):
    """Retry, from GitLab, a download whose verification failed.

    Does nothing if a cancellation has been requested. On success it marks the download as
    finished exactly like ``download_model`` does. On failure it reports through
    ``handle_error``.
    """
    if self.params_memory.get_bool(self.cancel_download_param):
      return

    print(f"Verification failed for model {model}. Retrying from GitLab...")
    model_url = f"{GITLAB_URL}Models/{model}.thneed"
    download_file(self.cancel_download_param, model_path, temp_model_path, self.download_progress_param, model_url, self.download_param, self.params_memory)

    if verify_download(model_path, temp_model_path, model_url):
      print(f"Model {model} redownloaded and verified successfully from GitLab.")
      # Clear the request as download_model's success path does. Otherwise anything polling
      # download_param (queue_model_download, download_all_models) waits forever.
      self.params_memory.put(self.download_progress_param, "Downloaded!")
      self.params_memory.remove(self.download_param)
    else:
      handle_error(model_path, "GitLab verification failed", "Verification failed", self.download_param, self.download_progress_param, self.params_memory)

  def download_model(self, model_to_download):
    """Download ``<model_to_download>.thneed`` into MODELS_PATH and verify it.

    Progress and errors are published through the memory params. If verification fails, the
    download is retried from GitLab via ``handle_verification_failure``.
    """
    model_path = os.path.join(MODELS_PATH, f"{model_to_download}.thneed")
    temp_model_path = f"{os.path.splitext(model_path)[0]}_temp.thneed"

    if os.path.isfile(model_path):
      # Nothing has been written yet, so there is nothing to clean up. The other call sites
      # pass the file they want discarded (or None) as handle_error's first argument. Passing
      # model_path here risked discarding the existing, valid model.
      # NOTE(review): confirm against handle_error.
      handle_error(None, "Model already exists...", "Model already exists...", self.download_param, self.download_progress_param, self.params_memory)
      return

    repo_url = get_repository_url()
    if not repo_url:
      handle_error(temp_model_path, "GitHub and GitLab are offline...", "Repository unavailable", self.download_param, self.download_progress_param, self.params_memory)
      return

    model_url = f"{repo_url}Models/{model_to_download}.thneed"
    print(f"Downloading model: {model_to_download}")
    download_file(self.cancel_download_param, model_path, temp_model_path, self.download_progress_param, model_url, self.download_param, self.params_memory)

    if verify_download(model_path, temp_model_path, model_url):
      print(f"Model {model_to_download} downloaded and verified successfully!")
      self.params_memory.put(self.download_progress_param, "Downloaded!")
      self.params_memory.remove(self.download_param)
    else:
      self.handle_verification_failure(model_to_download, model_path, temp_model_path)

  def queue_model_download(self, model, model_name=None):
    """Request a download of ``model`` once no other download is pending.

    Polls once per second until ``ModelToDownload`` is empty, then writes ``model`` to it. If
    ``model_name`` is given, it also publishes the progress text "Downloading <model_name>...".
    """
    while self.params_memory.get(self.download_param, encoding='utf-8'):
      time.sleep(1)

    self.params_memory.put(self.download_param, model)
    if model_name:
      self.params_memory.put(self.download_progress_param, f"Downloading {model_name}...")

  def update_model_params(self, model_info, repo_url):
    """Publish the model list and its category lists, then refresh ``ModelsDownloaded``.

    Args:
      model_info: Entries of the model index. Each needs ``id`` and ``name``, and may carry
        boolean ``classic_model`` / ``experimental`` flags.
      repo_url: Repository base URL, forwarded to ``are_all_models_downloaded``.
    """
    available_models = [model['id'] for model in model_info]

    self.params.put_nonblocking("AvailableModels", ','.join(available_models))
    self.params.put_nonblocking("AvailableModelsNames", ','.join([model['name'] for model in model_info]))
    self.params.put_nonblocking("ClassicModels", ','.join([model['id'] for model in model_info if model.get("classic_model", False)]))
    self.params.put_nonblocking("ExperimentalModels", ','.join([model['id'] for model in model_info if model.get("experimental", False)]))
    # Categories derived from emoji tags in the display names: names containing 🗺️ go to
    # NavigationModels, and names without 📡 go to RadarlessModels.
    self.params.put_nonblocking("NavigationModels", ','.join([model['id'] for model in model_info if "🗺️" in model['name']]))
    self.params.put_nonblocking("RadarlessModels", ','.join([model['id'] for model in model_info if "📡" not in model['name']]))
    print("Models list updated successfully.")

    if available_models:
      models_downloaded = self.are_all_models_downloaded(available_models, repo_url)
      self.params.put_bool_nonblocking("ModelsDownloaded", models_downloaded)

  def are_all_models_downloaded(self, available_models, repo_url):
    """Return True if every model in ``available_models`` is present locally.

    With the ``AutomaticallyUpdateModels`` setting on, two kinds of model are queued for
    download (and count as not downloaded):

    * missing models;
    * models whose local size differs from the repository's. Their files are deleted first.

    Models without size data are never queued, but a missing file still returns False.
    """
    automatically_update_models = self.params.get_bool("AutomaticallyUpdateModels")
    all_models_downloaded = True

    model_sizes = self.fetch_all_model_sizes(repo_url)
    download_queue = []

    for model in available_models:
      model_path = os.path.join(MODELS_PATH, f"{model}.thneed")
      model_exists = os.path.isfile(model_path)

      expected_size = model_sizes.get(model)
      if expected_size is None:
        print(f"Size data for {model} not available.")
        # The model cannot be compared or safely queued, but a missing file must still count as
        # "not downloaded". Skipping it reported True whenever the size lookup failed.
        if not model_exists:
          all_models_downloaded = False
        continue

      if model_exists:
        local_size = os.path.getsize(model_path)
        if automatically_update_models and local_size != expected_size:
          print(f"Model {model} is outdated. Re-downloading...")
          delete_file(model_path)
          download_queue.append(model)
          all_models_downloaded = False
      else:
        if automatically_update_models:
          print(f"Model {model} isn't downloaded. Downloading...")
          download_queue.append(model)
        all_models_downloaded = False

    # Queue after the scan. Each call blocks until the previous request has been picked up.
    for model in download_queue:
      self.queue_model_download(model)

    return all_models_downloaded

  def validate_models(self):
    """Reconcile the selected model and the files in MODELS_PATH with ``AvailableModels``.

    Steps:

    * Strip a stray "(Default)" tag from ``ModelName`` unless it is the default model's name.
    * Fall back to the default model if the selected one is no longer offered. Otherwise,
      download the selected model if its file is missing.
    * Delete every file in MODELS_PATH whose name (minus ``.thneed``) is not an available id.
    """
    current_model = self.params.get("Model", encoding='utf-8')
    current_model_name = self.params.get("ModelName", encoding='utf-8')

    # ModelName may be unset; using `in` on None would raise TypeError.
    if current_model_name and "(Default)" in current_model_name and current_model_name != DEFAULT_MODEL_NAME:
      self.params.put_nonblocking("ModelName", current_model_name.replace("(Default)", "").strip())

    available_models = self.params.get("AvailableModels", encoding='utf-8')
    if not available_models:
      return
    available_model_ids = set(available_models.split(','))

    # An unlisted model cannot be downloaded. Select the default instead, the same outcome the
    # cleanup loop below would produce if the unlisted model's file were on disk.
    if current_model and current_model not in available_model_ids:
      self.params.put_nonblocking("Model", DEFAULT_MODEL)
      self.params.put_nonblocking("ModelName", DEFAULT_MODEL_NAME)
      current_model = DEFAULT_MODEL

    if current_model in available_model_ids:
      current_model_path = os.path.join(MODELS_PATH, f"{current_model}.thneed")
      if not os.path.isfile(current_model_path):
        print(f"Model {current_model} is not downloaded. Downloading...")
        self.download_model(current_model)

    for model_file in os.listdir(MODELS_PATH):
      model_name = model_file.replace(".thneed", "")
      if model_name in available_model_ids:
        continue
      delete_file(os.path.join(MODELS_PATH, model_file))
      print(f"Deleted model file: {model_file} - Reason: Model is not in the list of available models")

  def update_models(self, boot_run=False):
    """Refresh the model-list params from the repository.

    On boot (``boot_run=True``) this also installs the bundled models beforehand and validates
    the selection and the model files afterwards.
    """
    if boot_run:
      self.copy_default_model()

    repo_url = get_repository_url()
    if not repo_url:
      print("GitHub and GitLab are offline...")
      return

    model_info = self.fetch_models(f"{repo_url}Versions/model_names_{VERSION}.json")
    if model_info:
      self.update_model_params(model_info, repo_url)

    if boot_run:
      self.validate_models()

  def download_all_models(self):
    """Download, one at a time, every available model that is not yet in MODELS_PATH.

    Stops early if a cancellation is requested. If any model is still missing afterwards, the
    failure is reported through ``handle_error``. On full success it publishes "All models
    downloaded!", clears ``DownloadAllModels`` and sets ``ModelsDownloaded``.
    """
    repo_url = get_repository_url()
    if not repo_url:
      handle_error(None, "GitHub and GitLab are offline...", "Repository unavailable", self.download_param, self.download_progress_param, self.params_memory)
      return

    model_info = self.fetch_models(f"{repo_url}Versions/model_names_{VERSION}.json")
    if not model_info:
      handle_error(None, "Unable to update model list...", "Model list unavailable", self.download_param, self.download_progress_param, self.params_memory)
      return

    available_models = self.params.get("AvailableModels", encoding='utf-8')
    if not available_models:
      handle_error(None, "There's no model to download...", "There's no model to download...", self.download_param, self.download_progress_param, self.params_memory)
      return

    available_models = available_models.split(',')
    # AvailableModelsNames may be unset; fall back to the model id for any missing name.
    available_model_names = (self.params.get("AvailableModelsNames", encoding='utf-8') or "").split(',')

    for index, model in enumerate(available_models):
      if self.params_memory.get_bool(self.cancel_download_param):
        return

      if os.path.isfile(os.path.join(MODELS_PATH, f"{model}.thneed")):
        continue

      model_name = available_model_names[index] if index < len(available_model_names) else model
      cleaned_model_name = re.sub(r'[🗺️👀📡]', '', model_name).strip()
      print(f"Downloading model: {cleaned_model_name}")
      self.queue_model_download(model, cleaned_model_name)
      # download_param is cleared when a download finishes (see download_model).
      while self.params_memory.get(self.download_param, encoding='utf-8'):
        time.sleep(1)

    if self.params_memory.get_bool(self.cancel_download_param):
      return

    # Every download has finished at this point. Check the result rather than polling for files
    # indefinitely, which never ends if a download failed.
    missing_models = [model for model in available_models if not os.path.isfile(os.path.join(MODELS_PATH, f"{model}.thneed"))]
    if missing_models:
      handle_error(None, "Some models failed to download...", f"Missing models: {', '.join(missing_models)}", self.download_param, self.download_progress_param, self.params_memory)
      return

    self.params_memory.put(self.download_progress_param, "All models downloaded!")
    self.params_memory.remove("DownloadAllModels")
    self.params.put_bool_nonblocking("ModelsDownloaded", True)