From 7c844de3336d4c53eba2f2591040af94559d7b35 Mon Sep 17 00:00:00 2001 From: Collier King Date: Tue, 5 May 2026 15:23:21 -0500 Subject: [PATCH] Add UUID and Enum type processors, fix all mypy errors (v0.3.10) Adds D1UUID and D1Enum type processors so SQLAlchemy uuid.UUID objects and Python enum classes round-trip correctly through D1's TEXT storage. Updates colspecs to map Uuid and Enum to their D1-aware subclasses. Also resolves all pre-existing mypy strict-mode errors across all five source files and adds full type annotations throughout the codebase. --- CHANGELOG.md | 11 + examples/workers/src/entry.py | 499 ++++++++++++++++++ examples/workers/uv.lock | 2 +- pyproject.toml | 2 +- src/sqlalchemy_cloudflare_d1/__init__.py | 4 +- src/sqlalchemy_cloudflare_d1/compiler.py | 88 +-- src/sqlalchemy_cloudflare_d1/connection.py | 56 +- src/sqlalchemy_cloudflare_d1/dialect.py | 101 +++- src/sqlalchemy_cloudflare_d1/dialect_async.py | 70 ++- tests/integration/test_restapi_integration.py | 301 +++++++++++ tests/integration/test_worker_integration.py | 127 +++++ uv.lock | 2 +- 12 files changed, 1149 insertions(+), 114 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f06e571..3b2e77d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +## [0.3.10] + +### Added + +- `UUID` and `Enum` column type support ([#24](https://github.com/CollierKing/sqlalchemy-cloudflare-d1/issues/24)) + - Added `D1UUID` type processor that converts `uuid.UUID` objects to hyphenated strings on bind and parses them back on result; supports UUID as primary key and nullable columns + - Added `D1Enum` type processor that converts Python `enum.Enum` objects to their string values on bind and reconstructs them on result; supports both string-value enums and Python enum classes + - Both types store as `TEXT` in D1 (SQLite-compatible) + - Works in both REST API and Worker modes + + ## [0.3.9] ### Added diff --git a/examples/workers/src/entry.py b/examples/workers/src/entry.py index 9e0f85d..4edf443 100644 --- a/examples/workers/src/entry.py +++ b/examples/workers/src/entry.py @@ -145,6 +145,24 @@ async def fetch(self, request, env): return await self.test_time_nullable() elif path == "time-orm": return await self.test_time_orm() + # UUID tests (GitHub issue #24) + elif path == "uuid-basic": + return await self.test_uuid_basic() + elif path == "uuid-nullable": + return await self.test_uuid_nullable() + elif path == "uuid-orm": + return await self.test_uuid_orm() + # Enum tests (GitHub issue #24) + elif path == "enum-basic": + return await self.test_enum_basic() + elif path == "enum-python-class": + return await self.test_enum_python_class() + elif path == "enum-nullable": + return await self.test_enum_nullable() + elif path == "enum-orm": + return await self.test_enum_orm() + elif path == "uuid-pk": + return await self.test_uuid_pk() # Parallel query tests (GitHub issue #20) elif path == "parallel-queries-engine": return await self.test_parallel_queries_engine() @@ -3642,6 +3660,487 @@ class Schedule(Base): status=500, ) + # MARK: - UUID Tests (GitHub issue #24) + + async def test_uuid_basic(self): + """Test UUID column insert and retrieve.""" + import uuid as uuid_module + from sqlalchemy import Column, Integer, MetaData, String, Table, Uuid, select + + table_name = f"test_uuid_{uuid_module.uuid4().hex[:8]}" + + try: + engine = self.get_engine() + metadata = MetaData() + + test_table = Table( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("title", String(127)), + Column("item_uuid", Uuid(as_uuid=True)), + ) + + metadata.create_all(engine) + + test_uuid = uuid_module.uuid4() + + with engine.connect() as conn: + conn.execute( + test_table.insert().values(title="Test", item_uuid=test_uuid) + ) + conn.commit() + + result = conn.execute( + select(test_table.c.title, test_table.c.item_uuid) + ) + row = result.fetchone() + + metadata.drop_all(engine) + + success = ( + row is not None + and row[0] == "Test" + and isinstance(row[1], uuid_module.UUID) + and row[1] == test_uuid + ) + + return Response.json( + { + "test": "uuid_basic", + "success": success, + "uuid_type": type(row[1]).__name__ if row else None, + "uuid_match": row[1] == test_uuid if row else False, + } + ) + except Exception as e: + try: + metadata.drop_all(engine) + except Exception: + pass + return Response.json( + {"test": "uuid_basic", "success": False, "error": str(e)}, + status=500, + ) + + async def test_uuid_nullable(self): + """Test nullable UUID columns handle NULL correctly.""" + import uuid as uuid_module + from sqlalchemy import Column, Integer, MetaData, String, Table, Uuid, select + + table_name = f"test_uuid_null_{uuid_module.uuid4().hex[:8]}" + + try: + engine = self.get_engine() + metadata = MetaData() + + test_table = Table( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("title", String(127)), + Column("item_uuid", Uuid(as_uuid=True), nullable=True), + ) + + metadata.create_all(engine) + + with engine.connect() as conn: + conn.execute( + test_table.insert().values( + title="With UUID", item_uuid=uuid_module.uuid4() + ) + ) + conn.execute( + test_table.insert().values(title="No UUID", item_uuid=None) + ) + conn.commit() + + result = conn.execute( + select(test_table.c.title, test_table.c.item_uuid).order_by( + test_table.c.id + ) + ) + rows = result.fetchall() + + metadata.drop_all(engine) + + success = ( + len(rows) == 2 + and isinstance(rows[0][1], uuid_module.UUID) + and rows[1][1] is None + ) + + return Response.json( + { + "test": "uuid_nullable", + "success": success, + "with_uuid_is_uuid": isinstance(rows[0][1], uuid_module.UUID) + if rows + else False, + "no_uuid_is_none": rows[1][1] is None if len(rows) > 1 else False, + } + ) + except Exception as e: + try: + metadata.drop_all(engine) + except Exception: + pass + return Response.json( + {"test": "uuid_nullable", "success": False, "error": str(e)}, + status=500, + ) + + async def test_uuid_orm(self): + """Test UUID via ORM session.""" + import uuid as uuid_module + from sqlalchemy import Integer, String, Uuid + from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, Session + + engine = create_engine_from_binding(self.env.DB) + table_name = f"test_uuid_orm_{uuid_module.uuid4().hex[:8]}" + + try: + + class Base(DeclarativeBase): + pass + + class Item(Base): + __tablename__ = table_name + id: Mapped[int] = mapped_column(Integer, primary_key=True) + title: Mapped[str] = mapped_column(String(127)) + item_uuid: Mapped[uuid_module.UUID] = mapped_column(Uuid(as_uuid=True)) + + Base.metadata.create_all(engine) + + test_uuid = uuid_module.uuid4() + + with Session(engine) as session: + item = Item(title="UUID ORM Test", item_uuid=test_uuid) + session.add(item) + session.commit() + session.refresh(item) + entry_title = item.title + uuid_is_uuid = isinstance(item.item_uuid, uuid_module.UUID) + uuid_match = item.item_uuid == test_uuid + + Base.metadata.drop_all(engine) + + return Response.json( + { + "test": "uuid_orm", + "success": uuid_is_uuid and uuid_match, + "entry_title": entry_title, + "uuid_is_uuid": uuid_is_uuid, + "uuid_match": uuid_match, + } + ) + except Exception as e: + try: + from sqlalchemy import MetaData, Table + + md = MetaData() + Table(table_name, md) + md.drop_all(engine) + except Exception: + pass + return Response.json( + {"test": "uuid_orm", "success": False, "error": str(e)}, + status=500, + ) + + async def test_uuid_pk(self): + """Test UUID used as primary key.""" + import uuid as uuid_module + from sqlalchemy import Column, MetaData, String, Table, Uuid, select + + table_name = f"test_uuid_pk_{uuid_module.uuid4().hex[:8]}" + + try: + engine = self.get_engine() + metadata = MetaData() + + test_table = Table( + table_name, + metadata, + Column("id", Uuid(as_uuid=True), primary_key=True), + Column("title", String(127)), + ) + + metadata.create_all(engine) + + pk_uuid = uuid_module.uuid4() + + with engine.connect() as conn: + conn.execute(test_table.insert().values(id=pk_uuid, title="PK Test")) + conn.commit() + + result = conn.execute( + select(test_table).where(test_table.c.id == pk_uuid) + ) + row = result.fetchone() + + metadata.drop_all(engine) + + success = ( + row is not None + and isinstance(row[0], uuid_module.UUID) + and row[0] == pk_uuid + and row[1] == "PK Test" + ) + + return Response.json( + { + "test": "uuid_pk", + "success": success, + "pk_is_uuid": isinstance(row[0], uuid_module.UUID) + if row + else False, + "pk_match": row[0] == pk_uuid if row else False, + } + ) + except Exception as e: + try: + metadata.drop_all(engine) + except Exception: + pass + return Response.json( + {"test": "uuid_pk", "success": False, "error": str(e)}, + status=500, + ) + + # MARK: - Enum Tests (GitHub issue #24) + + async def test_enum_basic(self): + """Test Enum column with string values insert and retrieve.""" + import uuid as uuid_module + from sqlalchemy import Column, Enum, Integer, MetaData, String, Table, select + + table_name = f"test_enum_{uuid_module.uuid4().hex[:8]}" + + try: + engine = self.get_engine() + metadata = MetaData() + + test_table = Table( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("title", String(127)), + Column("status", Enum("active", "inactive", "pending")), + ) + + metadata.create_all(engine) + + with engine.connect() as conn: + conn.execute(test_table.insert().values(title="Test", status="active")) + conn.commit() + + result = conn.execute(select(test_table.c.title, test_table.c.status)) + row = result.fetchone() + + metadata.drop_all(engine) + + success = row is not None and row[0] == "Test" and row[1] == "active" + + return Response.json( + { + "test": "enum_basic", + "success": success, + "status_value": row[1] if row else None, + "status_type": type(row[1]).__name__ if row else None, + } + ) + except Exception as e: + try: + metadata.drop_all(engine) + except Exception: + pass + return Response.json( + {"test": "enum_basic", "success": False, "error": str(e)}, + status=500, + ) + + async def test_enum_python_class(self): + """Test Enum column with Python enum.Enum class.""" + import enum as enum_module + import uuid as uuid_module + from sqlalchemy import Column, Enum, Integer, MetaData, String, Table, select + + class Status(enum_module.Enum): + active = "active" + inactive = "inactive" + pending = "pending" + + table_name = f"test_enum_cls_{uuid_module.uuid4().hex[:8]}" + + try: + engine = self.get_engine() + metadata = MetaData() + + test_table = Table( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("title", String(127)), + Column("status", Enum(Status)), + ) + + metadata.create_all(engine) + + with engine.connect() as conn: + conn.execute( + test_table.insert().values(title="Test", status=Status.active) + ) + conn.commit() + + result = conn.execute(select(test_table.c.title, test_table.c.status)) + row = result.fetchone() + + metadata.drop_all(engine) + + success = row is not None and row[0] == "Test" and row[1] == Status.active + + return Response.json( + { + "test": "enum_python_class", + "success": success, + "status_is_enum": isinstance(row[1], Status) if row else False, + "status_value": row[1].value + if row and isinstance(row[1], Status) + else str(row[1]) + if row + else None, + } + ) + except Exception as e: + try: + metadata.drop_all(engine) + except Exception: + pass + return Response.json( + {"test": "enum_python_class", "success": False, "error": str(e)}, + status=500, + ) + + async def test_enum_nullable(self): + """Test nullable Enum columns handle NULL correctly.""" + import uuid as uuid_module + from sqlalchemy import Column, Enum, Integer, MetaData, String, Table, select + + table_name = f"test_enum_null_{uuid_module.uuid4().hex[:8]}" + + try: + engine = self.get_engine() + metadata = MetaData() + + test_table = Table( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("title", String(127)), + Column("status", Enum("active", "inactive"), nullable=True), + ) + + metadata.create_all(engine) + + with engine.connect() as conn: + conn.execute( + test_table.insert().values(title="With Status", status="active") + ) + conn.execute(test_table.insert().values(title="No Status", status=None)) + conn.commit() + + result = conn.execute( + select(test_table.c.title, test_table.c.status).order_by( + test_table.c.id + ) + ) + rows = result.fetchall() + + metadata.drop_all(engine) + + success = len(rows) == 2 and rows[0][1] == "active" and rows[1][1] is None + + return Response.json( + { + "test": "enum_nullable", + "success": success, + "with_status_value": rows[0][1] if rows else None, + "no_status_is_none": rows[1][1] is None if len(rows) > 1 else False, + } + ) + except Exception as e: + try: + metadata.drop_all(engine) + except Exception: + pass + return Response.json( + {"test": "enum_nullable", "success": False, "error": str(e)}, + status=500, + ) + + async def test_enum_orm(self): + """Test Enum via ORM session with Python enum class.""" + import enum as enum_module + import uuid as uuid_module + from sqlalchemy import Enum, Integer, String + from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, Session + + class Priority(enum_module.Enum): + low = "low" + medium = "medium" + high = "high" + + engine = create_engine_from_binding(self.env.DB) + table_name = f"test_enum_orm_{uuid_module.uuid4().hex[:8]}" + + try: + + class Base(DeclarativeBase): + pass + + class Task(Base): + __tablename__ = table_name + id: Mapped[int] = mapped_column(Integer, primary_key=True) + title: Mapped[str] = mapped_column(String(127)) + priority: Mapped[Priority] = mapped_column(Enum(Priority)) + + Base.metadata.create_all(engine) + + with Session(engine) as session: + task = Task(title="Enum ORM Task", priority=Priority.high) + session.add(task) + session.commit() + session.refresh(task) + entry_title = task.title + priority_is_enum = isinstance(task.priority, Priority) + priority_match = task.priority == Priority.high + + Base.metadata.drop_all(engine) + + return Response.json( + { + "test": "enum_orm", + "success": priority_is_enum and priority_match, + "entry_title": entry_title, + "priority_is_enum": priority_is_enum, + "priority_match": priority_match, + } + ) + except Exception as e: + try: + from sqlalchemy import MetaData, Table + + md = MetaData() + Table(table_name, md) + md.drop_all(engine) + except Exception: + pass + return Response.json( + {"test": "enum_orm", "success": False, "error": str(e)}, + status=500, + ) + # MARK: - Parallel Query Tests (GitHub issue #20) async def test_parallel_queries_engine(self): diff --git a/examples/workers/uv.lock b/examples/workers/uv.lock index d70f25f..0b743fb 100644 --- a/examples/workers/uv.lock +++ b/examples/workers/uv.lock @@ -614,7 +614,7 @@ wheels = [ [[package]] name = "sqlalchemy-cloudflare-d1" -version = "0.3.8" +version = "0.3.10" source = { editable = "../../" } dependencies = [ { name = "httpx" }, diff --git a/pyproject.toml b/pyproject.toml index f359105..39aea16 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "sqlalchemy-cloudflare-d1" -version = "0.3.9" +version = "0.3.10" description = "A SQLAlchemy dialect for Cloudflare's D1 Serverless SQLite Database" readme = "README.md" authors = [ diff --git a/src/sqlalchemy_cloudflare_d1/__init__.py b/src/sqlalchemy_cloudflare_d1/__init__.py index 84b8776..78329c6 100644 --- a/src/sqlalchemy_cloudflare_d1/__init__.py +++ b/src/sqlalchemy_cloudflare_d1/__init__.py @@ -15,6 +15,8 @@ engine = create_async_engine("cloudflare_d1+async://account_id:api_token@database_id") """ +from typing import Any + from .dialect import CloudflareD1Dialect from .connection import ( # Sync classes @@ -47,7 +49,7 @@ ) -def __getattr__(name: str): +def __getattr__(name: str) -> Any: """Lazy import for async dialect to avoid requiring greenlet at import time. The async SQLAlchemy dialect (CloudflareD1Dialect_async) requires greenlet diff --git a/src/sqlalchemy_cloudflare_d1/compiler.py b/src/sqlalchemy_cloudflare_d1/compiler.py index 8e92d84..68e22c6 100644 --- a/src/sqlalchemy_cloudflare_d1/compiler.py +++ b/src/sqlalchemy_cloudflare_d1/compiler.py @@ -5,6 +5,8 @@ including support for INSERT ... ON CONFLICT DO UPDATE (upsert). """ +from typing import Any, NoReturn + from sqlalchemy.dialects.sqlite.base import ( SQLiteCompiler, SQLiteDDLCompiler, @@ -26,7 +28,7 @@ class CloudflareD1Compiler(SQLiteCompiler): including ON CONFLICT DO UPDATE support for upserts. """ - def limit_clause(self, select, **kw): + def limit_clause(self, select: Any, **kw: Any) -> str: """Handle LIMIT clause for SQLite.""" text = "" if select._limit_clause is not None: @@ -38,29 +40,29 @@ def limit_clause(self, select, **kw): text += " OFFSET " + self.process(select._offset_clause, **kw) return text - def visit_true(self, element, **kw): + def visit_true(self, element: Any, **kw: Any) -> str: # type: ignore[override] """Handle boolean TRUE.""" return "1" - def visit_false(self, element, **kw): + def visit_false(self, element: Any, **kw: Any) -> str: # type: ignore[override] """Handle boolean FALSE.""" return "0" - def visit_mod_binary(self, binary, operator, **kw): + def visit_mod_binary(self, binary: Any, operator: Any, **kw: Any) -> str: """Handle modulo operator.""" return ( self.process(binary.left, **kw) + " % " + self.process(binary.right, **kw) ) - def visit_now_func(self, fn, **kw): + def visit_now_func(self, fn: Any, **kw: Any) -> str: """Handle CURRENT_TIMESTAMP function.""" return "CURRENT_TIMESTAMP" - def visit_char_length_func(self, fn, **kw): + def visit_char_length_func(self, fn: Any, **kw: Any) -> str: """Handle CHAR_LENGTH function.""" return "length" + self.function_argspec(fn, **kw) - def visit_cast(self, cast, **kw): + def visit_cast(self, cast: Any, **kw: Any) -> str: """Handle CAST operations.""" type_ = cast.typeclause.type @@ -80,7 +82,7 @@ def visit_cast(self, cast, **kw): return "CAST(%s AS %s)" % (self.process(cast.clause, **kw), sqlite_type) - def visit_extract(self, extract, **kw): + def visit_extract(self, extract: Any, **kw: Any) -> str: """Handle EXTRACT function.""" field = extract.field expr = self.process(extract.expr, **kw) @@ -105,14 +107,18 @@ def visit_extract(self, extract, **kw): else: return f"strftime('%{field}', {expr})" - def visit_regexp_match_op_binary(self, binary, operator, **kw): + def visit_regexp_match_op_binary( + self, binary: Any, operator: Any, **kw: Any + ) -> str: """Handle REGEXP operator.""" return "%s REGEXP %s" % ( self.process(binary.left, **kw), self.process(binary.right, **kw), ) - def visit_regexp_replace_op_binary(self, binary, operator, **kw): + def visit_regexp_replace_op_binary( + self, binary: Any, operator: Any, **kw: Any + ) -> NoReturn: """Handle REGEXP_REPLACE (not natively supported in SQLite).""" # SQLite doesn't have native REGEXP_REPLACE, would need custom function raise NotImplementedError("REGEXP_REPLACE not supported in SQLite/D1") @@ -126,7 +132,9 @@ class CloudflareD1DDLCompiler(SQLiteDDLCompiler): D1 only supports it on INTEGER PRIMARY KEY, not TEXT PRIMARY KEY. """ - def get_column_specification(self, column, first_pk=False, **kwargs): + def get_column_specification( + self, column: Any, first_pk: bool = False, **kwargs: Any + ) -> str: """Get column specification for CREATE TABLE. Overrides SQLite's implementation to never add AUTOINCREMENT, @@ -161,8 +169,8 @@ def get_column_specification(self, column, first_pk=False, **kwargs): return colspec def create_table_constraints( - self, table, _include_foreign_key_constraints=None, **kw - ): + self, table: Any, _include_foreign_key_constraints: Any = None, **kw: Any + ) -> str: """Create table constraints, skipping PK if it was added inline. For single-column primary keys, we add PRIMARY KEY inline in @@ -203,7 +211,7 @@ def create_table_constraints( if p is not None ) - def visit_drop_table(self, drop, **kw): + def visit_drop_table(self, drop: Any, **kw: Any) -> str: """Handle DROP TABLE statements.""" text = "\nDROP TABLE " if drop.if_exists: @@ -211,7 +219,7 @@ def visit_drop_table(self, drop, **kw): text += self.preparer.format_table(drop.element) return text - def visit_create_index(self, create, **kw): + def visit_create_index(self, create: Any, **kw: Any) -> str: # type: ignore[override] """Handle CREATE INDEX statements.""" index = create.element preparer = self.preparer @@ -233,7 +241,7 @@ def visit_create_index(self, create, **kw): return text - def visit_drop_index(self, drop, **kw): + def visit_drop_index(self, drop: Any, **kw: Any) -> str: """Handle DROP INDEX statements.""" text = "\nDROP INDEX " if drop.if_exists: @@ -248,41 +256,41 @@ class CloudflareD1TypeCompiler(SQLiteTypeCompiler): Inherits from SQLiteTypeCompiler to get proper SQLite type compilation. """ - def visit_TEXT(self, type_, **kw): + def visit_TEXT(self, type_: Any, **kw: Any) -> str: """Handle TEXT type.""" return "TEXT" - def visit_STRING(self, type_, **kw): + def visit_STRING(self, type_: Any, **kw: Any) -> str: """Handle STRING/VARCHAR type.""" if type_.length: return f"VARCHAR({type_.length})" return "TEXT" - def visit_VARCHAR(self, type_, **kw): + def visit_VARCHAR(self, type_: Any, **kw: Any) -> str: """Handle VARCHAR type.""" if type_.length: return f"VARCHAR({type_.length})" return "TEXT" - def visit_CHAR(self, type_, **kw): + def visit_CHAR(self, type_: Any, **kw: Any) -> str: """Handle CHAR type.""" if type_.length: return f"CHAR({type_.length})" return "TEXT" - def visit_INTEGER(self, type_, **kw): + def visit_INTEGER(self, type_: Any, **kw: Any) -> str: """Handle INTEGER type.""" return "INTEGER" - def visit_BIGINT(self, type_, **kw): + def visit_BIGINT(self, type_: Any, **kw: Any) -> str: """Handle BIGINT type.""" return "INTEGER" # SQLite treats all integers the same - def visit_SMALLINT(self, type_, **kw): + def visit_SMALLINT(self, type_: Any, **kw: Any) -> str: """Handle SMALLINT type.""" return "INTEGER" # SQLite treats all integers the same - def visit_NUMERIC(self, type_, **kw): + def visit_NUMERIC(self, type_: Any, **kw: Any) -> str: """Handle NUMERIC type.""" if type_.precision is not None and type_.scale is not None: return f"NUMERIC({type_.precision}, {type_.scale})" @@ -290,42 +298,54 @@ def visit_NUMERIC(self, type_, **kw): return f"NUMERIC({type_.precision})" return "NUMERIC" - def visit_DECIMAL(self, type_, **kw): + def visit_DECIMAL(self, type_: Any, **kw: Any) -> str: """Handle DECIMAL type.""" return self.visit_NUMERIC(type_, **kw) - def visit_REAL(self, type_, **kw): + def visit_REAL(self, type_: Any, **kw: Any) -> str: """Handle REAL type.""" return "REAL" - def visit_FLOAT(self, type_, **kw): + def visit_FLOAT(self, type_: Any, **kw: Any) -> str: """Handle FLOAT type.""" return "REAL" # SQLite uses REAL for floating point - def visit_BOOLEAN(self, type_, **kw): + def visit_BOOLEAN(self, type_: Any, **kw: Any) -> str: """Handle BOOLEAN type.""" return "INTEGER" # SQLite stores boolean as INTEGER - def visit_DATE(self, type_, **kw): + def visit_DATE(self, type_: Any, **kw: Any) -> str: """Handle DATE type.""" return "TEXT" # SQLite stores dates as TEXT - def visit_TIME(self, type_, **kw): + def visit_TIME(self, type_: Any, **kw: Any) -> str: """Handle TIME type.""" return "TEXT" # SQLite stores times as TEXT - def visit_DATETIME(self, type_, **kw): + def visit_DATETIME(self, type_: Any, **kw: Any) -> str: """Handle DATETIME type.""" return "TEXT" # SQLite stores datetimes as TEXT - def visit_TIMESTAMP(self, type_, **kw): + def visit_TIMESTAMP(self, type_: Any, **kw: Any) -> str: """Handle TIMESTAMP type.""" return "TEXT" # SQLite stores timestamps as TEXT - def visit_BLOB(self, type_, **kw): + def visit_BLOB(self, type_: Any, **kw: Any) -> str: """Handle BLOB type.""" return "BLOB" - def visit_CLOB(self, type_, **kw): + def visit_CLOB(self, type_: Any, **kw: Any) -> str: """Handle CLOB type.""" return "TEXT" + + def visit_UUID(self, type_: Any, **kw: Any) -> str: + """Handle UUID type — D1 stores as TEXT.""" + return "TEXT" + + def visit_uuid(self, type_: Any, **kw: Any) -> str: + """Handle Uuid type — D1 stores as TEXT.""" + return "TEXT" + + def visit_enum(self, type_: Any, **kw: Any) -> str: + """Handle Enum type — D1 stores as TEXT.""" + return "TEXT" diff --git a/src/sqlalchemy_cloudflare_d1/connection.py b/src/sqlalchemy_cloudflare_d1/connection.py index bd67b48..005dce3 100644 --- a/src/sqlalchemy_cloudflare_d1/connection.py +++ b/src/sqlalchemy_cloudflare_d1/connection.py @@ -7,7 +7,7 @@ """ import os -from typing import Any, Dict, List, Optional, Sequence, Union +from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union try: import httpx @@ -243,33 +243,33 @@ def __getitem__(self, key: Union[int, str]) -> Any: else: raise TypeError("Key must be int or str") - def __iter__(self): + def __iter__(self) -> Iterator[Any]: """Iterate over values.""" return iter(self._values) - def __len__(self): + def __len__(self) -> int: """Get number of columns.""" return len(self._values) - def __bool__(self): + def __bool__(self) -> bool: """Check if row has data.""" return bool(self._values) - def __repr__(self): + def __repr__(self) -> str: """String representation of the row.""" return f"Row({dict(zip(self._keys, self._values))})" - def keys(self): + def keys(self) -> List[str]: """Get column names.""" return self._keys - def values(self): + def values(self) -> List[Any]: """Get values.""" return self._values - def items(self): + def items(self) -> Iterator[Tuple[str, Any]]: """Get (key, value) pairs.""" - return zip(self._keys, self._values) + return zip(self._keys, self._values) # type: ignore[return-value] # Add attribute access for compatibility def __getattr__(self, name: str) -> Any: @@ -404,11 +404,11 @@ def lastrowid(self) -> Optional[int]: """Get the ID of the last inserted row.""" return self._last_result_meta.get("last_row_id") - def __iter__(self): + def __iter__(self) -> "BaseCursorMixin": """Make cursor iterable.""" return self - def __next__(self): + def __next__(self) -> tuple: """Get next row for iteration.""" row = self.fetchone() if row is None: @@ -464,7 +464,9 @@ def executemany( class Connection: """DBAPI-compatible connection for Cloudflare D1 REST API.""" - def __init__(self, account_id: str, database_id: str, api_token: str, **kwargs): + def __init__( + self, account_id: str, database_id: str, api_token: str, **kwargs: Any + ) -> None: """Initialize D1 connection via REST API.""" if not HTTPX_AVAILABLE: raise ImportError( @@ -519,7 +521,7 @@ def rollback(self) -> None: # D1 doesn't support explicit transactions via REST API pass - def execute(self, operation: str, parameters: Optional[Sequence] = None): + def execute(self, operation: str, parameters: Optional[Sequence] = None) -> Cursor: """Execute operation directly on connection (convenience method).""" cursor = self.cursor() cursor.execute(operation, parameters) @@ -778,7 +780,7 @@ class CloudflareD1DBAPI: NotSupportedError = NotSupportedError @staticmethod - def connect(**kwargs) -> Connection: + def connect(**kwargs: Any) -> Connection: """Create a new database connection.""" return Connection(**kwargs) @@ -801,7 +803,7 @@ def Binary(data: bytes) -> bytes: paramstyle = CloudflareD1DBAPI.paramstyle -def connect(**kwargs) -> Connection: +def connect(**kwargs: Any) -> Connection: """Create a new database connection.""" return CloudflareD1DBAPI.connect(**kwargs) @@ -821,7 +823,9 @@ class AsyncConnection: rows = await cursor.fetchall() """ - def __init__(self, account_id: str, database_id: str, api_token: str, **kwargs): + def __init__( + self, account_id: str, database_id: str, api_token: str, **kwargs: Any + ) -> None: """Initialize async D1 connection via REST API.""" if not HTTPX_AVAILABLE: raise ImportError( @@ -858,7 +862,7 @@ async def __aenter__(self) -> "AsyncConnection": """Async context manager entry.""" return self - async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: """Async context manager exit.""" await self.close() @@ -975,7 +979,7 @@ async def __aenter__(self) -> "AsyncCursor": """Async context manager entry.""" return self - async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: """Async context manager exit.""" await self.close() @@ -1052,7 +1056,7 @@ async def close(self) -> None: # type: ignore[override] BaseCursorMixin.close(self) -async def connect_async(**kwargs) -> AsyncConnection: +async def connect_async(**kwargs: Any) -> AsyncConnection: """Create a new async database connection.""" return AsyncConnection(**kwargs) @@ -1088,7 +1092,7 @@ def __init__(self, d1_binding: Any): """Store the D1 binding for later use in connect().""" self._d1_binding = d1_binding - def connect(self, **kwargs) -> "SyncWorkerConnection": + def connect(self, **kwargs: Any) -> "SyncWorkerConnection": """Create a new database connection using the stored D1 binding.""" return SyncWorkerConnection(self._d1_binding) @@ -1153,15 +1157,15 @@ def _execute_query( try: # Import pyodide to use run_sync for async operations # This is available in Cloudflare Python Workers - from pyodide.ffi import run_sync + from pyodide.ffi import run_sync # type: ignore[import-not-found] - async def _run(): - from js import JSON + async def _run() -> Dict[str, Any]: + from js import JSON # type: ignore[import-not-found] # Get a proper JS null value (not undefined) JS_NULL = JSON.parse("null") - def convert_param(val): + def convert_param(val: Any) -> Any: """Convert parameter for D1 binding, handling None -> null.""" if val is None: return JS_NULL @@ -1226,7 +1230,7 @@ def convert_param(val): parsed["columns"] = fallback_columns return parsed - return run_sync(_run()) + return run_sync(_run()) # type: ignore[no-any-return] except ImportError: raise NotSupportedError( @@ -1289,7 +1293,7 @@ def executemany( # MARK: - Engine Factory -def create_engine_from_binding(d1_binding: Any, **kwargs) -> Any: +def create_engine_from_binding(d1_binding: Any, **kwargs: Any) -> Any: """Create a SQLAlchemy engine from a D1 Worker binding. This allows using SQLAlchemy Core and ORM patterns inside Cloudflare diff --git a/src/sqlalchemy_cloudflare_d1/dialect.py b/src/sqlalchemy_cloudflare_d1/dialect.py index e9da755..f88181d 100644 --- a/src/sqlalchemy_cloudflare_d1/dialect.py +++ b/src/sqlalchemy_cloudflare_d1/dialect.py @@ -3,6 +3,8 @@ """ import base64 +import enum as enum_module +import uuid as uuid_module from datetime import date, datetime, time from typing import Any, Callable, Dict, List, Optional @@ -11,6 +13,7 @@ from sqlalchemy.sql.sqltypes import ( Boolean, DateTime, + Enum, INTEGER, LargeBinary, NUMERIC, @@ -18,6 +21,7 @@ TEXT, Date, Time, + Uuid, ) from sqlalchemy import text @@ -88,7 +92,7 @@ def process(value: Any) -> Optional[str]: return None if isinstance(value, bytes): return base64.b64encode(value).decode("ascii") - return value + return value # type: ignore[no-any-return] return process @@ -109,7 +113,7 @@ def process(value: Any) -> Optional[bytes]: except Exception: # If not valid base64, return as encoded bytes return value.encode("utf-8") - return value + return value # type: ignore[no-any-return] return process @@ -155,8 +159,8 @@ def process(value: Any) -> Optional[date]: try: return date.fromisoformat(value) except ValueError: - return value - return value + return value # type: ignore[return-value] + return value # type: ignore[no-any-return] return process @@ -202,8 +206,8 @@ def process(value: Any) -> Optional[time]: try: return time.fromisoformat(value) except ValueError: - return value - return value + return value # type: ignore[return-value] + return value # type: ignore[no-any-return] return process @@ -248,6 +252,48 @@ def process(value: Any) -> Optional[datetime]: if isinstance(value, str): try: return datetime.fromisoformat(value) + except ValueError: + return value # type: ignore[return-value] + return value # type: ignore[no-any-return] + + return process + + +# MARK: - UUID Type Processor + + +class D1UUID(Uuid): + """Custom UUID type for Cloudflare D1. + + D1 stores UUIDs as TEXT. This processor converts uuid.UUID objects to + hyphenated string format on bind and parses them back to uuid.UUID on result. + """ + + def bind_processor(self, dialect: Dialect) -> Callable[[Any], Optional[str]]: + """Convert uuid.UUID to hyphenated string for D1.""" + + def process(value: Any) -> Optional[str]: + if value is None: + return None + if isinstance(value, uuid_module.UUID): + return str(value) + return str(value) + + return process + + def result_processor(self, dialect: Dialect, coltype: Any) -> Callable[[Any], Any]: + """Convert string from D1 back to uuid.UUID (or str if as_uuid=False).""" + as_uuid = self.as_uuid + + def process(value: Any) -> Any: + if value is None: + return None + if isinstance(value, uuid_module.UUID): + return value if as_uuid else str(value) + if isinstance(value, str): + try: + parsed = uuid_module.UUID(value) + return parsed if as_uuid else value except ValueError: return value return value @@ -255,6 +301,37 @@ def process(value: Any) -> Optional[datetime]: return process +# MARK: - Enum Type Processor + + +class D1Enum(Enum): + """Custom Enum type for Cloudflare D1. + + D1 stores enum values as TEXT. This processor ensures Python enum.Enum + objects are converted to their string values on bind. Result processing + delegates to the parent Enum type for enum reconstruction. + """ + + def bind_processor(self, dialect: Dialect) -> Callable[[Any], Optional[str]]: + """Convert Python enum (or any value) to string for D1.""" + + def process(value: Any) -> Optional[str]: + if value is None: + return None + if isinstance(value, enum_module.Enum): + return str(value.value) + return str(value) + + return process + + def result_processor( + self, dialect: Dialect, coltype: Any + ) -> Optional[Callable[[Any], Any]]: + """Delegate enum reconstruction to parent Enum type.""" + parent = super().result_processor(dialect, coltype) + return parent # type: ignore[no-any-return] + + # MARK: - Dialect @@ -292,8 +369,10 @@ class CloudflareD1Dialect(default.DefaultDialect): Boolean: D1Boolean, Date: D1Date, DateTime: D1DateTime, + Enum: D1Enum, LargeBinary: D1LargeBinary, Time: D1Time, + Uuid: D1UUID, } # Reserved words (SQLite keywords) @@ -443,7 +522,7 @@ def create_connect_args(self, url: Any) -> tuple: return (), opts - def get_isolation_level(self, connection: Any) -> Optional[str]: + def get_isolation_level(self, connection: Any) -> Optional[str]: # type: ignore[override] """D1 doesn't support isolation levels.""" return None @@ -476,7 +555,7 @@ def has_table( ) return bool(result.fetchone()) - def get_columns( + def get_columns( # type: ignore[override] self, connection: Any, table_name: str, schema: Optional[str] = None, **kw: Any ) -> List[Dict[str, Any]]: """Get column information for a table.""" @@ -518,7 +597,7 @@ def _get_column_type(self, type_string: str) -> Any: else: return TEXT() # Default to TEXT for unknown types - def get_pk_constraint( + def get_pk_constraint( # type: ignore[override] self, connection: Any, table_name: str, schema: Optional[str] = None, **kw: Any ) -> Dict[str, Any]: """Get primary key constraint information.""" @@ -530,7 +609,7 @@ def get_pk_constraint( "name": None, # SQLite doesn't name PK constraints } - def get_foreign_keys( + def get_foreign_keys( # type: ignore[override] self, connection: Any, table_name: str, schema: Optional[str] = None, **kw: Any ) -> List[Dict[str, Any]]: """Get foreign key constraints.""" @@ -558,7 +637,7 @@ def get_foreign_keys( return list(fks.values()) - def get_indexes( + def get_indexes( # type: ignore[override] self, connection: Any, table_name: str, schema: Optional[str] = None, **kw: Any ) -> List[Dict[str, Any]]: """Get index information.""" diff --git a/src/sqlalchemy_cloudflare_d1/dialect_async.py b/src/sqlalchemy_cloudflare_d1/dialect_async.py index 7b5a64d..7ead83b 100644 --- a/src/sqlalchemy_cloudflare_d1/dialect_async.py +++ b/src/sqlalchemy_cloudflare_d1/dialect_async.py @@ -17,7 +17,7 @@ """ from collections import deque -from typing import Any, Optional, Sequence +from typing import Any, Iterator, List, NoReturn, Optional, Sequence, Type from sqlalchemy.engine import AdaptedConnection from sqlalchemy.pool import AsyncAdaptedQueuePool @@ -62,13 +62,15 @@ def __init__(self, adapt_connection: "AsyncAdapt_d1_connection"): self.rowcount = -1 self.lastrowid = None self.description = None - self._rows = deque() + self._rows: deque[Any] = deque() - def close(self): + def close(self) -> None: """Close the cursor - just clear local rows.""" self._rows.clear() - def execute(self, operation: str, parameters: Optional[Sequence] = None): + def execute( + self, operation: str, parameters: Optional[Sequence] = None + ) -> "AsyncAdapt_d1_cursor": """Execute a database operation. Uses await_ to run async operations. Eagerly fetches all results @@ -91,12 +93,12 @@ def execute(self, operation: str, parameters: Optional[Sequence] = None): if is_select: # For SELECT statements, set description (may be empty list for no-column results) # D1 returns [] for empty results since it can't know column names - self.description = _cursor.description if _cursor.description else [] + self.description = _cursor.description if _cursor.description else [] # type: ignore[assignment] self.lastrowid = None self.rowcount = -1 # Eagerly fetch all results into local deque rows = self.await_(_cursor.fetchall()) - self._rows = deque(rows if rows else []) + self._rows = deque(rows if rows else []) # type: ignore[assignment] else: # For non-SELECT statements (INSERT, UPDATE, DELETE) self.description = None @@ -111,7 +113,9 @@ def execute(self, operation: str, parameters: Optional[Sequence] = None): except Exception as error: self._adapt_connection._handle_exception(error) - def executemany(self, operation: str, seq_of_parameters: Sequence[Sequence]): + def executemany( + self, operation: str, seq_of_parameters: Sequence[Sequence] + ) -> "AsyncAdapt_d1_cursor": """Execute operation multiple times.""" try: _cursor = self.await_(self._connection.cursor()) @@ -124,38 +128,38 @@ def executemany(self, operation: str, seq_of_parameters: Sequence[Sequence]): except Exception as error: self._adapt_connection._handle_exception(error) - def setinputsizes(self, *inputsizes): + def setinputsizes(self, *inputsizes: Any) -> None: """No-op for D1.""" pass - def setoutputsize(self, size, column=None): + def setoutputsize(self, size: Any, column: Any = None) -> None: """No-op for D1.""" pass - def __iter__(self): + def __iter__(self) -> Iterator[Any]: """Iterate over results from local deque.""" while self._rows: yield self._rows.popleft() - def fetchone(self): + def fetchone(self) -> Optional[Any]: """Fetch next row from local deque.""" if self._rows: return self._rows.popleft() return None - def fetchmany(self, size=None): + def fetchmany(self, size: Optional[int] = None) -> List[Any]: """Fetch multiple rows from local deque.""" if size is None: size = self.arraysize return [self._rows.popleft() for _ in range(min(size, len(self._rows)))] - def fetchall(self): + def fetchall(self) -> List[Any]: """Fetch all remaining rows from local deque.""" retval = list(self._rows) self._rows.clear() return retval - async def _async_soft_close(self): + async def _async_soft_close(self) -> None: """Async soft close for SQLAlchemy async result compatibility. This is called by SQLAlchemy after execute() completes but BEFORE @@ -183,31 +187,31 @@ def __init__(self, dbapi: "AsyncAdapt_d1_dbapi", connection: AsyncConnection): self.dbapi = dbapi self._connection = connection - def cursor(self): + def cursor(self) -> AsyncAdapt_d1_cursor: """Create a cursor.""" return AsyncAdapt_d1_cursor(self) - def execute(self, *args, **kw): + def execute(self, *args: Any, **kw: Any) -> AsyncAdapt_d1_cursor: """Execute directly on connection.""" cursor = self.cursor() cursor.execute(*args, **kw) return cursor - def rollback(self): + def rollback(self) -> None: """Rollback transaction (no-op for D1).""" try: self.await_(self._connection.rollback()) except Exception as error: self._handle_exception(error) - def commit(self): + def commit(self) -> None: """Commit transaction (no-op for D1).""" try: self.await_(self._connection.commit()) except Exception as error: self._handle_exception(error) - def close(self): + def close(self) -> None: """Close the connection.""" try: self.await_(self._connection.close()) @@ -215,11 +219,11 @@ def close(self): self._handle_exception(error) @property - def closed(self): + def closed(self) -> bool: """Check if connection is closed.""" - return self._connection.closed + return self._connection.closed # type: ignore[no-any-return] - def _handle_exception(self, error): + def _handle_exception(self, error: Exception) -> NoReturn: """Handle and re-raise exceptions appropriately.""" if isinstance( error, (Error, InterfaceError, OperationalError, ProgrammingError) @@ -255,19 +259,7 @@ class AsyncAdapt_d1_dbapi: NotSupportedError, ) - # Make them class attributes - Error = Error - Warning = Warning - InterfaceError = InterfaceError - DatabaseError = DatabaseError - DataError = DataError - OperationalError = OperationalError - IntegrityError = IntegrityError - InternalError = InternalError - ProgrammingError = ProgrammingError - NotSupportedError = NotSupportedError - - def connect(self, **kwargs) -> AsyncAdapt_d1_connection: + def connect(self, **kwargs: Any) -> AsyncAdapt_d1_connection: """Create an async-adapted connection. Creates the AsyncConnection and wraps it. The await_only call @@ -297,7 +289,7 @@ def Binary(data: bytes) -> bytes: _dbapi_singleton = None -def _get_dbapi(): +def _get_dbapi() -> "AsyncAdapt_d1_dbapi": """Get or create the DBAPI singleton.""" global _dbapi_singleton if _dbapi_singleton is None: @@ -327,15 +319,15 @@ def import_dbapi(cls) -> Any: return _get_dbapi() @classmethod - def get_pool_class(cls, url): + def get_pool_class(cls, url: Any) -> Type[AsyncAdaptedQueuePool]: """Return the async pool class.""" return AsyncAdaptedQueuePool - def get_driver_connection(self, connection): + def get_driver_connection(self, connection: Any) -> Any: """Get the underlying driver connection.""" return connection._connection - def is_disconnect(self, e, connection, cursor): + def is_disconnect(self, e: Exception, connection: Any, cursor: Any) -> bool: """Check if exception indicates a disconnected state.""" if isinstance(e, OperationalError): msg = str(e).lower() diff --git a/tests/integration/test_restapi_integration.py b/tests/integration/test_restapi_integration.py index 761af46..89b27d1 100644 --- a/tests/integration/test_restapi_integration.py +++ b/tests/integration/test_restapi_integration.py @@ -3015,5 +3015,306 @@ def test_time_filter_query(self, d1_engine, test_table_name): metadata.drop_all(d1_engine) +# MARK: - UUID Column Tests + + +class TestUUIDColumn: + """Test UUID column handling. + + D1 stores UUIDs as TEXT. The D1UUID type processor converts uuid.UUID + objects to hyphenated strings on bind and parses them back on result. + """ + + def test_uuid_insert_and_retrieve(self, d1_engine, test_table_name): + """Test that UUID columns can store and retrieve uuid.UUID values.""" + import uuid as uuid_module + from sqlalchemy import Uuid + + metadata = MetaData() + test_table = Table( + test_table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("title", String(127)), + Column("item_uuid", Uuid(as_uuid=True)), + ) + + metadata.create_all(d1_engine) + + try: + test_uuid = uuid_module.uuid4() + + with d1_engine.connect() as conn: + conn.execute( + test_table.insert().values(title="Test", item_uuid=test_uuid) + ) + conn.commit() + + result = conn.execute( + select(test_table.c.title, test_table.c.item_uuid) + ) + row = result.fetchone() + + assert row is not None + assert row[0] == "Test" + assert isinstance(row[1], uuid_module.UUID) + assert row[1] == test_uuid + finally: + metadata.drop_all(d1_engine) + + def test_uuid_nullable(self, d1_engine, test_table_name): + """Test nullable UUID columns handle NULL correctly.""" + import uuid as uuid_module + from sqlalchemy import Uuid + + metadata = MetaData() + test_table = Table( + test_table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("title", String(127)), + Column("item_uuid", Uuid(as_uuid=True), nullable=True), + ) + + metadata.create_all(d1_engine) + + try: + with d1_engine.connect() as conn: + conn.execute( + test_table.insert().values(title="No UUID", item_uuid=None) + ) + conn.execute( + test_table.insert().values( + title="With UUID", item_uuid=uuid_module.uuid4() + ) + ) + conn.commit() + + result = conn.execute( + select(test_table.c.title, test_table.c.item_uuid) + ) + rows = result.fetchall() + + assert len(rows) == 2 + assert rows[0][1] is None + assert isinstance(rows[1][1], uuid_module.UUID) + finally: + metadata.drop_all(d1_engine) + + def test_uuid_as_primary_key(self, d1_engine, test_table_name): + """Test UUID used as primary key.""" + import uuid as uuid_module + from sqlalchemy import Uuid + + metadata = MetaData() + test_table = Table( + test_table_name, + metadata, + Column("id", Uuid(as_uuid=True), primary_key=True), + Column("title", String(127)), + ) + + metadata.create_all(d1_engine) + + try: + pk_uuid = uuid_module.uuid4() + + with d1_engine.connect() as conn: + conn.execute(test_table.insert().values(id=pk_uuid, title="PK Test")) + conn.commit() + + result = conn.execute( + select(test_table).where(test_table.c.id == pk_uuid) + ) + row = result.fetchone() + + assert row is not None + assert isinstance(row[0], uuid_module.UUID) + assert row[0] == pk_uuid + assert row[1] == "PK Test" + finally: + metadata.drop_all(d1_engine) + + def test_uuid_orm_session(self, d1_engine): + """Test UUID via ORM session.""" + import uuid as uuid_module + from sqlalchemy import Uuid + from sqlalchemy.orm import Mapped, Session, declarative_base, mapped_column + + Base = declarative_base() + + class Item(Base): + __tablename__ = f"items_{uuid_module.uuid4().hex[:8]}" + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + title: Mapped[str] = mapped_column(String(127)) + item_uuid: Mapped[uuid_module.UUID] = mapped_column(Uuid(as_uuid=True)) + + Base.metadata.create_all(d1_engine) + + try: + test_uuid = uuid_module.uuid4() + + with Session(d1_engine) as session: + item = Item(title="ORM Test", item_uuid=test_uuid) + session.add(item) + session.commit() + + retrieved = session.query(Item).first() + + assert retrieved is not None + assert retrieved.title == "ORM Test" + assert isinstance(retrieved.item_uuid, uuid_module.UUID) + assert retrieved.item_uuid == test_uuid + finally: + Base.metadata.drop_all(d1_engine) + + +# MARK: - Enum Column Tests + + +class TestEnumColumn: + """Test Enum column handling. + + D1 stores enum values as TEXT. The D1Enum type processor converts Python + enum.Enum objects to their string values on bind and reconstructs them on result. + """ + + def test_enum_string_values_insert_and_retrieve(self, d1_engine, test_table_name): + """Test Enum with string values (no Python enum class).""" + from sqlalchemy import Enum + + metadata = MetaData() + test_table = Table( + test_table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("title", String(127)), + Column("status", Enum("active", "inactive", "pending")), + ) + + metadata.create_all(d1_engine) + + try: + with d1_engine.connect() as conn: + conn.execute(test_table.insert().values(title="Test", status="active")) + conn.commit() + + result = conn.execute(select(test_table.c.title, test_table.c.status)) + row = result.fetchone() + + assert row is not None + assert row[0] == "Test" + assert row[1] == "active" + finally: + metadata.drop_all(d1_engine) + + def test_enum_python_enum_class(self, d1_engine, test_table_name): + """Test Enum with a Python enum.Enum class.""" + import enum as enum_module + from sqlalchemy import Enum + + class Status(enum_module.Enum): + active = "active" + inactive = "inactive" + pending = "pending" + + metadata = MetaData() + test_table = Table( + test_table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("title", String(127)), + Column("status", Enum(Status)), + ) + + metadata.create_all(d1_engine) + + try: + with d1_engine.connect() as conn: + conn.execute( + test_table.insert().values(title="Test", status=Status.active) + ) + conn.commit() + + result = conn.execute(select(test_table.c.title, test_table.c.status)) + row = result.fetchone() + + assert row is not None + assert row[0] == "Test" + assert row[1] == Status.active + finally: + metadata.drop_all(d1_engine) + + def test_enum_nullable(self, d1_engine, test_table_name): + """Test nullable Enum columns handle NULL correctly.""" + from sqlalchemy import Enum + + metadata = MetaData() + test_table = Table( + test_table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("title", String(127)), + Column("status", Enum("active", "inactive"), nullable=True), + ) + + metadata.create_all(d1_engine) + + try: + with d1_engine.connect() as conn: + conn.execute(test_table.insert().values(title="No Status", status=None)) + conn.execute( + test_table.insert().values(title="Active", status="active") + ) + conn.commit() + + result = conn.execute(select(test_table.c.title, test_table.c.status)) + rows = result.fetchall() + + assert len(rows) == 2 + assert rows[0][1] is None + assert rows[1][1] == "active" + finally: + metadata.drop_all(d1_engine) + + def test_enum_orm_session(self, d1_engine): + """Test Enum via ORM session with Python enum class.""" + import enum as enum_module + import uuid as uuid_module + from sqlalchemy import Enum + from sqlalchemy.orm import Mapped, Session, declarative_base, mapped_column + + class Priority(enum_module.Enum): + low = "low" + medium = "medium" + high = "high" + + Base = declarative_base() + + class Task(Base): + __tablename__ = f"tasks_{uuid_module.uuid4().hex[:8]}" + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + title: Mapped[str] = mapped_column(String(127)) + priority: Mapped[Priority] = mapped_column(Enum(Priority)) + + Base.metadata.create_all(d1_engine) + + try: + with Session(d1_engine) as session: + task = Task(title="ORM Task", priority=Priority.high) + session.add(task) + session.commit() + + retrieved = session.query(Task).first() + + assert retrieved is not None + assert retrieved.title == "ORM Task" + assert retrieved.priority == Priority.high + finally: + Base.metadata.drop_all(d1_engine) + + if __name__ == "__main__": pytest.main([__file__, "-v", "-s"]) diff --git a/tests/integration/test_worker_integration.py b/tests/integration/test_worker_integration.py index da66268..ebaa417 100644 --- a/tests/integration/test_worker_integration.py +++ b/tests/integration/test_worker_integration.py @@ -961,6 +961,133 @@ def test_time_orm_session(self, dev_server): assert data["entry_title"] == "Time Test Entry" +# MARK: - UUID Tests (GitHub issue #24) + + +class TestWorkerUUIDColumn: + """Test UUID column handling via Worker endpoints.""" + + def test_uuid_insert_and_retrieve(self, dev_server): + """Test UUID column insert and retrieve.""" + port = dev_server + response = requests.get(f"http://localhost:{port}/uuid-basic") + + assert response.status_code == 200, f"uuid_basic failed: {response.json()}" + data = response.json() + + assert data["test"] == "uuid_basic" + assert data["success"] is True, f"uuid_basic failed: error={data.get('error')}" + assert data["uuid_type"] == "UUID" + assert data["uuid_match"] is True + + def test_uuid_nullable(self, dev_server): + """Test nullable UUID columns handle NULL correctly.""" + port = dev_server + response = requests.get(f"http://localhost:{port}/uuid-nullable") + + assert response.status_code == 200, f"uuid_nullable failed: {response.json()}" + data = response.json() + + assert data["test"] == "uuid_nullable" + assert data["success"] is True, ( + f"uuid_nullable failed: error={data.get('error')}" + ) + assert data["with_uuid_is_uuid"] is True + assert data["no_uuid_is_none"] is True + + def test_uuid_as_primary_key(self, dev_server): + """Test UUID used as primary key.""" + port = dev_server + response = requests.get(f"http://localhost:{port}/uuid-pk") + + assert response.status_code == 200, f"uuid_pk failed: {response.json()}" + data = response.json() + + assert data["test"] == "uuid_pk" + assert data["success"] is True, f"uuid_pk failed: error={data.get('error')}" + assert data["pk_is_uuid"] is True + assert data["pk_match"] is True + + def test_uuid_orm_session(self, dev_server): + """Test UUID via ORM session.""" + port = dev_server + response = requests.get(f"http://localhost:{port}/uuid-orm") + + assert response.status_code == 200, f"uuid_orm failed: {response.json()}" + data = response.json() + + assert data["test"] == "uuid_orm" + assert data["success"] is True, f"uuid_orm failed: error={data.get('error')}" + assert data["entry_title"] == "UUID ORM Test" + assert data["uuid_is_uuid"] is True + assert data["uuid_match"] is True + + +# MARK: - Enum Tests (GitHub issue #24) + + +class TestWorkerEnumColumn: + """Test Enum column handling via Worker endpoints.""" + + def test_enum_string_values(self, dev_server): + """Test Enum with string values insert and retrieve.""" + port = dev_server + response = requests.get(f"http://localhost:{port}/enum-basic") + + assert response.status_code == 200, f"enum_basic failed: {response.json()}" + data = response.json() + + assert data["test"] == "enum_basic" + assert data["success"] is True, f"enum_basic failed: error={data.get('error')}" + assert data["status_value"] == "active" + + def test_enum_python_class(self, dev_server): + """Test Enum with Python enum.Enum class.""" + port = dev_server + response = requests.get(f"http://localhost:{port}/enum-python-class") + + assert response.status_code == 200, ( + f"enum_python_class failed: {response.json()}" + ) + data = response.json() + + assert data["test"] == "enum_python_class" + assert data["success"] is True, ( + f"enum_python_class failed: error={data.get('error')}" + ) + assert data["status_is_enum"] is True + assert data["status_value"] == "active" + + def test_enum_nullable(self, dev_server): + """Test nullable Enum columns handle NULL correctly.""" + port = dev_server + response = requests.get(f"http://localhost:{port}/enum-nullable") + + assert response.status_code == 200, f"enum_nullable failed: {response.json()}" + data = response.json() + + assert data["test"] == "enum_nullable" + assert data["success"] is True, ( + f"enum_nullable failed: error={data.get('error')}" + ) + assert data["with_status_value"] == "active" + assert data["no_status_is_none"] is True + + def test_enum_orm_session(self, dev_server): + """Test Enum via ORM session.""" + port = dev_server + response = requests.get(f"http://localhost:{port}/enum-orm") + + assert response.status_code == 200, f"enum_orm failed: {response.json()}" + data = response.json() + + assert data["test"] == "enum_orm" + assert data["success"] is True, f"enum_orm failed: error={data.get('error')}" + assert data["entry_title"] == "Enum ORM Task" + assert data["priority_is_enum"] is True + assert data["priority_match"] is True + + # MARK: - Parallel Query Tests (GitHub issue #20) diff --git a/uv.lock b/uv.lock index 929c813..2af36dc 100644 --- a/uv.lock +++ b/uv.lock @@ -922,7 +922,7 @@ wheels = [ [[package]] name = "sqlalchemy-cloudflare-d1" -version = "0.3.8" +version = "0.3.10" source = { editable = "." } dependencies = [ { name = "httpx" },