Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
261 changes: 261 additions & 0 deletions agentplatform/agent_engines/templates/adk.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,13 @@
from google.auth.transport import requests as requests_auth

if TYPE_CHECKING:
try:
from google.genai import types

types = types
except (ImportError, AttributeError):
types = Any

try:
from google.adk.events.event import Event

Expand Down Expand Up @@ -1754,6 +1761,253 @@ async def async_search_memory(self, *, user_id: str, query: str):
query=query,
)

async def async_save_artifact(
self,
*,
user_id: str,
filename: str,
artifact: Union["types.Part", Dict[str, Any]],
session_id: Optional[str] = None,
custom_metadata: Optional[Dict[str, Any]] = None,
**kwargs,
):
"""Saves an artifact to the artifact service storage.

Args:
user_id (str):
Required. The ID of the user.
filename (str):
Required. The filename of the artifact.
artifact (Union[types.Part, Dict[str, Any]]):
Required. The artifact to save.
session_id (Optional[str]):
Optional. The ID of the session.
custom_metadata (Optional[Dict[str, Any]]):
Optional. Custom metadata to associate with the artifact.
**kwargs (dict[str, Any]):
Optional. Additional keyword arguments to pass to the
artifact service.

Returns:
int: The revision ID.
"""
if not self._tmpl_attrs.get("artifact_service"):
self.set_up()
return await self._tmpl_attrs.get("artifact_service").save_artifact(
app_name=self._app_name(),
user_id=user_id,
filename=filename,
artifact=artifact,
session_id=session_id,
custom_metadata=custom_metadata,
**kwargs,
)

async def async_load_artifact(
self,
*,
user_id: str,
filename: str,
session_id: Optional[str] = None,
version: Optional[int] = None,
**kwargs,
):
"""Gets an artifact from the artifact service storage.

Args:
user_id (str):
Required. The ID of the user.
filename (str):
Required. The filename of the artifact.
session_id (Optional[str]):
Optional. The ID of the session.
version (Optional[int]):
Optional. The version of the artifact.
**kwargs (dict[str, Any]):
Optional. Additional keyword arguments to pass to the
artifact service.

Returns:
Optional[types.Part]: The artifact or None if not found.
"""
if not self._tmpl_attrs.get("artifact_service"):
self.set_up()
return await self._tmpl_attrs.get("artifact_service").load_artifact(
app_name=self._app_name(),
user_id=user_id,
filename=filename,
session_id=session_id,
version=version,
**kwargs,
)

async def async_list_artifact_keys(
self,
*,
user_id: str,
session_id: Optional[str] = None,
**kwargs,
):
"""Lists all the artifact filenames within a session.

Args:
user_id (str):
Required. The ID of the user.
session_id (Optional[str]):
Optional. The ID of the session.
**kwargs (dict[str, Any]):
Optional. Additional keyword arguments to pass to the
artifact service.

Returns:
list[str]: A list of artifact filenames.
"""
if not self._tmpl_attrs.get("artifact_service"):
self.set_up()
return await self._tmpl_attrs.get("artifact_service").list_artifact_keys(
app_name=self._app_name(),
user_id=user_id,
session_id=session_id,
**kwargs,
)

async def async_delete_artifact(
self,
*,
user_id: str,
filename: str,
session_id: Optional[str] = None,
**kwargs,
):
"""Deletes an artifact.

Args:
user_id (str):
Required. The ID of the user.
filename (str):
Required. The filename of the artifact.
session_id (Optional[str]):
Optional. The ID of the session.
**kwargs (dict[str, Any]):
Optional. Additional keyword arguments to pass to the
artifact service.
"""
if not self._tmpl_attrs.get("artifact_service"):
self.set_up()
await self._tmpl_attrs.get("artifact_service").delete_artifact(
app_name=self._app_name(),
user_id=user_id,
filename=filename,
session_id=session_id,
**kwargs,
)

async def async_list_versions(
self,
*,
user_id: str,
filename: str,
session_id: Optional[str] = None,
**kwargs,
):
"""Lists all versions of an artifact.

Args:
user_id (str):
Required. The ID of the user.
filename (str):
Required. The filename of the artifact.
session_id (Optional[str]):
Optional. The ID of the session.
**kwargs (dict[str, Any]):
Optional. Additional keyword arguments to pass to the
artifact service.

Returns:
list[int]: A list of all available versions of the artifact.
"""
if not self._tmpl_attrs.get("artifact_service"):
self.set_up()
return await self._tmpl_attrs.get("artifact_service").list_versions(
app_name=self._app_name(),
user_id=user_id,
filename=filename,
session_id=session_id,
**kwargs,
)

async def async_list_artifact_versions(
self,
*,
user_id: str,
filename: str,
session_id: Optional[str] = None,
**kwargs,
):
"""Lists all versions and their metadata for a specific artifact.

Args:
user_id (str):
Required. The ID of the user.
filename (str):
Required. The filename of the artifact.
session_id (Optional[str]):
Optional. The ID of the session.
**kwargs (dict[str, Any]):
Optional. Additional keyword arguments to pass to the
artifact service.

Returns:
list[ArtifactVersion]: A list of ArtifactVersion objects.
"""
if not self._tmpl_attrs.get("artifact_service"):
self.set_up()
return await self._tmpl_attrs.get("artifact_service").list_artifact_versions(
app_name=self._app_name(),
user_id=user_id,
filename=filename,
session_id=session_id,
**kwargs,
)

async def async_get_artifact_version(
self,
*,
user_id: str,
filename: str,
session_id: Optional[str] = None,
version: Optional[int] = None,
**kwargs,
):
"""Gets the metadata for a specific version of an artifact.

Args:
user_id (str):
Required. The ID of the user.
filename (str):
Required. The filename of the artifact.
session_id (Optional[str]):
Optional. The ID of the session.
version (Optional[int]):
Optional. The version number of the artifact.
**kwargs (dict[str, Any]):
Optional. Additional keyword arguments to pass to the
artifact service.

Returns:
Optional[ArtifactVersion]: An ArtifactVersion object or None.
"""
if not self._tmpl_attrs.get("artifact_service"):
self.set_up()
return await self._tmpl_attrs.get("artifact_service").get_artifact_version(
app_name=self._app_name(),
user_id=user_id,
filename=filename,
session_id=session_id,
version=version,
**kwargs,
)

def register_operations(self) -> Dict[str, List[str]]:
"""Registers the operations of the ADK application."""
return {
Expand All @@ -1770,6 +2024,13 @@ def register_operations(self) -> Dict[str, List[str]]:
"async_delete_session",
"async_add_session_to_memory",
"async_search_memory",
"async_save_artifact",
"async_load_artifact",
"async_list_artifact_keys",
"async_delete_artifact",
"async_list_versions",
"async_list_artifact_versions",
"async_get_artifact_version",
],
"stream": ["stream_query"],
"async_stream": [
Expand Down
Loading
Loading