diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index 0c58174897..cf43fa2f38 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -377,11 +377,10 @@ def append_or_extend( append_or_extend(gemini_contents, parts, types.UserContent) elif role == "assistant": - if isinstance(content, str): - parts = [types.Part.from_text(text=content)] - append_or_extend(gemini_contents, parts, types.ModelContent) + parts = [] + if isinstance(content, str) and content: + parts.append(types.Part.from_text(text=content)) elif isinstance(content, list): - parts = [] thinking_signature = None text = "" for part in content: @@ -406,11 +405,10 @@ def append_or_extend( thought_signature=thinking_signature, ) ) - append_or_extend(gemini_contents, parts, types.ModelContent) - elif not native_tool_enabled and "tool_calls" in message: - parts = [] - for tool in message["tool_calls"]: + tool_calls = message.get("tool_calls") or [] + if not native_tool_enabled and tool_calls: + for tool in tool_calls: part = types.Part.from_function_call( name=tool["function"]["name"], args=json.loads(tool["function"]["arguments"]), @@ -427,15 +425,16 @@ def append_or_extend( if ts_bs64: part.thought_signature = base64.b64decode(ts_bs64) parts.append(part) - append_or_extend(gemini_contents, parts, types.ModelContent) - else: + + if not parts: logger.warning("assistant 角色的消息内容为空,已添加空格占位") if native_tool_enabled and "tool_calls" in message: logger.warning( "检测到启用Gemini原生工具,且上下文中存在函数调用,建议使用 /reset 重置上下文", ) parts = [types.Part.from_text(text=" ")] - append_or_extend(gemini_contents, parts, types.ModelContent) + + append_or_extend(gemini_contents, parts, types.ModelContent) elif role == "tool" and not native_tool_enabled: func_name = message.get("name", message["tool_call_id"]) diff --git a/tests/test_gemini_source.py b/tests/test_gemini_source.py index 4db8e92bfe..a9a1ca9c15 100644 --- a/tests/test_gemini_source.py +++ b/tests/test_gemini_source.py @@ -1,3 +1,5 @@ +import base64 + import pytest from astrbot.core.exceptions import EmptyModelOutputError @@ -27,3 +29,103 @@ def test_gemini_reasoning_only_output_is_allowed(): response_id="resp_reasoning", finish_reason="STOP", ) + + +def _make_gemini_provider_for_conversation(): + provider = object.__new__(ProviderGoogleGenAI) + provider.provider_config = { + "gm_native_coderunner": False, + "gm_native_search": False, + } + return provider + + +def _assistant_tool_call_message(content): + return { + "role": "assistant", + "content": content, + "tool_calls": [ + { + "id": "call_1", + "type": "function", + "function": { + "name": "get_pull_request_files", + "arguments": '{"owner":"AstrBotDevs","repo":"AstrBot","pull_number":8742}', + }, + }, + ], + } + + +def _first_model_parts(gemini_contents): + model_content = next( + content + for content in gemini_contents + if content.__class__.__name__ == "ModelContent" + ) + return model_content.parts or [] + + +def test_prepare_conversation_keeps_assistant_text_and_tool_calls(): + provider = _make_gemini_provider_for_conversation() + payloads = { + "messages": [ + {"role": "user", "content": "summarize this PR"}, + _assistant_tool_call_message("I will inspect the changed files first."), + ] + } + + parts = _first_model_parts(provider._prepare_conversation(payloads)) + + assert any(part.text == "I will inspect the changed files first." for part in parts) + assert [ + part.function_call.name + for part in parts + if getattr(part, "function_call", None) + ] == ["get_pull_request_files"] + + +def test_prepare_conversation_keeps_assistant_list_content_and_tool_calls(): + provider = _make_gemini_provider_for_conversation() + payloads = { + "messages": [ + {"role": "user", "content": "summarize this PR"}, + _assistant_tool_call_message( + [ + { + "type": "think", + "encrypted": base64.b64encode(b"signature").decode("utf-8"), + }, + {"type": "text", "text": "I will inspect the changed files first."}, + ] + ), + ] + } + + parts = _first_model_parts(provider._prepare_conversation(payloads)) + + assert any(part.text == "I will inspect the changed files first." for part in parts) + assert [ + part.function_call.name + for part in parts + if getattr(part, "function_call", None) + ] == ["get_pull_request_files"] + + +def test_prepare_conversation_ignores_null_tool_calls(): + provider = _make_gemini_provider_for_conversation() + payloads = { + "messages": [ + {"role": "user", "content": "hello"}, + { + "role": "assistant", + "content": "hello back", + "tool_calls": None, + }, + ] + } + + parts = _first_model_parts(provider._prepare_conversation(payloads)) + + assert [part.text for part in parts] == ["hello back"] + assert not any(getattr(part, "function_call", None) for part in parts)