Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ dependencies = [
"sentencepiece==0.2.1",
"protobuf==7.34.1",
"openai_harmony==0.0.8",
# HDR Histogram for live percentile/histogram approximations in the
# metrics aggregator (PyPI: hdrhistogram, importable as hdrh.histogram).
"hdrhistogram==0.10.3",
# Color support for cross-platform terminals
"colorama==0.4.6",
# Fix pytz-2024 import warning
Expand Down
10 changes: 6 additions & 4 deletions src/inference_endpoint/async_utils/event_publisher.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,14 @@

from inference_endpoint.async_utils.loop_manager import LoopManager
from inference_endpoint.async_utils.transport.zmq.context import ManagedZMQContext
from inference_endpoint.async_utils.transport.zmq.pubsub import ZmqEventRecordPublisher
from inference_endpoint.async_utils.transport.zmq.pubsub import ZmqMessagePublisher
from inference_endpoint.core.record import EventRecord, EventRecordCodec


class EventPublisherService(ZmqEventRecordPublisher):
class EventPublisherService(ZmqMessagePublisher[EventRecord]):
"""Publisher for publishing event records over ZMQ PUB socket.

Wraps ZmqEventRecordPublisher with LoopManager integration and
Wraps ZmqMessagePublisher[EventRecord] with LoopManager integration and
auto-generated socket names.
"""

Expand All @@ -44,7 +45,7 @@ def __init__(
synchronization mechanism (e.g., ENDED as a stop signal).
isolated_event_loop: If True, runs on a separate event loop thread.
send_threshold: Minimum number of buffered records before an
automatic flush is triggered. See ZmqEventRecordPublisher.
automatic flush is triggered. See ZmqMessagePublisher.
"""
if extra_eager:
loop = None
Expand All @@ -54,6 +55,7 @@ def __init__(
loop = LoopManager().default_loop
self.socket_name = f"ev_pub_{uuid.uuid4().hex[:8]}"
super().__init__(
EventRecordCodec(),
self.socket_name,
managed_zmq_context,
loop=loop,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,11 @@

from inference_endpoint.async_utils.loop_manager import LoopManager
from inference_endpoint.async_utils.transport.zmq.context import ManagedZMQContext
from inference_endpoint.async_utils.transport.zmq.pubsub import ZmqEventRecordSubscriber
from inference_endpoint.async_utils.transport.zmq.pubsub import ZmqMessageSubscriber
from inference_endpoint.async_utils.transport.zmq.ready_check import send_ready_signal
from inference_endpoint.core.record import (
EventRecord,
EventRecordCodec,
SessionEventType,
)
from inference_endpoint.utils.logging import setup_logging
Expand All @@ -52,7 +53,7 @@
_WRITER_REGISTRY["sql"] = SQLWriter


class EventLoggerService(ZmqEventRecordSubscriber):
class EventLoggerService(ZmqMessageSubscriber[EventRecord]):
"""Event logger service for logging event records.

When SessionEventType.ENDED is received (topic 'session.ended'), the service writes
Expand All @@ -69,7 +70,7 @@ def __init__(
shutdown_event: asyncio.Event | None = None,
**kwargs,
):
super().__init__(*args, **kwargs)
super().__init__(EventRecordCodec(), *args, **kwargs)
self._shutdown_received = False
self._shutdown_event = shutdown_event

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@
from inference_endpoint.utils.logging import setup_logging

from .aggregator import MetricsAggregatorService
from .kv_store import BasicKVStore
from .publisher import MetricsPublisher
from .registry import MetricsRegistry
from .snapshot import MetricsSnapshotCodec
from .token_metrics import TokenizePool


Expand All @@ -44,13 +46,37 @@ async def main() -> None:
"--socket-name",
type=str,
required=True,
help="Socket name within socket-dir",
help="EventRecord PUB socket name within socket-dir to subscribe to",
)
parser.add_argument(
"--metrics-dir",
"--metrics-socket",
type=str,
required=True,
help="Directory for mmap-backed metric files (created by the parent process)",
help="IPC socket name (within socket-dir) for the metrics PUB output",
)
parser.add_argument(
"--metrics-output-dir",
type=Path,
required=True,
help="Directory for the final-snapshot disk fallback (created if missing)",
)
parser.add_argument(
"--refresh-hz",
type=float,
default=4.0,
help="Live snapshot publish rate (default: 4.0)",
)
parser.add_argument(
"--hdr-sig-figs",
type=int,
default=3,
help="HDR Histogram significant figures (default: 3)",
)
parser.add_argument(
"--n-histogram-buckets",
type=int,
default=30,
help="Number of dense histogram buckets per series (default: 30)",
)
parser.add_argument(
"--tokenizer",
Expand Down Expand Up @@ -85,7 +111,9 @@ async def main() -> None:
args = parser.parse_args()
setup_logging(level="INFO")

metrics_dir = Path(args.metrics_dir)
metrics_output_dir: Path = args.metrics_output_dir
metrics_output_dir.mkdir(parents=True, exist_ok=True)

shutdown_event = asyncio.Event()
loop = LoopManager().default_loop

Expand All @@ -102,14 +130,25 @@ async def main() -> None:
pool_cm as pool,
ManagedZMQContext.scoped(socket_dir=args.socket_dir) as zmq_ctx,
):
kv_store = BasicKVStore(metrics_dir)
registry = MetricsRegistry()
publisher = MetricsPublisher(
MetricsSnapshotCodec(),
zmq_ctx,
args.metrics_socket,
loop,
fallback_path=metrics_output_dir / "final_snapshot.msgpack",
)
try:
aggregator = MetricsAggregatorService(
args.socket_name,
zmq_ctx,
loop,
topics=None,
kv_store=kv_store,
registry=registry,
publisher=publisher,
refresh_hz=args.refresh_hz,
sig_figs=args.hdr_sig_figs,
n_histogram_buckets=args.n_histogram_buckets,
tokenize_pool=pool,
streaming=args.streaming,
shutdown_event=shutdown_event,
Expand All @@ -121,7 +160,7 @@ async def main() -> None:

await shutdown_event.wait()
finally:
kv_store.close()
publisher.close()


if __name__ == "__main__":
Expand Down
Loading
Loading