Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 77 additions & 11 deletions src/google/adk/sessions/database_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
from sqlalchemy import delete
from sqlalchemy import event
from sqlalchemy import select
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
from sqlalchemy.engine import make_url
from sqlalchemy.exc import ArgumentError
from sqlalchemy.ext.asyncio import async_sessionmaker
Expand Down Expand Up @@ -359,6 +361,70 @@ async def _session_matches_storage_revision(
latest_storage_event_id = result.scalar_one_or_none()
return latest_storage_event_id == session.events[-1].id

async def _ensure_state_rows(
self,
*,
sql_session: DatabaseSessionFactory,
schema: _SchemaClasses,
app_name: str,
user_id: str,
) -> None:
"""Ensure app_states and user_states rows exist, creating them atomically.

Uses INSERT ... ON CONFLICT DO NOTHING (PostgreSQL/SQLite) to avoid
UniqueViolation errors when multiple concurrent create_session calls
race to insert the same app_name or (app_name, user_id) row.
"""
dialect_name = self.db_engine.dialect.name

if dialect_name == _POSTGRESQL_DIALECT:
app_stmt = (
pg_insert(schema.StorageAppState)
.values(app_name=app_name, state={})
.on_conflict_do_nothing(index_elements=["app_name"])
)
user_stmt = (
pg_insert(schema.StorageUserState)
.values(app_name=app_name, user_id=user_id, state={})
.on_conflict_do_nothing(
index_elements=["app_name", "user_id"]
)
)
elif dialect_name == _SQLITE_DIALECT:
app_stmt = (
sqlite_insert(schema.StorageAppState)
.values(app_name=app_name, state={})
.on_conflict_do_nothing()
)
user_stmt = (
sqlite_insert(schema.StorageUserState)
.values(app_name=app_name, user_id=user_id, state={})
.on_conflict_do_nothing()
)
else:
# Fallback for other dialects: use the original get-then-add pattern.
# This is not race-safe but maintains backward compatibility.
storage_app_state = await sql_session.get(
schema.StorageAppState, (app_name)
)
if not storage_app_state:
sql_session.add(
schema.StorageAppState(app_name=app_name, state={})
)
storage_user_state = await sql_session.get(
schema.StorageUserState, (app_name, user_id)
)
if not storage_user_state:
sql_session.add(
schema.StorageUserState(
app_name=app_name, user_id=user_id, state={}
)
)
return

await sql_session.execute(app_stmt)
await sql_session.execute(user_stmt)

@override
async def create_session(
self,
Expand All @@ -382,24 +448,24 @@ async def create_session(
raise AlreadyExistsError(
f"Session with id {session_id} already exists."
)
# Fetch app and user states from storage
# Ensure app and user state rows exist using INSERT ... ON CONFLICT
# DO NOTHING to avoid race conditions under concurrent
# create_session calls for the same app_name/user_id.
await self._ensure_state_rows(
sql_session=sql_session,
schema=schema,
app_name=app_name,
user_id=user_id,
)

# Fetch the (now guaranteed to exist) state rows
storage_app_state = await sql_session.get(
schema.StorageAppState, (app_name)
)
storage_user_state = await sql_session.get(
schema.StorageUserState, (app_name, user_id)
)

# Create state tables if not exist
if not storage_app_state:
storage_app_state = schema.StorageAppState(app_name=app_name, state={})
sql_session.add(storage_app_state)
if not storage_user_state:
storage_user_state = schema.StorageUserState(
app_name=app_name, user_id=user_id, state={}
)
sql_session.add(storage_user_state)

# Extract state deltas
state_deltas = _session_util.extract_state_delta(state)
app_state_delta = state_deltas["app"]
Expand Down
65 changes: 65 additions & 0 deletions tests/unittests/sessions/test_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -1481,3 +1481,68 @@ async def tracking_fn(**kwargs):
finally:
database_session_service._select_required_state = original_fn
await service.close()


@pytest.mark.asyncio
async def test_concurrent_create_session_no_unique_violation():
"""Concurrent create_session calls for the same app_name must not raise.

Regression test for https://github.com/google/adk-python/issues/4954.
Before the fix, the SELECT-then-INSERT pattern on app_states/user_states
caused UniqueViolation when multiple tasks raced to initialise the same
app_name on a fresh database.
"""
service = DatabaseSessionService('sqlite+aiosqlite:///:memory:')
try:
app_name = 'race_test_app'
num_concurrent = 10

# Spawn many concurrent create_session calls for the same app_name
# but different user_ids — all will try to INSERT the same app_states row.
sessions = await asyncio.gather(
*[
service.create_session(
app_name=app_name,
user_id=f'user_{i}',
)
for i in range(num_concurrent)
]
)

# All sessions should have been created successfully (no exceptions)
assert len(sessions) == num_concurrent
for i, session in enumerate(sessions):
assert session.app_name == app_name
assert session.user_id == f'user_{i}'
finally:
await service.close()


@pytest.mark.asyncio
async def test_concurrent_create_session_same_user_no_unique_violation():
"""Concurrent create_session for same app_name AND user_id must not raise.

This tests the user_states race condition (same composite key).
"""
service = DatabaseSessionService('sqlite+aiosqlite:///:memory:')
try:
app_name = 'race_test_app'
user_id = 'shared_user'
num_concurrent = 10

sessions = await asyncio.gather(
*[
service.create_session(
app_name=app_name,
user_id=user_id,
)
for i in range(num_concurrent)
]
)

assert len(sessions) == num_concurrent
for session in sessions:
assert session.app_name == app_name
assert session.user_id == user_id
finally:
await service.close()