Skip to content

Commit 202acdc

Browse files
committed
fix: load artifacts from workflow text responses
1 parent 9670ce2 commit 202acdc

2 files changed

Lines changed: 146 additions & 39 deletions

File tree

src/google/adk/tools/load_artifacts_tool.py

Lines changed: 77 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from __future__ import annotations
1616

17+
import ast
1718
import base64
1819
import binascii
1920
import json
@@ -46,6 +47,7 @@
4647
from .tool_context import ToolContext
4748

4849
logger = logging.getLogger('google_adk.' + __name__)
50+
_LOAD_ARTIFACTS_TEXT_MARKER = '`load_artifacts` tool returned result:'
4951

5052

5153
def _normalize_mime_type(mime_type: str | None) -> str | None:
@@ -121,6 +123,47 @@ def _as_safe_part_for_llm(
121123
)
122124

123125

126+
def _artifact_names_from_response(response: Any) -> list[str]:
127+
if not isinstance(response, dict):
128+
return []
129+
130+
artifact_names = response.get('artifact_names', [])
131+
if isinstance(artifact_names, str):
132+
return [artifact_names]
133+
if not isinstance(artifact_names, list):
134+
return []
135+
return [name for name in artifact_names if isinstance(name, str)]
136+
137+
138+
def _artifact_names_from_text_response(text: str | None) -> list[str]:
139+
if not text or _LOAD_ARTIFACTS_TEXT_MARKER not in text:
140+
return []
141+
142+
payload = text.split(_LOAD_ARTIFACTS_TEXT_MARKER, 1)[1].strip()
143+
try:
144+
response = ast.literal_eval(payload)
145+
except (SyntaxError, ValueError) as exc:
146+
logger.debug('Could not parse load_artifacts text response: %s', exc)
147+
return []
148+
149+
return _artifact_names_from_response(response)
150+
151+
152+
def _requested_artifact_names(content: types.Content) -> list[str]:
153+
artifact_names: list[str] = []
154+
for part in content.parts or []:
155+
function_response = part.function_response
156+
if function_response and function_response.name == 'load_artifacts':
157+
artifact_names.extend(
158+
_artifact_names_from_response(function_response.response or {})
159+
)
160+
continue
161+
162+
artifact_names.extend(_artifact_names_from_text_response(part.text))
163+
164+
return artifact_names
165+
166+
124167
class LoadArtifactsTool(BaseTool):
125168
"""A tool that loads the artifacts and adds them to the session."""
126169

@@ -210,46 +253,41 @@ async def _append_artifacts_to_llm_request(
210253
# Attach the content of the artifacts if the model requests them.
211254
# This only adds the content to the model request, instead of the session.
212255
if llm_request.contents and llm_request.contents[-1].parts:
213-
function_response = llm_request.contents[-1].parts[0].function_response
214-
if function_response and function_response.name == 'load_artifacts':
215-
response = function_response.response or {}
216-
artifact_names = response.get('artifact_names', [])
217-
for artifact_name in artifact_names:
218-
# Try session-scoped first (default behavior)
219-
artifact = await tool_context.load_artifact(artifact_name)
220-
221-
# If not found and name doesn't already have user: prefix,
222-
# try cross-session artifacts with user: prefix
223-
if artifact is None and not artifact_name.startswith('user:'):
224-
prefixed_name = f'user:{artifact_name}'
225-
artifact = await tool_context.load_artifact(prefixed_name)
226-
227-
if artifact is None:
228-
logger.warning('Artifact "%s" not found, skipping', artifact_name)
229-
continue
230-
231-
artifact_part = _as_safe_part_for_llm(artifact, artifact_name)
232-
if artifact_part is not artifact:
233-
mime_type = (
234-
artifact.inline_data.mime_type if artifact.inline_data else None
235-
)
236-
logger.debug(
237-
'Converted artifact "%s" (mime_type=%s) to text Part',
238-
artifact_name,
239-
mime_type,
240-
)
241-
242-
llm_request.contents.append(
243-
types.Content(
244-
role='user',
245-
parts=[
246-
types.Part.from_text(
247-
text=f'Artifact {artifact_name} is:'
248-
),
249-
artifact_part,
250-
],
251-
)
256+
artifact_names = _requested_artifact_names(llm_request.contents[-1])
257+
for artifact_name in artifact_names:
258+
# Try session-scoped first (default behavior)
259+
artifact = await tool_context.load_artifact(artifact_name)
260+
261+
# If not found and name doesn't already have user: prefix,
262+
# try cross-session artifacts with user: prefix
263+
if artifact is None and not artifact_name.startswith('user:'):
264+
prefixed_name = f'user:{artifact_name}'
265+
artifact = await tool_context.load_artifact(prefixed_name)
266+
267+
if artifact is None:
268+
logger.warning('Artifact "%s" not found, skipping', artifact_name)
269+
continue
270+
271+
artifact_part = _as_safe_part_for_llm(artifact, artifact_name)
272+
if artifact_part is not artifact:
273+
mime_type = (
274+
artifact.inline_data.mime_type if artifact.inline_data else None
275+
)
276+
logger.debug(
277+
'Converted artifact "%s" (mime_type=%s) to text Part',
278+
artifact_name,
279+
mime_type,
252280
)
253281

282+
llm_request.contents.append(
283+
types.Content(
284+
role='user',
285+
parts=[
286+
types.Part.from_text(text=f'Artifact {artifact_name} is:'),
287+
artifact_part,
288+
],
289+
)
290+
)
291+
254292

255293
load_artifacts_tool = LoadArtifactsTool()

tests/unittests/tools/test_load_artifacts_tool.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,75 @@ async def test_load_artifacts_keeps_supported_mime_types():
144144
assert artifact_part.inline_data.mime_type == 'application/pdf'
145145

146146

147+
@mark.asyncio
148+
async def test_load_artifacts_reads_workflow_text_response():
149+
"""Workflow context can stringify tool responses from other nodes."""
150+
artifact_name = 'invoice.txt'
151+
artifact = types.Part.from_text(text='invoice total: 42')
152+
153+
tool_context = _StubToolContext({artifact_name: artifact})
154+
llm_request = LlmRequest(
155+
contents=[
156+
types.Content(
157+
role='user',
158+
parts=[
159+
types.Part.from_text(text='For context:'),
160+
types.Part.from_text(
161+
text=(
162+
'[workflow_node] `load_artifacts` tool returned'
163+
" result: {'artifact_names': ['invoice.txt'],"
164+
" 'status': 'ok'}"
165+
)
166+
),
167+
],
168+
)
169+
]
170+
)
171+
172+
await load_artifacts_tool.process_llm_request(
173+
tool_context=tool_context, llm_request=llm_request
174+
)
175+
176+
assert llm_request.contents[-1].parts[0].text == (
177+
f'Artifact {artifact_name} is:'
178+
)
179+
assert llm_request.contents[-1].parts[1].text == 'invoice total: 42'
180+
181+
182+
@mark.asyncio
183+
async def test_load_artifacts_checks_all_function_response_parts():
184+
"""The load_artifacts response may not be the first part in a turn."""
185+
artifact_name = 'notes.txt'
186+
artifact = types.Part.from_text(text='important notes')
187+
188+
tool_context = _StubToolContext({artifact_name: artifact})
189+
llm_request = LlmRequest(
190+
contents=[
191+
types.Content(
192+
role='user',
193+
parts=[
194+
types.Part.from_text(text='Done.'),
195+
types.Part(
196+
function_response=types.FunctionResponse(
197+
name='load_artifacts',
198+
response={'artifact_names': [artifact_name]},
199+
)
200+
),
201+
],
202+
)
203+
]
204+
)
205+
206+
await load_artifacts_tool.process_llm_request(
207+
tool_context=tool_context, llm_request=llm_request
208+
)
209+
210+
assert llm_request.contents[-1].parts[0].text == (
211+
f'Artifact {artifact_name} is:'
212+
)
213+
assert llm_request.contents[-1].parts[1].text == 'important notes'
214+
215+
147216
def test_maybe_base64_to_bytes_decodes_standard_base64():
148217
"""Standard base64 encoded strings are decoded correctly."""
149218
original = b'hello world'

0 commit comments

Comments
 (0)