PandaDFU: abstract out USB comms (#1274)

* wip

* revert that

* split list + connect

* some more

* mypy fix

* add clear status back

* rename

* cleanup

* cleaner mypy fix

---------

Co-authored-by: Comma Device <device@comma.ai>
This commit is contained in:
Adeeb Shihadeh 2023-03-06 09:24:00 -08:00 committed by GitHub
parent 946f952aa7
commit 18230831f3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 143 additions and 68 deletions

View File

@ -1,9 +1,12 @@
from abc import ABC, abstractmethod
from typing import List
from typing import List, Optional
# This mimics the handle given by libusb1 for easy interoperability
class BaseHandle(ABC):
"""
A handle to talk to a panda.
Borrows heavily from the libusb1 handle API.
"""
@abstractmethod
def close(self) -> None:
...
@ -23,3 +26,30 @@ class BaseHandle(ABC):
@abstractmethod
def bulkRead(self, endpoint: int, length: int, timeout: int = 0) -> bytes:
...
class BaseSTBootloaderHandle(ABC):
"""
A handle to talk to a panda while it's in the STM32 bootloader.
"""
@abstractmethod
def close(self) -> None:
...
@abstractmethod
def clear_status(self) -> None:
...
@abstractmethod
def program(self, address: int, dat: bytes, block_size: Optional[int] = None) -> None:
...
@abstractmethod
def erase(self, address: int) -> None:
...
@abstractmethod
def jump(self, address: int) -> None:
...

View File

@ -1,20 +1,29 @@
import usb1
import struct
import binascii
from typing import List, Tuple, Optional
from .base import BaseSTBootloaderHandle
from .usb import STBootloaderUSBHandle
from .constants import McuType
# *** DFU mode ***
DFU_DNLOAD = 1
DFU_UPLOAD = 2
DFU_GETSTATUS = 3
DFU_CLRSTATUS = 4
DFU_ABORT = 6
class PandaDFU:
def __init__(self, dfu_serial):
self._handle = None
def __init__(self, dfu_serial: Optional[str]):
# try USB, then SPI
handle, mcu_type = PandaDFU.usb_connect(dfu_serial)
if None in (handle, mcu_type):
handle, mcu_type = PandaDFU.spi_connect(dfu_serial)
if handle is None or mcu_type is None:
raise Exception(f"failed to open DFU device {dfu_serial}")
self._handle: BaseSTBootloaderHandle = handle
self._mcu_type: McuType = mcu_type
@staticmethod
def usb_connect(dfu_serial: Optional[str]) -> Tuple[Optional[BaseSTBootloaderHandle], Optional[McuType]]:
handle, mcu_type = None, None
context = usb1.USBContext()
for device in context.getDeviceList(skip_on_error=True):
if device.getVendorID() == 0x0483 and device.getProductID() == 0xdf11:
@ -23,15 +32,26 @@ class PandaDFU:
except Exception:
continue
if this_dfu_serial == dfu_serial or dfu_serial is None:
self._handle = device.open()
self._mcu_type = self.get_mcu_type(device)
handle = STBootloaderUSBHandle(device.open())
# TODO: Find a way to detect F4 vs F2
# TODO: also check F4 BCD, don't assume in else
mcu_type = McuType.H7 if device.getbcdDevice() == 512 else McuType.F4
break
if self._handle is None:
raise Exception(f"failed to open DFU device {dfu_serial}")
return handle, mcu_type
@staticmethod
def list():
def spi_connect(dfu_serial: Optional[str]) -> Tuple[Optional[BaseSTBootloaderHandle], Optional[McuType]]:
return None, None
@staticmethod
def list() -> List[str]:
ret = PandaDFU.usb_list()
ret += PandaDFU.spi_list()
return list(set(ret))
@staticmethod
def usb_list() -> List[str]:
context = usb1.USBContext()
dfu_serials = []
try:
@ -45,6 +65,10 @@ class PandaDFU:
pass
return dfu_serials
@staticmethod
def spi_list() -> List[str]:
return []
@staticmethod
def st_serial_to_dfu_serial(st: str, mcu_type: McuType = McuType.F4):
if st is None or st == "none":
@ -55,52 +79,20 @@ class PandaDFU:
else:
return binascii.hexlify(struct.pack("!HHH", uid_base[1] + uid_base[5], uid_base[0] + uid_base[4] + 0xA, uid_base[3])).upper().decode("utf-8")
def get_mcu_type(self, dev) -> McuType:
# TODO: Find a way to detect F4 vs F2
# TODO: also check F4 BCD, don't assume in else
return McuType.H7 if dev.getbcdDevice() == 512 else McuType.F4
def get_mcu_type(self) -> McuType:
return self._mcu_type
def status(self):
while 1:
dat = self._handle.controlRead(0x21, DFU_GETSTATUS, 0, 0, 6)
if dat[1] == 0:
break
def erase(self, address: int) -> None:
self._handle.erase(address)
def clear_status(self):
# Clear status
stat = self._handle.controlRead(0x21, DFU_GETSTATUS, 0, 0, 6)
if stat[4] == 0xa:
self._handle.controlRead(0x21, DFU_CLRSTATUS, 0, 0, 0)
elif stat[4] == 0x9:
self._handle.controlWrite(0x21, DFU_ABORT, 0, 0, b"")
self.status()
stat = str(self._handle.controlRead(0x21, DFU_GETSTATUS, 0, 0, 6))
def erase(self, address):
self._handle.controlWrite(0x21, DFU_DNLOAD, 0, 0, b"\x41" + struct.pack("I", address))
self.status()
def program(self, address, dat, block_size=None):
if block_size is None:
block_size = len(dat)
# Set Address Pointer
self._handle.controlWrite(0x21, DFU_DNLOAD, 0, 0, b"\x21" + struct.pack("I", address))
self.status()
# Program
dat += b"\xFF" * ((block_size - len(dat)) % block_size)
for i in range(0, len(dat) // block_size):
ldat = dat[i * block_size:(i + 1) * block_size]
print("programming %d with length %d" % (i, len(ldat)))
self._handle.controlWrite(0x21, DFU_DNLOAD, 2 + i, 0, ldat)
self.status()
def reset(self):
self._handle.jump(self._mcu_type.config.bootstub_address)
def program_bootstub(self, code_bootstub):
self.clear_status()
self._handle.clear_status()
self.erase(self._mcu_type.config.bootstub_address)
self.erase(self._mcu_type.config.app_address)
self.program(self._mcu_type.config.bootstub_address, code_bootstub, self._mcu_type.config.block_size)
self._handle.program(self._mcu_type.config.bootstub_address, code_bootstub, self._mcu_type.config.block_size)
self.reset()
def recover(self):
@ -108,11 +100,3 @@ class PandaDFU:
code = f.read()
self.program_bootstub(code)
def reset(self):
self._handle.controlWrite(0x21, DFU_DNLOAD, 0, 0, b"\x21" + struct.pack("I", self._mcu_type.config.bootstub_address))
self.status()
try:
self._handle.controlWrite(0x21, DFU_DNLOAD, 2, 0, b"")
_ = str(self._handle.controlRead(0x21, DFU_GETSTATUS, 0, 0, 6))
except Exception:
pass

View File

@ -1,6 +1,8 @@
from typing import List
import struct
from typing import List, Optional
from .base import BaseHandle, BaseSTBootloaderHandle
from .base import BaseHandle
class PandaUsbHandle(BaseHandle):
def __init__(self, libusb_handle):
@ -21,3 +23,62 @@ class PandaUsbHandle(BaseHandle):
def bulkRead(self, endpoint: int, length: int, timeout: int = 0) -> bytes:
return self._libusb_handle.bulkRead(endpoint, length, timeout) # type: ignore
class STBootloaderUSBHandle(BaseSTBootloaderHandle):
DFU_DNLOAD = 1
DFU_UPLOAD = 2
DFU_GETSTATUS = 3
DFU_CLRSTATUS = 4
DFU_ABORT = 6
def __init__(self, libusb_handle):
self._libusb_handle = libusb_handle
def _status(self) -> None:
while 1:
dat = self._libusb_handle.controlRead(0x21, self.DFU_GETSTATUS, 0, 0, 6)
if dat[1] == 0:
break
def clear_status(self):
# Clear status
stat = self._libusb_handle.controlRead(0x21, self.DFU_GETSTATUS, 0, 0, 6)
if stat[4] == 0xa:
self._libusb_handle.controlRead(0x21, self.DFU_CLRSTATUS, 0, 0, 0)
elif stat[4] == 0x9:
self._libusb_handle.controlWrite(0x21, self.DFU_ABORT, 0, 0, b"")
self._status()
stat = str(self._libusb_handle.controlRead(0x21, self.DFU_GETSTATUS, 0, 0, 6))
def close(self):
self._libusb_handle.close()
def program(self, address: int, dat: bytes, block_size: Optional[int] = None) -> None:
if block_size is None:
block_size = len(dat)
# Set Address Pointer
self._libusb_handle.controlWrite(0x21, self.DFU_DNLOAD, 0, 0, b"\x21" + struct.pack("I", address))
self._status()
# Program
dat += b"\xFF" * ((block_size - len(dat)) % block_size)
for i in range(0, len(dat) // block_size):
ldat = dat[i * block_size:(i + 1) * block_size]
print("programming %d with length %d" % (i, len(ldat)))
self._libusb_handle.controlWrite(0x21, self.DFU_DNLOAD, 2 + i, 0, ldat)
self._status()
def erase(self, address):
self._libusb_handle.controlWrite(0x21, self.DFU_DNLOAD, 0, 0, b"\x41" + struct.pack("I", address))
self._status()
def jump(self, address):
self._libusb_handle.controlWrite(0x21, self.DFU_DNLOAD, 0, 0, b"\x21" + struct.pack("I", address))
self._status()
try:
self._libusb_handle.controlWrite(0x21, self.DFU_DNLOAD, 2, 0, b"")
_ = str(self._libusb_handle.controlRead(0x21, self.DFU_GETSTATUS, 0, 0, 6))
except Exception:
pass

View File

@ -12,10 +12,10 @@ def test_dfu(p):
assert Panda.wait_for_dfu(dfu_serial, timeout=20), "failed to enter DFU"
dfu = PandaDFU(dfu_serial)
assert dfu._mcu_type == app_mcu_type
assert dfu.get_mcu_type() == app_mcu_type
assert dfu_serial in PandaDFU.list()
dfu.clear_status()
dfu._handle.clear_status()
dfu.reset()
p.reconnect()