From ddaeef975049a1738a22297808b967bb2a459939 Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Sun, 14 Jul 2024 11:38:54 -0400 Subject: [PATCH] Cleanly shut down the serial port on disconnect --- zigpy_deconz/api.py | 27 +++++++-------- zigpy_deconz/uart.py | 54 +++++++++++------------------- zigpy_deconz/zigbee/application.py | 4 +-- 3 files changed, 35 insertions(+), 50 deletions(-) diff --git a/zigpy_deconz/api.py b/zigpy_deconz/api.py index e91991a..2a2d77d 100644 --- a/zigpy_deconz/api.py +++ b/zigpy_deconz/api.py @@ -14,7 +14,6 @@ else: from asyncio import timeout as asyncio_timeout # pragma: no cover -from zigpy.config import CONF_DEVICE_PATH from zigpy.datastructures import PriorityLock from zigpy.types import ( APSStatus, @@ -461,37 +460,37 @@ def protocol_version(self) -> int: async def connect(self) -> None: assert self._uart is None + self._uart = await zigpy_deconz.uart.connect(self._config, self) - await self.version() + try: + await self.version() + device_state_rsp = await self.send_command(CommandId.device_state) + except Exception: + await self.disconnect() + self._uart = None + raise - device_state_rsp = await self.send_command(CommandId.device_state) self._device_state = device_state_rsp["device_state"] self._data_poller_task = asyncio.create_task(self._data_poller()) - def connection_lost(self, exc: Exception) -> None: + def connection_lost(self, exc: Exception | None) -> None: """Lost serial connection.""" - LOGGER.debug( - "Serial %r connection lost unexpectedly: %r", - self._config[CONF_DEVICE_PATH], - exc, - ) - if self._app is not None: self._app.connection_lost(exc) - def close(self): - self._app = None - + async def disconnect(self): if self._data_poller_task is not None: self._data_poller_task.cancel() self._data_poller_task = None if self._uart is not None: - self._uart.close() + await self._uart.disconnect() self._uart = None + self._app = None + def _get_command_priority(self, command: Command) -> int: return { # The watchdog is fed using `write_parameter` and `get_device_state` so they diff --git a/zigpy_deconz/uart.py b/zigpy_deconz/uart.py index f555787..42457d1 100644 --- a/zigpy_deconz/uart.py +++ b/zigpy_deconz/uart.py @@ -1,9 +1,11 @@ """Uart module.""" +from __future__ import annotations + import asyncio import binascii import logging -from typing import Callable, Dict +from typing import Any, Callable import zigpy.config import zigpy.serial @@ -11,49 +13,37 @@ LOGGER = logging.getLogger(__name__) -class Gateway(asyncio.Protocol): +class Gateway(zigpy.serial.SerialProtocol): END = b"\xC0" ESC = b"\xDB" ESC_END = b"\xDC" ESC_ESC = b"\xDD" - def __init__(self, api, connected_future=None): + def __init__(self, api): """Initialize instance of the UART gateway.""" - + super().__init__() self._api = api - self._buffer = b"" - self._connected_future = connected_future - self._transport = None - def connection_lost(self, exc) -> None: + def connection_lost(self, exc: Exception | None) -> None: """Port was closed expectedly or unexpectedly.""" + super().connection_lost(exc) - if exc is not None: - LOGGER.warning("Lost connection: %r", exc, exc_info=exc) - - self._api.connection_lost(exc) - - def connection_made(self, transport): - """Call this when the uart connection is established.""" - - LOGGER.debug("Connection made") - self._transport = transport - if self._connected_future and not self._connected_future.done(): - self._connected_future.set_result(True) + if self._api is not None: + self._api.connection_lost(exc) def close(self): - self._transport.close() + self._api = None - def send(self, data): + def send(self, data: bytes) -> None: """Send data, taking care of escaping and framing.""" - LOGGER.debug("Send: %s", binascii.hexlify(data).decode()) checksum = bytes(self._checksum(data)) frame = self._escape(data + checksum) - self._transport.write(self.END + frame + self.END) + self.send_data(self.END + frame + self.END) - def data_received(self, data): + def data_received(self, data: bytes) -> None: """Handle data received from the uart.""" - self._buffer += data + super().data_received(data) + while self._buffer: end = self._buffer.find(self.END) if end < 0: @@ -121,23 +111,19 @@ def _checksum(self, data): return bytes(ret) -async def connect(config: Dict[str, any], api: Callable) -> Gateway: - loop = asyncio.get_running_loop() - connected_future = loop.create_future() - protocol = Gateway(api, connected_future) +async def connect(config: dict[str, Any], api: Callable) -> Gateway: + protocol = Gateway(api) LOGGER.debug("Connecting to %s", config[zigpy.config.CONF_DEVICE_PATH]) _, protocol = await zigpy.serial.create_serial_connection( - loop=loop, + loop=asyncio.get_running_loop(), protocol_factory=lambda: protocol, url=config[zigpy.config.CONF_DEVICE_PATH], baudrate=config[zigpy.config.CONF_DEVICE_BAUDRATE], xonxoff=False, ) - await connected_future - - LOGGER.debug("Connected to %s", config[zigpy.config.CONF_DEVICE_PATH]) + await protocol.wait_until_connected() return protocol diff --git a/zigpy_deconz/zigbee/application.py b/zigpy_deconz/zigbee/application.py index b3294c6..fa0233f 100644 --- a/zigpy_deconz/zigbee/application.py +++ b/zigpy_deconz/zigbee/application.py @@ -96,7 +96,7 @@ async def connect(self): try: await api.connect() except Exception: - api.close() + await api.disconnect() raise self._api = api @@ -108,7 +108,7 @@ async def disconnect(self): self._delayed_neighbor_scan_task = None if self._api is not None: - self._api.close() + await self._api.disconnect() self._api = None async def permit_with_link_key(self, node: t.EUI64, link_key: t.KeyData, time_s=60):