diff --git a/psqlpy_sqlalchemy/connection.py b/psqlpy_sqlalchemy/connection.py index 00b20ac..c94b1fe 100644 --- a/psqlpy_sqlalchemy/connection.py +++ b/psqlpy_sqlalchemy/connection.py @@ -1,8 +1,8 @@ import asyncio -import contextlib import re import time import typing as t +import uuid from collections import deque from typing import Any @@ -19,7 +19,6 @@ # Compiled regex patterns used for parameter substitution _PARAM_PATTERN = re.compile(r":([a-zA-Z_][a-zA-Z0-9_]*)(::[\w\[\]]+)?") -_CASTING_PATTERN = re.compile(r":([a-zA-Z_][a-zA-Z0-9_]*)::") _POSITIONAL_CHECK = re.compile(r"\$\d+:$") # UUID pattern for validation @@ -28,6 +27,9 @@ re.IGNORECASE, ) +# Cache for compiled parameter-specific regex patterns +_PARAM_REGEX_CACHE: dict[str, re.Pattern[str]] = {} + if t.TYPE_CHECKING: from sqlalchemy.engine.interfaces import ( DBAPICursor, @@ -76,73 +78,39 @@ async def _prepare_execute( querystring: str, parameters: t.Sequence[t.Any] | t.Mapping[str, Any] | None = None, ) -> None: - """Execute a prepared statement. - - Ensures transaction context is active before executing the statement. - Processes parameters to handle type conversions and named-to-positional - parameter conversion for PostgreSQL's numeric parameter style. - """ - adapt_conn = self._adapt_connection - - # Ensure transaction is started - if not adapt_conn._started: - await adapt_conn._start_transaction() + """Execute a prepared statement.""" + if not self._adapt_connection._started: + await self._adapt_connection._start_transaction() - # Process parameters to ensure proper type conversion (especially for UUIDs) - processed_parameters = self._process_parameters(parameters) - - # Convert named parameters with casting syntax to positional parameters - converted_query, converted_params = ( - self._convert_named_params_with_casting( - querystring, processed_parameters - ) + # Convert params + converted_query, converted_params = self._convert_params_single_pass( + querystring, parameters ) - # Handle mixed parameter styles specifically for explicit PostgreSQL casting - # Only trigger this for queries with explicit casting syntax like :param::TYPE - if ( - converted_params is not None - and not isinstance(converted_params, dict) - and converted_query == querystring - ): # Query unchanged means mixed parameters detected - # Look specifically for PostgreSQL casting syntax :param::TYPE - casting_pattern = r":([a-zA-Z_][a-zA-Z0-9_]*)::" - casting_matches = re.findall(casting_pattern, converted_query) - - if casting_matches: - # This is a known limitation: SQLAlchemy can't handle named parameters with explicit PostgreSQL casting - raise RuntimeError( - f"Named parameters with explicit PostgreSQL casting are not supported. " - f"Found casting parameters: {casting_matches} in query: {converted_query[:100]}... " - f"SQLAlchemy filters out parameters when explicit casting syntax like ':param::TYPE' is used. " - f"Solutions: " - f"1) Use positional parameters: 'WHERE uid = $1::UUID LIMIT $2' with parameters as a list, " - f"2) Remove explicit casting: 'WHERE uid = :uid LIMIT :limit' (casting will be handled automatically), " - f"3) Use SQLAlchemy's cast() function: 'WHERE uid = cast(:uid, UUID) LIMIT :limit'" + try: + # DML without RETURNING: use execute() directly + query_upper = converted_query.upper() + if ( + query_upper.lstrip()[:6] in ("INSERT", "UPDATE", "DELETE") + and "RETURNING" not in query_upper + ): + await self._connection.execute( + converted_query, converted_params, prepared=True ) + self._description = None + self._rowcount = 1 + self._rows = deque() + return - try: - # NOTE: psqlpy's Python API requires parameters at prepare() time - # and PreparedStatement.execute() doesn't accept parameters. - # While psqlpy's internal Rust API supports reusable prepared statements - # (used by execute_many), the Python API doesn't expose this capability. - # This prevents caching prepared statements like asyncpg does. + # SELECT/complex: use prepare() for column metadata prepared_stmt = await self._connection.prepare( querystring=converted_query, parameters=converted_params, ) self._description = [ - ( - column.name, - column.table_oid, - None, # display_size - None, # internal_size - None, # precision - None, # scale - None, # null_ok - ) - for column in prepared_stmt.columns() + (col.name, col.table_oid, None, None, None, None, None) + for col in prepared_stmt.columns() ] if self.server_side: @@ -156,13 +124,11 @@ async def _prepare_execute( results = await prepared_stmt.execute() - # Direct iteration without intermediate tuple creation - rows_list = [ + self._rows = deque( tuple(value for _, value in row) for row in results.row_factory(row_factories.tuple_row) - ] - self._rows = deque(rows_list) - self._rowcount = len(rows_list) + ) + self._rowcount = len(self._rows) except Exception: self._description = None @@ -175,146 +141,166 @@ def _process_parameters( self, parameters: t.Sequence[t.Any] | t.Mapping[str, Any] | None = None, ) -> t.Sequence[t.Any] | t.Mapping[str, Any] | None: - """Process parameters for type conversion. + """Process parameters for type conversion (legacy, used by executemany). Converts UUID objects to bytes format required by psqlpy. - Also handles string UUIDs by parsing and converting to bytes. """ if parameters is None: return None - import uuid - def process_value(value: Any) -> Any: - """Process a single parameter value for UUID conversion.""" if value is None: return None if isinstance(value, uuid.UUID): return value.bytes - # Only attempt UUID parsing for strings matching UUID pattern if isinstance(value, str) and _UUID_PATTERN.match(value): try: - parsed_uuid = uuid.UUID(value) - return parsed_uuid.bytes + return uuid.UUID(value).bytes except ValueError: - # Shouldn't happen with valid pattern, but be safe return value return value if isinstance(parameters, dict): - return { - key: process_value(value) for key, value in parameters.items() - } + return {k: process_value(v) for k, v in parameters.items()} if isinstance(parameters, list | tuple): - return type(parameters)( - process_value(value) for value in parameters - ) + return type(parameters)(process_value(v) for v in parameters) return process_value(parameters) - def _convert_named_params_with_casting( + def _convert_params_single_pass( self, querystring: str, parameters: t.Sequence[t.Any] | t.Mapping[str, Any] | None = None, - ) -> tuple[str, t.Sequence[t.Any] | t.Mapping[str, Any] | None]: - """Convert named parameters with PostgreSQL casting syntax to positional parameters. + ) -> tuple[str, list[Any] | None]: + """Single-pass conversion: named→positional + UUID→bytes. + + Optimized to avoid multiple iterations over parameters. + """ + # Fast path: no parameters + if parameters is None: + return querystring, None + + # Fast path: already positional (list/tuple) + if isinstance(parameters, list | tuple): + # Just process UUIDs + converted: list[Any] = [] + for val in parameters: + if isinstance(val, uuid.UUID): + converted.append(val.bytes) + elif isinstance(val, str) and _UUID_PATTERN.match(val): + try: + converted.append(uuid.UUID(val).bytes) + except ValueError: + converted.append(val) + else: + converted.append(val) + return querystring, converted - Transforms queries like: - 'SELECT * FROM table WHERE col = :param::UUID LIMIT :limit' + # Dict parameters: need named→positional conversion + if not isinstance(parameters, dict): + return querystring, None - To: - 'SELECT * FROM table WHERE col = $1::UUID LIMIT $2' + # Fast path: no named params in query + if ":" not in querystring: + return querystring, list(parameters.values()) + + # Find all parameter references + matches = list(_PARAM_PATTERN.finditer(querystring)) + if not matches: + return querystring, list(parameters.values()) - And converts the parameters dict to a list in the correct order. - Uses pre-compiled regex patterns for parameter detection. + # Build param order (first occurrence wins) + param_order: list[str] = [] + seen: set[str] = set() + for match in matches: + name = match.group(1) + if name not in seen and name in parameters: + param_order.append(name) + seen.add(name) + + # Check for missing params - return original if any missing + for match in matches: + name = match.group(1) + if name not in parameters: + # Missing param - return original query and values as list + return querystring, list(parameters.values()) + + # Single loop: build converted params + query replacement + converted_params: list[Any] = [] + converted_query = querystring + + for i, name in enumerate(param_order, 1): + val = parameters[name] + # UUID conversion inline + if isinstance(val, uuid.UUID): + converted_params.append(val.bytes) + elif isinstance(val, str) and _UUID_PATTERN.match(val): + try: + converted_params.append(uuid.UUID(val).bytes) + except ValueError: + converted_params.append(val) + else: + converted_params.append(val) + + # Get or create cached regex for this param + if name not in _PARAM_REGEX_CACHE: + _PARAM_REGEX_CACHE[name] = re.compile( + rf":({re.escape(name)})(::[\w\[\]]+)?" + ) + converted_query = _PARAM_REGEX_CACHE[name].sub( + f"${i}\\2", converted_query + ) + + return converted_query, converted_params + + def _convert_named_params_with_casting( + self, + querystring: str, + parameters: t.Sequence[t.Any] | t.Mapping[str, Any] | None = None, + ) -> tuple[str, t.Sequence[t.Any] | t.Mapping[str, Any] | None]: + """Convert named parameters to positional (without UUID conversion). + + Legacy method for backward compatibility. """ + # Fast path: no parameters or not a dict if parameters is None or not isinstance(parameters, dict): return querystring, parameters - # Find all parameter references in the query using pre-compiled pattern - matches = list(_PARAM_PATTERN.finditer(querystring)) + # Fast path: no named params in query + if ":" not in querystring: + return querystring, parameters + # Find all parameter references + matches = list(_PARAM_PATTERN.finditer(querystring)) if not matches: return querystring, parameters - # Build the conversion mapping and new parameter list - param_order = [] - seen_params = set() - missing_params = [] - - # Process matches to determine parameter order (first occurrence wins) + # Build param order (first occurrence wins) + param_order: list[str] = [] + seen: set[str] = set() for match in matches: - param_name = match.group(1) - if param_name not in seen_params: - if param_name in parameters: - param_order.append(param_name) - seen_params.add(param_name) - else: - missing_params.append(param_name) + name = match.group(1) + if name not in seen and name in parameters: + param_order.append(name) + seen.add(name) - # Defensive check: ensure all parameters found in query are available - if missing_params: - # Instead of raising an error, return the original query and parameters - # This prevents partial conversion which can cause SQL syntax errors - return querystring, parameters + # Check for missing params + for match in matches: + name = match.group(1) + if name not in parameters: + return querystring, parameters - # Convert the query string by replacing each parameter with its positional equivalent + # Build converted params + query + converted_params: list[Any] = [] converted_query = querystring - for i, param_name in enumerate(param_order, 1): - # Replace all occurrences of this parameter with $N, preserving any casting - param_pattern_specific = re.compile( - f":({re.escape(param_name)})" + r"(::[\w\[\]]+)?" - ) - replacement = f"${i}\\2" # $N + casting part (group 2) - - # Perform replacement and verify it worked - new_query = param_pattern_specific.sub( - replacement, converted_query - ) - - # Defensive check: ensure replacement actually occurred - if ( - new_query == converted_query - and f":{param_name}" in converted_query - ): - raise RuntimeError( - f"Failed to replace parameter '{param_name}' in query. " - f"Query: {converted_query}" + for i, name in enumerate(param_order, 1): + converted_params.append(parameters[name]) + if name not in _PARAM_REGEX_CACHE: + _PARAM_REGEX_CACHE[name] = re.compile( + rf":({re.escape(name)})(::[\w\[\]]+)?" ) - - converted_query = new_query - - # Convert parameters dict to list in the correct order - converted_params = [ - parameters[param_name] for param_name in param_order - ] - - # Final defensive check: ensure no named parameters remain in the converted query - # Look for the original parameter pattern, but exclude matches that are part of casting syntax - remaining_matches = [] - for match in _PARAM_PATTERN.finditer(converted_query): - full_match = match.group(0) - # Check if this looks like a real parameter (not casting syntax) - # Real parameters should not be preceded by a positional parameter like $1, $2, etc. - start_pos = match.start() - if start_pos > 0: - # Look at the characters before the match to see if this is casting syntax - # For casting syntax like $1::UUID, we need to check if preceded by $N: - preceding_text = converted_query[ - max(0, start_pos - 4) : start_pos - ] - # If preceded by $N: (positional parameter followed by colon), this is casting syntax - if _POSITIONAL_CHECK.search(preceding_text): - continue - # Also check the older pattern for backward compatibility - if re.search(r"\$\d+$", preceding_text): - continue - remaining_matches.append(full_match) - - if remaining_matches: - raise RuntimeError( - f"Conversion incomplete: named parameters still present in query: {remaining_matches}. " - f"Converted query: {converted_query}, Original query: {querystring}" + converted_query = _PARAM_REGEX_CACHE[name].sub( + f"${i}\\2", converted_query ) return converted_query, converted_params @@ -335,147 +321,74 @@ def arraysize(self) -> int: def arraysize(self, value: int) -> None: self._arraysize = value - def _is_simple_insert(self, operation: str) -> bool: - """Check if operation is a simple INSERT statement. - - Returns True if the query is INSERT INTO with VALUES clause - and no RETURNING clause, allowing multi-value INSERT transformation. - """ - operation_upper = operation.upper().strip() - return ( - operation_upper.startswith("INSERT INTO") - and "VALUES" in operation_upper - and "RETURNING" not in operation_upper - ) - async def _executemany( self, operation: str, seq_of_parameters: t.Sequence[t.Sequence[t.Any]], ) -> None: - """Execute a batch of parameter sets. - - For simple INSERT statements, automatically transforms multiple - individual INSERTs into a single multi-value INSERT statement: - - INSERT INTO t VALUES ($1, $2) (executed N times) + """Execute a batch of parameter sets.""" + if not self._adapt_connection._started: + await self._adapt_connection._start_transaction() - Becomes: + # Fast conversion + def convert_row(params: Any) -> list[Any]: + if params is None: + return [] + vals = params.values() if isinstance(params, dict) else params + return [v.bytes if isinstance(v, uuid.UUID) else v for v in vals] - INSERT INTO t VALUES ($1,$2), ($3,$4), ..., ($N*2-1,$N*2) + converted_seq = [convert_row(p) for p in seq_of_parameters] - This transformation reduces network round-trips from N to 1. - For non-INSERT statements, delegates to psqlpy's execute_many. - """ - adapt_connection = self._adapt_connection - self._description = None - - # Check for schema cache invalidation - await adapt_connection._check_type_cache_invalidation( - self._invalidate_schema_cache_asof - ) - - # Ensure transaction context is active before batch execution - if not adapt_connection._started: - await adapt_connection._start_transaction() - - # Process all parameters first - if seq_of_parameters and all( - isinstance(p, list | tuple) for p in seq_of_parameters + # INSERT: multi-value optimization + if ( + len(converted_seq) > 1 + and operation.lstrip()[:6].upper() == "INSERT" + and "RETURNING" not in operation.upper() ): - converted_seq = [list(p) for p in seq_of_parameters] - else: - converted_seq = [] - for params in seq_of_parameters: - processed = self._process_parameters(params) - if processed is None: - converted_seq.append([]) - elif isinstance(processed, dict): - converted_seq.append(list(processed.values())) - elif isinstance(processed, list | tuple): - converted_seq.append(list(processed)) - else: - converted_seq.append([processed]) - - # For simple INSERT statements, transform to multi-value INSERT - # to reduce network round-trips - if self._is_simple_insert(operation) and len(converted_seq) > 1: - # Build multi-value INSERT: VALUES ($1,$2), ($3,$4), ... - # Count placeholders in original query - placeholder_count = operation.count("$") - - if placeholder_count > 0: - # Build new VALUES clause with all rows - values_parts = [] - flat_params = [] - param_idx = 1 - - for row_params in converted_seq: - # Create placeholders for this row: ($1, $2, ...) - row_placeholders = ", ".join( - [ - f"${i}" - for i in range( - param_idx, param_idx + len(row_params) - ) - ] + try: + idx = 1 + parts = [] + flat: list[Any] = [] + for row in converted_seq: + n = len(row) + parts.append( + f"({', '.join(f'${i}' for i in range(idx, idx + n))})" ) - values_parts.append(f"({row_placeholders})") - flat_params.extend(row_params) - param_idx += len(row_params) - - # Replace original VALUES (...) with multi-row VALUES - # Find and replace the VALUES clause - import re + flat.extend(row) + idx += n - multi_value_query = re.sub( + query = re.sub( r"VALUES\s*\([^)]*\)", - f"VALUES {', '.join(values_parts)}", + f"VALUES {', '.join(parts)}", operation, flags=re.IGNORECASE, ) - - # Execute as single query - try: - await self._connection.execute( - multi_value_query, flat_params - ) - return None - except Exception: - # If multi-value fails, fall back to execute_many - pass - - # For non-INSERT statements, use pipeline when transaction is active. - # This provides protocol-level batching similar to asyncpg.executemany(). - # Pipeline sends all queries together and waits for all responses, - # dramatically reducing network round-trips compared to execute_many. - if adapt_connection._transaction is not None: - try: - # Build queries list for pipeline: [(query, params), ...] - queries: list[tuple[str, list[t.Any] | None]] = [ - (operation, params) for params in converted_seq - ] - await adapt_connection._transaction.pipeline( - queries, prepared=True - ) - return None + await self._connection.execute(query, flat) + self._rowcount = len(converted_seq) + return except Exception: - # If pipeline fails, fall back to execute_many pass - # Fallback: use standard execute_many with prepared statements - return await self._connection.execute_many( - operation, - converted_seq, - prepared=True, + await self._connection.execute_many( + operation, converted_seq, prepared=True ) + self._rowcount = len(converted_seq) def execute( self, operation: t.Any, parameters: t.Sequence[t.Any] | t.Mapping[str, Any] | None = None, ) -> None: - self.await_(self._prepare_execute(operation, parameters)) + # Auto-detect batch operations: if parameters is a list of dicts/tuples, + # treat it as executemany for better performance + if ( + isinstance(parameters, list) + and len(parameters) > 1 + and all(isinstance(p, dict | tuple) for p in parameters) + ): + self.await_(self._executemany(operation, parameters)) + else: + self.await_(self._prepare_execute(operation, parameters)) def executemany( self, operation: t.Any, seq_of_parameters: t.Sequence[t.Any] @@ -678,35 +591,26 @@ def set_isolation_level(self, level: t.Any) -> None: def rollback(self) -> None: """Rollback the current transaction.""" - try: - if self._transaction is not None: + if self._transaction is not None: + try: await_only(self._transaction.rollback()) - else: - await_only(self._connection.rollback()) # type: ignore[attr-defined] - except Exception: - self._connection_valid = False - # Ignore rollback errors as connection might be in bad state - pass - finally: - self._transaction = None - self._started = False + except Exception: + self._connection_valid = False + self._transaction = None + self._started = False def commit(self) -> None: """Commit the current transaction.""" - try: - if self._transaction is not None: + if self._transaction is not None: + try: await_only(self._transaction.commit()) - else: - await_only(self._connection.commit()) # type: ignore[attr-defined] - except Exception as e: - self._connection_valid = False - # On commit failure, try to rollback - with contextlib.suppress(Exception): - self.rollback() - raise e - finally: - self._transaction = None - self._started = False + except Exception as e: + self._connection_valid = False + self._transaction = None + self._started = False + raise e + self._transaction = None + self._started = False def is_valid(self) -> bool: """Check if connection is valid""" diff --git a/tests/test_connection.py b/tests/test_connection.py index 21de09b..b9f6eb7 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -671,16 +671,16 @@ def test_rollback_with_transaction(self, mock_await_only): self.assertIsNone(self.connection._transaction) self.assertFalse(self.connection._started) - @patch("psqlpy_sqlalchemy.connection.await_only") - def test_rollback_without_transaction(self, mock_await_only): - """Test rollback without active transaction""" + def test_rollback_without_transaction(self): + """Test rollback without active transaction - should be no-op""" self.connection._transaction = None + self.connection._started = True self.connection.rollback() - mock_await_only.assert_called_once_with( - self.mock_connection.rollback() - ) + # Without transaction, rollback should be a no-op + self.assertIsNone(self.connection._transaction) + self.assertFalse(self.connection._started) @patch("psqlpy_sqlalchemy.connection.await_only") def test_rollback_with_exception(self, mock_await_only): @@ -709,14 +709,16 @@ def test_commit_with_transaction(self, mock_await_only): self.assertIsNone(self.connection._transaction) self.assertFalse(self.connection._started) - @patch("psqlpy_sqlalchemy.connection.await_only") - def test_commit_without_transaction(self, mock_await_only): - """Test commit without active transaction""" + def test_commit_without_transaction(self): + """Test commit without active transaction - should be no-op""" self.connection._transaction = None + self.connection._started = True self.connection.commit() - mock_await_only.assert_called_once_with(self.mock_connection.commit()) + # Without transaction, commit should be a no-op + self.assertIsNone(self.connection._transaction) + self.assertFalse(self.connection._started) @patch("psqlpy_sqlalchemy.connection.await_only") def test_commit_with_exception(self, mock_await_only):