Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions src/bub/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,21 @@
from bub.framework import DEFAULT_HOME, BubFramework
from bub.hookspecs import hookimpl
from bub.tools import tool

__all__ = ["BubFramework", "Settings", "config", "ensure_config", "home", "hookimpl", "tool"]
from bub.turn_admission import AdmitAction, AdmitDecision, SteeringHandle, TurnSnapshot

__all__ = [
"AdmitAction",
"AdmitDecision",
"BubFramework",
"Settings",
"SteeringHandle",
"TurnSnapshot",
"config",
"ensure_config",
"home",
"hookimpl",
"tool",
]

try:
__version__ = import_module("bub._version").version
Expand Down
165 changes: 153 additions & 12 deletions src/bub/channels/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from bub.configure import Settings, ensure_config
from bub.envelope import content_of, field_of
from bub.framework import BubFramework
from bub.turn_admission import AdmitAction, AdmitDecision, SessionTurnController
from bub.types import Envelope, MessageHandler
from bub.utils import wait_until_stopped

Expand Down Expand Up @@ -57,7 +58,7 @@ def __init__(
else:
self._enabled_channels = self._settings.enabled_channels.split(",")
self._messages = asyncio.Queue[ChannelMessage]()
self._ongoing_tasks: dict[str, set[asyncio.Task]] = {}
self._session_controllers: dict[str, SessionTurnController] = {}
self._session_handlers: dict[str, MessageHandler] = {}

async def on_receive(self, message: ChannelMessage) -> None:
Expand Down Expand Up @@ -117,7 +118,14 @@ def wrap_stream(self, message: Envelope, stream: AsyncIterable[StreamEvent]) ->
return channel.stream_events(message, stream)

async def quit(self, session_id: str) -> None:
tasks = self._ongoing_tasks.pop(session_id, set())
controller = self._session_controllers.get(session_id)
if controller is None:
self.framework.clear_steering(session_id)
logger.info(f"channel.manager quit session_id={session_id}, cancelled 0 tasks")
return
controller.clear_pending()
controller.steering.drain_nowait()
tasks = set(controller.active_tasks)
current_task = asyncio.current_task()
cancelled_count = 0
for task in tasks:
Expand All @@ -127,6 +135,8 @@ async def quit(self, session_id: str) -> None:
with contextlib.suppress(asyncio.CancelledError):
await task
cancelled_count += 1
controller.active_tasks.difference_update(task for task in tasks if task is not current_task)
self._drop_empty_controller(session_id)
logger.info(f"channel.manager quit session_id={session_id}, cancelled {cancelled_count} tasks")

def enabled_channels(self) -> list[Channel]:
Expand All @@ -137,15 +147,132 @@ def enabled_channels(self) -> list[Channel]:
channel for name, channel in self._channels.items() if name in self._enabled_channels and channel.enabled
]

def _controller(self, session_id: str) -> SessionTurnController:
controller = self._session_controllers.get(session_id)
if controller is None:
controller = SessionTurnController(session_id=session_id, steering=self.framework.steering(session_id))
self._session_controllers[session_id] = controller
return controller

def _drop_empty_controller(self, session_id: str) -> None:
controller = self._session_controllers.get(session_id)
if controller is None:
return
if controller.active() or controller.pending_queue or controller.steering.has_messages():
return
self._session_controllers.pop(session_id, None)
self.framework.clear_steering(session_id)

def _on_task_done(self, session_id: str, task: asyncio.Task) -> None:
if task.cancelled():
logger.info("channel.manager task cancelled session_id={}", session_id)
else:
task.exception() # to log any exception
tasks = self._ongoing_tasks.get(session_id, set())
tasks.discard(task)
if not tasks:
self._ongoing_tasks.pop(session_id, None)
controller = self._session_controllers.get(session_id)
if controller is None:
return
controller.active_tasks.discard(task)
if not controller.active():
controller.promote_steering_to_pending()
self._schedule_pending(session_id)
self._drop_empty_controller(session_id)

async def _admit_message(self, message: ChannelMessage) -> bool:
try:
session_id = await self._resolve_message_session(message)
except Exception as exc:
logger.exception("channel.manager resolve_session failed")
await self.framework._hook_runtime.notify_error(stage="resolve_session", error=exc, message=message)
return False
controller = self._controller(session_id)
try:
decision = await self.framework.admit_message(
session_id=session_id,
message=message,
turn=controller.snapshot(),
)
except Exception as exc:
logger.exception("channel.manager admission hook failed")
await self.framework._hook_runtime.notify_error(stage="admit_message", error=exc, message=message)
return True
if decision is None:
self._drop_empty_controller(session_id)
return True
admitted = await self._apply_admission_decision(controller, message, decision)
if not admitted or not controller.active():
self._drop_empty_controller(session_id)
return admitted

async def _apply_admission_decision(
self,
controller: SessionTurnController,
message: ChannelMessage,
decision: AdmitDecision,
) -> bool:
action = _normalize_admit_action(decision.action)
if action == AdmitAction.PROCESS:
return True
if action == AdmitAction.DROP:
logger.info(
"channel.manager admission drop session_id={} reason={}",
message.session_id,
decision.reason,
)
return False
if action == AdmitAction.WAIT:
return self._queue_pending(controller, message, decision.reason)
if action == AdmitAction.STEER:
if controller.active() and controller.steering.put_nowait(message):
logger.info(
"channel.manager admission steer session_id={} reason={}",
message.session_id,
decision.reason,
)
return False
return self._queue_pending(controller, message, decision.reason)
logger.warning("channel.manager admission unknown action={} session_id={}", decision.action, message.session_id)
return True

def _queue_pending(
self,
controller: SessionTurnController,
message: ChannelMessage,
reason: str | None,
) -> bool:
if not controller.active():
return True
controller.add_pending(message)
logger.info(
"channel.manager admission wait session_id={} pending_count={} reason={}",
message.session_id,
len(controller.pending_queue),
reason,
)
return False

def _schedule_message(self, message: ChannelMessage) -> asyncio.Task:
controller = self._controller(message.session_id)
task = asyncio.create_task(self._run_message(message))
task.add_done_callback(functools.partial(self._on_task_done, message.session_id))
controller.active_tasks.add(task)
return task

def _schedule_pending(self, session_id: str) -> None:
controller = self._session_controllers.get(session_id)
if controller is None or controller.active():
return
message = controller.pop_pending()
if message is not None:
self._schedule_message(message)

async def _resolve_message_session(self, message: ChannelMessage) -> str:
session_id = await self.framework.resolve_session(message)
message.session_id = session_id
setattr(message, "_runtime_session_id", session_id) # noqa: B010
return session_id

async def _run_message(self, message: ChannelMessage) -> None:
await self.framework.process_inbound(message, self._stream_output)

async def listen_and_run(self) -> None:
stop_event = asyncio.Event()
Expand All @@ -157,9 +284,9 @@ async def listen_and_run(self) -> None:
try:
while True:
message = await wait_until_stopped(self._messages.get(), stop_event)
task = asyncio.create_task(self.framework.process_inbound(message, self._stream_output))
task.add_done_callback(functools.partial(self._on_task_done, message.session_id))
self._ongoing_tasks.setdefault(message.session_id, set()).add(task)
if not await self._admit_message(message):
continue
self._schedule_message(message)
except asyncio.CancelledError:
logger.info("channel.manager received shutdown signal")
except Exception:
Expand All @@ -172,13 +299,27 @@ async def listen_and_run(self) -> None:

async def shutdown(self) -> None:
count = 0
for tasks in self._ongoing_tasks.values():
for task in tasks:
session_ids = list(self._session_controllers)
for controller in list(self._session_controllers.values()):
controller.clear_pending()
controller.steering.drain_nowait()
for task in set(controller.active_tasks):
task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await task
count += 1
self._ongoing_tasks.clear()
self._session_controllers.clear()
for session_id in session_ids:
self.framework.clear_steering(session_id)
logger.info(f"channel.manager cancelled {count} in-flight tasks")
for channel in self.enabled_channels():
await channel.stop()


def _normalize_admit_action(action: AdmitAction | str) -> AdmitAction | str:
if isinstance(action, AdmitAction):
return action
try:
return AdmitAction(action)
except ValueError:
return action
40 changes: 35 additions & 5 deletions src/bub/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from collections.abc import AsyncGenerator, AsyncIterator, Iterator
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, cast

import pluggy
import typer
Expand All @@ -20,6 +20,7 @@
from bub.envelope import content_of, field_of, unpack_batch
from bub.hook_runtime import _SKIP_VALUE, HookRuntime
from bub.hookspecs import BUB_HOOK_NAMESPACE, BubHookSpecs
from bub.turn_admission import AdmitDecision, SteeringBuffer, SteeringHandle, TurnSnapshot
from bub.types import Envelope, MessageHandler, OutboundChannelRouter, TurnResult

if TYPE_CHECKING:
Expand Down Expand Up @@ -48,6 +49,7 @@ def __init__(self, config_file: Path = DEFAULT_CONFIG_FILE) -> None:
self._hook_runtime = HookRuntime(self._plugin_manager)
self._plugin_status: dict[str, PluginStatus] = {}
self._outbound_router: OutboundChannelRouter | None = None
self._steering_handles: dict[str, SteeringHandle] = {}
self._tape_store: TapeStore | AsyncTapeStore | None = None
configure.load(self.config_file)

Expand Down Expand Up @@ -109,12 +111,10 @@ async def process_inbound(self, inbound: Envelope, stream_output: bool = False)
"""Run one inbound message through hooks and return turn result."""

try:
session_id = await self._hook_runtime.call_first(
"resolve_session", message=inbound
) or self._default_session_id(inbound)
session_id = await self.resolve_session(inbound)
if isinstance(inbound, dict):
inbound.setdefault("session_id", session_id)
state = {"_runtime_workspace": str(self.workspace)}
state = {"_runtime_workspace": str(self.workspace), "_runtime_steering": self.steering(session_id)}
for hook_state in reversed(
await self._hook_runtime.call_many("load_state", message=inbound, session_id=session_id)
):
Expand Down Expand Up @@ -146,6 +146,15 @@ async def process_inbound(self, inbound: Envelope, stream_output: bool = False)
await self._hook_runtime.notify_error(stage="turn", error=exc, message=inbound)
raise

async def resolve_session(self, message: Envelope) -> str:
"""Resolve the canonical session id for a message."""

runtime_session_id = field_of(message, "_runtime_session_id")
if runtime_session_id is not None:
return str(runtime_session_id)
resolved = await self._hook_runtime.call_first("resolve_session", message=message)
return str(resolved or self._default_session_id(message))

async def _run_model(
self,
inbound: Envelope,
Expand Down Expand Up @@ -207,6 +216,27 @@ async def quit_via_router(self, session_id: str) -> None:
if self._outbound_router is not None:
await self._outbound_router.quit(session_id)

async def admit_message(self, *, session_id: str, message: Envelope, turn: TurnSnapshot) -> AdmitDecision | None:
return cast(
"AdmitDecision | None",
await self._hook_runtime.call_first(
"admit_message",
session_id=session_id,
message=message,
turn=turn,
),
)

def steering(self, session_id: str) -> SteeringHandle:
handle = self._steering_handles.get(session_id)
if handle is None:
handle = SteeringHandle(session_id=session_id, buffer=SteeringBuffer())
self._steering_handles[session_id] = handle
return handle

def clear_steering(self, session_id: str) -> None:
self._steering_handles.pop(session_id, None)

@staticmethod
def _default_session_id(message: Envelope) -> str:
session_id = field_of(message, "session_id")
Expand Down
14 changes: 14 additions & 0 deletions src/bub/hookspecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from republic import AsyncStreamEvents, AsyncTapeStore, TapeContext
from republic.tape import TapeStore

from bub.turn_admission import AdmitDecision, TurnSnapshot
from bub.types import Envelope, MessageHandler, State

if TYPE_CHECKING:
Expand Down Expand Up @@ -107,3 +108,16 @@ def provide_channels(self, message_handler: MessageHandler) -> list[Channel]:
def build_tape_context(self) -> TapeContext:
"""Build a tape context for the current session, to be used to build context messages."""
raise NotImplementedError

@hookspec(firstresult=True)
def admit_message(
self,
session_id: str,
message: Envelope,
turn: TurnSnapshot,
) -> AdmitDecision | None:
"""Decide how to handle an inbound channel message for a session.
Return ``None`` to keep Bub's default concurrent scheduling behavior.
"""
raise NotImplementedError
Comment on lines +112 to +123
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it better to implement this as a method of Channel?

Because steering is highly related to channels, each channel should be able to define its own steering logic. The current approach only takes the first implemented hook.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see admit_message as more than channel-side input handling. Some decisions, especially STEER, only become meaningful when they are paired with a run_model implementation that knows how to consume steering input. If the active model hook never drains the steering queue, a channel-level STEER decision cannot actually steer the running turn.

Loading
Loading