Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions src/google/adk/agents/live_request_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ class LiveRequest(BaseModel):
close: bool = False
"""If set, close the queue. queue.shutdown() is only supported in Python 3.13+."""

turn_complete: bool = True
"""If set, content messages complete the current model turn."""


class LiveRequestQueue:
"""Queue used to send LiveRequest in a live(bidirectional streaming) way."""
Expand All @@ -66,8 +69,10 @@ def __init__(self):
def close(self):
  """Enqueues a LiveRequest whose close flag is set."""
  close_request = LiveRequest(close=True)
  self._queue.put_nowait(close_request)

def send_content(self, content: types.Content):
self._queue.put_nowait(LiveRequest(content=content))
def send_content(self, content: types.Content, turn_complete: bool = True):
  """Enqueues a LiveRequest carrying content and its turn_complete flag.

  Args:
    content: The content to place on the request.
    turn_complete: Value stored on the request's turn_complete field.
  """
  request = LiveRequest(content=content, turn_complete=turn_complete)
  self._queue.put_nowait(request)

def send_realtime(self, blob: types.Blob):
self._queue.put_nowait(LiveRequest(blob=blob))
Expand Down
10 changes: 6 additions & 4 deletions src/google/adk/flows/llm_flows/base_llm_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -771,9 +771,9 @@ async def _send_to_model(
is_function_response = content.parts and any(
part.function_response for part in content.parts
)
if not is_function_response:
if not content.role:
content.role = 'user'
if not is_function_response and not content.role:
content.role = 'user'
if not is_function_response and live_request.turn_complete:
user_content_event = Event(
id=Event.new_id(),
invocation_id=invocation_context.invocation_id,
Expand All @@ -784,7 +784,9 @@ async def _send_to_model(
session=invocation_context.session,
event=user_content_event,
)
await llm_connection.send_content(live_request.content)
await llm_connection.send_content(
live_request.content, turn_complete=live_request.turn_complete
)

async def _receive_from_model(
self,
Expand Down
4 changes: 3 additions & 1 deletion src/google/adk/flows/llm_flows/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -986,7 +986,9 @@ async def run_tool_and_update_queue(tool, function_args, tool_context):
)
],
)
invocation_context.live_request_queue.send_content(updated_content)
invocation_context.live_request_queue.send_content(
updated_content, turn_complete=False
)
except asyncio.CancelledError:
raise # Re-raise to properly propagate the cancellation

Expand Down
7 changes: 5 additions & 2 deletions src/google/adk/models/base_llm_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,18 @@ async def send_history(self, history: list[types.Content]):
pass

@abstractmethod
async def send_content(self, content: types.Content):
async def send_content(
    self, content: types.Content, turn_complete: bool = True
):
  """Sends a piece of user content to the model.

  Unless told otherwise, the model responds upon receiving the content.
  When function responses are sent, every part of the content should be a
  function response.

  Args:
    content: The content to send to the model.
    turn_complete: Whether this content completes the model turn.
  """
  pass

Expand Down
16 changes: 12 additions & 4 deletions src/google/adk/models/gemini_llm_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,15 +106,18 @@ async def send_history(self, history: list[types.Content]):
else:
logger.info('no content is sent')

async def send_content(self, content: types.Content):
async def send_content(
self, content: types.Content, turn_complete: bool = True
):
"""Sends a user content to the gemini model.

The model will respond immediately upon receiving the content.
By default, the model will respond upon receiving the content.
If you send function responses, all parts in the content should be function
responses.

Args:
content: The content to send to the model.
turn_complete: Whether this content completes the model turn.
"""
assert content.parts
if content.parts[0].function_response:
Expand All @@ -129,7 +132,12 @@ async def send_content(self, content: types.Content):
is_gemini_31 = model_name_utils.is_gemini_3_1_flash_live(
self._model_version
)
if is_gemini_31 and len(content.parts) == 1 and content.parts[0].text:
if (
turn_complete
and is_gemini_31
and len(content.parts) == 1
and content.parts[0].text
):
logger.debug('Using send_realtime_input for Gemini 3.1 text input')
await self._gemini_session.send_realtime_input(
text=content.parts[0].text
Expand All @@ -138,7 +146,7 @@ async def send_content(self, content: types.Content):
await self._gemini_session.send(
input=types.LiveClientContent(
turns=[content],
turn_complete=True,
turn_complete=turn_complete,
)
)

Expand Down
11 changes: 11 additions & 0 deletions tests/unittests/agents/test_live_request_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,17 @@ def test_send_content():
mock_put_nowait.assert_called_once_with(LiveRequest(content=content))


def test_send_content_sets_turn_complete():
  """send_content(turn_complete=False) enqueues a request with that flag."""
  live_queue = LiveRequestQueue()
  fake_content = MagicMock(spec=types.Content)

  with patch.object(live_queue._queue, "put_nowait") as put_nowait_spy:
    live_queue.send_content(fake_content, turn_complete=False)
    expected_request = LiveRequest(content=fake_content, turn_complete=False)
    put_nowait_spy.assert_called_once_with(expected_request)


def test_send_realtime():
queue = LiveRequestQueue()
blob = MagicMock(spec=types.Blob)
Expand Down
33 changes: 32 additions & 1 deletion tests/unittests/flows/llm_flows/test_base_llm_flow_realtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,5 +197,36 @@ async def test_send_to_model_with_text_content(mock_llm_connection):
await flow._send_to_model(mock_llm_connection, invocation_context)

# Verify send_content was called instead of send_realtime
mock_llm_connection.send_content.assert_called_once_with(content)
mock_llm_connection.send_content.assert_called_once_with(
content, turn_complete=True
)
mock_llm_connection.send_realtime.assert_not_called()


@pytest.mark.asyncio
async def test_send_to_model_with_intermediate_text_content(
    mock_llm_connection,
):
  """Content queued with turn_complete=False is sent without appending an event."""
  ctx = await testing_utils.create_invocation_context(
      agent=Agent(name='test_agent', model='mock'), user_content=''
  )
  ctx.live_request_queue = LiveRequestQueue()
  ctx.session_service.append_event = mock.AsyncMock()

  progress_content = types.Content(
      role='user', parts=[types.Part.from_text(text='progress')]
  )
  open_turn_request = LiveRequest(
      content=progress_content, turn_complete=False
  )
  ctx.live_request_queue.send(open_turn_request)
  ctx.live_request_queue.close()

  flow_under_test = TestBaseLlmFlow()
  await flow_under_test._send_to_model(mock_llm_connection, ctx)

  # The flag must reach the connection unchanged.
  mock_llm_connection.send_content.assert_called_once_with(
      progress_content, turn_complete=False
  )
  # No event is expected to be appended for this request.
  ctx.session_service.append_event.assert_not_called()
16 changes: 16 additions & 0 deletions tests/unittests/models/test_gemini_llm_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,22 @@ async def test_send_content_text(gemini_connection, mock_gemini_session):
assert call_args['input'].turn_complete is True


@pytest.mark.asyncio
async def test_send_content_text_can_keep_turn_open(
    gemini_connection, mock_gemini_session
):
  """turn_complete=False is forwarded on the input given to session.send."""
  progress = types.Content(
      role='user', parts=[types.Part.from_text(text='progress')]
  )

  await gemini_connection.send_content(progress, turn_complete=False)

  send_mock = mock_gemini_session.send
  send_mock.assert_called_once()
  sent_input = send_mock.call_args.kwargs['input']
  assert sent_input.turns == [progress]
  assert sent_input.turn_complete is False


@pytest.mark.asyncio
async def test_send_content_function_response(
gemini_connection, mock_gemini_session
Expand Down
4 changes: 3 additions & 1 deletion tests/unittests/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,9 @@ def __init__(self, llm_responses: list[LlmResponse]):
async def send_history(self, history: list[types.Content]):
  """No-op: the history argument is ignored."""
  return None

async def send_content(self, content: types.Content):
async def send_content(
    self, content: types.Content, turn_complete: bool = True
):
  """No-op: both content and turn_complete are accepted and ignored."""
  return None

async def send(self, data):
Expand Down
4 changes: 3 additions & 1 deletion tests/unittests/workflow/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,9 @@ def __init__(
async def send_history(self, history: list[types.Content]):
  """No-op: the history argument is ignored."""
  return None

async def send_content(self, content: types.Content):
async def send_content(
    self, content: types.Content, turn_complete: bool = True
):
  """Records content on the mock model and sets the input event.

  turn_complete is accepted for signature compatibility and is not used here.
  """
  recorded_contents = self.mock_model.live_contents
  recorded_contents.append(content)
  self._input_event.set()

Expand Down