From 92b35baaa634f217b462df4a53123d48bbf5559e Mon Sep 17 00:00:00 2001 From: A Vertex SDK engineer Date: Fri, 29 May 2026 07:24:18 -0700 Subject: [PATCH] feat: Add async artifact management to AdkApp PiperOrigin-RevId: 923400255 --- agentplatform/agent_engines/templates/adk.py | 261 +++++++++++++++++ .../frameworks/test_frameworks_adk.py | 148 +++++++++- .../test_agent_engine_templates_adk.py | 86 ++++++ .../test_reasoning_engine_templates_adk.py | 88 ++++++ vertexai/agent_engines/templates/adk.py | 261 +++++++++++++++++ .../reasoning_engines/templates/adk.py | 267 +++++++++++++++++- 6 files changed, 1104 insertions(+), 7 deletions(-) diff --git a/agentplatform/agent_engines/templates/adk.py b/agentplatform/agent_engines/templates/adk.py index c68cc1d45d..29a1dfcf82 100644 --- a/agentplatform/agent_engines/templates/adk.py +++ b/agentplatform/agent_engines/templates/adk.py @@ -37,6 +37,13 @@ from google.auth.transport import requests as requests_auth if TYPE_CHECKING: + try: + from google.genai import types + + types = types + except (ImportError, AttributeError): + types = Any + try: from google.adk.events.event import Event @@ -1754,6 +1761,253 @@ async def async_search_memory(self, *, user_id: str, query: str): query=query, ) + async def async_save_artifact( + self, + *, + user_id: str, + filename: str, + artifact: Union["types.Part", Dict[str, Any]], + session_id: Optional[str] = None, + custom_metadata: Optional[Dict[str, Any]] = None, + **kwargs, + ): + """Saves an artifact to the artifact service storage. + + Args: + user_id (str): + Required. The ID of the user. + filename (str): + Required. The filename of the artifact. + artifact (Union[types.Part, Dict[str, Any]]): + Required. The artifact to save. + session_id (Optional[str]): + Optional. The ID of the session. + custom_metadata (Optional[Dict[str, Any]]): + Optional. Custom metadata to associate with the artifact. + **kwargs (dict[str, Any]): + Optional. Additional keyword arguments to pass to the + artifact service. + + Returns: + int: The revision ID. + """ + if not self._tmpl_attrs.get("artifact_service"): + self.set_up() + return await self._tmpl_attrs.get("artifact_service").save_artifact( + app_name=self._app_name(), + user_id=user_id, + filename=filename, + artifact=artifact, + session_id=session_id, + custom_metadata=custom_metadata, + **kwargs, + ) + + async def async_load_artifact( + self, + *, + user_id: str, + filename: str, + session_id: Optional[str] = None, + version: Optional[int] = None, + **kwargs, + ): + """Gets an artifact from the artifact service storage. + + Args: + user_id (str): + Required. The ID of the user. + filename (str): + Required. The filename of the artifact. + session_id (Optional[str]): + Optional. The ID of the session. + version (Optional[int]): + Optional. The version of the artifact. + **kwargs (dict[str, Any]): + Optional. Additional keyword arguments to pass to the + artifact service. + + Returns: + Optional[types.Part]: The artifact or None if not found. + """ + if not self._tmpl_attrs.get("artifact_service"): + self.set_up() + return await self._tmpl_attrs.get("artifact_service").load_artifact( + app_name=self._app_name(), + user_id=user_id, + filename=filename, + session_id=session_id, + version=version, + **kwargs, + ) + + async def async_list_artifact_keys( + self, + *, + user_id: str, + session_id: Optional[str] = None, + **kwargs, + ): + """Lists all the artifact filenames within a session. + + Args: + user_id (str): + Required. The ID of the user. + session_id (Optional[str]): + Optional. The ID of the session. + **kwargs (dict[str, Any]): + Optional. Additional keyword arguments to pass to the + artifact service. + + Returns: + list[str]: A list of artifact filenames. + """ + if not self._tmpl_attrs.get("artifact_service"): + self.set_up() + return await self._tmpl_attrs.get("artifact_service").list_artifact_keys( + app_name=self._app_name(), + user_id=user_id, + session_id=session_id, + **kwargs, + ) + + async def async_delete_artifact( + self, + *, + user_id: str, + filename: str, + session_id: Optional[str] = None, + **kwargs, + ): + """Deletes an artifact. + + Args: + user_id (str): + Required. The ID of the user. + filename (str): + Required. The filename of the artifact. + session_id (Optional[str]): + Optional. The ID of the session. + **kwargs (dict[str, Any]): + Optional. Additional keyword arguments to pass to the + artifact service. + """ + if not self._tmpl_attrs.get("artifact_service"): + self.set_up() + await self._tmpl_attrs.get("artifact_service").delete_artifact( + app_name=self._app_name(), + user_id=user_id, + filename=filename, + session_id=session_id, + **kwargs, + ) + + async def async_list_versions( + self, + *, + user_id: str, + filename: str, + session_id: Optional[str] = None, + **kwargs, + ): + """Lists all versions of an artifact. + + Args: + user_id (str): + Required. The ID of the user. + filename (str): + Required. The filename of the artifact. + session_id (Optional[str]): + Optional. The ID of the session. + **kwargs (dict[str, Any]): + Optional. Additional keyword arguments to pass to the + artifact service. + + Returns: + list[int]: A list of all available versions of the artifact. + """ + if not self._tmpl_attrs.get("artifact_service"): + self.set_up() + return await self._tmpl_attrs.get("artifact_service").list_versions( + app_name=self._app_name(), + user_id=user_id, + filename=filename, + session_id=session_id, + **kwargs, + ) + + async def async_list_artifact_versions( + self, + *, + user_id: str, + filename: str, + session_id: Optional[str] = None, + **kwargs, + ): + """Lists all versions and their metadata for a specific artifact. + + Args: + user_id (str): + Required. The ID of the user. + filename (str): + Required. The filename of the artifact. + session_id (Optional[str]): + Optional. The ID of the session. + **kwargs (dict[str, Any]): + Optional. Additional keyword arguments to pass to the + artifact service. + + Returns: + list[ArtifactVersion]: A list of ArtifactVersion objects. + """ + if not self._tmpl_attrs.get("artifact_service"): + self.set_up() + return await self._tmpl_attrs.get("artifact_service").list_artifact_versions( + app_name=self._app_name(), + user_id=user_id, + filename=filename, + session_id=session_id, + **kwargs, + ) + + async def async_get_artifact_version( + self, + *, + user_id: str, + filename: str, + session_id: Optional[str] = None, + version: Optional[int] = None, + **kwargs, + ): + """Gets the metadata for a specific version of an artifact. + + Args: + user_id (str): + Required. The ID of the user. + filename (str): + Required. The filename of the artifact. + session_id (Optional[str]): + Optional. The ID of the session. + version (Optional[int]): + Optional. The version number of the artifact. + **kwargs (dict[str, Any]): + Optional. Additional keyword arguments to pass to the + artifact service. + + Returns: + Optional[ArtifactVersion]: An ArtifactVersion object or None. + """ + if not self._tmpl_attrs.get("artifact_service"): + self.set_up() + return await self._tmpl_attrs.get("artifact_service").get_artifact_version( + app_name=self._app_name(), + user_id=user_id, + filename=filename, + session_id=session_id, + version=version, + **kwargs, + ) + def register_operations(self) -> Dict[str, List[str]]: """Registers the operations of the ADK application.""" return { @@ -1770,6 +2024,13 @@ def register_operations(self) -> Dict[str, List[str]]: "async_delete_session", "async_add_session_to_memory", "async_search_memory", + "async_save_artifact", + "async_load_artifact", + "async_list_artifact_keys", + "async_delete_artifact", + "async_list_versions", + "async_list_artifact_versions", + "async_get_artifact_version", ], "stream": ["stream_query"], "async_stream": [ diff --git a/tests/unit/agentplatform/frameworks/test_frameworks_adk.py b/tests/unit/agentplatform/frameworks/test_frameworks_adk.py index c7256a7c89..6acb0e905e 100644 --- a/tests/unit/agentplatform/frameworks/test_frameworks_adk.py +++ b/tests/unit/agentplatform/frameworks/test_frameworks_adk.py @@ -646,6 +646,146 @@ def test_delete_session(self, get_project_id_mock: mock.Mock): response0 = app.list_sessions(user_id=_TEST_USER_ID) assert not response0.sessions + @pytest.mark.asyncio + async def test_async_artifact_management(self, get_project_id_mock: mock.Mock): + app = adk_template.AdkApp(agent=_TEST_AGENT) + session = await app.async_create_session(user_id=_TEST_USER_ID) + session_id = session["id"] + + part = types.Part(text="test artifact content") + version = await app.async_save_artifact( + user_id=_TEST_USER_ID, + filename="test.txt", + artifact=part, + session_id=session_id, + ) + assert version == 0 + + loaded = await app.async_load_artifact( + user_id=_TEST_USER_ID, + filename="test.txt", + session_id=session_id, + ) + assert loaded.text == "test artifact content" + + keys = await app.async_list_artifact_keys( + user_id=_TEST_USER_ID, + session_id=session_id, + ) + assert keys == ["test.txt"] + + versions = await app.async_list_versions( + user_id=_TEST_USER_ID, + filename="test.txt", + session_id=session_id, + ) + assert versions == [0] + + art_versions = await app.async_list_artifact_versions( + user_id=_TEST_USER_ID, + filename="test.txt", + session_id=session_id, + ) + assert len(art_versions) == 1 + assert art_versions[0].version == 0 + + art_ver = await app.async_get_artifact_version( + user_id=_TEST_USER_ID, + filename="test.txt", + session_id=session_id, + version=0, + ) + assert art_ver.version == 0 + + await app.async_delete_artifact( + user_id=_TEST_USER_ID, + filename="test.txt", + session_id=session_id, + ) + keys_after = await app.async_list_artifact_keys( + user_id=_TEST_USER_ID, + session_id=session_id, + ) + assert not keys_after + + @pytest.mark.asyncio + async def test_async_artifact_management_lazy_init( + self, get_project_id_mock: mock.Mock + ): + part = types.Part(text="test lazy content") + + # 1. Save Artifact lazy init + app1 = adk_template.AdkApp(agent=_TEST_AGENT) + assert app1._tmpl_attrs.get("artifact_service") is None + version = await app1.async_save_artifact( + user_id=_TEST_USER_ID, + filename="lazy.txt", + artifact=part, + session_id="lazy_session", + ) + assert version == 0 + assert app1._tmpl_attrs.get("artifact_service") is not None + + # 2. Load Artifact lazy init + app2 = adk_template.AdkApp(agent=_TEST_AGENT) + assert app2._tmpl_attrs.get("artifact_service") is None + await app2.async_load_artifact( + user_id=_TEST_USER_ID, + filename="lazy.txt", + session_id="lazy_session", + ) + assert app2._tmpl_attrs.get("artifact_service") is not None + + # 3. List keys lazy init + app3 = adk_template.AdkApp(agent=_TEST_AGENT) + assert app3._tmpl_attrs.get("artifact_service") is None + await app3.async_list_artifact_keys( + user_id=_TEST_USER_ID, + session_id="lazy_session", + ) + assert app3._tmpl_attrs.get("artifact_service") is not None + + # 4. Delete lazy init + app4 = adk_template.AdkApp(agent=_TEST_AGENT) + assert app4._tmpl_attrs.get("artifact_service") is None + await app4.async_delete_artifact( + user_id=_TEST_USER_ID, + filename="lazy.txt", + session_id="lazy_session", + ) + assert app4._tmpl_attrs.get("artifact_service") is not None + + # 5. List versions lazy init + app5 = adk_template.AdkApp(agent=_TEST_AGENT) + assert app5._tmpl_attrs.get("artifact_service") is None + await app5.async_list_versions( + user_id=_TEST_USER_ID, + filename="lazy.txt", + session_id="lazy_session", + ) + assert app5._tmpl_attrs.get("artifact_service") is not None + + # 6. List artifact versions lazy init + app6 = adk_template.AdkApp(agent=_TEST_AGENT) + assert app6._tmpl_attrs.get("artifact_service") is None + await app6.async_list_artifact_versions( + user_id=_TEST_USER_ID, + filename="lazy.txt", + session_id="lazy_session", + ) + assert app6._tmpl_attrs.get("artifact_service") is not None + + # 7. Get version lazy init + app7 = adk_template.AdkApp(agent=_TEST_AGENT) + assert app7._tmpl_attrs.get("artifact_service") is None + await app7.async_get_artifact_version( + user_id=_TEST_USER_ID, + filename="lazy.txt", + session_id="lazy_session", + version=0, + ) + assert app7._tmpl_attrs.get("artifact_service") is not None + @pytest.mark.asyncio async def test_async_add_session_to_memory_dict( self, @@ -1125,9 +1265,7 @@ def test_create_default_telemetry_enablement( from agentplatform._genai import types as _genai_types mock_operation = mock.Mock() - mock_operation.name = ( - "projects/test-project/locations/us-central1/reasoningEngines/123456/operations/789" - ) + mock_operation.name = "projects/test-project/locations/us-central1/reasoningEngines/123456/operations/789" mock_create.return_value = mock_operation mock_await_operation.return_value = _genai_types.AgentEngineOperation( response=_genai_types.ReasoningEngine( @@ -1180,9 +1318,7 @@ def test_update_default_telemetry_enablement( from agentplatform._genai import types as _genai_types mock_operation = mock.Mock() - mock_operation.name = ( - "projects/test-project/locations/us-central1/reasoningEngines/123456/operations/789" - ) + mock_operation.name = "projects/test-project/locations/us-central1/reasoningEngines/123456/operations/789" mock_update.return_value = mock_operation mock_await_operation.return_value = _genai_types.AgentEngineOperation( response=_genai_types.ReasoningEngine( diff --git a/tests/unit/vertex_adk/test_agent_engine_templates_adk.py b/tests/unit/vertex_adk/test_agent_engine_templates_adk.py index ca4503a581..aee4631d94 100644 --- a/tests/unit/vertex_adk/test_agent_engine_templates_adk.py +++ b/tests/unit/vertex_adk/test_agent_engine_templates_adk.py @@ -643,6 +643,92 @@ def test_delete_session(self, get_project_id_mock: mock.Mock): response0 = app.list_sessions(user_id=_TEST_USER_ID) assert not response0.sessions + @pytest.mark.asyncio + async def test_async_artifact_management(self, get_project_id_mock: mock.Mock): + app = agent_engines.AdkApp(agent=_TEST_AGENT) + session = await app.async_create_session(user_id=_TEST_USER_ID) + session_id = session["id"] + + part = types.Part(text="test artifact content") + version = await app.async_save_artifact( + user_id=_TEST_USER_ID, + filename="test.txt", + artifact=part, + session_id=session_id, + ) + assert version == 0 + + loaded = await app.async_load_artifact( + user_id=_TEST_USER_ID, + filename="test.txt", + session_id=session_id, + ) + assert loaded.text == "test artifact content" + + keys = await app.async_list_artifact_keys( + user_id=_TEST_USER_ID, + session_id=session_id, + ) + assert keys == ["test.txt"] + + versions = await app.async_list_versions( + user_id=_TEST_USER_ID, + filename="test.txt", + session_id=session_id, + ) + assert versions == [0] + + art_versions = await app.async_list_artifact_versions( + user_id=_TEST_USER_ID, + filename="test.txt", + session_id=session_id, + ) + assert len(art_versions) == 1 + assert art_versions[0].version == 0 + + art_ver = await app.async_get_artifact_version( + user_id=_TEST_USER_ID, + filename="test.txt", + session_id=session_id, + version=0, + ) + assert art_ver.version == 0 + + await app.async_delete_artifact( + user_id=_TEST_USER_ID, + filename="test.txt", + session_id=session_id, + ) + keys_after = await app.async_list_artifact_keys( + user_id=_TEST_USER_ID, + session_id=session_id, + ) + assert not keys_after + + @pytest.mark.asyncio + async def test_async_artifact_management_lazy_init( + self, get_project_id_mock: mock.Mock + ): + app = adk_template.AdkApp(agent=_TEST_AGENT) + assert app._tmpl_attrs.get("artifact_service") is None + + part = types.Part(text="test lazy content") + version = await app.async_save_artifact( + user_id=_TEST_USER_ID, + filename="lazy.txt", + artifact=part, + session_id="lazy_session", + ) + assert version == 0 + assert app._tmpl_attrs.get("artifact_service") is not None + + loaded = await app.async_load_artifact( + user_id=_TEST_USER_ID, + filename="lazy.txt", + session_id="lazy_session", + ) + assert loaded.text == "test lazy content" + @pytest.mark.asyncio async def test_async_add_session_to_memory_dict( self, diff --git a/tests/unit/vertex_adk/test_reasoning_engine_templates_adk.py b/tests/unit/vertex_adk/test_reasoning_engine_templates_adk.py index e943ceee96..90b7bb7373 100644 --- a/tests/unit/vertex_adk/test_reasoning_engine_templates_adk.py +++ b/tests/unit/vertex_adk/test_reasoning_engine_templates_adk.py @@ -755,6 +755,94 @@ def test_delete_session(self): response0 = app.list_sessions(user_id=_TEST_USER_ID) assert not response0.sessions + @pytest.mark.asyncio + async def test_async_artifact_management(self): + app = reasoning_engines.AdkApp( + agent=Agent(name=_TEST_AGENT_NAME, model=_TEST_MODEL) + ) + session = await app.async_create_session(user_id=_TEST_USER_ID) + session_id = session["id"] + + part = types.Part(text="test artifact content") + version = await app.async_save_artifact( + user_id=_TEST_USER_ID, + filename="test.txt", + artifact=part, + session_id=session_id, + ) + assert version == 0 + + loaded = await app.async_load_artifact( + user_id=_TEST_USER_ID, + filename="test.txt", + session_id=session_id, + ) + assert loaded.text == "test artifact content" + + keys = await app.async_list_artifact_keys( + user_id=_TEST_USER_ID, + session_id=session_id, + ) + assert keys == ["test.txt"] + + versions = await app.async_list_versions( + user_id=_TEST_USER_ID, + filename="test.txt", + session_id=session_id, + ) + assert versions == [0] + + art_versions = await app.async_list_artifact_versions( + user_id=_TEST_USER_ID, + filename="test.txt", + session_id=session_id, + ) + assert len(art_versions) == 1 + assert art_versions[0].version == 0 + + art_ver = await app.async_get_artifact_version( + user_id=_TEST_USER_ID, + filename="test.txt", + session_id=session_id, + version=0, + ) + assert art_ver.version == 0 + + await app.async_delete_artifact( + user_id=_TEST_USER_ID, + filename="test.txt", + session_id=session_id, + ) + keys_after = await app.async_list_artifact_keys( + user_id=_TEST_USER_ID, + session_id=session_id, + ) + assert not keys_after + + @pytest.mark.asyncio + async def test_async_artifact_management_lazy_init(self): + app = reasoning_engines.AdkApp( + agent=Agent(name=_TEST_AGENT_NAME, model=_TEST_MODEL) + ) + assert app._tmpl_attrs.get("artifact_service") is None + + part = types.Part(text="test lazy content") + version = await app.async_save_artifact( + user_id=_TEST_USER_ID, + filename="lazy.txt", + artifact=part, + session_id="lazy_session", + ) + assert version == 0 + assert app._tmpl_attrs.get("artifact_service") is not None + + loaded = await app.async_load_artifact( + user_id=_TEST_USER_ID, + filename="lazy.txt", + session_id="lazy_session", + ) + assert loaded.text == "test lazy content" + @pytest.mark.asyncio async def test_async_add_session_to_memory(self): app = reasoning_engines.AdkApp( diff --git a/vertexai/agent_engines/templates/adk.py b/vertexai/agent_engines/templates/adk.py index 16b477c573..031fc7302b 100644 --- a/vertexai/agent_engines/templates/adk.py +++ b/vertexai/agent_engines/templates/adk.py @@ -37,6 +37,13 @@ from google.auth.transport import requests as requests_auth if TYPE_CHECKING: + try: + from google.genai import types + + types = types + except (ImportError, AttributeError): + types = Any + try: from google.adk.events.event import Event @@ -1754,6 +1761,253 @@ async def async_search_memory(self, *, user_id: str, query: str): query=query, ) + async def async_save_artifact( + self, + *, + user_id: str, + filename: str, + artifact: Union["types.Part", Dict[str, Any]], + session_id: Optional[str] = None, + custom_metadata: Optional[Dict[str, Any]] = None, + **kwargs, + ): + """Saves an artifact to the artifact service storage. + + Args: + user_id (str): + Required. The ID of the user. + filename (str): + Required. The filename of the artifact. + artifact (Union[types.Part, Dict[str, Any]]): + Required. The artifact to save. + session_id (Optional[str]): + Optional. The ID of the session. + custom_metadata (Optional[Dict[str, Any]]): + Optional. Custom metadata to associate with the artifact. + **kwargs (dict[str, Any]): + Optional. Additional keyword arguments to pass to the + artifact service. + + Returns: + int: The revision ID. + """ + if not self._tmpl_attrs.get("artifact_service"): + self.set_up() + return await self._tmpl_attrs.get("artifact_service").save_artifact( + app_name=self._app_name(), + user_id=user_id, + filename=filename, + artifact=artifact, + session_id=session_id, + custom_metadata=custom_metadata, + **kwargs, + ) + + async def async_load_artifact( + self, + *, + user_id: str, + filename: str, + session_id: Optional[str] = None, + version: Optional[int] = None, + **kwargs, + ): + """Gets an artifact from the artifact service storage. + + Args: + user_id (str): + Required. The ID of the user. + filename (str): + Required. The filename of the artifact. + session_id (Optional[str]): + Optional. The ID of the session. + version (Optional[int]): + Optional. The version of the artifact. + **kwargs (dict[str, Any]): + Optional. Additional keyword arguments to pass to the + artifact service. + + Returns: + Optional[types.Part]: The artifact or None if not found. + """ + if not self._tmpl_attrs.get("artifact_service"): + self.set_up() + return await self._tmpl_attrs.get("artifact_service").load_artifact( + app_name=self._app_name(), + user_id=user_id, + filename=filename, + session_id=session_id, + version=version, + **kwargs, + ) + + async def async_list_artifact_keys( + self, + *, + user_id: str, + session_id: Optional[str] = None, + **kwargs, + ): + """Lists all the artifact filenames within a session. + + Args: + user_id (str): + Required. The ID of the user. + session_id (Optional[str]): + Optional. The ID of the session. + **kwargs (dict[str, Any]): + Optional. Additional keyword arguments to pass to the + artifact service. + + Returns: + list[str]: A list of artifact filenames. + """ + if not self._tmpl_attrs.get("artifact_service"): + self.set_up() + return await self._tmpl_attrs.get("artifact_service").list_artifact_keys( + app_name=self._app_name(), + user_id=user_id, + session_id=session_id, + **kwargs, + ) + + async def async_delete_artifact( + self, + *, + user_id: str, + filename: str, + session_id: Optional[str] = None, + **kwargs, + ): + """Deletes an artifact. + + Args: + user_id (str): + Required. The ID of the user. + filename (str): + Required. The filename of the artifact. + session_id (Optional[str]): + Optional. The ID of the session. + **kwargs (dict[str, Any]): + Optional. Additional keyword arguments to pass to the + artifact service. + """ + if not self._tmpl_attrs.get("artifact_service"): + self.set_up() + await self._tmpl_attrs.get("artifact_service").delete_artifact( + app_name=self._app_name(), + user_id=user_id, + filename=filename, + session_id=session_id, + **kwargs, + ) + + async def async_list_versions( + self, + *, + user_id: str, + filename: str, + session_id: Optional[str] = None, + **kwargs, + ): + """Lists all versions of an artifact. + + Args: + user_id (str): + Required. The ID of the user. + filename (str): + Required. The filename of the artifact. + session_id (Optional[str]): + Optional. The ID of the session. + **kwargs (dict[str, Any]): + Optional. Additional keyword arguments to pass to the + artifact service. + + Returns: + list[int]: A list of all available versions of the artifact. + """ + if not self._tmpl_attrs.get("artifact_service"): + self.set_up() + return await self._tmpl_attrs.get("artifact_service").list_versions( + app_name=self._app_name(), + user_id=user_id, + filename=filename, + session_id=session_id, + **kwargs, + ) + + async def async_list_artifact_versions( + self, + *, + user_id: str, + filename: str, + session_id: Optional[str] = None, + **kwargs, + ): + """Lists all versions and their metadata for a specific artifact. + + Args: + user_id (str): + Required. The ID of the user. + filename (str): + Required. The filename of the artifact. + session_id (Optional[str]): + Optional. The ID of the session. + **kwargs (dict[str, Any]): + Optional. Additional keyword arguments to pass to the + artifact service. + + Returns: + list[ArtifactVersion]: A list of ArtifactVersion objects. + """ + if not self._tmpl_attrs.get("artifact_service"): + self.set_up() + return await self._tmpl_attrs.get("artifact_service").list_artifact_versions( + app_name=self._app_name(), + user_id=user_id, + filename=filename, + session_id=session_id, + **kwargs, + ) + + async def async_get_artifact_version( + self, + *, + user_id: str, + filename: str, + session_id: Optional[str] = None, + version: Optional[int] = None, + **kwargs, + ): + """Gets the metadata for a specific version of an artifact. + + Args: + user_id (str): + Required. The ID of the user. + filename (str): + Required. The filename of the artifact. + session_id (Optional[str]): + Optional. The ID of the session. + version (Optional[int]): + Optional. The version number of the artifact. + **kwargs (dict[str, Any]): + Optional. Additional keyword arguments to pass to the + artifact service. + + Returns: + Optional[ArtifactVersion]: An ArtifactVersion object or None. + """ + if not self._tmpl_attrs.get("artifact_service"): + self.set_up() + return await self._tmpl_attrs.get("artifact_service").get_artifact_version( + app_name=self._app_name(), + user_id=user_id, + filename=filename, + session_id=session_id, + version=version, + **kwargs, + ) + def register_operations(self) -> Dict[str, List[str]]: """Registers the operations of the ADK application.""" return { @@ -1770,6 +2024,13 @@ def register_operations(self) -> Dict[str, List[str]]: "async_delete_session", "async_add_session_to_memory", "async_search_memory", + "async_save_artifact", + "async_load_artifact", + "async_list_artifact_keys", + "async_delete_artifact", + "async_list_versions", + "async_list_artifact_versions", + "async_get_artifact_version", ], "stream": ["stream_query"], "async_stream": [ diff --git a/vertexai/preview/reasoning_engines/templates/adk.py b/vertexai/preview/reasoning_engines/templates/adk.py index 469ecf71cc..aaad046f44 100644 --- a/vertexai/preview/reasoning_engines/templates/adk.py +++ b/vertexai/preview/reasoning_engines/templates/adk.py @@ -30,8 +30,14 @@ import sys import threading - if TYPE_CHECKING: + try: + from google.genai import types + + types = types + except (ImportError, AttributeError): + types = Any + try: from google.adk.events.event import Event @@ -602,6 +608,11 @@ def _serialize(self, obj: Any) -> Any: return [self._serialize(v) for v in obj] return obj + def _app_name(self) -> str: + """Returns the app name.""" + app = self._tmpl_attrs.get("app") + return app.name if app else self._tmpl_attrs.get("app_name") + async def _init_session( self, session_service: "BaseSessionService", @@ -1538,6 +1549,253 @@ async def async_search_memory(self, *, user_id: str, query: str): query=query, ) + async def async_save_artifact( + self, + *, + user_id: str, + filename: str, + artifact: Union["types.Part", Dict[str, Any]], + session_id: Optional[str] = None, + custom_metadata: Optional[Dict[str, Any]] = None, + **kwargs, + ): + """Saves an artifact to the artifact service storage. + + Args: + user_id (str): + Required. The ID of the user. + filename (str): + Required. The filename of the artifact. + artifact (Union[types.Part, Dict[str, Any]]): + Required. The artifact to save. + session_id (Optional[str]): + Optional. The ID of the session. + custom_metadata (Optional[Dict[str, Any]]): + Optional. Custom metadata to associate with the artifact. + **kwargs (dict[str, Any]): + Optional. Additional keyword arguments to pass to the + artifact service. + + Returns: + int: The revision ID. + """ + if not self._tmpl_attrs.get("artifact_service"): + self.set_up() + return await self._tmpl_attrs.get("artifact_service").save_artifact( + app_name=self._app_name(), + user_id=user_id, + filename=filename, + artifact=artifact, + session_id=session_id, + custom_metadata=custom_metadata, + **kwargs, + ) + + async def async_load_artifact( + self, + *, + user_id: str, + filename: str, + session_id: Optional[str] = None, + version: Optional[int] = None, + **kwargs, + ): + """Gets an artifact from the artifact service storage. + + Args: + user_id (str): + Required. The ID of the user. + filename (str): + Required. The filename of the artifact. + session_id (Optional[str]): + Optional. The ID of the session. + version (Optional[int]): + Optional. The version of the artifact. + **kwargs (dict[str, Any]): + Optional. Additional keyword arguments to pass to the + artifact service. + + Returns: + Optional[types.Part]: The artifact or None if not found. + """ + if not self._tmpl_attrs.get("artifact_service"): + self.set_up() + return await self._tmpl_attrs.get("artifact_service").load_artifact( + app_name=self._app_name(), + user_id=user_id, + filename=filename, + session_id=session_id, + version=version, + **kwargs, + ) + + async def async_list_artifact_keys( + self, + *, + user_id: str, + session_id: Optional[str] = None, + **kwargs, + ): + """Lists all the artifact filenames within a session. + + Args: + user_id (str): + Required. The ID of the user. + session_id (Optional[str]): + Optional. The ID of the session. + **kwargs (dict[str, Any]): + Optional. Additional keyword arguments to pass to the + artifact service. + + Returns: + list[str]: A list of artifact filenames. + """ + if not self._tmpl_attrs.get("artifact_service"): + self.set_up() + return await self._tmpl_attrs.get("artifact_service").list_artifact_keys( + app_name=self._app_name(), + user_id=user_id, + session_id=session_id, + **kwargs, + ) + + async def async_delete_artifact( + self, + *, + user_id: str, + filename: str, + session_id: Optional[str] = None, + **kwargs, + ): + """Deletes an artifact. + + Args: + user_id (str): + Required. The ID of the user. + filename (str): + Required. The filename of the artifact. + session_id (Optional[str]): + Optional. The ID of the session. + **kwargs (dict[str, Any]): + Optional. Additional keyword arguments to pass to the + artifact service. + """ + if not self._tmpl_attrs.get("artifact_service"): + self.set_up() + await self._tmpl_attrs.get("artifact_service").delete_artifact( + app_name=self._app_name(), + user_id=user_id, + filename=filename, + session_id=session_id, + **kwargs, + ) + + async def async_list_versions( + self, + *, + user_id: str, + filename: str, + session_id: Optional[str] = None, + **kwargs, + ): + """Lists all versions of an artifact. + + Args: + user_id (str): + Required. The ID of the user. + filename (str): + Required. The filename of the artifact. + session_id (Optional[str]): + Optional. The ID of the session. + **kwargs (dict[str, Any]): + Optional. Additional keyword arguments to pass to the + artifact service. + + Returns: + list[int]: A list of all available versions of the artifact. + """ + if not self._tmpl_attrs.get("artifact_service"): + self.set_up() + return await self._tmpl_attrs.get("artifact_service").list_versions( + app_name=self._app_name(), + user_id=user_id, + filename=filename, + session_id=session_id, + **kwargs, + ) + + async def async_list_artifact_versions( + self, + *, + user_id: str, + filename: str, + session_id: Optional[str] = None, + **kwargs, + ): + """Lists all versions and their metadata for a specific artifact. + + Args: + user_id (str): + Required. The ID of the user. + filename (str): + Required. The filename of the artifact. + session_id (Optional[str]): + Optional. The ID of the session. + **kwargs (dict[str, Any]): + Optional. Additional keyword arguments to pass to the + artifact service. + + Returns: + list[ArtifactVersion]: A list of ArtifactVersion objects. + """ + if not self._tmpl_attrs.get("artifact_service"): + self.set_up() + return await self._tmpl_attrs.get("artifact_service").list_artifact_versions( + app_name=self._app_name(), + user_id=user_id, + filename=filename, + session_id=session_id, + **kwargs, + ) + + async def async_get_artifact_version( + self, + *, + user_id: str, + filename: str, + session_id: Optional[str] = None, + version: Optional[int] = None, + **kwargs, + ): + """Gets the metadata for a specific version of an artifact. + + Args: + user_id (str): + Required. The ID of the user. + filename (str): + Required. The filename of the artifact. + session_id (Optional[str]): + Optional. The ID of the session. + version (Optional[int]): + Optional. The version number of the artifact. + **kwargs (dict[str, Any]): + Optional. Additional keyword arguments to pass to the + artifact service. + + Returns: + Optional[ArtifactVersion]: An ArtifactVersion object or None. + """ + if not self._tmpl_attrs.get("artifact_service"): + self.set_up() + return await self._tmpl_attrs.get("artifact_service").get_artifact_version( + app_name=self._app_name(), + user_id=user_id, + filename=filename, + session_id=session_id, + version=version, + **kwargs, + ) + def register_operations(self) -> Dict[str, List[str]]: """Registers the operations of the ADK application.""" return { @@ -1554,6 +1812,13 @@ def register_operations(self) -> Dict[str, List[str]]: "async_delete_session", "async_add_session_to_memory", "async_search_memory", + "async_save_artifact", + "async_load_artifact", + "async_list_artifact_keys", + "async_delete_artifact", + "async_list_versions", + "async_list_artifact_versions", + "async_get_artifact_version", ], "stream": ["stream_query", "streaming_agent_run_with_events"], "async_stream": ["async_stream_query"],