From 2b7b5f11be2262d18173605d5e012fab97b7a03c Mon Sep 17 00:00:00 2001 From: jasinluo <1127097451@qq.com> Date: Wed, 15 Apr 2026 11:57:12 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E8=A1=A5=E5=85=85=E6=95=B0=E6=8D=AE?= =?UTF-8?q?=E5=BA=93=E5=AD=98=E5=82=A8=20usage=5Fmetadata=20=E5=AD=97?= =?UTF-8?q?=E6=AE=B5=E5=B9=B6=E6=96=B0=E5=A2=9E=E8=87=AA=E5=8A=A8=E8=BF=81?= =?UTF-8?q?=E7=A7=BB=E7=BC=BA=E5=A4=B1=E5=88=97=E8=83=BD=E5=8A=9B=20=20=20?= =?UTF-8?q?=20=20-=20SessionStorageEvent=20=E6=96=B0=E5=A2=9E=20usage=5Fme?= =?UTF-8?q?tadata=20=E5=88=97=EF=BC=8C=E6=94=AF=E6=8C=81=E5=AD=98=E5=82=A8?= =?UTF-8?q?=E5=92=8C=E8=AF=BB=E5=8F=96=20token=20=E7=94=A8=E9=87=8F?= =?UTF-8?q?=E7=BB=9F=E8=AE=A1=20=20=20=20=20-=20=E6=96=B0=E5=A2=9E=20decod?= =?UTF-8?q?e=5Fusage=5Fmetadata=20=E5=B7=A5=E5=85=B7=E5=87=BD=E6=95=B0?= =?UTF-8?q?=E7=94=A8=E4=BA=8E=E5=8F=8D=E5=BA=8F=E5=88=97=E5=8C=96=20=20=20?= =?UTF-8?q?=20=20-=20=E6=96=B0=E5=A2=9E=20=5Fmigrate=5Fmissing=5Fcolumns?= =?UTF-8?q?=20=E6=96=B9=E6=B3=95=EF=BC=8C=E8=87=AA=E5=8A=A8=E6=A3=80?= =?UTF-8?q?=E6=B5=8B=E5=B9=B6=20ALTER=20TABLE=20=E6=B7=BB=E5=8A=A0=20ORM?= =?UTF-8?q?=20=E6=A8=A1=E5=9E=8B=E4=B8=AD=E6=96=B0=E5=A2=9E=E4=BD=86?= =?UTF-8?q?=E6=95=B0=E6=8D=AE=E5=BA=93=E7=BC=BA=E5=A4=B1=E7=9A=84=E5=88=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/sessions/test_sql_session_service.py | 3 +- tests/storage/test_sql.py | 24 ++--- .../sessions/_sql_session_service.py | 5 + trpc_agent_sdk/storage/__init__.py | 2 + trpc_agent_sdk/storage/_sql.py | 95 ++++++++++++++++++- trpc_agent_sdk/storage/_sql_common.py | 15 +++ 6 files changed, 122 insertions(+), 22 deletions(-) diff --git a/tests/sessions/test_sql_session_service.py b/tests/sessions/test_sql_session_service.py index b6c4fa7..9925f60 100644 --- a/tests/sessions/test_sql_session_service.py +++ b/tests/sessions/test_sql_session_service.py @@ -66,8 +66,7 @@ def _make_event_with_function_call(): async def _create_service(config=None): config = config or _make_config() svc = SqlSessionService(db_url="sqlite:///:memory:", session_config=config, is_async=False) - with patch('trpc_agent_sdk.storage._sql.event.listen'): - await svc._sql_storage.create_sql_engine() + await svc._sql_storage.create_sql_engine() return svc diff --git a/tests/storage/test_sql.py b/tests/storage/test_sql.py index 4c90bc3..b1a3ce2 100644 --- a/tests/storage/test_sql.py +++ b/tests/storage/test_sql.py @@ -51,9 +51,7 @@ def async_db_url(self): async def sync_storage(self, db_url): """Synchronous SQL storage fixture with initialized engine.""" storage = SqlStorage(is_async=False, db_url=db_url, metadata=StorageData.metadata) - # Patch event.listen to work around SQLite pragma issue - with patch('trpc_agent_sdk.storage._sql.event.listen'): - await storage.create_sql_engine() + await storage.create_sql_engine() yield storage await storage.close() @@ -61,9 +59,7 @@ async def sync_storage(self, db_url): async def async_storage(self, async_db_url): """Asynchronous SQL storage fixture with initialized engine.""" storage = SqlStorage(is_async=True, db_url=async_db_url, metadata=StorageData.metadata) - # Patch event.listen to work around async engine event limitation - with patch('trpc_agent_sdk.storage._sql.event.listen'): - await storage.create_sql_engine() + await storage.create_sql_engine() yield storage await storage.close() @@ -93,9 +89,7 @@ async def test_create_sql_engine_async(self, async_db_url): """Test creating async SQL engine.""" storage = SqlStorage(is_async=True, db_url=async_db_url, metadata=StorageData.metadata) - # Patch event.listen to work around async engine event limitation - with patch('trpc_agent_sdk.storage._sql.event.listen'): - await storage.create_sql_engine() + await storage.create_sql_engine() assert storage._db_engine is not None assert storage._database_session_factory is not None @@ -108,9 +102,7 @@ async def test_create_sql_engine_sync(self, db_url): """Test creating sync SQL engine.""" storage = SqlStorage(is_async=False, db_url=db_url, metadata=StorageData.metadata) - # Patch event.listen to work around SQLite pragma issue in tests - with patch('trpc_agent_sdk.storage._sql.event.listen'): - await storage.create_sql_engine() + await storage.create_sql_engine() assert storage._db_engine is not None assert storage._database_session_factory is not None @@ -473,9 +465,7 @@ async def test_close_async(self, async_db_url): """Test closing async SQL engine.""" storage = SqlStorage(is_async=True, db_url=async_db_url, metadata=StorageData.metadata) - # Patch event.listen to work around async engine event limitation - with patch('trpc_agent_sdk.storage._sql.event.listen'): - await storage.create_sql_engine() + await storage.create_sql_engine() assert storage._db_engine is not None @@ -487,9 +477,7 @@ async def test_close_sync(self, db_url): """Test closing sync SQL engine.""" storage = SqlStorage(is_async=False, db_url=db_url, metadata=StorageData.metadata) - # Patch event.listen to work around SQLite pragma issue in tests - with patch('trpc_agent_sdk.storage._sql.event.listen'): - await storage.create_sql_engine() + await storage.create_sql_engine() assert storage._db_engine is not None diff --git a/trpc_agent_sdk/sessions/_sql_session_service.py b/trpc_agent_sdk/sessions/_sql_session_service.py index 5a59b48..779604a 100644 --- a/trpc_agent_sdk/sessions/_sql_session_service.py +++ b/trpc_agent_sdk/sessions/_sql_session_service.py @@ -64,6 +64,7 @@ from trpc_agent_sdk.storage import UTF8MB4String from trpc_agent_sdk.storage import decode_content from trpc_agent_sdk.storage import decode_grounding_metadata +from trpc_agent_sdk.storage import decode_usage_metadata from trpc_agent_sdk.utils import user_key from ._base_session_service import BaseSessionService @@ -165,6 +166,7 @@ class SessionStorageEvent(SessionStorageBase): content: Mapped[dict[str, Any]] = mapped_column(DynamicJSON, nullable=True) grounding_metadata: Mapped[dict[str, Any]] = mapped_column(DynamicJSON, nullable=True) + usage_metadata: Mapped[dict[str, Any]] = mapped_column(DynamicJSON, nullable=True) custom_metadata: Mapped[dict[str, Any]] = mapped_column(DynamicJSON, nullable=True) partial: Mapped[bool] = mapped_column(Boolean, nullable=True) @@ -218,6 +220,8 @@ def from_event(cls, session: Session, event: Event) -> SessionStorageEvent: storage_event.content = event.content.model_dump(exclude_none=True, mode="json") if event.grounding_metadata: storage_event.grounding_metadata = event.grounding_metadata.model_dump(exclude_none=True, mode="json") + if event.usage_metadata: + storage_event.usage_metadata = event.usage_metadata.model_dump(exclude_none=True, mode="json") if event.custom_metadata: storage_event.custom_metadata = event.custom_metadata return storage_event @@ -238,6 +242,7 @@ def to_event(self) -> Event: error_message=self.error_message, interrupted=self.interrupted, grounding_metadata=decode_grounding_metadata(self.grounding_metadata), + usage_metadata=decode_usage_metadata(self.usage_metadata), custom_metadata=self.custom_metadata, ) diff --git a/trpc_agent_sdk/storage/__init__.py b/trpc_agent_sdk/storage/__init__.py index eaeedcb..56c2432 100644 --- a/trpc_agent_sdk/storage/__init__.py +++ b/trpc_agent_sdk/storage/__init__.py @@ -29,6 +29,7 @@ from ._sql_common import UTF8MB4String from ._sql_common import decode_content from ._sql_common import decode_grounding_metadata +from ._sql_common import decode_usage_metadata __all__ = [ "EXPIRE_METHOD", @@ -55,4 +56,5 @@ "UTF8MB4String", "decode_content", "decode_grounding_metadata", + "decode_usage_metadata", ] diff --git a/trpc_agent_sdk/storage/_sql.py b/trpc_agent_sdk/storage/_sql.py index 810b36d..b836cd1 100644 --- a/trpc_agent_sdk/storage/_sql.py +++ b/trpc_agent_sdk/storage/_sql.py @@ -19,11 +19,16 @@ from sqlalchemy import MetaData from sqlalchemy import and_ from sqlalchemy import delete as sql_delete +from sqlalchemy import Dialect +from sqlalchemy.sql.compiler import IdentifierPreparer from sqlalchemy import event from sqlalchemy import select +from sqlalchemy import text from sqlalchemy.engine import Engine from sqlalchemy.engine import create_engine from sqlalchemy.engine.interfaces import DBAPICursor +from sqlalchemy.engine import Connection +from sqlalchemy.engine.reflection import Inspector from sqlalchemy.exc import ArgumentError from sqlalchemy.ext.asyncio import AsyncEngine from sqlalchemy.ext.asyncio import AsyncSession @@ -122,6 +127,89 @@ def __init__(self, is_async: bool, db_url: str, metadata: Optional[MetaData] = N self.__db_url = db_url self.__kwargs = kwargs + def _migrate_missing_columns(self, connection: Connection) -> None: + """Add columns that exist in the ORM model but are missing from the database, + for forward compatibility across version changes. + + SQLAlchemy's create_all only creates tables — it never ALTERs existing + tables. This helper bridges the gap for lightweight forward-only migrations. + Only handles adding new columns (forward-only). + + All-or-nothing semantics: on databases that support transactional DDL + (e.g. PostgreSQL) the caller's transaction handles rollback. On databases + where DDL auto-commits (e.g. MySQL), a compensating DROP COLUMN is issued + for every column that was already added before the failure. + + Args: + connection: A synchronous SQLAlchemy Connection object. + """ + insp: Inspector = inspect(connection) + dialect: Dialect = connection.dialect + preparer: IdentifierPreparer = dialect.identifier_preparer + ddl_compiler = dialect.ddl_compiler(dialect, None) + + pending_add_columns: list[tuple[str, str, str]] = [] + for table_name, table in self.__metadata.tables.items(): + if not insp.has_table(table_name): + continue + existing: set[str] = {col["name"] for col in insp.get_columns(table_name)} + for column in table.columns: + if column.name in existing: + continue + col_type: str = column.type.compile(dialect=dialect) + # handle different types of default value + nullable: str = "" if column.nullable else " NOT NULL" + default: str = "" + default_value = ddl_compiler.get_column_default_string(column) + if default_value is not None: + default = f" DEFAULT {default_value}" + elif column.server_default is not None: + # if the column has server_default, but it is not a DDL server_default, warning + logger.warning( + "Column '%s' on table '%s' has a non-DDL server_default " + "(%s); skipping DEFAULT clause generation.", + column.name, + table_name, + type(column.server_default).__name__, + ) + elif not column.nullable: + # if the column is NOT NULL and has no server_default, raise error + logger.warning( + "Column '%s' on table '%s' is NOT NULL without a server_default; " + "migration may fail if the table already contains rows.", + column.name, + table_name, + ) + quoted_table: str = preparer.quote_identifier(table_name) + quoted_col: str = preparer.quote_identifier(column.name) + stmt: str = f"ALTER TABLE {quoted_table} ADD COLUMN {quoted_col} {col_type}{default}{nullable}" + pending_add_columns.append((stmt, column.name, table_name)) + + if not pending_add_columns: + return + + added_columns: list[tuple[str, str]] = [] + try: + for stmt, col_name, table_name in pending_add_columns: + connection.execute(text(stmt)) + added_columns.append((col_name, table_name)) + logger.info("Auto-migrated: added column '%s' to table '%s'", col_name, table_name) + except Exception: + logger.error("Migration failed, compensating %d already-added column(s).", len(added_columns)) + for col_name, tbl_name in reversed(added_columns): + drop_stmt = (f"ALTER TABLE {preparer.quote_identifier(tbl_name)} " + f"DROP COLUMN {preparer.quote_identifier(col_name)}") + try: + connection.execute(text(drop_stmt)) + logger.info("Compensated: dropped column '%s' from table '%s'", col_name, tbl_name) + except Exception: + logger.error( + "Failed to compensate column '%s' on table '%s'; manual cleanup required.", + col_name, + tbl_name, + ) + raise + async def create_sql_engine(self): """Create the database engine.""" if self._db_engine: @@ -137,16 +225,19 @@ async def _async_inspect(): self.inspector = await _async_inspect() async with db_engine.begin() as conn: await conn.run_sync(self.__metadata.create_all) + await conn.run_sync(self._migrate_missing_columns) self._database_session_factory = async_sessionmaker(bind=db_engine) else: db_engine: SqlEngine = create_engine(self.__db_url, **self.__kwargs) self.inspector = inspect(db_engine) self.__metadata.create_all(db_engine) + with db_engine.begin() as conn: + self._migrate_missing_columns(conn) self._database_session_factory = sessionmaker(bind=db_engine) if db_engine.dialect.name == "sqlite": - # Set sqlite pragma to enable foreign keys constraints - event.listen(db_engine, "connect", _set_sqlite_pragma) + listen_target = db_engine.sync_engine if isinstance(db_engine, AsyncEngine) else db_engine + event.listen(listen_target, "connect", _set_sqlite_pragma) except Exception as ex: # pylint: disable=broad-except if isinstance(ex, ArgumentError): diff --git a/trpc_agent_sdk/storage/_sql_common.py b/trpc_agent_sdk/storage/_sql_common.py index d9d4c04..6ec0092 100644 --- a/trpc_agent_sdk/storage/_sql_common.py +++ b/trpc_agent_sdk/storage/_sql_common.py @@ -42,6 +42,7 @@ from trpc_agent_sdk.types import Content from trpc_agent_sdk.types import GroundingMetadata +from trpc_agent_sdk.types import GenerateContentResponseUsageMetadata def decode_content(content: Optional[dict[str, Any]]) -> Optional[Content]: @@ -58,6 +59,20 @@ def decode_content(content: Optional[dict[str, Any]]) -> Optional[Content]: return Content.model_validate(content) +def decode_usage_metadata(usage_metadata: Optional[dict[str, Any]]) -> Optional[GenerateContentResponseUsageMetadata]: + """Decode a usage metadata object from a JSON dictionary. + + Args: + usage_metadata: JSON dictionary containing usage metadata + + Returns: + Decoded GenerateContentResponseUsageMetadata object or None if usage_metadata is None + """ + if not usage_metadata: + return None + return GenerateContentResponseUsageMetadata.model_validate(usage_metadata) + + def decode_grounding_metadata(grounding_metadata: Optional[dict[str, Any]]) -> Optional[GroundingMetadata]: """Decode a grounding metadata object from a JSON dictionary.