Skip to content

Commit

Permalink
Cleanly shut down the serial port on disconnect
Browse files Browse the repository at this point in the history
  • Loading branch information
puddly committed Jul 14, 2024
1 parent 8d2315d commit ddaeef9
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 50 deletions.
27 changes: 13 additions & 14 deletions zigpy_deconz/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
else:
from asyncio import timeout as asyncio_timeout # pragma: no cover

from zigpy.config import CONF_DEVICE_PATH
from zigpy.datastructures import PriorityLock
from zigpy.types import (
APSStatus,
Expand Down Expand Up @@ -461,37 +460,37 @@ def protocol_version(self) -> int:

async def connect(self) -> None:
assert self._uart is None

self._uart = await zigpy_deconz.uart.connect(self._config, self)

await self.version()
try:
await self.version()
device_state_rsp = await self.send_command(CommandId.device_state)
except Exception:
await self.disconnect()
self._uart = None
raise

device_state_rsp = await self.send_command(CommandId.device_state)
self._device_state = device_state_rsp["device_state"]

self._data_poller_task = asyncio.create_task(self._data_poller())

def connection_lost(self, exc: Exception) -> None:
def connection_lost(self, exc: Exception | None) -> None:
"""Lost serial connection."""
LOGGER.debug(
"Serial %r connection lost unexpectedly: %r",
self._config[CONF_DEVICE_PATH],
exc,
)

if self._app is not None:
self._app.connection_lost(exc)

def close(self):
self._app = None

async def disconnect(self):
if self._data_poller_task is not None:
self._data_poller_task.cancel()
self._data_poller_task = None

if self._uart is not None:
self._uart.close()
await self._uart.disconnect()
self._uart = None

self._app = None

def _get_command_priority(self, command: Command) -> int:
return {
# The watchdog is fed using `write_parameter` and `get_device_state` so they
Expand Down
54 changes: 20 additions & 34 deletions zigpy_deconz/uart.py
Original file line number Diff line number Diff line change
@@ -1,59 +1,49 @@
"""Uart module."""

from __future__ import annotations

import asyncio
import binascii
import logging
from typing import Callable, Dict
from typing import Any, Callable

import zigpy.config
import zigpy.serial

LOGGER = logging.getLogger(__name__)


class Gateway(asyncio.Protocol):
class Gateway(zigpy.serial.SerialProtocol):
END = b"\xC0"
ESC = b"\xDB"
ESC_END = b"\xDC"
ESC_ESC = b"\xDD"

def __init__(self, api, connected_future=None):
def __init__(self, api):
"""Initialize instance of the UART gateway."""

super().__init__()
self._api = api
self._buffer = b""
self._connected_future = connected_future
self._transport = None

def connection_lost(self, exc) -> None:
def connection_lost(self, exc: Exception | None) -> None:
"""Port was closed expectedly or unexpectedly."""
super().connection_lost(exc)

if exc is not None:
LOGGER.warning("Lost connection: %r", exc, exc_info=exc)

self._api.connection_lost(exc)

def connection_made(self, transport):
"""Call this when the uart connection is established."""

LOGGER.debug("Connection made")
self._transport = transport
if self._connected_future and not self._connected_future.done():
self._connected_future.set_result(True)
if self._api is not None:
self._api.connection_lost(exc)

def close(self):
self._transport.close()
self._api = None

def send(self, data):
def send(self, data: bytes) -> None:
"""Send data, taking care of escaping and framing."""
LOGGER.debug("Send: %s", binascii.hexlify(data).decode())
checksum = bytes(self._checksum(data))
frame = self._escape(data + checksum)
self._transport.write(self.END + frame + self.END)
self.send_data(self.END + frame + self.END)

def data_received(self, data):
def data_received(self, data: bytes) -> None:
"""Handle data received from the uart."""
self._buffer += data
super().data_received(data)

while self._buffer:
end = self._buffer.find(self.END)
if end < 0:
Expand Down Expand Up @@ -121,23 +111,19 @@ def _checksum(self, data):
return bytes(ret)


async def connect(config: Dict[str, any], api: Callable) -> Gateway:
loop = asyncio.get_running_loop()
connected_future = loop.create_future()
protocol = Gateway(api, connected_future)
async def connect(config: dict[str, Any], api: Callable) -> Gateway:
protocol = Gateway(api)

LOGGER.debug("Connecting to %s", config[zigpy.config.CONF_DEVICE_PATH])

_, protocol = await zigpy.serial.create_serial_connection(
loop=loop,
loop=asyncio.get_running_loop(),
protocol_factory=lambda: protocol,
url=config[zigpy.config.CONF_DEVICE_PATH],
baudrate=config[zigpy.config.CONF_DEVICE_BAUDRATE],
xonxoff=False,
)

await connected_future

LOGGER.debug("Connected to %s", config[zigpy.config.CONF_DEVICE_PATH])
await protocol.wait_until_connected()

return protocol
4 changes: 2 additions & 2 deletions zigpy_deconz/zigbee/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ async def connect(self):
try:
await api.connect()
except Exception:
api.close()
await api.disconnect()
raise

self._api = api
Expand All @@ -108,7 +108,7 @@ async def disconnect(self):
self._delayed_neighbor_scan_task = None

if self._api is not None:
self._api.close()
await self._api.disconnect()
self._api = None

async def permit_with_link_key(self, node: t.EUI64, link_key: t.KeyData, time_s=60):
Expand Down

0 comments on commit ddaeef9

Please sign in to comment.