diff --git a/src/pymc_core/companion/binary_parsing.py b/src/pymc_core/companion/binary_parsing.py index fc150da..2c4195e 100644 --- a/src/pymc_core/companion/binary_parsing.py +++ b/src/pymc_core/companion/binary_parsing.py @@ -5,7 +5,12 @@ import logging from typing import Optional -from .constants import BinaryReqType +from .constants import ( + ANON_REQ_TYPE_BASIC, + ANON_REQ_TYPE_OWNER, + ANON_REQ_TYPE_REGIONS, + BinaryReqType, +) logger = logging.getLogger(__name__) @@ -17,6 +22,19 @@ def parse_binary_response( context: Optional[dict] = None, ) -> Optional[dict]: """Parse response_data by request_type. Returns dict or None.""" + context = context or {} + # Anonymous requests (CMD_SEND_ANON_REQ) all carry request_type 0x07, which + # collides with BinaryReqType.OWNER_INFO. Disambiguate by the recorded + # ANON_REQ_TYPE_* sub-type so a regions reply is not parsed as owner info. + anon_sub_type = context.get("anon_sub_type") + if anon_sub_type is not None: + if anon_sub_type == ANON_REQ_TYPE_REGIONS: + return _parse_regions(data) + if anon_sub_type == ANON_REQ_TYPE_OWNER: + return _parse_anon_owner(data) + if anon_sub_type == ANON_REQ_TYPE_BASIC: + return _parse_anon_basic(data) + return {"raw_hex": data.hex(), "anon_sub_type": anon_sub_type} if request_type == BinaryReqType.STATUS and len(data) >= 52: return _parse_status(data, pubkey_prefix=pubkey_prefix or None) if request_type == BinaryReqType.TELEMETRY and len(data) >= 0: @@ -111,6 +129,55 @@ def _parse_owner_info(data: bytes) -> dict: return {"raw_hex": data.hex(), "request_type": BinaryReqType.OWNER_INFO} +def _parse_regions(data: bytes) -> dict: + """Parse ANON_REQ_TYPE_REGIONS response: clock(4) + region-name list. + + The responder replies with tag(4) + clock(4) + names; the tag is stripped by + the caller, so ``data`` is clock(4) + names. Names are a null-terminated, + comma-separated string ('*' denotes the wildcard region; '#' prefixes are + already stripped by the firmware's exportNamesTo). + """ + try: + clock = int.from_bytes(data[:4], "little") if len(data) >= 4 else 0 + raw = data[4:].split(b"\x00", 1)[0] + text = raw.decode("utf-8", errors="replace") + regions = [r for r in text.split(",") if r != ""] + return {"type": "regions", "clock": clock, "regions": regions} + except Exception: + logger.debug("Regions parse failed, returning fallback", exc_info=True) + return {"raw_hex": data.hex(), "anon_sub_type": ANON_REQ_TYPE_REGIONS} + + +def _parse_anon_owner(data: bytes) -> dict: + """Parse ANON_REQ_TYPE_OWNER response: clock(4) + 'name\\nowner'.""" + try: + clock = int.from_bytes(data[:4], "little") if len(data) >= 4 else 0 + text = data[4:].split(b"\x00", 1)[0].decode("utf-8", errors="replace") + parts = text.split("\n", 1) + return { + "type": "owner", + "clock": clock, + "node_name": parts[0] if len(parts) > 0 else "", + "owner_info": parts[1] if len(parts) > 1 else "", + } + except Exception: + logger.debug("Anon owner parse failed, returning fallback", exc_info=True) + return {"raw_hex": data.hex(), "anon_sub_type": ANON_REQ_TYPE_OWNER} + + +def _parse_anon_basic(data: bytes) -> dict: + """Parse ANON_REQ_TYPE_BASIC response: clock(4) + feature flags(1).""" + clock = int.from_bytes(data[:4], "little") if len(data) >= 4 else 0 + features = data[4] if len(data) >= 5 else 0 + return { + "type": "basic", + "clock": clock, + "features": features, + "is_bridge": bool(features & 0x01), + "is_disabled": bool(features & 0x80), + } + + def _parse_acl(buf: bytes) -> dict: """ACL: 7-byte entries (key 6 + perm 1).""" res = [] diff --git a/src/pymc_core/companion/companion_base.py b/src/pymc_core/companion/companion_base.py index b98bd8c..f773197 100644 --- a/src/pymc_core/companion/companion_base.py +++ b/src/pymc_core/companion/companion_base.py @@ -77,9 +77,35 @@ from .models import AdvertPath, Channel, Contact, NodePrefs, QueuedMessage, SentResult from .path_cache import PathCache from .stats_collector import StatsCollector +from .timing import DEFAULT_MAX_ATTEMPTS, response_timeout_ms logger = logging.getLogger("CompanionBase") + +def _fmt_path(out_path_len: int, out_path: Any) -> str: + """Format a contact's out_path for [PATHDIAG] logs without ambiguity. + + ``out_path_len`` is the firmware-encoded path_len byte, not a hop count: + the top 2 bits are (hash_size - 1) and the low 6 bits are the hop count. + E.g. 0x42 == hash_size 2, 2 hops -> 4 path bytes. Render the decoded form + plus the path as hex so the byte value is never misread as a hop count. + """ + if out_path_len is None or out_path_len < 0: + return "unknown (out_path_len=-1, flood)" + if isinstance(out_path, (bytes, bytearray)): + path_hex = bytes(out_path).hex() + elif isinstance(out_path, (list, tuple)): + path_hex = bytes(int(b) & 0xFF for b in out_path).hex() + else: + path_hex = str(out_path) + return ( + f"path_len_byte=0x{out_path_len & 0xFF:02X} " + f"(hash_size={PathUtils.get_path_hash_size(out_path_len)}, " + f"hops={PathUtils.get_path_hash_count(out_path_len)}) " + f"path={path_hex or '(empty)'}" + ) + + PUSH_CALLBACK_KEYS = [ "message_received", "channel_message_received", @@ -813,7 +839,17 @@ async def _on_contact_path_updated(self, pub: bytes, path_len: int, path_bytes: """ contact = self.get_contact_by_key(pub) if contact is None: + logger.debug( + "[PATHDIAG] _on_contact_path_updated: no contact for pub=%s (ignored)", + pub[:4].hex(), + ) return # Firmware does not send PATH for non-contacts + logger.debug( + "[PATHDIAG] _on_contact_path_updated pub=%s name=%s %s", + pub[:4].hex(), + getattr(contact, "name", "?"), + _fmt_path(path_len, path_bytes), + ) contact.out_path_len = path_len contact.out_path = path_bytes self.contacts.update(contact) @@ -908,12 +944,25 @@ async def _on_binary_response( tag_hex = tag_bytes.hex() info = self._pending_binary_requests.pop(tag_hex, None) if not info: - # Skip log for small payloads (e.g. login response handled elsewhere) - if len(response_data) >= 20: - logger.debug(f"Binary response for unknown tag {tag_hex}") + # A decryptable response arrived but no request is waiting for this tag. + # This is the signature of "response arrived but we already timed out" + # (or a tag mismatch); distinct from "no response arrived at all". + logger.debug( + "[PATHDIAG] anon/binary response UNMATCHED tag=%s (%dB) — no pending " + "request (arrived after timeout, or tag mismatch). pending=%s", + tag_hex, + len(response_data), + list(self._pending_binary_requests.keys()), + ) await self._fire_callbacks("binary_response", tag_bytes, response_data) return request_type = info["request_type"] + logger.debug( + "[PATHDIAG] anon/binary response MATCHED tag=%s type=%s (%dB)", + tag_hex, + request_type, + len(response_data), + ) pubkey_prefix = info.get("pubkey_prefix", "") context = info.get("context", {}) parsed = None @@ -1148,26 +1197,36 @@ async def send_anon_req( return SentResult(success=False) request_type = PROTOCOL_CODE_ANON_REQ req_payload = data # no random tag; timestamp provides uniqueness + # The first byte is the ANON_REQ_TYPE_* sub-type (e.g. REGIONS/OWNER); + # record it so the response can be parsed by sub-type rather than being + # mistaken for a binary REQ_TYPE_GET_OWNER_INFO (both use code 0x07). + anon_sub_type = req_payload[0] if len(req_payload) >= 1 else None self.cleanup_expired_binary_requests() try: - pkt, timestamp = PacketBuilder.create_protocol_request( + pkt, timestamp = PacketBuilder.create_anon_request( contact=proxy, local_identity=self._identity, - protocol_code=PROTOCOL_CODE_ANON_REQ, - data=req_payload, + req_data=req_payload, ) # Use the timestamp as the tag — matches what the repeater echoes back tag_int = timestamp tag_bytes = tag_int.to_bytes(4, "little") tag_hex = tag_bytes.hex() + self._apply_flood_scope(pkt) + self._apply_path_hash_mode(pkt) + # Adaptive timeout (firmware calcFlood/DirectTimeoutMillisFor). This is + # fire-and-forget: the response arrives async via the binary-response + # push, and the client retries on this timeout hint — the same model + # firmware uses for anon/discovery (it returns est_timeout and the host + # app re-issues). A short adaptive hint => fast client-driven retry. + timeout_s = self._response_timeout_s(pkt, proxy) self.register_binary_request( tag_hex, request_type=request_type, - timeout_seconds=timeout_seconds, + timeout_seconds=max(timeout_seconds, timeout_s * DEFAULT_MAX_ATTEMPTS), pubkey_prefix=pub_key[:6].hex(), + context={"anon_sub_type": anon_sub_type}, ) - self._apply_flood_scope(pkt) - self._apply_path_hash_mode(pkt) success = await self._send_packet(pkt, wait_for_ack=False) except Exception as e: logger.error(f"Anon request send error: {e}") @@ -1179,9 +1238,11 @@ async def send_anon_req( return SentResult(success=False) return SentResult( success=True, - is_flood=contact.out_path_len <= 0, + # Direct (incl. zero-hop, out_path_len == 0) when the path is known; + # flood only when the out_path is unknown (-1). Mirrors create_anon_request. + is_flood=contact.out_path_len < 0, expected_ack=tag_int, - timeout_ms=DEFAULT_RESPONSE_TIMEOUT_MS, + timeout_ms=int(timeout_s * 1000), ) async def send_path_discovery(self, pub_key: bytes) -> bool: @@ -1446,6 +1507,43 @@ async def send_raw_data_direct( logger.error(f"Error sending raw data direct: {e}") return SentResult(success=False) + async def send_raw_packet(self, priority: int, packet_bytes: bytes) -> bool: + """Inject a fully-formed on-air packet for transmission (CMD_SEND_RAW_PACKET). + + Mirrors firmware ``MyMesh.cpp`` ``CMD_SEND_RAW_PACKET``: parse the raw + on-air bytes into a :class:`Packet` (``tryParsePacket``) and enqueue it + for TX (``sendPacket``). ``packet_bytes`` is the complete wire packet + (header, optional transport codes, path, payload) as produced by + :meth:`Packet.write_to`; it is sent verbatim, with no encryption, + contact lookup, flood-scope, or path-hash-mode rewriting. + + The ``priority`` argument is accepted for protocol compatibility but is + currently ignored: the bridge's low-level send path + (:meth:`_send_packet`) does not expose a prioritized TX queue. + + Returns True if the packet parsed and was handed off for transmission, + False on parse failure or send error (the frame_server handler maps + False to ``ERR_CODE_TABLE_FULL``). + """ + try: + pkt = Packet() + if not pkt.read_from(bytes(packet_bytes)): + return False + except Exception as e: + logger.warning(f"send_raw_packet: failed to parse packet: {e}") + return False + try: + success = await self._send_packet(pkt, wait_for_ack=False) + if success: + self.stats.record_tx(is_flood=False) + else: + self.stats.record_tx_error() + return success + except Exception as e: + logger.error(f"Error sending raw packet: {e}") + self.stats.record_tx_error() + return False + async def send_trace_path( self, pub_key: bytes, @@ -1523,14 +1621,42 @@ def _login_cb(success: bool, data: dict) -> None: login_handler.set_login_callback(_login_cb) try: - pkt = PacketBuilder.create_login_packet( - contact=proxy, local_identity=self._identity, password=password - ) - self._apply_path_hash_mode(pkt) - await self._send_packet(pkt, wait_for_ack=False) - try: - await asyncio.wait_for(login_event.wait(), timeout=10.0) - except asyncio.TimeoutError: + # The login callback fires on any decryptable login response from this + # repeater (keyed by password/dest_hash, not by tag), so we can resend + # a freshly-built login packet each attempt and a single event resolves + # whichever attempt's reply arrives. Each attempt waits one adaptive + # timeout (firmware cadence) instead of a single fixed 10s wait. + for attempt in range(DEFAULT_MAX_ATTEMPTS): + pkt = PacketBuilder.create_login_packet( + contact=proxy, local_identity=self._identity, password=password + ) + self._apply_path_hash_mode(pkt) + timeout_s = self._response_timeout_s(pkt, proxy) + logger.debug( + "[PATHDIAG] login -> 0x%02X (%s) route=%s attempt=%d/%d " + "timeout=%.1fs out_path_len=%s; listening for reply", + dest_hash, + contact.name, + "FLOOD" if pkt.is_route_flood() else "DIRECT", + attempt + 1, + DEFAULT_MAX_ATTEMPTS, + timeout_s, + getattr(proxy, "out_path_len", -1), + ) + await self._send_packet(pkt, wait_for_ack=False) + try: + await asyncio.wait_for(login_event.wait(), timeout=timeout_s) + break # got a response + except asyncio.TimeoutError: + logger.debug( + "[PATHDIAG] login to 0x%02X attempt %d/%d TIMEOUT after %.1fs — " + "no decryptable login response arrived", + dest_hash, + attempt + 1, + DEFAULT_MAX_ATTEMPTS, + timeout_s, + ) + if not login_event.is_set(): return {"success": False, "reason": "Login response timeout"} data = login_result["data"] return { @@ -1566,24 +1692,39 @@ async def send_logout(self, pub_key: bytes) -> bool: logger.error(f"Logout error: {e}") return False + def _response_timeout_s(self, pkt: Packet, proxy: Any) -> float: + """Adaptive response timeout (seconds) for a request packet. + + Mirrors firmware calcFloodTimeoutMillisFor / calcDirectTimeoutMillisFor + using the radio's SF/BW/CR and the packet's on-air length, so a lost + round-trip is retried on a ~3s cadence instead of a fixed 10-15s wait. + """ + try: + out_path_len = getattr(proxy, "out_path_len", -1) + ms = response_timeout_ms( + raw_length=pkt.get_raw_length(), + is_flood=pkt.is_route_flood(), + out_path_len=out_path_len, + sf=int(getattr(self.prefs, "spreading_factor", 10)), + bw_hz=int(getattr(self.prefs, "bandwidth_hz", 250000)), + cr=int(getattr(self.prefs, "coding_rate", 5)), + ) + return ms / 1000.0 + except Exception: + return 5.0 # safe fallback + async def _wait_for_path_propagation(self, proxy: Any, request_type: str) -> None: - """Wait for reciprocal PATH to propagate through the mesh for multi-hop contacts. + """Log the pre-send path; no longer sleeps. - After login, pyMC sends a reciprocal PATH so the remote repeater learns - the return route. Each mesh hop adds ~500ms (airtime + processing). - Without this delay, the first REQ may arrive before the reciprocal PATH, - causing the remote to fall back to sendFlood() — which gets dropped by - intermediate repeaters due to transport-code region filtering. + Firmware sends the request immediately and relies on the reciprocal PATH + (which pyMC already sends at login time, see ProtocolResponseHandler). + The previous 0.5s/hop sleep added up to ~1.5s+ of latency per request for + multi-hop contacts with no reliability benefit and has been removed; the + adaptive timeout + internal resend now handle a lost first attempt. """ out_path_len = getattr(proxy, "out_path_len", -1) - if out_path_len > 0: - hop_count = PathUtils.get_path_hash_count(out_path_len) - propagation_delay = hop_count * 0.5 # e.g. 3 hops → 1.5s - logger.debug( - f"Multi-hop {request_type}: waiting {propagation_delay:.1f}s for " - f"reciprocal PATH propagation ({hop_count} hops)" - ) - await asyncio.sleep(propagation_delay) + out_path = getattr(proxy, "out_path", b"") or b"" + logger.debug("[PATHDIAG] %s pre-send: %s", request_type, _fmt_path(out_path_len, out_path)) async def send_status_request(self, pub_key: bytes, timeout: float = 15.0) -> dict: """Send a protocol request for repeater status/stats.""" @@ -1601,15 +1742,39 @@ async def send_status_request(self, pub_key: bytes, timeout: float = 15.0) -> di proto_handler.set_response_callback(contact_hash, waiter.callback) try: await self._wait_for_path_propagation(proxy, "stats request") - pkt, _ = PacketBuilder.create_protocol_request( - contact=proxy, - local_identity=self._identity, - protocol_code=REQ_TYPE_GET_STATUS, - data=b"", - ) - self._apply_path_hash_mode(pkt) - await self._send_packet(pkt, wait_for_ack=False) - result = await waiter.wait(timeout) + # Status responses resolve the waiter by contact_hash (not tag), so a + # fresh REQ each attempt is fine and dodges the repeater's flood dedup. + # Each attempt waits one adaptive timeout (firmware cadence); a late + # reply that lands between attempts resolves the waiter immediately. + result: dict = {"timeout": True} + for attempt in range(DEFAULT_MAX_ATTEMPTS): + pkt, _ = PacketBuilder.create_protocol_request( + contact=proxy, + local_identity=self._identity, + protocol_code=REQ_TYPE_GET_STATUS, + data=b"", + ) + self._apply_path_hash_mode(pkt) + timeout_s = self._response_timeout_s(pkt, proxy) + logger.debug( + "[PATHDIAG] stats REQ: route=%s attempt=%d/%d timeout=%.1fs " + "path_len_byte=0x%02X (hops=%s) path=%s", + "FLOOD" if pkt.is_route_flood() else "DIRECT", + attempt + 1, + DEFAULT_MAX_ATTEMPTS, + timeout_s, + pkt.path_len & 0xFF, + pkt.get_path_hash_count() if pkt.path_len else 0, + ( + bytes(pkt.path[: pkt.get_path_byte_len()]).hex() + if pkt.path_len + else "(empty)" + ), + ) + await self._send_packet(pkt, wait_for_ack=False) + result = await waiter.wait(timeout_s) + if not result.get("timeout"): + break return { "success": result.get("success", False), "repeater": contact.name, @@ -1649,15 +1814,20 @@ async def send_telemetry_request( inv = PacketBuilder._compute_inverse_perm_mask( want_base, want_location, want_environment ) - pkt, _ = PacketBuilder.create_protocol_request( - contact=proxy, - local_identity=self._identity, - protocol_code=REQ_TYPE_GET_TELEMETRY_DATA, - data=bytes([inv]), - ) - self._apply_path_hash_mode(pkt) - await self._send_packet(pkt, wait_for_ack=False) - result = await waiter.wait(timeout) + result: dict = {"timeout": True} + for attempt in range(DEFAULT_MAX_ATTEMPTS): + pkt, _ = PacketBuilder.create_protocol_request( + contact=proxy, + local_identity=self._identity, + protocol_code=REQ_TYPE_GET_TELEMETRY_DATA, + data=bytes([inv]), + ) + self._apply_path_hash_mode(pkt) + timeout_s = self._response_timeout_s(pkt, proxy) + await self._send_packet(pkt, wait_for_ack=False) + result = await waiter.wait(timeout_s) + if not result.get("timeout"): + break telemetry_data = dict(result.get("parsed", {})) raw_bytes = telemetry_data.get("raw_bytes", b"") if raw_bytes and len(pub_key) >= 6: @@ -1704,15 +1874,20 @@ async def _send_protocol_request(self, pub_key: bytes, protocol_code: int, data: waiter = ResponseWaiter() proto_handler.set_response_callback(contact_hash, waiter.callback) try: - pkt, _ = PacketBuilder.create_protocol_request( - contact=proxy, - local_identity=self._identity, - protocol_code=protocol_code, - data=data, - ) - self._apply_path_hash_mode(pkt) - await self._send_packet(pkt, wait_for_ack=False) - result = await waiter.wait(10.0) + result: dict = {"timeout": True} + for _attempt in range(DEFAULT_MAX_ATTEMPTS): + pkt, _ = PacketBuilder.create_protocol_request( + contact=proxy, + local_identity=self._identity, + protocol_code=protocol_code, + data=data, + ) + self._apply_path_hash_mode(pkt) + timeout_s = self._response_timeout_s(pkt, proxy) + await self._send_packet(pkt, wait_for_ack=False) + result = await waiter.wait(timeout_s) + if not result.get("timeout"): + break return { "success": result.get("success", False), "response": result.get("text"), @@ -1993,7 +2168,9 @@ async def _handle_group_data_packet(self, packet: Packet) -> None: else: secret = secret[:32] try: - plaintext = CryptoUtils.mac_then_decrypt(hashlib.sha256(secret).digest(), secret, cipher_mac + ciphertext) + plaintext = CryptoUtils.mac_then_decrypt( + hashlib.sha256(secret).digest(), secret, cipher_mac + ciphertext + ) except Exception: plaintext = None if plaintext is not None: @@ -2009,7 +2186,11 @@ async def _handle_group_data_packet(self, packet: Packet) -> None: blob = bytes(plaintext[3 : 3 + data_len]) route_type = packet.get_route_type() - path_len = packet.path_len if route_type in (ROUTE_TYPE_FLOOD, ROUTE_TYPE_TRANSPORT_FLOOD) else 0xFF + path_len = ( + packet.path_len + if route_type in (ROUTE_TYPE_FLOOD, ROUTE_TYPE_TRANSPORT_FLOOD) + else 0xFF + ) snr = packet.get_snr() if hasattr(packet, "get_snr") else getattr(packet, "_snr", 0.0) rssi = packet.rssi if hasattr(packet, "rssi") else getattr(packet, "_rssi", 0) queued = QueuedMessage( diff --git a/src/pymc_core/companion/companion_radio.py b/src/pymc_core/companion/companion_radio.py index 81e4e55..b051315 100644 --- a/src/pymc_core/companion/companion_radio.py +++ b/src/pymc_core/companion/companion_radio.py @@ -269,6 +269,10 @@ def _setup_packet_callbacks(self) -> None: dispatcher.protocol_response_handler.set_contact_path_updated_callback( self._on_contact_path_updated ) + # Wire the TX path so the handler can send reciprocal PATH packets + # (firmware onContactPathRecv behaviour). Without this the remote + # repeater never learns its route back to us and floods every reply. + dispatcher.protocol_response_handler.set_packet_injector(self._send_packet) async def _on_packet_received(self, pkt: Any) -> None: route_type = pkt.get_route_type() diff --git a/src/pymc_core/companion/constants.py b/src/pymc_core/companion/constants.py index 63cd292..f6e0cb4 100644 --- a/src/pymc_core/companion/constants.py +++ b/src/pymc_core/companion/constants.py @@ -79,6 +79,16 @@ class BinaryReqType(IntEnum): PROTOCOL_CODE_BINARY_REQ = 0x02 PROTOCOL_CODE_ANON_REQ = 0x07 +# --------------------------------------------------------------------------- +# Anonymous request sub-types (first byte of an ANON_REQ payload, after the +# 4-byte timestamp). Used by the "discover regions from zero-hop repeaters" +# feature and related anonymous queries. Note these collide numerically with +# BinaryReqType values, so anon responses must be disambiguated by sub-type. +# --------------------------------------------------------------------------- +ANON_REQ_TYPE_REGIONS = 0x01 # repeater replies with comma-separated region names +ANON_REQ_TYPE_OWNER = 0x02 # repeater replies with "name\nowner" +ANON_REQ_TYPE_BASIC = 0x03 # repeater replies with clock + feature flags + # --------------------------------------------------------------------------- # Default configuration # --------------------------------------------------------------------------- @@ -98,7 +108,9 @@ class BinaryReqType(IntEnum): # CMD_SEND_ANON_REQ (owner requests, etc.) is supported. # 10+ provides support for multi-byte path lengths. # 11+ adds channel binary datagrams and default flood scope commands. -FIRMWARE_VER_CODE = 11 +# 12+ matches the MeshCore dev-branch companion (v1.15.x/1.16.0 family): adds +# CMD_GET_ALLOWED_REPEAT_FREQ and CMD_SEND_RAW_PACKET. +FIRMWARE_VER_CODE = 12 # --------------------------------------------------------------------------- # Commands (app -> radio) @@ -155,10 +167,12 @@ class BinaryReqType(IntEnum): CMD_SEND_ANON_REQ = 57 CMD_SET_AUTOADD_CONFIG = 58 CMD_GET_AUTOADD_CONFIG = 59 +CMD_GET_ALLOWED_REPEAT_FREQ = 60 CMD_SET_PATH_HASH_MODE = 61 CMD_SEND_CHANNEL_DATA = 62 CMD_SET_DEFAULT_FLOOD_SCOPE = 63 CMD_GET_DEFAULT_FLOOD_SCOPE = 64 +CMD_SEND_RAW_PACKET = 65 # --------------------------------------------------------------------------- # Response codes (radio -> app) @@ -189,6 +203,7 @@ class BinaryReqType(IntEnum): RESP_CODE_TUNING_PARAMS = 23 RESP_CODE_STATS = 24 RESP_CODE_AUTOADD_CONFIG = 25 +RESP_CODE_ALLOWED_REPEAT_FREQ = 26 RESP_CODE_CHANNEL_DATA_RECV = 27 RESP_CODE_DEFAULT_FLOOD_SCOPE = 28 diff --git a/src/pymc_core/companion/frame_server.py b/src/pymc_core/companion/frame_server.py index 646bf3a..852c5cc 100644 --- a/src/pymc_core/companion/frame_server.py +++ b/src/pymc_core/companion/frame_server.py @@ -29,12 +29,14 @@ CMD_EXPORT_CONTACT, CMD_EXPORT_PRIVATE_KEY, CMD_GET_ADVERT_PATH, + CMD_GET_ALLOWED_REPEAT_FREQ, CMD_GET_AUTOADD_CONFIG, CMD_GET_BATT_AND_STORAGE, CMD_GET_CHANNEL, CMD_GET_CONTACT_BY_KEY, CMD_GET_CONTACTS, CMD_GET_CUSTOM_VARS, + CMD_GET_DEFAULT_FLOOD_SCOPE, CMD_GET_DEVICE_TIME, CMD_GET_STATS, CMD_IMPORT_CONTACT, @@ -50,6 +52,7 @@ CMD_SEND_LOGIN, CMD_SEND_PATH_DISCOVERY_REQ, CMD_SEND_RAW_DATA, + CMD_SEND_RAW_PACKET, CMD_SEND_SELF_ADVERT, CMD_SEND_STATUS_REQ, CMD_SEND_TELEMETRY_REQ, @@ -60,10 +63,9 @@ CMD_SET_AUTOADD_CONFIG, CMD_SET_CHANNEL, CMD_SET_CUSTOM_VAR, - CMD_SET_DEVICE_TIME, CMD_SET_DEFAULT_FLOOD_SCOPE, + CMD_SET_DEVICE_TIME, CMD_SET_FLOOD_SCOPE, - CMD_GET_DEFAULT_FLOOD_SCOPE, CMD_SET_OTHER_PARAMS, CMD_SET_PATH_HASH_MODE, CMD_SET_RADIO_PARAMS, @@ -79,8 +81,8 @@ FIRMWARE_VER_CODE, FRAME_INBOUND_PREFIX, FRAME_OUTBOUND_PREFIX, - MAX_FRAME_SIZE, MAX_CHANNEL_DATA_LENGTH, + MAX_FRAME_SIZE, MAX_PATH_SIZE, MAX_PAYLOAD_SIZE, OUT_PATH_UNKNOWN, @@ -103,11 +105,11 @@ PUSH_CODE_TELEMETRY_RESPONSE, PUSH_CODE_TRACE_DATA, RESP_CODE_ADVERT_PATH, + RESP_CODE_ALLOWED_REPEAT_FREQ, RESP_CODE_AUTOADD_CONFIG, RESP_CODE_BATT_AND_STORAGE, RESP_CODE_CHANNEL_DATA_RECV, RESP_CODE_CHANNEL_INFO, - RESP_CODE_DEFAULT_FLOOD_SCOPE, RESP_CODE_CHANNEL_MSG_RECV, RESP_CODE_CHANNEL_MSG_RECV_V3, RESP_CODE_CONTACT, @@ -116,6 +118,7 @@ RESP_CODE_CONTACTS_START, RESP_CODE_CURR_TIME, RESP_CODE_CUSTOM_VARS, + RESP_CODE_DEFAULT_FLOOD_SCOPE, RESP_CODE_DEVICE_INFO, RESP_CODE_END_OF_CONTACTS, RESP_CODE_ERR, @@ -270,7 +273,9 @@ def __init__( CMD_GET_AUTOADD_CONFIG: self._cmd_get_autoadd_config, CMD_SET_OTHER_PARAMS: self._cmd_set_other_params, CMD_SEND_RAW_DATA: self._cmd_send_raw_data, + CMD_SEND_RAW_PACKET: self._cmd_send_raw_packet, CMD_SET_PATH_HASH_MODE: self._cmd_set_path_hash_mode, + CMD_GET_ALLOWED_REPEAT_FREQ: self._cmd_get_allowed_repeat_freq, } # ------------------------------------------------------------------------- @@ -1134,13 +1139,8 @@ async def _cmd_send_channel_txt_msg(self, data: bytes) -> None: if self.bridge.get_channel(channel_idx) is None: self._write_err(ERR_CODE_NOT_FOUND) return - self._write_ok() ok = await self.bridge.send_channel_message(channel_idx, text) - if not ok: - logger.warning( - "Channel message send failed for channel %d after OK response was already sent", - channel_idx, - ) + self._write_ok() if ok else self._write_err(ERR_CODE_BAD_STATE) async def _cmd_send_channel_data(self, data: bytes) -> None: """Handle CMD_SEND_CHANNEL_DATA (62).""" @@ -2019,3 +2019,39 @@ async def _cmd_set_path_hash_mode(self, data: bytes) -> None: return self.bridge.set_path_hash_mode(mode) self._write_ok() + + async def _cmd_get_allowed_repeat_freq(self, data: bytes) -> None: + """Handle CMD_GET_ALLOWED_REPEAT_FREQ (60). + + Firmware (MyMesh.cpp:1958) replies with RESP_ALLOWED_REPEAT_FREQ followed + by zero or more (lower_freq, upper_freq) little-endian u32 pairs. The + virtual companion does not model regional repeat-frequency restrictions, + so it advertises an empty range list (response code with no pairs). + """ + self._write_frame(bytes([RESP_CODE_ALLOWED_REPEAT_FREQ])) + + async def _cmd_send_raw_packet(self, data: bytes) -> None: + """Handle CMD_SEND_RAW_PACKET (65). Format: [priority(1)][raw_packet...]. + + Mirrors MyMesh.cpp:1967: inject a low-level packet with a TX priority. + Delegates to the bridge's ``send_raw_packet`` if available. + """ + if len(data) < 3: + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + priority = data[0] + packet_bytes = data[1:] + send_raw_packet = getattr(self.bridge, "send_raw_packet", None) + if not send_raw_packet: + self._write_err(ERR_CODE_UNSUPPORTED_CMD) + return + try: + ok = await send_raw_packet(priority, packet_bytes) + except Exception as e: + logger.error("send_raw_packet error: %s", e, exc_info=True) + self._write_err(ERR_CODE_ILLEGAL_ARG) + return + if ok: + self._write_ok() + else: + self._write_err(ERR_CODE_TABLE_FULL) diff --git a/src/pymc_core/companion/timing.py b/src/pymc_core/companion/timing.py new file mode 100644 index 0000000..0d04f5a --- /dev/null +++ b/src/pymc_core/companion/timing.py @@ -0,0 +1,102 @@ +"""Adaptive request timeouts mirroring MeshCore firmware. + +The firmware companion (``BaseChatMesh``) sizes each request's response timeout +from the packet's airtime and route, then retries on that cadence: + + calcFloodTimeoutMillisFor(t) = 500 + 16.0 * t + calcDirectTimeoutMillisFor(t, hops) = 500 + (6.0 * t + 250) * (hops + 1) + +where ``t`` is the estimated airtime in ms (see ``examples/companion_radio``). +pyMC previously used fixed 10 s / 15 s waits with no resend, so a single lost +packet stalled for 10-15 s where firmware recovers in ~3 s. This module +reproduces the firmware math so login/stats/discovery use the same cadence. +""" + +import math + +from ..protocol.packet_utils import PathUtils + +# Firmware constants (examples/companion_radio/MyMesh.cpp). +SEND_TIMEOUT_BASE_MILLIS = 500 +FLOOD_SEND_TIMEOUT_FACTOR = 16.0 +DIRECT_SEND_PERHOP_FACTOR = 6.0 +DIRECT_SEND_PERHOP_EXTRA_MILLIS = 250 + +# Default number of attempts (initial send + resends) for a request before +# giving up. Each attempt waits one adaptive timeout. Firmware relies on the +# host app to re-issue; we resend internally so recovery is independent of it. +DEFAULT_MAX_ATTEMPTS = 3 + +# Guard rails so a fast SF doesn't produce a pathologically short timeout and a +# slow SF / huge path doesn't block for too long before a resend. +MIN_TIMEOUT_MILLIS = 1500 +MAX_TIMEOUT_MILLIS = 12000 + + +def estimate_airtime_ms( + packet_length: int, + sf: int, + bw_hz: int, + cr: int, + preamble_symbols: int = 8, + low_dr_opt: bool = None, +) -> float: + """Estimate LoRa airtime (ms) for a packet, per the Semtech formula. + + Mirrors ``SX1262Wrapper`` airtime math: explicit header, CRC on. ``cr`` is + the MeshCore coding-rate index (1->4/5 .. 4->4/8). ``packet_length`` is the + full on-air byte length (use ``Packet.get_raw_length()``). + """ + sf = max(6, min(12, int(sf))) + bw_hz = int(bw_hz) or 250000 + cr = max(1, min(4, int(cr))) + if low_dr_opt is None: + low_dr_opt = sf >= 11 and bw_hz <= 125000 + ldro = 1 if low_dr_opt else 0 + + symbol_time = (1 << sf) / float(bw_hz) + preamble_time = (preamble_symbols + 4.25) * symbol_time + tmp = 8 * packet_length - 4 * sf + 28 + 16 * 1 - 20 * 0 # crc=1, explicit header + denom = 4 * (sf - 2 * ldro) + if tmp > 0 and denom > 0: + payload_symbols = 8 + max(math.ceil(tmp / denom) * (cr + 4), 0) + else: + payload_symbols = 8 + payload_time = payload_symbols * symbol_time + return (preamble_time + payload_time) * 1000.0 + + +def calc_flood_timeout_ms(airtime_ms: float) -> int: + """Firmware ``calcFloodTimeoutMillisFor``.""" + return int(SEND_TIMEOUT_BASE_MILLIS + FLOOD_SEND_TIMEOUT_FACTOR * airtime_ms) + + +def calc_direct_timeout_ms(airtime_ms: float, out_path_len: int) -> int: + """Firmware ``calcDirectTimeoutMillisFor`` (out_path_len is the encoded byte).""" + hops = PathUtils.get_path_hash_count(out_path_len) if out_path_len > 0 else 0 + return int( + SEND_TIMEOUT_BASE_MILLIS + + (DIRECT_SEND_PERHOP_FACTOR * airtime_ms + DIRECT_SEND_PERHOP_EXTRA_MILLIS) * (hops + 1) + ) + + +def response_timeout_ms( + raw_length: int, + is_flood: bool, + out_path_len: int, + sf: int, + bw_hz: int, + cr: int, + preamble_symbols: int = 8, +) -> int: + """Adaptive response timeout (ms) for a request packet, clamped. + + ``is_flood`` selects the flood vs direct firmware formula; ``out_path_len`` + is the contact's encoded path_len byte (used for the per-hop direct term). + """ + airtime = estimate_airtime_ms(raw_length, sf, bw_hz, cr, preamble_symbols) + if is_flood: + ms = calc_flood_timeout_ms(airtime) + else: + ms = calc_direct_timeout_ms(airtime, out_path_len) + return max(MIN_TIMEOUT_MILLIS, min(MAX_TIMEOUT_MILLIS, ms)) diff --git a/src/pymc_core/hardware/kiss_modem_wrapper.py b/src/pymc_core/hardware/kiss_modem_wrapper.py index 17180c1..0e41a88 100644 --- a/src/pymc_core/hardware/kiss_modem_wrapper.py +++ b/src/pymc_core/hardware/kiss_modem_wrapper.py @@ -14,6 +14,7 @@ import random import struct import threading +import time from collections import deque from concurrent.futures import ThreadPoolExecutor from typing import Any, Callable, Dict, Optional, Union @@ -174,6 +175,8 @@ def _invoke_rx_callback( HW_ERR_MAC_FAILED = 0x04 HW_ERR_UNKNOWN_CMD = 0x05 HW_ERR_ENCRYPT_FAILED = 0x06 +# Emitted only on the DATA path when a transmit is already pending (single-slot modem TX). +HW_ERR_TX_BUSY = 0x07 ERR_INVALID_LENGTH = HW_ERR_INVALID_LENGTH ERR_INVALID_PARAM = HW_ERR_INVALID_PARAM @@ -181,6 +184,7 @@ def _invoke_rx_callback( ERR_MAC_FAILED = HW_ERR_MAC_FAILED ERR_UNKNOWN_CMD = HW_ERR_UNKNOWN_CMD ERR_ENCRYPT_FAILED = HW_ERR_ENCRYPT_FAILED +ERR_TX_BUSY = HW_ERR_TX_BUSY # Buffer and timing constants MAX_FRAME_SIZE = 512 @@ -192,6 +196,12 @@ def _invoke_rx_callback( DEFAULT_BAUDRATE = 115200 DEFAULT_TIMEOUT = 1.0 RESPONSE_TIMEOUT = 5.0 # Timeout for command responses +# Extra margin added to estimated airtime when waiting for a DATA TX_DONE, so a long +# transmit (e.g. a high-SF flood advert) is not cut short by the flat command timeout. +TX_DONE_TIMEOUT_MARGIN_S = 1.0 +POST_CONNECT_SETTLE_SECONDS = 0.75 +POST_CONNECT_CONFIGURE_RETRIES = 2 +POST_CONNECT_CONFIGURE_RETRY_BACKOFF_SECONDS = 0.25 logger = logging.getLogger("KissModemWrapper") @@ -221,6 +231,15 @@ class KissModemWrapper(LoRaRadio): specific packet, avoiding race conditions with get_last_rssi/get_last_snr. """ + # Some SetHardware requests may legitimately respond with OK instead of the + # command|0x80 specific response code. + _SETHW_ALLOW_OK_FOR: set[int] = { + HW_CMD_SET_RADIO, + HW_CMD_SET_TX_POWER, + HW_CMD_SET_SIGNAL_REPORT, + HW_CMD_REBOOT, + } + def __init__( self, port: str, @@ -230,6 +249,10 @@ def __init__( radio_config: Optional[Dict[str, Any]] = None, auto_configure: bool = True, lbt_enabled: bool = False, + connect_retries: int = 3, + post_open_delay_ms: int = 500, + usb_reset_on_connect: Optional[bool] = None, + startup_retry_budget_sec: float = 5.0, ): """ Initialize MeshCore KISS Modem Wrapper @@ -259,6 +282,14 @@ def __init__( self.timeout = timeout self.auto_configure = auto_configure self.lbt_enabled = lbt_enabled + self.connect_retries = max(1, int(connect_retries)) + self.post_open_delay_ms = max(0, int(post_open_delay_ms)) + self.startup_retry_budget_sec = max(1.0, float(startup_retry_budget_sec)) + if usb_reset_on_connect is None: + self.usb_reset_on_connect = str(port).startswith("/dev/serial/by-id/") + else: + self.usb_reset_on_connect = bool(usb_reset_on_connect) + self._shutting_down = False self.radio_config = radio_config or {} self.is_configured = False @@ -274,6 +305,8 @@ def __init__( self.serial_conn: Optional[serial.Serial] = None self.is_connected = False + self._degraded = False + self._degraded_reason: Optional[str] = None self.rx_buffer = deque(maxlen=RX_BUFFER_SIZE) self.tx_buffer = deque(maxlen=TX_BUFFER_SIZE) @@ -284,7 +317,49 @@ def __init__( self.rx_thread: Optional[threading.Thread] = None self.tx_thread: Optional[threading.Thread] = None + self.reconnect_thread: Optional[threading.Thread] = None self.stop_event = threading.Event() + self._reconnecting_event = threading.Event() + self._connection_lock = threading.RLock() + self._failure_log_lock = threading.Lock() + self._last_failure_log_ts = 0.0 + self._failure_log_interval_s = float( + self.radio_config.get("failure_log_interval_seconds", 10.0) + ) + self._reconnect_base_delay_s = float( + self.radio_config.get("reconnect_base_delay_seconds", 0.5) + ) + self._reconnect_max_delay_s = float( + self.radio_config.get("reconnect_max_delay_seconds", 15.0) + ) + self._reconnect_max_attempts = int(self.radio_config.get("reconnect_max_attempts", 0)) + self._post_connect_settle_s = max( + 0.0, + float( + self.radio_config.get( + "post_connect_settle_seconds", + POST_CONNECT_SETTLE_SECONDS, + ) + ), + ) + self._post_connect_configure_retries = max( + 0, + int( + self.radio_config.get( + "post_connect_configure_retries", + POST_CONNECT_CONFIGURE_RETRIES, + ) + ), + ) + self._post_connect_configure_retry_backoff_s = max( + 0.0, + float( + self.radio_config.get( + "post_connect_configure_retry_backoff_seconds", + POST_CONNECT_CONFIGURE_RETRY_BACKOFF_SECONDS, + ) + ), + ) # Callbacks self.on_frame_received = on_frame_received @@ -295,11 +370,23 @@ def __init__( self._callback_executor: Optional[ThreadPoolExecutor] = None # Response handling + # Single-flight SetHardware command execution (send -> wait -> return) + self._command_lock = threading.Lock() + # Serialize all UART writes so frame bytes from different callers/threads + # (TX worker vs SetHardware/control paths) cannot interleave. + self._serial_write_lock = threading.Lock() self._response_event = threading.Event() self._pending_response: Optional[tuple[int, bytes]] = None self._response_lock = threading.Lock() + self._expected_response_subcmds: Optional[set[int]] = None + self._active_request_subcmd: Optional[int] = None + self._response_queue: deque[tuple[int, bytes]] = deque(maxlen=32) # TX completion tracking + # Single-flight DATA transmit: the modem holds one pending TX, so only one + # frame may be in flight at a time. Serializing here prevents a second frame + # being written mid-transmit and rejected with TX_BUSY (0x07). + self._tx_inflight_lock = threading.Lock() self._tx_done_event = threading.Event() self._tx_done_result: Optional[bool] = None @@ -361,7 +448,44 @@ def connect(self) -> bool: Returns: True if connection successful, False otherwise """ + with self._connection_lock: + self.stop_event.clear() + self.is_connected = False + if not self._open_serial_and_start_threads(): + return False + if not self._run_post_connect_handshake(): + self._close_serial_connection() + self.is_connected = False + return False + self.is_connected = True + self._reconnecting_event.clear() + self._degraded = False + self._degraded_reason = None + return True + + def disconnect(self): + """Disconnect from serial port and stop threads""" + with self._connection_lock: + self.stop_event.set() + self.is_connected = False + self._degraded = False + self._degraded_reason = None + self._reconnecting_event.clear() + self._close_serial_connection() + + self._stop_io_threads(join_timeout=2.0) + self._stop_reconnect_thread(join_timeout=2.0) + + if self._callback_executor is not None: + self._callback_executor.shutdown(wait=False) + self._callback_executor = None + + logger.info(f"KISS modem disconnected from {self.port}") + + def _open_serial_and_start_threads(self) -> bool: + """Open serial device and start RX/TX workers.""" try: + self._shutting_down = False self.serial_conn = serial.Serial( port=self.port, baudrate=self.baudrate, @@ -370,68 +494,154 @@ def connect(self) -> bool: parity=serial.PARITY_NONE, stopbits=serial.STOPBITS_ONE, ) + self.is_connected = False - self.is_connected = True - self.stop_event.clear() - - # Start communication threads self.rx_thread = threading.Thread(target=self._rx_worker, daemon=True) self.tx_thread = threading.Thread(target=self._tx_worker, daemon=True) - self.rx_thread.start() self.tx_thread.start() + logger.info("KISS modem connected to %s at %s baud", self.port, self.baudrate) - logger.info(f"KISS modem connected to {self.port} at {self.baudrate} baud") - - # Auto-configure if requested - if self.auto_configure and self.radio_config: - if not self.configure_radio(): - logger.warning("Auto-configuration failed") - return False - - # Query modem info - self._query_modem_info() - - # Set KISS TXDELAY so key-up delay is not the firmware default 500ms (reduces - # round-trip latency for repeaters). Value in 10ms units; default 50ms. - tx_delay_ms = self.radio_config.get("tx_delay_ms", 50) - self._set_kiss_tx_delay(tx_delay_ms) - if "kiss_persistence" in self.radio_config: - self.set_kiss_persistence(self.radio_config["kiss_persistence"]) - if "kiss_slottime_ms" in self.radio_config: - self.set_kiss_slottime(self.radio_config["kiss_slottime_ms"]) - if "kiss_txtail_ms" in self.radio_config: - self.set_kiss_txtail(self.radio_config["kiss_txtail_ms"]) - if "kiss_full_duplex" in self.radio_config: - self.set_kiss_full_duplex(bool(self.radio_config["kiss_full_duplex"])) + if not self._wait_for_modem_ready(): + logger.warning("KISS modem did not become ready after reconnect") + return False return True - except Exception as e: - logger.error(f"Failed to connect to {self.port}: {e}") + logger.error("Failed to connect to %s: %s", self.port, e) self.is_connected = False return False - def disconnect(self): - """Disconnect from serial port and stop threads""" - self.is_connected = False - self.stop_event.set() + def _run_post_connect_handshake(self) -> bool: + """Run modem setup steps after serial open.""" + if self._post_connect_settle_s > 0: + logger.debug( + "Post-connect settle delay %.2fs before SetHardware handshake", + self._post_connect_settle_s, + ) + time.sleep(self._post_connect_settle_s) - # Wait for threads to finish - if self.rx_thread and self.rx_thread.is_alive(): - self.rx_thread.join(timeout=2.0) - if self.tx_thread and self.tx_thread.is_alive(): - self.tx_thread.join(timeout=2.0) + # Auto-configure if requested + if self.auto_configure and self.radio_config: + if not self._configure_radio_with_retries(): + logger.warning("Auto-configuration failed after retries") + return False - if self._callback_executor is not None: - self._callback_executor.shutdown(wait=False) - self._callback_executor = None + # Query modem info + self._query_modem_info() + + # Set KISS TXDELAY so key-up delay is not the firmware default 500ms. + tx_delay_ms = self.radio_config.get("tx_delay_ms", 50) + self._set_kiss_tx_delay(tx_delay_ms) + if "kiss_persistence" in self.radio_config: + self.set_kiss_persistence(self.radio_config["kiss_persistence"]) + if "kiss_slottime_ms" in self.radio_config: + self.set_kiss_slottime(self.radio_config["kiss_slottime_ms"]) + if "kiss_txtail_ms" in self.radio_config: + self.set_kiss_txtail(self.radio_config["kiss_txtail_ms"]) + if "kiss_full_duplex" in self.radio_config: + self.set_kiss_full_duplex(bool(self.radio_config["kiss_full_duplex"])) + return True - # Close serial connection - if self.serial_conn and self.serial_conn.is_open: - self.serial_conn.close() + def _close_serial_connection(self) -> None: + """Close serial handle without waiting for worker threads.""" + conn = self.serial_conn + self.serial_conn = None + if conn and conn.is_open: + try: + conn.close() + except Exception: + pass + + def _stop_io_threads(self, join_timeout: float = 2.0) -> None: + """Join RX/TX threads, skipping current thread to avoid deadlock.""" + current = threading.current_thread() + if self.rx_thread and self.rx_thread.is_alive() and self.rx_thread is not current: + self.rx_thread.join(timeout=join_timeout) + if self.tx_thread and self.tx_thread.is_alive() and self.tx_thread is not current: + self.tx_thread.join(timeout=join_timeout) + + def _stop_reconnect_thread(self, join_timeout: float = 2.0) -> None: + """Join reconnect thread if it is running.""" + current = threading.current_thread() + if ( + self.reconnect_thread + and self.reconnect_thread.is_alive() + and self.reconnect_thread is not current + ): + self.reconnect_thread.join(timeout=join_timeout) + + def cleanup(self) -> None: + """Release resources and abort any in-flight blocking waits.""" + self._shutting_down = True + self._response_event.set() + self._tx_done_event.set() + self.disconnect() - logger.info(f"KISS modem disconnected from {self.port}") + def _wait_for_modem_ready(self) -> bool: + """ + Perform serial resync and readiness probing after opening the port. + """ + if not self.serial_conn: + return False + + if self.post_open_delay_ms > 0: + threading.Event().wait(self.post_open_delay_ms / 1000.0) + + try: + self.serial_conn.reset_input_buffer() + self.serial_conn.reset_output_buffer() + except Exception as e: + logger.debug("Serial buffer reset skipped/failed: %s", e) + + if self.usb_reset_on_connect and hasattr(self.serial_conn, "dtr"): + try: + self.serial_conn.dtr = False + threading.Event().wait(0.1) + self.serial_conn.dtr = True + except Exception as e: + logger.debug("USB DTR toggle skipped/failed: %s", e) + + try: + self.serial_conn.write(bytes([KISS_FEND])) + self.serial_conn.flush() + except Exception as e: + logger.debug("KISS parser resync write failed: %s", e) + + backoff_seconds = [0.5, 1.0, 2.0, 2.0, 2.0] + deadline = time.monotonic() + self.startup_retry_budget_sec + + for attempt in range(self.connect_retries): + if self._shutting_down: + return False + resp = self._send_command(CMD_PING, timeout=1.0) + if resp and resp[0] == RESP_PONG: + logger.debug("KISS modem responded to ping on attempt %d", attempt + 1) + return True + + remaining = deadline - time.monotonic() + if remaining <= 0: + break + delay = min(backoff_seconds[min(attempt, len(backoff_seconds) - 1)], remaining) + threading.Event().wait(delay) + + return False + + def _configure_radio_with_retries(self) -> bool: + """Attempt auto-configuration with bounded retries/backoff.""" + deadline = time.monotonic() + self.startup_retry_budget_sec + backoff_seconds = [0.5, 1.0, 2.0, 2.0, 2.0] + for attempt in range(self.connect_retries): + if self._shutting_down: + return False + if self.configure_radio(): + return True + remaining = deadline - time.monotonic() + if remaining <= 0: + break + delay = min(backoff_seconds[min(attempt, len(backoff_seconds) - 1)], remaining) + threading.Event().wait(delay) + return False def _write_frame(self, frame: bytes) -> bool: """ @@ -439,29 +649,109 @@ def _write_frame(self, frame: bytes) -> bool: Ensures the entire frame (including trailing FEND) is written; retries on partial write so we never send a truncated frame. + This method is atomic across threads so frame bytes cannot interleave + on the UART when multiple callers write concurrently. Returns: True if all bytes written, False on error or incomplete write. """ - if not self.serial_conn or not self.serial_conn.is_open: - return False - offset = 0 - while offset < len(frame): - try: - n = self.serial_conn.write(frame[offset:]) - if n is None or n <= 0: - logger.error("Serial write returned %s", n) + with self._serial_write_lock: + if not self.serial_conn or not self.serial_conn.is_open: + self._mark_serial_failure("Serial connection closed during write") + return False + offset = 0 + while offset < len(frame): + try: + n = self.serial_conn.write(frame[offset:]) + if n is None or n <= 0: + logger.error("Serial write returned %s", n) + self._mark_serial_failure(f"Serial write returned {n}") + return False + offset += n + except Exception as e: + logger.error("Serial write error: %s", e) + self._mark_serial_failure(f"Serial write failed: {e}") return False - offset += n + try: + self.serial_conn.flush() except Exception as e: - logger.error("Serial write error: %s", e) + logger.error("Serial flush error: %s", e) + self._mark_serial_failure(f"Serial flush failed: {e}") return False - try: - self.serial_conn.flush() - except Exception as e: - logger.error("Serial flush error: %s", e) - return False - return True + return True + + def _mark_serial_failure(self, reason: str) -> None: + """Transition to degraded mode and trigger reconnect loop once.""" + if self.stop_event.is_set(): + return + + now = time.time() + with self._failure_log_lock: + should_log = (now - self._last_failure_log_ts) >= self._failure_log_interval_s + if should_log: + self._last_failure_log_ts = now + logger.warning("Marking KISS serial link degraded: %s", reason) + + with self._connection_lock: + self._degraded = True + self._degraded_reason = reason + self.is_connected = False + self._close_serial_connection() + + # Wake any in-flight DATA sender so it fails fast instead of waiting out the + # full TX_DONE timeout on a link that is already gone (cleanup() does the same). + self._tx_done_event.set() + + self._start_reconnect_worker() + + def _start_reconnect_worker(self) -> None: + """Start reconnect thread once.""" + if self.stop_event.is_set() or self._reconnecting_event.is_set(): + return + self._reconnecting_event.set() + self.reconnect_thread = threading.Thread(target=self._reconnect_worker, daemon=True) + self.reconnect_thread.start() + + def _reconnect_worker(self) -> None: + """Reconnect with exponential backoff and re-run modem handshake.""" + attempts = 0 + while not self.stop_event.is_set(): + attempts += 1 + if self._reconnect_max_attempts > 0 and attempts > self._reconnect_max_attempts: + logger.error( + "KISS modem reconnect exhausted after %s attempts (last reason: %s)", + self._reconnect_max_attempts, + self._degraded_reason or "unknown", + ) + break + + delay = min( + self._reconnect_base_delay_s * (2 ** max(0, attempts - 1)), + self._reconnect_max_delay_s, + ) + jitter = random.uniform(0.0, min(0.25, delay * 0.2)) + if attempts > 1: + time.sleep(delay + jitter) + + with self._connection_lock: + if self.stop_event.is_set(): + break + self.is_connected = False + self._stop_io_threads(join_timeout=0.5) + if not self._open_serial_and_start_threads(): + continue + if not self._run_post_connect_handshake(): + self._close_serial_connection() + self.is_connected = False + continue + self.is_connected = True + self._degraded = False + self._degraded_reason = None + logger.info("KISS modem serial reconnect successful on attempt %s", attempts) + self._reconnecting_event.clear() + return + + self._reconnecting_event.clear() def _set_kiss_tx_delay(self, delay_ms: int) -> None: """ @@ -599,8 +889,8 @@ def configure_radio( Returns: True if configuration successful, False otherwise """ - if not self.is_connected: - logger.error("Cannot configure radio: not connected") + if not self.serial_conn or not self.serial_conn.is_open: + logger.error("Cannot configure radio: serial link not ready") return False try: @@ -696,25 +986,61 @@ def send_frame(self, data: bytes) -> bool: def send_frame_and_wait(self, data: bytes, timeout: float = RESPONSE_TIMEOUT) -> bool: """ - Send a data frame and wait for TX_DONE response + Send a data frame and wait for the modem's TX_DONE. + + DATA transmits are single-flight: the modem holds only one pending TX, so + only one frame may be in flight at a time. Concurrent callers serialize on + ``_tx_inflight_lock`` so a second frame is never written mid-transmit and + rejected with TX_BUSY (0x07). Args: data: Raw packet data to send - timeout: Timeout in seconds to wait for TX_DONE + timeout: Base timeout in seconds to wait for TX_DONE; extended to cover + the estimated airtime of long frames. Returns: - True if transmission successful, False otherwise + True if transmission completed (TX_DONE ok), False otherwise. """ - self._tx_done_event.clear() - self._tx_done_result = None + if self._shutting_down: + return False - if not self.send_frame(data): + # Don't enqueue DATA while the link is down/reconnecting (mirror _send_command). + in_reconnect_thread = threading.current_thread() is self.reconnect_thread + if (self._reconnecting_event.is_set() or self._degraded) and not in_reconnect_thread: return False - # Wait for TX_DONE response - if self._tx_done_event.wait(timeout): - return self._tx_done_result or False - else: + # Extend the wait to cover real airtime; a high-SF flood advert can exceed the + # flat command timeout, which would otherwise look like a spurious TX_DONE timeout. + try: + airtime_s = PacketTimingUtils.estimate_airtime_ms(len(data), self.radio_config) / 1000.0 + except Exception: + airtime_s = 0.0 + effective_timeout = max(timeout, airtime_s + TX_DONE_TIMEOUT_MARGIN_S) + + with self._tx_inflight_lock: + self._tx_done_event.clear() + self._tx_done_result = None + + if not self.send_frame(data): + return False + + # Poll in short slices so a shutdown or mid-flight link failure returns + # promptly instead of stalling the full timeout. TX_DONE (success/fail) + # and TX_BUSY both set the event; a link drop sets it via + # _mark_serial_failure / cleanup. + deadline = time.monotonic() + effective_timeout + while not self._shutting_down: + remaining = deadline - time.monotonic() + if remaining <= 0: + break + if self._tx_done_event.wait(min(0.1, remaining)): + return self._tx_done_result or False + if self._degraded or self.stop_event.is_set(): + return False + + if self._shutting_down: + return False + logger.warning("TX_DONE timeout") return False @@ -734,33 +1060,86 @@ def _send_command( Returns: Tuple of (response_sub_cmd, response_data) or None on timeout """ - with self._response_lock: - self._response_event.clear() - self._pending_response = None - - # SetHardware frame: type 0x06, payload = sub_cmd (1 byte) + data - kiss_frame = self._encode_kiss_frame(KISS_CMD_SETHARDWARE, bytes([sub_cmd]) + data) + if self._shutting_down: + return None - if not self._write_frame(kiss_frame): - logger.warning("SetHardware frame write failed") + # Ensure SetHardware requests are single-flight. This prevents concurrent + # callers from clearing the shared waiter state or stealing responses. + in_reconnect_thread = threading.current_thread() is self.reconnect_thread + reconnecting_from_non_reconnect_thread = ( + self._reconnecting_event.is_set() and not in_reconnect_thread + ) + degraded_from_non_reconnect_thread = self._degraded and not in_reconnect_thread + if reconnecting_from_non_reconnect_thread or degraded_from_non_reconnect_thread: return None - # Wait for response - if self._response_event.wait(timeout): + with self._command_lock: + expected = sub_cmd | 0x80 + acceptable: set[int] = {expected, HW_RESP_ERROR} + if sub_cmd in self._SETHW_ALLOW_OK_FOR: + acceptable.add(HW_RESP_OK) + + # Check queued responses first (late/out-of-order arrivals). with self._response_lock: - return self._pending_response - else: - logger.warning(f"SetHardware sub_cmd 0x{sub_cmd:02X} timeout") - return None + if self._response_queue: + n = len(self._response_queue) + matched: Optional[tuple[int, bytes]] = None + for _ in range(n): + resp_sub, resp_payload = self._response_queue.popleft() + if matched is None and resp_sub in acceptable: + matched = (resp_sub, resp_payload) + else: + self._response_queue.append((resp_sub, resp_payload)) + if matched is not None: + return matched + + self._response_event.clear() + self._pending_response = None + self._expected_response_subcmds = acceptable + self._active_request_subcmd = sub_cmd - def get_radio_config(self) -> Optional[Dict[str, Any]]: + try: + # SetHardware frame: type 0x06, payload = sub_cmd (1 byte) + data + kiss_frame = self._encode_kiss_frame(KISS_CMD_SETHARDWARE, bytes([sub_cmd]) + data) + + if not self._write_frame(kiss_frame): + logger.warning("SetHardware frame write failed") + return None + + # Wait for response with shutdown-aware polling. + deadline = time.monotonic() + timeout + while not self._shutting_down: + remaining = deadline - time.monotonic() + if remaining <= 0: + break + if self._response_event.wait(min(0.1, remaining)): + with self._response_lock: + return self._pending_response + + if self._shutting_down: + return None + + logger.warning(f"SetHardware sub_cmd 0x{sub_cmd:02X} timeout") + return None + finally: + with self._response_lock: + self._expected_response_subcmds = None + self._active_request_subcmd = None + + def get_radio_config(self, timeout: Optional[float] = None) -> Optional[Dict[str, Any]]: """ - Get current radio configuration from modem + Get current radio configuration from modem. + + Blocks the caller thread for up to ``timeout`` seconds (default RESPONSE_TIMEOUT). + + Args: + timeout: SetHardware response wait in seconds, or None for RESPONSE_TIMEOUT. Returns: Dict with frequency, bandwidth, sf, cr, or None on error """ - resp = self._send_command(CMD_GET_RADIO) + t = timeout if timeout is not None else RESPONSE_TIMEOUT + resp = self._send_command(CMD_GET_RADIO, timeout=t) if resp and resp[0] == RESP_RADIO and len(resp[1]) >= 10: freq, bw, sf, cr = struct.unpack(" bool: logger.error(f"Error setting TX power: {e}") return False - def get_tx_power(self) -> Optional[int]: - """Get current TX power in dBm""" - resp = self._send_command(CMD_GET_TX_POWER) + def get_tx_power(self, timeout: Optional[float] = None) -> Optional[int]: + """Get current TX power in dBm. + + Blocks the caller thread for up to ``timeout`` seconds (default RESPONSE_TIMEOUT). + + Args: + timeout: SetHardware response wait in seconds, or None for RESPONSE_TIMEOUT. + """ + t = timeout if timeout is not None else RESPONSE_TIMEOUT + resp = self._send_command(CMD_GET_TX_POWER, timeout=t) if resp and resp[0] == RESP_TX_POWER and len(resp[1]) >= 1: return resp[1][0] return None - def get_current_rssi(self) -> int: - """Get current RSSI from modem""" - resp = self._send_command(CMD_GET_CURRENT_RSSI) + def get_current_rssi(self, timeout: Optional[float] = None) -> int: + """Get current RSSI from modem. + + Blocks the caller thread for up to ``timeout`` seconds (default RESPONSE_TIMEOUT). + + Args: + timeout: SetHardware response wait in seconds, or None for RESPONSE_TIMEOUT. + """ + t = timeout if timeout is not None else RESPONSE_TIMEOUT + resp = self._send_command(CMD_GET_CURRENT_RSSI, timeout=t) if resp and resp[0] == RESP_CURRENT_RSSI and len(resp[1]) >= 1: # RSSI is signed byte rssi = resp[1][0] @@ -810,9 +1203,16 @@ def get_current_rssi(self) -> int: return rssi return -999 - def is_channel_busy(self) -> bool: - """Check if channel is busy""" - resp = self._send_command(CMD_IS_CHANNEL_BUSY) + def is_channel_busy(self, timeout: Optional[float] = None) -> bool: + """Check if channel is busy. + + Blocks the caller thread for up to ``timeout`` seconds (default RESPONSE_TIMEOUT). + + Args: + timeout: SetHardware response wait in seconds, or None for RESPONSE_TIMEOUT. + """ + t = timeout if timeout is not None else RESPONSE_TIMEOUT + resp = self._send_command(CMD_IS_CHANNEL_BUSY, timeout=t) if resp and resp[0] == RESP_CHANNEL_BUSY and len(resp[1]) >= 1: return resp[1][0] == 0x01 return False @@ -836,9 +1236,16 @@ def get_airtime(self, packet_length: int, timeout: Optional[float] = None) -> Op return struct.unpack(" Optional[int]: - """Get noise floor in dBm""" - resp = self._send_command(CMD_GET_NOISE_FLOOR) + def get_noise_floor(self, timeout: Optional[float] = None) -> Optional[int]: + """Get noise floor in dBm. + + Blocks the caller thread for up to ``timeout`` seconds (default RESPONSE_TIMEOUT). + + Args: + timeout: SetHardware response wait in seconds, or None for RESPONSE_TIMEOUT. + """ + t = timeout if timeout is not None else RESPONSE_TIMEOUT + resp = self._send_command(CMD_GET_NOISE_FLOOR, timeout=t) if resp and resp[0] == RESP_NOISE_FLOOR and len(resp[1]) >= 2: # Noise floor is signed 16-bit noise = struct.unpack(" Optional[int]: return noise return None - def get_modem_stats(self) -> Optional[Dict[str, int]]: + def get_modem_stats(self, timeout: Optional[float] = None) -> Optional[Dict[str, int]]: """ - Get modem statistics + Get modem statistics. + + Blocks the caller thread for up to ``timeout`` seconds (default RESPONSE_TIMEOUT). + + Args: + timeout: SetHardware response wait in seconds, or None for RESPONSE_TIMEOUT. Returns: Dict with rx, tx, errors counts or None on error """ - resp = self._send_command(CMD_GET_STATS) + t = timeout if timeout is not None else RESPONSE_TIMEOUT + resp = self._send_command(CMD_GET_STATS, timeout=t) if resp and resp[0] == RESP_STATS and len(resp[1]) >= 12: rx, tx, errors = struct.unpack(" Optional[int]: - """Get battery voltage in millivolts""" - resp = self._send_command(CMD_GET_BATTERY) + def get_battery(self, timeout: Optional[float] = None) -> Optional[int]: + """Get battery voltage in millivolts. + + Blocks the caller thread for up to ``timeout`` seconds (default RESPONSE_TIMEOUT). + + Args: + timeout: SetHardware response wait in seconds, or None for RESPONSE_TIMEOUT. + """ + t = timeout if timeout is not None else RESPONSE_TIMEOUT + resp = self._send_command(CMD_GET_BATTERY, timeout=t) if resp and resp[0] == RESP_BATTERY and len(resp[1]) >= 2: return struct.unpack(" bool: if not self.is_connected: return False try: - return self.ping() + healthy = self.ping() + if not healthy: + self._mark_serial_failure("Health check ping failed") + return healthy except Exception as e: logger.debug(f"KISS modem health check failed: {e}") + self._mark_serial_failure(f"Health check exception: {e}") return False # Optional host-side LBT (only when lbt_enabled, e.g. full-duplex on half-duplex link) @@ -1163,13 +1587,16 @@ async def send(self, data: bytes) -> Optional[Dict[str, Any]]: if self.lbt_enabled: _, lbt_backoff_delays = await self._prepare_for_tx_lbt() - success = self.send_frame(data) + # Wait for modem-level TX_DONE instead of treating queueing as success. + # Run the blocking wait off the event loop. + success = await asyncio.to_thread(self.send_frame_and_wait, data, RESPONSE_TIMEOUT) if not success: - raise Exception("Failed to send frame via KISS modem") + raise Exception("Failed to send frame via KISS modem (no TX_DONE)") # Use short timeout for GET_AIRTIME so TX path is not blocked if modem # is busy or unresponsive (avoids 5s stall and subsequent bad state). - airtime = self.get_airtime(len(data), timeout=1.0) + # Run off the event loop: get_airtime uses blocking SetHardware wait. + airtime = await asyncio.to_thread(self.get_airtime, len(data), 1.0) if airtime is None: airtime = int(PacketTimingUtils.estimate_airtime_ms(len(data), self.radio_config)) return { @@ -1226,10 +1653,10 @@ def get_stats(self) -> Dict[str, Any]: """Get interface statistics""" return self.stats.copy() - def get_status(self) -> Dict[str, Any]: - """Get radio status. Uses cached config/stats where possible.""" - cfg = self.get_radio_config() - tx_power = self.get_tx_power() + def _sync_get_status(self, timeout: Optional[float] = None) -> Dict[str, Any]: + """Build radio status dict (blocking SetHardware reads for config and TX power).""" + cfg = self.get_radio_config(timeout=timeout) + tx_power = self.get_tx_power(timeout=timeout) status: Dict[str, Any] = { "initialized": self.is_connected, "frequency": cfg["frequency"] if cfg else self.radio_config.get("frequency", 0), @@ -1248,6 +1675,54 @@ def get_status(self) -> Dict[str, Any]: } return status + def get_status(self, timeout: Optional[float] = None) -> Dict[str, Any]: + """Get radio status. Queries modem for config and TX power; blocks the caller thread. + + Args: + timeout: Per-query SetHardware timeout in seconds for each modem read + (default RESPONSE_TIMEOUT). + """ + return self._sync_get_status(timeout) + + async def get_status_async(self, timeout: Optional[float] = None) -> Dict[str, Any]: + """Get radio status without blocking the asyncio event loop. + + Runs blocking modem I/O in ``asyncio``'s default thread pool executor. + """ + return await asyncio.to_thread(self._sync_get_status, timeout) + + async def get_radio_config_async( + self, timeout: Optional[float] = None + ) -> Optional[Dict[str, Any]]: + """Async-safe :meth:`get_radio_config`; runs blocking modem I/O in a worker thread.""" + return await asyncio.to_thread(self.get_radio_config, timeout) + + async def get_tx_power_async(self, timeout: Optional[float] = None) -> Optional[int]: + """Async-safe :meth:`get_tx_power`; runs blocking modem I/O in a worker thread.""" + return await asyncio.to_thread(self.get_tx_power, timeout) + + async def get_current_rssi_async(self, timeout: Optional[float] = None) -> int: + """Async-safe :meth:`get_current_rssi`; runs blocking modem I/O in a worker thread.""" + return await asyncio.to_thread(self.get_current_rssi, timeout) + + async def is_channel_busy_async(self, timeout: Optional[float] = None) -> bool: + """Async-safe :meth:`is_channel_busy`; runs blocking modem I/O in a worker thread.""" + return await asyncio.to_thread(self.is_channel_busy, timeout) + + async def get_noise_floor_async(self, timeout: Optional[float] = None) -> Optional[int]: + """Async-safe :meth:`get_noise_floor`; runs blocking modem I/O in a worker thread.""" + return await asyncio.to_thread(self.get_noise_floor, timeout) + + async def get_modem_stats_async( + self, timeout: Optional[float] = None + ) -> Optional[Dict[str, int]]: + """Async-safe :meth:`get_modem_stats`; runs blocking modem I/O in a worker thread.""" + return await asyncio.to_thread(self.get_modem_stats, timeout) + + async def get_battery_async(self, timeout: Optional[float] = None) -> Optional[int]: + """Async-safe :meth:`get_battery`; runs blocking modem I/O in a worker thread.""" + return await asyncio.to_thread(self.get_battery, timeout) + # KISS frame encoding/decoding def _encode_kiss_frame(self, cmd: int, data: bytes) -> bytes: @@ -1424,23 +1899,56 @@ def _process_received_frame(self): self._tx_done_event.set() elif sub_cmd == HW_RESP_ERROR: - if len(payload) >= 1: + err_code = payload[0] if len(payload) >= 1 else None + if err_code is not None: self.stats["errors"] += 1 - logger.warning(f"Modem error: 0x{payload[0]:02X}") - with self._response_lock: - self._pending_response = (sub_cmd, payload) - self._response_event.set() + logger.warning(f"Modem error: 0x{err_code:02X}") + if err_code == HW_ERR_TX_BUSY: + # TX_BUSY is a DATA-transmit rejection, not a SetHardware response. + # Wake the in-flight DATA sender so it fails fast (and can retry) + # rather than stalling until the TX_DONE timeout, and keep it out + # of the SetHardware response path (where it would be mis-consumed + # as the in-flight command's error reply). + self._tx_done_result = False + self._tx_done_event.set() + else: + with self._response_lock: + expected = self._expected_response_subcmds + if expected is not None and sub_cmd in expected: + self._pending_response = (sub_cmd, payload) + self._response_event.set() + else: + if len(self._response_queue) == self._response_queue.maxlen: + logger.debug( + "Dropping oldest SetHardware response (queue full); " + "sub_cmd=0x%02X", + sub_cmd, + ) + self._response_queue.append((sub_cmd, payload)) else: # Other response sub-commands (Identity, Radio, OK, etc.) with self._response_lock: - self._pending_response = (sub_cmd, payload) - self._response_event.set() + expected = self._expected_response_subcmds + if expected is not None and sub_cmd in expected: + self._pending_response = (sub_cmd, payload) + self._response_event.set() + else: + if len(self._response_queue) == self._response_queue.maxlen: + logger.debug( + "Dropping oldest SetHardware response (queue full); sub_cmd=0x%02X", + sub_cmd, + ) + self._response_queue.append((sub_cmd, payload)) # cmd 0xFF (Return) has port=15 so is already discarded above def _rx_worker(self): """Background thread for receiving data""" - while not self.stop_event.is_set() and self.is_connected: + while ( + not self.stop_event.is_set() + and self.serial_conn is not None + and self.serial_conn.is_open + ): try: if self.serial_conn and self.serial_conn.in_waiting > 0: data = self.serial_conn.read(self.serial_conn.in_waiting) @@ -1454,11 +1962,16 @@ def _rx_worker(self): except Exception as e: if self.is_connected: logger.error(f"RX worker error: {e}") + self._mark_serial_failure(f"RX worker error: {e}") break def _tx_worker(self): """Background thread for sending data""" - while not self.stop_event.is_set() and self.is_connected: + while ( + not self.stop_event.is_set() + and self.serial_conn is not None + and self.serial_conn.is_open + ): try: if self.tx_buffer: frame = self.tx_buffer.popleft() @@ -1471,12 +1984,14 @@ def _tx_worker(self): logger.warning("TX frame write failed, dropping frame") else: logger.warning("Serial connection not open") + self._mark_serial_failure("Serial connection not open in TX worker") else: threading.Event().wait(0.01) except Exception as e: if self.is_connected: logger.error(f"TX worker error: {e}") + self._mark_serial_failure(f"TX worker error: {e}") break def __enter__(self): @@ -1491,6 +2006,6 @@ def __exit__(self, exc_type, exc_val, exc_tb): def __del__(self): """Destructor to ensure cleanup""" try: - self.disconnect() + self.cleanup() except Exception: pass diff --git a/src/pymc_core/node/dispatcher.py b/src/pymc_core/node/dispatcher.py index a0525b6..7042959 100644 --- a/src/pymc_core/node/dispatcher.py +++ b/src/pymc_core/node/dispatcher.py @@ -544,6 +544,10 @@ async def _send_packet_immediate( self._log(f"Radio transmit error: {e}") self.state = DispatcherState.IDLE return False + if tx_metadata is None: + self._log("Radio transmit returned no confirmation metadata") + self.state = DispatcherState.IDLE + return False # Log what we sent type_name = PAYLOAD_TYPES.get(payload_type, f"UNKNOWN_{payload_type}") route_name = ROUTE_TYPES.get(packet.get_route_type(), f"UNKNOWN_{packet.get_route_type()}") @@ -676,7 +680,7 @@ async def run_forever(self) -> None: if health_check_counter >= 60: health_check_counter = 0 if hasattr(self.radio, "check_radio_health"): - self.radio.check_radio_health() + await asyncio.to_thread(self.radio.check_radio_health) # With callback-based RX, just do maintenance tasks await asyncio.sleep(1.0) # Check every second for cleanup diff --git a/src/pymc_core/node/handlers/__init__.py b/src/pymc_core/node/handlers/__init__.py index 363176c..7188617 100644 --- a/src/pymc_core/node/handlers/__init__.py +++ b/src/pymc_core/node/handlers/__init__.py @@ -4,6 +4,7 @@ from .ack import AckHandler from .advert import AdvertHandler +from .anon_request import AnonRateLimiter, AnonRequestHandler from .base import BaseHandler from .control import ControlHandler from .group_text import GroupTextHandler @@ -17,6 +18,8 @@ __all__ = [ "BaseHandler", + "AnonRequestHandler", + "AnonRateLimiter", "TextMessageHandler", "AdvertHandler", "AckHandler", diff --git a/src/pymc_core/node/handlers/anon_request.py b/src/pymc_core/node/handlers/anon_request.py new file mode 100644 index 0000000..aa16e25 --- /dev/null +++ b/src/pymc_core/node/handlers/anon_request.py @@ -0,0 +1,304 @@ +"""Anonymous-request dispatch handler for repeaters/room servers. + +Mirrors the firmware ``MyMesh::onAnonDataRecv`` (``simple_repeater/MyMesh.cpp``): +a single ``PAYLOAD_TYPE_ANON_REQ`` packet is decrypted and the first byte after +the 4-byte timestamp (``data[4]``) selects the handler: + +- ``0`` or ``>= 0x20`` -> login request (delegated to the wrapped + :class:`LoginServerHandler`, unchanged). +- ``ANON_REQ_TYPE_REGIONS`` (0x01) -> comma-separated region names. +- ``ANON_REQ_TYPE_OWNER`` (0x02) -> ``"node_name\nowner_info"``. +- ``ANON_REQ_TYPE_BASIC`` (0x03) -> clock + a feature-flags byte. + +The regions/owner/basic responders only answer route-direct requests and are +gated behind a shared :class:`AnonRateLimiter` (mirroring the firmware +``anon_limiter``) so the node does not become a flood amplifier. + +This is a pure protocol handler: the application supplies the actual data +(region names, owner info, feature flags, clock) via callbacks. +""" + +import struct +import time +from typing import Callable, Optional, Tuple + +from ...protocol import CryptoUtils, Identity, Packet, PacketBuilder +from ...protocol.constants import MAX_PACKET_PAYLOAD, PAYLOAD_TYPE_ANON_REQ, PAYLOAD_TYPE_RESPONSE +from ...protocol.packet_utils import PathUtils +from .base import BaseHandler +from .login_server import LoginServerHandler + +# Anonymous-request sub-types (first byte of an ANON_REQ payload after the +# 4-byte timestamp). Mirrors ``pymc_core.companion.constants`` but defined here +# to avoid a circular import (the companion package imports node.handlers). +ANON_REQ_TYPE_REGIONS = 0x01 +ANON_REQ_TYPE_OWNER = 0x02 +ANON_REQ_TYPE_BASIC = 0x03 + +# Server response delay (ms) — matches firmware SERVER_RESPONSE_DELAY. +SERVER_RESPONSE_DELAY_MS = 300 + + +class AnonRateLimiter: + """Fixed-window limiter for anonymous discovery replies. + + Direct port of the firmware ``RateLimiter`` (``simple_repeater/RateLimiter.h``), + configured to match ``anon_limiter(4, 180)`` — at most 4 replies per fixed + 3-minute window. One instance is shared across all identities so the node's + total anon-reply rate is bounded regardless of which identity is targeted. + """ + + def __init__(self, maximum: int = 4, secs: float = 180.0): + self._maximum = maximum + self._secs = secs + self._start_timestamp = 0.0 + self._count = 0 + + def allow(self, now: Optional[float] = None) -> bool: + """Return ``True`` if under the cap, else ``False`` (mirrors RateLimiter.h).""" + if now is None: + now = time.time() + if now < self._start_timestamp + self._secs: + self._count += 1 + if self._count > self._maximum: + return False # deny + else: # window expired -> reset + self._start_timestamp = now + self._count = 1 + return True + + +class AnonRequestHandler(BaseHandler): + """Decrypts ANON_REQ packets and dispatches on the sub-type byte. + + Wraps an existing :class:`LoginServerHandler` so the login/password path is + byte-for-byte unchanged; the new regions/owner/basic responders are handled + here using application-supplied data callbacks. + """ + + @staticmethod + def payload_type() -> int: + return PAYLOAD_TYPE_ANON_REQ + + def __init__( + self, + local_identity, + log_fn: Callable[[str], None], + login_handler: LoginServerHandler, + anon_limiter: AnonRateLimiter, + *, + region_names_fn: Optional[Callable[[], str]] = None, + owner_info_fn: Optional[Callable[[], Tuple[str, str]]] = None, + features_fn: Optional[Callable[[], int]] = None, + clock_fn: Optional[Callable[[], int]] = None, + ): + """Initialize the dispatcher. + + Args: + local_identity: Server's local identity. + log_fn: Logging function. + login_handler: Existing login handler to delegate password logins to. + anon_limiter: Shared rate limiter for regions/owner/basic replies. + region_names_fn: Returns the comma-separated region-names string. + owner_info_fn: Returns ``(node_name, owner_info)``. + features_fn: Returns the feature-flags byte (bit0 = bridge, + bit7 = forwarding disabled). + clock_fn: Returns the current clock as a Unix timestamp (seconds). + """ + self.local_identity = local_identity + self.log = log_fn + self.login_handler = login_handler + self.anon_limiter = anon_limiter + self.region_names_fn = region_names_fn + self.owner_info_fn = owner_info_fn + self.features_fn = features_fn + self.clock_fn = clock_fn or (lambda: int(time.time())) + self._send_packet_callback: Optional[Callable[[Packet, int], None]] = None + + def set_send_packet_callback(self, callback: Callable[[Packet, int], None]): + """Set callback for sending response packets: ``callback(packet, delay_ms)``.""" + self._send_packet_callback = callback + # Keep the wrapped login handler wired through the same sender. + self.login_handler.set_send_packet_callback(callback) + + async def __call__(self, packet: Packet) -> None: + """Handle an ANON_REQ packet: decrypt, then dispatch on the sub-type byte.""" + try: + # Parse ANON_REQ: dest_hash(1) + client_pubkey(32) + encrypted_data + if len(packet.payload) < 34: + return + + dest_hash = packet.payload[0] + our_hash = self.local_identity.get_public_key()[0] + if dest_hash != our_hash: + return # Not for us + + client_pubkey = bytes(packet.payload[1:33]) + encrypted_data = bytes(packet.payload[33:]) + + client_identity = Identity(client_pubkey) + shared_secret = client_identity.calc_shared_secret( + self.local_identity.get_private_key() + ) + aes_key = shared_secret[:16] + + try: + plaintext = CryptoUtils.mac_then_decrypt(aes_key, shared_secret, encrypted_data) + except Exception as e: + self.log(f"[AnonReq] Failed to decrypt request: {e}") + return + + if len(plaintext) < 5: + # Too short to carry a sub-type byte; treat as login (let it + # apply its own length checks). + await self.login_handler(packet) + return + + subtype = plaintext[4] + + # Login request: sub-type 0x00 or any printable ASCII (>= 0x20), + # i.e. an actual password. Delegate verbatim to the login handler. + if subtype == 0x00 or subtype >= 0x20: + await self.login_handler(packet) + return + + # Regions/owner/basic discovery: route-direct only (firmware parity). + if subtype in (ANON_REQ_TYPE_REGIONS, ANON_REQ_TYPE_OWNER, ANON_REQ_TYPE_BASIC): + kind = { + ANON_REQ_TYPE_REGIONS: "regions", + ANON_REQ_TYPE_OWNER: "owner", + ANON_REQ_TYPE_BASIC: "basic", + }[subtype] + client_hex = client_pubkey[:4].hex() + if not packet.is_route_direct(): + self.log( + f"[AnonReq] {kind} request from {client_hex} ignored " + f"(not route-direct — firmware only answers direct)" + ) + return + if not self.anon_limiter.allow(time.time()): + self.log(f"[AnonReq] {kind} request from {client_hex} rate limited, dropping") + return + self.log(f"[AnonReq] {kind} request from {client_hex} -> replying") + await self._handle_discovery( + packet, client_identity, shared_secret, subtype, plaintext + ) + return + + # Unknown/invalid sub-type: ignore. + self.log(f"[AnonReq] Unknown anon sub-type 0x{subtype:02X}, ignoring") + + except Exception as e: + self.log(f"[AnonReq] Error handling anon request: {e}") + + async def _handle_discovery( + self, + packet: Packet, + client_identity: Identity, + shared_secret: bytes, + subtype: int, + plaintext: bytes, + ) -> None: + """Build and send a regions/owner/basic discovery reply.""" + # plaintext layout: timestamp(4) + subtype(1) + reply_path_byte(1) + reply_path... + sender_timestamp = bytes(plaintext[:4]) # echoed back verbatim as a tag + now_clock = int(self.clock_fn()) & 0xFFFFFFFF + + # Reply-path descriptor follows the sub-type byte. Firmware relies on its + # ``data[len] = 0`` terminator; mirror that by treating a missing + # descriptor byte as zero (no reply path, zero-hop direct). + reply_path_byte = plaintext[5] if len(plaintext) > 5 else 0 + reply_path_len = reply_path_byte & 0x3F + hash_size = (reply_path_byte >> 6) + 1 + path_bytes = bytes(plaintext[6 : 6 + reply_path_len * hash_size]) + + # Common prefix: sender_timestamp(4) + now_clock(4) + reply = bytearray(sender_timestamp) + reply += struct.pack(" max_names: + name_bytes = name_bytes[:max_names] + cut = name_bytes.rfind(b",") + name_bytes = name_bytes[:cut] if cut > 0 else name_bytes + reply += name_bytes + elif subtype == ANON_REQ_TYPE_OWNER: + node_name, owner = self.owner_info_fn() if self.owner_info_fn else ("", "") + reply += f"{node_name}\n{owner}".encode("utf-8", errors="ignore") + elif subtype == ANON_REQ_TYPE_BASIC: + features = self.features_fn() if self.features_fn else 0 + reply.append(features & 0xFF) + else: + return + + self._send_response( + packet, + client_identity, + shared_secret, + bytes(reply), + reply_path_len, + hash_size, + path_bytes, + ) + + def _send_response( + self, + packet: Packet, + client_identity: Identity, + shared_secret: bytes, + reply_data: bytes, + reply_path_len: int, + hash_size: int, + path_bytes: bytes, + ) -> None: + """Encode and dispatch the RESPONSE packet (path-return for flood, direct otherwise).""" + if self._send_packet_callback is None: + self.log("[AnonReq] No send packet callback set, cannot send response") + return + + try: + if packet.is_route_flood(): + # Fallback: tell the sender the path TO here and carry the reply. + client_hash = client_identity.get_public_key()[0] + server_hash = self.local_identity.get_public_key()[0] + in_path = ( + list(packet.path[: packet.get_path_byte_len()]) if packet.path_len > 0 else [] + ) + response_pkt = PacketBuilder.create_path_return( + dest_hash=client_hash, + src_hash=server_hash, + secret=shared_secret, + path=in_path, + extra_type=PAYLOAD_TYPE_RESPONSE, + extra=reply_data, + path_len_encoded=(packet.path_len if packet.path_len > 0 else None), + ) + else: + # Direct reply routed along the path supplied in the request. + response_pkt = PacketBuilder.create_datagram( + ptype=PAYLOAD_TYPE_RESPONSE, + dest=client_identity, + local_identity=self.local_identity, + secret=shared_secret, + plaintext=reply_data, + route_type="direct", + ) + if reply_path_len > 0 and path_bytes: + encoded = PathUtils.encode_path_len(hash_size, reply_path_len) + response_pkt.set_path(path_bytes, encoded) + + self._send_packet_callback(response_pkt, SERVER_RESPONSE_DELAY_MS) + route = "flood" if packet.is_route_flood() else "direct" + self.log( + f"[AnonReq] queued RESPONSE ({len(reply_data)}B, {route}, " + f"+{SERVER_RESPONSE_DELAY_MS}ms)" + ) + except Exception as e: + self.log(f"[AnonReq] Failed to send response: {e}") diff --git a/src/pymc_core/node/handlers/protocol_response.py b/src/pymc_core/node/handlers/protocol_response.py index d059dd5..5f1fbd8 100644 --- a/src/pymc_core/node/handlers/protocol_response.py +++ b/src/pymc_core/node/handlers/protocol_response.py @@ -229,8 +229,16 @@ async def __call__(self, pkt: Packet) -> None: f"({route_label}, {len(pkt.payload)}B)" ) - # Proceed if we have a callback for this source or the binary (path-discovery) callback - if src_hash not in self._response_callbacks and self._binary_response_callback is None: + # Proceed if we have a callback for this source or the binary (path-discovery) + # callback. PATH packets always proceed regardless of waiters so that the + # firmware-equivalent path learning in _decrypt_protocol_response + # (_update_contact_path + reciprocal PATH) runs for the login PATH-return, + # which arrives before any stats/telemetry waiter is registered. + if ( + src_hash not in self._response_callbacks + and self._binary_response_callback is None + and pkt_type != PAYLOAD_TYPE_PATH + ): return # Try to decrypt the response @@ -360,11 +368,23 @@ def _update_contact_path( """ try: if not PathUtils.is_valid_path_len(path_len_byte): + self._log( + f"[PATHDIAG] _update_contact_path REJECT src=0x{src_hash:02X}: " + f"invalid path_len_byte=0x{path_len_byte:02X}" + ) return False path_byte_len = PathUtils.get_path_byte_len(path_len_byte) out_path_bytes = bytes(decrypted[1 : 1 + path_byte_len]) + self._log( + f"[PATHDIAG] _update_contact_path src=0x{src_hash:02X} " + f"path_len_byte=0x{path_len_byte:02X} " + f"hops={PathUtils.get_path_hash_count(path_len_byte)} " + f"hash_size={PathUtils.get_path_hash_size(path_len_byte)} " + f"byte_len={path_byte_len} out_path={out_path_bytes.hex() or '(empty)'}" + ) contact_obj = self._contact_book.get_by_key(contact_pubkey) if contact_obj is not None: + prev_len = getattr(contact_obj, "out_path_len", None) contact_obj.out_path_len = path_len_byte contact_obj.out_path = out_path_bytes self._contact_book.update(contact_obj) @@ -372,6 +392,11 @@ def _update_contact_path( f"[ProtocolResponse] Updated out_path for 0x{src_hash:02X}: " f"path_len={path_len_byte}" ) + self._log( + f"[PATHDIAG] contact 0x{src_hash:02X} out_path_len " + f"{prev_len} -> {path_len_byte} " + f"(hops {PathUtils.get_path_hash_count(path_len_byte)})" + ) return True else: self._log( @@ -444,6 +469,12 @@ async def _send_reciprocal_path( f"[ProtocolResponse] Sending reciprocal PATH to 0x{src_hash:02X} " f"via DIRECT (out_path_len={path_len_byte}, in_path_len={len(in_path)})" ) + self._log( + f"[PATHDIAG] reciprocal -> 0x{src_hash:02X} route=DIRECT " + f"routing_path={out_path_bytes.hex() or '(empty)'} " + f"embedded_in_path={bytes(in_path).hex() or '(empty)'} " + f"path_len_byte=0x{path_len_byte:02X}" + ) except Exception as e: self._log(f"[ProtocolResponse] Failed to send reciprocal PATH: {e}") @@ -496,6 +527,15 @@ async def _decrypt_protocol_response( # Determine the actual response data based on packet type. response_data = decrypted if pkt_type == PAYLOAD_TYPE_PATH: + outer_path = bytes(pkt.path[: pkt.get_path_byte_len()]) if pkt.path_len else b"" + self._log( + f"[PATHDIAG] PATH rx src=0x{src_hash:02X} " + f"route={'FLOOD' if pkt.is_route_flood() else 'DIRECT'} " + f"outer_path_len=0x{pkt.path_len:02X} " + f"outer_hops={pkt.get_path_hash_count()} " + f"outer_path={outer_path.hex() or '(empty)'} " + f"inner_path_len_byte=0x{(decrypted[0] if decrypted else 0):02X}" + ) if len(decrypted) >= 2: path_len_byte = decrypted[0] path_byte_len = PathUtils.get_path_byte_len(path_len_byte) diff --git a/src/pymc_core/protocol/packet_builder.py b/src/pymc_core/protocol/packet_builder.py index dadddb6..68dee16 100644 --- a/src/pymc_core/protocol/packet_builder.py +++ b/src/pymc_core/protocol/packet_builder.py @@ -1,6 +1,7 @@ import hashlib import logging import struct +import threading import time from typing import Any, Optional, Sequence, Union @@ -58,6 +59,11 @@ class PacketBuilder: headers, encryption, and routing information for reliable mesh communication. """ + # Monotonic timestamp state (mirrors firmware getCurrentTimeUnique). Shared + # across all packet types so every request/login tag is strictly increasing. + _last_unique_timestamp: int = 0 + _timestamp_lock = threading.Lock() + @staticmethod def _hash_byte(pubkey: bytes) -> int: """Compute hash byte from public key for packet addressing.""" @@ -89,8 +95,23 @@ def _get_route_type_value(route_type: str, has_routing_path: bool = False) -> in @staticmethod def _get_timestamp() -> int: - """Get current timestamp for packet timing.""" - return int(time.time()) + """Get a strictly-increasing timestamp (epoch seconds) for packet tags. + + Mirrors firmware ``RTCClock::getCurrentTimeUnique`` (MeshCore.h): returns + the current epoch second, but if called more than once within the same + second it bumps by 1 so every request carries a unique, strictly-greater + tag. Firmware repeaters drop a REQ/login whose timestamp is not strictly + greater than the client's last stored timestamp (replay guard), so two + whole-second ``time.time()`` values from back-to-back requests (e.g. a + login immediately followed by a stats request) would collide and the + second packet would be silently ignored. + """ + with PacketBuilder._timestamp_lock: + t = int(time.time()) + if t <= PacketBuilder._last_unique_timestamp: + t = PacketBuilder._last_unique_timestamp + 1 + PacketBuilder._last_unique_timestamp = t + return t @staticmethod def _calc_shared_secret_and_key( @@ -442,6 +463,68 @@ def create_anon_req( pkt.path = bytearray() return pkt + @staticmethod + def create_anon_request( + contact: Any, + local_identity: LocalIdentity, + req_data: bytes = b"", + timestamp: Optional[int] = None, + ) -> tuple[Packet, int]: + """Create a PAYLOAD_TYPE_ANON_REQ packet for an anonymous request. + + Unlike ``create_protocol_request`` (which builds a PAYLOAD_TYPE_REQ and + relies on the recipient already knowing the sender), this emits a true + anonymous request: ``dest_hash(1) + sender_pubkey(32) + cipher`` under a + PAYLOAD_TYPE_ANON_REQ header. The decrypted plaintext is + ``timestamp(4) + req_data`` with ``req_data`` passed through verbatim + (e.g. ``[ANON_REQ_TYPE_REGIONS][reply_path_byte][reply_path...]``); no + protocol/sub-type byte is prepended. + + Routing mirrors firmware ``BaseChatMesh::sendAnonReq``: direct when the + out_path is known (``out_path_len >= 0``, including ``0`` for a zero-hop + direct neighbour) and flood when unknown (``-1``). The firmware regions + handler only answers ``isRouteDirect()`` packets, so zero-hop discovery + requires direct routing. + + Returns: + tuple: (packet, timestamp) - the packet and the timestamp used as the + request tag (echoed back by the responder). + """ + if timestamp is None: + timestamp = PacketBuilder._get_timestamp() + + plaintext = PacketBuilder._pack_timestamp_data(timestamp, req_data) + + contact_pubkey = bytes.fromhex(contact.public_key) + shared_secret, aes_key = PacketBuilder._calc_shared_secret_and_key(contact, local_identity) + cipher = PacketBuilder._encrypt_payload(aes_key, shared_secret, plaintext) + dest_hash = PacketBuilder._hash_byte(contact_pubkey) + payload = bytearray([dest_hash]) + local_identity.get_public_key() + cipher + + out_path_len = getattr(contact, "out_path_len", -1) + out_path = getattr(contact, "out_path", b"") or b"" + # Direct (incl. zero-hop, out_path_len == 0) when the path is known; + # flood only when the out_path is unknown (-1 / OUT_PATH_UNKNOWN). + route_type = "direct" if out_path_len >= 0 else "flood" + + header = PacketBuilder._create_header(PAYLOAD_TYPE_ANON_REQ, route_type) + packet = PacketBuilder._create_packet(header, payload) + packet.path_len = 0 + packet.path = bytearray() + + if route_type == "direct" and len(out_path) > 0: + path_bytes = out_path[:MAX_PATH_SIZE] + encoded_len = None + if PathUtils.is_valid_path_len(out_path_len) and PathUtils.get_path_byte_len( + out_path_len + ) <= len(path_bytes): + encoded_len = out_path_len + elif len(path_bytes) == 64: + path_bytes = path_bytes[:63] + packet.set_path(path_bytes, encoded_len) + + return packet, timestamp + @staticmethod def create_login_packet(contact: Any, local_identity: LocalIdentity, password: str) -> Packet: """ @@ -902,10 +985,11 @@ def create_protocol_request( out_path_len = getattr(contact, "out_path_len", -1) out_path = getattr(contact, "out_path", b"") or b"" - if out_path_len <= 0 or not out_path: - route_type = "flood" - else: - route_type = "direct" + # Direct (incl. zero-hop, out_path_len == 0 with an empty path) when the + # path is known; flood only when the out_path is unknown (-1). Mirrors + # create_anon_request and firmware sendRequest (OUT_PATH_UNKNOWN -> flood, + # else sendDirect, which works with a 0-length path). + route_type = "direct" if out_path_len >= 0 else "flood" header = PacketBuilder._create_header(PAYLOAD_TYPE_REQ, route_type) packet = PacketBuilder._create_packet(header, payload) diff --git a/tests/test_anon_request_handler.py b/tests/test_anon_request_handler.py new file mode 100644 index 0000000..7328ba4 --- /dev/null +++ b/tests/test_anon_request_handler.py @@ -0,0 +1,213 @@ +"""Tests for AnonRequestHandler / AnonRateLimiter (MeshCore 1.16.0 anon discovery). + +Validates parity with firmware ``MyMesh::onAnonDataRecv`` + the regions/owner/ +basic anon responders (``examples/simple_repeater/MyMesh.cpp``): +- sub-type 0x00 / >= 0x20 (a password) is delegated to the login handler; +- sub-type 0x01/0x02/0x03, route-direct, produce a RESPONSE datagram prefixed + with ``sender_timestamp(4) + now_clock(4)``; +- the responders are route-direct only and gated by a shared rate limiter. +""" + +import struct + +import pytest + +from pymc_core.node.handlers.anon_request import ( + ANON_REQ_TYPE_BASIC, + ANON_REQ_TYPE_OWNER, + ANON_REQ_TYPE_REGIONS, + AnonRateLimiter, + AnonRequestHandler, +) +from pymc_core.node.handlers.login_server import LoginServerHandler +from pymc_core.protocol import CryptoUtils, Identity, LocalIdentity, Packet +from pymc_core.protocol.constants import ( + PAYLOAD_TYPE_ANON_REQ, + PAYLOAD_TYPE_RESPONSE, + ROUTE_TYPE_DIRECT, + ROUTE_TYPE_FLOOD, +) + + +class TestAnonRateLimiter: + def test_default_matches_firmware_anon_limiter(self): + # Firmware: anon_limiter(4, 180) — max 4 per fixed 3-minute window. + rl = AnonRateLimiter() + assert rl._maximum == 4 + assert rl._secs == 180.0 + + def test_allows_up_to_cap_then_blocks_within_window(self): + rl = AnonRateLimiter(maximum=4, secs=180) + # First call (start_timestamp=0) resets the window to now=1000. + assert [rl.allow(1000.0) for _ in range(6)] == [True, True, True, True, False, False] + + def test_fixed_window_resets_after_secs(self): + rl = AnonRateLimiter(maximum=2, secs=180) + assert rl.allow(1000.0) is True # reset -> count 1 + assert rl.allow(1000.0) is True # count 2 + assert rl.allow(1000.0) is False # count 3 > 2 -> deny + # Still denied just before the window edge (1000 + 180 = 1180). + assert rl.allow(1179.0) is False + # At/after the window edge the counter resets fully. + assert rl.allow(1180.0) is True + + +class TestAnonRequestHandler: + def setup_method(self): + self.server_identity = LocalIdentity() + self.client_identity_local = LocalIdentity() + + self.auth_callback = lambda *a, **k: (True, 0x03) + self.login_handler = LoginServerHandler( + local_identity=self.server_identity, + log_fn=lambda *_: None, + authenticate_callback=self.auth_callback, + is_room_server=False, + ) + + self.limiter = AnonRateLimiter(maximum=4, secs=180) + self.handler = AnonRequestHandler( + local_identity=self.server_identity, + log_fn=lambda *_: None, + login_handler=self.login_handler, + anon_limiter=self.limiter, + region_names_fn=lambda: "*,VHF,USA", + owner_info_fn=lambda: ("repeater-1", "owner@example.com"), + features_fn=lambda: 0x80, + clock_fn=lambda: 1_700_000_000, + ) + + self.sent = [] + self.handler.set_send_packet_callback(lambda pkt, delay: self.sent.append((pkt, delay))) + + # -- request builders --------------------------------------------------- + + def _shared_secret(self): + server_id = Identity(self.server_identity.get_public_key()) + return server_id.calc_shared_secret(self.client_identity_local.get_private_key()) + + def _build_packet(self, plaintext: bytes, route_type="direct"): + shared_secret = self._shared_secret() + aes_key = shared_secret[:16] + encrypted = CryptoUtils.encrypt_then_mac(aes_key, shared_secret, plaintext) + + server_pubkey = self.server_identity.get_public_key() + payload = ( + bytes([server_pubkey[0]]) + self.client_identity_local.get_public_key() + encrypted + ) + route = ROUTE_TYPE_FLOOD if route_type == "flood" else ROUTE_TYPE_DIRECT + pkt = Packet() + pkt.header = (PAYLOAD_TYPE_ANON_REQ << 2) | route + pkt.payload = bytearray(payload) + pkt.payload_len = len(payload) + pkt.path = bytearray() + pkt.path_len = 0 + return pkt + + def _build_login(self, password="admin123", route_type="flood"): + plaintext = struct.pack("= 0x20 first byte) must still log in (item 1 regression).""" + await self.handler(self._build_login(password="admin123", route_type="flood")) + assert len(self.sent) == 1 + # Login flood reply is a PATH packet, not a RESPONSE datagram. + assert self.sent[0][0].get_payload_type() != PAYLOAD_TYPE_RESPONSE + + @pytest.mark.asyncio + async def test_zero_subtype_routes_to_login(self): + """Empty password (sub-type byte 0x00) routes to login, not discovery.""" + await self.handler(self._build_login(password="", route_type="flood")) + assert len(self.sent) == 1 + + # -- regions / owner / basic responders -------------------------------- + + @pytest.mark.asyncio + async def test_regions_reply(self): + await self.handler(self._build_discovery(ANON_REQ_TYPE_REGIONS)) + assert len(self.sent) == 1 + pkt, delay = self.sent[0] + assert delay == 300 + assert pkt.get_payload_type() == PAYLOAD_TYPE_RESPONSE + assert pkt.is_route_direct() + + body = self._decrypt_response(pkt) + assert body[:4] == struct.pack(" high bits 0), 2 hops => 0x02 + plaintext = struct.pack(" failure. + result = await bridge.send_raw_packet(0, b"\x00") + await bridge.stop() + assert result is False + assert injector.calls == [] + # --------------------------------------------------------------------------- # Path discovery, trace, control data diff --git a/tests/test_companion_radio.py b/tests/test_companion_radio.py index 5448fd1..107e38d 100644 --- a/tests/test_companion_radio.py +++ b/tests/test_companion_radio.py @@ -375,3 +375,61 @@ async def test_send_repeater_command_no_contact(self): out = await comp.send_repeater_command(b"\x00" * 32, "status") assert out["success"] is False assert "not found" in out["reason"].lower() + + +@pytest.mark.asyncio +class TestCompanionLoginRetry: + """send_login resends on a lost round-trip and succeeds on a later attempt.""" + + async def test_login_resends_then_succeeds(self, monkeypatch): + radio = MockRadio() + comp = CompanionRadio(radio, LocalIdentity()) + contact = _make_peer_contact("Rpt") + comp.contacts.add(contact) + + # Tiny per-attempt timeout so the test doesn't actually wait seconds. + monkeypatch.setattr(comp, "_response_timeout_s", lambda pkt, proxy: 0.05) + + handler = comp._get_login_response_handler() + captured = {} + orig_set = handler.set_login_callback + monkeypatch.setattr( + handler, "set_login_callback", lambda cb: (captured.__setitem__("cb", cb), orig_set(cb)) + ) + + calls = {"n": 0} + + async def fake_send(pkt, wait_for_ack=False): + calls["n"] += 1 + # First attempt is "lost"; reply only on the second attempt. + if calls["n"] == 2 and captured.get("cb"): + captured["cb"](True, {"timestamp": 1, "is_admin": False}) + return True + + monkeypatch.setattr(comp, "_send_packet", fake_send) + + result = await comp.send_login(contact.public_key, "pw") + assert result["success"] is True + assert calls["n"] == 2 # resent exactly once before success + + async def test_login_all_attempts_timeout(self, monkeypatch): + radio = MockRadio() + comp = CompanionRadio(radio, LocalIdentity()) + contact = _make_peer_contact("Rpt") + comp.contacts.add(contact) + monkeypatch.setattr(comp, "_response_timeout_s", lambda pkt, proxy: 0.02) + + calls = {"n": 0} + + async def fake_send(pkt, wait_for_ack=False): + calls["n"] += 1 + return True # never reply + + monkeypatch.setattr(comp, "_send_packet", fake_send) + + from pymc_core.companion.timing import DEFAULT_MAX_ATTEMPTS + + result = await comp.send_login(contact.public_key, "pw") + assert result["success"] is False + assert "timeout" in result["reason"].lower() + assert calls["n"] == DEFAULT_MAX_ATTEMPTS # tried the full budget diff --git a/tests/test_dispatcher.py b/tests/test_dispatcher.py index def3456..3f990f7 100644 --- a/tests/test_dispatcher.py +++ b/tests/test_dispatcher.py @@ -1,5 +1,5 @@ import asyncio -from unittest.mock import AsyncMock, Mock +from unittest.mock import AsyncMock, Mock, patch import pytest @@ -495,6 +495,21 @@ async def test_send_packet_with_ack_waiting(self, dispatcher): assert result is True + @pytest.mark.asyncio + async def test_send_packet_returns_false_when_radio_send_returns_none(self, dispatcher): + """If radio.send returns None, dispatcher must fail the send.""" + packet = Packet() + packet.header = (0 << 6) | (0 << 4) | (PAYLOAD_TYPE_ADVERT << 2) | 0 + packet.payload = bytearray(b"test_packet_data") + packet.payload_len = len(packet.payload) + packet.path_len = 0 + + dispatcher.radio.send = AsyncMock(return_value=None) + + result = await dispatcher.send_packet(packet, wait_for_ack=False) + + assert result is False + def test_own_packet_detection(self, dispatcher): """Test detection of own packets.""" # Create packet with our own address as source @@ -652,6 +667,29 @@ async def test_packet_filter_cleanup(self, dispatcher): # Verify cleanup was called dispatcher.packet_filter.cleanup_old_hashes.assert_called_once() + @pytest.mark.asyncio + async def test_run_forever_health_check_uses_to_thread(self, dispatcher): + """Health checks should run via asyncio.to_thread to avoid loop blocking.""" + dispatcher.radio.check_radio_health = Mock(return_value=True) + + sleep_calls = {"count": 0} + + async def fake_sleep(_seconds): + sleep_calls["count"] += 1 + if sleep_calls["count"] >= 60: + raise asyncio.CancelledError() + + to_thread_mock = AsyncMock(return_value=True) + + with ( + patch("pymc_core.node.dispatcher.asyncio.sleep", side_effect=fake_sleep), + patch("pymc_core.node.dispatcher.asyncio.to_thread", to_thread_mock), + ): + with pytest.raises(asyncio.CancelledError): + await dispatcher.run_forever() + + to_thread_mock.assert_awaited_once_with(dispatcher.radio.check_radio_health) + class TestDispatcherErrorHandling: """Test error handling.""" diff --git a/tests/test_frame_server.py b/tests/test_frame_server.py index 8eb9d65..fe585a0 100644 --- a/tests/test_frame_server.py +++ b/tests/test_frame_server.py @@ -16,12 +16,13 @@ PUB_KEY_SIZE, PUSH_CODE_ADVERT, PUSH_CODE_NEW_ADVERT, + RESP_CODE_ALLOWED_REPEAT_FREQ, RESP_CODE_CHANNEL_DATA_RECV, RESP_CODE_DEFAULT_FLOOD_SCOPE, RESP_CODE_OK, ) from pymc_core.companion.frame_server import CompanionFrameServer, _build_advert_push_frames -from pymc_core.companion.models import Contact, NodePrefs, QueuedMessage, SentResult +from pymc_core.companion.models import Contact, QueuedMessage, SentResult def test_build_advert_push_frames_short_only_when_no_name(): @@ -137,7 +138,9 @@ def __init__(self, send_ok: bool = True): def get_channel(self, idx: int): return self._channel if idx == 1 else None - async def send_channel_data(self, channel_idx, data_type, payload, *, path=None, path_len_encoded=None): + async def send_channel_data( + self, channel_idx, data_type, payload, *, path=None, path_len_encoded=None + ): self.calls.append((channel_idx, data_type, payload, path, path_len_encoded)) return self._send_ok @@ -181,7 +184,7 @@ async def test_cmd_send_channel_data_valid_direct_path(): server._write_err = Mock() path_len = PathUtils.encode_path_len(1, 2) # two 1-byte hops - payload = b"\xDE\xAD\xBE" + payload = b"\xde\xad\xbe" data = bytes([1, path_len, 0x10, 0x20, 0x34, 0x12]) + payload await server._cmd_send_channel_data(data) @@ -294,7 +297,7 @@ async def test_cmd_send_raw_data_2byte_hashes(): # path_len_encoded=0x42 → 2-byte hashes, 2 hops → 4 bytes of path path_len_byte = PathUtils.encode_path_len(2, 2) # 0x42 path_data = b"\x01\x02\x03\x04" - payload_data = b"\xAA\xBB\xCC\xDD" + payload_data = b"\xaa\xbb\xcc\xdd" data = bytes([path_len_byte]) + path_data + payload_data await server._cmd_send_raw_data(data) assert len(bridge.calls) == 1 @@ -383,7 +386,7 @@ def test_build_message_frame_channel_data_v15(): snr=2.0, rssi=-90, channel_data_type=0x1234, - channel_data_payload=b"\xAA\xBB", + channel_data_payload=b"\xaa\xbb", ) frame = server._build_message_frame(msg) assert frame[0] == RESP_CODE_CHANNEL_DATA_RECV @@ -391,7 +394,7 @@ def test_build_message_frame_channel_data_v15(): assert frame[5] == 0xFF assert frame[6:8] == b"\x34\x12" assert frame[8] == 2 - assert frame[9:11] == b"\xAA\xBB" + assert frame[9:11] == b"\xaa\xbb" @pytest.mark.asyncio @@ -826,3 +829,84 @@ async def test_handle_client_connection_reset_disconnects_cleanly(caplog): assert server._client_reader is None assert server._writer_task is None assert any("ConnectionResetError" in rec.message for rec in caplog.records) + + +@pytest.mark.asyncio +async def test_cmd_get_allowed_repeat_freq_empty_list(): + """CMD_GET_ALLOWED_REPEAT_FREQ replies with the response code and no ranges.""" + server = CompanionFrameServer(Mock(), "hash", port=0) + frames: list[bytes] = [] + server._write_frame = lambda f: frames.append(f) + await server._cmd_get_allowed_repeat_freq(b"") + assert frames == [bytes([RESP_CODE_ALLOWED_REPEAT_FREQ])] + + +@pytest.mark.asyncio +async def test_cmd_send_raw_packet_unsupported_without_bridge_method(): + """CMD_SEND_RAW_PACKET returns UNSUPPORTED when the bridge can't inject packets.""" + bridge = Mock(spec=[]) # no send_raw_packet attribute + server = CompanionFrameServer(bridge, "hash", port=0) + server._write_err = Mock() + server._write_ok = Mock() + await server._cmd_send_raw_packet(bytes([0x00, 0xAA, 0xBB])) + server._write_err.assert_called_once_with(ERR_CODE_UNSUPPORTED_CMD) + server._write_ok.assert_not_called() + + +@pytest.mark.asyncio +async def test_cmd_send_raw_packet_delegates_to_bridge(): + """CMD_SEND_RAW_PACKET parses [priority][raw...] and delegates to the bridge.""" + bridge = Mock() + bridge.send_raw_packet = AsyncMock(return_value=True) + server = CompanionFrameServer(bridge, "hash", port=0) + server._write_ok = Mock() + server._write_err = Mock() + await server._cmd_send_raw_packet(bytes([0x05, 0xDE, 0xAD, 0xBE])) + bridge.send_raw_packet.assert_awaited_once_with(0x05, b"\xde\xad\xbe") + server._write_ok.assert_called_once() + server._write_err.assert_not_called() + + +@pytest.mark.asyncio +async def test_cmd_send_raw_packet_too_short(): + """CMD_SEND_RAW_PACKET rejects a frame with no packet body.""" + server = CompanionFrameServer(Mock(), "hash", port=0) + server._write_err = Mock() + await server._cmd_send_raw_packet(bytes([0x00])) + server._write_err.assert_called_once_with(ERR_CODE_ILLEGAL_ARG) + + +def test_parse_binary_response_regions(): + """Anon REGIONS response decodes clock + comma-separated region names.""" + from pymc_core.companion import binary_parsing + from pymc_core.companion.constants import ANON_REQ_TYPE_REGIONS, PROTOCOL_CODE_ANON_REQ + + # response_data (tag already stripped) = clock(4) + null-terminated name list + data = struct.pack(" N assert callback_calls[0][1] == path_len_byte # encoded byte, not raw count assert callback_calls[0][2] == path_bytes # all 4 bytes of path data + @pytest.mark.asyncio + async def test_login_path_return_learns_path_without_waiter(self): + """Zero-hop login PATH-return must trigger path learning + reciprocal PATH + even with no stats/telemetry waiter registered (Fix B). + + During login no response callback exists yet, so the old guard dropped the + PATH-return before path learning could run, leaving out_path_len == -1 and + forcing the follow-up stats REQ to flood. The PATH branch must always decrypt + so _update_contact_path + reciprocal PATH run (firmware onContactPathRecv).""" + from pymc_core.companion.contact_store import ContactStore + from pymc_core.companion.models import Contact + + local_identity = LocalIdentity() # companion + server_identity = LocalIdentity() # firmware repeater + server_pubkey = server_identity.get_public_key() + contacts = ContactStore(5) + contacts.add(Contact(public_key=server_pubkey, name="Repeater")) + handler = ProtocolResponseHandler(MagicMock(), local_identity, contacts) + + # Login state: no response waiter and no binary callback registered. + injector = AsyncMock() + handler.set_packet_injector(injector) + path_updates = [] + + async def on_path_updated(pub, path_len, path_bytes_arg): + path_updates.append((pub, path_len, path_bytes_arg)) + + handler.set_contact_path_updated_callback(on_path_updated) + + # Firmware zero-hop login reply: 13-byte login response embedded in a + # flood PATH-return with an empty (0-hop) path. + reply = bytearray(13) + struct.pack_into(" bool: + # sub_cmd is the first payload byte (frame[2]) in SetHardware frames. + if frame[2] == CMD_GET_VERSION: + wrote_first.set() + allow_first_write.wait(timeout=1.0) + elif frame[2] == CMD_PING: + wrote_second.set() + return True + + modem._write_frame = mock_write_frame + + results: dict[str, object] = {} + + def call_version(): + results["v"] = modem._send_command(CMD_GET_VERSION, timeout=0.5) + + def call_ping(): + results["p"] = modem._send_command(CMD_PING, timeout=0.5) + + t1 = threading.Thread(target=call_version) + t2 = threading.Thread(target=call_ping) + + t1.start() + assert wrote_first.wait(timeout=1.0) + + # Start second call while first is still holding the command lock in _write_frame. + t2.start() + assert not wrote_second.wait(timeout=0.1) + + # Let the first command proceed and respond. + allow_first_write.set() + version_bytes = bytes([KISS_FEND, KISS_CMD_SETHARDWARE, RESP_VERSION, 0x01, KISS_FEND]) + for b in version_bytes: + modem._decode_kiss_byte(b) + + # Now second command can write and receive response. + assert wrote_second.wait(timeout=1.0) + pong_bytes = bytes([KISS_FEND, KISS_CMD_SETHARDWARE, RESP_PONG, KISS_FEND]) + for b in pong_bytes: + modem._decode_kiss_byte(b) + + t1.join(timeout=1.0) + t2.join(timeout=1.0) + + assert results.get("v") is not None + assert results.get("p") == (RESP_PONG, b"") + + def test_send_command_timeout_clears_waiter_state(self): + """Timeout path must clear active waiter metadata for later commands.""" + modem = KissModemWrapper(port="/dev/null", auto_configure=False) + modem.is_connected = True + modem._write_frame = MagicMock(return_value=True) + + resp = modem._send_command(CMD_GET_VERSION, timeout=0.05) + assert resp is None + assert modem._expected_response_subcmds is None + assert modem._active_request_subcmd is None + + # Ensure no lock leak by issuing another command. + resp2 = modem._send_command(CMD_PING, timeout=0.05) + assert resp2 is None + assert modem._expected_response_subcmds is None + assert modem._active_request_subcmd is None + + def test_send_command_write_failure_clears_waiter_state(self): + """Write-failure path must clear active waiter metadata.""" + modem = KissModemWrapper(port="/dev/null", auto_configure=False) + modem.is_connected = True + modem._write_frame = MagicMock(return_value=False) + + resp = modem._send_command(CMD_GET_VERSION, timeout=0.1) + assert resp is None + assert modem._expected_response_subcmds is None + assert modem._active_request_subcmd is None + + def test_response_queue_drop_oldest_when_full(self): + """Unmatched SetHardware responses should drop oldest when queue is full.""" + modem = KissModemWrapper(port="/dev/null", auto_configure=False) + modem.is_connected = True + + maxlen = modem._response_queue.maxlen or 0 + for i in range(maxlen): + modem._response_queue.append((0xA0 + (i % 10), bytes([i % 256]))) + oldest = modem._response_queue[0] + + # No active waiter; incoming response should be enqueued as unmatched. + frame = bytes([KISS_FEND, KISS_CMD_SETHARDWARE, RESP_IDENTITY, 0x42, KISS_FEND]) + for b in frame: + modem._decode_kiss_byte(b) + + assert len(modem._response_queue) == maxlen + assert modem._response_queue[0] != oldest + assert modem._response_queue[-1] == (RESP_IDENTITY, b"\x42") + + def test_send_command_ok_policy_allowlisted_command(self): + """Allowlisted SetHardware commands may resolve with HW_RESP_OK.""" + modem = KissModemWrapper(port="/dev/null", auto_configure=False) + modem.is_connected = True + mock_serial = MagicMock() + mock_serial.is_open = True + mock_serial.write.side_effect = lambda b: len(b) + modem.serial_conn = mock_serial + + result_holder: dict[str, object] = {} + + def caller(): + result_holder["resp"] = modem._send_command(CMD_SET_TX_POWER, b"\x16", timeout=0.5) + + t = threading.Thread(target=caller) + t.start() + + ok_frame = bytes([KISS_FEND, KISS_CMD_SETHARDWARE, HW_RESP_OK, KISS_FEND]) + for b in ok_frame: + modem._decode_kiss_byte(b) + + t.join(timeout=1.0) + assert result_holder.get("resp") == (HW_RESP_OK, b"") + + def test_send_command_ok_policy_non_allowlisted_command(self): + """Non-allowlisted commands should not complete on HW_RESP_OK.""" + modem = KissModemWrapper(port="/dev/null", auto_configure=False) + modem.is_connected = True + mock_serial = MagicMock() + mock_serial.is_open = True + mock_serial.write.side_effect = lambda b: len(b) + modem.serial_conn = mock_serial + + result_holder: dict[str, object] = {} + + def caller(): + result_holder["resp"] = modem._send_command(CMD_PING, timeout=0.5) + + t = threading.Thread(target=caller) + t.start() + + ok_frame = bytes([KISS_FEND, KISS_CMD_SETHARDWARE, HW_RESP_OK, KISS_FEND]) + for b in ok_frame: + modem._decode_kiss_byte(b) + + pong_frame = bytes([KISS_FEND, KISS_CMD_SETHARDWARE, RESP_PONG, KISS_FEND]) + for b in pong_frame: + modem._decode_kiss_byte(b) + + t.join(timeout=1.0) + assert result_holder.get("resp") == (RESP_PONG, b"") + assert len(modem._response_queue) == 1 + assert modem._response_queue[0] == (HW_RESP_OK, b"") + + def test_send_command_preserves_unrelated_response_order(self): + """Multiple unrelated responses remain queued in arrival order.""" + modem = KissModemWrapper(port="/dev/null", auto_configure=False) + modem.is_connected = True + mock_serial = MagicMock() + mock_serial.is_open = True + mock_serial.write.side_effect = lambda b: len(b) + modem.serial_conn = mock_serial + + result_holder: dict[str, object] = {} + + def caller(): + result_holder["resp"] = modem._send_command(CMD_PING, timeout=0.5) + + t = threading.Thread(target=caller) + t.start() + + identity = bytes([KISS_FEND, KISS_CMD_SETHARDWARE, RESP_IDENTITY, 0xAA, KISS_FEND]) + version = bytes([KISS_FEND, KISS_CMD_SETHARDWARE, RESP_VERSION, 0x01, KISS_FEND]) + stats = bytes([KISS_FEND, KISS_CMD_SETHARDWARE, RESP_STATS, 0x02, KISS_FEND]) + for frame in (identity, version, stats): + for b in frame: + modem._decode_kiss_byte(b) + + pong = bytes([KISS_FEND, KISS_CMD_SETHARDWARE, RESP_PONG, KISS_FEND]) + for b in pong: + modem._decode_kiss_byte(b) + + t.join(timeout=1.0) + assert result_holder.get("resp") == (RESP_PONG, b"") + assert [entry[0] for entry in modem._response_queue] == [ + RESP_IDENTITY, + RESP_VERSION, + RESP_STATS, + ] + assert [entry[1] for entry in modem._response_queue] == [b"\xAA", b"\x01", b"\x02"] def test_tx_done_response(self): """Test SetHardware TxDone (0xF8) response sets event""" @@ -315,6 +563,82 @@ def test_tx_done_response(self): assert modem._tx_done_event.is_set() assert modem._tx_done_result is True + @pytest.mark.asyncio + async def test_send_offloads_get_airtime_to_thread(self): + """send() must not call blocking get_airtime on the asyncio event loop.""" + modem = KissModemWrapper(port="/dev/null", auto_configure=False) + modem.is_connected = True + modem.send_frame_and_wait = MagicMock(return_value=True) + + async def to_thread_side_effect(fn, *args, **kwargs): + if to_thread_side_effect.calls == 0: + to_thread_side_effect.calls += 1 + return True + to_thread_side_effect.calls += 1 + return 42 + + to_thread_side_effect.calls = 0 + + to_thread_mock = AsyncMock(side_effect=to_thread_side_effect) + with patch("pymc_core.hardware.kiss_modem_wrapper.asyncio.to_thread", to_thread_mock): + result = await modem.send(b"payload") + + assert result is not None + assert result["airtime_ms"] == 42 + assert to_thread_mock.await_count == 2 + to_thread_mock.assert_any_await(modem.send_frame_and_wait, b"payload", RESPONSE_TIMEOUT) + to_thread_mock.assert_any_await(modem.get_airtime, len(b"payload"), 1.0) + + +class TestKissAsyncTelemetry: + """Async-safe telemetry entrypoints delegate blocking work via asyncio.to_thread.""" + + @pytest.mark.asyncio + async def test_get_status_async_delegates_to_thread(self): + modem = KissModemWrapper(port="/dev/null", auto_configure=False) + to_thread_mock = AsyncMock(return_value={"ok": True}) + with patch("pymc_core.hardware.kiss_modem_wrapper.asyncio.to_thread", to_thread_mock): + result = await modem.get_status_async(1.25) + assert result == {"ok": True} + to_thread_mock.assert_awaited_once_with(modem._sync_get_status, 1.25) + + @pytest.mark.asyncio + async def test_get_noise_floor_async_delegates_to_thread(self): + modem = KissModemWrapper(port="/dev/null", auto_configure=False) + to_thread_mock = AsyncMock(return_value=-95) + with patch("pymc_core.hardware.kiss_modem_wrapper.asyncio.to_thread", to_thread_mock): + result = await modem.get_noise_floor_async(0.75) + assert result == -95 + to_thread_mock.assert_awaited_once_with(modem.get_noise_floor, 0.75) + + @pytest.mark.asyncio + async def test_get_modem_stats_async_delegates_to_thread(self): + modem = KissModemWrapper(port="/dev/null", auto_configure=False) + stats = {"rx": 1, "tx": 2, "errors": 0} + to_thread_mock = AsyncMock(return_value=stats) + with patch("pymc_core.hardware.kiss_modem_wrapper.asyncio.to_thread", to_thread_mock): + result = await modem.get_modem_stats_async(None) + assert result == stats + to_thread_mock.assert_awaited_once_with(modem.get_modem_stats, None) + + def test_get_noise_floor_forwards_timeout_to_send_command(self): + """Optional timeout on sync getter must reach _send_command.""" + modem = KissModemWrapper(port="/dev/null", auto_configure=False) + modem.is_connected = True + calls: list[tuple] = [] + + def mock_send_command(cmd, data=b"", timeout=5.0): + calls.append((cmd, data, timeout)) + if cmd == CMD_GET_NOISE_FLOOR: + return (RESP_NOISE_FLOOR, struct.pack(" None: + self._modem = m + + def write(self, data): + frame = bytes(data) + if ( + self._modem is not None + and len(frame) >= 4 + and frame[0] == KISS_FEND + and frame[-1] == KISS_FEND + and frame[1] == KISS_CMD_SETHARDWARE + ): + sub_cmd = frame[2] + if sub_cmd == CMD_PING: + response_sub = RESP_PONG + response_payload = b"" + elif sub_cmd == CMD_GET_NOISE_FLOOR: + response_sub = RESP_NOISE_FLOOR + response_payload = struct.pack(" None: + resp = ( + bytes([KISS_FEND, KISS_CMD_SETHARDWARE, response_sub]) + + response_payload + + bytes([KISS_FEND]) + ) + for b in resp: + self._modem._decode_kiss_byte(b) + + threading.Thread(target=emit, daemon=True).start() + return len(frame) + + def flush(self): + self.flush_count += 1 + + serial_conn = RespondingSerial() + serial_conn.set_modem(modem) + modem.serial_conn = serial_conn + + stop_event = threading.Event() + data_frame = modem._encode_kiss_frame(CMD_DATA, b"\xAA\xBB\xCC") + + def data_tx_worker() -> None: + for _ in range(200): + if stop_event.is_set(): + return + modem._write_frame(data_frame) + + tx_thread = threading.Thread(target=data_tx_worker) + tx_thread.start() + + try: + for _ in range(40): + ping_resp = modem._send_command(CMD_PING, timeout=0.2) + assert ping_resp is not None + assert ping_resp[0] == RESP_PONG + + noise = modem.get_noise_floor(timeout=0.2) + assert noise == -95 + finally: + stop_event.set() + tx_thread.join(timeout=1.0) + + def test_tx_worker_and_sethardware_queries_make_progress_together(self): + """Queued data TX should still make progress while periodic SetHardware queries run.""" + modem = KissModemWrapper(port="/dev/null", auto_configure=False) + modem.is_connected = True + + class RespondingSerial: + def __init__(self): + self.is_open = True + self._modem: KissModemWrapper | None = None + self.flush_count = 0 + self.data_writes = 0 + + def set_modem(self, m: KissModemWrapper) -> None: + self._modem = m + + def write(self, data): + frame = bytes(data) + if ( + self._modem is not None + and len(frame) >= 4 + and frame[0] == KISS_FEND + and frame[-1] == KISS_FEND + and frame[1] == KISS_CMD_SETHARDWARE + ): + sub_cmd = frame[2] + if sub_cmd == CMD_PING: + response_sub = RESP_PONG + response_payload = b"" + elif sub_cmd == CMD_GET_NOISE_FLOOR: + response_sub = RESP_NOISE_FLOOR + response_payload = struct.pack(" None: + resp = ( + bytes([KISS_FEND, KISS_CMD_SETHARDWARE, response_sub]) + + response_payload + + bytes([KISS_FEND]) + ) + for b in resp: + self._modem._decode_kiss_byte(b) + + threading.Thread(target=emit, daemon=True).start() + elif len(frame) >= 2 and frame[0] == KISS_FEND and frame[1] == CMD_DATA: + self.data_writes += 1 + + return len(frame) + + def flush(self): + self.flush_count += 1 + + serial_conn = RespondingSerial() + serial_conn.set_modem(modem) + modem.serial_conn = serial_conn + + for _ in range(120): + assert modem.send_frame(b"\x01\x02\x03") + + modem.stop_event.clear() + tx_thread = threading.Thread(target=modem._tx_worker, daemon=True) + tx_thread.start() + + try: + for _ in range(25): + ping_resp = modem._send_command(CMD_PING, timeout=0.3) + assert ping_resp is not None + assert ping_resp[0] == RESP_PONG + + noise = modem.get_noise_floor(timeout=0.3) + assert noise == -92 + + deadline = time.time() + 2.0 + while modem.tx_buffer and time.time() < deadline: + time.sleep(0.01) + assert len(modem.tx_buffer) == 0 + assert serial_conn.data_writes > 0 + finally: + modem.stop_event.set() + tx_thread.join(timeout=1.0) + + def test_write_error_does_not_poison_future_writes(self): + """A serial write error should fail fast and transition to degraded mode.""" + modem = KissModemWrapper(port="/dev/null", auto_configure=False) + modem.is_connected = True + + class FlakySerial: + def __init__(self): + self.is_open = True + self.calls = 0 + self.flush_count = 0 + + def write(self, data): + self.calls += 1 + if self.calls == 1: + raise OSError("simulated serial failure") + return len(data) + + def flush(self): + self.flush_count += 1 + + serial_conn = FlakySerial() + modem.serial_conn = serial_conn + modem._start_reconnect_worker = MagicMock() + + frame = modem._encode_kiss_frame(CMD_DATA, b"\xAA\xBB") + assert modem._write_frame(frame) is False + assert modem._degraded is True + assert modem.is_connected is False + modem._start_reconnect_worker.assert_called_once() + + class TestQueryMethods: """Test modem query methods""" @@ -749,6 +1349,7 @@ def mock_send_command(cmd, data=b"", timeout=5.0): modem._send_command = mock_send_command modem.is_connected = True + modem.serial_conn = MagicMock(is_open=True) modem.configure_radio() @@ -773,6 +1374,7 @@ def mock_send_command(cmd, data=b"", timeout=5.0): modem._send_command = mock_send_command modem.is_connected = True + modem.serial_conn = MagicMock(is_open=True) modem.configure_radio() @@ -797,6 +1399,7 @@ def mock_send_command(cmd, data=b"", timeout=5.0): modem._send_command = mock_send_command modem.is_connected = True + modem.serial_conn = MagicMock(is_open=True) modem.configure_radio() @@ -999,3 +1602,413 @@ def test_context_manager_calls_connect_disconnect(self): mock_connect.assert_called_once() mock_disconnect.assert_called_once() _ = modem # hold ref so __del__ runs after assert, not before + + +class TestAsyncSendTxDone: + """Test async send() TX_DONE confirmation behavior.""" + + @pytest.mark.asyncio + async def test_send_returns_metadata_after_tx_done(self): + modem = KissModemWrapper(port="/dev/null", auto_configure=False) + modem.is_connected = True + modem.lbt_enabled = False + modem.send_frame_and_wait = MagicMock(return_value=True) + modem.get_airtime = MagicMock(return_value=123) + + result = await modem.send(b"\x01\x02\x03\x04") + + assert result["airtime_ms"] == 123 + assert result["lbt_attempts"] == 0 + modem.send_frame_and_wait.assert_called_once() + + @pytest.mark.asyncio + async def test_send_raises_when_tx_done_not_received(self): + modem = KissModemWrapper(port="/dev/null", auto_configure=False) + modem.is_connected = True + modem.lbt_enabled = False + modem.send_frame_and_wait = MagicMock(return_value=False) + + with pytest.raises(Exception, match="TX_DONE"): + await modem.send(b"\x01\x02\x03\x04") + + +class TestShutdownAndConnectResilience: + def test_cleanup_sets_abort_flags_and_disconnects(self): + modem = KissModemWrapper(port="/dev/null", auto_configure=False) + modem._response_event.clear() + modem._tx_done_event.clear() + + with patch.object(modem, "disconnect") as mock_disconnect: + modem.cleanup() + + assert modem._shutting_down is True + assert modem._response_event.is_set() is True + assert modem._tx_done_event.is_set() is True + mock_disconnect.assert_called_once() + + def test_send_command_returns_none_while_shutting_down(self): + modem = KissModemWrapper(port="/dev/null", auto_configure=False) + modem._shutting_down = True + assert modem._send_command(CMD_GET_VERSION, timeout=0.1) is None + + def test_wait_for_modem_ready_retries_until_pong(self): + modem = KissModemWrapper( + port="/dev/serial/by-id/test", + auto_configure=False, + connect_retries=3, + post_open_delay_ms=0, + startup_retry_budget_sec=30.0, + ) + serial_conn = MagicMock() + serial_conn.is_open = True + modem.serial_conn = serial_conn + + responses = [None, None, (RESP_PONG, b"")] + + def _fake_send_command(_cmd, data=b"", timeout=1.0): + return responses.pop(0) + + modem._send_command = _fake_send_command + + with patch("threading.Event.wait", return_value=None): + assert modem._wait_for_modem_ready() is True + + assert modem.serial_conn.write.called + assert modem.serial_conn.flush.called + + def test_configure_radio_with_retries_eventually_succeeds(self): + modem = KissModemWrapper( + port="/dev/null", + auto_configure=False, + connect_retries=3, + startup_retry_budget_sec=30.0, + ) + + with ( + patch.object(modem, "configure_radio", side_effect=[False, False, True]) as mock_cfg, + patch("threading.Event.wait", return_value=None), + ): + assert modem._configure_radio_with_retries() is True + + assert mock_cfg.call_count == 3 + + +class TestSerialRecovery: + """Test serial degraded-state and reconnect behavior.""" + + def test_write_frame_marks_degraded_and_triggers_reconnect(self): + modem = KissModemWrapper(port="/dev/null", auto_configure=False) + modem.is_connected = True + + class _FailingSerial: + is_open = True + + def write(self, _data): + raise OSError(5, "Input/output error") + + modem.serial_conn = _FailingSerial() + modem._start_reconnect_worker = MagicMock() + + frame = modem._encode_kiss_frame(CMD_DATA, b"\x01\x02") + assert modem._write_frame(frame) is False + assert modem._degraded is True + assert modem.is_connected is False + assert modem.serial_conn is None + modem._start_reconnect_worker.assert_called_once() + + def test_send_command_fails_fast_while_reconnecting(self): + modem = KissModemWrapper(port="/dev/null", auto_configure=False) + modem.is_connected = True + modem._reconnecting_event.set() + modem._write_frame = MagicMock(return_value=True) + + assert modem._send_command(CMD_PING, timeout=0.1) is None + modem._write_frame.assert_not_called() + + def test_send_command_allowed_from_reconnect_thread_during_reconnect(self): + modem = KissModemWrapper(port="/dev/null", auto_configure=False) + modem.is_connected = True + modem._reconnecting_event.set() + modem.reconnect_thread = threading.current_thread() + modem._response_queue.append((RESP_PONG, b"")) + + assert modem._send_command(CMD_PING, timeout=0.1) == (RESP_PONG, b"") + + def test_send_command_allowed_from_reconnect_thread_while_degraded(self): + modem = KissModemWrapper(port="/dev/null", auto_configure=False) + modem.is_connected = True + modem._degraded = True + modem.reconnect_thread = threading.current_thread() + modem._response_queue.append((RESP_PONG, b"")) + + assert modem._send_command(CMD_PING, timeout=0.1) == (RESP_PONG, b"") + + def test_reconnect_worker_recovers_after_open_failure(self): + modem = KissModemWrapper(port="/dev/null", auto_configure=False) + modem._reconnecting_event.set() + modem._degraded = True + modem._degraded_reason = "test failure" + modem._reconnect_base_delay_s = 0.0 + modem._reconnect_max_delay_s = 0.0 + + modem._open_serial_and_start_threads = MagicMock(side_effect=[False, True]) + modem._run_post_connect_handshake = MagicMock(return_value=True) + modem._stop_io_threads = MagicMock() + + with patch("pymc_core.hardware.kiss_modem_wrapper.time.sleep", return_value=None): + modem._reconnect_worker() + + assert modem._open_serial_and_start_threads.call_count == 2 + assert modem._run_post_connect_handshake.call_count == 1 + assert modem._degraded is False + assert modem._reconnecting_event.is_set() is False + + def test_start_reconnect_worker_guard_prevents_duplicate_thread(self): + modem = KissModemWrapper(port="/dev/null", auto_configure=False) + modem._reconnecting_event.set() + modem._start_reconnect_worker() + assert modem.reconnect_thread is None + + def test_connect_clears_reconnecting_gate_after_success(self): + modem = KissModemWrapper(port="/dev/null", auto_configure=False) + modem._reconnecting_event.set() + modem._open_serial_and_start_threads = MagicMock(return_value=True) + modem._run_post_connect_handshake = MagicMock(return_value=True) + + assert modem.connect() is True + assert modem.is_connected is True + assert modem._reconnecting_event.is_set() is False + + def test_connect_sets_connected_only_after_handshake_success(self): + modem = KissModemWrapper(port="/dev/null", auto_configure=False) + modem._open_serial_and_start_threads = MagicMock(return_value=True) + + def handshake() -> bool: + # is_connected should stay false until handshake fully succeeds. + assert modem.is_connected is False + return True + + modem._run_post_connect_handshake = MagicMock(side_effect=handshake) + + assert modem.connect() is True + assert modem.is_connected is True + + def test_connect_handshake_failure_leaves_disconnected(self): + modem = KissModemWrapper(port="/dev/null", auto_configure=False) + modem._open_serial_and_start_threads = MagicMock(return_value=True) + modem._run_post_connect_handshake = MagicMock(return_value=False) + modem._close_serial_connection = MagicMock() + + assert modem.connect() is False + assert modem.is_connected is False + modem._close_serial_connection.assert_called_once() + + def test_connect_retries_transient_configure_failure_then_succeeds(self): + modem = KissModemWrapper( + port="/dev/null", + auto_configure=True, + radio_config={"frequency": 869618000}, + ) + modem._open_serial_and_start_threads = MagicMock(return_value=True) + modem._close_serial_connection = MagicMock() + modem._query_modem_info = MagicMock() + modem._set_kiss_tx_delay = MagicMock() + + connected_states = [] + + def transient_configure_failure() -> bool: + connected_states.append(modem.is_connected) + return len(connected_states) > 1 + + modem.configure_radio = MagicMock(side_effect=transient_configure_failure) + + with patch( + "pymc_core.hardware.kiss_modem_wrapper.time.sleep", return_value=None + ) as sleep_mock, patch("threading.Event.wait", return_value=None): + assert modem.connect() is True + + assert modem.configure_radio.call_count == 2 + assert connected_states == [False, False] + assert modem.is_connected is True + assert modem._close_serial_connection.call_count == 0 + assert len(sleep_mock.call_args_list) == 1 # post-connect settle only + + def test_connect_persistent_configure_failures_still_fail(self): + modem = KissModemWrapper( + port="/dev/null", + auto_configure=True, + radio_config={"frequency": 869618000}, + ) + modem._open_serial_and_start_threads = MagicMock(return_value=True) + modem._close_serial_connection = MagicMock() + modem._query_modem_info = MagicMock() + modem._set_kiss_tx_delay = MagicMock() + modem.configure_radio = MagicMock(return_value=False) + + with patch("pymc_core.hardware.kiss_modem_wrapper.time.sleep", return_value=None), patch( + "threading.Event.wait", return_value=None + ): + assert modem.connect() is False + + assert modem.is_connected is False + assert modem.configure_radio.call_count == modem.connect_retries + modem._close_serial_connection.assert_called_once() + + def test_connect_sets_connected_only_after_retrying_handshake(self): + modem = KissModemWrapper( + port="/dev/null", + auto_configure=True, + radio_config={"frequency": 869618000}, + ) + modem._open_serial_and_start_threads = MagicMock(return_value=True) + modem._close_serial_connection = MagicMock() + modem._query_modem_info = MagicMock() + modem._set_kiss_tx_delay = MagicMock() + + observed_states = [] + + def configure_with_one_retry() -> bool: + observed_states.append(modem.is_connected) + return len(observed_states) >= 2 + + modem.configure_radio = MagicMock(side_effect=configure_with_one_retry) + + with patch("pymc_core.hardware.kiss_modem_wrapper.time.sleep", return_value=None): + assert modem.connect() is True + + assert observed_states == [False, False] + assert modem.is_connected is True + + def test_reconnect_sets_connected_only_after_handshake_success(self): + modem = KissModemWrapper(port="/dev/null", auto_configure=False) + modem._reconnecting_event.set() + modem._degraded = True + modem._degraded_reason = "test failure" + modem._reconnect_base_delay_s = 0.0 + modem._reconnect_max_delay_s = 0.0 + modem._open_serial_and_start_threads = MagicMock(return_value=True) + + def reconnect_handshake() -> bool: + assert modem.is_connected is False + return True + + modem._run_post_connect_handshake = MagicMock(side_effect=reconnect_handshake) + modem._stop_io_threads = MagicMock() + + with patch("pymc_core.hardware.kiss_modem_wrapper.time.sleep", return_value=None): + modem._reconnect_worker() + + assert modem.is_connected is True + assert modem._degraded is False + assert modem._reconnecting_event.is_set() is False + + +class TestKissDataTxSingleFlight: + """DATA transmits are single-flight and fail fast on TX_BUSY / link loss.""" + + def test_send_frame_and_wait_is_single_flight(self): + """A second DATA frame must not be written while the first is in flight.""" + modem = KissModemWrapper(port="/dev/null", auto_configure=False) + modem.is_connected = True + + first_sent = threading.Event() + second_sent = threading.Event() + order: list[str] = [] + + def mock_send_frame(data: bytes) -> bool: + if data == b"AA": + order.append("A") + first_sent.set() + elif data == b"BB": + order.append("B") + second_sent.set() + return True + + modem.send_frame = mock_send_frame + + results: dict[str, object] = {} + t1 = threading.Thread( + target=lambda: results.__setitem__("a", modem.send_frame_and_wait(b"AA", timeout=2.0)) + ) + t1.start() + assert first_sent.wait(timeout=1.0) + + # B must block on the in-flight lock while A awaits TX_DONE. + t2 = threading.Thread( + target=lambda: results.__setitem__("b", modem.send_frame_and_wait(b"BB", timeout=2.0)) + ) + t2.start() + assert not second_sent.wait(timeout=0.2) + + tx_done = bytes([KISS_FEND, KISS_CMD_SETHARDWARE, RESP_TX_DONE, 0x01, KISS_FEND]) + for byte in tx_done: # complete A -> releases the lock + modem._decode_kiss_byte(byte) + + assert second_sent.wait(timeout=1.0) + for byte in tx_done: # complete B + modem._decode_kiss_byte(byte) + + t1.join(timeout=1.0) + t2.join(timeout=1.0) + assert results.get("a") is True + assert results.get("b") is True + assert order == ["A", "B"] + + def test_tx_busy_wakes_sender_and_is_not_queued(self): + """A 0x07 (TX_BUSY) error fails the sender fast and is not mis-routed to the + SetHardware response path.""" + modem = KissModemWrapper(port="/dev/null", auto_configure=False) + modem.is_connected = True + + sent = threading.Event() + modem.send_frame = lambda data: (sent.set() or True) + + result: dict[str, object] = {} + start = time.monotonic() + t = threading.Thread( + target=lambda: result.__setitem__("r", modem.send_frame_and_wait(b"AA", timeout=5.0)) + ) + t.start() + assert sent.wait(timeout=1.0) + + err = bytes([KISS_FEND, KISS_CMD_SETHARDWARE, RESP_ERROR, HW_ERR_TX_BUSY, KISS_FEND]) + for byte in err: + modem._decode_kiss_byte(byte) + + t.join(timeout=1.0) + assert result.get("r") is False + assert time.monotonic() - start < 2.0 # failed fast, not via the 5s timeout + assert len(modem._response_queue) == 0 # not consumed by the SetHardware waiter + + def test_serial_failure_wakes_in_flight_sender(self): + """A serial failure mid-transmit wakes the waiter instead of stalling.""" + modem = KissModemWrapper(port="/dev/null", auto_configure=False) + modem.is_connected = True + modem._start_reconnect_worker = MagicMock() # don't spawn a reconnect thread + + sent = threading.Event() + modem.send_frame = lambda data: (sent.set() or True) + + result: dict[str, object] = {} + start = time.monotonic() + t = threading.Thread( + target=lambda: result.__setitem__("r", modem.send_frame_and_wait(b"AA", timeout=5.0)) + ) + t.start() + assert sent.wait(timeout=1.0) + + modem._mark_serial_failure("link lost") + + t.join(timeout=1.0) + assert result.get("r") is False + assert time.monotonic() - start < 2.0 + + def test_send_frame_and_wait_skips_when_degraded(self): + """Don't enqueue DATA while the link is degraded.""" + modem = KissModemWrapper(port="/dev/null", auto_configure=False) + modem.is_connected = True + modem._degraded = True + modem.send_frame = MagicMock(return_value=True) + + assert modem.send_frame_and_wait(b"AA", timeout=2.0) is False + modem.send_frame.assert_not_called() diff --git a/tests/test_packet_builder.py b/tests/test_packet_builder.py index 335f811..5906cac 100644 --- a/tests/test_packet_builder.py +++ b/tests/test_packet_builder.py @@ -4,6 +4,7 @@ MAX_PACKET_PAYLOAD, PAYLOAD_TYPE_ACK, PAYLOAD_TYPE_ADVERT, + PAYLOAD_TYPE_ANON_REQ, PAYLOAD_TYPE_PATH, PAYLOAD_TYPE_RAW_CUSTOM, ) @@ -265,3 +266,117 @@ def test_truncated_path_packet_round_trip(): assert ok assert pkt2.get_path_byte_len() == len(pkt2.path) assert pkt2.get_path_byte_len() == 63 + + +def _make_contact(other, out_path=b"", out_path_len=-1): + return type( + "Contact", + (), + { + "public_key": other.get_public_key().hex(), + "out_path": out_path, + "out_path_len": out_path_len, + }, + )() + + +def _decrypt_anon(pkt, sender_local, recipient_local): + """Decrypt an ANON_REQ packet: payload = dest_hash(1)+sender_pubkey(32)+cipher.""" + assert pkt.payload[1:33] == bytes(sender_local.get_public_key()) + cipher = bytes(pkt.payload[33:]) + secret = Identity(sender_local.get_public_key()).calc_shared_secret( + recipient_local.get_private_key() + ) + return CryptoUtils.mac_then_decrypt(secret[:16], secret, cipher) + + +def test_create_anon_request_is_anon_payload_type_no_subtype_prefix(): + """Regression: anon requests must be PAYLOAD_TYPE_ANON_REQ with the client's + sub-type byte at offset 4 (after the 4-byte timestamp) - NOT a PAYLOAD_TYPE_REQ + with 0x07 prepended (which repeaters read as REQ_TYPE_GET_OWNER_INFO).""" + local = LocalIdentity() + other = LocalIdentity() + contact = _make_contact(other, out_path=b"", out_path_len=0) # zero-hop neighbour + # ANON_REQ_TYPE_REGIONS (0x01) + reply-path byte (0 = empty path) + req_data = bytes([0x01, 0x00]) + pkt, ts = PacketBuilder.create_anon_request(contact, local, req_data) + + assert pkt.get_payload_type() == PAYLOAD_TYPE_ANON_REQ + plaintext = _decrypt_anon(pkt, local, other) + assert int.from_bytes(plaintext[:4], "little") == ts + # sub-type byte sits immediately after the timestamp, with no 0x07 prefix + assert plaintext[4] == 0x01 + # (trailing bytes are AES block padding, ignored by the responder) + assert plaintext[4 : 4 + len(req_data)] == req_data + + +def test_create_anon_request_zero_hop_is_direct(): + """out_path_len == 0 (zero-hop direct neighbour) must route DIRECT so the + firmware regions handler (which requires isRouteDirect()) answers.""" + local = LocalIdentity() + other = LocalIdentity() + contact = _make_contact(other, out_path=b"", out_path_len=0) + pkt, _ = PacketBuilder.create_anon_request(contact, local, bytes([0x01, 0x00])) + assert pkt.is_route_direct() + assert not pkt.is_route_flood() + + +def test_create_anon_request_unknown_path_is_flood(): + """out_path_len == -1 (unknown) must route FLOOD.""" + local = LocalIdentity() + other = LocalIdentity() + contact = _make_contact(other, out_path=b"", out_path_len=-1) + pkt, _ = PacketBuilder.create_anon_request(contact, local, bytes([0x01, 0x00])) + assert pkt.is_route_flood() + + +def test_create_protocol_request_zero_hop_is_direct(): + """out_path_len == 0 (zero-hop direct neighbour, empty path) must route DIRECT. + + After login establishes the path, stats/telemetry requests must use sendDirect + so the firmware repeater answers directly instead of flooding (matches firmware + BaseChatMesh::sendRequest and create_anon_request).""" + local = LocalIdentity() + other = LocalIdentity() + contact = _make_contact(other, out_path=b"", out_path_len=0) + pkt, _ = PacketBuilder.create_protocol_request(contact, local, 0x01, b"") + assert pkt.is_route_direct() + assert not pkt.is_route_flood() + # Zero-hop direct packet carries an empty path (firmware sendDirect(pkt, path, 0)). + assert pkt.path_len == 0 + + +def test_create_protocol_request_unknown_path_is_flood(): + """out_path_len == -1 (unknown) must route FLOOD.""" + local = LocalIdentity() + other = LocalIdentity() + contact = _make_contact(other, out_path=b"", out_path_len=-1) + pkt, _ = PacketBuilder.create_protocol_request(contact, local, 0x01, b"") + assert pkt.is_route_flood() + + +def test_get_timestamp_is_strictly_monotonic_within_same_second(): + """Back-to-back tags must strictly increase even within one wall-clock second. + + Firmware repeaters drop a REQ/login whose timestamp is not strictly greater + than the client's last stored timestamp (replay guard). Mirrors firmware + getCurrentTimeUnique so a login + immediate stats request don't collide and + get silently ignored.""" + ts = [PacketBuilder._get_timestamp() for _ in range(5)] + assert ts == sorted(ts) + assert len(set(ts)) == 5 # all unique + assert all(b == a + 1 or b > a for a, b in zip(ts, ts[1:])) + + +def test_login_then_stats_tags_strictly_increase(): + """A login followed immediately by a stats request must carry strictly + increasing timestamps so the firmware repeater accepts the stats REQ.""" + local = LocalIdentity() + other = LocalIdentity() + contact = _make_contact(other, out_path=b"", out_path_len=0) + login_pkt = PacketBuilder.create_login_packet( + contact=contact, local_identity=local, password="x" + ) + _, stats_ts = PacketBuilder.create_protocol_request(contact, local, 0x01, b"") + login_ts = int.from_bytes(_decrypt_anon(login_pkt, local, other)[:4], "little") + assert stats_ts > login_ts diff --git a/tests/test_timing.py b/tests/test_timing.py new file mode 100644 index 0000000..a073889 --- /dev/null +++ b/tests/test_timing.py @@ -0,0 +1,63 @@ +"""Tests for adaptive request timeouts (companion/timing.py).""" + +import math + +from pymc_core.companion import timing +from pymc_core.protocol.packet_utils import PathUtils + + +def test_estimate_airtime_matches_semtech_formula(): + """SF10/250kHz/CR4-5, 24-byte packet ~= 185ms (hand-computed Semtech airtime).""" + air = timing.estimate_airtime_ms(24, sf=10, bw_hz=250000, cr=1) + assert math.isclose(air, 185.3, rel_tol=0.02) + + +def test_airtime_grows_with_spreading_factor(): + """Higher SF => much longer airtime (each +1 SF roughly doubles symbol time).""" + a_sf8 = timing.estimate_airtime_ms(40, sf=8, bw_hz=250000, cr=1) + a_sf10 = timing.estimate_airtime_ms(40, sf=10, bw_hz=250000, cr=1) + a_sf12 = timing.estimate_airtime_ms(40, sf=12, bw_hz=125000, cr=1) + assert a_sf8 < a_sf10 < a_sf12 + + +def test_flood_timeout_matches_firmware_formula(): + assert timing.calc_flood_timeout_ms(200.0) == int(500 + 16.0 * 200.0) + + +def test_direct_timeout_uses_hop_count_not_raw_byte(): + # 0x42 encodes hash_size=2, hop_count=2 -> firmware (hops+1) factor of 3. + out_path_len = 0x42 + assert PathUtils.get_path_hash_count(out_path_len) == 2 + expected = int(500 + (6.0 * 200.0 + 250) * (2 + 1)) + assert timing.calc_direct_timeout_ms(200.0, out_path_len) == expected + + +def test_direct_timeout_zero_hop(): + # 0x40 -> hash_size 2, 0 hops -> factor (0+1). + expected = int(500 + (6.0 * 200.0 + 250) * 1) + assert timing.calc_direct_timeout_ms(200.0, 0x40) == expected + + +def test_response_timeout_is_clamped(): + # Tiny airtime (fast SF, small packet) must not drop below the floor. + fast = timing.response_timeout_ms( + raw_length=12, is_flood=False, out_path_len=0, sf=7, bw_hz=500000, cr=1 + ) + assert fast == timing.MIN_TIMEOUT_MILLIS + # Huge multi-hop flood must not exceed the ceiling. + slow = timing.response_timeout_ms( + raw_length=200, is_flood=True, out_path_len=0, sf=12, bw_hz=125000, cr=4 + ) + assert slow == timing.MAX_TIMEOUT_MILLIS + + +def test_flood_vs_direct_selection(): + """A typical small request lands in a sane few-second window for both routes.""" + flood = timing.response_timeout_ms( + raw_length=53, is_flood=True, out_path_len=-1, sf=10, bw_hz=250000, cr=1 + ) + direct = timing.response_timeout_ms( + raw_length=22, is_flood=False, out_path_len=0x42, sf=10, bw_hz=250000, cr=1 + ) + assert timing.MIN_TIMEOUT_MILLIS <= flood <= timing.MAX_TIMEOUT_MILLIS + assert timing.MIN_TIMEOUT_MILLIS <= direct <= timing.MAX_TIMEOUT_MILLIS