Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ __pycache__/
*.py[codz]
*$py.class
.agents
.claude

# C extensions
*.so
Expand Down
29 changes: 2 additions & 27 deletions src/openrtc/core/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from pathlib import Path
from typing import Any, Literal

from livekit.agents import Agent, AgentServer, AgentSession, JobContext, JobProcess, cli
from livekit.agents import Agent, AgentServer, AgentSession, JobContext, cli

from openrtc.core.config import (
AgentConfig,
Expand All @@ -21,6 +21,7 @@
)
from openrtc.core.routing import _resolve_agent_config
from openrtc.core.turn_handling import _build_session_kwargs
from openrtc.execution.prewarm import _prewarm_worker
from openrtc.observability.metrics import (
MetricsStreamEvent,
RuntimeMetricsStore,
Expand Down Expand Up @@ -49,18 +50,6 @@ class _PoolRuntimeState:
metrics: RuntimeMetricsStore = field(default_factory=RuntimeMetricsStore)


def _prewarm_worker(
    runtime_state: _PoolRuntimeState,
    proc: JobProcess,
) -> None:
    """Populate ``proc.userdata`` with the assets shared by a worker's sessions.

    Stores a loaded Silero VAD under ``"vad"`` and the turn-detector model
    class under ``"turn_detection_factory"``.

    Raises:
        RuntimeError: If ``runtime_state.agents`` is empty.
    """
    registered_agents = runtime_state.agents
    if not registered_agents:
        raise RuntimeError("Register at least one agent before calling run().")

    silero_plugin, detector_cls = _load_shared_runtime_dependencies()

    # Load the VAD before touching userdata so a failed load leaves it untouched.
    loaded_vad = silero_plugin.VAD.load()
    shared_assets = proc.userdata
    shared_assets["vad"] = loaded_vad
    shared_assets["turn_detection_factory"] = detector_cls


async def _run_universal_session(
runtime_state: _PoolRuntimeState,
ctx: JobContext,
Expand Down Expand Up @@ -430,17 +419,3 @@ def _merge_session_kwargs(
if direct_session_kwargs is not None:
merged_kwargs.update(direct_session_kwargs)
return merged_kwargs


def _load_shared_runtime_dependencies() -> tuple[Any, type[Any]]:
    """Import the LiveKit plugins used during prewarm and return them.

    The imports happen inside the function rather than at module import time.

    Returns:
        A ``(silero, MultilingualModel)`` pair: the ``livekit.plugins.silero``
        module and the ``MultilingualModel`` class from
        ``livekit.plugins.turn_detector.multilingual``.

    Raises:
        RuntimeError: If either plugin cannot be imported. The underlying
            import error is chained as ``__cause__``.
    """
    try:
        from livekit.plugins import silero
        from livekit.plugins.turn_detector.multilingual import MultilingualModel
    except ImportError as exc:
        # ImportError rather than ModuleNotFoundError: when ``livekit.plugins``
        # is importable but has no ``silero`` submodule, the ``from ... import``
        # form fails with a plain ImportError ("cannot import name ..."), which
        # the narrower handler let escape without the install hint below.
        raise RuntimeError(
            "OpenRTC requires the LiveKit Silero and turn-detector plugins. "
            "Reinstall openrtc, or install livekit-agents[silero,turn-detector]."
        ) from exc

    return silero, MultilingualModel
44 changes: 44 additions & 0 deletions src/openrtc/execution/prewarm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""Worker prewarm hook.

The function in this module is registered as ``AgentServer.setup_fnc`` and runs
once per worker process before any session starts. It loads the shared runtime
assets (Silero VAD, LiveKit turn-detector model) into ``proc.userdata`` so they
are not re-loaded per session.

Adding a new shared resource means adding it to ``_prewarm_worker``.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Any

from livekit.agents import JobProcess

if TYPE_CHECKING:
from openrtc.core.pool import _PoolRuntimeState


def _prewarm_worker(
    runtime_state: _PoolRuntimeState,
    proc: JobProcess,
) -> None:
    """Populate ``proc.userdata`` with the assets shared by a worker's sessions.

    Stores a loaded Silero VAD under ``"vad"`` and the turn-detector model
    class under ``"turn_detection_factory"``.

    Raises:
        RuntimeError: If ``runtime_state.agents`` is empty.
    """
    registered_agents = runtime_state.agents
    if not registered_agents:
        raise RuntimeError("Register at least one agent before calling run().")

    silero_plugin, detector_cls = _load_shared_runtime_dependencies()

    # Load the VAD before touching userdata so a failed load leaves it untouched.
    loaded_vad = silero_plugin.VAD.load()
    shared_assets = proc.userdata
    shared_assets["vad"] = loaded_vad
    shared_assets["turn_detection_factory"] = detector_cls


def _load_shared_runtime_dependencies() -> tuple[Any, type[Any]]:
    """Import the LiveKit plugins used during prewarm and return them.

    The imports happen inside the function rather than at module import time.

    Returns:
        A ``(silero, MultilingualModel)`` pair: the ``livekit.plugins.silero``
        module and the ``MultilingualModel`` class from
        ``livekit.plugins.turn_detector.multilingual``.

    Raises:
        RuntimeError: If either plugin cannot be imported. The underlying
            import error is chained as ``__cause__``.
    """
    try:
        from livekit.plugins import silero
        from livekit.plugins.turn_detector.multilingual import MultilingualModel
    except ImportError as exc:
        # ImportError rather than ModuleNotFoundError: when ``livekit.plugins``
        # is importable but has no ``silero`` submodule, the ``from ... import``
        # form fails with a plain ImportError ("cannot import name ..."), which
        # the narrower handler let escape without the install hint below.
        raise RuntimeError(
            "OpenRTC requires the LiveKit Silero and turn-detector plugins. "
            "Reinstall openrtc, or install livekit-agents[silero,turn-detector]."
        ) from exc

    return silero, MultilingualModel
9 changes: 5 additions & 4 deletions tests/test_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import openrtc.core.discovery as discovery_module
import openrtc.core.pool as pool_module
import openrtc.core.serialization as serialization_module
import openrtc.execution.prewarm as prewarm_module
from openrtc import AgentPool


Expand Down Expand Up @@ -468,7 +469,7 @@ class FakeSilero:
VAD = FakeVAD

monkeypatch.setattr(
"openrtc.core.pool._load_shared_runtime_dependencies",
"openrtc.execution.prewarm._load_shared_runtime_dependencies",
lambda: (FakeSilero, FakeTurnDetector),
)
setup_callback(process)
Expand Down Expand Up @@ -843,7 +844,7 @@ def test_prewarm_worker_raises_when_runtime_state_has_no_agents() -> None:
proc = SimpleNamespace(userdata={})

with pytest.raises(RuntimeError, match="Register at least one agent"):
pool_module._prewarm_worker(pool._runtime_state, proc)
prewarm_module._prewarm_worker(pool._runtime_state, proc)


def test_run_universal_session_raises_when_no_agents_registered() -> None:
Expand Down Expand Up @@ -881,7 +882,7 @@ def _import_without_silero(
monkeypatch.setattr(builtins, "__import__", _import_without_silero)

with pytest.raises(RuntimeError, match="silero"):
pool_module._load_shared_runtime_dependencies()
prewarm_module._load_shared_runtime_dependencies()


def test_merge_session_kwargs_skips_direct_when_none() -> None:
Expand All @@ -898,7 +899,7 @@ def test_load_shared_runtime_dependencies_returns_silero_and_turn_detector() ->
pytest.importorskip("livekit.plugins.silero")
pytest.importorskip("livekit.plugins.turn_detector.multilingual")

silero, multilingual = pool_module._load_shared_runtime_dependencies()
silero, multilingual = prewarm_module._load_shared_runtime_dependencies()

assert hasattr(silero, "VAD")
assert multilingual.__name__ == "MultilingualModel"
Loading