PandaDFU: SPI support (#1270)

* PandaDFU: SPI support

* get mcu type

* program bootstub

* little cleanup

* more cleanup

* connect by dfu serial

* time to remove that

* none

* fix linter

* little more

* catch

---------

Co-authored-by: Comma Device <device@comma.ai>
This commit is contained in:
Adeeb Shihadeh
2023-03-06 21:52:08 -08:00
committed by GitHub
parent 18230831f3
commit efb36197bb
7 changed files with 217 additions and 181 deletions

View File

@@ -1,5 +1,7 @@
from abc import ABC, abstractmethod
from typing import List, Optional
from typing import List
from .constants import McuType
class BaseHandle(ABC):
@@ -33,6 +35,10 @@ class BaseSTBootloaderHandle(ABC):
A handle to talk to a panda while it's in the STM32 bootloader.
"""
@abstractmethod
def get_mcu_type(self) -> McuType:
...
@abstractmethod
def close(self) -> None:
...
@@ -42,11 +48,15 @@ class BaseSTBootloaderHandle(ABC):
...
@abstractmethod
def program(self, address: int, dat: bytes, block_size: Optional[int] = None) -> None:
def program(self, address: int, dat: bytes) -> None:
...
@abstractmethod
def erase(self, address: int) -> None:
def erase_app(self) -> None:
...
@abstractmethod
def erase_bootstub(self) -> None:
...
@abstractmethod

View File

@@ -8,6 +8,7 @@ BASEDIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../")
class McuConfig(NamedTuple):
mcu: str
mcu_idcode: int
uid_address: int
block_size: int
sector_sizes: List[int]
serial_number_address: int
@@ -17,6 +18,7 @@ class McuConfig(NamedTuple):
bootstub_path: str
Fx = (
0x1FFF7A10,
0x800,
[0x4000 for _ in range(4)] + [0x10000] + [0x20000 for _ in range(11)],
0x1FFF79C0,
@@ -31,6 +33,7 @@ F4Config = McuConfig("STM32F4", 0x463, *Fx)
H7Config = McuConfig(
"STM32H7",
0x483,
0x1FF1E800,
0x400,
# there is an 8th sector, but we use that for the provisioning chunk, so don't program over that!
[0x20000 for _ in range(7)],
@@ -50,3 +53,5 @@ class McuType(enum.Enum):
@property
def config(self):
return self.value
MCU_TYPE_BY_IDCODE = {m.config.mcu_idcode: m for m in McuType}

View File

@@ -1,9 +1,10 @@
import usb1
import struct
import binascii
from typing import List, Tuple, Optional
from typing import List, Optional
from .base import BaseSTBootloaderHandle
from .spi import STBootloaderSPIHandle, PandaSpiException
from .usb import STBootloaderUSBHandle
from .constants import McuType
@@ -11,19 +12,20 @@ from .constants import McuType
class PandaDFU:
def __init__(self, dfu_serial: Optional[str]):
# try USB, then SPI
handle, mcu_type = PandaDFU.usb_connect(dfu_serial)
if None in (handle, mcu_type):
handle, mcu_type = PandaDFU.spi_connect(dfu_serial)
handle: Optional[BaseSTBootloaderHandle]
handle = PandaDFU.usb_connect(dfu_serial)
if handle is None:
handle = PandaDFU.spi_connect(dfu_serial)
if handle is None or mcu_type is None:
if handle is None:
raise Exception(f"failed to open DFU device {dfu_serial}")
self._handle: BaseSTBootloaderHandle = handle
self._mcu_type: McuType = mcu_type
self._mcu_type: McuType = self._handle.get_mcu_type()
@staticmethod
def usb_connect(dfu_serial: Optional[str]) -> Tuple[Optional[BaseSTBootloaderHandle], Optional[McuType]]:
handle, mcu_type = None, None
def usb_connect(dfu_serial: Optional[str]) -> Optional[STBootloaderUSBHandle]:
handle = None
context = usb1.USBContext()
for device in context.getDeviceList(skip_on_error=True):
if device.getVendorID() == 0x0483 and device.getProductID() == 0xdf11:
@@ -31,18 +33,28 @@ class PandaDFU:
this_dfu_serial = device.open().getASCIIStringDescriptor(3)
except Exception:
continue
if this_dfu_serial == dfu_serial or dfu_serial is None:
handle = STBootloaderUSBHandle(device.open())
# TODO: Find a way to detect F4 vs F2
# TODO: also check F4 BCD, don't assume in else
mcu_type = McuType.H7 if device.getbcdDevice() == 512 else McuType.F4
handle = STBootloaderUSBHandle(device, device.open())
break
return handle, mcu_type
return handle
@staticmethod
def spi_connect(dfu_serial: Optional[str]) -> Tuple[Optional[BaseSTBootloaderHandle], Optional[McuType]]:
return None, None
def spi_connect(dfu_serial: Optional[str]) -> Optional[STBootloaderSPIHandle]:
handle = None
this_dfu_serial = None
try:
handle = STBootloaderSPIHandle()
this_dfu_serial = PandaDFU.st_serial_to_dfu_serial(handle.get_uid(), handle.get_mcu_type())
except PandaSpiException:
handle = None
if dfu_serial is not None and dfu_serial != this_dfu_serial:
handle = None
return handle
@staticmethod
def list() -> List[str]:
@@ -67,6 +79,13 @@ class PandaDFU:
@staticmethod
def spi_list() -> List[str]:
try:
h = PandaDFU.spi_connect(None)
if h is not None:
dfu_serial = PandaDFU.st_serial_to_dfu_serial(h.get_uid(), h.get_mcu_type())
return [dfu_serial, ]
except PandaSpiException:
pass
return []
@staticmethod
@@ -82,17 +101,14 @@ class PandaDFU:
def get_mcu_type(self) -> McuType:
return self._mcu_type
def erase(self, address: int) -> None:
self._handle.erase(address)
def reset(self):
self._handle.jump(self._mcu_type.config.bootstub_address)
def program_bootstub(self, code_bootstub):
self._handle.clear_status()
self.erase(self._mcu_type.config.bootstub_address)
self.erase(self._mcu_type.config.app_address)
self._handle.program(self._mcu_type.config.bootstub_address, code_bootstub, self._mcu_type.config.block_size)
self._handle.erase_bootstub()
self._handle.erase_app()
self._handle.program(self._mcu_type.config.bootstub_address, code_bootstub)
self.reset()
def recover(self):

View File

@@ -1,3 +1,4 @@
import binascii
import os
import fcntl
import math
@@ -7,9 +8,10 @@ import logging
import threading
from contextlib import contextmanager
from functools import reduce
from typing import List
from typing import List, Optional
from .base import BaseHandle
from .base import BaseHandle, BaseSTBootloaderHandle
from .constants import McuType, MCU_TYPE_BY_IDCODE
try:
import spidev
@@ -172,3 +174,136 @@ class PandaSpiHandle(BaseHandle):
if len(d) < USB_MAX_SIZE:
break
return bytes(ret)
class STBootloaderSPIHandle(BaseSTBootloaderHandle):
"""
Implementation of the STM32 SPI bootloader protocol described in:
https://www.st.com/resource/en/application_note/an4286-spi-protocol-used-in-the-stm32-bootloader-stmicroelectronics.pdf
"""
SYNC = 0x5A
ACK = 0x79
NACK = 0x1F
def __init__(self):
self.dev = SpiDevice(speed=1000000)
# say hello
try:
with self.dev.acquire() as spi:
spi.xfer([self.SYNC, ])
try:
self._get_ack(spi)
except PandaSpiNackResponse:
# NACK ok here, will only ACK the first time
pass
self._mcu_type = MCU_TYPE_BY_IDCODE[self.get_chip_id()]
except PandaSpiException:
raise PandaSpiException("failed to connect to panda") # pylint: disable=W0707
def _get_ack(self, spi, timeout=1.0):
data = 0x00
start_time = time.monotonic()
while data not in (self.ACK, self.NACK) and (time.monotonic() - start_time < timeout):
data = spi.xfer([0x00, ])[0]
time.sleep(0.001)
spi.xfer([self.ACK, ])
if data == self.NACK:
raise PandaSpiNackResponse
elif data != self.ACK:
raise PandaSpiMissingAck
def _cmd(self, cmd: int, data: Optional[List[bytes]] = None, read_bytes: int = 0, predata=None) -> bytes:
ret = b""
with self.dev.acquire() as spi:
# sync + command
spi.xfer([self.SYNC, ])
spi.xfer([cmd, cmd ^ 0xFF])
self._get_ack(spi)
# "predata" - for commands that send the first data without a checksum
if predata is not None:
spi.xfer(predata)
self._get_ack(spi)
# send data
if data is not None:
for d in data:
if predata is not None:
spi.xfer(d + self._checksum(predata + d))
else:
spi.xfer(d + self._checksum(d))
self._get_ack(spi, timeout=20)
# receive
if read_bytes > 0:
ret = spi.xfer([0x00, ]*(read_bytes + 1))[1:]
if data is None or len(data) == 0:
self._get_ack(spi)
return bytes(ret)
def _checksum(self, data: bytes) -> bytes:
if len(data) == 1:
ret = data[0] ^ 0xFF
else:
ret = reduce(lambda a, b: a ^ b, data)
return bytes([ret, ])
# *** Bootloader commands ***
def read(self, address: int, length: int):
data = [struct.pack('>I', address), struct.pack('B', length - 1)]
return self._cmd(0x11, data=data, read_bytes=length)
def get_chip_id(self) -> int:
r = self._cmd(0x02, read_bytes=3)
assert r[0] == 1 # response length - 1
return ((r[1] << 8) + r[2])
def go_cmd(self, address: int) -> None:
self._cmd(0x21, data=[struct.pack('>I', address), ])
# *** helpers ***
def get_uid(self):
dat = self.read(McuType.H7.config.uid_address, 12)
return binascii.hexlify(dat).decode()
def erase_sector(self, sector: int):
p = struct.pack('>H', 0) # number of sectors to erase
d = struct.pack('>H', sector)
self._cmd(0x44, data=[d, ], predata=p)
# *** PandaDFU API ***
def erase_app(self):
self.erase_sector(1)
def erase_bootstub(self):
self.erase_sector(0)
def get_mcu_type(self):
return self._mcu_type
def clear_status(self):
pass
def close(self):
self.dev.close()
def program(self, address, dat):
bs = 256 # max block size for writing to flash over SPI
dat += b"\xFF" * ((bs - len(dat)) % bs)
for i in range(0, len(dat) // bs):
block = dat[i * bs:(i + 1) * bs]
self._cmd(0x31, data=[
struct.pack('>I', address + i*bs),
bytes([len(block) - 1]) + block,
])
def jump(self, address):
self.go_cmd(self._mcu_type.config.bootstub_address)

View File

@@ -1,118 +0,0 @@
import time
import struct
from functools import reduce
from .constants import McuType
from .spi import SpiDevice
SYNC = 0x5A
ACK = 0x79
NACK = 0x1F
# https://www.st.com/resource/en/application_note/an4286-spi-protocol-used-in-the-stm32-bootloader-stmicroelectronics.pdf
class PandaSpiDFU:
def __init__(self, dfu_serial):
self.dev = SpiDevice(speed=1000000)
# say hello
with self.dev.acquire() as spi:
try:
spi.xfer([SYNC, ])
self._get_ack(spi)
except Exception:
raise Exception("failed to connect to panda") # pylint: disable=W0707
self._mcu_type = self.get_mcu_type()
def _get_ack(self, spi, timeout=1.0):
data = 0x00
start_time = time.monotonic()
while data not in (ACK, NACK) and (time.monotonic() - start_time < timeout):
data = spi.xfer([0x00, ])[0]
time.sleep(0.001)
spi.xfer([ACK, ])
if data == NACK:
raise Exception("Got NACK response")
elif data != ACK:
raise Exception("Missing ACK")
def _cmd(self, cmd, data=None, read_bytes=0) -> bytes:
ret = b""
with self.dev.acquire() as spi:
# sync
spi.xfer([SYNC, ])
# send command
spi.xfer([cmd, cmd ^ 0xFF])
self._get_ack(spi)
# send data
if data is not None:
for d in data:
spi.xfer(self.add_checksum(d))
self._get_ack(spi, timeout=20)
# receive
if read_bytes > 0:
# send busy byte
ret = spi.xfer([0x00, ]*(read_bytes + 1))[1:]
self._get_ack(spi)
return ret
def add_checksum(self, data):
return data + bytes([reduce(lambda a, b: a ^ b, data)])
# ***** ST Bootloader functions *****
def get_bootloader_version(self) -> int:
return self._cmd(0x01, read_bytes=1)[0]
def get_id(self) -> int:
ret = self._cmd(0x02, read_bytes=3)
assert ret[0] == 1
return ((ret[1] << 8) + ret[2])
def go_cmd(self, address: int) -> None:
self._cmd(0x21, data=[struct.pack('>I', address), ])
def erase(self, address: int) -> None:
d = struct.pack('>H', address)
self._cmd(0x44, data=[d, ])
# ***** panda api *****
def get_mcu_type(self) -> McuType:
mcu_by_id = {mcu.config.mcu_idcode: mcu for mcu in McuType}
return mcu_by_id[self.get_id()]
def global_erase(self):
self.erase(0xFFFF)
def program_file(self, address, fn):
with open(fn, 'rb') as f:
code = f.read()
i = 0
while i < len(code):
#print(i, len(code))
block = code[i:i+256]
if len(block) < 256:
block += b'\xFF' * (256 - len(block))
self._cmd(0x31, data=[
struct.pack('>I', address + i),
bytes([len(block) - 1]) + block,
])
#print(f"Written {len(block)} bytes to {hex(address + i)}")
i += 256
def program_bootstub(self):
self.program_file(self._mcu_type.config.bootstub_address, self._mcu_type.config.bootstub_path)
def program_app(self):
self.program_file(self._mcu_type.config.app_address, self._mcu_type.config.app_path)
def reset(self):
self.go_cmd(self._mcu_type.config.bootstub_address)

View File

@@ -1,8 +1,8 @@
import struct
from typing import List, Optional
from typing import List
from .base import BaseHandle, BaseSTBootloaderHandle
from .constants import McuType
class PandaUsbHandle(BaseHandle):
def __init__(self, libusb_handle):
@@ -32,15 +32,32 @@ class STBootloaderUSBHandle(BaseSTBootloaderHandle):
DFU_CLRSTATUS = 4
DFU_ABORT = 6
def __init__(self, libusb_handle):
def __init__(self, libusb_device, libusb_handle):
self._libusb_handle = libusb_handle
# TODO: Find a way to detect F4 vs F2
# TODO: also check F4 BCD, don't assume in else
self._mcu_type = McuType.H7 if libusb_device.getbcdDevice() == 512 else McuType.F4
def _status(self) -> None:
while 1:
dat = self._libusb_handle.controlRead(0x21, self.DFU_GETSTATUS, 0, 0, 6)
if dat[1] == 0:
break
def _erase_page_address(self, address: int) -> None:
self._libusb_handle.controlWrite(0x21, self.DFU_DNLOAD, 0, 0, b"\x41" + struct.pack("I", address))
self._status()
def get_mcu_type(self):
return self._mcu_type
def erase_app(self):
self._erase_page_address(self._mcu_type.config.app_address)
def erase_bootstub(self):
self._erase_page_address(self._mcu_type.config.bootstub_address)
def clear_status(self):
# Clear status
stat = self._libusb_handle.controlRead(0x21, self.DFU_GETSTATUS, 0, 0, 6)
@@ -54,26 +71,20 @@ class STBootloaderUSBHandle(BaseSTBootloaderHandle):
def close(self):
self._libusb_handle.close()
def program(self, address: int, dat: bytes, block_size: Optional[int] = None) -> None:
if block_size is None:
block_size = len(dat)
def program(self, address, dat):
# Set Address Pointer
self._libusb_handle.controlWrite(0x21, self.DFU_DNLOAD, 0, 0, b"\x21" + struct.pack("I", address))
self._status()
# Program
dat += b"\xFF" * ((block_size - len(dat)) % block_size)
for i in range(0, len(dat) // block_size):
ldat = dat[i * block_size:(i + 1) * block_size]
bs = self._mcu_type.config.block_size
dat += b"\xFF" * ((bs - len(dat)) % bs)
for i in range(0, len(dat) // bs):
ldat = dat[i * bs:(i + 1) * bs]
print("programming %d with length %d" % (i, len(ldat)))
self._libusb_handle.controlWrite(0x21, self.DFU_DNLOAD, 2 + i, 0, ldat)
self._status()
def erase(self, address):
self._libusb_handle.controlWrite(0x21, self.DFU_DNLOAD, 0, 0, b"\x41" + struct.pack("I", address))
self._status()
def jump(self, address):
self._libusb_handle.controlWrite(0x21, self.DFU_DNLOAD, 0, 0, b"\x21" + struct.pack("I", address))
self._status()