diff --git a/portkey_ai/integrations/adk.py b/portkey_ai/integrations/adk.py index fdf9fa22..bdc08ad7 100644 --- a/portkey_ai/integrations/adk.py +++ b/portkey_ai/integrations/adk.py @@ -76,8 +76,15 @@ def __init__( class _TextChunk: - def __init__(self, text: str) -> None: + def __init__(self, text: str, thought_signature: Optional[str] = None) -> None: self.text = text + self.thought_signature = thought_signature + + +class _ThoughtChunk: + def __init__(self, text: str, thought_signature: Optional[str] = None) -> None: + self.text = text + self.thought_signature = thought_signature class _UsageMetadataChunk: @@ -89,6 +96,92 @@ def __init__( self.total_tokens = total_tokens +def _get_anthropic_content_blocks(message: Any) -> Optional[list[dict[str, Any]]]: + """Extract content_blocks from a message, falling back to list-typed content. + + Portkey returns thinking/reasoning via content_blocks when + strict_open_ai_compliance=False. Some providers put them directly in + content as a list instead. + """ + content_blocks = getattr(message, "content_blocks", None) + if content_blocks is None: + content = getattr(message, "content", None) + if isinstance(content, list): + content_blocks = content + return content_blocks + + +def _iter_anthropic_content_blocks( + content_blocks: Optional[list[dict[str, Any]]], +) -> Iterable[tuple[str, str, Optional[str]]]: + """Yields (block_type, text, thought_signature) from content blocks. + + Handles both non-streaming format (type/thinking/text keys) and + streaming delta format (delta dict with thinking/text keys). + """ + if not content_blocks: + return [] + items: list[tuple[str, str, Optional[str]]] = [] + for block in content_blocks: + block_type = block.get("type") + thought_signature = _get_gemini_thought_signature( + block.get("thought_signature") + ) + if block_type == "thinking": + text = block.get("thinking") + if text: + items.append(("thinking", text, thought_signature)) + elif block_type == "text": + text = block.get("text") + if text: + items.append(("text", text, thought_signature)) + elif "delta" in block: + delta = block.get("delta", {}) + if delta.get("thinking"): + items.append(("thinking", delta["thinking"], thought_signature)) + elif delta.get("text"): + items.append(("text", delta["text"], thought_signature)) + return items + + +def _get_gemini_thought_signature(value: Any) -> Optional[str]: + """Normalize thought_signature to str. 
Gemini returns bytes, Portkey returns str.""" + if value is None: + return None + if isinstance(value, bytes): + return base64.b64encode(value).decode("utf-8") + if isinstance(value, str): + return value + return None + + +def _build_content_blocks( + thought_text: str, + thought_signature: Optional[str], + text: str, + text_signature: Optional[str], +) -> Optional[list[dict[str, Any]]]: + """Build content_blocks for aggregated streaming responses, omitting empty blocks.""" + blocks: list[dict[str, Any]] = [] + if thought_text: + blocks.append( + { + "type": "thinking", + "thinking": thought_text, + "thought_signature": thought_signature, + } + ) + if text: + blocks.append( + { + "type": "text", + "text": text, + "thought_signature": text_signature, + } + ) + return blocks or None + + def _safe_json_serialize(obj: Any) -> str: try: return json.dumps(obj, ensure_ascii=False) @@ -112,11 +205,17 @@ def _get_content(parts: Iterable[Any]) -> Union[list[dict], str]: for part in parts: text = getattr(part, "text", None) inline_data = getattr(part, "inline_data", None) + thought_signature = _get_gemini_thought_signature( + getattr(part, "thought_signature", None) + ) if text: # Return simple string when it's a single text part if isinstance(parts, list) and len(parts) == 1: return text - content_objects.append({"type": "text", "text": text}) + content_object: dict[str, Any] = {"type": "text", "text": text} + if thought_signature: + content_object["thought_signature"] = thought_signature + content_objects.append(content_object) elif ( inline_data and getattr(inline_data, "data", None) @@ -126,15 +225,24 @@ def _get_content(parts: Iterable[Any]) -> Union[list[dict], str]: data_uri = f"data:{inline_data.mime_type};base64,{b64}" if inline_data.mime_type.startswith("image"): content_objects.append( - {"type": "image_url", "image_url": {"url": data_uri}} + { + "type": "image_url", + "image_url": {"url": data_uri}, + } ) elif inline_data.mime_type.startswith("video"): content_objects.append( - {"type": "video_url", "video_url": {"url": data_uri}} + { + "type": "video_url", + "video_url": {"url": data_uri}, + } ) elif inline_data.mime_type.startswith("audio"): content_objects.append( - {"type": "audio_url", "audio_url": {"url": data_uri}} + { + "type": "audio_url", + "audio_url": {"url": data_uri}, + } ) elif inline_data.mime_type == "application/pdf": content_objects.append( @@ -181,18 +289,22 @@ def _content_to_message_param(content: Any) -> Union[dict, list[dict]]: for part in getattr(content, "parts", []) or []: function_call = getattr(part, "function_call", None) if function_call: - tool_calls.append( - { - "type": "function", - "id": getattr(function_call, "id", None), - "function": { - "name": getattr(function_call, "name", None), - "arguments": _safe_json_serialize( - getattr(function_call, "args", None) - ), - }, - } + tool_call: dict[str, Any] = { + "type": "function", + "id": getattr(function_call, "id", None), + "function": { + "name": getattr(function_call, "name", None), + "arguments": _safe_json_serialize( + getattr(function_call, "args", None) + ), + }, + } + thought_signature = _get_gemini_thought_signature( + getattr(part, "thought_signature", None) ) + if thought_signature: + tool_call["thought_signature"] = thought_signature + tool_calls.append(tool_call) elif getattr(part, "text", None) or getattr(part, "inline_data", None): content_present = True @@ -281,7 +393,8 @@ def _model_response_to_chunk( response: Any, ) -> Generator[ Tuple[ - Optional[Union[_TextChunk, _FunctionChunk, 
_UsageMetadataChunk]], Optional[str] + Optional[Union[_TextChunk, _FunctionChunk, _UsageMetadataChunk, _ThoughtChunk]], + Optional[str], ], None, None, @@ -298,27 +411,54 @@ def _model_response_to_chunk( message = choice0.message if message: - if getattr(message, "content", None): + reasoning_content = getattr(message, "reasoning_content", None) + if reasoning_content: + yield _ThoughtChunk(text=reasoning_content), finish_reason + + content_blocks = _get_anthropic_content_blocks(message) + has_content_blocks = bool(content_blocks) + for block_type, text, thought_signature in _iter_anthropic_content_blocks( + content_blocks + ): + if block_type == "thinking": + yield ( + _ThoughtChunk(text=text, thought_signature=thought_signature), + finish_reason, + ) + elif block_type == "text": + yield ( + _TextChunk(text=text, thought_signature=thought_signature), + finish_reason, + ) + + if not has_content_blocks and getattr(message, "content", None): yield _TextChunk(text=message.content), finish_reason tool_calls = getattr(message, "tool_calls", None) if tool_calls: for tool_call in tool_calls: if getattr(tool_call, "type", None) == "function": - yield _FunctionChunk( - id=getattr(tool_call, "id", None), - name=getattr( - getattr(tool_call, "function", None), "name", None - ), - args=getattr( - getattr(tool_call, "function", None), "arguments", None + yield ( + _FunctionChunk( + id=getattr(tool_call, "id", None), + name=getattr( + getattr(tool_call, "function", None), "name", None + ), + args=getattr( + getattr(tool_call, "function", None), + "arguments", + None, + ), + index=getattr(tool_call, "index", 0), ), - index=getattr(tool_call, "index", 0), - ), finish_reason + finish_reason, + ) if finish_reason and not ( (getattr(message, "content", None)) or (getattr(message, "tool_calls", None)) + or reasoning_content + or has_content_blocks ): yield None, finish_reason @@ -327,21 +467,54 @@ def _model_response_to_chunk( usage = getattr(response, "usage", None) if usage: - yield _UsageMetadataChunk( - prompt_tokens=getattr(usage, "prompt_tokens", 0), - completion_tokens=getattr(usage, "completion_tokens", 0), - total_tokens=getattr(usage, "total_tokens", 0), - ), None + yield ( + _UsageMetadataChunk( + prompt_tokens=getattr(usage, "prompt_tokens", 0), + completion_tokens=getattr(usage, "completion_tokens", 0), + total_tokens=getattr(usage, "total_tokens", 0), + ), + None, + ) -def _message_to_generate_content_response(message: Any, is_partial: bool = False) -> "LlmResponse": # type: ignore[name-defined] +def _message_to_generate_content_response( + message: Any, is_partial: bool = False +) -> "LlmResponse": # type: ignore[name-defined] """Convert a Portkey-style message object to ADK LlmResponse.""" from google.genai import types as genai_types # type: ignore from google.adk.models.llm_response import LlmResponse # type: ignore parts: list[Any] = [] - if getattr(message, "content", None): - parts.append(genai_types.Part.from_text(text=message.content)) + + content_blocks = _get_anthropic_content_blocks(message) + if content_blocks: + # content_blocks take priority; they carry both thinking and text with signatures + for block_type, text, thought_signature in _iter_anthropic_content_blocks( + content_blocks + ): + if block_type == "thinking": + thought_part = genai_types.Part.from_text(text=text) + thought_part.thought = True + if thought_signature: + # stubs say bytes | None but _get_gemini_thought_signature + # and _iter_anthropic_content_blocks return str; works at runtime + 
thought_part.thought_signature = thought_signature # type: ignore[assignment] + parts.append(thought_part) + elif block_type == "text": + text_part = genai_types.Part.from_text(text=text) + if thought_signature: + # stubs say bytes | None but _get_gemini_thought_signature + # and _iter_anthropic_content_blocks return str; works at runtime + text_part.thought_signature = thought_signature # type: ignore[assignment] + parts.append(text_part) + else: + reasoning_content = getattr(message, "reasoning_content", None) + if reasoning_content: + thought_part = genai_types.Part.from_text(text=reasoning_content) + thought_part.thought = True + parts.append(thought_part) + if getattr(message, "content", None): + parts.append(genai_types.Part.from_text(text=message.content)) if getattr(message, "tool_calls", None): for tool_call in message.tool_calls: @@ -443,13 +616,30 @@ def _get_completion_inputs( return messages, tools, response_format +def _get_thinking_config(llm_request: "LlmRequest") -> Optional[dict[str, Any]]: # type: ignore[name-defined] + config = getattr(llm_request, "config", None) + thinking_config = getattr(config, "thinking_config", None) if config else None + if not thinking_config: + return None + include_thoughts = getattr(thinking_config, "include_thoughts", None) + thinking_budget = getattr(thinking_config, "thinking_budget", None) + if not include_thoughts: + return None + result: dict[str, Any] = {"type": "enabled"} + if thinking_budget: + result["budget_tokens"] = thinking_budget + return result + + # ----------------------------- main adapter --------------------------------- class PortkeyAdk(_AdkBaseLlm): # type: ignore[misc] """ADK `BaseLlm` adapter backed by Portkey Async client.""" - def __init__(self, model: str, api_key: Optional[str] = None, **kwargs: Any) -> None: # type: ignore[override] + def __init__( + self, model: str, api_key: Optional[str] = None, **kwargs: Any + ) -> None: # type: ignore[override] if not _HAS_ADK: raise ImportError( "google-adk is not installed. 
Install with: pip install 'portkey-ai[adk]'" @@ -461,7 +651,9 @@ def __init__(self, model: str, api_key: Optional[str] = None, **kwargs: Any) -> sys_role if sys_role in ("developer", "system") else "developer" ) - super().__init__(model=model, **{k: v for k, v in kwargs.items() if k != "model"}) # type: ignore[misc] + super().__init__( + model=model, **{k: v for k, v in kwargs.items() if k != "model"} + ) # type: ignore[misc] # Set up Portkey client client_args: dict[str, Any] = {} @@ -478,6 +670,9 @@ def __init__(self, model: str, api_key: Optional[str] = None, **kwargs: Any) -> client_args["provider"] = kwargs.pop("provider") if "Authorization" in kwargs: client_args["Authorization"] = kwargs.pop("Authorization") + client_args["strict_open_ai_compliance"] = kwargs.pop( + "strict_open_ai_compliance", False + ) self._client = AsyncPortkey(**client_args) # type: ignore[arg-type] @@ -488,7 +683,9 @@ def __init__(self, model: str, api_key: Optional[str] = None, **kwargs: Any) -> self._additional_args.pop("tools", None) self._additional_args.pop("stream", None) - async def generate_content_async(self, llm_request: "LlmRequest", stream: bool = False) -> AsyncGenerator["LlmResponse", None]: # type: ignore[override,name-defined] + async def generate_content_async( + self, llm_request: "LlmRequest", stream: bool = False + ) -> AsyncGenerator["LlmResponse", None]: # type: ignore[override,name-defined] """Generate ADK LlmResponse objects using Portkey Chat Completions.""" # Use ADK BaseLlm helper to ensure a user message exists so model can respond self._maybe_append_user_content(llm_request) @@ -496,6 +693,7 @@ async def generate_content_async(self, llm_request: "LlmRequest", stream: bool = messages, tools, response_format = _get_completion_inputs( llm_request, getattr(self, "_system_role", "developer") ) + thinking_config = _get_thinking_config(llm_request) completion_args: dict[str, Any] = { "model": getattr(self, "model", None), @@ -504,6 +702,8 @@ async def generate_content_async(self, llm_request: "LlmRequest", stream: bool = # Only include response_format if we successfully converted it **({"response_format": response_format} if response_format else {}), } + if thinking_config: + completion_args["thinking"] = thinking_config completion_args.update(self._additional_args) if tools and "tool_choice" not in completion_args: # Encourage tool use when functions are provided, mirroring Strands behavior @@ -512,6 +712,9 @@ async def generate_content_async(self, llm_request: "LlmRequest", stream: bool = if stream: # Aggregate streaming text and tool calls to yield ADK LlmResponse objects text_accum = "" + thought_accum = "" + text_signature = None + thought_signature = None function_calls: dict[int, dict[str, Any]] = {} fallback_index = 0 usage_metadata = None @@ -519,7 +722,9 @@ async def generate_content_async(self, llm_request: "LlmRequest", stream: bool = aggregated_llm_response_with_tool_call = None # Await the creation to obtain an async iterator for streaming - stream_obj = await self._client.chat.completions.create(stream=True, **completion_args) # type: ignore[arg-type] + stream_obj = await self._client.chat.completions.create( + stream=True, **completion_args + ) # type: ignore[arg-type] stream_iter = cast(AsyncIterator[Any], stream_obj) async for part in stream_iter: for chunk, finish_reason in _model_response_to_chunk(part): @@ -540,12 +745,50 @@ async def generate_content_async(self, llm_request: "LlmRequest", stream: bool = function_calls[idx]["id"] = ( chunk.id or 
function_calls[idx]["id"] or str(idx) ) + elif isinstance(chunk, _ThoughtChunk): + thought_accum += chunk.text + if chunk.thought_signature: + thought_signature = chunk.thought_signature + yield _message_to_generate_content_response( + type( + "Msg", + (), + { + "content": None, + "reasoning_content": chunk.text, + "content_blocks": [ + { + "type": "thinking", + "thinking": chunk.text, + "thought_signature": chunk.thought_signature, + } + ], + "tool_calls": None, + }, + )(), + is_partial=True, + ) elif isinstance(chunk, _TextChunk): text_accum += chunk.text + if chunk.thought_signature: + text_signature = chunk.thought_signature # Yield partials for better interactivity yield _message_to_generate_content_response( type( - "Msg", (), {"content": chunk.text, "tool_calls": None} + "Msg", + (), + { + "content": chunk.text, + "reasoning_content": None, + "content_blocks": [ + { + "type": "text", + "text": chunk.text, + "thought_signature": chunk.thought_signature, + } + ], + "tool_calls": None, + }, )(), is_partial=True, ) @@ -587,18 +830,45 @@ async def generate_content_async(self, llm_request: "LlmRequest", stream: bool = aggregated_llm_response_with_tool_call = ( _message_to_generate_content_response( type( - "Msg", (), {"content": "", "tool_calls": tool_calls} + "Msg", + (), + { + "content": "", + "reasoning_content": thought_accum or None, + "content_blocks": _build_content_blocks( + thought_accum, + thought_signature, + text_accum, + text_signature, + ), + "tool_calls": tool_calls, + }, )() ) ) function_calls.clear() - elif finish_reason == "stop" and text_accum: + elif finish_reason == "stop" and (text_accum or thought_accum): aggregated_llm_response = _message_to_generate_content_response( type( - "Msg", (), {"content": text_accum, "tool_calls": None} + "Msg", + (), + { + "content": text_accum or None, + "reasoning_content": thought_accum or None, + "content_blocks": _build_content_blocks( + thought_accum, + thought_signature, + text_accum, + text_signature, + ), + "tool_calls": None, + }, )() ) text_accum = "" + thought_accum = "" + text_signature = None + thought_signature = None # End of stream: yield aggregated responses (attach usage if available) if aggregated_llm_response: diff --git a/tests/integrations/test_adk_adapter.py b/tests/integrations/test_adk_adapter.py index f244392b..6c1b427a 100644 --- a/tests/integrations/test_adk_adapter.py +++ b/tests/integrations/test_adk_adapter.py @@ -1,4 +1,3 @@ -import asyncio from typing import Any, AsyncIterator, Optional import pytest @@ -9,17 +8,39 @@ from google.adk.models.llm_request import LlmRequest # type: ignore from google.genai import types as genai_types # type: ignore -from portkey_ai.integrations.adk import PortkeyAdk +from portkey_ai.integrations.adk import ( + PortkeyAdk, + _get_anthropic_content_blocks, + _iter_anthropic_content_blocks, + _get_gemini_thought_signature, + _get_thinking_config, +) class _FakeDelta: - def __init__(self, content: Optional[str] = None): + def __init__( + self, + content: Optional[str] = None, + reasoning_content: Optional[str] = None, + content_blocks: Optional[list[dict[str, Any]]] = None, + ): self.content = content + self.reasoning_content = reasoning_content + if content_blocks is not None: + self.content_blocks = content_blocks class _FakeMessage: - def __init__(self, content: Optional[str] = None): + def __init__( + self, + content: Optional[str] = None, + reasoning_content: Optional[str] = None, + content_blocks: Optional[list[dict[str, Any]]] = None, + ): self.content = content + 
self.reasoning_content = reasoning_content + if content_blocks is not None: + self.content_blocks = content_blocks self.tool_calls = None @@ -36,9 +57,21 @@ def __init__( class _FakeResponse: - def __init__(self, message_text: str): + def __init__( + self, + message_text: str, + reasoning_content: Optional[str] = None, + content_blocks: Optional[list[dict[str, Any]]] = None, + ): self.choices = [ - _FakeChoice(message=_FakeMessage(message_text), finish_reason="stop") + _FakeChoice( + message=_FakeMessage( + message_text, + reasoning_content=reasoning_content, + content_blocks=content_blocks, + ), + finish_reason="stop", + ) ] self.usage = type( "Usage", (), {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3} @@ -74,8 +107,9 @@ async def fake_create(**kwargs: Any) -> _FakeResponse: assert not getattr(resp, "partial", False) assert resp.content and resp.content.parts for p in resp.content.parts: - if getattr(p, "text", None): - outputs.append(p.text) + text = getattr(p, "text", None) + if text: + outputs.append(text) assert "".join(outputs).strip() == "Hello world!" @@ -109,13 +143,464 @@ async def fake_create(**kwargs: Any) -> AsyncIterator[Any] | _FakeResponse: async for resp in llm.generate_content_async(req, stream=True): assert resp.content and resp.content.parts - text_parts = [p.text for p in resp.content.parts if getattr(p, "text", None)] + text_parts = [ + p.text for p in resp.content.parts if getattr(p, "text", None) is not None + ] if getattr(resp, "partial", False): - partial_text.extend(text_parts) + partial_text.extend([t for t in text_parts if t]) else: - final_text.extend(text_parts) + final_text.extend([t for t in text_parts if t]) - # Partial updates should reflect the stream pieces assert "".join(partial_text) == "Hello world!" - # Final message should be aggregated once assert "".join(final_text) == "Hello world!" + + +@pytest.mark.asyncio +async def test_non_streaming_with_reasoning_content( + monkeypatch: pytest.MonkeyPatch, +) -> None: + llm = PortkeyAdk(model="@openai/gpt-4o-mini", api_key="test") + + async def fake_create(**kwargs: Any) -> _FakeResponse: + return _FakeResponse( + message_text="The answer is 42.", + reasoning_content="Let me think about this step by step...", + ) + + monkeypatch.setattr(llm._client.chat.completions, "create", fake_create) # type: ignore[attr-defined] + + req = _build_request(model="@openai/gpt-4o-mini", text="test") + thought_parts: list[str] = [] + text_parts: list[str] = [] + + async for resp in llm.generate_content_async(req, stream=False): + assert resp.content and resp.content.parts + for p in resp.content.parts: + text = getattr(p, "text", None) + if getattr(p, "thought", False) and text: + thought_parts.append(text) + elif text: + text_parts.append(text) + + assert "".join(thought_parts) == "Let me think about this step by step..." + assert "".join(text_parts) == "The answer is 42." 
+ + +@pytest.mark.asyncio +async def test_streaming_with_reasoning_content( + monkeypatch: pytest.MonkeyPatch, +) -> None: + llm = PortkeyAdk(model="@openai/gpt-4o-mini", api_key="test") + + part1 = type( + "Chunk", + (), + {"choices": [_FakeChoice(delta=_FakeDelta(reasoning_content="Thinking..."))]}, + )() + part2 = type( + "Chunk", (), {"choices": [_FakeChoice(delta=_FakeDelta(content="Answer: 42"))]} + )() + part3 = type( + "Chunk", + (), + {"choices": [_FakeChoice(message=_FakeMessage(None), finish_reason="stop")]}, + )() + + async def fake_stream_gen() -> AsyncIterator[Any]: + yield part1 + yield part2 + yield part3 + + async def fake_create(**kwargs: Any) -> AsyncIterator[Any] | _FakeResponse: + if kwargs.get("stream"): + return fake_stream_gen() + return _FakeResponse(message_text="unused") + + monkeypatch.setattr(llm._client.chat.completions, "create", fake_create) # type: ignore[attr-defined] + + req = _build_request(model="@openai/gpt-4o-mini", text="test") + partial_thoughts: list[str] = [] + partial_text: list[str] = [] + final_thoughts: list[str] = [] + final_text: list[str] = [] + + async for resp in llm.generate_content_async(req, stream=True): + assert resp.content and resp.content.parts + for p in resp.content.parts: + text = getattr(p, "text", None) + if getattr(p, "thought", False) and text: + if getattr(resp, "partial", False): + partial_thoughts.append(text) + else: + final_thoughts.append(text) + elif text: + if getattr(resp, "partial", False): + partial_text.append(text) + else: + final_text.append(text) + + assert "".join(partial_thoughts) == "Thinking..." + assert "".join(partial_text) == "Answer: 42" + assert "".join(final_thoughts) == "Thinking..." + assert "".join(final_text) == "Answer: 42" + + +@pytest.mark.asyncio +async def test_non_streaming_without_reasoning_content( + monkeypatch: pytest.MonkeyPatch, +) -> None: + llm = PortkeyAdk(model="@openai/gpt-4o-mini", api_key="test") + + async def fake_create(**kwargs: Any) -> _FakeResponse: + return _FakeResponse(message_text="Simple response") + + monkeypatch.setattr(llm._client.chat.completions, "create", fake_create) # type: ignore[attr-defined] + + req = _build_request(model="@openai/gpt-4o-mini", text="test") + thought_parts: list[str] = [] + text_parts: list[str] = [] + + async for resp in llm.generate_content_async(req, stream=False): + assert resp.content and resp.content.parts + for p in resp.content.parts: + text = getattr(p, "text", None) + if getattr(p, "thought", False) and text: + thought_parts.append(text) + elif text: + text_parts.append(text) + + assert thought_parts == [] + assert "".join(text_parts) == "Simple response" + + +@pytest.mark.asyncio +async def test_non_streaming_with_content_blocks( + monkeypatch: pytest.MonkeyPatch, +) -> None: + llm = PortkeyAdk(model="gemini-2.5-pro", api_key="test") + + content_blocks = [ + {"type": "thinking", "thinking": "Let me analyze this problem..."}, + {"type": "text", "text": "The answer is 4."}, + ] + + async def fake_create(**kwargs: Any) -> _FakeResponse: + return _FakeResponse( + message_text="The answer is 4.", + content_blocks=content_blocks, + ) + + monkeypatch.setattr(llm._client.chat.completions, "create", fake_create) # type: ignore[attr-defined] + + req = _build_request(model="gemini-2.5-pro", text="What is 2+2?") + thought_parts: list[str] = [] + text_parts: list[str] = [] + + async for resp in llm.generate_content_async(req, stream=False): + assert resp.content and resp.content.parts + for p in resp.content.parts: + text = getattr(p, 
"text", None) + if getattr(p, "thought", False) and text: + thought_parts.append(text) + elif text: + text_parts.append(text) + + assert "".join(thought_parts) == "Let me analyze this problem..." + assert "".join(text_parts) == "The answer is 4." + + +@pytest.mark.asyncio +async def test_streaming_with_content_blocks_delta_format( + monkeypatch: pytest.MonkeyPatch, +) -> None: + llm = PortkeyAdk(model="gemini-2.5-pro", api_key="test") + + part1 = type( + "Chunk", + (), + { + "choices": [ + _FakeChoice( + delta=_FakeDelta( + content_blocks=[ + {"index": 0, "delta": {"thinking": "Thinking part 1..."}} + ] + ) + ) + ] + }, + )() + part2 = type( + "Chunk", + (), + { + "choices": [ + _FakeChoice( + delta=_FakeDelta( + content_blocks=[ + {"index": 0, "delta": {"thinking": "Thinking part 2..."}} + ] + ) + ) + ] + }, + )() + part3 = type( + "Chunk", + (), + { + "choices": [ + _FakeChoice( + delta=_FakeDelta( + content_blocks=[{"index": 1, "delta": {"text": "Answer: 42"}}] + ) + ) + ] + }, + )() + part4 = type( + "Chunk", + (), + {"choices": [_FakeChoice(message=_FakeMessage(None), finish_reason="stop")]}, + )() + + async def fake_stream_gen() -> AsyncIterator[Any]: + yield part1 + yield part2 + yield part3 + yield part4 + + async def fake_create(**kwargs: Any) -> AsyncIterator[Any] | _FakeResponse: + if kwargs.get("stream"): + return fake_stream_gen() + return _FakeResponse(message_text="unused") + + monkeypatch.setattr(llm._client.chat.completions, "create", fake_create) # type: ignore[attr-defined] + + req = _build_request(model="gemini-2.5-pro", text="test") + partial_thoughts: list[str] = [] + partial_text: list[str] = [] + final_thoughts: list[str] = [] + final_text: list[str] = [] + + async for resp in llm.generate_content_async(req, stream=True): + assert resp.content and resp.content.parts + for p in resp.content.parts: + text = getattr(p, "text", None) + if getattr(p, "thought", False) and text: + if getattr(resp, "partial", False): + partial_thoughts.append(text) + else: + final_thoughts.append(text) + elif text: + if getattr(resp, "partial", False): + partial_text.append(text) + else: + final_text.append(text) + + assert "".join(partial_thoughts) == "Thinking part 1...Thinking part 2..." 
+ assert "".join(partial_text) == "Answer: 42" + + +@pytest.mark.asyncio +async def test_non_streaming_with_thought_signature( + monkeypatch: pytest.MonkeyPatch, +) -> None: + llm = PortkeyAdk(model="gemini-2.5-pro", api_key="test") + + content_blocks = [ + { + "type": "thinking", + "thinking": "Deep thinking...", + "thought_signature": "sig123abc", + }, + {"type": "text", "text": "Final answer."}, + ] + + async def fake_create(**kwargs: Any) -> _FakeResponse: + return _FakeResponse( + message_text="Final answer.", + content_blocks=content_blocks, + ) + + monkeypatch.setattr(llm._client.chat.completions, "create", fake_create) # type: ignore[attr-defined] + + req = _build_request(model="gemini-2.5-pro", text="test") + thought_signatures: list[str] = [] + + async for resp in llm.generate_content_async(req, stream=False): + assert resp.content and resp.content.parts + for p in resp.content.parts: + sig = getattr(p, "thought_signature", None) + if sig: + thought_signatures.append(sig) + + assert "sig123abc" in thought_signatures + + +@pytest.mark.asyncio +async def test_thinking_config_passed_to_request( + monkeypatch: pytest.MonkeyPatch, +) -> None: + llm = PortkeyAdk(model="gemini-2.5-pro", api_key="test") + captured_kwargs: dict[str, Any] = {} + + async def fake_create(**kwargs: Any) -> _FakeResponse: + captured_kwargs.update(kwargs) + return _FakeResponse(message_text="response") + + monkeypatch.setattr(llm._client.chat.completions, "create", fake_create) # type: ignore[attr-defined] + + req = LlmRequest( + model="gemini-2.5-pro", + contents=[ + genai_types.Content( + role="user", + parts=[genai_types.Part.from_text(text="test")], + ) + ], + config=genai_types.GenerateContentConfig( + thinking_config=genai_types.ThinkingConfig( + include_thoughts=True, + thinking_budget=1024, + ), + ), + ) + + async for _ in llm.generate_content_async(req, stream=False): + pass + + assert "thinking" in captured_kwargs + assert captured_kwargs["thinking"]["type"] == "enabled" + assert captured_kwargs["thinking"]["budget_tokens"] == 1024 + + +def test_get_anthropic_content_blocks_from_attribute() -> None: + msg = _FakeMessage( + content="text", + content_blocks=[{"type": "text", "text": "hello"}], + ) + blocks = _get_anthropic_content_blocks(msg) + assert blocks is not None + assert len(blocks) == 1 + assert blocks[0]["type"] == "text" + + +def test_get_anthropic_content_blocks_from_list_content() -> None: + msg = type( + "Msg", + (), + { + "content": [{"type": "thinking", "thinking": "thought"}], + "content_blocks": None, + }, + )() + blocks = _get_anthropic_content_blocks(msg) + assert blocks is not None + assert len(blocks) == 1 + assert blocks[0]["type"] == "thinking" + + +def test_get_anthropic_content_blocks_returns_none_for_string() -> None: + msg = type("Msg", (), {"content": "plain string", "content_blocks": None})() + blocks = _get_anthropic_content_blocks(msg) + assert blocks is None + + +def test_iter_anthropic_content_blocks_non_streaming_format() -> None: + blocks = [ + {"type": "thinking", "thinking": "thought text", "thought_signature": "sig1"}, + {"type": "text", "text": "response text"}, + ] + result = list(_iter_anthropic_content_blocks(blocks)) + assert len(result) == 2 + assert result[0] == ("thinking", "thought text", "sig1") + assert result[1] == ("text", "response text", None) + + +def test_iter_anthropic_content_blocks_streaming_delta_format() -> None: + blocks = [ + {"index": 0, "delta": {"thinking": "streaming thought"}}, + {"index": 1, "delta": {"text": "streaming text"}}, + ] + 
result = list(_iter_anthropic_content_blocks(blocks)) + assert len(result) == 2 + assert result[0] == ("thinking", "streaming thought", None) + assert result[1] == ("text", "streaming text", None) + + +def test_iter_anthropic_content_blocks_empty_delta() -> None: + blocks = [{"index": 0, "delta": {}}] + result = list(_iter_anthropic_content_blocks(blocks)) + assert len(result) == 0 + + +def test_iter_anthropic_content_blocks_none_input() -> None: + result = list(_iter_anthropic_content_blocks(None)) + assert result == [] + + +def test_get_gemini_thought_signature_string() -> None: + result = _get_gemini_thought_signature("signature_string") + assert result == "signature_string" + + +def test_get_gemini_thought_signature_bytes() -> None: + result = _get_gemini_thought_signature(b"binary_sig") + assert result == "YmluYXJ5X3NpZw==" + + +def test_get_gemini_thought_signature_none() -> None: + result = _get_gemini_thought_signature(None) + assert result is None + + +def test_get_thinking_config_with_budget() -> None: + req = LlmRequest( + model="test", + contents=[], + config=genai_types.GenerateContentConfig( + thinking_config=genai_types.ThinkingConfig( + include_thoughts=True, + thinking_budget=2048, + ), + ), + ) + result = _get_thinking_config(req) + assert result is not None + assert result["type"] == "enabled" + assert result["budget_tokens"] == 2048 + + +def test_get_thinking_config_without_budget() -> None: + req = LlmRequest( + model="test", + contents=[], + config=genai_types.GenerateContentConfig( + thinking_config=genai_types.ThinkingConfig( + include_thoughts=True, + ), + ), + ) + result = _get_thinking_config(req) + assert result is not None + assert result["type"] == "enabled" + assert "budget_tokens" not in result + + +def test_get_thinking_config_none_when_not_configured() -> None: + req = LlmRequest(model="test", contents=[]) + result = _get_thinking_config(req) + assert result is None + + +def test_get_thinking_config_none_when_empty() -> None: + req = LlmRequest( + model="test", + contents=[], + config=genai_types.GenerateContentConfig( + thinking_config=genai_types.ThinkingConfig(), + ), + ) + result = _get_thinking_config(req) + assert result is None
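
For reference, a minimal end-to-end sketch of how the thinking support added in this diff would be exercised from application code. This is not part of the diff: the model slug and API key below are placeholders, and the request shape simply mirrors what the new tests (test_thinking_config_passed_to_request and test_streaming_with_content_blocks_delta_format) construct.

import asyncio

from google.adk.models.llm_request import LlmRequest
from google.genai import types as genai_types

from portkey_ai.integrations.adk import PortkeyAdk


async def main() -> None:
    # strict_open_ai_compliance defaults to False in the adapter, so Portkey
    # may return thinking/text content_blocks that become thought parts.
    llm = PortkeyAdk(model="gemini-2.5-pro", api_key="YOUR_PORTKEY_API_KEY")

    req = LlmRequest(
        model="gemini-2.5-pro",
        contents=[
            genai_types.Content(
                role="user",
                parts=[genai_types.Part.from_text(text="What is 2+2?")],
            )
        ],
        # thinking_config is translated by _get_thinking_config into
        # {"type": "enabled", "budget_tokens": 1024} on the completion call.
        config=genai_types.GenerateContentConfig(
            thinking_config=genai_types.ThinkingConfig(
                include_thoughts=True,
                thinking_budget=1024,
            ),
        ),
    )

    async for resp in llm.generate_content_async(req, stream=True):
        if not (resp.content and resp.content.parts):
            continue
        for part in resp.content.parts:
            text = getattr(part, "text", None)
            if not text:
                continue
            # Thought parts are marked with part.thought = True and may carry
            # a thought_signature that the adapter normalizes to str.
            label = "[thought]" if getattr(part, "thought", False) else "[text]"
            print(label, text)


if __name__ == "__main__":
    asyncio.run(main())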