From bd5ab85df459be3709f0bc72faec639822342303 Mon Sep 17 00:00:00 2001 From: Alice Cheng Date: Tue, 28 Apr 2026 16:09:47 -0700 Subject: [PATCH 1/2] refactor: generalize ZMQ pub/sub over message type via MessageCodec Replace EventRecord-specific publisher/subscriber classes with generic ZmqMessagePublisher[T] / ZmqMessageSubscriber[T] parameterized by a MessageCodec[T] Protocol. EventRecordCodec preserves existing wire format and decode-error wrapping behavior. Sets up the generic transport that the upcoming MetricsSnapshot publisher will reuse. - protocol.py: drop EventRecordPublisher/Subscriber ABCs; add MessageCodec, MessagePublisher[T], MessageSubscriber[T]. - pubsub.py: rewrite as ZmqMessagePublisher[T]/ZmqMessageSubscriber[T]; expose sndhwm/linger/conflate so future callers (e.g. live snapshots) can choose drop-old vs. delivery-guarantee semantics. - record.py: add EventRecordCodec next to encode/decode helpers. - Update EventPublisherService, EventLoggerService, MetricsAggregatorService and tests to use the generic classes. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../async_utils/event_publisher.py | 10 +- .../services/event_logger/__main__.py | 7 +- .../services/metrics_aggregator/aggregator.py | 7 +- .../async_utils/transport/protocol.py | 183 +++++++++--------- .../async_utils/transport/zmq/pubsub.py | 118 ++++++++--- src/inference_endpoint/core/record.py | 45 +++-- .../load_generator/session.py | 2 +- .../unit/async_utils/test_event_publisher.py | 10 +- tests/unit/core/test_record.py | 45 ++--- .../unit/transport/test_zmq_pool_transport.py | 7 +- 10 files changed, 258 insertions(+), 176 deletions(-) diff --git a/src/inference_endpoint/async_utils/event_publisher.py b/src/inference_endpoint/async_utils/event_publisher.py index 98e25eae..bfc2f92a 100644 --- a/src/inference_endpoint/async_utils/event_publisher.py +++ b/src/inference_endpoint/async_utils/event_publisher.py @@ -18,13 +18,14 @@ from inference_endpoint.async_utils.loop_manager import LoopManager from inference_endpoint.async_utils.transport.zmq.context import ManagedZMQContext -from inference_endpoint.async_utils.transport.zmq.pubsub import ZmqEventRecordPublisher +from inference_endpoint.async_utils.transport.zmq.pubsub import ZmqMessagePublisher +from inference_endpoint.core.record import EventRecord, EventRecordCodec -class EventPublisherService(ZmqEventRecordPublisher): +class EventPublisherService(ZmqMessagePublisher[EventRecord]): """Publisher for publishing event records over ZMQ PUB socket. - Wraps ZmqEventRecordPublisher with LoopManager integration and + Wraps ZmqMessagePublisher[EventRecord] with LoopManager integration and auto-generated socket names. """ @@ -44,7 +45,7 @@ def __init__( synchronization mechanism (e.g., ENDED as a stop signal). isolated_event_loop: If True, runs on a separate event loop thread. send_threshold: Minimum number of buffered records before an - automatic flush is triggered. See ZmqEventRecordPublisher. + automatic flush is triggered. See ZmqMessagePublisher. 
""" if extra_eager: loop = None @@ -54,6 +55,7 @@ def __init__( loop = LoopManager().default_loop self.socket_name = f"ev_pub_{uuid.uuid4().hex[:8]}" super().__init__( + EventRecordCodec(), self.socket_name, managed_zmq_context, loop=loop, diff --git a/src/inference_endpoint/async_utils/services/event_logger/__main__.py b/src/inference_endpoint/async_utils/services/event_logger/__main__.py index a7842b74..f57d51a0 100644 --- a/src/inference_endpoint/async_utils/services/event_logger/__main__.py +++ b/src/inference_endpoint/async_utils/services/event_logger/__main__.py @@ -29,10 +29,11 @@ from inference_endpoint.async_utils.loop_manager import LoopManager from inference_endpoint.async_utils.transport.zmq.context import ManagedZMQContext -from inference_endpoint.async_utils.transport.zmq.pubsub import ZmqEventRecordSubscriber +from inference_endpoint.async_utils.transport.zmq.pubsub import ZmqMessageSubscriber from inference_endpoint.async_utils.transport.zmq.ready_check import send_ready_signal from inference_endpoint.core.record import ( EventRecord, + EventRecordCodec, SessionEventType, ) from inference_endpoint.utils.logging import setup_logging @@ -52,7 +53,7 @@ _WRITER_REGISTRY["sql"] = SQLWriter -class EventLoggerService(ZmqEventRecordSubscriber): +class EventLoggerService(ZmqMessageSubscriber[EventRecord]): """Event logger service for logging event records. When SessionEventType.ENDED is received (topic 'session.ended'), the service writes @@ -69,7 +70,7 @@ def __init__( shutdown_event: asyncio.Event | None = None, **kwargs, ): - super().__init__(*args, **kwargs) + super().__init__(EventRecordCodec(), *args, **kwargs) self._shutdown_received = False self._shutdown_event = shutdown_event diff --git a/src/inference_endpoint/async_utils/services/metrics_aggregator/aggregator.py b/src/inference_endpoint/async_utils/services/metrics_aggregator/aggregator.py index c4640bbc..0186a26d 100644 --- a/src/inference_endpoint/async_utils/services/metrics_aggregator/aggregator.py +++ b/src/inference_endpoint/async_utils/services/metrics_aggregator/aggregator.py @@ -22,11 +22,12 @@ from enum import Enum from inference_endpoint.async_utils.transport.zmq.pubsub import ( - ZmqEventRecordSubscriber, + ZmqMessageSubscriber, ) from inference_endpoint.core.record import ( ErrorEventType, EventRecord, + EventRecordCodec, SampleEventType, SessionEventType, ) @@ -81,7 +82,7 @@ class MetricCounterKey(str, Enum): ) -class MetricsAggregatorService(ZmqEventRecordSubscriber): +class MetricsAggregatorService(ZmqMessageSubscriber[EventRecord]): """Subscribes to EventRecords and computes per-sample metrics in real time. The aggregator is a thin event router. 
All state management, trigger @@ -99,7 +100,7 @@ def __init__( shutdown_event: asyncio.Event | None = None, **kwargs, ): - super().__init__(*args, **kwargs) + super().__init__(EventRecordCodec(), *args, **kwargs) self._kv_store = kv_store self._tokenize_pool = tokenize_pool self._shutdown_event = shutdown_event diff --git a/src/inference_endpoint/async_utils/transport/protocol.py b/src/inference_endpoint/async_utils/transport/protocol.py index c865eb4e..4527b025 100644 --- a/src/inference_endpoint/async_utils/transport/protocol.py +++ b/src/inference_endpoint/async_utils/transport/protocol.py @@ -26,18 +26,14 @@ from abc import ABC, abstractmethod from collections.abc import AsyncIterator from contextlib import asynccontextmanager -from typing import Any, Protocol, runtime_checkable +from typing import Any, Generic, Protocol, TypeVar, runtime_checkable import msgspec from pydantic import BaseModel, ConfigDict, Field -from inference_endpoint.core.record import ( - ErrorEventType, - EventRecord, - decode_event_record, - encode_event_record, -) -from inference_endpoint.core.types import ErrorData, Query, QueryResult, StreamChunk +from inference_endpoint.core.types import Query, QueryResult, StreamChunk + +T = TypeVar("T") class TransportConfig(BaseModel, ABC): @@ -235,54 +231,76 @@ def cleanup(self) -> None: pass -class EventRecordPublisher(ABC): - """Abstract base class for publishing event records over a transport.""" +class MessageCodec(Protocol[T]): + """Encode/decode policy for a single message type on the pub/sub layer. + + The codec is the only type-specific surface in the pub/sub stack. All + transport machinery (ZmqMessagePublisher / ZmqMessageSubscriber) operates + on (topic_bytes, payload_bytes); the codec is what binds those bytes to + a concrete Python type T. + """ + + def encode(self, item: T) -> tuple[bytes, bytes]: + """Return (topic, payload). topic must be exactly TOPIC_FRAME_SIZE bytes.""" + ... + + def decode(self, payload: bytes) -> T: + """Decode payload back to T. May raise; the caller routes failures + through on_decode_error.""" + ... + + def on_decode_error(self, payload: bytes, exc: Exception) -> T | None: + """Fallback for malformed payloads. Return a sentinel item or None + to drop the message.""" + ... + + +class MessagePublisher(ABC, Generic[T]): + """Abstract base for publishing typed messages over a transport. + + Subclasses implement send(topic, payload) and close(). publish() is + generic over T via the codec. + """ def __init__( self, + codec: MessageCodec[T], bind_address: str, loop: asyncio.AbstractEventLoop | None = None, ): - """Creates a new EventRecordPublisher. + """Creates a new MessagePublisher. Args: - bind_address: The address to bind the publisher to. This can be an IPC or TCP socket address. - loop: The event loop to use for the publisher. If not provided, it is assumed that the publisher - should always execute eagerly and will be blocking. This means that the call to `.publish()` - will always be called immediately and the current loop and thread will block until the message - is sent. + codec: Encode policy. Required because turning T into wire bytes + is the only type-specific operation; injecting it is the + whole point of generalization. + bind_address: IPC or TCP socket address to bind to. + loop: Event loop to register async writes on. If None, send is + eager/blocking — used by callers that publish before a loop + is running (e.g. service startup). 
""" + self._codec = codec self.bind_address = bind_address self.loop = loop self.is_closed: bool = False - def publish(self, event_record: EventRecord) -> None: - """Publish the event record on the bound address. - - Args: - event_record: The event record to publish. - """ + def publish(self, item: T) -> None: + """Encode item via the codec and send.""" if self.is_closed: return - - topic, payload = encode_event_record(event_record) + topic, payload = self._codec.encode(item) self.send(topic, payload) @abstractmethod def send(self, topic: bytes, payload: bytes) -> None: - """Send the message via the implemented transport layer. - - Args: - topic: The topic of the message. - payload: The payload of the message. - """ - raise NotImplementedError("Subclasses must implement this method.") + """Send raw frame via the implemented transport layer.""" + raise NotImplementedError def flush(self) -> None: # noqa: B027 — intentionally non-abstract """Force-send any buffered records. Unbuffered implementations need no override. Buffered subclasses - (e.g., ZmqEventRecordPublisher) override this to drain their buffer. + (e.g. ZmqMessagePublisher) override this to drain their buffer. """ @abstractmethod @@ -291,34 +309,39 @@ def close(self) -> None: Implementations must flush any buffered records before closing. """ - raise NotImplementedError("Subclasses must implement this method.") + raise NotImplementedError -class EventRecordSubscriber(ABC): - """Abstract base class for subscribing to event records over a transport.""" +class MessageSubscriber(ABC, Generic[T]): + """Abstract base for subscribing to typed messages over a transport. + + Subclasses implement receive() (raw bytes from socket) and process() + (handle decoded items). _on_readable wires them together using the + codec. + """ def __init__( self, + codec: MessageCodec[T], connect_address: str, loop: asyncio.AbstractEventLoop, topics: list[str] | None = None, ): - """Creates a new EventRecordSubscriber. + """Creates a new MessageSubscriber. - Initializing the subscriber does NOT start processing. The subscriber connects - to the address and subscribes to topics, but the socket reader is only added - when .start() is called. This allows bookkeeping or other setup before - listening. Each subscriber should use its own event loop (e.g. from LoopManager), - not shared with the publisher. - - It is mandatory for subscriber implementations to set the `_fd` attribute to the file - descriptor of the socket to add an asyncio reader to the event loop. + Initializing does NOT start processing — call .start() to add the + socket reader to the loop. Subclasses must set ``self._fd`` to the + socket file descriptor before .start() is called. Args: - connect_address: The address to connect the subscriber to. This can be an IPC or TCP socket address. - loop: The event loop to use for the subscriber (typically a dedicated loop per subscriber). - topics: The topics to subscribe to. If not provided, it is assumed that the subscriber should subscribe to all topics. + codec: Decode policy. Required for the same reason as in + MessagePublisher. + connect_address: IPC or TCP socket address to connect to. + loop: Dedicated loop for this subscriber (typically from + LoopManager — not shared with the publisher). + topics: Topics to subscribe to. None means subscribe to all. 
""" + self._codec = codec self.connect_address = connect_address self.topics = topics self.loop = loop @@ -328,31 +351,22 @@ def __init__( @abstractmethod def receive(self) -> bytes | None: - """Receive data from the transport. - - Should receive data from the socket and return a bytes object that should be able - to be decoded into an EventRecord. + """Receive a single payload (no topic prefix) from the transport. - If the received data is malformed, this method should return None. - - For the specific case that the transport is not readable or the underlying socket is busy - (such as when an EAGAIN error is raised), this method should raise a StopIteration exception. + Returns None for malformed-but-recognized frames. Raises + StopIteration when the transport has nothing more to deliver right + now (EAGAIN). """ - raise NotImplementedError("Subclasses must implement this method.") + raise NotImplementedError @abstractmethod - async def process(self, records: list[EventRecord]) -> None: - """Process a list of EventRecords. - - Called asynchronously (scheduled via create_task) so that heavy work does not - block the socket read path. Implementations should be async. - """ - raise NotImplementedError("Subclasses must implement this method.") + async def process(self, items: list[T]) -> None: + """Handle a batch of decoded items. Called as an asyncio task so + heavy work does not block the socket read path.""" + raise NotImplementedError def close(self) -> None: - """Close the subscriber and release resources (e.g. remove reader, close socket). - Should be idempotent; safe to call multiple times. Call when the session has ended. - """ + """Close the subscriber. Idempotent.""" if self.loop is not None and self._fd is not None: try: self.loop.remove_reader(self._fd) @@ -361,46 +375,32 @@ def close(self) -> None: pass def _on_readable(self) -> None: - """Drain socket, decode records, and schedule process() as an async task.""" + """Drain socket, decode via codec, and schedule process().""" if self.is_closed: return - records: list[EventRecord] = [] + items: list[T] = [] try: while True: payload = self.receive() if payload is None: continue - - # Attempt decode try: - event_record = decode_event_record(payload) + items.append(self._codec.decode(payload)) except msgspec.DecodeError as e: - event_record = EventRecord( - event_type=ErrorEventType.GENERIC, - data=ErrorData( - error_type="msgspec.DecodeError", - error_message=str(e), - ), - ) - records.append(event_record) + fallback = self._codec.on_decode_error(payload, e) + if fallback is not None: + items.append(fallback) except StopIteration: - # No more messages to receive right now pass finally: - if records: - # Schedule process() so it does not block the socket read path - self.loop.create_task(self.process(records)) + if items: + self.loop.create_task(self.process(items)) def start(self) -> None: - """Start the subscriber: add the socket reader to the loop and begin processing. - - Call this after any setup (e.g. when the session is about to start). Before - start() is called, no messages are received. 
- """ + """Add the socket reader to the loop and begin processing.""" if self._fd is None: raise ValueError("Subscriber not initialized with a file descriptor") - self.loop.add_reader(self._fd, self._on_readable) @@ -410,6 +410,7 @@ def start(self) -> None: "SenderTransport", "WorkerConnector", "WorkerPoolTransport", - "EventRecordPublisher", - "EventRecordSubscriber", + "MessageCodec", + "MessagePublisher", + "MessageSubscriber", ] diff --git a/src/inference_endpoint/async_utils/transport/zmq/pubsub.py b/src/inference_endpoint/async_utils/transport/zmq/pubsub.py index 9463356b..5f0ade89 100644 --- a/src/inference_endpoint/async_utils/transport/zmq/pubsub.py +++ b/src/inference_endpoint/async_utils/transport/zmq/pubsub.py @@ -17,14 +17,17 @@ import logging import os from collections import deque +from typing import TypeVar from urllib.parse import urlparse +import msgspec import msgspec.msgpack import zmq from inference_endpoint.async_utils.transport.protocol import ( - EventRecordPublisher, - EventRecordSubscriber, + MessageCodec, + MessagePublisher, + MessageSubscriber, ) from inference_endpoint.core.record import BATCH_TOPIC, TOPIC_FRAME_SIZE @@ -32,12 +35,14 @@ logger = logging.getLogger(__name__) +T = TypeVar("T") + _batch_encoder = msgspec.msgpack.Encoder() _batch_decoder = msgspec.msgpack.Decoder(type=list[bytes]) -class ZmqEventRecordPublisher(EventRecordPublisher): - """ZMQ PUB socket publisher with batched sending. +class ZmqMessagePublisher(MessagePublisher[T]): + """ZMQ PUB socket publisher generic over message type T. Records are buffered in memory and flushed as a single msgpack-encoded batch when the buffer reaches ``send_threshold``. This reduces syscalls @@ -46,36 +51,67 @@ class ZmqEventRecordPublisher(EventRecordPublisher): The ``send_threshold`` is the *minimum* number of records in the buffer before an automatic flush is triggered. There is no maximum — records accumulate until the threshold is reached or ``flush()``/``close()`` - is called explicitly. Callers that need immediate delivery (e.g., + is called explicitly. Callers that need immediate delivery (e.g. session control events) should call ``flush()`` after publishing. + Setting ``send_threshold=1`` effectively disables batching: every + publish is sent immediately as a single record without batch overhead + via the ``len(buf) == 1`` fast path in ``_flush_batch``. Batching protocol: - Batched messages use ``BATCH_TOPIC`` as the ZMQ routing prefix. - The payload is ``msgpack(list[bytes])`` where each element is a pre-encoded record payload (no per-record topic prefix). - Subscribers unpack the list and yield payloads in insertion order. - - Per-record topics are omitted because EventRecord already contains - event_type for dispatching. + - Per-record topics are omitted because the codec-decoded item + already carries any dispatch information. - Single-record flushes use the record's own topic (no batch overhead). """ def __init__( self, + codec: MessageCodec[T], path: str, zmq_context: ManagedZMQContext, loop: asyncio.AbstractEventLoop | None = None, scheme: str = "ipc", send_threshold: int = 1000, + sndhwm: int = 0, + linger: int = -1, ): + """Creates a new ZmqMessagePublisher. + + Args: + codec: Encode policy for T. Required — the only type-specific + surface in this class. + path: IPC path / socket name. Bind-side identity. Required — + each publisher in the system has a distinct path. + zmq_context: ManagedZMQContext owning socket lifetime and IPC + file cleanup. 
Required — sharing one context across + publishers is the existing pattern. + loop: Event loop for async writer registration. None means + eager/blocking send (used by callers that publish before a + loop is running). + scheme: ipc:// vs tcp://. Default ipc matches all current + callers; tcp is an escape hatch. + send_threshold: Minimum buffered records before automatic batch + flush. Set to 1 to disable batching (e.g. one snapshot per + tick, where batching adds latency). + sndhwm: ZMQ SNDHWM. 0 (default, unlimited) for delivery + guarantees. A small value (e.g. 4) caps the queue so ZMQ + drops new messages when subscribers are slow (a PUB socket + drops at its HWM instead of blocking) — appropriate for + telemetry-style senders. + linger: ZMQ LINGER on close. -1 (default, wait forever) + guarantees buffered records are sent. 0 drops in-flight on + close — appropriate when the caller flushes synchronously + before close. + """ self._socket = zmq_context.socket(zmq.PUB) - - # Guarantee delivery: unlimited send buffer, wait on close. - self._socket.setsockopt(zmq.SNDHWM, 0) - self._socket.setsockopt(zmq.LINGER, -1) + self._socket.setsockopt(zmq.SNDHWM, sndhwm) + self._socket.setsockopt(zmq.LINGER, linger) self._socket.setsockopt(zmq.IMMEDIATE, 1) bind_address = zmq_context.bind(self._socket, path, scheme) - super().__init__(bind_address, loop) + super().__init__(codec, bind_address, loop) self.bind_path = path logger.info(f"Publisher bound to {self.bind_address}") @@ -97,10 +133,10 @@ def pending_count(self) -> int: return len(self._pending) def send(self, topic: bytes, payload: bytes) -> None: - """Buffer a record for batched sending. + """Buffer a payload for batched sending. Only the payload is buffered — topics are not stored per-record - since the EventRecord already contains event_type for dispatching. + since the codec-decoded item already carries any dispatch info. When the buffer reaches ``send_threshold``, payloads are encoded as a single msgpack list and sent with BATCH_TOPIC. For a single record, a direct send with the record's own topic is used instead. @@ -136,8 +172,7 @@ def _flush_batch(self) -> None: else: # Multiple records: encode payloads as msgpack list[bytes], # prefix with BATCH_TOPIC for routing. Individual topics are - # not included — subscribers decode EventRecord.event_type - # from the payload for dispatching. + # not included — codec-decoded items carry their own dispatch. frame = BATCH_TOPIC + _batch_encoder.encode(buf) try: @@ -228,31 +263,54 @@ def close(self) -> None: pass -class ZmqEventRecordSubscriber(EventRecordSubscriber): - """ZMQ SUB socket subscriber that handles both single and batched messages. +class ZmqMessageSubscriber(MessageSubscriber[T]): + """ZMQ SUB socket subscriber generic over message type T. Automatically subscribes to BATCH_TOPIC in addition to any explicit topic subscriptions. Batched messages are unpacked into individual - records and yielded in order via ``receive()``. - - Note on topic filtering with batches: batched messages contain records - of mixed event types. Subscribers with specific topic filters will - receive ALL event types from batches, not just their filtered topics. - Per-record filtering must be done in application code (e.g., checking - ``EventRecord.event_type`` after decode). This is acceptable because - the decode cost (~0.6us/record) is negligible compared to processing. + payloads and yielded in order via ``receive()``; the codec then + decodes each payload to T. + + Note on topic filtering with batches: batched messages contain + payloads of mixed types.
Subscribers with specific topic filters will + receive ALL types from batches, not just their filtered topics. + Per-payload filtering must be done in application code (e.g. by + inspecting the decoded item). This is acceptable because the decode + cost is negligible compared to processing. + """ def __init__( self, + codec: MessageCodec[T], path: str, zmq_context: ManagedZMQContext, loop: asyncio.AbstractEventLoop, topics: list[str] | None = None, scheme: str = "ipc", + conflate: bool = False, + rcvhwm: int = 0, ): + """Creates a new ZmqMessageSubscriber. + + Args: + codec: Decode policy for T. Required. + path: IPC path / socket name to connect to. + zmq_context: Managed context. Reusing one context across + multiple subscribers is fine. + loop: Dedicated loop for this subscriber. + topics: Topics to subscribe to. None means subscribe to all. + scheme: ipc:// vs tcp://. + conflate: ZMQ_CONFLATE. False (default) keeps every message; + appropriate for EventRecord and for the final-snapshot + consumer. True keeps only the latest message; appropriate + for a TUI rendering live snapshots, where stale ticks have + no value. + rcvhwm: ZMQ RCVHWM. 0 (default) is unlimited. + """ self._socket = zmq_context.socket(zmq.SUB) - self._socket.setsockopt(zmq.RCVHWM, 0) + self._socket.setsockopt(zmq.RCVHWM, rcvhwm) + if conflate: + self._socket.setsockopt(zmq.CONFLATE, 1) if not topics: self._socket.setsockopt(zmq.SUBSCRIBE, b"") @@ -263,7 +321,7 @@ def __init__( self._socket.setsockopt(zmq.SUBSCRIBE, BATCH_TOPIC) connect_address = zmq_context.connect(self._socket, path, scheme) - super().__init__(connect_address, loop, topics) + super().__init__(codec, connect_address, loop, topics) self.connect_path = path logger.info(f"Subscriber connected to {self.connect_address}") @@ -273,7 +331,7 @@ def __init__( # Reader is added in .start(); do not add here. def receive(self) -> bytes | None: - """Receive a single record payload. + """Receive a single payload. If a batched message was received, individual payloads are buffered and returned one at a time in insertion order. @@ -291,8 +349,6 @@ def receive(self) -> bytes | None: raise StopIteration from e # Batch message: BATCH_TOPIC prefix + msgpack list[bytes] of payloads. - # Individual payloads do not have topic prefixes — EventRecord.event_type - # is used for dispatching instead. if raw[:TOPIC_FRAME_SIZE] == BATCH_TOPIC: batch_data = raw[TOPIC_FRAME_SIZE:] try: diff --git a/src/inference_endpoint/core/record.py b/src/inference_endpoint/core/record.py index adac5c8f..a78edd49 100644 --- a/src/inference_endpoint/core/record.py +++ b/src/inference_endpoint/core/record.py @@ -156,22 +156,39 @@ class EventRecord(msgspec.Struct, kw_only=True, frozen=True, gc=False): # type: data: OUTPUT_TYPE | PromptData | ErrorData | None = None -_ENCODER = msgspec.msgpack.Encoder(enc_hook=EventType.encode_hook) -_DECODER = msgspec.msgpack.Decoder(type=EventRecord, dec_hook=EventType.decode_hook) - +class EventRecordCodec: + """MessageCodec[EventRecord] — binds the pub/sub layer to EventRecord wire format. + + Implements the structural ``MessageCodec`` Protocol from + ``inference_endpoint.async_utils.transport.protocol`` without importing it + (avoids a core→transport import, which would be circular since transport + already imports core). Decode failures are wrapped in + ``ErrorEventType.GENERIC`` so downstream consumers see a recognizable + record rather than a silently dropped payload.
+ + The encoder and decoder are class-level singletons: msgspec's dispatch + tables are stateless after construction, so one instance per process + suffices. + """ -def encode_event_record(event_record: EventRecord) -> tuple[bytes, bytes]: - """Encodes an EventRecord into a tuple of (topic_bytes_padded, payload_bytes). + __slots__ = () - Args: - event_record: The EventRecord to encode. + _ENCODER: ClassVar = msgspec.msgpack.Encoder(enc_hook=EventType.encode_hook) + _DECODER: ClassVar = msgspec.msgpack.Decoder( + type=EventRecord, dec_hook=EventType.decode_hook + ) - Returns: - A tuple of (topic_bytes_padded, payload_bytes). - """ - # MyPy doesn't recognize custom attributes defined by __new__ in the metaclass. - return event_record.event_type.topic_bytes_padded, _ENCODER.encode(event_record) # type: ignore[attr-defined] + def encode(self, item: EventRecord) -> tuple[bytes, bytes]: + # MyPy doesn't recognize custom attributes defined by __new__ in the metaclass. + return item.event_type.topic_bytes_padded, self._ENCODER.encode(item) # type: ignore[attr-defined] + def decode(self, payload: bytes) -> EventRecord: + return self._DECODER.decode(payload) -def decode_event_record(payload: bytes) -> EventRecord: - return _DECODER.decode(payload) + def on_decode_error(self, payload: bytes, exc: Exception) -> EventRecord: + return EventRecord( + event_type=ErrorEventType.GENERIC, + data=ErrorData( + error_type=f"{type(exc).__module__}.{type(exc).__name__}", + error_message=str(exc), + ), + ) diff --git a/src/inference_endpoint/load_generator/session.py b/src/inference_endpoint/load_generator/session.py index 1c8ad992..1d0a63ec 100644 --- a/src/inference_endpoint/load_generator/session.py +++ b/src/inference_endpoint/load_generator/session.py @@ -123,7 +123,7 @@ def shutdown(self) -> None: ... # --------------------------------------------------------------------------- -# EventRecordPublisher protocol +# EventPublisher protocol # --------------------------------------------------------------------------- diff --git a/tests/unit/async_utils/test_event_publisher.py b/tests/unit/async_utils/test_event_publisher.py index 411e21c4..d65c1274 100644 --- a/tests/unit/async_utils/test_event_publisher.py +++ b/tests/unit/async_utils/test_event_publisher.py @@ -33,13 +33,13 @@ from inference_endpoint.async_utils.event_publisher import EventPublisherService from inference_endpoint.async_utils.loop_manager import LoopManager from inference_endpoint.async_utils.transport.zmq.context import ManagedZMQContext -from inference_endpoint.async_utils.transport.zmq.pubsub import ZmqEventRecordSubscriber +from inference_endpoint.async_utils.transport.zmq.pubsub import ZmqMessageSubscriber from inference_endpoint.core.record import ( TOPIC_FRAME_SIZE, EventRecord, + EventRecordCodec, SampleEventType, SessionEventType, - decode_event_record, ) from inference_endpoint.core.types import TextModelOutput @@ -52,7 +52,7 @@ # ============================================================================= -class CollectingEventSubscriber(ZmqEventRecordSubscriber): +class CollectingEventSubscriber(ZmqMessageSubscriber[EventRecord]): """Subscriber that appends all received EventRecords to a list for tests. Uses its own event loop (passed in). Call .start() to begin receiving.
@@ -61,7 +61,7 @@ class CollectingEventSubscriber(ZmqEventRecordSubscriber): """ def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + super().__init__(EventRecordCodec(), *args, **kwargs) self.received: list[EventRecord] = [] self._wait_event: asyncio.Event | None = None self._wait_count: int | None = None @@ -165,7 +165,7 @@ async def test_publish_sends_data_on_ipc_socket( topic_bytes = frame[:TOPIC_FRAME_SIZE].rstrip(b"\x00") payload = frame[TOPIC_FRAME_SIZE:] assert topic_bytes == b"session.started" - rec = decode_event_record(bytes(payload)) + rec = EventRecordCodec().decode(bytes(payload)) assert rec.event_type.value == SessionEventType.STARTED.value assert rec.data is None # Socket is closed by ManagedZMQContext.cleanup() in ev_pub_zmq_context fixture teardown. diff --git a/tests/unit/core/test_record.py b/tests/unit/core/test_record.py index e21984c7..75698eb5 100644 --- a/tests/unit/core/test_record.py +++ b/tests/unit/core/test_record.py @@ -22,14 +22,15 @@ TOPIC_FRAME_SIZE, ErrorEventType, EventRecord, + EventRecordCodec, EventType, SampleEventType, SessionEventType, - decode_event_record, - encode_event_record, ) from inference_endpoint.core.types import ErrorData, PromptData, TextModelOutput +_codec = EventRecordCodec() + class TestEventType: def test_category_base_raises_subclasses_return_expected(self): @@ -71,19 +72,19 @@ class TestEncodeEventRecord: def test_returns_tuple_of_topic_bytes_padded_and_payload_bytes_with_valid_msgpack( self, ): - """encode_event_record returns (topic_bytes_padded, payload) for single-frame ZMQ.""" + """EventRecordCodec.encode returns (topic_bytes_padded, payload) for single-frame ZMQ.""" data = TextModelOutput(output="test-output") record = EventRecord( event_type=SampleEventType.ISSUED, sample_uuid="test-uuid", data=data, ) - topic_bytes, payload = encode_event_record(record) + topic_bytes, payload = _codec.encode(record) assert isinstance(topic_bytes, bytes) assert len(topic_bytes) == TOPIC_FRAME_SIZE assert topic_bytes.rstrip(b"\x00") == b"sample.issued" assert isinstance(payload, bytes) - decoded = decode_event_record(payload) + decoded = _codec.decode(payload) assert decoded.sample_uuid == "test-uuid" assert decoded.data == data @@ -95,7 +96,7 @@ def test_topic_bytes_padded_matches_event_type_for_session_sample_error(self): (SampleEventType.COMPLETE, "sample.complete"), (ErrorEventType.GENERIC, "error.generic"), ]: - topic_bytes, _ = encode_event_record(EventRecord(event_type=ev)) + topic_bytes, _ = _codec.encode(EventRecord(event_type=ev)) assert len(topic_bytes) == TOPIC_FRAME_SIZE assert topic_bytes.rstrip(b"\x00") == expected_prefix.encode("utf-8") @@ -106,8 +107,8 @@ def test_session_event_round_trips_with_all_fields(self): event_type=SessionEventType.STARTED, sample_uuid="sess-1", ) - _, payload = encode_event_record(record) - decoded = decode_event_record(payload) + _, payload = _codec.encode(record) + decoded = _codec.decode(payload) assert decoded.event_type.topic == SessionEventType.STARTED.topic assert decoded.sample_uuid == "sess-1" assert decoded.data is None @@ -121,8 +122,8 @@ def test_sample_event_round_trips_with_output(self): sample_uuid="sample-42", data=data, ) - _, payload = encode_event_record(record) - decoded = decode_event_record(payload) + _, payload = _codec.encode(record) + decoded = _codec.decode(payload) assert decoded.event_type.topic == SampleEventType.COMPLETE.topic assert decoded.sample_uuid == "sample-42" assert decoded.data == data @@ -133,8 +134,8 @@ def 
test_sample_event_round_trips_with_text_model_output(self): sample_uuid="sample-42", data=TextModelOutput(output="out", reasoning="reason"), ) - _, payload = encode_event_record(record) - decoded = decode_event_record(payload) + _, payload = _codec.encode(record) + decoded = _codec.decode(payload) assert decoded.event_type.topic == SampleEventType.COMPLETE.topic assert decoded.sample_uuid == "sample-42" assert isinstance(decoded.data, TextModelOutput) @@ -147,8 +148,8 @@ def test_sample_event_round_trips_with_prompt_data_text(self): sample_uuid="sample-99", data=PromptData(text="What is AI?"), ) - _, payload = encode_event_record(record) - decoded = decode_event_record(payload) + _, payload = _codec.encode(record) + decoded = _codec.decode(payload) assert decoded.event_type.topic == SampleEventType.ISSUED.topic assert decoded.sample_uuid == "sample-99" assert isinstance(decoded.data, PromptData) @@ -161,8 +162,8 @@ def test_sample_event_round_trips_with_prompt_data_token_ids(self): sample_uuid="sample-100", data=PromptData(token_ids=(101, 202, 303)), ) - _, payload = encode_event_record(record) - decoded = decode_event_record(payload) + _, payload = _codec.encode(record) + decoded = _codec.decode(payload) assert decoded.event_type.topic == SampleEventType.ISSUED.topic assert isinstance(decoded.data, PromptData) assert decoded.data.token_ids == (101, 202, 303) @@ -176,8 +177,8 @@ def test_error_event_round_trips_with_error_data(self): error_message="error details", ), ) - _, payload = encode_event_record(record) - decoded = decode_event_record(payload) + _, payload = _codec.encode(record) + decoded = _codec.decode(payload) assert decoded.event_type.topic == ErrorEventType.LOADGEN.topic assert isinstance(decoded.data, ErrorData) assert decoded.data.error_type == "LoadgenError" @@ -186,8 +187,8 @@ def test_error_event_round_trips_with_error_data(self): def test_record_with_only_event_type_round_trips_with_defaults(self): record = EventRecord(event_type=SessionEventType.ENDED) - _, payload = encode_event_record(record) - decoded = decode_event_record(payload) + _, payload = _codec.encode(record) + decoded = _codec.decode(payload) assert decoded.event_type.topic == SessionEventType.ENDED.topic assert decoded.sample_uuid == "" assert decoded.data is None @@ -199,6 +200,6 @@ def test_explicit_timestamp_ns_preserved_round_trip(self): event_type=SampleEventType.ISSUED, timestamp_ns=ts, ) - _, payload = encode_event_record(record) - decoded = decode_event_record(payload) + _, payload = _codec.encode(record) + decoded = _codec.decode(payload) assert decoded.timestamp_ns == ts diff --git a/tests/unit/transport/test_zmq_pool_transport.py b/tests/unit/transport/test_zmq_pool_transport.py index 69c90f10..8c5455c5 100644 --- a/tests/unit/transport/test_zmq_pool_transport.py +++ b/tests/unit/transport/test_zmq_pool_transport.py @@ -27,7 +27,7 @@ import zmq from inference_endpoint.async_utils.transport.zmq.context import ManagedZMQContext from inference_endpoint.async_utils.transport.zmq.pubsub import ( - ZmqEventRecordPublisher, + ZmqMessagePublisher, ) from inference_endpoint.async_utils.transport.zmq.ready_check import ( ReadyCheckReceiver, @@ -36,6 +36,7 @@ ZMQTransportConfig, ZmqWorkerPoolTransport, ) +from inference_endpoint.core.record import EventRecordCodec @pytest.fixture(autouse=True) @@ -122,7 +123,9 @@ async def test_pool(self, num_workers: int, create_publisher: bool): dummy = None if create_publisher: sid = uuid.uuid4().hex[:8] - publisher = ZmqEventRecordPublisher(f"ev_pub_{sid}", zmq_ctx, 
loop=loop) + publisher = ZmqMessagePublisher( + EventRecordCodec(), f"ev_pub_{sid}", zmq_ctx, loop=loop + ) else: # Baseline: bind an unrelated PUB socket so the context is non-empty. dummy = zmq_ctx.socket(zmq.PUB) From faaff7bd5da99428cf1fdb998c77777f904344a6 Mon Sep 17 00:00:00 2001 From: Alice Cheng Date: Tue, 28 Apr 2026 17:22:36 -0700 Subject: [PATCH 2/2] refactor: keep MessageSubscriber catch generic; narrow per-codec Per Gemini review on PR #300: catching only msgspec.DecodeError in MessageSubscriber._on_readable bakes the codec implementation into the supposedly-generic base class. A future codec backed by json, pickle, etc. raises different exception types and would bypass on_decode_error, crashing the reader. - protocol.py: widen the catch to Exception so the base class makes no assumption about which decoder library a codec uses; drop the now-unused msgspec import. - record.py: tighten EventRecordCodec.on_decode_error to wrap only msgspec.DecodeError and re-raise other exceptions. Preserves parity with the previous behavior (only malformed-payload errors become ErrorEventType.GENERIC records; programmer bugs in the decode path still surface). Co-Authored-By: Claude Opus 4.7 (1M context) --- src/inference_endpoint/async_utils/transport/protocol.py | 8 ++++++-- src/inference_endpoint/core/record.py | 5 +++++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/src/inference_endpoint/async_utils/transport/protocol.py b/src/inference_endpoint/async_utils/transport/protocol.py index 4527b025..f4dd907c 100644 --- a/src/inference_endpoint/async_utils/transport/protocol.py +++ b/src/inference_endpoint/async_utils/transport/protocol.py @@ -28,7 +28,6 @@ from contextlib import asynccontextmanager from typing import Any, Generic, Protocol, TypeVar, runtime_checkable -import msgspec from pydantic import BaseModel, ConfigDict, Field from inference_endpoint.core.types import Query, QueryResult, StreamChunk @@ -387,7 +386,12 @@ def _on_readable(self) -> None: continue try: items.append(self._codec.decode(payload)) - except msgspec.DecodeError as e: + except Exception as e: # noqa: BLE001 — codec decides handling + # The base class is codec-agnostic: different codec + # implementations raise different exception types + # (msgspec.DecodeError, json.JSONDecodeError, ValueError, + # etc.). The codec's on_decode_error decides whether to + # return a fallback item, drop the message, or re-raise. fallback = self._codec.on_decode_error(payload, e) if fallback is not None: items.append(fallback) diff --git a/src/inference_endpoint/core/record.py b/src/inference_endpoint/core/record.py index a78edd49..da35389e 100644 --- a/src/inference_endpoint/core/record.py +++ b/src/inference_endpoint/core/record.py @@ -185,6 +185,11 @@ def decode(self, payload: bytes) -> EventRecord: return self._DECODER.decode(payload) def on_decode_error(self, payload: bytes, exc: Exception) -> EventRecord: + # Only wrap genuine wire-format failures (malformed payload). Other + # exceptions indicate a bug somewhere in the decode path and should + # propagate so they aren't silently swallowed into an EventRecord. + if not isinstance(exc, msgspec.DecodeError): + raise exc return EventRecord( event_type=ErrorEventType.GENERIC, data=ErrorData(
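
Usage sketch for the follow-up this series sets up (illustrative only, not part of the applied diffs): MetricsSnapshot, its fields, SnapshotRenderer, the "metrics_snap" path, and the free names zmq_ctx / loop below are assumptions standing in for an existing ManagedZMQContext and a LoopManager loop; only MessageCodec, ZmqMessagePublisher / ZmqMessageSubscriber, and the send_threshold / sndhwm / linger / conflate knobs come from the patches above.

    import msgspec
    import msgspec.msgpack

    from inference_endpoint.async_utils.transport.zmq.pubsub import (
        ZmqMessagePublisher,
        ZmqMessageSubscriber,
    )
    from inference_endpoint.core.record import TOPIC_FRAME_SIZE

    class MetricsSnapshot(msgspec.Struct, frozen=True):
        # Hypothetical payload; the real fields land in the follow-up PR.
        timestamp_ns: int
        requests_completed: int
        tokens_per_second: float

    # Assumes "metrics.snapshot" fits within TOPIC_FRAME_SIZE; padded with
    # NUL bytes the same way EventType.topic_bytes_padded is.
    SNAPSHOT_TOPIC = b"metrics.snapshot".ljust(TOPIC_FRAME_SIZE, b"\x00")

    class MetricsSnapshotCodec:
        """Structural MessageCodec[MetricsSnapshot], mirroring EventRecordCodec."""

        _ENCODER = msgspec.msgpack.Encoder()
        _DECODER = msgspec.msgpack.Decoder(type=MetricsSnapshot)

        def encode(self, item: MetricsSnapshot) -> tuple[bytes, bytes]:
            return SNAPSHOT_TOPIC, self._ENCODER.encode(item)

        def decode(self, payload: bytes) -> MetricsSnapshot:
            return self._DECODER.decode(payload)

        def on_decode_error(self, payload: bytes, exc: Exception) -> MetricsSnapshot | None:
            if not isinstance(exc, msgspec.DecodeError):
                raise exc
            return None  # drop a malformed snapshot; the next tick supersedes it

    # Publisher side: telemetry semantics (drop old) rather than delivery
    # guarantees, via the knobs exposed in PATCH 1/2.
    pub = ZmqMessagePublisher(
        MetricsSnapshotCodec(),
        "metrics_snap",
        zmq_ctx,            # an existing ManagedZMQContext
        send_threshold=1,   # one snapshot per tick; batching would only add latency
        sndhwm=4,           # small HWM: ZMQ drops new snapshots if readers lag
        linger=0,           # in-flight snapshots are stale by close time
    )

    # Subscriber side: a live view wants only the newest tick.
    class SnapshotRenderer(ZmqMessageSubscriber[MetricsSnapshot]):
        async def process(self, items: list[MetricsSnapshot]) -> None:
            latest = items[-1]  # with conflate=True this is the only item
            ...                 # hand `latest` to the TUI / renderer

    sub = SnapshotRenderer(
        MetricsSnapshotCodec(),
        "metrics_snap",
        zmq_ctx,
        loop,               # dedicated loop, e.g. from LoopManager
        conflate=True,      # keep only the latest snapshot
    )
    sub.start()

With send_threshold=1 every publish() goes out immediately on the record's own topic, so the publisher never sits on a stale snapshot, and conflate=True on the subscriber means a slow renderer sees the newest state instead of a backlog.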