Skip to content

Commit

Permalink
Add a unit test and make sure command futures are always removed
Browse files Browse the repository at this point in the history
  • Loading branch information
puddly committed Dec 17, 2023
1 parent ff45b1d commit 02862a7
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 14 deletions.
48 changes: 43 additions & 5 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@
import zigpy.config
import zigpy.types as zigpy_t

if sys.version_info[:2] < (3, 11):
from async_timeout import timeout as asyncio_timeout
else:
from asyncio import timeout as asyncio_timeout

from zigpy_deconz import api as deconz_api, types as t, uart
import zigpy_deconz.exception
import zigpy_deconz.zigbee.application
Expand Down Expand Up @@ -86,7 +91,7 @@ async def mock_connect(config, api):

@pytest.fixture
async def mock_command_rsp(gateway):
def inner(command_id, params, rsp, *, replace=False):
def inner(command_id, params, rsp, *, rsp_command=None, replace=False):
if (
getattr(getattr(gateway.send, "side_effect", None), "_handlers", None)
is None
Expand All @@ -107,15 +112,18 @@ def receiver(data):

kwargs, rest = t.deserialize_dict(command.payload, schema)

for params, mock in receiver._handlers[command.command_id]:
for params, rsp_command, mock in receiver._handlers[command.command_id]:
if rsp_command is None:
rsp_command = command.command_id

if all(kwargs[k] == v for k, v in params.items()):
_, rx_schema = deconz_api.COMMAND_SCHEMAS[command.command_id]
_, rx_schema = deconz_api.COMMAND_SCHEMAS[rsp_command]
ret = mock(**kwargs)

asyncio.get_running_loop().call_soon(
gateway._api.data_received,
deconz_api.Command(
command_id=command.command_id,
command_id=rsp_command,
seq=command.seq,
payload=t.serialize_dict(ret, rx_schema),
).serialize(),
Expand All @@ -128,7 +136,9 @@ def receiver(data):
gateway.send.side_effect._handlers[command_id].clear()

mock = MagicMock(return_value=rsp)
gateway.send.side_effect._handlers[command_id].append((params, mock))
gateway.send.side_effect._handlers[command_id].append(
(params, rsp_command, mock)
)

return mock

Expand Down Expand Up @@ -993,3 +1003,31 @@ async def test_cb3_device_state_callback_bug(api, mock_command_rsp):
await asyncio.sleep(0.01)

assert api._device_state == device_state


async def test_firmware_responding_with_wrong_type_with_correct_seq(
api, mock_command_rsp, caplog
):
await api.connect()

mock_command_rsp(
command_id=deconz_api.CommandId.aps_data_confirm,
params={},
# Completely different response
rsp_command=deconz_api.CommandId.version,
rsp={
"status": deconz_api.Status.SUCCESS,
"frame_length": t.uint16_t(9),
"version": deconz_api.FirmwareVersion(0x26450900),
},
)

with caplog.at_level(logging.DEBUG):
with pytest.raises(asyncio.TimeoutError):
async with asyncio_timeout(0.5):
await api.send_command(deconz_api.CommandId.aps_data_confirm)

assert (
"Firmware responded incorrectly (Response is mismatched! Sent"
" <CommandId.aps_data_confirm: 4>, received <CommandId.version: 13>), retrying"
) in caplog.text
15 changes: 6 additions & 9 deletions zigpy_deconz/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,9 +418,7 @@ def __init__(self, app: Callable, device_config: dict[str, Any]):
self._app = app

# [seq][cmd_id] = [fut1, fut2, ...]
self._awaiting = collections.defaultdict(
lambda: collections.defaultdict(lambda: collections.deque([]))
)
self._awaiting = collections.defaultdict(lambda: collections.defaultdict(list))
self._command_lock = asyncio.Lock()
self._config = device_config
self._device_state = DeviceState(
Expand Down Expand Up @@ -574,11 +572,10 @@ async def _command(self, cmd, **kwargs):
async with asyncio_timeout(COMMAND_TIMEOUT):
return await fut
except asyncio.TimeoutError:
LOGGER.warning(
"No response to '%s' command with seq id '0x%02x'", cmd, seq
)
self._awaiting[seq][cmd].remove(fut)
LOGGER.debug("No response to '%s' command with seq %d", cmd, seq)
raise
finally:
self._awaiting[seq][cmd].remove(fut)

def data_received(self, data: bytes) -> None:
command, _ = Command.deserialize(data)
Expand All @@ -593,13 +590,13 @@ def data_received(self, data: bytes) -> None:
wrong_fut_cmd_id = None

try:
fut = self._awaiting[command.seq][command.command_id].popleft()
fut = self._awaiting[command.seq][command.command_id][0]
except IndexError:
# XXX: The firmware can sometimes respond with the wrong response. Find the
# future associated with it so we can throw an appropriate error.
for cmd_id, futs in self._awaiting[command.seq].items():
if futs:
fut = futs.popleft()
fut = futs[0]
wrong_fut_cmd_id = cmd_id
break

Expand Down

0 comments on commit 02862a7

Please sign in to comment.