diff --git a/CHANGELOG.md b/CHANGELOG.md index 3b2e77d..021e7a3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,19 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +## [0.3.11] + +### Added + +- Added REST API and Worker integration coverage for composite primary key DDL and constraint reflection + +### Fixed + +- Fixed composite primary key DDL emitting duplicate `PRIMARY KEY` clauses, which D1 rejects with `SQLITE_ERROR` +- Fixed foreign key reflection to include SQLAlchemy's expected `referred_schema` key +- Implemented unique constraint reflection for D1 tables, including inline and named table-level unique constraints + + ## [0.3.10] ### Added diff --git a/examples/workers/src/entry.py b/examples/workers/src/entry.py index 5ae0ca3..c6b37f8 100644 --- a/examples/workers/src/entry.py +++ b/examples/workers/src/entry.py @@ -45,6 +45,8 @@ async def fetch(self, request, env): return await self.test_sqlalchemy_composite_pk() elif path == "sqlalchemy-reflect": return await self.test_sqlalchemy_reflect() + elif path == "sqlalchemy-reflect-constraints": + return await self.test_sqlalchemy_reflect_constraints() # Empty result set tests (GitHub issue #4) elif path == "empty-result": return await self.test_empty_result() @@ -192,6 +194,7 @@ async def index(self): "/sqlalchemy-crud": "Test SQLAlchemy Core CRUD (no raw SQL)", "/sqlalchemy-composite-pk": "Test SQLAlchemy composite primary key DDL", "/sqlalchemy-reflect": "Test SQLAlchemy table reflection", + "/sqlalchemy-reflect-constraints": "Test SQLAlchemy constraint reflection", "/empty-result": "Test empty result set description (issue #4)", "/empty-result-sqlalchemy": "Test SQLAlchemy empty result (issue #4)", "/json-filter": "Test filtering on JSON array columns", @@ -701,6 +704,100 @@ async def test_sqlalchemy_reflect(self): status=500, ) + async def test_sqlalchemy_reflect_constraints(self): + """Test SQLAlchemy foreign key and unique constraint reflection.""" + parent_table_name = f"test_reflect_parent_{uuid.uuid4().hex[:8]}" + child_table_name = f"test_reflect_child_{uuid.uuid4().hex[:8]}" + unique_constraint_name = f"uq_{uuid.uuid4().hex[:12]}" + + try: + from sqlalchemy import ( + Column, + ForeignKey, + Integer, + MetaData, + String, + Table, + UniqueConstraint, + inspect, + ) + + engine = self.get_engine() + metadata = MetaData() + + Table( + parent_table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("slug", String, unique=True), + ) + Table( + child_table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("parent_id", Integer, ForeignKey(f"{parent_table_name}.id")), + Column("tenant_id", String), + Column("record_key", String), + UniqueConstraint( + "tenant_id", + "record_key", + name=unique_constraint_name, + ), + ) + + metadata.create_all(engine) + + inspector = inspect(engine) + foreign_keys = inspector.get_foreign_keys(child_table_name) + unique_constraints = inspector.get_unique_constraints(child_table_name) + parent_unique_constraints = inspector.get_unique_constraints( + parent_table_name + ) + + metadata.drop_all(engine) + + expected_child_unique = { + "name": unique_constraint_name, + "column_names": ["tenant_id", "record_key"], + } + expected_parent_unique = {"name": None, "column_names": ["slug"]} + foreign_key = foreign_keys[0] if foreign_keys else {} + success = ( + foreign_key.get("constrained_columns") == ["parent_id"] + and foreign_key.get("referred_schema") is None + and foreign_key.get("referred_table") == parent_table_name + and foreign_key.get("referred_columns") == ["id"] + and expected_child_unique in unique_constraints + and expected_parent_unique in parent_unique_constraints + ) + + return Response.json( + { + "test": "sqlalchemy_reflect_constraints", + "success": success, + "foreign_keys": foreign_keys, + "unique_constraints": unique_constraints, + "parent_unique_constraints": parent_unique_constraints, + } + ) + except Exception as e: + try: + conn = self.get_connection() + cursor = conn.cursor() + await cursor.execute_async(f"DROP TABLE IF EXISTS {child_table_name}") + await cursor.execute_async(f"DROP TABLE IF EXISTS {parent_table_name}") + conn.close() + except Exception: + pass + return Response.json( + { + "test": "sqlalchemy_reflect_constraints", + "success": False, + "error": str(e), + }, + status=500, + ) + # MARK: - Empty Result Set Tests (GitHub issue #4) async def test_empty_result(self): diff --git a/examples/workers/uv.lock b/examples/workers/uv.lock index 0b743fb..6ea178e 100644 --- a/examples/workers/uv.lock +++ b/examples/workers/uv.lock @@ -614,7 +614,7 @@ wheels = [ [[package]] name = "sqlalchemy-cloudflare-d1" -version = "0.3.10" +version = "0.3.11" source = { editable = "../../" } dependencies = [ { name = "httpx" }, diff --git a/pyproject.toml b/pyproject.toml index 39aea16..2028aaa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "sqlalchemy-cloudflare-d1" -version = "0.3.10" +version = "0.3.11" description = "A SQLAlchemy dialect for Cloudflare's D1 Serverless SQLite Database" readme = "README.md" authors = [ diff --git a/src/sqlalchemy_cloudflare_d1/dialect.py b/src/sqlalchemy_cloudflare_d1/dialect.py index f88181d..4725cec 100644 --- a/src/sqlalchemy_cloudflare_d1/dialect.py +++ b/src/sqlalchemy_cloudflare_d1/dialect.py @@ -4,6 +4,7 @@ import base64 import enum as enum_module +import re import uuid as uuid_module from datetime import date, datetime, time from typing import Any, Callable, Dict, List, Optional @@ -627,6 +628,7 @@ def get_foreign_keys( # type: ignore[override] fks[fk_id] = { "name": None, "constrained_columns": [], + "referred_schema": schema, "referred_table": row[2], "referred_columns": [], "options": {"onupdate": row[5], "ondelete": row[6]}, @@ -637,6 +639,42 @@ def get_foreign_keys( # type: ignore[override] return list(fks.values()) + def get_unique_constraints( # type: ignore[override] + self, connection: Any, table_name: str, schema: Optional[str] = None, **kw: Any + ) -> List[Dict[str, Any]]: + """Get unique constraint information.""" + query = text( + f"PRAGMA index_list({self.identifier_preparer.quote_identifier(table_name)})" + ) + result = connection.execute(query) + + unique_constraints_by_sig: Dict[tuple[str, ...], Dict[str, Any]] = {} + for row in result: + # PRAGMA index_list returns: seq, name, unique, origin, partial + if not bool(row[2]): + continue + + origin = row[3] if len(row) > 3 else None + if origin in {"c", "pk"}: + # Explicit unique indexes are reflected by get_indexes(); primary + # key autoindexes are reflected by get_pk_constraint(). + continue + + index_name = row[1] + column_names = self._get_index_column_names(connection, index_name) + unique_constraints_by_sig[tuple(column_names)] = { + "name": None, + "column_names": column_names, + } + + table_sql = self._get_table_sql(connection, table_name) + for constraint_name, column_names in self._parse_unique_constraints(table_sql): + constraint = unique_constraints_by_sig.get(tuple(column_names)) + if constraint is not None: + constraint["name"] = constraint_name + + return list(unique_constraints_by_sig.values()) + def get_indexes( # type: ignore[override] self, connection: Any, table_name: str, schema: Optional[str] = None, **kw: Any ) -> List[Dict[str, Any]]: @@ -647,22 +685,15 @@ def get_indexes( # type: ignore[override] result = connection.execute(query) indexes = [] + include_auto_indexes = kw.get("include_auto_indexes", False) for row in result: # PRAGMA index_list returns: seq, name, unique, origin, partial index_name = row[1] - if index_name.startswith("sqlite_autoindex_"): + if index_name.startswith("sqlite_autoindex_") and not include_auto_indexes: continue # Skip auto-generated indexes # Get column information for this index - col_query = text( - f"PRAGMA index_info({self.identifier_preparer.quote_identifier(index_name)})" - ) - col_result = connection.execute(col_query) - - column_names = [] - for col_row in col_result: - # PRAGMA index_info returns: seqno, cid, name - column_names.append(col_row[2]) + column_names = self._get_index_column_names(connection, index_name) indexes.append( { @@ -673,3 +704,63 @@ def get_indexes( # type: ignore[override] ) return indexes + + def _get_index_column_names(self, connection: Any, index_name: str) -> List[str]: + """Get column names for an index.""" + col_query = text( + f"PRAGMA index_info({self.identifier_preparer.quote_identifier(index_name)})" + ) + col_result = connection.execute(col_query) + + column_names = [] + for col_row in col_result: + # PRAGMA index_info returns: seqno, cid, name + column_names.append(col_row[2]) + + return column_names + + def _get_table_sql(self, connection: Any, table_name: str) -> Optional[str]: + """Get the CREATE TABLE SQL stored in sqlite_master.""" + query = text(""" + SELECT sql FROM sqlite_master + WHERE type='table' AND name=:table_name + """) + row = connection.execute(query, {"table_name": table_name}).fetchone() + return row[0] if row is not None else None + + def _parse_unique_constraints( + self, table_sql: Optional[str] + ) -> List[tuple[Optional[str], List[str]]]: + """Parse unique constraint names and column lists from CREATE TABLE SQL.""" + if table_sql is None: + return [] + + unique_constraints = [] + unique_pattern = re.compile( + r'(?:CONSTRAINT\s+(?:"([^"]+)"|`([^`]+)`|\[([^\]]+)\]|(\w+))\s+)?' + r"UNIQUE\s*\(([^)]+)\)", + re.IGNORECASE, + ) + + for match in unique_pattern.finditer(table_sql): + constraint_name = next( + (group for group in match.group(1, 2, 3, 4) if group), None + ) + column_names = self._parse_column_list(match.group(5)) + unique_constraints.append((constraint_name, column_names)) + + return unique_constraints + + def _parse_column_list(self, column_list_sql: str) -> List[str]: + """Parse a comma-separated identifier list.""" + column_names = [] + for raw_column in column_list_sql.split(","): + column = raw_column.strip() + if ( + (column.startswith('"') and column.endswith('"')) + or (column.startswith("`") and column.endswith("`")) + or (column.startswith("[") and column.endswith("]")) + ): + column = column[1:-1] + column_names.append(column) + return column_names diff --git a/tests/integration/test_restapi_integration.py b/tests/integration/test_restapi_integration.py index 819da68..d8ecf6f 100644 --- a/tests/integration/test_restapi_integration.py +++ b/tests/integration/test_restapi_integration.py @@ -457,6 +457,62 @@ def test_engine_create_table_with_composite_primary_key( finally: metadata.drop_all(d1_engine) + def test_engine_reflects_foreign_keys_and_unique_constraints( + self, d1_engine, test_table_name + ): + """Test SQLAlchemy reflection for D1 foreign keys and unique constraints.""" + from sqlalchemy import ForeignKey, UniqueConstraint, inspect + + metadata = MetaData() + parent_table_name = f"{test_table_name}_parent" + child_table_name = f"{test_table_name}_child" + unique_constraint_name = f"uq_{test_table_name}_tenant_record" + + Table( + parent_table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("slug", String, unique=True), + ) + Table( + child_table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("parent_id", Integer, ForeignKey(f"{parent_table_name}.id")), + Column("tenant_id", String), + Column("record_key", String), + UniqueConstraint( + "tenant_id", + "record_key", + name=unique_constraint_name, + ), + ) + + metadata.create_all(d1_engine) + + try: + inspector = inspect(d1_engine) + foreign_keys = inspector.get_foreign_keys(child_table_name) + unique_constraints = inspector.get_unique_constraints(child_table_name) + parent_unique_constraints = inspector.get_unique_constraints( + parent_table_name + ) + + assert foreign_keys + foreign_key = foreign_keys[0] + assert foreign_key["constrained_columns"] == ["parent_id"] + assert foreign_key["referred_schema"] is None + assert foreign_key["referred_table"] == parent_table_name + assert foreign_key["referred_columns"] == ["id"] + + assert { + "name": unique_constraint_name, + "column_names": ["tenant_id", "record_key"], + } in unique_constraints + assert {"name": None, "column_names": ["slug"]} in parent_unique_constraints + finally: + metadata.drop_all(d1_engine) + def test_engine_insert_and_select(self, d1_engine, test_table_name): """Test INSERT and SELECT using SQLAlchemy ORM-style.""" metadata = MetaData() diff --git a/tests/integration/test_worker_integration.py b/tests/integration/test_worker_integration.py index 390b175..08e4f6f 100644 --- a/tests/integration/test_worker_integration.py +++ b/tests/integration/test_worker_integration.py @@ -213,6 +213,32 @@ def test_sqlalchemy_reflect(self, dev_server): assert "username" in column_names assert "email" in column_names + def test_sqlalchemy_reflect_constraints(self, dev_server): + """Test SQLAlchemy reflection for foreign keys and unique constraints.""" + port = dev_server + response = requests.get( + f"http://localhost:{port}/sqlalchemy-reflect-constraints" + ) + + assert response.status_code == 200 + data = response.json() + + assert data["test"] == "sqlalchemy_reflect_constraints" + assert data["success"] is True + + assert data["foreign_keys"] + foreign_key = data["foreign_keys"][0] + assert foreign_key["constrained_columns"] == ["parent_id"] + assert foreign_key["referred_schema"] is None + assert foreign_key["referred_columns"] == ["id"] + + child_unique = data["unique_constraints"][0] + assert child_unique["column_names"] == ["tenant_id", "record_key"] + assert child_unique["name"] is not None + assert {"name": None, "column_names": ["slug"]} in data[ + "parent_unique_constraints" + ] + # MARK: - Empty Result Set Tests diff --git a/tests/unit/test_dialect.py b/tests/unit/test_dialect.py index 0b2125b..030c565 100644 --- a/tests/unit/test_dialect.py +++ b/tests/unit/test_dialect.py @@ -205,6 +205,98 @@ def test_create_table_composite_primary_key(): assert "PRIMARY KEY (" in sql.upper(), f"composite PK not table-level in: {sql}" +def test_get_foreign_keys_includes_referred_schema(): + """Test that foreign key reflection includes SQLAlchemy's schema key.""" + from sqlalchemy import create_engine, text + + dialect = CloudflareD1Dialect() + engine = create_engine("sqlite://") + + with engine.begin() as conn: + conn.execute(text("CREATE TABLE parent (id INTEGER PRIMARY KEY)")) + conn.execute( + text(""" + CREATE TABLE child ( + id INTEGER PRIMARY KEY, + parent_id INTEGER, + FOREIGN KEY (parent_id) REFERENCES parent(id) + ) + """) + ) + + foreign_keys = dialect.get_foreign_keys(conn, "child") + + assert foreign_keys == [ + { + "name": None, + "constrained_columns": ["parent_id"], + "referred_schema": None, + "referred_table": "parent", + "referred_columns": ["id"], + "options": {"onupdate": "NO ACTION", "ondelete": "NO ACTION"}, + } + ] + + +def test_get_unique_constraints_reflects_inline_and_named_constraints(): + """Test that unique constraint reflection returns SQLAlchemy's shape.""" + from sqlalchemy import create_engine, text + + dialect = CloudflareD1Dialect() + engine = create_engine("sqlite://") + + with engine.begin() as conn: + conn.execute( + text(""" + CREATE TABLE example ( + id INTEGER PRIMARY KEY, + slug TEXT UNIQUE, + tenant_id TEXT, + record_key TEXT, + CONSTRAINT uq_tenant_record UNIQUE (tenant_id, record_key) + ) + """) + ) + + unique_constraints = dialect.get_unique_constraints(conn, "example") + + assert unique_constraints == [ + {"name": "uq_tenant_record", "column_names": ["tenant_id", "record_key"]}, + {"name": None, "column_names": ["slug"]}, + ] + + +def test_get_unique_constraints_excludes_unique_indexes(): + """Test that unique indexes stay in index reflection, not constraints.""" + from sqlalchemy import create_engine, text + + dialect = CloudflareD1Dialect() + engine = create_engine("sqlite://") + + with engine.begin() as conn: + conn.execute( + text(""" + CREATE TABLE example ( + id INTEGER PRIMARY KEY, + name TEXT + ) + """) + ) + conn.execute(text("CREATE UNIQUE INDEX ix_example_name ON example (name)")) + + unique_constraints = dialect.get_unique_constraints(conn, "example") + indexes = dialect.get_indexes(conn, "example") + + assert unique_constraints == [] + assert indexes == [ + { + "name": "ix_example_name", + "column_names": ["name"], + "unique": True, + } + ] + + def test_async_dialect_import(): """Test that the async dialect can be imported.""" from sqlalchemy_cloudflare_d1 import CloudflareD1Dialect_async diff --git a/uv.lock b/uv.lock index 2af36dc..eee09c8 100644 --- a/uv.lock +++ b/uv.lock @@ -922,7 +922,7 @@ wheels = [ [[package]] name = "sqlalchemy-cloudflare-d1" -version = "0.3.10" +version = "0.3.11" source = { editable = "." } dependencies = [ { name = "httpx" },