make URLFile safe after fork() (#31309)
* make URLFile safe after fork() * cache the pool manager in each instance * type hints old-commit-hash: a8aa04e6bda2fc8ca31db055f584bfc52d104d2c
This commit is contained in:
@@ -6,6 +6,7 @@ from hashlib import sha256
|
||||
from urllib3 import PoolManager
|
||||
from urllib3.util import Timeout
|
||||
from tenacity import retry, wait_random_exponential, stop_after_attempt
|
||||
from typing import Optional
|
||||
|
||||
from openpilot.common.file_helpers import atomic_write_in_dir
|
||||
from openpilot.system.hardware.hw import Paths
|
||||
@@ -25,9 +26,12 @@ class URLFileException(Exception):
|
||||
|
||||
|
||||
class URLFile:
|
||||
_tlocal = threading.local()
|
||||
_pid: Optional[int] = None
|
||||
_pool_manager: Optional[PoolManager] = None
|
||||
_pool_manager_lock = threading.Lock()
|
||||
|
||||
def __init__(self, url, debug=False, cache=None):
|
||||
self._pool_manager = None
|
||||
self._url = url
|
||||
self._pos = 0
|
||||
self._length = None
|
||||
@@ -41,11 +45,6 @@ class URLFile:
|
||||
if not self._force_download:
|
||||
os.makedirs(Paths.download_cache_root(), exist_ok=True)
|
||||
|
||||
try:
|
||||
self._http_client = URLFile._tlocal.http_client
|
||||
except AttributeError:
|
||||
self._http_client = URLFile._tlocal.http_client = PoolManager()
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
@@ -55,10 +54,20 @@ class URLFile:
|
||||
self._local_file.close()
|
||||
self._local_file = None
|
||||
|
||||
def _http_client(self) -> PoolManager:
|
||||
if self._pool_manager is None:
|
||||
pid = os.getpid()
|
||||
with URLFile._pool_manager_lock:
|
||||
if URLFile._pid != pid or URLFile._pool_manager is None: # unsafe to share after fork
|
||||
URLFile._pid = pid
|
||||
URLFile._pool_manager = PoolManager(num_pools=10, maxsize=10)
|
||||
self._pool_manager = URLFile._pool_manager
|
||||
return self._pool_manager
|
||||
|
||||
@retry(wait=wait_random_exponential(multiplier=1, max=5), stop=stop_after_attempt(3), reraise=True)
|
||||
def get_length_online(self):
|
||||
timeout = Timeout(connect=50.0, read=500.0)
|
||||
response = self._http_client.request('HEAD', self._url, timeout=timeout, preload_content=False)
|
||||
response = self._http_client().request('HEAD', self._url, timeout=timeout, preload_content=False)
|
||||
if not (200 <= response.status <= 299):
|
||||
return -1
|
||||
length = response.headers.get('content-length', 0)
|
||||
@@ -131,7 +140,7 @@ class URLFile:
|
||||
t1 = time.time()
|
||||
|
||||
timeout = Timeout(connect=50.0, read=500.0)
|
||||
response = self._http_client.request('GET', self._url, timeout=timeout, preload_content=False, headers=headers)
|
||||
response = self._http_client().request('GET', self._url, timeout=timeout, preload_content=False, headers=headers)
|
||||
ret = response.data
|
||||
|
||||
if self._debug:
|
||||
|
||||
Reference in New Issue
Block a user