diff --git a/src/google/adk_community/memory/__init__.py b/src/google/adk_community/memory/__init__.py index 1f3442c0..6cffbd8f 100644 --- a/src/google/adk_community/memory/__init__.py +++ b/src/google/adk_community/memory/__init__.py @@ -16,9 +16,10 @@ from .open_memory_service import OpenMemoryService from .open_memory_service import OpenMemoryServiceConfig +from .redis_memory_service import RedisMemoryService __all__ = [ "OpenMemoryService", "OpenMemoryServiceConfig", + "RedisMemoryService", ] - diff --git a/src/google/adk_community/memory/redis_memory_service.py b/src/google/adk_community/memory/redis_memory_service.py new file mode 100644 index 00000000..5d171089 --- /dev/null +++ b/src/google/adk_community/memory/redis_memory_service.py @@ -0,0 +1,267 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Mapping +from collections.abc import Sequence +import hashlib +import json +import re +from typing import Any +from typing import TYPE_CHECKING +from urllib.parse import quote + +from google.adk.memory import _utils +from google.adk.memory.base_memory_service import BaseMemoryService +from google.adk.memory.base_memory_service import SearchMemoryResponse +from google.adk.memory.memory_entry import MemoryEntry +from google.genai import types +import redis.asyncio as redis +from typing_extensions import override + +from .utils import extract_text_from_event + +if TYPE_CHECKING: + from google.adk.events.event import Event + from google.adk.sessions.session import Session + +_UNKNOWN_SESSION_ID = '__unknown_session_id__' + + +def _key_part(value: str) -> str: + return quote(value, safe='') + + +def _decode(value: Any) -> str: + if isinstance(value, bytes): + return value.decode('utf-8') + return str(value) + + +def _extract_words_lower(text: str) -> set[str]: + return set(word.lower() for word in re.findall(r'[A-Za-z]+', text)) + + +def _content_from_text(text: str) -> types.Content: + return types.Content(parts=[types.Part(text=text)]) + + +def _event_id(event: Event, content_text: str) -> str: + if event.id: + return event.id + digest = hashlib.sha256( + f'{event.author}:{event.timestamp}:{content_text}'.encode('utf-8') + ).hexdigest() + return f'generated-{digest}' + + +def _event_to_payload( + event: Event, + *, + session_id: str, + content_text: str, + custom_metadata: Mapping[str, object] | None = None, +) -> dict[str, Any]: + metadata = dict(custom_metadata or {}) + metadata.setdefault('session_id', session_id) + return { + 'id': _event_id(event, content_text), + 'author': event.author, + 'timestamp': ( + _utils.format_timestamp(event.timestamp) + if event.timestamp is not None + else None + ), + 'content': ( + _content_from_text(content_text).model_dump( + mode='json', by_alias=True, exclude_none=True + ) + ), + 'text': content_text, + 'custom_metadata': metadata, + } + + +class RedisMemoryService(BaseMemoryService): + """Redis-backed memory service for ADK community integrations. + + This service mirrors InMemoryMemoryService's keyword search behavior while + keeping memory entries in Redis so they survive process restarts. + """ + + def __init__( + self, + host: str = 'localhost', + port: int = 6379, + db: int = 0, + uri: str | None = None, + cluster_uri: str | None = None, + *, + key_prefix: str = 'adk:memory:', + client: Any | None = None, + **kwargs: Any, + ): + """Initializes the Redis memory service. + + Args: + host: Redis host used when uri, cluster_uri, and client are not supplied. + port: Redis port used when uri, cluster_uri, and client are not supplied. + db: Redis database used when uri, cluster_uri, and client are not supplied. + uri: Redis URL used to create a standalone Redis client. + cluster_uri: Redis Cluster URL used to create a Redis Cluster client. + key_prefix: Prefix for all Redis keys written by this service. + client: Optional async Redis-compatible client, mainly for tests. + **kwargs: Extra keyword arguments forwarded to the Redis client factory. + """ + if client is not None: + self.cache = client + elif cluster_uri: + self.cache = redis.RedisCluster.from_url(cluster_uri, **kwargs) + elif uri: + self.cache = redis.Redis.from_url(uri, **kwargs) + else: + self.cache = redis.Redis(host=host, port=port, db=db, **kwargs) + + self._key_prefix = key_prefix + + def _scope_prefix(self, app_name: str, user_id: str) -> str: + return f'{self._key_prefix}{_key_part(app_name)}:{_key_part(user_id)}' + + def _sessions_key(self, app_name: str, user_id: str) -> str: + return f'{self._scope_prefix(app_name, user_id)}:sessions' + + def _session_keys( + self, app_name: str, user_id: str, session_id: str + ) -> tuple[str, str]: + session_prefix = ( + f'{self._scope_prefix(app_name, user_id)}:{_key_part(session_id)}' + ) + return f'{session_prefix}:order', f'{session_prefix}:entries' + + async def _append_events( + self, + *, + app_name: str, + user_id: str, + session_id: str, + events: Sequence[Event], + custom_metadata: Mapping[str, object] | None = None, + ) -> None: + await self.cache.sadd(self._sessions_key(app_name, user_id), session_id) + order_key, entries_key = self._session_keys(app_name, user_id, session_id) + + for event in events: + content_text = extract_text_from_event(event) + if not content_text: + continue + + event_id = _event_id(event, content_text) + payload = _event_to_payload( + event, + session_id=session_id, + content_text=content_text, + custom_metadata=custom_metadata, + ) + was_added = await self.cache.hsetnx( + entries_key, event_id, json.dumps(payload) + ) + if was_added: + await self.cache.rpush(order_key, event_id) + + @override + async def add_session_to_memory(self, session: Session) -> None: + session_id = session.id or _UNKNOWN_SESSION_ID + order_key, entries_key = self._session_keys( + session.app_name, session.user_id, session_id + ) + await self.cache.delete(order_key) + await self.cache.delete(entries_key) + await self._append_events( + app_name=session.app_name, + user_id=session.user_id, + session_id=session_id, + events=session.events, + ) + + @override + async def add_events_to_memory( + self, + *, + app_name: str, + user_id: str, + events: Sequence[Event], + session_id: str | None = None, + custom_metadata: Mapping[str, object] | None = None, + ) -> None: + await self._append_events( + app_name=app_name, + user_id=user_id, + session_id=session_id or _UNKNOWN_SESSION_ID, + events=events, + custom_metadata=custom_metadata, + ) + + @override + async def search_memory( + self, *, app_name: str, user_id: str, query: str + ) -> SearchMemoryResponse: + sessions_key = self._sessions_key(app_name, user_id) + session_ids = sorted( + [_decode(value) for value in await self.cache.smembers(sessions_key)] + ) + words_in_query = _extract_words_lower(query) + response = SearchMemoryResponse() + + for session_id in session_ids: + order_key, entries_key = self._session_keys(app_name, user_id, session_id) + event_ids = [ + _decode(value) for value in await self.cache.lrange(order_key, 0, -1) + ] + for event_id in event_ids: + raw_payload = await self.cache.hget(entries_key, event_id) + if raw_payload is None: + continue + payload = json.loads(_decode(raw_payload)) + words_in_memory = _extract_words_lower(payload.get('text', '')) + if not words_in_memory: + continue + if any(query_word in words_in_memory for query_word in words_in_query): + response.memories.append( + MemoryEntry( + id=payload['id'], + content=types.Content.model_validate(payload['content']), + author=payload.get('author'), + timestamp=payload.get('timestamp'), + custom_metadata=payload.get('custom_metadata') or {}, + ) + ) + + return response + + async def close(self) -> None: + """Closes the Redis client if it exposes a close method.""" + close = getattr(self.cache, 'aclose', None) + if close is None: + close = getattr(self.cache, 'close', None) + if close is not None: + result = close() + if hasattr(result, '__await__'): + await result + + async def __aenter__(self) -> RedisMemoryService: + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: + await self.close() diff --git a/tests/unittests/memory/test_redis_memory_service.py b/tests/unittests/memory/test_redis_memory_service.py new file mode 100644 index 00000000..eba5bacd --- /dev/null +++ b/tests/unittests/memory/test_redis_memory_service.py @@ -0,0 +1,453 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from google.adk.events.event import Event +from google.adk.sessions.session import Session +from google.genai import types +import pytest + +from google.adk_community.memory import RedisMemoryService + +MOCK_APP_NAME = 'test-app' +MOCK_USER_ID = 'test-user' +MOCK_OTHER_USER_ID = 'another-user' + +MOCK_SESSION_1 = Session( + app_name=MOCK_APP_NAME, + user_id=MOCK_USER_ID, + id='session-1', + last_update_time=1000, + events=[ + Event( + id='event-1a', + invocation_id='inv-1', + author='user', + timestamp=12345, + content=types.Content( + parts=[types.Part(text='The ADK is a great toolkit.')] + ), + ), + Event( + id='event-1b', + invocation_id='inv-2', + author='user', + timestamp=12346, + ), + Event( + id='event-1c', + invocation_id='inv-3', + author='model', + timestamp=12347, + content=types.Content( + parts=[ + types.Part( + text='I agree. The Agent Development Kit (ADK) rocks!' + ) + ] + ), + ), + ], +) + +MOCK_SESSION_2 = Session( + app_name=MOCK_APP_NAME, + user_id=MOCK_USER_ID, + id='session-2', + last_update_time=2000, + events=[ + Event( + id='event-2a', + invocation_id='inv-4', + author='user', + timestamp=54321, + content=types.Content( + parts=[types.Part(text='I like to code in Python.')] + ), + ), + ], +) + +MOCK_SESSION_DIFFERENT_USER = Session( + app_name=MOCK_APP_NAME, + user_id=MOCK_OTHER_USER_ID, + id='session-3', + last_update_time=3000, + events=[ + Event( + id='event-3a', + invocation_id='inv-5', + author='user', + timestamp=60000, + content=types.Content(parts=[types.Part(text='This is a secret.')]), + ), + ], +) + +MOCK_SESSION_WITH_NO_EVENTS = Session( + app_name=MOCK_APP_NAME, + user_id=MOCK_USER_ID, + id='session-4', + last_update_time=4000, +) + + +class FakeAsyncRedis: + + def __init__(self): + self.sets: dict[str, set[str]] = {} + self.lists: dict[str, list[str]] = {} + self.hashes: dict[str, dict[str, str]] = {} + self.closed = False + + async def sadd(self, key: str, *values: str) -> int: + values_set = self.sets.setdefault(key, set()) + old_len = len(values_set) + values_set.update(values) + return len(values_set) - old_len + + async def smembers(self, key: str) -> set[str]: + return set(self.sets.get(key, set())) + + async def delete(self, *keys: str) -> int: + deleted = 0 + for key in keys: + for store in (self.sets, self.lists, self.hashes): + if key in store: + del store[key] + deleted += 1 + return deleted + + async def hsetnx(self, key: str, field: str, value: str) -> int: + values = self.hashes.setdefault(key, {}) + if field in values: + return 0 + values[field] = value + return 1 + + async def rpush(self, key: str, value: str) -> int: + values = self.lists.setdefault(key, []) + values.append(value) + return len(values) + + async def lrange(self, key: str, start: int, end: int) -> list[str]: + values = self.lists.get(key, []) + if end == -1: + return values[start:] + return values[start : end + 1] + + async def hget(self, key: str, field: str) -> str | None: + return self.hashes.get(key, {}).get(field) + + async def aclose(self) -> None: + self.closed = True + + +def redis_memory_service() -> RedisMemoryService: + return RedisMemoryService(client=FakeAsyncRedis()) + + +@pytest.mark.asyncio +async def test_add_session_to_memory(): + memory_service = redis_memory_service() + + await memory_service.add_session_to_memory(MOCK_SESSION_1) + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query='ADK' + ) + + assert len(result.memories) == 2 + assert {memory.id for memory in result.memories} == {'event-1a', 'event-1c'} + + +@pytest.mark.asyncio +async def test_add_events_to_memory_with_explicit_events(): + memory_service = redis_memory_service() + + await memory_service.add_events_to_memory( + app_name=MOCK_SESSION_1.app_name, + user_id=MOCK_SESSION_1.user_id, + session_id=MOCK_SESSION_1.id, + events=[MOCK_SESSION_1.events[0]], + ) + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query='toolkit' + ) + + assert len(result.memories) == 1 + assert result.memories[0].id == 'event-1a' + + +@pytest.mark.asyncio +async def test_add_events_to_memory_without_session_id_uses_default_bucket(): + memory_service = redis_memory_service() + + await memory_service.add_events_to_memory( + app_name=MOCK_SESSION_1.app_name, + user_id=MOCK_SESSION_1.user_id, + events=[MOCK_SESSION_1.events[0]], + ) + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query='toolkit' + ) + + assert len(result.memories) == 1 + assert result.memories[0].custom_metadata['session_id'] + + +@pytest.mark.asyncio +async def test_add_events_to_memory_appends_without_replacing(): + memory_service = redis_memory_service() + await memory_service.add_session_to_memory(MOCK_SESSION_1) + new_event = Event( + id='event-1d', + invocation_id='inv-6', + author='user', + timestamp=12348, + content=types.Content(parts=[types.Part(text='A new fact.')]), + ) + + await memory_service.add_events_to_memory( + app_name=MOCK_SESSION_1.app_name, + user_id=MOCK_SESSION_1.user_id, + session_id=MOCK_SESSION_1.id, + events=[new_event], + ) + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query='fact' + ) + + assert len(result.memories) == 1 + assert result.memories[0].id == 'event-1d' + + +@pytest.mark.asyncio +async def test_add_events_to_memory_deduplicates_event_ids(): + memory_service = redis_memory_service() + await memory_service.add_session_to_memory(MOCK_SESSION_1) + duplicate_event = Event( + id='event-1a', + invocation_id='inv-7', + author='user', + timestamp=12349, + content=types.Content(parts=[types.Part(text='Updated duplicate text.')]), + ) + + await memory_service.add_events_to_memory( + app_name=MOCK_SESSION_1.app_name, + user_id=MOCK_SESSION_1.user_id, + session_id=MOCK_SESSION_1.id, + events=[duplicate_event], + ) + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query='duplicate' + ) + + assert not result.memories + + +@pytest.mark.asyncio +async def test_add_session_replaces_existing_session_memory(): + memory_service = redis_memory_service() + await memory_service.add_session_to_memory(MOCK_SESSION_1) + replacement_session = Session( + app_name=MOCK_APP_NAME, + user_id=MOCK_USER_ID, + id=MOCK_SESSION_1.id, + last_update_time=5000, + events=[ + Event( + id='replacement', + invocation_id='inv-8', + author='user', + timestamp=12350, + content=types.Content(parts=[types.Part(text='Replacement')]), + ) + ], + ) + + await memory_service.add_session_to_memory(replacement_session) + old_result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query='ADK' + ) + new_result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query='Replacement' + ) + + assert not old_result.memories + assert len(new_result.memories) == 1 + assert new_result.memories[0].id == 'replacement' + + +@pytest.mark.asyncio +async def test_add_session_with_no_events_to_memory(): + memory_service = redis_memory_service() + + await memory_service.add_session_to_memory(MOCK_SESSION_WITH_NO_EVENTS) + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query='anything' + ) + + assert not result.memories + + +@pytest.mark.asyncio +async def test_search_memory_simple_match(): + memory_service = redis_memory_service() + await memory_service.add_session_to_memory(MOCK_SESSION_1) + await memory_service.add_session_to_memory(MOCK_SESSION_2) + + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query='Python' + ) + + assert len(result.memories) == 1 + assert result.memories[0].content.parts[0].text == 'I like to code in Python.' + assert result.memories[0].author == 'user' + + +@pytest.mark.asyncio +async def test_search_memory_case_insensitive_match(): + memory_service = redis_memory_service() + await memory_service.add_session_to_memory(MOCK_SESSION_1) + + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query='development' + ) + + assert len(result.memories) == 1 + assert ( + result.memories[0].content.parts[0].text + == 'I agree. The Agent Development Kit (ADK) rocks!' + ) + + +@pytest.mark.asyncio +async def test_search_memory_multiple_matches(): + memory_service = redis_memory_service() + await memory_service.add_session_to_memory(MOCK_SESSION_1) + + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query='How about ADK?' + ) + + assert len(result.memories) == 2 + texts = {memory.content.parts[0].text for memory in result.memories} + assert 'The ADK is a great toolkit.' in texts + assert 'I agree. The Agent Development Kit (ADK) rocks!' in texts + + +@pytest.mark.asyncio +async def test_search_memory_no_match(): + memory_service = redis_memory_service() + await memory_service.add_session_to_memory(MOCK_SESSION_1) + + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query='nonexistent' + ) + + assert not result.memories + + +@pytest.mark.asyncio +async def test_search_memory_is_scoped_by_user(): + memory_service = redis_memory_service() + await memory_service.add_session_to_memory(MOCK_SESSION_1) + await memory_service.add_session_to_memory(MOCK_SESSION_DIFFERENT_USER) + + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query='secret' + ) + result_other_user = await memory_service.search_memory( + app_name=MOCK_APP_NAME, user_id=MOCK_OTHER_USER_ID, query='secret' + ) + + assert not result.memories + assert len(result_other_user.memories) == 1 + assert ( + result_other_user.memories[0].content.parts[0].text == 'This is a secret.' + ) + + +@pytest.mark.asyncio +async def test_thought_parts_are_filtered_from_memory(): + memory_service = redis_memory_service() + session = Session( + app_name=MOCK_APP_NAME, + user_id=MOCK_USER_ID, + id='thought-session', + last_update_time=7000, + events=[ + Event( + id='thought-event', + invocation_id='inv-9', + author='model', + timestamp=7001, + content=types.Content( + parts=[ + types.Part(text='Private reasoning', thought=True), + types.Part(text='Visible answer'), + ] + ), + ) + ], + ) + + await memory_service.add_session_to_memory(session) + private_result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query='reasoning' + ) + visible_result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query='answer' + ) + + assert not private_result.memories + assert len(visible_result.memories) == 1 + assert visible_result.memories[0].content.parts[0].text == 'Visible answer' + + +@pytest.mark.asyncio +async def test_event_without_timestamp_is_stored(): + memory_service = redis_memory_service() + event = Event( + id='missing-timestamp-event', + invocation_id='inv-10', + author='user', + content=types.Content(parts=[types.Part(text='No timestamp')]), + ) + event.timestamp = None + session = Session( + app_name=MOCK_APP_NAME, + user_id=MOCK_USER_ID, + id='missing-timestamp-session', + last_update_time=8000, + events=[event], + ) + + await memory_service.add_session_to_memory(session) + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query='timestamp' + ) + + assert len(result.memories) == 1 + assert result.memories[0].timestamp is None + + +@pytest.mark.asyncio +async def test_close_closes_client(): + client = FakeAsyncRedis() + memory_service = RedisMemoryService(client=client) + + await memory_service.close() + + assert client.closed