diff --git a/CHANGELOG.md b/CHANGELOG.md index 0b66cad..5d13679 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -91,6 +91,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Size check when restoring firmware via USB/DFU. - Added `pybricksdev.tools.chunk()` function. +- Added basic command completion to `pybricksdev lwp3 repl`. ### Fixed - Wait for some time to allow program output to be received before disconnecting in the `run` command. diff --git a/pybricksdev/cli/lwp3/repl.py b/pybricksdev/cli/lwp3/repl.py index 37e745d..b94bc55 100644 --- a/pybricksdev/cli/lwp3/repl.py +++ b/pybricksdev/cli/lwp3/repl.py @@ -7,9 +7,11 @@ """ import asyncio +from enum import Enum import inspect import logging import os +import re import struct from pathlib import Path @@ -18,6 +20,8 @@ from bleak.backends.device import BLEDevice from bleak.backends.scanner import AdvertisementData from prompt_toolkit import PromptSession +from prompt_toolkit.completion import Completer, Completion, FuzzyCompleter +from prompt_toolkit.document import Document from prompt_toolkit.history import FileHistory from prompt_toolkit.patch_stdout import StdoutProxy, patch_stdout @@ -42,26 +46,53 @@ # The first groups is any type from bytecodes that inherits from int (includes # enums/flags) or bytes. -_eval_pool.update( - { - k: v - for k, v in bytecodes.__dict__.items() - if inspect.isclass(v) - and v.__module__ == bytecodes.__name__ - and (issubclass(v, int) or issubclass(v, bytes)) - } -) +_PARAMETER_TYPES = { + k: v + for k, v in bytecodes.__dict__.items() + if inspect.isclass(v) + and v.__module__ == bytecodes.__name__ + and (issubclass(v, int) or issubclass(v, bytes)) +} + +_eval_pool.update(_PARAMETER_TYPES) # The second group are all of the non-abstract message types from the messages module. -_eval_pool.update( - { - k: v - for k, v in messages.__dict__.items() - if inspect.isclass(v) - and issubclass(v, AbstractMessage) - and not inspect.isabstract(v) - } -) +_MESSAGE_KINDS = { + k: v + for k, v in messages.__dict__.items() + if inspect.isclass(v) + and issubclass(v, AbstractMessage) + and not inspect.isabstract(v) +} + +_eval_pool.update(_MESSAGE_KINDS) + + +class _CommandCompleter(Completer): + """ + Custom completer for command prompt. + """ + + # matches words with dots in them, e.g. "Enum.MEMBER" + _MATCH_DOT = re.compile(r"[a-zA-Z0-9_\.]+") + + def get_completions(self, document: Document, complete_event): + if document.get_word_before_cursor() == ".": + # if this is a dotted word, look up the enum member + cls = _PARAMETER_TYPES.get( + document.get_word_before_cursor(pattern=self._MATCH_DOT).split(".")[0] + ) + if cls and issubclass(cls, Enum): + for m in cls: + yield Completion(m.name) + elif document.find_enclosing_bracket_left("(", ")") is not None: + # if we are inside of "(...)", list the enums and other parameter types + for p in _PARAMETER_TYPES.keys(): + yield Completion(p) + elif document.get_word_under_cursor() == "": + # if we are at the beginning of the line, list the commands + for m in _MESSAGE_KINDS.keys(): + yield Completion(m) async def repl() -> None: @@ -69,7 +100,11 @@ async def repl() -> None: Provides an interactive REPL for sending and receiving LWP3 messages. """ os.makedirs(history_file.parent, exist_ok=True) - session = PromptSession(history=FileHistory(history_file)) + + session = PromptSession( + history=FileHistory(history_file), + completer=FuzzyCompleter(_CommandCompleter()), + ) def match_lwp3_uuid(dev: BLEDevice, adv: AdvertisementData) -> None: if LWP3_HUB_SERVICE_UUID.lower() not in adv.service_uuids: