From 6c1a7e901aa9423643fce92e9d7b66125635f0c9 Mon Sep 17 00:00:00 2001 From: Eugene Shershen Date: Tue, 2 Dec 2025 13:12:08 +0200 Subject: [PATCH 1/4] add result_processor to handle UUID conversion; enhance tests for UUID return types --- psqlpy_sqlalchemy/dialect.py | 27 +++++++++++++++++++ tests/test_uuid_support.py | 51 ++++++++++++++++++++++++++++++++++++ 2 files changed, 78 insertions(+) diff --git a/psqlpy_sqlalchemy/dialect.py b/psqlpy_sqlalchemy/dialect.py index 04d2778..820c7a2 100644 --- a/psqlpy_sqlalchemy/dialect.py +++ b/psqlpy_sqlalchemy/dialect.py @@ -256,6 +256,33 @@ def process(value: t.Any) -> t.Optional[bytes]: return process + def result_processor( + self, dialect: t.Any, coltype: t.Any + ) -> t.Optional[t.Callable[[t.Any], t.Any]]: + """Process UUID results from psqlpy. + + Converts string UUID values returned by psqlpy to Python uuid.UUID objects + when as_uuid=True (which is the default in SQLAlchemy 2.0+). + """ + if self.as_uuid: + + def process(value: t.Any) -> t.Optional[uuid.UUID]: + if value is None: + return None + if isinstance(value, uuid.UUID): + return value + if isinstance(value, str): + # psqlpy returns UUID as string, convert to uuid.UUID + return uuid.UUID(value) + if isinstance(value, bytes): + # Handle bytes representation + return uuid.UUID(bytes=value) + # For other types, try to convert + return uuid.UUID(str(value)) + + return process + return None + class PSQLPyAsyncDialect(PGDialect): driver = "psqlpy" diff --git a/tests/test_uuid_support.py b/tests/test_uuid_support.py index 2195965..dfe0db0 100644 --- a/tests/test_uuid_support.py +++ b/tests/test_uuid_support.py @@ -306,6 +306,57 @@ async def test_uuid_edge_cases(self, engine): count = result.fetchone().count assert count == len(test_cases) + async def test_uuid_select_returns_uuid_objects(self, session): + """Test that SELECT returns uuid.UUID objects, not strings. + + This is a regression test for the issue where UUID values were + returned as strings instead of uuid.UUID objects. + """ + from sqlalchemy import select + + test_uuid = uuid.uuid4() + + # Insert test data + obj = UUIDTable(uid=test_uuid, name="test_select_type") + session.add(obj) + await session.commit() + + # Select UUID column using ORM + stmt = select(UUIDTable.uid) + result = (await session.scalars(stmt)).all() + + # Verify result type + assert len(result) == 1 + retrieved_uuid = result[0] + + # Critical assertion: UUID should be returned as uuid.UUID object, not string + assert isinstance(retrieved_uuid, uuid.UUID), ( + f"Expected uuid.UUID but got {type(retrieved_uuid).__name__}" + ) + assert retrieved_uuid == test_uuid + + async def test_uuid_select_full_row(self, session): + """Test that SELECT * returns uuid.UUID objects in full row results.""" + from sqlalchemy import select + + test_uuid = uuid.uuid4() + + # Insert test data + obj = UUIDTable(uid=test_uuid, name="test_full_row") + session.add(obj) + await session.commit() + + # Select full row + stmt = select(UUIDTable) + result = (await session.scalars(stmt)).all() + + assert len(result) == 1 + row = result[0] + + # Verify UUID field is uuid.UUID object + assert isinstance(row.uid, uuid.UUID) + assert row.uid == test_uuid + class TestUUIDTypeCompatibility: """Test UUID type compatibility with existing functionality.""" From 0fffdc006abd11515fae7bf253d6c8b42ff3a66a Mon Sep 17 00:00:00 2001 From: Eugene Shershen Date: Tue, 2 Dec 2025 13:23:25 +0200 Subject: [PATCH 2/4] update CI configuration to include Python 3.14; adjust Python version requirements in pyproject.toml; drop 3.8 --- .github/workflows/ci.yml | 2 +- pyproject.toml | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c468ae1..9181f42 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -27,7 +27,7 @@ jobs: - 5432:5432 strategy: matrix: - python: ["3.8", "3.10", "3.11", "3.13.3"] + python: ["3.14", "3.10", "3.11", "3.13.3"] steps: - name: Checkout repository uses: actions/checkout@v4 diff --git a/pyproject.toml b/pyproject.toml index fab193b..869020d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,6 @@ classifiers = [ "Intended Audience :: Developers", "License :: OSI Approved :: MIT License", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", @@ -26,7 +25,7 @@ classifiers = [ "Topic :: Software Development :: Libraries :: Python Modules", ] keywords = ["postgresql", "psqlpy", "sqlalchemy", "database", "async"] -requires-python = ">=3.8" +requires-python = ">=3.10" dependencies = [ "sqlalchemy>=2.0.0", "psqlpy>=0.11.0", @@ -65,7 +64,7 @@ include = ["psqlpy_sqlalchemy*"] [tool.ruff] line-length = 79 -target-version = "py38" +target-version = "py310" [tool.ruff.lint] select = [ From aa2178bfca0323166b6912b6facea0d9a8a8199c Mon Sep 17 00:00:00 2001 From: Eugene Shershen Date: Tue, 2 Dec 2025 13:27:32 +0200 Subject: [PATCH 3/4] bump version to 0.1.1b1; update type hints to use new syntax; enhance type annotations across multiple files --- performance_comparison.py | 14 +++--- psqlpy_sqlalchemy/__init__.py | 2 +- psqlpy_sqlalchemy/connection.py | 58 +++++++++++-------------- psqlpy_sqlalchemy/dbapi.py | 2 +- psqlpy_sqlalchemy/dialect.py | 19 ++++---- pyproject.toml | 2 +- tests/test_custom_fastapi_middleware.py | 15 +++---- tests/test_dbapi.py | 28 +++++++----- 8 files changed, 67 insertions(+), 73 deletions(-) diff --git a/performance_comparison.py b/performance_comparison.py index 91fca87..54b33a3 100644 --- a/performance_comparison.py +++ b/performance_comparison.py @@ -7,9 +7,7 @@ import asyncio import time -import typing as t from statistics import mean, median, stdev -from typing import Dict, List from sqlalchemy import ( Integer, @@ -36,7 +34,7 @@ class TestModel(Base): id: Mapped[int] = mapped_column(Integer, primary_key=True) name: Mapped[str] = mapped_column(String(100), nullable=False) - description: Mapped[t.Optional[str]] = mapped_column(Text) + description: Mapped[str | None] = mapped_column(Text) value: Mapped[int] = mapped_column(Integer, default=0) @@ -45,13 +43,13 @@ class BenchmarkResult: def __init__(self, name: str): self.name = name - self.times: List[float] = [] + self.times: list[float] = [] def add_time(self, duration: float) -> None: """Add a timing measurement.""" self.times.append(duration) - def get_stats(self) -> Dict[str, float]: + def get_stats(self) -> dict[str, float]: """Calculate statistics for the benchmark.""" if not self.times: return {"mean": 0, "median": 0, "stdev": 0, "min": 0, "max": 0} @@ -298,7 +296,7 @@ async def benchmark_transaction( async def run_benchmarks( url: str, dialect_name: str -) -> Dict[str, BenchmarkResult]: +) -> dict[str, BenchmarkResult]: """Run all benchmarks for a specific dialect.""" print(f"\n{'=' * 60}") print(f"Running benchmarks for {dialect_name}") @@ -338,8 +336,8 @@ async def run_benchmarks( def print_comparison( - psqlpy_results: Dict[str, BenchmarkResult], - asyncpg_results: Dict[str, BenchmarkResult], + psqlpy_results: dict[str, BenchmarkResult], + asyncpg_results: dict[str, BenchmarkResult], ) -> None: """Print comparison of results.""" print(f"\n{'=' * 60}") diff --git a/psqlpy_sqlalchemy/__init__.py b/psqlpy_sqlalchemy/__init__.py index c270121..44518b2 100644 --- a/psqlpy_sqlalchemy/__init__.py +++ b/psqlpy_sqlalchemy/__init__.py @@ -2,5 +2,5 @@ PsqlpyDialect = PSQLPyAsyncDialect -__version__ = "0.1.0a12" +__version__ = "0.1.1b1" __all__ = ["PsqlpyDialect", "PSQLPyAsyncDialect"] diff --git a/psqlpy_sqlalchemy/connection.py b/psqlpy_sqlalchemy/connection.py index a12962c..00b20ac 100644 --- a/psqlpy_sqlalchemy/connection.py +++ b/psqlpy_sqlalchemy/connection.py @@ -4,7 +4,7 @@ import time import typing as t from collections import deque -from typing import Any, Optional, Tuple, Union +from typing import Any import psqlpy from psqlpy import row_factories @@ -55,7 +55,7 @@ class AsyncAdapt_psqlpy_cursor(AsyncAdapt_dbapi_cursor): _adapt_connection: "AsyncAdapt_psqlpy_connection" _connection: psqlpy.Connection # type: ignore[assignment] - _cursor: t.Optional[t.Any] # type: ignore[assignment] + _cursor: t.Any | None # type: ignore[assignment] _awaitable_cursor_close: bool = False def __init__( @@ -66,7 +66,7 @@ def __init__( self.await_ = adapt_connection.await_ self._rows: deque[t.Any] = deque() self._cursor = None - self._description: t.Optional[t.List[t.Tuple[t.Any, ...]]] = None + self._description: list[tuple[t.Any, ...]] | None = None self._arraysize = 1 self._rowcount = -1 self._invalidate_schema_cache_asof = 0 @@ -74,9 +74,7 @@ def __init__( async def _prepare_execute( self, querystring: str, - parameters: t.Union[ - t.Sequence[t.Any], t.Mapping[str, Any], None - ] = None, + parameters: t.Sequence[t.Any] | t.Mapping[str, Any] | None = None, ) -> None: """Execute a prepared statement. @@ -175,10 +173,8 @@ async def _prepare_execute( def _process_parameters( self, - parameters: t.Union[ - t.Sequence[t.Any], t.Mapping[str, Any], None - ] = None, - ) -> t.Union[t.Sequence[t.Any], t.Mapping[str, Any], None]: + parameters: t.Sequence[t.Any] | t.Mapping[str, Any] | None = None, + ) -> t.Sequence[t.Any] | t.Mapping[str, Any] | None: """Process parameters for type conversion. Converts UUID objects to bytes format required by psqlpy. @@ -209,7 +205,7 @@ def process_value(value: Any) -> Any: return { key: process_value(value) for key, value in parameters.items() } - if isinstance(parameters, (list, tuple)): + if isinstance(parameters, list | tuple): return type(parameters)( process_value(value) for value in parameters ) @@ -218,10 +214,8 @@ def process_value(value: Any) -> Any: def _convert_named_params_with_casting( self, querystring: str, - parameters: t.Union[ - t.Sequence[t.Any], t.Mapping[str, Any], None - ] = None, - ) -> t.Tuple[str, t.Union[t.Sequence[t.Any], t.Mapping[str, Any], None]]: + parameters: t.Sequence[t.Any] | t.Mapping[str, Any] | None = None, + ) -> tuple[str, t.Sequence[t.Any] | t.Mapping[str, Any] | None]: """Convert named parameters with PostgreSQL casting syntax to positional parameters. Transforms queries like: @@ -326,7 +320,7 @@ def _convert_named_params_with_casting( return converted_query, converted_params @property - def description(self) -> "Optional[_DBAPICursorDescription]": + def description(self) -> "_DBAPICursorDescription | None": return self._description @property @@ -387,7 +381,7 @@ async def _executemany( # Process all parameters first if seq_of_parameters and all( - isinstance(p, (list, tuple)) for p in seq_of_parameters + isinstance(p, list | tuple) for p in seq_of_parameters ): converted_seq = [list(p) for p in seq_of_parameters] else: @@ -398,7 +392,7 @@ async def _executemany( converted_seq.append([]) elif isinstance(processed, dict): converted_seq.append(list(processed.values())) - elif isinstance(processed, (list, tuple)): + elif isinstance(processed, list | tuple): converted_seq.append(list(processed)) else: converted_seq.append([processed]) @@ -458,7 +452,7 @@ async def _executemany( if adapt_connection._transaction is not None: try: # Build queries list for pipeline: [(query, params), ...] - queries: t.List[t.Tuple[str, t.Optional[t.List[t.Any]]]] = [ + queries: list[tuple[str, list[t.Any] | None]] = [ (operation, params) for params in converted_seq ] await adapt_connection._transaction.pipeline( @@ -479,9 +473,7 @@ async def _executemany( def execute( self, operation: t.Any, - parameters: t.Union[ - t.Sequence[t.Any], t.Mapping[str, Any], None - ] = None, + parameters: t.Sequence[t.Any] | t.Mapping[str, Any] | None = None, ) -> None: self.await_(self._prepare_execute(operation, parameters)) @@ -500,7 +492,7 @@ class AsyncAdapt_psqlpy_ss_cursor( ): """Server-side cursor implementation for psqlpy.""" - _cursor: t.Optional[psqlpy.Cursor] # type: ignore[assignment] + _cursor: psqlpy.Cursor | None # type: ignore[assignment] def __init__( self, adapt_connection: "AsyncAdapt_psqlpy_connection" @@ -514,7 +506,7 @@ def __init__( def _convert_result( self, result: psqlpy.QueryResult, - ) -> Tuple[Tuple[Any, ...], ...]: + ) -> tuple[tuple[Any, ...], ...]: """Convert psqlpy QueryResult to tuple of tuples.""" if result is None: return () @@ -540,7 +532,7 @@ def close(self) -> None: self._cursor = None self._closed = True - def fetchone(self) -> Optional[Tuple[Any, ...]]: + def fetchone(self) -> tuple[Any, ...] | None: """Fetch the next row from the cursor.""" if self._closed or self._cursor is None: return None @@ -552,7 +544,7 @@ def fetchone(self) -> Optional[Tuple[Any, ...]]: except Exception: return None - def fetchmany(self, size: Optional[int] = None) -> t.List[Tuple[Any, ...]]: + def fetchmany(self, size: int | None = None) -> list[tuple[Any, ...]]: """Fetch the next set of rows from the cursor.""" if self._closed or self._cursor is None: return [] @@ -565,7 +557,7 @@ def fetchmany(self, size: Optional[int] = None) -> t.List[Tuple[Any, ...]]: except Exception: return [] - def fetchall(self) -> t.List[Tuple[Any, ...]]: + def fetchall(self) -> list[tuple[Any, ...]]: """Fetch all remaining rows from the cursor.""" if self._closed or self._cursor is None: return [] @@ -576,7 +568,7 @@ def fetchall(self) -> t.List[Tuple[Any, ...]]: except Exception: return [] - def __iter__(self) -> t.Iterator[Tuple[Any, ...]]: + def __iter__(self) -> t.Iterator[tuple[Any, ...]]: if self._closed or self._cursor is None: return @@ -596,7 +588,7 @@ class AsyncAdapt_psqlpy_connection(AsyncAdapt_dbapi_connection): _ss_cursor_cls = AsyncAdapt_psqlpy_ss_cursor # type: ignore[assignment] _connection: psqlpy.Connection # type: ignore[assignment] - _transaction: t.Optional[psqlpy.Transaction] + _transaction: psqlpy.Transaction | None __slots__ = ( "_invalidate_schema_cache_asof", @@ -637,7 +629,7 @@ def __init__( # LRU cache for prepared statements. Defaults to 100 statements per # connection. The cache is on a per-connection basis, stored within # connections pooled by the connection pool. - self._prepared_statement_cache: t.Optional[util.LRUCache[t.Any, t.Any]] + self._prepared_statement_cache: util.LRUCache[t.Any, t.Any] | None if prepared_statement_cache_size > 0: self._prepared_statement_cache = util.LRUCache( prepared_statement_cache_size @@ -649,7 +641,7 @@ def __init__( self._prepared_statement_name_func = self._default_name_func # Legacy query cache (kept for compatibility) - self._query_cache: t.Dict[str, t.Any] = {} + self._query_cache: dict[str, t.Any] = {} self._cache_max_size = prepared_statement_cache_size async def _check_type_cache_invalidation( @@ -739,7 +731,7 @@ def ping(self, reconnect: t.Any = None) -> t.Any: self._connection_valid = False return False - def _get_cached_query(self, query_key: str) -> t.Optional[t.Any]: + def _get_cached_query(self, query_key: str) -> t.Any | None: """Get a cached prepared statement if available.""" return self._query_cache.get(query_key) @@ -761,7 +753,7 @@ def close(self) -> None: def cursor( self, server_side: bool = False - ) -> Union[AsyncAdapt_psqlpy_cursor, AsyncAdapt_psqlpy_ss_cursor]: + ) -> AsyncAdapt_psqlpy_cursor | AsyncAdapt_psqlpy_ss_cursor: if server_side: return self._ss_cursor_cls(self) return self._cursor_cls(self) diff --git a/psqlpy_sqlalchemy/dbapi.py b/psqlpy_sqlalchemy/dbapi.py index 46e9c63..15fb807 100644 --- a/psqlpy_sqlalchemy/dbapi.py +++ b/psqlpy_sqlalchemy/dbapi.py @@ -155,7 +155,7 @@ def TimestampFromTicks(self, ticks: float) -> t.Any: return datetime.datetime.fromtimestamp(ticks) - def Binary(self, string: t.Union[str, bytes]) -> bytes: + def Binary(self, string: str | bytes) -> bytes: """Construct a binary value""" if isinstance(string, str): return string.encode("utf-8") diff --git a/psqlpy_sqlalchemy/dialect.py b/psqlpy_sqlalchemy/dialect.py index 820c7a2..a7efad3 100644 --- a/psqlpy_sqlalchemy/dialect.py +++ b/psqlpy_sqlalchemy/dialect.py @@ -1,7 +1,8 @@ import typing as t import uuid +from collections.abc import MutableMapping, Sequence from types import ModuleType -from typing import Any, Dict, MutableMapping, Sequence, Tuple +from typing import Any import psqlpy from sqlalchemy import URL, util @@ -28,8 +29,8 @@ class CompatibleNullPool(NullPool): def __init__( self, creator: t.Any, - pool_size: t.Optional[int] = None, - max_overflow: t.Optional[int] = None, + pool_size: int | None = None, + max_overflow: int | None = None, **kw: t.Any, ) -> None: # Filter out pool sizing arguments that NullPool doesn't accept @@ -231,10 +232,10 @@ class _PGUUID(UUID[t.Any]): def bind_processor( self, dialect: t.Any - ) -> t.Optional[t.Callable[[t.Any], t.Any]]: + ) -> t.Callable[[t.Any], t.Any] | None: """Process UUID parameters for psqlpy compatibility.""" - def process(value: t.Any) -> t.Optional[bytes]: + def process(value: t.Any) -> bytes | None: if value is None: return None if isinstance(value, uuid.UUID): @@ -258,7 +259,7 @@ def process(value: t.Any) -> t.Optional[bytes]: def result_processor( self, dialect: t.Any, coltype: t.Any - ) -> t.Optional[t.Callable[[t.Any], t.Any]]: + ) -> t.Callable[[t.Any], t.Any] | None: """Process UUID results from psqlpy. Converts string UUID values returned by psqlpy to Python uuid.UUID objects @@ -266,7 +267,7 @@ def result_processor( """ if self.as_uuid: - def process(value: t.Any) -> t.Optional[uuid.UUID]: + def process(value: t.Any) -> uuid.UUID | None: if value is None: return None if isinstance(value, uuid.UUID): @@ -341,7 +342,7 @@ def import_dbapi(cls) -> ModuleType: return t.cast(ModuleType, PSQLPyAdaptDBAPI(__import__("psqlpy"))) @util.memoized_property - def _isolation_lookup(self) -> Dict[str, Any]: + def _isolation_lookup(self) -> dict[str, Any]: """Mapping of SQLAlchemy isolation levels to psqlpy isolation levels""" return { "READ_COMMITTED": psqlpy.IsolationLevel.ReadCommitted, @@ -352,7 +353,7 @@ def _isolation_lookup(self) -> Dict[str, Any]: def create_connect_args( self, url: URL, - ) -> Tuple[Sequence[str], MutableMapping[str, Any]]: + ) -> tuple[Sequence[str], MutableMapping[str, Any]]: opts = url.translate_connect_args() return ( [], diff --git a/pyproject.toml b/pyproject.toml index 869020d..c76ef73 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "psqlpy-sqlalchemy" -version = "0.1.0a12" +version = "0.1.1b1" description = "SQLAlchemy dialect for psqlpy PostgreSQL driver" readme = "README.md" license = {text = "MIT"} diff --git a/tests/test_custom_fastapi_middleware.py b/tests/test_custom_fastapi_middleware.py index 974fbff..730f9eb 100644 --- a/tests/test_custom_fastapi_middleware.py +++ b/tests/test_custom_fastapi_middleware.py @@ -7,7 +7,6 @@ import asyncio import unittest from contextvars import ContextVar -from typing import Dict, Optional, Union try: from fastapi import FastAPI @@ -82,8 +81,8 @@ class SessionNotInitialisedError(Exception): def create_middleware_and_session_proxy(): """Create the custom middleware and session proxy as provided in the issue.""" - _Session: Optional[async_sessionmaker] = None - _session: ContextVar[Optional[AsyncSession]] = ContextVar( + _Session: async_sessionmaker | None = None + _session: ContextVar[AsyncSession | None] = ContextVar( "_session", default=None ) _multi_sessions_ctx: ContextVar[bool] = ContextVar( @@ -97,10 +96,10 @@ class SQLAlchemyMiddleware(BaseHTTPMiddleware): def __init__( self, app: ASGIApp, - db_url: Optional[Union[str, URL]] = None, - custom_engine: Optional[Engine] = None, - engine_args: Dict = None, - session_args: Dict = None, + db_url: str | URL | None = None, + custom_engine: Engine | None = None, + engine_args: dict = None, + session_args: dict = None, commit_on_exit: bool = False, ): super().__init__(app) @@ -168,7 +167,7 @@ async def cleanup(): class DBSession(metaclass=DBSessionMeta): def __init__( self, - session_args: Dict = None, + session_args: dict = None, commit_on_exit: bool = False, multi_sessions: bool = False, ): diff --git a/tests/test_dbapi.py b/tests/test_dbapi.py index 3841fe4..cc3e956 100644 --- a/tests/test_dbapi.py +++ b/tests/test_dbapi.py @@ -150,9 +150,10 @@ def setUp(self): def test_server_settings_handling(self): """Test server_settings parameter handling""" - with patch( - "psqlpy_sqlalchemy.dbapi.AsyncAdapt_psqlpy_connection" - ), patch("psqlpy_sqlalchemy.dbapi.await_only") as mock_await_only: + with ( + patch("psqlpy_sqlalchemy.dbapi.AsyncAdapt_psqlpy_connection"), + patch("psqlpy_sqlalchemy.dbapi.await_only") as mock_await_only, + ): mock_connection = Mock() mock_await_only.return_value = mock_connection @@ -188,9 +189,10 @@ def test_server_settings_handling(self): def test_server_settings_without_application_name(self): """Test server_settings parameter handling without application_name""" - with patch( - "psqlpy_sqlalchemy.dbapi.AsyncAdapt_psqlpy_connection" - ), patch("psqlpy_sqlalchemy.dbapi.await_only") as mock_await_only: + with ( + patch("psqlpy_sqlalchemy.dbapi.AsyncAdapt_psqlpy_connection"), + patch("psqlpy_sqlalchemy.dbapi.await_only") as mock_await_only, + ): mock_connection = Mock() mock_await_only.return_value = mock_connection @@ -210,9 +212,10 @@ def test_server_settings_without_application_name(self): def test_connect_without_server_settings(self): """Test connect method without server_settings""" - with patch( - "psqlpy_sqlalchemy.dbapi.AsyncAdapt_psqlpy_connection" - ), patch("psqlpy_sqlalchemy.dbapi.await_only") as mock_await_only: + with ( + patch("psqlpy_sqlalchemy.dbapi.AsyncAdapt_psqlpy_connection"), + patch("psqlpy_sqlalchemy.dbapi.await_only") as mock_await_only, + ): mock_connection = Mock() mock_await_only.return_value = mock_connection @@ -226,9 +229,10 @@ def test_connect_without_server_settings(self): def test_parameter_filtering(self): """Test that unsupported parameters are filtered out""" - with patch( - "psqlpy_sqlalchemy.dbapi.AsyncAdapt_psqlpy_connection" - ), patch("psqlpy_sqlalchemy.dbapi.await_only") as mock_await_only: + with ( + patch("psqlpy_sqlalchemy.dbapi.AsyncAdapt_psqlpy_connection"), + patch("psqlpy_sqlalchemy.dbapi.await_only") as mock_await_only, + ): mock_connection = Mock() mock_await_only.return_value = mock_connection From e8f197294d2c7ac25a73625eec18bade44b09656 Mon Sep 17 00:00:00 2001 From: Eugene Shershen Date: Tue, 2 Dec 2025 13:45:06 +0200 Subject: [PATCH 4/4] refactor type hints to use new syntax; update Optional to use union types for better clarity --- tests/test_sqlmodel_compatibility.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/tests/test_sqlmodel_compatibility.py b/tests/test_sqlmodel_compatibility.py index 0f0312e..7573fcd 100644 --- a/tests/test_sqlmodel_compatibility.py +++ b/tests/test_sqlmodel_compatibility.py @@ -4,7 +4,6 @@ """ import unittest -from typing import Optional from sqlalchemy import create_engine from sqlmodel import Field, Session, SQLModel, select @@ -13,10 +12,10 @@ class Hero(SQLModel, table=True): """Test model for SQLModel compatibility tests""" - id: Optional[int] = Field(default=None, primary_key=True) + id: int | None = Field(default=None, primary_key=True) name: str secret_name: str - age: Optional[int] = None + age: int | None = None class TestSQLModelCompatibility(unittest.TestCase): @@ -120,16 +119,16 @@ def test_sqlmodel_relationships(self): """Test SQLModel relationship handling""" class Team(SQLModel, table=True): - id: Optional[int] = Field(default=None, primary_key=True) + id: int | None = Field(default=None, primary_key=True) name: str headquarters: str class HeroWithTeam(SQLModel, table=True): - id: Optional[int] = Field(default=None, primary_key=True) + id: int | None = Field(default=None, primary_key=True) name: str secret_name: str - age: Optional[int] = None - team_id: Optional[int] = Field(default=None, foreign_key="team.id") + age: int | None = None + team_id: int | None = Field(default=None, foreign_key="team.id") # Create tables for these models Team.metadata.create_all(self.engine)