diff --git a/common/api/__init__.py b/common/api/__init__.py index 7f517bc378..8b261486ba 100644 --- a/common/api/__init__.py +++ b/common/api/__init__.py @@ -20,3 +20,7 @@ class Api: def api_get(endpoint, method='GET', timeout=None, access_token=None, **params): return CommaConnectApi(None).api_get(endpoint, method, timeout, access_token, **params) + + +def get_key_pair(): + return CommaConnectApi(None).get_key_pair() diff --git a/common/api/base.py b/common/api/base.py index 8778e14239..682b266056 100644 --- a/common/api/base.py +++ b/common/api/base.py @@ -1,18 +1,22 @@ import jwt +import os import requests import unicodedata from datetime import datetime, timedelta, UTC from openpilot.system.hardware.hw import Paths from openpilot.system.version import get_version + # name : jwt signature algorithm +KEYS = {"id_rsa" : "RS256", + "id_ecdsa" : "ES256"} + class BaseApi: def __init__(self, dongle_id, api_host, user_agent="openpilot-"): self.dongle_id = dongle_id self.api_host = api_host self.user_agent = user_agent - with open(f'{Paths.persist_root()}/comma/id_rsa') as f: - self.private_key = f.read() + self.jwt_algorithm, self.private_key, _ = self.get_key_pair() def get(self, *args, **kwargs): return self.request('GET', *args, **kwargs) @@ -34,7 +38,7 @@ class BaseApi: } if payload_extra is not None: payload.update(payload_extra) - token = jwt.encode(payload, self.private_key, algorithm='RS256') + token = jwt.encode(payload, self.private_key, algorithm=self.jwt_algorithm) if isinstance(token, bytes): token = token.decode('utf8') return token @@ -56,3 +60,11 @@ class BaseApi: headers['User-Agent'] = self.user_agent + version return requests.request(method, f"{self.api_host}/{endpoint}", timeout=timeout, headers=headers, json=json, params=params) + + @staticmethod + def get_key_pair(): + for key in KEYS: + if os.path.isfile(Paths.persist_root() + f'/comma/{key}') and os.path.isfile(Paths.persist_root() + f'/comma/{key}.pub'): + with open(Paths.persist_root() + f'/comma/{key}') as private, open(Paths.persist_root() + f'/comma/{key}.pub') as public: + return KEYS[key], private.read(), public.read() + return None, None, None diff --git a/tools/lib/api.py b/tools/lib/api.py index c6e2d98914..f84fe75869 100644 --- a/tools/lib/api.py +++ b/tools/lib/api.py @@ -1,5 +1,6 @@ import os import requests +from requests.adapters import HTTPAdapter, Retry API_HOST = os.getenv('API_HOST', 'https://api.commadotai.com') # TODO: this should be merged into common.api @@ -11,6 +12,9 @@ class CommaApi: if token: self.session.headers['Authorization'] = 'JWT ' + token + retries = Retry(total=5, backoff_factor=1, status_forcelist=[500, 502, 503, 504]) + self.session.mount('https://', HTTPAdapter(max_retries=retries)) + def request(self, method, endpoint, **kwargs): with self.session.request(method, API_HOST + '/' + endpoint, **kwargs) as resp: resp_json = resp.json()