Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ jobs:
- 5432:5432
strategy:
matrix:
python: ["3.8", "3.10", "3.11", "3.13.3"]
python: ["3.14", "3.10", "3.11", "3.13.3"]
steps:
- name: Checkout repository
uses: actions/checkout@v4
Expand Down
14 changes: 6 additions & 8 deletions performance_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@

import asyncio
import time
import typing as t
from statistics import mean, median, stdev
from typing import Dict, List

from sqlalchemy import (
Integer,
Expand All @@ -36,7 +34,7 @@ class TestModel(Base):

id: Mapped[int] = mapped_column(Integer, primary_key=True)
name: Mapped[str] = mapped_column(String(100), nullable=False)
description: Mapped[t.Optional[str]] = mapped_column(Text)
description: Mapped[str | None] = mapped_column(Text)
value: Mapped[int] = mapped_column(Integer, default=0)


Expand All @@ -45,13 +43,13 @@ class BenchmarkResult:

def __init__(self, name: str):
self.name = name
self.times: List[float] = []
self.times: list[float] = []

def add_time(self, duration: float) -> None:
"""Add a timing measurement."""
self.times.append(duration)

def get_stats(self) -> Dict[str, float]:
def get_stats(self) -> dict[str, float]:
"""Calculate statistics for the benchmark."""
if not self.times:
return {"mean": 0, "median": 0, "stdev": 0, "min": 0, "max": 0}
Expand Down Expand Up @@ -298,7 +296,7 @@ async def benchmark_transaction(

async def run_benchmarks(
url: str, dialect_name: str
) -> Dict[str, BenchmarkResult]:
) -> dict[str, BenchmarkResult]:
"""Run all benchmarks for a specific dialect."""
print(f"\n{'=' * 60}")
print(f"Running benchmarks for {dialect_name}")
Expand Down Expand Up @@ -338,8 +336,8 @@ async def run_benchmarks(


def print_comparison(
psqlpy_results: Dict[str, BenchmarkResult],
asyncpg_results: Dict[str, BenchmarkResult],
psqlpy_results: dict[str, BenchmarkResult],
asyncpg_results: dict[str, BenchmarkResult],
) -> None:
"""Print comparison of results."""
print(f"\n{'=' * 60}")
Expand Down
2 changes: 1 addition & 1 deletion psqlpy_sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@

PsqlpyDialect = PSQLPyAsyncDialect

__version__ = "0.1.0a12"
__version__ = "0.1.1b1"
__all__ = ["PsqlpyDialect", "PSQLPyAsyncDialect"]
58 changes: 25 additions & 33 deletions psqlpy_sqlalchemy/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import time
import typing as t
from collections import deque
from typing import Any, Optional, Tuple, Union
from typing import Any

import psqlpy
from psqlpy import row_factories
Expand Down Expand Up @@ -55,7 +55,7 @@ class AsyncAdapt_psqlpy_cursor(AsyncAdapt_dbapi_cursor):

_adapt_connection: "AsyncAdapt_psqlpy_connection"
_connection: psqlpy.Connection # type: ignore[assignment]
_cursor: t.Optional[t.Any] # type: ignore[assignment]
_cursor: t.Any | None # type: ignore[assignment]
_awaitable_cursor_close: bool = False

def __init__(
Expand All @@ -66,17 +66,15 @@ def __init__(
self.await_ = adapt_connection.await_
self._rows: deque[t.Any] = deque()
self._cursor = None
self._description: t.Optional[t.List[t.Tuple[t.Any, ...]]] = None
self._description: list[tuple[t.Any, ...]] | None = None
self._arraysize = 1
self._rowcount = -1
self._invalidate_schema_cache_asof = 0

async def _prepare_execute(
self,
querystring: str,
parameters: t.Union[
t.Sequence[t.Any], t.Mapping[str, Any], None
] = None,
parameters: t.Sequence[t.Any] | t.Mapping[str, Any] | None = None,
) -> None:
"""Execute a prepared statement.

Expand Down Expand Up @@ -175,10 +173,8 @@ async def _prepare_execute(

def _process_parameters(
self,
parameters: t.Union[
t.Sequence[t.Any], t.Mapping[str, Any], None
] = None,
) -> t.Union[t.Sequence[t.Any], t.Mapping[str, Any], None]:
parameters: t.Sequence[t.Any] | t.Mapping[str, Any] | None = None,
) -> t.Sequence[t.Any] | t.Mapping[str, Any] | None:
"""Process parameters for type conversion.

Converts UUID objects to bytes format required by psqlpy.
Expand Down Expand Up @@ -209,7 +205,7 @@ def process_value(value: Any) -> Any:
return {
key: process_value(value) for key, value in parameters.items()
}
if isinstance(parameters, (list, tuple)):
if isinstance(parameters, list | tuple):
return type(parameters)(
process_value(value) for value in parameters
)
Expand All @@ -218,10 +214,8 @@ def process_value(value: Any) -> Any:
def _convert_named_params_with_casting(
self,
querystring: str,
parameters: t.Union[
t.Sequence[t.Any], t.Mapping[str, Any], None
] = None,
) -> t.Tuple[str, t.Union[t.Sequence[t.Any], t.Mapping[str, Any], None]]:
parameters: t.Sequence[t.Any] | t.Mapping[str, Any] | None = None,
) -> tuple[str, t.Sequence[t.Any] | t.Mapping[str, Any] | None]:
"""Convert named parameters with PostgreSQL casting syntax to positional parameters.

Transforms queries like:
Expand Down Expand Up @@ -326,7 +320,7 @@ def _convert_named_params_with_casting(
return converted_query, converted_params

@property
def description(self) -> "Optional[_DBAPICursorDescription]":
def description(self) -> "_DBAPICursorDescription | None":
return self._description

@property
Expand Down Expand Up @@ -387,7 +381,7 @@ async def _executemany(

# Process all parameters first
if seq_of_parameters and all(
isinstance(p, (list, tuple)) for p in seq_of_parameters
isinstance(p, list | tuple) for p in seq_of_parameters
):
converted_seq = [list(p) for p in seq_of_parameters]
else:
Expand All @@ -398,7 +392,7 @@ async def _executemany(
converted_seq.append([])
elif isinstance(processed, dict):
converted_seq.append(list(processed.values()))
elif isinstance(processed, (list, tuple)):
elif isinstance(processed, list | tuple):
converted_seq.append(list(processed))
else:
converted_seq.append([processed])
Expand Down Expand Up @@ -458,7 +452,7 @@ async def _executemany(
if adapt_connection._transaction is not None:
try:
# Build queries list for pipeline: [(query, params), ...]
queries: t.List[t.Tuple[str, t.Optional[t.List[t.Any]]]] = [
queries: list[tuple[str, list[t.Any] | None]] = [
(operation, params) for params in converted_seq
]
await adapt_connection._transaction.pipeline(
Expand All @@ -479,9 +473,7 @@ async def _executemany(
def execute(
self,
operation: t.Any,
parameters: t.Union[
t.Sequence[t.Any], t.Mapping[str, Any], None
] = None,
parameters: t.Sequence[t.Any] | t.Mapping[str, Any] | None = None,
) -> None:
self.await_(self._prepare_execute(operation, parameters))

Expand All @@ -500,7 +492,7 @@ class AsyncAdapt_psqlpy_ss_cursor(
):
"""Server-side cursor implementation for psqlpy."""

_cursor: t.Optional[psqlpy.Cursor] # type: ignore[assignment]
_cursor: psqlpy.Cursor | None # type: ignore[assignment]

def __init__(
self, adapt_connection: "AsyncAdapt_psqlpy_connection"
Expand All @@ -514,7 +506,7 @@ def __init__(
def _convert_result(
self,
result: psqlpy.QueryResult,
) -> Tuple[Tuple[Any, ...], ...]:
) -> tuple[tuple[Any, ...], ...]:
"""Convert psqlpy QueryResult to tuple of tuples."""
if result is None:
return ()
Expand All @@ -540,7 +532,7 @@ def close(self) -> None:
self._cursor = None
self._closed = True

def fetchone(self) -> Optional[Tuple[Any, ...]]:
def fetchone(self) -> tuple[Any, ...] | None:
"""Fetch the next row from the cursor."""
if self._closed or self._cursor is None:
return None
Expand All @@ -552,7 +544,7 @@ def fetchone(self) -> Optional[Tuple[Any, ...]]:
except Exception:
return None

def fetchmany(self, size: Optional[int] = None) -> t.List[Tuple[Any, ...]]:
def fetchmany(self, size: int | None = None) -> list[tuple[Any, ...]]:
"""Fetch the next set of rows from the cursor."""
if self._closed or self._cursor is None:
return []
Expand All @@ -565,7 +557,7 @@ def fetchmany(self, size: Optional[int] = None) -> t.List[Tuple[Any, ...]]:
except Exception:
return []

def fetchall(self) -> t.List[Tuple[Any, ...]]:
def fetchall(self) -> list[tuple[Any, ...]]:
"""Fetch all remaining rows from the cursor."""
if self._closed or self._cursor is None:
return []
Expand All @@ -576,7 +568,7 @@ def fetchall(self) -> t.List[Tuple[Any, ...]]:
except Exception:
return []

def __iter__(self) -> t.Iterator[Tuple[Any, ...]]:
def __iter__(self) -> t.Iterator[tuple[Any, ...]]:
if self._closed or self._cursor is None:
return

Expand All @@ -596,7 +588,7 @@ class AsyncAdapt_psqlpy_connection(AsyncAdapt_dbapi_connection):
_ss_cursor_cls = AsyncAdapt_psqlpy_ss_cursor # type: ignore[assignment]

_connection: psqlpy.Connection # type: ignore[assignment]
_transaction: t.Optional[psqlpy.Transaction]
_transaction: psqlpy.Transaction | None

__slots__ = (
"_invalidate_schema_cache_asof",
Expand Down Expand Up @@ -637,7 +629,7 @@ def __init__(
# LRU cache for prepared statements. Defaults to 100 statements per
# connection. The cache is on a per-connection basis, stored within
# connections pooled by the connection pool.
self._prepared_statement_cache: t.Optional[util.LRUCache[t.Any, t.Any]]
self._prepared_statement_cache: util.LRUCache[t.Any, t.Any] | None
if prepared_statement_cache_size > 0:
self._prepared_statement_cache = util.LRUCache(
prepared_statement_cache_size
Expand All @@ -649,7 +641,7 @@ def __init__(
self._prepared_statement_name_func = self._default_name_func

# Legacy query cache (kept for compatibility)
self._query_cache: t.Dict[str, t.Any] = {}
self._query_cache: dict[str, t.Any] = {}
self._cache_max_size = prepared_statement_cache_size

async def _check_type_cache_invalidation(
Expand Down Expand Up @@ -739,7 +731,7 @@ def ping(self, reconnect: t.Any = None) -> t.Any:
self._connection_valid = False
return False

def _get_cached_query(self, query_key: str) -> t.Optional[t.Any]:
def _get_cached_query(self, query_key: str) -> t.Any | None:
"""Get a cached prepared statement if available."""
return self._query_cache.get(query_key)

Expand All @@ -761,7 +753,7 @@ def close(self) -> None:

def cursor(
self, server_side: bool = False
) -> Union[AsyncAdapt_psqlpy_cursor, AsyncAdapt_psqlpy_ss_cursor]:
) -> AsyncAdapt_psqlpy_cursor | AsyncAdapt_psqlpy_ss_cursor:
if server_side:
return self._ss_cursor_cls(self)
return self._cursor_cls(self)
Expand Down
2 changes: 1 addition & 1 deletion psqlpy_sqlalchemy/dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def TimestampFromTicks(self, ticks: float) -> t.Any:

return datetime.datetime.fromtimestamp(ticks)

def Binary(self, string: t.Union[str, bytes]) -> bytes:
def Binary(self, string: str | bytes) -> bytes:
"""Construct a binary value"""
if isinstance(string, str):
return string.encode("utf-8")
Expand Down
42 changes: 35 additions & 7 deletions psqlpy_sqlalchemy/dialect.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import typing as t
import uuid
from collections.abc import MutableMapping, Sequence
from types import ModuleType
from typing import Any, Dict, MutableMapping, Sequence, Tuple
from typing import Any

import psqlpy
from sqlalchemy import URL, util
Expand All @@ -28,8 +29,8 @@ class CompatibleNullPool(NullPool):
def __init__(
self,
creator: t.Any,
pool_size: t.Optional[int] = None,
max_overflow: t.Optional[int] = None,
pool_size: int | None = None,
max_overflow: int | None = None,
**kw: t.Any,
) -> None:
# Filter out pool sizing arguments that NullPool doesn't accept
Expand Down Expand Up @@ -231,10 +232,10 @@ class _PGUUID(UUID[t.Any]):

def bind_processor(
self, dialect: t.Any
) -> t.Optional[t.Callable[[t.Any], t.Any]]:
) -> t.Callable[[t.Any], t.Any] | None:
"""Process UUID parameters for psqlpy compatibility."""

def process(value: t.Any) -> t.Optional[bytes]:
def process(value: t.Any) -> bytes | None:
if value is None:
return None
if isinstance(value, uuid.UUID):
Expand All @@ -256,6 +257,33 @@ def process(value: t.Any) -> t.Optional[bytes]:

return process

def result_processor(
self, dialect: t.Any, coltype: t.Any
) -> t.Callable[[t.Any], t.Any] | None:
"""Process UUID results from psqlpy.

Converts string UUID values returned by psqlpy to Python uuid.UUID objects
when as_uuid=True (which is the default in SQLAlchemy 2.0+).
"""
if self.as_uuid:

def process(value: t.Any) -> uuid.UUID | None:
if value is None:
return None
if isinstance(value, uuid.UUID):
return value
if isinstance(value, str):
# psqlpy returns UUID as string, convert to uuid.UUID
return uuid.UUID(value)
if isinstance(value, bytes):
# Handle bytes representation
return uuid.UUID(bytes=value)
# For other types, try to convert
return uuid.UUID(str(value))

return process
return None


class PSQLPyAsyncDialect(PGDialect):
driver = "psqlpy"
Expand Down Expand Up @@ -314,7 +342,7 @@ def import_dbapi(cls) -> ModuleType:
return t.cast(ModuleType, PSQLPyAdaptDBAPI(__import__("psqlpy")))

@util.memoized_property
def _isolation_lookup(self) -> Dict[str, Any]:
def _isolation_lookup(self) -> dict[str, Any]:
"""Mapping of SQLAlchemy isolation levels to psqlpy isolation levels"""
return {
"READ_COMMITTED": psqlpy.IsolationLevel.ReadCommitted,
Expand All @@ -325,7 +353,7 @@ def _isolation_lookup(self) -> Dict[str, Any]:
def create_connect_args(
self,
url: URL,
) -> Tuple[Sequence[str], MutableMapping[str, Any]]:
) -> tuple[Sequence[str], MutableMapping[str, Any]]:
opts = url.translate_connect_args()
return (
[],
Expand Down
Loading