Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions src/inference_endpoint/async_utils/event_publisher.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,14 @@

from inference_endpoint.async_utils.loop_manager import LoopManager
from inference_endpoint.async_utils.transport.zmq.context import ManagedZMQContext
from inference_endpoint.async_utils.transport.zmq.pubsub import ZmqEventRecordPublisher
from inference_endpoint.async_utils.transport.zmq.pubsub import ZmqMessagePublisher
from inference_endpoint.core.record import EventRecord, EventRecordCodec


class EventPublisherService(ZmqEventRecordPublisher):
class EventPublisherService(ZmqMessagePublisher[EventRecord]):
"""Publisher for publishing event records over ZMQ PUB socket.

Wraps ZmqEventRecordPublisher with LoopManager integration and
Wraps ZmqMessagePublisher[EventRecord] with LoopManager integration and
auto-generated socket names.
"""

Expand All @@ -44,7 +45,7 @@ def __init__(
synchronization mechanism (e.g., ENDED as a stop signal).
isolated_event_loop: If True, runs on a separate event loop thread.
send_threshold: Minimum number of buffered records before an
automatic flush is triggered. See ZmqEventRecordPublisher.
automatic flush is triggered. See ZmqMessagePublisher.
"""
if extra_eager:
loop = None
Expand All @@ -54,6 +55,7 @@ def __init__(
loop = LoopManager().default_loop
self.socket_name = f"ev_pub_{uuid.uuid4().hex[:8]}"
super().__init__(
EventRecordCodec(),
self.socket_name,
managed_zmq_context,
loop=loop,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,11 @@

from inference_endpoint.async_utils.loop_manager import LoopManager
from inference_endpoint.async_utils.transport.zmq.context import ManagedZMQContext
from inference_endpoint.async_utils.transport.zmq.pubsub import ZmqEventRecordSubscriber
from inference_endpoint.async_utils.transport.zmq.pubsub import ZmqMessageSubscriber
from inference_endpoint.async_utils.transport.zmq.ready_check import send_ready_signal
from inference_endpoint.core.record import (
EventRecord,
EventRecordCodec,
SessionEventType,
)
from inference_endpoint.utils.logging import setup_logging
Expand All @@ -52,7 +53,7 @@
_WRITER_REGISTRY["sql"] = SQLWriter


class EventLoggerService(ZmqEventRecordSubscriber):
class EventLoggerService(ZmqMessageSubscriber[EventRecord]):
"""Event logger service for logging event records.

When SessionEventType.ENDED is received (topic 'session.ended'), the service writes
Expand All @@ -69,7 +70,7 @@ def __init__(
shutdown_event: asyncio.Event | None = None,
**kwargs,
):
super().__init__(*args, **kwargs)
super().__init__(EventRecordCodec(), *args, **kwargs)
self._shutdown_received = False
self._shutdown_event = shutdown_event

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,12 @@
from enum import Enum

from inference_endpoint.async_utils.transport.zmq.pubsub import (
ZmqEventRecordSubscriber,
ZmqMessageSubscriber,
)
from inference_endpoint.core.record import (
ErrorEventType,
EventRecord,
EventRecordCodec,
SampleEventType,
SessionEventType,
)
Expand Down Expand Up @@ -81,7 +82,7 @@ class MetricCounterKey(str, Enum):
)


class MetricsAggregatorService(ZmqEventRecordSubscriber):
class MetricsAggregatorService(ZmqMessageSubscriber[EventRecord]):
"""Subscribes to EventRecords and computes per-sample metrics in real time.

The aggregator is a thin event router. All state management, trigger
Expand All @@ -99,7 +100,7 @@ def __init__(
shutdown_event: asyncio.Event | None = None,
**kwargs,
):
super().__init__(*args, **kwargs)
super().__init__(EventRecordCodec(), *args, **kwargs)
self._kv_store = kv_store
self._tokenize_pool = tokenize_pool
self._shutdown_event = shutdown_event
Expand Down
191 changes: 98 additions & 93 deletions src/inference_endpoint/async_utils/transport/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,13 @@
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import Any, Protocol, runtime_checkable
from typing import Any, Generic, Protocol, TypeVar, runtime_checkable

import msgspec
from pydantic import BaseModel, ConfigDict, Field

from inference_endpoint.core.record import (
ErrorEventType,
EventRecord,
decode_event_record,
encode_event_record,
)
from inference_endpoint.core.types import ErrorData, Query, QueryResult, StreamChunk
from inference_endpoint.core.types import Query, QueryResult, StreamChunk

T = TypeVar("T")


class TransportConfig(BaseModel, ABC):
Expand Down Expand Up @@ -235,54 +230,76 @@ def cleanup(self) -> None:
pass


class EventRecordPublisher(ABC):
"""Abstract base class for publishing event records over a transport."""
class MessageCodec(Protocol[T]):
"""Encode/decode policy for a single message type on the pub/sub layer.

The codec is the only type-specific surface in the pub/sub stack. All
transport machinery (ZmqMessagePublisher / ZmqMessageSubscriber) operates
on (topic_bytes, payload_bytes); the codec is what binds those bytes to
a concrete Python type T.
"""

def encode(self, item: T) -> tuple[bytes, bytes]:
"""Return (topic, payload). topic must be exactly TOPIC_FRAME_SIZE bytes."""
...
Comment thread
github-code-quality[bot] marked this conversation as resolved.
Fixed

def decode(self, payload: bytes) -> T:
"""Decode payload back to T. May raise; the caller routes failures
through on_decode_error."""
...
Comment thread
github-code-quality[bot] marked this conversation as resolved.
Fixed

def on_decode_error(self, payload: bytes, exc: Exception) -> T | None:
"""Fallback for malformed payloads. Return a sentinel item or None
to drop the message."""
...
Comment thread
github-code-quality[bot] marked this conversation as resolved.
Fixed


class MessagePublisher(ABC, Generic[T]):
"""Abstract base for publishing typed messages over a transport.

Subclasses implement send(topic, payload) and close(). publish() is
generic over T via the codec.
"""

def __init__(
self,
codec: MessageCodec[T],
bind_address: str,
loop: asyncio.AbstractEventLoop | None = None,
):
"""Creates a new EventRecordPublisher.
"""Creates a new MessagePublisher.

Args:
bind_address: The address to bind the publisher to. This can be an IPC or TCP socket address.
loop: The event loop to use for the publisher. If not provided, it is assumed that the publisher
should always execute eagerly and will be blocking. This means that the call to `.publish()`
will always be called immediately and the current loop and thread will block until the message
is sent.
codec: Encode policy. Required because turning T into wire bytes
is the only type-specific operation; injecting it is the
whole point of generalization.
bind_address: IPC or TCP socket address to bind to.
loop: Event loop to register async writes on. If None, send is
eager/blocking — used by callers that publish before a loop
is running (e.g. service startup).
"""
self._codec = codec
self.bind_address = bind_address
self.loop = loop
self.is_closed: bool = False

def publish(self, event_record: EventRecord) -> None:
"""Publish the event record on the bound address.

Args:
event_record: The event record to publish.
"""
def publish(self, item: T) -> None:
"""Encode item via the codec and send."""
if self.is_closed:
return

topic, payload = encode_event_record(event_record)
topic, payload = self._codec.encode(item)
self.send(topic, payload)

@abstractmethod
def send(self, topic: bytes, payload: bytes) -> None:
"""Send the message via the implemented transport layer.

Args:
topic: The topic of the message.
payload: The payload of the message.
"""
raise NotImplementedError("Subclasses must implement this method.")
"""Send raw frame via the implemented transport layer."""
raise NotImplementedError

def flush(self) -> None: # noqa: B027 — intentionally non-abstract
"""Force-send any buffered records.

Unbuffered implementations need no override. Buffered subclasses
(e.g., ZmqEventRecordPublisher) override this to drain their buffer.
(e.g. ZmqMessagePublisher) override this to drain their buffer.
"""

@abstractmethod
Expand All @@ -291,34 +308,39 @@ def close(self) -> None:

Implementations must flush any buffered records before closing.
"""
raise NotImplementedError("Subclasses must implement this method.")
raise NotImplementedError


class EventRecordSubscriber(ABC):
"""Abstract base class for subscribing to event records over a transport."""
class MessageSubscriber(ABC, Generic[T]):
"""Abstract base for subscribing to typed messages over a transport.

Subclasses implement receive() (raw bytes from socket) and process()
(handle decoded items). _on_readable wires them together using the
codec.
"""

def __init__(
self,
codec: MessageCodec[T],
connect_address: str,
loop: asyncio.AbstractEventLoop,
topics: list[str] | None = None,
):
"""Creates a new EventRecordSubscriber.
"""Creates a new MessageSubscriber.

Initializing the subscriber does NOT start processing. The subscriber connects
to the address and subscribes to topics, but the socket reader is only added
when .start() is called. This allows bookkeeping or other setup before
listening. Each subscriber should use its own event loop (e.g. from LoopManager),
not shared with the publisher.

It is mandatory for subscriber implementations to set the `_fd` attribute to the file
descriptor of the socket to add an asyncio reader to the event loop.
Initializing does NOT start processing — call .start() to add the
socket reader to the loop. Subclasses must set ``self._fd`` to the
socket file descriptor before .start() is called.

Args:
connect_address: The address to connect the subscriber to. This can be an IPC or TCP socket address.
loop: The event loop to use for the subscriber (typically a dedicated loop per subscriber).
topics: The topics to subscribe to. If not provided, it is assumed that the subscriber should subscribe to all topics.
codec: Decode policy. Required for the same reason as in
MessagePublisher.
connect_address: IPC or TCP socket address to connect to.
loop: Dedicated loop for this subscriber (typically from
LoopManager — not shared with the publisher).
topics: Topics to subscribe to. None means subscribe to all.
"""
self._codec = codec
self.connect_address = connect_address
self.topics = topics
self.loop = loop
Expand All @@ -328,31 +350,22 @@ def __init__(

@abstractmethod
def receive(self) -> bytes | None:
"""Receive data from the transport.

Should receive data from the socket and return a bytes object that should be able
to be decoded into an EventRecord.
"""Receive a single payload (no topic prefix) from the transport.

If the received data is malformed, this method should return None.

For the specific case that the transport is not readable or the underlying socket is busy
(such as when an EAGAIN error is raised), this method should raise a StopIteration exception.
Returns None for malformed-but-recognized frames. Raises
StopIteration when the transport has nothing more to deliver right
now (EAGAIN).
"""
raise NotImplementedError("Subclasses must implement this method.")
raise NotImplementedError

@abstractmethod
async def process(self, records: list[EventRecord]) -> None:
"""Process a list of EventRecords.

Called asynchronously (scheduled via create_task) so that heavy work does not
block the socket read path. Implementations should be async.
"""
raise NotImplementedError("Subclasses must implement this method.")
async def process(self, items: list[T]) -> None:
"""Handle a batch of decoded items. Called as an asyncio task so
heavy work does not block the socket read path."""
raise NotImplementedError

def close(self) -> None:
"""Close the subscriber and release resources (e.g. remove reader, close socket).
Should be idempotent; safe to call multiple times. Call when the session has ended.
"""
"""Close the subscriber. Idempotent."""
if self.loop is not None and self._fd is not None:
try:
self.loop.remove_reader(self._fd)
Expand All @@ -361,46 +374,37 @@ def close(self) -> None:
pass

def _on_readable(self) -> None:
"""Drain socket, decode records, and schedule process() as an async task."""
"""Drain socket, decode via codec, and schedule process()."""
if self.is_closed:
return

records: list[EventRecord] = []
items: list[T] = []
try:
while True:
payload = self.receive()
if payload is None:
continue

# Attempt decode
try:
event_record = decode_event_record(payload)
except msgspec.DecodeError as e:
event_record = EventRecord(
event_type=ErrorEventType.GENERIC,
data=ErrorData(
error_type="msgspec.DecodeError",
error_message=str(e),
),
)
records.append(event_record)
items.append(self._codec.decode(payload))
except Exception as e: # noqa: BLE001 — codec decides handling
# The base class is codec-agnostic: different codec
# implementations raise different exception types
# (msgspec.DecodeError, json.JSONDecodeError, ValueError,
# etc.). The codec's on_decode_error decides whether to
# return a fallback item, drop the message, or re-raise.
fallback = self._codec.on_decode_error(payload, e)
if fallback is not None:
items.append(fallback)
except StopIteration:
# No more messages to receive right now
pass
finally:
if records:
# Schedule process() so it does not block the socket read path
self.loop.create_task(self.process(records))
if items:
self.loop.create_task(self.process(items))

def start(self) -> None:
"""Start the subscriber: add the socket reader to the loop and begin processing.

Call this after any setup (e.g. when the session is about to start). Before
start() is called, no messages are received.
"""
"""Add the socket reader to the loop and begin processing."""
if self._fd is None:
raise ValueError("Subscriber not initialized with a file descriptor")

self.loop.add_reader(self._fd, self._on_readable)


Expand All @@ -410,6 +414,7 @@ def start(self) -> None:
"SenderTransport",
"WorkerConnector",
"WorkerPoolTransport",
"EventRecordPublisher",
"EventRecordSubscriber",
"MessageCodec",
"MessagePublisher",
"MessageSubscriber",
]
Loading
Loading