python: lock hw device inside loop (#1522)

This commit is contained in:
Adeeb Shihadeh
2023-07-18 00:03:17 -07:00
committed by GitHub
parent b2741013f6
commit 268f6bc2fb
4 changed files with 72 additions and 67 deletions

View File

@@ -9,10 +9,10 @@ import logging
import threading
from contextlib import contextmanager
from functools import reduce
from typing import List, Optional
from typing import Callable, List, Optional
from .base import BaseHandle, BaseSTBootloaderHandle, TIMEOUT
from .constants import McuType, MCU_TYPE_BY_IDCODE
from .constants import McuType, MCU_TYPE_BY_IDCODE, USBPACKET_MAX_SIZE
from .utils import crc8_pedal
try:
@@ -115,13 +115,13 @@ class PandaSpiHandle(BaseHandle):
PROTOCOL_VERSION = 1
def __init__(self):
def __init__(self) -> None:
self.dev = SpiDevice()
self._transfer = self._transfer_spidev
self._transfer_raw: Callable[[SpiDevice, int, bytes, int, int, bool], bytes] = self._transfer_spidev
if "KERN" in os.environ:
self._transfer = self._transfer_kernel_driver
self._transfer_raw = self._transfer_kernel_driver
self.tx_buf = bytearray(1024)
self.rx_buf = bytearray(1024)
@@ -134,25 +134,65 @@ class PandaSpiHandle(BaseHandle):
self.fileno = self.dev._spidev.fileno()
# helpers
def _calc_checksum(self, data: List[int]) -> int:
def _calc_checksum(self, data: bytes) -> int:
cksum = CHECKSUM_START
for b in data:
cksum ^= b
return cksum
def _wait_for_ack(self, spi, ack_val: int, timeout: int, tx: int) -> None:
def _wait_for_ack(self, spi, ack_val: int, timeout: int, tx: int, length: int = 1) -> bytes:
timeout_s = max(MIN_ACK_TIMEOUT_MS, timeout) * 1e-3
start = time.monotonic()
while (timeout == 0) or ((time.monotonic() - start) < timeout_s):
dat = spi.xfer2([tx, ])[0]
if dat == NACK:
dat = spi.xfer2([tx, ] * length)
if dat[0] == NACK:
raise PandaSpiNackResponse
elif dat == ack_val:
return
elif dat[0] == ack_val:
return bytes(dat)
raise PandaSpiMissingAck
def _transfer_spidev(self, spi, endpoint: int, data, timeout: int, max_rx_len: int = 1000, expect_disconnect: bool = False) -> bytes:
max_rx_len = max(USBPACKET_MAX_SIZE, max_rx_len)
logging.debug("- send header")
packet = struct.pack("<BBHH", SYNC, endpoint, len(data), max_rx_len)
packet += bytes([self._calc_checksum(packet), ])
spi.xfer2(packet)
logging.debug("- waiting for header ACK")
self._wait_for_ack(spi, HACK, MIN_ACK_TIMEOUT_MS, 0x11)
# send data
logging.debug("- sending data")
packet = bytes([*data, self._calc_checksum(data)])
spi.xfer2(packet)
if expect_disconnect:
logging.debug("- expecting disconnect, returning")
return b""
else:
logging.debug("- waiting for data ACK")
preread_len = USBPACKET_MAX_SIZE + 1 # read enough for a controlRead
dat = self._wait_for_ack(spi, DACK, timeout, 0x13, length=3 + preread_len)
# get response length, then response
response_len = struct.unpack("<H", dat[1:3])[0]
if response_len > max_rx_len:
raise PandaSpiException(f"response length greater than max ({max_rx_len} {response_len})")
# read rest
remaining = response_len - preread_len
if remaining > 0:
dat += bytes(spi.readbytes(remaining))
dat = dat[:3 + response_len + 1]
if self._calc_checksum(dat) != 0:
raise PandaSpiBadChecksum
return dat[3:-1]
def _transfer_kernel_driver(self, spi, endpoint: int, data, timeout: int, max_rx_len: int = 1000, expect_disconnect: bool = False) -> bytes:
self.tx_buf[:len(data)] = data
self.ioctl_data.endpoint = endpoint
@@ -169,7 +209,7 @@ class PandaSpiHandle(BaseHandle):
raise PandaSpiException(f"ioctl returned {ret}")
return bytes(self.rx_buf[:ret])
def _transfer_spidev(self, spi, endpoint: int, data, timeout: int, max_rx_len: int = 1000, expect_disconnect: bool = False) -> bytes:
def _transfer(self, endpoint: int, data, timeout: int, max_rx_len: int = 1000, expect_disconnect: bool = False) -> bytes:
logging.debug("starting transfer: endpoint=%d, max_rx_len=%d", endpoint, max_rx_len)
logging.debug("==============================================")
@@ -179,44 +219,12 @@ class PandaSpiHandle(BaseHandle):
while (time.monotonic() - start_time) < timeout*1e-3:
n += 1
logging.debug("\ntry #%d", n)
try:
logging.debug("- send header")
packet = struct.pack("<BBHH", SYNC, endpoint, len(data), max_rx_len)
packet += bytes([reduce(lambda x, y: x^y, packet) ^ CHECKSUM_START])
spi.xfer2(packet)
to = timeout - (time.monotonic() - start_time)*1e3
logging.debug("- waiting for header ACK")
self._wait_for_ack(spi, HACK, int(to), 0x11)
# send data
logging.debug("- sending data")
packet = bytes([*data, self._calc_checksum(data)])
spi.xfer2(packet)
if expect_disconnect:
logging.debug("- expecting disconnect, returning")
return b""
else:
to = timeout - (time.monotonic() - start_time)*1e3
logging.debug("- waiting for data ACK")
self._wait_for_ack(spi, DACK, int(to), 0x13)
# get response length, then response
response_len_bytes = bytes(spi.xfer2(b"\x00" * 2))
response_len = struct.unpack("<H", response_len_bytes)[0]
if response_len > max_rx_len:
raise PandaSpiException("response length greater than max")
logging.debug("- receiving response")
dat = bytes(spi.xfer2(b"\x00" * (response_len + 1)))
if self._calc_checksum([DACK, *response_len_bytes, *dat]) != 0:
raise PandaSpiBadChecksum
return dat[:-1]
except PandaSpiException as e:
exc = e
logging.debug("SPI transfer failed, retrying", exc_info=True)
with self.dev.acquire() as spi:
try:
return self._transfer_raw(spi, endpoint, data, timeout, max_rx_len, expect_disconnect)
except PandaSpiException as e:
exc = e
logging.debug("SPI transfer failed, retrying", exc_info=True)
raise exc
@@ -261,28 +269,24 @@ class PandaSpiHandle(BaseHandle):
self.dev.close()
def controlWrite(self, request_type: int, request: int, value: int, index: int, data, timeout: int = TIMEOUT, expect_disconnect: bool = False):
with self.dev.acquire() as spi:
return self._transfer(spi, 0, struct.pack("<BHHH", request, value, index, 0), timeout, expect_disconnect=expect_disconnect)
return self._transfer(0, struct.pack("<BHHH", request, value, index, 0), timeout, expect_disconnect=expect_disconnect)
def controlRead(self, request_type: int, request: int, value: int, index: int, length: int, timeout: int = TIMEOUT):
with self.dev.acquire() as spi:
return self._transfer(spi, 0, struct.pack("<BHHH", request, value, index, length), timeout)
return self._transfer(0, struct.pack("<BHHH", request, value, index, length), timeout, max_rx_len=length)
# TODO: implement these properly
def bulkWrite(self, endpoint: int, data: List[int], timeout: int = TIMEOUT) -> int:
with self.dev.acquire() as spi:
for x in range(math.ceil(len(data) / XFER_SIZE)):
self._transfer(spi, endpoint, data[XFER_SIZE*x:XFER_SIZE*(x+1)], timeout)
return len(data)
for x in range(math.ceil(len(data) / XFER_SIZE)):
self._transfer(endpoint, data[XFER_SIZE*x:XFER_SIZE*(x+1)], timeout)
return len(data)
def bulkRead(self, endpoint: int, length: int, timeout: int = TIMEOUT) -> bytes:
ret: List[int] = []
with self.dev.acquire() as spi:
for _ in range(math.ceil(length / XFER_SIZE)):
d = self._transfer(spi, endpoint, [], timeout, max_rx_len=XFER_SIZE)
ret += d
if len(d) < XFER_SIZE:
break
for _ in range(math.ceil(length / XFER_SIZE)):
d = self._transfer(endpoint, [], timeout, max_rx_len=XFER_SIZE)
ret += d
if len(d) < XFER_SIZE:
break
return bytes(ret)