diff --git a/CHANGELOG.md b/CHANGELOG.md index 4d2c6b9..0ea76cf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Added `pybricksdev oad info` command. +- Added `pybricksdev oad flash` command. ## [1.0.0-alpha.50] - 2024-07-01 diff --git a/pybricksdev/ble/oad/__init__.py b/pybricksdev/ble/oad/__init__.py index dbcadd4..d387859 100644 --- a/pybricksdev/ble/oad/__init__.py +++ b/pybricksdev/ble/oad/__init__.py @@ -7,9 +7,18 @@ https://software-dl.ti.com/lprf/sdg-latest/html/oad-ble-stack-3.x/oad_profile.html """ -from ._common import oad_uuid +from ._common import OADReturn, oad_uuid +from .control_point import OADControlPoint +from .image_block import OADImageBlock +from .image_identify import OADImageIdentify -__all__ = ["OAD_SERVICE_UUID"] +__all__ = [ + "OAD_SERVICE_UUID", + "OADReturn", + "OADImageBlock", + "OADControlPoint", + "OADImageIdentify", +] OAD_SERVICE_UUID = oad_uuid(0xFFC0) """OAD service UUID.""" diff --git a/pybricksdev/ble/oad/_common.py b/pybricksdev/ble/oad/_common.py index 66fe884..4822ad5 100644 --- a/pybricksdev/ble/oad/_common.py +++ b/pybricksdev/ble/oad/_common.py @@ -2,8 +2,198 @@ # Copyright (c) 2024 The Pybricks Authors +import struct +from enum import IntEnum +from typing import NamedTuple + + def oad_uuid(uuid16: int) -> str: """ Converts a 16-bit UUID to the TI OAD 128-bit UUID format. """ return "f000{:04x}-0451-4000-b000-000000000000".format(uuid16) + + +IMAGE_ID_TI = " OAD IMG" # leading space is intentional +IMAGE_ID_LEGO = "LEGO 132" + + +class ImageType(IntEnum): + PERSISTENT_APP = 0x00 + APPLICATION = 0x01 + STACK = 0x02 + APP_AND_STACK = 0x03 + NETWORK_PROCESSOR = 0x04 + BLE_FACTORY_IMAGE = 0x05 + BIM = 0x06 + MERGED = 0x07 + + USER_0F = 0x0F + USER_10 = 0x10 + USER_11 = 0x11 + USER_12 = 0x12 + USER_13 = 0x13 + USER_14 = 0x14 + USER_15 = 0x15 + USER_16 = 0x16 + USER_17 = 0x17 + USER_18 = 0x18 + USER_19 = 0x19 + USER_1A = 0x1A + USER_1B = 0x1B + USER_1C = 0x1C + USER_1D = 0x1D + USER_1E = 0x1E + USER_1F = 0x1F + + HOST_20 = 0x20 + HOST_21 = 0x21 + HOST_22 = 0x22 + HOST_23 = 0x23 + HOST_24 = 0x24 + HOST_25 = 0x25 + HOST_26 = 0x26 + HOST_27 = 0x27 + HOST_28 = 0x28 + HOST_29 = 0x29 + HOST_2A = 0x2A + HOST_2B = 0x2B + HOST_2C = 0x2C + HOST_2D = 0x2D + HOST_2E = 0x2E + HOST_2F = 0x2F + HOST_30 = 0x30 + HOST_31 = 0x31 + HOST_32 = 0x32 + HOST_33 = 0x33 + HOST_34 = 0x34 + HOST_35 = 0x35 + HOST_36 = 0x36 + HOST_37 = 0x37 + HOST_38 = 0x38 + HOST_39 = 0x39 + HOST_3A = 0x3A + HOST_3B = 0x3B + HOST_3C = 0x3C + HOST_3D = 0x3D + HOST_3E = 0x3E + HOST_3F = 0x3F + + +class ImageCopyStatus(IntEnum): + DEFAULT_STATUS = 0xFF + IMAGE_TO_BE_COPIED = 0xFE + IMAGE_COPIED = 0xFC + + +class CRCStatus(IntEnum): + INVALID = 0b00 + VALID = 0b01 + NOT_CALCULATED = 0b11 + + UNKNOWN = 0xFF + + +DEFAULT_IMAGE_NUMBER = 0xFF + + +class ImageInfo(NamedTuple): + copy_status: ImageCopyStatus + crc_status: CRCStatus + image_type: ImageType + image_num: int + + @staticmethod + def from_bytes(data: bytes) -> "ImageInfo": + if len(data) != 4: + raise ValueError("Expected 4 bytes") + + return ImageInfo( + ImageCopyStatus(data[0]), + CRCStatus(data[1]), + ImageType(data[2]), + data[3], + ) + + def __bytes__(self): + return struct.pack( + " int: + return ((version // 10) << 4) | (version % 10) + + +def _decode_version(v: int) -> int: + return (v >> 4) * 10 + (v & 0x0F) + + +class SoftwareVersion(NamedTuple): + app: Version + stack: Version + + @staticmethod + def from_bytes(data: bytes) -> "SoftwareVersion": + if len(data) != 4: + raise ValueError("Expected 4 bytes") + + return SoftwareVersion( + Version(_decode_version(data[0]), _decode_version(data[1])), + Version(_decode_version(data[2]), _decode_version(data[3])), + ) + + def __bytes__(self): + return struct.pack( + "<4B", + _encode_version(self.app.major), + _encode_version(self.app.minor), + _encode_version(self.stack.major), + _encode_version(self.stack.minor), + ) + + +class OADReturn(IntEnum): + SUCCESS = 0 + """OAD succeeded""" + CRC_ERR = 1 + """The downloaded image’s CRC doesn’t match the one expected from the metadata""" + FLASH_ERR = 2 + """Flash function failure such as flashOpen/flashRead/flash write/flash erase""" + BUFFER_OFL = 3 + """The block number of the received packet doesn’t match the one requested, an overflow has occurred.""" + ALREADY_STARTED = 4 + """OAD start command received, while OAD is already is progress""" + NOT_STARTED = 5 + """OAD data block received with OAD start process""" + DL_NOT_COMPLETE = 6 + """OAD enable command received without complete OAD image download""" + NO_RESOURCES = 7 + """Memory allocation fails/ used only for backward compatibility""" + IMAGE_TOO_BIG = 8 + """Image is too big""" + INCOMPATIBLE_IMAGE = 9 + """Stack and flash boundary mismatch, program entry mismatch""" + INVALID_FILE = 10 + """Invalid image ID received""" + INCOMPATIBLE_FILE = 11 + """BIM/image header/firmware version mismatch""" + AUTH_FAIL = 12 + """Start OAD process / Image Identify message/image payload authentication/validation fail""" + EXT_NOT_SUPPORTED = 13 + """Data length extension or OAD control point characteristic not supported""" + DL_COMPLETE = 14 + """OAD image payload download complete""" + CCCD_NOT_ENABLED = 15 + """Internal (target side) error code used to halt the process if a CCCD has not been enabled""" + IMG_ID_TIMEOUT = 16 + """OAD Image ID has been tried too many times and has timed out. Device will disconnect.""" diff --git a/pybricksdev/ble/oad/control_point.py b/pybricksdev/ble/oad/control_point.py index 1246bad..fa94fb1 100644 --- a/pybricksdev/ble/oad/control_point.py +++ b/pybricksdev/ble/oad/control_point.py @@ -2,12 +2,12 @@ # Copyright (c) 2024 The Pybricks Authors import asyncio -import struct from enum import IntEnum +from typing import AsyncGenerator from bleak import BleakClient -from ._common import oad_uuid +from ._common import OADReturn, SoftwareVersion, oad_uuid __all__ = ["OADControlPoint"] @@ -31,45 +31,11 @@ class CmdId(IntEnum): ERASE_ALL_BONDS = 0x13 -class OADReturn(IntEnum): - SUCCESS = 0 - """OAD succeeded""" - CRC_ERR = 1 - """The downloaded image’s CRC doesn’t match the one expected from the metadata""" - FLASH_ERR = 2 - """Flash function failure such as flashOpen/flashRead/flash write/flash erase""" - BUFFER_OFL = 3 - """The block number of the received packet doesn’t match the one requested, an overflow has occurred.""" - ALREADY_STARTED = 4 - """OAD start command received, while OAD is already is progress""" - NOT_STARTED = 5 - """OAD data block received with OAD start process""" - DL_NOT_COMPLETE = 6 - """OAD enable command received without complete OAD image download""" - NO_RESOURCES = 7 - """Memory allocation fails/ used only for backward compatibility""" - IMAGE_TOO_BIG = 8 - """Image is too big""" - INCOMPATIBLE_IMAGE = 9 - """Stack and flash boundary mismatch, program entry mismatch""" - INVALID_FILE = 10 - """Invalid image ID received""" - INCOMPATIBLE_FILE = 11 - """BIM/image header/firmware version mismatch""" - AUTH_FAIL = 12 - """Start OAD process / Image Identify message/image payload authentication/validation fail""" - EXT_NOT_SUPPORTED = 13 - """Data length extension or OAD control point characteristic not supported""" - DL_COMPLETE = 14 - """OAD image payload download complete""" - CCCD_NOT_ENABLED = 15 - """Internal (target side) error code used to halt the process if a CCCD has not been enabled""" - IMG_ID_TIMEOUT = 16 - """OAD Image ID has been tried too many times and has timed out. Device will disconnect.""" - - -def _decode_version(v: int) -> int: - return (v >> 4) * 10 + (v & 0x0F) +OAD_LEGO_MARIO_DEVICE_TYPE = 0xFF150409 +"""Device type for LEGO Mario and friends.""" + +OAD_LEGO_TECHNIC_MOVE_DEVICE_TYPE = 0xFF160409 +"""Device type for LEGO Technic Move Hub.""" class OADControlPoint: @@ -91,7 +57,7 @@ def _notification_handler(self, sender, data): async def _send_command(self, cmd_id: CmdId, payload: bytes = b""): await self._client.write_gatt_char( - OAD_CONTROL_POINT_CHAR_UUID, bytes([cmd_id]) + payload + OAD_CONTROL_POINT_CHAR_UUID, bytes([cmd_id]) + payload, response=False ) rsp = await self._queue.get() @@ -129,18 +95,28 @@ async def set_image_count(self, count: int) -> OADReturn: return OADReturn(rsp[0]) - async def start_oad_process(self) -> int: + async def start_oad_process(self) -> AsyncGenerator[tuple[OADReturn, int], None]: """ Start the OAD process. Returns: Block Number """ - rsp = await self._send_command(CmdId.START_OAD_PROCESS) + await self._client.write_gatt_char( + OAD_CONTROL_POINT_CHAR_UUID, + bytes([CmdId.START_OAD_PROCESS]), + response=False, + ) - if len(rsp) != 4: - raise RuntimeError(f"Unexpected response: {rsp.hex(':')}") + while True: + rsp = await self._queue.get() - return int.from_bytes(rsp, "little") + if len(rsp) != 6 or rsp[0] != CmdId.IMAGE_BLOCK_WRITE_CHAR: + raise RuntimeError(f"Unexpected response: {rsp.hex(':')}") + + status = OADReturn(rsp[1]) + block_num = int.from_bytes(rsp[2:], "little") + + yield status, block_num async def enable_oad_image(self) -> OADReturn: """ @@ -182,7 +158,7 @@ async def disable_oad_image_block_write(self) -> OADReturn: return OADReturn(rsp[0]) - async def get_software_version(self) -> tuple[tuple[int, int], tuple[int, int]]: + async def get_software_version(self) -> SoftwareVersion: """ Get the software version. @@ -193,10 +169,7 @@ async def get_software_version(self) -> tuple[tuple[int, int], tuple[int, int]]: if len(rsp) != 4: raise RuntimeError(f"Unexpected response: {rsp.hex(':')}") - return ( - (_decode_version(rsp[0]), _decode_version(rsp[1])), - (_decode_version(rsp[2]), _decode_version(rsp[3])), - ) + return SoftwareVersion.from_bytes(rsp) async def get_oad_image_status(self) -> OADReturn: """ @@ -237,21 +210,6 @@ async def get_device_type(self) -> int: return int.from_bytes(rsp, "little") - async def image_block_write(self, prev_status: int, block_num: int) -> None: - """ - Write an image block. - - Args: - prev_status: Status of the previous block received - block_num: Block number - """ - rsp = await self._send_command( - CmdId.IMAGE_BLOCK_WRITE_CHAR, struct.pack(" OADReturn: """ Erase all bonds. diff --git a/pybricksdev/ble/oad/firmware.py b/pybricksdev/ble/oad/firmware.py new file mode 100644 index 0000000..42752cf --- /dev/null +++ b/pybricksdev/ble/oad/firmware.py @@ -0,0 +1,63 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2024 The Pybricks Authors + +import struct +from typing import NamedTuple + +from ._common import ImageInfo, SoftwareVersion + +# More info at: +# https://github.com/TexasInstruments/simplelink-lowpower-f3-sdk/blob/main/tools/common/oad/oad_image_tool.py + + +class ODAHeader(NamedTuple): + image_id: str + image_crc: int + bmi_version: int + header_version: int + wireless_tech: int + image_info: ImageInfo + image_validation: int + image_length: int + program_entry_address: int + software_version: int + image_end_address: int + image_header_length: int + rfu2: int + + +def parse_oad_header(firmware: bytes) -> ODAHeader: + ( + image_id, + image_crc, + bmi_version, + header_version, + wireless_tech, + image_info, + image_validation, + image_length, + program_entry_address, + software_version, + image_end_address, + image_header_length, + rfu2, + ) = struct.unpack_from( + "<8sI2BH4s3I4sI2H", + firmware, + ) + + return ODAHeader( + image_id.decode("ascii"), + image_crc, + bmi_version, + header_version, + wireless_tech, + ImageInfo.from_bytes(image_info), + image_validation, + image_length, + program_entry_address, + SoftwareVersion.from_bytes(software_version), + image_end_address, + image_header_length, + rfu2, + ) diff --git a/pybricksdev/ble/oad/image_block.py b/pybricksdev/ble/oad/image_block.py new file mode 100644 index 0000000..4b454bc --- /dev/null +++ b/pybricksdev/ble/oad/image_block.py @@ -0,0 +1,38 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2024 The Pybricks Authors + +""" +TI OAD (Over-the-Air Download) Image Block characteristic. + +https://software-dl.ti.com/lprf/sdg-latest/html/oad-ble-stack-3.x/oad_profile.html#oad-image-block-characteristic-0xffc2 +""" + + +from bleak import BleakClient + +from ._common import oad_uuid + +__all__ = ["OADImageBlock"] + +OAD_IMAGE_BLOCK_CHAR_UUID = oad_uuid(0xFFC2) + + +class OADImageBlock: + def __init__(self, client: BleakClient): + self._client = client + + async def write(self, block_num: int, data: bytes) -> None: + """ + Write an image block. + + Args: + offset: Offset of the block. + data: Block data. + + Returns: None. + """ + await self._client.write_gatt_char( + OAD_IMAGE_BLOCK_CHAR_UUID, + block_num.to_bytes(4, "little") + data, + response=False, + ) diff --git a/pybricksdev/ble/oad/image_identify.py b/pybricksdev/ble/oad/image_identify.py new file mode 100644 index 0000000..f342a01 --- /dev/null +++ b/pybricksdev/ble/oad/image_identify.py @@ -0,0 +1,69 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2024 The Pybricks Authors + +""" +TI OAD (Over-the-Air Download) Image Identify characteristic. + +https://software-dl.ti.com/lprf/sdg-latest/html/oad-ble-stack-3.x/oad_profile.html#oad-image-identify-0xffc1 +""" + +import asyncio +import struct + +from bleak import BleakClient + +from ._common import ImageInfo, OADReturn, SoftwareVersion, oad_uuid + +__all__ = ["OADImageIdentify"] + +OAD_IMAGE_IDENTIFY_CHAR_UUID = oad_uuid(0xFFC1) +"""OAD Image Identify characteristic UUID.""" + + +class OADImageIdentify: + def __init__(self, client: BleakClient): + self._client = client + self._queue = asyncio.Queue[bytes]() + + def _notification_handler(self, sender, data): + self._queue.put_nowait(data) + + async def __aenter__(self): + await self._client.start_notify( + OAD_IMAGE_IDENTIFY_CHAR_UUID, self._notification_handler + ) + return self + + async def __aexit__(self, *exc_info): + await self._client.stop_notify(OAD_IMAGE_IDENTIFY_CHAR_UUID) + + async def validate( + self, + img_id: str, + bmi_ver: int, + header_ver: int, + image_info: ImageInfo, + image_len: int, + sw_ver: SoftwareVersion, + ) -> OADReturn: + """ + Validate the image header. + + Returns: True if the image header is valid. + """ + data = struct.pack( + "<8s2B4sI4s", + bytes(img_id, "ascii"), + bmi_ver, + header_ver, + bytes(image_info), + image_len, + bytes(sw_ver), + ) + + await self._client.write_gatt_char( + OAD_IMAGE_IDENTIFY_CHAR_UUID, data, response=False + ) + rsp = await self._queue.get() + + return OADReturn(rsp[0]) diff --git a/pybricksdev/cli/__init__.py b/pybricksdev/cli/__init__.py index f1873a4..ee80357 100644 --- a/pybricksdev/cli/__init__.py +++ b/pybricksdev/cli/__init__.py @@ -290,6 +290,26 @@ def run(self, args: argparse.Namespace): return self.subparsers.choices[args.action].tool.run(args) +class OADFlash(Tool): + def add_parser(self, subparsers: argparse._SubParsersAction): + parser = subparsers.add_parser( + "flash", + help="update firmware on a LEGO Powered Up device using TI OAD", + ) + parser.tool = self + parser.add_argument( + "firmware", + metavar="", + type=argparse.FileType(mode="rb"), + help="the firmware .oda file", + ).completer = FilesCompleter(allowednames=(".oda",)) + + async def run(self, args: argparse.Namespace): + from .oad import flash_oad_image + + await flash_oad_image(args.firmware) + + class OADInfo(Tool): def add_parser(self, subparsers: argparse._SubParsersAction): parser = subparsers.add_parser( @@ -315,7 +335,7 @@ def add_parser(self, subparsers: argparse._SubParsersAction): metavar="", dest="action", help="the action to perform" ) - for tool in (OADInfo(),): + for tool in OADFlash(), OADInfo(): tool.add_parser(self.subparsers) def run(self, args: argparse.Namespace): diff --git a/pybricksdev/cli/oad.py b/pybricksdev/cli/oad.py index 7c72fd4..12cbad7 100644 --- a/pybricksdev/cli/oad.py +++ b/pybricksdev/cli/oad.py @@ -2,19 +2,32 @@ # Copyright (c) 2024 The Pybricks Authors import asyncio +from typing import BinaryIO from bleak import BleakClient, BleakScanner from bleak.backends.device import BLEDevice from bleak.backends.scanner import AdvertisementData +from tqdm.auto import tqdm +from tqdm.contrib.logging import logging_redirect_tqdm from ..ble.lwp3 import LEGO_CID, LWP3_HUB_SERVICE_UUID, HubKind -from ..ble.oad.control_point import OADControlPoint +from ..ble.oad import OADControlPoint, OADImageBlock, OADImageIdentify, OADReturn +from ..ble.oad.control_point import ( + OAD_LEGO_MARIO_DEVICE_TYPE, + OAD_LEGO_TECHNIC_MOVE_DEVICE_TYPE, +) +from ..ble.oad.firmware import parse_oad_header -__all__ = ["dump_oad_info"] +__all__ = ["dump_oad_info", "flash_oad_image"] # hubs known to use TI OAD _OAD_HUBS = [HubKind.MARIO, HubKind.LUIGI, HubKind.PEACH, HubKind.TECHNIC_MOVE] +_KNOWN_DEVICE_TYPES = { + OAD_LEGO_MARIO_DEVICE_TYPE: "LEGO Mario", + OAD_LEGO_TECHNIC_MOVE_DEVICE_TYPE: "LEGO Technic Move Hub", +} + def _match_oad_hubs(dev: BLEDevice, adv: AdvertisementData): """ @@ -33,6 +46,91 @@ def _match_oad_hubs(dev: BLEDevice, adv: AdvertisementData): return kind in _OAD_HUBS +async def flash_oad_image(firmware: BinaryIO) -> None: + """ + Connects to an OAD hub and flashes a firmware image to it. + """ + + firmware_bytes = firmware.read() + + header = parse_oad_header(firmware_bytes) + + print("Scanning for hubs...") + device = await BleakScanner.find_device_by_filter(_match_oad_hubs) + + if device is None: + print("No OAD device found") + return + + # long timeout in case pairing is needed + async with asyncio.timeout(60), BleakClient(device) as client, OADImageIdentify( + client + ) as image_identify, OADControlPoint(client) as control_point: + image_block = OADImageBlock(client) + + print(f"Connected to {device.name}") + + dev_type = await control_point.get_device_type() + + # TODO: match this based on firmware image target + if dev_type not in _KNOWN_DEVICE_TYPES: + print(f"Unsupported device type: {dev_type:08X}") + return + + block_size = await control_point.get_oad_block_size() + + status = await image_identify.validate( + header.image_id, + header.bmi_version, + header.header_version, + header.image_info, + header.image_length, + header.software_version, + ) + if status != OADReturn.SUCCESS: + print(f"Failed to validate image: {status.name}") + return + + sent_blocks = set() + + print("Flashing...") + + with logging_redirect_tqdm(), tqdm( + total=header.image_length, unit="B", unit_scale=True + ) as pbar: + async with asyncio.TaskGroup() as group: + try: + async for ( + status, + block_num, + ) in control_point.start_oad_process(): + if status == OADReturn.SUCCESS: + data = firmware_bytes[ + block_num + * (block_size - 4) : (block_num + 1) + * (block_size - 4) + ] + + task = group.create_task(image_block.write(block_num, data)) + + if block_num not in sent_blocks: + task.add_done_callback(lambda _: pbar.update(len(data))) + sent_blocks.add(block_num) + + elif status == OADReturn.DL_COMPLETE: + break + else: + print( + f"Block {block_num} with unhandled status: {status.name}" + ) + except BaseException: + await control_point.cancel_oad() + raise + + # This causes hub to reset and disconnect + await control_point.enable_oad_image() + + async def dump_oad_info(): """ Connects to an OAD hub and prints some information about it. @@ -43,20 +141,25 @@ async def dump_oad_info(): print("No OAD device found") return - async with BleakClient(device) as client, OADControlPoint(client) as control_point: - # long timeout in case pairing is needed - async with asyncio.timeout(30): - sw_ver = await control_point.get_software_version() - print(f"Software version: {sw_ver}") + # long timeout in case pairing is needed + async with asyncio.timeout(30), BleakClient(device) as client, OADControlPoint( + client + ) as control_point: + sw_ver = await control_point.get_software_version() + print( + f"Software version: app={sw_ver.app.major}.{sw_ver.app.minor}, stack={sw_ver.stack.major}.{sw_ver.stack.minor}" + ) - profile_ver = await control_point.get_profile_version() - print(f"Profile version: {profile_ver}") + profile_ver = await control_point.get_profile_version() + print(f"Profile version: {profile_ver}") - dev_type = await control_point.get_device_type() - print(f"Device type: {dev_type:08X}") + dev_type = await control_point.get_device_type() + print( + f"Device type: {dev_type:08X} ({_KNOWN_DEVICE_TYPES.get(dev_type, 'Unknown')})" + ) - block_size = await control_point.get_oad_block_size() - print(f"Block size: {block_size}") + block_size = await control_point.get_oad_block_size() + print(f"Block size: {block_size}") - image_status = await control_point.get_oad_image_status() - print(f"Image status: {image_status.name}") + image_status = await control_point.get_oad_image_status() + print(f"Image status: {image_status.name}")