Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion packages/bub-schedule/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ dependencies = [
]

[project.entry-points.bub]
schedule = "bub_schedule.plugin:main"
schedule = "bub_schedule.plugin:ScheduleImpl"

[build-system]
requires = ["uv_build>=0.9.7,<0.10.0"]
Expand All @@ -21,4 +21,8 @@ build-backend = "uv_build"
[dependency-groups]
dev = [
"pytest>=9.0.3",
"pytest-asyncio>=1.3.0",
]

[tool.pytest.ini_options]
asyncio_mode = "auto"
22 changes: 21 additions & 1 deletion packages/bub-schedule/src/bub_schedule/channel.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,43 @@
from __future__ import annotations

import asyncio
from asyncio import Event

from apscheduler.schedulers.base import BaseScheduler
from bub.channels import Channel
from bub.framework import BubFramework
from loguru import logger


class ScheduleChannel(Channel):
name = "schedule"

def __init__(self, scheduler: BaseScheduler) -> None:
# Class-level runtime state (singleton per process)
_framework: BubFramework | None = None

def __init__(self, scheduler: BaseScheduler, *, framework: BubFramework) -> None:
self.scheduler = scheduler
self._instance_framework = framework

@classmethod
def current_framework(cls) -> BubFramework:
"""Return the live framework bound to the current gateway process."""
if cls._framework is None:
raise RuntimeError(
"no live schedule framework available, cannot deliver scheduled message"
)
return cls._framework

async def start(self, stop_event: Event) -> None:
ScheduleChannel._framework = self._instance_framework

loop = asyncio.get_running_loop()
loop.call_soon_threadsafe(self.scheduler.start)
logger.info("schedule.start complete")

async def stop(self) -> None:
loop = asyncio.get_running_loop()
loop.call_soon_threadsafe(self.scheduler.shutdown)

ScheduleChannel._framework = None
logger.info("schedule.stop complete")
11 changes: 4 additions & 7 deletions packages/bub-schedule/src/bub_schedule/jobs.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
from __future__ import annotations

from pathlib import Path

from bub import BubFramework
from bub.channels.message import ChannelMessage

SCHEDULE_SUBPROCESS_TIMEOUT_SECONDS = 300
Expand All @@ -11,19 +8,19 @@
async def run_scheduled_reminder(
message: str, session_id: str, workspace: str | None = None
) -> None:
framework = BubFramework()
framework.load_hooks()
if workspace:
framework.workspace = Path(workspace).resolve()
from bub_schedule.channel import ScheduleChannel

if ":" in session_id:
channel, chat_id = session_id.split(":", 1)
else:
channel = "schedule"
chat_id = "default"

payload = ChannelMessage(
content=message,
session_id=session_id,
channel=channel,
chat_id=chat_id,
)
framework = ScheduleChannel.current_framework()
await framework.process_inbound(payload)
9 changes: 4 additions & 5 deletions packages/bub-schedule/src/bub_schedule/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from apscheduler.schedulers.base import BaseScheduler
from bub import hookimpl
from bub.channels import Channel
from bub.framework import BubFramework
from bub.types import Envelope, MessageHandler, State

from bub_schedule.jobstore import JSONJobStore
Expand All @@ -15,9 +16,10 @@ def default_scheduler() -> BaseScheduler:


class ScheduleImpl:
def __init__(self) -> None:
def __init__(self, framework: BubFramework | None = None) -> None:
from bub_schedule import tools # noqa: F401

self.framework = framework
self.scheduler = default_scheduler()

@hookimpl
Expand All @@ -28,7 +30,4 @@ def load_state(self, message: Envelope, session_id: str) -> State:
def provide_channels(self, message_handler: MessageHandler) -> list[Channel]:
from bub_schedule.channel import ScheduleChannel

return [ScheduleChannel(self.scheduler)]


main = ScheduleImpl()
return [ScheduleChannel(self.scheduler, framework=self.framework)]
121 changes: 121 additions & 0 deletions packages/bub-schedule/tests/test_scheduled_delivery.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
"""Tests for scheduled task delivery through the live framework."""
from __future__ import annotations

import asyncio
from unittest.mock import AsyncMock, MagicMock

import pytest

from bub.channels.message import ChannelMessage
from bub_schedule.channel import ScheduleChannel
from bub_schedule.jobs import run_scheduled_reminder


@pytest.fixture
def mock_framework():
"""Create a mock BubFramework."""
framework = MagicMock()
framework.process_inbound = AsyncMock()
return framework


@pytest.fixture
def scheduler():
"""Create a mock scheduler; job execution itself is not under test here."""
return MagicMock()


@pytest.fixture
def channel(scheduler, mock_framework):
"""Create a ScheduleChannel with mock framework."""
return ScheduleChannel(scheduler, framework=mock_framework)


class TestCurrentFramework:
"""Test class-level live framework registration lifecycle."""

async def test_raises_when_no_channel_started(self):
"""current_framework should raise RuntimeError before start()."""
ScheduleChannel._framework = None
with pytest.raises(RuntimeError, match="no live schedule framework"):
ScheduleChannel.current_framework()

async def test_returns_framework_after_start(self, channel, mock_framework):
"""start() should register the live framework."""
stop_event = asyncio.Event()
await channel.start(stop_event)
try:
assert ScheduleChannel.current_framework() is mock_framework
finally:
await channel.stop()

async def test_cleared_after_stop(self, channel):
"""Class-level state should be None after stop()."""
stop_event = asyncio.Event()
await channel.start(stop_event)
await channel.stop()
assert ScheduleChannel._framework is None


class TestRunScheduledReminder:
"""Test that run_scheduled_reminder directly uses the live framework."""

async def test_processes_payload_via_live_framework(self, channel, mock_framework):
"""run_scheduled_reminder should call framework.process_inbound directly."""
stop_event = asyncio.Event()
await channel.start(stop_event)
try:
await run_scheduled_reminder(
message="test message",
session_id="feishu:oc_123",
)

mock_framework.process_inbound.assert_called_once()
payload = mock_framework.process_inbound.call_args[0][0]
assert isinstance(payload, ChannelMessage)
assert payload.content == "test message"
assert payload.session_id == "feishu:oc_123"
assert payload.channel == "feishu"
assert payload.chat_id == "oc_123"
finally:
await channel.stop()

async def test_fallback_session_channel(self, channel, mock_framework):
"""Session ID without ':' should default to schedule:default."""
stop_event = asyncio.Event()
await channel.start(stop_event)
try:
await run_scheduled_reminder(
message="reminder",
session_id="simple_session",
)

mock_framework.process_inbound.assert_called_once()
payload = mock_framework.process_inbound.call_args[0][0]
assert payload.channel == "schedule"
assert payload.chat_id == "default"
finally:
await channel.stop()

async def test_no_framework_raises_error(self):
"""When no live framework is registered, should raise RuntimeError."""
ScheduleChannel._framework = None
with pytest.raises(RuntimeError, match="no live schedule framework"):
await run_scheduled_reminder(
message="should fail",
session_id="feishu:oc_456",
)

async def test_propagates_process_inbound_error(self, channel, mock_framework):
"""Delivery should surface framework.process_inbound failures directly."""
mock_framework.process_inbound.side_effect = RuntimeError("boom")
stop_event = asyncio.Event()
await channel.start(stop_event)
try:
with pytest.raises(RuntimeError, match="boom"):
await run_scheduled_reminder(
message="fail",
session_id="feishu:oc_999",
)
finally:
await channel.stop()
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,4 @@ test = [

[tool.pytest.ini_options]
addopts = "-ra --import-mode=importlib"
asyncio_mode = "auto"
Loading