diff --git a/tools/lib/filereader.py b/tools/lib/filereader.py
index ee9ee294bb..f5418be81a 100644
--- a/tools/lib/filereader.py
+++ b/tools/lib/filereader.py
@@ -1,4 +1,5 @@
 import os
+import io
 import posixpath
 import socket
 from functools import cache
@@ -41,9 +42,17 @@ def file_exists(fn):
     return URLFile(fn).get_length_online() != -1
   return os.path.exists(fn)
 
+class DiskFile(io.BufferedReader):
+  def get_multi_range(self, ranges: list[tuple[int, int]]) -> list[bytes]:
+    parts = []
+    for r in ranges:
+      self.seek(r[0])
+      parts.append(self.read(r[1] - r[0]))
+    return parts
 
-def FileReader(fn, debug=False):
+def FileReader(fn):
   fn = resolve_name(fn)
   if fn.startswith(("http://", "https://")):
-    return URLFile(fn, debug=debug)
-  return open(fn, "rb")
+    return URLFile(fn)
+  else:
+    return DiskFile(open(fn, "rb"))
diff --git a/tools/lib/url_file.py b/tools/lib/url_file.py
index 2bf3ba8209..01c6c5dc47 100644
--- a/tools/lib/url_file.py
+++ b/tools/lib/url_file.py
@@ -1,13 +1,11 @@
+import re
 import logging
 import os
 import socket
-import time
 from hashlib import sha256
 from urllib3 import PoolManager, Retry
 from urllib3.response import BaseHTTPResponse
 from urllib3.util import Timeout
-from urllib3.exceptions import MaxRetryError
-
 from openpilot.common.utils import atomic_write
 from openpilot.system.hardware.hw import Paths
 
@@ -42,12 +40,11 @@ class URLFile:
       URLFile._pool_manager = PoolManager(num_pools=10, maxsize=100, socket_options=socket_options, retries=retries)
     return URLFile._pool_manager
 
-  def __init__(self, url: str, timeout: int = 10, debug: bool = False, cache: bool | None = None):
+  def __init__(self, url: str, timeout: int = 10, cache: bool | None = None):
     self._url = url
     self._timeout = Timeout(connect=timeout, read=timeout)
     self._pos = 0
     self._length: int | None = None
-    self._debug = debug
     # True by default, false if FILEREADER_CACHE is defined, but can be overwritten by the cache input
     self._force_download = not int(os.environ.get("FILEREADER_CACHE", "0"))
     if cache is not None:
@@ -63,10 +60,7 @@ class URLFile:
       pass
 
   def _request(self, method: str, url: str, headers: dict[str, str] | None = None) -> BaseHTTPResponse:
-    try:
-      return URLFile.pool_manager().request(method, url, timeout=self._timeout, headers=headers)
-    except MaxRetryError as e:
-      raise URLFileException(f"Failed to {method} {url}: {e}") from e
+    return URLFile.pool_manager().request(method, url, timeout=self._timeout, headers=headers)
 
   def get_length_online(self) -> int:
     response = self._request('HEAD', self._url)
@@ -125,39 +119,45 @@
     return response
 
   def read_aux(self, ll: int | None = None) -> bytes:
-    download_range = False
-    headers = {}
-    if self._pos != 0 or ll is not None:
-      if ll is None:
-        end = self.get_length() - 1
-      else:
-        end = min(self._pos + ll, self.get_length()) - 1
-      if self._pos >= end:
-        return b""
-      headers['Range'] = f"bytes={self._pos}-{end}"
-      download_range = True
+    if ll is None:
+      length = self.get_length()
+      if length == -1:
+        raise URLFileException(f"Remote file is empty or doesn't exist: {self._url}")
+      end = length
+    else:
+      end = self._pos + ll
+    data = self.get_multi_range([(self._pos, end)])
+    self._pos += len(data[0])
+    return data[0]
 
-    if self._debug:
-      t1 = time.monotonic()
+  def get_multi_range(self, ranges: list[tuple[int, int]]) -> list[bytes]:
+    # HTTP range requests are inclusive
+    assert all(e > s for s, e in ranges), "Range end must be greater than start"
+    rs = [f"{s}-{e-1}" for s, e in ranges if e > s]
 
-    response = self._request('GET', self._url, headers=headers)
-    ret = response.data
+    r = self._request("GET", self._url, headers={"Range": "bytes=" + ",".join(rs)})
+    if r.status not in [200, 206]:
+      raise URLFileException(f"Expected 206 or 200 response {r.status} ({self._url})")
 
-    if self._debug:
-      t2 = time.monotonic()
-      if t2 - t1 > 0.1:
-        print(f"get {self._url} {headers!r} {t2 - t1:.3f} slow")
+    ctype = (r.headers.get("content-type") or "").lower()
+    if "multipart/byteranges" not in ctype:
+      return [r.data,]
 
-    response_code = response.status
-    if response_code == 416:  # Requested Range Not Satisfiable
-      raise URLFileException(f"Error, range out of bounds {response_code} {headers} ({self._url}): {repr(ret)[:500]}")
-    if download_range and response_code != 206:  # Partial Content
-      raise URLFileException(f"Error, requested range but got unexpected response {response_code} {headers} ({self._url}): {repr(ret)[:500]}")
-    if (not download_range) and response_code != 200:  # OK
-      raise URLFileException(f"Error {response_code} {headers} ({self._url}): {repr(ret)[:500]}")
+    m = re.search(r'boundary="?([^";]+)"?', ctype)
+    if not m:
+      raise URLFileException(f"Missing multipart boundary ({self._url})")
+    boundary = m.group(1).encode()
 
-    self._pos += len(ret)
-    return ret
+    parts = []
+    for chunk in r.data.split(b"--" + boundary):
+      if b"\r\n\r\n" not in chunk:
+        continue
+      payload = chunk.split(b"\r\n\r\n", 1)[1].rstrip(b"\r\n")
+      if payload and payload != b"--":
+        parts.append(payload)
+    if len(parts) != len(ranges):
+      raise URLFileException(f"Expected {len(ranges)} parts, got {len(parts)} ({self._url})")
+    return parts
 
   def seek(self, pos: int) -> None:
     self._pos = pos