From 8deff5c954194ec708824762dc6ed278fa7c2db0 Mon Sep 17 00:00:00 2001 From: Gezi-lzq Date: Sun, 17 May 2026 16:08:25 +0800 Subject: [PATCH] feat: add turn admission hook --- src/bub/__init__.py | 17 +- src/bub/channels/manager.py | 165 +++++++++++- src/bub/framework.py | 40 ++- src/bub/hookspecs.py | 14 ++ src/bub/turn_admission.py | 154 ++++++++++++ tests/test_channels.py | 236 +++++++++++++++++- tests/test_framework.py | 49 ++++ .../src/content/docs/docs/reference/hooks.mdx | 1 + .../src/content/docs/docs/reference/types.mdx | 59 ++++- .../docs/zh-cn/docs/reference/hooks.mdx | 1 + .../docs/zh-cn/docs/reference/types.mdx | 59 ++++- 11 files changed, 765 insertions(+), 30 deletions(-) create mode 100644 src/bub/turn_admission.py diff --git a/src/bub/__init__.py b/src/bub/__init__.py index 27305efc..e82f0d01 100644 --- a/src/bub/__init__.py +++ b/src/bub/__init__.py @@ -13,8 +13,21 @@ from bub.framework import DEFAULT_HOME, BubFramework from bub.hookspecs import hookimpl from bub.tools import tool - -__all__ = ["BubFramework", "Settings", "config", "ensure_config", "home", "hookimpl", "tool"] +from bub.turn_admission import AdmitAction, AdmitDecision, SteeringHandle, TurnSnapshot + +__all__ = [ + "AdmitAction", + "AdmitDecision", + "BubFramework", + "Settings", + "SteeringHandle", + "TurnSnapshot", + "config", + "ensure_config", + "home", + "hookimpl", + "tool", +] try: __version__ = import_module("bub._version").version diff --git a/src/bub/channels/manager.py b/src/bub/channels/manager.py index a7a26b06..bbae191a 100644 --- a/src/bub/channels/manager.py +++ b/src/bub/channels/manager.py @@ -15,6 +15,7 @@ from bub.configure import Settings, ensure_config from bub.envelope import content_of, field_of from bub.framework import BubFramework +from bub.turn_admission import AdmitAction, AdmitDecision, SessionTurnController from bub.types import Envelope, MessageHandler from bub.utils import wait_until_stopped @@ -57,7 +58,7 @@ def __init__( else: self._enabled_channels = self._settings.enabled_channels.split(",") self._messages = asyncio.Queue[ChannelMessage]() - self._ongoing_tasks: dict[str, set[asyncio.Task]] = {} + self._session_controllers: dict[str, SessionTurnController] = {} self._session_handlers: dict[str, MessageHandler] = {} async def on_receive(self, message: ChannelMessage) -> None: @@ -117,7 +118,14 @@ def wrap_stream(self, message: Envelope, stream: AsyncIterable[StreamEvent]) -> return channel.stream_events(message, stream) async def quit(self, session_id: str) -> None: - tasks = self._ongoing_tasks.pop(session_id, set()) + controller = self._session_controllers.get(session_id) + if controller is None: + self.framework.clear_steering(session_id) + logger.info(f"channel.manager quit session_id={session_id}, cancelled 0 tasks") + return + controller.clear_pending() + controller.steering.drain_nowait() + tasks = set(controller.active_tasks) current_task = asyncio.current_task() cancelled_count = 0 for task in tasks: @@ -127,6 +135,8 @@ async def quit(self, session_id: str) -> None: with contextlib.suppress(asyncio.CancelledError): await task cancelled_count += 1 + controller.active_tasks.difference_update(task for task in tasks if task is not current_task) + self._drop_empty_controller(session_id) logger.info(f"channel.manager quit session_id={session_id}, cancelled {cancelled_count} tasks") def enabled_channels(self) -> list[Channel]: @@ -137,15 +147,132 @@ def enabled_channels(self) -> list[Channel]: channel for name, channel in self._channels.items() if name in self._enabled_channels and channel.enabled ] + def _controller(self, session_id: str) -> SessionTurnController: + controller = self._session_controllers.get(session_id) + if controller is None: + controller = SessionTurnController(session_id=session_id, steering=self.framework.steering(session_id)) + self._session_controllers[session_id] = controller + return controller + + def _drop_empty_controller(self, session_id: str) -> None: + controller = self._session_controllers.get(session_id) + if controller is None: + return + if controller.active() or controller.pending_queue or controller.steering.has_messages(): + return + self._session_controllers.pop(session_id, None) + self.framework.clear_steering(session_id) + def _on_task_done(self, session_id: str, task: asyncio.Task) -> None: if task.cancelled(): logger.info("channel.manager task cancelled session_id={}", session_id) else: task.exception() # to log any exception - tasks = self._ongoing_tasks.get(session_id, set()) - tasks.discard(task) - if not tasks: - self._ongoing_tasks.pop(session_id, None) + controller = self._session_controllers.get(session_id) + if controller is None: + return + controller.active_tasks.discard(task) + if not controller.active(): + controller.promote_steering_to_pending() + self._schedule_pending(session_id) + self._drop_empty_controller(session_id) + + async def _admit_message(self, message: ChannelMessage) -> bool: + try: + session_id = await self._resolve_message_session(message) + except Exception as exc: + logger.exception("channel.manager resolve_session failed") + await self.framework._hook_runtime.notify_error(stage="resolve_session", error=exc, message=message) + return False + controller = self._controller(session_id) + try: + decision = await self.framework.admit_message( + session_id=session_id, + message=message, + turn=controller.snapshot(), + ) + except Exception as exc: + logger.exception("channel.manager admission hook failed") + await self.framework._hook_runtime.notify_error(stage="admit_message", error=exc, message=message) + return True + if decision is None: + self._drop_empty_controller(session_id) + return True + admitted = await self._apply_admission_decision(controller, message, decision) + if not admitted or not controller.active(): + self._drop_empty_controller(session_id) + return admitted + + async def _apply_admission_decision( + self, + controller: SessionTurnController, + message: ChannelMessage, + decision: AdmitDecision, + ) -> bool: + action = _normalize_admit_action(decision.action) + if action == AdmitAction.PROCESS: + return True + if action == AdmitAction.DROP: + logger.info( + "channel.manager admission drop session_id={} reason={}", + message.session_id, + decision.reason, + ) + return False + if action == AdmitAction.WAIT: + return self._queue_pending(controller, message, decision.reason) + if action == AdmitAction.STEER: + if controller.active() and controller.steering.put_nowait(message): + logger.info( + "channel.manager admission steer session_id={} reason={}", + message.session_id, + decision.reason, + ) + return False + return self._queue_pending(controller, message, decision.reason) + logger.warning("channel.manager admission unknown action={} session_id={}", decision.action, message.session_id) + return True + + def _queue_pending( + self, + controller: SessionTurnController, + message: ChannelMessage, + reason: str | None, + ) -> bool: + if not controller.active(): + return True + controller.add_pending(message) + logger.info( + "channel.manager admission wait session_id={} pending_count={} reason={}", + message.session_id, + len(controller.pending_queue), + reason, + ) + return False + + def _schedule_message(self, message: ChannelMessage) -> asyncio.Task: + controller = self._controller(message.session_id) + task = asyncio.create_task(self._run_message(message)) + task.add_done_callback(functools.partial(self._on_task_done, message.session_id)) + controller.active_tasks.add(task) + return task + + def _schedule_pending(self, session_id: str) -> None: + controller = self._session_controllers.get(session_id) + if controller is None or controller.active(): + return + message = controller.pop_pending() + if message is not None: + self._schedule_message(message) + + async def _resolve_message_session(self, message: ChannelMessage) -> str: + session_id = await self.framework.resolve_session(message) + message.session_id = session_id + setattr(message, "_runtime_session_id", session_id) # noqa: B010 + return session_id + + async def _run_message(self, message: ChannelMessage) -> None: + await self.framework.process_inbound(message, self._stream_output) async def listen_and_run(self) -> None: stop_event = asyncio.Event() @@ -157,9 +284,9 @@ async def listen_and_run(self) -> None: try: while True: message = await wait_until_stopped(self._messages.get(), stop_event) - task = asyncio.create_task(self.framework.process_inbound(message, self._stream_output)) - task.add_done_callback(functools.partial(self._on_task_done, message.session_id)) - self._ongoing_tasks.setdefault(message.session_id, set()).add(task) + if not await self._admit_message(message): + continue + self._schedule_message(message) except asyncio.CancelledError: logger.info("channel.manager received shutdown signal") except Exception: @@ -172,13 +299,27 @@ async def listen_and_run(self) -> None: async def shutdown(self) -> None: count = 0 - for tasks in self._ongoing_tasks.values(): - for task in tasks: + session_ids = list(self._session_controllers) + for controller in list(self._session_controllers.values()): + controller.clear_pending() + controller.steering.drain_nowait() + for task in set(controller.active_tasks): task.cancel() with contextlib.suppress(asyncio.CancelledError): await task count += 1 - self._ongoing_tasks.clear() + self._session_controllers.clear() + for session_id in session_ids: + self.framework.clear_steering(session_id) logger.info(f"channel.manager cancelled {count} in-flight tasks") for channel in self.enabled_channels(): await channel.stop() + + +def _normalize_admit_action(action: AdmitAction | str) -> AdmitAction | str: + if isinstance(action, AdmitAction): + return action + try: + return AdmitAction(action) + except ValueError: + return action diff --git a/src/bub/framework.py b/src/bub/framework.py index 14e9bbd5..c5dd5228 100644 --- a/src/bub/framework.py +++ b/src/bub/framework.py @@ -6,7 +6,7 @@ from collections.abc import AsyncGenerator, AsyncIterator, Iterator from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast import pluggy import typer @@ -20,6 +20,7 @@ from bub.envelope import content_of, field_of, unpack_batch from bub.hook_runtime import _SKIP_VALUE, HookRuntime from bub.hookspecs import BUB_HOOK_NAMESPACE, BubHookSpecs +from bub.turn_admission import AdmitDecision, SteeringBuffer, SteeringHandle, TurnSnapshot from bub.types import Envelope, MessageHandler, OutboundChannelRouter, TurnResult if TYPE_CHECKING: @@ -48,6 +49,7 @@ def __init__(self, config_file: Path = DEFAULT_CONFIG_FILE) -> None: self._hook_runtime = HookRuntime(self._plugin_manager) self._plugin_status: dict[str, PluginStatus] = {} self._outbound_router: OutboundChannelRouter | None = None + self._steering_handles: dict[str, SteeringHandle] = {} self._tape_store: TapeStore | AsyncTapeStore | None = None configure.load(self.config_file) @@ -109,12 +111,10 @@ async def process_inbound(self, inbound: Envelope, stream_output: bool = False) """Run one inbound message through hooks and return turn result.""" try: - session_id = await self._hook_runtime.call_first( - "resolve_session", message=inbound - ) or self._default_session_id(inbound) + session_id = await self.resolve_session(inbound) if isinstance(inbound, dict): inbound.setdefault("session_id", session_id) - state = {"_runtime_workspace": str(self.workspace)} + state = {"_runtime_workspace": str(self.workspace), "_runtime_steering": self.steering(session_id)} for hook_state in reversed( await self._hook_runtime.call_many("load_state", message=inbound, session_id=session_id) ): @@ -146,6 +146,15 @@ async def process_inbound(self, inbound: Envelope, stream_output: bool = False) await self._hook_runtime.notify_error(stage="turn", error=exc, message=inbound) raise + async def resolve_session(self, message: Envelope) -> str: + """Resolve the canonical session id for a message.""" + + runtime_session_id = field_of(message, "_runtime_session_id") + if runtime_session_id is not None: + return str(runtime_session_id) + resolved = await self._hook_runtime.call_first("resolve_session", message=message) + return str(resolved or self._default_session_id(message)) + async def _run_model( self, inbound: Envelope, @@ -207,6 +216,27 @@ async def quit_via_router(self, session_id: str) -> None: if self._outbound_router is not None: await self._outbound_router.quit(session_id) + async def admit_message(self, *, session_id: str, message: Envelope, turn: TurnSnapshot) -> AdmitDecision | None: + return cast( + "AdmitDecision | None", + await self._hook_runtime.call_first( + "admit_message", + session_id=session_id, + message=message, + turn=turn, + ), + ) + + def steering(self, session_id: str) -> SteeringHandle: + handle = self._steering_handles.get(session_id) + if handle is None: + handle = SteeringHandle(session_id=session_id, buffer=SteeringBuffer()) + self._steering_handles[session_id] = handle + return handle + + def clear_steering(self, session_id: str) -> None: + self._steering_handles.pop(session_id, None) + @staticmethod def _default_session_id(message: Envelope) -> str: session_id = field_of(message, "session_id") diff --git a/src/bub/hookspecs.py b/src/bub/hookspecs.py index 000aa9c5..35d9b1b6 100644 --- a/src/bub/hookspecs.py +++ b/src/bub/hookspecs.py @@ -8,6 +8,7 @@ from republic import AsyncStreamEvents, AsyncTapeStore, TapeContext from republic.tape import TapeStore +from bub.turn_admission import AdmitDecision, TurnSnapshot from bub.types import Envelope, MessageHandler, State if TYPE_CHECKING: @@ -107,3 +108,16 @@ def provide_channels(self, message_handler: MessageHandler) -> list[Channel]: def build_tape_context(self) -> TapeContext: """Build a tape context for the current session, to be used to build context messages.""" raise NotImplementedError + + @hookspec(firstresult=True) + def admit_message( + self, + session_id: str, + message: Envelope, + turn: TurnSnapshot, + ) -> AdmitDecision | None: + """Decide how to handle an inbound channel message for a session. + + Return ``None`` to keep Bub's default concurrent scheduling behavior. + """ + raise NotImplementedError diff --git a/src/bub/turn_admission.py b/src/bub/turn_admission.py new file mode 100644 index 00000000..861b7ef1 --- /dev/null +++ b/src/bub/turn_admission.py @@ -0,0 +1,154 @@ +"""Turn admission primitives for channel message scheduling.""" + +from __future__ import annotations + +import asyncio +from collections import deque +from dataclasses import dataclass, field +from enum import StrEnum +from typing import Literal + +from bub.types import Envelope + + +class AdmitAction(StrEnum): + """Actions an ``admit_message`` hook can return.""" + + PROCESS = "process" + DROP = "drop" + WAIT = "wait" + STEER = "steer" + + +TurnAdmissionAction = AdmitAction | Literal["process", "drop", "wait", "steer"] + + +@dataclass(frozen=True) +class AdmitDecision: + """Decision returned by ``admit_message`` hooks.""" + + action: TurnAdmissionAction + reason: str | None = None + + +@dataclass(frozen=True) +class TurnSnapshot: + """Snapshot of current session turn state exposed to admission hooks.""" + + session_id: str + is_running: bool + running_count: int + pending_count: int + steering_count: int + + +@dataclass +class SteeringBuffer: + """Per-session queue for steering messages offered to active turns.""" + + _queue: asyncio.Queue[Envelope] = field(default_factory=asyncio.Queue, init=False, repr=False) + + def put_nowait(self, message: Envelope) -> bool: + """Append one message.""" + + try: + self._queue.put_nowait(message) + except asyncio.QueueFull: + return False + return True + + @property + def count(self) -> int: + return self._queue.qsize() + + def has_messages(self) -> bool: + return not self._queue.empty() + + def get_nowait(self) -> Envelope | None: + """Return one queued message without waiting.""" + + try: + return self._queue.get_nowait() + except asyncio.QueueEmpty: + return None + + def drain_nowait(self) -> list[Envelope]: + """Return all queued messages without waiting.""" + + messages: list[Envelope] = [] + while True: + message = self.get_nowait() + if message is None: + return messages + messages.append(message) + + +@dataclass +class SteeringHandle: + """Control surface exposed to model hooks through turn state.""" + + session_id: str + buffer: SteeringBuffer + + def put_nowait(self, message: Envelope) -> bool: + return self.buffer.put_nowait(message) + + @property + def count(self) -> int: + return self.buffer.count + + def has_messages(self) -> bool: + return self.buffer.has_messages() + + def get_nowait(self) -> Envelope | None: + """Drain one steering input and acknowledge ownership of it.""" + + return self.buffer.get_nowait() + + def drain_nowait(self) -> list[Envelope]: + """Drain steering input and acknowledge ownership of those messages.""" + + return self.buffer.drain_nowait() + + +@dataclass +class SessionTurnController: + """Per-session runtime queues used by ``ChannelManager``.""" + + session_id: str + steering: SteeringHandle + active_tasks: set[asyncio.Task] = field(default_factory=set) + pending_queue: deque[Envelope] = field(default_factory=deque) + + def active(self) -> set[asyncio.Task]: + return {task for task in self.active_tasks if not task.done()} + + def snapshot(self) -> TurnSnapshot: + running_count = len(self.active()) + return TurnSnapshot( + session_id=self.session_id, + is_running=running_count > 0, + running_count=running_count, + pending_count=len(self.pending_queue), + steering_count=self.steering.count, + ) + + def add_pending(self, message: Envelope) -> bool: + self.pending_queue.append(message) + return True + + def add_pending_left(self, message: Envelope) -> bool: + self.pending_queue.appendleft(message) + return True + + def pop_pending(self) -> Envelope | None: + if not self.pending_queue: + return None + return self.pending_queue.popleft() + + def clear_pending(self) -> None: + self.pending_queue.clear() + + def promote_steering_to_pending(self) -> None: + for message in reversed(self.steering.drain_nowait()): + self.add_pending_left(message) diff --git a/tests/test_channels.py b/tests/test_channels.py index c574766d..301af983 100644 --- a/tests/test_channels.py +++ b/tests/test_channels.py @@ -16,6 +16,7 @@ from bub.channels.manager import ChannelManager from bub.channels.message import ChannelMessage from bub.channels.telegram import BubMessageFilter, TelegramChannel, TelegramMessageParser +from bub.turn_admission import AdmitAction, AdmitDecision, SessionTurnController, SteeringBuffer, SteeringHandle def _load_channel_config( @@ -66,6 +67,11 @@ def __init__(self, channels: dict[str, FakeChannel]) -> None: self._channels = channels self.router = None self.process_calls: list[tuple[ChannelMessage, bool]] = [] + self.admission_decisions: list[AdmitDecision | None] = [] + self.admission_calls: list[tuple[str, ChannelMessage, object]] = [] + self._steering_handles: dict[str, SteeringHandle] = {} + self.resolved_sessions: dict[str, str] = {} + self._hook_runtime = SimpleNamespace(notify_error=self._notify_error) self.running_entries = 0 self.running_exits = 0 @@ -91,6 +97,28 @@ async def process_inbound(self, message: ChannelMessage, stream_output: bool = F stop_event.set() return None + async def admit_message(self, *, session_id: str, message: ChannelMessage, turn): + self.admission_calls.append((session_id, message, turn)) + if self.admission_decisions: + return self.admission_decisions.pop(0) + return None + + async def resolve_session(self, message: ChannelMessage) -> str: + return self.resolved_sessions.get(message.session_id, message.session_id) + + def steering(self, session_id: str) -> SteeringHandle: + handle = self._steering_handles.get(session_id) + if handle is None: + handle = SteeringHandle(session_id=session_id, buffer=SteeringBuffer()) + self._steering_handles[session_id] = handle + return handle + + def clear_steering(self, session_id: str) -> None: + self._steering_handles.pop(session_id, None) + + async def _notify_error(self, *, stage: str, error: Exception, message: ChannelMessage | None) -> None: + return None + def _message( content: str, @@ -264,7 +292,7 @@ async def never_finish() -> None: await asyncio.sleep(10) task = asyncio.create_task(never_finish()) - manager._ongoing_tasks["telegram:chat"] = {task} + manager._controller("telegram:chat").active_tasks = {task} await manager.shutdown() @@ -343,15 +371,15 @@ async def never_finish() -> None: target_task = asyncio.create_task(never_finish()) other_task = asyncio.create_task(never_finish()) - manager._ongoing_tasks["session:target"] = {target_task} - manager._ongoing_tasks["session:other"] = {other_task} + manager._controller("session:target").active_tasks = {target_task} + manager._controller("session:other").active_tasks = {other_task} await manager.quit("session:target") assert target_task.cancelled() - assert "session:target" not in manager._ongoing_tasks + assert "session:target" not in manager._session_controllers assert other_task.cancelled() is False - assert manager._ongoing_tasks["session:other"] == {other_task} + assert manager._session_controllers["session:other"].active_tasks == {other_task} other_task.cancel() with contextlib.suppress(asyncio.CancelledError): @@ -369,13 +397,14 @@ async def never_finish() -> None: current_task = asyncio.current_task() assert current_task is not None target_task = asyncio.create_task(never_finish()) - manager._ongoing_tasks["session:target"] = {current_task, target_task} + controller = manager._controller("session:target") + controller.active_tasks = {current_task, target_task} await manager.quit("session:target") assert current_task.cancelled() is False assert target_task.cancelled() - assert "session:target" not in manager._ongoing_tasks + assert controller.active_tasks == {current_task} @pytest.mark.asyncio @@ -387,14 +416,203 @@ async def never_finish() -> None: await asyncio.sleep(10) task = asyncio.create_task(never_finish()) - manager._ongoing_tasks["session:target"] = {task} + manager._controller("session:target").active_tasks = {task} task.cancel() with contextlib.suppress(asyncio.CancelledError): await task manager._on_task_done("session:target", task) - assert "session:target" not in manager._ongoing_tasks + assert "session:target" not in manager._session_controllers + + +@pytest.mark.asyncio +async def test_channel_manager_admission_default_keeps_concurrent_processing(load_config) -> None: + _load_channel_config(load_config, enabled_channels="telegram") + framework = FakeFramework({"telegram": FakeChannel("telegram")}) + manager = ChannelManager(framework, enabled_channels=["telegram"]) + + async def never_finish() -> None: + await asyncio.sleep(10) + + active = asyncio.create_task(never_finish()) + manager._controller("telegram:chat").active_tasks = {active} + + admitted = await manager._admit_message(_message("second")) + + assert admitted is True + session_id, message, turn = framework.admission_calls[0] + assert session_id == "telegram:chat" + assert message.content == "second" + assert turn.is_running is True + assert turn.running_count == 1 + assert turn.pending_count == 0 + assert turn.steering_count == 0 + + active.cancel() + with contextlib.suppress(asyncio.CancelledError): + await active + + +@pytest.mark.asyncio +async def test_channel_manager_admission_uses_resolved_session_for_control(load_config) -> None: + _load_channel_config(load_config, enabled_channels="telegram") + framework = FakeFramework({"telegram": FakeChannel("telegram")}) + framework.resolved_sessions["telegram:raw"] = "tenant:canonical" + framework.admission_decisions.append(AdmitDecision(AdmitAction.WAIT, reason="serial")) + manager = ChannelManager(framework, enabled_channels=["telegram"]) + + async def never_finish() -> None: + await asyncio.sleep(10) + + active = asyncio.create_task(never_finish()) + manager._controller("tenant:canonical").active_tasks = {active} + + admitted = await manager._admit_message(_message("second", session_id="telegram:raw")) + + assert admitted is False + assert framework.admission_calls[0][0] == "tenant:canonical" + assert "telegram:raw" not in manager._session_controllers + assert [message.content for message in manager._session_controllers["tenant:canonical"].pending_queue] == ["second"] + + active.cancel() + with contextlib.suppress(asyncio.CancelledError): + await active + + +@pytest.mark.asyncio +async def test_channel_manager_admission_drop_discards_message(load_config) -> None: + _load_channel_config(load_config, enabled_channels="telegram") + framework = FakeFramework({"telegram": FakeChannel("telegram")}) + framework.admission_decisions.append(AdmitDecision(AdmitAction.DROP, reason="busy")) + manager = ChannelManager(framework, enabled_channels=["telegram"]) + + async def never_finish() -> None: + await asyncio.sleep(10) + + active = asyncio.create_task(never_finish()) + manager._controller("telegram:chat").active_tasks = {active} + + admitted = await manager._admit_message(_message("drop me")) + + assert admitted is False + assert not manager._session_controllers["telegram:chat"].pending_queue + + active.cancel() + with contextlib.suppress(asyncio.CancelledError): + await active + + +@pytest.mark.asyncio +async def test_channel_manager_admission_wait_queues_pending_message(load_config) -> None: + _load_channel_config(load_config, enabled_channels="telegram") + framework = FakeFramework({"telegram": FakeChannel("telegram")}) + framework.admission_decisions.append(AdmitDecision(AdmitAction.WAIT, reason="serial")) + manager = ChannelManager(framework, enabled_channels=["telegram"]) + + async def never_finish() -> None: + await asyncio.sleep(10) + + active = asyncio.create_task(never_finish()) + manager._controller("telegram:chat").active_tasks = {active} + + admitted = await manager._admit_message(_message("queued")) + + assert admitted is False + assert [message.content for message in manager._session_controllers["telegram:chat"].pending_queue] == ["queued"] + + active.cancel() + with contextlib.suppress(asyncio.CancelledError): + await active + + +@pytest.mark.asyncio +async def test_channel_manager_admission_steer_promotes_undrained_messages_to_pending(load_config) -> None: + _load_channel_config(load_config, enabled_channels="telegram") + framework = FakeFramework({"telegram": FakeChannel("telegram")}) + framework.admission_decisions.extend([ + AdmitDecision(AdmitAction.STEER, reason="correction"), + AdmitDecision(AdmitAction.STEER, reason="correction"), + ]) + manager = ChannelManager(framework, enabled_channels=["telegram"]) + + done = asyncio.create_task(asyncio.sleep(0)) + controller = manager._controller("telegram:chat") + controller.active_tasks = {done} + controller.add_pending(_message("already waiting")) + + admitted = await manager._admit_message(_message("actually do this")) + admitted_again = await manager._admit_message(_message("then this")) + await done + manager._on_task_done("telegram:chat", done) + for _ in range(10): + if len(framework.process_calls) == 3: + break + await asyncio.sleep(0) + + assert admitted is False + assert admitted_again is False + assert [message.content for message, _ in framework.process_calls] == [ + "actually do this", + "then this", + "already waiting", + ] + + +@pytest.mark.asyncio +async def test_channel_manager_admission_steer_drain_acknowledges_ownership(load_config) -> None: + _load_channel_config(load_config, enabled_channels="telegram") + framework = FakeFramework({"telegram": FakeChannel("telegram")}) + framework.admission_decisions.append(AdmitDecision(AdmitAction.STEER, reason="correction")) + manager = ChannelManager(framework, enabled_channels=["telegram"]) + + done = asyncio.create_task(asyncio.sleep(0)) + controller = manager._controller("telegram:chat") + controller.active_tasks = {done} + + admitted = await manager._admit_message(_message("consume me")) + drained = framework.steering("telegram:chat").drain_nowait() + await done + manager._on_task_done("telegram:chat", done) + + assert admitted is False + assert [message.content for message in drained] == ["consume me"] + assert framework.process_calls == [] + + +def test_turn_admission_queues_preserve_messages_without_capacity_policy() -> None: + steering = SteeringBuffer() + + assert steering.put_nowait(_message("one")) is True + assert steering.put_nowait(_message("two")) is True + assert steering.put_nowait(_message("three with a long body")) is True + drained_one = steering.get_nowait() + assert drained_one is not None + assert drained_one.content == "one" + assert [message.content for message in steering.drain_nowait()] == ["two", "three with a long body"] + + handle = SteeringHandle(session_id="telegram:chat", buffer=SteeringBuffer()) + handle.put_nowait(_message("handle one")) + handle.put_nowait(_message("handle two")) + drained_from_handle = handle.get_nowait() + assert drained_from_handle is not None + assert drained_from_handle.content == "handle one" + assert [message.content for message in handle.drain_nowait()] == ["handle two"] + + controller = SessionTurnController(session_id="telegram:chat", steering=handle) + + assert controller.add_pending(_message("one")) is True + assert controller.add_pending(_message("two")) is True + assert controller.add_pending(_message("three with a long body")) is True + assert [message.content for message in controller.pending_queue] == ["one", "two", "three with a long body"] + + assert controller.add_pending_left(_message("priority")) is True + assert [message.content for message in controller.pending_queue] == [ + "priority", + "one", + "two", + "three with a long body", + ] def test_cli_channel_normalize_input_prefixes_shell_commands() -> None: diff --git a/tests/test_framework.py b/tests/test_framework.py index 63756cfc..6550c77a 100644 --- a/tests/test_framework.py +++ b/tests/test_framework.py @@ -20,6 +20,7 @@ from bub.configure import ensure_config from bub.framework import BubFramework from bub.hookspecs import hookimpl +from bub.turn_admission import AdmitAction, AdmitDecision, SteeringHandle, TurnSnapshot def make_named_channel(name: str, label: str) -> Channel: @@ -284,6 +285,54 @@ async def dispatch_outbound(self, message) -> bool: assert saved_outputs == ["plain-text"] +@pytest.mark.asyncio +async def test_process_inbound_exposes_runtime_steering_handle() -> None: + framework = BubFramework() + observed_state: dict[str, Any] = {} + + class SteeringAwarePlugin: + @hookimpl + async def run_model(self, prompt, session_id, state) -> str: + observed_state.update(state) + return "ok" + + framework._plugin_manager.register(SteeringAwarePlugin(), name="steering-aware") + + result = await framework.process_inbound({"session_id": "session", "content": "hi"}) + + assert result.model_output == "ok" + assert isinstance(observed_state["_runtime_steering"], SteeringHandle) + assert observed_state["_runtime_steering"].session_id == "session" + + +@pytest.mark.asyncio +async def test_framework_admit_message_calls_hook_with_snapshot() -> None: + framework = BubFramework() + + class AdmissionPlugin: + @hookimpl + def admit_message(self, session_id, message, turn): + assert session_id == "session" + assert message["content"] == "hello" + assert turn.pending_count == 1 + return AdmitDecision(AdmitAction.WAIT, reason="busy") + + framework._plugin_manager.register(AdmissionPlugin(), name="admission") + decision = await framework.admit_message( + session_id="session", + message={"content": "hello"}, + turn=TurnSnapshot( + session_id="session", + is_running=True, + running_count=1, + pending_count=1, + steering_count=0, + ), + ) + + assert decision == AdmitDecision(AdmitAction.WAIT, reason="busy") + + @pytest.mark.asyncio async def test_process_inbound_streams_when_requested() -> None: # noqa: C901 framework = BubFramework() diff --git a/website/src/content/docs/docs/reference/hooks.mdx b/website/src/content/docs/docs/reference/hooks.mdx index 2face52f..9dd42d7f 100644 --- a/website/src/content/docs/docs/reference/hooks.mdx +++ b/website/src/content/docs/docs/reference/hooks.mdx @@ -30,6 +30,7 @@ For the *why* and *how* of each stage see [Turn pipeline](/docs/concepts/turn-pi | `provide_tape_store` | firstresult | `() -> TapeStore \| AsyncTapeStore` | tape store | `BubFramework.running()` | Resolved once when the runtime scope opens; sync/async iterators are entered as context managers. | | `provide_channels` | sync-only consumer (deduped) | `(message_handler: MessageHandler) -> list[Channel]` | channels | `BubFramework.get_channels` (`call_many_sync`) | Channels are deduplicated by `Channel.name`; the first channel seen in hook priority order wins. | | `build_tape_context` | firstresult | `() -> TapeContext` | tape context | `BubFramework.build_tape_context` (`call_first_sync`) | Sync-only; awaitable returns are skipped. | +| `admit_message` | firstresult | `(session_id, message, turn) -> AdmitDecision \| None` | turn admission decision | `ChannelManager` | Runs before channel scheduling. `None` keeps default concurrent scheduling; decision types are listed in [Types](/docs/reference/types/#turn-admission-types). | ## How hooks are invoked diff --git a/website/src/content/docs/docs/reference/types.mdx b/website/src/content/docs/docs/reference/types.mdx index 689fedda..a7a68c7a 100644 --- a/website/src/content/docs/docs/reference/types.mdx +++ b/website/src/content/docs/docs/reference/types.mdx @@ -51,7 +51,7 @@ Normalizes one `render_outbound` return value to a list. `None` → `[]`; `list` type State = dict[str, Any] ``` -The per-turn state dict. The framework seeds it with `_runtime_workspace` and merges the results of every `load_state` hook before the model call. The same dict is passed to `build_prompt`, `run_model[_stream]`, `save_state`, `render_outbound`, and `system_prompt`. +The per-turn state dict. The framework seeds it with `_runtime_workspace` and `_runtime_steering`, then merges the results of every `load_state` hook before the model call. The same dict is passed to `build_prompt`, `run_model[_stream]`, `save_state`, `render_outbound`, and `system_prompt`. ## `MessageHandler` @@ -95,6 +95,51 @@ class TurnResult: Returned by `BubFramework.process_inbound`. `prompt` is the resolved prompt. The source annotation is currently `str`, but a `build_prompt` hook may return multimodal content parts and the runtime preserves that list. `outbounds` is the flattened result of every `render_outbound` impl. +## Turn Admission Types + +These types are exported from `bub` and defined in `src/bub/turn_admission.py`. + +```python +class AdmitAction(StrEnum): + PROCESS = "process" + DROP = "drop" + WAIT = "wait" + STEER = "steer" +``` + +```python +@dataclass(frozen=True) +class AdmitDecision: + action: AdmitAction | Literal["process", "drop", "wait", "steer"] + reason: str | None = None +``` + +```python +@dataclass(frozen=True) +class TurnSnapshot: + session_id: str + is_running: bool + running_count: int + pending_count: int + steering_count: int +``` + +`admit_message` implementations return `AdmitDecision` to tell the channel manager whether to process, drop, wait, or steer one inbound message. + +```python +@dataclass +class SteeringHandle: + session_id: str + + @property + def count(self) -> int: ... + def has_messages(self) -> bool: ... + def get_nowait(self) -> Envelope | None: ... + def drain_nowait(self) -> list[Envelope]: ... +``` + +`SteeringHandle` is exposed to model hooks as `state["_runtime_steering"]`. `get_nowait()` removes one message; `drain_nowait()` removes all currently queued messages. Returned messages are owned by the model hook and will not be replayed. + ## `Channel` ```python @@ -159,6 +204,11 @@ class BubFramework: def get_tape_store(self) -> TapeStore | AsyncTapeStore | None: ... def get_system_prompt(self, prompt: str | list[dict], state: dict[str, Any]) -> str: ... def hook_report(self) -> dict[str, list[str]]: ... + async def admit_message( + self, *, session_id: str, message: Envelope, turn: TurnSnapshot + ) -> AdmitDecision | None: ... + def steering(self, session_id: str) -> SteeringHandle: ... + def clear_steering(self, session_id: str) -> None: ... @contextlib.asynccontextmanager async def running(self) -> AsyncGenerator[contextlib.AsyncExitStack, None]: ... @@ -179,6 +229,9 @@ class BubFramework: | `get_tape_store()` | Return the tape store entered by `running()`, or `None` outside the scope. | | `get_system_prompt(prompt, state)` | Run `system_prompt` impls (sync), reverse, and join non-empty results with `\n\n`. | | `hook_report()` | Map hook name → discovered adapter names. Backs `bub hooks`; read the hook reference before treating this order as runtime precedence. | +| `admit_message(...)` | Call the `admit_message` hook and return the selected decision. Used by `ChannelManager`. | +| `steering(session_id)` | Return the per-session steering handle exposed to model hooks. | +| `clear_steering(session_id)` | Clear an idle session's steering handle. | | `running()` | Async context manager; resolves `provide_tape_store` once and binds the resulting store for the duration. | | `bind_outbound_router(router)` | Attach (or detach with `None`) the `OutboundChannelRouter`. The `ChannelManager` calls this on start/stop. | | `build_tape_context()` | Sync-call `build_tape_context` and return the resulting `TapeContext`. | @@ -193,7 +246,11 @@ From `src/bub/__init__.py`: | Name | Kind | Description | | --- | --- | --- | | `BubFramework` | class | Framework runtime (above). | +| `AdmitAction` | enum | Turn admission action vocabulary. | +| `AdmitDecision` | dataclass | Decision returned by `admit_message`. | | `Settings` | class | Base class for plugin settings (re-exported from `bub.configure`). | +| `SteeringHandle` | dataclass | Per-session steering queue handle exposed to model hooks. | +| `TurnSnapshot` | dataclass | Snapshot passed to `admit_message`. | | `config` | decorator | `@config(name="...")` registers a settings class for YAML/env validation. | | `ensure_config` | function | `ensure_config(SettingsCls)` — return the singleton instance for that class. | | `home` | `Path` (lazy attr) | `Path(BUB_HOME)` if set, else `~/.bub`. | diff --git a/website/src/content/docs/zh-cn/docs/reference/hooks.mdx b/website/src/content/docs/zh-cn/docs/reference/hooks.mdx index b3f4939a..27f4639d 100644 --- a/website/src/content/docs/zh-cn/docs/reference/hooks.mdx +++ b/website/src/content/docs/zh-cn/docs/reference/hooks.mdx @@ -30,6 +30,7 @@ description: BubHookSpecs 中每个钩子的类型、签名、返回值与调用 | `provide_tape_store` | firstresult | `() -> TapeStore \| AsyncTapeStore` | tape store | `BubFramework.running()` | 仅在 runtime 作用域开启时解析一次;返回同步或异步迭代器时会被作为 context manager 进入。 | | `provide_channels` | sync-only consumer (deduped) | `(message_handler: MessageHandler) -> list[Channel]` | channels | `BubFramework.get_channels` (`call_many_sync`) | 按 `Channel.name` 去重;在钩子优先级顺序中最先出现的 channel 胜出。 | | `build_tape_context` | firstresult | `() -> TapeContext` | tape context | `BubFramework.build_tape_context` (`call_first_sync`) | 仅同步;awaitable 返回会被跳过。 | +| `admit_message` | firstresult | `(session_id, message, turn) -> AdmitDecision \| None` | turn admission decision | `ChannelManager` | 调度 channel message 前调用。返回 `None` 保持默认并发调度;decision 类型见 [类型](/zh-cn/docs/reference/types/#turn-admission-类型)。 | ## 钩子如何被调用 diff --git a/website/src/content/docs/zh-cn/docs/reference/types.mdx b/website/src/content/docs/zh-cn/docs/reference/types.mdx index 29c5086f..555cbeac 100644 --- a/website/src/content/docs/zh-cn/docs/reference/types.mdx +++ b/website/src/content/docs/zh-cn/docs/reference/types.mdx @@ -51,7 +51,7 @@ def unpack_batch(batch: Any) -> list[Envelope] type State = dict[str, Any] ``` -per-turn 的 state dict。框架先以 `_runtime_workspace` 初始化,再合并所有 `load_state` 钩子的结果,然后才调用模型。同一个 dict 会被传给 `build_prompt`、`run_model[_stream]`、`save_state`、`render_outbound` 与 `system_prompt`。 +per-turn 的 state dict。框架先以 `_runtime_workspace` 与 `_runtime_steering` 初始化,再合并所有 `load_state` 钩子的结果,然后才调用模型。同一个 dict 会被传给 `build_prompt`、`run_model[_stream]`、`save_state`、`render_outbound` 与 `system_prompt`。 ## `MessageHandler` @@ -95,6 +95,51 @@ class TurnResult: `BubFramework.process_inbound` 的返回值。`prompt` 是解析后的 prompt。源码标注当前仍是 `str`,但 `build_prompt` hook 可以返回多模态内容片段列表,运行时会保留这个 list。`outbounds` 是所有 `render_outbound` 实现结果的扁平拼接。 +## Turn Admission 类型 + +这些类型从 `bub` 导出,定义在 `src/bub/turn_admission.py`。 + +```python +class AdmitAction(StrEnum): + PROCESS = "process" + DROP = "drop" + WAIT = "wait" + STEER = "steer" +``` + +```python +@dataclass(frozen=True) +class AdmitDecision: + action: AdmitAction | Literal["process", "drop", "wait", "steer"] + reason: str | None = None +``` + +```python +@dataclass(frozen=True) +class TurnSnapshot: + session_id: str + is_running: bool + running_count: int + pending_count: int + steering_count: int +``` + +`admit_message` 实现返回 `AdmitDecision`,告诉 channel manager 对单条 inbound message 执行 process、drop、wait 或 steer。 + +```python +@dataclass +class SteeringHandle: + session_id: str + + @property + def count(self) -> int: ... + def has_messages(self) -> bool: ... + def get_nowait(self) -> Envelope | None: ... + def drain_nowait(self) -> list[Envelope]: ... +``` + +`SteeringHandle` 会以 `state["_runtime_steering"]` 暴露给 model hooks。`get_nowait()` 取出一条;`drain_nowait()` 取出当前全部 queued messages。返回的消息由 model hook 接管,不会重放。 + ## `Channel` ```python @@ -159,6 +204,11 @@ class BubFramework: def get_tape_store(self) -> TapeStore | AsyncTapeStore | None: ... def get_system_prompt(self, prompt: str | list[dict], state: dict[str, Any]) -> str: ... def hook_report(self) -> dict[str, list[str]]: ... + async def admit_message( + self, *, session_id: str, message: Envelope, turn: TurnSnapshot + ) -> AdmitDecision | None: ... + def steering(self, session_id: str) -> SteeringHandle: ... + def clear_steering(self, session_id: str) -> None: ... @contextlib.asynccontextmanager async def running(self) -> AsyncGenerator[contextlib.AsyncExitStack, None]: ... @@ -179,6 +229,9 @@ class BubFramework: | `get_tape_store()` | 返回 `running()` 中启用的 tape store;在作用域之外返回 `None`。 | | `get_system_prompt(prompt, state)` | 同步调用 `system_prompt` 实现,反转后用 `\n\n` 拼接非空片段。 | | `hook_report()` | 返回 hook 名 → 已发现的 adapter 列表。`bub hooks` 的数据来源;不要只根据该输出顺序推断运行时优先级。 | +| `admit_message(...)` | 调用 `admit_message` hook 并返回选中的 decision。由 `ChannelManager` 使用。 | +| `steering(session_id)` | 返回暴露给 model hooks 的 per-session steering handle。 | +| `clear_steering(session_id)` | 清除 idle session 的 steering handle。 | | `running()` | 异步 context manager;一次性解析 `provide_tape_store` 并在作用域内绑定 tape store。 | | `bind_outbound_router(router)` | 绑定(或传 `None` 解绑)`OutboundChannelRouter`。`ChannelManager` 在启停时调用。 | | `build_tape_context()` | 同步调用 `build_tape_context` 并返回 `TapeContext`。 | @@ -193,7 +246,11 @@ class BubFramework: | 名称 | 类型 | 描述 | | --- | --- | --- | | `BubFramework` | class | 框架运行时(见上)。 | +| `AdmitAction` | enum | Turn admission action 词汇。 | +| `AdmitDecision` | dataclass | `admit_message` 返回的 decision。 | | `Settings` | class | 插件配置基类(从 `bub.configure` 重新导出)。 | +| `SteeringHandle` | dataclass | 暴露给 model hooks 的 per-session steering queue handle。 | +| `TurnSnapshot` | dataclass | 传给 `admit_message` 的快照。 | | `config` | decorator | `@config(name="...")` 注册一个用于 YAML/env 验证的配置类。 | | `ensure_config` | function | `ensure_config(SettingsCls)` —— 返回该类的单例实例。 | | `home` | `Path`(惰性属性) | `BUB_HOME` 已设置时为 `Path(BUB_HOME)`,否则为 `~/.bub`。 |