From a54623ded54ca61160ceb62d3221e35fec40a692 Mon Sep 17 00:00:00 2001 From: zzz27578 <2950506809@qq.com> Date: Sat, 20 Jun 2026 01:59:57 +0800 Subject: [PATCH] fix: apply fallback chat models to proactive wakeups --- astrbot/core/astr_agent_tool_exec.py | 7 ++- astrbot/core/cron/manager.py | 8 +-- tests/unit/test_astr_agent_tool_exec.py | 76 +++++++++++++++++++++++++ tests/unit/test_cron_manager.py | 66 +++++++++++++++++++++ 4 files changed, 150 insertions(+), 7 deletions(-) diff --git a/astrbot/core/astr_agent_tool_exec.py b/astrbot/core/astr_agent_tool_exec.py index 8c3ed661f9..23c174f1fd 100644 --- a/astrbot/core/astr_agent_tool_exec.py +++ b/astrbot/core/astr_agent_tool_exec.py @@ -543,11 +543,12 @@ async def _wake_main_agent_for_background_result( message_type=session.message_type, ) cron_event.role = event.role + cfg = ctx.get_config(umo=event.unified_msg_origin) or {} + provider_settings = cfg.get("provider_settings", {}) config = MainAgentBuildConfig( tool_call_timeout=run_context.tool_call_timeout, - streaming_response=ctx.get_config() - .get("provider_settings", {}) - .get("stream", False), + streaming_response=provider_settings.get("stream", False), + provider_settings=provider_settings, ) req = ProviderRequest() diff --git a/astrbot/core/cron/manager.py b/astrbot/core/cron/manager.py index fde2ad5cd8..a4bd1e1b71 100644 --- a/astrbot/core/cron/manager.py +++ b/astrbot/core/cron/manager.py @@ -328,7 +328,7 @@ async def _woke_main_agent( # judge user's role umo = cron_event.unified_msg_origin - cfg = self.ctx.get_config(umo=umo) + cfg = self.ctx.get_config(umo=umo) or {} cron_payload = extras.get("cron_payload", {}) if extras else {} sender_id = cron_payload.get("sender_id") admin_ids = cfg.get("admins_id", []) @@ -337,13 +337,13 @@ async def _woke_main_agent( if cron_payload.get("origin", "tool") == "api": cron_event.role = "admin" - tool_call_timeout = cfg.get("provider_settings", {}).get( - "tool_call_timeout", 120 - ) + provider_settings = cfg.get("provider_settings", {}) + tool_call_timeout = provider_settings.get("tool_call_timeout", 120) config = MainAgentBuildConfig( tool_call_timeout=tool_call_timeout, llm_safety_mode=False, streaming_response=False, + provider_settings=provider_settings, ) req = ProviderRequest() conv = await _get_session_conv(event=cron_event, plugin_context=self.ctx) diff --git a/tests/unit/test_astr_agent_tool_exec.py b/tests/unit/test_astr_agent_tool_exec.py index 61fb4048c8..c66be69dae 100644 --- a/tests/unit/test_astr_agent_tool_exec.py +++ b/tests/unit/test_astr_agent_tool_exec.py @@ -1,4 +1,5 @@ from types import SimpleNamespace +from unittest.mock import AsyncMock import mcp import pytest @@ -19,6 +20,7 @@ class _DummyEvent: def __init__(self, message_components: list[object] | None = None) -> None: self.unified_msg_origin = "webchat:FriendMessage:webchat!user!session" self.message_obj = SimpleNamespace(message=message_components or []) + self.role = "member" def get_extra(self, _key: str): return None @@ -36,6 +38,15 @@ def _build_run_context(message_components: list[object] | None = None): return ContextWrapper(context=ctx) +class _DoneRunner: + async def step_until_done(self, _max_step): + for item in (): + yield item + + def get_final_llm_resp(self): + return SimpleNamespace(role="assistant", completion_text="done") + + def test_build_handoff_toolset_keeps_permission_guards_for_default_tools(): mgr = FunctionToolManager() plugin_tool = FunctionTool( @@ -354,6 +365,71 @@ async def _fake_tool_loop_agent(**kwargs): assert captured["tool_call_timeout"] == 120 +@pytest.mark.asyncio +async def test_background_wakeup_passes_provider_settings_to_main_agent( + monkeypatch: pytest.MonkeyPatch, +): + provider_settings = { + "fallback_chat_models": ["fallback-provider"], + "request_max_retries": 3, + "stream": True, + } + captured: dict = {} + + async def _fake_get_session_conv(**_kwargs): + return SimpleNamespace(history="[]") + + async def _fake_build_main_agent(**kwargs): + captured.update(kwargs) + return SimpleNamespace(agent_runner=_DoneRunner()) + + monkeypatch.setattr( + "astrbot.core.astr_main_agent._get_session_conv", + _fake_get_session_conv, + ) + monkeypatch.setattr( + "astrbot.core.astr_main_agent.build_main_agent", + _fake_build_main_agent, + ) + monkeypatch.setattr( + "astrbot.core.astr_agent_tool_exec.persist_agent_history", + AsyncMock(), + ) + + send_tool = FunctionTool( + name="send_message_to_user", + description="send", + parameters={"type": "object", "properties": {}}, + ) + context = SimpleNamespace( + get_config=lambda **_kwargs: {"provider_settings": provider_settings}, + get_llm_tool_manager=lambda: SimpleNamespace( + get_builtin_tool=lambda _tool_cls: send_tool + ), + conversation_manager=SimpleNamespace(), + ) + run_context = ContextWrapper( + context=SimpleNamespace(event=_DummyEvent([]), context=context), + tool_call_timeout=456, + ) + + await FunctionToolExecutor._wake_main_agent_for_background_result( + run_context, + task_id="task-id", + tool_name="long_tool", + result_text="ok", + tool_args={}, + note="task finished", + summary_name="BackgroundTask", + ) + + config = captured["config"] + assert config.tool_call_timeout == 456 + assert config.streaming_response == provider_settings["stream"] + assert config.provider_settings is provider_settings + assert config.provider_settings["fallback_chat_models"] == ["fallback-provider"] + + @pytest.mark.asyncio async def test_collect_handoff_image_urls_filters_extensionless_file_outside_temp_root( monkeypatch: pytest.MonkeyPatch, diff --git a/tests/unit/test_cron_manager.py b/tests/unit/test_cron_manager.py index 5596973133..0e25c862a1 100644 --- a/tests/unit/test_cron_manager.py +++ b/tests/unit/test_cron_manager.py @@ -1,6 +1,7 @@ """Tests for CronJobManager.""" from datetime import datetime, timedelta, timezone +from types import SimpleNamespace from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -462,6 +463,71 @@ async def test_run_job_not_found(self, cron_manager, mock_db): mock_db.update_cron_job.assert_not_called() +class _DoneRunner: + async def step_until_done(self, _max_step): + for item in (): + yield item + + def get_final_llm_resp(self): + return SimpleNamespace(role="assistant", completion_text="done") + + +class TestWokeMainAgent: + """Tests for active-agent wakeup configuration.""" + + @pytest.mark.asyncio + async def test_woke_main_agent_passes_provider_settings_to_main_agent( + self, cron_manager, mock_context, monkeypatch + ): + """Future tasks should use configured fallback chat models.""" + provider_settings = { + "fallback_chat_models": ["fallback-provider"], + "request_max_retries": 2, + "tool_call_timeout": 321, + } + mock_context.get_config.return_value = { + "admins_id": [], + "provider_settings": provider_settings, + } + cron_manager.ctx = mock_context + captured: dict = {} + + async def fake_get_session_conv(**_kwargs): + return SimpleNamespace(history="[]") + + async def fake_build_main_agent(**kwargs): + captured.update(kwargs) + return SimpleNamespace(agent_runner=_DoneRunner()) + + monkeypatch.setattr( + "astrbot.core.astr_main_agent._get_session_conv", + fake_get_session_conv, + ) + monkeypatch.setattr( + "astrbot.core.astr_main_agent.build_main_agent", + fake_build_main_agent, + ) + monkeypatch.setattr( + "astrbot.core.cron.manager.persist_agent_history", + AsyncMock(), + ) + + await cron_manager._woke_main_agent( + message="run scheduled task", + session_str="cron:OtherMessage:test-job-id", + extras={ + "cron_job": {"id": "test-job-id"}, + "cron_payload": {"origin": "tool"}, + }, + ) + + config = captured["config"] + assert config.tool_call_timeout == 321 + assert config.streaming_response is False + assert config.provider_settings is provider_settings + assert config.provider_settings["fallback_chat_models"] == ["fallback-provider"] + + class TestRunBasicJob: """Tests for _run_basic_job method."""