diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index c052d04..38ea4e8 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -10,6 +10,21 @@ jobs:
     name: pytest (Python ${{ matrix.python-version }})
     runs-on: ubuntu-latest
 
+    services:
+      postgres:
+        image: postgres:16
+        env:
+          POSTGRES_USER: test
+          POSTGRES_PASSWORD: test
+          POSTGRES_DB: orm_loader_test
+        ports:
+          - 5432:5432
+        options: >-
+          --health-cmd pg_isready
+          --health-interval 5s
+          --health-timeout 5s
+          --health-retries 10
+
     strategy:
       fail-fast: false
       matrix:
@@ -30,9 +45,10 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
-          pip install -e ".[dev]"
+          pip install -e ".[dev,postgres]"
 
       - name: Run pytest
         env:
           PYTHONPATH: src
-        run: pytest -m "not postgres"
\ No newline at end of file
+          TEST_POSTGRES_URL: postgresql+psycopg://test:test@localhost:5432/orm_loader_test
+        run: pytest
diff --git a/.gitignore b/.gitignore
index b6dc481..8475748 100644
--- a/.gitignore
+++ b/.gitignore
@@ -211,3 +211,5 @@ OMOP_CDM*.csv
 *.db
 .vscode/
 .DS_Store
+_temp/
+notebooks/
\ No newline at end of file
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 9e912da..3580cb4 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -108,4 +108,10 @@
 - literally just removing stale sqlalchemy-utils dependency
 
 # 0.3.27
-- adding minimum versions for dependabot alerts (dev deps only)
\ No newline at end of file
+- adding minimum versions for dependabot alerts (dev deps only)
+
+# 0.4.0
+- update to handle psycopg (as opposed to psycopg2) cleanly
+- overall api cleanup with the goal of being more explicit about selection of specific db backends
+- general typing cleanup
+- removed example notebooks until they can be cleaned up with working use-cases according to updated api
\ No newline at end of file
diff --git a/README.md b/README.md
index 63c4644..c0b1c98 100644
--- a/README.md
+++ b/README.md
@@ -4,28 +4,27 @@
https://github.com/AustralianCancerDataNetwork/orm-loader/actions/workflows/tests.yml ) -A lightweight, reusable foundation for building and validating SQLAlchemy-based clinical (and non-clinical) data models. +A lightweight foundation for building and validating SQLAlchemy-based data models. -This library provides general-purpose ORM infrastructure that sits below any specific data model (OMOP, PCORnet, custom CDMs, etc.), focusing on: +`orm-loader` sits below any particular schema or CDM. It gives you a small set of reusable pieces for defining tables, loading files through staging tables, and checking models against external specifications. It stays out of domain logic on purpose. -* declarative base configuration -* bulk ingestion patterns -* file-based validation & loading -* table introspection -* model-agnostic validation scaffolding -* safe, database-portable operational helpers +The library focuses on: -It intentionally contains no domain logic and no assumptions about a specific schema. +* ORM table mixins and introspection +* staged file loading +* loader and validation infrastructure +* operational helpers that work across supported backends +At the moment, the built-in backends are SQLite and PostgreSQL. -### What this library provides: -This library provides a small set of composable building blocks for defining, loading, inspecting, and validating SQLAlchemy-based data models. -All components are model-agnostic and can be selectively combined in downstream libraries. +### What this library provides -1. A minimal, opinionated ORM table base +The package is deliberately small. Most downstream projects only need a couple of these pieces. -ORMTableBase provides structural introspection utilities for SQLAlchemy-mapped tables, without imposing any domain semantics. +1. A minimal ORM table base + +`ORMTableBase` provides structural utilities for mapped tables without pulling domain rules into the base layer. 
It supports: * mapper access and inspection @@ -41,17 +40,19 @@ class MyTable(ORMTableBase, Base): __tablename__ = "my_table" ``` -This base is intended to be inherited by all ORM tables, either directly or via higher-level mixins. +You can inherit from it directly or pick it up through one of the higher-level mixins. 2. CSV-based ingestion mixins -CSVLoadableTableInterface adds opt-in CSV loading support for ORM tables using pandas, with a focus on correctness and scalability. +`CSVLoadableTableInterface` adds staged file loading to ORM tables. It can use pandas or PyArrow loaders, and on PostgreSQL it can use a fast `COPY` path when the input is clean enough. Features include: +* staging table creation and cleanup * chunked loading for large files -* optional per-table normalisation logic -* optional deduplication against existing database rows -* safe bulk inserts using SQLAlchemy sessions +* optional casting and deduplication before insert +* backend-specific merge behaviour +* PostgreSQL fast-path loading with ORM fallback +* backend-aware index handling during merge ```python class MyTable(CSVLoadableTableInterface, ORMTableBase, Base): @@ -59,15 +60,11 @@ class MyTable(CSVLoadableTableInterface, ORMTableBase, Base): ``` -Downstream models may override: -* normalise_dataframe(...) -* dedupe_dataframe(...) -* csv_columns() -to implement table-specific ingestion policies. +The main extension points here are loader choice, column mapping, and the normal SQLAlchemy model definitions themselves. Most downstream projects do not need to override much beyond `csv_columns()` and the model schema. 3. Structured serialisation and hashing -SerialisableTableInterface adds lightweight, explicit serialisation helpers for ORM rows. +`SerialisableTableInterface` adds lightweight serialisation helpers for ORM rows. It supports: * conversion to dictionaries @@ -92,7 +89,7 @@ This is useful for: 4. 
Model registry and validation scaffolding -The library includes model-agnostic validation infrastructure, designed to compare ORM models against external specifications. +The library includes validation infrastructure for comparing ORM models against external specifications. This includes: * a model registry @@ -118,7 +115,8 @@ Validation output is available as: * exit codes suitable for pipelines 5. Database bootstrap helpers -The library provides lightweight helpers for schema creation and bootstrapping, without imposing a migration strategy. + +The library provides lightweight helpers for schema creation and bootstrapping. It does not try to replace migrations. ```python from orm_loader.metadata import Base @@ -127,24 +125,20 @@ from orm_loader.bootstrap import bootstrap bootstrap(engine, create=True) ``` -6. Safe bulk-loading utilities +6. Bulk-loading helpers -A reusable context manager simplifies trusted bulk ingestion workflows: -* temporarily disables foreign key checks where supported -* suppresses autoflush for performance -* ensures reliable rollback on failure +There are a few lower-level helpers for trusted bulk workflows, including backend-aware foreign key management and SQLite connection setup for heavy local loads. ## Summary -This library intentionally focuses on infrastructure, not semantics. +This library is meant to be the boring layer underneath downstream models: -It provides: * reusable ORM mixins -* safe ingestion patterns +* staged ingestion patterns * validation scaffolding -* database-portable utilities +* operational helpers -while leaving domain rules, business logic, and schema semantics to downstream libraries. +Domain rules, business logic, and schema semantics stay in the downstream project. 
This makes it suitable as a shared foundation for: * clinical data models diff --git a/TODO.txt b/TODO.txt new file mode 100644 index 0000000..fe7d4f0 --- /dev/null +++ b/TODO.txt @@ -0,0 +1,2 @@ +[] consider opt-in malformed text repair (as opposed to existing normalisation) - e.g. load_csv(..., text_repair: str | None = None) +- consider ftfy.fix_encoding() \ No newline at end of file diff --git a/docs/index.md b/docs/index.md index c8ff6bb..015be3c 100644 --- a/docs/index.md +++ b/docs/index.md @@ -3,7 +3,7 @@ A lightweight, reusable foundation for building and validating SQLAlchemy-based data models. -`orm-loader` provides **infrastructure, not semantics**. +`orm-loader` provides infrastructure for SQLAlchemy-based data models. It is the shared plumbing layer, not the place where model-specific rules live. It focuses on: @@ -11,17 +11,16 @@ It focuses on: - safe bulk ingestion patterns - file-based loading via staging tables - model-agnostic validation scaffolding -- database-portable operational helpers +- operational helpers for supported backends -No domain logic is included. -No schema assumptions are enforced. +It currently ships with backend implementations for SQLite and PostgreSQL. --- ## Core Concepts - **Tables are structural** — semantics live downstream -- **Mixins define capabilities**, not behaviour contracts +- **Mixins define capabilities** - **Protocols decouple infrastructure from implementations** - **Ingestion is explicit and staged** @@ -37,13 +36,7 @@ No schema assumptions are enforced. # Design Philosophy -`orm-loader` is intentionally conservative. - -It provides: - -- *mechanisms*, not policies -- *capabilities*, not workflows -- *structure*, not semantics +`orm-loader` is intentionally conservative. It gives downstream libraries the machinery to load, inspect, and validate data without deciding what the data means. 
The library is designed to sit **below**: @@ -65,6 +58,7 @@ and **above**: - No schema enforcement - No migrations - No concurrency guarantees +- No support yet for arbitrary database dialects --- @@ -81,4 +75,3 @@ This allows downstream libraries to: - replace base classes - mock implementations - incrementally adopt features - diff --git a/docs/loaders/context.md b/docs/loaders/context.md index 2e4528d..29418fd 100644 --- a/docs/loaders/context.md +++ b/docs/loaders/context.md @@ -25,6 +25,7 @@ on globals or implicit configuration. | `chunksize` | Optional chunk size | | `normalise` | Whether to cast values to ORM types | | `dedupe` | Whether to deduplicate incoming data | +| `quote_mode` | CSV quoting mode for PostgreSQL fast-path loading | ::: orm_loader.loaders.data_classes.LoaderContext diff --git a/docs/loaders/helpers.md b/docs/loaders/helpers.md index f299a95..8842dd3 100644 --- a/docs/loaders/helpers.md +++ b/docs/loaders/helpers.md @@ -1,8 +1,6 @@ # Loader Helper Utilities -This page documents low-level helper functions used by loaders. - -These utilities are stateless and intentionally conservative. +This page covers the low-level functions that support the loader implementations. --- @@ -37,17 +35,17 @@ Used by `ParquetLoader` for internal deduplication. --- -## Conservative CSV parsing +## Batch-oriented CSV parsing ### `conservative_load_parquet(...)` -Reads CSV files using PyArrow with: +Despite the name, this helper reads delimited text with PyArrow and yields batches: - strict column inclusion - malformed row skipping - chunked batch iteration -This is used when loading CSVs via the Parquet pipeline. +This is used by the PyArrow-based loader path. --- @@ -55,18 +53,18 @@ This is used when loading CSVs via the Parquet pipeline. ### `quick_load_pg(...)` -Loads CSV files into PostgreSQL staging tables using `COPY`. +Loads CSV files into a PostgreSQL staging table using `COPY`. 
### Characteristics -- Extremely fast -- Bypasses ORM -- Sensitive to data quality issues +- Fast +- Bypasses ORM row construction +- Works best on clean input ### Failure handling - Errors trigger rollback -- Loader falls back to ORM-based loading -- No partial silent loads +- `CSVLoadableTableInterface` falls back to ORM-based loading +- Failures are noisy on purpose This helper is only used when explicitly supported by the database. diff --git a/docs/loaders/index.md b/docs/loaders/index.md index b38f267..3d41b0f 100644 --- a/docs/loaders/index.md +++ b/docs/loaders/index.md @@ -1,14 +1,13 @@ # Loaders -The `orm_loader.loaders` module provides **conservative, schema-aware file -ingestion infrastructure** for loading external data into ORM-backed -staging tables. +The `orm_loader.loaders` module provides conservative, schema-aware file +loading into ORM-backed staging tables. This subsystem is designed to handle: - untrusted or messy source files - large datasets requiring chunked processing -- incremental and repeatable loads +- repeatable staged loads - dialect-specific optimisations (e.g. PostgreSQL COPY) - explicit, inspectable failure modes @@ -23,7 +22,7 @@ they do not embed domain rules or business semantics. [`LoaderContext`](context.md) -A `LoaderContext` object carries all state required to load a single file: +A `LoaderContext` object carries the state required to load one file: - target ORM table - database session @@ -44,8 +43,7 @@ All loaders implement a common interface: - `orm_file_load(ctx)` — orchestrates file ingestion - `dedupe(data, ctx)` — defines deduplication semantics -Concrete implementations differ only in **how data is read and processed**, -not in how it is staged. +Concrete implementations mainly differ in how they read and transform incoming data. --- @@ -54,11 +52,10 @@ not in how it is staged. Loaders always write to **staging tables**, never directly to production tables. 
-This allows: +This gives you: - safe rollback - repeatable merges -- database-level deduplication - bulk loading optimisations Final merge semantics are handled by the table mixins, not by loaders. @@ -69,8 +66,8 @@ Final merge semantics are handled by the table mixins, not by loaders. | Loader | Use case | |------|----------| -| `PandasLoader` | Flexible, debuggable CSV ingestion | -| `ParquetLoader` | High-volume, columnar ingestion | +| `PandasLoader` | Flexible CSV and TSV ingestion | +| `ParquetLoader` | Columnar or batch-oriented ingestion | Both loaders share the same lifecycle and guarantees. @@ -81,11 +78,11 @@ Both loaders share the same lifecycle and guarantees. 1. Detect file format and encoding 2. Read data in chunks or batches 3. Optionally normalise to ORM column types -4. Optionally deduplicate (internal and/or database-level) +4. Optionally deduplicate within the incoming data 5. Insert into staging table 6. Return row count -No implicit commits or merges occur at this layer. +Final merge behaviour belongs to the table mixins and backend layer, not to the loader itself. --- diff --git a/docs/loaders/loaders.md b/docs/loaders/loaders.md index 897bdb2..1fa8728 100644 --- a/docs/loaders/loaders.md +++ b/docs/loaders/loaders.md @@ -3,8 +3,7 @@ This page documents the concrete loader implementations provided by `orm_loader`. -All loaders implement the same interface and differ only in -how data is read and processed. +All loaders implement the same interface. The difference is in how they read data and how much work they do before rows reach the staging table. 
--- @@ -24,7 +23,7 @@ All loaders: - load into staging tables only - respect `LoaderContext` flags - return row counts -- avoid implicit commits +- leave final merge behaviour to the table layer --- @@ -34,7 +33,7 @@ All loaders: ### Characteristics -- Supports CSV and TSV inputs +- Works well with CSV and TSV inputs - Easy to debug and inspect - Supports chunked loading - Flexible transformation pipeline @@ -67,7 +66,6 @@ All loaders: - More complex pipeline - Less flexible row-wise transformations -- DB-level deduplication not yet implemented ### Best suited for @@ -79,16 +77,7 @@ All loaders: ## Deduplication behaviour -Deduplication occurs in two phases: - -1. **Internal deduplication** - Removes duplicate primary key rows within the incoming data. - -2. **Database-level deduplication (optional)** - Removes rows that already exist in the database. - -Database-level deduplication is currently implemented for pandas-based -loads. +Deduplication here means deduplicating within the incoming data before it is inserted into staging. The merge step is what decides what happens when incoming rows overlap with existing target rows. --- @@ -100,4 +89,4 @@ When enabled, loaders: - drop rows violating required constraints - log casting failures with examples -No schema changes are performed. +No schema changes are performed at the loader layer. diff --git a/docs/tables/loadable_table.md b/docs/tables/loadable_table.md index 90302e1..51ebfe4 100644 --- a/docs/tables/loadable_table.md +++ b/docs/tables/loadable_table.md @@ -3,10 +3,11 @@ Infrastructure for staged, file-based ingestion into ORM tables. 
Supports: -- CSV-based ingestion -- optional fast-path database COPY -- dialect-aware merge strategies -- Parquet loading hooks +- staged file loading into backend-specific staging tables +- PostgreSQL fast-path `COPY` with ORM fallback +- backend-aware merge strategies +- pandas and PyArrow-based loader paths +- index handling during merge --- diff --git a/docs/tables/mat_view.md b/docs/tables/mat_view.md index 5bbd429..2a1a807 100644 --- a/docs/tables/mat_view.md +++ b/docs/tables/mat_view.md @@ -1,6 +1,6 @@ # Materialised Views -This module provides a SQLAlchemy-native pattern for defining, creating, refreshing, and orchestrating materialized views using normal `Select` constructs, with explicit dependency management and deterministic refresh order. +This module provides a SQLAlchemy-native way to define, create, refresh, and order materialized views from ordinary `Select` constructs. It is designed for: @@ -9,7 +9,7 @@ It is designed for: * large fact tables with repeated joins or aggregates * schema-level orchestration (migrations, setup, Airflow, admin tasks) -The implementation is PostgreSQL-oriented (due to materialized view support), but remains cleanly isolated from ORM persistence logic. +The implementation is PostgreSQL-oriented. The mixin resolves a backend from the supplied bind, and the built-in PostgreSQL backend is currently the only one that supports materialized views. ## Overview @@ -21,7 +21,7 @@ The materialized view system consists of four main parts: * backing `Select` * optional dependencies 3. Dependency resolution: A topological sort over declared dependencies to determine refresh order. -4. Refresh orchestration: Helpers to refresh one or many materialized views safely and predictably. +4. Refresh orchestration: Helpers to refresh one or many materialized views in a predictable order. 
### Defining the Materialised View diff --git a/notebooks/01_setup_registry.ipynb b/notebooks/01_setup_registry.ipynb deleted file mode 100644 index 1e8a679..0000000 --- a/notebooks/01_setup_registry.ipynb +++ /dev/null @@ -1,207 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "d4a7dfa5", - "metadata": {}, - "outputs": [], - "source": [ - "from orm_loader.registry import (\n", - " ModelRegistry,\n", - " ModelDescriptor,\n", - " TableSpec,\n", - " FieldSpec,\n", - " Validator,\n", - " ValidationIssue,\n", - " SeverityLevel,\n", - " ValidationRunner,\n", - " always_on_validators,\n", - ")\n", - "from pathlib import Path" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "9fec9cb5", - "metadata": {}, - "outputs": [], - "source": [ - "m = ModelRegistry(model_version = '5.4', model_name = 'CDM')" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "5685c951", - "metadata": {}, - "outputs": [], - "source": [ - "field_spec = Path('OMOP_CDMv5.4_Field_Level.csv')\n", - "table_spec = Path('OMOP_CDMv5.4_Table_Level.csv')\n", - "\n", - "m.load_table_specs(table_csv=table_spec, field_csv=field_spec)" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "6efeec23", - "metadata": {}, - "outputs": [], - "source": [ - "m.discover_models('omop_alchemy.cdm.model')" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "9cb956ec", - "metadata": {}, - "outputs": [], - "source": [ - "runner = ValidationRunner(\n", - " validators=always_on_validators(),\n", - " fail_fast=False,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "ad537e7e", - "metadata": {}, - "outputs": [], - "source": [ - "report = runner.run(m)" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "e3d63142", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "MODEL v5.4: 0 error(s), 27 warning(s), 8 info\n" - ] - } - ], - "source": 
[ - "print(report.summary())\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "43589cc8", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "📦 cdm_source\n", - " ⚠️ PRIMARY_KEY_NOT_DECLARED_IN_SPEC (field: cdm_source_name) Hint: ORM primary key not marked as primary key in specification\n", - "\n", - "📦 cohort\n", - " ⚠️ PRIMARY_KEY_NOT_DECLARED_IN_SPEC (field: cohort_definition_id) Hint: ORM primary key not marked as primary key in specification\n", - " ⚠️ PRIMARY_KEY_NOT_DECLARED_IN_SPEC (field: subject_id) Hint: ORM primary key not marked as primary key in specification\n", - " ⚠️ COMPOSITE_PRIMARY_KEY Hint: Composite primary key detected\n", - "\n", - "📦 cohort_definition\n", - " ⚠️ PRIMARY_KEY_NOT_DECLARED_IN_SPEC (field: cohort_definition_id) Hint: ORM primary key not marked as primary key in specification\n", - "\n", - "📦 concept_ancestor\n", - " ⚠️ PRIMARY_KEY_NOT_DECLARED_IN_SPEC (field: ancestor_concept_id) Hint: ORM primary key not marked as primary key in specification\n", - " ⚠️ PRIMARY_KEY_NOT_DECLARED_IN_SPEC (field: descendant_concept_id) Hint: ORM primary key not marked as primary key in specification\n", - " ⚠️ COMPOSITE_PRIMARY_KEY Hint: Composite primary key detected\n", - "\n", - "📦 concept_relationship\n", - " ⚠️ PRIMARY_KEY_NOT_DECLARED_IN_SPEC (field: concept_id_1) Hint: ORM primary key not marked as primary key in specification\n", - " ⚠️ PRIMARY_KEY_NOT_DECLARED_IN_SPEC (field: concept_id_2) Hint: ORM primary key not marked as primary key in specification\n", - " ⚠️ PRIMARY_KEY_NOT_DECLARED_IN_SPEC (field: relationship_id) Hint: ORM primary key not marked as primary key in specification\n", - " ⚠️ COMPOSITE_PRIMARY_KEY Hint: Composite primary key detected\n", - "\n", - "📦 concept_synonym\n", - " ⚠️ PRIMARY_KEY_NOT_DECLARED_IN_SPEC (field: concept_id) Hint: ORM primary key not marked as primary key in specification\n", - " ⚠️ PRIMARY_KEY_NOT_DECLARED_IN_SPEC 
(field: concept_synonym_name) Hint: ORM primary key not marked as primary key in specification\n", - " ⚠️ COMPOSITE_PRIMARY_KEY Hint: Composite primary key detected\n", - "\n", - "📦 death\n", - " ⚠️ PRIMARY_KEY_NOT_DECLARED_IN_SPEC (field: person_id) Hint: ORM primary key not marked as primary key in specification\n", - "\n", - "📦 drug_strength\n", - " ⚠️ PRIMARY_KEY_NOT_DECLARED_IN_SPEC (field: drug_concept_id) Hint: ORM primary key not marked as primary key in specification\n", - " ⚠️ PRIMARY_KEY_NOT_DECLARED_IN_SPEC (field: ingredient_concept_id) Hint: ORM primary key not marked as primary key in specification\n", - " ⚠️ COMPOSITE_PRIMARY_KEY Hint: Composite primary key detected\n", - "\n", - "📦 episode\n", - " ⚠️ FOREIGN_KEY_NOT_IN_SPEC (field: episode_parent_id) Hint: ORM defines FK but specification does not\n", - "\n", - "📦 episode_event\n", - " ⚠️ PRIMARY_KEY_NOT_DECLARED_IN_SPEC (field: episode_event_field_concept_id) Hint: ORM primary key not marked as primary key in specification\n", - " ⚠️ PRIMARY_KEY_NOT_DECLARED_IN_SPEC (field: episode_id) Hint: ORM primary key not marked as primary key in specification\n", - " ⚠️ PRIMARY_KEY_NOT_DECLARED_IN_SPEC (field: event_id) Hint: ORM primary key not marked as primary key in specification\n", - " ⚠️ COMPOSITE_PRIMARY_KEY Hint: Composite primary key detected\n", - "\n", - "📦 fact_relationship\n", - " ⚠️ PRIMARY_KEY_NOT_DECLARED_IN_SPEC (field: domain_concept_id_1) Hint: ORM primary key not marked as primary key in specification\n", - " ⚠️ PRIMARY_KEY_NOT_DECLARED_IN_SPEC (field: domain_concept_id_2) Hint: ORM primary key not marked as primary key in specification\n", - " ⚠️ PRIMARY_KEY_NOT_DECLARED_IN_SPEC (field: fact_id_1) Hint: ORM primary key not marked as primary key in specification\n", - " ⚠️ PRIMARY_KEY_NOT_DECLARED_IN_SPEC (field: fact_id_2) Hint: ORM primary key not marked as primary key in specification\n", - " ⚠️ PRIMARY_KEY_NOT_DECLARED_IN_SPEC (field: relationship_concept_id) Hint: ORM primary key 
not marked as primary key in specification\n", - " ⚠️ COMPOSITE_PRIMARY_KEY Hint: Composite primary key detected\n", - "\n", - "📦 relationship\n", - " ⚠️ FOREIGN_KEY_NOT_IN_SPEC (field: reverse_relationship_id) Hint: ORM defines FK but specification does not\n", - "\n", - "📦 source_to_concept_map\n", - " ⚠️ PRIMARY_KEY_NOT_DECLARED_IN_SPEC (field: source_code) Hint: ORM primary key not marked as primary key in specification\n", - " ⚠️ PRIMARY_KEY_NOT_DECLARED_IN_SPEC (field: source_concept_id) Hint: ORM primary key not marked as primary key in specification\n", - " ⚠️ PRIMARY_KEY_NOT_DECLARED_IN_SPEC (field: source_vocabulary_id) Hint: ORM primary key not marked as primary key in specification\n", - " ⚠️ COMPOSITE_PRIMARY_KEY Hint: Composite primary key detected\n" - ] - } - ], - "source": [ - "if not report.is_valid():\n", - " print(report.render_text_report())" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1ef909ef", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "orm-loader (3.11.12)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.10" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/notebooks/02_test_file_load.ipynb b/notebooks/02_test_file_load.ipynb deleted file mode 100644 index eb2a64d..0000000 --- a/notebooks/02_test_file_load.ipynb +++ /dev/null @@ -1,220 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "c5d4e71b", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "897b6570", - "metadata": {}, - "outputs": [], - "source": [ - "import sqlalchemy as sa\n", - "import sqlalchemy.orm as so\n", - "from 
sqlalchemy.orm import DeclarativeBase, Session\n", - "from pathlib import Path\n", - "import pandas as pd\n", - "import tempfile\n", - "import logging\n", - "from orm_loader.tables.base import CSVLoadableTableInterface \n", - "\n", - "logging.basicConfig(level=logging.INFO)\n", - "\n", - "class Base(DeclarativeBase):\n", - " pass\n", - "\n", - "engine = sa.create_engine(\"sqlite:///test.db\", echo=False, future=True)\n", - "Base.metadata.bind = engine\n", - "\n", - "\n", - "class TestTable(Base, CSVLoadableTableInterface):\n", - " __tablename__ = \"test_table\"\n", - "\n", - " id: so.Mapped[int] = so.mapped_column(primary_key=True)\n", - " name: so.Mapped[str] = so.mapped_column(nullable=False)\n", - "\n", - "Base.metadata.create_all(engine)\n", - "\n", - "tmp = Path(tempfile.mkdtemp())\n", - "\n", - "csv_initial = tmp / \"test_table.csv\"\n", - "csv_replace = tmp / \"test_table_replace.csv\"\n", - "csv_empty = tmp / \"test_table_empty.csv\"\n", - "\n", - "pd.DataFrame(\n", - " [\n", - " {\"id\": 1, \"name\": \"alpha\"},\n", - " {\"id\": 2, \"name\": \"beta\"},\n", - " {\"id\": 3, \"name\": \"gamma\"},\n", - " ]\n", - ").to_csv(csv_initial, index=False, sep=\"\\t\")\n", - "\n", - "pd.DataFrame(\n", - " [\n", - " {\"id\": 2, \"name\": \"beta_updated\"},\n", - " {\"id\": 3, \"name\": \"gamma_updated\"},\n", - " ]\n", - ").to_csv(csv_replace, index=False, sep=\"\\t\")\n", - "\n", - "csv_empty.touch()\n" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "a62502c4", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:orm_loader.tables.base.loadable_table:Staging table _staging_test_table does not exist; recreating\n" - ] - }, - { - "data": { - "text/plain": [ - "[<__main__.TestTable at 0x120949d30>, <__main__.TestTable at 0x1166facf0>]" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "with Session(engine) as session:\n", - " inserted = 
TestTable.load_csv(\n", - " session,\n", - " csv_initial,\n", - " dedupe=False,\n", - " )\n", - " session.commit()\n", - "\n", - " rows = session.execute(\n", - " sa.select(TestTable).order_by(TestTable.id)\n", - " ).scalars().all()\n", - "\n", - "rows\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ba6337f0", - "metadata": {}, - "outputs": [], - "source": [ - "with Session(engine) as session:\n", - " rows = session.execute(\n", - " sa.select(TestTable).order_by(TestTable.id)\n", - " ).scalars().all()\n", - "rows" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a6956332", - "metadata": {}, - "outputs": [], - "source": [ - "with Session(engine) as session:\n", - " replaced = TestTable.replace_from_csv(\n", - " session,\n", - " csv_replace,\n", - " )\n", - " session.commit()\n", - "\n", - " rows = session.execute(\n", - " sa.select(TestTable).order_by(TestTable.id)\n", - " ).scalars().all()\n", - "\n", - "rows\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "29c775f5", - "metadata": {}, - "outputs": [], - "source": [ - "with engine.connect() as conn:\n", - " tables = conn.execute(\n", - " sa.text(\n", - " \"SELECT name FROM sqlite_master WHERE type='table'\"\n", - " )\n", - " ).fetchall()\n", - "\n", - "tables" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a6e8fc89", - "metadata": {}, - "outputs": [], - "source": [ - "with Session(engine) as session:\n", - " loaded = TestTable.replace_from_csv(\n", - " session,\n", - " csv_empty,\n", - " )\n", - " session.commit()\n", - "\n", - " rows = session.execute(\n", - " sa.select(TestTable).order_by(TestTable.id)\n", - " ).scalars().all()\n", - "\n", - " print(\"After empty file replace:\", [(r.id, r.name) for r in rows])\n", - " print(\"Rows loaded from empty file:\", loaded)\n", - "\n", - " # hard assertions (will raise if broken)\n", - " assert loaded == 0, \"Empty CSV should load 0 rows\"\n", - " assert [(r.id, r.name) 
for r in rows] == [\n", - " (1, \"alpha\"),\n", - " (2, \"beta_updated\"),\n", - " (3, \"gamma_updated\"),\n", - " ], \"Empty CSV must not modify existing rows\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "30eea280", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "orm-loader (3.12.10)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.10" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/notebooks/03_improve_load_perf.ipynb b/notebooks/03_improve_load_perf.ipynb deleted file mode 100644 index 079f1c8..0000000 --- a/notebooks/03_improve_load_perf.ipynb +++ /dev/null @@ -1,318 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "a251fa62", - "metadata": {}, - "outputs": [], - "source": [ - "import sqlalchemy as sa\n", - "import sqlalchemy.orm as so\n", - "from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker\n", - "from sqlalchemy.exc import IntegrityError\n", - "from pathlib import Path\n", - "import pandas as pd\n", - "import numpy as np\n", - "import tempfile, logging, os\n", - "from orm_loader.tables.base import CSVLoadableTableInterface \n", - "from orm_loader.loaders import LoaderContext\n", - "from orm_loader.loaders.loader_interface import ParquetLoader, LoaderInterface, PandasLoader\n", - "\n", - "from orm_loader.helpers import configure_logging, bootstrap, explain_sqlite_fk_error, bulk_load_context, configure_logging\n", - "\n", - "from omop_alchemy import get_engine_name, load_environment, TEST_PATH, ROOT_PATH\n", - "from omop_alchemy.cdm.model.vocabulary import (\n", - " Domain,\n", - " Vocabulary,\n", - " Concept_Class,\n", - " Relationship,\n", - " Concept,\n", - 
" Concept_Ancestor,\n", - " Concept_Relationship,\n", - " Concept_Synonym,\n", - ")\n" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "9173aad2", - "metadata": {}, - "outputs": [], - "source": [ - "logging.basicConfig(level=logging.INFO)\n", - "\n", - "class Base(DeclarativeBase):\n", - " pass\n", - "\n", - "engine_string = \"postgresql+psycopg2://airflow:airflow@0.0.0.0:5433/mosaiq\"\n", - "engine = sa.create_engine(engine_string, echo=False, future=True)\n", - "Base.metadata.bind = engine\n", - "\n", - "class TestTable(Base, CSVLoadableTableInterface):\n", - " __tablename__ = \"test_table\"\n", - "\n", - " id: so.Mapped[int] = so.mapped_column(primary_key=True)\n", - " name: so.Mapped[str] = so.mapped_column(nullable=False)\n", - "\n", - "Base.metadata.create_all(engine)" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "8c2f3d9a", - "metadata": {}, - "outputs": [], - "source": [ - "tmp = Path(tempfile.mkdtemp())\n", - "\n", - "csv_initial = tmp / \"test_table.csv\"\n", - "csv_replace = tmp / \"test_table_replace.csv\"\n", - "csv_empty = tmp / \"test_table_empty.csv\"\n", - "\n", - "pd.DataFrame(\n", - " [\n", - " {\"id\": 1, \"name\": \"alpha\"},\n", - " {\"id\": 2, \"name\": \"beta\"},\n", - " {\"id\": 3, \"name\": \"gamma\"},\n", - " ]\n", - ").to_csv(csv_initial, index=False, sep=\"\\t\")\n", - "\n", - "pd.DataFrame(\n", - " [\n", - " {\"id\": 2, \"name\": \"beta_updated\"},\n", - " {\"id\": 3, \"name\": \"gamma_updated\"},\n", - " ]\n", - ").to_csv(csv_replace, index=False, sep=\"\\t\")\n", - "\n", - "csv_empty.touch()\n" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "53b8a52a", - "metadata": {}, - "outputs": [], - "source": [ - "session = Session(engine)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "5a683a1b", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:orm_loader.tables.base.loadable_table:Staging table 
_staging_test_table does not exist; recreating\n", - "INFO:orm_loader.loaders.loading_helpers:Bulk loading _staging_test_table via COPY (encoding=utf-8, delimiter=\t)\n" - ] - }, - { - "data": { - "text/plain": [ - "3" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "TestTable.load_csv(path=csv_initial, session=session)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "83505eb9", - "metadata": {}, - "outputs": [], - "source": [ - "session.commit()" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "eba074e6", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2026-01-22 17:10:53,957 | INFO | sql_loader.omop_alchemy.config | Environment variables loaded from .env file\n", - "INFO:sql_loader.omop_alchemy.config:Environment variables loaded from .env file\n" - ] - } - ], - "source": [ - "ATHENA_INITIAL_LOAD = [\n", - " Domain,\n", - " Vocabulary,\n", - " Concept_Class,\n", - " Relationship,\n", - " Concept\n", - "]\n", - "\n", - "ATHENA_SUBSEQUENT_LOAD = [\n", - " Concept_Ancestor,\n", - " Concept_Relationship,\n", - " Concept_Synonym,\n", - "]\n", - "\n", - "configure_logging()\n", - "load_environment()" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "930f6572", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2026-01-22 17:10:54,687 | INFO | sql_loader.omop_alchemy.config | Database engine configured for schema 'cdm'\n", - "INFO:sql_loader.omop_alchemy.config:Database engine configured for schema 'cdm'\n" - ] - } - ], - "source": [ - "engine_string = get_engine_name('cdm')" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "de3d47e5", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:orm_loader.helpers.bootstrap:Bootstrapping schema (create=True)\n" - ] - } - ], - "source": [ - 
"engine = sa.create_engine(engine_string, future=True, echo=False)\n", - "bootstrap(engine, create=True)\n", - "\n", - "Session = sessionmaker(bind=engine, future=True)\n", - "session = Session()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7da32a8f", - "metadata": {}, - "outputs": [], - "source": [ - "source_path = Path(os.environ['SOURCE_PATH'])\n", - "\n", - "\n", - "p = ParquetLoader()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6ac0d9a5", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "23ae5a8a", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2026-01-22 17:12:05,251 | INFO | sql_loader.orm_loader.helpers.bulk | Disabled foreign key checks for bulk load\n", - "INFO:sql_loader.orm_loader.helpers.bulk:Disabled foreign key checks for bulk load\n", - "WARNING:orm_loader.tables.base.loadable_table:Staging table _staging_domain does not exist; recreating\n", - "INFO:orm_loader.loaders.loading_helpers:Bulk loading _staging_domain via COPY (encoding=utf-8, delimiter=\t)\n", - "WARNING:orm_loader.tables.base.loadable_table:Staging table _staging_vocabulary does not exist; recreating\n", - "INFO:orm_loader.loaders.loading_helpers:Bulk loading _staging_vocabulary via COPY (encoding=utf-8, delimiter=\t)\n", - "WARNING:orm_loader.tables.base.loadable_table:Staging table _staging_concept_class does not exist; recreating\n", - "INFO:orm_loader.loaders.loading_helpers:Bulk loading _staging_concept_class via COPY (encoding=utf-8, delimiter=\t)\n", - "WARNING:orm_loader.tables.base.loadable_table:Staging table _staging_relationship does not exist; recreating\n", - "INFO:orm_loader.loaders.loading_helpers:Bulk loading _staging_relationship via COPY (encoding=utf-8, delimiter=\t)\n", - "WARNING:orm_loader.tables.base.loadable_table:Staging table _staging_concept does not exist; recreating\n", - 
"INFO:orm_loader.loaders.loading_helpers:Bulk loading _staging_concept via COPY (encoding=utf-8, delimiter=\t)\n" - ] - } - ], - "source": [ - "with bulk_load_context(session):\n", - " for model in ATHENA_INITIAL_LOAD:\n", - " _ = model.load_csv(\n", - " session,\n", - " source_path / f\"{model.__tablename__.upper()}.csv\",\n", - " dedupe=False,\n", - " merge_strategy=\"upsert\",\n", - " loader=p,\n", - " )\n", - " session.commit()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e694d48c", - "metadata": {}, - "outputs": [], - "source": [ - "with bulk_load_context(session):\n", - " for model in ATHENA_SUBSEQUENT_LOAD:\n", - " _ = model.load_csv(\n", - " session,\n", - " source_path / f\"{model.__tablename__.upper()}.csv\",\n", - " dedupe=False,\n", - " chunksize=60_000_000, # parquet loader chunk is bytes not rows\n", - " merge_strategy=\"replace\",\n", - " loader=p,\n", - " )\n", - " session.commit()" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "orm-loader (3.12.10)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.10" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/pyproject.toml b/pyproject.toml index 398dc4c..407130e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "orm-loader" -version = "0.3.27" +version = "0.4.0" description = "Generic base classes to handle ORM functionality for multiple downstream datamodels" readme = "README.md" authors = [ @@ -14,6 +14,18 @@ dependencies = [ "sqlalchemy>=2.0.45", ] + +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "Topic :: Scientific/Engineering :: Medical Science Apps.", + "Topic :: Database :: 
Database Engines/Servers", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.12", +] + [project.urls] Homepage = "https://AustralianCancerDataNetwork.github.io/orm-loader" Documentation = "https://AustralianCancerDataNetwork.github.io/orm-loader" @@ -25,16 +37,20 @@ requires = ["uv_build>=0.9.2,<0.10.0"] build-backend = "uv_build" [project.optional-dependencies] +postgres = [ + "psycopg[binary]>=3.2", +] dev = [ - "mypy>=1.19.1", "pytest>=9.0.3", + "mypy>=1.19.1", "ruff>=0.14.11", "mkdocs-material>=9.7.1", "mkdocstrings-python>=2.0.1", "requests>=2.33.0", "mkdocs>=1.6.1", "mkdocs-mermaid2-plugin", - "Pygments>=2.20.0" + "Pygments>=2.20.0", + "python-dotenv" ] [tool.setuptools] @@ -54,3 +70,9 @@ python_files = ["test_*.py"] python_classes = ["Test*"] python_functions = ["test_*"] addopts = "-ra" +markers = [ + "postgres: requires a running Postgres instance (set TEST_POSTGRES_URL)", +] + +[tool.pyright] +reportMissingTypeStubs = false \ No newline at end of file diff --git a/src/orm_loader/backends/__init__.py b/src/orm_loader/backends/__init__.py new file mode 100644 index 0000000..3fa6888 --- /dev/null +++ b/src/orm_loader/backends/__init__.py @@ -0,0 +1,13 @@ +from .postgres import PostgresBackend +from .resolve import resolve_backend +from .sqlite import SQLiteBackend +from .base import BackendCapabilities, DatabaseBackend, Dialect + +__all__ = [ + "BackendCapabilities", + "DatabaseBackend", + "Dialect", + "PostgresBackend", + "SQLiteBackend", + "resolve_backend", +] diff --git a/src/orm_loader/backends/base.py b/src/orm_loader/backends/base.py new file mode 100644 index 0000000..b5fc9a1 --- /dev/null +++ b/src/orm_loader/backends/base.py @@ -0,0 +1,263 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from contextlib import AbstractContextManager, contextmanager, nullcontext +from dataclasses import dataclass +from enum import Enum +from typing 
import TYPE_CHECKING, Type, Any, Iterator + +import sqlalchemy as sa +import sqlalchemy.orm as so +from sqlalchemy.engine import Connection, Engine + +if TYPE_CHECKING: + from ..loaders.data_classes import LoaderContext + from ..tables.typing import CSVTableProtocol + + +@dataclass(frozen=True) +class BackendCapabilities: + """ + Capability flags exposed by a database backend. + + These defaults are intentionally conservative. Concrete backends should + opt into capabilities explicitly. + """ + + supports_fast_load: bool = False + supports_unlogged_staging: bool = False + supports_fk_toggle: bool = False + supports_materialized_views: bool = False + + +class Dialect(str, Enum): + """Supported SQLAlchemy dialect names.""" + + SQLITE = "sqlite" + POSTGRESQL = "postgresql" + + +class DatabaseBackend(ABC): + """ + Abstract base class for database-specific loader behavior. + + This class defines the stable contract for future backend implementations + without changing existing loader orchestration yet. + """ + + @property + @abstractmethod + def name(self) -> str: + """Human-readable backend name.""" + + @property + @abstractmethod + def dialect(self) -> Dialect: + """SQLAlchemy dialect handled by this backend.""" + + @property + @abstractmethod + def capabilities(self) -> BackendCapabilities: + """Capability flags supported by this backend.""" + + def supports_dialect(self, dialect: Dialect) -> bool: + """Return ``True`` when the backend handles the given dialect.""" + return self.dialect == dialect + + @property + def default_index_strategy(self) -> str: + """Default index strategy used when callers request ``auto``.""" + return "drop_rebuild" + + def resolve_index_strategy(self, index_strategy: str) -> str: + """ + Resolve a caller-facing index strategy to a concrete backend choice. + """ + valid = {"auto", "drop_rebuild", "keep"} + if index_strategy not in valid: + raise ValueError( + f"Unknown index_strategy '{index_strategy}'. 
Expected one of: {sorted(valid)}" + ) + if index_strategy == "auto": + return self.default_index_strategy + return index_strategy + + def _require_capability(self, capability_name: str, feature_name: str) -> None: + """ + Raise a clear error when a backend capability is not supported. + """ + if not hasattr(self.capabilities, capability_name): + raise AttributeError( + f"Unknown backend capability {capability_name!r} on {type(self.capabilities).__name__}" + ) + if not getattr(self.capabilities, capability_name): + raise NotImplementedError( + f"Backend '{self.name}' does not support {feature_name}" + ) + + @contextmanager + def _as_connection( + self, + bind: Engine | Connection, + ) -> Iterator[Connection]: + if isinstance(bind, Engine): + with bind.begin() as conn: + yield conn + else: + yield bind + + def _insertable_column_names( + self, + table_cls: Type["CSVTableProtocol"], + ) -> list[str]: + """ + Return column names safe to include in generic insert statements. + + Computed columns are excluded because backend loaders and merge helpers + should not attempt to write to them directly. + """ + return [c.name for c in table_cls.__table__.columns if c.computed is None] + + @abstractmethod + def create_staging_table( + self, + table_cls: Type["CSVTableProtocol"], + session: so.Session, + staging_name: str, + ) -> None: + """Create a staging table for the supplied ORM table class.""" + + @abstractmethod + def drop_staging_table( + self, + session: so.Session, + staging_name: str, + ) -> None: + """Drop a staging table if it exists.""" + + def load_staging_fast( + self, + loader_context: "LoaderContext", + staging_name: str, + ) -> int | None: + """ + Attempt a backend-native fast-path load. + + Return the inserted row count when handled, or ``None`` when the + backend has no fast-path loader for the given context. 
+ """ + return None + + @staticmethod + @abstractmethod + def _normalize_fk_check_state(previous_state: str | int) -> str | int: + """Validate and normalise a previously-returned FK state before interpolating into SQL. + + Each backend accepts a different type (SQLite: int, Postgres: str) and must + implement this to guard restore_fk_check() against invalid or injected values. + """ + + @abstractmethod + def disable_fk_check(self, session: so.Session) -> str | int: + """Disable FK checks and return the previous backend-specific state.""" + + @abstractmethod + def enable_fk_check(self, session: so.Session) -> str | int: + """Explicitly enable FK checks and return the previous backend-specific state.""" + + @abstractmethod + def restore_fk_check( + self, + session: so.Session, + previous_state: str | int, + ) -> None: + """Restore FK checks to a previously returned backend-specific state.""" + + @abstractmethod + def merge_replace( + self, + table_cls: Type["CSVTableProtocol"], + session: so.Session, + target_name: str, + staging_name: str, + pk_cols: list[str], + ) -> None: + """Merge staging rows by replacing matching target rows first.""" + + @abstractmethod + def merge_upsert( + self, + table_cls: Type["CSVTableProtocol"], + session: so.Session, + target_name: str, + staging_name: str, + pk_cols: list[str], + ) -> None: + """Merge staging rows using backend-specific upsert semantics.""" + + @abstractmethod + def merge_insert( + self, + table_cls: Type["CSVTableProtocol"], + session: so.Session, + target_name: str, + staging_name: str, + ) -> None: + """Insert all staging rows into the target table.""" + + def merge_context( + self, + table_cls: Type["CSVTableProtocol"], + session: so.Session, + ) -> AbstractContextManager[None]: + """Return a context manager for merge-time backend operations.""" + return nullcontext() + + @contextmanager + def bulk_load_context( + self, + session: so.Session, + *, + disable_fk: bool = True, + no_autoflush: bool = True, + ): + """ 
+ Generic bulk-load context that defers FK semantics to the backend. + """ + previous_fk_state: str | int | None = None + try: + if disable_fk: + self._require_capability("supports_fk_toggle", "foreign key toggling") + raw_state = self.disable_fk_check(session) + previous_fk_state = self._normalize_fk_check_state(raw_state) + + if no_autoflush: + with session.no_autoflush: + yield + else: + yield + + except Exception: + session.rollback() + raise + + finally: + if previous_fk_state is not None: + self.restore_fk_check(session, previous_fk_state) + + @abstractmethod + def create_materialized_view( + self, + bind: "Engine | Connection", + name: str, + selectable: sa.sql.Select[Any], + ) -> None: + """Create a materialized view for the supplied selectable.""" + + @abstractmethod + def refresh_materialized_view( + self, + bind: "Engine | Connection", + name: str, + ) -> None: + """Refresh a materialized view.""" diff --git a/src/orm_loader/backends/postgres.py b/src/orm_loader/backends/postgres.py new file mode 100644 index 0000000..7b57c8b --- /dev/null +++ b/src/orm_loader/backends/postgres.py @@ -0,0 +1,237 @@ +from __future__ import annotations + +from contextlib import contextmanager, AbstractContextManager +from typing import TYPE_CHECKING, Any +import sqlalchemy as sa +import sqlalchemy.orm as so +import sqlalchemy.event as sae + +from .base import BackendCapabilities, DatabaseBackend, Dialect +from ..loaders.loading_helpers import quick_load_pg + +if TYPE_CHECKING: + from sqlalchemy.engine import Connection, Engine + + from ..loaders.data_classes import LoaderContext + from ..tables.typing import CSVTableProtocol + +_VALID_PG_REPLICATION_ROLES = frozenset({"origin", "local", "replica"}) + + +class PostgresBackend(DatabaseBackend): + @property + def name(self) -> str: + return "postgres" + + @property + def dialect(self) -> Dialect: + return Dialect.POSTGRESQL + + @property + def capabilities(self) -> BackendCapabilities: + return BackendCapabilities( + 
supports_fast_load=True, + supports_unlogged_staging=True, + supports_fk_toggle=True, + supports_materialized_views=True, + ) + + def create_staging_table( + self, + table_cls: type["CSVTableProtocol"], + session: so.Session, + staging_name: str, + ) -> None: + table = table_cls.__table__ + session.execute(sa.text(f'DROP TABLE IF EXISTS "{staging_name}";')) + session.execute( + sa.text( + f''' + CREATE UNLOGGED TABLE "{staging_name}" + (LIKE "{table.name}" INCLUDING DEFAULTS INCLUDING CONSTRAINTS); + ''' + ) + ) + + computed_cols = [c.name for c in table.columns if c.computed is not None] + for col in computed_cols: + session.execute(sa.text(f'ALTER TABLE "{staging_name}" DROP COLUMN "{col}";')) + + session.commit() + + def drop_staging_table( + self, + session: so.Session, + staging_name: str, + ) -> None: + session.execute(sa.text(f'DROP TABLE IF EXISTS "{staging_name}"')) + + def load_staging_fast( + self, + loader_context: "LoaderContext", + staging_name: str, + ) -> int | None: + return quick_load_pg( + path=loader_context.path, + session=loader_context.session, + tablename=staging_name, + quote_mode=loader_context.quote_mode, + ) + + @staticmethod + def _normalize_fk_check_state(previous_state: str | int) -> str: + if isinstance(previous_state, int): + raise ValueError( + f"Invalid PostgreSQL session_replication_role {previous_state!r}: " + "Postgres uses string roles ('origin', 'local', 'replica'), not integers. " + "The value passed here should always come from this backend's own " + "disable_fk_check(), which returns a string." + ) + normalised = previous_state.strip().lower() + if normalised not in _VALID_PG_REPLICATION_ROLES: + raise ValueError( + f"Invalid PostgreSQL session_replication_role {previous_state!r}. 
" + f"Expected one of: {sorted(_VALID_PG_REPLICATION_ROLES)}" + ) + return normalised + + def disable_fk_check(self, session: so.Session) -> str | int: + previous_state = session.execute(sa.text("SHOW session_replication_role")).scalar() + session.execute(sa.text("SET session_replication_role = 'replica'")) + if not isinstance(previous_state, str): + raise RuntimeError("Expected PostgreSQL FK state to be a string") + return previous_state + + def enable_fk_check(self, session: so.Session) -> str | int: + previous_state = session.execute(sa.text("SHOW session_replication_role")).scalar() + session.execute(sa.text("SET session_replication_role = 'origin'")) + if not isinstance(previous_state, str): + raise RuntimeError("Expected PostgreSQL FK state to be a string") + return previous_state + + def restore_fk_check( + self, + session: so.Session, + previous_state: str | int, + ) -> None: + safe_state = self._normalize_fk_check_state(previous_state) + session.execute(sa.text(f"SET session_replication_role = '{safe_state}'")) + + def merge_replace( + self, + table_cls: type["CSVTableProtocol"], + session: so.Session, + target_name: str, + staging_name: str, + pk_cols: list[str], + ) -> None: + pk_join = " AND ".join( + f't."{c}" = s."{c}"' for c in pk_cols + ) + session.execute( + sa.text( + f""" + DELETE FROM "{target_name}" t + USING "{staging_name}" s + WHERE {pk_join}; + """ + ) + ) + + def merge_upsert( + self, + table_cls: type["CSVTableProtocol"], + session: so.Session, + target_name: str, + staging_name: str, + pk_cols: list[str], + ) -> None: + insertable_cols = self._insertable_column_names(table_cls) + cols_str = ", ".join(f'"{c}"' for c in insertable_cols) + conflict_cols = ", ".join(f'"{c}"' for c in pk_cols) + session.execute( + sa.text( + f""" + INSERT INTO "{target_name}" ({cols_str}) + SELECT {cols_str} FROM "{staging_name}" + ON CONFLICT ({conflict_cols}) DO NOTHING; + """ + ) + ) + + def merge_insert( + self, + table_cls: type["CSVTableProtocol"], + 
session: so.Session, + target_name: str, + staging_name: str, + ) -> None: + insertable_cols = self._insertable_column_names(table_cls) + cols_str = ", ".join(f'"{c}"' for c in insertable_cols) + session.execute( + sa.text( + f""" + INSERT INTO "{target_name}" ({cols_str}) + SELECT {cols_str} FROM "{staging_name}"; + """ + ) + ) + + def merge_context( + self, + table_cls: type["CSVTableProtocol"], + session: so.Session, + ) -> AbstractContextManager[None]: + return self.bulk_load_context(session, disable_fk=True, no_autoflush=False) + + + + def create_materialized_view( + self, + bind: Engine | Connection, + name: str, + selectable: sa.sql.Select[Any], + ) -> None: + from ..mappers.materialised_view_mixin import CreateMaterializedView + + with self._as_connection(bind) as conn: + conn.execute(CreateMaterializedView(name, selectable)) + + def refresh_materialized_view( + self, + bind: Engine | Connection, + name: str, + ) -> None: + with self._as_connection(bind) as conn: + safe_name = name + dialect = getattr(conn, "dialect", None) + if dialect is not None: + safe_name = dialect.identifier_preparer.quote(name) + conn.execute( + sa.text(f"REFRESH MATERIALIZED VIEW {safe_name};") + ) + + @contextmanager + def engine_with_replica_role(self, engine: "Engine"): + def _set_replica_role( + dbapi_conn: sa.engine.interfaces.DBAPIConnection, + _, + ) -> None: + cur = dbapi_conn.cursor() + cur.execute("SET session_replication_role = 'replica'") + cur.close() + + sae.listen(engine, "connect", _set_replica_role) + + try: + yield engine + finally: + sae.remove(engine, "connect", _set_replica_role) + with engine.connect() as conn: + conn = conn.execution_options(isolation_level="AUTOCOMMIT") + conn.execute(sa.text("SET session_replication_role = DEFAULT")) + role = conn.execute( + sa.text("SHOW session_replication_role") + ).scalar() + if role != "origin": + raise RuntimeError("Failed to restore session_replication_role") diff --git a/src/orm_loader/backends/resolve.py 
b/src/orm_loader/backends/resolve.py new file mode 100644 index 0000000..e3919c6 --- /dev/null +++ b/src/orm_loader/backends/resolve.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING +import sqlalchemy.orm as so + +from .base import DatabaseBackend, Dialect +from .postgres import PostgresBackend +from .sqlite import SQLiteBackend + +if TYPE_CHECKING: + from sqlalchemy.engine import Connection, Engine + + +_BACKEND_TYPES: tuple[type[DatabaseBackend], ...] = ( + PostgresBackend, + SQLiteBackend, +) + + +def _dialect(bindable: "so.Session | Engine | Connection") -> Dialect: + if isinstance(bindable, so.Session): + bind = bindable.get_bind() + dialect_name = bind.dialect.name + elif hasattr(bindable, "dialect"): + dialect_name = bindable.dialect.name + else: + raise TypeError(f"Unsupported bindable type: {type(bindable)!r}") + + try: + return Dialect(dialect_name) + except ValueError as exc: + raise NotImplementedError( + f"Unsupported SQLAlchemy dialect '{dialect_name}'" + ) from exc + + +def resolve_backend(bindable: "so.Session | Engine | Connection") -> DatabaseBackend: + """ + Resolve a concrete backend from a SQLAlchemy session, engine, or connection. 
+ """ + dialect = _dialect(bindable) + for backend_type in _BACKEND_TYPES: + backend = backend_type() + if backend.supports_dialect(dialect): + return backend + raise NotImplementedError(f"No backend registered for dialect '{dialect.value}'") diff --git a/src/orm_loader/backends/sqlite.py b/src/orm_loader/backends/sqlite.py new file mode 100644 index 0000000..753abd4 --- /dev/null +++ b/src/orm_loader/backends/sqlite.py @@ -0,0 +1,284 @@ +from __future__ import annotations + +import logging +import sqlite3 +from pathlib import Path +from typing import TYPE_CHECKING, Any +from contextlib import AbstractContextManager + +import sqlalchemy as sa +import sqlalchemy.orm as so +from sqlalchemy import event, text +from sqlalchemy.exc import IntegrityError + +from .base import BackendCapabilities, DatabaseBackend, Dialect + +if TYPE_CHECKING: + from sqlalchemy.engine import Connection, Engine + + from ..tables.typing import CSVTableProtocol + + +logger = logging.getLogger(__name__) +VALID_SQLITE_JOURNAL_MODES = frozenset( + {"DELETE", "TRUNCATE", "PERSIST", "MEMORY", "WAL", "OFF"} +) + + +class SQLiteBackend(DatabaseBackend): + def __init__( + self, + *, + busy_timeout_ms: int = 60000, + journal_mode: str = "WAL", + defer_foreign_keys: bool = True, + ) -> None: + self.busy_timeout_ms = busy_timeout_ms + self.journal_mode = self._validate_journal_mode(journal_mode) + self.defer_foreign_keys = defer_foreign_keys + + @staticmethod + def _validate_journal_mode(journal_mode: str) -> str: + normalised = journal_mode.strip().upper() + if normalised not in VALID_SQLITE_JOURNAL_MODES: + raise ValueError( + "Unsupported SQLite journal_mode " + f"{journal_mode!r}. 
Expected one of: {sorted(VALID_SQLITE_JOURNAL_MODES)}" + ) + return normalised + + @staticmethod + def _normalize_fk_check_state(previous_state: str | int) -> str: + if isinstance(previous_state, int): + if previous_state == 1: + return "ON" + if previous_state == 0: + return "OFF" + elif isinstance(previous_state, str): + normalised = previous_state.strip().upper() + if normalised in {"1", "ON"}: + return "ON" + if normalised in {"0", "OFF"}: + return "OFF" + raise ValueError( + f"Invalid SQLite foreign_keys state {previous_state!r}. " + "Expected 0, 1, 'OFF', or 'ON'." + ) + + @property + def name(self) -> str: + return "sqlite" + + @property + def dialect(self) -> Dialect: + return Dialect.SQLITE + + @property + def capabilities(self) -> BackendCapabilities: + return BackendCapabilities( + supports_fast_load=False, + supports_unlogged_staging=False, + supports_fk_toggle=True, + supports_materialized_views=False, + ) + + @property + def default_index_strategy(self) -> str: + return "keep" + + def create_staging_table( + self, + table_cls: type["CSVTableProtocol"], + session: so.Session, + staging_name: str, + ) -> None: + session.execute(sa.text(f'DROP TABLE IF EXISTS "{staging_name}";')) + + metadata = sa.MetaData() + staging_columns = [ + sa.Column(col.name, col.type, nullable=True) + for col in table_cls.__table__.columns + ] + staging_table = sa.Table(staging_name, metadata, *staging_columns) + metadata.create_all(bind=session.connection(), tables=[staging_table]) + session.commit() + + def drop_staging_table( + self, + session: so.Session, + staging_name: str, + ) -> None: + session.execute(sa.text(f'DROP TABLE IF EXISTS "{staging_name}"')) + + def disable_fk_check(self, session: so.Session) -> str | int: + previous_state = session.execute(text("PRAGMA foreign_keys")).scalar() + session.execute(text("PRAGMA foreign_keys = OFF")) + if not isinstance(previous_state, int): + raise RuntimeError("Expected SQLite FK state to be an int") + return previous_state + + 
def enable_fk_check(self, session: so.Session) -> str | int: + previous_state = session.execute(text("PRAGMA foreign_keys")).scalar() + session.execute(text("PRAGMA foreign_keys = ON")) + if not isinstance(previous_state, int): + raise RuntimeError("Expected SQLite FK state to be an int") + return previous_state + + def restore_fk_check( + self, + session: so.Session, + previous_state: str | int, + ) -> None: + safe_state = self._normalize_fk_check_state(previous_state) + session.execute(text(f"PRAGMA foreign_keys = {safe_state}")) + + def merge_replace( + self, + table_cls: type["CSVTableProtocol"], + session: so.Session, + target_name: str, + staging_name: str, + pk_cols: list[str], + ) -> None: + if len(pk_cols) == 1: + pk = pk_cols[0] + session.execute( + sa.text( + f""" + DELETE FROM "{target_name}" + WHERE "{pk}" IN ( + SELECT "{pk}" FROM "{staging_name}" + ); + """ + ) + ) + return + + pk_match = " AND ".join( + f'"{target_name}"."{c}" = "{staging_name}"."{c}"' for c in pk_cols + ) + session.execute( + sa.text( + f""" + DELETE FROM "{target_name}" + WHERE EXISTS ( + SELECT 1 FROM "{staging_name}" + WHERE {pk_match} + ); + """ + ) + ) + + def merge_upsert( + self, + table_cls: type["CSVTableProtocol"], + session: so.Session, + target_name: str, + staging_name: str, + pk_cols: list[str], + ) -> None: + insertable_cols = self._insertable_column_names(table_cls) + cols_str = ", ".join(f'"{c}"' for c in insertable_cols) + session.execute( + sa.text( + f""" + INSERT OR IGNORE INTO "{target_name}" ({cols_str}) + SELECT {cols_str} FROM "{staging_name}"; + """ + ) + ) + + def merge_insert( + self, + table_cls: type["CSVTableProtocol"], + session: so.Session, + target_name: str, + staging_name: str, + ) -> None: + insertable_cols = self._insertable_column_names(table_cls) + cols_str = ", ".join(f'"{c}"' for c in insertable_cols) + session.execute( + sa.text( + f""" + INSERT INTO "{target_name}" ({cols_str}) + SELECT {cols_str} FROM "{staging_name}"; + """ + ) + ) + + 
def merge_context( + self, + table_cls: type["CSVTableProtocol"], + session: so.Session, + ) -> AbstractContextManager[None]: + return self.bulk_load_context(session, disable_fk=True, no_autoflush=False) + + def create_materialized_view( + self, + bind: "Engine | Connection", + name: str, + selectable: sa.sql.Select[Any], + ) -> None: + self._require_capability("supports_materialized_views", "materialized views") + + def refresh_materialized_view( + self, + bind: "Engine | Connection", + name: str, + ) -> None: + self._require_capability("supports_materialized_views", "materialized views") + + def configure_dbapi_connection(self, dbapi_connection: sa.engine.interfaces.DBAPIConnection) -> None: + if dbapi_connection.__class__.__module__.startswith("sqlite3"): + cursor = dbapi_connection.cursor() + cursor.execute(f"PRAGMA busy_timeout = {self.busy_timeout_ms}") + cursor.execute(f"PRAGMA journal_mode = {self.journal_mode}") + cursor.execute("PRAGMA foreign_keys = ON;") + if self.defer_foreign_keys: + cursor.execute("PRAGMA defer_foreign_keys = ON;") + cursor.close() + + def install_engine_hooks(self, engine: "Engine") -> None: + @event.listens_for(engine, "connect") + def _enable_sqlite_foreign_keys( # type: ignore + dbapi_connection: sa.engine.interfaces.DBAPIConnection, + _connection_record: Any + ) -> None: + self.configure_dbapi_connection(dbapi_connection) + + def explain_fk_error( + self, + session: so.Session, + exc: IntegrityError, + *, + raise_error: bool = True, + ) -> None: + bind: Engine | Connection = session.get_bind() + if bind.dialect.name != "sqlite": + raise exc + + with self._as_connection(bind) as conn: + rows = conn.execute(text("PRAGMA foreign_key_check")).fetchall() + + if rows: + for row in rows: + logger.error( + "FK violation: table=%s rowid=%s references=%s fk_index=%s", + row[0], row[1], row[2], row[3] + ) + + if raise_error: + raise exc + + def restore_journal_mode(self, db_path: Path) -> None: + timeout_s = max(self.busy_timeout_ms / 
1000, 5) + try: + with sqlite3.connect(db_path.resolve(), timeout=timeout_s) as conn: + conn.execute("PRAGMA wal_checkpoint(TRUNCATE)") + conn.execute("PRAGMA journal_mode = DELETE") + conn.commit() + except sqlite3.OperationalError as exc: + raise RuntimeError( + "Failed to restore SQLite journal mode. " + "Close or dispose active SQLite connections before calling this helper." + ) from exc diff --git a/src/orm_loader/helpers/__init__.py b/src/orm_loader/helpers/__init__.py index 32623f5..f2c49a1 100644 --- a/src/orm_loader/helpers/__init__.py +++ b/src/orm_loader/helpers/__init__.py @@ -1,7 +1,11 @@ from .errors import IngestError, ValidationError from .logging import get_logger, configure_logging from .bootstrap import bootstrap, create_db -from .sqlite import enable_sqlite_foreign_keys, explain_sqlite_fk_error +from .sqlite import ( + attach_sqlite_bulk_load_pragmas, + explain_sqlite_fk_error, + restore_sqlite_journal_mode, +) from .bulk import bulk_load_context, engine_with_replica_role from .metadata import Base from .discovery import get_model_by_tablename @@ -14,11 +18,12 @@ "configure_logging", "bootstrap", "create_db", - "enable_sqlite_foreign_keys", + "attach_sqlite_bulk_load_pragmas", "explain_sqlite_fk_error", + "restore_sqlite_journal_mode", "bulk_load_context", "engine_with_replica_role", "Base", "get_model_by_tablename", "normalise_null", -] \ No newline at end of file +] diff --git a/src/orm_loader/helpers/bootstrap.py b/src/orm_loader/helpers/bootstrap.py index 473d6e5..08f7760 100644 --- a/src/orm_loader/helpers/bootstrap.py +++ b/src/orm_loader/helpers/bootstrap.py @@ -1,13 +1,13 @@ from .metadata import Base import logging - +import sqlalchemy as sa logger = logging.getLogger(__name__) -def create_db(engine): +def create_db(engine: sa.engine.Engine) -> None: logger.debug("Creating database schema") Base.metadata.create_all(engine) -def bootstrap(engine, *, create: bool = True): +def bootstrap(engine: sa.engine.Engine, *, create: bool = True) -> 
None: logger.info("Bootstrapping schema (create=%s)", create) if create: create_db(engine) diff --git a/src/orm_loader/helpers/bulk.py b/src/orm_loader/helpers/bulk.py index 4c3a40a..4be22b4 100644 --- a/src/orm_loader/helpers/bulk.py +++ b/src/orm_loader/helpers/bulk.py @@ -1,61 +1,33 @@ from contextlib import contextmanager -from sqlalchemy import text, Engine +from sqlalchemy import Engine from sqlalchemy.orm import Session -import sqlalchemy as sa +from typing import Iterator +from ..backends.resolve import resolve_backend from .logging import get_logger logger = get_logger(__name__) def disable_fk_check(session: Session) -> str | int: - """Disables FK checks and returns the previous state.""" - engine = session.get_bind() - dialect = engine.dialect.name - previous_state = None - - if dialect == "postgresql": - previous_state = session.execute(text("SHOW session_replication_role")).scalar() - session.execute(text("SET session_replication_role = 'replica'")) - elif dialect == "sqlite": - previous_state = session.execute(text("PRAGMA foreign_keys")).scalar() - session.execute(text("PRAGMA foreign_keys = OFF")) - else: - raise NotImplementedError(f"FK disable not implemented for {dialect}") - + """Disable foreign-key checks for the current session and return the previous state.""" + previous_state = resolve_backend(session).disable_fk_check(session) logger.info("Disabled foreign key checks for bulk load.") - assert isinstance(previous_state, (str, int)), "Expected previous FK state to be str or int" + if not isinstance(previous_state, (str, int)): + logger.error(f"Unexpected FK state type: {type(previous_state)}. 
Expected str or int.") + raise TypeError(f"Expected previous FK state to be str or int, got {type(previous_state)}") return previous_state def enable_fk_check(session: Session) -> str | int: - """Explicitly enables FK checks and returns the previous state.""" - engine = session.get_bind() - dialect = engine.dialect.name - previous_state = None - - if dialect == "postgresql": - previous_state = session.execute(text("SHOW session_replication_role")).scalar() - session.execute(text("SET session_replication_role = 'origin'")) - elif dialect == "sqlite": - previous_state = session.execute(text("PRAGMA foreign_keys")).scalar() - session.execute(text("PRAGMA foreign_keys = ON")) - else: - raise NotImplementedError(f"FK enable not implemented for {dialect}") - + """Enable foreign-key checks for the current session and return the previous state.""" + previous_state = resolve_backend(session).enable_fk_check(session) logger.info("Explicitly re-enabled foreign key checks.") - assert isinstance(previous_state, (str, int)), "Expected previous FK state to be str or int" + if not isinstance(previous_state, (str, int)): + logger.error(f"Unexpected FK state type: {type(previous_state)}. 
Expected str or int.") + raise TypeError(f"Expected previous FK state to be str or int, got {type(previous_state)}") return previous_state def restore_fk_check(session: Session, previous_state: str | int): - """Restores FK checks to a specifically provided previous state.""" - engine = session.get_bind() - dialect = engine.dialect.name - - if dialect == "postgresql": - session.execute(text(f"SET session_replication_role = '{previous_state}'")) - elif dialect == "sqlite": - session.execute(text(f"PRAGMA foreign_keys = {previous_state}")) - else: - raise NotImplementedError(f"FK restore not implemented for {dialect}") - + """Restore foreign-key checks to a previously captured backend-specific state.""" + resolve_backend(session).restore_fk_check(session, previous_state) logger.info(f"Restored foreign key checks to state: {previous_state}") @contextmanager @@ -64,61 +36,36 @@ def bulk_load_context( *, disable_fk: bool = True, no_autoflush: bool = True, -): - previous_fk_state = None - try: - if disable_fk: - previous_fk_state = disable_fk_check(session) - - if no_autoflush: - with session.no_autoflush: - yield - else: - yield - - except Exception: - session.rollback() - raise +) -> Iterator[None]: + """ + Wrap a trusted bulk operation in backend-aware session settings. - finally: - if previous_fk_state is not None: - restore_fk_check(session, previous_fk_state) + This is a thin helper over ``DatabaseBackend.bulk_load_context()``. + It exists so older call sites can keep using the helper import path. 
+ """ + backend = resolve_backend(session) + with backend.bulk_load_context( + session, + disable_fk=disable_fk, + no_autoflush=no_autoflush, + ): + yield @contextmanager -def engine_with_replica_role(engine: Engine): +def engine_with_replica_role(engine: Engine) -> Iterator[Engine]: """ - Context manager that: - - forces session_replication_role=replica on all connections - - restores DEFAULT on exit - - this is different to bulk_load_context manager from orm_loader.helpers - because this is engine scoped where that one is session scoped + Force ``session_replication_role=replica`` on PostgreSQL engine connections. - postgres only + This is engine-scoped rather than session-scoped. It is only available + on backends that explicitly implement the behaviour. """ - @sa.event.listens_for(engine, "connect") # type: ignore - def _set_replica_role(dbapi_conn, _): - cur = dbapi_conn.cursor() - cur.execute("SET session_replication_role = replica") - cur.close() - - try: - yield engine - finally: - # Explicitly restore on a fresh connection - with engine.connect() as conn: - conn = conn.execution_options(isolation_level="AUTOCOMMIT") - conn.execute(text("SET session_replication_role = DEFAULT")) - - role = conn.execute( - text("SHOW session_replication_role") - ).scalar() - - if role != "origin": - raise RuntimeError( - "Failed to restore session_replication_role" - ) - - logger.info("session_replication_role restored to DEFAULT") + backend = resolve_backend(engine) + method = getattr(backend, "engine_with_replica_role", None) + if method is None: + raise NotImplementedError( + f"Backend '{backend.name}' does not support replica-role engine contexts" + ) + with method(engine) as wrapped: + yield wrapped diff --git a/src/orm_loader/helpers/discovery.py b/src/orm_loader/helpers/discovery.py index eb3e1a1..0e0333c 100644 --- a/src/orm_loader/helpers/discovery.py +++ b/src/orm_loader/helpers/discovery.py @@ -1,11 +1,19 @@ -from typing import Type +from typing import TypeVar 
from .metadata import Base -def get_model_by_tablename(tablename: str, base: Type[Base] | None = None) -> Type | None: +ModelT = TypeVar("ModelT", bound=Base) + +def get_model_by_tablename( + tablename: str, + base: type[ModelT] = Base, +) -> type[ModelT] | None: tablename = tablename.lower().strip() - if base is None: - base = Base - for cls in base.__subclasses__(): + for mapper in base.registry.mappers: + cls = mapper.class_ + if not isinstance(cls, type): + continue + if not issubclass(cls, base): + continue if getattr(cls, "__tablename__", None) == tablename: return cls return None diff --git a/src/orm_loader/helpers/logging.py b/src/orm_loader/helpers/logging.py index fc92ae8..bce30f9 100644 --- a/src/orm_loader/helpers/logging.py +++ b/src/orm_loader/helpers/logging.py @@ -1,7 +1,8 @@ from __future__ import annotations + import logging -from typing import Optional import re +from typing import Any, Optional SENSITIVE_KEYS = { "password", @@ -15,22 +16,22 @@ } LOGGING_NAMESPACE = "sql_loader" + def _coerce_log_level(level: int | str) -> int: if isinstance(level, int): return level - if isinstance(level, str): - s = level.strip().upper() - if s.isdigit(): - return int(s) - - mapping = logging.getLevelNamesMapping() - if s in mapping: - return mapping[s] + if not isinstance(level, str): + raise TypeError(f"log level must be an int or str, got {type(level).__name__}") + s = level.strip().upper() + if s.isdigit(): + return int(s) - raise ValueError(f"Invalid log level: {level!r}") + mapping = logging.getLevelNamesMapping() + if s in mapping: + return mapping[s] - raise TypeError(f"Invalid log level type: {type(level)}") + raise ValueError(f"Invalid log level: {level!r}") def get_logger(name: Optional[str] = None) -> logging.Logger: @@ -46,16 +47,17 @@ def get_logger(name: Optional[str] = None) -> logging.Logger: class RedactingFormatter(logging.Formatter): - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any): 
super().__init__(*args, **kwargs) self._pattern = re.compile( r"(?i)\\b(" + "|".join(SENSITIVE_KEYS) + r")\\b\\s*[:=]\\s*[^\\s,;]+" ) - def format(self, record): + def format(self, record: logging.LogRecord) -> str: msg = super().format(record) return self._pattern.sub(r"\\1=", msg) - + + def configure_logging( *, level: int | str = logging.INFO, @@ -87,4 +89,4 @@ def configure_logging( logger.propagate = propagate -logging.getLogger(LOGGING_NAMESPACE).addHandler(logging.NullHandler()) \ No newline at end of file +logging.getLogger(LOGGING_NAMESPACE).addHandler(logging.NullHandler()) diff --git a/src/orm_loader/helpers/sqlite.py b/src/orm_loader/helpers/sqlite.py index 19e4fe0..ca8c134 100644 --- a/src/orm_loader/helpers/sqlite.py +++ b/src/orm_loader/helpers/sqlite.py @@ -1,32 +1,45 @@ -from sqlalchemy import event, text +from pathlib import Path + from sqlalchemy.engine import Engine from sqlalchemy.exc import IntegrityError -import logging -logger = logging.getLogger(__name__) +from ..backends.sqlite import SQLiteBackend + + +def attach_sqlite_bulk_load_pragmas( + engine: Engine, + *, + busy_timeout_ms: int = 60000, + journal_mode: str = "WAL", + defer_foreign_keys: bool = True, +) -> None: + """ + Install SQLite connect hooks aimed at heavy local write workloads. + + The hook currently sets ``busy_timeout``, journal mode, and foreign-key + enforcement, and can also enable deferred foreign-key checking for the + connection. + + This replaces the old ``enable_sqlite_foreign_keys()`` workaround, + which should no longer be needed. 
+ """ + SQLiteBackend( + busy_timeout_ms=busy_timeout_ms, + journal_mode=journal_mode, + defer_foreign_keys=defer_foreign_keys, + ).install_engine_hooks(engine) -@event.listens_for(Engine, "connect") -def enable_sqlite_foreign_keys(dbapi_connection, connection_record): - if dbapi_connection.__class__.__module__.startswith("sqlite3"): - logger.debug("Enabling SQLite foreign key enforcement") - cursor = dbapi_connection.cursor() - cursor.execute("PRAGMA defer_foreign_keys = ON;") - cursor.close() def explain_sqlite_fk_error(session, exc: IntegrityError, raise_error: bool = True): - engine = session.get_bind() - if engine.dialect.name != "sqlite": - raise exc - - with engine.connect() as conn: - rows = conn.execute(text("PRAGMA foreign_key_check")).fetchall() - - if rows: - for r in rows: - logger.error( - "FK violation: table=%s rowid=%s references=%s fk_index=%s", - r[0], r[1], r[2], r[3] - ) - - if raise_error: - raise exc + """Log SQLite foreign-key check details before re-raising an error.""" + SQLiteBackend().explain_fk_error(session, exc, raise_error=raise_error) + + +def restore_sqlite_journal_mode(db_path: Path) -> None: + """ + Checkpoint WAL contents and switch the database back to ``DELETE`` mode. + + Call this after disposing active SQLite connections. Reconnecting + through an engine that still installs WAL hooks will enable WAL again. 
+ """ + SQLiteBackend().restore_journal_mode(db_path) diff --git a/src/orm_loader/loaders/data_classes.py b/src/orm_loader/loaders/data_classes.py index f7dfe8b..d7031fd 100644 --- a/src/orm_loader/loaders/data_classes.py +++ b/src/orm_loader/loaders/data_classes.py @@ -70,7 +70,7 @@ class LoaderContext: chunksize: int | None = None normalise: bool = True dedupe: bool = True - quote_mode: str = "csv" + quote_mode: str = "auto" class LoaderInterface: @@ -170,61 +170,6 @@ def dedupe(cls, data: pd.DataFrame | pa.Table, ctx: LoaderContext) -> Any: """ raise NotImplementedError - # @classmethod - # def _dedupe_db(cls, df: pd.DataFrame, ctx: LoaderContext) -> pd.DataFrame: - # """ - # Perform database-level deduplication against existing rows. - - # Parameters - # ---------- - # df - # Incoming DataFrame. - # ctx - # Loader context. - - # Returns - # ------- - # pandas.DataFrame - # DataFrame with rows already present in the database removed. - # """ - # pk_names = ctx.tableclass.pk_names() - # pk_tuples = list(df[pk_names].itertuples(index=False, name=None)) - # if not pk_tuples: - # return df - # tableclass = ( - # ctx.staging_table - # if ctx.staging_table is not None - # else ctx.tableclass.__table__ - # ) - # pk_cols = [getattr(tableclass.c, pk) for pk in pk_names] - - # vars_per_row = len(pk_cols) - # chunk_size = max(1, 10_000 // vars_per_row) - # existing_rows: list[tuple] = [] - - # for i in range(0, len(pk_tuples), chunk_size): - # chunk = pk_tuples[i : i + chunk_size] - - # rows = ( - # ctx.session.query(*pk_cols) - # .filter(sa.tuple_(*pk_cols).in_(chunk)) - # .all() - # ) - # existing_rows.extend(rows) - - # if not existing_rows: - # return df - - # existing = pd.DataFrame(existing_rows, columns=pk_names) - - # logger.warning(f"Dropping {len(existing)} rows from {ctx.tableclass.__tablename__} that already exist in the database") - # df = ( - # df.merge(existing, on=pk_names, how="left", indicator=True) - # .loc[lambda x: x["_merge"] == "left_only"] - # 
.drop(columns="_merge") - # ) - # return df - @dataclass class ColumnCastingStats: @@ -286,4 +231,3 @@ def to_dict(self) -> dict[str, dict[str, Any]]: for col, stats in self.columns.items() } - diff --git a/src/orm_loader/loaders/loading_helpers.py b/src/orm_loader/loaders/loading_helpers.py index 93dd09e..cbc5ef7 100644 --- a/src/orm_loader/loaders/loading_helpers.py +++ b/src/orm_loader/loaders/loading_helpers.py @@ -1,6 +1,8 @@ from __future__ import annotations from pathlib import Path import chardet +import csv as _csv +import re import sqlalchemy as sa import sqlalchemy.orm as so import logging @@ -9,7 +11,10 @@ import pyarrow.csv as pv import io +_SAFE_ENCODING = re.compile(r'^[A-Za-z][A-Za-z0-9_-]*$') + logger = logging.getLogger(__name__) +COPY_BLOCK_SIZE = 8192 """ Loader Helper Functions @@ -85,6 +90,66 @@ def infer_delim(file): return '\t' return ',' + +def infer_quote_mode( + path: Path, + delimiter: str, + encoding: str = "utf-8", + sample_rows: int = 200, +) -> str: + """Return 'csv' or 'literal' by comparing column-count consistency under both + quoting interpretations across a sample of rows. + + - 'csv' → standard RFC-4180 quoting; surrounding double-quotes are stripped + and embedded delimiters/newlines inside quotes are preserved. + - 'literal' → double-quote has no special meaning; every byte is stored as-is. + + Defaults to 'csv' when both modes produce identical output (no quoting in play) + or when the evidence is tied. Callers can always override by passing an + explicit value instead of relying on auto-detection. 
+ """ + with open(path, encoding=encoding, errors="replace", newline="") as f: + lines = [f.readline() for _ in range(sample_rows + 1)] + + raw = "".join(ln for ln in lines if ln) + if not raw: + return "csv" + + try: + rows_csv = list(_csv.reader(io.StringIO(raw), delimiter=delimiter)) + except _csv.Error: + return "literal" + + try: + rows_lit = list( + _csv.reader(io.StringIO(raw), delimiter=delimiter, quoting=_csv.QUOTE_NONE) + ) + except _csv.Error: + return "csv" + + if not rows_csv: + return "csv" + + ncols = len(rows_csv[0]) + if ncols <= 1: + return "csv" + + # No difference between modes → no quoting is active, csv is the safe default + if rows_csv == rows_lit: + return "csv" + + data_csv = rows_csv[1:] + data_lit = rows_lit[1:] if len(rows_lit) > 1 else [] + + if not data_csv: + return "csv" + + csv_ok = sum(1 for r in data_csv if len(r) == ncols) + lit_ok = sum(1 for r in data_lit if len(r) == ncols) + + # Prefer csv on a tie; only choose literal when it is strictly more consistent + return "literal" if lit_ok > csv_ok else "csv" + def arrow_drop_duplicates( table: pa.Table, pk_names: list[str], @@ -168,15 +233,20 @@ def quick_load_pg( path: Path, session: so.Session, tablename: str, - quote_mode: str = "csv", + quote_mode: str = "auto", ) -> int: - raw_conn = session.connection().connection + raw_conn = session.connection().connection if not hasattr(raw_conn, "cursor"): raise RuntimeError("Expected DB-API connection for COPY") - + encoding = infer_encoding(path)['encoding'] or 'utf-8' + if not _SAFE_ENCODING.match(encoding): + raise ValueError(f"Unsafe encoding value from chardet: {encoding!r}") delimiter = infer_delim(path) + if quote_mode == "auto": + quote_mode = infer_quote_mode(path, delimiter=delimiter, encoding=encoding) + logger.info(f"Auto-detected quote_mode={quote_mode!r} for {path.name}") if quote_mode == "csv": copy_options = f""" FORMAT csv, @@ -202,17 +272,17 @@ def quick_load_pg( try: with open(path, "rb") as f: stream = 
NormalisedCSVStream(f, encoding=encoding, delimiter=delimiter) - - cur.copy_expert( - sql=f''' + with cur.copy( + f''' COPY "{tablename}" FROM STDIN WITH ( {copy_options} ) - ''', - file=stream, - ) + ''' + ) as copy: + while data := stream.read(COPY_BLOCK_SIZE): + copy.write(data) session.flush() total = session.execute(sa.text(f'SELECT COUNT(*) FROM "{tablename}"')).scalar_one() return total diff --git a/src/orm_loader/mappers/materialised_view_mixin.py b/src/orm_loader/mappers/materialised_view_mixin.py index a01096f..34e037c 100644 --- a/src/orm_loader/mappers/materialised_view_mixin.py +++ b/src/orm_loader/mappers/materialised_view_mixin.py @@ -1,7 +1,9 @@ from sqlalchemy.ext import compiler from sqlalchemy.schema import DDLElement import sqlalchemy as sa +from typing import Any from collections import defaultdict, deque +from ..backends.resolve import resolve_backend class CreateMaterializedView(DDLElement): """ @@ -23,12 +25,16 @@ class CreateMaterializedView(DDLElement): materialized view. """ - def __init__(self, name, selectable): + def __init__(self, name: str, selectable: sa.sql.Select[Any]): self.name = name self.selectable = selectable @compiler.compiles(CreateMaterializedView) -def _create_view(element, compiler, **kw): +def _create_view( # type: ignore + element: CreateMaterializedView, + compiler: sa.sql.compiler.SQLCompiler, + **kwargs: Any +) -> str: """ `_create_view` @@ -150,11 +156,11 @@ class DailyObservationCountsMV(Base, MaterializedViewMixin): """ __mv_name__: str - __mv_select__: sa.sql.Select + __mv_select__: sa.sql.Select[Any] __mv_dependencies__: set[str] = set() @classmethod - def create_mv(cls, bind): + def create_mv(cls, bind: "sa.engine.Connection | sa.engine.Engine") -> None: """ Create the materialized view if it does not already exist. @@ -166,8 +172,8 @@ def create_mv(cls, bind): Notes ----- The underlying SQL is emitted via a custom DDL element and executed - directly against the database. 
This operation is not transactional - on all backends. + through the resolved backend. With the built-in backends, this means + PostgreSQL. Unsupported backends raise ``NotImplementedError``. Examples @@ -193,11 +199,11 @@ def create_mv(cls, bind): WHERE observation.observation_date >= CURRENT_DATE - INTERVAL '30 days'; ``` """ - ddl = CreateMaterializedView(cls.__mv_name__, cls.__mv_select__) - bind.execute(ddl) + backend = resolve_backend(bind) + backend.create_materialized_view(bind, cls.__mv_name__, cls.__mv_select__) @classmethod - def refresh_mv(cls, bind): + def refresh_mv(cls, bind: "sa.engine.Connection | sa.engine.Engine") -> None: """ Refresh the contents of the materialized view. @@ -208,9 +214,9 @@ def refresh_mv(cls, bind): Notes ----- - This method issues a REFRESH MATERIALIZED VIEW statement and assumes - backend support (e.g. PostgreSQL). Concurrent refresh semantics are - not handled here. + This method issues a backend-specific refresh statement. With the + built-in backends, materialized views are PostgreSQL-only. + Concurrent refresh semantics are not handled here. 
Examples -------- @@ -219,7 +225,8 @@ def refresh_mv(cls, bind): RecentObservationMV.refresh_mv(conn) ``` """ - bind.execute(sa.text(f"REFRESH MATERIALIZED VIEW {cls.__mv_name__};")) + backend = resolve_backend(bind) + backend.refresh_materialized_view(bind, cls.__mv_name__) def resolve_mv_refresh_order(mv_classes: list[type[MaterializedViewMixin]]) -> list[type]: @@ -271,7 +278,7 @@ def resolve_mv_refresh_order(mv_classes: list[type[MaterializedViewMixin]]) -> l return [name_to_mv[name] for name in ordered] -def refresh_all_mvs(bind, mv_classes): +def refresh_all_mvs(bind: "sa.engine.Connection | sa.engine.Engine", mv_classes: list[type[MaterializedViewMixin]]) -> None: """ `refresh_all_mvs` @@ -289,7 +296,7 @@ def refresh_all_mvs(bind, mv_classes): refresh_all_mvs(engine, ALL_MVS) ``` """ - ordered = resolve_mv_refresh_order(mv_classes) + ordered: list[type[MaterializedViewMixin]] = resolve_mv_refresh_order(mv_classes) for mv in ordered: - mv.refresh_mv(bind) \ No newline at end of file + mv.refresh_mv(bind) diff --git a/src/orm_loader/tables/loadable_table.py b/src/orm_loader/tables/loadable_table.py index f8a91c5..1ce97fe 100644 --- a/src/orm_loader/tables/loadable_table.py +++ b/src/orm_loader/tables/loadable_table.py @@ -1,20 +1,29 @@ +# pyright: reportPrivateUsage=false import sqlalchemy as sa import sqlalchemy.orm as so import logging +from sqlalchemy.exc import InvalidRequestError, UnboundExecutionError -from typing import Type, ClassVar, Optional +from typing import Type, ClassVar, Optional, Any, Iterator from pathlib import Path from contextlib import contextmanager from .orm_table import ORMTableBase from .typing import CSVTableProtocol +from ..backends.resolve import resolve_backend from ..loaders.loader_interface import LoaderInterface, LoaderContext, PandasLoader, ParquetLoader -from ..loaders.loading_helpers import quick_load_pg -from ..helpers.bulk import restore_fk_check, disable_fk_check logger = logging.getLogger(__name__) +def 
_require_bind(session: so.Session) -> sa.Engine | sa.Connection: + """Return a bound connectable or raise a stable runtime error.""" + try: + return session.get_bind() + except (InvalidRequestError, UnboundExecutionError) as exc: + raise RuntimeError("Session is not bound to an engine") from exc + + """ CSV Loadable Table Mixins ================================== @@ -66,7 +75,7 @@ def staging_tablename(cls: Type[CSVTableProtocol]) -> str: str The staging table name. """ - if cls._staging_tablename: + if cls._staging_tablename: # type: ignore return cls._staging_tablename return f"_staging_{cls.__tablename__}" @@ -93,75 +102,30 @@ def create_staging_table( NotImplementedError If the database dialect is unsupported. """ - table = cls.__table__ - session.execute(sa.text(f"""DROP TABLE IF EXISTS "{cls.staging_tablename()}";""")) - - if session.bind is None: - raise RuntimeError("Session is not bound to an engine") - - dialect = session.bind.dialect.name - - if dialect == "postgresql": - logger.info("Disabling indices on staging table for performance") - session.execute(sa.text(f''' - CREATE UNLOGGED TABLE "{cls.staging_tablename()}" - (LIKE "{table.name}" INCLUDING DEFAULTS INCLUDING CONSTRAINTS); - ''')) - - # Need to drop the columns we are not going to load into, otherwise the COPY will fail - computed_cols = [c.name for c in table.columns if c.computed is not None] - for col in computed_cols: - session.execute(sa.text(f'ALTER TABLE "{cls.staging_tablename()}" DROP COLUMN "{col}";')) - - elif dialect == "sqlite": - - metadata = sa.MetaData() - - staging_columns = [] - for col in table.columns: - staging_columns.append( - sa.Column( - col.name, - col.type, - nullable=True, - ) - ) - - staging_table = sa.Table( - cls.staging_tablename(), - metadata, - *staging_columns, - ) - - conn = session.connection() - metadata.create_all(bind=conn, tables=[staging_table]) - # this borks on date cols because it loses the date - # specification and reverts to NUM - # - changing to 
metadata.create_all approach for sqlite - # but not postgresql for now to keep unlogged table feature - # session.execute(sa.text(f''' - # CREATE TABLE "{cls.staging_tablename()}" AS - # SELECT * FROM "{table.name}" WHERE 0; - # ''')) - else: - raise NotImplementedError( - f"Staging table creation not implemented for dialect '{dialect}'" - ) - # query the sense of having internal commit here, but for now - # it is required for the ORM-based fallback loader to function - # cleanly for external pipeline purposes - - session.commit() + _require_bind(session) + backend = resolve_backend(session) + backend.create_staging_table(cls, session, cls.staging_tablename()) @classmethod @contextmanager - def manage_indices(cls: Type['CSVTableProtocol'], session: so.Session): + def manage_indices( + cls: Type['CSVTableProtocol'], + session: so.Session, + index_strategy: str = "auto", + ) -> Iterator[None]: """ - Temporarily drops non-primary key indices before a bulk operation - and recreates them afterwards to prevent write amplification. + Manage non-primary-key indexes around a staged merge. + + ``index_strategy`` may be ``"auto"``, ``"drop_rebuild"``, or + ``"keep"``. The backend decides what ``"auto"`` means. At the + moment SQLite keeps indexes by default, while PostgreSQL drops + and rebuilds them. """ - indices = list(cls.__table__.indexes) - inspector = sa.inspect(session.bind) + backend = resolve_backend(session) + resolved_index_strategy = backend.resolve_index_strategy(index_strategy) + + indices = list(cls.__table__.indexes) if resolved_index_strategy == "drop_rebuild" else [] + inspector = sa.inspect(_require_bind(session)) assert inspector is not None, "Failed to create inspector for index management" if indices: @@ -174,20 +138,16 @@ def manage_indices(cls: Type['CSVTableProtocol'], session: so.Session): session.execute(sa.schema.DropIndex(idx)) session.commit() - # session.commit() above restores the original state of the session. 
We need that one after we are done - previous_fk_state = disable_fk_check(session) - try: - yield - session.commit() + with backend.merge_context(cls, session): + yield + session.commit() except Exception as e: session.rollback() logger.error(f"Table `{cls.__tablename__}`: Merge operation failed - {e}") raise finally: - restore_fk_check(session, previous_fk_state) - if indices: logger.info(f"Table `{cls.__tablename__}`: Verifying/Rebuilding indices.") inspector.clear_cache() # Required to ensure we get the current state of the database after potential changes @@ -225,10 +185,7 @@ def get_staging_table( sqlalchemy.Table The reflected staging table. """ - if session.bind is None: - raise RuntimeError("Session is not bound to an engine") - - engine = session.get_bind() + engine = _require_bind(session) inspector = sa.inspect(engine) staging_name = cls.staging_tablename() @@ -266,28 +223,25 @@ def load_staging( int Number of rows loaded into the staging table. """ - if loader_context.session.bind is None: - raise RuntimeError("Session is not bound to an engine") + _require_bind(loader_context.session) - dialect = loader_context.session.bind.dialect.name + backend = resolve_backend(loader_context.session) total = 0 try: cls.create_staging_table(loader_context.session) - if dialect == "postgresql": - try: - total = quick_load_pg( - path=loader_context.path, - session=loader_context.session, - tablename=cls.staging_tablename(), - quote_mode=loader_context.quote_mode, - ) + try: + total = backend.load_staging_fast( + loader_context=loader_context, + staging_name=cls.staging_tablename(), + ) + if total is not None: return total - except Exception as e: - loader_context.session.rollback() - logger.warning(f"COPY failed for {cls.staging_tablename()}: {e}") - logger.info('Falling back to ORM-based load functionality') + except Exception as e: + loader_context.session.rollback() + logger.warning(f"Fast-path load failed for {cls.staging_tablename()}: {e}") + logger.info('Falling 
back to ORM-based load functionality') total = cls.orm_staging_load( loader=loader, @@ -346,7 +300,8 @@ def load_csv( dedupe: bool = False, chunksize: int | None = None, merge_strategy: str = "replace", - quote_mode: str = "csv", + quote_mode: str = "auto", + index_strategy: str = "auto", ) -> int: """ @@ -374,11 +329,16 @@ def load_csv( Optional chunk size for incremental loading. merge_strategy Merge strategy to apply (e.g. ``replace`` or ``upsert``). + quote_mode + Quoting mode used by the PostgreSQL fast-path loader. + index_strategy + Index handling strategy during merge. Use ``"auto"`` to let + the backend choose a sensible default. Returns ------- int - Number of rows loaded. + Number of rows loaded into staging before merge. """ logger.debug(f"Table `{cls.__tablename__}`: Loading CSV from {path}") @@ -403,12 +363,12 @@ def load_csv( loader = cls._select_loader(path) # Load to staging (Indices are already excluded via updated create_staging_table) - logger.info(f"Table `{cls.__tablename__}`: Loading data into unlogged staging table") + logger.info(f"Table `{cls.__tablename__}`: Loading data into staging table") total = cls.load_staging(loader=loader, loader_context=loader_context) # Merge staging to target (Wrapped in our index dropper!) logger.info(f"Table `{cls.__tablename__}`: Merging staging data into target table") - with cls.manage_indices(session): + with cls.manage_indices(session, index_strategy=index_strategy): cls.merge_from_staging(session, merge_strategy=merge_strategy) cls.drop_staging_table(session) @@ -423,8 +383,7 @@ def _merge_replace( session: so.Session, target: str, staging: str, - pk_cols: list[str], - dialect: str + pk_cols: list[str] ): """ Merge staging data by replacing existing rows. @@ -432,37 +391,8 @@ def _merge_replace( Existing target rows matching the staging primary keys are deleted prior to insertion. 
""" - if dialect == "postgresql": - pk_join = " AND ".join( - f't."{c}" = s."{c}"' for c in pk_cols - ) - - session.execute(sa.text(f""" - DELETE FROM "{target}" t - USING "{staging}" s - WHERE {pk_join}; - """)) - - elif dialect == "sqlite": - if len(pk_cols) == 1: - pk = pk_cols[0] - session.execute(sa.text(f""" - DELETE FROM "{target}" - WHERE "{pk}" IN ( - SELECT "{pk}" FROM "{staging}" - ); - """)) - else: - pk_match = " AND ".join( - f'"{target}"."{c}" = "{staging}"."{c}"' for c in pk_cols - ) - session.execute(sa.text(f""" - DELETE FROM "{target}" - WHERE EXISTS ( - SELECT 1 FROM "{staging}" - WHERE {pk_match} - ); - """)) + backend = resolve_backend(session) + backend.merge_replace(cls, session, target, staging, pk_cols) @classmethod def _merge_insert( @@ -470,18 +400,12 @@ def _merge_insert( session: so.Session, target: str, staging: str - ): + ): """ Insert all rows from the staging table into the target table. """ - # Get all columns that are NOT computed - insertable_cols = [c.name for c in cls.__table__.columns if c.computed is None] - cols_str = ", ".join(f'"{c}"' for c in insertable_cols) - - session.execute(sa.text(f""" - INSERT INTO "{target}" ({cols_str}) - SELECT {cols_str} FROM "{staging}"; - """)) + backend = resolve_backend(session) + backend.merge_insert(cls, session, target, staging) @classmethod @@ -490,33 +414,13 @@ def _merge_upsert( session: so.Session, target: str, staging: str, - pk_cols: list[str], - dialect: str + pk_cols: list[str] ): """ Merge staging data using an upsert strategy. 
""" - - # Get all columns that are NOT computed - insertable_cols = [c.name for c in cls.__table__.columns if c.computed is None] - cols_str = ", ".join(f'"{c}"' for c in insertable_cols) - - if dialect == "postgresql": - # INSERT … ON CONFLICT DO NOTHING - session.execute(sa.text(f""" - INSERT INTO "{target}" ({cols_str}) - SELECT {cols_str} FROM "{staging}" - ON CONFLICT ({", ".join(f'"{c}"' for c in pk_cols)}) DO NOTHING; - """)) - - elif dialect == "sqlite": - session.execute(sa.text(f""" - INSERT OR IGNORE INTO "{target}" ({cols_str}) - SELECT {cols_str} FROM "{staging}"; - """)) - - else: - raise NotImplementedError + backend = resolve_backend(session) + backend.merge_upsert(cls, session, target, staging, pk_cols) @classmethod def merge_from_staging( @@ -538,17 +442,13 @@ def merge_from_staging( staging = cls.staging_tablename() pk_cols = cls.pk_names() - if not session.bind: - raise RuntimeError("Session is not bound to an engine") - - dialect = session.bind.dialect.name + _require_bind(session) if merge_strategy == "replace": cls._merge_replace( session=session, target=target, staging=staging, pk_cols=pk_cols, - dialect=dialect, ) cls._merge_insert( session=session, @@ -561,7 +461,6 @@ def merge_from_staging( target=target, staging=staging, pk_cols=pk_cols, - dialect=dialect, ) else: raise ValueError(f"Unknown merge strategy '{merge_strategy}'") @@ -571,12 +470,11 @@ def drop_staging_table(cls: Type[CSVTableProtocol], session: so.Session): """ Drop the staging table if it exists. """ - session.execute( - sa.text(f'DROP TABLE IF EXISTS "{cls.staging_tablename()}"') - ) + backend = resolve_backend(session) + backend.drop_staging_table(session, cls.staging_tablename()) @classmethod - def csv_columns(cls) -> dict[str, sa.ColumnElement]: + def csv_columns(cls) -> dict[str, sa.ColumnElement[Any]]: """ Return a mapping of CSV column names to model columns. 
@@ -590,4 +488,4 @@ def csv_columns(cls) -> dict[str, sa.ColumnElement]: """ cols = cls.model_columns() computed_names = {c.name for c in cls.__table__.columns if c.computed is not None} # type: ignore - return {k: v for k, v in cols.items() if k not in computed_names} \ No newline at end of file + return {k: v for k, v in cols.items() if k not in computed_names} diff --git a/src/orm_loader/tables/orm_table.py b/src/orm_loader/tables/orm_table.py index a7748e6..24c634d 100644 --- a/src/orm_loader/tables/orm_table.py +++ b/src/orm_loader/tables/orm_table.py @@ -1,7 +1,7 @@ import sqlalchemy as sa import sqlalchemy.orm as so -from sqlalchemy.exc import StatementError -from typing import Any, Tuple, Type, cast +from sqlalchemy.exc import NoInspectionAvailable, StatementError +from typing import Any import logging from .allocators import IdAllocator from ..helpers import normalise_null @@ -46,7 +46,7 @@ class ORMTableBase: __abstract__ = True @classmethod - def mapper_for(cls: Type) -> so.Mapper: + def mapper_for(cls: type[Any]) -> so.Mapper[Any]: """ Return the SQLAlchemy mapper associated with this ORM class. @@ -63,13 +63,13 @@ def mapper_for(cls: Type) -> so.Mapper: TypeError If the class is not a mapped SQLAlchemy ORM class. """ - mapper = sa.inspect(cls) - if not mapper: + try: + return sa.inspect(cls) + except NoInspectionAvailable: raise TypeError(f"{cls.__name__} is not a mapped ORM class") - return cast(so.Mapper, mapper) @classmethod - def pk_columns(cls) -> list[sa.ColumnElement]: + def pk_columns(cls) -> list[sa.ColumnElement[Any]]: """ Return the primary key columns for the mapped table. @@ -120,7 +120,7 @@ def pk_values(cls, obj: Any) -> dict[str, Any]: return {c.key: getattr(obj, c.key) for c in cls.pk_columns() if c.key is not None} @classmethod - def pk_tuple(cls, obj: Any) -> Tuple[Any, ...]: + def pk_tuple(cls, obj: Any) -> tuple[Any, ...]: """ Extract primary key values from an ORM instance as a tuple. 
@@ -143,7 +143,7 @@ def pk_tuple(cls, obj: Any) -> Tuple[Any, ...]: ) @classmethod - def model_columns(cls) -> dict[str, sa.ColumnElement]: + def model_columns(cls) -> dict[str, sa.ColumnElement[Any]]: """ Return all mapped columns for the table. @@ -153,7 +153,7 @@ def model_columns(cls) -> dict[str, sa.ColumnElement]: A mapping of column name to column object. """ mapper = cls.mapper_for() - return {c.key: c for c in mapper.columns if c.key is not None} + return {c.key: c for c in mapper.columns} @classmethod def required_columns(cls) -> set[str]: @@ -177,11 +177,11 @@ def required_columns(cls) -> set[str]: return { c.key for c in mapper.columns - if not c.nullable and not c.default and not c.server_default and c.key is not None + if not c.nullable and not c.default and not c.server_default } @classmethod - def max_id(cls, session) -> int: + def max_id(cls, session: so.Session) -> int: """ Return the maximum value of the primary key column. @@ -211,7 +211,7 @@ def max_id(cls, session) -> int: return session.query(sa.func.max(pk)).scalar() or 0 @classmethod - def allocator(cls, session) -> IdAllocator: + def allocator(cls, session: so.Session) -> IdAllocator: """ Create an ID allocator initialised from the current table state. 
@@ -251,7 +251,7 @@ def clean_kwargs( """ cols = cls.model_columns() - cleaned = {} + cleaned: dict[str, Any] = {} for k, v in data.items(): if k not in cols: continue # ignore unknown keys safely diff --git a/src/orm_loader/tables/serialisable_table.py b/src/orm_loader/tables/serialisable_table.py index e340b45..2310855 100644 --- a/src/orm_loader/tables/serialisable_table.py +++ b/src/orm_loader/tables/serialisable_table.py @@ -1,10 +1,14 @@ -from .orm_table import ORMTableBase -from typing import Any +from typing import Any, Unpack +from collections.abc import Iterator import json import hashlib import datetime -def json_default(obj) -> str: +from .orm_table import ORMTableBase +from .typing import ToDictKwargs + + +def json_default(obj: Any) -> str: """ Default JSON serialisation handler for unsupported types. @@ -79,7 +83,7 @@ def to_dict( dict[str, Any] A dictionary representation of the ORM row. """ - data = {} + data: dict[str, Any] = {} for key, _ in self.model_columns().items(): if only and key not in only: continue @@ -91,7 +95,7 @@ def to_dict( data[key] = value return data - def to_json(self, **kwargs) -> str: + def to_json(self, **kwargs: Unpack[ToDictKwargs]) -> str: """ Serialise the ORM instance to a JSON string. @@ -133,7 +137,7 @@ def fingerprint(self) -> str: payload = self.to_json(include_nulls=True) return hashlib.sha256(payload.encode("utf-8")).hexdigest() - def __iter__(self): + def __iter__(self) -> Iterator[tuple[str, Any]]: """ Iterate over the ORM instance as ``(key, value)`` pairs. @@ -147,7 +151,7 @@ def __iter__(self): """ yield from self.to_dict().items() - def __json__(self): + def __json__(self) -> dict[str, Any]: """ Return a JSON-serialisable representation of the ORM instance. 
diff --git a/src/orm_loader/tables/typing.py b/src/orm_loader/tables/typing.py index b61d183..bda4751 100644 --- a/src/orm_loader/tables/typing.py +++ b/src/orm_loader/tables/typing.py @@ -1,4 +1,4 @@ -from typing import Protocol, ClassVar, runtime_checkable, TYPE_CHECKING, Optional, Type, Dict, Any +from typing import Protocol, ClassVar, runtime_checkable, TYPE_CHECKING, Optional, Type, Dict, Any, Unpack, TypedDict import sqlalchemy.orm as so import sqlalchemy as sa from pathlib import Path @@ -6,6 +6,11 @@ if TYPE_CHECKING: from ..loaders import LoaderContext, LoaderInterface +class ToDictKwargs(TypedDict, total=False): + include_nulls: bool + only: set[str] | None + exclude: set[str] | None + @runtime_checkable class ORMTableProtocol(Protocol): """ @@ -28,17 +33,16 @@ class ORMTableProtocol(Protocol): metadata: ClassVar[sa.MetaData] @classmethod - def mapper_for(cls) -> so.Mapper: ... + def mapper_for(cls) -> so.Mapper[Any]: ... @classmethod def pk_names(cls) -> list[str]: ... @classmethod - def pk_columns(cls) -> list[sa.ColumnElement]: ... + def pk_columns(cls) -> list[sa.ColumnElement[Any]]: ... @classmethod - def model_columns(cls) -> dict[str, sa.ColumnElement]: ... - + def model_columns(cls) -> dict[str, sa.ColumnElement[Any]]: ... @runtime_checkable class CSVTableProtocol(ORMTableProtocol, Protocol): @@ -72,15 +76,17 @@ def load_staging(cls: Type["CSVTableProtocol"], loader: "LoaderInterface", loade @classmethod def load_csv( - cls, - session: so.Session, - path: Path, - *, - normalise: bool = True, - dedupe: bool = False, - chunksize: int | None = None, + cls, + session: so.Session, + path: Path, + *, + loader: Optional["LoaderInterface"] = None, + normalise: bool = True, + dedupe: bool = False, + chunksize: int | None = None, merge_strategy: str = "replace", quote_mode: str = "csv", + index_strategy: str = "auto", ) -> int: ... @classmethod @@ -99,13 +105,13 @@ def drop_staging_table(cls, session: so.Session) -> None: ... 
def _merge_insert(cls, session: so.Session, target: str, staging: str) -> None: ... @classmethod - def _merge_replace(cls, session: so.Session, target: str, staging: str, pk_cols: list[str], dialect: str) -> None: ... + def _merge_replace(cls, session: so.Session, target: str, staging: str, pk_cols: list[str]) -> None: ... @classmethod - def _merge_upsert(cls, session: so.Session, target: str, staging: str, pk_cols: list[str], dialect: str) -> None: ... + def _merge_upsert(cls, session: so.Session, target: str, staging: str, pk_cols: list[str]) -> None: ... @classmethod - def manage_indices(cls, session: so.Session) -> AbstractContextManager[None]: + def manage_indices(cls, session: so.Session, index_strategy: str = "auto") -> AbstractContextManager[None]: ... @@ -130,11 +136,10 @@ def to_dict( exclude: set[str] | None = None, ) -> Dict[str, Any]: ... - def to_json(self, **kwargs) -> str: ... + def to_json(self, **kwargs: Unpack[ToDictKwargs]) -> str: ... def fingerprint(self) -> str: ... def __iter__(self) -> Any: ... def __json__(self) -> Any: ... 
- diff --git a/tests/backends/test_base_backend.py b/tests/backends/test_base_backend.py new file mode 100644 index 0000000..e1d2b44 --- /dev/null +++ b/tests/backends/test_base_backend.py @@ -0,0 +1,287 @@ +from __future__ import annotations + +import importlib +import importlib.abc +import sys +from importlib.machinery import ModuleSpec +from types import ModuleType +from typing import TYPE_CHECKING, Sequence, Type, cast, Any + +import pytest +import sqlalchemy as sa +import sqlalchemy.orm as so +from sqlalchemy.engine import Connection, Engine + +from orm_loader.backends import ( + BackendCapabilities, + DatabaseBackend, + Dialect, + resolve_backend, +) + +if TYPE_CHECKING: + from orm_loader.loaders.data_classes import LoaderContext + from orm_loader.tables.typing import CSVTableProtocol + + +class _BlockPsycopg(importlib.abc.MetaPathFinder): + def find_spec( + self, + fullname: str, + path: Sequence[str] | None = None, + target: ModuleType | None = None, + ) -> ModuleSpec | None: + if fullname == "psycopg" or fullname.startswith("psycopg."): + raise ModuleNotFoundError("No module named 'psycopg'") + return None + + +class FakeBackend(DatabaseBackend): + def __init__(self) -> None: + self.calls: list[tuple[str, object]] = [] + + @property + def name(self) -> str: + return "fake" + + @property + def dialect(self) -> Dialect: + return Dialect.SQLITE + + @property + def capabilities(self) -> BackendCapabilities: + return BackendCapabilities( + supports_fast_load=True, + supports_fk_toggle=True, + ) + + def create_staging_table( + self, table_cls: Type[CSVTableProtocol], session: so.Session, staging_name: str + ) -> None: + return None + + def drop_staging_table(self, session: so.Session, staging_name: str) -> None: + return None + + def merge_replace( + self, + table_cls: Type[CSVTableProtocol], + session: so.Session, + target_name: str, + staging_name: str, + pk_cols: list[str], + ) -> None: + return None + + def merge_upsert( + self, + table_cls: 
Type[CSVTableProtocol], + session: so.Session, + target_name: str, + staging_name: str, + pk_cols: list[str], + ) -> None: + return None + + def merge_insert( + self, + table_cls: Type[CSVTableProtocol], + session: so.Session, + target_name: str, + staging_name: str, + ) -> None: + return None + + @staticmethod + def _normalize_fk_check_state(previous_state: str | int) -> str | int: + return previous_state + + def disable_fk_check(self, session: so.Session) -> str | int: + self.calls.append(("disable_fk_check", session)) + return "enabled" + + def enable_fk_check(self, session: so.Session) -> str | int: + self.calls.append(("enable_fk_check", session)) + return "disabled" + + def restore_fk_check(self, session: so.Session, previous_state: str | int) -> None: + self.calls.append(("restore_fk_check", previous_state)) + + def create_materialized_view( + self, bind: Engine | Connection, name: str, selectable: sa.sql.Select[Any] + ) -> None: + return None + + def refresh_materialized_view(self, bind: Engine | Connection, name: str) -> None: + return None + + +class _ComputedTable: + __table__ = sa.Table( + "computed_table", + sa.MetaData(), + sa.Column("id", sa.Integer, primary_key=True), + sa.Column("name", sa.String), + sa.Column("slug", sa.String, sa.Computed("lower(name)")), + ) + + +_ComputedTableCls = cast("Type[CSVTableProtocol]", _ComputedTable) + + +def test_backend_capabilities_defaults(): + caps = BackendCapabilities() + + assert caps.supports_fast_load is False + assert caps.supports_unlogged_staging is False + assert caps.supports_fk_toggle is False + assert caps.supports_materialized_views is False + + +def test_database_backend_is_abstract(): + with pytest.raises(TypeError): + DatabaseBackend() # type: ignore + + +def test_fake_backend_can_implement_contract(): + backend = FakeBackend() + + assert backend.name == "fake" + assert backend.dialect == Dialect.SQLITE + assert backend.capabilities.supports_fast_load is True + assert 
backend.capabilities.supports_fk_toggle is True + assert backend.supports_dialect(Dialect.SQLITE) is True + assert backend.supports_dialect(Dialect.POSTGRESQL) is False + assert backend.resolve_index_strategy("auto") == "drop_rebuild" + assert backend.resolve_index_strategy("keep") == "keep" + assert backend.load_staging_fast(cast("LoaderContext", None), "staging") is None + + with backend.merge_context(cast("Type[CSVTableProtocol]", None), cast(so.Session, None)): + pass + + +def test_require_capability_passes_for_supported_feature(): + backend = FakeBackend() + + backend._require_capability("supports_fast_load", "fast loading") + + +def test_require_capability_raises_for_unsupported_feature(): + backend = FakeBackend() + + with pytest.raises(NotImplementedError, match="does not support materialized views"): + backend._require_capability("supports_materialized_views", "materialized views") + + +def test_require_capability_raises_for_unknown_flag(): + backend = FakeBackend() + + with pytest.raises(AttributeError, match="Unknown backend capability"): + backend._require_capability("not_a_capability", "something") + + +def test_resolve_index_strategy_raises_for_invalid_value(): + backend = FakeBackend() + + with pytest.raises(ValueError, match="Unknown index_strategy"): + backend.resolve_index_strategy("not-valid") + + +def test_insertable_column_names_exclude_computed_columns(): + backend = FakeBackend() + + assert backend._insertable_column_names(_ComputedTableCls) == ["id", "name"] + + +def test_bulk_load_context_toggles_fk_and_restores(session): + backend = FakeBackend() + + with backend.bulk_load_context(session): + pass + + assert backend.calls == [ + ("disable_fk_check", session), + ("restore_fk_check", "enabled"), + ] + + +def test_bulk_load_context_without_fk_toggle(session): + backend = FakeBackend() + + with backend.bulk_load_context(session, disable_fk=False): + pass + + assert backend.calls == [] + + +def 
test_bulk_load_context_raises_when_capability_missing(session): + class NoFKBackend(FakeBackend): + @property + def capabilities(self) -> BackendCapabilities: + return BackendCapabilities() + + backend = NoFKBackend() + + with pytest.raises(NotImplementedError, match="does not support foreign key toggling"): + with backend.bulk_load_context(session): + pass + + +def test_bulk_load_context_rolls_back_and_restores(session): + backend = FakeBackend() + + with pytest.raises(RuntimeError, match="boom"): + with backend.bulk_load_context(session): + raise RuntimeError("boom") + + assert backend.calls == [ + ("disable_fk_check", session), + ("restore_fk_check", "enabled"), + ] + + +def test_backends_package_exports(): + import orm_loader.backends as backends + + assert backends.DatabaseBackend is DatabaseBackend + assert backends.BackendCapabilities is BackendCapabilities + assert backends.Dialect is Dialect + assert backends.resolve_backend is resolve_backend + + +def test_resolve_backend_for_sqlite_engine_and_session(): + engine = sa.create_engine("sqlite:///:memory:", future=True) + session = so.Session(engine) + + try: + engine_backend = resolve_backend(engine) + session_backend = resolve_backend(session) + + assert engine_backend.name == "sqlite" + assert session_backend.name == "sqlite" + finally: + session.close() + + +def test_resolve_backend_raises_for_unknown_dialect(): + class _Unknown: + class dialect: + name = "unknown" + + with pytest.raises(NotImplementedError, match="Unsupported SQLAlchemy dialect"): + resolve_backend(cast(Engine, _Unknown())) + + +def test_backends_import_does_not_require_psycopg(): + blocker = _BlockPsycopg() + original = sys.modules.pop("orm_loader.backends", None) + sys.meta_path.insert(0, blocker) + + try: + module = importlib.import_module("orm_loader.backends") + assert module.DatabaseBackend is not None + finally: + sys.meta_path.remove(blocker) + sys.modules.pop("orm_loader.backends", None) + if original is not None: + 
sys.modules["orm_loader.backends"] = original diff --git a/tests/backends/test_postgres_backend.py b/tests/backends/test_postgres_backend.py new file mode 100644 index 0000000..6ac4467 --- /dev/null +++ b/tests/backends/test_postgres_backend.py @@ -0,0 +1,260 @@ +from __future__ import annotations + +import sqlalchemy.event as sae +from typing import TYPE_CHECKING, Type, cast + +import sqlalchemy as sa +import sqlalchemy.orm as so +from sqlalchemy.dialects import postgresql +from sqlalchemy.engine import Connection, Engine + +from orm_loader.backends import Dialect, PostgresBackend + +if TYPE_CHECKING: + from orm_loader.tables.typing import CSVTableProtocol + + +class _ComputedTable: + __table__ = sa.Table( + "target_table", + sa.MetaData(), + sa.Column("id", sa.Integer, primary_key=True), + sa.Column("name", sa.String), + sa.Column("slug", sa.String, sa.Computed("lower(name)")), + ) + + +class _FakeSession: + def __init__(self, scalar_result: str | int = "origin") -> None: + self.statements: list[str] = [] + self.scalar_result = scalar_result + self.commits = 0 + + def execute(self, statement): + if hasattr(statement, "compile"): + sql = str(statement.compile(dialect=postgresql.dialect())) + else: + sql = str(statement) + self.statements.append(sql) + + class _Result: + def __init__(self, value): + self._value = value + + def scalar(self): + return self._value + + return _Result(self.scalar_result) + + def commit(self) -> None: + self.commits += 1 + + +_ComputedTableCls = cast("Type[CSVTableProtocol]", _ComputedTable) + + +def _sess(s: _FakeSession) -> so.Session: + return cast(so.Session, s) + + +def _as_engine(s: _FakeSession) -> Engine | Connection: + return cast(Engine, s) + + +def test_postgres_backend_identity_and_capabilities(): + backend = PostgresBackend() + + assert backend.name == "postgres" + assert backend.dialect == Dialect.POSTGRESQL + assert backend.supports_dialect(Dialect.POSTGRESQL) is True + assert backend.capabilities.supports_fast_load is 
True + assert backend.capabilities.supports_unlogged_staging is True + assert backend.capabilities.supports_fk_toggle is True + assert backend.capabilities.supports_materialized_views is True + + +def test_postgres_backend_create_staging_table_drops_computed_columns(): + backend = PostgresBackend() + session = _FakeSession() + + backend.create_staging_table(_ComputedTableCls, _sess(session), "_staging_target_table") + + assert any('DROP TABLE IF EXISTS "_staging_target_table"' in sql for sql in session.statements) + assert any('CREATE UNLOGGED TABLE "_staging_target_table"' in sql for sql in session.statements) + assert any('ALTER TABLE "_staging_target_table" DROP COLUMN "slug"' in sql for sql in session.statements) + assert session.commits == 1 + + +def test_postgres_backend_drop_staging_table(): + backend = PostgresBackend() + session = _FakeSession() + + backend.drop_staging_table(_sess(session), "_staging_target_table") + + assert session.statements == ['DROP TABLE IF EXISTS "_staging_target_table"'] + + +def test_postgres_backend_fk_methods_emit_expected_sql(): + backend = PostgresBackend() + session = _FakeSession() + + previous = backend.disable_fk_check(_sess(session)) + enabled = backend.enable_fk_check(_sess(session)) + backend.restore_fk_check(_sess(session), previous) + + assert previous == "origin" + assert enabled == "origin" + assert session.statements == [ + "SHOW session_replication_role", + "SET session_replication_role = 'replica'", + "SHOW session_replication_role", + "SET session_replication_role = 'origin'", + "SET session_replication_role = 'origin'", + ] + + +def test_postgres_backend_merge_replace_uses_using_delete(): + backend = PostgresBackend() + session = _FakeSession() + + backend.merge_replace(_ComputedTableCls, _sess(session), "target_table", "_staging_target_table", ["id", "name"]) + + sql = session.statements[0] + assert 'DELETE FROM "target_table" t' in sql + assert 'USING "_staging_target_table" s' in sql + assert 't."id" = 
s."id" AND t."name" = s."name"' in sql + + +def test_postgres_backend_merge_insert_excludes_computed_columns(): + backend = PostgresBackend() + session = _FakeSession() + + backend.merge_insert(_ComputedTableCls, _sess(session), "target_table", "_staging_target_table") + + sql = session.statements[0] + assert 'INSERT INTO "target_table" ("id", "name")' in sql + assert 'SELECT "id", "name" FROM "_staging_target_table"' in sql + + +def test_postgres_backend_merge_upsert_excludes_computed_columns(): + backend = PostgresBackend() + session = _FakeSession() + + backend.merge_upsert(_ComputedTableCls, _sess(session), "target_table", "_staging_target_table", ["id"]) + + sql = session.statements[0] + assert 'INSERT INTO "target_table" ("id", "name")' in sql + assert 'ON CONFLICT ("id") DO NOTHING' in sql + + +def test_postgres_backend_materialized_view_methods_emit_expected_sql(): + backend = PostgresBackend() + session = _FakeSession() + selectable = sa.select(sa.literal(1).label("n")) + + backend.create_materialized_view(_as_engine(session), "mv_test", selectable) + backend.refresh_materialized_view(_as_engine(session), "mv_test") + + assert any("CREATE MATERIALIZED VIEW IF NOT EXISTS mv_test as SELECT" in sql for sql in session.statements) + assert any("REFRESH MATERIALIZED VIEW mv_test;" == sql for sql in session.statements) + + +def test_postgres_backend_normalize_fk_check_state(): + normalize = PostgresBackend._normalize_fk_check_state + + assert normalize("origin") == "origin" + assert normalize("local") == "local" + assert normalize("replica") == "replica" + assert normalize(" ORIGIN ") == "origin" + + try: + normalize("invalid_role") + except ValueError as exc: + assert "Invalid PostgreSQL session_replication_role" in str(exc) + else: + raise AssertionError("Expected ValueError for unrecognised role") + + try: + normalize(1) + except ValueError as exc: + assert "Postgres uses string roles" in str(exc) + else: + raise AssertionError("Expected ValueError for integer 
input") + + +def test_postgres_backend_disable_fk_raises_when_show_returns_non_string(): + backend = PostgresBackend() + session = _FakeSession(scalar_result=42) + + try: + backend.disable_fk_check(_sess(session)) + except RuntimeError as exc: + assert "Expected PostgreSQL FK state to be a string" in str(exc) + else: + raise AssertionError("Expected RuntimeError when SHOW returns a non-string") + + +def test_postgres_backend_enable_fk_raises_when_show_returns_non_string(): + backend = PostgresBackend() + session = _FakeSession(scalar_result=42) + + try: + backend.enable_fk_check(_sess(session)) + except RuntimeError as exc: + assert "Expected PostgreSQL FK state to be a string" in str(exc) + else: + raise AssertionError("Expected RuntimeError when SHOW returns a non-string") + + +def test_postgres_backend_engine_with_replica_role_unregisters_listener(monkeypatch): + backend = PostgresBackend() + events: list[tuple[str, object, str]] = [] + statements: list[str] = [] + + class _Result: + def scalar(self): + return "origin" + + class _Conn: + def __enter__(self): + return self + + def __exit__(self, *_) -> None: + return None + + def execution_options(self, **_): + return self + + def execute(self, statement): + sql = str(statement.compile(dialect=postgresql.dialect())) + statements.append(sql) + return _Result() + + class _Engine: + def connect(self): + events.append(("connect", self, "connect")) + return _Conn() + + engine = _Engine() + + def _listen(target, name, *_) -> None: + events.append(("listen", target, name)) + + def _remove(target, name, *_) -> None: + events.append(("remove", target, name)) + + monkeypatch.setattr(sae, "listen", _listen) + monkeypatch.setattr(sae, "remove", _remove) + + with backend.engine_with_replica_role(cast(Engine, engine)): + pass + + assert events == [ + ("listen", engine, "connect"), + ("remove", engine, "connect"), + ("connect", engine, "connect"), + ] + assert statements == [ + "SET session_replication_role = DEFAULT", + "SHOW 
session_replication_role", + ] diff --git a/tests/backends/test_sqlite_backend.py b/tests/backends/test_sqlite_backend.py new file mode 100644 index 0000000..5b1f060 --- /dev/null +++ b/tests/backends/test_sqlite_backend.py @@ -0,0 +1,328 @@ +from __future__ import annotations + +import sqlite3 +from pathlib import Path +from typing import TYPE_CHECKING, Type, cast + +import sqlalchemy as sa +import sqlalchemy.orm as so + +from orm_loader.backends import Dialect, SQLiteBackend +from orm_loader.helpers.sqlite import attach_sqlite_bulk_load_pragmas + +if TYPE_CHECKING: + from orm_loader.tables.typing import CSVTableProtocol + + +class _ComputedTable: + __table__ = sa.Table( + "target_table", + sa.MetaData(), + sa.Column("id", sa.Integer, primary_key=True), + sa.Column("name", sa.String), + sa.Column("slug", sa.String, sa.Computed("lower(name)")), + ) + + +class _FakeSession: + def __init__(self, scalar_result: int | str = 1) -> None: + self.statements: list[str] = [] + self.scalar_result = scalar_result + + def execute(self, statement): + self.statements.append(str(statement)) + + class _Result: + def __init__(self, value): + self._value = value + + def scalar(self): + return self._value + + return _Result(self.scalar_result) + + +_ComputedTableCls = cast("Type[CSVTableProtocol]", _ComputedTable) + + +def _sess(s: _FakeSession) -> so.Session: + return cast(so.Session, s) + + +def test_sqlite_backend_identity_and_capabilities(): + backend = SQLiteBackend() + + assert backend.name == "sqlite" + assert backend.dialect == Dialect.SQLITE + assert backend.supports_dialect(Dialect.SQLITE) is True + assert backend.capabilities.supports_fast_load is False + assert backend.capabilities.supports_unlogged_staging is False + assert backend.capabilities.supports_fk_toggle is True + assert backend.capabilities.supports_materialized_views is False + assert backend.resolve_index_strategy("auto") == "keep" + assert backend.journal_mode == "WAL" + + +def 
test_sqlite_backend_create_staging_table(session, engine): + backend = SQLiteBackend() + + backend.create_staging_table(_ComputedTableCls, session, "_staging_target_table") + inspector = sa.inspect(engine) + assert inspector.has_table("_staging_target_table") is True + cols = inspector.get_columns("_staging_target_table") + assert [c["name"] for c in cols] == ["id", "name", "slug"] + assert all(c["nullable"] is True for c in cols) + + +def test_sqlite_backend_drop_staging_table(): + backend = SQLiteBackend() + session = _FakeSession() + + backend.drop_staging_table(_sess(session), "_staging_target_table") + + assert session.statements == ['DROP TABLE IF EXISTS "_staging_target_table"'] + + +def test_sqlite_backend_disable_fk_reads_then_sets(): + backend = SQLiteBackend() + session = _FakeSession(scalar_result=1) + + previous = backend.disable_fk_check(_sess(session)) + + assert previous == 1 + assert session.statements == [ + "PRAGMA foreign_keys", # read current state + "PRAGMA foreign_keys = OFF", # set to OFF + ] + + +def test_sqlite_backend_enable_fk_reads_then_sets(): + backend = SQLiteBackend() + session = _FakeSession(scalar_result=0) + + previous = backend.enable_fk_check(_sess(session)) + + assert previous == 0 + assert session.statements == [ + "PRAGMA foreign_keys", # read current state + "PRAGMA foreign_keys = ON", # set to ON + ] + + +def test_sqlite_backend_restore_fk_normalises_int_and_emits(): + backend = SQLiteBackend() + session = _FakeSession() + + backend.restore_fk_check(_sess(session), 1) + backend.restore_fk_check(_sess(session), 0) + + assert session.statements == [ + "PRAGMA foreign_keys = ON", + "PRAGMA foreign_keys = OFF", + ] + + +def test_sqlite_backend_normalize_fk_check_state(): + normalize = SQLiteBackend._normalize_fk_check_state + + assert normalize(1) == "ON" + assert normalize(0) == "OFF" + assert normalize("1") == "ON" + assert normalize("0") == "OFF" + assert normalize("ON") == "ON" + assert normalize("OFF") == "OFF" + assert 
normalize("on") == "ON" + assert normalize("off") == "OFF" + + try: + normalize(2) + except ValueError as exc: + assert "Invalid SQLite foreign_keys state" in str(exc) + else: + raise AssertionError("Expected ValueError for out-of-range int") + + try: + normalize("enabled") + except ValueError as exc: + assert "Invalid SQLite foreign_keys state" in str(exc) + else: + raise AssertionError("Expected ValueError for unrecognised string") + + +def test_sqlite_backend_merge_replace_single_pk(): + backend = SQLiteBackend() + session = _FakeSession() + + backend.merge_replace( + _ComputedTableCls, _sess(session), "target_table", "_staging_target_table", ["id"] + ) + + sql = session.statements[0] + assert 'DELETE FROM "target_table"' in sql + assert 'SELECT "id" FROM "_staging_target_table"' in sql + + +def test_sqlite_backend_merge_replace_composite_pk(): + backend = SQLiteBackend() + session = _FakeSession() + + backend.merge_replace( + _ComputedTableCls, _sess(session), "target_table", "_staging_target_table", ["id", "name"] + ) + + sql = session.statements[0] + assert "WHERE EXISTS (" in sql + assert '"target_table"."id" = "_staging_target_table"."id"' in sql + assert '"target_table"."name" = "_staging_target_table"."name"' in sql + + +def test_sqlite_backend_merge_insert_excludes_computed_columns(): + backend = SQLiteBackend() + session = _FakeSession() + + backend.merge_insert(_ComputedTableCls, _sess(session), "target_table", "_staging_target_table") + + sql = session.statements[0] + assert 'INSERT INTO "target_table" ("id", "name")' in sql + assert 'SELECT "id", "name" FROM "_staging_target_table"' in sql + + +def test_sqlite_backend_merge_upsert_excludes_computed_columns(): + backend = SQLiteBackend() + session = _FakeSession() + + backend.merge_upsert( + _ComputedTableCls, _sess(session), "target_table", "_staging_target_table", ["id"] + ) + + sql = session.statements[0] + assert 'INSERT OR IGNORE INTO "target_table" ("id", "name")' in sql + + +def 
test_sqlite_backend_materialized_view_methods_raise(engine): + backend = SQLiteBackend() + selectable = sa.select(sa.literal(1).label("n")) + + try: + backend.create_materialized_view(engine, "mv_test", selectable) + except NotImplementedError as exc: + assert "does not support materialized views" in str(exc) + else: + raise AssertionError("Expected create_materialized_view() to raise NotImplementedError") + + try: + backend.refresh_materialized_view(engine, "mv_test") + except NotImplementedError as exc: + assert "does not support materialized views" in str(exc) + else: + raise AssertionError("Expected refresh_materialized_view() to raise NotImplementedError") + + +def test_sqlite_backend_configures_bulk_load_pragmas(tmp_path: Path): + backend = SQLiteBackend() + db_path = tmp_path / "test.db" + engine = sa.create_engine(f"sqlite:///{db_path}", future=True) + backend.install_engine_hooks(engine) + + with engine.connect() as conn: + busy_timeout = conn.execute(sa.text("PRAGMA busy_timeout")).scalar_one() + journal_mode = conn.execute(sa.text("PRAGMA journal_mode")).scalar_one() + foreign_keys = conn.execute(sa.text("PRAGMA foreign_keys")).scalar_one() + + assert busy_timeout == 60000 + assert str(journal_mode).lower() == "wal" + assert foreign_keys == 1 + + +def test_sqlite_backend_restore_journal_mode(tmp_path: Path): + backend = SQLiteBackend() + db_path = tmp_path / "journal.db" + engine = sa.create_engine(f"sqlite:///{db_path}", future=True) + backend.install_engine_hooks(engine) + + with engine.begin() as conn: + conn.execute(sa.text("CREATE TABLE t (id INTEGER PRIMARY KEY, name TEXT)")) + conn.execute(sa.text("INSERT INTO t (name) VALUES ('x')")) + + engine.dispose() + backend.restore_journal_mode(db_path) + + with sqlite3.connect(db_path.resolve()) as conn: + journal_mode = conn.execute("PRAGMA journal_mode").fetchone()[0] + + assert str(journal_mode).lower() == "delete" + + +def test_attach_sqlite_bulk_load_pragmas_installs_backend_hook(tmp_path: Path): + 
db_path = tmp_path / "attached.db" + engine = sa.create_engine(f"sqlite:///{db_path}", future=True) + + attach_sqlite_bulk_load_pragmas(engine, busy_timeout_ms=45000) + + with engine.connect() as conn: + busy_timeout = conn.execute(sa.text("PRAGMA busy_timeout")).scalar_one() + journal_mode = conn.execute(sa.text("PRAGMA journal_mode")).scalar_one() + foreign_keys = conn.execute(sa.text("PRAGMA foreign_keys")).scalar_one() + + assert busy_timeout == 45000 + assert str(journal_mode).lower() == "wal" + assert foreign_keys == 1 + + +def test_sqlite_backend_rejects_invalid_journal_mode(): + try: + SQLiteBackend(journal_mode="wal; drop table x;") + except ValueError as exc: + assert "Unsupported SQLite journal_mode" in str(exc) + else: + raise AssertionError("Expected invalid journal_mode to raise ValueError") + + +def test_sqlite_backend_disable_fk_raises_when_pragma_returns_non_int(): + backend = SQLiteBackend() + session = _FakeSession(scalar_result="not_an_int") + + try: + backend.disable_fk_check(_sess(session)) + except RuntimeError as exc: + assert "Expected SQLite FK state to be an int" in str(exc) + else: + raise AssertionError("Expected RuntimeError when PRAGMA returns a non-int") + + +def test_sqlite_backend_enable_fk_raises_when_pragma_returns_non_int(): + backend = SQLiteBackend() + session = _FakeSession(scalar_result="not_an_int") + + try: + backend.enable_fk_check(_sess(session)) + except RuntimeError as exc: + assert "Expected SQLite FK state to be an int" in str(exc) + else: + raise AssertionError("Expected RuntimeError when PRAGMA returns a non-int") + + +def test_sqlite_backend_restore_fk_accepts_string_values(): + backend = SQLiteBackend() + session = _FakeSession() + + backend.restore_fk_check(_sess(session), "ON") + backend.restore_fk_check(_sess(session), "OFF") + + assert session.statements == [ + "PRAGMA foreign_keys = ON", + "PRAGMA foreign_keys = OFF", + ] + + +def test_sqlite_backend_fk_toggle_round_trip(session): + backend = SQLiteBackend() 
+ + session.execute(sa.text("PRAGMA foreign_keys = ON")) + assert session.execute(sa.text("PRAGMA foreign_keys")).scalar() == 1 + + previous = backend.disable_fk_check(session) + assert session.execute(sa.text("PRAGMA foreign_keys")).scalar() == 0 + + backend.restore_fk_check(session, previous) + assert session.execute(sa.text("PRAGMA foreign_keys")).scalar() == 1 diff --git a/tests/conftest.py b/tests/conftest.py index e12e9a3..64a2531 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,25 +1,94 @@ +import os +import time +from pathlib import Path +from urllib.parse import urlparse, urlunparse + import pytest import sqlalchemy as sa import sqlalchemy.orm as so -import time +from dotenv import load_dotenv from tests.models import Base +load_dotenv(Path(__file__).parent.parent / ".env") + + @pytest.fixture def engine(): - return sa.create_engine("sqlite:///:memory:") + engine = sa.create_engine("sqlite:///:memory:", future=True) + Base.metadata.create_all(engine) + return engine + @pytest.fixture def session(engine): - Base.metadata.create_all(engine) with so.Session(engine) as s: yield s -POSTGRES_URL = "postgresql+psycopg2://test:test@localhost:55432/test_db" +# --------------------------------------------------------------------------- +# Postgres fixtures +# --------------------------------------------------------------------------- + +POSTGRES_URL = os.getenv( + "TEST_POSTGRES_URL", + "postgresql+psycopg://test:test@localhost:55432/test", +) + +# Shown whenever Postgres is unreachable — centralised so every skip carries +# the same actionable instructions. 
+_PG_SKIP_MSG = ( + "Postgres tests skipped — could not connect to {url}.\n" + " Set TEST_POSTGRES_URL to a writable test database and re-run, e.g.:\n" + " export TEST_POSTGRES_URL='postgresql+psycopg://user:pass@host:5432/orm_loader_test'\n" + " Or add it to orm-loader/.env.\n" + " Last error: {{last_err}}" +).format(url=POSTGRES_URL) + +# Module-level sentinel: None = not yet attempted, str = skip reason. +# Prevents the 20-retry loop from running once per postgres test when +# the server is not reachable. +_pg_unavailable: str | None = None + + +def _ensure_db_exists(url: str) -> None: + """Create the target database if it doesn't already exist. + + Connects to the 'postgres' maintenance database (same host/user/pass) + so the target database can be created without touching anything else. + """ + parsed = urlparse(url) + db_name = parsed.path.lstrip("/") + admin_url = urlunparse(parsed._replace(path="/postgres")) + + admin_engine = sa.create_engine(admin_url, isolation_level="AUTOCOMMIT") + try: + with admin_engine.connect() as conn: + exists = conn.execute( + sa.text("SELECT 1 FROM pg_database WHERE datname = :name"), + {"name": db_name}, + ).scalar() + if not exists: + conn.execute(sa.text(f'CREATE DATABASE "{db_name}"')) + print(f"Created test database: {db_name!r}") + finally: + admin_engine.dispose() + @pytest.fixture(scope="session") def pg_engine(): + global _pg_unavailable + + # Fast path: already know Postgres is not reachable — skip immediately + # without re-running the retry loop. 
+ if _pg_unavailable is not None: + pytest.skip(_pg_unavailable) + + try: + _ensure_db_exists(POSTGRES_URL) + except Exception as e: + print(f"Could not ensure test DB exists (will try connecting anyway): {e}") + last_err = None for i in range(20): try: @@ -35,13 +104,14 @@ def pg_engine(): print(f"[{i}] Postgres not ready:", repr(e)) time.sleep(1) - raise RuntimeError(f"Postgres never became available: {last_err!r}") + _pg_unavailable = _PG_SKIP_MSG.format(last_err=last_err) + pytest.skip(_pg_unavailable) + @pytest.fixture def pg_session(pg_engine): Session = so.sessionmaker(bind=pg_engine, future=True) with pg_engine.begin() as conn: - # optional: recreate schema per test Base.metadata.drop_all(conn) Base.metadata.create_all(conn) @@ -51,7 +121,3 @@ def pg_session(pg_engine): finally: session.rollback() session.close() - - - - diff --git a/tests/docker-compose.yaml b/tests/docker-compose.yaml index 6328bfd..b8d6f8f 100644 --- a/tests/docker-compose.yaml +++ b/tests/docker-compose.yaml @@ -4,11 +4,15 @@ services: environment: POSTGRES_USER: test POSTGRES_PASSWORD: test - POSTGRES_DB: test_db + POSTGRES_DB: test ports: - "55432:5432" + volumes: + - postgres_orm_test_data:/var/lib/postgresql/data healthcheck: - test: ["CMD-SHELL", "pg_isready -U test"] + test: ["CMD-SHELL", "pg_isready -U test -d test"] interval: 2s timeout: 2s - retries: 10 \ No newline at end of file + retries: 10 +volumes: + postgres_orm_test_data: \ No newline at end of file diff --git a/tests/helpers/test_discovery.py b/tests/helpers/test_discovery.py new file mode 100644 index 0000000..5bf16c8 --- /dev/null +++ b/tests/helpers/test_discovery.py @@ -0,0 +1,22 @@ +from __future__ import annotations + +import sqlalchemy as sa + +from orm_loader.helpers.discovery import get_model_by_tablename +from orm_loader.helpers.metadata import Base + + +def test_get_model_by_tablename_supports_nested_inheritance() -> None: + class Child(Base): + __abstract__ = True + + class GrandChild(Child): + 
__tablename__ = "_discovery_grandchild" + id = sa.Column(sa.Integer, primary_key=True) + + resolved = get_model_by_tablename("_discovery_grandchild") + assert resolved is GrandChild + + +def test_get_model_by_tablename_returns_none_for_unknown_table() -> None: + assert get_model_by_tablename("_not_a_real_table_name_") is None diff --git a/tests/loaders/test_dedupe.py b/tests/loaders/test_dedupe.py index 6cae76b..c84a6b3 100644 --- a/tests/loaders/test_dedupe.py +++ b/tests/loaders/test_dedupe.py @@ -1,10 +1,12 @@ import pyarrow as pa +from typing import cast, Type from orm_loader.loaders.loading_helpers import arrow_drop_duplicates import pandas as pd import sqlalchemy as sa import sqlalchemy.orm as so from sqlalchemy.orm import DeclarativeBase from orm_loader.tables.loadable_table import CSVLoadableTableInterface +from orm_loader.tables.typing import CSVTableProtocol from orm_loader.loaders.loader_interface import PandasLoader @@ -19,6 +21,9 @@ class DedupTable(Base, CSVLoadableTableInterface): value: so.Mapped[str] = so.mapped_column(sa.String, nullable=False) +_DedupTable = cast(Type[CSVTableProtocol], DedupTable) + + def test_arrow_drop_duplicates_simple(): table = pa.table({ "id": [1, 1, 2], @@ -31,8 +36,8 @@ def test_arrow_drop_duplicates_simple(): -def test_internal_deduplication(session, tmp_path): - Base.metadata.create_all(session.get_bind()) +def test_internal_deduplication(session, engine, tmp_path): + Base.metadata.create_all(engine) csv = tmp_path / "dedup_table.csv" pd.DataFrame( @@ -43,7 +48,7 @@ def test_internal_deduplication(session, tmp_path): ] ).to_csv(csv, index=False) - inserted = DedupTable.load_csv( # type: ignore + inserted = _DedupTable.load_csv( session, csv, loader=PandasLoader(), diff --git a/tests/loaders/test_helpers.py b/tests/loaders/test_helpers.py index 9a55de2..da907d1 100644 --- a/tests/loaders/test_helpers.py +++ b/tests/loaders/test_helpers.py @@ -1,5 +1,5 @@ from orm_loader.loaders.data_classes import ColumnCastingStats, 
TableCastingStats -from orm_loader.loaders.loading_helpers import infer_delim, infer_encoding +from orm_loader.loaders.loading_helpers import infer_delim, infer_encoding, infer_quote_mode def test_column_casting_stats_records_examples(): stats = ColumnCastingStats() @@ -35,3 +35,35 @@ def test_infer_encoding_utf8(tmp_path): p.write_text("hello") enc = infer_encoding(p).get("encoding") or "" assert enc.lower() in {"utf-8", "ascii"} + + +def test_infer_quote_mode_unquoted_tsv_returns_csv(tmp_path): + # No quotes anywhere: both modes identical, csv is the safe default + p = tmp_path / "x.csv" + p.write_text("id\tname\tvalue\n1\tAlice\t10\n2\tBob\t20\n") + assert infer_quote_mode(p, delimiter="\t") == "csv" + + +def test_infer_quote_mode_rfc4180_quoted_field_returns_csv(tmp_path): + # Athena-style: quoted concept_name at the varchar(255) boundary, + # no embedded delimiter — the column-count tie-break must favour csv + p = tmp_path / "x.csv" + long_name = "A" * 255 + p.write_text(f'id\tname\n1\t"{long_name}"\n2\tnormal\n') + assert infer_quote_mode(p, delimiter="\t") == "csv" + + +def test_infer_quote_mode_embedded_delimiter_in_quoted_field_returns_csv(tmp_path): + # Quoted field contains the delimiter: csv mode keeps column count consistent, + # literal mode splits on the embedded tab and produces ragged rows + p = tmp_path / "x.csv" + p.write_text('id\tname\tval\n1\t"foo\tbar"\t99\n2\tbaz\t0\n') + assert infer_quote_mode(p, delimiter="\t") == "csv" + + +def test_infer_quote_mode_unbalanced_quote_returns_literal(tmp_path): + # Unbalanced leading quote breaks CSV parsing: literal mode produces + # consistent 2-column rows while csv mode does not + p = tmp_path / "x.csv" + p.write_text('id\tname\n1\t"open\n2\t"open\n3\t"open\n') + assert infer_quote_mode(p, delimiter="\t") == "literal" diff --git a/tests/loaders/test_loader_e2e.py b/tests/loaders/test_loader_e2e.py index 697d601..bb53dd9 100644 --- a/tests/loaders/test_loader_e2e.py +++ 
b/tests/loaders/test_loader_e2e.py @@ -1,37 +1,27 @@ -import sqlalchemy as sa -import sqlalchemy.orm as so -from sqlalchemy.orm import Session -from pathlib import Path +from typing import Type, cast + +import numpy as np import pandas as pd import pytest +import sqlalchemy as sa +import sqlalchemy.event as sae +import sqlalchemy.orm as so + from orm_loader.loaders.data_classes import _clean_nulls -from orm_loader.tables.loadable_table import CSVLoadableTableInterface from orm_loader.loaders.loader_interface import PandasLoader +from orm_loader.tables.loadable_table import CSVLoadableTableInterface +from orm_loader.tables.typing import CSVTableProtocol +from tests.models import Base, CompositeTable, RequiredTable, SimpleTable -from tests.models import Base, SimpleTable, RequiredTable, CompositeTable - -import numpy as np - -@pytest.fixture -def engine(): - engine = sa.create_engine("sqlite:///:memory:", future=True) - Base.metadata.create_all(engine) - return engine - - -@pytest.fixture -def session(engine): - with Session(engine) as session: - yield session - - -@pytest.fixture -def tmp_csv_dir(tmp_path: Path) -> Path: - return tmp_path +# Typed aliases: Pylance cannot verify SQLAlchemy metaclass-generated attrs +# satisfy CSVTableProtocol structurally, so we cast once per class here. 
+_SimpleTable = cast(Type[CSVTableProtocol], SimpleTable) +_RequiredTable = cast(Type[CSVTableProtocol], RequiredTable) +_CompositeTable = cast(Type[CSVTableProtocol], CompositeTable) -def test_initial_csv_load(session, tmp_csv_dir): - csv_path = tmp_csv_dir / "test_table.csv" +def test_initial_csv_load(session, tmp_path): + csv_path = tmp_path / "test_table.csv" pd.DataFrame( [ @@ -43,7 +33,7 @@ def test_initial_csv_load(session, tmp_csv_dir): loader = PandasLoader() - inserted = SimpleTable.load_csv( # type: ignore + inserted = _SimpleTable.load_csv( session, csv_path, dedupe=False, @@ -53,9 +43,7 @@ def test_initial_csv_load(session, tmp_csv_dir): assert inserted == 3 - rows = session.execute( - sa.select(SimpleTable).order_by(SimpleTable.id) - ).scalars().all() + rows = session.execute(sa.select(SimpleTable).order_by(SimpleTable.id)).scalars().all() assert [(r.id, r.name) for r in rows] == [ (1, "alpha"), @@ -64,8 +52,8 @@ def test_initial_csv_load(session, tmp_csv_dir): ] -def test_replace_merge_strategy(session, tmp_csv_dir): - csv_path = tmp_csv_dir / "test_table.csv" +def test_replace_merge_strategy(session, tmp_path): + csv_path = tmp_path / "test_table.csv" # Initial load pd.DataFrame( @@ -78,7 +66,7 @@ def test_replace_merge_strategy(session, tmp_csv_dir): loader = PandasLoader() - SimpleTable.load_csv( # type: ignore + _SimpleTable.load_csv( session, csv_path, dedupe=False, @@ -94,7 +82,7 @@ def test_replace_merge_strategy(session, tmp_csv_dir): ] ).to_csv(csv_path, index=False, sep="\t") - replaced = SimpleTable.load_csv( # type: ignore + replaced = _SimpleTable.load_csv( session, csv_path, dedupe=False, @@ -105,9 +93,7 @@ def test_replace_merge_strategy(session, tmp_csv_dir): assert replaced == 2 - rows = session.execute( - sa.select(SimpleTable).order_by(SimpleTable.id) - ).scalars().all() + rows = session.execute(sa.select(SimpleTable).order_by(SimpleTable.id)).scalars().all() assert [(r.id, r.name) for r in rows] == [ (1, "alpha"), @@ -116,14 
+102,14 @@ def test_replace_merge_strategy(session, tmp_csv_dir): ] -def test_empty_csv_is_noop(session, tmp_csv_dir): - csv_path = tmp_csv_dir / "test_table.csv" +def test_empty_csv_is_noop(session, tmp_path): + csv_path = tmp_path / "test_table.csv" csv_path.touch() loader = PandasLoader() - inserted = SimpleTable.load_csv( # type: ignore - session, + inserted = _SimpleTable.load_csv( + session, csv_path, dedupe=False, loader=loader, @@ -136,10 +122,7 @@ def test_empty_csv_is_noop(session, tmp_csv_dir): assert rows == [] - def test_required_column_violation_drops_rows(session, tmp_path): - Base.metadata.create_all(session.get_bind()) - csv = tmp_path / "required_table.csv" pd.DataFrame( [ @@ -148,7 +131,7 @@ def test_required_column_violation_drops_rows(session, tmp_path): ] ).to_csv(csv, index=False) - inserted = RequiredTable.load_csv( # type: ignore + inserted = _RequiredTable.load_csv( session, csv, loader=PandasLoader(), @@ -159,10 +142,7 @@ def test_required_column_violation_drops_rows(session, tmp_path): assert inserted == 1 - def test_composite_pk_dedup(session, tmp_path): - Base.metadata.create_all(session.get_bind()) - csv = tmp_path / "composite_table.csv" pd.DataFrame( [ @@ -172,7 +152,7 @@ def test_composite_pk_dedup(session, tmp_path): ] ).to_csv(csv, index=False) - inserted = CompositeTable.load_csv( # type: ignore + inserted = _CompositeTable.load_csv( session, csv, loader=PandasLoader(), @@ -206,8 +186,8 @@ def test_composite_pk_dedup(session, tmp_path): ), ], ) -def test_merge_strategies(session, tmp_csv_dir, merge_strategy, expected_rows, expected_inserted): - csv_path = tmp_csv_dir / "test_table.csv" +def test_merge_strategies(session, tmp_path, merge_strategy, expected_rows, expected_inserted): + csv_path = tmp_path / "test_table.csv" pd.DataFrame( [ @@ -239,19 +219,21 @@ def test_merge_strategies(session, tmp_csv_dir, merge_strategy, expected_rows, e assert inserted == expected_inserted - rows = session.execute( - 
sa.select(SimpleTable).order_by(SimpleTable.id, SimpleTable.name) - ).scalars().all() + rows = ( + session.execute(sa.select(SimpleTable).order_by(SimpleTable.id, SimpleTable.name)) + .scalars() + .all() + ) assert [(r.id, r.name) for r in rows] == expected_rows -def test_staging_table_is_created_and_dropped(session, tmp_csv_dir): - csv_path = tmp_csv_dir / "test_table.csv" +def test_staging_table_is_created_and_dropped(session, engine, tmp_path): + csv_path = tmp_path / "test_table.csv" pd.DataFrame([{"id": 1, "name": "alpha"}]).to_csv(csv_path, index=False) - SimpleTable.load_csv( + _SimpleTable.load_csv( session, csv_path, loader=PandasLoader(), @@ -259,7 +241,7 @@ def test_staging_table_is_created_and_dropped(session, tmp_csv_dir): ) session.commit() - inspector = sa.inspect(session.get_bind()) + inspector = sa.inspect(engine) assert not inspector.has_table(SimpleTable.staging_tablename()) @@ -290,9 +272,11 @@ def test_composite_pk_replace_merge(session, tmp_path): ) session.commit() - rows = session.execute( - sa.select(CompositeTable).order_by(CompositeTable.a, CompositeTable.b) - ).scalars().all() + rows = ( + session.execute(sa.select(CompositeTable).order_by(CompositeTable.a, CompositeTable.b)) + .scalars() + .all() + ) assert [(r.a, r.b, r.value) for r in rows] == [ (1, 1, "x_updated"), @@ -318,31 +302,33 @@ def test_clean_nulls_basic(): assert _clean_nulls(float("nan")) is None assert _clean_nulls(np.nan) is None + def test_clean_nulls_passthrough(): assert _clean_nulls("") == "" - assert _clean_nulls("nan") == "nan" # string 'nan' must not be converted + assert _clean_nulls("nan") == "nan" # string 'nan' must not be converted assert _clean_nulls(0) == 0 assert _clean_nulls("S") == "S" -def test_nullable_column_with_nan_does_not_crash(session, tmp_path): +def test_nullable_column_with_nan_does_not_crash(session, engine, tmp_path): class NullableTable(Base, CSVLoadableTableInterface): __tablename__ = "nullable_table" id: so.Mapped[int] = 
so.mapped_column(sa.Integer, primary_key=True) flag: so.Mapped[str | None] = so.mapped_column(sa.String, nullable=True) - Base.metadata.create_all(session.get_bind()) + Base.metadata.create_all(engine) + _NullableTable = cast(Type[CSVTableProtocol], NullableTable) csv = tmp_path / "nullable_table.csv" pd.DataFrame( [ {"id": 1, "flag": "S"}, - {"id": 2, "flag": None}, # becomes NaN in pandas + {"id": 2, "flag": None}, # becomes NaN in pandas ] ).to_csv(csv, index=False) - inserted = NullableTable.load_csv( # type: ignore + inserted = _NullableTable.load_csv( session, csv, loader=PandasLoader(), @@ -352,9 +338,7 @@ class NullableTable(Base, CSVLoadableTableInterface): assert inserted == 2 - rows = session.execute( - sa.select(NullableTable).order_by(NullableTable.id) - ).scalars().all() + rows = session.execute(sa.select(NullableTable).order_by(NullableTable.id)).scalars().all() assert [(r.id, r.flag) for r in rows] == [ (1, "S"), @@ -362,24 +346,22 @@ class NullableTable(Base, CSVLoadableTableInterface): ] -def test_embedded_newline_in_field_is_preserved(session, tmp_path): +def test_embedded_newline_in_field_is_preserved(session, engine, tmp_path): class TextTable(Base, CSVLoadableTableInterface): __tablename__ = "text_table" id: so.Mapped[int] = so.mapped_column(sa.Integer, primary_key=True) name: so.Mapped[str] = so.mapped_column(sa.String) - Base.metadata.create_all(session.get_bind()) + Base.metadata.create_all(engine) + _TextTable = cast(Type[CSVTableProtocol], TextTable) csv = tmp_path / "text_table.csv" # Properly quoted CSV with embedded newline - csv.write_text( - 'id\tname\n' - '1\t"hello\nworld"\n' - ) + csv.write_text('id\tname\n1\t"hello\nworld"\n') - TextTable.load_csv( # type: ignore + _TextTable.load_csv( session, csv, loader=PandasLoader(), @@ -391,22 +373,20 @@ class TextTable(Base, CSVLoadableTableInterface): assert rows[0].name == "hello\nworld" -def test_embedded_tab_in_field(session, tmp_path): +def test_embedded_tab_in_field(session, engine, 
tmp_path): class TextTable2(Base, CSVLoadableTableInterface): __tablename__ = "tab_table" id: so.Mapped[int] = so.mapped_column(sa.Integer, primary_key=True) name: so.Mapped[str] = so.mapped_column(sa.String) - Base.metadata.create_all(session.get_bind()) + Base.metadata.create_all(engine) + _TextTable2 = cast(Type[CSVTableProtocol], TextTable2) csv = tmp_path / "tab_table.csv" - csv.write_text( - 'id\tname\n' - '1\t"foo\tbar"\n' - ) + csv.write_text('id\tname\n1\t"foo\tbar"\n') - TextTable2.load_csv( # type: ignore + _TextTable2.load_csv( session, csv, loader=PandasLoader(), @@ -417,6 +397,84 @@ class TextTable2(Base, CSVLoadableTableInterface): rows = session.execute(sa.select(TextTable2)).scalars().all() assert rows[0].name == "foo\tbar" + +# --- index_strategy tests --- + + +def _make_ddl_tracker(engine): + """Return a list that is populated with DROP/CREATE INDEX statements as they execute.""" + ddl_log: list[str] = [] + + def _capture(*args): + statement: str = args[2] + upper = statement.strip().upper() + if upper.startswith("DROP INDEX") or upper.startswith("CREATE INDEX"): + ddl_log.append(statement.strip()) + + sae.listen(engine, "before_cursor_execute", _capture) + return ddl_log + + +def test_auto_strategy_keeps_indices_on_sqlite(session, engine, tmp_path): + """On SQLite, 'auto' resolves to 'keep' — no index DDL should be emitted.""" + ddl_log = _make_ddl_tracker(engine) + csv_path = tmp_path / "test_table.csv" + pd.DataFrame([{"id": 1, "name": "alpha"}, {"id": 2, "name": "beta"}]).to_csv( + csv_path, index=False, sep="\t" + ) + + _SimpleTable.load_csv(session, csv_path, loader=PandasLoader(), index_strategy="auto") + session.commit() + + assert not any("DROP INDEX" in s.upper() for s in ddl_log) + assert not any("CREATE INDEX" in s.upper() for s in ddl_log) + inspector = sa.inspect(engine) + inspector.clear_cache() + assert "ix_test_table_name" in {idx["name"] for idx in inspector.get_indexes("test_table")} + + +def 
test_explicit_keep_preserves_indices(session, engine, tmp_path): + """Explicit 'keep' emits no index DDL regardless of dialect.""" + ddl_log = _make_ddl_tracker(engine) + csv_path = tmp_path / "test_table.csv" + pd.DataFrame([{"id": 1, "name": "alpha"}]).to_csv(csv_path, index=False, sep="\t") + + _SimpleTable.load_csv(session, csv_path, loader=PandasLoader(), index_strategy="keep") + session.commit() + + assert not any("DROP INDEX" in s.upper() for s in ddl_log) + inspector = sa.inspect(engine) + inspector.clear_cache() + assert "ix_test_table_name" in {idx["name"] for idx in inspector.get_indexes("test_table")} + + +def test_explicit_drop_rebuild_on_sqlite_restores_index(session, engine, tmp_path): + """Explicit 'drop_rebuild' drops then restores the index even on SQLite.""" + ddl_log = _make_ddl_tracker(engine) + csv_path = tmp_path / "test_table.csv" + pd.DataFrame([{"id": 1, "name": "alpha"}, {"id": 2, "name": "beta"}]).to_csv( + csv_path, index=False, sep="\t" + ) + + _SimpleTable.load_csv(session, csv_path, loader=PandasLoader(), index_strategy="drop_rebuild") + session.commit() + + assert any("DROP INDEX" in s.upper() for s in ddl_log) + assert any("CREATE INDEX" in s.upper() for s in ddl_log) + inspector = sa.inspect(engine) + inspector.clear_cache() + assert "ix_test_table_name" in {idx["name"] for idx in inspector.get_indexes("test_table")} + + +def test_invalid_index_strategy_raises(session, tmp_path): + """An unrecognised strategy value raises ValueError before any DB work.""" + csv_path = tmp_path / "test_table.csv" + pd.DataFrame([{"id": 1, "name": "alpha"}]).to_csv(csv_path, index=False, sep="\t") + + with pytest.raises(ValueError, match="Unknown index_strategy"): + _SimpleTable.load_csv(session, csv_path, loader=PandasLoader(), index_strategy="not-valid") + + # from hypothesis import given, strategies as st # from sqlalchemy.orm import declarative_base # from pathlib import Path @@ -460,4 +518,4 @@ class TextTable2(Base, CSVLoadableTableInterface): 
# assert rows == [] # else: # # stored value may be str-canonicalised version -# assert rows[0].txt.encode("utf-8", errors="replace") == s.encode("utf-8", errors="replace") \ No newline at end of file +# assert rows[0].txt.encode("utf-8", errors="replace") == s.encode("utf-8", errors="replace") diff --git a/tests/loaders/test_parquet_loader.py b/tests/loaders/test_parquet_loader.py index 8dbca70..a1e735a 100644 --- a/tests/loaders/test_parquet_loader.py +++ b/tests/loaders/test_parquet_loader.py @@ -4,8 +4,10 @@ import sqlalchemy as sa import sqlalchemy.orm as so from sqlalchemy.orm import DeclarativeBase +from typing import cast, Type from orm_loader.tables.loadable_table import CSVLoadableTableInterface +from orm_loader.tables.typing import CSVTableProtocol from orm_loader.loaders.loader_interface import ParquetLoader @@ -20,8 +22,11 @@ class ParquetTable(Base, CSVLoadableTableInterface): value: so.Mapped[int] = so.mapped_column(sa.Integer, nullable=False) -def test_parquet_loader(session, tmp_path): - Base.metadata.create_all(session.get_bind()) +_ParquetTable = cast(Type[CSVTableProtocol], ParquetTable) + + +def test_parquet_loader(session, engine, tmp_path): + Base.metadata.create_all(engine) df = pd.DataFrame( [ @@ -33,7 +38,7 @@ def test_parquet_loader(session, tmp_path): path = tmp_path / "parquet_table.parquet" pq.write_table(table, path) - inserted = ParquetTable.load_csv( # type: ignore + inserted = _ParquetTable.load_csv( session, path, loader=ParquetLoader(), diff --git a/tests/loaders/test_pg_loader.py b/tests/loaders/test_pg_loader.py index 32228fd..0e278a8 100644 --- a/tests/loaders/test_pg_loader.py +++ b/tests/loaders/test_pg_loader.py @@ -49,8 +49,8 @@ def fake_quick_load_pg(*args, **kwargs): called["copy"] = True return 1 - import orm_loader.tables.loadable_table as loadable_table - monkeypatch.setattr(loadable_table, "quick_load_pg", fake_quick_load_pg) + import orm_loader.backends.postgres as pg_backend + monkeypatch.setattr(pg_backend, 
"quick_load_pg", fake_quick_load_pg) inserted = SimpleTable.load_csv(pg_session, csv) pg_session.commit() @@ -63,12 +63,12 @@ def test_copy_failure_falls_back_to_orm(pg_session, tmp_path, monkeypatch): csv = tmp_path / "test_table.csv" pd.DataFrame([{"id": 1, "name": "alpha"}]).to_csv(csv, index=False) - from orm_loader.loaders import loading_helpers + import orm_loader.backends.postgres as pg_backend def broken_copy(*args, **kwargs): raise RuntimeError("boom") - monkeypatch.setattr(loading_helpers, "quick_load_pg", broken_copy) + monkeypatch.setattr(pg_backend, "quick_load_pg", broken_copy) inserted = SimpleTable.load_csv(pg_session, csv) pg_session.commit() @@ -145,7 +145,9 @@ def test_infer_encoding_utf8(tmp_path): p.write_text("id,name\n1,α\n", encoding="utf-8") enc = infer_encoding(p) - assert enc["encoding"].lower().startswith("utf") + enc_str = enc["encoding"] + assert enc_str is not None + assert enc_str.lower().startswith("utf") def test_infer_delim_tab(tmp_path): p = tmp_path / "tab.csv" diff --git a/tests/models.py b/tests/models.py index b581b87..630361a 100644 --- a/tests/models.py +++ b/tests/models.py @@ -14,6 +14,9 @@ class PandasLoaderTable(CSVLoadableTableInterface, Base): class SimpleTable(Base, CSVLoadableTableInterface): __tablename__ = "test_table" + __table_args__ = ( + sa.Index("ix_test_table_name", "name"), + ) id: so.Mapped[int] = so.mapped_column(sa.Integer, primary_key=True) name: so.Mapped[str] = so.mapped_column(sa.String, nullable=False) @@ -32,4 +35,3 @@ class CompositeTable(Base, CSVLoadableTableInterface): a: so.Mapped[int] = so.mapped_column(sa.Integer, primary_key=True) b: so.Mapped[int] = so.mapped_column(sa.Integer, primary_key=True) value: so.Mapped[str] = so.mapped_column(sa.String) - diff --git a/tests/pg_db.py b/tests/pg_db.py index e383f9d..d0aacd5 100644 --- a/tests/pg_db.py +++ b/tests/pg_db.py @@ -5,7 +5,7 @@ from tests.models import Base -POSTGRES_URL = "postgresql+psycopg://test:test@localhost:55432/test_db" 
+POSTGRES_URL = "postgresql+psycopg://test:test@localhost:55432/test"
 @pytest.fixture(scope="session")
 def pg_engine():
diff --git a/tests/pytest.ini b/tests/pytest.ini
deleted file mode 100644
index d89701f..0000000
--- a/tests/pytest.ini
+++ /dev/null
@@ -1,3 +0,0 @@
-[pytest]
-markers =
-    postgres: requires a running postgres instance
\ No newline at end of file
diff --git a/tests/tables/test_orm_table_base.py b/tests/tables/test_orm_table_base.py
index 23c1dab..806ad10 100644
--- a/tests/tables/test_orm_table_base.py
+++ b/tests/tables/test_orm_table_base.py
@@ -1,11 +1,12 @@
-from sqlalchemy.exc import NoInspectionAvailable
-import sqlalchemy.orm as so
+import pytest
 import sqlalchemy as sa
+import sqlalchemy.orm as so
+
 from orm_loader.tables.orm_table import ORMTableBase
-import pytest
 
 Base = so.declarative_base()
 
+
 def test_pk_introspection():
     class T(ORMTableBase, Base):
         __tablename__ = "t"
@@ -13,8 +14,10 @@ class T(ORMTableBase, Base):
 
     assert T.pk_names() == ["id"]
 
+
 def test_pk_missing_raises():
     class T(ORMTableBase):
         __tablename__ = "t"
-    with pytest.raises(NoInspectionAvailable):
+
+    with pytest.raises(TypeError, match="not a mapped ORM class"):
         T.pk_columns()
diff --git a/uv.lock b/uv.lock
index 0ec9570..f483117 100644
--- a/uv.lock
+++ b/uv.lock
@@ -134,6 +134,17 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335, upload-time = "2022-10-25T02:36:20.889Z" },
 ]
 
+[[package]]
+name = "dotenv"
+version = "0.9.9"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "python-dotenv" },
+]
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/b2/b7/545d2c10c1fc15e48653c91efde329a790f2eecfbbf2bd16003b5db2bab0/dotenv-0.9.9-py2.py3-none-any.whl", hash = "sha256:29cf74a087b31dafdb5a446b6d7e11cbce8ed2741540e2339c69fbef92c94ce9", size = 1892, upload-time = "2025-02-19T22:15:01.647Z" },
+]
+
 [[package]]
 name = "editorconfig"
 version = "0.17.1"
@@ -618,7 +629,7 @@ wheels = [
 
 [[package]]
 name = "orm-loader"
-version = "0.3.27"
+version = "0.4.0"
 source = { editable = "." }
 dependencies = [
     { name = "chardet" },
@@ -629,6 +640,7 @@ dependencies = [
 
 [package.optional-dependencies]
 dev = [
+    { name = "dotenv" },
     { name = "mkdocs" },
     { name = "mkdocs-material" },
     { name = "mkdocs-mermaid2-plugin" },
@@ -639,24 +651,30 @@ dev = [
     { name = "requests" },
     { name = "ruff" },
 ]
+postgres = [
+    { name = "psycopg", extra = ["binary"] },
+]
 
 [package.metadata]
 requires-dist = [
     { name = "chardet", specifier = ">=5.2.0" },
+    { name = "dotenv", marker = "extra == 'dev'" },
     { name = "mkdocs", marker = "extra == 'dev'", specifier = ">=1.6.1" },
     { name = "mkdocs-material", marker = "extra == 'dev'", specifier = ">=9.7.1" },
     { name = "mkdocs-mermaid2-plugin", marker = "extra == 'dev'" },
     { name = "mkdocstrings-python", marker = "extra == 'dev'", specifier = ">=2.0.1" },
     { name = "mypy", marker = "extra == 'dev'", specifier = ">=1.19.1" },
     { name = "pandas", specifier = ">=2.3.3" },
+    { name = "psycopg", extras = ["binary"], marker = "extra == 'postgres'", specifier = ">=3.2" },
     { name = "pyarrow", specifier = ">=23.0.0" },
     { name = "pygments", marker = "extra == 'dev'", specifier = ">=2.20.0" },
     { name = "pytest", marker = "extra == 'dev'", specifier = ">=9.0.3" },
     { name = "requests", marker = "extra == 'dev'", specifier = ">=2.33.0" },
+    { name = "ruff", marker = "extra == 'dev'" },
     { name = "ruff", marker = "extra == 'dev'", specifier = ">=0.14.11" },
     { name = "sqlalchemy", specifier = ">=2.0.45" },
 ]
-provides-extras = ["dev"]
+provides-extras = ["postgres", "dev"]
 
 [[package]]
 name = "packaging"
@@ -750,6 +768,64 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" },
 ]
 
+[[package]]
+name = "psycopg"
+version = "3.3.4"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "typing-extensions", marker = "python_full_version < '3.13'" },
+    { name = "tzdata", marker = "sys_platform == 'win32'" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/db/2f/cb91e5502ec9de1de6f1b76cfbf69531932725361168bb06963620c77e2e/psycopg-3.3.4.tar.gz", hash = "sha256:e21207764952cff81b6b8bdacad9a3939f2793367fdac2987b3aac36a651b5bc", size = 165799, upload-time = "2026-05-01T23:31:55.179Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/5c/e0/7b3dee031daae7743609ce3c746565d4a3ed7c2c186479eb48e34e838c64/psycopg-3.3.4-py3-none-any.whl", hash = "sha256:b6bbc25ccf05c8fad3b061d9db2ef0909a555171b84b07f29458a447253d679a", size = 213001, upload-time = "2026-05-01T23:20:50.816Z" },
+]
+
+[package.optional-dependencies]
+binary = [
+    { name = "psycopg-binary", marker = "implementation_name != 'pypy'" },
+]
+
+[[package]]
+name = "psycopg-binary"
+version = "3.3.4"
+source = { registry = "https://pypi.org/simple" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/95/7d/03818e13ba7f36de93573c93ee3482006d3dfa8b0f8d28df511bad0a1a92/psycopg_binary-3.3.4-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:5ab28a2a7649df3b72e6b674b4c190e448e8e77cf496a65bd846472048de2089", size = 4591122, upload-time = "2026-05-01T23:27:56.162Z" },
+    { url = "https://files.pythonhosted.org/packages/a5/b9/11b341edf8d54e2694726b273fe9652b254d989f4f63e3ac6816ad6b55f4/psycopg_binary-3.3.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6402a9d8146cf4b3974ded3fd28a971e83dc6a0333eb7822524a3aa20b546578", size = 4669943, upload-time = "2026-05-01T23:28:04.522Z" },
+    { url = "https://files.pythonhosted.org/packages/8b/18/4665bacd65e7865b4372fcd8abb8b9186ada4b0025f8c2ca691b364a556c/psycopg_binary-3.3.4-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:580ae30a5f95ccd90008ec697d3ed6a4a2047a516407ad904283fa42086936e9", size = 5469697, upload-time = "2026-05-01T23:28:11.337Z" },
+    { url = "https://files.pythonhosted.org/packages/7c/b1/b83136c6e510593d9b0c759ba5384337bc4ad82d19fda675adc4b2703c84/psycopg_binary-3.3.4-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e7510c37550f91a187e3660a8cc50d4b760f8c3b8b2f89ebc5698cd2c7f2c85d", size = 5152995, upload-time = "2026-05-01T23:28:20.529Z" },
+    { url = "https://files.pythonhosted.org/packages/67/8d/a9821e2a648afe6091989929982a3b0f00b2631a859cb81379728f08fb75/psycopg_binary-3.3.4-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:77df19583501ea288eaf15ac0fe7ad01e6d8091a91d5c41df5c718f307d8e31b", size = 6738180, upload-time = "2026-05-01T23:28:30.654Z" },
+    { url = "https://files.pythonhosted.org/packages/7e/58/2e349e8d23905dc2317b80ac65f48fb6f821a4777a4e994a60da91c4850f/psycopg_binary-3.3.4-cp312-cp312-manylinux_2_38_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:018fbed325936da502feb546642c982dcc4b9ffdea32dfef78dbf3b7f7ad4070", size = 4978828, upload-time = "2026-05-01T23:28:37.277Z" },
+    { url = "https://files.pythonhosted.org/packages/45/48/57b00d03b4721878326122a1f1e6b0a90b85bcaec56b5b2f8ea6cfa45235/psycopg_binary-3.3.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:17a21953a9e5ff3a16dab692625a3676e2f101db5e40072f39dbee2250194d68", size = 4509757, upload-time = "2026-05-01T23:28:43.078Z" },
+    { url = "https://files.pythonhosted.org/packages/25/37/33b47d8c007df69aec500df5889767c4d313748e8e9e27a2fef8a6dabcee/psycopg_binary-3.3.4-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:eb05ee1c2b817d27c537333224c9e83c7afb86fe7296ba970990068baf819b16", size = 4190546, upload-time = "2026-05-01T23:28:50.016Z" },
+    { url = "https://files.pythonhosted.org/packages/ca/c6/32b0835dbc2122617902b649d76a91c1e75406e76bf3d595b0c3bb5ffad6/psycopg_binary-3.3.4-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:773d573e11f437ce0bdb95b7c18dc58390494f96d43f8b45b9760436114f7652", size = 3926197, upload-time = "2026-05-01T23:28:55.55Z" },
+    { url = "https://files.pythonhosted.org/packages/cd/68/d190ef0c0c5b16ded07831dabc8ddd412f4cdab07ec6e30ed38d9bda0e1f/psycopg_binary-3.3.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:71e55ccbdfae79a2ed9c6369c3008a3025817ff9d7e27b32a2d84e2a4267e66e", size = 4236627, upload-time = "2026-05-01T23:29:05.336Z" },
+    { url = "https://files.pythonhosted.org/packages/25/8f/81dcbc2e8454b74d14881275ea45f00791052dac531a9fa8be1730d1685b/psycopg_binary-3.3.4-cp312-cp312-win_amd64.whl", hash = "sha256:494ca54901be8cf9eb7e02c25b731f2317c378efa44f43e8f9bd0e1184ae7be4", size = 3560782, upload-time = "2026-05-01T23:29:11.967Z" },
+    { url = "https://files.pythonhosted.org/packages/09/43/13e9c406fbbf354580476e248a16b64802a376873ebe6339e30bb655572d/psycopg_binary-3.3.4-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:fbd1d4ed566895ad2d3bf4ddfd8bae90026930ddf29df3b9d91d32c8c47866a7", size = 4590377, upload-time = "2026-05-01T23:29:18.782Z" },
+    { url = "https://files.pythonhosted.org/packages/22/be/2923cd7c3683e7afdecf4f10796a18de02f5c5ddc0969aa2ad0a8cdd3bbd/psycopg_binary-3.3.4-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:75a9067e236f9b9ae3535b66fe99bddb33d39c0de10112e49b9ab11eee53dc31", size = 4669023, upload-time = "2026-05-01T23:29:25.884Z" },
+    { url = "https://files.pythonhosted.org/packages/96/a0/2c913d6fe13d6a8bd13597d36739bf47af063ad9399e402cfecab16f3c1e/psycopg_binary-3.3.4-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:b56b603ebcea8aa10b46228b8410ba7f13e7c2ee54389d4d9be0927fd8ce2a70", size = 5467423, upload-time = "2026-05-01T23:29:33.416Z" },
+    { url = "https://files.pythonhosted.org/packages/e7/38/205d10bc1ad0df4a21c5c51659126bd3ea0ef98fcad1e852f78c249bb9c3/psycopg_binary-3.3.4-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:c677c4ad433cb7150c8cd304a0769ae3bcfbe5ea0676eb53faa7b1443b16d0d3", size = 5151137, upload-time = "2026-05-01T23:29:42.013Z" },
+    { url = "https://files.pythonhosted.org/packages/36/fc/f0381ddcd45eff3bb70dbca6823a996048d7f507b2ec3fc92c6fabc0fe87/psycopg_binary-3.3.4-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:26df2717e59c0473e4465a97dfb1b7afebaa479277870fd5784d1436470db47c", size = 6736671, upload-time = "2026-05-01T23:29:51.626Z" },
+    { url = "https://files.pythonhosted.org/packages/95/40/fa545ae152c24327651e5624e4902121e808270be36c10b12e9939be09bc/psycopg_binary-3.3.4-cp313-cp313-manylinux_2_38_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:1dc1f79fd16bb1f3f4421417a514607539f17804d95c7ed617265369d1981cae", size = 4979601, upload-time = "2026-05-01T23:29:56.961Z" },
+    { url = "https://files.pythonhosted.org/packages/86/e4/2f8a47ee97f90cd2b933d0463081d35631ff419de2b8c984a5f369857de0/psycopg_binary-3.3.4-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:136f199a407b5348b9b857c504aff60c77622a28482e7195839ce1b51238c4cc", size = 4510513, upload-time = "2026-05-01T23:30:07.243Z" },
+    { url = "https://files.pythonhosted.org/packages/0e/0e/94e842ff4a7f98ed162580ca2e8b8864b28c1e0350f2443f8ee47f821167/psycopg_binary-3.3.4-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:b6f5a29e9c775b9f12a1a717aa7a2c80f9e1db6f27ba44a5b59c80ac61d2ffcf", size = 4187243, upload-time = "2026-05-01T23:30:15.352Z" },
+    { url = "https://files.pythonhosted.org/packages/d0/83/fc6c174b672e29b7de996ea77b6cbddf46c891751c3355f6974292baa6b4/psycopg_binary-3.3.4-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:ee17a2cf4943cde261adfad1bbc5bf38d6b3776d7afff74c7cabcbeaeb08c260", size = 3927347, upload-time = "2026-05-01T23:30:21.186Z" },
+    { url = "https://files.pythonhosted.org/packages/e9/65/768364d4a97a15b1a7f47ba52688c1686f22941d8332a8398cefc468e25f/psycopg_binary-3.3.4-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:5c4ab71be17bdca30cb34c34c4e1496e2f5d6f20c199c12bad226070b22ef9bf", size = 4236393, upload-time = "2026-05-01T23:30:26.211Z" },
+    { url = "https://files.pythonhosted.org/packages/bd/3b/218efbc9e645becd80cdf651acda05f85cfe546b7a9c0458c7cbc8fe1f74/psycopg_binary-3.3.4-cp313-cp313-win_amd64.whl", hash = "sha256:dbfdb9b6cc79f31104a7b162a2b921b765fcc62af6c00540a167a8de47e4ed38", size = 3564592, upload-time = "2026-05-01T23:30:31.764Z" },
+    { url = "https://files.pythonhosted.org/packages/48/a6/828c9185701dab71b234c2a76c38a08b098ebfec5020716b4e93807492b5/psycopg_binary-3.3.4-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:28b7398fdd19db3232c884fb24550bdfe951221f510e195e233299e4c9b78f97", size = 4607292, upload-time = "2026-05-01T23:30:38.962Z" },
+    { url = "https://files.pythonhosted.org/packages/92/58/5b40dbc9d839045c9dae956960e4fb6d20bcabe6c59a2aa34fc3a371913f/psycopg_binary-3.3.4-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:1fbaa292a3c8bb61b45df1ad3da1908ccee7cb889db9425e3557d9e34e2a4829", size = 4687023, upload-time = "2026-05-01T23:30:47.227Z" },
+    { url = "https://files.pythonhosted.org/packages/85/a9/793f0ac107a9003b48441d0d1f9f616d96e0f37458dd8dc12528ceff55fb/psycopg_binary-3.3.4-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:94596f9e7633ee3f6440711d43bb70aa31cc0a46a900ab8b4201a366ace5c9e7", size = 5486985, upload-time = "2026-05-01T23:30:55.517Z" },
+    { url = "https://files.pythonhosted.org/packages/8f/26/42e8533497e2592334f68ec529cf5f840f7fa4e99575a4bb61aa184dbfbf/psycopg_binary-3.3.4-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:8c0056529e68dbe9184cd4019a1f3d8f3a4ead2f6fc7a5afcf27d3314edd1277", size = 5168745, upload-time = "2026-05-01T23:31:01.904Z" },
+    { url = "https://files.pythonhosted.org/packages/15/af/b7151776cc08d5935d45c833ec818a9beb417cf7c08239af1aafbdae78ee/psycopg_binary-3.3.4-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2c09aad7051326e7603c14e50636db9c01f78272dc54b3accff03d46370461e6", size = 6761486, upload-time = "2026-05-01T23:31:14.511Z" },
+    { url = "https://files.pythonhosted.org/packages/d0/ed/c92533b9124712d592cbf1cd6c76da933a2e0acea81dfe1fbe7e735f0cff/psycopg_binary-3.3.4-cp314-cp314-manylinux_2_38_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:514404ed543efd620c85602b747df2a23cf1241b4067199e1a66f2d2757aaa41", size = 4997427, upload-time = "2026-05-01T23:31:20.901Z" },
+    { url = "https://files.pythonhosted.org/packages/a2/23/ccadfd0de416aa188356daa199453af24087b042e296088706d190ae0295/psycopg_binary-3.3.4-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:46893c26858be12cc49ca4226ed6a60b4bfccadd946b3bebb783a60b38788228", size = 4533549, upload-time = "2026-05-01T23:31:26.204Z" },
+    { url = "https://files.pythonhosted.org/packages/fd/a0/c8f43cee36386f7bc891ab41a9d31ea07cf9826038e732da79f26b1e5f34/psycopg_binary-3.3.4-cp314-cp314-musllinux_1_2_ppc64le.whl", hash = "sha256:df1d567fc430f6df15c9fcf67d87685fc49bdb325adc0db5af1adfb2f44eb5c9", size = 4210256, upload-time = "2026-05-01T23:31:33.884Z" },
+    { url = "https://files.pythonhosted.org/packages/4e/2c/c1547871be3790676e8868b38655496422f94f0978dfb66b74bdba2f1676/psycopg_binary-3.3.4-cp314-cp314-musllinux_1_2_riscv64.whl", hash = "sha256:6b9016b1714da4dd5ecaaa75b82098aa5a0b87854ce9b092e21c27c4ae23e014", size = 3946204, upload-time = "2026-05-01T23:31:39.626Z" },
+    { url = "https://files.pythonhosted.org/packages/c4/b1/f6670f00fa7ea601584623f6c11602ab92117d83eaff885e0210f6de7418/psycopg_binary-3.3.4-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:47c656a8a7ba6eb0cff1801a4caaa9c8bdc12d03080e273aff1c8ac39971a77e", size = 4255811, upload-time = "2026-05-01T23:31:44.986Z" },
+    { url = "https://files.pythonhosted.org/packages/eb/e6/5fff07a70d1f945ed90ae131c3bd76cab32beff7c58c6db15ad5820b6d1f/psycopg_binary-3.3.4-cp314-cp314-win_amd64.whl", hash = "sha256:c37e024c07308cd06cf3ec51bfd0e7f6157585a4d84d1bce4a7f5f7913719bf8", size = 3666849, upload-time = "2026-05-01T23:31:51.165Z" },
+]
+
 [[package]]
 name = "pyarrow"
 version = "23.0.0"
@@ -843,6 +919,15 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427", size = 229892, upload-time = "2024-03-01T18:36:18.57Z" },
 ]
 
+[[package]]
+name = "python-dotenv"
+version = "1.2.2"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/82/ed/0301aeeac3e5353ef3d94b6ec08bbcabd04a72018415dcb29e588514bba8/python_dotenv-1.2.2.tar.gz", hash = "sha256:2c371a91fbd7ba082c2c1dc1f8bf89ca22564a087c2c287cd9b662adde799cf3", size = 50135, upload-time = "2026-03-01T16:00:26.196Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/0b/d7/1959b9648791274998a9c3526f6d0ec8fd2233e4d4acce81bbae76b44b2a/python_dotenv-1.2.2-py3-none-any.whl", hash = "sha256:1d8214789a24de455a8b8bd8ae6fe3c6b69a5e3d64aa8a8e5d68e694bbcb285a", size = 22101, upload-time = "2026-03-01T16:00:25.09Z" },
+]
+
 [[package]]
 name = "pytz"
 version = "2025.2"