diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py index 1aeb464b7f..b393fdc301 100644 --- a/src/google/adk/sessions/database_session_service.py +++ b/src/google/adk/sessions/database_session_service.py @@ -28,6 +28,8 @@ from sqlalchemy import delete from sqlalchemy import event from sqlalchemy import select +from sqlalchemy.dialects.postgresql import insert as pg_insert +from sqlalchemy.dialects.sqlite import insert as sqlite_insert from sqlalchemy.engine import make_url from sqlalchemy.exc import ArgumentError from sqlalchemy.ext.asyncio import async_sessionmaker @@ -359,6 +361,70 @@ async def _session_matches_storage_revision( latest_storage_event_id = result.scalar_one_or_none() return latest_storage_event_id == session.events[-1].id + async def _ensure_state_rows( + self, + *, + sql_session: DatabaseSessionFactory, + schema: _SchemaClasses, + app_name: str, + user_id: str, + ) -> None: + """Ensure app_states and user_states rows exist, creating them atomically. + + Uses INSERT ... ON CONFLICT DO NOTHING (PostgreSQL/SQLite) to avoid + UniqueViolation errors when multiple concurrent create_session calls + race to insert the same app_name or (app_name, user_id) row. + """ + dialect_name = self.db_engine.dialect.name + + if dialect_name == _POSTGRESQL_DIALECT: + app_stmt = ( + pg_insert(schema.StorageAppState) + .values(app_name=app_name, state={}) + .on_conflict_do_nothing(index_elements=["app_name"]) + ) + user_stmt = ( + pg_insert(schema.StorageUserState) + .values(app_name=app_name, user_id=user_id, state={}) + .on_conflict_do_nothing( + index_elements=["app_name", "user_id"] + ) + ) + elif dialect_name == _SQLITE_DIALECT: + app_stmt = ( + sqlite_insert(schema.StorageAppState) + .values(app_name=app_name, state={}) + .on_conflict_do_nothing() + ) + user_stmt = ( + sqlite_insert(schema.StorageUserState) + .values(app_name=app_name, user_id=user_id, state={}) + .on_conflict_do_nothing() + ) + else: + # Fallback for other dialects: use the original get-then-add pattern. + # This is not race-safe but maintains backward compatibility. + storage_app_state = await sql_session.get( + schema.StorageAppState, (app_name) + ) + if not storage_app_state: + sql_session.add( + schema.StorageAppState(app_name=app_name, state={}) + ) + storage_user_state = await sql_session.get( + schema.StorageUserState, (app_name, user_id) + ) + if not storage_user_state: + sql_session.add( + schema.StorageUserState( + app_name=app_name, user_id=user_id, state={} + ) + ) + return + + await sql_session.execute(app_stmt) + await sql_session.execute(user_stmt) + @override async def create_session( self, @@ -382,7 +448,17 @@ async def create_session( raise AlreadyExistsError( f"Session with id {session_id} already exists." ) - # Fetch app and user states from storage + # Ensure app and user state rows exist using INSERT ... ON CONFLICT + # DO NOTHING to avoid race conditions under concurrent + # create_session calls for the same app_name/user_id. + await self._ensure_state_rows( + sql_session=sql_session, + schema=schema, + app_name=app_name, + user_id=user_id, + ) + + # Fetch the (now guaranteed to exist) state rows storage_app_state = await sql_session.get( schema.StorageAppState, (app_name) ) @@ -390,16 +466,6 @@ async def create_session( schema.StorageUserState, (app_name, user_id) ) - # Create state tables if not exist - if not storage_app_state: - storage_app_state = schema.StorageAppState(app_name=app_name, state={}) - sql_session.add(storage_app_state) - if not storage_user_state: - storage_user_state = schema.StorageUserState( - app_name=app_name, user_id=user_id, state={} - ) - sql_session.add(storage_user_state) - # Extract state deltas state_deltas = _session_util.extract_state_delta(state) app_state_delta = state_deltas["app"] diff --git a/tests/unittests/sessions/test_session_service.py b/tests/unittests/sessions/test_session_service.py index 8a56600edc..072ba52c41 100644 --- a/tests/unittests/sessions/test_session_service.py +++ b/tests/unittests/sessions/test_session_service.py @@ -1481,3 +1481,68 @@ async def tracking_fn(**kwargs): finally: database_session_service._select_required_state = original_fn await service.close() + + +@pytest.mark.asyncio +async def test_concurrent_create_session_no_unique_violation(): + """Concurrent create_session calls for the same app_name must not raise. + + Regression test for https://github.com/google/adk-python/issues/4954. + Before the fix, the SELECT-then-INSERT pattern on app_states/user_states + caused UniqueViolation when multiple tasks raced to initialise the same + app_name on a fresh database. + """ + service = DatabaseSessionService('sqlite+aiosqlite:///:memory:') + try: + app_name = 'race_test_app' + num_concurrent = 10 + + # Spawn many concurrent create_session calls for the same app_name + # but different user_ids — all will try to INSERT the same app_states row. + sessions = await asyncio.gather( + *[ + service.create_session( + app_name=app_name, + user_id=f'user_{i}', + ) + for i in range(num_concurrent) + ] + ) + + # All sessions should have been created successfully (no exceptions) + assert len(sessions) == num_concurrent + for i, session in enumerate(sessions): + assert session.app_name == app_name + assert session.user_id == f'user_{i}' + finally: + await service.close() + + +@pytest.mark.asyncio +async def test_concurrent_create_session_same_user_no_unique_violation(): + """Concurrent create_session for same app_name AND user_id must not raise. + + This tests the user_states race condition (same composite key). + """ + service = DatabaseSessionService('sqlite+aiosqlite:///:memory:') + try: + app_name = 'race_test_app' + user_id = 'shared_user' + num_concurrent = 10 + + sessions = await asyncio.gather( + *[ + service.create_session( + app_name=app_name, + user_id=user_id, + ) + for i in range(num_concurrent) + ] + ) + + assert len(sessions) == num_concurrent + for session in sessions: + assert session.app_name == app_name + assert session.user_id == user_id + finally: + await service.close()