From 268f6bc2fbd29a4a7a8d7ddcb0d04825cb98a173 Mon Sep 17 00:00:00 2001 From: Adeeb Shihadeh Date: Tue, 18 Jul 2023 00:03:17 -0700 Subject: [PATCH] python: lock hw device inside loop (#1522) --- __init__.py | 4 +- python/__init__.py | 1 - python/constants.py | 2 + python/spi.py | 132 +++++++++++++++++++++++--------------------- 4 files changed, 72 insertions(+), 67 deletions(-) diff --git a/__init__.py b/__init__.py index 19cbee3e..dfb2bbf1 100644 --- a/__init__.py +++ b/__init__.py @@ -1,6 +1,6 @@ -from .python.constants import McuType, BASEDIR, FW_PATH # noqa: F401 +from .python.constants import McuType, BASEDIR, FW_PATH, USBPACKET_MAX_SIZE # noqa: F401 from .python.spi import PandaSpiException, PandaProtocolMismatch # noqa: F401 from .python.serial import PandaSerial # noqa: F401 from .python import (Panda, PandaDFU, # noqa: F401 pack_can_buffer, unpack_can_buffer, calculate_checksum, unpack_log, - DLC_TO_LEN, LEN_TO_DLC, ALTERNATIVE_EXPERIENCE, USBPACKET_MAX_SIZE, CANPACKET_HEAD_SIZE) + DLC_TO_LEN, LEN_TO_DLC, ALTERNATIVE_EXPERIENCE, CANPACKET_HEAD_SIZE) diff --git a/python/__init__.py b/python/__init__.py index 0349d161..5558a9ab 100644 --- a/python/__init__.py +++ b/python/__init__.py @@ -26,7 +26,6 @@ __version__ = '0.0.10' LOGLEVEL = os.environ.get('LOGLEVEL', 'INFO').upper() logging.basicConfig(level=LOGLEVEL, format='%(message)s') -USBPACKET_MAX_SIZE = 0x40 CANPACKET_HEAD_SIZE = 0x6 DLC_TO_LEN = [0, 1, 2, 3, 4, 5, 6, 7, 8, 12, 16, 20, 24, 32, 48, 64] LEN_TO_DLC = {length: dlc for (dlc, length) in enumerate(DLC_TO_LEN)} diff --git a/python/constants.py b/python/constants.py index 4c3e778a..16409ac3 100644 --- a/python/constants.py +++ b/python/constants.py @@ -5,6 +5,8 @@ from typing import List, NamedTuple BASEDIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../") FW_PATH = os.path.join(BASEDIR, "board/obj/") +USBPACKET_MAX_SIZE = 0x40 + class McuConfig(NamedTuple): mcu: str mcu_idcode: int diff --git a/python/spi.py b/python/spi.py index 55bfa6d2..f14a04ef 100644 --- a/python/spi.py +++ b/python/spi.py @@ -9,10 +9,10 @@ import logging import threading from contextlib import contextmanager from functools import reduce -from typing import List, Optional +from typing import Callable, List, Optional from .base import BaseHandle, BaseSTBootloaderHandle, TIMEOUT -from .constants import McuType, MCU_TYPE_BY_IDCODE +from .constants import McuType, MCU_TYPE_BY_IDCODE, USBPACKET_MAX_SIZE from .utils import crc8_pedal try: @@ -115,13 +115,13 @@ class PandaSpiHandle(BaseHandle): PROTOCOL_VERSION = 1 - def __init__(self): + def __init__(self) -> None: self.dev = SpiDevice() - self._transfer = self._transfer_spidev + self._transfer_raw: Callable[[SpiDevice, int, bytes, int, int, bool], bytes] = self._transfer_spidev if "KERN" in os.environ: - self._transfer = self._transfer_kernel_driver + self._transfer_raw = self._transfer_kernel_driver self.tx_buf = bytearray(1024) self.rx_buf = bytearray(1024) @@ -134,25 +134,65 @@ class PandaSpiHandle(BaseHandle): self.fileno = self.dev._spidev.fileno() # helpers - def _calc_checksum(self, data: List[int]) -> int: + def _calc_checksum(self, data: bytes) -> int: cksum = CHECKSUM_START for b in data: cksum ^= b return cksum - def _wait_for_ack(self, spi, ack_val: int, timeout: int, tx: int) -> None: + def _wait_for_ack(self, spi, ack_val: int, timeout: int, tx: int, length: int = 1) -> bytes: timeout_s = max(MIN_ACK_TIMEOUT_MS, timeout) * 1e-3 start = time.monotonic() while (timeout == 0) or ((time.monotonic() - start) < timeout_s): - dat = spi.xfer2([tx, ])[0] - if dat == NACK: + dat = spi.xfer2([tx, ] * length) + if dat[0] == NACK: raise PandaSpiNackResponse - elif dat == ack_val: - return + elif dat[0] == ack_val: + return bytes(dat) raise PandaSpiMissingAck + def _transfer_spidev(self, spi, endpoint: int, data, timeout: int, max_rx_len: int = 1000, expect_disconnect: bool = False) -> bytes: + max_rx_len = max(USBPACKET_MAX_SIZE, max_rx_len) + + logging.debug("- send header") + packet = struct.pack(" max_rx_len: + raise PandaSpiException(f"response length greater than max ({max_rx_len} {response_len})") + + # read rest + remaining = response_len - preread_len + if remaining > 0: + dat += bytes(spi.readbytes(remaining)) + + dat = dat[:3 + response_len + 1] + if self._calc_checksum(dat) != 0: + raise PandaSpiBadChecksum + + return dat[3:-1] + def _transfer_kernel_driver(self, spi, endpoint: int, data, timeout: int, max_rx_len: int = 1000, expect_disconnect: bool = False) -> bytes: self.tx_buf[:len(data)] = data self.ioctl_data.endpoint = endpoint @@ -169,7 +209,7 @@ class PandaSpiHandle(BaseHandle): raise PandaSpiException(f"ioctl returned {ret}") return bytes(self.rx_buf[:ret]) - def _transfer_spidev(self, spi, endpoint: int, data, timeout: int, max_rx_len: int = 1000, expect_disconnect: bool = False) -> bytes: + def _transfer(self, endpoint: int, data, timeout: int, max_rx_len: int = 1000, expect_disconnect: bool = False) -> bytes: logging.debug("starting transfer: endpoint=%d, max_rx_len=%d", endpoint, max_rx_len) logging.debug("==============================================") @@ -179,44 +219,12 @@ class PandaSpiHandle(BaseHandle): while (time.monotonic() - start_time) < timeout*1e-3: n += 1 logging.debug("\ntry #%d", n) - try: - logging.debug("- send header") - packet = struct.pack(" max_rx_len: - raise PandaSpiException("response length greater than max") - - logging.debug("- receiving response") - dat = bytes(spi.xfer2(b"\x00" * (response_len + 1))) - if self._calc_checksum([DACK, *response_len_bytes, *dat]) != 0: - raise PandaSpiBadChecksum - - return dat[:-1] - except PandaSpiException as e: - exc = e - logging.debug("SPI transfer failed, retrying", exc_info=True) + with self.dev.acquire() as spi: + try: + return self._transfer_raw(spi, endpoint, data, timeout, max_rx_len, expect_disconnect) + except PandaSpiException as e: + exc = e + logging.debug("SPI transfer failed, retrying", exc_info=True) raise exc @@ -261,28 +269,24 @@ class PandaSpiHandle(BaseHandle): self.dev.close() def controlWrite(self, request_type: int, request: int, value: int, index: int, data, timeout: int = TIMEOUT, expect_disconnect: bool = False): - with self.dev.acquire() as spi: - return self._transfer(spi, 0, struct.pack(" int: - with self.dev.acquire() as spi: - for x in range(math.ceil(len(data) / XFER_SIZE)): - self._transfer(spi, endpoint, data[XFER_SIZE*x:XFER_SIZE*(x+1)], timeout) - return len(data) + for x in range(math.ceil(len(data) / XFER_SIZE)): + self._transfer(endpoint, data[XFER_SIZE*x:XFER_SIZE*(x+1)], timeout) + return len(data) def bulkRead(self, endpoint: int, length: int, timeout: int = TIMEOUT) -> bytes: ret: List[int] = [] - with self.dev.acquire() as spi: - for _ in range(math.ceil(length / XFER_SIZE)): - d = self._transfer(spi, endpoint, [], timeout, max_rx_len=XFER_SIZE) - ret += d - if len(d) < XFER_SIZE: - break + for _ in range(math.ceil(length / XFER_SIZE)): + d = self._transfer(endpoint, [], timeout, max_rx_len=XFER_SIZE) + ret += d + if len(d) < XFER_SIZE: + break return bytes(ret)