From 000c4e73f79a62c9cdc7060c5d8df1ad7d832445 Mon Sep 17 00:00:00 2001 From: Kenneth Geisshirt Date: Wed, 25 Feb 2026 16:44:31 +0100 Subject: [PATCH 1/3] Support SQL Alchemy 2.1 --- .github/workflows/tests.yml | 8 +- pyproject.toml | 2 +- src/sqlalchemy_cratedb/compat/core21.py | 425 +++++++++++++++++++++ src/sqlalchemy_cratedb/compiler.py | 3 + src/sqlalchemy_cratedb/dialect.py | 7 +- src/sqlalchemy_cratedb/sa_version.py | 1 + src/sqlalchemy_cratedb/support/polyfill.py | 3 +- tests/create_table_test.py | 12 +- tests/dict_test.py | 97 ++++- tests/test_error_handling.py | 4 +- tests/update_test.py | 3 +- 11 files changed, 541 insertions(+), 24 deletions(-) create mode 100644 src/sqlalchemy_cratedb/compat/core21.py diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 4d043a43..497eefba 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -22,7 +22,7 @@ jobs: os: ['ubuntu-22.04'] python-version: ['3.7', '3.8', '3.9', '3.10', '3.11', '3.12', '3.13', '3.14'] cratedb-version: ['nightly'] - sqla-version: ['<1.4', '<1.5', '<2.1'] + sqla-version: ['<1.4', '<1.5', '<2.1', '<2.2'] pip-allow-prerelease: ['false'] exclude: @@ -34,6 +34,12 @@ jobs: - python-version: '3.14' sqla-version: '<1.4' + # SQLAlchemy 2.1 requires Python 3.9+. + - python-version: '3.7' + sqla-version: '<2.2' + - python-version: '3.8' + sqla-version: '<2.2' + # Another CI test matrix slot to test against prerelease versions of Python packages. include: - os: 'ubuntu-latest' diff --git a/pyproject.toml b/pyproject.toml index 0e18c01e..9215e180 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -87,7 +87,7 @@ dependencies = [ "geojson>=2.5,<4", "importlib-metadata; python_version<'3.8'", "importlib-resources; python_version<'3.9'", - "sqlalchemy>=1,<2.1", + "sqlalchemy>=1,<2.2", "verlib2<0.4", ] optional-dependencies.all = [ diff --git a/src/sqlalchemy_cratedb/compat/core21.py b/src/sqlalchemy_cratedb/compat/core21.py new file mode 100644 index 00000000..aa873c7b --- /dev/null +++ b/src/sqlalchemy_cratedb/compat/core21.py @@ -0,0 +1,425 @@ +# -*- coding: utf-8; -*- +# +# Licensed to CRATE Technology GmbH ("Crate") under one or more contributor +# license agreements. See the NOTICE file distributed with this work for +# additional information regarding copyright ownership. Crate licenses +# this file to you under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. You may +# obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +# +# However, if you have executed another commercial license agreement +# with Crate these terms will supersede the license and you may use the +# software solely pursuant to the terms of the relevant commercial agreement. + +# ruff: noqa: S101 # Use of `assert` detected + +from typing import Any, Dict, List, MutableMapping, Optional, Tuple, Union + +import sqlalchemy as sa +from sqlalchemy import ColumnClause, ValuesBase, cast, exc +from sqlalchemy.sql import dml +from sqlalchemy.sql.base import _from_objects +from sqlalchemy.sql.compiler import SQLCompiler +from sqlalchemy.sql.crud import ( + REQUIRED, + _as_dml_column, + _create_bind_param, + _CrudParamElement, + _CrudParams, + _extend_values_for_multiparams, + _get_stmt_parameter_tuples_params, + _get_update_multitable_params, + _key_getters_for_crud_column, + _scan_cols, + _scan_insert_from_select_cols, + _setup_delete_return_defaults, +) +from sqlalchemy.sql.dml import DMLState, _DMLColumnElement +from sqlalchemy.sql.dml import isinsert as _compile_state_isinsert + +from sqlalchemy_cratedb.compiler import CrateCompiler + + +class CrateCompilerSA21(CrateCompiler): + def visit_update(self, update_stmt, visiting_cte=None, **kw): + compile_state = update_stmt._compile_state_factory(update_stmt, self, **kw) + update_stmt = compile_state.statement + + # [21] CrateDB patch. + if not compile_state._dict_parameters and not hasattr(update_stmt, "_crate_specific"): + return super().visit_update(update_stmt, visiting_cte=visiting_cte, **kw) + + if visiting_cte is not None: + kw["visiting_cte"] = visiting_cte + toplevel = False + else: + toplevel = not self.stack + + if toplevel: + self.isupdate = True + if not self.dml_compile_state: + self.dml_compile_state = compile_state + if not self.compile_state: + self.compile_state = compile_state + + extra_froms = compile_state._extra_froms + is_multitable = bool(extra_froms) + + if is_multitable: + # main table might be a JOIN + main_froms = set(_from_objects(update_stmt.table)) + render_extra_froms = [f for f in extra_froms if f not in main_froms] + correlate_froms = main_froms.union(extra_froms) + else: + render_extra_froms = [] + correlate_froms = {update_stmt.table} + + self.stack.append( + { + "correlate_froms": correlate_froms, + "asfrom_froms": correlate_froms, + "selectable": update_stmt, + } + ) + + text = "UPDATE " + + if update_stmt._prefixes: + text += self._generate_prefixes(update_stmt, update_stmt._prefixes, **kw) + + table_text = self.update_tables_clause( + update_stmt, update_stmt.table, render_extra_froms, **kw + ) + # [21] CrateDB patch. + crud_params_struct = _get_crud_params(self, update_stmt, compile_state, toplevel, **kw) + crud_params = crud_params_struct.single_params + + if update_stmt._hints: + dialect_hints, table_text = self._setup_crud_hints(update_stmt, table_text) + else: + dialect_hints = None + + if update_stmt._independent_ctes: + self._dispatch_independent_ctes(update_stmt, kw) + + text += table_text + + text += " SET " + + # [21] CrateDB patch begin. + include_table = extra_froms and self.render_table_with_column_in_update_from + + set_clauses = [] + + for c, expr, value, _ in crud_params: # noqa: B007 + key = c._compiler_dispatch(self, include_table=include_table) + clause = key + " = " + value + set_clauses.append(clause) + + for k, v in compile_state._dict_parameters.items(): + if isinstance(k, str) and "[" in k: + bindparam = sa.sql.bindparam(k, v) + clause = k + " = " + self.process(bindparam) + set_clauses.append(clause) + + text += ", ".join(set_clauses) + # [21] CrateDB patch end. + + if self.implicit_returning or update_stmt._returning: + if self.returning_precedes_values: + text += " " + self.returning_clause( + update_stmt, + self.implicit_returning or update_stmt._returning, + populate_result_map=toplevel, + ) + + if extra_froms: + extra_from_text = self.update_from_clause( + update_stmt, + update_stmt.table, + render_extra_froms, + dialect_hints, + **kw, + ) + if extra_from_text: + text += " " + extra_from_text + + if update_stmt._where_criteria: + t = self._generate_delimited_and_list(update_stmt._where_criteria, **kw) + if t: + text += " WHERE " + t + + limit_clause = self.update_post_criteria_clause(update_stmt, **kw) + if limit_clause: + text += " " + limit_clause + + if ( + self.implicit_returning or update_stmt._returning + ) and not self.returning_precedes_values: + text += " " + self.returning_clause( + update_stmt, + self.implicit_returning or update_stmt._returning, + populate_result_map=toplevel, + ) + + if self.ctes: + nesting_level = len(self.stack) if not toplevel else None + text = self._render_cte_clause(nesting_level=nesting_level) + text + + self.stack.pop(-1) + + return text + + +def _get_crud_params( + compiler: SQLCompiler, + stmt: ValuesBase, + compile_state: DMLState, + toplevel: bool, + **kw: Any, +) -> _CrudParams: + """create a set of tuples representing column/string pairs for use + in an INSERT or UPDATE statement. + + Also generates the Compiled object's postfetch, prefetch, and + returning column collections, used for default handling and ultimately + populating the CursorResult's prefetch_cols() and postfetch_cols() + collections. + + """ + + # note: the _get_crud_params() system was written with the notion in mind + # that INSERT, UPDATE, DELETE are always the top level statement and + # that there is only one of them. With the addition of CTEs that can + # make use of DML, this assumption is no longer accurate; the DML + # statement is not necessarily the top-level "row returning" thing + # and it is also theoretically possible (fortunately nobody has asked yet) + # to have a single statement with multiple DMLs inside of it via CTEs. + + # the current _get_crud_params() design doesn't accommodate these cases + # right now. It "just works" for a CTE that has a single DML inside of + # it, and for a CTE with multiple DML, it's not clear what would happen. + + # overall, the "compiler.XYZ" collections here would need to be in a + # per-DML structure of some kind, and DefaultDialect would need to + # navigate these collections on a per-statement basis, with additional + # emphasis on the "toplevel returning data" statement. However we + # still need to run through _get_crud_params() for all DML as we have + # Python / SQL generated column defaults that need to be rendered. + + # if there is user need for this kind of thing, it's likely a post 2.0 + # kind of change as it would require deep changes to DefaultDialect + # as well as here. + + compiler.postfetch = [] + compiler.insert_prefetch = [] + compiler.update_prefetch = [] + compiler.implicit_returning = [] + + # getters - these are normally just column.key, + # but in the case of mysql multi-table update, the rules for + # .key must conditionally take tablename into account + ( + _column_as_key, + _getattr_col_key, + _col_bind_name, + ) = _key_getters_for_crud_column(compiler, stmt, compile_state) + + compiler._get_bind_name_for_col = _col_bind_name + + if stmt._returning and stmt._return_defaults: + raise exc.CompileError( + "Can't compile statement that includes returning() and return_defaults() simultaneously" + ) + + if compile_state.isdelete: + _setup_delete_return_defaults( + compiler, + stmt, + compile_state, + (), + _getattr_col_key, + _column_as_key, + _col_bind_name, + (), + (), + toplevel, + kw, + ) + return _CrudParams([], []) + + # no parameters in the statement, no parameters in the + # compiled params - return binds for all columns + if compiler.column_keys is None and compile_state._no_parameters: + return _CrudParams( + [ + ( + c, + compiler.preparer.format_column(c), + _create_bind_param(compiler, c, None, required=True), + (c.key,), + ) + for c in stmt.table.columns + ], + [], + ) + + stmt_parameter_tuples: Optional[List[Tuple[Union[str, ColumnClause[Any]], Any]]] + spd: Optional[MutableMapping[_DMLColumnElement, Any]] + + if _compile_state_isinsert(compile_state) and compile_state._has_multi_parameters: + mp = compile_state._multi_parameters + assert mp is not None + spd = mp[0] + stmt_parameter_tuples = list(spd.items()) + elif compile_state._dict_parameters: + spd = compile_state._dict_parameters + stmt_parameter_tuples = list(spd.items()) + else: + stmt_parameter_tuples = spd = None + + # if we have statement parameters - set defaults in the + # compiled params + if compiler.column_keys is None: + parameters = {} + elif stmt_parameter_tuples: + assert spd is not None + parameters = { + _column_as_key(key): REQUIRED for key in compiler.column_keys if key not in spd + } + else: + parameters = {_column_as_key(key): REQUIRED for key in compiler.column_keys} + + # create a list of column assignment clauses as tuples + values: List[_CrudParamElement] = [] + + if stmt_parameter_tuples is not None: + _get_stmt_parameter_tuples_params( + compiler, + compile_state, + parameters, + stmt_parameter_tuples, + _column_as_key, + values, + kw, + ) + + check_columns: Dict[str, ColumnClause[Any]] = {} + + # special logic that only occurs for multi-table UPDATE + # statements + if dml.isupdate(compile_state) and compile_state.is_multitable: + _get_update_multitable_params( + compiler, + stmt, + compile_state, + stmt_parameter_tuples, + check_columns, + _col_bind_name, + _getattr_col_key, + values, + kw, + ) + + if _compile_state_isinsert(compile_state) and stmt._select_names: + # is an insert from select, is not a multiparams + + assert not compile_state._has_multi_parameters + + _scan_insert_from_select_cols( + compiler, + stmt, + compile_state, + parameters, + _getattr_col_key, + _column_as_key, + _col_bind_name, + check_columns, + values, + toplevel, + kw, + ) + else: + _scan_cols( + compiler, + stmt, + compile_state, + parameters, + _getattr_col_key, + _column_as_key, + _col_bind_name, + check_columns, + values, + toplevel, + kw, + ) + + # [21] CrateDB patch. + # + # This sanity check performed by SQLAlchemy currently needs to be + # deactivated in order to satisfy the rewriting logic of the CrateDB + # dialect in `rewrite_update` and `visit_update`. + # + # It can be quickly reproduced by activating this section and running the + # test cases:: + # + # ./bin/test -vvvv -t dict_test + # + # That croaks like:: + # + # sqlalchemy.exc.CompileError: Unconsumed column names: characters_name + # + # TODO: Investigate why this is actually happening and eventually mitigate + # the root cause. + """ + if parameters and stmt_parameter_tuples: + check = ( + set(parameters) + .intersection(_column_as_key(k) for k, v in stmt_parameter_tuples) + .difference(check_columns) + ) + if check: + raise exc.CompileError( + "Unconsumed column names: %s" + % (", ".join("%s" % (c,) for c in check)) + ) + """ + + if _compile_state_isinsert(compile_state) and compile_state._has_multi_parameters: + # is a multiparams, is not an insert from a select + assert not stmt._select_names + multi_extended_values = _extend_values_for_multiparams( + compiler, + stmt, + compile_state, + cast( + "Sequence[_CrudParamElementStr]", + values, + ), + cast("Callable[..., str]", _column_as_key), + kw, + ) + return _CrudParams(values, multi_extended_values) + elif not values and compiler.for_executemany and compiler.dialect.supports_default_metavalue: + # convert an "INSERT DEFAULT VALUES" + # into INSERT (firstcol) VALUES (DEFAULT) which can be turned + # into an in-place multi values. This supports + # insert_executemany_returning mode :) + values = [ + ( + _as_dml_column(stmt.table.columns[0]), + compiler.preparer.format_column(stmt.table.columns[0]), + compiler.dialect.default_metavalue_token, + (), + ) + ] + + return _CrudParams(values, []) diff --git a/src/sqlalchemy_cratedb/compiler.py b/src/sqlalchemy_cratedb/compiler.py index 7b2c5ccd..851e8ebb 100644 --- a/src/sqlalchemy_cratedb/compiler.py +++ b/src/sqlalchemy_cratedb/compiler.py @@ -214,6 +214,9 @@ def visit_TEXT(self, type_, **kw): def visit_DECIMAL(self, type_, **kw): return "DOUBLE" + def visit_double(self, type_, **kw): + return "DOUBLE" + def visit_BIGINT(self, type_, **kw): return "LONG" diff --git a/src/sqlalchemy_cratedb/dialect.py b/src/sqlalchemy_cratedb/dialect.py index 90102a78..b780ffd7 100644 --- a/src/sqlalchemy_cratedb/dialect.py +++ b/src/sqlalchemy_cratedb/dialect.py @@ -32,7 +32,7 @@ CrateIdentifierPreparer, CrateTypeCompiler, ) -from .sa_version import SA_1_4, SA_2_0, SA_VERSION +from .sa_version import SA_1_4, SA_2_0, SA_2_1, SA_VERSION from .type import FloatVector, ObjectArray, ObjectType TYPES_MAP = { @@ -160,8 +160,11 @@ def process(value): sqltypes.TIMESTAMP: DateTime, } +if SA_VERSION >= SA_2_1: + from .compat.core21 import CrateCompilerSA21 -if SA_VERSION >= SA_2_0: + statement_compiler = CrateCompilerSA21 +elif SA_VERSION >= SA_2_0: from .compat.core20 import CrateCompilerSA20 statement_compiler = CrateCompilerSA20 diff --git a/src/sqlalchemy_cratedb/sa_version.py b/src/sqlalchemy_cratedb/sa_version.py index 22f31e51..1b9addce 100644 --- a/src/sqlalchemy_cratedb/sa_version.py +++ b/src/sqlalchemy_cratedb/sa_version.py @@ -26,3 +26,4 @@ SA_1_4 = Version("1.4.0b1") SA_2_0 = Version("2.0.0") +SA_2_1 = Version("2.1.0b1") diff --git a/src/sqlalchemy_cratedb/support/polyfill.py b/src/sqlalchemy_cratedb/support/polyfill.py index 22dad7ce..13c040f5 100644 --- a/src/sqlalchemy_cratedb/support/polyfill.py +++ b/src/sqlalchemy_cratedb/support/polyfill.py @@ -56,11 +56,10 @@ def check_uniqueness(mapper, connection, target): stmt = stmt.filter( getattr(sa_entity, attribute_name) == getattr(target, attribute_name) ) - stmt = stmt.compile(bind=connection.engine) results = connection.execute(stmt) if results.rowcount > 0: raise IntegrityError( - statement=stmt, + statement=str(stmt), params=[], orig=Exception( f"DuplicateKeyException in table '{target.__tablename__}' " diff --git a/tests/create_table_test.py b/tests/create_table_test.py index ac4b2dda..cb147f30 100644 --- a/tests/create_table_test.py +++ b/tests/create_table_test.py @@ -26,12 +26,13 @@ except ImportError: from sqlalchemy.ext.declarative import declarative_base -from unittest import TestCase +from unittest import TestCase, skipIf from unittest.mock import MagicMock, patch from crate.client.cursor import Cursor from sqlalchemy_cratedb import Geopoint, ObjectArray, ObjectType +from sqlalchemy_cratedb.sa_version import SA_2_0, SA_VERSION fake_cursor = MagicMock(name="fake_cursor") FakeCursor = MagicMock(name="FakeCursor", spec=Cursor) @@ -383,3 +384,12 @@ class DummyTable(self.Base): ), (), ) + + @skipIf(SA_VERSION < SA_2_0, "sa.Double was introduced in SA 2.0") + def test_visit_double(self): + """ + Verify ``CrateTypeCompiler.visit_double()`` compiles ``sa.Double`` + to the CrateDB ``DOUBLE`` type keyword. + """ + result = sa.Double().compile(dialect=self.engine.dialect) + self.assertEqual(str(result), "DOUBLE") diff --git a/tests/dict_test.py b/tests/dict_test.py index 13e9b835..359e4dd8 100644 --- a/tests/dict_test.py +++ b/tests/dict_test.py @@ -21,6 +21,7 @@ from __future__ import absolute_import +import re from unittest import TestCase, skipIf from unittest.mock import MagicMock, patch @@ -29,14 +30,17 @@ from sqlalchemy.sql import select try: - from sqlalchemy.orm import declarative_base + from sqlalchemy.orm import Mapped, declarative_base, mapped_column except ImportError: from sqlalchemy.ext.declarative import declarative_base + Mapped = None + mapped_column = None + from crate.client.cursor import Cursor from sqlalchemy_cratedb import ObjectArray, ObjectType -from sqlalchemy_cratedb.sa_version import SA_1_4, SA_VERSION +from sqlalchemy_cratedb.sa_version import SA_1_4, SA_2_1, SA_VERSION fake_cursor = MagicMock(name="fake_cursor") FakeCursor = MagicMock(name="FakeCursor", spec=Cursor) @@ -96,21 +100,83 @@ def test_update_with_dict_column(self): self.assertSQL("UPDATE mytable SET data['x'] = ? WHERE mytable.name = ?", stmt) def set_up_character_and_cursor(self, return_value=None): - return_value = return_value or [("Trillian", {})] - fake_cursor.fetchall.return_value = return_value - fake_cursor.description = ( - ("characters_name", None, None, None, None, None, None), - ("characters_data", None, None, None, None, None, None), - ) + """ + Set up a ``Character`` model and a fake cursor, compatible with all + supported SQLAlchemy versions. + + **SA 2.1+** may issue SELECTs for different subsets of columns + depending on which attributes have been expired after a flush. A + dynamic ``execute`` side-effect is installed on the fake cursor so that + ``cursor.description`` and ``cursor.fetchall`` are adjusted to return + exactly the columns present in each SELECT statement. The model uses + ``mapped_column`` / ``Mapped`` annotations available since SA 2.0. + + **SA < 2.1** always selects the same fixed set of columns after a + flush. A static cursor mock with a two-column description + (``name``, ``data``) is used, matching the behaviour of those + versions. The model uses plain ``sa.Column`` declarations. + """ fake_cursor.rowcount = 1 Base = declarative_base() - class Character(Base): - __tablename__ = "characters" - name = sa.Column(sa.String, primary_key=True) - age = sa.Column(sa.Integer) - data = sa.Column(ObjectType) - data_list = sa.Column(ObjectArray) + if SA_VERSION >= SA_2_1: + # SA 2.1 may fire SELECTs for different subsets of columns depending + # on what is expired. Use a dynamic mock that adjusts description and + # return value to match exactly the columns present in each SELECT. + data_rows = return_value or [("Trillian", {})] + col_order = [ + "characters_name", + "characters_age", + "characters_data", + "characters_data_list", + ] + full_rows = [(r[0], None, r[1], None) for r in data_rows] + + def _col_in_sql(col, sql): + # Match 'characters.' where attr does not bleed into another column name. + attr = col.replace("characters_", "") + return bool(re.search(rf"characters\.{attr}(?!\w)", sql)) + + def execute_side_effect(sql, *args, **kwargs): + sql_str = str(sql) + if "SELECT" in sql_str.upper(): + cols = [c for c in col_order if _col_in_sql(c, sql_str)] + if cols: + indices = [col_order.index(c) for c in cols] + fake_cursor.description = tuple( + (c, None, None, None, None, None, None) for c in cols + ) + fake_cursor.fetchall.return_value = [ + tuple(row[i] for i in indices) for row in full_rows + ] + + fake_cursor.execute.side_effect = execute_side_effect + # Set defaults so non-SELECT calls (INSERT/UPDATE) don't need description. + fake_cursor.fetchall.return_value = [] + fake_cursor.description = () + + class Character(Base): + __tablename__ = "characters" + name: Mapped[str] = mapped_column(primary_key=True) + age = sa.Column(sa.Integer) + data = sa.Column(ObjectType) + data_list = sa.Column(ObjectArray) + + else: + # Older SA always selects a fixed set of columns; use a static mock. + fake_cursor.execute.side_effect = None + fake_cursor.fetchall.return_value = return_value or [("Trillian", {})] + fake_cursor.description = ( + ("characters_name", None, None, None, None, None, None), + ("characters_data", None, None, None, None, None, None), + ) + + class Character(Base): + __tablename__ = "characters" + name = sa.Column(sa.String, primary_key=True) + age = sa.Column(sa.Integer) + data = sa.Column(ObjectType) + data_list = sa.Column(ObjectArray) session = Session(bind=self.engine) return session, Character @@ -301,6 +367,9 @@ def test_partial_dict_update_with_setitem_delitem_setitem(self): def set_up_character_and_cursor_data_list(self, return_value=None): return_value = return_value or [("Trillian", {})] + # Clear any side_effect installed by set_up_character_and_cursor so the + # static description below is used as-is for this 2-column model. + fake_cursor.execute.side_effect = None fake_cursor.fetchall.return_value = return_value fake_cursor.description = ( ("characters_name", None, None, None, None, None, None), diff --git a/tests/test_error_handling.py b/tests/test_error_handling.py index 58ee499e..8a342033 100644 --- a/tests/test_error_handling.py +++ b/tests/test_error_handling.py @@ -14,7 +14,7 @@ def test_statement_with_error_trace(cratedb_service): connection.execute(sa.text("CREATE TABLE foo AS SELECT 1 AS _id")) # Make sure both variants match, to validate it's actually an error trace. - assert ex.match(re.escape('InvalidColumnNameException["_id" conflicts with system column]')) assert ex.match( - 'io.crate.exceptions.InvalidColumnNameException: "_id" conflicts with system column' + re.escape('InvalidColumnNameException["_id" conflicts with system column pattern]') ) + assert ex.match('InvalidColumnNameException: "_id" conflicts with system column pattern') diff --git a/tests/update_test.py b/tests/update_test.py index 107979b3..632aa25f 100644 --- a/tests/update_test.py +++ b/tests/update_test.py @@ -65,9 +65,10 @@ def test_onupdate_is_triggered(self): self.session.commit() now = datetime.utcnow() - fake_cursor.fetchall.return_value = [("Arthur", None)] + fake_cursor.fetchall.return_value = [("Arthur", None, None)] fake_cursor.description = ( ("characters_name", None, None, None, None, None, None), + ("characters_obj", None, None, None, None, None, None), ("characters_ts", None, None, None, None, None, None), ) From eedb9794690dfe24fb00fe3177241124fe96b49f Mon Sep 17 00:00:00 2001 From: Kenneth Geisshirt Date: Tue, 10 Mar 2026 15:43:39 +0100 Subject: [PATCH 2/3] Add CLAUDE.md --- CLAUDE.md | 73 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) create mode 100644 CLAUDE.md diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 00000000..d980e998 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,73 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Project Overview + +`sqlalchemy-cratedb` is a SQLAlchemy dialect for CrateDB, a distributed SQL database. It supports SQLAlchemy 1.3 through 2.1 (with ongoing 2.1 compatibility work on the current branch). + +## Development Setup + +```bash +source bootstrap.sh # Creates .venv with Python 3.11, installs all deps in editable mode +``` + +Environment variables that influence bootstrap: +- `CRATEDB_VERSION` (default: `5.5.1`) — CrateDB Docker image version +- `SQLALCHEMY_VERSION` (default: `<2.2`) — SQLAlchemy version constraint +- `PIP_ALLOW_PRERELEASE=true` — allow pre-release packages + +## Common Commands + +```bash +poe format # Auto-format code (ruff + black) +poe lint # Run linters (ruff, validate-pyproject) +poe test # Run pytest + integration tests +poe check # lint + test combined + +# Run specific tests +pytest tests/dict_test.py +pytest -k SqlAlchemyCompilerTest +pytest -k test_score + +# Run integration/doctests +python -m unittest -vvv tests/integration.py +``` + +Tests require a live CrateDB instance via Docker (managed automatically by `cratedb_toolkit.testing.testcontainers`). + +## Architecture + +### Source layout (`src/sqlalchemy_cratedb/`) + +- **`dialect.py`** — Core dialect: type mappings, Date/DateTime handling, schema reflection +- **`compiler.py`** — SQL/DDL compilation: `CrateDDLCompiler`, `CrateTypeCompiler`, `CrateIdentifierPreparer`, and `rewrite_update()` for partial object updates +- **`predicate.py`** — `match()` predicate for full-text search +- **`sa_version.py`** — Version detection; exports `SA_VERSION`, `SA_1_4`, `SA_2_0`, `SA_2_1` constants +- **`compat/`** — Multi-version SQLAlchemy compatibility: `core10.py`, `core14.py`, `core20.py`, `core21.py`, `api13.py` +- **`type/`** — Custom CrateDB types: `ObjectType` (JSON objects), `ObjectArray`, `FloatVector`, `Geopoint`, `Geoshape` +- **`support/`** — Integrations and polyfills: `pandas.py` (bulk insert), `polyfill.py` (refresh-after-DML, uniqueness, autoincrement timestamps), `util.py` + +### Key architectural patterns + +**Multi-version compatibility:** The `compat/` directory contains separate modules for each major SQLAlchemy version. `sa_version.py` detects the installed version at runtime using `verlib2`, and code conditionally imports from the appropriate compat module. When adding features, check whether they need version-specific handling. + +**Custom types:** CrateDB types (ObjectType, FloatVector, etc.) implement SQLAlchemy's bind/result processor pattern — `bind_processor()` converts Python → SQL, `result_processor()` converts SQL → Python. The `CrateTypeCompiler` generates the SQL type strings. + +**Update rewriting:** `compiler.py::rewrite_update()` transforms partial dictionary updates on `ObjectType` columns into CrateDB's subscript assignment syntax (e.g., `obj['key'] = value`). + +**Polyfills:** `support/polyfill.py` monkey-patches SQLAlchemy internals to add features CrateDB doesn't natively support (e.g., `refresh_after_dml`, `uniqueness_strategy`). + +### Testing + +Tests in `tests/` follow two patterns: +- `*_test.py` files: unit/integration tests using pytest with a live CrateDB instance +- `tests/integration.py`: doctests for documentation examples, run with `unittest` + +The `conftest.py` provides a session-scoped `cratedb_service` fixture that starts CrateDB via Docker containers. + +## Code Style + +- Line length: 100 characters (ruff + black) +- Ruff rules enforced: A, B, C4, E, ERA, F, I, PD, RET, S, T20, W, YTT +- Mypy strict mode is configured but not always enforced in CI From aa3376cc83cc017ad1666212d55db19d0738785f Mon Sep 17 00:00:00 2001 From: Kenneth Geisshirt Date: Tue, 10 Mar 2026 16:09:13 +0100 Subject: [PATCH 3/3] adjust expected value in test --- tests/test_error_handling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_error_handling.py b/tests/test_error_handling.py index 8a342033..58ee499e 100644 --- a/tests/test_error_handling.py +++ b/tests/test_error_handling.py @@ -14,7 +14,7 @@ def test_statement_with_error_trace(cratedb_service): connection.execute(sa.text("CREATE TABLE foo AS SELECT 1 AS _id")) # Make sure both variants match, to validate it's actually an error trace. + assert ex.match(re.escape('InvalidColumnNameException["_id" conflicts with system column]')) assert ex.match( - re.escape('InvalidColumnNameException["_id" conflicts with system column pattern]') + 'io.crate.exceptions.InvalidColumnNameException: "_id" conflicts with system column' ) - assert ex.match('InvalidColumnNameException: "_id" conflicts with system column pattern')