Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions tests/sessions/test_sql_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,7 @@ def _make_event_with_function_call():
async def _create_service(config=None):
config = config or _make_config()
svc = SqlSessionService(db_url="sqlite:///:memory:", session_config=config, is_async=False)
with patch('trpc_agent_sdk.storage._sql.event.listen'):
await svc._sql_storage.create_sql_engine()
await svc._sql_storage.create_sql_engine()
return svc


Expand Down
24 changes: 6 additions & 18 deletions tests/storage/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,19 +51,15 @@ def async_db_url(self):
async def sync_storage(self, db_url):
"""Synchronous SQL storage fixture with initialized engine."""
storage = SqlStorage(is_async=False, db_url=db_url, metadata=StorageData.metadata)
# Patch event.listen to work around SQLite pragma issue
with patch('trpc_agent_sdk.storage._sql.event.listen'):
await storage.create_sql_engine()
await storage.create_sql_engine()
yield storage
await storage.close()

@pytest.fixture
async def async_storage(self, async_db_url):
"""Asynchronous SQL storage fixture with initialized engine."""
storage = SqlStorage(is_async=True, db_url=async_db_url, metadata=StorageData.metadata)
# Patch event.listen to work around async engine event limitation
with patch('trpc_agent_sdk.storage._sql.event.listen'):
await storage.create_sql_engine()
await storage.create_sql_engine()
yield storage
await storage.close()

Expand Down Expand Up @@ -93,9 +89,7 @@ async def test_create_sql_engine_async(self, async_db_url):
"""Test creating async SQL engine."""
storage = SqlStorage(is_async=True, db_url=async_db_url, metadata=StorageData.metadata)

# Patch event.listen to work around async engine event limitation
with patch('trpc_agent_sdk.storage._sql.event.listen'):
await storage.create_sql_engine()
await storage.create_sql_engine()

assert storage._db_engine is not None
assert storage._database_session_factory is not None
Expand All @@ -108,9 +102,7 @@ async def test_create_sql_engine_sync(self, db_url):
"""Test creating sync SQL engine."""
storage = SqlStorage(is_async=False, db_url=db_url, metadata=StorageData.metadata)

# Patch event.listen to work around SQLite pragma issue in tests
with patch('trpc_agent_sdk.storage._sql.event.listen'):
await storage.create_sql_engine()
await storage.create_sql_engine()

assert storage._db_engine is not None
assert storage._database_session_factory is not None
Expand Down Expand Up @@ -473,9 +465,7 @@ async def test_close_async(self, async_db_url):
"""Test closing async SQL engine."""
storage = SqlStorage(is_async=True, db_url=async_db_url, metadata=StorageData.metadata)

# Patch event.listen to work around async engine event limitation
with patch('trpc_agent_sdk.storage._sql.event.listen'):
await storage.create_sql_engine()
await storage.create_sql_engine()

assert storage._db_engine is not None

Expand All @@ -487,9 +477,7 @@ async def test_close_sync(self, db_url):
"""Test closing sync SQL engine."""
storage = SqlStorage(is_async=False, db_url=db_url, metadata=StorageData.metadata)

# Patch event.listen to work around SQLite pragma issue in tests
with patch('trpc_agent_sdk.storage._sql.event.listen'):
await storage.create_sql_engine()
await storage.create_sql_engine()

assert storage._db_engine is not None

Expand Down
5 changes: 5 additions & 0 deletions trpc_agent_sdk/sessions/_sql_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
from trpc_agent_sdk.storage import UTF8MB4String
from trpc_agent_sdk.storage import decode_content
from trpc_agent_sdk.storage import decode_grounding_metadata
from trpc_agent_sdk.storage import decode_usage_metadata
from trpc_agent_sdk.utils import user_key

from ._base_session_service import BaseSessionService
Expand Down Expand Up @@ -165,6 +166,7 @@ class SessionStorageEvent(SessionStorageBase):

content: Mapped[dict[str, Any]] = mapped_column(DynamicJSON, nullable=True)
grounding_metadata: Mapped[dict[str, Any]] = mapped_column(DynamicJSON, nullable=True)
usage_metadata: Mapped[dict[str, Any]] = mapped_column(DynamicJSON, nullable=True)
custom_metadata: Mapped[dict[str, Any]] = mapped_column(DynamicJSON, nullable=True)

partial: Mapped[bool] = mapped_column(Boolean, nullable=True)
Expand Down Expand Up @@ -218,6 +220,8 @@ def from_event(cls, session: Session, event: Event) -> SessionStorageEvent:
storage_event.content = event.content.model_dump(exclude_none=True, mode="json")
if event.grounding_metadata:
storage_event.grounding_metadata = event.grounding_metadata.model_dump(exclude_none=True, mode="json")
if event.usage_metadata:
storage_event.usage_metadata = event.usage_metadata.model_dump(exclude_none=True, mode="json")
if event.custom_metadata:
storage_event.custom_metadata = event.custom_metadata
return storage_event
Expand All @@ -238,6 +242,7 @@ def to_event(self) -> Event:
error_message=self.error_message,
interrupted=self.interrupted,
grounding_metadata=decode_grounding_metadata(self.grounding_metadata),
usage_metadata=decode_usage_metadata(self.usage_metadata),
custom_metadata=self.custom_metadata,
)

Expand Down
2 changes: 2 additions & 0 deletions trpc_agent_sdk/storage/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from ._sql_common import UTF8MB4String
from ._sql_common import decode_content
from ._sql_common import decode_grounding_metadata
from ._sql_common import decode_usage_metadata

__all__ = [
"EXPIRE_METHOD",
Expand All @@ -55,4 +56,5 @@
"UTF8MB4String",
"decode_content",
"decode_grounding_metadata",
"decode_usage_metadata",
]
95 changes: 93 additions & 2 deletions trpc_agent_sdk/storage/_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,16 @@
from sqlalchemy import MetaData
from sqlalchemy import and_
from sqlalchemy import delete as sql_delete
from sqlalchemy import Dialect
from sqlalchemy.sql.compiler import IdentifierPreparer
from sqlalchemy import event
from sqlalchemy import select
from sqlalchemy import text
from sqlalchemy.engine import Engine
from sqlalchemy.engine import create_engine
from sqlalchemy.engine.interfaces import DBAPICursor
from sqlalchemy.engine import Connection
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.exc import ArgumentError
from sqlalchemy.ext.asyncio import AsyncEngine
from sqlalchemy.ext.asyncio import AsyncSession
Expand Down Expand Up @@ -122,6 +127,89 @@ def __init__(self, is_async: bool, db_url: str, metadata: Optional[MetaData] = N
self.__db_url = db_url
self.__kwargs = kwargs

def _migrate_missing_columns(self, connection: Connection) -> None:
    """Add columns that exist in the ORM model but are missing from the database,
    for forward compatibility across version changes.

    SQLAlchemy's create_all only creates tables — it never ALTERs existing
    tables. This helper bridges the gap for lightweight forward-only migrations.
    Only handles adding new columns (forward-only); dropped/renamed/retyped
    columns are out of scope.

    All-or-nothing semantics: on databases that support transactional DDL
    (e.g. PostgreSQL) the caller's transaction handles rollback. On databases
    where DDL auto-commits (e.g. MySQL), a compensating DROP COLUMN is issued
    for every column that was already added before the failure.

    Args:
        connection: A synchronous SQLAlchemy Connection object. (For async
            engines, invoke via ``conn.run_sync(self._migrate_missing_columns)``.)

    Raises:
        Exception: Re-raises whatever the failing ALTER TABLE raised, after
            attempting compensation.
    """
    insp: Inspector = inspect(connection)
    dialect: Dialect = connection.dialect
    preparer: IdentifierPreparer = dialect.identifier_preparer
    # A statement-less DDL compiler is enough to render DEFAULT clauses
    # via get_column_default_string below.
    ddl_compiler = dialect.ddl_compiler(dialect, None)

    # Phase 1: plan. Collect (ALTER statement, column name, table name)
    # tuples without executing anything, so execution failures can be
    # compensated as a unit in phase 2.
    pending_add_columns: list[tuple[str, str, str]] = []
    for table_name, table in self.__metadata.tables.items():
        # Tables that don't exist yet were just handled by create_all;
        # nothing to migrate.
        # NOTE(review): metadata keys for schema-qualified tables are
        # "schema.table"; has_table expects a bare name + schema arg —
        # presumably all tables here are schema-less. TODO confirm.
        if not insp.has_table(table_name):
            continue
        existing: set[str] = {col["name"] for col in insp.get_columns(table_name)}
        for column in table.columns:
            if column.name in existing:
                continue
            col_type: str = column.type.compile(dialect=dialect)
            # Handle the different kinds of default value a column may carry.
            nullable: str = "" if column.nullable else " NOT NULL"
            default: str = ""
            default_value = ddl_compiler.get_column_default_string(column)
            if default_value is not None:
                default = f" DEFAULT {default_value}"
            elif column.server_default is not None:
                # The column has a server_default, but not one renderable as
                # DDL text (e.g. a callable/computed default): warn and emit
                # the column without a DEFAULT clause.
                logger.warning(
                    "Column '%s' on table '%s' has a non-DDL server_default "
                    "(%s); skipping DEFAULT clause generation.",
                    column.name,
                    table_name,
                    type(column.server_default).__name__,
                )
            elif not column.nullable:
                # NOT NULL with no server_default: the ALTER will fail on any
                # non-empty table, so warn up front (we do not raise here —
                # the database will reject the statement itself if it applies).
                logger.warning(
                    "Column '%s' on table '%s' is NOT NULL without a server_default; "
                    "migration may fail if the table already contains rows.",
                    column.name,
                    table_name,
                )
            # Quote identifiers through the dialect preparer to stay safe
            # across dialects and reserved words.
            quoted_table: str = preparer.quote_identifier(table_name)
            quoted_col: str = preparer.quote_identifier(column.name)
            stmt: str = f"ALTER TABLE {quoted_table} ADD COLUMN {quoted_col} {col_type}{default}{nullable}"
            pending_add_columns.append((stmt, column.name, table_name))

    if not pending_add_columns:
        return

    # Phase 2: execute, tracking what succeeded so a mid-run failure can be
    # compensated on auto-committing (non-transactional DDL) backends.
    added_columns: list[tuple[str, str]] = []
    try:
        for stmt, col_name, table_name in pending_add_columns:
            connection.execute(text(stmt))
            added_columns.append((col_name, table_name))
            logger.info("Auto-migrated: added column '%s' to table '%s'", col_name, table_name)
    except Exception:
        logger.error("Migration failed, compensating %d already-added column(s).", len(added_columns))
        # Drop in reverse order of addition; best-effort — a failed drop is
        # logged for manual cleanup rather than masking the original error.
        for col_name, tbl_name in reversed(added_columns):
            drop_stmt = (f"ALTER TABLE {preparer.quote_identifier(tbl_name)} "
                         f"DROP COLUMN {preparer.quote_identifier(col_name)}")
            try:
                connection.execute(text(drop_stmt))
                logger.info("Compensated: dropped column '%s' from table '%s'", col_name, tbl_name)
            except Exception:
                logger.error(
                    "Failed to compensate column '%s' on table '%s'; manual cleanup required.",
                    col_name,
                    tbl_name,
                )
        raise

async def create_sql_engine(self):
"""Create the database engine."""
if self._db_engine:
Expand All @@ -137,16 +225,19 @@ async def _async_inspect():
self.inspector = await _async_inspect()
async with db_engine.begin() as conn:
await conn.run_sync(self.__metadata.create_all)
await conn.run_sync(self._migrate_missing_columns)
self._database_session_factory = async_sessionmaker(bind=db_engine)
else:
db_engine: SqlEngine = create_engine(self.__db_url, **self.__kwargs)
self.inspector = inspect(db_engine)
self.__metadata.create_all(db_engine)
with db_engine.begin() as conn:
self._migrate_missing_columns(conn)
self._database_session_factory = sessionmaker(bind=db_engine)

if db_engine.dialect.name == "sqlite":
# Set sqlite pragma to enable foreign keys constraints
event.listen(db_engine, "connect", _set_sqlite_pragma)
listen_target = db_engine.sync_engine if isinstance(db_engine, AsyncEngine) else db_engine
event.listen(listen_target, "connect", _set_sqlite_pragma)

except Exception as ex: # pylint: disable=broad-except
if isinstance(ex, ArgumentError):
Expand Down
15 changes: 15 additions & 0 deletions trpc_agent_sdk/storage/_sql_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@

from trpc_agent_sdk.types import Content
from trpc_agent_sdk.types import GroundingMetadata
from trpc_agent_sdk.types import GenerateContentResponseUsageMetadata


def decode_content(content: Optional[dict[str, Any]]) -> Optional[Content]:
Expand All @@ -58,6 +59,20 @@ def decode_content(content: Optional[dict[str, Any]]) -> Optional[Content]:
return Content.model_validate(content)


def decode_usage_metadata(usage_metadata: Optional[dict[str, Any]]) -> Optional[GenerateContentResponseUsageMetadata]:
    """Rehydrate usage metadata from its stored JSON representation.

    Args:
        usage_metadata: JSON dictionary holding serialized usage metadata,
            or None / empty when nothing was stored.

    Returns:
        A validated GenerateContentResponseUsageMetadata instance, or None
        when the input is None or empty.
    """
    if usage_metadata:
        return GenerateContentResponseUsageMetadata.model_validate(usage_metadata)
    return None


def decode_grounding_metadata(grounding_metadata: Optional[dict[str, Any]]) -> Optional[GroundingMetadata]:
"""Decode a grounding metadata object from a JSON dictionary.

Expand Down
Loading