|
14 | 14 |
|
15 | 15 | from __future__ import annotations |
16 | 16 |
|
| 17 | +import ast |
17 | 18 | import base64 |
18 | 19 | import binascii |
19 | 20 | import json |
|
46 | 47 | from .tool_context import ToolContext |
47 | 48 |
|
48 | 49 | logger = logging.getLogger('google_adk.' + __name__) |
| 50 | +_LOAD_ARTIFACTS_TEXT_MARKER = '`load_artifacts` tool returned result:' |
49 | 51 |
|
50 | 52 |
|
51 | 53 | def _normalize_mime_type(mime_type: str | None) -> str | None: |
@@ -121,6 +123,47 @@ def _as_safe_part_for_llm( |
121 | 123 | ) |
122 | 124 |
|
123 | 125 |
|
| 126 | +def _artifact_names_from_response(response: Any) -> list[str]: |
| 127 | + if not isinstance(response, dict): |
| 128 | + return [] |
| 129 | + |
| 130 | + artifact_names = response.get('artifact_names', []) |
| 131 | + if isinstance(artifact_names, str): |
| 132 | + return [artifact_names] |
| 133 | + if not isinstance(artifact_names, list): |
| 134 | + return [] |
| 135 | + return [name for name in artifact_names if isinstance(name, str)] |
| 136 | + |
| 137 | + |
| 138 | +def _artifact_names_from_text_response(text: str | None) -> list[str]: |
| 139 | + if not text or _LOAD_ARTIFACTS_TEXT_MARKER not in text: |
| 140 | + return [] |
| 141 | + |
| 142 | + payload = text.split(_LOAD_ARTIFACTS_TEXT_MARKER, 1)[1].strip() |
| 143 | + try: |
| 144 | + response = ast.literal_eval(payload) |
| 145 | + except (SyntaxError, ValueError) as exc: |
| 146 | + logger.debug('Could not parse load_artifacts text response: %s', exc) |
| 147 | + return [] |
| 148 | + |
| 149 | + return _artifact_names_from_response(response) |
| 150 | + |
| 151 | + |
| 152 | +def _requested_artifact_names(content: types.Content) -> list[str]: |
| 153 | + artifact_names: list[str] = [] |
| 154 | + for part in content.parts or []: |
| 155 | + function_response = part.function_response |
| 156 | + if function_response and function_response.name == 'load_artifacts': |
| 157 | + artifact_names.extend( |
| 158 | + _artifact_names_from_response(function_response.response or {}) |
| 159 | + ) |
| 160 | + continue |
| 161 | + |
| 162 | + artifact_names.extend(_artifact_names_from_text_response(part.text)) |
| 163 | + |
| 164 | + return artifact_names |
| 165 | + |
| 166 | + |
124 | 167 | class LoadArtifactsTool(BaseTool): |
125 | 168 | """A tool that loads the artifacts and adds them to the session.""" |
126 | 169 |
|
@@ -210,46 +253,41 @@ async def _append_artifacts_to_llm_request( |
210 | 253 | # Attach the content of the artifacts if the model requests them. |
211 | 254 | # This only adds the content to the model request, instead of the session. |
212 | 255 | if llm_request.contents and llm_request.contents[-1].parts: |
213 | | - function_response = llm_request.contents[-1].parts[0].function_response |
214 | | - if function_response and function_response.name == 'load_artifacts': |
215 | | - response = function_response.response or {} |
216 | | - artifact_names = response.get('artifact_names', []) |
217 | | - for artifact_name in artifact_names: |
218 | | - # Try session-scoped first (default behavior) |
219 | | - artifact = await tool_context.load_artifact(artifact_name) |
220 | | - |
221 | | - # If not found and name doesn't already have user: prefix, |
222 | | - # try cross-session artifacts with user: prefix |
223 | | - if artifact is None and not artifact_name.startswith('user:'): |
224 | | - prefixed_name = f'user:{artifact_name}' |
225 | | - artifact = await tool_context.load_artifact(prefixed_name) |
226 | | - |
227 | | - if artifact is None: |
228 | | - logger.warning('Artifact "%s" not found, skipping', artifact_name) |
229 | | - continue |
230 | | - |
231 | | - artifact_part = _as_safe_part_for_llm(artifact, artifact_name) |
232 | | - if artifact_part is not artifact: |
233 | | - mime_type = ( |
234 | | - artifact.inline_data.mime_type if artifact.inline_data else None |
235 | | - ) |
236 | | - logger.debug( |
237 | | - 'Converted artifact "%s" (mime_type=%s) to text Part', |
238 | | - artifact_name, |
239 | | - mime_type, |
240 | | - ) |
241 | | - |
242 | | - llm_request.contents.append( |
243 | | - types.Content( |
244 | | - role='user', |
245 | | - parts=[ |
246 | | - types.Part.from_text( |
247 | | - text=f'Artifact {artifact_name} is:' |
248 | | - ), |
249 | | - artifact_part, |
250 | | - ], |
251 | | - ) |
| 256 | + artifact_names = _requested_artifact_names(llm_request.contents[-1]) |
| 257 | + for artifact_name in artifact_names: |
| 258 | + # Try session-scoped first (default behavior) |
| 259 | + artifact = await tool_context.load_artifact(artifact_name) |
| 260 | + |
| 261 | + # If not found and name doesn't already have user: prefix, |
| 262 | + # try cross-session artifacts with user: prefix |
| 263 | + if artifact is None and not artifact_name.startswith('user:'): |
| 264 | + prefixed_name = f'user:{artifact_name}' |
| 265 | + artifact = await tool_context.load_artifact(prefixed_name) |
| 266 | + |
| 267 | + if artifact is None: |
| 268 | + logger.warning('Artifact "%s" not found, skipping', artifact_name) |
| 269 | + continue |
| 270 | + |
| 271 | + artifact_part = _as_safe_part_for_llm(artifact, artifact_name) |
| 272 | + if artifact_part is not artifact: |
| 273 | + mime_type = ( |
| 274 | + artifact.inline_data.mime_type if artifact.inline_data else None |
| 275 | + ) |
| 276 | + logger.debug( |
| 277 | + 'Converted artifact "%s" (mime_type=%s) to text Part', |
| 278 | + artifact_name, |
| 279 | + mime_type, |
252 | 280 | ) |
253 | 281 |
|
| 282 | + llm_request.contents.append( |
| 283 | + types.Content( |
| 284 | + role='user', |
| 285 | + parts=[ |
| 286 | + types.Part.from_text(text=f'Artifact {artifact_name} is:'), |
| 287 | + artifact_part, |
| 288 | + ], |
| 289 | + ) |
| 290 | + ) |
| 291 | + |
254 | 292 |
|
255 | 293 | load_artifacts_tool = LoadArtifactsTool() |
0 commit comments