From 7821ed2c5bd7132881d13e519205a892389c8f92 Mon Sep 17 00:00:00 2001 From: Grant Ramsay Date: Mon, 23 Feb 2026 19:37:04 +0000 Subject: [PATCH 1/5] Implement runtime db_column mapping across SQL generation --- sqliter/model/foreign_key.py | 24 ++++++ sqliter/query/query.py | 148 +++++++++++++++++++++++++++-------- sqliter/sqliter.py | 59 +++++++++++--- 3 files changed, 189 insertions(+), 42 deletions(-) diff --git a/sqliter/model/foreign_key.py b/sqliter/model/foreign_key.py index 72fa8ae0..4b7d7114 100644 --- a/sqliter/model/foreign_key.py +++ b/sqliter/model/foreign_key.py @@ -151,3 +151,27 @@ def get_foreign_key_info(field_info: FieldInfo) -> Optional[ForeignKeyInfo]: if isinstance(fk_info, ForeignKeyInfo): return fk_info return None + + +def get_model_field_db_column( + model_class: type[BaseDBModel], field_name: str +) -> str: + """Resolve a model field name to its actual database column name. + + Args: + model_class: The model class owning the field. + field_name: The model field name (for example ``author_id``). + + Returns: + The database column name. For non-FK fields this is the same as + ``field_name``; for FK fields with ``db_column`` metadata this returns + the configured column name. + """ + field_info = model_class.model_fields.get(field_name) + if field_info is None: + return field_name + + fk_info = get_foreign_key_info(field_info) + if fk_info is None or fk_info.db_column is None: + return field_name + return fk_info.db_column diff --git a/sqliter/query/query.py b/sqliter/query/query.py index 04e073b6..85a719cd 100644 --- a/sqliter/query/query.py +++ b/sqliter/query/query.py @@ -44,6 +44,7 @@ RecordFetchError, RecordUpdateError, ) +from sqliter.model.foreign_key import get_model_field_db_column from sqliter.query.aggregates import AggregateSpec if TYPE_CHECKING: # pragma: no cover @@ -205,6 +206,27 @@ def _validate_fields(self) -> None: ) raise ValueError(err_message) + @staticmethod + def _model_field_to_db_column( + model_class: type[BaseDBModel], field_name: str + ) -> str: + """Resolve a model field name to its database column name.""" + return get_model_field_db_column(model_class, field_name) + + def _column_sql( + self, + field_name: str, + *, + model_class: Optional[type[BaseDBModel]] = None, + table_alias: Optional[str] = None, + ) -> str: + """Build a quoted SQL column reference for a model field.""" + resolved_model = model_class or self.model_class + db_column = self._model_field_to_db_column(resolved_model, field_name) + if table_alias is None: + return f'"{db_column}"' + return f'{table_alias}."{db_column}"' + def filter(self, **conditions: FilterValue) -> Self: """Apply filter conditions to the query. @@ -294,7 +316,11 @@ def _handle_relationship_filter( raise InvalidFilterError(error_msg) # Apply filter with table alias - qualified_field = f'{join_info.alias}."{target_field}"' + qualified_field = self._column_sql( + target_field, + model_class=join_info.model_class, + table_alias=join_info.alias, + ) # Use the appropriate handler # Note: __isnull/__notnull operators don't reach here due to @@ -474,7 +500,7 @@ def having(self, **conditions: FilterValue) -> Self: for field, value in conditions.items(): field_name, operator = self._parse_field_operator(field) if field_name in allowed_group_fields: - field_sql = f't0."{field_name}"' + field_sql = self._column_sql(field_name, table_alias="t0") elif field_name in allowed_aggregate_aliases: field_sql = f'"{field_name}"' else: @@ -1248,7 +1274,7 @@ def _validate_and_build_join_info(self, path: str) -> None: # ORM FK descriptor fk_descriptor = fk_descriptors[segment] to_model = fk_descriptor.to_model - fk_column = f"{segment}_id" + fk_column = fk_descriptor.fk_info.db_column or f"{segment}_id" is_nullable = fk_descriptor.fk_info.null # Create alias for this join using global counter @@ -1613,7 +1639,11 @@ def _build_aggregate_expression( elif aggregate_spec.field in {None, "*"}: field_sql = "*" else: - field_sql = f't0."{aggregate_spec.field}"' + field_name = cast("str", aggregate_spec.field) + field_sql = self._column_sql( + field_name, + table_alias="t0", + ) if aggregate_spec.distinct and field_sql == "*": msg = ( @@ -1645,7 +1675,8 @@ def _build_projection_select_parts(self) -> tuple[list[str], list[str]]: projection_columns: list[str] = [] for field_name in self._group_by: - select_parts.append(f't0."{field_name}" AS "{field_name}"') + column_sql = self._column_sql(field_name, table_alias="t0") + select_parts.append(f'{column_sql} AS "{field_name}"') projection_columns.append(field_name) for alias, aggregate_spec in self._aggregates.items(): @@ -1669,7 +1700,8 @@ def _build_projection_order_clause(self) -> str: field_name = match.group(1) direction = match.group(2) if field_name in self.model_class.model_fields: - return f' ORDER BY t0."{field_name}" {direction}' + field_sql = self._column_sql(field_name, table_alias="t0") + return f" ORDER BY {field_sql} {direction}" if field_name in self._aggregates: return f' ORDER BY "{field_name}" {direction}' @@ -1704,7 +1736,8 @@ def _build_projection_sql(self) -> tuple[str, list[Any], list[str]]: if self._group_by: grouped_columns = ", ".join( - f't0."{field}"' for field in self._group_by + self._column_sql(field, table_alias="t0") + for field in self._group_by ) sql += f" GROUP BY {grouped_columns}" @@ -1802,7 +1835,8 @@ def _build_join_sql( # Main table columns (t0) for field in self.model_class.model_fields: alias = f"t0__{field}" - select_parts.append(f't0."{field}" AS "{alias}"') + column_sql = self._column_sql(field, table_alias="t0") + select_parts.append(f'{column_sql} AS "{alias}"') column_names.append(("t0", field, self.model_class)) # Add JOINed table columns @@ -1818,7 +1852,12 @@ def _build_join_sql( # Add columns from joined table for field in join.model_class.model_fields: alias = f"{join.alias}__{field}" - select_parts.append(f'{join.alias}."{field}" AS "{alias}"') + column_sql = self._column_sql( + field, + model_class=join.model_class, + table_alias=join.alias, + ) + select_parts.append(f'{column_sql} AS "{alias}"') column_names.append((join.alias, field, join.model_class)) select_clause = ", ".join(select_parts) @@ -1977,7 +2016,10 @@ def _execute_query( # noqa: C901, PLR0912, PLR0915 ) elif self._fields: # Build custom field selection with JOINs - field_list = ", ".join(f't0."{f}"' for f in self._fields) + field_list = ", ".join( + self._column_sql(field, table_alias="t0") + for field in self._fields + ) # table_name and fields validated - safe from SQL injection sql = ( f"SELECT {field_list} FROM " # noqa: S608 @@ -2008,7 +2050,8 @@ def _execute_query( # noqa: C901, PLR0912, PLR0915 if match: field_name = match.group(1) direction = match.group(2) - sql += f' ORDER BY t0."{field_name}" {direction}' + field_sql = self._column_sql(field_name, table_alias="t0") + sql += f" ORDER BY {field_sql} {direction}" elif self._order_by.lower().startswith("rowid"): # Fallback for non-quoted patterns such as "rowid DESC" sql += f" ORDER BY t0.{self._order_by}" @@ -2039,10 +2082,13 @@ def _execute_query( # noqa: C901, PLR0912, PLR0915 elif self._fields: if "pk" not in self._fields: self._fields.append("pk") - fields = ", ".join(f'"{field}"' for field in self._fields) + fields = ", ".join( + self._column_sql(field) for field in self._fields + ) else: fields = ", ".join( - f'"{field}"' for field in self.model_class.model_fields + self._column_sql(field) + for field in self.model_class.model_fields ) sql = f'SELECT {fields} FROM "{self.table_name}"' # noqa: S608 @@ -2054,7 +2100,14 @@ def _execute_query( # noqa: C901, PLR0912, PLR0915 sql += f" WHERE {where_clause}" if self._order_by: - sql += f" ORDER BY {self._order_by}" + match = re.match(r'"([^"]+)"\s+(ASC|DESC)', self._order_by) + if match: + field_name = match.group(1) + direction = match.group(2) + field_sql = self._column_sql(field_name) + sql += f" ORDER BY {field_sql} {direction}" + else: + sql += f" ORDER BY {self._order_by}" if self._limit is not None: sql += " LIMIT ?" @@ -2074,31 +2127,51 @@ def _execute_query( # noqa: C901, PLR0912, PLR0915 else: return (results, []) # Empty column_names for backward compat - def _qualify_base_field_name(self, field_name: str) -> str: - """Qualify a base-model field name with the main JOIN alias. + def _render_base_field_name( + self, field_name: str, *, qualify_base_fields: bool + ) -> str: + """Render a base-model field name as SQL for WHERE/HAVING clauses. Args: field_name: Raw field expression used in filters. + qualify_base_fields: Whether to qualify base-model fields with + ``t0`` alias. Returns: - The qualified field expression when it targets the base table. + SQL field expression for the base table when ``field_name`` + targets a model field. """ if ( field_name in self.model_class.model_fields and "." not in field_name and '"' not in field_name ): - return f't0."{field_name}"' + return self._column_sql( + field_name, + table_alias="t0" if qualify_base_fields else None, + ) return field_name - def _qualify_base_filter_clause(self, clause: str) -> str: - """Qualify the leading base-model field in a SQL filter clause. + def _qualify_base_field_name(self, field_name: str) -> str: + """Backwards-compatible wrapper for qualified base field rendering.""" + return self._render_base_field_name( + field_name, + qualify_base_fields=True, + ) + + def _render_base_filter_clause( + self, clause: str, *, qualify_base_fields: bool + ) -> str: + """Render the leading base-model field in a SQL filter clause. Args: clause: A single SQL clause fragment from the filter stack. + qualify_base_fields: Whether to qualify base-model fields with + ``t0`` alias. Returns: - Clause with the base-model field qualified for JOIN queries. + Clause with the base-model field rendered using mapped DB column + names, optionally qualified for JOIN queries. """ if _RE_ALIAS_PREFIX.match(clause): return clause @@ -2111,8 +2184,18 @@ def _qualify_base_filter_clause(self, clause: str) -> str: if field_name not in self.model_class.model_fields: return clause - qualified = f't0."{field_name}"' - return f"{qualified}{clause[match.end() :]}" + rendered = self._column_sql( + field_name, + table_alias="t0" if qualify_base_fields else None, + ) + return f"{rendered}{clause[match.end() :]}" + + def _qualify_base_filter_clause(self, clause: str) -> str: + """Backwards-compatible wrapper for qualifying base filter clauses.""" + return self._render_base_filter_clause( + clause, + qualify_base_fields=True, + ) def _parse_filter( self, *, qualify_base_fields: bool = False @@ -2132,18 +2215,16 @@ def _parse_filter( values = [] for field, value, operator in self.filters: if operator == "__eq": - field_expr = ( - self._qualify_base_field_name(field) - if qualify_base_fields - else field + field_expr = self._render_base_field_name( + field, + qualify_base_fields=qualify_base_fields, ) where_clauses.append(f"{field_expr} = ?") values.append(value) else: - clause = ( - self._qualify_base_filter_clause(field) - if qualify_base_fields - else field + clause = self._render_base_filter_clause( + field, + qualify_base_fields=qualify_base_fields, ) where_clauses.append(clause) if operator not in ["__isnull", "__notnull"]: @@ -2695,7 +2776,10 @@ def update(self, values: dict[str, Any]) -> int: for field_name, value in values.items(): # Serialize the value if needed serialized = self.model_class.serialize_field(value) - set_clauses.append(f'"{field_name}" = ?') + db_column = self._model_field_to_db_column( + self.model_class, field_name + ) + set_clauses.append(f'"{db_column}" = ?') set_values.append(serialized) set_clause = ", ".join(set_clauses) diff --git a/sqliter/sqliter.py b/sqliter/sqliter.py index 0ab26b28..32b9d146 100644 --- a/sqliter/sqliter.py +++ b/sqliter/sqliter.py @@ -31,7 +31,11 @@ TableDeletionError, ) from sqliter.helpers import infer_sqlite_type -from sqliter.model.foreign_key import ForeignKeyInfo, get_foreign_key_info +from sqliter.model.foreign_key import ( + ForeignKeyInfo, + get_foreign_key_info, + get_model_field_db_column, +) from sqliter.model.model import BaseDBModel from sqliter.query.query import QueryBuilder @@ -896,6 +900,33 @@ def _create_instance_from_data( instance.db_context = self return instance + @staticmethod + def _model_field_to_db_column( + model_class: type[BaseDBModel], field_name: str + ) -> str: + """Resolve a model field name to its database column name.""" + return get_model_field_db_column(model_class, field_name) + + def _map_data_to_db_columns( + self, model_class: type[BaseDBModel], data: dict[str, Any] + ) -> dict[str, Any]: + """Map model-field keyed data to database-column keyed data.""" + return { + self._model_field_to_db_column(model_class, field_name): value + for field_name, value in data.items() + } + + def _build_model_select_list(self, model_class: type[BaseDBModel]) -> str: + """Build a SELECT column list that maps DB columns to model fields.""" + select_parts: list[str] = [] + for field_name in model_class.model_fields: + db_column = self._model_field_to_db_column(model_class, field_name) + if db_column == field_name: + select_parts.append(f'"{db_column}"') + else: + select_parts.append(f'"{db_column}" AS "{field_name}"') + return ", ".join(select_parts) + def insert( self, model_instance: T, *, timestamp_override: bool = False ) -> T: @@ -934,11 +965,17 @@ def insert( if data.get("pk", None) == 0: data.pop("pk") - fields = ", ".join(data.keys()) + sql_data = self._map_data_to_db_columns(model_class, data) + fields = ", ".join(f'"{field}"' for field in sql_data) placeholders = ", ".join( - ["?" if value is not None else "NULL" for value in data.values()] + [ + "?" if value is not None else "NULL" + for value in sql_data.values() + ] + ) + values = tuple( + value for value in sql_data.values() if value is not None ) - values = tuple(value for value in data.values() if value is not None) insert_sql = f""" INSERT INTO {table_name} ({fields}) @@ -1009,11 +1046,12 @@ def _insert_single_record( if data.get("pk") == 0: data.pop("pk") - fields = ", ".join(data.keys()) + sql_data = self._map_data_to_db_columns(model_class, data) + fields = ", ".join(f'"{field}"' for field in sql_data) placeholders = ", ".join( - "?" if v is not None else "NULL" for v in data.values() + "?" if v is not None else "NULL" for v in sql_data.values() ) - values = tuple(v for v in data.values() if v is not None) + values = tuple(v for v in sql_data.values() if v is not None) insert_sql = ( f"INSERT INTO {table_name} ({fields}) " # noqa: S608 @@ -1134,7 +1172,7 @@ def get( if hit: return cast("Optional[T]", cached) - fields = ", ".join(model_class.model_fields) + fields = self._build_model_select_list(model_class) select_sql = f""" SELECT {fields} FROM {table_name} WHERE {primary_key} = ? @@ -1194,8 +1232,9 @@ def update(self, model_instance: BaseDBModel) -> None: primary_key_value = data.pop(primary_key) # Create the SQL using the processed data - fields = ", ".join(f"{field} = ?" for field in data) - values = tuple(data.values()) + sql_data = self._map_data_to_db_columns(model_class, data) + fields = ", ".join(f'"{field}" = ?' for field in sql_data) + values = tuple(sql_data.values()) update_sql = f""" UPDATE {table_name} From f93291f1b8f68f91663f633806f64901b719f8d0 Mon Sep 17 00:00:00 2001 From: Grant Ramsay Date: Mon, 23 Feb 2026 19:37:26 +0000 Subject: [PATCH 2/5] Add db_column runtime regression coverage and align debug SQL tests --- tests/test_debug_logging.py | 14 ++-- tests/test_foreign_keys.py | 83 +++++++++++++++++++ tests/test_foreign_keys_orm.py | 146 +++++++++++++++++++++++++++++++++ tests/test_query.py | 6 ++ 4 files changed, 242 insertions(+), 7 deletions(-) diff --git a/tests/test_debug_logging.py b/tests/test_debug_logging.py index 41cf0011..abb30429 100644 --- a/tests/test_debug_logging.py +++ b/tests/test_debug_logging.py @@ -41,7 +41,7 @@ def test_debug_sql_output_basic_query( assert ( 'Executing SQL: SELECT "pk", "created_at", "updated_at", "name", ' '"age", "is_active", "score", ' - '"nullable_field" FROM "complex_model" WHERE age = 30.5' + '"nullable_field" FROM "complex_model" WHERE "age" = 30.5' in caplog.text ) @@ -58,7 +58,7 @@ def test_debug_sql_output_string_values( assert ( 'Executing SQL: SELECT "pk", "created_at", "updated_at", "name", ' '"age", "is_active", "score", ' - '"nullable_field" FROM "complex_model" WHERE name = \'Alice\'' + '"nullable_field" FROM "complex_model" WHERE "name" = \'Alice\'' in caplog.text ) @@ -75,8 +75,8 @@ def test_debug_sql_output_multiple_conditions( assert ( 'Executing SQL: SELECT "pk", "created_at", "updated_at", "name", ' '"age", "is_active", "score", ' - '"nullable_field" FROM "complex_model" WHERE name = \'Alice\' AND ' - "age = 30.5" in caplog.text + '"nullable_field" FROM "complex_model" WHERE "name" = \'Alice\' ' + 'AND "age" = 30.5' in caplog.text ) def test_debug_sql_output_order_and_limit( @@ -120,7 +120,7 @@ def test_debug_sql_output_with_null_value( assert ( 'Executing SQL: SELECT "pk", "created_at", "updated_at", "name", ' '"age", "is_active", "score", ' - '"nullable_field" FROM "complex_model" WHERE age IS NULL' + '"nullable_field" FROM "complex_model" WHERE "age" IS NULL' in caplog.text ) @@ -166,7 +166,7 @@ def test_debug_sql_output_with_fields_and_filter( # Assert the SQL query selects 'name' and 'score' and applies the filter assert ( 'Executing SQL: SELECT "name", "score", "pk" FROM "complex_model" ' - "WHERE score > 85" in caplog.text + 'WHERE "score" > 85' in caplog.text ) def test_no_log_output_when_debug_false(self, caplog) -> None: @@ -235,7 +235,7 @@ def test_debug_sql_output_no_matching_records( assert ( 'Executing SQL: SELECT "pk", "created_at", "updated_at", "name", ' '"age", "is_active", "score", ' - '"nullable_field" FROM "complex_model" WHERE age = 100' + '"nullable_field" FROM "complex_model" WHERE "age" = 100' in caplog.text ) diff --git a/tests/test_foreign_keys.py b/tests/test_foreign_keys.py index 671971fb..b45a2bdb 100644 --- a/tests/test_foreign_keys.py +++ b/tests/test_foreign_keys.py @@ -4,6 +4,7 @@ from typing import Optional import pytest +from pydantic import Field from sqliter import SqliterDB from sqliter.exceptions import ( @@ -14,6 +15,7 @@ RecordInsertionError, ) from sqliter.model import BaseDBModel, ForeignKey, get_foreign_key_info +from sqliter.model.foreign_key import get_model_field_db_column class Author(BaseDBModel): @@ -627,6 +629,21 @@ def test_field_without_json_schema_extra_attribute(self, mocker) -> None: assert fk_info is None + def test_get_model_field_db_column_unknown_field(self) -> None: + """Unknown model fields should map to themselves.""" + assert get_model_field_db_column(Author, "unknown_field") == ( + "unknown_field" + ) + + def test_get_foreign_key_info_dict_without_fk_entry(self) -> None: + """Dict json_schema_extra without FK metadata should return None.""" + + class TestBook(BaseDBModel): + title: str = Field(..., json_schema_extra={"note": "plain"}) + + field_info = TestBook.model_fields["title"] + assert get_foreign_key_info(field_info) is None + class TestForeignKeyDatabaseErrors: """Test database error handling for FK operations.""" @@ -714,3 +731,69 @@ class TestBook(BaseDBModel): assert fk_info is not None assert fk_info.null is False + + +class TestForeignKeyDbColumnRuntime: + """Runtime behavior tests for explicit FK fields with db_column.""" + + def test_insert_get_filter_order_with_custom_db_column(self) -> None: + """CRUD/query paths should accept model field names with db_column.""" + + class CustomBook(BaseDBModel): + title: str + writer_id: int = ForeignKey( + Author, on_delete="CASCADE", db_column="auth_id" + ) + + db = SqliterDB(":memory:") + db.create_table(Author) + db.create_table(CustomBook) + + author = db.insert(Author(name="Alice", email="alice@example.com")) + book = db.insert(CustomBook(title="A1", writer_id=author.pk)) + + fetched = db.get(CustomBook, book.pk) + assert fetched is not None + assert fetched.writer_id == author.pk + + rows = ( + db.select(CustomBook) + .filter(writer_id=author.pk) + .order("writer_id") + .fetch_all() + ) + assert len(rows) == 1 + assert rows[0].title == "A1" + + def test_update_and_bulk_update_with_custom_db_column(self) -> None: + """db.update and QueryBuilder.update should honor db_column mapping.""" + + class CustomBook(BaseDBModel): + title: str + writer_id: int = ForeignKey( + Author, on_delete="CASCADE", db_column="auth_id" + ) + + db = SqliterDB(":memory:") + db.create_table(Author) + db.create_table(CustomBook) + + alice = db.insert(Author(name="Alice", email="alice@example.com")) + bob = db.insert(Author(name="Bob", email="bob@example.com")) + book = db.insert(CustomBook(title="A1", writer_id=alice.pk)) + + book.writer_id = bob.pk + db.update(book) + refreshed = db.get(CustomBook, book.pk) + assert refreshed is not None + assert refreshed.writer_id == bob.pk + + updated = ( + db.select(CustomBook) + .filter(title="A1") + .update({"writer_id": alice.pk}) + ) + assert updated == 1 + final = db.get(CustomBook, book.pk) + assert final is not None + assert final.writer_id == alice.pk diff --git a/tests/test_foreign_keys_orm.py b/tests/test_foreign_keys_orm.py index 4c40e722..b9430721 100644 --- a/tests/test_foreign_keys_orm.py +++ b/tests/test_foreign_keys_orm.py @@ -40,6 +40,24 @@ class Book(BaseDBModel): author: ForeignKey[Author] = ForeignKey(Author, on_delete="CASCADE") +class CustomColAuthor(BaseDBModel): + """Author model for custom db_column FK runtime tests.""" + + name: str + + +class CustomColBook(BaseDBModel): + """Book model whose FK uses a custom database column name.""" + + title: str + author: ForeignKey[CustomColAuthor] = ForeignKey( + CustomColAuthor, + on_delete="CASCADE", + related_name="custom_books", + db_column="author_ref", + ) + + class Publisher(BaseDBModel): """Test model for a publisher.""" @@ -254,6 +272,134 @@ class CustomBook(BaseDBModel): assert len(books) == 2 +class TestCustomDbColumnRuntime: + """End-to-end runtime tests for ORM FK db_column mappings.""" + + def test_insert_and_get_use_custom_fk_db_column( + self, db: SqliterDB + ) -> None: + """insert()/get() should persist and hydrate custom FK db columns.""" + db.create_table(CustomColAuthor) + db.create_table(CustomColBook) + + alice = db.insert(CustomColAuthor(name="Alice")) + book = db.insert(CustomColBook(title="A1", author=alice)) + + fetched = db.get(CustomColBook, book.pk) + assert fetched is not None + assert fetched.author_id == alice.pk + assert fetched.author.name == "Alice" + + def test_select_filter_and_order_use_model_fk_field_names( + self, db: SqliterDB + ) -> None: + """Query API should accept model FK field names with custom columns.""" + db.create_table(CustomColAuthor) + db.create_table(CustomColBook) + + alice = db.insert(CustomColAuthor(name="Alice")) + bob = db.insert(CustomColAuthor(name="Bob")) + db.insert(CustomColBook(title="A1", author=alice)) + db.insert(CustomColBook(title="A2", author=alice)) + db.insert(CustomColBook(title="B1", author=bob)) + + rows = ( + db.select(CustomColBook) + .filter(author_id=alice.pk) + .order("author_id") + .fetch_all() + ) + assert [row.title for row in rows] == ["A1", "A2"] + + def test_select_related_and_relationship_filter_with_custom_fk_column( + self, db: SqliterDB + ) -> None: + """Relationship traversal should join on custom FK db columns.""" + db.create_table(CustomColAuthor) + db.create_table(CustomColBook) + + alice = db.insert(CustomColAuthor(name="Alice")) + bob = db.insert(CustomColAuthor(name="Bob")) + db.insert(CustomColBook(title="A1", author=alice)) + db.insert(CustomColBook(title="B1", author=bob)) + + rows = ( + db.select(CustomColBook) + .select_related("author") + .filter(author__name="Alice") + .fetch_all() + ) + + assert len(rows) == 1 + assert rows[0].title == "A1" + assert rows[0].author.name == "Alice" + + def test_prefetch_reverse_fk_with_custom_fk_column( + self, db: SqliterDB + ) -> None: + """prefetch_related should resolve reverse FK via custom db columns.""" + db.create_table(CustomColAuthor) + db.create_table(CustomColBook) + + alice = db.insert(CustomColAuthor(name="Alice")) + bob = db.insert(CustomColAuthor(name="Bob")) + db.insert(CustomColBook(title="A1", author=alice)) + db.insert(CustomColBook(title="A2", author=alice)) + db.insert(CustomColBook(title="B1", author=bob)) + + authors = ( + db.select(CustomColAuthor) + .prefetch_related("custom_books") + .order("name") + .fetch_all() + ) + counts = { + author.name: len(author.custom_books.fetch_all()) + for author in authors + } + assert counts == {"Alice": 2, "Bob": 1} + + def test_model_update_uses_custom_fk_db_column(self, db: SqliterDB) -> None: + """db.update(model_instance) should update the mapped FK db column.""" + db.create_table(CustomColAuthor) + db.create_table(CustomColBook) + + alice = db.insert(CustomColAuthor(name="Alice")) + bob = db.insert(CustomColAuthor(name="Bob")) + book = db.insert(CustomColBook(title="A1", author=alice)) + + book.author = bob + db.update(book) + + refreshed = db.get(CustomColBook, book.pk) + assert refreshed is not None + assert refreshed.author_id == bob.pk + assert refreshed.author.name == "Bob" + + def test_querybuilder_update_uses_custom_fk_db_column( + self, db: SqliterDB + ) -> None: + """QueryBuilder.update() should map FK fields to custom db columns.""" + db.create_table(CustomColAuthor) + db.create_table(CustomColBook) + + alice = db.insert(CustomColAuthor(name="Alice")) + bob = db.insert(CustomColAuthor(name="Bob")) + db.insert(CustomColBook(title="A1", author=alice)) + + updated = ( + db.select(CustomColBook) + .filter(title="A1") + .update({"author_id": bob.pk}) + ) + + assert updated == 1 + book = db.select(CustomColBook).filter(title="A1").fetch_one() + assert book is not None + assert book.author_id == bob.pk + assert book.author.name == "Bob" + + class TestModelRegistry: """Test suite for ModelRegistry.""" diff --git a/tests/test_query.py b/tests/test_query.py index ff1f9aed..d4002efe 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -680,6 +680,12 @@ def test_qualify_base_filter_clause_already_aliased(self, db_mock) -> None: assert qualified == clause + def test_qualify_base_field_name_wrapper(self, db_mock) -> None: + """Test wrapper method forwards to qualified base-field rendering.""" + query = db_mock.select(ExampleModel) + assert query._qualify_base_field_name("pk") == 't0."pk"' + assert query._qualify_base_field_name("unknown") == "unknown" + def test_fetch_result_with_list_of_tuples(self, mocker) -> None: """Test _fetch_result when _execute_query returns list of tuples.""" # ensure we get a dependable timestamp From f282f08dd496400598e9ce6c98f4b6a14fb1e54a Mon Sep 17 00:00:00 2001 From: Grant Ramsay Date: Mon, 23 Feb 2026 19:37:43 +0000 Subject: [PATCH 3/5] Document and demo custom FK db_column runtime behavior --- docs/api-reference/orm.md | 17 +++++++++- docs/guide/foreign-keys/orm.md | 28 +++++++++++++-- docs/tui-demo/orm.md | 62 ++++++++++++++++++++++++++++++++++ sqliter/tui/demos/orm.py | 59 ++++++++++++++++++++++++++++++++ 4 files changed, 163 insertions(+), 3 deletions(-) diff --git a/docs/api-reference/orm.md b/docs/api-reference/orm.md index 0142e3d0..344924f4 100644 --- a/docs/api-reference/orm.md +++ b/docs/api-reference/orm.md @@ -120,7 +120,7 @@ class ForeignKey(Generic[T]): | `null` | `bool` | `False` | Whether FK can be null | | `unique` | `bool` | `False` | Whether FK must be unique (one-to-one) | | `related_name` | `str` | `None` | `None` | Name for reverse relationship (auto-generated if `None`) | -| `db_column` | `str` | `None` | `None` | Custom column name for `_id` field | +| `db_column` | `str` | `None` | `None` | Custom database column for the generated `_id` field | **Example:** @@ -139,6 +139,21 @@ class Book(BaseDBModel): ) ``` +When `db_column` is set, runtime operations still use model field names while +SQL is generated against the mapped database column: + +```python +class Book(BaseDBModel): + title: str + author: ForeignKey[Author] = ForeignKey( + Author, + db_column="author_ref", + ) + +# Still uses model field names: +db.select(Book).filter(author_id=1).order("author_id").fetch_all() +``` + ### `ForeignKeyDescriptor` > [!CAUTION] diff --git a/docs/guide/foreign-keys/orm.md b/docs/guide/foreign-keys/orm.md index 9da69a8b..c928e51a 100644 --- a/docs/guide/foreign-keys/orm.md +++ b/docs/guide/foreign-keys/orm.md @@ -31,8 +31,8 @@ db.create_table(Book) > [!NOTE] > > When using ORM foreign keys, SQLiter automatically creates an `author_id` -> field in the database. You define `author` (without `_id`) in your model and -> access it for lazy loading. +> model field. By default the database column is also `author_id`, but you can +> override the physical column name with `db_column=...`. ## Database Context @@ -51,6 +51,30 @@ book.db_context = db # Set manually for lazy loading to work print(book.author.name) ``` +## Custom DB Column Names + +You can keep ORM-style access (`book.author` / `author_id`) while storing the +FK value in a different database column: + +```python +class Author(BaseDBModel): + name: str + +class Book(BaseDBModel): + title: str + author: ForeignKey[Author] = ForeignKey( + Author, + db_column="author_ref", + ) +``` + +With this configuration: + +- Model-level access still uses `author_id` and `author` +- SQLiter maps runtime CRUD/query operations to `author_ref` in SQL +- Filtering and ordering keep using model field names (for example, + `.filter(author_id=1).order("author_id")`) + ## Lazy Loading When you access a foreign key field, SQLiter automatically loads the related diff --git a/docs/tui-demo/orm.md b/docs/tui-demo/orm.md index f9e1bb19..7aa575a8 100644 --- a/docs/tui-demo/orm.md +++ b/docs/tui-demo/orm.md @@ -127,6 +127,68 @@ db.close() # --8<-- [end:insert-foreign-key] ``` +## Custom FK `db_column` + +Use a custom DB column name for an ORM FK while keeping the model API on +`author_id`. + +```python +# --8<-- [start:custom-fk-db-column] +from sqliter import SqliterDB +from sqliter.orm import BaseDBModel, ForeignKey + +class Author(BaseDBModel): + name: str + +class Book(BaseDBModel): + title: str + author: ForeignKey[Author] = ForeignKey( + Author, + db_column="author_ref", + related_name="books", + ) + +db = SqliterDB(memory=True) +db.create_table(Author) +db.create_table(Book) + +alice = db.insert(Author(name="Alice")) +bob = db.insert(Author(name="Bob")) +db.insert(Book(title="A1", author=alice)) +db.insert(Book(title="A2", author=alice)) + +rows = ( + db.select(Book) + .filter(author_id=alice.pk) + .order("author_id") + .select_related("author") + .fetch_all() +) + +print("Books filtered by model field author_id:") +for row in rows: + print(f" {row.title} -> {row.author.name}") + +db.select(Book).filter(title="A1").update({"author_id": bob.pk}) +updated = db.select(Book).filter(title="A1").fetch_one() + +if updated is not None: + print("\nAfter update(author_id=...):") + print(f" {updated.title} -> {updated.author.name}") + +print("\nModel API uses author_id while SQL stores it in author_ref.") + +db.close() +# --8<-- [end:custom-fk-db-column] +``` + +### What Happens + +- `db_column="author_ref"` customizes only the physical column name +- ORM and QueryBuilder calls still use model field names like `author_id` +- `insert`, `get`, `filter`, `order`, `select_related`, and `update` all map + through to the custom DB column + ### Storage vs Access - **Storage**: The `author` field stores only the primary key (integer) diff --git a/sqliter/tui/demos/orm.py b/sqliter/tui/demos/orm.py index 4111767d..caa1295c 100644 --- a/sqliter/tui/demos/orm.py +++ b/sqliter/tui/demos/orm.py @@ -82,6 +82,57 @@ class Book(BaseDBModel): return output.getvalue() +def _run_custom_fk_db_column() -> str: + """Use a custom FK db_column while keeping model-level _id access.""" + output = io.StringIO() + + class Author(BaseDBModel): + name: str + + class Book(BaseDBModel): + title: str + author: ForeignKey[Author] = ForeignKey( + Author, + db_column="author_ref", + related_name="books", + ) + + db = SqliterDB(memory=True) + db.create_table(Author) + db.create_table(Book) + + alice = db.insert(Author(name="Alice")) + bob = db.insert(Author(name="Bob")) + db.insert(Book(title="A1", author=alice)) + db.insert(Book(title="A2", author=alice)) + + rows = ( + db.select(Book) + .filter(author_id=alice.pk) + .order("author_id") + .select_related("author") + .fetch_all() + ) + + output.write("Books filtered by model field author_id:\n") + for row in rows: + output.write(f" {row.title} -> {row.author.name}\n") + + db.select(Book).filter(title="A1").update({"author_id": bob.pk}) + updated = db.select(Book).filter(title="A1").fetch_one() + + if updated is not None: + output.write("\nAfter update(author_id=...):\n") + output.write(f" {updated.title} -> {updated.author.name}\n") + + output.write( + "\nModel API uses author_id while SQL stores it in author_ref.\n" + ) + + db.close() + return output.getvalue() + + def _run_nullable_foreign_key() -> str: """Declare nullable FKs using Optional[T] in the type annotation. @@ -775,6 +826,14 @@ def get_category() -> DemoCategory: code=extract_demo_code(_run_orm_style_access), execute=_run_orm_style_access, ), + Demo( + id="orm_custom_fk_db_column", + title="Custom FK db_column", + description="Use author_id API with a custom DB column", + category="orm", + code=extract_demo_code(_run_custom_fk_db_column), + execute=_run_custom_fk_db_column, + ), Demo( id="orm_nullable_fk", title="Nullable Foreign Keys", From 817f578d48b570443064ed75c5c39c3a009a2651 Mon Sep 17 00:00:00 2001 From: Grant Ramsay Date: Mon, 23 Feb 2026 19:48:17 +0000 Subject: [PATCH 4/5] Polish db_column docs and add explicit FK runtime edge-case tests --- docs/guide/foreign-keys/explicit.md | 13 +++++ tests/test_foreign_keys.py | 82 +++++++++++++++++++++++++++++ 2 files changed, 95 insertions(+) diff --git a/docs/guide/foreign-keys/explicit.md b/docs/guide/foreign-keys/explicit.md index f751e359..17496c37 100644 --- a/docs/guide/foreign-keys/explicit.md +++ b/docs/guide/foreign-keys/explicit.md @@ -67,6 +67,19 @@ class Book(BaseDBModel): ) ``` +When `db_column` is set, your Python API still uses the model field name +(`author_id`). SQLiter maps runtime SQL generation to the configured database +column (`writer_id`) for insert/get/filter/order/update operations. + +```python +author = db.insert(Author(name="Jane Austen", email="jane@example.com")) +book = db.insert(Book(title="Emma", author_id=author.pk)) + +# Uses model field names in Python: +rows = db.select(Book).filter(author_id=author.pk).order("author_id").fetch_all() +db.select(Book).filter(pk=book.pk).update({"author_id": author.pk}) +``` + ## Type Checking The examples in this documentation show the simplest syntax that works at diff --git a/tests/test_foreign_keys.py b/tests/test_foreign_keys.py index b45a2bdb..6ad445a5 100644 --- a/tests/test_foreign_keys.py +++ b/tests/test_foreign_keys.py @@ -797,3 +797,85 @@ class CustomBook(BaseDBModel): final = db.get(CustomBook, book.pk) assert final is not None assert final.writer_id == alice.pk + + def test_field_selection_and_boundary_fetch_with_custom_db_column( + self, + ) -> None: + """Field selection APIs should map custom FK db columns correctly.""" + + class CustomBook(BaseDBModel): + title: str + writer_id: int = ForeignKey( + Author, on_delete="CASCADE", db_column="auth_id" + ) + + db = SqliterDB(":memory:") + db.create_table(Author) + db.create_table(CustomBook) + + alice = db.insert(Author(name="Alice", email="alice@example.com")) + bob = db.insert(Author(name="Bob", email="bob@example.com")) + db.insert(CustomBook(title="A1", writer_id=alice.pk)) + db.insert(CustomBook(title="A2", writer_id=alice.pk)) + db.insert(CustomBook(title="B1", writer_id=bob.pk)) + + selected = ( + db.select(CustomBook).fields(["title", "writer_id"]).fetch_all() + ) + assert len(selected) == 3 + assert {row.writer_id for row in selected} == {alice.pk, bob.pk} + + only_rows = ( + db.select(CustomBook) + .only("writer_id") + .filter(writer_id=alice.pk) + .fetch_all() + ) + assert len(only_rows) == 2 + assert all(row.writer_id == alice.pk for row in only_rows) + + excluded = db.select(CustomBook).exclude(["title"]).fetch_all() + assert len(excluded) == 3 + assert {row.writer_id for row in excluded} == {alice.pk, bob.pk} + + first = db.select(CustomBook).order("pk").fetch_first() + last = db.select(CustomBook).fetch_last() + assert first is not None + assert last is not None + assert first.title == "A1" + assert last.title == "B1" + + def test_in_and_not_in_filters_with_custom_db_column(self) -> None: + """__in and __not_in should use custom FK db columns correctly.""" + + class CustomBook(BaseDBModel): + title: str + writer_id: int = ForeignKey( + Author, on_delete="CASCADE", db_column="auth_id" + ) + + db = SqliterDB(":memory:") + db.create_table(Author) + db.create_table(CustomBook) + + alice = db.insert(Author(name="Alice", email="alice@example.com")) + bob = db.insert(Author(name="Bob", email="bob@example.com")) + db.insert(CustomBook(title="A1", writer_id=alice.pk)) + db.insert(CustomBook(title="A2", writer_id=alice.pk)) + db.insert(CustomBook(title="B1", writer_id=bob.pk)) + + in_rows = ( + db.select(CustomBook) + .filter(writer_id__in=[alice.pk]) + .order("title") + .fetch_all() + ) + assert [row.title for row in in_rows] == ["A1", "A2"] + + not_in_rows = ( + db.select(CustomBook) + .filter(writer_id__not_in=[alice.pk]) + .fetch_all() + ) + assert len(not_in_rows) == 1 + assert not_in_rows[0].title == "B1" From 58a1e974951cb13831f7fee9619913e294c6cceb Mon Sep 17 00:00:00 2001 From: Grant Ramsay Date: Tue, 24 Feb 2026 20:20:23 +0000 Subject: [PATCH 5/5] Fix punctuation in ORM foreign key docs --- docs/guide/foreign-keys/orm.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/guide/foreign-keys/orm.md b/docs/guide/foreign-keys/orm.md index c928e51a..c2266c59 100644 --- a/docs/guide/foreign-keys/orm.md +++ b/docs/guide/foreign-keys/orm.md @@ -31,7 +31,7 @@ db.create_table(Book) > [!NOTE] > > When using ORM foreign keys, SQLiter automatically creates an `author_id` -> model field. By default the database column is also `author_id`, but you can +> model field. By default, the database column is also `author_id`, but you can > override the physical column name with `db_column=...`. ## Database Context