From f204cda9a13a64528224479a1ef4ee0b07c3dea4 Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Wed, 31 Dec 2025 12:26:47 -0500 Subject: [PATCH 01/18] EZSP v18 --- bellows/ezsp/__init__.py | 5 +- bellows/ezsp/v18/__init__.py | 106 +++++++++++++++++++++++++++++++++++ bellows/ezsp/v18/commands.py | 105 ++++++++++++++++++++++++++++++++++ bellows/ezsp/v18/config.py | 16 ++++++ bellows/types/struct.py | 20 +++++++ 5 files changed, 250 insertions(+), 2 deletions(-) create mode 100644 bellows/ezsp/v18/__init__.py create mode 100644 bellows/ezsp/v18/commands.py create mode 100644 bellows/ezsp/v18/config.py diff --git a/bellows/ezsp/__init__.py b/bellows/ezsp/__init__.py index e4441486..b5bd087f 100644 --- a/bellows/ezsp/__init__.py +++ b/bellows/ezsp/__init__.py @@ -29,11 +29,11 @@ import bellows.types as t import bellows.uart -from . import v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v16, v17 +from . import v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v16, v17, v18 RESET_ATTEMPTS = 3 -EZSP_LATEST = v17.EZSPv17.VERSION +EZSP_LATEST = v18.EZSPv18.VERSION LOGGER = logging.getLogger(__name__) MTOR_MIN_INTERVAL = 60 MTOR_MAX_INTERVAL = 3600 @@ -61,6 +61,7 @@ class EZSP: v14.EZSPv14.VERSION: v14.EZSPv14, v16.EZSPv16.VERSION: v16.EZSPv16, v17.EZSPv17.VERSION: v17.EZSPv17, + v18.EZSPv18.VERSION: v18.EZSPv18, } def __init__(self, device_config: dict, application: Any | None = None): diff --git a/bellows/ezsp/v18/__init__.py b/bellows/ezsp/v18/__init__.py new file mode 100644 index 00000000..996ee8d2 --- /dev/null +++ b/bellows/ezsp/v18/__init__.py @@ -0,0 +1,106 @@ +""""EZSP Protocol version 18 protocol handler.""" +from __future__ import annotations + +import voluptuous as vol + +import bellows.config +import bellows.types as t + +from . import commands, config +from ..v17 import EZSPv17 + + +class EZSPv18(EZSPv17): + """EZSP Version 18 Protocol version handler.""" + + VERSION = 18 + COMMANDS = commands.COMMANDS + SCHEMAS = { + bellows.config.CONF_EZSP_CONFIG: vol.Schema(config.EZSP_SCHEMA), + bellows.config.CONF_EZSP_POLICIES: vol.Schema(config.EZSP_POLICIES_SCH), + } + + async def send_unicast( + self, + nwk: t.NWK, + aps_frame: t.EmberApsFrame, + message_tag: t.uint8_t, + data: bytes, + ) -> tuple[t.sl_Status, t.uint8_t]: + status, sequence = await self.sendUnicast( + message_type=t.EmberOutgoingMessageType.OUTGOING_DIRECT, + nwk=nwk, + aps_frame=t.EmberApsFrameV18( + profileId=aps_frame.profileId, + clusterId=aps_frame.clusterId, + sourceEndpoint=aps_frame.sourceEndpoint, + destinationEndpoint=aps_frame.destinationEndpoint, + options=aps_frame.options, + groupId=aps_frame.groupId, + sequence=aps_frame.sequence, + radius=0, + ), + message_tag=message_tag, + message=data, + ) + + return status, sequence + + async def send_multicast( + self, + aps_frame: t.EmberApsFrame, + radius: t.uint8_t, + non_member_radius: t.uint8_t, + message_tag: t.uint8_t, + data: bytes, + ) -> tuple[t.sl_Status, t.uint8_t]: + status, sequence = await self.sendMulticast( + aps_frame=t.EmberApsFrameV18( + profileId=aps_frame.profileId, + clusterId=aps_frame.clusterId, + sourceEndpoint=aps_frame.sourceEndpoint, + destinationEndpoint=aps_frame.destinationEndpoint, + options=aps_frame.options, + groupId=aps_frame.groupId, + sequence=aps_frame.sequence, + radius=radius, + ), + hops=radius, + broadcast_addr=t.BroadcastAddress.RX_ON_WHEN_IDLE, + alias=0x0000, + sequence=aps_frame.sequence, + message_tag=message_tag, + message=data, + ) + + return status, sequence + + async def send_broadcast( + self, + address: t.BroadcastAddress, + aps_frame: t.EmberApsFrame, + radius: t.uint8_t, + message_tag: t.uint8_t, + aps_sequence: t.uint8_t, + data: bytes, + ) -> tuple[t.sl_Status, t.uint8_t]: + status, sequence = await self.sendBroadcast( + alias=0x0000, + destination=address, + sequence=aps_sequence, + aps_frame=t.EmberApsFrameV18( + profileId=aps_frame.profileId, + clusterId=aps_frame.clusterId, + sourceEndpoint=aps_frame.sourceEndpoint, + destinationEndpoint=aps_frame.destinationEndpoint, + options=aps_frame.options, + groupId=aps_frame.groupId, + sequence=aps_frame.sequence, + radius=radius, + ), + radius=radius, + message_tag=message_tag, + message=data, + ) + + return status, sequence diff --git a/bellows/ezsp/v18/commands.py b/bellows/ezsp/v18/commands.py new file mode 100644 index 00000000..fdec8623 --- /dev/null +++ b/bellows/ezsp/v18/commands.py @@ -0,0 +1,105 @@ +from zigpy.types import EUI64, NWK, BroadcastAddress + +import bellows.types as t + +from ..v17.commands import COMMANDS as COMMANDS_v17 + +COMMANDS = { + **COMMANDS_v17, + "sendUnicast": ( + 0x0034, + { + "message_type": t.EmberOutgoingMessageType, + "nwk": NWK, + "aps_frame": t.EmberApsFrameV18, # APS frame format has changed + "message_tag": t.uint16_t, + "message": t.LVBytes, + }, + { + "status": t.sl_Status, + "sequence": t.uint8_t, + }, + ), + "sendBroadcast": ( + 0x0036, + { + "alias": t.uint16_t, + "destination": BroadcastAddress, + "sequence": t.uint8_t, + "aps_frame": t.EmberApsFrameV18, # APS frame format has changed + "radius": t.uint8_t, + "message_tag": t.uint16_t, + "message": t.LVBytes, + }, + { + "status": t.sl_Status, + "sequence": t.uint8_t, + }, + ), + "sendMulticast": ( + 0x0038, + { + "aps_frame": t.EmberApsFrameV18, # APS frame format has changed + "hops": t.uint8_t, + "broadcast_addr": t.BroadcastAddress, + "alias": t.uint16_t, + "sequence": t.uint8_t, + "message_tag": t.uint16_t, + "message": t.LVBytes, + }, + { + "status": t.sl_Status, + "sequence": t.uint8_t, + }, + ), + "sendReply": ( + 0x0039, + { + "sender": t.NWK, + "aps_frame": t.EmberApsFrameV18, # APS frame format has changed + "message": t.LVBytes, + }, + { + "status": t.sl_Status, + }, + ), + "incomingMessageHandler": ( + 0x0045, + {}, + { + "message_type": t.EmberIncomingMessageType, + "aps_frame": t.EmberApsFrameV18, # APS frame format has changed + "nwk": NWK, + "eui64": EUI64, + "binding_index": t.uint8_t, + "address_index": t.uint8_t, + "lqi": t.uint8_t, + "rssi": t.int8s, + "timestamp": t.uint32_t, + "message": t.LVBytes, + }, + ), + "messageSentHandler": ( + 0x003F, + {}, + { + "status": t.sl_Status, + "message_type": t.EmberOutgoingMessageType, + "nwk": NWK, + "aps_frame": t.EmberApsFrameV18, # APS frame format has changed + "message_tag": t.uint16_t, + "message": t.LVBytes, + }, + ), + "macFilterMatchMessageHandler": ( + 0x46, + {}, + { + "filterValueMatch": t.uint16_t, # Was `filterIndexMatch: uint8_t` + "legacyPassthroughType": t.EmberMacPassthroughType, + "lastHopLqi": t.uint8_t, + "lastHopRssi": t.int8s, + "messageContents": t.LVBytes, + }, + ), +} diff --git a/bellows/ezsp/v18/config.py b/bellows/ezsp/v18/config.py new file mode 100644 index 00000000..167fa57a --- /dev/null +++ b/bellows/ezsp/v18/config.py @@ -0,0 +1,16 @@ +import voluptuous as vol + +from bellows.config import cv_uint16 +from bellows.types import EzspPolicyId + +from ..v4.config import EZSP_POLICIES_SHARED +from ..v17 import config as v17_config + +EZSP_SCHEMA = { + **v17_config.EZSP_SCHEMA, +} + +EZSP_POLICIES_SCH = { + **EZSP_POLICIES_SHARED, + **{vol.Optional(policy.name): cv_uint16 for policy in EzspPolicyId}, +} diff --git a/bellows/types/struct.py b/bellows/types/struct.py index 91a17c01..df34a440 100644 --- a/bellows/types/struct.py +++ b/bellows/types/struct.py @@ -67,6 +67,26 @@ class EmberApsFrame(EzspStruct): sequence: basic.uint8_t +class EmberApsFrameV18(EzspStruct): + # ZigBee APS frame parameters (EZSP v18+). + # The application profile ID that describes the format of the message. + profileId: basic.uint16_t + # The cluster ID for this message. + clusterId: basic.uint16_t + # The source endpoint. + sourceEndpoint: basic.uint8_t + # The destination endpoint. + destinationEndpoint: basic.uint8_t + # A bitmask of options. + options: named.EmberApsOption + # The group ID for this message, if it is multicast mode. + groupId: basic.uint16_t + # The sequence number. + sequence: basic.uint8_t + # The radius of the message. (Added in EZSP v18) + radius: basic.uint8_t + + class EmberBindingTableEntry(EzspStruct): # An entry in the binding table. # The type of binding. From 5d76d1ba0840600605ceab71648ad0ccab31fb84 Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Wed, 31 Dec 2025 13:18:23 -0500 Subject: [PATCH 02/18] Remove unused `_handle_no_such_device` --- bellows/zigbee/application.py | 11 ----------- tests/test_application.py | 29 ----------------------------- 2 files changed, 40 deletions(-) diff --git a/bellows/zigbee/application.py b/bellows/zigbee/application.py index 1bd0271c..8dd8b6a1 100644 --- a/bellows/zigbee/application.py +++ b/bellows/zigbee/application.py @@ -800,17 +800,6 @@ def _handle_frame_sent( exc, ) - async def _handle_no_such_device(self, sender: int) -> None: - """Try to match unknown device by its EUI64 address.""" - status, ieee = await self._ezsp.lookupEui64ByNodeId(nodeId=sender) - status = t.sl_Status.from_ember_status(status) - - if status == t.sl_Status.OK: - LOGGER.debug("Found %s ieee for %s sender", ieee, sender) - self.handle_join(sender, ieee, 0) - return - LOGGER.debug("Couldn't look up ieee for %s", sender) - def _handle_tc_join_handler( self, nwk: t.EmberNodeId, diff --git a/tests/test_application.py b/tests/test_application.py index 018dd632..72865c7e 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -1600,35 +1600,6 @@ def test_handle_id_conflict(app, ieee): assert app.handle_leave.call_args[0][0] == nwk -async def test_handle_no_such_device(app, ieee): - """Test handling of an unknown device IEEE lookup.""" - - app._ezsp.lookupEui64ByNodeId = AsyncMock() - - p1 = patch.object( - app._ezsp, - "lookupEui64ByNodeId", - AsyncMock(return_value=(t.EmberStatus.ERR_FATAL, ieee)), - ) - p2 = patch.object(app, "handle_join") - with p1 as lookup_mock, p2 as handle_join_mock: - await app._handle_no_such_device(sentinel.nwk) - assert lookup_mock.mock_calls == [call(nodeId=sentinel.nwk)] - assert handle_join_mock.call_count == 0 - - p1 = patch.object( - app._ezsp, - "lookupEui64ByNodeId", - AsyncMock(return_value=(t.EmberStatus.SUCCESS, sentinel.ieee)), - ) - with p1 as lookup_mock, p2 as handle_join_mock: - await app._handle_no_such_device(sentinel.nwk) - assert lookup_mock.mock_calls == [call(nodeId=sentinel.nwk)] - assert handle_join_mock.call_count == 1 - assert handle_join_mock.call_args[0][0] == sentinel.nwk - assert handle_join_mock.call_args[0][1] == sentinel.ieee - - async def test_cleanup_tc_link_key(app): """Test cleaning up tc link key.""" ezsp = app._ezsp From 17549d2e71d7c093428b90a9b053c7970bab811e Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Wed, 31 Dec 2025 17:05:38 -0500 Subject: [PATCH 03/18] Move packet reception logic into the protocol handler --- bellows/ezsp/protocol.py | 79 +++++-------- bellows/ezsp/v14/__init__.py | 118 +++++++++++++++++++ bellows/ezsp/v4/__init__.py | 113 +++++++++++++++++++ bellows/zigbee/application.py | 141 ++++------------------- tests/test_application.py | 206 +++++++++++++++------------------- tests/test_ezsp_protocol.py | 56 ++++----- 6 files changed, 407 insertions(+), 306 deletions(-) diff --git a/bellows/ezsp/protocol.py b/bellows/ezsp/protocol.py index 30006dc4..f86d261a 100644 --- a/bellows/ezsp/protocol.py +++ b/bellows/ezsp/protocol.py @@ -11,7 +11,9 @@ from typing import TYPE_CHECKING, Any from zigpy.datastructures import PriorityDynamicBoundedSemaphore +from zigpy.event.event_base import EventBase import zigpy.state +import zigpy.types from bellows.config import CONF_EZSP_POLICIES from bellows.exception import InvalidCommandError @@ -27,13 +29,14 @@ MAX_COMMAND_CONCURRENCY = 1 -class ProtocolHandler(abc.ABC): +class ProtocolHandler(EventBase, abc.ABC): """EZSP protocol specific handler.""" COMMANDS = {} VERSION = None def __init__(self, cb_handler: Callable, gateway: Gateway) -> None: + super().__init__() self._handle_callback = cb_handler self._awaiting = {} self._gw = gateway @@ -179,52 +182,6 @@ def __call__(self, data: bytes) -> None: if data: LOGGER.debug("Frame contains trailing data: %s", data) - if ( - frame_name == "incomingMessageHandler" - and result[1].options & t.EmberApsOption.APS_OPTION_FRAGMENT - ): - # Extract received APS frame and sender - aps_frame = result[1] - sender = result[4] - - # The fragment count and index are encoded in the groupId field - fragment_count = (aps_frame.groupId >> 8) & 0xFF - fragment_index = aps_frame.groupId & 0xFF - - ( - complete, - reassembled, - frag_count, - frag_index, - ) = self._fragment_manager.handle_incoming_fragment( - sender_nwk=sender, - aps_sequence=aps_frame.sequence, - profile_id=aps_frame.profileId, - cluster_id=aps_frame.clusterId, - fragment_count=fragment_count, - fragment_index=fragment_index, - payload=result[7], - ) - - ack_task = asyncio.create_task( - self._send_fragment_ack(sender, aps_frame, frag_count, frag_index) - ) # APS Ack - - self._fragment_ack_tasks.add(ack_task) - ack_task.add_done_callback(lambda t: self._fragment_ack_tasks.discard(t)) - - if not complete: - # Do not pass partial data up the stack - LOGGER.debug("Fragment reassembly not complete. waiting for more data.") - return - - # Replace partial data with fully reassembled data - result[7] = reassembled - - LOGGER.debug( - "Reassembled fragmented message. Proceeding with normal handling." - ) - if sequence in self._awaiting: expected_id, schema, future = self._awaiting.pop(sequence) try: @@ -246,8 +203,19 @@ def __call__(self, data: bytes) -> None: sequence, self.COMMANDS_BY_ID.get(expected_id, [expected_id])[0], ) - else: - self._handle_callback(frame_name, result) + + return + + # Handle callbacks via version-specific methods that emit events + if frame_name == "incomingMessageHandler": + if not self._handle_incoming_message(result): + # Fragment incomplete, skip legacy callback + return + elif frame_name == "messageSentHandler": + self._handle_message_sent(result) + + # Always call legacy callback handler for backwards compatibility + self._handle_callback(frame_name, result) async def _send_fragment_ack( self, @@ -386,3 +354,16 @@ async def set_extended_timeout( self, nwk: t.NWK, ieee: t.EUI64, extended_timeout: bool = True ) -> None: raise NotImplementedError() + + @abc.abstractmethod + def _handle_incoming_message(self, args: list) -> bool: + """Handle incomingMessageHandler callback and emit packet_received event. + + Returns True if message was fully handled, False if fragment is incomplete. + """ + raise NotImplementedError + + @abc.abstractmethod + def _handle_message_sent(self, args: list) -> None: + """Handle messageSentHandler callback and emit message_sent event.""" + raise NotImplementedError diff --git a/bellows/ezsp/v14/__init__.py b/bellows/ezsp/v14/__init__.py index 16dbeec7..ba868ef5 100644 --- a/bellows/ezsp/v14/__init__.py +++ b/bellows/ezsp/v14/__init__.py @@ -1,11 +1,14 @@ """"EZSP Protocol version 14 protocol handler.""" from __future__ import annotations +import asyncio from collections.abc import AsyncGenerator +import logging import voluptuous as vol from zigpy.exceptions import NetworkNotFormed import zigpy.state +import zigpy.types import bellows.config import bellows.types as t @@ -13,6 +16,8 @@ from . import commands, config from ..v13 import EZSPv13 +LOGGER = logging.getLogger(__name__) + class EZSPv14(EZSPv13): """EZSP Version 14 Protocol version handler.""" @@ -144,3 +149,116 @@ async def send_broadcast( ) return status, sequence + + def _handle_incoming_message(self, args: list) -> bool: + """Handle incomingMessageHandler callback and emit packet_received event. + + Returns True if message was fully handled, False if fragment is incomplete. + """ + ( + message_type, + aps_frame, + sender, + eui64, + binding_index, + address_index, + lqi, + rssi, + timestamp, + message, + ) = args + + # Handle fragmented messages + if aps_frame.options & t.EmberApsOption.APS_OPTION_FRAGMENT: + fragment_count = (aps_frame.groupId >> 8) & 0xFF + fragment_index = aps_frame.groupId & 0xFF + + ( + complete, + reassembled, + frag_count, + frag_index, + ) = self._fragment_manager.handle_incoming_fragment( + sender_nwk=sender, + aps_sequence=aps_frame.sequence, + profile_id=aps_frame.profileId, + cluster_id=aps_frame.clusterId, + fragment_count=fragment_count, + fragment_index=fragment_index, + payload=message, + ) + + ack_task = asyncio.create_task( + self._send_fragment_ack(sender, aps_frame, frag_count, frag_index) + ) + self._fragment_ack_tasks.add(ack_task) + ack_task.add_done_callback(lambda t: self._fragment_ack_tasks.discard(t)) + + if not complete: + LOGGER.debug("Fragment reassembly not complete, waiting for more data") + return False + + LOGGER.debug("Reassembled fragmented message, proceeding with handling") + message = reassembled + + # Determine destination address based on message type + if message_type == t.EmberIncomingMessageType.INCOMING_BROADCAST: + dst = zigpy.types.AddrModeAddress( + addr_mode=zigpy.types.AddrMode.Broadcast, + address=zigpy.types.BroadcastAddress.ALL_ROUTERS_AND_COORDINATOR, + ) + elif message_type == t.EmberIncomingMessageType.INCOMING_MULTICAST: + dst = zigpy.types.AddrModeAddress( + addr_mode=zigpy.types.AddrMode.Group, + address=aps_frame.groupId, + ) + elif message_type == t.EmberIncomingMessageType.INCOMING_UNICAST: + # We don't know our own NWK at this level, leave as None + dst = None + else: + LOGGER.debug("Ignoring message type: %r", message_type) + return True + + self.emit( + "packet_received", + zigpy.types.ZigbeePacket( + src=zigpy.types.AddrModeAddress( + addr_mode=zigpy.types.AddrMode.NWK, + address=zigpy.types.NWK(sender), + ), + src_ep=aps_frame.sourceEndpoint, + dst=dst, + dst_ep=aps_frame.destinationEndpoint, + tsn=aps_frame.sequence, + profile_id=aps_frame.profileId, + cluster_id=aps_frame.clusterId, + data=zigpy.types.SerializableBytes(message), + lqi=lqi, + rssi=rssi, + ), + ) + + return True + + def _handle_message_sent(self, args: list) -> None: + """Handle messageSentHandler callback and emit message_sent event.""" + ( + status, + message_type, + destination, + aps_frame, + message_tag, + message, + ) = args + + self.emit( + "message_sent", + ( + status, # Already sl_Status in v14 + message_type, + destination, + aps_frame, + message_tag, + message, + ), + ) diff --git a/bellows/ezsp/v4/__init__.py b/bellows/ezsp/v4/__init__.py index 3b454ecd..bd23699c 100644 --- a/bellows/ezsp/v4/__init__.py +++ b/bellows/ezsp/v4/__init__.py @@ -1,12 +1,14 @@ """"EZSP Protocol version 4 command.""" from __future__ import annotations +import asyncio from collections.abc import AsyncGenerator, Iterable import logging import random import voluptuous as vol import zigpy.state +import zigpy.types import bellows.config import bellows.types as t @@ -235,3 +237,114 @@ async def set_extended_timeout( newId=nwk, newExtendedTimeout=extended_timeout, ) + + def _handle_incoming_message(self, args: list) -> bool: + """Handle incomingMessageHandler callback and emit packet_received event. + + Returns True if message was fully handled, False if fragment is incomplete. + """ + ( + message_type, + aps_frame, + lqi, + rssi, + sender, + binding_index, + address_index, + message, + ) = args + + # Handle fragmented messages + if aps_frame.options & t.EmberApsOption.APS_OPTION_FRAGMENT: + fragment_count = (aps_frame.groupId >> 8) & 0xFF + fragment_index = aps_frame.groupId & 0xFF + + ( + complete, + reassembled, + frag_count, + frag_index, + ) = self._fragment_manager.handle_incoming_fragment( + sender_nwk=sender, + aps_sequence=aps_frame.sequence, + profile_id=aps_frame.profileId, + cluster_id=aps_frame.clusterId, + fragment_count=fragment_count, + fragment_index=fragment_index, + payload=message, + ) + + ack_task = asyncio.create_task( + self._send_fragment_ack(sender, aps_frame, frag_count, frag_index) + ) + self._fragment_ack_tasks.add(ack_task) + ack_task.add_done_callback(lambda t: self._fragment_ack_tasks.discard(t)) + + if not complete: + LOGGER.debug("Fragment reassembly not complete, waiting for more data") + return False + + LOGGER.debug("Reassembled fragmented message, proceeding with handling") + message = reassembled + + # Determine destination address based on message type + if message_type == t.EmberIncomingMessageType.INCOMING_BROADCAST: + dst = zigpy.types.AddrModeAddress( + addr_mode=zigpy.types.AddrMode.Broadcast, + address=zigpy.types.BroadcastAddress.ALL_ROUTERS_AND_COORDINATOR, + ) + elif message_type == t.EmberIncomingMessageType.INCOMING_MULTICAST: + dst = zigpy.types.AddrModeAddress( + addr_mode=zigpy.types.AddrMode.Group, + address=aps_frame.groupId, + ) + elif message_type == t.EmberIncomingMessageType.INCOMING_UNICAST: + # We don't know our own NWK at this level, leave as None + dst = None + else: + LOGGER.debug("Ignoring message type: %r", message_type) + return True + + self.emit( + "packet_received", + zigpy.types.ZigbeePacket( + src=zigpy.types.AddrModeAddress( + addr_mode=zigpy.types.AddrMode.NWK, + address=zigpy.types.NWK(sender), + ), + src_ep=aps_frame.sourceEndpoint, + dst=dst, + dst_ep=aps_frame.destinationEndpoint, + tsn=aps_frame.sequence, + profile_id=aps_frame.profileId, + cluster_id=aps_frame.clusterId, + data=zigpy.types.SerializableBytes(message), + lqi=lqi, + rssi=rssi, + ), + ) + + return True + + def _handle_message_sent(self, args: list) -> None: + """Handle messageSentHandler callback and emit message_sent event.""" + ( + message_type, + destination, + aps_frame, + message_tag, + status, + message, + ) = args + + self.emit( + "message_sent", + ( + t.sl_Status.from_ember_status(status), + message_type, + destination, + aps_frame, + message_tag, + message, + ), + ) diff --git a/bellows/zigbee/application.py b/bellows/zigbee/application.py index 8dd8b6a1..2cdd8398 100644 --- a/bellows/zigbee/application.py +++ b/bellows/zigbee/application.py @@ -241,6 +241,11 @@ async def start_network(self): cnt_group.reset() ezsp.add_callback(self.ezsp_callback_handler) + + # Subscribe to protocol-level events + ezsp._protocol.on_event("packet_received", self._on_packet_received) + ezsp._protocol.on_event("message_sent", self._on_message_sent) + self.controller_event.set() group_membership = {} @@ -621,72 +626,7 @@ async def force_remove(self, dev): def ezsp_callback_handler(self, frame_name, args): LOGGER.debug("Received %s frame with %s", frame_name, args) - if frame_name == "incomingMessageHandler": - if self._ezsp.ezsp_version >= 14: - ( - message_type, - aps_frame, - nwk, - _eui64, - binding_index, - address_index, - lqi, - rssi, - _timestamp, - message, - ) = args - else: - ( - message_type, - aps_frame, - lqi, - rssi, - nwk, - binding_index, - address_index, - message, - ) = args - - self._handle_frame( - message_type=message_type, - aps_frame=aps_frame, - lqi=lqi, - rssi=rssi, - sender=nwk, - binding_index=binding_index, - address_index=address_index, - message=message, - ) - elif frame_name == "messageSentHandler": - if self._ezsp.ezsp_version >= 14: - ( - status, - message_type, - destination, - aps_frame, - message_tag, - message, - ) = args - else: - ( - message_type, - destination, - aps_frame, - message_tag, - status, - message, - ) = args - status = t.sl_Status.from_ember_status(status) - - self._handle_frame_sent( - message_type=message_type, - destination=destination, - aps_frame=aps_frame, - message_tag=message_tag, - status=status, - message=message, - ) - elif frame_name == "trustCenterJoinHandler": + if frame_name == "trustCenterJoinHandler": self._handle_tc_join_handler(*args) elif frame_name == "incomingRouteRecordHandler": self.handle_route_record(*args) @@ -697,64 +637,27 @@ def ezsp_callback_handler(self, frame_name, args): elif frame_name == "idConflictHandler": self._handle_id_conflict(*args) - def _handle_frame( - self, - message_type: t.EmberIncomingMessageType, - aps_frame: t.EmberApsFrame, - lqi: t.uint8_t, - rssi: t.int8s, - sender: t.EmberNodeId, - binding_index: t.uint8_t, - address_index: t.uint8_t, - message: bytes, - ) -> None: - if message_type == t.EmberIncomingMessageType.INCOMING_BROADCAST: - dst = zigpy.types.AddrModeAddress( - addr_mode=zigpy.types.AddrMode.Broadcast, - address=zigpy.types.BroadcastAddress.ALL_ROUTERS_AND_COORDINATOR, + def _on_packet_received(self, packet: zigpy.types.ZigbeePacket) -> None: + """Handle packet_received event from protocol handler.""" + if packet.dst is None: + packet = packet.replace( + dst=zigpy.types.AddrModeAddress( + addr_mode=zigpy.types.AddrMode.NWK, + address=self.state.node_info.nwk, + ) ) + self.state.counters[COUNTERS_CTRL][COUNTER_RX_UNICAST].increment() + elif packet.dst.addr_mode == zigpy.types.AddrMode.Broadcast: self.state.counters[COUNTERS_CTRL][COUNTER_RX_BCAST].increment() - elif message_type == t.EmberIncomingMessageType.INCOMING_MULTICAST: - dst = zigpy.types.AddrModeAddress( - addr_mode=zigpy.types.AddrMode.Group, address=aps_frame.groupId - ) + elif packet.dst.addr_mode == zigpy.types.AddrMode.Group: self.state.counters[COUNTERS_CTRL][COUNTER_RX_MCAST].increment() - elif message_type == t.EmberIncomingMessageType.INCOMING_UNICAST: - dst = zigpy.types.AddrModeAddress( - addr_mode=zigpy.types.AddrMode.NWK, address=self.state.node_info.nwk - ) - self.state.counters[COUNTERS_CTRL][COUNTER_RX_UNICAST].increment() - else: - LOGGER.debug("Ignoring message type: %r", message_type) - return - self.packet_received( - zigpy.types.ZigbeePacket( - src=zigpy.types.AddrModeAddress( - addr_mode=zigpy.types.AddrMode.NWK, - address=sender, - ), - src_ep=aps_frame.sourceEndpoint, - dst=dst, - dst_ep=aps_frame.destinationEndpoint, - tsn=aps_frame.sequence, - profile_id=aps_frame.profileId, - cluster_id=aps_frame.clusterId, - data=zigpy.types.SerializableBytes(message), - lqi=lqi, - rssi=rssi, - ) - ) + self.packet_received(packet) + + def _on_message_sent(self, data: tuple) -> None: + """Handle message_sent event from protocol handler.""" + status, message_type, destination, aps_frame, message_tag, message = data - def _handle_frame_sent( - self, - message_type: t.EmberIncomingMessageType, - destination: t.EmberNodeId, - aps_frame: t.EmberApsFrame, - message_tag: int, - status: t.sl_Status, - message: bytes, - ): if status == t.sl_Status.OK: msg = "success" else: diff --git a/tests/test_application.py b/tests/test_application.py index 72865c7e..8544c965 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -71,6 +71,10 @@ def inner(config, send_timeout: float = 0.05, **kwargs): app.handle_message = MagicMock() app.packet_received = MagicMock() + # Set up event subscriptions normally done in connect() + app._ezsp._protocol.on_event("packet_received", app._on_packet_received) + app._ezsp._protocol.on_event("message_sent", app._on_message_sent) + return app return inner @@ -431,8 +435,8 @@ def aps_frame(): def _handle_incoming_aps_frame(app, aps_frame, type): - app.ezsp_callback_handler( - "incomingMessageHandler", + # Call protocol handler directly (v4/v8 field order) + app._ezsp._protocol._handle_incoming_message( list( dict( type=type, @@ -444,7 +448,16 @@ def _handle_incoming_aps_frame(app, aps_frame, type): addressIndex=78, message=b"test message", ).values() - ), + ) + ) + + +def _handle_message_sent( + app, msg_type, destination, aps_frame, message_tag, status, message +): + # Call protocol handler directly (v4/v8 field order) + app._ezsp._protocol._handle_message_sent( + [msg_type, destination, aps_frame, message_tag, status, message] ) @@ -552,9 +565,7 @@ def test_frame_handler_ignored(app, aps_frame): ) async def test_send_failure(app, aps, ieee, msg_type): fut = app._pending_requests[(0xBEED, 254)] = asyncio.Future() - app.ezsp_callback_handler( - "messageSentHandler", [msg_type, 0xBEED, aps, 254, t.EmberStatus.SUCCESS, b""] - ) + _handle_message_sent(app, msg_type, 0xBEED, aps, 254, t.EmberStatus.SUCCESS, b"") assert fut.result() == (t.sl_Status.OK, "message send success") @@ -562,54 +573,47 @@ async def test_dup_send_failure(app, aps, ieee): fut = app._pending_requests[(0xBEED, 254)] = asyncio.Future() fut.set_result("Already set") - app.ezsp_callback_handler( - "messageSentHandler", - [ - t.EmberIncomingMessageType.INCOMING_UNICAST, - 0xBEED, - aps, - 254, - sentinel.status, - b"", - ], + _handle_message_sent( + app, + t.EmberIncomingMessageType.INCOMING_UNICAST, + 0xBEED, + aps, + 254, + sentinel.status, + b"", ) def test_send_failure_unexpected(app, aps, ieee): - app.ezsp_callback_handler( - "messageSentHandler", - [ - t.EmberIncomingMessageType.INCOMING_BROADCAST_LOOPBACK, - 0xBEED, - aps, - 257, - 1, - b"", - ], + _handle_message_sent( + app, + t.EmberIncomingMessageType.INCOMING_BROADCAST_LOOPBACK, + 0xBEED, + aps, + 257, + 1, + b"", ) async def test_send_success(app, aps, ieee): fut = app._pending_requests[(0xBEED, 253)] = asyncio.Future() - app.ezsp_callback_handler( - "messageSentHandler", - [ - t.EmberIncomingMessageType.INCOMING_MULTICAST_LOOPBACK, - 0xBEED, - aps, - 253, - t.EmberStatus.SUCCESS, - b"", - ], + _handle_message_sent( + app, + t.EmberIncomingMessageType.INCOMING_MULTICAST_LOOPBACK, + 0xBEED, + aps, + 253, + t.EmberStatus.SUCCESS, + b"", ) assert fut.result() == (t.sl_Status.OK, "message send success") def test_unexpected_send_success(app, aps, ieee): - app.ezsp_callback_handler( - "messageSentHandler", - [t.EmberIncomingMessageType.INCOMING_MULTICAST, 0xBEED, aps, 253, 0, b""], + _handle_message_sent( + app, t.EmberIncomingMessageType.INCOMING_MULTICAST, 0xBEED, aps, 253, 0, b"" ) @@ -737,26 +741,24 @@ def packet(): async def test_request_concurrency_duplicate_failure( make_app, packet: zigpy_t.ZigbeePacket ) -> None: + # Increase the send timeout, CI is inconsistent with the default + app = make_app({}, send_timeout=0.5) + def send_unicast(aps_frame, data, message_tag, nwk): asyncio.get_running_loop().call_soon( - app.ezsp_callback_handler, - "messageSentHandler", - list( - dict( - type=t.EmberOutgoingMessageType.OUTGOING_DIRECT, - indexOrDestination=0x1234, - apsFrame=aps_frame, - messageTag=message_tag, - status=bellows.types.sl_Status.OK, - message=b"", - ).values() - ), + app._ezsp._protocol._handle_message_sent, + [ + t.EmberOutgoingMessageType.OUTGOING_DIRECT, + 0x1234, + aps_frame, + message_tag, + bellows.types.sl_Status.OK, + b"", + ], ) return [bellows.types.sl_Status.OK, 0x12] - # Increase the send timeout, CI is inconsistent with the default - app = make_app({}, send_timeout=0.5) app._ezsp.send_unicast = AsyncMock( side_effect=send_unicast, spec=app._ezsp.send_unicast ) @@ -791,18 +793,15 @@ async def _test_send_packet_unicast( def send_unicast(*args, **kwargs): asyncio.get_running_loop().call_later( 0.01, - app.ezsp_callback_handler, - "messageSentHandler", - list( - dict( - type=t.EmberOutgoingMessageType.OUTGOING_DIRECT, - indexOrDestination=0x1234, - apsFrame=sentinel.aps, - messageTag=sentinel.msg_tag, - status=sent_handler_status, - message=b"", - ).values() - ), + app._ezsp._protocol._handle_message_sent, + [ + t.EmberOutgoingMessageType.OUTGOING_DIRECT, + 0x1234, + sentinel.aps, + sentinel.msg_tag, + sent_handler_status, + b"", + ], ) return [status, 0x12] @@ -1045,18 +1044,8 @@ async def send_message_sent_reply( await asyncio.sleep(0.01) - app.ezsp_callback_handler( - "messageSentHandler", - list( - dict( - type=type, - indexOrDestination=indexOrDestination, - apsFrame=apsFrame, - messageTag=messageTag, - status=t.EmberStatus.SUCCESS, - message=b"", - ).values() - ), + app._ezsp._protocol._handle_message_sent( + [type, indexOrDestination, apsFrame, messageTag, t.EmberStatus.SUCCESS, b""] ) async def send_unicast(nwk, aps_frame, message_tag, data): @@ -1102,18 +1091,15 @@ async def test_send_packet_broadcast(app, packet): app.get_sequence = MagicMock(return_value=sentinel.msg_tag) asyncio.get_running_loop().call_soon( - app.ezsp_callback_handler, - "messageSentHandler", - list( - dict( - type=t.EmberOutgoingMessageType.OUTGOING_BROADCAST, - indexOrDestination=0xFFFE, - apsFrame=sentinel.aps, - messageTag=sentinel.msg_tag, - status=t.EmberStatus.SUCCESS, - message=b"", - ).values() - ), + app._ezsp._protocol._handle_message_sent, + [ + t.EmberOutgoingMessageType.OUTGOING_BROADCAST, + 0xFFFE, + sentinel.aps, + sentinel.msg_tag, + t.EmberStatus.SUCCESS, + b"", + ], ) await app.send_packet(packet) @@ -1148,18 +1134,15 @@ async def test_send_packet_broadcast_ignored_delivery_failure(app, packet): app.get_sequence = MagicMock(return_value=sentinel.msg_tag) asyncio.get_running_loop().call_soon( - app.ezsp_callback_handler, - "messageSentHandler", - list( - dict( - type=t.EmberOutgoingMessageType.OUTGOING_BROADCAST, - indexOrDestination=0xFFFE, - apsFrame=sentinel.aps, - messageTag=sentinel.msg_tag, - status=t.EmberStatus.DELIVERY_FAILED, - message=b"", - ).values() - ), + app._ezsp._protocol._handle_message_sent, + [ + t.EmberOutgoingMessageType.OUTGOING_BROADCAST, + 0xFFFE, + sentinel.aps, + sentinel.msg_tag, + t.EmberStatus.DELIVERY_FAILED, + b"", + ], ) # Does not throw an error @@ -1201,18 +1184,15 @@ async def test_send_packet_multicast(app, packet): app.get_sequence = MagicMock(return_value=sentinel.msg_tag) asyncio.get_running_loop().call_soon( - app.ezsp_callback_handler, - "messageSentHandler", - list( - dict( - type=t.EmberOutgoingMessageType.OUTGOING_MULTICAST, - indexOrDestination=0x1234, - apsFrame=sentinel.aps, - messageTag=sentinel.msg_tag, - status=t.EmberStatus.SUCCESS, - message=b"", - ).values() - ), + app._ezsp._protocol._handle_message_sent, + [ + t.EmberOutgoingMessageType.OUTGOING_MULTICAST, + 0x1234, + sentinel.aps, + sentinel.msg_tag, + t.EmberStatus.SUCCESS, + b"", + ], ) await app.send_packet(packet) diff --git a/tests/test_ezsp_protocol.py b/tests/test_ezsp_protocol.py index 3906eb5e..d294b7f4 100644 --- a/tests/test_ezsp_protocol.py +++ b/tests/test_ezsp_protocol.py @@ -206,9 +206,11 @@ async def test_incoming_fragmented_message_incomplete(prot_hndl, caplog): len(prot_hndl._fragment_ack_tasks) == 0 ), "Done callback should have removed task" - prot_hndl._handle_callback.assert_not_called() - assert "Fragment reassembly not complete. waiting for more data." in caplog.text - mock_ack.assert_called_once_with(sender, aps_frame, 2, 0) + assert len(prot_hndl._handle_callback.mock_calls) == 0 + assert "Fragment reassembly not complete, waiting for more data" in caplog.text + assert prot_hndl._send_fragment_ack.mock_calls == [ + call(sender, aps_frame, 2, 0) + ] async def test_incoming_fragmented_message_complete(prot_hndl, caplog): @@ -241,7 +243,6 @@ async def test_incoming_fragmented_message_complete(prot_hndl, caplog): groupId=513, # fragment_count=2, fragment_index=1 sequence=238, ) - reassembled = b"complete message" with patch.object(prot_hndl, "_send_fragment_ack", new=AsyncMock()) as mock_ack: mock_ack.return_value = None @@ -256,12 +257,14 @@ async def test_incoming_fragmented_message_complete(prot_hndl, caplog): len(prot_hndl._fragment_ack_tasks) == 0 ), "Done callback should have removed task" - prot_hndl._handle_callback.assert_not_called() + assert len(prot_hndl._handle_callback.mock_calls) == 0 assert ( - "Reassembled fragmented message. Proceeding with normal handling." + "Reassembled fragmented message, proceeding with handling" not in caplog.text ) - mock_ack.assert_called_with(sender, aps_frame_1, 2, 0) + assert prot_hndl._send_fragment_ack.mock_calls == [ + call(sender, aps_frame_1, 2, 0) + ] # Packet 2 prot_hndl(packet2) @@ -272,21 +275,24 @@ async def test_incoming_fragmented_message_complete(prot_hndl, caplog): len(prot_hndl._fragment_ack_tasks) == 0 ), "Done callback should have removed task" - prot_hndl._handle_callback.assert_called_once_with( - "incomingMessageHandler", - [ - t.EmberIncomingMessageType.INCOMING_UNICAST, # 0x00 - aps_frame_2, # Parsed APS frame - 255, # lastHopLqi: 0xFF - -8, # lastHopRssi: 0xF8 - sender, # 0x1D6F - 255, # bindingIndex: 0xFF - 255, # addressIndex: 0xFF - reassembled, # Reassembled payload - ], - ) - assert ( - "Reassembled fragmented message. Proceeding with normal handling." - in caplog.text - ) - mock_ack.assert_called_with(sender, aps_frame_2, 2, 1) + # Legacy callback is called with original args (last fragment's payload) + assert prot_hndl._handle_callback.mock_calls == [ + call( + "incomingMessageHandler", + [ + t.EmberIncomingMessageType.INCOMING_UNICAST, + aps_frame_2, + 255, # lastHopLqi + -8, # lastHopRssi + sender, + 255, # bindingIndex + 255, # addressIndex + b"message", # Original last fragment payload, not reassembled + ], + ) + ] + assert "Reassembled fragmented message, proceeding with handling" in caplog.text + assert prot_hndl._send_fragment_ack.mock_calls == [ + call(sender, aps_frame_1, 2, 0), + call(sender, aps_frame_2, 2, 1), + ] From 19a7a1dada0417fb8d10c1467495e1676ca45243 Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Wed, 31 Dec 2025 17:05:48 -0500 Subject: [PATCH 04/18] Add to EZSP config --- bellows/ezsp/config.py | 1 + 1 file changed, 1 insertion(+) diff --git a/bellows/ezsp/config.py b/bellows/ezsp/config.py index 1045da0f..61e485ce 100644 --- a/bellows/ezsp/config.py +++ b/bellows/ezsp/config.py @@ -129,4 +129,5 @@ class ValueConfig: 14: DEFAULT_CONFIG_NEW, 16: DEFAULT_CONFIG_NEW, 17: DEFAULT_CONFIG_NEW, + 18: DEFAULT_CONFIG_NEW, } From 76487f1fb00c40eb0f8998b74307a401de165638 Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Wed, 31 Dec 2025 17:07:19 -0500 Subject: [PATCH 05/18] Add a new test for v18 --- tests/test_ezsp_v18.py | 164 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 164 insertions(+) create mode 100644 tests/test_ezsp_v18.py diff --git a/tests/test_ezsp_v18.py b/tests/test_ezsp_v18.py new file mode 100644 index 00000000..db0db68c --- /dev/null +++ b/tests/test_ezsp_v18.py @@ -0,0 +1,164 @@ +from unittest.mock import MagicMock, call + +import pytest + +import bellows.ezsp.v18 +import bellows.types as t + +from tests.common import mock_ezsp_commands + + +@pytest.fixture +def ezsp_f(): + """EZSP v18 protocol handler.""" + ezsp = bellows.ezsp.v18.EZSPv18(MagicMock(), MagicMock()) + mock_ezsp_commands(ezsp) + + return ezsp + + +def test_ezsp_frame(ezsp_f): + ezsp_f._seq = 0x22 + data = ezsp_f._ezsp_frame("version", 18) + assert data == b"\x22\x00\x01\x00\x00\x12" + + +def test_ezsp_frame_rx(ezsp_f): + """Test receiving a version frame.""" + ezsp_f(b"\x01\x01\x80\x00\x00\x01\x02\x34\x12") + assert ezsp_f._handle_callback.call_count == 1 + assert ezsp_f._handle_callback.call_args[0][0] == "version" + assert ezsp_f._handle_callback.call_args[0][1] == [0x01, 0x02, 0x1234] + + +async def test_send_unicast(ezsp_f) -> None: + ezsp_f.sendUnicast.return_value = (t.sl_Status.OK, 0x0042) + + aps_frame = t.EmberApsFrame( + profileId=0x0104, + clusterId=0x0006, + sourceEndpoint=1, + destinationEndpoint=2, + options=t.EmberApsOption.APS_OPTION_RETRY, + groupId=0x0000, + sequence=0x34, + ) + + status, message_tag = await ezsp_f.send_unicast( + nwk=0x1234, + aps_frame=aps_frame, + message_tag=0x42, + data=b"hello", + ) + + assert status == t.sl_Status.OK + assert message_tag == 0x42 + assert ezsp_f.sendUnicast.mock_calls == [ + call( + message_type=t.EmberOutgoingMessageType.OUTGOING_DIRECT, + nwk=0x1234, + aps_frame=t.EmberApsFrameV18( + profileId=0x0104, + clusterId=0x0006, + sourceEndpoint=1, + destinationEndpoint=2, + options=t.EmberApsOption.APS_OPTION_RETRY, + groupId=0x0000, + sequence=0x34, + radius=0, + ), + message_tag=0x42, + message=b"hello", + ) + ] + + +async def test_send_multicast(ezsp_f) -> None: + ezsp_f.sendMulticast.return_value = (t.sl_Status.OK, 0x0042) + + aps_frame = t.EmberApsFrame( + profileId=0x0104, + clusterId=0x0006, + sourceEndpoint=1, + destinationEndpoint=2, + options=t.EmberApsOption.APS_OPTION_RETRY, + groupId=0x1234, + sequence=0x34, + ) + + status, message_tag = await ezsp_f.send_multicast( + aps_frame=aps_frame, + radius=12, + non_member_radius=34, + message_tag=0x42, + data=b"hello", + ) + + assert status == t.sl_Status.OK + assert message_tag == 0x42 + assert ezsp_f.sendMulticast.mock_calls == [ + call( + aps_frame=t.EmberApsFrameV18( + profileId=0x0104, + clusterId=0x0006, + sourceEndpoint=1, + destinationEndpoint=2, + options=t.EmberApsOption.APS_OPTION_RETRY, + groupId=0x1234, + sequence=0x34, + radius=12, + ), + hops=12, + broadcast_addr=t.BroadcastAddress.RX_ON_WHEN_IDLE, + alias=0x0000, + sequence=0x34, + message_tag=0x0042, + message=b"hello", + ) + ] + + +async def test_send_broadcast(ezsp_f) -> None: + ezsp_f.sendBroadcast.return_value = (t.sl_Status.OK, 0x0042) + + aps_frame = t.EmberApsFrame( + profileId=0x0104, + clusterId=0x0006, + sourceEndpoint=1, + destinationEndpoint=2, + options=t.EmberApsOption.APS_OPTION_RETRY, + groupId=0x0000, + sequence=0x34, + ) + + status, message_tag = await ezsp_f.send_broadcast( + address=t.BroadcastAddress.ALL_ROUTERS_AND_COORDINATOR, + aps_frame=aps_frame, + radius=12, + message_tag=0x42, + aps_sequence=34, + data=b"hello", + ) + + assert status == t.sl_Status.OK + assert message_tag == 0x42 + assert ezsp_f.sendBroadcast.mock_calls == [ + call( + alias=0x0000, + destination=t.BroadcastAddress.ALL_ROUTERS_AND_COORDINATOR, + sequence=34, + aps_frame=t.EmberApsFrameV18( + profileId=0x0104, + clusterId=0x0006, + sourceEndpoint=1, + destinationEndpoint=2, + options=t.EmberApsOption.APS_OPTION_RETRY, + groupId=0x0000, + sequence=0x34, + radius=12, + ), + radius=12, + message_tag=0x42, + message=b"hello", + ) + ] From fa84c12ed9671b9fc31e4748506f4f77adb23159 Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Wed, 31 Dec 2025 18:14:17 -0500 Subject: [PATCH 06/18] Properly handle event lifecycle --- bellows/zigbee/application.py | 27 ++++++++++++++++++++++----- tests/test_application.py | 5 ++--- 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/bellows/zigbee/application.py b/bellows/zigbee/application.py index 2cdd8398..b18e0b90 100644 --- a/bellows/zigbee/application.py +++ b/bellows/zigbee/application.py @@ -2,7 +2,7 @@ import asyncio from asyncio import timeout as asyncio_timeout -from collections.abc import AsyncGenerator +from collections.abc import AsyncGenerator, Callable from datetime import UTC, datetime import importlib.metadata import logging @@ -97,6 +97,7 @@ def __init__(self, config: dict) -> None: self._multicast = None self._mfg_id_task: asyncio.Task | None = None self._pending_requests = {} + self._protocol_on_remove_callbacks: list[Callable[[], None]] = [] self._watchdog_failures = 0 self._watchdog_feed_counter = 0 @@ -241,10 +242,7 @@ async def start_network(self): cnt_group.reset() ezsp.add_callback(self.ezsp_callback_handler) - - # Subscribe to protocol-level events - ezsp._protocol.on_event("packet_received", self._on_packet_received) - ezsp._protocol.on_event("message_sent", self._on_message_sent) + self._subscribe_to_protocol_events() self.controller_event.set() @@ -607,14 +605,33 @@ async def reset_network_info(self): else: await self._ezsp.leaveNetwork() + def _unsubscribe_from_protocol_events(self) -> None: + """Unsubscribe from protocol events.""" + for callback in self._protocol_on_remove_callbacks: + callback() + + self._protocol_on_remove_callbacks.clear() + async def _reset(self): + self._unsubscribe_from_protocol_events() self._ezsp.stop_ezsp() await self._ezsp.startup_reset() await self._ezsp.write_config(self.config[CONF_EZSP_CONFIG]) + self._subscribe_to_protocol_events() + + def _subscribe_to_protocol_events(self) -> None: + """Subscribe to protocol-level events.""" + self._protocol_on_remove_callbacks.append( + self._ezsp._protocol.on_event("packet_received", self._on_packet_received) + ) + self._protocol_on_remove_callbacks.append( + self._ezsp._protocol.on_event("message_sent", self._on_message_sent) + ) async def disconnect(self): # TODO: how do you shut down the stack? self.controller_event.clear() + self._unsubscribe_from_protocol_events() if self._ezsp is not None: await self._ezsp.disconnect() self._ezsp = None diff --git a/tests/test_application.py b/tests/test_application.py index 8544c965..376347ac 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -71,9 +71,8 @@ def inner(config, send_timeout: float = 0.05, **kwargs): app.handle_message = MagicMock() app.packet_received = MagicMock() - # Set up event subscriptions normally done in connect() - app._ezsp._protocol.on_event("packet_received", app._on_packet_received) - app._ezsp._protocol.on_event("message_sent", app._on_message_sent) + # Set up event subscriptions normally done in start_network() + app._subscribe_to_protocol_events() return app From 5c725b6c96f9cc8ebca6cd0669900c27ef6a3d1f Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Thu, 1 Jan 2026 15:00:10 -0500 Subject: [PATCH 07/18] Migrate to events --- bellows/ezsp/protocol.py | 154 +++++++++++++++++++++++++++++----- bellows/ezsp/v14/__init__.py | 142 ++++++++----------------------- bellows/ezsp/v4/__init__.py | 138 ++++++++---------------------- bellows/zigbee/application.py | 34 +++++--- tests/test_ezsp_v14.py | 151 +++++++++++++++++++++++++++++++++ 5 files changed, 378 insertions(+), 241 deletions(-) diff --git a/bellows/ezsp/protocol.py b/bellows/ezsp/protocol.py index f86d261a..52498fd7 100644 --- a/bellows/ezsp/protocol.py +++ b/bellows/ezsp/protocol.py @@ -5,10 +5,11 @@ from asyncio import timeout as asyncio_timeout import binascii from collections.abc import AsyncGenerator, Callable, Iterable +from dataclasses import dataclass import functools import logging import time -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Final from zigpy.datastructures import PriorityDynamicBoundedSemaphore from zigpy.event.event_base import EventBase @@ -29,6 +30,25 @@ MAX_COMMAND_CONCURRENCY = 1 +@dataclass(frozen=True, kw_only=True) +class MessageSentEvent: + event_type: Final[str] = "message_sent" + + status: t.sl_Status + message_type: t.EmberOutgoingMessageType + destination: t.uint16_t + aps_frame: t.EmberApsFrame + message_tag: t.uint8_t + message_contents: t.LVBytes + + +@dataclass(frozen=True, kw_only=True) +class PacketReceivedEvent: + event_type: Final[str] = "packet_received" + + packet: zigpy.types.ZigbeePacket + + class ProtocolHandler(EventBase, abc.ABC): """EZSP protocol specific handler.""" @@ -206,17 +226,16 @@ def __call__(self, data: bytes) -> None: return - # Handle callbacks via version-specific methods that emit events - if frame_name == "incomingMessageHandler": - if not self._handle_incoming_message(result): - # Fragment incomplete, skip legacy callback - return - elif frame_name == "messageSentHandler": - self._handle_message_sent(result) + self.handle_parsed_callback(frame_name, result) # Always call legacy callback handler for backwards compatibility self._handle_callback(frame_name, result) + @abc.abstractmethod + def handle_parsed_callback(self, frame_name: str, args: list[Any]) -> None: + """Handle a parsed callback frame.""" + raise NotImplementedError + async def _send_fragment_ack( self, sender: int, @@ -243,6 +262,112 @@ async def _send_fragment_ack( status = await self.sendReply(sender, ackFrame, b"") return status[0] + def _handle_incoming_message( + self, + message_type: t.EmberIncomingMessageType, + aps_frame: t.EmberApsFrame | t.EmberApsFrameV18, + sender: zigpy.types.NWK, + eui64: zigpy.types.EUI64 | None, + binding_index: t.uint8_t, + address_index: t.uint8_t, + lqi: t.uint8_t, + rssi: t.int8s, + timestamp: t.uint32_t | None, + message: t.LVBytes, + ) -> None: + """Handle incomingMessageHandler callback and maybe return a packet.""" + + if aps_frame.options & t.EmberApsOption.APS_OPTION_FRAGMENT: + fragment_count = (aps_frame.groupId >> 8) & 0xFF + fragment_index = aps_frame.groupId & 0xFF + + ( + complete, + reassembled, + frag_count, + frag_index, + ) = self._fragment_manager.handle_incoming_fragment( + sender_nwk=sender, + aps_sequence=aps_frame.sequence, + profile_id=aps_frame.profileId, + cluster_id=aps_frame.clusterId, + fragment_count=fragment_count, + fragment_index=fragment_index, + payload=message, + ) + + ack_task = asyncio.create_task( + self._send_fragment_ack(sender, aps_frame, frag_count, frag_index) + ) + self._fragment_ack_tasks.add(ack_task) + ack_task.add_done_callback(lambda t: self._fragment_ack_tasks.discard(t)) + + if not complete: + LOGGER.debug("Fragment reassembly not complete, waiting for more data") + return + + LOGGER.debug("Reassembled fragmented message, proceeding with handling") + message = reassembled + + # Determine destination address based on message type + if message_type == t.EmberIncomingMessageType.INCOMING_BROADCAST: + dst = zigpy.types.AddrModeAddress( + addr_mode=zigpy.types.AddrMode.Broadcast, + address=zigpy.types.BroadcastAddress.ALL_ROUTERS_AND_COORDINATOR, + ) + elif message_type == t.EmberIncomingMessageType.INCOMING_MULTICAST: + dst = zigpy.types.AddrModeAddress( + addr_mode=zigpy.types.AddrMode.Group, + address=aps_frame.groupId, + ) + elif message_type == t.EmberIncomingMessageType.INCOMING_UNICAST: + dst = None # We don't know the current NWK here + else: + LOGGER.debug("Ignoring message type: %r", message_type) + return + + self.emit( + PacketReceivedEvent.event_type, + PacketReceivedEvent( + packet=zigpy.types.ZigbeePacket( + src=zigpy.types.AddrModeAddress( + addr_mode=zigpy.types.AddrMode.NWK, + address=zigpy.types.NWK(sender), + ), + src_ep=aps_frame.sourceEndpoint, + dst=dst, + dst_ep=aps_frame.destinationEndpoint, + tsn=aps_frame.sequence, + profile_id=aps_frame.profileId, + cluster_id=aps_frame.clusterId, + data=zigpy.types.SerializableBytes(message), + lqi=lqi, + rssi=rssi, + ) + ), + ) + + def _handle_message_sent( + self, + message_type: t.EmberOutgoingMessageType, + destination: t.uint16_t, + aps_frame: t.EmberApsFrame, + message_tag: t.uint8_t, + status: t.sl_Status, + message_contents: t.LVBytes, + ) -> None: + self.emit( + MessageSentEvent.event_type, + MessageSentEvent( + status=t.sl_Status.from_ember_status(status), + message_type=message_type, + destination=destination, + aps_frame=aps_frame, + message_tag=message_tag, + message_contents=message_contents, + ), + ) + def __getattr__(self, name: str) -> Callable: if name not in self.COMMANDS: raise AttributeError(f"{name} not found in COMMANDS") @@ -354,16 +479,3 @@ async def set_extended_timeout( self, nwk: t.NWK, ieee: t.EUI64, extended_timeout: bool = True ) -> None: raise NotImplementedError() - - @abc.abstractmethod - def _handle_incoming_message(self, args: list) -> bool: - """Handle incomingMessageHandler callback and emit packet_received event. - - Returns True if message was fully handled, False if fragment is incomplete. - """ - raise NotImplementedError - - @abc.abstractmethod - def _handle_message_sent(self, args: list) -> None: - """Handle messageSentHandler callback and emit message_sent event.""" - raise NotImplementedError diff --git a/bellows/ezsp/v14/__init__.py b/bellows/ezsp/v14/__init__.py index ba868ef5..4fb2ff6f 100644 --- a/bellows/ezsp/v14/__init__.py +++ b/bellows/ezsp/v14/__init__.py @@ -1,9 +1,9 @@ """"EZSP Protocol version 14 protocol handler.""" from __future__ import annotations -import asyncio from collections.abc import AsyncGenerator import logging +from typing import Any import voluptuous as vol from zigpy.exceptions import NetworkNotFormed @@ -150,115 +150,47 @@ async def send_broadcast( return status, sequence - def _handle_incoming_message(self, args: list) -> bool: - """Handle incomingMessageHandler callback and emit packet_received event. - - Returns True if message was fully handled, False if fragment is incomplete. - """ - ( - message_type, - aps_frame, - sender, - eui64, - binding_index, - address_index, - lqi, - rssi, - timestamp, - message, - ) = args - - # Handle fragmented messages - if aps_frame.options & t.EmberApsOption.APS_OPTION_FRAGMENT: - fragment_count = (aps_frame.groupId >> 8) & 0xFF - fragment_index = aps_frame.groupId & 0xFF - + def handle_parsed_callback(self, frame_name: str, args: list[Any]) -> None: + """Handle a parsed callback frame.""" + if frame_name == "incomingMessageHandler": ( - complete, - reassembled, - frag_count, - frag_index, - ) = self._fragment_manager.handle_incoming_fragment( - sender_nwk=sender, - aps_sequence=aps_frame.sequence, - profile_id=aps_frame.profileId, - cluster_id=aps_frame.clusterId, - fragment_count=fragment_count, - fragment_index=fragment_index, - payload=message, - ) - - ack_task = asyncio.create_task( - self._send_fragment_ack(sender, aps_frame, frag_count, frag_index) - ) - self._fragment_ack_tasks.add(ack_task) - ack_task.add_done_callback(lambda t: self._fragment_ack_tasks.discard(t)) - - if not complete: - LOGGER.debug("Fragment reassembly not complete, waiting for more data") - return False - - LOGGER.debug("Reassembled fragmented message, proceeding with handling") - message = reassembled - - # Determine destination address based on message type - if message_type == t.EmberIncomingMessageType.INCOMING_BROADCAST: - dst = zigpy.types.AddrModeAddress( - addr_mode=zigpy.types.AddrMode.Broadcast, - address=zigpy.types.BroadcastAddress.ALL_ROUTERS_AND_COORDINATOR, - ) - elif message_type == t.EmberIncomingMessageType.INCOMING_MULTICAST: - dst = zigpy.types.AddrModeAddress( - addr_mode=zigpy.types.AddrMode.Group, - address=aps_frame.groupId, - ) - elif message_type == t.EmberIncomingMessageType.INCOMING_UNICAST: - # We don't know our own NWK at this level, leave as None - dst = None - else: - LOGGER.debug("Ignoring message type: %r", message_type) - return True - - self.emit( - "packet_received", - zigpy.types.ZigbeePacket( - src=zigpy.types.AddrModeAddress( - addr_mode=zigpy.types.AddrMode.NWK, - address=zigpy.types.NWK(sender), - ), - src_ep=aps_frame.sourceEndpoint, - dst=dst, - dst_ep=aps_frame.destinationEndpoint, - tsn=aps_frame.sequence, - profile_id=aps_frame.profileId, - cluster_id=aps_frame.clusterId, - data=zigpy.types.SerializableBytes(message), + message_type, + aps_frame, + lqi, + rssi, + sender, + binding_index, + address_index, + message, + ) = args + + self._handle_incoming_message( + message_type=message_type, + aps_frame=aps_frame, + sender=sender, + eui64=None, + binding_index=binding_index, + address_index=address_index, lqi=lqi, rssi=rssi, - ), - ) - - return True - - def _handle_message_sent(self, args: list) -> None: - """Handle messageSentHandler callback and emit message_sent event.""" - ( - status, - message_type, - destination, - aps_frame, - message_tag, - message, - ) = args - - self.emit( - "message_sent", + timestamp=None, + message=message, + ) + elif frame_name == "messageSentHandler": ( - status, # Already sl_Status in v14 + status, message_type, - destination, + nwk, aps_frame, message_tag, message, - ), - ) + ) = args + + self._handle_message_sent( + message_type=message_type, + destination=nwk, + aps_frame=aps_frame, + message_tag=message_tag, + status=status, + message_contents=message, + ) diff --git a/bellows/ezsp/v4/__init__.py b/bellows/ezsp/v4/__init__.py index bd23699c..fbda9527 100644 --- a/bellows/ezsp/v4/__init__.py +++ b/bellows/ezsp/v4/__init__.py @@ -1,10 +1,10 @@ """"EZSP Protocol version 4 command.""" from __future__ import annotations -import asyncio from collections.abc import AsyncGenerator, Iterable import logging import random +from typing import Any import voluptuous as vol import zigpy.state @@ -238,113 +238,47 @@ async def set_extended_timeout( newExtendedTimeout=extended_timeout, ) - def _handle_incoming_message(self, args: list) -> bool: - """Handle incomingMessageHandler callback and emit packet_received event. - - Returns True if message was fully handled, False if fragment is incomplete. - """ - ( - message_type, - aps_frame, - lqi, - rssi, - sender, - binding_index, - address_index, - message, - ) = args - - # Handle fragmented messages - if aps_frame.options & t.EmberApsOption.APS_OPTION_FRAGMENT: - fragment_count = (aps_frame.groupId >> 8) & 0xFF - fragment_index = aps_frame.groupId & 0xFF - + def handle_parsed_callback(self, frame_name: str, args: list[Any]) -> None: + """Handle a parsed callback frame.""" + if frame_name == "incomingMessageHandler": ( - complete, - reassembled, - frag_count, - frag_index, - ) = self._fragment_manager.handle_incoming_fragment( - sender_nwk=sender, - aps_sequence=aps_frame.sequence, - profile_id=aps_frame.profileId, - cluster_id=aps_frame.clusterId, - fragment_count=fragment_count, - fragment_index=fragment_index, - payload=message, - ) - - ack_task = asyncio.create_task( - self._send_fragment_ack(sender, aps_frame, frag_count, frag_index) - ) - self._fragment_ack_tasks.add(ack_task) - ack_task.add_done_callback(lambda t: self._fragment_ack_tasks.discard(t)) - - if not complete: - LOGGER.debug("Fragment reassembly not complete, waiting for more data") - return False - - LOGGER.debug("Reassembled fragmented message, proceeding with handling") - message = reassembled - - # Determine destination address based on message type - if message_type == t.EmberIncomingMessageType.INCOMING_BROADCAST: - dst = zigpy.types.AddrModeAddress( - addr_mode=zigpy.types.AddrMode.Broadcast, - address=zigpy.types.BroadcastAddress.ALL_ROUTERS_AND_COORDINATOR, - ) - elif message_type == t.EmberIncomingMessageType.INCOMING_MULTICAST: - dst = zigpy.types.AddrModeAddress( - addr_mode=zigpy.types.AddrMode.Group, - address=aps_frame.groupId, - ) - elif message_type == t.EmberIncomingMessageType.INCOMING_UNICAST: - # We don't know our own NWK at this level, leave as None - dst = None - else: - LOGGER.debug("Ignoring message type: %r", message_type) - return True - - self.emit( - "packet_received", - zigpy.types.ZigbeePacket( - src=zigpy.types.AddrModeAddress( - addr_mode=zigpy.types.AddrMode.NWK, - address=zigpy.types.NWK(sender), - ), - src_ep=aps_frame.sourceEndpoint, - dst=dst, - dst_ep=aps_frame.destinationEndpoint, - tsn=aps_frame.sequence, - profile_id=aps_frame.profileId, - cluster_id=aps_frame.clusterId, - data=zigpy.types.SerializableBytes(message), + message_type, + aps_frame, + lqi, + rssi, + sender, + binding_index, + address_index, + message, + ) = args + + self._handle_incoming_message( + message_type=message_type, + aps_frame=aps_frame, + sender=sender, + eui64=None, + binding_index=binding_index, + address_index=address_index, lqi=lqi, rssi=rssi, - ), - ) - - return True - - def _handle_message_sent(self, args: list) -> None: - """Handle messageSentHandler callback and emit message_sent event.""" - ( - message_type, - destination, - aps_frame, - message_tag, - status, - message, - ) = args - - self.emit( - "message_sent", + timestamp=None, + message=message, + ) + elif frame_name == "messageSentHandler": ( - t.sl_Status.from_ember_status(status), message_type, destination, aps_frame, message_tag, + status, message, - ), - ) + ) = args + + self._handle_message_sent( + type=message_type, + destination=destination, + aps_frame=aps_frame, + message_tag=message_tag, + status=t.sl_Status.from_ember_status(status), + message_contents=message, + ) diff --git a/bellows/zigbee/application.py b/bellows/zigbee/application.py index b18e0b90..6a6ff271 100644 --- a/bellows/zigbee/application.py +++ b/bellows/zigbee/application.py @@ -39,6 +39,7 @@ StackAlreadyRunning, ) import bellows.ezsp +from bellows.ezsp.protocol import MessageSentEvent, PacketReceivedEvent from bellows.ezsp.xncp import FirmwareFeatures import bellows.multicast import bellows.types as t @@ -622,10 +623,14 @@ async def _reset(self): def _subscribe_to_protocol_events(self) -> None: """Subscribe to protocol-level events.""" self._protocol_on_remove_callbacks.append( - self._ezsp._protocol.on_event("packet_received", self._on_packet_received) + self._ezsp._protocol.on_event( + PacketReceivedEvent.event_type, self._on_packet_received + ) ) self._protocol_on_remove_callbacks.append( - self._ezsp._protocol.on_event("message_sent", self._on_message_sent) + self._ezsp._protocol.on_event( + MessageSentEvent.event_type, self._on_message_sent + ) ) async def disconnect(self): @@ -654,8 +659,11 @@ def ezsp_callback_handler(self, frame_name, args): elif frame_name == "idConflictHandler": self._handle_id_conflict(*args) - def _on_packet_received(self, packet: zigpy.types.ZigbeePacket) -> None: + def _on_packet_received(self, message: PacketReceivedEvent) -> None: """Handle packet_received event from protocol handler.""" + packet = message.packet + + # The protocol handler doesn't know our current NWK address if packet.dst is None: packet = packet.replace( dst=zigpy.types.AddrModeAddress( @@ -663,6 +671,8 @@ def _on_packet_received(self, packet: zigpy.types.ZigbeePacket) -> None: address=self.state.node_info.nwk, ) ) + + if packet.dst.addr_mode == zigpy.types.AddrMode.NWK: self.state.counters[COUNTERS_CTRL][COUNTER_RX_UNICAST].increment() elif packet.dst.addr_mode == zigpy.types.AddrMode.Broadcast: self.state.counters[COUNTERS_CTRL][COUNTER_RX_BCAST].increment() @@ -671,40 +681,38 @@ def _on_packet_received(self, packet: zigpy.types.ZigbeePacket) -> None: self.packet_received(packet) - def _on_message_sent(self, data: tuple) -> None: + def _on_message_sent(self, event: MessageSentEvent) -> None: """Handle message_sent event from protocol handler.""" - status, message_type, destination, aps_frame, message_tag, message = data - - if status == t.sl_Status.OK: + if event.status == t.sl_Status.OK: msg = "success" else: msg = "failure" - if message_type in ( + if event.message_type in ( t.EmberOutgoingMessageType.OUTGOING_BROADCAST, t.EmberOutgoingMessageType.OUTGOING_BROADCAST_WITH_ALIAS, ): cnt_name = f"broadcast_tx_{msg}" - elif message_type in ( + elif event.message_type in ( t.EmberOutgoingMessageType.OUTGOING_MULTICAST, t.EmberOutgoingMessageType.OUTGOING_MULTICAST_WITH_ALIAS, ): cnt_name = f"multicast_tx_{msg}" - elif message_type in ( + elif event.message_type in ( t.EmberOutgoingMessageType.OUTGOING_DIRECT, t.EmberOutgoingMessageType.OUTGOING_VIA_ADDRESS_TABLE, ): cnt_name = f"unicast_tx_{msg}" - elif message_type == t.EmberOutgoingMessageType.OUTGOING_VIA_BINDING: + elif event.message_type == t.EmberOutgoingMessageType.OUTGOING_VIA_BINDING: cnt_name = f"via_binding_tx_{msg}" else: cnt_name = f"unknown_msg_type_{msg}" - pending_tag = (destination, message_tag) + pending_tag = (event.destination, event.message_tag) try: future = self._pending_requests[pending_tag] - future.set_result((status, f"message send {msg}")) + future.set_result((event.status, f"message send {msg}")) self.state.counters[COUNTERS_CTRL][cnt_name].increment() except KeyError: self.state.counters[COUNTERS_CTRL][f"{cnt_name}_unexpected"].increment() diff --git a/tests/test_ezsp_v14.py b/tests/test_ezsp_v14.py index 49bf152b..75c22a98 100644 --- a/tests/test_ezsp_v14.py +++ b/tests/test_ezsp_v14.py @@ -3,6 +3,7 @@ import pytest import zigpy.exceptions import zigpy.state +import zigpy.types import bellows.ezsp.v14 import bellows.types as t @@ -226,3 +227,153 @@ async def test_send_broadcast(ezsp_f) -> None: message=b"hello", ) ] + + +@pytest.mark.parametrize( + "message_type, expected_dst", + [ + ( + t.EmberIncomingMessageType.INCOMING_UNICAST, + None, + ), + ( + t.EmberIncomingMessageType.INCOMING_BROADCAST, + zigpy.types.AddrModeAddress( + addr_mode=zigpy.types.AddrMode.Broadcast, + address=zigpy.types.BroadcastAddress.ALL_ROUTERS_AND_COORDINATOR, + ), + ), + ( + t.EmberIncomingMessageType.INCOMING_MULTICAST, + zigpy.types.AddrModeAddress( + addr_mode=zigpy.types.AddrMode.Group, + address=0x1234, + ), + ), + ], +) +def test_incoming_message_handler(ezsp_f, message_type, expected_dst) -> None: + """Test incomingMessageHandler emits packet_received event.""" + received_packets = [] + ezsp_f.on_event("packet_received", lambda pkt: received_packets.append(pkt)) + + aps_frame = t.EmberApsFrame( + profileId=0x0104, + clusterId=0x0006, + sourceEndpoint=1, + destinationEndpoint=2, + options=t.EmberApsOption.APS_OPTION_NONE, + groupId=0x1234, + sequence=0x42, + ) + + ezsp_f.handle_parsed_callback( + "incomingMessageHandler", + [ + message_type, + aps_frame, + t.EmberNodeId(0x1234), # sender nwk + t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11"), # sender eui64 + 0, # binding_index + 0, # address_index + 200, # lqi + -40, # rssi + 12345678, # timestamp + b"test message", # message + ], + ) + + assert len(received_packets) == 1 + packet = received_packets[0] + assert packet.src == zigpy.types.AddrModeAddress( + addr_mode=zigpy.types.AddrMode.NWK, + address=zigpy.types.NWK(0x1234), + ) + assert packet.src_ep == 1 + assert packet.dst == expected_dst + assert packet.dst_ep == 2 + assert packet.profile_id == 0x0104 + assert packet.cluster_id == 0x0006 + assert packet.data == zigpy.types.SerializableBytes(b"test message") + assert packet.lqi == 200 + assert packet.rssi == -40 + + +def test_incoming_message_handler_ignored_type(ezsp_f) -> None: + """Test incomingMessageHandler ignores unknown message types.""" + received_packets = [] + ezsp_f.on_event("packet_received", lambda pkt: received_packets.append(pkt)) + + aps_frame = t.EmberApsFrame(options=t.EmberApsOption.APS_OPTION_NONE) + ezsp_f.handle_parsed_callback( + "incomingMessageHandler", + [ + t.EmberIncomingMessageType.INCOMING_MANY_TO_ONE_ROUTE_REQUEST, + aps_frame, + t.EmberNodeId(0x1234), + t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11"), + 0, + 0, + 200, + -40, + 12345678, + b"test", + ], + ) + + assert len(received_packets) == 0 + # Legacy callback should still be called + assert ezsp_f._handle_callback.mock_calls == [ + call( + "incomingMessageHandler", + [ + t.EmberIncomingMessageType.INCOMING_MANY_TO_ONE_ROUTE_REQUEST, + aps_frame, + t.EmberNodeId(0x1234), + t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11"), + 0, + 0, + 200, + -40, + 12345678, + b"test", + ], + ) + ] + + +def test_message_sent_handler(ezsp_f) -> None: + """Test messageSentHandler emits message_sent event.""" + sent_messages = [] + ezsp_f.on_event("message_sent", lambda msg: sent_messages.append(msg)) + + aps_frame = t.EmberApsFrame( + profileId=0x0104, + clusterId=0x0006, + sourceEndpoint=1, + destinationEndpoint=2, + options=t.EmberApsOption.APS_OPTION_NONE, + groupId=0x0000, + sequence=0x42, + ) + + ezsp_f.handle_parsed_callback( + "messageSentHandler", + [ + t.sl_Status.OK, + t.EmberOutgoingMessageType.OUTGOING_DIRECT, + t.EmberNodeId(0x1234), + aps_frame, + 0x42, # message_tag + b"sent message", + ], + ) + + assert len(sent_messages) == 1 + status, msg_type, destination, frame, tag, message = sent_messages[0] + assert status == t.sl_Status.OK + assert msg_type == t.EmberOutgoingMessageType.OUTGOING_DIRECT + assert destination == t.EmberNodeId(0x1234) + assert frame == aps_frame + assert tag == 0x42 + assert message == b"sent message" From ee832b274d66724ae5767bf0b6f4c7898e2bb390 Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Thu, 1 Jan 2026 15:01:23 -0500 Subject: [PATCH 08/18] Revert unit test changes --- tests/test_application.py | 231 ++++++++++++++++++++++-------------- tests/test_ezsp_protocol.py | 56 ++++----- tests/test_ezsp_v14.py | 151 ----------------------- tests/test_ezsp_v18.py | 164 ------------------------- 4 files changed, 167 insertions(+), 435 deletions(-) delete mode 100644 tests/test_ezsp_v18.py diff --git a/tests/test_application.py b/tests/test_application.py index 376347ac..2f055ec5 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -434,8 +434,8 @@ def aps_frame(): def _handle_incoming_aps_frame(app, aps_frame, type): - # Call protocol handler directly (v4/v8 field order) - app._ezsp._protocol._handle_incoming_message( + app.ezsp_callback_handler( + "incomingMessageHandler", list( dict( type=type, @@ -447,16 +447,7 @@ def _handle_incoming_aps_frame(app, aps_frame, type): addressIndex=78, message=b"test message", ).values() - ) - ) - - -def _handle_message_sent( - app, msg_type, destination, aps_frame, message_tag, status, message -): - # Call protocol handler directly (v4/v8 field order) - app._ezsp._protocol._handle_message_sent( - [msg_type, destination, aps_frame, message_tag, status, message] + ), ) @@ -564,7 +555,9 @@ def test_frame_handler_ignored(app, aps_frame): ) async def test_send_failure(app, aps, ieee, msg_type): fut = app._pending_requests[(0xBEED, 254)] = asyncio.Future() - _handle_message_sent(app, msg_type, 0xBEED, aps, 254, t.EmberStatus.SUCCESS, b"") + app.ezsp_callback_handler( + "messageSentHandler", [msg_type, 0xBEED, aps, 254, t.EmberStatus.SUCCESS, b""] + ) assert fut.result() == (t.sl_Status.OK, "message send success") @@ -572,47 +565,54 @@ async def test_dup_send_failure(app, aps, ieee): fut = app._pending_requests[(0xBEED, 254)] = asyncio.Future() fut.set_result("Already set") - _handle_message_sent( - app, - t.EmberIncomingMessageType.INCOMING_UNICAST, - 0xBEED, - aps, - 254, - sentinel.status, - b"", + app.ezsp_callback_handler( + "messageSentHandler", + [ + t.EmberIncomingMessageType.INCOMING_UNICAST, + 0xBEED, + aps, + 254, + sentinel.status, + b"", + ], ) def test_send_failure_unexpected(app, aps, ieee): - _handle_message_sent( - app, - t.EmberIncomingMessageType.INCOMING_BROADCAST_LOOPBACK, - 0xBEED, - aps, - 257, - 1, - b"", + app.ezsp_callback_handler( + "messageSentHandler", + [ + t.EmberIncomingMessageType.INCOMING_BROADCAST_LOOPBACK, + 0xBEED, + aps, + 257, + 1, + b"", + ], ) async def test_send_success(app, aps, ieee): fut = app._pending_requests[(0xBEED, 253)] = asyncio.Future() - _handle_message_sent( - app, - t.EmberIncomingMessageType.INCOMING_MULTICAST_LOOPBACK, - 0xBEED, - aps, - 253, - t.EmberStatus.SUCCESS, - b"", + app.ezsp_callback_handler( + "messageSentHandler", + [ + t.EmberIncomingMessageType.INCOMING_MULTICAST_LOOPBACK, + 0xBEED, + aps, + 253, + t.EmberStatus.SUCCESS, + b"", + ], ) assert fut.result() == (t.sl_Status.OK, "message send success") def test_unexpected_send_success(app, aps, ieee): - _handle_message_sent( - app, t.EmberIncomingMessageType.INCOMING_MULTICAST, 0xBEED, aps, 253, 0, b"" + app.ezsp_callback_handler( + "messageSentHandler", + [t.EmberIncomingMessageType.INCOMING_MULTICAST, 0xBEED, aps, 253, 0, b""], ) @@ -740,24 +740,26 @@ def packet(): async def test_request_concurrency_duplicate_failure( make_app, packet: zigpy_t.ZigbeePacket ) -> None: - # Increase the send timeout, CI is inconsistent with the default - app = make_app({}, send_timeout=0.5) - def send_unicast(aps_frame, data, message_tag, nwk): asyncio.get_running_loop().call_soon( - app._ezsp._protocol._handle_message_sent, - [ - t.EmberOutgoingMessageType.OUTGOING_DIRECT, - 0x1234, - aps_frame, - message_tag, - bellows.types.sl_Status.OK, - b"", - ], + app.ezsp_callback_handler, + "messageSentHandler", + list( + dict( + type=t.EmberOutgoingMessageType.OUTGOING_DIRECT, + indexOrDestination=0x1234, + apsFrame=aps_frame, + messageTag=message_tag, + status=bellows.types.sl_Status.OK, + message=b"", + ).values() + ), ) return [bellows.types.sl_Status.OK, 0x12] + # Increase the send timeout, CI is inconsistent with the default + app = make_app({}, send_timeout=0.5) app._ezsp.send_unicast = AsyncMock( side_effect=send_unicast, spec=app._ezsp.send_unicast ) @@ -792,15 +794,18 @@ async def _test_send_packet_unicast( def send_unicast(*args, **kwargs): asyncio.get_running_loop().call_later( 0.01, - app._ezsp._protocol._handle_message_sent, - [ - t.EmberOutgoingMessageType.OUTGOING_DIRECT, - 0x1234, - sentinel.aps, - sentinel.msg_tag, - sent_handler_status, - b"", - ], + app.ezsp_callback_handler, + "messageSentHandler", + list( + dict( + type=t.EmberOutgoingMessageType.OUTGOING_DIRECT, + indexOrDestination=0x1234, + apsFrame=sentinel.aps, + messageTag=sentinel.msg_tag, + status=sent_handler_status, + message=b"", + ).values() + ), ) return [status, 0x12] @@ -1043,8 +1048,18 @@ async def send_message_sent_reply( await asyncio.sleep(0.01) - app._ezsp._protocol._handle_message_sent( - [type, indexOrDestination, apsFrame, messageTag, t.EmberStatus.SUCCESS, b""] + app.ezsp_callback_handler( + "messageSentHandler", + list( + dict( + type=type, + indexOrDestination=indexOrDestination, + apsFrame=apsFrame, + messageTag=messageTag, + status=t.EmberStatus.SUCCESS, + message=b"", + ).values() + ), ) async def send_unicast(nwk, aps_frame, message_tag, data): @@ -1090,15 +1105,18 @@ async def test_send_packet_broadcast(app, packet): app.get_sequence = MagicMock(return_value=sentinel.msg_tag) asyncio.get_running_loop().call_soon( - app._ezsp._protocol._handle_message_sent, - [ - t.EmberOutgoingMessageType.OUTGOING_BROADCAST, - 0xFFFE, - sentinel.aps, - sentinel.msg_tag, - t.EmberStatus.SUCCESS, - b"", - ], + app.ezsp_callback_handler, + "messageSentHandler", + list( + dict( + type=t.EmberOutgoingMessageType.OUTGOING_BROADCAST, + indexOrDestination=0xFFFE, + apsFrame=sentinel.aps, + messageTag=sentinel.msg_tag, + status=t.EmberStatus.SUCCESS, + message=b"", + ).values() + ), ) await app.send_packet(packet) @@ -1133,15 +1151,18 @@ async def test_send_packet_broadcast_ignored_delivery_failure(app, packet): app.get_sequence = MagicMock(return_value=sentinel.msg_tag) asyncio.get_running_loop().call_soon( - app._ezsp._protocol._handle_message_sent, - [ - t.EmberOutgoingMessageType.OUTGOING_BROADCAST, - 0xFFFE, - sentinel.aps, - sentinel.msg_tag, - t.EmberStatus.DELIVERY_FAILED, - b"", - ], + app.ezsp_callback_handler, + "messageSentHandler", + list( + dict( + type=t.EmberOutgoingMessageType.OUTGOING_BROADCAST, + indexOrDestination=0xFFFE, + apsFrame=sentinel.aps, + messageTag=sentinel.msg_tag, + status=t.EmberStatus.DELIVERY_FAILED, + message=b"", + ).values() + ), ) # Does not throw an error @@ -1183,15 +1204,18 @@ async def test_send_packet_multicast(app, packet): app.get_sequence = MagicMock(return_value=sentinel.msg_tag) asyncio.get_running_loop().call_soon( - app._ezsp._protocol._handle_message_sent, - [ - t.EmberOutgoingMessageType.OUTGOING_MULTICAST, - 0x1234, - sentinel.aps, - sentinel.msg_tag, - t.EmberStatus.SUCCESS, - b"", - ], + app.ezsp_callback_handler, + "messageSentHandler", + list( + dict( + type=t.EmberOutgoingMessageType.OUTGOING_MULTICAST, + indexOrDestination=0x1234, + apsFrame=sentinel.aps, + messageTag=sentinel.msg_tag, + status=t.EmberStatus.SUCCESS, + message=b"", + ).values() + ), ) await app.send_packet(packet) @@ -1579,6 +1603,35 @@ def test_handle_id_conflict(app, ieee): assert app.handle_leave.call_args[0][0] == nwk +async def test_handle_no_such_device(app, ieee): + """Test handling of an unknown device IEEE lookup.""" + + app._ezsp.lookupEui64ByNodeId = AsyncMock() + + p1 = patch.object( + app._ezsp, + "lookupEui64ByNodeId", + AsyncMock(return_value=(t.EmberStatus.ERR_FATAL, ieee)), + ) + p2 = patch.object(app, "handle_join") + with p1 as lookup_mock, p2 as handle_join_mock: + await app._handle_no_such_device(sentinel.nwk) + assert lookup_mock.mock_calls == [call(nodeId=sentinel.nwk)] + assert handle_join_mock.call_count == 0 + + p1 = patch.object( + app._ezsp, + "lookupEui64ByNodeId", + AsyncMock(return_value=(t.EmberStatus.SUCCESS, sentinel.ieee)), + ) + with p1 as lookup_mock, p2 as handle_join_mock: + await app._handle_no_such_device(sentinel.nwk) + assert lookup_mock.mock_calls == [call(nodeId=sentinel.nwk)] + assert handle_join_mock.call_count == 1 + assert handle_join_mock.call_args[0][0] == sentinel.nwk + assert handle_join_mock.call_args[0][1] == sentinel.ieee + + async def test_cleanup_tc_link_key(app): """Test cleaning up tc link key.""" ezsp = app._ezsp diff --git a/tests/test_ezsp_protocol.py b/tests/test_ezsp_protocol.py index d294b7f4..3906eb5e 100644 --- a/tests/test_ezsp_protocol.py +++ b/tests/test_ezsp_protocol.py @@ -206,11 +206,9 @@ async def test_incoming_fragmented_message_incomplete(prot_hndl, caplog): len(prot_hndl._fragment_ack_tasks) == 0 ), "Done callback should have removed task" - assert len(prot_hndl._handle_callback.mock_calls) == 0 - assert "Fragment reassembly not complete, waiting for more data" in caplog.text - assert prot_hndl._send_fragment_ack.mock_calls == [ - call(sender, aps_frame, 2, 0) - ] + prot_hndl._handle_callback.assert_not_called() + assert "Fragment reassembly not complete. waiting for more data." in caplog.text + mock_ack.assert_called_once_with(sender, aps_frame, 2, 0) async def test_incoming_fragmented_message_complete(prot_hndl, caplog): @@ -243,6 +241,7 @@ async def test_incoming_fragmented_message_complete(prot_hndl, caplog): groupId=513, # fragment_count=2, fragment_index=1 sequence=238, ) + reassembled = b"complete message" with patch.object(prot_hndl, "_send_fragment_ack", new=AsyncMock()) as mock_ack: mock_ack.return_value = None @@ -257,14 +256,12 @@ async def test_incoming_fragmented_message_complete(prot_hndl, caplog): len(prot_hndl._fragment_ack_tasks) == 0 ), "Done callback should have removed task" - assert len(prot_hndl._handle_callback.mock_calls) == 0 + prot_hndl._handle_callback.assert_not_called() assert ( - "Reassembled fragmented message, proceeding with handling" + "Reassembled fragmented message. Proceeding with normal handling." not in caplog.text ) - assert prot_hndl._send_fragment_ack.mock_calls == [ - call(sender, aps_frame_1, 2, 0) - ] + mock_ack.assert_called_with(sender, aps_frame_1, 2, 0) # Packet 2 prot_hndl(packet2) @@ -275,24 +272,21 @@ async def test_incoming_fragmented_message_complete(prot_hndl, caplog): len(prot_hndl._fragment_ack_tasks) == 0 ), "Done callback should have removed task" - # Legacy callback is called with original args (last fragment's payload) - assert prot_hndl._handle_callback.mock_calls == [ - call( - "incomingMessageHandler", - [ - t.EmberIncomingMessageType.INCOMING_UNICAST, - aps_frame_2, - 255, # lastHopLqi - -8, # lastHopRssi - sender, - 255, # bindingIndex - 255, # addressIndex - b"message", # Original last fragment payload, not reassembled - ], - ) - ] - assert "Reassembled fragmented message, proceeding with handling" in caplog.text - assert prot_hndl._send_fragment_ack.mock_calls == [ - call(sender, aps_frame_1, 2, 0), - call(sender, aps_frame_2, 2, 1), - ] + prot_hndl._handle_callback.assert_called_once_with( + "incomingMessageHandler", + [ + t.EmberIncomingMessageType.INCOMING_UNICAST, # 0x00 + aps_frame_2, # Parsed APS frame + 255, # lastHopLqi: 0xFF + -8, # lastHopRssi: 0xF8 + sender, # 0x1D6F + 255, # bindingIndex: 0xFF + 255, # addressIndex: 0xFF + reassembled, # Reassembled payload + ], + ) + assert ( + "Reassembled fragmented message. Proceeding with normal handling." + in caplog.text + ) + mock_ack.assert_called_with(sender, aps_frame_2, 2, 1) diff --git a/tests/test_ezsp_v14.py b/tests/test_ezsp_v14.py index 75c22a98..49bf152b 100644 --- a/tests/test_ezsp_v14.py +++ b/tests/test_ezsp_v14.py @@ -3,7 +3,6 @@ import pytest import zigpy.exceptions import zigpy.state -import zigpy.types import bellows.ezsp.v14 import bellows.types as t @@ -227,153 +226,3 @@ async def test_send_broadcast(ezsp_f) -> None: message=b"hello", ) ] - - -@pytest.mark.parametrize( - "message_type, expected_dst", - [ - ( - t.EmberIncomingMessageType.INCOMING_UNICAST, - None, - ), - ( - t.EmberIncomingMessageType.INCOMING_BROADCAST, - zigpy.types.AddrModeAddress( - addr_mode=zigpy.types.AddrMode.Broadcast, - address=zigpy.types.BroadcastAddress.ALL_ROUTERS_AND_COORDINATOR, - ), - ), - ( - t.EmberIncomingMessageType.INCOMING_MULTICAST, - zigpy.types.AddrModeAddress( - addr_mode=zigpy.types.AddrMode.Group, - address=0x1234, - ), - ), - ], -) -def test_incoming_message_handler(ezsp_f, message_type, expected_dst) -> None: - """Test incomingMessageHandler emits packet_received event.""" - received_packets = [] - ezsp_f.on_event("packet_received", lambda pkt: received_packets.append(pkt)) - - aps_frame = t.EmberApsFrame( - profileId=0x0104, - clusterId=0x0006, - sourceEndpoint=1, - destinationEndpoint=2, - options=t.EmberApsOption.APS_OPTION_NONE, - groupId=0x1234, - sequence=0x42, - ) - - ezsp_f.handle_parsed_callback( - "incomingMessageHandler", - [ - message_type, - aps_frame, - t.EmberNodeId(0x1234), # sender nwk - t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11"), # sender eui64 - 0, # binding_index - 0, # address_index - 200, # lqi - -40, # rssi - 12345678, # timestamp - b"test message", # message - ], - ) - - assert len(received_packets) == 1 - packet = received_packets[0] - assert packet.src == zigpy.types.AddrModeAddress( - addr_mode=zigpy.types.AddrMode.NWK, - address=zigpy.types.NWK(0x1234), - ) - assert packet.src_ep == 1 - assert packet.dst == expected_dst - assert packet.dst_ep == 2 - assert packet.profile_id == 0x0104 - assert packet.cluster_id == 0x0006 - assert packet.data == zigpy.types.SerializableBytes(b"test message") - assert packet.lqi == 200 - assert packet.rssi == -40 - - -def test_incoming_message_handler_ignored_type(ezsp_f) -> None: - """Test incomingMessageHandler ignores unknown message types.""" - received_packets = [] - ezsp_f.on_event("packet_received", lambda pkt: received_packets.append(pkt)) - - aps_frame = t.EmberApsFrame(options=t.EmberApsOption.APS_OPTION_NONE) - ezsp_f.handle_parsed_callback( - "incomingMessageHandler", - [ - t.EmberIncomingMessageType.INCOMING_MANY_TO_ONE_ROUTE_REQUEST, - aps_frame, - t.EmberNodeId(0x1234), - t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11"), - 0, - 0, - 200, - -40, - 12345678, - b"test", - ], - ) - - assert len(received_packets) == 0 - # Legacy callback should still be called - assert ezsp_f._handle_callback.mock_calls == [ - call( - "incomingMessageHandler", - [ - t.EmberIncomingMessageType.INCOMING_MANY_TO_ONE_ROUTE_REQUEST, - aps_frame, - t.EmberNodeId(0x1234), - t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11"), - 0, - 0, - 200, - -40, - 12345678, - b"test", - ], - ) - ] - - -def test_message_sent_handler(ezsp_f) -> None: - """Test messageSentHandler emits message_sent event.""" - sent_messages = [] - ezsp_f.on_event("message_sent", lambda msg: sent_messages.append(msg)) - - aps_frame = t.EmberApsFrame( - profileId=0x0104, - clusterId=0x0006, - sourceEndpoint=1, - destinationEndpoint=2, - options=t.EmberApsOption.APS_OPTION_NONE, - groupId=0x0000, - sequence=0x42, - ) - - ezsp_f.handle_parsed_callback( - "messageSentHandler", - [ - t.sl_Status.OK, - t.EmberOutgoingMessageType.OUTGOING_DIRECT, - t.EmberNodeId(0x1234), - aps_frame, - 0x42, # message_tag - b"sent message", - ], - ) - - assert len(sent_messages) == 1 - status, msg_type, destination, frame, tag, message = sent_messages[0] - assert status == t.sl_Status.OK - assert msg_type == t.EmberOutgoingMessageType.OUTGOING_DIRECT - assert destination == t.EmberNodeId(0x1234) - assert frame == aps_frame - assert tag == 0x42 - assert message == b"sent message" diff --git a/tests/test_ezsp_v18.py b/tests/test_ezsp_v18.py deleted file mode 100644 index db0db68c..00000000 --- a/tests/test_ezsp_v18.py +++ /dev/null @@ -1,164 +0,0 @@ -from unittest.mock import MagicMock, call - -import pytest - -import bellows.ezsp.v18 -import bellows.types as t - -from tests.common import mock_ezsp_commands - - -@pytest.fixture -def ezsp_f(): - """EZSP v18 protocol handler.""" - ezsp = bellows.ezsp.v18.EZSPv18(MagicMock(), MagicMock()) - mock_ezsp_commands(ezsp) - - return ezsp - - -def test_ezsp_frame(ezsp_f): - ezsp_f._seq = 0x22 - data = ezsp_f._ezsp_frame("version", 18) - assert data == b"\x22\x00\x01\x00\x00\x12" - - -def test_ezsp_frame_rx(ezsp_f): - """Test receiving a version frame.""" - ezsp_f(b"\x01\x01\x80\x00\x00\x01\x02\x34\x12") - assert ezsp_f._handle_callback.call_count == 1 - assert ezsp_f._handle_callback.call_args[0][0] == "version" - assert ezsp_f._handle_callback.call_args[0][1] == [0x01, 0x02, 0x1234] - - -async def test_send_unicast(ezsp_f) -> None: - ezsp_f.sendUnicast.return_value = (t.sl_Status.OK, 0x0042) - - aps_frame = t.EmberApsFrame( - profileId=0x0104, - clusterId=0x0006, - sourceEndpoint=1, - destinationEndpoint=2, - options=t.EmberApsOption.APS_OPTION_RETRY, - groupId=0x0000, - sequence=0x34, - ) - - status, message_tag = await ezsp_f.send_unicast( - nwk=0x1234, - aps_frame=aps_frame, - message_tag=0x42, - data=b"hello", - ) - - assert status == t.sl_Status.OK - assert message_tag == 0x42 - assert ezsp_f.sendUnicast.mock_calls == [ - call( - message_type=t.EmberOutgoingMessageType.OUTGOING_DIRECT, - nwk=0x1234, - aps_frame=t.EmberApsFrameV18( - profileId=0x0104, - clusterId=0x0006, - sourceEndpoint=1, - destinationEndpoint=2, - options=t.EmberApsOption.APS_OPTION_RETRY, - groupId=0x0000, - sequence=0x34, - radius=0, - ), - message_tag=0x42, - message=b"hello", - ) - ] - - -async def test_send_multicast(ezsp_f) -> None: - ezsp_f.sendMulticast.return_value = (t.sl_Status.OK, 0x0042) - - aps_frame = t.EmberApsFrame( - profileId=0x0104, - clusterId=0x0006, - sourceEndpoint=1, - destinationEndpoint=2, - options=t.EmberApsOption.APS_OPTION_RETRY, - groupId=0x1234, - sequence=0x34, - ) - - status, message_tag = await ezsp_f.send_multicast( - aps_frame=aps_frame, - radius=12, - non_member_radius=34, - message_tag=0x42, - data=b"hello", - ) - - assert status == t.sl_Status.OK - assert message_tag == 0x42 - assert ezsp_f.sendMulticast.mock_calls == [ - call( - aps_frame=t.EmberApsFrameV18( - profileId=0x0104, - clusterId=0x0006, - sourceEndpoint=1, - destinationEndpoint=2, - options=t.EmberApsOption.APS_OPTION_RETRY, - groupId=0x1234, - sequence=0x34, - radius=12, - ), - hops=12, - broadcast_addr=t.BroadcastAddress.RX_ON_WHEN_IDLE, - alias=0x0000, - sequence=0x34, - message_tag=0x0042, - message=b"hello", - ) - ] - - -async def test_send_broadcast(ezsp_f) -> None: - ezsp_f.sendBroadcast.return_value = (t.sl_Status.OK, 0x0042) - - aps_frame = t.EmberApsFrame( - profileId=0x0104, - clusterId=0x0006, - sourceEndpoint=1, - destinationEndpoint=2, - options=t.EmberApsOption.APS_OPTION_RETRY, - groupId=0x0000, - sequence=0x34, - ) - - status, message_tag = await ezsp_f.send_broadcast( - address=t.BroadcastAddress.ALL_ROUTERS_AND_COORDINATOR, - aps_frame=aps_frame, - radius=12, - message_tag=0x42, - aps_sequence=34, - data=b"hello", - ) - - assert status == t.sl_Status.OK - assert message_tag == 0x42 - assert ezsp_f.sendBroadcast.mock_calls == [ - call( - alias=0x0000, - destination=t.BroadcastAddress.ALL_ROUTERS_AND_COORDINATOR, - sequence=34, - aps_frame=t.EmberApsFrameV18( - profileId=0x0104, - clusterId=0x0006, - sourceEndpoint=1, - destinationEndpoint=2, - options=t.EmberApsOption.APS_OPTION_RETRY, - groupId=0x0000, - sequence=0x34, - radius=12, - ), - radius=12, - message_tag=0x42, - message=b"hello", - ) - ] From f41a091d750cca6697d80d9b1838a3a2976a1c24 Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Thu, 1 Jan 2026 15:43:03 -0500 Subject: [PATCH 09/18] Remove unnecessary unit tests --- tests/test_application.py | 225 -------------------------------------- 1 file changed, 225 deletions(-) diff --git a/tests/test_application.py b/tests/test_application.py index 2f055ec5..01892e3b 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -420,202 +420,6 @@ async def test_startup_no_board_info(app, ieee, caplog): assert "EZSP Radio does not support getMfgToken command" in caplog.text -@pytest.fixture -def aps_frame(): - return t.EmberApsFrame( - profileId=0x1234, - clusterId=0x5678, - sourceEndpoint=0x9A, - destinationEndpoint=0xBC, - options=t.EmberApsOption.APS_OPTION_NONE, - groupId=0x0000, - sequence=0xDE, - ) - - -def _handle_incoming_aps_frame(app, aps_frame, type): - app.ezsp_callback_handler( - "incomingMessageHandler", - list( - dict( - type=type, - apsFrame=aps_frame, - lastHopLqi=123, - lastHopRssi=-45, - sender=0xABCD, - bindingIndex=56, - addressIndex=78, - message=b"test message", - ).values() - ), - ) - - -def test_frame_handler_unicast(app, aps_frame): - _handle_incoming_aps_frame( - app, aps_frame, type=t.EmberIncomingMessageType.INCOMING_UNICAST - ) - assert app.packet_received.call_count == 1 - - packet = app.packet_received.mock_calls[0].args[0] - assert packet.profile_id == 0x1234 - assert packet.cluster_id == 0x5678 - assert packet.src_ep == 0x9A - assert packet.dst_ep == 0xBC - assert packet.tsn == 0xDE - assert packet.src.addr_mode == zigpy_t.AddrMode.NWK - assert packet.src.address == 0xABCD - assert packet.dst.addr_mode == zigpy_t.AddrMode.NWK - assert packet.dst.address == app.state.node_info.nwk - assert packet.data.serialize() == b"test message" - assert packet.lqi == 123 - assert packet.rssi == -45 - - assert ( - app.state.counters[bellows.zigbee.application.COUNTERS_CTRL][ - bellows.zigbee.application.COUNTER_RX_UNICAST - ] - == 1 - ) - - -def test_frame_handler_broadcast(app, aps_frame): - _handle_incoming_aps_frame( - app, aps_frame, type=t.EmberIncomingMessageType.INCOMING_BROADCAST - ) - assert app.packet_received.call_count == 1 - - packet = app.packet_received.mock_calls[0].args[0] - assert packet.profile_id == 0x1234 - assert packet.cluster_id == 0x5678 - assert packet.src_ep == 0x9A - assert packet.dst_ep == 0xBC - assert packet.tsn == 0xDE - assert packet.src.addr_mode == zigpy_t.AddrMode.NWK - assert packet.src.address == 0xABCD - assert packet.dst.addr_mode == zigpy_t.AddrMode.Broadcast - assert packet.dst.address == zigpy_t.BroadcastAddress.ALL_ROUTERS_AND_COORDINATOR - assert packet.data.serialize() == b"test message" - assert packet.lqi == 123 - assert packet.rssi == -45 - - assert ( - app.state.counters[bellows.zigbee.application.COUNTERS_CTRL][ - bellows.zigbee.application.COUNTER_RX_BCAST - ] - == 1 - ) - - -def test_frame_handler_multicast(app, aps_frame): - aps_frame.groupId = 0xEF12 - _handle_incoming_aps_frame( - app, aps_frame, type=t.EmberIncomingMessageType.INCOMING_MULTICAST - ) - - assert app.packet_received.call_count == 1 - - packet = app.packet_received.mock_calls[0].args[0] - assert packet.profile_id == 0x1234 - assert packet.cluster_id == 0x5678 - assert packet.src_ep == 0x9A - assert packet.dst_ep == 0xBC - assert packet.tsn == 0xDE - assert packet.src.addr_mode == zigpy_t.AddrMode.NWK - assert packet.src.address == 0xABCD - assert packet.dst.addr_mode == zigpy_t.AddrMode.Group - assert packet.dst.address == 0xEF12 - assert packet.data.serialize() == b"test message" - assert packet.lqi == 123 - assert packet.rssi == -45 - - assert ( - app.state.counters[bellows.zigbee.application.COUNTERS_CTRL][ - bellows.zigbee.application.COUNTER_RX_MCAST - ] - == 1 - ) - - -def test_frame_handler_ignored(app, aps_frame): - _handle_incoming_aps_frame( - app, aps_frame, type=t.EmberIncomingMessageType.INCOMING_BROADCAST_LOOPBACK - ) - assert app.packet_received.call_count == 0 - - -@pytest.mark.parametrize( - "msg_type", - ( - t.EmberIncomingMessageType.INCOMING_BROADCAST, - t.EmberIncomingMessageType.INCOMING_MULTICAST, - t.EmberIncomingMessageType.INCOMING_UNICAST, - 0xFF, - ), -) -async def test_send_failure(app, aps, ieee, msg_type): - fut = app._pending_requests[(0xBEED, 254)] = asyncio.Future() - app.ezsp_callback_handler( - "messageSentHandler", [msg_type, 0xBEED, aps, 254, t.EmberStatus.SUCCESS, b""] - ) - assert fut.result() == (t.sl_Status.OK, "message send success") - - -async def test_dup_send_failure(app, aps, ieee): - fut = app._pending_requests[(0xBEED, 254)] = asyncio.Future() - fut.set_result("Already set") - - app.ezsp_callback_handler( - "messageSentHandler", - [ - t.EmberIncomingMessageType.INCOMING_UNICAST, - 0xBEED, - aps, - 254, - sentinel.status, - b"", - ], - ) - - -def test_send_failure_unexpected(app, aps, ieee): - app.ezsp_callback_handler( - "messageSentHandler", - [ - t.EmberIncomingMessageType.INCOMING_BROADCAST_LOOPBACK, - 0xBEED, - aps, - 257, - 1, - b"", - ], - ) - - -async def test_send_success(app, aps, ieee): - fut = app._pending_requests[(0xBEED, 253)] = asyncio.Future() - app.ezsp_callback_handler( - "messageSentHandler", - [ - t.EmberIncomingMessageType.INCOMING_MULTICAST_LOOPBACK, - 0xBEED, - aps, - 253, - t.EmberStatus.SUCCESS, - b"", - ], - ) - - assert fut.result() == (t.sl_Status.OK, "message send success") - - -def test_unexpected_send_success(app, aps, ieee): - app.ezsp_callback_handler( - "messageSentHandler", - [t.EmberIncomingMessageType.INCOMING_MULTICAST, 0xBEED, aps, 253, 0, b""], - ) - - async def test_join_handler(app, ieee): # Calls device.initialize, leaks a task app.handle_join = MagicMock() @@ -1603,35 +1407,6 @@ def test_handle_id_conflict(app, ieee): assert app.handle_leave.call_args[0][0] == nwk -async def test_handle_no_such_device(app, ieee): - """Test handling of an unknown device IEEE lookup.""" - - app._ezsp.lookupEui64ByNodeId = AsyncMock() - - p1 = patch.object( - app._ezsp, - "lookupEui64ByNodeId", - AsyncMock(return_value=(t.EmberStatus.ERR_FATAL, ieee)), - ) - p2 = patch.object(app, "handle_join") - with p1 as lookup_mock, p2 as handle_join_mock: - await app._handle_no_such_device(sentinel.nwk) - assert lookup_mock.mock_calls == [call(nodeId=sentinel.nwk)] - assert handle_join_mock.call_count == 0 - - p1 = patch.object( - app._ezsp, - "lookupEui64ByNodeId", - AsyncMock(return_value=(t.EmberStatus.SUCCESS, sentinel.ieee)), - ) - with p1 as lookup_mock, p2 as handle_join_mock: - await app._handle_no_such_device(sentinel.nwk) - assert lookup_mock.mock_calls == [call(nodeId=sentinel.nwk)] - assert handle_join_mock.call_count == 1 - assert handle_join_mock.call_args[0][0] == sentinel.nwk - assert handle_join_mock.call_args[0][1] == sentinel.ieee - - async def test_cleanup_tc_link_key(app): """Test cleaning up tc link key.""" ezsp = app._ezsp From 92635fd5df7af4a6b6cc9788d70f4fe7fd3fe0a3 Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Thu, 1 Jan 2026 16:48:36 -0500 Subject: [PATCH 10/18] Fix tests --- bellows/ezsp/v4/__init__.py | 2 +- tests/test_application.py | 154 ++++++++++++++++++++++++- tests/test_ezsp_protocol.py | 221 +++++++++++++++++++++++++++++------- tests/test_ezsp_v14.py | 97 ++++++++++++++++ tests/test_ezsp_v18.py | 164 ++++++++++++++++++++++++++ 5 files changed, 587 insertions(+), 51 deletions(-) create mode 100644 tests/test_ezsp_v18.py diff --git a/bellows/ezsp/v4/__init__.py b/bellows/ezsp/v4/__init__.py index fbda9527..f5712a8b 100644 --- a/bellows/ezsp/v4/__init__.py +++ b/bellows/ezsp/v4/__init__.py @@ -275,7 +275,7 @@ def handle_parsed_callback(self, frame_name: str, args: list[Any]) -> None: ) = args self._handle_message_sent( - type=message_type, + message_type=message_type, destination=destination, aps_frame=aps_frame, message_tag=message_tag, diff --git a/tests/test_application.py b/tests/test_application.py index 01892e3b..2d15c420 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -16,6 +16,7 @@ import bellows.config as config from bellows.exception import ControllerError, EzspError, InvalidCommandError import bellows.ezsp as ezsp +from bellows.ezsp.protocol import MessageSentEvent, PacketReceivedEvent from bellows.ezsp.v9.commands import GetTokenDataRsp from bellows.ezsp.xncp import ( FirmwareFeatures, @@ -546,7 +547,7 @@ async def test_request_concurrency_duplicate_failure( ) -> None: def send_unicast(aps_frame, data, message_tag, nwk): asyncio.get_running_loop().call_soon( - app.ezsp_callback_handler, + app._ezsp._protocol.handle_parsed_callback, "messageSentHandler", list( dict( @@ -598,7 +599,7 @@ async def _test_send_packet_unicast( def send_unicast(*args, **kwargs): asyncio.get_running_loop().call_later( 0.01, - app.ezsp_callback_handler, + app._ezsp._protocol.handle_parsed_callback, "messageSentHandler", list( dict( @@ -852,7 +853,7 @@ async def send_message_sent_reply( await asyncio.sleep(0.01) - app.ezsp_callback_handler( + app._ezsp._protocol.handle_parsed_callback( "messageSentHandler", list( dict( @@ -909,7 +910,7 @@ async def test_send_packet_broadcast(app, packet): app.get_sequence = MagicMock(return_value=sentinel.msg_tag) asyncio.get_running_loop().call_soon( - app.ezsp_callback_handler, + app._ezsp._protocol.handle_parsed_callback, "messageSentHandler", list( dict( @@ -955,7 +956,7 @@ async def test_send_packet_broadcast_ignored_delivery_failure(app, packet): app.get_sequence = MagicMock(return_value=sentinel.msg_tag) asyncio.get_running_loop().call_soon( - app.ezsp_callback_handler, + app._ezsp._protocol.handle_parsed_callback, "messageSentHandler", list( dict( @@ -1008,7 +1009,7 @@ async def test_send_packet_multicast(app, packet): app.get_sequence = MagicMock(return_value=sentinel.msg_tag) asyncio.get_running_loop().call_soon( - app.ezsp_callback_handler, + app._ezsp._protocol.handle_parsed_callback, "messageSentHandler", list( dict( @@ -2461,3 +2462,144 @@ async def test_set_tx_power(app: ControllerApplication) -> None: assert result == 12.0 assert app._ezsp.setRadioPower.mock_calls == [call(power=12)] assert mock_update.mock_calls == [call(app._ezsp, tx_power=12)] + + +async def test_reset_resubscribes_events(app: ControllerApplication) -> None: + """Test that _reset unsubscribes, resets, and resubscribes to protocol events.""" + app._ezsp.stop_ezsp = MagicMock() + app._ezsp.startup_reset = AsyncMock() + app._ezsp.write_config = AsyncMock() + + # Add a dummy callback to verify unsubscribe is called + unsubscribe_mock = MagicMock() + app._protocol_on_remove_callbacks.append(unsubscribe_mock) + + await app._reset() + + # Verify unsubscribe was called + assert unsubscribe_mock.mock_calls == [call()] + + # Verify EZSP reset sequence + assert len(app._ezsp.stop_ezsp.mock_calls) == 1 + assert len(app._ezsp.startup_reset.mock_calls) == 1 + assert len(app._ezsp.write_config.mock_calls) == 1 + + # Verify we resubscribed (callbacks list should have 2 entries now) + assert len(app._protocol_on_remove_callbacks) == 2 + + +def test_on_packet_received_unicast(app: ControllerApplication) -> None: + """Test _on_packet_received with unicast message (dst=None gets replaced).""" + app.state.node_info.nwk = zigpy_t.NWK(0x0000) + + packet_received_mock = MagicMock() + app.packet_received = packet_received_mock + + # Unicast packets have dst=None, protocol handler doesn't know our NWK + event = PacketReceivedEvent( + packet=zigpy_t.ZigbeePacket( + src=zigpy_t.AddrModeAddress( + addr_mode=zigpy_t.AddrMode.NWK, + address=zigpy_t.NWK(0x1234), + ), + src_ep=1, + dst=None, # Will be replaced with our NWK + dst_ep=2, + tsn=0x42, + profile_id=0x0104, + cluster_id=0x0006, + data=zigpy_t.SerializableBytes(b"test"), + lqi=200, + rssi=-40, + ) + ) + + app._on_packet_received(event) + + # Verify packet_received was called with dst replaced + assert packet_received_mock.mock_calls == [ + call( + zigpy_t.ZigbeePacket( + src=zigpy_t.AddrModeAddress( + addr_mode=zigpy_t.AddrMode.NWK, + address=zigpy_t.NWK(0x1234), + ), + src_ep=1, + dst=zigpy_t.AddrModeAddress( + addr_mode=zigpy_t.AddrMode.NWK, + address=zigpy_t.NWK(0x0000), + ), + dst_ep=2, + tsn=0x42, + profile_id=0x0104, + cluster_id=0x0006, + data=zigpy_t.SerializableBytes(b"test"), + lqi=200, + rssi=-40, + ) + ) + ] + + +def test_on_packet_received_broadcast(app: ControllerApplication) -> None: + """Test _on_packet_received with broadcast message.""" + packet_received_mock = MagicMock() + app.packet_received = packet_received_mock + + event = PacketReceivedEvent( + packet=zigpy_t.ZigbeePacket( + src=zigpy_t.AddrModeAddress( + addr_mode=zigpy_t.AddrMode.NWK, + address=zigpy_t.NWK(0x1234), + ), + src_ep=1, + dst=zigpy_t.AddrModeAddress( + addr_mode=zigpy_t.AddrMode.Broadcast, + address=zigpy_t.BroadcastAddress.ALL_ROUTERS_AND_COORDINATOR, + ), + dst_ep=2, + tsn=0x42, + profile_id=0x0104, + cluster_id=0x0006, + data=zigpy_t.SerializableBytes(b"broadcast"), + lqi=200, + rssi=-40, + ) + ) + + app._on_packet_received(event) + + # Verify packet_received was called with the same packet (dst already set) + assert packet_received_mock.mock_calls == [call(event.packet)] + + +def test_on_packet_received_multicast(app: ControllerApplication) -> None: + """Test _on_packet_received with multicast message.""" + packet_received_mock = MagicMock() + app.packet_received = packet_received_mock + + event = PacketReceivedEvent( + packet=zigpy_t.ZigbeePacket( + src=zigpy_t.AddrModeAddress( + addr_mode=zigpy_t.AddrMode.NWK, + address=zigpy_t.NWK(0x1234), + ), + src_ep=1, + dst=zigpy_t.AddrModeAddress( + addr_mode=zigpy_t.AddrMode.Group, + address=0x5678, + ), + dst_ep=2, + tsn=0x42, + profile_id=0x0104, + cluster_id=0x0006, + data=zigpy_t.SerializableBytes(b"multicast"), + lqi=200, + rssi=-40, + ) + ) + + app._on_packet_received(event) + + # Verify packet_received was called with the same packet (dst already set) + assert packet_received_mock.mock_calls == [call(event.packet)] diff --git a/tests/test_ezsp_protocol.py b/tests/test_ezsp_protocol.py index 3906eb5e..c05db28c 100644 --- a/tests/test_ezsp_protocol.py +++ b/tests/test_ezsp_protocol.py @@ -3,8 +3,10 @@ from unittest.mock import AsyncMock, MagicMock, call, patch import pytest +import zigpy.types from bellows.ezsp import EZSP +from bellows.ezsp.protocol import PacketReceivedEvent import bellows.ezsp.v4 import bellows.ezsp.v9 from bellows.ezsp.v9.commands import GetTokenDataRsp @@ -206,9 +208,9 @@ async def test_incoming_fragmented_message_incomplete(prot_hndl, caplog): len(prot_hndl._fragment_ack_tasks) == 0 ), "Done callback should have removed task" - prot_hndl._handle_callback.assert_not_called() - assert "Fragment reassembly not complete. waiting for more data." in caplog.text - mock_ack.assert_called_once_with(sender, aps_frame, 2, 0) + assert len(prot_hndl._handle_callback.mock_calls) == 1 + assert "Fragment reassembly not complete, waiting for more data" in caplog.text + assert mock_ack.mock_calls == [call(sender, aps_frame, 2, 0)] async def test_incoming_fragmented_message_complete(prot_hndl, caplog): @@ -221,27 +223,34 @@ async def test_incoming_fragmented_message_complete(prot_hndl, caplog): b"\x90\x01\x45\x00\x04\x01\x01\xff\x02\x02\x40\x81\x01\x02\xee\xff\xf8\x6f\x1d\xff\xff\x07" + b"message" ) # fragment index 1 - sender = 0x1D6F aps_frame_1 = t.EmberApsFrame( profileId=260, - clusterId=65281, + clusterId=0xFF01, sourceEndpoint=2, destinationEndpoint=2, - options=33088, # Includes APS_OPTION_FRAGMENT - groupId=512, # fragment_count=2, fragment_index=0 + options=( + t.EmberApsOption.APS_OPTION_RETRY + | t.EmberApsOption.APS_OPTION_ENABLE_ROUTE_DISCOVERY + | t.EmberApsOption.APS_OPTION_FRAGMENT + ), + groupId=0x0200, # fragment_count=2, fragment_index=0 sequence=238, ) + aps_frame_2 = t.EmberApsFrame( profileId=260, - clusterId=65281, + clusterId=0xFF01, sourceEndpoint=2, destinationEndpoint=2, - options=33088, - groupId=513, # fragment_count=2, fragment_index=1 + options=( + t.EmberApsOption.APS_OPTION_RETRY + | t.EmberApsOption.APS_OPTION_ENABLE_ROUTE_DISCOVERY + | t.EmberApsOption.APS_OPTION_FRAGMENT + ), + groupId=0x0201, # fragment_count=2, fragment_index=1 sequence=238, ) - reassembled = b"complete message" with patch.object(prot_hndl, "_send_fragment_ack", new=AsyncMock()) as mock_ack: mock_ack.return_value = None @@ -250,43 +259,167 @@ async def test_incoming_fragmented_message_complete(prot_hndl, caplog): # Packet 1 prot_hndl(packet1) assert len(prot_hndl._fragment_ack_tasks) == 1 - ack_task = next(iter(prot_hndl._fragment_ack_tasks)) - await asyncio.gather(ack_task) # Ensure task completes and triggers callback - assert ( - len(prot_hndl._fragment_ack_tasks) == 0 - ), "Done callback should have removed task" - - prot_hndl._handle_callback.assert_not_called() - assert ( - "Reassembled fragmented message. Proceeding with normal handling." - not in caplog.text - ) - mock_ack.assert_called_with(sender, aps_frame_1, 2, 0) + await asyncio.gather( + *prot_hndl._fragment_ack_tasks + ) # Ensure task completes and triggers callback + assert len(prot_hndl._fragment_ack_tasks) == 0 # Packet 2 prot_hndl(packet2) assert len(prot_hndl._fragment_ack_tasks) == 1 - ack_task = next(iter(prot_hndl._fragment_ack_tasks)) - await asyncio.gather(ack_task) # Ensure task completes and triggers callback - assert ( - len(prot_hndl._fragment_ack_tasks) == 0 - ), "Done callback should have removed task" + await asyncio.gather( + *prot_hndl._fragment_ack_tasks + ) # Ensure task completes and triggers callback + assert len(prot_hndl._fragment_ack_tasks) == 0 + + assert "Reassembled fragmented message, proceeding with handling" in caplog.text + assert mock_ack.mock_calls == [ + call(0x1D6F, aps_frame_1, 2, 0), + call(0x1D6F, aps_frame_2, 2, 1), + ] + + +def test_incoming_message_broadcast(prot_hndl) -> None: + """Test handling of incoming broadcast message.""" + handler = MagicMock() + prot_hndl.on_event(PacketReceivedEvent.event_type, handler) + + aps_frame = t.EmberApsFrame( + profileId=0x0104, + clusterId=0x0006, + sourceEndpoint=1, + destinationEndpoint=2, + options=t.EmberApsOption.APS_OPTION_NONE, + groupId=0x0000, + sequence=0x42, + ) + + # v4 field order: type, apsFrame, lqi, rssi, sender, bindingIndex, addressIndex, message + prot_hndl.handle_parsed_callback( + "incomingMessageHandler", + [ + t.EmberIncomingMessageType.INCOMING_BROADCAST, + aps_frame, + 200, # lqi + -40, # rssi + t.EmberNodeId(0x1234), # sender + 0, # binding_index + 0, # address_index + b"broadcast message", + ], + ) - prot_hndl._handle_callback.assert_called_once_with( - "incomingMessageHandler", - [ - t.EmberIncomingMessageType.INCOMING_UNICAST, # 0x00 - aps_frame_2, # Parsed APS frame - 255, # lastHopLqi: 0xFF - -8, # lastHopRssi: 0xF8 - sender, # 0x1D6F - 255, # bindingIndex: 0xFF - 255, # addressIndex: 0xFF - reassembled, # Reassembled payload - ], + assert handler.mock_calls == [ + call( + PacketReceivedEvent( + packet=zigpy.types.ZigbeePacket( + src=zigpy.types.AddrModeAddress( + addr_mode=zigpy.types.AddrMode.NWK, + address=zigpy.types.NWK(0x1234), + ), + src_ep=1, + dst=zigpy.types.AddrModeAddress( + addr_mode=zigpy.types.AddrMode.Broadcast, + address=zigpy.types.BroadcastAddress.ALL_ROUTERS_AND_COORDINATOR, + ), + dst_ep=2, + tsn=0x42, + profile_id=0x0104, + cluster_id=0x0006, + data=zigpy.types.SerializableBytes(b"broadcast message"), + lqi=200, + rssi=-40, + ) + ) ) - assert ( - "Reassembled fragmented message. Proceeding with normal handling." - in caplog.text + ] + + +def test_incoming_message_multicast(prot_hndl) -> None: + """Test handling of incoming multicast message.""" + handler = MagicMock() + prot_hndl.on_event(PacketReceivedEvent.event_type, handler) + + aps_frame = t.EmberApsFrame( + profileId=0x0104, + clusterId=0x0006, + sourceEndpoint=1, + destinationEndpoint=2, + options=t.EmberApsOption.APS_OPTION_NONE, + groupId=0x5678, + sequence=0x42, + ) + + prot_hndl.handle_parsed_callback( + "incomingMessageHandler", + [ + t.EmberIncomingMessageType.INCOMING_MULTICAST, + aps_frame, + 200, + -40, + t.EmberNodeId(0x1234), + 0, + 0, + b"multicast message", + ], + ) + + assert handler.mock_calls == [ + call( + PacketReceivedEvent( + packet=zigpy.types.ZigbeePacket( + src=zigpy.types.AddrModeAddress( + addr_mode=zigpy.types.AddrMode.NWK, + address=zigpy.types.NWK(0x1234), + ), + src_ep=1, + dst=zigpy.types.AddrModeAddress( + addr_mode=zigpy.types.AddrMode.Group, + address=0x5678, + ), + dst_ep=2, + tsn=0x42, + profile_id=0x0104, + cluster_id=0x0006, + data=zigpy.types.SerializableBytes(b"multicast message"), + lqi=200, + rssi=-40, + ) + ) ) - mock_ack.assert_called_with(sender, aps_frame_2, 2, 1) + ] + + +def test_incoming_message_ignored_type(prot_hndl, caplog) -> None: + """Test that unknown message types are ignored.""" + handler = MagicMock() + prot_hndl.on_event(PacketReceivedEvent.event_type, handler) + + aps_frame = t.EmberApsFrame( + profileId=0x0104, + clusterId=0x0006, + sourceEndpoint=1, + destinationEndpoint=2, + options=t.EmberApsOption.APS_OPTION_NONE, + groupId=0x0000, + sequence=0x42, + ) + + caplog.set_level(logging.DEBUG) + prot_hndl.handle_parsed_callback( + "incomingMessageHandler", + [ + t.EmberIncomingMessageType.INCOMING_MANY_TO_ONE_ROUTE_REQUEST, + aps_frame, + 200, + -40, + t.EmberNodeId(0x1234), + 0, + 0, + b"ignored message", + ], + ) + + # No event should be emitted for ignored message types + assert len(handler.mock_calls) == 0 + assert "Ignoring message type" in caplog.text diff --git a/tests/test_ezsp_v14.py b/tests/test_ezsp_v14.py index 49bf152b..5d9055a7 100644 --- a/tests/test_ezsp_v14.py +++ b/tests/test_ezsp_v14.py @@ -3,7 +3,9 @@ import pytest import zigpy.exceptions import zigpy.state +import zigpy.types +from bellows.ezsp.protocol import MessageSentEvent, PacketReceivedEvent import bellows.ezsp.v14 import bellows.types as t @@ -226,3 +228,98 @@ async def test_send_broadcast(ezsp_f) -> None: message=b"hello", ) ] + + +def test_handle_parsed_callback_incoming_message(ezsp_f) -> None: + """Test handle_parsed_callback for incomingMessageHandler.""" + handler = MagicMock() + ezsp_f.on_event(PacketReceivedEvent.event_type, handler) + + aps_frame = t.EmberApsFrame( + profileId=0x0104, + clusterId=0x0006, + sourceEndpoint=1, + destinationEndpoint=2, + options=t.EmberApsOption.APS_OPTION_NONE, + groupId=0x0000, + sequence=0x42, + ) + + # v14 field order: type, apsFrame, lqi, rssi, sender, bindingIndex, addressIndex, message + ezsp_f.handle_parsed_callback( + "incomingMessageHandler", + [ + t.EmberIncomingMessageType.INCOMING_UNICAST, + aps_frame, + 200, # lqi + -40, # rssi + t.EmberNodeId(0x1234), # sender + 0, # binding_index + 0, # address_index + b"test message", + ], + ) + + assert handler.mock_calls == [ + call( + PacketReceivedEvent( + packet=zigpy.types.ZigbeePacket( + src=zigpy.types.AddrModeAddress( + addr_mode=zigpy.types.AddrMode.NWK, + address=zigpy.types.NWK(0x1234), + ), + src_ep=1, + dst=None, + dst_ep=2, + tsn=0x42, + profile_id=0x0104, + cluster_id=0x0006, + data=zigpy.types.SerializableBytes(b"test message"), + lqi=200, + rssi=-40, + ) + ) + ) + ] + + +def test_handle_parsed_callback_message_sent(ezsp_f) -> None: + """Test handle_parsed_callback for messageSentHandler.""" + handler = MagicMock() + ezsp_f.on_event(MessageSentEvent.event_type, handler) + + aps_frame = t.EmberApsFrame( + profileId=0x0104, + clusterId=0x0006, + sourceEndpoint=1, + destinationEndpoint=2, + options=t.EmberApsOption.APS_OPTION_NONE, + groupId=0x0000, + sequence=0x42, + ) + + # v14 field order: status, type, nwk, apsFrame, messageTag, message + ezsp_f.handle_parsed_callback( + "messageSentHandler", + [ + t.sl_Status.OK, + t.EmberOutgoingMessageType.OUTGOING_DIRECT, + t.EmberNodeId(0x1234), + aps_frame, + 0x42, # message_tag + b"sent message", + ], + ) + + assert handler.mock_calls == [ + call( + MessageSentEvent( + status=t.sl_Status.OK, + message_type=t.EmberOutgoingMessageType.OUTGOING_DIRECT, + destination=t.EmberNodeId(0x1234), + aps_frame=aps_frame, + message_tag=0x42, + message_contents=b"sent message", + ) + ) + ] diff --git a/tests/test_ezsp_v18.py b/tests/test_ezsp_v18.py new file mode 100644 index 00000000..db0db68c --- /dev/null +++ b/tests/test_ezsp_v18.py @@ -0,0 +1,164 @@ +from unittest.mock import MagicMock, call + +import pytest + +import bellows.ezsp.v18 +import bellows.types as t + +from tests.common import mock_ezsp_commands + + +@pytest.fixture +def ezsp_f(): + """EZSP v18 protocol handler.""" + ezsp = bellows.ezsp.v18.EZSPv18(MagicMock(), MagicMock()) + mock_ezsp_commands(ezsp) + + return ezsp + + +def test_ezsp_frame(ezsp_f): + ezsp_f._seq = 0x22 + data = ezsp_f._ezsp_frame("version", 18) + assert data == b"\x22\x00\x01\x00\x00\x12" + + +def test_ezsp_frame_rx(ezsp_f): + """Test receiving a version frame.""" + ezsp_f(b"\x01\x01\x80\x00\x00\x01\x02\x34\x12") + assert ezsp_f._handle_callback.call_count == 1 + assert ezsp_f._handle_callback.call_args[0][0] == "version" + assert ezsp_f._handle_callback.call_args[0][1] == [0x01, 0x02, 0x1234] + + +async def test_send_unicast(ezsp_f) -> None: + ezsp_f.sendUnicast.return_value = (t.sl_Status.OK, 0x0042) + + aps_frame = t.EmberApsFrame( + profileId=0x0104, + clusterId=0x0006, + sourceEndpoint=1, + destinationEndpoint=2, + options=t.EmberApsOption.APS_OPTION_RETRY, + groupId=0x0000, + sequence=0x34, + ) + + status, message_tag = await ezsp_f.send_unicast( + nwk=0x1234, + aps_frame=aps_frame, + message_tag=0x42, + data=b"hello", + ) + + assert status == t.sl_Status.OK + assert message_tag == 0x42 + assert ezsp_f.sendUnicast.mock_calls == [ + call( + message_type=t.EmberOutgoingMessageType.OUTGOING_DIRECT, + nwk=0x1234, + aps_frame=t.EmberApsFrameV18( + profileId=0x0104, + clusterId=0x0006, + sourceEndpoint=1, + destinationEndpoint=2, + options=t.EmberApsOption.APS_OPTION_RETRY, + groupId=0x0000, + sequence=0x34, + radius=0, + ), + message_tag=0x42, + message=b"hello", + ) + ] + + +async def test_send_multicast(ezsp_f) -> None: + ezsp_f.sendMulticast.return_value = (t.sl_Status.OK, 0x0042) + + aps_frame = t.EmberApsFrame( + profileId=0x0104, + clusterId=0x0006, + sourceEndpoint=1, + destinationEndpoint=2, + options=t.EmberApsOption.APS_OPTION_RETRY, + groupId=0x1234, + sequence=0x34, + ) + + status, message_tag = await ezsp_f.send_multicast( + aps_frame=aps_frame, + radius=12, + non_member_radius=34, + message_tag=0x42, + data=b"hello", + ) + + assert status == t.sl_Status.OK + assert message_tag == 0x42 + assert ezsp_f.sendMulticast.mock_calls == [ + call( + aps_frame=t.EmberApsFrameV18( + profileId=0x0104, + clusterId=0x0006, + sourceEndpoint=1, + destinationEndpoint=2, + options=t.EmberApsOption.APS_OPTION_RETRY, + groupId=0x1234, + sequence=0x34, + radius=12, + ), + hops=12, + broadcast_addr=t.BroadcastAddress.RX_ON_WHEN_IDLE, + alias=0x0000, + sequence=0x34, + message_tag=0x0042, + message=b"hello", + ) + ] + + +async def test_send_broadcast(ezsp_f) -> None: + ezsp_f.sendBroadcast.return_value = (t.sl_Status.OK, 0x0042) + + aps_frame = t.EmberApsFrame( + profileId=0x0104, + clusterId=0x0006, + sourceEndpoint=1, + destinationEndpoint=2, + options=t.EmberApsOption.APS_OPTION_RETRY, + groupId=0x0000, + sequence=0x34, + ) + + status, message_tag = await ezsp_f.send_broadcast( + address=t.BroadcastAddress.ALL_ROUTERS_AND_COORDINATOR, + aps_frame=aps_frame, + radius=12, + message_tag=0x42, + aps_sequence=34, + data=b"hello", + ) + + assert status == t.sl_Status.OK + assert message_tag == 0x42 + assert ezsp_f.sendBroadcast.mock_calls == [ + call( + alias=0x0000, + destination=t.BroadcastAddress.ALL_ROUTERS_AND_COORDINATOR, + sequence=34, + aps_frame=t.EmberApsFrameV18( + profileId=0x0104, + clusterId=0x0006, + sourceEndpoint=1, + destinationEndpoint=2, + options=t.EmberApsOption.APS_OPTION_RETRY, + groupId=0x0000, + sequence=0x34, + radius=12, + ), + radius=12, + message_tag=0x42, + message=b"hello", + ) + ] From 072ee73171c14803d5042b5b0db2a9e04c97614b Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Thu, 1 Jan 2026 17:58:51 -0500 Subject: [PATCH 11/18] Revert EmberApsFrameV18 addition --- bellows/ezsp/__init__.py | 2 +- bellows/ezsp/protocol.py | 2 +- bellows/ezsp/v14/__init__.py | 6 +- bellows/ezsp/v18/__init__.py | 86 ---------------------- bellows/ezsp/v18/commands.py | 87 ---------------------- bellows/types/struct.py | 20 ------ tests/test_ezsp_v18.py | 135 +---------------------------------- 7 files changed, 7 insertions(+), 331 deletions(-) diff --git a/bellows/ezsp/__init__.py b/bellows/ezsp/__init__.py index b5bd087f..e9bbbd10 100644 --- a/bellows/ezsp/__init__.py +++ b/bellows/ezsp/__init__.py @@ -365,7 +365,7 @@ def frame_received(self, data: bytes) -> None: try: self._protocol(data) except Exception: - LOGGER.warning("Failed to parse frame, ignoring") + LOGGER.warning("Failed to parse frame. This is a bug!", exc_info=True) async def get_board_info( self, diff --git a/bellows/ezsp/protocol.py b/bellows/ezsp/protocol.py index 52498fd7..c045c0da 100644 --- a/bellows/ezsp/protocol.py +++ b/bellows/ezsp/protocol.py @@ -265,7 +265,7 @@ async def _send_fragment_ack( def _handle_incoming_message( self, message_type: t.EmberIncomingMessageType, - aps_frame: t.EmberApsFrame | t.EmberApsFrameV18, + aps_frame: t.EmberApsFrame, sender: zigpy.types.NWK, eui64: zigpy.types.EUI64 | None, binding_index: t.uint8_t, diff --git a/bellows/ezsp/v14/__init__.py b/bellows/ezsp/v14/__init__.py index 4fb2ff6f..0877cb3c 100644 --- a/bellows/ezsp/v14/__init__.py +++ b/bellows/ezsp/v14/__init__.py @@ -156,11 +156,13 @@ def handle_parsed_callback(self, frame_name: str, args: list[Any]) -> None: ( message_type, aps_frame, - lqi, - rssi, sender, + _eui64, binding_index, address_index, + lqi, + rssi, + _timestamp, message, ) = args diff --git a/bellows/ezsp/v18/__init__.py b/bellows/ezsp/v18/__init__.py index 996ee8d2..f087a8d1 100644 --- a/bellows/ezsp/v18/__init__.py +++ b/bellows/ezsp/v18/__init__.py @@ -4,7 +4,6 @@ import voluptuous as vol import bellows.config -import bellows.types as t from . import commands, config from ..v17 import EZSPv17 @@ -19,88 +18,3 @@ class EZSPv18(EZSPv17): bellows.config.CONF_EZSP_CONFIG: vol.Schema(config.EZSP_SCHEMA), bellows.config.CONF_EZSP_POLICIES: vol.Schema(config.EZSP_POLICIES_SCH), } - - async def send_unicast( - self, - nwk: t.NWK, - aps_frame: t.EmberApsFrame, - message_tag: t.uint8_t, - data: bytes, - ) -> tuple[t.sl_Status, t.uint8_t]: - status, sequence = await self.sendUnicast( - message_type=t.EmberOutgoingMessageType.OUTGOING_DIRECT, - nwk=nwk, - aps_frame=t.EmberApsFrameV18( - profileId=aps_frame.profileId, - clusterId=aps_frame.clusterId, - sourceEndpoint=aps_frame.sourceEndpoint, - destinationEndpoint=aps_frame.destinationEndpoint, - options=aps_frame.options, - groupId=aps_frame.groupId, - sequence=aps_frame.sequence, - radius=0, - ), - message_tag=message_tag, - message=data, - ) - - return status, sequence - - async def send_multicast( - self, - aps_frame: t.EmberApsFrame, - radius: t.uint8_t, - non_member_radius: t.uint8_t, - message_tag: t.uint8_t, - data: bytes, - ) -> tuple[t.sl_Status, t.uint8_t]: - status, sequence = await self.sendMulticast( - aps_frame=t.EmberApsFrameV18( - profileId=aps_frame.profileId, - clusterId=aps_frame.clusterId, - sourceEndpoint=aps_frame.sourceEndpoint, - destinationEndpoint=aps_frame.destinationEndpoint, - options=aps_frame.options, - groupId=aps_frame.groupId, - sequence=aps_frame.sequence, - radius=radius, - ), - hops=radius, - broadcast_addr=t.BroadcastAddress.RX_ON_WHEN_IDLE, - alias=0x0000, - sequence=aps_frame.sequence, - message_tag=message_tag, - message=data, - ) - - return status, sequence - - async def send_broadcast( - self, - address: t.BroadcastAddress, - aps_frame: t.EmberApsFrame, - radius: t.uint8_t, - message_tag: t.uint8_t, - aps_sequence: t.uint8_t, - data: bytes, - ) -> tuple[t.sl_Status, t.uint8_t]: - status, sequence = await self.sendBroadcast( - alias=0x0000, - destination=address, - sequence=aps_sequence, - aps_frame=t.EmberApsFrameV18( - profileId=aps_frame.profileId, - clusterId=aps_frame.clusterId, - sourceEndpoint=aps_frame.sourceEndpoint, - destinationEndpoint=aps_frame.destinationEndpoint, - options=aps_frame.options, - groupId=aps_frame.groupId, - sequence=aps_frame.sequence, - radius=radius, - ), - radius=radius, - message_tag=message_tag, - message=data, - ) - - return status, sequence diff --git a/bellows/ezsp/v18/commands.py b/bellows/ezsp/v18/commands.py index fdec8623..ff458c4c 100644 --- a/bellows/ezsp/v18/commands.py +++ b/bellows/ezsp/v18/commands.py @@ -1,96 +1,9 @@ -from zigpy.types import EUI64, NWK, BroadcastAddress - import bellows.types as t from ..v17.commands import COMMANDS as COMMANDS_v17 COMMANDS = { **COMMANDS_v17, - "sendUnicast": ( - 0x0034, - { - "message_type": t.EmberOutgoingMessageType, - "nwk": NWK, - "aps_frame": t.EmberApsFrameV18, # APS frame format has changed - "message_tag": t.uint16_t, - "message": t.LVBytes, - }, - { - "status": t.sl_Status, - "sequence": t.uint8_t, - }, - ), - "sendBroadcast": ( - 0x0036, - { - "alias": t.uint16_t, - "destination": BroadcastAddress, - "sequence": t.uint8_t, - "aps_frame": t.EmberApsFrameV18, # APS frame format has changed - "radius": t.uint8_t, - "message_tag": t.uint16_t, - "message": t.LVBytes, - }, - { - "status": t.sl_Status, - "sequence": t.uint8_t, - }, - ), - "sendMulticast": ( - 0x0038, - { - "aps_frame": t.EmberApsFrameV18, # APS frame format has changed - "hops": t.uint8_t, - "broadcast_addr": t.BroadcastAddress, - "alias": t.uint16_t, - "sequence": t.uint8_t, - "message_tag": t.uint16_t, - "message": t.LVBytes, - }, - { - "status": t.sl_Status, - "sequence": t.uint8_t, - }, - ), - "sendReply": ( - 0x0039, - { - "sender": t.NWK, - "aps_frame": t.EmberApsFrameV18, # APS frame format has changed - "message": t.LVBytes, - }, - { - "status": t.sl_Status, - }, - ), - "incomingMessageHandler": ( - 0x0045, - {}, - { - "message_type": t.EmberIncomingMessageType, - "aps_frame": t.EmberApsFrameV18, # APS frame format has changed - "nwk": NWK, - "eui64": EUI64, - "binding_index": t.uint8_t, - "address_index": t.uint8_t, - "lqi": t.uint8_t, - "rssi": t.int8s, - "timestamp": t.uint32_t, - "message": t.LVBytes, - }, - ), - "messageSentHandler": ( - 0x003F, - {}, - { - "status": t.sl_Status, - "message_type": t.EmberOutgoingMessageType, - "nwk": NWK, - "aps_frame": t.EmberApsFrameV18, # APS frame format has changed - "message_tag": t.uint16_t, - "message": t.LVBytes, - }, - ), "macFilterMatchMessageHandler": ( 0x46, {}, diff --git a/bellows/types/struct.py b/bellows/types/struct.py index df34a440..91a17c01 100644 --- a/bellows/types/struct.py +++ b/bellows/types/struct.py @@ -67,26 +67,6 @@ class EmberApsFrame(EzspStruct): sequence: basic.uint8_t -class EmberApsFrameV18(EzspStruct): - # ZigBee APS frame parameters (EZSP v18+). - # The application profile ID that describes the format of the message. - profileId: basic.uint16_t - # The cluster ID for this message. - clusterId: basic.uint16_t - # The source endpoint. - sourceEndpoint: basic.uint8_t - # The destination endpoint. - destinationEndpoint: basic.uint8_t - # A bitmask of options. - options: named.EmberApsOption - # The group ID for this message, if it is multicast mode. - groupId: basic.uint16_t - # The sequence number. - sequence: basic.uint8_t - # The radius of the message. (Added in EZSP v18) - radius: basic.uint8_t - - class EmberBindingTableEntry(EzspStruct): # An entry in the binding table. # The type of binding. diff --git a/tests/test_ezsp_v18.py b/tests/test_ezsp_v18.py index db0db68c..b0c8c90e 100644 --- a/tests/test_ezsp_v18.py +++ b/tests/test_ezsp_v18.py @@ -1,4 +1,4 @@ -from unittest.mock import MagicMock, call +from unittest.mock import MagicMock import pytest @@ -29,136 +29,3 @@ def test_ezsp_frame_rx(ezsp_f): assert ezsp_f._handle_callback.call_count == 1 assert ezsp_f._handle_callback.call_args[0][0] == "version" assert ezsp_f._handle_callback.call_args[0][1] == [0x01, 0x02, 0x1234] - - -async def test_send_unicast(ezsp_f) -> None: - ezsp_f.sendUnicast.return_value = (t.sl_Status.OK, 0x0042) - - aps_frame = t.EmberApsFrame( - profileId=0x0104, - clusterId=0x0006, - sourceEndpoint=1, - destinationEndpoint=2, - options=t.EmberApsOption.APS_OPTION_RETRY, - groupId=0x0000, - sequence=0x34, - ) - - status, message_tag = await ezsp_f.send_unicast( - nwk=0x1234, - aps_frame=aps_frame, - message_tag=0x42, - data=b"hello", - ) - - assert status == t.sl_Status.OK - assert message_tag == 0x42 - assert ezsp_f.sendUnicast.mock_calls == [ - call( - message_type=t.EmberOutgoingMessageType.OUTGOING_DIRECT, - nwk=0x1234, - aps_frame=t.EmberApsFrameV18( - profileId=0x0104, - clusterId=0x0006, - sourceEndpoint=1, - destinationEndpoint=2, - options=t.EmberApsOption.APS_OPTION_RETRY, - groupId=0x0000, - sequence=0x34, - radius=0, - ), - message_tag=0x42, - message=b"hello", - ) - ] - - -async def test_send_multicast(ezsp_f) -> None: - ezsp_f.sendMulticast.return_value = (t.sl_Status.OK, 0x0042) - - aps_frame = t.EmberApsFrame( - profileId=0x0104, - clusterId=0x0006, - sourceEndpoint=1, - destinationEndpoint=2, - options=t.EmberApsOption.APS_OPTION_RETRY, - groupId=0x1234, - sequence=0x34, - ) - - status, message_tag = await ezsp_f.send_multicast( - aps_frame=aps_frame, - radius=12, - non_member_radius=34, - message_tag=0x42, - data=b"hello", - ) - - assert status == t.sl_Status.OK - assert message_tag == 0x42 - assert ezsp_f.sendMulticast.mock_calls == [ - call( - aps_frame=t.EmberApsFrameV18( - profileId=0x0104, - clusterId=0x0006, - sourceEndpoint=1, - destinationEndpoint=2, - options=t.EmberApsOption.APS_OPTION_RETRY, - groupId=0x1234, - sequence=0x34, - radius=12, - ), - hops=12, - broadcast_addr=t.BroadcastAddress.RX_ON_WHEN_IDLE, - alias=0x0000, - sequence=0x34, - message_tag=0x0042, - message=b"hello", - ) - ] - - -async def test_send_broadcast(ezsp_f) -> None: - ezsp_f.sendBroadcast.return_value = (t.sl_Status.OK, 0x0042) - - aps_frame = t.EmberApsFrame( - profileId=0x0104, - clusterId=0x0006, - sourceEndpoint=1, - destinationEndpoint=2, - options=t.EmberApsOption.APS_OPTION_RETRY, - groupId=0x0000, - sequence=0x34, - ) - - status, message_tag = await ezsp_f.send_broadcast( - address=t.BroadcastAddress.ALL_ROUTERS_AND_COORDINATOR, - aps_frame=aps_frame, - radius=12, - message_tag=0x42, - aps_sequence=34, - data=b"hello", - ) - - assert status == t.sl_Status.OK - assert message_tag == 0x42 - assert ezsp_f.sendBroadcast.mock_calls == [ - call( - alias=0x0000, - destination=t.BroadcastAddress.ALL_ROUTERS_AND_COORDINATOR, - sequence=34, - aps_frame=t.EmberApsFrameV18( - profileId=0x0104, - clusterId=0x0006, - sourceEndpoint=1, - destinationEndpoint=2, - options=t.EmberApsOption.APS_OPTION_RETRY, - groupId=0x0000, - sequence=0x34, - radius=12, - ), - radius=12, - message_tag=0x42, - message=b"hello", - ) - ] From adba8a32b9f0a7d0f75960d746623b29ced83abb Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Thu, 1 Jan 2026 18:38:38 -0500 Subject: [PATCH 12/18] Use dicts for tests --- tests/test_ezsp_v14.py | 115 +++++++++++++++++++++++------------------ 1 file changed, 64 insertions(+), 51 deletions(-) diff --git a/tests/test_ezsp_v14.py b/tests/test_ezsp_v14.py index 5d9055a7..9eab1e96 100644 --- a/tests/test_ezsp_v14.py +++ b/tests/test_ezsp_v14.py @@ -235,29 +235,31 @@ def test_handle_parsed_callback_incoming_message(ezsp_f) -> None: handler = MagicMock() ezsp_f.on_event(PacketReceivedEvent.event_type, handler) - aps_frame = t.EmberApsFrame( - profileId=0x0104, - clusterId=0x0006, - sourceEndpoint=1, - destinationEndpoint=2, - options=t.EmberApsOption.APS_OPTION_NONE, - groupId=0x0000, - sequence=0x42, - ) - - # v14 field order: type, apsFrame, lqi, rssi, sender, bindingIndex, addressIndex, message ezsp_f.handle_parsed_callback( "incomingMessageHandler", - [ - t.EmberIncomingMessageType.INCOMING_UNICAST, - aps_frame, - 200, # lqi - -40, # rssi - t.EmberNodeId(0x1234), # sender - 0, # binding_index - 0, # address_index - b"test message", - ], + { + "message_type": t.EmberIncomingMessageType.INCOMING_UNICAST, + "aps_frame": t.EmberApsFrame( + profileId=260, + clusterId=8, + sourceEndpoint=1, + destinationEndpoint=1, + options=( + t.EmberApsOption.APS_OPTION_RETRY + | t.EmberApsOption.APS_OPTION_ENABLE_ROUTE_DISCOVERY + ), + groupId=0, + sequence=168, + ), + "nwk": 0x1174, + "eui64": t.EUI64.convert("00:00:00:00:00:00:00:00"), + "binding_index": 255, + "address_index": 13, + "lqi": 192, + "rssi": -63, + "timestamp": 1333671578, + "message": b"\x18,\x0b\x04\x00", + }.values(), ) assert handler.mock_calls == [ @@ -266,17 +268,17 @@ def test_handle_parsed_callback_incoming_message(ezsp_f) -> None: packet=zigpy.types.ZigbeePacket( src=zigpy.types.AddrModeAddress( addr_mode=zigpy.types.AddrMode.NWK, - address=zigpy.types.NWK(0x1234), + address=zigpy.types.NWK(0x1174), ), src_ep=1, dst=None, - dst_ep=2, - tsn=0x42, + dst_ep=1, + tsn=168, profile_id=0x0104, - cluster_id=0x0006, - data=zigpy.types.SerializableBytes(b"test message"), - lqi=200, - rssi=-40, + cluster_id=0x0008, + data=zigpy.types.SerializableBytes(b"\x18,\x0b\x04\x00"), + lqi=192, + rssi=-63, ) ) ) @@ -288,27 +290,27 @@ def test_handle_parsed_callback_message_sent(ezsp_f) -> None: handler = MagicMock() ezsp_f.on_event(MessageSentEvent.event_type, handler) - aps_frame = t.EmberApsFrame( - profileId=0x0104, - clusterId=0x0006, - sourceEndpoint=1, - destinationEndpoint=2, - options=t.EmberApsOption.APS_OPTION_NONE, - groupId=0x0000, - sequence=0x42, - ) - - # v14 field order: status, type, nwk, apsFrame, messageTag, message ezsp_f.handle_parsed_callback( "messageSentHandler", - [ - t.sl_Status.OK, - t.EmberOutgoingMessageType.OUTGOING_DIRECT, - t.EmberNodeId(0x1234), - aps_frame, - 0x42, # message_tag - b"sent message", - ], + { + "status": t.sl_Status.OK, + "message_type": t.EmberOutgoingMessageType.OUTGOING_DIRECT, + "nwk": 0x0E0D, + "aps_frame": t.EmberApsFrame( + profileId=260, + clusterId=513, + sourceEndpoint=1, + destinationEndpoint=1, + options=( + t.EmberApsOption.APS_OPTION_RETRY + | t.EmberApsOption.APS_OPTION_ENABLE_ROUTE_DISCOVERY + ), + groupId=0, + sequence=236, + ), + "message_tag": 103, + "message": b"", + }.values(), ) assert handler.mock_calls == [ @@ -316,10 +318,21 @@ def test_handle_parsed_callback_message_sent(ezsp_f) -> None: MessageSentEvent( status=t.sl_Status.OK, message_type=t.EmberOutgoingMessageType.OUTGOING_DIRECT, - destination=t.EmberNodeId(0x1234), - aps_frame=aps_frame, - message_tag=0x42, - message_contents=b"sent message", + destination=t.EmberNodeId(0x0E0D), + aps_frame=t.EmberApsFrame( + profileId=260, + clusterId=513, + sourceEndpoint=1, + destinationEndpoint=1, + options=( + t.EmberApsOption.APS_OPTION_RETRY + | t.EmberApsOption.APS_OPTION_ENABLE_ROUTE_DISCOVERY + ), + groupId=0, + sequence=236, + ), + message_tag=103, + message_contents=b"", ) ) ] From 97cb90cdabe15ee81460c764a53f4d3dc7e6e069 Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Thu, 1 Jan 2026 19:09:02 -0500 Subject: [PATCH 13/18] Migrate the rest to events --- bellows/ezsp/protocol.py | 29 +++++++ bellows/ezsp/v14/__init__.py | 31 +++++++ bellows/ezsp/v4/__init__.py | 31 +++++++ bellows/zigbee/application.py | 96 +++++++++++---------- tests/test_application.py | 152 +++++++++++++++++----------------- 5 files changed, 216 insertions(+), 123 deletions(-) diff --git a/bellows/ezsp/protocol.py b/bellows/ezsp/protocol.py index c045c0da..58465b98 100644 --- a/bellows/ezsp/protocol.py +++ b/bellows/ezsp/protocol.py @@ -49,6 +49,35 @@ class PacketReceivedEvent: packet: zigpy.types.ZigbeePacket +@dataclass(frozen=True, kw_only=True) +class TrustCenterJoinEvent: + event_type: Final[str] = "trust_center_join" + + nwk: t.EmberNodeId + ieee: t.EUI64 + device_update_status: t.EmberDeviceUpdate + decision: t.EmberJoinDecision + parent_nwk: t.EmberNodeId + + +@dataclass(frozen=True, kw_only=True) +class RouteRecordEvent: + event_type: Final[str] = "route_record" + + nwk: t.EmberNodeId + ieee: t.EUI64 + lqi: t.uint8_t + rssi: t.int8s + relays: t.LVList[t.EmberNodeId] + + +@dataclass(frozen=True, kw_only=True) +class IdConflictEvent: + event_type: Final[str] = "id_conflict" + + nwk: t.EmberNodeId + + class ProtocolHandler(EventBase, abc.ABC): """EZSP protocol specific handler.""" diff --git a/bellows/ezsp/v14/__init__.py b/bellows/ezsp/v14/__init__.py index 0877cb3c..b8ed240c 100644 --- a/bellows/ezsp/v14/__init__.py +++ b/bellows/ezsp/v14/__init__.py @@ -14,6 +14,7 @@ import bellows.types as t from . import commands, config +from ..protocol import IdConflictEvent, RouteRecordEvent, TrustCenterJoinEvent from ..v13 import EZSPv13 LOGGER = logging.getLogger(__name__) @@ -196,3 +197,33 @@ def handle_parsed_callback(self, frame_name: str, args: list[Any]) -> None: status=status, message_contents=message, ) + elif frame_name == "trustCenterJoinHandler": + nwk, ieee, device_update_status, decision, parent_nwk = args + self.emit( + TrustCenterJoinEvent.event_type, + TrustCenterJoinEvent( + nwk=nwk, + ieee=ieee, + device_update_status=device_update_status, + decision=decision, + parent_nwk=parent_nwk, + ), + ) + elif frame_name == "incomingRouteRecordHandler": + nwk, ieee, lqi, rssi, relays = args + self.emit( + RouteRecordEvent.event_type, + RouteRecordEvent( + nwk=nwk, + ieee=ieee, + lqi=lqi, + rssi=rssi, + relays=relays, + ), + ) + elif frame_name == "idConflictHandler": + (nwk,) = args + self.emit( + IdConflictEvent.event_type, + IdConflictEvent(nwk=nwk), + ) diff --git a/bellows/ezsp/v4/__init__.py b/bellows/ezsp/v4/__init__.py index f5712a8b..ac0cf321 100644 --- a/bellows/ezsp/v4/__init__.py +++ b/bellows/ezsp/v4/__init__.py @@ -16,6 +16,7 @@ from . import commands, config from .. import protocol +from ..protocol import IdConflictEvent, RouteRecordEvent, TrustCenterJoinEvent LOGGER = logging.getLogger(__name__) @@ -282,3 +283,33 @@ def handle_parsed_callback(self, frame_name: str, args: list[Any]) -> None: status=t.sl_Status.from_ember_status(status), message_contents=message, ) + elif frame_name == "trustCenterJoinHandler": + nwk, ieee, device_update_status, decision, parent_nwk = args + self.emit( + TrustCenterJoinEvent.event_type, + TrustCenterJoinEvent( + nwk=nwk, + ieee=ieee, + device_update_status=device_update_status, + decision=decision, + parent_nwk=parent_nwk, + ), + ) + elif frame_name == "incomingRouteRecordHandler": + nwk, ieee, lqi, rssi, relays = args + self.emit( + RouteRecordEvent.event_type, + RouteRecordEvent( + nwk=nwk, + ieee=ieee, + lqi=lqi, + rssi=rssi, + relays=relays, + ), + ) + elif frame_name == "idConflictHandler": + (nwk,) = args + self.emit( + IdConflictEvent.event_type, + IdConflictEvent(nwk=nwk), + ) diff --git a/bellows/zigbee/application.py b/bellows/zigbee/application.py index 6a6ff271..51db4585 100644 --- a/bellows/zigbee/application.py +++ b/bellows/zigbee/application.py @@ -39,7 +39,13 @@ StackAlreadyRunning, ) import bellows.ezsp -from bellows.ezsp.protocol import MessageSentEvent, PacketReceivedEvent +from bellows.ezsp.protocol import ( + IdConflictEvent, + MessageSentEvent, + PacketReceivedEvent, + RouteRecordEvent, + TrustCenterJoinEvent, +) from bellows.ezsp.xncp import FirmwareFeatures import bellows.multicast import bellows.types as t @@ -242,7 +248,6 @@ async def start_network(self): for cnt_group in self.state.counters: cnt_group.reset() - ezsp.add_callback(self.ezsp_callback_handler) self._subscribe_to_protocol_events() self.controller_event.set() @@ -632,6 +637,21 @@ def _subscribe_to_protocol_events(self) -> None: MessageSentEvent.event_type, self._on_message_sent ) ) + self._protocol_on_remove_callbacks.append( + self._ezsp._protocol.on_event( + TrustCenterJoinEvent.event_type, self._on_trust_center_join + ) + ) + self._protocol_on_remove_callbacks.append( + self._ezsp._protocol.on_event( + RouteRecordEvent.event_type, self._on_route_record + ) + ) + self._protocol_on_remove_callbacks.append( + self._ezsp._protocol.on_event( + IdConflictEvent.event_type, self._on_id_conflict + ) + ) async def disconnect(self): # TODO: how do you shut down the stack? @@ -646,19 +666,6 @@ async def force_remove(self, dev): # of the device itself. await self._ezsp.removeDevice(dev.nwk, dev.ieee, dev.ieee) - def ezsp_callback_handler(self, frame_name, args): - LOGGER.debug("Received %s frame with %s", frame_name, args) - if frame_name == "trustCenterJoinHandler": - self._handle_tc_join_handler(*args) - elif frame_name == "incomingRouteRecordHandler": - self.handle_route_record(*args) - elif frame_name == "incomingRouteErrorHandler": - status, nwk = args - status = t.sl_Status.from_ember_status(status) - self.handle_route_error(status, nwk) - elif frame_name == "idConflictHandler": - self._handle_id_conflict(*args) - def _on_packet_received(self, message: PacketReceivedEvent) -> None: """Handle packet_received event from protocol handler.""" packet = message.packet @@ -728,33 +735,31 @@ def _on_message_sent(self, event: MessageSentEvent) -> None: exc, ) - def _handle_tc_join_handler( - self, - nwk: t.EmberNodeId, - ieee: t.EUI64, - device_update_status: t.EmberDeviceUpdate, - decision: t.EmberJoinDecision, - parent_nwk: t.EmberNodeId, - ) -> None: - """Trust Center Join handler.""" - if device_update_status == t.EmberDeviceUpdate.DEVICE_LEFT: - self.handle_leave(nwk, ieee) + def _on_trust_center_join(self, event: TrustCenterJoinEvent) -> None: + """Handle trust_center_join event from protocol handler.""" + if event.device_update_status == t.EmberDeviceUpdate.DEVICE_LEFT: + self.handle_leave(event.nwk, event.ieee) return - if device_update_status == t.EmberDeviceUpdate.STANDARD_SECURITY_UNSECURED_JOIN: - self.create_task(self.cleanup_tc_link_key(ieee), "cleanup_tc_link_key") + if ( + event.device_update_status + == t.EmberDeviceUpdate.STANDARD_SECURITY_UNSECURED_JOIN + ): + self.create_task( + self.cleanup_tc_link_key(event.ieee), "cleanup_tc_link_key" + ) - if decision == t.EmberJoinDecision.DENY_JOIN: + if event.decision == t.EmberJoinDecision.DENY_JOIN: # no point in handling the join if it was denied return - mfg_id = IEEE_PREFIX_MFG_ID.get(str(ieee)[:8].upper()) + mfg_id = IEEE_PREFIX_MFG_ID.get(str(event.ieee)[:8].upper()) if mfg_id is not None: if self._mfg_id_task and not self._mfg_id_task.done(): self._mfg_id_task.cancel() self._mfg_id_task = asyncio.create_task(self._reset_mfg_id(mfg_id)) - self.handle_join(nwk, ieee, parent_nwk) + self.handle_join(event.nwk, event.ieee, event.parent_nwk) async def _reset_mfg_id(self, mfg_id: int) -> None: """Resets manufacturer id if was temporary overridden by a joining device.""" @@ -1048,20 +1053,21 @@ async def permit_with_link_key( return await super().permit(time_s) - def _handle_id_conflict(self, nwk: t.EmberNodeId) -> None: - LOGGER.warning("NWK conflict is reported for 0x%04x", nwk) + def _on_id_conflict(self, event: IdConflictEvent) -> None: + """Handle id_conflict event from protocol handler.""" + LOGGER.warning("NWK conflict is reported for 0x%04x", event.nwk) self.state.counters[COUNTERS_CTRL][COUNTER_NWK_CONFLICTS].increment() for device in self.devices.values(): - if device.nwk != nwk: + if device.nwk != event.nwk: continue LOGGER.warning( "Found %s device for 0x%04x NWK conflict: %s %s", device.ieee, - nwk, + event.nwk, device.manufacturer, device.model, ) - self.handle_leave(nwk, device.ieee) + self.handle_leave(event.nwk, device.ieee) async def _watchdog_loop(self): self._watchdog_failures = 0 @@ -1122,18 +1128,10 @@ async def _get_free_buffers(self) -> int | None: LOGGER.debug("Free buffers status %s, value: %s", status, buffers) return buffers - def handle_route_record( - self, - nwk: t.EmberNodeId, - ieee: t.EUI64, - lqi: t.uint8_t, - rssi: t.int8s, - relays: t.LVList[t.EmberNodeId], - ) -> None: + def _on_route_record(self, event: RouteRecordEvent) -> None: + """Handle route_record event from protocol handler.""" LOGGER.debug( - "Processing route record request: %s", (nwk, ieee, lqi, rssi, relays) + "Processing route record request: %s", + (event.nwk, event.ieee, event.lqi, event.rssi, event.relays), ) - self.handle_relays(nwk=nwk, relays=relays) - - def handle_route_error(self, status: t.sl_Status, nwk: t.EmberNodeId) -> None: - LOGGER.debug("Processing route error: status=%s, nwk=%s", status, nwk) + self.handle_relays(nwk=event.nwk, relays=event.relays) diff --git a/tests/test_application.py b/tests/test_application.py index 2d15c420..f19eede1 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -16,7 +16,13 @@ import bellows.config as config from bellows.exception import ControllerError, EzspError, InvalidCommandError import bellows.ezsp as ezsp -from bellows.ezsp.protocol import MessageSentEvent, PacketReceivedEvent +from bellows.ezsp.protocol import ( + IdConflictEvent, + MessageSentEvent, + PacketReceivedEvent, + RouteRecordEvent, + TrustCenterJoinEvent, +) from bellows.ezsp.v9.commands import GetTokenDataRsp from bellows.ezsp.xncp import ( FirmwareFeatures, @@ -425,15 +431,14 @@ async def test_join_handler(app, ieee): # Calls device.initialize, leaks a task app.handle_join = MagicMock() app.cleanup_tc_link_key = AsyncMock() - app.ezsp_callback_handler( - "trustCenterJoinHandler", - [ - 1, - ieee, - t.EmberDeviceUpdate.STANDARD_SECURITY_UNSECURED_JOIN, - t.EmberJoinDecision.NO_ACTION, - sentinel.parent, - ], + app._on_trust_center_join( + TrustCenterJoinEvent( + nwk=1, + ieee=ieee, + device_update_status=t.EmberDeviceUpdate.STANDARD_SECURITY_UNSECURED_JOIN, + decision=t.EmberJoinDecision.NO_ACTION, + parent_nwk=sentinel.parent, + ) ) await asyncio.sleep(0) assert ieee not in app.devices @@ -447,15 +452,14 @@ async def test_join_handler(app, ieee): # cleanup TCLK, but no join handling app.handle_join.reset_mock() app.cleanup_tc_link_key.reset_mock() - app.ezsp_callback_handler( - "trustCenterJoinHandler", - [ - 1, - ieee, - t.EmberDeviceUpdate.STANDARD_SECURITY_UNSECURED_JOIN, - t.EmberJoinDecision.DENY_JOIN, - sentinel.parent, - ], + app._on_trust_center_join( + TrustCenterJoinEvent( + nwk=1, + ieee=ieee, + device_update_status=t.EmberDeviceUpdate.STANDARD_SECURITY_UNSECURED_JOIN, + decision=t.EmberJoinDecision.DENY_JOIN, + parent_nwk=sentinel.parent, + ) ) await asyncio.sleep(0) assert app.cleanup_tc_link_key.await_count == 1 @@ -466,8 +470,14 @@ async def test_join_handler(app, ieee): def test_leave_handler(app, ieee): app.handle_join = MagicMock() app.devices[ieee] = MagicMock() - app.ezsp_callback_handler( - "trustCenterJoinHandler", [1, ieee, t.EmberDeviceUpdate.DEVICE_LEFT, None, None] + app._on_trust_center_join( + TrustCenterJoinEvent( + nwk=1, + ieee=ieee, + device_update_status=t.EmberDeviceUpdate.DEVICE_LEFT, + decision=t.EmberJoinDecision.NO_ACTION, + parent_nwk=t.EmberNodeId(0x0000), + ) ) assert ieee in app.devices assert app.handle_join.call_count == 0 @@ -731,15 +741,14 @@ async def test_send_packet_unicast_extended_timeout_with_acks(app, ieee, packet) asyncio.get_running_loop().call_later( 0.1, - app.ezsp_callback_handler, - "incomingRouteRecordHandler", - { - "source": packet.dst.address, - "sourceEui": ieee, - "lastHopLqi": 123, - "lastHopRssi": -60, - "relayList": [0x1234], - }.values(), + app._on_route_record, + RouteRecordEvent( + nwk=packet.dst.address, + ieee=ieee, + lqi=123, + rssi=-60, + relays=[0x1234], + ), ) await _test_send_packet_unicast( @@ -761,15 +770,14 @@ async def test_send_packet_unicast_extended_timeout_without_acks(app, ieee, pack asyncio.get_running_loop().call_later( 0.1, - app.ezsp_callback_handler, - "incomingRouteRecordHandler", - { - "source": packet.dst.address, - "sourceEui": ieee, - "lastHopLqi": 123, - "lastHopRssi": -60, - "relayList": [0x1234], - }.values(), + app._on_route_record, + RouteRecordEvent( + nwk=packet.dst.address, + ieee=ieee, + lqi=123, + rssi=-60, + relays=[0x1234], + ), ) await _test_send_packet_unicast( @@ -1378,20 +1386,18 @@ def test_coordinator_model_manuf(coordinator): def test_handle_route_record(app): """Test route record handling for an existing device.""" app.handle_relays = MagicMock(spec_set=app.handle_relays) - app.ezsp_callback_handler( - "incomingRouteRecordHandler", - [sentinel.nwk, sentinel.ieee, sentinel.lqi, sentinel.rssi, sentinel.relays], - ) - app.handle_relays.assert_called_once_with(nwk=sentinel.nwk, relays=sentinel.relays) - - -def test_handle_route_error(app): - """Test route error handler.""" - app.handle_relays = MagicMock(spec_set=app.handle_relays) - app.ezsp_callback_handler( - "incomingRouteErrorHandler", [sentinel.status, sentinel.nwk] + app._on_route_record( + RouteRecordEvent( + nwk=sentinel.nwk, + ieee=sentinel.ieee, + lqi=sentinel.lqi, + rssi=sentinel.rssi, + relays=sentinel.relays, + ) ) - app.handle_relays.assert_not_called() + assert app.handle_relays.mock_calls == [ + call(nwk=sentinel.nwk, relays=sentinel.relays) + ] def test_handle_id_conflict(app, ieee): @@ -1400,10 +1406,10 @@ def test_handle_id_conflict(app, ieee): app.add_device(ieee, nwk) app.handle_leave = MagicMock() - app.ezsp_callback_handler("idConflictHandler", [nwk + 1]) + app._on_id_conflict(IdConflictEvent(nwk=nwk + 1)) assert app.handle_leave.call_count == 0 - app.ezsp_callback_handler("idConflictHandler", [nwk]) + app._on_id_conflict(IdConflictEvent(nwk=nwk)) assert app.handle_leave.call_count == 1 assert app.handle_leave.call_args[0][0] == nwk @@ -1452,26 +1458,24 @@ async def test_set_mfg_id(ieee, expected_mfg_id, app): app.handle_join = MagicMock() app.cleanup_tc_link_key = AsyncMock() - app.ezsp_callback_handler( - "trustCenterJoinHandler", - [ - 1, - t.EUI64.convert(ieee), - t.EmberDeviceUpdate.STANDARD_SECURITY_UNSECURED_JOIN, - t.EmberJoinDecision.NO_ACTION, - sentinel.parent, - ], + app._on_trust_center_join( + TrustCenterJoinEvent( + nwk=1, + ieee=t.EUI64.convert(ieee), + device_update_status=t.EmberDeviceUpdate.STANDARD_SECURITY_UNSECURED_JOIN, + decision=t.EmberJoinDecision.NO_ACTION, + parent_nwk=sentinel.parent, + ) ) # preempt - app.ezsp_callback_handler( - "trustCenterJoinHandler", - [ - 1, - t.EUI64.convert(ieee), - t.EmberDeviceUpdate.STANDARD_SECURITY_UNSECURED_JOIN, - t.EmberJoinDecision.NO_ACTION, - sentinel.parent, - ], + app._on_trust_center_join( + TrustCenterJoinEvent( + nwk=1, + ieee=t.EUI64.convert(ieee), + device_update_status=t.EmberDeviceUpdate.STANDARD_SECURITY_UNSECURED_JOIN, + decision=t.EmberJoinDecision.NO_ACTION, + parent_nwk=sentinel.parent, + ) ) await asyncio.sleep(0.20) if expected_mfg_id is not None: @@ -2484,8 +2488,8 @@ async def test_reset_resubscribes_events(app: ControllerApplication) -> None: assert len(app._ezsp.startup_reset.mock_calls) == 1 assert len(app._ezsp.write_config.mock_calls) == 1 - # Verify we resubscribed (callbacks list should have 2 entries now) - assert len(app._protocol_on_remove_callbacks) == 2 + # Verify we resubscribed (callbacks list should have 5 entries now) + assert len(app._protocol_on_remove_callbacks) == 5 def test_on_packet_received_unicast(app: ControllerApplication) -> None: From f2b7b21a38116a3f3a11526e77bf5fecd6771be9 Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Thu, 1 Jan 2026 19:21:03 -0500 Subject: [PATCH 14/18] Use named callbacks to avoid duplication --- bellows/ezsp/protocol.py | 54 ++++++++++++++-- bellows/ezsp/v14/__init__.py | 121 +++++++++++++---------------------- bellows/ezsp/v4/__init__.py | 117 ++++++++++++--------------------- 3 files changed, 134 insertions(+), 158 deletions(-) diff --git a/bellows/ezsp/protocol.py b/bellows/ezsp/protocol.py index 58465b98..9821a14b 100644 --- a/bellows/ezsp/protocol.py +++ b/bellows/ezsp/protocol.py @@ -257,13 +257,15 @@ def __call__(self, data: bytes) -> None: self.handle_parsed_callback(frame_name, result) - # Always call legacy callback handler for backwards compatibility + # Legacy callback system for CLI tools self._handle_callback(frame_name, result) - @abc.abstractmethod def handle_parsed_callback(self, frame_name: str, args: list[Any]) -> None: - """Handle a parsed callback frame.""" - raise NotImplementedError + """Dispatch a callback frame to the appropriate handler method.""" + handler = getattr(self, f"_handle_{frame_name}", None) + + if handler is not None: + handler(*args) async def _send_fragment_ack( self, @@ -397,6 +399,50 @@ def _handle_message_sent( ), ) + def _handle_trustCenterJoinHandler( + self, + nwk: t.EmberNodeId, + ieee: t.EUI64, + device_update_status: t.EmberDeviceUpdate, + decision: t.EmberJoinDecision, + parent_nwk: t.EmberNodeId, + ) -> None: + self.emit( + TrustCenterJoinEvent.event_type, + TrustCenterJoinEvent( + nwk=nwk, + ieee=ieee, + device_update_status=device_update_status, + decision=decision, + parent_nwk=parent_nwk, + ), + ) + + def _handle_incomingRouteRecordHandler( + self, + nwk: t.EmberNodeId, + ieee: t.EUI64, + lqi: t.uint8_t, + rssi: t.int8s, + relays: t.LVList[t.EmberNodeId], + ) -> None: + self.emit( + RouteRecordEvent.event_type, + RouteRecordEvent( + nwk=nwk, + ieee=ieee, + lqi=lqi, + rssi=rssi, + relays=relays, + ), + ) + + def _handle_idConflictHandler(self, nwk: t.EmberNodeId) -> None: + self.emit( + IdConflictEvent.event_type, + IdConflictEvent(nwk=nwk), + ) + def __getattr__(self, name: str) -> Callable: if name not in self.COMMANDS: raise AttributeError(f"{name} not found in COMMANDS") diff --git a/bellows/ezsp/v14/__init__.py b/bellows/ezsp/v14/__init__.py index b8ed240c..773342c3 100644 --- a/bellows/ezsp/v14/__init__.py +++ b/bellows/ezsp/v14/__init__.py @@ -3,7 +3,6 @@ from collections.abc import AsyncGenerator import logging -from typing import Any import voluptuous as vol from zigpy.exceptions import NetworkNotFormed @@ -14,7 +13,6 @@ import bellows.types as t from . import commands, config -from ..protocol import IdConflictEvent, RouteRecordEvent, TrustCenterJoinEvent from ..v13 import EZSPv13 LOGGER = logging.getLogger(__name__) @@ -151,79 +149,46 @@ async def send_broadcast( return status, sequence - def handle_parsed_callback(self, frame_name: str, args: list[Any]) -> None: - """Handle a parsed callback frame.""" - if frame_name == "incomingMessageHandler": - ( - message_type, - aps_frame, - sender, - _eui64, - binding_index, - address_index, - lqi, - rssi, - _timestamp, - message, - ) = args - - self._handle_incoming_message( - message_type=message_type, - aps_frame=aps_frame, - sender=sender, - eui64=None, - binding_index=binding_index, - address_index=address_index, - lqi=lqi, - rssi=rssi, - timestamp=None, - message=message, - ) - elif frame_name == "messageSentHandler": - ( - status, - message_type, - nwk, - aps_frame, - message_tag, - message, - ) = args - - self._handle_message_sent( - message_type=message_type, - destination=nwk, - aps_frame=aps_frame, - message_tag=message_tag, - status=status, - message_contents=message, - ) - elif frame_name == "trustCenterJoinHandler": - nwk, ieee, device_update_status, decision, parent_nwk = args - self.emit( - TrustCenterJoinEvent.event_type, - TrustCenterJoinEvent( - nwk=nwk, - ieee=ieee, - device_update_status=device_update_status, - decision=decision, - parent_nwk=parent_nwk, - ), - ) - elif frame_name == "incomingRouteRecordHandler": - nwk, ieee, lqi, rssi, relays = args - self.emit( - RouteRecordEvent.event_type, - RouteRecordEvent( - nwk=nwk, - ieee=ieee, - lqi=lqi, - rssi=rssi, - relays=relays, - ), - ) - elif frame_name == "idConflictHandler": - (nwk,) = args - self.emit( - IdConflictEvent.event_type, - IdConflictEvent(nwk=nwk), - ) + def _handle_incomingMessageHandler( + self, + message_type: t.EmberIncomingMessageType, + aps_frame: t.EmberApsFrame, + sender: t.EmberNodeId, + eui64: t.EUI64, + binding_index: t.uint8_t, + address_index: t.uint8_t, + lqi: t.uint8_t, + rssi: t.int8s, + timestamp: t.uint32_t, + message: t.LVBytes, + ) -> None: + self._handle_incoming_message( + message_type=message_type, + aps_frame=aps_frame, + sender=sender, + eui64=None, + binding_index=binding_index, + address_index=address_index, + lqi=lqi, + rssi=rssi, + timestamp=None, + message=message, + ) + + def _handle_messageSentHandler( + self, + status: t.sl_Status, + message_type: t.EmberOutgoingMessageType, + destination: t.EmberNodeId, + aps_frame: t.EmberApsFrame, + message_tag: t.uint8_t, + message: t.LVBytes, + ) -> None: + self._handle_message_sent( + message_type=message_type, + destination=destination, + aps_frame=aps_frame, + message_tag=message_tag, + status=status, + message_contents=message, + ) diff --git a/bellows/ezsp/v4/__init__.py b/bellows/ezsp/v4/__init__.py index ac0cf321..4868b7d4 100644 --- a/bellows/ezsp/v4/__init__.py +++ b/bellows/ezsp/v4/__init__.py @@ -4,7 +4,6 @@ from collections.abc import AsyncGenerator, Iterable import logging import random -from typing import Any import voluptuous as vol import zigpy.state @@ -16,7 +15,6 @@ from . import commands, config from .. import protocol -from ..protocol import IdConflictEvent, RouteRecordEvent, TrustCenterJoinEvent LOGGER = logging.getLogger(__name__) @@ -239,77 +237,44 @@ async def set_extended_timeout( newExtendedTimeout=extended_timeout, ) - def handle_parsed_callback(self, frame_name: str, args: list[Any]) -> None: - """Handle a parsed callback frame.""" - if frame_name == "incomingMessageHandler": - ( - message_type, - aps_frame, - lqi, - rssi, - sender, - binding_index, - address_index, - message, - ) = args - - self._handle_incoming_message( - message_type=message_type, - aps_frame=aps_frame, - sender=sender, - eui64=None, - binding_index=binding_index, - address_index=address_index, - lqi=lqi, - rssi=rssi, - timestamp=None, - message=message, - ) - elif frame_name == "messageSentHandler": - ( - message_type, - destination, - aps_frame, - message_tag, - status, - message, - ) = args - - self._handle_message_sent( - message_type=message_type, - destination=destination, - aps_frame=aps_frame, - message_tag=message_tag, - status=t.sl_Status.from_ember_status(status), - message_contents=message, - ) - elif frame_name == "trustCenterJoinHandler": - nwk, ieee, device_update_status, decision, parent_nwk = args - self.emit( - TrustCenterJoinEvent.event_type, - TrustCenterJoinEvent( - nwk=nwk, - ieee=ieee, - device_update_status=device_update_status, - decision=decision, - parent_nwk=parent_nwk, - ), - ) - elif frame_name == "incomingRouteRecordHandler": - nwk, ieee, lqi, rssi, relays = args - self.emit( - RouteRecordEvent.event_type, - RouteRecordEvent( - nwk=nwk, - ieee=ieee, - lqi=lqi, - rssi=rssi, - relays=relays, - ), - ) - elif frame_name == "idConflictHandler": - (nwk,) = args - self.emit( - IdConflictEvent.event_type, - IdConflictEvent(nwk=nwk), - ) + def _handle_incomingMessageHandler( + self, + message_type: t.EmberIncomingMessageType, + aps_frame: t.EmberApsFrame, + lqi: t.uint8_t, + rssi: t.int8s, + sender: t.EmberNodeId, + binding_index: t.uint8_t, + address_index: t.uint8_t, + message: t.LVBytes, + ) -> None: + self._handle_incoming_message( + message_type=message_type, + aps_frame=aps_frame, + sender=sender, + eui64=None, + binding_index=binding_index, + address_index=address_index, + lqi=lqi, + rssi=rssi, + timestamp=None, + message=message, + ) + + def _handle_messageSentHandler( + self, + message_type: t.EmberOutgoingMessageType, + destination: t.EmberNodeId, + aps_frame: t.EmberApsFrame, + message_tag: t.uint8_t, + status: t.EmberStatus, + message: t.LVBytes, + ) -> None: + self._handle_message_sent( + message_type=message_type, + destination=destination, + aps_frame=aps_frame, + message_tag=message_tag, + status=t.sl_Status.from_ember_status(status), + message_contents=message, + ) From 71b479a42d6c4c2917efe1546057838f457f6e50 Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Thu, 1 Jan 2026 19:26:32 -0500 Subject: [PATCH 15/18] Remove unnecessary proxy methods --- bellows/ezsp/protocol.py | 21 --------------------- bellows/ezsp/v14/__init__.py | 18 +++++++++++------- bellows/ezsp/v4/__init__.py | 18 +++++++++++------- 3 files changed, 22 insertions(+), 35 deletions(-) diff --git a/bellows/ezsp/protocol.py b/bellows/ezsp/protocol.py index 9821a14b..0ae93b8e 100644 --- a/bellows/ezsp/protocol.py +++ b/bellows/ezsp/protocol.py @@ -378,27 +378,6 @@ def _handle_incoming_message( ), ) - def _handle_message_sent( - self, - message_type: t.EmberOutgoingMessageType, - destination: t.uint16_t, - aps_frame: t.EmberApsFrame, - message_tag: t.uint8_t, - status: t.sl_Status, - message_contents: t.LVBytes, - ) -> None: - self.emit( - MessageSentEvent.event_type, - MessageSentEvent( - status=t.sl_Status.from_ember_status(status), - message_type=message_type, - destination=destination, - aps_frame=aps_frame, - message_tag=message_tag, - message_contents=message_contents, - ), - ) - def _handle_trustCenterJoinHandler( self, nwk: t.EmberNodeId, diff --git a/bellows/ezsp/v14/__init__.py b/bellows/ezsp/v14/__init__.py index 773342c3..dfba42e0 100644 --- a/bellows/ezsp/v14/__init__.py +++ b/bellows/ezsp/v14/__init__.py @@ -13,6 +13,7 @@ import bellows.types as t from . import commands, config +from ..protocol import MessageSentEvent from ..v13 import EZSPv13 LOGGER = logging.getLogger(__name__) @@ -184,11 +185,14 @@ def _handle_messageSentHandler( message_tag: t.uint8_t, message: t.LVBytes, ) -> None: - self._handle_message_sent( - message_type=message_type, - destination=destination, - aps_frame=aps_frame, - message_tag=message_tag, - status=status, - message_contents=message, + self.emit( + MessageSentEvent.event_type, + MessageSentEvent( + status=status, + message_type=message_type, + destination=destination, + aps_frame=aps_frame, + message_tag=message_tag, + message_contents=message, + ), ) diff --git a/bellows/ezsp/v4/__init__.py b/bellows/ezsp/v4/__init__.py index 4868b7d4..e2946818 100644 --- a/bellows/ezsp/v4/__init__.py +++ b/bellows/ezsp/v4/__init__.py @@ -15,6 +15,7 @@ from . import commands, config from .. import protocol +from ..protocol import MessageSentEvent LOGGER = logging.getLogger(__name__) @@ -270,11 +271,14 @@ def _handle_messageSentHandler( status: t.EmberStatus, message: t.LVBytes, ) -> None: - self._handle_message_sent( - message_type=message_type, - destination=destination, - aps_frame=aps_frame, - message_tag=message_tag, - status=t.sl_Status.from_ember_status(status), - message_contents=message, + self.emit( + MessageSentEvent.event_type, + MessageSentEvent( + status=t.sl_Status.from_ember_status(status), + message_type=message_type, + destination=destination, + aps_frame=aps_frame, + message_tag=message_tag, + message_contents=message, + ), ) From b5d5ceac716b6bc37b20d4935827e65b8700796c Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Thu, 1 Jan 2026 19:32:09 -0500 Subject: [PATCH 16/18] Bring coverage up --- tests/test_application.py | 22 ++++++++++ tests/test_ezsp_protocol.py | 80 ++++++++++++++++++++++++++++++++++++- 2 files changed, 101 insertions(+), 1 deletion(-) diff --git a/tests/test_application.py b/tests/test_application.py index f19eede1..f1fa27da 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -2607,3 +2607,25 @@ def test_on_packet_received_multicast(app: ControllerApplication) -> None: # Verify packet_received was called with the same packet (dst already set) assert packet_received_mock.mock_calls == [call(event.packet)] + + +async def test_on_message_sent_via_binding(app: ControllerApplication) -> None: + """Test _on_message_sent with OUTGOING_VIA_BINDING message type.""" + # Create a pending request future + future = asyncio.get_running_loop().create_future() + app._pending_requests[(0x1234, 0x42)] = future + + event = MessageSentEvent( + status=t.sl_Status.OK, + message_type=t.EmberOutgoingMessageType.OUTGOING_VIA_BINDING, + destination=0x1234, + aps_frame=t.EmberApsFrame(), + message_tag=0x42, + message_contents=b"test", + ) + + app._on_message_sent(event) + + # Verify the future was resolved + assert future.done() + assert future.result() == (t.sl_Status.OK, "message send success") diff --git a/tests/test_ezsp_protocol.py b/tests/test_ezsp_protocol.py index c05db28c..8550c2c5 100644 --- a/tests/test_ezsp_protocol.py +++ b/tests/test_ezsp_protocol.py @@ -6,7 +6,12 @@ import zigpy.types from bellows.ezsp import EZSP -from bellows.ezsp.protocol import PacketReceivedEvent +from bellows.ezsp.protocol import ( + IdConflictEvent, + PacketReceivedEvent, + RouteRecordEvent, + TrustCenterJoinEvent, +) import bellows.ezsp.v4 import bellows.ezsp.v9 from bellows.ezsp.v9.commands import GetTokenDataRsp @@ -423,3 +428,76 @@ def test_incoming_message_ignored_type(prot_hndl, caplog) -> None: # No event should be emitted for ignored message types assert len(handler.mock_calls) == 0 assert "Ignoring message type" in caplog.text + + +def test_trust_center_join_handler(prot_hndl) -> None: + """Test trustCenterJoinHandler callback.""" + handler = MagicMock() + prot_hndl.on_event(TrustCenterJoinEvent.event_type, handler) + + ieee = t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11") + prot_hndl.handle_parsed_callback( + "trustCenterJoinHandler", + { + "newNodeId": t.EmberNodeId(0x1234), + "newNodeEui64": ieee, + "status": t.EmberDeviceUpdate.STANDARD_SECURITY_UNSECURED_JOIN, + "policyDecision": t.EmberJoinDecision.NO_ACTION, + "parentOfNewNodeId": t.EmberNodeId(0x0000), + }.values(), + ) + + assert handler.mock_calls == [ + call( + TrustCenterJoinEvent( + nwk=t.EmberNodeId(0x1234), + ieee=ieee, + device_update_status=t.EmberDeviceUpdate.STANDARD_SECURITY_UNSECURED_JOIN, + decision=t.EmberJoinDecision.NO_ACTION, + parent_nwk=t.EmberNodeId(0x0000), + ) + ) + ] + + +def test_incoming_route_record_handler(prot_hndl) -> None: + """Test incomingRouteRecordHandler callback.""" + handler = MagicMock() + prot_hndl.on_event(RouteRecordEvent.event_type, handler) + + ieee = t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11") + prot_hndl.handle_parsed_callback( + "incomingRouteRecordHandler", + { + "source": t.EmberNodeId(0x1234), + "sourceEui": ieee, + "lastHopLqi": t.uint8_t(200), + "lastHopRssi": t.int8s(-40), + "relayList": [t.EmberNodeId(0x0001), t.EmberNodeId(0x0002)], + }.values(), + ) + + assert handler.mock_calls == [ + call( + RouteRecordEvent( + nwk=t.EmberNodeId(0x1234), + ieee=ieee, + lqi=t.uint8_t(200), + rssi=t.int8s(-40), + relays=[t.EmberNodeId(0x0001), t.EmberNodeId(0x0002)], + ) + ) + ] + + +def test_id_conflict_handler(prot_hndl) -> None: + """Test idConflictHandler callback.""" + handler = MagicMock() + prot_hndl.on_event(IdConflictEvent.event_type, handler) + + prot_hndl.handle_parsed_callback( + "idConflictHandler", + {"conflictingId": t.EmberNodeId(0x1234)}.values(), + ) + + assert handler.mock_calls == [call(IdConflictEvent(nwk=t.EmberNodeId(0x1234)))] From 370e08c8fba4a7f5ade8ea23ef0df31af24908a7 Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Thu, 1 Jan 2026 19:37:14 -0500 Subject: [PATCH 17/18] Reduce diff size --- bellows/ezsp/__init__.py | 2 +- bellows/ezsp/protocol.py | 247 ++++----------- bellows/ezsp/v14/__init__.py | 52 ---- bellows/ezsp/v4/__init__.py | 47 --- bellows/zigbee/application.py | 285 +++++++++++------- tests/test_application.py | 550 +++++++++++++++++++--------------- tests/test_ezsp_protocol.py | 299 +++--------------- tests/test_ezsp_v14.py | 110 ------- 8 files changed, 582 insertions(+), 1010 deletions(-) diff --git a/bellows/ezsp/__init__.py b/bellows/ezsp/__init__.py index e9bbbd10..b5bd087f 100644 --- a/bellows/ezsp/__init__.py +++ b/bellows/ezsp/__init__.py @@ -365,7 +365,7 @@ def frame_received(self, data: bytes) -> None: try: self._protocol(data) except Exception: - LOGGER.warning("Failed to parse frame. This is a bug!", exc_info=True) + LOGGER.warning("Failed to parse frame, ignoring") async def get_board_info( self, diff --git a/bellows/ezsp/protocol.py b/bellows/ezsp/protocol.py index 0ae93b8e..30006dc4 100644 --- a/bellows/ezsp/protocol.py +++ b/bellows/ezsp/protocol.py @@ -5,16 +5,13 @@ from asyncio import timeout as asyncio_timeout import binascii from collections.abc import AsyncGenerator, Callable, Iterable -from dataclasses import dataclass import functools import logging import time -from typing import TYPE_CHECKING, Any, Final +from typing import TYPE_CHECKING, Any from zigpy.datastructures import PriorityDynamicBoundedSemaphore -from zigpy.event.event_base import EventBase import zigpy.state -import zigpy.types from bellows.config import CONF_EZSP_POLICIES from bellows.exception import InvalidCommandError @@ -30,62 +27,13 @@ MAX_COMMAND_CONCURRENCY = 1 -@dataclass(frozen=True, kw_only=True) -class MessageSentEvent: - event_type: Final[str] = "message_sent" - - status: t.sl_Status - message_type: t.EmberOutgoingMessageType - destination: t.uint16_t - aps_frame: t.EmberApsFrame - message_tag: t.uint8_t - message_contents: t.LVBytes - - -@dataclass(frozen=True, kw_only=True) -class PacketReceivedEvent: - event_type: Final[str] = "packet_received" - - packet: zigpy.types.ZigbeePacket - - -@dataclass(frozen=True, kw_only=True) -class TrustCenterJoinEvent: - event_type: Final[str] = "trust_center_join" - - nwk: t.EmberNodeId - ieee: t.EUI64 - device_update_status: t.EmberDeviceUpdate - decision: t.EmberJoinDecision - parent_nwk: t.EmberNodeId - - -@dataclass(frozen=True, kw_only=True) -class RouteRecordEvent: - event_type: Final[str] = "route_record" - - nwk: t.EmberNodeId - ieee: t.EUI64 - lqi: t.uint8_t - rssi: t.int8s - relays: t.LVList[t.EmberNodeId] - - -@dataclass(frozen=True, kw_only=True) -class IdConflictEvent: - event_type: Final[str] = "id_conflict" - - nwk: t.EmberNodeId - - -class ProtocolHandler(EventBase, abc.ABC): +class ProtocolHandler(abc.ABC): """EZSP protocol specific handler.""" COMMANDS = {} VERSION = None def __init__(self, cb_handler: Callable, gateway: Gateway) -> None: - super().__init__() self._handle_callback = cb_handler self._awaiting = {} self._gw = gateway @@ -231,6 +179,52 @@ def __call__(self, data: bytes) -> None: if data: LOGGER.debug("Frame contains trailing data: %s", data) + if ( + frame_name == "incomingMessageHandler" + and result[1].options & t.EmberApsOption.APS_OPTION_FRAGMENT + ): + # Extract received APS frame and sender + aps_frame = result[1] + sender = result[4] + + # The fragment count and index are encoded in the groupId field + fragment_count = (aps_frame.groupId >> 8) & 0xFF + fragment_index = aps_frame.groupId & 0xFF + + ( + complete, + reassembled, + frag_count, + frag_index, + ) = self._fragment_manager.handle_incoming_fragment( + sender_nwk=sender, + aps_sequence=aps_frame.sequence, + profile_id=aps_frame.profileId, + cluster_id=aps_frame.clusterId, + fragment_count=fragment_count, + fragment_index=fragment_index, + payload=result[7], + ) + + ack_task = asyncio.create_task( + self._send_fragment_ack(sender, aps_frame, frag_count, frag_index) + ) # APS Ack + + self._fragment_ack_tasks.add(ack_task) + ack_task.add_done_callback(lambda t: self._fragment_ack_tasks.discard(t)) + + if not complete: + # Do not pass partial data up the stack + LOGGER.debug("Fragment reassembly not complete. waiting for more data.") + return + + # Replace partial data with fully reassembled data + result[7] = reassembled + + LOGGER.debug( + "Reassembled fragmented message. Proceeding with normal handling." + ) + if sequence in self._awaiting: expected_id, schema, future = self._awaiting.pop(sequence) try: @@ -252,20 +246,8 @@ def __call__(self, data: bytes) -> None: sequence, self.COMMANDS_BY_ID.get(expected_id, [expected_id])[0], ) - - return - - self.handle_parsed_callback(frame_name, result) - - # Legacy callback system for CLI tools - self._handle_callback(frame_name, result) - - def handle_parsed_callback(self, frame_name: str, args: list[Any]) -> None: - """Dispatch a callback frame to the appropriate handler method.""" - handler = getattr(self, f"_handle_{frame_name}", None) - - if handler is not None: - handler(*args) + else: + self._handle_callback(frame_name, result) async def _send_fragment_ack( self, @@ -293,135 +275,6 @@ async def _send_fragment_ack( status = await self.sendReply(sender, ackFrame, b"") return status[0] - def _handle_incoming_message( - self, - message_type: t.EmberIncomingMessageType, - aps_frame: t.EmberApsFrame, - sender: zigpy.types.NWK, - eui64: zigpy.types.EUI64 | None, - binding_index: t.uint8_t, - address_index: t.uint8_t, - lqi: t.uint8_t, - rssi: t.int8s, - timestamp: t.uint32_t | None, - message: t.LVBytes, - ) -> None: - """Handle incomingMessageHandler callback and maybe return a packet.""" - - if aps_frame.options & t.EmberApsOption.APS_OPTION_FRAGMENT: - fragment_count = (aps_frame.groupId >> 8) & 0xFF - fragment_index = aps_frame.groupId & 0xFF - - ( - complete, - reassembled, - frag_count, - frag_index, - ) = self._fragment_manager.handle_incoming_fragment( - sender_nwk=sender, - aps_sequence=aps_frame.sequence, - profile_id=aps_frame.profileId, - cluster_id=aps_frame.clusterId, - fragment_count=fragment_count, - fragment_index=fragment_index, - payload=message, - ) - - ack_task = asyncio.create_task( - self._send_fragment_ack(sender, aps_frame, frag_count, frag_index) - ) - self._fragment_ack_tasks.add(ack_task) - ack_task.add_done_callback(lambda t: self._fragment_ack_tasks.discard(t)) - - if not complete: - LOGGER.debug("Fragment reassembly not complete, waiting for more data") - return - - LOGGER.debug("Reassembled fragmented message, proceeding with handling") - message = reassembled - - # Determine destination address based on message type - if message_type == t.EmberIncomingMessageType.INCOMING_BROADCAST: - dst = zigpy.types.AddrModeAddress( - addr_mode=zigpy.types.AddrMode.Broadcast, - address=zigpy.types.BroadcastAddress.ALL_ROUTERS_AND_COORDINATOR, - ) - elif message_type == t.EmberIncomingMessageType.INCOMING_MULTICAST: - dst = zigpy.types.AddrModeAddress( - addr_mode=zigpy.types.AddrMode.Group, - address=aps_frame.groupId, - ) - elif message_type == t.EmberIncomingMessageType.INCOMING_UNICAST: - dst = None # We don't know the current NWK here - else: - LOGGER.debug("Ignoring message type: %r", message_type) - return - - self.emit( - PacketReceivedEvent.event_type, - PacketReceivedEvent( - packet=zigpy.types.ZigbeePacket( - src=zigpy.types.AddrModeAddress( - addr_mode=zigpy.types.AddrMode.NWK, - address=zigpy.types.NWK(sender), - ), - src_ep=aps_frame.sourceEndpoint, - dst=dst, - dst_ep=aps_frame.destinationEndpoint, - tsn=aps_frame.sequence, - profile_id=aps_frame.profileId, - cluster_id=aps_frame.clusterId, - data=zigpy.types.SerializableBytes(message), - lqi=lqi, - rssi=rssi, - ) - ), - ) - - def _handle_trustCenterJoinHandler( - self, - nwk: t.EmberNodeId, - ieee: t.EUI64, - device_update_status: t.EmberDeviceUpdate, - decision: t.EmberJoinDecision, - parent_nwk: t.EmberNodeId, - ) -> None: - self.emit( - TrustCenterJoinEvent.event_type, - TrustCenterJoinEvent( - nwk=nwk, - ieee=ieee, - device_update_status=device_update_status, - decision=decision, - parent_nwk=parent_nwk, - ), - ) - - def _handle_incomingRouteRecordHandler( - self, - nwk: t.EmberNodeId, - ieee: t.EUI64, - lqi: t.uint8_t, - rssi: t.int8s, - relays: t.LVList[t.EmberNodeId], - ) -> None: - self.emit( - RouteRecordEvent.event_type, - RouteRecordEvent( - nwk=nwk, - ieee=ieee, - lqi=lqi, - rssi=rssi, - relays=relays, - ), - ) - - def _handle_idConflictHandler(self, nwk: t.EmberNodeId) -> None: - self.emit( - IdConflictEvent.event_type, - IdConflictEvent(nwk=nwk), - ) - def __getattr__(self, name: str) -> Callable: if name not in self.COMMANDS: raise AttributeError(f"{name} not found in COMMANDS") diff --git a/bellows/ezsp/v14/__init__.py b/bellows/ezsp/v14/__init__.py index dfba42e0..16dbeec7 100644 --- a/bellows/ezsp/v14/__init__.py +++ b/bellows/ezsp/v14/__init__.py @@ -2,22 +2,17 @@ from __future__ import annotations from collections.abc import AsyncGenerator -import logging import voluptuous as vol from zigpy.exceptions import NetworkNotFormed import zigpy.state -import zigpy.types import bellows.config import bellows.types as t from . import commands, config -from ..protocol import MessageSentEvent from ..v13 import EZSPv13 -LOGGER = logging.getLogger(__name__) - class EZSPv14(EZSPv13): """EZSP Version 14 Protocol version handler.""" @@ -149,50 +144,3 @@ async def send_broadcast( ) return status, sequence - - def _handle_incomingMessageHandler( - self, - message_type: t.EmberIncomingMessageType, - aps_frame: t.EmberApsFrame, - sender: t.EmberNodeId, - eui64: t.EUI64, - binding_index: t.uint8_t, - address_index: t.uint8_t, - lqi: t.uint8_t, - rssi: t.int8s, - timestamp: t.uint32_t, - message: t.LVBytes, - ) -> None: - self._handle_incoming_message( - message_type=message_type, - aps_frame=aps_frame, - sender=sender, - eui64=None, - binding_index=binding_index, - address_index=address_index, - lqi=lqi, - rssi=rssi, - timestamp=None, - message=message, - ) - - def _handle_messageSentHandler( - self, - status: t.sl_Status, - message_type: t.EmberOutgoingMessageType, - destination: t.EmberNodeId, - aps_frame: t.EmberApsFrame, - message_tag: t.uint8_t, - message: t.LVBytes, - ) -> None: - self.emit( - MessageSentEvent.event_type, - MessageSentEvent( - status=status, - message_type=message_type, - destination=destination, - aps_frame=aps_frame, - message_tag=message_tag, - message_contents=message, - ), - ) diff --git a/bellows/ezsp/v4/__init__.py b/bellows/ezsp/v4/__init__.py index e2946818..3b454ecd 100644 --- a/bellows/ezsp/v4/__init__.py +++ b/bellows/ezsp/v4/__init__.py @@ -7,7 +7,6 @@ import voluptuous as vol import zigpy.state -import zigpy.types import bellows.config import bellows.types as t @@ -15,7 +14,6 @@ from . import commands, config from .. import protocol -from ..protocol import MessageSentEvent LOGGER = logging.getLogger(__name__) @@ -237,48 +235,3 @@ async def set_extended_timeout( newId=nwk, newExtendedTimeout=extended_timeout, ) - - def _handle_incomingMessageHandler( - self, - message_type: t.EmberIncomingMessageType, - aps_frame: t.EmberApsFrame, - lqi: t.uint8_t, - rssi: t.int8s, - sender: t.EmberNodeId, - binding_index: t.uint8_t, - address_index: t.uint8_t, - message: t.LVBytes, - ) -> None: - self._handle_incoming_message( - message_type=message_type, - aps_frame=aps_frame, - sender=sender, - eui64=None, - binding_index=binding_index, - address_index=address_index, - lqi=lqi, - rssi=rssi, - timestamp=None, - message=message, - ) - - def _handle_messageSentHandler( - self, - message_type: t.EmberOutgoingMessageType, - destination: t.EmberNodeId, - aps_frame: t.EmberApsFrame, - message_tag: t.uint8_t, - status: t.EmberStatus, - message: t.LVBytes, - ) -> None: - self.emit( - MessageSentEvent.event_type, - MessageSentEvent( - status=t.sl_Status.from_ember_status(status), - message_type=message_type, - destination=destination, - aps_frame=aps_frame, - message_tag=message_tag, - message_contents=message, - ), - ) diff --git a/bellows/zigbee/application.py b/bellows/zigbee/application.py index 51db4585..1bd0271c 100644 --- a/bellows/zigbee/application.py +++ b/bellows/zigbee/application.py @@ -2,7 +2,7 @@ import asyncio from asyncio import timeout as asyncio_timeout -from collections.abc import AsyncGenerator, Callable +from collections.abc import AsyncGenerator from datetime import UTC, datetime import importlib.metadata import logging @@ -39,13 +39,6 @@ StackAlreadyRunning, ) import bellows.ezsp -from bellows.ezsp.protocol import ( - IdConflictEvent, - MessageSentEvent, - PacketReceivedEvent, - RouteRecordEvent, - TrustCenterJoinEvent, -) from bellows.ezsp.xncp import FirmwareFeatures import bellows.multicast import bellows.types as t @@ -104,7 +97,6 @@ def __init__(self, config: dict) -> None: self._multicast = None self._mfg_id_task: asyncio.Task | None = None self._pending_requests = {} - self._protocol_on_remove_callbacks: list[Callable[[], None]] = [] self._watchdog_failures = 0 self._watchdog_feed_counter = 0 @@ -248,8 +240,7 @@ async def start_network(self): for cnt_group in self.state.counters: cnt_group.reset() - self._subscribe_to_protocol_events() - + ezsp.add_callback(self.ezsp_callback_handler) self.controller_event.set() group_membership = {} @@ -611,52 +602,14 @@ async def reset_network_info(self): else: await self._ezsp.leaveNetwork() - def _unsubscribe_from_protocol_events(self) -> None: - """Unsubscribe from protocol events.""" - for callback in self._protocol_on_remove_callbacks: - callback() - - self._protocol_on_remove_callbacks.clear() - async def _reset(self): - self._unsubscribe_from_protocol_events() self._ezsp.stop_ezsp() await self._ezsp.startup_reset() await self._ezsp.write_config(self.config[CONF_EZSP_CONFIG]) - self._subscribe_to_protocol_events() - - def _subscribe_to_protocol_events(self) -> None: - """Subscribe to protocol-level events.""" - self._protocol_on_remove_callbacks.append( - self._ezsp._protocol.on_event( - PacketReceivedEvent.event_type, self._on_packet_received - ) - ) - self._protocol_on_remove_callbacks.append( - self._ezsp._protocol.on_event( - MessageSentEvent.event_type, self._on_message_sent - ) - ) - self._protocol_on_remove_callbacks.append( - self._ezsp._protocol.on_event( - TrustCenterJoinEvent.event_type, self._on_trust_center_join - ) - ) - self._protocol_on_remove_callbacks.append( - self._ezsp._protocol.on_event( - RouteRecordEvent.event_type, self._on_route_record - ) - ) - self._protocol_on_remove_callbacks.append( - self._ezsp._protocol.on_event( - IdConflictEvent.event_type, self._on_id_conflict - ) - ) async def disconnect(self): # TODO: how do you shut down the stack? self.controller_event.clear() - self._unsubscribe_from_protocol_events() if self._ezsp is not None: await self._ezsp.disconnect() self._ezsp = None @@ -666,60 +619,172 @@ async def force_remove(self, dev): # of the device itself. await self._ezsp.removeDevice(dev.nwk, dev.ieee, dev.ieee) - def _on_packet_received(self, message: PacketReceivedEvent) -> None: - """Handle packet_received event from protocol handler.""" - packet = message.packet - - # The protocol handler doesn't know our current NWK address - if packet.dst is None: - packet = packet.replace( - dst=zigpy.types.AddrModeAddress( - addr_mode=zigpy.types.AddrMode.NWK, - address=self.state.node_info.nwk, - ) + def ezsp_callback_handler(self, frame_name, args): + LOGGER.debug("Received %s frame with %s", frame_name, args) + if frame_name == "incomingMessageHandler": + if self._ezsp.ezsp_version >= 14: + ( + message_type, + aps_frame, + nwk, + _eui64, + binding_index, + address_index, + lqi, + rssi, + _timestamp, + message, + ) = args + else: + ( + message_type, + aps_frame, + lqi, + rssi, + nwk, + binding_index, + address_index, + message, + ) = args + + self._handle_frame( + message_type=message_type, + aps_frame=aps_frame, + lqi=lqi, + rssi=rssi, + sender=nwk, + binding_index=binding_index, + address_index=address_index, + message=message, + ) + elif frame_name == "messageSentHandler": + if self._ezsp.ezsp_version >= 14: + ( + status, + message_type, + destination, + aps_frame, + message_tag, + message, + ) = args + else: + ( + message_type, + destination, + aps_frame, + message_tag, + status, + message, + ) = args + status = t.sl_Status.from_ember_status(status) + + self._handle_frame_sent( + message_type=message_type, + destination=destination, + aps_frame=aps_frame, + message_tag=message_tag, + status=status, + message=message, + ) + elif frame_name == "trustCenterJoinHandler": + self._handle_tc_join_handler(*args) + elif frame_name == "incomingRouteRecordHandler": + self.handle_route_record(*args) + elif frame_name == "incomingRouteErrorHandler": + status, nwk = args + status = t.sl_Status.from_ember_status(status) + self.handle_route_error(status, nwk) + elif frame_name == "idConflictHandler": + self._handle_id_conflict(*args) + + def _handle_frame( + self, + message_type: t.EmberIncomingMessageType, + aps_frame: t.EmberApsFrame, + lqi: t.uint8_t, + rssi: t.int8s, + sender: t.EmberNodeId, + binding_index: t.uint8_t, + address_index: t.uint8_t, + message: bytes, + ) -> None: + if message_type == t.EmberIncomingMessageType.INCOMING_BROADCAST: + dst = zigpy.types.AddrModeAddress( + addr_mode=zigpy.types.AddrMode.Broadcast, + address=zigpy.types.BroadcastAddress.ALL_ROUTERS_AND_COORDINATOR, ) - - if packet.dst.addr_mode == zigpy.types.AddrMode.NWK: - self.state.counters[COUNTERS_CTRL][COUNTER_RX_UNICAST].increment() - elif packet.dst.addr_mode == zigpy.types.AddrMode.Broadcast: self.state.counters[COUNTERS_CTRL][COUNTER_RX_BCAST].increment() - elif packet.dst.addr_mode == zigpy.types.AddrMode.Group: + elif message_type == t.EmberIncomingMessageType.INCOMING_MULTICAST: + dst = zigpy.types.AddrModeAddress( + addr_mode=zigpy.types.AddrMode.Group, address=aps_frame.groupId + ) self.state.counters[COUNTERS_CTRL][COUNTER_RX_MCAST].increment() + elif message_type == t.EmberIncomingMessageType.INCOMING_UNICAST: + dst = zigpy.types.AddrModeAddress( + addr_mode=zigpy.types.AddrMode.NWK, address=self.state.node_info.nwk + ) + self.state.counters[COUNTERS_CTRL][COUNTER_RX_UNICAST].increment() + else: + LOGGER.debug("Ignoring message type: %r", message_type) + return - self.packet_received(packet) + self.packet_received( + zigpy.types.ZigbeePacket( + src=zigpy.types.AddrModeAddress( + addr_mode=zigpy.types.AddrMode.NWK, + address=sender, + ), + src_ep=aps_frame.sourceEndpoint, + dst=dst, + dst_ep=aps_frame.destinationEndpoint, + tsn=aps_frame.sequence, + profile_id=aps_frame.profileId, + cluster_id=aps_frame.clusterId, + data=zigpy.types.SerializableBytes(message), + lqi=lqi, + rssi=rssi, + ) + ) - def _on_message_sent(self, event: MessageSentEvent) -> None: - """Handle message_sent event from protocol handler.""" - if event.status == t.sl_Status.OK: + def _handle_frame_sent( + self, + message_type: t.EmberIncomingMessageType, + destination: t.EmberNodeId, + aps_frame: t.EmberApsFrame, + message_tag: int, + status: t.sl_Status, + message: bytes, + ): + if status == t.sl_Status.OK: msg = "success" else: msg = "failure" - if event.message_type in ( + if message_type in ( t.EmberOutgoingMessageType.OUTGOING_BROADCAST, t.EmberOutgoingMessageType.OUTGOING_BROADCAST_WITH_ALIAS, ): cnt_name = f"broadcast_tx_{msg}" - elif event.message_type in ( + elif message_type in ( t.EmberOutgoingMessageType.OUTGOING_MULTICAST, t.EmberOutgoingMessageType.OUTGOING_MULTICAST_WITH_ALIAS, ): cnt_name = f"multicast_tx_{msg}" - elif event.message_type in ( + elif message_type in ( t.EmberOutgoingMessageType.OUTGOING_DIRECT, t.EmberOutgoingMessageType.OUTGOING_VIA_ADDRESS_TABLE, ): cnt_name = f"unicast_tx_{msg}" - elif event.message_type == t.EmberOutgoingMessageType.OUTGOING_VIA_BINDING: + elif message_type == t.EmberOutgoingMessageType.OUTGOING_VIA_BINDING: cnt_name = f"via_binding_tx_{msg}" else: cnt_name = f"unknown_msg_type_{msg}" - pending_tag = (event.destination, event.message_tag) + pending_tag = (destination, message_tag) try: future = self._pending_requests[pending_tag] - future.set_result((event.status, f"message send {msg}")) + future.set_result((status, f"message send {msg}")) self.state.counters[COUNTERS_CTRL][cnt_name].increment() except KeyError: self.state.counters[COUNTERS_CTRL][f"{cnt_name}_unexpected"].increment() @@ -735,31 +800,44 @@ def _on_message_sent(self, event: MessageSentEvent) -> None: exc, ) - def _on_trust_center_join(self, event: TrustCenterJoinEvent) -> None: - """Handle trust_center_join event from protocol handler.""" - if event.device_update_status == t.EmberDeviceUpdate.DEVICE_LEFT: - self.handle_leave(event.nwk, event.ieee) + async def _handle_no_such_device(self, sender: int) -> None: + """Try to match unknown device by its EUI64 address.""" + status, ieee = await self._ezsp.lookupEui64ByNodeId(nodeId=sender) + status = t.sl_Status.from_ember_status(status) + + if status == t.sl_Status.OK: + LOGGER.debug("Found %s ieee for %s sender", ieee, sender) + self.handle_join(sender, ieee, 0) return + LOGGER.debug("Couldn't look up ieee for %s", sender) - if ( - event.device_update_status - == t.EmberDeviceUpdate.STANDARD_SECURITY_UNSECURED_JOIN - ): - self.create_task( - self.cleanup_tc_link_key(event.ieee), "cleanup_tc_link_key" - ) + def _handle_tc_join_handler( + self, + nwk: t.EmberNodeId, + ieee: t.EUI64, + device_update_status: t.EmberDeviceUpdate, + decision: t.EmberJoinDecision, + parent_nwk: t.EmberNodeId, + ) -> None: + """Trust Center Join handler.""" + if device_update_status == t.EmberDeviceUpdate.DEVICE_LEFT: + self.handle_leave(nwk, ieee) + return + + if device_update_status == t.EmberDeviceUpdate.STANDARD_SECURITY_UNSECURED_JOIN: + self.create_task(self.cleanup_tc_link_key(ieee), "cleanup_tc_link_key") - if event.decision == t.EmberJoinDecision.DENY_JOIN: + if decision == t.EmberJoinDecision.DENY_JOIN: # no point in handling the join if it was denied return - mfg_id = IEEE_PREFIX_MFG_ID.get(str(event.ieee)[:8].upper()) + mfg_id = IEEE_PREFIX_MFG_ID.get(str(ieee)[:8].upper()) if mfg_id is not None: if self._mfg_id_task and not self._mfg_id_task.done(): self._mfg_id_task.cancel() self._mfg_id_task = asyncio.create_task(self._reset_mfg_id(mfg_id)) - self.handle_join(event.nwk, event.ieee, event.parent_nwk) + self.handle_join(nwk, ieee, parent_nwk) async def _reset_mfg_id(self, mfg_id: int) -> None: """Resets manufacturer id if was temporary overridden by a joining device.""" @@ -1053,21 +1131,20 @@ async def permit_with_link_key( return await super().permit(time_s) - def _on_id_conflict(self, event: IdConflictEvent) -> None: - """Handle id_conflict event from protocol handler.""" - LOGGER.warning("NWK conflict is reported for 0x%04x", event.nwk) + def _handle_id_conflict(self, nwk: t.EmberNodeId) -> None: + LOGGER.warning("NWK conflict is reported for 0x%04x", nwk) self.state.counters[COUNTERS_CTRL][COUNTER_NWK_CONFLICTS].increment() for device in self.devices.values(): - if device.nwk != event.nwk: + if device.nwk != nwk: continue LOGGER.warning( "Found %s device for 0x%04x NWK conflict: %s %s", device.ieee, - event.nwk, + nwk, device.manufacturer, device.model, ) - self.handle_leave(event.nwk, device.ieee) + self.handle_leave(nwk, device.ieee) async def _watchdog_loop(self): self._watchdog_failures = 0 @@ -1128,10 +1205,18 @@ async def _get_free_buffers(self) -> int | None: LOGGER.debug("Free buffers status %s, value: %s", status, buffers) return buffers - def _on_route_record(self, event: RouteRecordEvent) -> None: - """Handle route_record event from protocol handler.""" + def handle_route_record( + self, + nwk: t.EmberNodeId, + ieee: t.EUI64, + lqi: t.uint8_t, + rssi: t.int8s, + relays: t.LVList[t.EmberNodeId], + ) -> None: LOGGER.debug( - "Processing route record request: %s", - (event.nwk, event.ieee, event.lqi, event.rssi, event.relays), + "Processing route record request: %s", (nwk, ieee, lqi, rssi, relays) ) - self.handle_relays(nwk=event.nwk, relays=event.relays) + self.handle_relays(nwk=nwk, relays=relays) + + def handle_route_error(self, status: t.sl_Status, nwk: t.EmberNodeId) -> None: + LOGGER.debug("Processing route error: status=%s, nwk=%s", status, nwk) diff --git a/tests/test_application.py b/tests/test_application.py index f1fa27da..018dd632 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -16,13 +16,6 @@ import bellows.config as config from bellows.exception import ControllerError, EzspError, InvalidCommandError import bellows.ezsp as ezsp -from bellows.ezsp.protocol import ( - IdConflictEvent, - MessageSentEvent, - PacketReceivedEvent, - RouteRecordEvent, - TrustCenterJoinEvent, -) from bellows.ezsp.v9.commands import GetTokenDataRsp from bellows.ezsp.xncp import ( FirmwareFeatures, @@ -78,9 +71,6 @@ def inner(config, send_timeout: float = 0.05, **kwargs): app.handle_message = MagicMock() app.packet_received = MagicMock() - # Set up event subscriptions normally done in start_network() - app._subscribe_to_protocol_events() - return app return inner @@ -427,18 +417,215 @@ async def test_startup_no_board_info(app, ieee, caplog): assert "EZSP Radio does not support getMfgToken command" in caplog.text +@pytest.fixture +def aps_frame(): + return t.EmberApsFrame( + profileId=0x1234, + clusterId=0x5678, + sourceEndpoint=0x9A, + destinationEndpoint=0xBC, + options=t.EmberApsOption.APS_OPTION_NONE, + groupId=0x0000, + sequence=0xDE, + ) + + +def _handle_incoming_aps_frame(app, aps_frame, type): + app.ezsp_callback_handler( + "incomingMessageHandler", + list( + dict( + type=type, + apsFrame=aps_frame, + lastHopLqi=123, + lastHopRssi=-45, + sender=0xABCD, + bindingIndex=56, + addressIndex=78, + message=b"test message", + ).values() + ), + ) + + +def test_frame_handler_unicast(app, aps_frame): + _handle_incoming_aps_frame( + app, aps_frame, type=t.EmberIncomingMessageType.INCOMING_UNICAST + ) + assert app.packet_received.call_count == 1 + + packet = app.packet_received.mock_calls[0].args[0] + assert packet.profile_id == 0x1234 + assert packet.cluster_id == 0x5678 + assert packet.src_ep == 0x9A + assert packet.dst_ep == 0xBC + assert packet.tsn == 0xDE + assert packet.src.addr_mode == zigpy_t.AddrMode.NWK + assert packet.src.address == 0xABCD + assert packet.dst.addr_mode == zigpy_t.AddrMode.NWK + assert packet.dst.address == app.state.node_info.nwk + assert packet.data.serialize() == b"test message" + assert packet.lqi == 123 + assert packet.rssi == -45 + + assert ( + app.state.counters[bellows.zigbee.application.COUNTERS_CTRL][ + bellows.zigbee.application.COUNTER_RX_UNICAST + ] + == 1 + ) + + +def test_frame_handler_broadcast(app, aps_frame): + _handle_incoming_aps_frame( + app, aps_frame, type=t.EmberIncomingMessageType.INCOMING_BROADCAST + ) + assert app.packet_received.call_count == 1 + + packet = app.packet_received.mock_calls[0].args[0] + assert packet.profile_id == 0x1234 + assert packet.cluster_id == 0x5678 + assert packet.src_ep == 0x9A + assert packet.dst_ep == 0xBC + assert packet.tsn == 0xDE + assert packet.src.addr_mode == zigpy_t.AddrMode.NWK + assert packet.src.address == 0xABCD + assert packet.dst.addr_mode == zigpy_t.AddrMode.Broadcast + assert packet.dst.address == zigpy_t.BroadcastAddress.ALL_ROUTERS_AND_COORDINATOR + assert packet.data.serialize() == b"test message" + assert packet.lqi == 123 + assert packet.rssi == -45 + + assert ( + app.state.counters[bellows.zigbee.application.COUNTERS_CTRL][ + bellows.zigbee.application.COUNTER_RX_BCAST + ] + == 1 + ) + + +def test_frame_handler_multicast(app, aps_frame): + aps_frame.groupId = 0xEF12 + _handle_incoming_aps_frame( + app, aps_frame, type=t.EmberIncomingMessageType.INCOMING_MULTICAST + ) + + assert app.packet_received.call_count == 1 + + packet = app.packet_received.mock_calls[0].args[0] + assert packet.profile_id == 0x1234 + assert packet.cluster_id == 0x5678 + assert packet.src_ep == 0x9A + assert packet.dst_ep == 0xBC + assert packet.tsn == 0xDE + assert packet.src.addr_mode == zigpy_t.AddrMode.NWK + assert packet.src.address == 0xABCD + assert packet.dst.addr_mode == zigpy_t.AddrMode.Group + assert packet.dst.address == 0xEF12 + assert packet.data.serialize() == b"test message" + assert packet.lqi == 123 + assert packet.rssi == -45 + + assert ( + app.state.counters[bellows.zigbee.application.COUNTERS_CTRL][ + bellows.zigbee.application.COUNTER_RX_MCAST + ] + == 1 + ) + + +def test_frame_handler_ignored(app, aps_frame): + _handle_incoming_aps_frame( + app, aps_frame, type=t.EmberIncomingMessageType.INCOMING_BROADCAST_LOOPBACK + ) + assert app.packet_received.call_count == 0 + + +@pytest.mark.parametrize( + "msg_type", + ( + t.EmberIncomingMessageType.INCOMING_BROADCAST, + t.EmberIncomingMessageType.INCOMING_MULTICAST, + t.EmberIncomingMessageType.INCOMING_UNICAST, + 0xFF, + ), +) +async def test_send_failure(app, aps, ieee, msg_type): + fut = app._pending_requests[(0xBEED, 254)] = asyncio.Future() + app.ezsp_callback_handler( + "messageSentHandler", [msg_type, 0xBEED, aps, 254, t.EmberStatus.SUCCESS, b""] + ) + assert fut.result() == (t.sl_Status.OK, "message send success") + + +async def test_dup_send_failure(app, aps, ieee): + fut = app._pending_requests[(0xBEED, 254)] = asyncio.Future() + fut.set_result("Already set") + + app.ezsp_callback_handler( + "messageSentHandler", + [ + t.EmberIncomingMessageType.INCOMING_UNICAST, + 0xBEED, + aps, + 254, + sentinel.status, + b"", + ], + ) + + +def test_send_failure_unexpected(app, aps, ieee): + app.ezsp_callback_handler( + "messageSentHandler", + [ + t.EmberIncomingMessageType.INCOMING_BROADCAST_LOOPBACK, + 0xBEED, + aps, + 257, + 1, + b"", + ], + ) + + +async def test_send_success(app, aps, ieee): + fut = app._pending_requests[(0xBEED, 253)] = asyncio.Future() + app.ezsp_callback_handler( + "messageSentHandler", + [ + t.EmberIncomingMessageType.INCOMING_MULTICAST_LOOPBACK, + 0xBEED, + aps, + 253, + t.EmberStatus.SUCCESS, + b"", + ], + ) + + assert fut.result() == (t.sl_Status.OK, "message send success") + + +def test_unexpected_send_success(app, aps, ieee): + app.ezsp_callback_handler( + "messageSentHandler", + [t.EmberIncomingMessageType.INCOMING_MULTICAST, 0xBEED, aps, 253, 0, b""], + ) + + async def test_join_handler(app, ieee): # Calls device.initialize, leaks a task app.handle_join = MagicMock() app.cleanup_tc_link_key = AsyncMock() - app._on_trust_center_join( - TrustCenterJoinEvent( - nwk=1, - ieee=ieee, - device_update_status=t.EmberDeviceUpdate.STANDARD_SECURITY_UNSECURED_JOIN, - decision=t.EmberJoinDecision.NO_ACTION, - parent_nwk=sentinel.parent, - ) + app.ezsp_callback_handler( + "trustCenterJoinHandler", + [ + 1, + ieee, + t.EmberDeviceUpdate.STANDARD_SECURITY_UNSECURED_JOIN, + t.EmberJoinDecision.NO_ACTION, + sentinel.parent, + ], ) await asyncio.sleep(0) assert ieee not in app.devices @@ -452,14 +639,15 @@ async def test_join_handler(app, ieee): # cleanup TCLK, but no join handling app.handle_join.reset_mock() app.cleanup_tc_link_key.reset_mock() - app._on_trust_center_join( - TrustCenterJoinEvent( - nwk=1, - ieee=ieee, - device_update_status=t.EmberDeviceUpdate.STANDARD_SECURITY_UNSECURED_JOIN, - decision=t.EmberJoinDecision.DENY_JOIN, - parent_nwk=sentinel.parent, - ) + app.ezsp_callback_handler( + "trustCenterJoinHandler", + [ + 1, + ieee, + t.EmberDeviceUpdate.STANDARD_SECURITY_UNSECURED_JOIN, + t.EmberJoinDecision.DENY_JOIN, + sentinel.parent, + ], ) await asyncio.sleep(0) assert app.cleanup_tc_link_key.await_count == 1 @@ -470,14 +658,8 @@ async def test_join_handler(app, ieee): def test_leave_handler(app, ieee): app.handle_join = MagicMock() app.devices[ieee] = MagicMock() - app._on_trust_center_join( - TrustCenterJoinEvent( - nwk=1, - ieee=ieee, - device_update_status=t.EmberDeviceUpdate.DEVICE_LEFT, - decision=t.EmberJoinDecision.NO_ACTION, - parent_nwk=t.EmberNodeId(0x0000), - ) + app.ezsp_callback_handler( + "trustCenterJoinHandler", [1, ieee, t.EmberDeviceUpdate.DEVICE_LEFT, None, None] ) assert ieee in app.devices assert app.handle_join.call_count == 0 @@ -557,7 +739,7 @@ async def test_request_concurrency_duplicate_failure( ) -> None: def send_unicast(aps_frame, data, message_tag, nwk): asyncio.get_running_loop().call_soon( - app._ezsp._protocol.handle_parsed_callback, + app.ezsp_callback_handler, "messageSentHandler", list( dict( @@ -609,7 +791,7 @@ async def _test_send_packet_unicast( def send_unicast(*args, **kwargs): asyncio.get_running_loop().call_later( 0.01, - app._ezsp._protocol.handle_parsed_callback, + app.ezsp_callback_handler, "messageSentHandler", list( dict( @@ -741,14 +923,15 @@ async def test_send_packet_unicast_extended_timeout_with_acks(app, ieee, packet) asyncio.get_running_loop().call_later( 0.1, - app._on_route_record, - RouteRecordEvent( - nwk=packet.dst.address, - ieee=ieee, - lqi=123, - rssi=-60, - relays=[0x1234], - ), + app.ezsp_callback_handler, + "incomingRouteRecordHandler", + { + "source": packet.dst.address, + "sourceEui": ieee, + "lastHopLqi": 123, + "lastHopRssi": -60, + "relayList": [0x1234], + }.values(), ) await _test_send_packet_unicast( @@ -770,14 +953,15 @@ async def test_send_packet_unicast_extended_timeout_without_acks(app, ieee, pack asyncio.get_running_loop().call_later( 0.1, - app._on_route_record, - RouteRecordEvent( - nwk=packet.dst.address, - ieee=ieee, - lqi=123, - rssi=-60, - relays=[0x1234], - ), + app.ezsp_callback_handler, + "incomingRouteRecordHandler", + { + "source": packet.dst.address, + "sourceEui": ieee, + "lastHopLqi": 123, + "lastHopRssi": -60, + "relayList": [0x1234], + }.values(), ) await _test_send_packet_unicast( @@ -861,7 +1045,7 @@ async def send_message_sent_reply( await asyncio.sleep(0.01) - app._ezsp._protocol.handle_parsed_callback( + app.ezsp_callback_handler( "messageSentHandler", list( dict( @@ -918,7 +1102,7 @@ async def test_send_packet_broadcast(app, packet): app.get_sequence = MagicMock(return_value=sentinel.msg_tag) asyncio.get_running_loop().call_soon( - app._ezsp._protocol.handle_parsed_callback, + app.ezsp_callback_handler, "messageSentHandler", list( dict( @@ -964,7 +1148,7 @@ async def test_send_packet_broadcast_ignored_delivery_failure(app, packet): app.get_sequence = MagicMock(return_value=sentinel.msg_tag) asyncio.get_running_loop().call_soon( - app._ezsp._protocol.handle_parsed_callback, + app.ezsp_callback_handler, "messageSentHandler", list( dict( @@ -1017,7 +1201,7 @@ async def test_send_packet_multicast(app, packet): app.get_sequence = MagicMock(return_value=sentinel.msg_tag) asyncio.get_running_loop().call_soon( - app._ezsp._protocol.handle_parsed_callback, + app.ezsp_callback_handler, "messageSentHandler", list( dict( @@ -1386,18 +1570,20 @@ def test_coordinator_model_manuf(coordinator): def test_handle_route_record(app): """Test route record handling for an existing device.""" app.handle_relays = MagicMock(spec_set=app.handle_relays) - app._on_route_record( - RouteRecordEvent( - nwk=sentinel.nwk, - ieee=sentinel.ieee, - lqi=sentinel.lqi, - rssi=sentinel.rssi, - relays=sentinel.relays, - ) + app.ezsp_callback_handler( + "incomingRouteRecordHandler", + [sentinel.nwk, sentinel.ieee, sentinel.lqi, sentinel.rssi, sentinel.relays], ) - assert app.handle_relays.mock_calls == [ - call(nwk=sentinel.nwk, relays=sentinel.relays) - ] + app.handle_relays.assert_called_once_with(nwk=sentinel.nwk, relays=sentinel.relays) + + +def test_handle_route_error(app): + """Test route error handler.""" + app.handle_relays = MagicMock(spec_set=app.handle_relays) + app.ezsp_callback_handler( + "incomingRouteErrorHandler", [sentinel.status, sentinel.nwk] + ) + app.handle_relays.assert_not_called() def test_handle_id_conflict(app, ieee): @@ -1406,14 +1592,43 @@ def test_handle_id_conflict(app, ieee): app.add_device(ieee, nwk) app.handle_leave = MagicMock() - app._on_id_conflict(IdConflictEvent(nwk=nwk + 1)) + app.ezsp_callback_handler("idConflictHandler", [nwk + 1]) assert app.handle_leave.call_count == 0 - app._on_id_conflict(IdConflictEvent(nwk=nwk)) + app.ezsp_callback_handler("idConflictHandler", [nwk]) assert app.handle_leave.call_count == 1 assert app.handle_leave.call_args[0][0] == nwk +async def test_handle_no_such_device(app, ieee): + """Test handling of an unknown device IEEE lookup.""" + + app._ezsp.lookupEui64ByNodeId = AsyncMock() + + p1 = patch.object( + app._ezsp, + "lookupEui64ByNodeId", + AsyncMock(return_value=(t.EmberStatus.ERR_FATAL, ieee)), + ) + p2 = patch.object(app, "handle_join") + with p1 as lookup_mock, p2 as handle_join_mock: + await app._handle_no_such_device(sentinel.nwk) + assert lookup_mock.mock_calls == [call(nodeId=sentinel.nwk)] + assert handle_join_mock.call_count == 0 + + p1 = patch.object( + app._ezsp, + "lookupEui64ByNodeId", + AsyncMock(return_value=(t.EmberStatus.SUCCESS, sentinel.ieee)), + ) + with p1 as lookup_mock, p2 as handle_join_mock: + await app._handle_no_such_device(sentinel.nwk) + assert lookup_mock.mock_calls == [call(nodeId=sentinel.nwk)] + assert handle_join_mock.call_count == 1 + assert handle_join_mock.call_args[0][0] == sentinel.nwk + assert handle_join_mock.call_args[0][1] == sentinel.ieee + + async def test_cleanup_tc_link_key(app): """Test cleaning up tc link key.""" ezsp = app._ezsp @@ -1458,24 +1673,26 @@ async def test_set_mfg_id(ieee, expected_mfg_id, app): app.handle_join = MagicMock() app.cleanup_tc_link_key = AsyncMock() - app._on_trust_center_join( - TrustCenterJoinEvent( - nwk=1, - ieee=t.EUI64.convert(ieee), - device_update_status=t.EmberDeviceUpdate.STANDARD_SECURITY_UNSECURED_JOIN, - decision=t.EmberJoinDecision.NO_ACTION, - parent_nwk=sentinel.parent, - ) + app.ezsp_callback_handler( + "trustCenterJoinHandler", + [ + 1, + t.EUI64.convert(ieee), + t.EmberDeviceUpdate.STANDARD_SECURITY_UNSECURED_JOIN, + t.EmberJoinDecision.NO_ACTION, + sentinel.parent, + ], ) # preempt - app._on_trust_center_join( - TrustCenterJoinEvent( - nwk=1, - ieee=t.EUI64.convert(ieee), - device_update_status=t.EmberDeviceUpdate.STANDARD_SECURITY_UNSECURED_JOIN, - decision=t.EmberJoinDecision.NO_ACTION, - parent_nwk=sentinel.parent, - ) + app.ezsp_callback_handler( + "trustCenterJoinHandler", + [ + 1, + t.EUI64.convert(ieee), + t.EmberDeviceUpdate.STANDARD_SECURITY_UNSECURED_JOIN, + t.EmberJoinDecision.NO_ACTION, + sentinel.parent, + ], ) await asyncio.sleep(0.20) if expected_mfg_id is not None: @@ -2466,166 +2683,3 @@ async def test_set_tx_power(app: ControllerApplication) -> None: assert result == 12.0 assert app._ezsp.setRadioPower.mock_calls == [call(power=12)] assert mock_update.mock_calls == [call(app._ezsp, tx_power=12)] - - -async def test_reset_resubscribes_events(app: ControllerApplication) -> None: - """Test that _reset unsubscribes, resets, and resubscribes to protocol events.""" - app._ezsp.stop_ezsp = MagicMock() - app._ezsp.startup_reset = AsyncMock() - app._ezsp.write_config = AsyncMock() - - # Add a dummy callback to verify unsubscribe is called - unsubscribe_mock = MagicMock() - app._protocol_on_remove_callbacks.append(unsubscribe_mock) - - await app._reset() - - # Verify unsubscribe was called - assert unsubscribe_mock.mock_calls == [call()] - - # Verify EZSP reset sequence - assert len(app._ezsp.stop_ezsp.mock_calls) == 1 - assert len(app._ezsp.startup_reset.mock_calls) == 1 - assert len(app._ezsp.write_config.mock_calls) == 1 - - # Verify we resubscribed (callbacks list should have 5 entries now) - assert len(app._protocol_on_remove_callbacks) == 5 - - -def test_on_packet_received_unicast(app: ControllerApplication) -> None: - """Test _on_packet_received with unicast message (dst=None gets replaced).""" - app.state.node_info.nwk = zigpy_t.NWK(0x0000) - - packet_received_mock = MagicMock() - app.packet_received = packet_received_mock - - # Unicast packets have dst=None, protocol handler doesn't know our NWK - event = PacketReceivedEvent( - packet=zigpy_t.ZigbeePacket( - src=zigpy_t.AddrModeAddress( - addr_mode=zigpy_t.AddrMode.NWK, - address=zigpy_t.NWK(0x1234), - ), - src_ep=1, - dst=None, # Will be replaced with our NWK - dst_ep=2, - tsn=0x42, - profile_id=0x0104, - cluster_id=0x0006, - data=zigpy_t.SerializableBytes(b"test"), - lqi=200, - rssi=-40, - ) - ) - - app._on_packet_received(event) - - # Verify packet_received was called with dst replaced - assert packet_received_mock.mock_calls == [ - call( - zigpy_t.ZigbeePacket( - src=zigpy_t.AddrModeAddress( - addr_mode=zigpy_t.AddrMode.NWK, - address=zigpy_t.NWK(0x1234), - ), - src_ep=1, - dst=zigpy_t.AddrModeAddress( - addr_mode=zigpy_t.AddrMode.NWK, - address=zigpy_t.NWK(0x0000), - ), - dst_ep=2, - tsn=0x42, - profile_id=0x0104, - cluster_id=0x0006, - data=zigpy_t.SerializableBytes(b"test"), - lqi=200, - rssi=-40, - ) - ) - ] - - -def test_on_packet_received_broadcast(app: ControllerApplication) -> None: - """Test _on_packet_received with broadcast message.""" - packet_received_mock = MagicMock() - app.packet_received = packet_received_mock - - event = PacketReceivedEvent( - packet=zigpy_t.ZigbeePacket( - src=zigpy_t.AddrModeAddress( - addr_mode=zigpy_t.AddrMode.NWK, - address=zigpy_t.NWK(0x1234), - ), - src_ep=1, - dst=zigpy_t.AddrModeAddress( - addr_mode=zigpy_t.AddrMode.Broadcast, - address=zigpy_t.BroadcastAddress.ALL_ROUTERS_AND_COORDINATOR, - ), - dst_ep=2, - tsn=0x42, - profile_id=0x0104, - cluster_id=0x0006, - data=zigpy_t.SerializableBytes(b"broadcast"), - lqi=200, - rssi=-40, - ) - ) - - app._on_packet_received(event) - - # Verify packet_received was called with the same packet (dst already set) - assert packet_received_mock.mock_calls == [call(event.packet)] - - -def test_on_packet_received_multicast(app: ControllerApplication) -> None: - """Test _on_packet_received with multicast message.""" - packet_received_mock = MagicMock() - app.packet_received = packet_received_mock - - event = PacketReceivedEvent( - packet=zigpy_t.ZigbeePacket( - src=zigpy_t.AddrModeAddress( - addr_mode=zigpy_t.AddrMode.NWK, - address=zigpy_t.NWK(0x1234), - ), - src_ep=1, - dst=zigpy_t.AddrModeAddress( - addr_mode=zigpy_t.AddrMode.Group, - address=0x5678, - ), - dst_ep=2, - tsn=0x42, - profile_id=0x0104, - cluster_id=0x0006, - data=zigpy_t.SerializableBytes(b"multicast"), - lqi=200, - rssi=-40, - ) - ) - - app._on_packet_received(event) - - # Verify packet_received was called with the same packet (dst already set) - assert packet_received_mock.mock_calls == [call(event.packet)] - - -async def test_on_message_sent_via_binding(app: ControllerApplication) -> None: - """Test _on_message_sent with OUTGOING_VIA_BINDING message type.""" - # Create a pending request future - future = asyncio.get_running_loop().create_future() - app._pending_requests[(0x1234, 0x42)] = future - - event = MessageSentEvent( - status=t.sl_Status.OK, - message_type=t.EmberOutgoingMessageType.OUTGOING_VIA_BINDING, - destination=0x1234, - aps_frame=t.EmberApsFrame(), - message_tag=0x42, - message_contents=b"test", - ) - - app._on_message_sent(event) - - # Verify the future was resolved - assert future.done() - assert future.result() == (t.sl_Status.OK, "message send success") diff --git a/tests/test_ezsp_protocol.py b/tests/test_ezsp_protocol.py index 8550c2c5..3906eb5e 100644 --- a/tests/test_ezsp_protocol.py +++ b/tests/test_ezsp_protocol.py @@ -3,15 +3,8 @@ from unittest.mock import AsyncMock, MagicMock, call, patch import pytest -import zigpy.types from bellows.ezsp import EZSP -from bellows.ezsp.protocol import ( - IdConflictEvent, - PacketReceivedEvent, - RouteRecordEvent, - TrustCenterJoinEvent, -) import bellows.ezsp.v4 import bellows.ezsp.v9 from bellows.ezsp.v9.commands import GetTokenDataRsp @@ -213,9 +206,9 @@ async def test_incoming_fragmented_message_incomplete(prot_hndl, caplog): len(prot_hndl._fragment_ack_tasks) == 0 ), "Done callback should have removed task" - assert len(prot_hndl._handle_callback.mock_calls) == 1 - assert "Fragment reassembly not complete, waiting for more data" in caplog.text - assert mock_ack.mock_calls == [call(sender, aps_frame, 2, 0)] + prot_hndl._handle_callback.assert_not_called() + assert "Fragment reassembly not complete. waiting for more data." in caplog.text + mock_ack.assert_called_once_with(sender, aps_frame, 2, 0) async def test_incoming_fragmented_message_complete(prot_hndl, caplog): @@ -228,34 +221,27 @@ async def test_incoming_fragmented_message_complete(prot_hndl, caplog): b"\x90\x01\x45\x00\x04\x01\x01\xff\x02\x02\x40\x81\x01\x02\xee\xff\xf8\x6f\x1d\xff\xff\x07" + b"message" ) # fragment index 1 + sender = 0x1D6F aps_frame_1 = t.EmberApsFrame( profileId=260, - clusterId=0xFF01, + clusterId=65281, sourceEndpoint=2, destinationEndpoint=2, - options=( - t.EmberApsOption.APS_OPTION_RETRY - | t.EmberApsOption.APS_OPTION_ENABLE_ROUTE_DISCOVERY - | t.EmberApsOption.APS_OPTION_FRAGMENT - ), - groupId=0x0200, # fragment_count=2, fragment_index=0 + options=33088, # Includes APS_OPTION_FRAGMENT + groupId=512, # fragment_count=2, fragment_index=0 sequence=238, ) - aps_frame_2 = t.EmberApsFrame( profileId=260, - clusterId=0xFF01, + clusterId=65281, sourceEndpoint=2, destinationEndpoint=2, - options=( - t.EmberApsOption.APS_OPTION_RETRY - | t.EmberApsOption.APS_OPTION_ENABLE_ROUTE_DISCOVERY - | t.EmberApsOption.APS_OPTION_FRAGMENT - ), - groupId=0x0201, # fragment_count=2, fragment_index=1 + options=33088, + groupId=513, # fragment_count=2, fragment_index=1 sequence=238, ) + reassembled = b"complete message" with patch.object(prot_hndl, "_send_fragment_ack", new=AsyncMock()) as mock_ack: mock_ack.return_value = None @@ -264,240 +250,43 @@ async def test_incoming_fragmented_message_complete(prot_hndl, caplog): # Packet 1 prot_hndl(packet1) assert len(prot_hndl._fragment_ack_tasks) == 1 - await asyncio.gather( - *prot_hndl._fragment_ack_tasks - ) # Ensure task completes and triggers callback - assert len(prot_hndl._fragment_ack_tasks) == 0 + ack_task = next(iter(prot_hndl._fragment_ack_tasks)) + await asyncio.gather(ack_task) # Ensure task completes and triggers callback + assert ( + len(prot_hndl._fragment_ack_tasks) == 0 + ), "Done callback should have removed task" + + prot_hndl._handle_callback.assert_not_called() + assert ( + "Reassembled fragmented message. Proceeding with normal handling." + not in caplog.text + ) + mock_ack.assert_called_with(sender, aps_frame_1, 2, 0) # Packet 2 prot_hndl(packet2) assert len(prot_hndl._fragment_ack_tasks) == 1 - await asyncio.gather( - *prot_hndl._fragment_ack_tasks - ) # Ensure task completes and triggers callback - assert len(prot_hndl._fragment_ack_tasks) == 0 - - assert "Reassembled fragmented message, proceeding with handling" in caplog.text - assert mock_ack.mock_calls == [ - call(0x1D6F, aps_frame_1, 2, 0), - call(0x1D6F, aps_frame_2, 2, 1), - ] - - -def test_incoming_message_broadcast(prot_hndl) -> None: - """Test handling of incoming broadcast message.""" - handler = MagicMock() - prot_hndl.on_event(PacketReceivedEvent.event_type, handler) - - aps_frame = t.EmberApsFrame( - profileId=0x0104, - clusterId=0x0006, - sourceEndpoint=1, - destinationEndpoint=2, - options=t.EmberApsOption.APS_OPTION_NONE, - groupId=0x0000, - sequence=0x42, - ) - - # v4 field order: type, apsFrame, lqi, rssi, sender, bindingIndex, addressIndex, message - prot_hndl.handle_parsed_callback( - "incomingMessageHandler", - [ - t.EmberIncomingMessageType.INCOMING_BROADCAST, - aps_frame, - 200, # lqi - -40, # rssi - t.EmberNodeId(0x1234), # sender - 0, # binding_index - 0, # address_index - b"broadcast message", - ], - ) - - assert handler.mock_calls == [ - call( - PacketReceivedEvent( - packet=zigpy.types.ZigbeePacket( - src=zigpy.types.AddrModeAddress( - addr_mode=zigpy.types.AddrMode.NWK, - address=zigpy.types.NWK(0x1234), - ), - src_ep=1, - dst=zigpy.types.AddrModeAddress( - addr_mode=zigpy.types.AddrMode.Broadcast, - address=zigpy.types.BroadcastAddress.ALL_ROUTERS_AND_COORDINATOR, - ), - dst_ep=2, - tsn=0x42, - profile_id=0x0104, - cluster_id=0x0006, - data=zigpy.types.SerializableBytes(b"broadcast message"), - lqi=200, - rssi=-40, - ) - ) - ) - ] - - -def test_incoming_message_multicast(prot_hndl) -> None: - """Test handling of incoming multicast message.""" - handler = MagicMock() - prot_hndl.on_event(PacketReceivedEvent.event_type, handler) - - aps_frame = t.EmberApsFrame( - profileId=0x0104, - clusterId=0x0006, - sourceEndpoint=1, - destinationEndpoint=2, - options=t.EmberApsOption.APS_OPTION_NONE, - groupId=0x5678, - sequence=0x42, - ) - - prot_hndl.handle_parsed_callback( - "incomingMessageHandler", - [ - t.EmberIncomingMessageType.INCOMING_MULTICAST, - aps_frame, - 200, - -40, - t.EmberNodeId(0x1234), - 0, - 0, - b"multicast message", - ], - ) - - assert handler.mock_calls == [ - call( - PacketReceivedEvent( - packet=zigpy.types.ZigbeePacket( - src=zigpy.types.AddrModeAddress( - addr_mode=zigpy.types.AddrMode.NWK, - address=zigpy.types.NWK(0x1234), - ), - src_ep=1, - dst=zigpy.types.AddrModeAddress( - addr_mode=zigpy.types.AddrMode.Group, - address=0x5678, - ), - dst_ep=2, - tsn=0x42, - profile_id=0x0104, - cluster_id=0x0006, - data=zigpy.types.SerializableBytes(b"multicast message"), - lqi=200, - rssi=-40, - ) - ) - ) - ] - - -def test_incoming_message_ignored_type(prot_hndl, caplog) -> None: - """Test that unknown message types are ignored.""" - handler = MagicMock() - prot_hndl.on_event(PacketReceivedEvent.event_type, handler) - - aps_frame = t.EmberApsFrame( - profileId=0x0104, - clusterId=0x0006, - sourceEndpoint=1, - destinationEndpoint=2, - options=t.EmberApsOption.APS_OPTION_NONE, - groupId=0x0000, - sequence=0x42, - ) - - caplog.set_level(logging.DEBUG) - prot_hndl.handle_parsed_callback( - "incomingMessageHandler", - [ - t.EmberIncomingMessageType.INCOMING_MANY_TO_ONE_ROUTE_REQUEST, - aps_frame, - 200, - -40, - t.EmberNodeId(0x1234), - 0, - 0, - b"ignored message", - ], - ) - - # No event should be emitted for ignored message types - assert len(handler.mock_calls) == 0 - assert "Ignoring message type" in caplog.text - - -def test_trust_center_join_handler(prot_hndl) -> None: - """Test trustCenterJoinHandler callback.""" - handler = MagicMock() - prot_hndl.on_event(TrustCenterJoinEvent.event_type, handler) - - ieee = t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11") - prot_hndl.handle_parsed_callback( - "trustCenterJoinHandler", - { - "newNodeId": t.EmberNodeId(0x1234), - "newNodeEui64": ieee, - "status": t.EmberDeviceUpdate.STANDARD_SECURITY_UNSECURED_JOIN, - "policyDecision": t.EmberJoinDecision.NO_ACTION, - "parentOfNewNodeId": t.EmberNodeId(0x0000), - }.values(), - ) + ack_task = next(iter(prot_hndl._fragment_ack_tasks)) + await asyncio.gather(ack_task) # Ensure task completes and triggers callback + assert ( + len(prot_hndl._fragment_ack_tasks) == 0 + ), "Done callback should have removed task" - assert handler.mock_calls == [ - call( - TrustCenterJoinEvent( - nwk=t.EmberNodeId(0x1234), - ieee=ieee, - device_update_status=t.EmberDeviceUpdate.STANDARD_SECURITY_UNSECURED_JOIN, - decision=t.EmberJoinDecision.NO_ACTION, - parent_nwk=t.EmberNodeId(0x0000), - ) + prot_hndl._handle_callback.assert_called_once_with( + "incomingMessageHandler", + [ + t.EmberIncomingMessageType.INCOMING_UNICAST, # 0x00 + aps_frame_2, # Parsed APS frame + 255, # lastHopLqi: 0xFF + -8, # lastHopRssi: 0xF8 + sender, # 0x1D6F + 255, # bindingIndex: 0xFF + 255, # addressIndex: 0xFF + reassembled, # Reassembled payload + ], ) - ] - - -def test_incoming_route_record_handler(prot_hndl) -> None: - """Test incomingRouteRecordHandler callback.""" - handler = MagicMock() - prot_hndl.on_event(RouteRecordEvent.event_type, handler) - - ieee = t.EUI64.convert("aa:bb:cc:dd:ee:ff:00:11") - prot_hndl.handle_parsed_callback( - "incomingRouteRecordHandler", - { - "source": t.EmberNodeId(0x1234), - "sourceEui": ieee, - "lastHopLqi": t.uint8_t(200), - "lastHopRssi": t.int8s(-40), - "relayList": [t.EmberNodeId(0x0001), t.EmberNodeId(0x0002)], - }.values(), - ) - - assert handler.mock_calls == [ - call( - RouteRecordEvent( - nwk=t.EmberNodeId(0x1234), - ieee=ieee, - lqi=t.uint8_t(200), - rssi=t.int8s(-40), - relays=[t.EmberNodeId(0x0001), t.EmberNodeId(0x0002)], - ) + assert ( + "Reassembled fragmented message. Proceeding with normal handling." + in caplog.text ) - ] - - -def test_id_conflict_handler(prot_hndl) -> None: - """Test idConflictHandler callback.""" - handler = MagicMock() - prot_hndl.on_event(IdConflictEvent.event_type, handler) - - prot_hndl.handle_parsed_callback( - "idConflictHandler", - {"conflictingId": t.EmberNodeId(0x1234)}.values(), - ) - - assert handler.mock_calls == [call(IdConflictEvent(nwk=t.EmberNodeId(0x1234)))] + mock_ack.assert_called_with(sender, aps_frame_2, 2, 1) diff --git a/tests/test_ezsp_v14.py b/tests/test_ezsp_v14.py index 9eab1e96..49bf152b 100644 --- a/tests/test_ezsp_v14.py +++ b/tests/test_ezsp_v14.py @@ -3,9 +3,7 @@ import pytest import zigpy.exceptions import zigpy.state -import zigpy.types -from bellows.ezsp.protocol import MessageSentEvent, PacketReceivedEvent import bellows.ezsp.v14 import bellows.types as t @@ -228,111 +226,3 @@ async def test_send_broadcast(ezsp_f) -> None: message=b"hello", ) ] - - -def test_handle_parsed_callback_incoming_message(ezsp_f) -> None: - """Test handle_parsed_callback for incomingMessageHandler.""" - handler = MagicMock() - ezsp_f.on_event(PacketReceivedEvent.event_type, handler) - - ezsp_f.handle_parsed_callback( - "incomingMessageHandler", - { - "message_type": t.EmberIncomingMessageType.INCOMING_UNICAST, - "aps_frame": t.EmberApsFrame( - profileId=260, - clusterId=8, - sourceEndpoint=1, - destinationEndpoint=1, - options=( - t.EmberApsOption.APS_OPTION_RETRY - | t.EmberApsOption.APS_OPTION_ENABLE_ROUTE_DISCOVERY - ), - groupId=0, - sequence=168, - ), - "nwk": 0x1174, - "eui64": t.EUI64.convert("00:00:00:00:00:00:00:00"), - "binding_index": 255, - "address_index": 13, - "lqi": 192, - "rssi": -63, - "timestamp": 1333671578, - "message": b"\x18,\x0b\x04\x00", - }.values(), - ) - - assert handler.mock_calls == [ - call( - PacketReceivedEvent( - packet=zigpy.types.ZigbeePacket( - src=zigpy.types.AddrModeAddress( - addr_mode=zigpy.types.AddrMode.NWK, - address=zigpy.types.NWK(0x1174), - ), - src_ep=1, - dst=None, - dst_ep=1, - tsn=168, - profile_id=0x0104, - cluster_id=0x0008, - data=zigpy.types.SerializableBytes(b"\x18,\x0b\x04\x00"), - lqi=192, - rssi=-63, - ) - ) - ) - ] - - -def test_handle_parsed_callback_message_sent(ezsp_f) -> None: - """Test handle_parsed_callback for messageSentHandler.""" - handler = MagicMock() - ezsp_f.on_event(MessageSentEvent.event_type, handler) - - ezsp_f.handle_parsed_callback( - "messageSentHandler", - { - "status": t.sl_Status.OK, - "message_type": t.EmberOutgoingMessageType.OUTGOING_DIRECT, - "nwk": 0x0E0D, - "aps_frame": t.EmberApsFrame( - profileId=260, - clusterId=513, - sourceEndpoint=1, - destinationEndpoint=1, - options=( - t.EmberApsOption.APS_OPTION_RETRY - | t.EmberApsOption.APS_OPTION_ENABLE_ROUTE_DISCOVERY - ), - groupId=0, - sequence=236, - ), - "message_tag": 103, - "message": b"", - }.values(), - ) - - assert handler.mock_calls == [ - call( - MessageSentEvent( - status=t.sl_Status.OK, - message_type=t.EmberOutgoingMessageType.OUTGOING_DIRECT, - destination=t.EmberNodeId(0x0E0D), - aps_frame=t.EmberApsFrame( - profileId=260, - clusterId=513, - sourceEndpoint=1, - destinationEndpoint=1, - options=( - t.EmberApsOption.APS_OPTION_RETRY - | t.EmberApsOption.APS_OPTION_ENABLE_ROUTE_DISCOVERY - ), - groupId=0, - sequence=236, - ), - message_tag=103, - message_contents=b"", - ) - ) - ] From 185284a4c4991a564e0a94cf5483b8c0d2697faf Mon Sep 17 00:00:00 2001 From: puddly <32534428+puddly@users.noreply.github.com> Date: Thu, 1 Jan 2026 19:45:00 -0500 Subject: [PATCH 18/18] Fix new unit test --- tests/test_ezsp.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/test_ezsp.py b/tests/test_ezsp.py index b3e24821..d548309a 100644 --- a/tests/test_ezsp.py +++ b/tests/test_ezsp.py @@ -951,21 +951,20 @@ async def test_cfg_initialize_skip(): ) -@pytest.mark.parametrize("unsupported_version", [15, 18, 99]) -async def test_unsupported_ezsp_version_startup(unsupported_version: int, caplog): +async def test_unsupported_ezsp_version_startup(caplog): """Test that startup works with an unsupported EZSP version.""" - ezsp = make_ezsp(version=unsupported_version) + ezsp = make_ezsp(version=99) with patch("bellows.uart.connect"): await ezsp.connect() # The EZSP version should be stored as the unsupported version - assert ezsp._ezsp_version == unsupported_version + assert ezsp._ezsp_version == 99 # But the protocol should fall back to the latest assert ezsp._protocol.VERSION == EZSP_LATEST - assert f"Protocol version {unsupported_version} is not supported" in caplog.text + assert "Protocol version 99 is not supported" in caplog.text ezsp.getConfigurationValue = AsyncMock(return_value=(t.EzspStatus.SUCCESS, 0)) ezsp.setConfigurationValue = AsyncMock(return_value=(t.EzspStatus.SUCCESS,))