From 148a512726a5b9cfedda4d464cc5fc788fb33c0e Mon Sep 17 00:00:00 2001 From: Marko Lekic Date: Fri, 5 Dec 2025 16:00:00 +0100 Subject: [PATCH 1/7] Implemented first version of contracts --- docs/Config_and_Macros.md | 2 +- docs/Contracts.md | 459 +++++++++ docs/Quickstart.md | 2 +- docs/Source_Freshness.md | 2 +- docs/Sources.md | 6 +- docs/Technical_Overview.md | 3 +- docs/YAML_Tests.md | 2 +- docs/examples/DQ_Demo.md | 877 +++++++++++------- docs/examples/Local_Engine_Setup.md | 2 +- examples/api_demo/sources.yml | 2 +- examples/basic_demo/sources.yml | 2 +- examples/cache_demo/sources.yml | 2 +- examples/ci_demo/sources.yml | 2 +- examples/dq_demo/contracts.yml | 19 + .../marts/mart_orders_agg.contracts.yml | 36 + .../models/staging/customers.contracts.yml | 46 + .../models/staging/orders.contracts.yml | 12 + examples/dq_demo/project.yml | 23 +- examples/dq_demo/sources.yml | 2 +- .../dq_demo/tests/dq/min_positive_share.ff.py | 6 +- examples/env_matrix/sources.yml | 2 +- examples/hooks_demo/sources.yml | 2 +- examples/incremental_demo/sources.yml | 2 +- examples/macros_demo/sources.yml | 2 +- examples/materializations_demo/sources.yml | 2 +- .../packages_demo/main_project/sources.yml | 2 +- examples/snapshot_demo/sources.yml | 2 +- .../building_locally_demo/sources.yml | 2 +- examples_article/http_cache_demo/sources.yml | 2 +- mkdocs.yml | 1 + pyproject.toml | 2 +- src/fastflowtransform/cli/__init__.py | 2 - src/fastflowtransform/cli/bootstrap.py | 28 +- src/fastflowtransform/cli/init_cmd.py | 2 +- src/fastflowtransform/cli/source_cmd.py | 4 +- src/fastflowtransform/cli/test_cmd.py | 21 +- src/fastflowtransform/config/contracts.py | 328 +++++++ src/fastflowtransform/config/loaders.py | 19 + src/fastflowtransform/config/sources.py | 2 +- src/fastflowtransform/contracts.py | 300 ++++++ src/fastflowtransform/decorators.py | 6 +- src/fastflowtransform/errors.py | 18 + src/fastflowtransform/executors/_shims.py | 142 --- .../executors/_test_utils.py | 82 ++ 
src/fastflowtransform/executors/base.py | 42 + .../executors/bigquery/base.py | 117 ++- .../executors/databricks_spark.py | 90 +- src/fastflowtransform/executors/duckdb.py | 77 +- src/fastflowtransform/executors/postgres.py | 79 +- .../executors/snowflake_snowpark.py | 83 +- src/fastflowtransform/incremental.py | 28 +- src/fastflowtransform/schema_loader.py | 4 +- src/fastflowtransform/source_freshness.py | 4 +- src/fastflowtransform/testing/base.py | 409 ++++---- src/fastflowtransform/testing/registry.py | 167 +++- tests/common/fixtures.py | 11 +- .../test_test_cmd_schema_merge_integration.py | 6 +- .../test_ephemeral_inlining_integration.py | 2 +- .../test_schema_yaml_basic_integration.py | 4 +- ...st_schema_yaml_registry_mix_integration.py | 4 +- .../streaming/test_smoke_streaming.py | 71 -- .../integration/test_artifacts_integration.py | 2 +- .../registry/test_dispatch_integration.py | 8 +- tests/unit/artifacts/test_manifest_unit.py | 2 +- tests/unit/cli/test_bootstrap_unit.py | 29 - tests/unit/cli/test_source_cmd_unit.py | 2 +- tests/unit/config/test_config_hook_unit.py | 2 +- tests/unit/core/test_macros_loading_unit.py | 2 +- .../test_databricks_spark_exec_unit.py | 25 - tests/unit/executors/test_shims_unit.py | 262 ------ .../executors/test_snowflake_snowpark_exec.py | 11 +- tests/unit/schema/test_schema_loader_unit.py | 2 +- tests/unit/test_contracts_unit.py | 406 ++++++++ tests/unit/test_testing_unit.py | 429 +++------ .../unit/testing/test_accepted_values_unit.py | 4 +- uv.lock | 2 +- 76 files changed, 3177 insertions(+), 1691 deletions(-) create mode 100644 docs/Contracts.md create mode 100644 examples/dq_demo/contracts.yml create mode 100644 examples/dq_demo/models/marts/mart_orders_agg.contracts.yml create mode 100644 examples/dq_demo/models/staging/customers.contracts.yml create mode 100644 examples/dq_demo/models/staging/orders.contracts.yml create mode 100644 src/fastflowtransform/config/contracts.py create mode 100644 
src/fastflowtransform/config/loaders.py create mode 100644 src/fastflowtransform/contracts.py delete mode 100644 src/fastflowtransform/executors/_shims.py create mode 100644 src/fastflowtransform/executors/_test_utils.py delete mode 100644 tests/integration/streaming/test_smoke_streaming.py delete mode 100644 tests/unit/executors/test_shims_unit.py create mode 100644 tests/unit/test_contracts_unit.py diff --git a/docs/Config_and_Macros.md b/docs/Config_and_Macros.md index 42608a4..f1252e2 100644 --- a/docs/Config_and_Macros.md +++ b/docs/Config_and_Macros.md @@ -144,7 +144,7 @@ Allowed values are case-insensitive strings or tuples. If the engine does not ma ```yaml # sources.yml -version: 2 +version: 1 sources: - name: crm diff --git a/docs/Contracts.md b/docs/Contracts.md new file mode 100644 index 0000000..305c8f2 --- /dev/null +++ b/docs/Contracts.md @@ -0,0 +1,459 @@ +# Contracts + +FastFlowTransform supports **data contracts**: declarative expectations about your +tables and columns. Contracts are stored in YAML files and are compiled into +normal `fft test` checks. + +You get: + +- A place to describe the **intended schema** (types, nullability, enums, etc.) +- Automatic **data-quality tests** derived from those contracts +- Optional checks for the **physical DB data type** (per engine) + +Contracts live in two places: + +- Per-table: `models/**/*.contracts.yml` +- Project-level defaults: `contracts.yml` at the project root + + +--- + +## Per-table contracts (`*.contracts.yml`) + +For each logical table you can create a `*.contracts.yml` file under `models/`. 
+ +**Convention** + +- File name: ends with `.contracts.yml` +- Location: anywhere under `models/` +- Each file describes **exactly one table** + +Example: + +```yaml +# models/staging/customers.contracts.yml +version: 1 +table: customers + +columns: + customer_id: + type: integer + physical: + duckdb: BIGINT + postgres: integer + bigquery: INT64 + snowflake_snowpark: NUMBER + databricks_spark: BIGINT + nullable: false + unique: true + + name: + type: string + nullable: false + + status: + type: string + nullable: false + enum: + - active + - inactive + + created_at: + type: timestamp + nullable: false +```` + +The `table` name should match the logical relation name you use in your models +(e.g. `relation_for("customers")`). + +--- + +## Column attributes + +Each entry under `columns:` is a **column contract**. + +Supported attributes: + +```yaml +columns: + some_column: + type: string # optional semantic type + physical: # optional physical DB type(s) + duckdb: VARCHAR + postgres: text + nullable: false # nullability contract + unique: true # uniqueness contract + enum: [a, b, c] # allowed values + regex: "^[A-Z]{2}[0-9]{4}$" # regex pattern + min: 0 # numeric min (inclusive) + max: 100 # numeric max (inclusive) + description: "Human note" # free-form description +``` + +### `type` (semantic type) + +Free-form semantic type hint, things like: + +* `integer` +* `string` +* `timestamp` +* `boolean` +* … + +Right now this is **documentation / intent only**; it does not generate tests by itself. +Use it to communicate intent and align with your physical types. + +--- + +### `physical` (engine-specific physical DB type) + +`physical` describes the **actual DB type** of the column, per engine. 
+ +There are two forms: + +**1) Shorthand string** + +```yaml +physical: BIGINT +``` + +This is interpreted as: + +```yaml +physical: + default: BIGINT +``` + +**2) Per-engine mapping** + +```yaml +physical: + default: BIGINT # fallback if no engine-specific key is set + duckdb: BIGINT + postgres: integer + bigquery: INT64 + snowflake_snowpark: NUMBER + databricks_spark: BIGINT +``` + +Supported keys: + +| Key | Engine / executor | +| -------------------- | --------------------------- | +| `default` | Fallback for all engines | +| `duckdb` | DuckDB executor | +| `postgres` | Postgres executor | +| `bigquery` | BigQuery executors | +| `snowflake_snowpark` | Snowflake Snowpark executor | +| `databricks_spark` | Databricks / Spark executor | + +> **Important** +> +> The value here must match what your warehouse reports in its catalog / +> information schema for that column (e.g. `INT64` in BigQuery, `NUMBER` in +> Snowflake, etc.). + +Each `physical` contract is turned into a `column_physical_type` test. +If the engine does not yet support physical type introspection, the test will +fail with a clear “engine not yet supported” message instead of silently +passing. + +--- + +### `nullable` + +```yaml +nullable: false +``` + +* `nullable: false` → generates a `not_null` test for this column. +* `nullable: true` or omitted → no nullability test. + +--- + +### `unique` + +```yaml +unique: true +``` + +* `unique: true` → generates a `unique` test for this column. +* `unique: false` or omitted → no uniqueness test. + +--- + +### `enum` + +```yaml +enum: + - active + - inactive + - pending +``` + +`enum` defines a finite set of allowed values and generates an +`accepted_values` test. + +You can also use a single scalar: + +```yaml +enum: active +``` + +which is treated as `["active"]`. + +--- + +### `regex` + +```yaml +regex: "^[^@]+@[^@]+$" +``` + +`regex` defines a pattern that all non-null values must match. It generates a +`regex_match` test. 
+ +--- + +### `min` / `max` + +```yaml +min: 0 +max: 100 +``` + +`min` and `max` define an inclusive numeric range and generate a `between` test. + +You can specify just one side: + +```yaml +min: 0 # only lower bound +# or +max: 100 # only upper bound +``` + +--- + +### `description` + +```yaml +description: "Customer signup timestamp in UTC" +``` + +Free-form description field. This does not generate tests; it’s for docs / +tooling. + +--- + +## Project-level contracts (`contracts.yml`) + +You can define **project-wide defaults** in a single `contracts.yml` file at +the project root. + +This file only defines **defaults**, not concrete tables. + +Example: + +```yaml +# contracts.yml +version: 1 + +defaults: + columns: + # All *_id columns are non-null integers with engine-specific types + - match: + name: ".*_id$" + type: integer + nullable: false + physical: + duckdb: BIGINT + postgres: integer + bigquery: INT64 + + # created_at should always be a non-null timestamp + - match: + name: "^created_at$" + type: timestamp + nullable: false +``` + +### Column match rules + +Each entry under `defaults.columns` is a **column default rule**: + +```yaml +- match: + name: "regex on column name" # required + table: "regex on table name" # optional + type: ... + physical: ... + nullable: ... + unique: ... + enum: ... + regex: ... + min: ... + max: ... + description: ... +``` + +* `match.name` + Required **regex** applied to the column name. + +* `match.table` + Optional **regex** applied to the table name. + +All the other fields are the same as in `*.contracts.yml`. They act as +**defaults**. + +### How defaults are applied + +For each column contract from a per-table file: + +1. All `defaults.columns` rules are evaluated **in file order**. +2. A rule applies if both: + + * `match.name` matches the column name, and + * `match.table` is empty or matches the table name. +3. 
For every applicable rule: + + * Fields that are currently `null` / unset on the column are **filled** from + the rule. + * Fields that are already set on the column are **not overridden**. + +**Per-table contracts always win.** +Defaults only fill in missing values. + +Example: + +```yaml +# contracts.yml +defaults: + columns: + - match: + name: ".*_id$" + nullable: false + physical: BIGINT +``` + +```yaml +# models/orders.contracts.yml +version: 1 +table: orders +columns: + customer_id: + # nullable unspecified → inherited as false from defaults + physical: + duckdb: BIGINT + postgres: integer # overrides default +``` + +Effective contract for `orders.customer_id`: + +```yaml +type: null +nullable: false # from defaults +physical: + duckdb: BIGINT # from per-table + postgres: integer # from per-table + default: BIGINT # from defaults.physical (other engines) +unique: null +... +``` + +--- + +## How contracts become tests + +Contracts are turned into regular `TestSpec` entries used by `fft test`. 
+ +For each column: + +| Contract field | Generated test type | Notes | +| ----------------- | ---------------------- | ---------------------------- | +| `physical` | `column_physical_type` | Uses engine-specific mapping | +| `nullable: false` | `not_null` | | +| `unique: true` | `unique` | | +| `enum` | `accepted_values` | | +| `min` / `max` | `between` | inclusive range | +| `regex` | `regex_match` | Python regex | + +All contract-derived tests: + +* Use **severity** `error` by default (today) +* Receive the tag `contract` (so you can filter on them later) + +Example for `customers`: + +```yaml +# models/staging/customers.contracts.yml +version: 1 +table: customers +columns: + customer_id: + nullable: false + unique: true + physical: + duckdb: BIGINT + status: + enum: [active, inactive] +``` + +This yields tests roughly equivalent to: + +```text +customers.customer_id not_null (tags: contract) +customers.customer_id unique (tags: contract) +customers.customer_id column_physical_type (tags: contract) +customers.status accepted_values (tags: contract) +``` + +You don’t need to write those tests yourself; they’re derived automatically +from the contract files. + +--- + +## Using contracts with `fft test` + +The high-level flow: + +1. You define `*.contracts.yml` under `models/` and, optionally, a root + `contracts.yml` with defaults. +2. `fft test` loads: + + * all per-table contracts + * project-level defaults +3. Contracts are expanded into test specs. +4. Tests are executed like any other `fft test` checks. + +If a contract file is malformed (YAML, duplicate keys, or schema), FFT raises a +friendly `ContractsConfigError` with a hint. The test run will fail until the +file is fixed, rather than silently skipping it. + +--- + +## Current limitations + +A few things contracts **do not** do yet: + +* Contracts **do not change DDL**: tables are still created with the types + inferred by the warehouse from your `SELECT`. 
+* `type` (semantic type) is not used to alter the schema; it is for intent / + documentation. +* Physical type checks require engine support: + + * Currently, only engines that can introspect their `INFORMATION_SCHEMA` + and expose that to FFT can fully enforce `column_physical_type`. + * Other engines may reject such tests with a clear “engine not supported” + message. + +The intended next step (not implemented yet) is an **“enforce schema”** mode +which uses contracts to drive actual table DDL (or casts) instead of only +post-hoc assertions. + +For now, contracts give you **schema-as-YAML** + **tests-from-contracts** in a +single, consistent place. + +Additional validation: + +* Duplicate YAML keys in contract files are rejected (the loader raises before + parsing). Fix or remove duplicates to proceed. diff --git a/docs/Quickstart.md b/docs/Quickstart.md index a06407f..9ea17ab 100644 --- a/docs/Quickstart.md +++ b/docs/Quickstart.md @@ -47,7 +47,7 @@ pip install -e .[full] ```bash mkdir -p demo/{models,seeds} cat <<'YAML' > demo/sources.yml -version: 2 +version: 1 sources: - name: raw diff --git a/docs/Source_Freshness.md b/docs/Source_Freshness.md index f441edc..baeb426 100644 --- a/docs/Source_Freshness.md +++ b/docs/Source_Freshness.md @@ -33,7 +33,7 @@ Freshness rules are attached to source tables in your metadata (conceptually alo A minimal example: ```yaml -version: 2 +version: 1 sources: - name: crm schema: raw diff --git a/docs/Sources.md b/docs/Sources.md index c67d827..f2433a2 100644 --- a/docs/Sources.md +++ b/docs/Sources.md @@ -11,14 +11,14 @@ project/ ├── models/ ├── sources.yml └── seeds/ -```` +``` ## YAML Schema (Version 2) FastFlowTransform expects a dbt-style structure: ```yaml -version: 2 +version: 1 sources: - name: raw schema: staging # default schema for this source group @@ -82,7 +82,7 @@ Engine-specific overrides follow this merge order: A typical analytics project mixes **seeded reference data**, **database tables**, and **lakehouse 
paths**. A single `sources.yml` might look like this: ```yaml -version: 2 +version: 1 sources: # Seeded reference data (CSV → tables) - name: ref diff --git a/docs/Technical_Overview.md b/docs/Technical_Overview.md index 0eba393..6fd9968 100644 --- a/docs/Technical_Overview.md +++ b/docs/Technical_Overview.md @@ -142,7 +142,7 @@ CLI (Typer) ├── Executors (executors/*) │ ├── BaseExecutor (SQL rendering, dependency loading, materialization, requires guard) │ ├── DuckExecutor (DuckDB) -│ ├── PostgresExecutor (SQLAlchemy, shims) +│ ├── PostgresExecutor (SQLAlchemy) │ ├── BigQueryExecutor (pandas) │ ├── BigQueryBFExecutor (BigQuery DataFrames / bigframes) │ ├── DatabricksSparkExecutor (PySpark, without pandas) @@ -256,7 +256,6 @@ class BaseExecutor(ABC): **Postgres (`postgres.py`)** -- `_SAConnShim` (compatible with `testing._exec`). - `run_sql` renders SQL and rewrites `CREATE OR REPLACE TABLE` to `DROP + CREATE AS`. - `_read_relation` uses pandas, handles schemas, and provides clear guidance. - `_materialize_relation` writes via `to_sql(if_exists="replace")`. 
diff --git a/docs/YAML_Tests.md b/docs/YAML_Tests.md index 7eba47f..36c34c8 100644 --- a/docs/YAML_Tests.md +++ b/docs/YAML_Tests.md @@ -6,7 +6,7 @@ Schema-bound tests live in `models/*.yml` or `models/**/schema.yml` and compleme ```yaml # examples/r1_demo/models/users_enriched.yml -version: 2 +version: 1 models: - name: users_enriched description: "Adds gmail flag" diff --git a/docs/examples/DQ_Demo.md b/docs/examples/DQ_Demo.md index 28fcbe3..23719b0 100644 --- a/docs/examples/DQ_Demo.md +++ b/docs/examples/DQ_Demo.md @@ -1,159 +1,296 @@ # Data Quality Demo Project -The **Data Quality Demo** shows how to use **all built-in FFT data quality tests** plus **custom DQ tests (Python & SQL)** on a small, understandable model: - -* Column checks: - - * `not_null` - * `unique` - * `accepted_values` - * `greater_equal` - * `non_negative_sum` - * `row_count_between` - * `freshness` -* Cross-table reconciliations: - - * `reconcile_equal` - * `reconcile_ratio_within` - * `reconcile_diff_within` - * `reconcile_coverage` - -* Custom tests (demo): - - * `min_positive_share` (Python-based) - * `no_future_orders` (SQL-based) - -It uses a simple **customers / orders / mart** setup so you can see exactly what each test does and how it fails when something goes wrong. 
+The **Data Quality Demo** shows how to combine: + +- **Built-in FFT data quality tests** +- **Tests generated from data contracts (`*.contracts.yml` + `contracts.yml`)** +- **Custom DQ tests (Python & SQL)** +- **Multiple engines** (DuckDB, Postgres, Databricks Spark, BigQuery, Snowflake Snowpark) + +on a small, understandable model: + +- **Column checks (from contracts + project.yml):** + - `column_physical_type` + - `not_null` + - `unique` + - `accepted_values` + - `between` + - `regex_match` + - `greater_equal` + - `non_negative_sum` + - `row_count_between` + - `freshness` + - `relationships` + +- **Cross-table reconciliations:** + - `reconcile_equal` + - `reconcile_ratio_within` + - `reconcile_diff_within` + - `reconcile_coverage` + +- **Custom tests (demo):** + - `min_positive_share` (Python-based) + - `no_future_orders` (SQL-based) + +It uses a simple **customers / orders / mart** setup so you can see exactly what +each test does and how it fails when something goes wrong. --- ## What this example demonstrates -1. **Basic column checks** on staging tables - Ensure IDs are present and unique, amounts are non-negative, and status values are valid. +1. **Basic column checks** on staging tables + - Enforced via **contracts** (`*.contracts.yml` + `contracts.yml`) and + project tests: + - IDs are present / non-null, status values are constrained, numeric ranges + are respected, physical types match the warehouse. + +2. **Freshness** on a timestamp column + - Table-level `freshness` test on `orders.order_ts`. + - Source-level freshness via `sources.yml` for `crm.customers` / `crm.orders`. -2. **Freshness** on a timestamp column - Check that the most recent order in your mart is not “too old”, using `last_order_ts`. +3. **Row count sanity checks** + - Guard against empty tables and unexpectedly large row counts. -3. **Row count sanity checks** - Guard against empty tables and unexpectedly large row counts. +4. 
**Cross-table reconciliations** between staging and mart + - Verify that sums and counts match between `orders` and the aggregated + `mart_orders_agg`, and that every order has a matching customer. -4. **Cross-table reconciliations** between staging and mart - Verify that sums and counts match between `orders` and the aggregated `mart_orders_agg`, and that every customer has a corresponding mart row. +5. **Tagged tests and selective execution** + - All tests are tagged (e.g. `example:dq_demo`, `reconcile`, `fk`, + `contract`) so you can run exactly the subset you care about. -5. **Tagged tests and selective execution** - All tests are tagged (e.g. `example:dq_demo`, `reconcile`) so you can run exactly the subset you care about. +6. **Contracts-driven tests** + - Per-table contracts plus project-wide defaults generate DQ tests + automatically (including `column_physical_type`). --- -## Project layout (example) +## Project layout ```text examples/dq_demo/ - .env + .env.dev_bigquery_bigframes + .env.dev_bigquery_pandas + .env.dev_databricks .env.dev_duckdb .env.dev_postgres - .env.dev_databricks - .env.dev_bigquery_pandas - .env.dev_bigquery_bigframes .env.dev_snowflake - Makefile # optional, convenience wrapper around fft commands + Makefile + README.md + contracts.yml profiles.yml project.yml sources.yml - seeds/ - customers.csv - orders.csv - models/ + README.md + marts/ + mart_orders_agg.contracts.yml + mart_orders_agg.ff.sql staging/ + customers.contracts.yml customers.ff.sql + orders.contracts.yml orders.ff.sql - marts/ - mart_orders_agg.ff.sql + + seeds/ + README.md + schema.yml + seed_customers.csv + seed_orders.csv tests/ dq/ min_positive_share.ff.py no_future_orders.ff.sql + unit/ + README.md +```` + +High level: + +* **`.env.dev_*`** — engine-specific environment examples +* **`Makefile`** — convenience wrapper for seeding, running models, DAG HTML and tests +* **`profiles.yml`** — connection profiles for all engines +* **`project.yml`** — central place for 
**tests** (including reconciliations & custom DQ tests) +* **`contracts.yml`** — project-level **contract defaults** +* **`models/**.contracts.yml`** — per-table contracts +* **`sources.yml`** — source definitions + freshness on raw seeds +* **`seeds/`** — demo CSVs and seed schema +* **`tests/dq/`** — custom DQ tests (Python + SQL) + +--- + +## Seeds + +### `seeds/seed_customers.csv` + +Simple customer dimension with a creation timestamp: + +```csv +customer_id,name,status,created_at +1,Alice,active,2025-01-01T10:00:00 +2,Bob,active,2025-01-02T11:00:00 +3,Carol,inactive,2025-01-03T12:00:00 ``` -### Seeds +Columns: -* `seeds/customers.csv` - Simple customer dimension with a creation timestamp: - `customer_id`, `name`, `status`, `created_at` (ISO-8601, e.g. `2025-01-01T10:00:00`). - The demo ships with three rows (Alice, Bob, Carol) so it’s easy to reason about failures. +* `customer_id` – integer +* `name` – string +* `status` – string (`active` / `inactive`) +* `created_at` – ISO-8601 timestamp -* `seeds/orders.csv` - Order fact data with per-order timestamps: - `order_id`, `customer_id`, `amount`, `order_ts` (ISO-8601, e.g. `2025-01-10T09:00:00`). - One order has `amount = 0.00` so the custom `min_positive_share` test has something to complain about. +### `seeds/seed_orders.csv` -### Models +Order fact data with per-order timestamps: -**1. Staging: `customers.ff.sql`** +```csv +order_id,customer_id,amount,order_ts +100,1,50.00,2025-01-10T09:00:00 +101,1,20.00,2025-01-11T09:00:00 +102,2,30.00,2025-01-11T10:00:00 +103,3,0.00,2025-01-12T10:00:00 +``` -* Materialized as a table. -* Casts IDs and other fields into proper types. -* Used as the “clean” customer dimension for downstream checks. 
+Columns: - ```sql - {{ config( - materialized='table', - tags=[ - 'example:dq_demo', - 'scope:staging', - 'engine:duckdb', - 'engine:postgres', - 'engine:databricks_spark', - 'engine:bigquery', - 'engine:snowflake_snowpark' - ], - ) }} +* `order_id` – integer +* `customer_id` – integer +* `amount` – double +* `order_ts` – ISO-8601 timestamp - select - cast(customer_id as int) as customer_id, - name, - status, - cast(created_at as timestamp) as created_at - from {{ source('crm', 'customers') }}; - ``` - -**2. Staging: `orders.ff.sql`** - -* Materialized as a table. -* Casts fields to proper types so DQ tests work reliably: - - ```sql - {{ config( - materialized='table', - tags=[ - 'example:dq_demo', - 'scope:staging', - 'engine:duckdb', - 'engine:postgres', - 'engine:databricks_spark', - 'engine:bigquery', - 'engine:snowflake_snowpark' - ], - ) }} +**One order has `amount = 0.00`** so the custom +`min_positive_share` test has something to complain about. - select - cast(order_id as int) as order_id, - cast(customer_id as int) as customer_id, - cast(amount as double) as amount, - cast(order_ts as timestamp) as order_ts - from {{ source('crm', 'orders') }}; - ``` +### Seed schema and sources + +`seeds/schema.yml` defines target placement and types: + +```yaml +targets: + seed_customers: + schema: dq_demo + seed_orders: + schema: dq_demo + +columns: + seed_customers: + customer_id: integer + name: string + status: string + created_at: + type: timestamp + seed_orders: + order_id: integer + customer_id: integer + amount: double + order_ts: + type: timestamp +``` + +`sources.yml` exposes them as `crm.customers` and `crm.orders` with **source +freshness**: + +```yaml +version: 1 + +sources: + - name: crm + schema: dq_demo + tables: + - name: customers + identifier: seed_customers + description: "Seeded customers table" + freshness: + loaded_at_field: _ff_loaded_at + warn_after: + count: 60 + period: minute + error_after: + count: 240 + period: minute + - name: orders + 
identifier: seed_orders + description: "Seeded orders table" + freshness: + loaded_at_field: _ff_loaded_at + warn_after: + count: 60 + period: minute + error_after: + count: 240 + period: minute +``` + +--- + +## Models + +### 1. Staging: `models/staging/customers.ff.sql` + +Materialized as a table; casts IDs and timestamps into proper types and +prepares the customer dimension: + +```sql +{{ config( + materialized='table', + tags=[ + 'example:dq_demo', + 'scope:staging', + 'engine:duckdb', + 'engine:postgres', + 'engine:databricks_spark', + 'engine:bigquery', + 'engine:snowflake_snowpark' + ], +) }} + +-- Staging table for customers +select + cast(customer_id as int) as customer_id, + name, + status, + cast(created_at as timestamp) as created_at +from {{ source('crm', 'customers') }}; +``` + +### 2. Staging: `models/staging/orders.ff.sql` + +Materialized as a table; ensures types are suitable for numeric and freshness +checks: + +```sql +{{ config( + materialized='table', + tags=[ + 'example:dq_demo', + 'scope:staging', + 'engine:duckdb', + 'engine:postgres', + 'engine:databricks_spark', + 'engine:bigquery', + 'engine:snowflake_snowpark' + ], +) }} - This is important for: +-- Staging table for orders with proper types for DQ checks +select + cast(order_id as int) as order_id, + cast(customer_id as int) as customer_id, + cast(amount as numeric) as amount, + cast(order_ts as timestamp) as order_ts +from {{ source('crm', 'orders') }}; +``` - * numeric checks (`greater_equal`, `non_negative_sum`) - * timestamp-based `freshness` checks +This is important for: -**3. Mart: `mart_orders_agg.ff.sql`** +* Numeric checks (`greater_equal`, `non_negative_sum`) +* Timestamp-based `freshness` checks on `order_ts` +* Relationships on `customer_id` + +### 3. 
Mart: `models/marts/mart_orders_agg.ff.sql` Aggregates orders per customer and prepares data for reconciliation + freshness: @@ -171,13 +308,13 @@ Aggregates orders per customer and prepares data for reconciliation + freshness: ], ) }} --- Aggregate orders per customer for DQ & reconciliation tests +-- Aggregate orders per customer for reconciliation & freshness tests with base as ( select o.order_id, o.customer_id, -- Ensure numeric and timestamp types for downstream DQ checks - cast(o.amount as double) as amount, + cast(o.amount as numeric) as amount, cast(o.order_ts as timestamp) as order_ts, c.name as customer_name, c.status as customer_status @@ -197,93 +334,184 @@ from base group by customer_id, customer_name, customer_status; ``` -The important columns for DQ tests are: +Key columns: + +* `status` → used by contracts (`enum`) +* `order_count` and `total_amount` → used by reconciliation tests +* `first_order_ts` / `last_order_ts` → available for freshness & diagnostics + +--- + +## Contracts in the demo + +The demo uses contracts for: + +* **Per-table contracts** in `models/**.contracts.yml` +* **Project-wide defaults** in `contracts.yml` + +See `docs/Contracts.md` for the full specification; below is how the demo uses +it. + +### Project-level defaults: `contracts.yml` + +```yaml +version: 1 + +defaults: + columns: + - match: + name: ".*_id$" + type: integer + nullable: false + + - match: + name: "created_at" + type: timestamp + nullable: false + + - match: + name: ".*_ts$" + type: timestamp + nullable: true + description: "Timestamp-like but allowed to be null in some pipelines" +``` + +These rules say: + +* Any column ending with `_id` is an integer and not nullable. +* Any `created_at` is a non-null timestamp. +* Any `*_ts` column is a (possibly nullable) timestamp with a description. + +Defaults are *merged into* per-table contracts, but never override explicit +settings. 
+ +### Example: `models/staging/customers.contracts.yml` + +```yaml +version: 1 +table: customers + +columns: + customer_id: + type: integer + nullable: false + physical: + duckdb: integer + postgres: integer + bigquery: INT64 + snowflake_snowpark: NUMBER + databricks_spark: INT + + name: + type: string + nullable: false + physical: + duckdb: VARCHAR + postgres: text + bigquery: STRING + snowflake_snowpark: TEXT + databricks_spark: STRING + + status: + type: string + nullable: false + enum: + - active + - inactive + physical: + duckdb: VARCHAR + postgres: text + bigquery: STRING + snowflake_snowpark: TEXT + databricks_spark: STRING + + created_at: + type: timestamp + nullable: false + physical: + duckdb: TIMESTAMP + postgres: timestamp without time zone + bigquery: TIMESTAMP + snowflake_snowpark: TIMESTAMP_NTZ + databricks_spark: TIMESTAMP +``` + +At runtime, these contracts get turned into tests: + +* `column_physical_type` for each column with `physical` +* `not_null` for columns with `nullable: false` +* `accepted_values` for `status` via `enum` +* Plus any inherited defaults from `contracts.yml` -* `status` → used for `accepted_values` -* `order_count` and `total_amount` → used for numeric and reconciliation tests -* `last_order_ts` → used for `freshness` +Similarly, the mart has a contract at +`models/marts/mart_orders_agg.contracts.yml` specifying types, nullability, +and enums. --- ## Data quality configuration (`project.yml`) -All tests live under `project.yml → tests:`. -This example uses the tag `example:dq_demo` for easy selection. +All **explicit** tests live under `project.yml → tests:`. +Contracts produce additional tests with the tag `contract`. -### Column-level checks +The demo uses the tag `example:dq_demo` for easy selection. 
+ +### Single-table & relationships tests ```yaml tests: - # 1) IDs must be present and unique - - type: not_null - table: customers - column: customer_id - tags: [example:dq_demo, batch] + # --- Single-table checks ---------------------------------------------------- - - type: unique + - type: row_count_between table: customers - column: customer_id + min_rows: 1 + max_rows: 100 tags: [example:dq_demo, batch] - # 2) Order amounts must be >= 0 - type: greater_equal table: orders column: amount threshold: 0 tags: [example:dq_demo, batch] - # 3) Total sum of amounts must not be negative - type: non_negative_sum table: orders column: amount tags: [example:dq_demo, batch] - # 4) Customer status values must be within a known set - - type: accepted_values - table: mart_orders_agg - column: status - values: ["active", "inactive", "prospect"] - severity: warn # show as warning, not hard failure - tags: [example:dq_demo, batch] - - # 5) Row count sanity check on mart - - type: row_count_between - table: mart_orders_agg - min_rows: 1 - max_rows: 100000 - tags: [example:dq_demo, batch] - - # 6) Freshness: last order in the mart must not be "too old" - - type: freshness - table: mart_orders_agg - column: last_order_ts - max_delay_minutes: 100000000 - tags: [example:dq_demo, batch] - - # 7) Custom Python test: ensure at least a given share of positive amounts - - type: min_positive_share + - type: relationships table: orders - column: amount - params: - min_share: 0.75 - where: "amount <> 0" - tags: [example:dq_demo, batch] + column: customer_id + to: "ref('customers.ff')" + to_field: customer_id + tags: [example:dq_demo, fk] - # 8) Custom SQL test: no future orders allowed - - type: no_future_orders + # Large max_delay_minutes so the example typically passes; + # adjust down in real projects to enforce freshness SLAs. 
+ - type: freshness table: orders column: order_ts - params: - where: "amount <> 0" + max_delay_minutes: 100000000 tags: [example:dq_demo, batch] ``` -### Cross-table reconciliations +What these do: + +* `row_count_between` — ensure `customers` is not empty and not unexpectedly + large. +* `greater_equal` / `non_negative_sum` — protect against negative `amount` and + weird aggregates. +* `relationships` — enforces referential integrity: + every `orders.customer_id` must exist in `customers.customer_id`. +* `freshness` — checks that the latest `order_ts` is recent enough. + +### Reconciliation tests ```yaml - # 7) Reconcile total revenue between orders and mart + # --- Reconciliation checks -------------------------------------------------- + - type: reconcile_equal - name: total_amount_orders_vs_mart + name: orders_total_matches_mart tags: [example:dq_demo, reconcile] left: table: orders @@ -291,103 +519,140 @@ tests: right: table: mart_orders_agg expr: "sum(total_amount)" - abs_tolerance: 0.01 + abs_tolerance: 0.0 - # 8) Ratio of sums should be ~1 (within tight bounds) - type: reconcile_ratio_within - name: total_amount_ratio + name: order_counts_match tags: [example:dq_demo, reconcile] left: - table: orders - expr: "sum(amount)" - right: table: mart_orders_agg - expr: "sum(total_amount)" + expr: "sum(order_count)" + right: + table: orders + expr: "count(*)" min_ratio: 0.999 max_ratio: 1.001 - # 9) Row count diff between orders and mart should be bounded - type: reconcile_diff_within - name: order_count_diff + name: customers_vs_orders_volume tags: [example:dq_demo, reconcile] left: table: orders expr: "count(*)" right: - table: mart_orders_agg - expr: "sum(order_count)" - max_abs_diff: 0 + table: customers + expr: "count(*)" + max_abs_diff: 10 - # 10) Coverage: every customer should appear in the mart - type: reconcile_coverage - name: customers_covered_in_mart + name: all_orders_have_customers tags: [example:dq_demo, reconcile] source: - table: customers + 
table: orders key: "customer_id" target: - table: mart_orders_agg + table: customers key: "customer_id" ``` -This set of tests touches **all available test types** and ties directly back to the simple data model. +These checks ensure: + +* Sums match between raw `orders.amount` and `mart_orders_agg.total_amount`. +* The number of rows in `orders` matches the sum of `order_count` in the mart. +* Overall orders vs customers volume stays within a reasonable bound. +* All orders reference an existing customer (coverage). + +### Custom tests + +```yaml + # --- Custom tests -------------------------------------------------- + - type: no_future_orders + table: orders + column: order_ts + where: "order_ts is not null" + tags: [example:dq_demo, batch] + + - type: min_positive_share + table: orders + column: amount + params: + min_share: 0.75 + where: "amount <> 0" + tags: [example:dq_demo, batch] +``` + +* `no_future_orders` — SQL-based test that fails if any order has a timestamp + in the future. +* `min_positive_share` — Python-based test that requires a minimum share of + positive values in `amount`. --- ## Custom DQ tests (Python & SQL) -The demo also shows how to define **custom data quality tests** that integrate with: +The demo shows how to define **custom data quality tests** that integrate with: -* the `project.yml → tests:` block, -* the `fft test` CLI, -* and the same summary output as built-in tests. +* `project.yml → tests:` +* `fft test` +* The same summary output as built-in tests. 
### Python-based test: `min_positive_share` File: `examples/dq_demo/tests/dq/min_positive_share.ff.py` ```python +# examples/dq_demo/tests/dq/min_positive_share.ff.py from __future__ import annotations from typing import Any +from pydantic import BaseModel, ConfigDict + from fastflowtransform.decorators import dq_test from fastflowtransform.testing import base as testing -@dq_test("min_positive_share") +class MinPositiveShareParams(BaseModel): + """ + Params for the min_positive_share test. + + - min_share: required minimum share of positive values in [0, 1] + - where: optional WHERE predicate to filter rows + """ + + model_config = ConfigDict(extra="forbid") + + min_share: float = 0.5 + where: str | None = None + + +@dq_test("min_positive_share", params_model=MinPositiveShareParams) def min_positive_share( - con: Any, + executor: Any, table: str, column: str | None, params: dict[str, Any], ) -> tuple[bool, str | None, str | None]: """ - Custom DQ test: require that at least `min_share` of rows have column > 0. - - Parameters (from project.yml → tests → params): - - min_share: float in [0,1], e.g. 0.75 - - where: optional filter (string) to restrict the population + Require that at least `min_share` of rows have column > 0. 
""" if column is None: example = f"select count(*) from {table} where > 0" return False, "min_positive_share requires a 'column' parameter", example - # Params come from project.yml under `params:` - cfg: dict[str, Any] = params.get("params") or params # project.yml wrapper - min_share: float = cfg["min_share"] - where: str | None = cfg.get("where") + min_share: float = params["min_share"] + where: str | None = params.get("where") where_clause = f" where {where}" if where else "" total_sql = f"select count(*) from {table}{where_clause}" if where: - pos_sql = f"select count(*) from {table}{where_clause} and {column} > 0" + pos_sql = f"{total_sql} and {column} > 0" else: pos_sql = f"select count(*) from {table} where {column} > 0" - total = testing._scalar(con, total_sql) - positives = testing._scalar(con, pos_sql) + total = testing._scalar(executor, total_sql) + positives = testing._scalar(executor, pos_sql) example_sql = f"{pos_sql}; -- positives\n{total_sql}; -- total" @@ -399,14 +664,16 @@ def min_positive_share( msg = ( f"min_positive_share failed: positive share {share:.4f} " f"< required {min_share:.4f} " - f"({positives} of {total} rows have {column} > 0)" + f"({positives} of {total} rows have {column} > 0" + + (f" where {where}" if where else "") + + ")" ) return False, msg, example_sql return True, None, example_sql -```` +``` -This test is wired up from `project.yml` like this: +Wiring in `project.yml`: ```yaml - type: min_positive_share @@ -430,10 +697,11 @@ File: `examples/dq_demo/tests/dq/no_future_orders.ff.sql` -- Custom DQ test: fail if any row has a timestamp in the future. -- --- Conventions: --- - {{ table }} : table name (e.g. "orders") --- - {{ column }} : timestamp column (e.g. "order_ts") --- - {{ where }} : optional filter, passed via params["where"] +-- Context variables injected by the runner: +-- {{ table }} : table name (e.g. "orders") +-- {{ column }} : timestamp column (e.g. 
"order_ts") +-- {{ where }} : optional filter (string), from params["where"] +-- {{ params }} : full params dict (validated), if you ever need it select count(*) as failures from {{ table }} @@ -441,144 +709,111 @@ where {{ column }} > current_timestamp {%- if where %} and ({{ where }}){%- endif %} ``` -And the corresponding `project.yml` test: +And the corresponding entry in `project.yml`: ```yaml - type: no_future_orders table: orders column: order_ts - params: - where: "amount <> 0" + where: "order_ts is not null" tags: [example:dq_demo, batch] ``` At runtime: -* The SQL file is discovered under `tests/**/*.ff.sql`. -* `{{ config(...) }}` tells FFT the logical `type` and allowed `params`. -* `fft test` validates your `params:` from `project.yml` against this schema and - then executes the rendered SQL as a “violation count” query (`0` = pass, `>0` = fail). +* FFT discovers `*.ff.sql` test files under `tests/dq/`. +* `{{ config(...) }}` declares the test `type` and valid `params`. +* `fft test` validates and injects params, then executes the query as a + “violation count” (`0` = pass, `>0` = fail). --- ## Running the demo -Assuming you are in the repo root and using DuckDB as a starting point: +From `examples/dq_demo/`, you can either: + +* Use the **Makefile** (recommended), or +* Run `fft` commands manually. + +### Using the Makefile -### 1. Seed the data +Pick an engine: ```bash -fft seed examples/dq_demo --env dev_duckdb -``` +# DuckDB +make demo ENGINE=duckdb -This reads `seeds/customers.csv` and `seeds/orders.csv` and materializes them as tables referenced by `sources.yml`. +# Postgres +make demo ENGINE=postgres -### 2. 
Run the models +# Databricks Spark +make demo ENGINE=databricks_spark -```bash -fft run examples/dq_demo --env dev_duckdb +# BigQuery (pandas or BigFrames) +make demo ENGINE=bigquery BQ_FRAME=pandas +make demo ENGINE=bigquery BQ_FRAME=bigframes + +# Snowflake Snowpark +make demo ENGINE=snowflake_snowpark ``` -This builds: +The `demo` target runs: + +1. `fft seed` (load seeds) +2. `fft source freshness` +3. `fft run` (build models) +4. `fft dag` (generate DAG HTML) +5. `fft test` (run DQ tests) +6. Prints locations of artifacts (manifest, run_results, catalog, DAG HTML) -* `customers` (staging) -* `orders` (staging) -* `mart_orders_agg` (mart) +### Running manually (DuckDB example) -### 3. Run all DQ tests +From the repo root: ```bash -fft test examples/dq_demo --env dev_duckdb --select tag:example:dq_demo -``` +# 1) Seed +fft seed examples/dq_demo --env dev_duckdb -You should see a summary like: +# 2) Build models +fft run examples/dq_demo --env dev_duckdb -```text -Data Quality Summary -──────────────────── -✅ not_null customers.customer_id -✅ unique customers.customer_id -✅ greater_equal orders.amount -✅ non_negative_sum orders.amount -❕ accepted_values mart_orders_agg.status -✅ row_count_between mart_orders_agg -✅ freshness mart_orders_agg.last_order_ts -✅ reconcile_equal total_amount_orders_vs_mart -✅ reconcile_ratio_within total_amount_ratio -✅ reconcile_diff_within order_count_diff -✅ reconcile_coverage customers_covered_in_mart - -Totals -────── -✓ passed: 10 -! warnings: 1 +# 3) Run all DQ tests +fft test examples/dq_demo --env dev_duckdb --select tag:example:dq_demo ``` -(Exact output will differ, but you’ll see pass/failed/warned checks listed.) +You’ll see a summary of: -### 4. 
Run only reconciliation tests +* Tests derived from **contracts** (tag: `contract`) +* Explicit tests from `project.yml` (tags: `batch`, `reconcile`, `fk`, …) + +You can also run just reconciliations, just FK tests, etc.: ```bash +# Only reconciliation tests fft test examples/dq_demo --env dev_duckdb --select tag:reconcile -``` -This executes just the cross-table checks, which is handy when you’re iterating on a mart. +# Only FK-style relationship tests +fft test examples/dq_demo --env dev_duckdb --select tag:fk +``` --- -## BigQuery variant (pandas or BigFrames) - -To run the same demo on BigQuery: - -1. Copy `.env.dev_bigquery_pandas` or `.env.dev_bigquery_bigframes` to `.env` and fill in: - ```bash - FF_BQ_PROJECT= - FF_BQ_DATASET=dq_demo - FF_BQ_LOCATION= # e.g., EU or US - GOOGLE_APPLICATION_CREDENTIALS=../secrets/.json # or rely on gcloud / WIF - ``` -2. Run via the Makefile from `examples/dq_demo`: - ```bash - make demo ENGINE=bigquery BQ_FRAME=pandas # or bigframes - ``` - -Both profiles accept `allow_create_dataset` in `profiles.yml` if you want the example to create the dataset automatically. - -## Snowflake Snowpark variant - -To run on Snowflake: - -1. Copy `.env.dev_snowflake` to `.env` and populate: - ```bash - FF_SF_ACCOUNT= - FF_SF_USER= - FF_SF_PASSWORD= - FF_SF_WAREHOUSE=COMPUTE_WH - FF_SF_DATABASE=DQ_DEMO - FF_SF_SCHEMA=DQ_DEMO - FF_SF_ROLE= - ``` -2. Install the Snowflake extra if needed: - ```bash - pip install "fastflowtransform[snowflake]" - ``` -3. Run via the Makefile: - ```bash - make demo ENGINE=snowflake_snowpark - ``` - -The Snowflake profile enables `allow_create_schema`, so the schema is created automatically on first run when permitted. - ## Things to experiment with -To understand the tests better, intentionally break the data and re-run `fft test`: +To understand the tests better, intentionally break the data and re-run +`fft test`: -* Set one `customers.customer_id` to `NULL` → watch `not_null` fail. 
-* Duplicate a `customer_id` → watch `unique` fail. -* Put a negative `amount` in `orders.csv` → `greater_equal` and `non_negative_sum` fail. -* Add a new `status` value (e.g. `"paused"`) → `accepted_values` warns. -* Drop a customer from `mart_orders_agg` manually (or filter it out in SQL) → `reconcile_coverage` fails. +* Set one `customers.customer_id` to `NULL` → `not_null` (from contracts) fails. +* Duplicate a `customer_id` → `unique` (from contracts) fails. +* Put a negative `amount` in `seed_orders.csv` → `greater_equal` and + `non_negative_sum` fail. +* Change `status` to a value not in the enum → `accepted_values` fails. +* Drop a customer from `customers` or change an ID → `relationships` and + reconciliation tests complain. * Change an amount in the mart only → reconciliation tests fail. +* Push an order timestamp into the future → `no_future_orders` fails. +* Change a physical column type in the warehouse to disagree with the + contract → `column_physical_type` fails. This makes it very clear what each test guards against. @@ -589,24 +824,38 @@ This makes it very clear what each test guards against. The Data Quality Demo is designed to be: * **Small and readable** – customers, orders, and a single mart. -* **Complete** – exercises every built-in FFT DQ test type. +* **Complete** – exercises: + + * Built-in FFT DQ tests, + * Tests generated from contracts, + * Custom Python & SQL tests. * **Practical** – real-world patterns like: - * typing in staging models, - * testing freshness on a mart timestamp, - * reconciling sums and row counts across tables. + * Typing in staging models, + * Testing freshness on staging tables and sources, + * Reconciling sums and row counts across tables, + * Enforcing physical types per engine. 
+
+Once you’re comfortable with this example, you can copy the patterns into your
+real projects:
-Once you’re comfortable with this example, you can copy the patterns into your real project: start with staging-level checks, then layer in reconciliations and freshness on your most important marts.
+1. Start with **contracts** and simple column tests on staging.
+2. Add **freshness** on key timestamps and sources.
+3. Layer in **reconciliations** across marts and fact tables.
+4. Add **custom tests** when built-ins aren’t enough.
 
 > **Tip – Source vs. table freshness**
->
-> The demo uses the `freshness` test type on the mart (`mart_orders_agg.last_order_ts`).
-> For *source-level freshness* (e.g. “when was `crm.orders` last loaded?”), define
-> freshness rules on your sources and run:
->
+>
+> The demo uses:
+>
+> * `freshness` tests on tables (`orders.order_ts`), and
+> * `freshness` in `sources.yml` (via `_ff_loaded_at`).
+>
+> Run source freshness with:
+>
 > ```bash
 > fft source freshness examples/dq_demo --env dev_duckdb
 > ```
->
-> This complements table-level DQ tests by checking whether your inputs are recent enough
-> *before* you even build marts.
+>
+> This complements table-level DQ tests by checking whether your inputs are
+> recent enough *before* you even build marts.
diff --git a/docs/examples/Local_Engine_Setup.md b/docs/examples/Local_Engine_Setup.md
index 6ef2fdb..04ce832 100644
--- a/docs/examples/Local_Engine_Setup.md
+++ b/docs/examples/Local_Engine_Setup.md
@@ -172,7 +172,7 @@ The BigQuery client in `fastflowtransform` will pick this up automatically **as
    make ENGINE=bigquery test
    ```
 
-   `fft test` uses the BigQuery shim to run checks like `not_null`, `unique`,
+   `fft test` uses BigQuery to run checks like `not_null`, `unique`,
    `row_count_between`, `greater_equal`, etc. against
    `${FF_BQ_PROJECT}.${FF_BQ_DATASET}.
`. diff --git a/examples/api_demo/sources.yml b/examples/api_demo/sources.yml index 84e04d5..4048b3c 100644 --- a/examples/api_demo/sources.yml +++ b/examples/api_demo/sources.yml @@ -1,4 +1,4 @@ -version: 2 +version: 1 sources: - name: crm diff --git a/examples/basic_demo/sources.yml b/examples/basic_demo/sources.yml index d48deca..2b8db8a 100644 --- a/examples/basic_demo/sources.yml +++ b/examples/basic_demo/sources.yml @@ -1,4 +1,4 @@ -version: 2 +version: 1 sources: - name: crm diff --git a/examples/cache_demo/sources.yml b/examples/cache_demo/sources.yml index 0490edc..174bb70 100644 --- a/examples/cache_demo/sources.yml +++ b/examples/cache_demo/sources.yml @@ -1,4 +1,4 @@ -version: 2 +version: 1 sources: - name: crm schema: cache_demo diff --git a/examples/ci_demo/sources.yml b/examples/ci_demo/sources.yml index 83436dc..fa207c5 100644 --- a/examples/ci_demo/sources.yml +++ b/examples/ci_demo/sources.yml @@ -1,5 +1,5 @@ # Source declarations describe external tables. See docs/Sources.md for details. 
-version: 2 +version: 1 # sources: # Example: # - name: raw diff --git a/examples/dq_demo/contracts.yml b/examples/dq_demo/contracts.yml new file mode 100644 index 0000000..8834b05 --- /dev/null +++ b/examples/dq_demo/contracts.yml @@ -0,0 +1,19 @@ +version: 1 + +defaults: + columns: + - match: + name: ".*_id$" + type: integer + nullable: false + + - match: + name: "created_at" + type: timestamp + nullable: false + + - match: + name: ".*_ts$" + type: timestamp + nullable: true + description: "Timestamp-like but allowed to be null in some pipelines" diff --git a/examples/dq_demo/models/marts/mart_orders_agg.contracts.yml b/examples/dq_demo/models/marts/mart_orders_agg.contracts.yml new file mode 100644 index 0000000..5fa2cba --- /dev/null +++ b/examples/dq_demo/models/marts/mart_orders_agg.contracts.yml @@ -0,0 +1,36 @@ +version: 1 +table: mart_orders_agg + +columns: + customer_id: + type: integer + nullable: false + + customer_name: + type: string + nullable: false + + status: + type: string + nullable: false + enum: + - active + - inactive # keep in sync with customers.status + + order_count: + type: integer + nullable: false + min: 0 + + total_amount: + type: double + nullable: false + min: 0 + + first_order_ts: + type: timestamp + nullable: false + + last_order_ts: + type: timestamp + nullable: false diff --git a/examples/dq_demo/models/staging/customers.contracts.yml b/examples/dq_demo/models/staging/customers.contracts.yml new file mode 100644 index 0000000..2ac46ff --- /dev/null +++ b/examples/dq_demo/models/staging/customers.contracts.yml @@ -0,0 +1,46 @@ +version: 1 +table: customers + +columns: + customer_id: + type: integer + nullable: false + physical: + duckdb: integer + postgres: integer + bigquery: INT64 + snowflake_snowpark: NUMBER + databricks_spark: INT + + name: + type: string + nullable: false + physical: + duckdb: VARCHAR + postgres: text + bigquery: STRING + snowflake_snowpark: TEXT + databricks_spark: STRING + + status: + type: string + 
nullable: false + enum: + - active + - inactive + physical: + duckdb: VARCHAR + postgres: text + bigquery: STRING + snowflake_snowpark: TEXT + databricks_spark: STRING + + created_at: + type: timestamp + nullable: false + physical: + duckdb: TIMESTAMP + postgres: timestamp without time zone + bigquery: TIMESTAMP + snowflake_snowpark: TIMESTAMP_NTZ + databricks_spark: TIMESTAMP diff --git a/examples/dq_demo/models/staging/orders.contracts.yml b/examples/dq_demo/models/staging/orders.contracts.yml new file mode 100644 index 0000000..45d7a7c --- /dev/null +++ b/examples/dq_demo/models/staging/orders.contracts.yml @@ -0,0 +1,12 @@ +version: 1 +table: orders + +columns: + order_id: + type: integer + + customer_id: + type: integer + + order_ts: + type: timestamp diff --git a/examples/dq_demo/project.yml b/examples/dq_demo/project.yml index 49b6f85..8ea9e70 100644 --- a/examples/dq_demo/project.yml +++ b/examples/dq_demo/project.yml @@ -10,21 +10,10 @@ seeds: {} tests: # --- Single-table checks ---------------------------------------------------- - - type: not_null - table: customers - column: customer_id - tags: [example:dq_demo, batch] - - - type: unique - table: customers - column: customer_id - tags: [example:dq_demo, batch] - - - type: accepted_values + - type: row_count_between table: customers - column: status - values: [active, inactive] - severity: warn # demo of warn vs error + min_rows: 1 + max_rows: 100 tags: [example:dq_demo, batch] - type: greater_equal @@ -38,12 +27,6 @@ tests: column: amount tags: [example:dq_demo, batch] - - type: row_count_between - table: customers - min_rows: 1 - max_rows: 100 - tags: [example:dq_demo, batch] - - type: relationships table: orders column: customer_id diff --git a/examples/dq_demo/sources.yml b/examples/dq_demo/sources.yml index 57bc26a..93d49e4 100644 --- a/examples/dq_demo/sources.yml +++ b/examples/dq_demo/sources.yml @@ -1,4 +1,4 @@ -version: 2 +version: 1 sources: - name: crm diff --git 
a/examples/dq_demo/tests/dq/min_positive_share.ff.py b/examples/dq_demo/tests/dq/min_positive_share.ff.py index 47c054d..67da3d1 100644 --- a/examples/dq_demo/tests/dq/min_positive_share.ff.py +++ b/examples/dq_demo/tests/dq/min_positive_share.ff.py @@ -25,7 +25,7 @@ class MinPositiveShareParams(BaseModel): @dq_test("min_positive_share", params_model=MinPositiveShareParams) def min_positive_share( - con: Any, + executor: Any, table: str, column: str | None, params: dict[str, Any], @@ -48,8 +48,8 @@ def min_positive_share( else: pos_sql = f"select count(*) from {table} where {column} > 0" - total = testing._scalar(con, total_sql) - positives = testing._scalar(con, pos_sql) + total = testing._scalar(executor, total_sql) + positives = testing._scalar(executor, pos_sql) example_sql = f"{pos_sql}; -- positives\n{total_sql}; -- total" diff --git a/examples/env_matrix/sources.yml b/examples/env_matrix/sources.yml index cac9e94..16f3749 100644 --- a/examples/env_matrix/sources.yml +++ b/examples/env_matrix/sources.yml @@ -1,4 +1,4 @@ -version: 2 +version: 1 sources: - name: raw diff --git a/examples/hooks_demo/sources.yml b/examples/hooks_demo/sources.yml index 4386923..293a2c5 100644 --- a/examples/hooks_demo/sources.yml +++ b/examples/hooks_demo/sources.yml @@ -1,4 +1,4 @@ -version: 2 +version: 1 sources: - name: raw diff --git a/examples/incremental_demo/sources.yml b/examples/incremental_demo/sources.yml index 6715df2..014c7b6 100644 --- a/examples/incremental_demo/sources.yml +++ b/examples/incremental_demo/sources.yml @@ -1,4 +1,4 @@ -version: 2 +version: 1 sources: - name: raw diff --git a/examples/macros_demo/sources.yml b/examples/macros_demo/sources.yml index d451751..1d0f8af 100644 --- a/examples/macros_demo/sources.yml +++ b/examples/macros_demo/sources.yml @@ -1,4 +1,4 @@ -version: 2 +version: 1 sources: - name: crm diff --git a/examples/materializations_demo/sources.yml b/examples/materializations_demo/sources.yml index 1fa3552..44dbb1d 100644 --- 
a/examples/materializations_demo/sources.yml +++ b/examples/materializations_demo/sources.yml @@ -1,4 +1,4 @@ -version: 2 +version: 1 sources: - name: demo diff --git a/examples/packages_demo/main_project/sources.yml b/examples/packages_demo/main_project/sources.yml index 7866e1e..61d3cca 100644 --- a/examples/packages_demo/main_project/sources.yml +++ b/examples/packages_demo/main_project/sources.yml @@ -1,4 +1,4 @@ -version: 2 +version: 1 sources: - name: crm diff --git a/examples/snapshot_demo/sources.yml b/examples/snapshot_demo/sources.yml index d48deca..2b8db8a 100644 --- a/examples/snapshot_demo/sources.yml +++ b/examples/snapshot_demo/sources.yml @@ -1,4 +1,4 @@ -version: 2 +version: 1 sources: - name: crm diff --git a/examples_article/building_locally_demo/sources.yml b/examples_article/building_locally_demo/sources.yml index 7ef2d9f..1a7cfe3 100644 --- a/examples_article/building_locally_demo/sources.yml +++ b/examples_article/building_locally_demo/sources.yml @@ -1,4 +1,4 @@ -version: 2 +version: 1 sources: - name: raw diff --git a/examples_article/http_cache_demo/sources.yml b/examples_article/http_cache_demo/sources.yml index 83436dc..fa207c5 100644 --- a/examples_article/http_cache_demo/sources.yml +++ b/examples_article/http_cache_demo/sources.yml @@ -1,5 +1,5 @@ # Source declarations describe external tables. See docs/Sources.md for details. 
-version: 2 +version: 1 # sources: # Example: # - name: raw diff --git a/mkdocs.yml b/mkdocs.yml index 8cb8f0d..d2a933d 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -24,6 +24,7 @@ nav: - Technical Overview: Technical_Overview.md - API & Models: Api_Models.md - Configuration & Macros: Config_and_Macros.md + - Contracts: Contracts.md - Cache & Parallelism: Cache_and_Parallelism.md - Incremental Processing: Incremental.md - Profiles & Environments: Profiles.md diff --git a/pyproject.toml b/pyproject.toml index cdf6cbf..74cfd22 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "fastflowtransform" -version = "0.6.11" +version = "0.6.13" description = "Python framework for SQL & Python data transformation, ETL pipelines, and dbt-style data modeling" readme = "README.md" license = { text = "Apache-2.0" } diff --git a/src/fastflowtransform/cli/__init__.py b/src/fastflowtransform/cli/__init__.py index 7f22b7f..88fff42 100644 --- a/src/fastflowtransform/cli/__init__.py +++ b/src/fastflowtransform/cli/__init__.py @@ -8,7 +8,6 @@ from fastflowtransform.cli.bootstrap import ( CLIContext, _die, - _get_test_con, _load_project_and_env, _make_executor, _parse_cli_vars, @@ -170,7 +169,6 @@ def main( "_build_predicates", "_compile_selector", "_die", - "_get_test_con", "_infer_sql_ref_aliases", "_load_project_and_env", "_make_executor", diff --git a/src/fastflowtransform/cli/bootstrap.py b/src/fastflowtransform/cli/bootstrap.py index e7ee61d..ebb4a10 100644 --- a/src/fastflowtransform/cli/bootstrap.py +++ b/src/fastflowtransform/cli/bootstrap.py @@ -16,7 +16,6 @@ from fastflowtransform.config.budgets import BudgetsConfig, load_budgets_config from fastflowtransform.core import REGISTRY from fastflowtransform.errors import DependencyNotFoundError -from fastflowtransform.executors._shims import BigQueryConnShim, SAConnShim from fastflowtransform.executors.base import BaseExecutor from fastflowtransform.logging import echo 
from fastflowtransform.settings import ( @@ -36,7 +35,7 @@ class CLIContext: profile: Profile budgets_cfg: BudgetsConfig | None = None - def make_executor(self) -> tuple[Any, Callable, Callable]: + def make_executor(self) -> tuple[BaseExecutor, Callable, Callable]: executor, run_sql, run_py = _make_executor(self.profile, self.jinja_env) self._configure_budget_limit(executor) return executor, run_sql, run_py @@ -316,30 +315,7 @@ def _parse_cli_vars(pairs: list[str]) -> dict[str, object]: return out -def _get_test_con(executor: Any) -> Any: - """ - Return a connection with .execute(...) that understands sequences and (sql, params). - Reuse shims on the executor or build an appropriate one when needed. - """ - if hasattr(executor, "engine"): - try: - return SAConnShim(executor.engine, schema=getattr(executor, "schema", None)) - except Exception: - pass - if hasattr(executor, "client") and hasattr(executor, "dataset"): - try: - return BigQueryConnShim(executor.client, executor.dataset, executor.location) - except Exception: - try: - return BigQueryConnShim(executor.client, getattr(executor, "location", None)) - except Exception: - pass - if hasattr(executor, "con") and hasattr(executor.con, "execute"): - return executor.con - return executor - - -def _make_executor(prof: Profile, jenv: Environment) -> tuple[Any, Callable, Callable]: +def _make_executor(prof: Profile, jenv: Environment) -> tuple[BaseExecutor, Callable, Callable]: ex: BaseExecutor if prof.engine == "duckdb": DuckExecutor = _import_optional( diff --git a/src/fastflowtransform/cli/init_cmd.py b/src/fastflowtransform/cli/init_cmd.py index e8e3c6d..1cd50f2 100644 --- a/src/fastflowtransform/cli/init_cmd.py +++ b/src/fastflowtransform/cli/init_cmd.py @@ -143,7 +143,7 @@ def _build_sources_yaml() -> str: return "\n".join( [ "# Source declarations describe external tables. 
See docs/Sources.md for details.", - "version: 2", + "version: 1", "# sources:", " # Example:", " # - name: raw", diff --git a/src/fastflowtransform/cli/source_cmd.py b/src/fastflowtransform/cli/source_cmd.py index c626cff..4ac3080 100644 --- a/src/fastflowtransform/cli/source_cmd.py +++ b/src/fastflowtransform/cli/source_cmd.py @@ -3,7 +3,7 @@ import typer -from fastflowtransform.cli.bootstrap import _get_test_con, _prepare_context +from fastflowtransform.cli.bootstrap import _prepare_context from fastflowtransform.cli.options import EngineOpt, EnvOpt, ProjectArg, VarsOpt from fastflowtransform.logging import bind_context, clear_context, echo from fastflowtransform.source_freshness import SourceFreshnessResult, run_source_freshness @@ -31,12 +31,10 @@ def freshness( # Get a live connection / executor from the context execu, _run_sql, _run_py = ctx.make_executor() - con = _get_test_con(execu) # Run freshness checks over all sources with a configured freshness block results: list[SourceFreshnessResult] = run_source_freshness( execu, - con=con, engine=ctx.profile.engine, ) diff --git a/src/fastflowtransform/cli/test_cmd.py b/src/fastflowtransform/cli/test_cmd.py index 3261d22..48b4a82 100644 --- a/src/fastflowtransform/cli/test_cmd.py +++ b/src/fastflowtransform/cli/test_cmd.py @@ -1,7 +1,6 @@ # fastflowtransform/cli/test_cmd.py from __future__ import annotations -import os import re import time from collections.abc import Callable, Iterable, Mapping @@ -11,7 +10,7 @@ import typer -from fastflowtransform.cli.bootstrap import _get_test_con, _prepare_context +from fastflowtransform.cli.bootstrap import _prepare_context from fastflowtransform.cli.options import ( EngineOpt, EnvOpt, @@ -25,9 +24,11 @@ BaseProjectTestConfig, parse_project_yaml_config, ) +from fastflowtransform.contracts import load_contract_tests from fastflowtransform.core import REGISTRY from fastflowtransform.dag import topo_sort from fastflowtransform.errors import ModelExecutionError +from 
fastflowtransform.executors.base import BaseExecutor from fastflowtransform.logging import echo from fastflowtransform.schema_loader import Severity, TestSpec, load_schema_tests from fastflowtransform.testing.discovery import ( @@ -111,11 +112,6 @@ def _execute_models( on_error(name, node, exc) -def _maybe_print_marker(con: Any) -> None: - if os.getenv("FFT_SQL_DEBUG") == "1": - echo(getattr(con, "marker", "NO_SHIM")) - - def _run_models( pred: Callable[[Any], bool], run_sql: Callable[[Any], Any], @@ -342,7 +338,7 @@ def _prepare_test( return _prepare_test_from_mapping(raw_test, executor) -def _run_dq_tests(con: Any, tests: Iterable[Any], executor: Any) -> list[DQResult]: +def _run_dq_tests(executor: BaseExecutor, tests: Iterable[Any]) -> list[DQResult]: results: list[DQResult] = [] for raw_test in tests: @@ -381,7 +377,7 @@ def _run_dq_tests(con: Any, tests: Iterable[Any], executor: Any) -> list[DQResul ) continue - ok, msg, example = runner(con, table_for_exec, col, params) + ok, msg, example = runner(executor, table_for_exec, col, params) ms = int((time.perf_counter() - t0) * 1000) param_str = _format_params_for_summary(kind, params) @@ -476,9 +472,6 @@ def test( legacy_tag_only = _is_legacy_test_token(tokens) and not has_model_matches execu, run_sql, run_py = ctx.make_executor() - con = _get_test_con(execu) - _maybe_print_marker(con) - model_pred = (lambda _n: True) if legacy_tag_only else pred # Run models; if a model fails, show friendly error then exit(1). 
if not skip_build: @@ -492,13 +485,15 @@ def test( tests: list[Any] = _load_tests(ctx.project) # 2) schema YAML tests tests.extend(load_schema_tests(ctx.project)) + # 2b) contracts tests (contracts/*.contracts.yml) + tests.extend(load_contract_tests(ctx.project)) # 3) optional legacy tagfilter (e.g., "batch") tests = _apply_legacy_tag_filter(tests, tokens, legacy_token=legacy_tag_only) if not tests: typer.secho("No tests configured.", fg="bright_black") raise typer.Exit(code=0) - results = _run_dq_tests(con, tests, execu) + results = _run_dq_tests(execu, tests) _print_summary(results) # Exit code: count only ERROR fails diff --git a/src/fastflowtransform/config/contracts.py b/src/fastflowtransform/config/contracts.py new file mode 100644 index 0000000..cf34780 --- /dev/null +++ b/src/fastflowtransform/config/contracts.py @@ -0,0 +1,328 @@ +# fastflowtransform/config/contracts.py +from __future__ import annotations + +from pathlib import Path +from typing import Any + +import yaml +from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator + +from fastflowtransform.config.loaders import NoDupLoader +from fastflowtransform.errors import ContractsConfigError + + +class PhysicalTypeConfig(BaseModel): + """ + Engine-specific physical type configuration for a column. + + All fields are optional; you can set: + - default: applies to all engines if no engine-specific override is set + - duckdb, postgres, bigquery, snowflake_snowpark, databricks_spark: + engine-specific physical types (e.g. 
"integer", "NUMERIC", "TIMESTAMP") + + Example YAML: + physical: "integer" + + physical: + default: numeric + postgres: numeric + bigquery: NUMERIC + """ + + model_config = ConfigDict(extra="forbid") + + default: str | None = None + duckdb: str | None = None + postgres: str | None = None + bigquery: str | None = None + snowflake_snowpark: str | None = None + databricks_spark: str | None = None + + +class ColumnContractModel(BaseModel): + """ + Column-level contract definition. + + Example YAML fragment: + + columns: + id: + type: integer + nullable: false + status: + type: string + enum: ["active", "inactive"] + amount: + type: double + nullable: false + min: 0 + max: 10000 + email: + type: string + regex: "^[^@]+@[^@]+$" + """ + + model_config = ConfigDict(extra="forbid") + + # Optional semantic / physical type hint ("integer", "string", "timestamp", ...) + type: str | None = None + + # Engine-specific physical DB types; see PhysicalTypeConfig. + physical: PhysicalTypeConfig | None = None + + # Nullability: nullable=False → not_null check + nullable: bool | None = None + + # Uniqueness: unique=True → unique test + unique: bool | None = None + + # Enumerated allowed values (accepted_values test) + enum: list[Any] | None = None + + # Regex constraint; currently used via a generic regex_match test + regex: str | None = None + + # Numeric range (inclusive) for numeric-like columns + min: float | int | None = None + max: float | int | None = None + + # Optional free-form description (handy for docs later) + description: str | None = None + + @field_validator("enum", mode="before") + @classmethod + def _normalize_enum(cls, v: Any) -> list[Any] | None: + """ + Allow: + enum: "A" -> ["A"] + enum: [1, 2, 3] -> [1, 2, 3] + """ + if v is None: + return None + if isinstance(v, (list, tuple)): + return list(v) + return [v] + + @field_validator("physical", mode="before") + @classmethod + def _coerce_physical(cls, v: Any) -> Any: + """ + Accept either: + physical: "integer" + 
physical: + default: numeric + postgres: numeric + bigquery: NUMERIC + and normalize to a PhysicalTypeConfig-compatible dict. + """ + if v is None: + return None + if isinstance(v, PhysicalTypeConfig): + return v + if isinstance(v, str): + # Shorthand: same type for all engines → default + return {"default": v} + if isinstance(v, dict): + # Let Pydantic validate keys; we just pass through. + return v + raise TypeError( + "physical must be either a string or a mapping of engine keys to types " + "(e.g. {default: numeric, postgres: numeric})" + ) + + +class ContractsFileModel(BaseModel): + """ + One contracts file. + + Convention: + - One file describes contracts for exactly one table/relation. + - The table name is what will be used in DQ tests (SELECT ... FROM
). + + Example `*.contracts.yml`: + + version: 1 + table: users_enriched + columns: + id: + type: integer + nullable: false + status: + type: string + enum: ["active", "inactive"] + email: + type: string + nullable: false + regex: "^[^@]+@[^@]+$" + """ + + model_config = ConfigDict(extra="forbid") + + version: int = 1 + table: str = Field(..., description="Logical/physical table name the contract applies to") + columns: dict[str, ColumnContractModel] = Field(default_factory=dict) + + +# --------------------------------------------------------------------------- +# Project-level contracts (contracts.yml at project root) +# --------------------------------------------------------------------------- + + +class ColumnMatchModel(BaseModel): + """ + Column match expression for project-level defaults. + + Currently supports: + - name: regex on column name (required) + - table: optional regex on table name (future-proof; optional) + """ + + model_config = ConfigDict(extra="forbid") + + name: str = Field(..., description="Regex to match column name") + table: str | None = Field( + default=None, description="Optional regex to restrict to specific tables" + ) + + @model_validator(mode="after") + def _strip(self) -> ColumnMatchModel: + object.__setattr__(self, "name", self.name.strip()) + if self.table is not None: + object.__setattr__(self, "table", self.table.strip()) + return self + + +class ColumnDefaultsRuleModel(BaseModel): + """ + One rule under defaults.columns in contracts.yml. 
+ + Example: + + defaults: + columns: + - match: + name: ".*_id$" + type: integer + nullable: false + - match: + name: "created_at" + type: timestamp + nullable: false + """ + + model_config = ConfigDict(extra="forbid") + + match: ColumnMatchModel + # Payload is the same shape as ColumnContractModel but optional: + type: str | None = None + physical: PhysicalTypeConfig | None = None + nullable: bool | None = None + unique: bool | None = None + enum: list[Any] | None = None + regex: str | None = None + min: float | None = None + max: float | None = None + description: str | None = None + + @field_validator("enum", mode="before") + @classmethod + def _normalize_enum(cls, v: Any) -> list[Any] | None: + if v is None: + return None + if isinstance(v, (list, tuple)): + return list(v) + return [v] + + @field_validator("physical", mode="before") + @classmethod + def _coerce_physical(cls, v: Any) -> Any: + if v is None: + return None + if isinstance(v, PhysicalTypeConfig): + return v + if isinstance(v, str): + return {"default": v} + if isinstance(v, dict): + return v + raise TypeError( + "defaults.columns[*].physical must be either a string or a mapping of engine " + "keys to types (e.g. {default: numeric, postgres: numeric})" + ) + + +class ContractsDefaultsModel(BaseModel): + """ + Root defaults block for project-level contracts.yml. + + Example: + + version: 1 + + defaults: + models: + - match: + name: "staging.*" + materialized: table + + columns: + - match: + name: ".*_id$" + type: integer + nullable: false + - match: + name: "created_at" + type: timestamp + nullable: false + """ + + model_config = ConfigDict(extra="forbid") + + # Future global defaults (e.g. a default severity for contract tests) could live here. + columns: list[ColumnDefaultsRuleModel] = Field(default_factory=list) + + +class ProjectContractsModel(BaseModel): + """ + Top-level model for project-level contracts.yml. 
+ + Only defines defaults, no table-specific contracts (those live in + per-table *.contracts.yml files). + """ + + model_config = ConfigDict(extra="forbid") + + version: int = 1 + defaults: ContractsDefaultsModel = Field(default_factory=ContractsDefaultsModel) + + +# ---- Parsers ----------------------------------------------------------------- + + +def parse_contracts_file(path: Path) -> ContractsFileModel: + """ + Load and validate a single *.contracts.yml file. + Raises a Pydantic validation error or yaml.YAMLError on malformed input. + """ + try: + raw = yaml.load(path.read_text(encoding="utf-8"), Loader=NoDupLoader) or {} + return ContractsFileModel.model_validate(raw) + except Exception as exc: + hint = "Check the contracts YAML for duplicate keys or invalid structure." + raise ContractsConfigError( + f"Failed to parse contracts file: {exc}", path=str(path), hint=hint + ) from exc + + +def parse_project_contracts_file(path: Path) -> ProjectContractsModel: + """ + Load and validate the project-level contracts.yml file. + Returns ProjectContractsModel, raising on malformed input. + """ + try: + raw = yaml.load(path.read_text(encoding="utf-8"), Loader=NoDupLoader) or {} + return ProjectContractsModel.model_validate(raw) + except Exception as exc: + hint = "Check the project-level contracts.yml for duplicate keys or invalid structure." 
+ raise ContractsConfigError( + f"Failed to parse project contracts file: {exc}", path=str(path), hint=hint + ) from exc diff --git a/src/fastflowtransform/config/loaders.py b/src/fastflowtransform/config/loaders.py new file mode 100644 index 0000000..7032b88 --- /dev/null +++ b/src/fastflowtransform/config/loaders.py @@ -0,0 +1,19 @@ +import yaml +from yaml.loader import SafeLoader + + +class NoDupLoader(SafeLoader): + pass + + +def _construct_mapping(loader, node, deep=False): + mapping = {} + for key_node, value_node in node.value: + key = loader.construct_object(key_node, deep=deep) + if key in mapping: + raise ValueError(f"Duplicate key {key!r} in {node.start_mark}") + mapping[key] = loader.construct_object(value_node, deep=deep) + return mapping + + +NoDupLoader.add_constructor(yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, _construct_mapping) diff --git a/src/fastflowtransform/config/sources.py b/src/fastflowtransform/config/sources.py index dce68da..ee02936 100644 --- a/src/fastflowtransform/config/sources.py +++ b/src/fastflowtransform/config/sources.py @@ -305,7 +305,7 @@ class SourcesFileConfig(BaseModel): model_config = ConfigDict(extra="forbid") - version: Literal[2] + version: Literal[1] sources: list[SourceGroupConfig] = Field(default_factory=list) diff --git a/src/fastflowtransform/contracts.py b/src/fastflowtransform/contracts.py new file mode 100644 index 0000000..bd2c4a7 --- /dev/null +++ b/src/fastflowtransform/contracts.py @@ -0,0 +1,300 @@ +# fastflowtransform/contracts.py +from __future__ import annotations + +import re +from pathlib import Path +from typing import Any + +from fastflowtransform.config.contracts import ( + ColumnContractModel, + ContractsDefaultsModel, + ContractsFileModel, + ProjectContractsModel, + parse_contracts_file, + parse_project_contracts_file, +) +from fastflowtransform.logging import get_logger +from fastflowtransform.schema_loader import Severity, TestSpec + +logger = get_logger("contracts") + + +# 
--------------------------------------------------------------------------- +# Discovery helpers +# --------------------------------------------------------------------------- + + +def _discover_contract_paths(project_dir: Path) -> list[Path]: + """ + Discover *.contracts.yml files under models/. + + Convention: + - You can place contracts anywhere under models/, as long as the file + name ends with ".contracts.yml". + - Each file describes contracts for one logical table (ContractsFileModel.table). + """ + models_dir = project_dir / "models" + if not models_dir.exists(): + return [] + + paths: list[Path] = [] + for p in models_dir.rglob("*.contracts.yml"): + if p.is_file(): + paths.append(p) + return sorted(paths) + + +def load_contracts(project_dir: Path) -> dict[str, ContractsFileModel]: + """ + Load all per-table contracts from *.contracts.yml under models/. + + Returns: + dict[table_name, ContractsFileModel] + If multiple files define the same `table`, the last one wins (with a warning). + """ + contracts: dict[str, ContractsFileModel] = {} + for path in _discover_contract_paths(project_dir): + cfg = parse_contracts_file(path) + + table = cfg.table + if table in contracts: + logger.warning( + "Multiple contracts for table %r: overriding previous definition with %s", + table, + path, + ) + contracts[table] = cfg + + return contracts + + +def _load_project_contracts(project_dir: Path) -> ProjectContractsModel | None: + """ + Load project-level contracts.yml (if present). + + The file is optional; if it does not exist, None is returned. 
+ """ + path = project_dir / "contracts.yml" + if not path.exists(): + return None + + cfg = parse_project_contracts_file(path) + + return cfg + + +# --------------------------------------------------------------------------- +# Column defaults application +# --------------------------------------------------------------------------- + + +def _apply_column_defaults( + col_name: str, + table: str, + col: ColumnContractModel, + defaults: ContractsDefaultsModel | None, +) -> ColumnContractModel: + """ + Merge project-level column defaults into a column contract. + + Rules: + - We only consider defaults.columns rules where the regex on + name matches `col_name` *and* optional table regex matches `table`. + - Rules are applied in file order; later rules override earlier ones. + - Per-table contracts take precedence: we only fill attributes that are + still None on the ColumnContractModel. + """ + if defaults is None or not defaults.columns: + return col + + # Start from the explicit per-table column config (already validated) + data: dict[str, Any] = col.model_dump() + + for rule in defaults.columns: + m = rule.match + # name regex is required + if not re.search(m.name, col_name): + continue + # optional table regex + if m.table and not re.search(m.table, table): + continue + + # For each field, only apply if current value is None and rule defines a value + if data.get("type") is None and rule.type is not None: + data["type"] = rule.type + if data.get("physical") is None and rule.physical is not None: + data["physical"] = rule.physical + if data.get("nullable") is None and rule.nullable is not None: + data["nullable"] = rule.nullable + if data.get("unique") is None and rule.unique is not None: + data["unique"] = rule.unique + if data.get("enum") is None and rule.enum is not None: + data["enum"] = list(rule.enum) + if data.get("regex") is None and rule.regex is not None: + data["regex"] = rule.regex + if data.get("min") is None and rule.min is not None: + data["min"] = 
rule.min + if data.get("max") is None and rule.max is not None: + data["max"] = rule.max + if data.get("description") is None and rule.description is not None: + data["description"] = rule.description + + # Re-validate into a ColumnContractModel (cheap; keeps invariants) + return ColumnContractModel.model_validate(data) + + +# --------------------------------------------------------------------------- +# DQ test expansion from contracts +# --------------------------------------------------------------------------- + + +def _contract_tests_for_table( + table: str, + contract: ContractsFileModel, + *, + defaults: ContractsDefaultsModel | None, + default_severity: Severity = "error", +) -> list[TestSpec]: + """ + Convert column contracts for a single table into TestSpec instances, taking + project-level column defaults into account. + """ + specs: list[TestSpec] = [] + + # Base tags shared by all contract-derived tests. You can always add more + # tags at the project level later if needed. 
+ base_tags: list[str] = ["contract"] + + for col_name, col in contract.columns.items(): + effective_col = _apply_column_defaults(col_name, table, col, defaults) + + # 0) Physical type assertion → column_physical_type test + if effective_col.physical is not None: + specs.append( + TestSpec( + type="column_physical_type", + table=table, + column=col_name, + params={"physical": effective_col.physical}, + severity=default_severity, + tags=list(base_tags), + ) + ) + + # 1) Nullability: nullable=False → not_null test + if effective_col.nullable is False: + specs.append( + TestSpec( + type="not_null", + table=table, + column=col_name, + params={}, # no extra params + severity=default_severity, + tags=list(base_tags), + ) + ) + + # 1b) Uniqueness → unique test + if effective_col.unique: + specs.append( + TestSpec( + type="unique", + table=table, + column=col_name, + params={}, + severity=default_severity, + tags=list(base_tags), + ) + ) + + # 2) Enumerated values → accepted_values test (if any values declared) + if effective_col.enum: + specs.append( + TestSpec( + type="accepted_values", + table=table, + column=col_name, + params={"values": list(effective_col.enum)}, + severity=default_severity, + tags=list(base_tags), + ) + ) + + # 3) Numeric range (inclusive) → between test + if effective_col.min is not None or effective_col.max is not None: + params: dict[str, Any] = {} + if effective_col.min is not None: + params["min"] = effective_col.min + if effective_col.max is not None: + params["max"] = effective_col.max + + specs.append( + TestSpec( + type="between", + table=table, + column=col_name, + params=params, + severity=default_severity, + tags=list(base_tags), + ) + ) + + # 4) Regex constraint → regex_match test (Python-side evaluation) + if effective_col.regex: + specs.append( + TestSpec( + type="regex_match", + table=table, + column=col_name, + params={"pattern": effective_col.regex}, + severity=default_severity, + tags=list(base_tags), + ) + ) + + return specs + + 
+def build_contract_tests( + contracts: dict[str, ContractsFileModel], + *, + defaults: ContractsDefaultsModel | None = None, + default_severity: Severity = "error", +) -> list[TestSpec]: + """ + Convert a set of ContractsFileModel objects into a flat list of TestSpec. + + `defaults` is the (optional) project-level defaults section from contracts.yml. + """ + if not contracts: + return [] + + all_specs: list[TestSpec] = [] + for table, cfg in contracts.items(): + all_specs.extend( + _contract_tests_for_table( + table, + cfg, + defaults=defaults, + default_severity=default_severity, + ) + ) + return all_specs + + +def load_contract_tests(project_dir: Path) -> list[TestSpec]: + """ + High-level helper used by the CLI: + + project_dir -> [TestSpec, ...] + + This is what we plug into `fft test` so contracts become "first-class" tests. + """ + contracts = load_contracts(project_dir) + if not contracts: + return [] + + project_cfg = _load_project_contracts(project_dir) + defaults = project_cfg.defaults if project_cfg is not None else None + + return build_contract_tests(contracts, defaults=defaults) diff --git a/src/fastflowtransform/decorators.py b/src/fastflowtransform/decorators.py index 1346e98..fe974af 100644 --- a/src/fastflowtransform/decorators.py +++ b/src/fastflowtransform/decorators.py @@ -195,14 +195,14 @@ def dq_test( from fastflowtransform import dq_test @dq_test("email_domain_allowed") - def email_domain_allowed(con, table, column, params): + def email_domain_allowed(executor, table, column, params): ... return True, None, "select ..." If `name` is omitted, the function name is used: @dq_test() - def email_sanity(con, table, column, params): + def email_sanity(executor, table, column, params): ... 
# In project.yml / schema.yml: type: email_sanity @@ -213,7 +213,7 @@ class EmailTestParams(DQParamsBase): allowed_domains: list[str] @dq_test("email_domain_allowed", params_model=EmailTestParams) - def email_domain_allowed(con, table, column, params: EmailTestParams): + def email_domain_allowed(executor, table, column, params: EmailTestParams): ... Args: diff --git a/src/fastflowtransform/errors.py b/src/fastflowtransform/errors.py index e5deccb..a8e7394 100644 --- a/src/fastflowtransform/errors.py +++ b/src/fastflowtransform/errors.py @@ -114,6 +114,24 @@ def __init__(self, message: str): super().__init__(message.replace("\n", " ").strip()) +class ContractsConfigError(FastFlowTransformError): + """ + Raised when a contracts.yml (project-level or per-table) is malformed. + """ + + def __init__( + self, + message: str, + *, + path: str | None = None, + hint: str | None = None, + code: str = "CONTRACTS_PARSE", + ): + prefix = f"{path}: " if path else "" + super().__init__(prefix + message, code=code, hint=hint) + self.path = path + + class ModelExecutionError(Exception): """Raised when a model fails to execute/render on the engine. Carries friendly context for CLI formatting. diff --git a/src/fastflowtransform/executors/_shims.py b/src/fastflowtransform/executors/_shims.py deleted file mode 100644 index f0a0d84..0000000 --- a/src/fastflowtransform/executors/_shims.py +++ /dev/null @@ -1,142 +0,0 @@ -# fastflowtransform/executors/_shims.py -from __future__ import annotations - -import re -from collections.abc import Iterable, Sequence -from typing import Any - -from sqlalchemy import text -from sqlalchemy.engine import Engine -from sqlalchemy.sql.elements import ClauseElement - -from fastflowtransform.typing import Client - - -class BigQueryConnShim: - """ - Lightweight shim so fastflowtransform.testing can call executor.con.execute(...) - against BigQuery clients. 
- """ - - marker = "BQ_SHIM" - - def __init__( - self, - client: Client, - location: str | None = None, - project: str | None = None, - dataset: str | None = None, - ): - self.client = client - self.location = location - self.project = project - self.dataset = dataset - - class _ResultWrapper: - """ - Minimal wrapper around a BigQuery RowIterator so that testing helpers - can call .fetchone() like on a DB-API cursor. - """ - - def __init__(self, row_iter: Any): - self._iter = iter(row_iter) - - def fetchone(self): - try: - return next(self._iter) - except StopIteration: - return None - - def execute(self, sql_or_stmts: Any) -> Any: - if isinstance(sql_or_stmts, str): - # Execute the query and return a cursor-like wrapper with .fetchone() - job = self.client.query(sql_or_stmts) - rows = job.result() - return BigQueryConnShim._ResultWrapper(rows) - - if isinstance(sql_or_stmts, Sequence) and not isinstance( - sql_or_stmts, (bytes, bytearray, str) - ): - # Execute a sequence of statements; return wrapper for the last result. - last_rows: Any = None - for stmt in sql_or_stmts: - job = self.client.query(str(stmt)) - last_rows = job.result() - return BigQueryConnShim._ResultWrapper(last_rows or []) - raise TypeError(f"Unsupported sql argument type for BigQuery shim: {type(sql_or_stmts)}") - - -_RE_PG_COR_TABLE = re.compile( - r"""^\s*create\s+or\s+replace\s+table\s+ - (?P(?:"[^"]+"|\w+)(?:\.(?:"[^"]+"|\w+))?) # optional schema + ident - \s+as\s+(?P.*)$ - """, - re.IGNORECASE | re.DOTALL | re.VERBOSE, -) - - -def _rewrite_pg_create_or_replace_table(sql: str) -> str: - """ - Rewrite 'CREATE OR REPLACE TABLE t AS ' into - 'DROP TABLE IF EXISTS t CASCADE; CREATE TABLE t AS ' for Postgres. - Leave all other SQL untouched. 
- """ - m = _RE_PG_COR_TABLE.match(sql or "") - if not m: - return sql - ident = m.group("ident").strip() - body = m.group("body").strip().rstrip(";\n\t ") - return f"DROP TABLE IF EXISTS {ident} CASCADE; CREATE TABLE {ident} AS {body}" - - -class SAConnShim: - """ - Compatibility layer so fastflowtransform.testing can call executor.con.execute(...) - against SQLAlchemy engines (Postgres, etc.). Adds PG-safe DDL rewrites. - """ - - marker = "PG_SHIM" - - def __init__(self, engine: Engine, schema: str | None = None): - self._engine = engine - self._schema = schema - - def _exec_one(self, conn: Any, stmt: Any, params: dict | None = None) -> Any: - # tuple (sql, params) - statement_len = 2 - if ( - isinstance(stmt, tuple) - and len(stmt) == statement_len - and isinstance(stmt[0], str) - and isinstance(stmt[1], dict) - ): - return self._exec_one(conn, stmt[0], stmt[1]) - - # sqlalchemy expression - if isinstance(stmt, ClauseElement): - return conn.execute(stmt) - - # plain string (apply rewrite, then possibly split into multiple statements) - if isinstance(stmt, str): - rewritten = _rewrite_pg_create_or_replace_table(stmt) - parts = [p.strip() for p in rewritten.split(";") if p.strip()] - res = None - for i, part in enumerate(parts): - res = conn.execute(text(part), params if (i == len(parts) - 1) else None) - return res - - # iterable of statements -> sequential execution - if isinstance(stmt, Iterable) and not isinstance(stmt, (bytes, bytearray, str)): - res = None - for s in stmt: - res = self._exec_one(conn, s) - return res - - # fallback - return self._exec_one(conn, str(stmt)) - - def execute(self, sql: Any) -> Any: - with self._engine.begin() as conn: - if self._schema: - conn.execute(text(f'SET LOCAL search_path = "{self._schema}"')) - return self._exec_one(conn, sql) diff --git a/src/fastflowtransform/executors/_test_utils.py b/src/fastflowtransform/executors/_test_utils.py new file mode 100644 index 0000000..a372e8b --- /dev/null +++ 
b/src/fastflowtransform/executors/_test_utils.py @@ -0,0 +1,82 @@ +# fastflowtransform/executors/_test_utils.py +from __future__ import annotations + +from collections.abc import Iterable, Sequence +from typing import Any + + +class _ListFetchWrapper: + """ + Minimal fetch wrapper for iterable results to mimic DB-API/SQLA cursors. + """ + + def __init__(self, rows: Iterable[Any] | None): + self._rows = list(rows or []) + self._idx = 0 + + def fetchone(self) -> Any: + if self._idx >= len(self._rows): + return None + row = self._rows[self._idx] + self._idx += 1 + return row + + def fetchall(self) -> list[Any]: + if self._idx == 0: + self._idx = len(self._rows) + return list(self._rows) + rows = self._rows[self._idx :] + self._idx = len(self._rows) + return rows + + +def make_fetchable(result: Any) -> Any: + """ + Ensure a result supports .fetchone()/.fetchall(). + + - If already fetchable, return as-is. + - If it has .result() (e.g., BigQuery QueryJob), use that first. + - If it's an iterable (not string/bytes), wrap in a simple fetch wrapper. + - Otherwise return unchanged (may still fail if caller expects fetch methods). + """ + if hasattr(result, "fetchone") and hasattr(result, "fetchall"): + return result + if hasattr(result, "result") and callable(result.result): + try: + return make_fetchable(result.result()) + except Exception: + raise + if isinstance(result, Iterable) and not isinstance(result, (str, bytes, bytearray)): + return _ListFetchWrapper(result) + return result + + +def rows_to_tuples(rows: Iterable[Any] | None) -> list[tuple[Any, ...]]: + """ + Normalize various row shapes to simple tuples. 
+ + Supported shapes: + - tuple -> returned as-is + - mapping-like via .asDict() -> tuple(values) + - general Sequence (excluding str/bytes) -> tuple(row) + - fallback: (row,) + """ + + def _one(row: Any) -> tuple[Any, ...]: + if isinstance(row, tuple): + return row + if hasattr(row, "asDict"): + try: + d = row.asDict() + if isinstance(d, dict): + return tuple(d.values()) + except Exception: + pass + if isinstance(row, Sequence) and not isinstance(row, (str, bytes, bytearray)): + try: + return tuple(row) + except Exception: + pass + return (row,) + + return [_one(r) for r in (rows or [])] diff --git a/src/fastflowtransform/executors/base.py b/src/fastflowtransform/executors/base.py index 4a33461..64f63df 100644 --- a/src/fastflowtransform/executors/base.py +++ b/src/fastflowtransform/executors/base.py @@ -84,6 +84,12 @@ def _load_callable(path: str) -> Callable[..., Any]: return fn +def _scalar(executor: BaseExecutor, sql: Any) -> Any: + """Execute SQL and return the first column of the first row (or None).""" + row = executor.execute_test_sql(sql).fetchone() + return None if row is None else row[0] + + # Frame type (pandas.DataFrame, pyspark.sql.DataFrame, snowflake.snowpark.DataFrame, ...) TFrame = TypeVar("TFrame") @@ -433,6 +439,28 @@ def _execute_sql( """ raise NotImplementedError + def execute_test_sql(self, stmt: Any) -> Any: # pragma: no cover - abstract + """ + Execute a lightweight SQL statement for DQ tests. + + Implementations should accept: + - str + - (str, params dict) + - ClauseElement (optional, where supported) + - Sequence of the above (executed sequentially; return last result) + and return an object supporting .fetchone() / .fetchall(). + """ + raise NotImplementedError + + def compute_freshness_delay_minutes(self, table: str, ts_col: str) -> tuple[float | None, str]: + """ + Compute delay in minutes between now and max(ts_col) for a relation. + + Returns (delay_minutes, sql_used). 
+ Default implementation is not provided; executors implement engine-specific logic. + """ + raise NotImplementedError + def _render_ephemeral_sql(self, name: str, env: Environment) -> str: """ Render the SQL for an 'ephemeral' model and return it as a parenthesized @@ -1090,6 +1118,20 @@ def utest_clean_target(self, relation: str) -> None: """ return + # ── Column schema introspection hook ──────────────────────────────── + def introspect_column_physical_type(self, table: str, column: str) -> str | None: + """ + Return the engine's physical data type for `table.column`, or None + if it cannot be determined. + + Subclasses should override this. Default implementation raises so + callers can surface a clear "engine not supported" message. + """ + raise NotImplementedError( + f"Column physical type introspection is not implemented for " + f"engine '{self.engine_name}'." + ) + ENGINE_NAME: str = "generic" @property diff --git a/src/fastflowtransform/executors/bigquery/base.py b/src/fastflowtransform/executors/bigquery/base.py index e9729a5..5d997a2 100644 --- a/src/fastflowtransform/executors/bigquery/base.py +++ b/src/fastflowtransform/executors/bigquery/base.py @@ -1,12 +1,13 @@ # fastflowtransform/executors/bigquery/base.py from __future__ import annotations -from typing import TypeVar +from collections.abc import Iterable +from typing import Any, TypeVar from fastflowtransform.core import Node, relation_for from fastflowtransform.executors._budget_runner import run_sql_with_budget -from fastflowtransform.executors._shims import BigQueryConnShim from fastflowtransform.executors._snapshot_sql_mixin import SnapshotSqlMixin +from fastflowtransform.executors._test_utils import make_fetchable from fastflowtransform.executors.base import BaseExecutor from fastflowtransform.executors.bigquery._bigquery_mixin import BigQueryIdentifierMixin from fastflowtransform.executors.budget import BudgetGuard @@ -54,13 +55,62 @@ def __init__( project=self.project, 
location=self.location, ) - # Testing-API: con.execute(...) - self.con = BigQueryConnShim( - self.client, - location=self.location, - project=self.project, - dataset=self.dataset, + + def execute_test_sql(self, stmt: Any) -> Any: + """ + Execute lightweight SQL for DQ tests using the BigQuery client. + """ + + def _infer_param_type(value: Any) -> str: + if isinstance(value, bool): + return "BOOL" + if isinstance(value, int) and not isinstance(value, bool): + return "INT64" + if isinstance(value, float): + return "FLOAT64" + return "STRING" + + def _run_job(sql: str, params: dict[str, Any] | None = None) -> Any: + job_config = bigquery.QueryJobConfig() + if self.dataset: + job_config.default_dataset = bigquery.DatasetReference(self.project, self.dataset) + if params: + job_config.query_parameters = [ + bigquery.ScalarQueryParameter(k, _infer_param_type(v), v) + for k, v in params.items() + ] + return self.client.query(sql, job_config=job_config, location=self.location) + + def _run_one(s: Any) -> Any: + statement_len = 2 + if ( + isinstance(s, tuple) + and len(s) == statement_len + and isinstance(s[0], str) + and isinstance(s[1], dict) + ): + return _run_job(s[0], s[1]).result() + if isinstance(s, str): + # Use guarded execution path for simple statements + return self._execute_sql(s).result() + if isinstance(s, Iterable) and not isinstance(s, (bytes, bytearray, str)): + res = None + for item in s: + res = _run_one(item) + return res + return _run_job(str(s)).result() + + return make_fetchable(_run_one(stmt)) + + def compute_freshness_delay_minutes(self, table: str, ts_col: str) -> tuple[float | None, str]: + sql = ( + f"select cast(TIMESTAMP_DIFF(CURRENT_TIMESTAMP(), max({ts_col}), MINUTE) as float64) " + f"as delay_min from {table}" ) + res = self.execute_test_sql(sql) + delay = getattr(res, "fetchone", lambda: None)() + val = delay[0] if delay else None + return (float(val) if val is not None else None, sql) def _execute_sql(self, sql: str) -> _TrackedQueryJob: 
""" @@ -333,3 +383,54 @@ def execute_hook_sql(self, sql: str) -> None: Execute one SQL statement for pre/post/on_run hooks. """ self._execute_sql(sql).result() + + def introspect_column_physical_type(self, table: str, column: str) -> str | None: + """ + BigQuery: read DATA_TYPE from INFORMATION_SCHEMA.COLUMNS, handling qualified names. + """ + project = self.project + dataset = self.dataset + table_name = table + + parts = table.split(".") + if len(parts) == 3: + project, dataset, table_name = parts + elif len(parts) == 2: + dataset, table_name = parts + + table_name = table_name.strip("`") + dataset = dataset.strip("`") if dataset else dataset + project = project.strip("`") if project else project + + if not table_name: + return None + + sql = """ + select data_type + from `{catalog}.{schema}.INFORMATION_SCHEMA.COLUMNS` + where lower(table_name) = lower(@t) + and lower(column_name) = lower(@c) + limit 1 + """ + sql = sql.format( + catalog=project or self.project, + schema=dataset or self.dataset, + ) + + job = self.client.query( + sql, + job_config=bigquery.QueryJobConfig( + query_parameters=[ + bigquery.ScalarQueryParameter("t", "STRING", table_name), + bigquery.ScalarQueryParameter("c", "STRING", column), + ], + default_dataset=bigquery.DatasetReference( + project or self.project, dataset or self.dataset + ), + ), + location=self.location, + ) + rows = list(job.result()) + if not rows: + return None + return rows[0][0] diff --git a/src/fastflowtransform/executors/databricks_spark.py b/src/fastflowtransform/executors/databricks_spark.py index 3261820..2e527ce 100644 --- a/src/fastflowtransform/executors/databricks_spark.py +++ b/src/fastflowtransform/executors/databricks_spark.py @@ -21,6 +21,7 @@ get_spark_functions, get_spark_window, ) +from fastflowtransform.executors._test_utils import make_fetchable, rows_to_tuples from fastflowtransform.executors.base import BaseExecutor from fastflowtransform.executors.budget import BudgetGuard from 
fastflowtransform.logging import echo, echo_debug @@ -269,8 +270,6 @@ def __init__( builder = builder.config(catalog_key, _DELTA_CATALOG) self.spark = self._user_spark or builder.getOrCreate() - # Lightweight testing shim so tests can call executor.con.execute("SQL") - self.con = _SparkConnShim(self.spark) self._registered_path_sources: dict[str, dict[str, Any]] = {} self.warehouse_dir = warehouse_path self.catalog = catalog @@ -427,6 +426,33 @@ def _estimate_query_bytes(self, sql: str) -> int | None: """ return self._spark_plan_bytes(sql) + def execute_test_sql(self, stmt: Any) -> Any: + """ + Execute lightweight SQL for DQ tests via Spark and return fetchable rows. + """ + + def _run_one(s: Any) -> Any: + if isinstance(s, str): + return rows_to_tuples(self.spark.sql(s).collect()) + if isinstance(s, Iterable) and not isinstance(s, (bytes, bytearray, str)): + res = None + for item in s: + res = _run_one(item) + return res + return rows_to_tuples(self.spark.sql(str(s)).collect()) + + return make_fetchable(_run_one(stmt)) + + def compute_freshness_delay_minutes(self, table: str, ts_col: str) -> tuple[float | None, str]: + sql = ( + f"select (unix_timestamp(current_timestamp()) - unix_timestamp(max({ts_col}))) / 60.0 " + f"as delay_min from {table}" + ) + res = self.execute_test_sql(sql) + row = getattr(res, "fetchone", lambda: None)() + val = row[0] if row else None + return (float(val) if val is not None else None, sql) + def _execute_sql(self, sql: str) -> SDF: """ Central Spark SQL runner. 
@@ -1293,45 +1319,21 @@ def utest_clean_target(self, relation: str) -> None: with suppress(Exception): self._execute_sql(f"DROP TABLE IF EXISTS {ident}") - -# ────────────────────────── local helpers / shim ────────────────────────── -class _SparkResult: - """Tiny result shim to mimic duckdb/psycopg fetch API in tests.""" - - def __init__(self, rows: list[tuple]): - self._rows = rows - - def fetchall(self) -> list[tuple]: - return self._rows - - def fetchone(self) -> tuple | None: - return self._rows[0] if self._rows else None - - -class _SparkConnShim: # pragma: no cover - """Provide .execute(sql) with fetch* for test utilities.""" - - def __init__(self, spark: SparkSession): - self._spark = spark - - def execute(self, sql: str, params: Any | None = None) -> _SparkResult: - if params: - # Minimal positional param interpolation for tests is intentionally not implemented. - # All internal calls use plain SQL strings for Spark. - raise NotImplementedError("SparkConnShim does not support parametrized SQL") - df = self._spark.sql(sql) - rows = [tuple(r) for r in df.collect()] - return _SparkResult(rows) - - -def _split_db_table(qualified: str) -> tuple[str | None, str]: - """ - Split "db.table" → (db, table); backticks allowed. - Returns (None, name) if unqualified. - """ - s = qualified.strip("`") - parts = s.split(".") - part_len = 2 - if len(parts) >= part_len: - return parts[-2], parts[-1] - return None, s + def introspect_column_physical_type(self, table: str, column: str) -> str | None: + """ + Spark: use DataFrame schema for `table` and return the Spark SQL type + (simpleString) for the given column, uppercased. + """ + physical = self._physical_identifier(table) + df = self.spark.table(physical) + + col_lower = column.lower() + for field in df.schema.fields: + if field.name.lower() == col_lower: + dt = field.dataType + try: + # e.g. 
"bigint", "string", "timestamp" + return dt.simpleString().upper() + except Exception: + return str(dt) + return None diff --git a/src/fastflowtransform/executors/duckdb.py b/src/fastflowtransform/executors/duckdb.py index 8c31ad8..e99437d 100644 --- a/src/fastflowtransform/executors/duckdb.py +++ b/src/fastflowtransform/executors/duckdb.py @@ -17,7 +17,8 @@ from fastflowtransform.executors._budget_runner import run_sql_with_budget from fastflowtransform.executors._snapshot_sql_mixin import SnapshotSqlMixin from fastflowtransform.executors._sql_identifier import SqlIdentifierMixin -from fastflowtransform.executors.base import BaseExecutor +from fastflowtransform.executors._test_utils import make_fetchable +from fastflowtransform.executors.base import BaseExecutor, _scalar from fastflowtransform.executors.budget import BudgetGuard from fastflowtransform.meta import ensure_meta_table, upsert_meta @@ -86,6 +87,40 @@ def __init__( self._execute_sql(f"create schema if not exists {safe_schema}") self._execute_sql(f"set schema '{self.schema}'") + def execute_test_sql(self, stmt: Any) -> Any: + """ + Execute lightweight SQL for DQ tests using the underlying DuckDB connection. 
+ """ + + def _run_one(s: Any) -> Any: + statement_len = 2 + if ( + isinstance(s, tuple) + and len(s) == statement_len + and isinstance(s[0], str) + and isinstance(s[1], dict) + ): + return self.con.execute(s[0], s[1]) + if isinstance(s, str): + return self.con.execute(s) + if isinstance(s, Iterable) and not isinstance(s, (bytes, bytearray, str)): + res = None + for item in s: + res = _run_one(item) + return res + return self.con.execute(str(s)) + + return make_fetchable(_run_one(stmt)) + + def compute_freshness_delay_minutes(self, table: str, ts_col: str) -> tuple[float | None, str]: + now_expr = "cast(now() as timestamp)" + sql = ( + f"select date_part('epoch', {now_expr} - max({ts_col})) " + f"/ 60.0 as delay_min from {table}" + ) + delay = _scalar(self, sql) + return (float(delay) if delay is not None else None, sql) + def _execute_sql(self, sql: str, *args: Any, **kwargs: Any) -> duckdb.DuckDBPyConnection: """ Central DuckDB SQL runner. @@ -666,3 +701,43 @@ def utest_clean_target(self, relation: str) -> None: self._execute_sql(f"drop view if exists {target}") with suppress(Exception): self._execute_sql(f"drop table if exists {target}") + + def introspect_column_physical_type(self, table: str, column: str) -> str | None: + """ + DuckDB: read `data_type` from information_schema.columns. + """ + if "." in table: + schema, table_name = table.split(".", 1) + else: + schema, table_name = None, table + + table_lower = table_name.lower() + column_lower = column.lower() + + if schema: + rows = self._execute_sql( + """ + select data_type + from information_schema.columns + where lower(table_name) = lower(?) + and lower(table_schema)= lower(?) + and lower(column_name) = lower(?) + order by table_schema, ordinal_position + limit 1 + """, + [table_lower, schema.lower(), column_lower], + ).fetchall() + else: + rows = self._execute_sql( + """ + select data_type + from information_schema.columns + where lower(table_name) = lower(?) + and lower(column_name) = lower(?) 
+ order by table_schema, ordinal_position + limit 1 + """, + [table_lower, column_lower], + ).fetchall() + + return rows[0][0] if rows else None diff --git a/src/fastflowtransform/executors/postgres.py b/src/fastflowtransform/executors/postgres.py index dacc4d8..0d3f7ef 100644 --- a/src/fastflowtransform/executors/postgres.py +++ b/src/fastflowtransform/executors/postgres.py @@ -2,20 +2,22 @@ import json from collections.abc import Callable, Iterable from time import perf_counter -from typing import Any +from typing import Any, cast import pandas as pd from sqlalchemy import create_engine, text from sqlalchemy.engine import Connection, Engine from sqlalchemy.exc import ProgrammingError, SQLAlchemyError +from sqlalchemy.sql import Executable +from sqlalchemy.sql.elements import ClauseElement from fastflowtransform.core import Node from fastflowtransform.errors import ModelExecutionError, ProfileConfigError from fastflowtransform.executors._budget_runner import run_sql_with_budget -from fastflowtransform.executors._shims import SAConnShim from fastflowtransform.executors._snapshot_sql_mixin import SnapshotSqlMixin from fastflowtransform.executors._sql_identifier import SqlIdentifierMixin -from fastflowtransform.executors.base import BaseExecutor +from fastflowtransform.executors._test_utils import make_fetchable +from fastflowtransform.executors.base import BaseExecutor, _scalar from fastflowtransform.executors.budget import BudgetGuard from fastflowtransform.executors.query_stats import QueryStats from fastflowtransform.meta import ensure_meta_table, upsert_meta @@ -54,8 +56,39 @@ def __init__(self, dsn: str, schema: str | None = None): f"Failed to ensure schema '{self.schema}' exists: {exc}" ) from exc - # ⇣ fastflowtransform.testing expects executor.con.execute("SQL") - self.con = SAConnShim(self.engine, schema=self.schema) + def execute_test_sql(self, stmt: Any) -> Any: + """ + Execute lightweight SQL for DQ tests using a transactional connection. 
+ """ + + def _run_one(s: Any, conn: Connection) -> Any: + statement_len = 2 + if ( + isinstance(s, tuple) + and len(s) == statement_len + and isinstance(s[0], str) + and isinstance(s[1], dict) + ): + return conn.execute(text(s[0]), s[1]) + if isinstance(s, str): + return conn.execute(text(s)) + if isinstance(s, ClauseElement): + return conn.execute(cast(Executable, s)) + if isinstance(s, Iterable) and not isinstance(s, (bytes, bytearray, str)): + res = None + for item in s: + res = _run_one(item, conn) + return res + return conn.execute(text(str(s))) + + with self.engine.begin() as conn: + self._set_search_path(conn) + return make_fetchable(_run_one(stmt, conn)) + + def compute_freshness_delay_minutes(self, table: str, ts_col: str) -> tuple[float | None, str]: + sql = f"select date_part('epoch', now() - max({ts_col})) / 60.0 as delay_min from {table}" + delay = _scalar(self, sql) + return (float(delay) if delay is not None else None, sql) def _execute_sql_core( self, @@ -274,7 +307,7 @@ def _quote_identifier(self, ident: str) -> str: def _qualified(self, relname: str, schema: str | None = None) -> str: return self._qualify_identifier(relname, schema=schema) - def _set_search_path(self, conn: Connection | SAConnShim) -> None: + def _set_search_path(self, conn: Connection) -> None: if self.schema: conn.execute(text(f"SET LOCAL search_path = {self._q_ident(self.schema)}")) @@ -636,3 +669,37 @@ def utest_clean_target(self, relation: str) -> None: conn.execute(text(f"DROP VIEW IF EXISTS {qualified} CASCADE")) else: # table conn.execute(text(f"DROP TABLE IF EXISTS {qualified} CASCADE")) + + def introspect_column_physical_type(self, table: str, column: str) -> str | None: + """ + Postgres: read `data_type` from information_schema.columns for the + current schema (or an explicit schema if table is qualified). + """ + if "." 
in table: + schema, table_name = table.split(".", 1) + else: + schema, table_name = None, table + + if schema: + sql = """ + select data_type + from information_schema.columns + where lower(table_schema) = lower(:schema) + and lower(table_name) = lower(:table) + and lower(column_name) = lower(:column) + limit 1 + """ + params = {"schema": schema, "table": table_name, "column": column} + else: + sql = """ + select data_type + from information_schema.columns + where table_schema = current_schema() + and lower(table_name) = lower(:table) + and lower(column_name) = lower(:column) + limit 1 + """ + params = {"table": table_name, "column": column} + + rows = self._execute_sql(sql, params).fetchall() + return rows[0][0] if rows else None diff --git a/src/fastflowtransform/executors/snowflake_snowpark.py b/src/fastflowtransform/executors/snowflake_snowpark.py index 400bf97..f6aeac2 100644 --- a/src/fastflowtransform/executors/snowflake_snowpark.py +++ b/src/fastflowtransform/executors/snowflake_snowpark.py @@ -13,6 +13,7 @@ from fastflowtransform.executors._budget_runner import run_sql_with_budget from fastflowtransform.executors._snapshot_sql_mixin import SnapshotSqlMixin from fastflowtransform.executors._sql_identifier import SqlIdentifierMixin +from fastflowtransform.executors._test_utils import make_fetchable, rows_to_tuples from fastflowtransform.executors.base import BaseExecutor from fastflowtransform.executors.budget import BudgetGuard from fastflowtransform.executors.query_stats import QueryStats @@ -39,8 +40,32 @@ def __init__(self, cfg: dict): self.allow_create_schema: bool = bool(cfg["allow_create_schema"]) self._ensure_schema() - # Provide a tiny testing shim so tests can call executor.con.execute("SQL") - self.con = _SFCursorShim(self.session) + def execute_test_sql(self, stmt: Any) -> Any: + """ + Execute lightweight SQL for DQ tests via Snowpark and return fetchable rows. 
+ """ + + def _run_one(s: Any) -> Any: + if isinstance(s, str): + return rows_to_tuples(self._execute_sql(s).collect()) + if isinstance(s, Iterable) and not isinstance(s, (bytes, bytearray, str)): + res = None + for item in s: + res = _run_one(item) + return res + return rows_to_tuples(self._execute_sql(str(s)).collect()) + + return make_fetchable(_run_one(stmt)) + + def compute_freshness_delay_minutes(self, table: str, ts_col: str) -> tuple[float | None, str]: + sql = ( + f"select DATEDIFF('minute', max({ts_col}), CURRENT_TIMESTAMP())::float as delay_min " + f"from {table}" + ) + res = self.execute_test_sql(sql) + row = getattr(res, "fetchone", lambda: None)() + val = row[0] if row else None + return (float(val) if val is not None else None, sql) # ---------- Cost estimation & central execution ---------- @@ -625,35 +650,27 @@ def utest_clean_target(self, relation: str) -> None: with suppress(Exception): self.session.sql(f"DROP TABLE IF EXISTS {qualified}").collect() + def introspect_column_physical_type(self, table: str, column: str) -> str | None: + """ + Snowflake: read DATA_TYPE from database.information_schema.columns. 
+ """ + db_ident = self._q(self.database) + schema_lit = self.schema.replace("'", "''").upper() + table_name = table.split(".")[-1] + table_lit = table_name.replace("'", "''").upper() + col_lit = column.replace("'", "''").upper() + + sql = f""" + select data_type + from {db_ident}.information_schema.columns + where upper(table_schema) = '{schema_lit}' + and upper(table_name) = '{table_lit}' + and upper(column_name) = '{col_lit}' + limit 1 + """ -# ────────────────────────── local testing shim ─────────────────────────── -class _SFCursorShim: - """Very small shim to expose .execute(...).fetch* for tests.""" - - def __init__(self, session: Session): - self._session = session - - def execute(self, sql: str, params: Any | None = None) -> _SFResult: - if params: - # Parametrized SQL not needed in our internal calls - raise NotImplementedError("Snowflake shim does not support parametrized SQL") - rows = self._session.sql(sql).collect() - - if rows: - cols = list(rows[0].asDict().keys()) - as_tuples = [tuple(row.asDict()[c] for c in cols) for row in rows] - else: - as_tuples = [] - - return _SFResult(as_tuples) - - -class _SFResult: - def __init__(self, rows: list[tuple]): - self._rows = rows - - def fetchall(self) -> list[tuple]: - return self._rows - - def fetchone(self) -> tuple | None: - return self._rows[0] if self._rows else None + rows = self._execute_sql(sql).collect() + if not rows: + return None + # first column of first row + return str(rows[0][0]) if rows[0] and rows[0][0] is not None else None diff --git a/src/fastflowtransform/incremental.py b/src/fastflowtransform/incremental.py index cc70c65..dc91219 100644 --- a/src/fastflowtransform/incremental.py +++ b/src/fastflowtransform/incremental.py @@ -4,8 +4,6 @@ from collections.abc import Callable, Sequence from typing import Any -from sqlalchemy import text as _sa_text - from fastflowtransform.core import relation_for from fastflowtransform.errors import ModelExecutionError @@ -71,30 +69,6 @@ def 
_is_merge_not_supported_error(exc: Exception) -> bool: # ---------- Helper ---------- -def _exec_sql(exe: Any, sql: str) -> None: - """Best-effort SQL execution across engines (DuckDB/PG/Snowflake/BQ shims).""" - # Prefer an engine-provided '_execute_sql' hook if available. - hook = getattr(exe, "_execute_sql", None) - if callable(hook): - hook(sql) - return - - if hasattr(exe, "con") and hasattr(exe.con, "execute"): # DuckDB / BQ shim etc. - exe.con.execute(sql) - return - if hasattr(exe, "engine"): # SQLAlchemy Engine - with exe.engine.begin() as conn: - conn.execute(_sa_text(sql)) - return - if hasattr(exe, "execute"): # BigQuery-like shim - exe.execute(sql) - return - if hasattr(exe, "run_sql_raw"): - exe.run_sql_raw(sql) - return - raise RuntimeError("No suitable raw-SQL execution path on executor") - - def _safe_exists(executor: Any, relation: Any) -> bool: try: return bool(executor.exists_relation(relation)) @@ -181,7 +155,7 @@ def _full_refresh_table(executor: Any, relation: Any, rendered_sql: str) -> None try: executor.create_table_as(relation, rendered_sql) except Exception: - _exec_sql(executor, f"create or replace table {target} as {rendered_sql}") + executor._execute_sql(f"create or replace table {target} as {rendered_sql}") UniqueKey = str | Sequence[str] | None diff --git a/src/fastflowtransform/schema_loader.py b/src/fastflowtransform/schema_loader.py index 02520aa..dee6784 100644 --- a/src/fastflowtransform/schema_loader.py +++ b/src/fastflowtransform/schema_loader.py @@ -39,7 +39,7 @@ def __post_init__(self): def load_schema_tests(project_dir: Path) -> list[TestSpec]: """ - Loads schema yamls (version: 2) in models/**.yml (& schema.yml), + Loads schema yamls (version: 1) in models/**.yml (& schema.yml), and returns normalized TestSpec objects. 
""" project_dir = Path(project_dir) @@ -54,7 +54,7 @@ def load_schema_tests(project_dir: Path) -> list[TestSpec]: files = sorted(set(files)) specs: list[TestSpec] = [] - version = 2 + version = 1 for yml in files: try: data = yaml.safe_load(yml.read_text(encoding="utf-8")) or {} diff --git a/src/fastflowtransform/source_freshness.py b/src/fastflowtransform/source_freshness.py index 9565b83..5ad915a 100644 --- a/src/fastflowtransform/source_freshness.py +++ b/src/fastflowtransform/source_freshness.py @@ -59,7 +59,6 @@ def _relation_for_source( def run_source_freshness( executor: Any, *, - con: Any | None = None, engine: str | None = None, ) -> list[SourceFreshnessResult]: """ @@ -70,7 +69,6 @@ def run_source_freshness( """ engine_label = engine or getattr(executor, "engine_name", None) or "" engine_norm = engine_label.lower() - connection = con or getattr(executor, "con", executor) results: list[SourceFreshnessResult] = [] sources = getattr(REGISTRY, "sources", {}) or {} @@ -102,7 +100,7 @@ def run_source_freshness( if threshold is None: # should not happen given the guard above continue - _freshness_test(connection, relation, loaded_at, max_delay_minutes=int(threshold)) + _freshness_test(executor, relation, loaded_at, max_delay_minutes=int(threshold)) # If we reach here, delay <= threshold; we can recompute the actual delay # by re-running with a large threshold and inferring from error message # OR we can simply omit it. Keep it simple and omit for now. 
diff --git a/src/fastflowtransform/testing/base.py b/src/fastflowtransform/testing/base.py index 25f71a3..f0500b7 100644 --- a/src/fastflowtransform/testing/base.py +++ b/src/fastflowtransform/testing/base.py @@ -1,93 +1,18 @@ # src/fastflowtransform/testing/base.py from __future__ import annotations -from collections.abc import Iterable, Sequence -from typing import Any, cast +import re +from collections.abc import Sequence +from typing import Any -from sqlalchemy import text -from sqlalchemy.engine import Connection as _SAConn from sqlalchemy.sql.elements import ClauseElement +from fastflowtransform.config.contracts import PhysicalTypeConfig +from fastflowtransform.executors.base import BaseExecutor, _scalar from fastflowtransform.logging import dprint from fastflowtransform.utils.timefmt import format_duration_minutes -# ===== Execution helpers (consistent for DuckDB / Postgres / BigQuery) == - - -def _exec(con: Any, sql: Any) -> Any: - """ - Execute SQL robustly and consistently. - Accepts: - - str - - (str, params: dict) - - SQLAlchemy ClauseElement (if available) - - Sequence[ of the above types ] -> executed sequentially (return result of the last) - Delegates to `con.execute(sql)` when available (e.g. DuckDB or our executor shims). - Fallback: use Connection.begin() + SQLAlchemy text(). - """ - # 1) Direct delegation to existing con.execute (e.g. 
DuckDB, our PG/BQ shims) - if hasattr(con, "execute"): - dprint("con.execute <-", _pretty_sql(sql)) - try: - if isinstance(con, _SAConn) or "sqlalchemy" in type(con).__module__: - sql_tuple_len = 2 - if isinstance(sql, str): - return con.execute(text(sql)) - if ( - isinstance(sql, tuple) - and len(sql) == sql_tuple_len - and isinstance(sql[0], str) - and isinstance(sql[1], dict) - ): - return con.execute(text(sql[0]), sql[1]) - return cast(Any, con).execute(sql) - except Exception: - # The check name is unknown at this point → the caller adds that context - raise - - # 2) Fallback: generic SQLAlchemy handling - - statement_tuple_len = 2 - - def _exec_one(c: Any, stmt: Any) -> Any: - if ( - isinstance(stmt, tuple) - and len(stmt) == statement_tuple_len - and isinstance(stmt[0], str) - and isinstance(stmt[1], dict) - ): - dprint("run (sql, params):", stmt[0], stmt[1]) - return c.execute(text(stmt[0]), stmt[1]) - if isinstance(stmt, ClauseElement): - dprint("run ClauseElement") - return c.execute(stmt) - if isinstance(stmt, str): - dprint("run sql:", stmt) - return c.execute(text(stmt)) - # Sequences (recursive) - if isinstance(stmt, Iterable) and not isinstance(stmt, (bytes, bytearray, str)): - res = None - for s in stmt: - res = _exec_one(c, s) - return res - raise TypeError(f"Unsupported statement type: {type(stmt)} → {stmt!r}") - - if hasattr(con, "begin"): - with con.begin() as c: - return _exec_one(c, sql) - # Last resort: best effort - return _exec_one(con, sql) - - -def _scalar(con: Any, sql: Any) -> Any: - """Execute SQL and return the first column of the first row (or None).""" - try: - res = _exec(con, sql) - except Exception as e: - # Caller adds the check name in _fail() - raise e - row = getattr(res, "fetchone", lambda: None)() - return None if row is None else row[0] +# ===== Execution helpers ================== def _fail(check: str, table: str, column: str | None, sql: str, detail: str) -> None: @@ -129,7 +54,7 @@ def lit(v: Any) -> str: def 
accepted_values( - con: Any, table: str, column: str, *, values: list[Any], where: str | None = None + executor: BaseExecutor, table: str, column: str, *, values: list[Any], where: str | None = None ) -> None: """ Fail if any non-NULL value of table.column is outside the set 'values'. @@ -143,7 +68,7 @@ def accepted_values( sql = f"select count(*) from {table} where {column} is not null and {column} not in ({in_list})" - n = _scalar(con, sql) + n = _scalar(executor, sql) if int(n or 0) > 0: sample_sql = f"select distinct {column} from {table} where {column} is not null" if in_list: @@ -152,7 +77,7 @@ def accepted_values( sql += f" and ({where})" sample_sql += f" and ({where})" sample_sql += " limit 5" - rows = [r[0] for r in _exec(con, sample_sql).fetchall()] + rows = [r[0] for r in executor.execute_test_sql(sample_sql).fetchall()] raise TestFailure(f"{table}.{column} has {n} value(s) outside accepted set; e.g. {rows}") @@ -182,13 +107,13 @@ def _wrap_db_error( return TestFailure("\n".join(msg)) -def not_null(con: Any, table: str, column: str, where: str | None = None) -> None: +def not_null(executor: BaseExecutor, table: str, column: str, where: str | None = None) -> None: """Fails if any non-filtered row has NULL in `column`.""" sql = f"select count(*) from {table} where {column} is null" if where: sql += f" and ({where})" try: - c = _scalar(con, sql) + c = _scalar(executor, sql) except Exception as e: raise _wrap_db_error("not_null", table, column, sql, e) from e dprint("not_null:", sql, "=>", c) @@ -196,7 +121,7 @@ def not_null(con: Any, table: str, column: str, where: str | None = None) -> Non _fail("not_null", table, column, sql, f"has {c} NULL-values") -def unique(con: Any, table: str, column: str, where: str | None = None) -> None: +def unique(executor: BaseExecutor, table: str, column: str, where: str | None = None) -> None: """Fails if any duplicate appears in `column` within the (optionally) filtered set.""" sql = ( "select count(*) from (select {col} as 
v, " @@ -205,7 +130,7 @@ def unique(con: Any, table: str, column: str, where: str | None = None) -> None: w = f" where ({where})" if where else "" sql = sql.format(col=column, tbl=table, w=w) try: - c = _scalar(con, sql) + c = _scalar(executor, sql) except Exception as e: raise _wrap_db_error("unique", table, column, sql, e) from e dprint("unique:", sql, "=>", c) @@ -213,25 +138,111 @@ def unique(con: Any, table: str, column: str, where: str | None = None) -> None: _fail("unique", table, column, sql, f"contains {c} duplicates") -def greater_equal(con: Any, table: str, column: str, threshold: float = 0.0) -> None: +def greater_equal(executor: BaseExecutor, table: str, column: str, threshold: float = 0.0) -> None: sql = f"select count(*) from {table} where {column} < {threshold}" - c = _scalar(con, sql) + c = _scalar(executor, sql) dprint("greater_equal:", sql, "=>", c) if c and c != 0: raise TestFailure(f"{table}.{column} has {c} values < {threshold}") -def non_negative_sum(con: Any, table: str, column: str) -> None: +def between( + executor: BaseExecutor, + table: str, + column: str, + *, + min_value: float | int | None = None, + max_value: float | int | None = None, +) -> None: + """ + Fail if any non-NULL value of table.column is outside the inclusive + range [min_value, max_value]. If one bound is None, only the other + is enforced. 
+ """ + if min_value is None and max_value is None: + return + + conds: list[str] = [] + if min_value is not None: + conds.append(f"{column} < {min_value}") + if max_value is not None: + conds.append(f"{column} > {max_value}") + + where_expr = " or ".join(conds) + sql = f"select count(*) from {table} where {column} is not null and ({where_expr})" + c = _scalar(executor, sql) + dprint("between:", sql, "=>", c) + + if c and c != 0: + if min_value is not None and max_value is not None: + raise TestFailure( + f"{table}.{column} has {c} value(s) outside inclusive range " + f"[{min_value}, {max_value}]" + ) + elif min_value is not None: + raise TestFailure(f"{table}.{column} has {c} value(s) < {min_value}") + else: + raise TestFailure(f"{table}.{column} has {c} value(s) > {max_value}") + + +def regex_match( + executor: BaseExecutor, + table: str, + column: str, + pattern: str, + where: str | None = None, +) -> None: + """ + Fail if any non-NULL value in table.column does not match the given + Python regex pattern. This is implemented client-side for engine + independence: + + SELECT column FROM table [WHERE ...] + -> evaluate in Python -> fail on first few mismatches. 
+ """ + try: + regex = re.compile(pattern) + except re.error as exc: + raise TestFailure(f"Invalid regex pattern {pattern!r} for {table}.{column}: {exc}") from exc + + sql = f"select {column} from {table}" + if where: + sql += f" where ({where})" + + res = executor.execute_test_sql(sql) + rows: list = getattr(res, "fetchall", lambda: [])() + + bad_values: list[Any] = [] + for row in rows: + val = row[0] + if val is None: + continue + if not regex.match(str(val)): + bad_values.append(val) + if len(bad_values) >= 5: + break + + dprint("regex_match:", sql, "=> bad_values:", bad_values) + + if bad_values: + raise TestFailure( + f"{table}.{column} has values not matching regex {pattern!r}; examples: {bad_values}" + ) + + +def non_negative_sum(executor: BaseExecutor, table: str, column: str) -> None: sql = f"select coalesce(sum({column}),0) from {table}" - s = _scalar(con, sql) + s = _scalar(executor, sql) dprint("non_negative_sum:", sql, "=>", s) if s is not None and s < 0: raise TestFailure(f"sum({table}.{column}) is negative: {s}") -def row_count_between(con: Any, table: str, min_rows: int = 1, max_rows: int | None = None) -> None: +def row_count_between( + executor: BaseExecutor, table: str, min_rows: int = 1, max_rows: int | None = None +) -> None: sql = f"select count(*) from {table}" - c = _scalar(con, sql) + c = _scalar(executor, sql) dprint("row_count_between:", sql, "=>", c) if c is None or c < min_rows: raise TestFailure(f"{table} has too few rows: {c} < {min_rows}") @@ -239,142 +250,77 @@ def row_count_between(con: Any, table: str, min_rows: int = 1, max_rows: int | N raise TestFailure(f"{table} has too many rows: {c} > {max_rows}") -def _freshness_probe(con: Any, table: str, ts_col: str) -> Any: +def _freshness_probe(executor: BaseExecutor, table: str, ts_col: str) -> Any: """Read max(ts_col) and wrap engine errors with context.""" probe_sql = f"select max({ts_col}) from {table}" try: - return _scalar(con, probe_sql) + return _scalar(executor, probe_sql) 
except Exception as e: # Column missing or other metadata-related DB error raise _wrap_db_error("freshness", table, ts_col, probe_sql, e) from e -def _detect_engine(con: Any) -> tuple[bool, bool, bool, bool]: +def _resolve_expected_physical( + physical_cfg: PhysicalTypeConfig | None, + engine_key: str, +) -> str | None: """ - Detect engine flavour from the connection object. + Given the PhysicalTypeConfig and an engine key, return the expected + physical type string for that engine, or None if nothing is declared. - Returns: - (is_spark_like, is_bigquery, is_snowflake, is_duckdb) + Precedence: + 1) physical. + 2) physical.default """ - con_type = type(con) - mod = getattr(con_type, "__module__", "") or "" - name = getattr(con_type, "__name__", "") or "" - mod_l = mod.lower() - name_l = name.lower() - - is_spark_like = any(token in mod_l or token in name_l for token in ("spark", "databricks")) - is_bigquery = ( - "bigquery" in mod_l - or "bigquery" in name_l - or str(getattr(con, "marker", "")).upper() == "BQ_SHIM" - ) - is_snowflake = ( - "snowflake" in mod_l - or "snowpark" in mod_l - or "snowflake" in name_l - or "snowpark" in name_l - or hasattr(con, "_session") - ) - is_duckdb = "duckdb" in mod_l or "duckdb" in name_l + if physical_cfg is None: + return None - return is_spark_like, is_bigquery, is_snowflake, is_duckdb + # Engine-specific override + eng_val = getattr(physical_cfg, engine_key, None) + if isinstance(eng_val, str) and eng_val.strip(): + return eng_val.strip() + # Fallback to default + if isinstance(physical_cfg.default, str) and physical_cfg.default.strip(): + return physical_cfg.default.strip() -def _compute_delay_minutes( - con: Any, + return None + + +def column_physical_type( + executor: BaseExecutor, table: str, - ts_col: str, - is_spark_like: bool, - is_bigquery: bool, - is_snowflake: bool, - is_duckdb: bool, -) -> tuple[float | None, str]: + column: str, + physical_cfg: PhysicalTypeConfig | None, +) -> None: """ - Compute delay in minutes 
for max(ts_col) depending on engine type. - - Returns: - (delay_minutes, sql_used) + Assert that the physical DB type of table.column matches the contract's + PhysicalTypeConfig for the current engine. """ - # Primary SQL (Postgres / DuckDB style) - now_expr = "now()" - if is_duckdb: - now_expr = "cast(now() as timestamp)" + engine_key = executor.engine_name + expected = _resolve_expected_physical(physical_cfg, engine_key) + if not expected: + # No expectation configured for this engine → nothing to enforce. + return - sql_primary = ( - f"select date_part('epoch', {now_expr} - max({ts_col})) / 60.0 as delay_min from {table}" - ) + actual = executor.introspect_column_physical_type(table, column) + if actual is None: + raise TestFailure( + f"[column_physical_type] Could not determine physical type for {table}.{column} " + f"(engine={engine_key}). Ensure the table exists and the column name is correct." + ) - # Spark / Databricks: unix_timestamp over timestamps - sql_spark = ( - f"select (unix_timestamp(current_timestamp()) - unix_timestamp(max({ts_col}))) / 60.0 " - f"as delay_min from {table}" - ) + exp_norm = str(expected).strip().lower() + act_norm = str(actual).strip().lower() - # BigQuery: TIMESTAMP_DIFF returns integer minutes; keep float compatibility - sql_bigquery = ( - f"select cast(TIMESTAMP_DIFF(CURRENT_TIMESTAMP(), max({ts_col}), MINUTE) as float64) " - f"as delay_min from {table}" - ) + if exp_norm != act_norm: + raise TestFailure( + f"{table}.{column} has physical type {actual!r}, expected {expected!r} " + f"for engine {engine_key}" + ) - # Snowflake: DATEDIFF on minutes; cast to float to align with other engines - sql_snowflake = ( - f"select DATEDIFF('minute', max({ts_col}), " - f"CURRENT_TIMESTAMP())::float as delay_min from {table}" - ) - delay: float | None = None - sql_used: str - - if is_spark_like: - # For Spark-like engines we never send the date_part('epoch', ...) SQL, - # to avoid INVALID_EXTRACT_FIELD noise in the logs. 
- sql_used = sql_spark - try: - delay = _scalar(con, sql_spark) - except Exception as e: - raise _wrap_db_error("freshness", table, ts_col, sql_spark, e) from e - - elif is_bigquery: - sql_used = sql_bigquery - try: - delay = _scalar(con, sql_bigquery) - except Exception as e: - # BigQuery error messages don't mention EXTRACT/EPOCH; surface directly. - raise _wrap_db_error("freshness", table, ts_col, sql_bigquery, e) from e - - elif is_snowflake: - sql_used = sql_snowflake - try: - delay = _scalar(con, sql_snowflake) - except Exception as e: - raise _wrap_db_error("freshness", table, ts_col, sql_snowflake, e) from e - - else: - # Non-Spark engines: try the Postgres/DuckDB expression first. - sql_used = sql_primary - try: - delay = _scalar(con, sql_primary) - except Exception as e: - txt = str(e).lower() - # If the engine complains about invalid extract fields / epoch, - # attempt the Spark-style expression as a fallback. - if ( - "invalid_extract_field" in txt - or "cannot extract" in txt - or ("epoch" in txt and "extract" in txt) - ): - sql_used = sql_spark - try: - delay = _scalar(con, sql_spark) - except Exception as e2: - raise _wrap_db_error("freshness", table, ts_col, sql_spark, e2) from e2 - else: - raise _wrap_db_error("freshness", table, ts_col, sql_primary, e) from e - - return delay, sql_used - - -def freshness(con: Any, table: str, ts_col: str, max_delay_minutes: int) -> None: +def freshness(executor: BaseExecutor, table: str, ts_col: str, max_delay_minutes: int) -> None: """ Fail if the latest timestamp in `ts_col` is older than `max_delay_minutes`. @@ -391,7 +337,7 @@ def freshness(con: Any, table: str, ts_col: str, max_delay_minutes: int) -> None we do not trigger noisy INVALID_EXTRACT_FIELD logs from the planner. """ # 1) Probe type: read max(ts_col) and inspect the Python value that comes back. - probe = _freshness_probe(con, table, ts_col) + probe = _freshness_probe(executor, table, ts_col) # If max(...) 
comes back as a string, this is almost certainly a typed-as-VARCHAR # timestamp column. Fail with a clear hint instead of letting the engine throw. @@ -404,17 +350,8 @@ def freshness(con: Any, table: str, ts_col: str, max_delay_minutes: int) -> None "and then reference that column in the freshness test." ) - # 2) Compute delay based on connection type. - is_spark_like, is_bigquery, is_snowflake, is_duckdb = _detect_engine(con) - delay, sql_used = _compute_delay_minutes( - con=con, - table=table, - ts_col=ts_col, - is_spark_like=is_spark_like, - is_bigquery=is_bigquery, - is_snowflake=is_snowflake, - is_duckdb=is_duckdb, - ) + # 2) Compute delay based on executor (engine-specific hook). + delay, sql_used = executor.compute_freshness_delay_minutes(table, ts_col) dprint("freshness:", sql_used, "=>", delay) @@ -429,15 +366,15 @@ def freshness(con: Any, table: str, ts_col: str, max_delay_minutes: int) -> None # ===== Cross-table reconciliations (FF-310) ====================================== -def _scalar_where(con: Any, table: str, expr: str, where: str | None = None) -> Any: +def _scalar_where(executor: BaseExecutor, table: str, expr: str, where: str | None = None) -> Any: """Return the first scalar from `SELECT {expr} FROM {table} [WHERE ...]`.""" sql = f"select {expr} from {table}" + (f" where {where}" if where else "") dprint("reconcile:", sql) - return _scalar(con, sql) + return _scalar(executor, sql) def reconcile_equal( - con: Any, + executor: BaseExecutor, left: dict, right: dict, abs_tolerance: float | None = None, @@ -448,8 +385,8 @@ def reconcile_equal( Both sides are dictionaries: {"table": str, "expr": str, "where": Optional[str]}. If both tolerances are omitted, exact equality is enforced. 
""" - L = _scalar_where(con, left["table"], left["expr"], left.get("where")) - R = _scalar_where(con, right["table"], right["expr"], right.get("where")) + L = _scalar_where(executor, left["table"], left["expr"], left.get("where")) + R = _scalar_where(executor, right["table"], right["expr"], right.get("where")) if L is None or R is None: raise TestFailure(f"One side is NULL (left={L}, right={R})") diff = abs(float(L) - float(R)) @@ -475,11 +412,11 @@ def reconcile_equal( def reconcile_ratio_within( - con: Any, left: dict, right: dict, min_ratio: float, max_ratio: float + executor: BaseExecutor, left: dict, right: dict, min_ratio: float, max_ratio: float ) -> None: """Assert min_ratio <= (left/right) <= max_ratio.""" - L = _scalar_where(con, left["table"], left["expr"], left.get("where")) - R = _scalar_where(con, right["table"], right["expr"], right.get("where")) + L = _scalar_where(executor, left["table"], left["expr"], left.get("where")) + R = _scalar_where(executor, right["table"], right["expr"], right.get("where")) if L is None or R is None: raise TestFailure(f"One side is NULL (left={L}, right={R})") eps = 1e-12 @@ -491,10 +428,12 @@ def reconcile_ratio_within( ) -def reconcile_diff_within(con: Any, left: dict, right: dict, max_abs_diff: float) -> None: +def reconcile_diff_within( + executor: BaseExecutor, left: dict, right: dict, max_abs_diff: float +) -> None: """Assert |left - right| <= max_abs_diff.""" - L = _scalar_where(con, left["table"], left["expr"], left.get("where")) - R = _scalar_where(con, right["table"], right["expr"], right.get("where")) + L = _scalar_where(executor, left["table"], left["expr"], left.get("where")) + R = _scalar_where(executor, right["table"], right["expr"], right.get("where")) if L is None or R is None: raise TestFailure(f"One side is NULL (left={L}, right={R})") diff = abs(float(L) - float(R)) @@ -503,7 +442,7 @@ def reconcile_diff_within(con: Any, left: dict, right: dict, max_abs_diff: float def reconcile_coverage( - con: Any, + 
executor: BaseExecutor, source: dict, target: dict, source_where: str | None = None, @@ -521,14 +460,14 @@ def reconcile_coverage( left join tgt t on s.k = t.k where t.k is null """ - missing = _scalar(con, sql) + missing = _scalar(executor, sql) dprint("reconcile_coverage:", sql, "=>", missing) if missing and missing != 0: raise TestFailure(f"Coverage failed: {missing} source keys missing in target") def relationships( - con: Any, + executor: BaseExecutor, table: str, field: str, to_table: str, @@ -551,7 +490,7 @@ def relationships( where p.k is null """ try: - missing = _scalar(con, sql) + missing = _scalar(executor, sql) except Exception as e: raise _wrap_db_error("relationships", table, field, sql, e) from e dprint("relationships:", sql, "=>", missing) diff --git a/src/fastflowtransform/testing/registry.py b/src/fastflowtransform/testing/registry.py index 950e4fc..76d1b96 100644 --- a/src/fastflowtransform/testing/registry.py +++ b/src/fastflowtransform/testing/registry.py @@ -7,6 +7,7 @@ from pydantic import BaseModel, ConfigDict, ValidationError from fastflowtransform.core import REGISTRY +from fastflowtransform.executors.base import BaseExecutor from fastflowtransform.logging import get_logger from fastflowtransform.testing import base as testing from fastflowtransform.testing.base import _scalar @@ -26,7 +27,7 @@ class Runner(Protocol): __name__: str def __call__( - self, con: Any, table: str, column: str | None, params: dict[str, Any] + self, executor: BaseExecutor, table: str, column: str | None, params: dict[str, Any] ) -> tuple[bool, str | None, str | None]: ... 
@@ -35,11 +36,6 @@ def __call__( # --------------------------------------------------------------------------- -def _example_where(where: str | None) -> str: - """Return a ' where (...)' suffix if where is provided, otherwise empty string.""" - return f" where ({where})" if where else "" - - def _format_param_validation_error( kind: str, origin: str | None, @@ -71,7 +67,7 @@ def _format_param_validation_error( def run_not_null( - con: Any, table: str, column: str | None, params: dict[str, Any] + executor: Any, table: str, column: str | None, params: dict[str, Any] ) -> tuple[bool, str | None, str | None]: where = params.get("where") example = f"select count(*) from {table} where {column} is null" + ( @@ -82,14 +78,14 @@ def run_not_null( return False, "missing required parameter: column", example col = column try: - testing.not_null(con, table, col, where=where) + testing.not_null(executor, table, col, where=where) return True, None, example except testing.TestFailure as e: return False, str(e), example def run_unique( - con: Any, table: str, column: str | None, params: dict[str, Any] + executor: Any, table: str, column: str | None, params: dict[str, Any] ) -> tuple[bool, str | None, str | None]: where = params.get("where") example = ( @@ -101,14 +97,14 @@ def run_unique( return False, "missing required parameter: column", example col = column try: - testing.unique(con, table, col, where=where) + testing.unique(executor, table, col, where=where) return True, None, example except testing.TestFailure as e: return False, str(e), example def run_accepted_values( - con: Any, table: str, column: str | None, params: dict[str, Any] + executor: Any, table: str, column: str | None, params: dict[str, Any] ) -> tuple[bool, str | None, str | None]: """Runner for testing.accepted_values.""" values = params.get("values") or [] @@ -133,14 +129,14 @@ def run_accepted_values( col = column try: - testing.accepted_values(con, table, col, values=values, where=where) + 
testing.accepted_values(executor, table, col, values=values, where=where) return True, None, example except testing.TestFailure as e: return False, str(e), example def run_greater_equal( - con: Any, table: str, column: str | None, params: dict[str, Any] + executor: Any, table: str, column: str | None, params: dict[str, Any] ) -> tuple[bool, str | None, str | None]: """Runner for testing.greater_equal (column >= threshold).""" threshold = float(params.get("threshold", 0.0)) @@ -151,14 +147,117 @@ def run_greater_equal( example = f"select count(*) from {table} where {column} < {threshold}" col = column try: - testing.greater_equal(con, table, col, threshold=threshold) + testing.greater_equal(executor, table, col, threshold=threshold) + return True, None, example + except testing.TestFailure as e: + return False, str(e), example + + +def run_between( + executor: Any, table: str, column: str | None, params: dict[str, Any] +) -> tuple[bool, str | None, str | None]: + """Runner for testing.between (inclusive numeric range).""" + if column is None: + example = f"select count(*) from {table} where < or > " + return False, "missing required parameter: column", example + + min_val = params.get("min") + max_val = params.get("max") + + if min_val is None and max_val is None: + example = f"-- between: no min/max provided for {table}.{column}" + return ( + False, + "between test requires at least one of 'min' or 'max'", + example, + ) + + conds: list[str] = [] + if min_val is not None: + conds.append(f"{column} < {min_val}") + if max_val is not None: + conds.append(f"{column} > {max_val}") + where_expr = " or ".join(conds) + example = f"select count(*) from {table} where {column} is not null and ({where_expr})" + + col = column + try: + testing.between( + executor, + table, + col, + min_value=min_val, + max_value=max_val, + ) + return True, None, example + except testing.TestFailure as e: + return False, str(e), example + + +def run_regex_match( + executor: Any, table: str, 
column: str | None, params: dict[str, Any] +) -> tuple[bool, str | None, str | None]: + """Runner for testing.regex_match (Python-side regex evaluation).""" + pattern = params.get("pattern") or params.get("regex") + where = params.get("where") + + if column is None: + example = f"select {column or ''} from {table}" + return False, "missing required parameter: column", example + + if not pattern: + example = f"select {column} from {table} -- pattern missing" + return False, "missing required parameter: pattern", example + + example = f"select {column} from {table}" + if where: + example += f" where ({where})" + + col = column + try: + testing.regex_match( + executor, + table, + col, + pattern=str(pattern), + where=where, + ) + return True, None, example + except testing.TestFailure as e: + return False, str(e), example + + +def run_column_physical_type( + executor: Any, table: str, column: str | None, params: dict[str, Any] +) -> tuple[bool, str | None, str | None]: + """ + Runner for testing.column_physical_type (schema/DDL assertion). + + Params: + - physical: string or mapping {engine_key: type, default: type} + """ + physical_cfg = params.get("physical") + + if column is None: + example = "-- column_physical_type: column parameter is required" + return False, "missing required parameter: column", example + + if physical_cfg is None: + # Nothing to enforce; treat as noop (passes). 
+ example = f"-- column_physical_type: no 'physical' configured for {table}.{column}" + return True, None, example + + example = f"-- physical type check for {table}.{column} via information_schema.columns" + + try: + testing.column_physical_type(executor, table, column, physical_cfg) return True, None, example except testing.TestFailure as e: return False, str(e), example def run_non_negative_sum( - con: Any, table: str, column: str | None, params: dict[str, Any] + executor: Any, table: str, column: str | None, params: dict[str, Any] ) -> tuple[bool, str | None, str | None]: """Runner for testing.non_negative_sum.""" if column is None: @@ -168,7 +267,7 @@ def run_non_negative_sum( example = f"select coalesce(sum({column}), 0) from {table}" col = column try: - testing.non_negative_sum(con, table, col) + testing.non_negative_sum(executor, table, col) return True, None, example except testing.TestFailure as e: return False, str(e), example @@ -180,7 +279,7 @@ def run_non_negative_sum( def run_row_count_between( - con: Any, table: str, column: str | None, params: dict[str, Any] + executor: Any, table: str, column: str | None, params: dict[str, Any] ) -> tuple[bool, str | None, str | None]: """Runner for testing.row_count_between.""" min_rows = int(params.get("min_rows", 1)) @@ -189,14 +288,14 @@ def run_row_count_between( example = f"select count(*) from {table}" try: - testing.row_count_between(con, table, min_rows=min_rows, max_rows=max_rows) + testing.row_count_between(executor, table, min_rows=min_rows, max_rows=max_rows) return True, None, example except testing.TestFailure as e: return False, str(e), example def run_freshness( - con: Any, table: str, column: str | None, params: dict[str, Any] + executor: Any, table: str, column: str | None, params: dict[str, Any] ) -> tuple[bool, str | None, str | None]: """Runner for testing.freshness (max timestamp delay in minutes).""" if column is None: @@ -222,7 +321,7 @@ def run_freshness( col = column try: - 
testing.freshness(con, table, col, max_delay_minutes=max_delay_int) + testing.freshness(executor, table, col, max_delay_minutes=max_delay_int) return True, None, example except testing.TestFailure as e: return False, str(e), example @@ -291,7 +390,7 @@ def _example_relationship_sql( def run_reconcile_equal( - con: Any, table: str, column: str | None, params: dict[str, Any] + executor: Any, table: str, column: str | None, params: dict[str, Any] ) -> tuple[bool, str | None, str | None]: """Runner for testing.reconcile_equal (left == right within tolerances).""" left = params.get("left") @@ -307,7 +406,7 @@ def run_reconcile_equal( try: testing.reconcile_equal( - con, + executor, left=left, right=right, abs_tolerance=abs_tol, @@ -319,7 +418,7 @@ def run_reconcile_equal( def run_reconcile_ratio_within( - con: Any, table: str, column: str | None, params: dict[str, Any] + executor: Any, table: str, column: str | None, params: dict[str, Any] ) -> tuple[bool, str | None, str | None]: """Runner for testing.reconcile_ratio_within (min_ratio <= L/R <= max_ratio).""" left = params.get("left") @@ -339,7 +438,7 @@ def run_reconcile_ratio_within( try: testing.reconcile_ratio_within( - con, + executor, left=left, right=right, min_ratio=float(min_ratio), @@ -351,7 +450,7 @@ def run_reconcile_ratio_within( def run_reconcile_diff_within( - con: Any, table: str, column: str | None, params: dict[str, Any] + executor: Any, table: str, column: str | None, params: dict[str, Any] ) -> tuple[bool, str | None, str | None]: """Runner for testing.reconcile_diff_within (|L - R| <= max_abs_diff).""" left = params.get("left") @@ -370,7 +469,7 @@ def run_reconcile_diff_within( try: testing.reconcile_diff_within( - con, + executor, left=left, right=right, max_abs_diff=float(max_abs_diff), @@ -381,7 +480,7 @@ def run_reconcile_diff_within( def run_reconcile_coverage( - con: Any, table: str, column: str | None, params: dict[str, Any] + executor: Any, table: str, column: str | None, params: dict[str, 
Any] ) -> tuple[bool, str | None, str | None]: """Runner for testing.reconcile_coverage (anti-join count == 0).""" source = params.get("source") @@ -397,7 +496,7 @@ def run_reconcile_coverage( try: testing.reconcile_coverage( - con, + executor, source=source, target=target, source_where=source_where, @@ -409,7 +508,7 @@ def run_reconcile_coverage( def run_relationships( - con: Any, table: str, column: str | None, params: dict[str, Any] + executor: Any, table: str, column: str | None, params: dict[str, Any] ) -> tuple[bool, str | None, str | None]: """Runner for testing.relationships (FK-style anti join).""" field = params.get("field") or column @@ -429,7 +528,7 @@ def run_relationships( try: testing.relationships( - con, + executor, table=table, field=field, to_table=to_table, @@ -471,6 +570,10 @@ def run_relationships( "reconcile_ratio_within": run_reconcile_ratio_within, "reconcile_diff_within": run_reconcile_diff_within, "reconcile_coverage": run_reconcile_coverage, + # Contracts helpers + "between": run_between, + "regex_match": run_regex_match, + "column_physical_type": run_column_physical_type, } @@ -557,7 +660,7 @@ def register_sql_test( META_KEYS = {"type", "table", "column", "severity", "tags", "name"} def _runner( - con: Any, table: str, column: str | None, params: dict[str, Any] + executor: Any, table: str, column: str | None, params: dict[str, Any] ) -> tuple[bool, str | None, str | None]: # 1) Strip generic test metadata and validate params if a schema is provided raw_params: dict[str, Any] = dict(params or {}) @@ -596,7 +699,7 @@ def _runner( ) from exc # 3) Execute the SQL: convention here is "fail if count(*) > 0" - n = _scalar(con, sql) + n = _scalar(executor, sql) ok = int(n or 0) == 0 msg: str | None = None if ok else f"{kind} failed: {n} offending row(s)" example_sql = sql diff --git a/tests/common/fixtures.py b/tests/common/fixtures.py index f0a3dd2..c81788f 100644 --- a/tests/common/fixtures.py +++ b/tests/common/fixtures.py @@ -45,13 +45,9 @@ 
# Snowflake try: - from fastflowtransform.executors.snowflake_snowpark import ( - SnowflakeSnowparkExecutor, - _SFCursorShim, - ) + from fastflowtransform.executors.snowflake_snowpark import SnowflakeSnowparkExecutor except ModuleNotFoundError: # pragma: no cover SnowflakeSnowparkExecutor = None # type: ignore[assignment] - _SFCursorShim = None # type: ignore[assignment] # ---- Jinja env ---------------------------------------------------------------- @@ -450,11 +446,6 @@ def snowflake_executor_fake() -> Any: # Fake Snowflake session and cursor shim. session = FakeSnowflakeSession() ex.session = session - if _SFCursorShim is not None: - ex.con = _SFCursorShim(session) # type: ignore[arg-type] - else: - # Cheap fallback if for some reason the shim isn't available - ex.con = SimpleNamespace(execute=lambda sql, params=None: None) return ex diff --git a/tests/integration/cli/test_cmd/test_test_cmd_schema_merge_integration.py b/tests/integration/cli/test_cmd/test_test_cmd_schema_merge_integration.py index 1941ac5..296099a 100644 --- a/tests/integration/cli/test_cmd/test_test_cmd_schema_merge_integration.py +++ b/tests/integration/cli/test_cmd/test_test_cmd_schema_merge_integration.py @@ -32,7 +32,7 @@ def test_merge_project_yaml_and_schema_yaml(tmp_path: Path): ) (tmp_path / "models" / "users.yml").write_text( """ -version: 2 +version: 1 models: - name: users.ff tags: [schema] @@ -64,9 +64,9 @@ def test_merge_project_yaml_and_schema_yaml(tmp_path: Path): schema_specs = load_schema_tests(tmp_path) legacy_only = _apply_legacy_tag_filter(legacy + schema_specs, ["legacy"], legacy_token=True) - res_legacy = _run_dq_tests(ex.con, legacy_only, ex) + res_legacy = _run_dq_tests(ex, legacy_only) assert all(r.ok for r in res_legacy) schema_only = _apply_legacy_tag_filter(legacy + schema_specs, ["schema"], legacy_token=True) - res_schema = _run_dq_tests(ex.con, schema_only, ex) + res_schema = _run_dq_tests(ex, schema_only) assert all(r.ok or r.severity == "warn" for r in 
res_schema) diff --git a/tests/integration/executors/duckdb/test_ephemeral_inlining_integration.py b/tests/integration/executors/duckdb/test_ephemeral_inlining_integration.py index 16f95f9..f22daba 100644 --- a/tests/integration/executors/duckdb/test_ephemeral_inlining_integration.py +++ b/tests/integration/executors/duckdb/test_ephemeral_inlining_integration.py @@ -32,7 +32,7 @@ def test_ephemeral_inlining_end_to_end(tmp_path: Path): # sources.yml _w( proj / "sources.yml", - """version: 2 + """version: 1 sources: - name: crm diff --git a/tests/integration/schema_loader/test_schema_yaml_basic_integration.py b/tests/integration/schema_loader/test_schema_yaml_basic_integration.py index 3c491d3..982b124 100644 --- a/tests/integration/schema_loader/test_schema_yaml_basic_integration.py +++ b/tests/integration/schema_loader/test_schema_yaml_basic_integration.py @@ -18,7 +18,7 @@ def test_schema_yaml_runs_basic_checks(tmp_path: Path): ) (tmp_path / "models" / "users.yml").write_text( """ -version: 2 +version: 1 models: - name: users.ff tags: [batch] @@ -44,7 +44,7 @@ def test_schema_yaml_runs_basic_checks(tmp_path: Path): specs = load_schema_tests(tmp_path) specs = _apply_legacy_tag_filter(specs, ["batch"], legacy_token=True) - results = _run_dq_tests(ex.con, specs, ex) + results = _run_dq_tests(ex, specs) error_fails = [r for r in results if (not r.ok) and r.severity != "warn"] assert error_fails == [] diff --git a/tests/integration/schema_loader/test_schema_yaml_registry_mix_integration.py b/tests/integration/schema_loader/test_schema_yaml_registry_mix_integration.py index d80d454..d3294a8 100644 --- a/tests/integration/schema_loader/test_schema_yaml_registry_mix_integration.py +++ b/tests/integration/schema_loader/test_schema_yaml_registry_mix_integration.py @@ -19,7 +19,7 @@ def test_mix_multiple_tests_per_column(tmp_path: Path): ) (tmp_path / "models" / "u.yml").write_text( """ -version: 2 +version: 1 models: - name: u.ff columns: @@ -39,7 +39,7 @@ def 
test_mix_multiple_tests_per_column(tmp_path: Path): ex = DuckExecutor(":memory:") ex.run_sql(REGISTRY.get_node("u.ff"), env) specs = load_schema_tests(tmp_path) - res = _run_dq_tests(ex.con, specs, ex) + res = _run_dq_tests(ex, specs) # Both should fail with error severity assert any((not r.ok) and r.kind == "unique" for r in res) assert any((not r.ok) and r.kind == "accepted_values" for r in res) diff --git a/tests/integration/streaming/test_smoke_streaming.py b/tests/integration/streaming/test_smoke_streaming.py deleted file mode 100644 index 13216d4..0000000 --- a/tests/integration/streaming/test_smoke_streaming.py +++ /dev/null @@ -1,71 +0,0 @@ -from datetime import UTC, datetime, timedelta - -import duckdb -import pandas as pd -import pytest - -from fastflowtransform.streaming import StreamSessionizer -from fastflowtransform.testing import base as testing - - -@pytest.mark.integration -@pytest.mark.streaming -def test_stream_sessionizer_produces_sessions(): - delay_amt = 15 - expected_rows_count = 2 - con = duckdb.connect(":memory:") - sess = StreamSessionizer(con) - - # Fixe "now" Referenz & echte Timestamps (UTC, pandas dtype) - now = datetime.now(UTC) - events = pd.DataFrame( - [ - { - "user_id": "u1", - "session_id": "s1", - "source": "ads", - "event_type": "page_view", - "event_timestamp": now - timedelta(minutes=2), - "amount": None, - }, - { - "user_id": "u1", - "session_id": "s1", - "source": "ads", - "event_type": "purchase", - "event_timestamp": now - timedelta(minutes=1, seconds=30), - "amount": 19.9, - }, - { - "user_id": "u2", - "session_id": "s2", - "source": "organic", - "event_type": "page_view", - "event_timestamp": now - timedelta(minutes=1), - "amount": None, - }, - ] - ) - events["event_timestamp"] = pd.to_datetime(events["event_timestamp"], utc=True) - - # Prozessieren - sess.process_batch(events) - - # Tabelle existiert & hat Zeilen - row = con.execute("select count(*) from fct_sessions_streaming").fetchone() - assert row is not None, 
"Query lieferte keine Zeile" - rows = int(row[0]) - assert rows >= expected_rows_count - - # Basis-Checks (funktionieren in DuckDB & PG) - testing.greater_equal(con, "fct_sessions_streaming", "revenue", 0) - testing.non_negative_sum(con, "fct_sessions_streaming", "revenue") - - # Freshness: DuckDB-sichere Variante direkt im Test (um SQL-Dialekte zu umgehen) - # DuckDB: date_diff('minute', max(ts), now()) - delay_min = con.execute( - "select date_diff('minute', max(session_end), current_timestamp) " - "from fct_sessions_streaming" - ).fetchone() - delay_min = int(delay_min[0]) if delay_min is not None else None - assert delay_min is not None and delay_min <= delay_amt, f"freshness too old: {delay_min} min" diff --git a/tests/integration/test_artifacts_integration.py b/tests/integration/test_artifacts_integration.py index 5f78c82..d6733ce 100644 --- a/tests/integration/test_artifacts_integration.py +++ b/tests/integration/test_artifacts_integration.py @@ -22,7 +22,7 @@ def test_artifacts_all_written(tmp_path: Path): "create or replace table m as select 1 as id", encoding="utf-8" ) (tmp_path / "sources.yml").write_text( - "version: 2\nsources: []\n", + "version: 1\nsources: []\n", encoding="utf-8", ) diff --git a/tests/integration/testing/registry/test_dispatch_integration.py b/tests/integration/testing/registry/test_dispatch_integration.py index 6708687..ec6d260 100644 --- a/tests/integration/testing/registry/test_dispatch_integration.py +++ b/tests/integration/testing/registry/test_dispatch_integration.py @@ -11,11 +11,11 @@ def test_registry_not_null_and_unique_and_params_and_sql(): ex.con.execute("create table t(id int, email varchar)") ex.con.execute("insert into t values (1,'a@example.com'),(1,'b@example.com'),(2,null)") - ok1, msg1, sql1 = TESTS["not_null"](ex.con, "t", "email", {}) + ok1, msg1, sql1 = TESTS["not_null"](ex, "t", "email", {}) assert not ok1 and "is null" in (msg1 or "").lower() assert "select count(*) from t where email is null" in (sql1 or 
"").lower() - ok2, msg2, sql2 = TESTS["unique"](ex.con, "t", "id", {}) + ok2, msg2, sql2 = TESTS["unique"](ex, "t", "id", {}) assert not ok2 and "duplicate" in (msg2 or "").lower() assert "group by 1 having count(*) > 1" in (sql2 or "").lower() @@ -30,7 +30,7 @@ def test_registry_relationships_runner(): ex.con.execute("insert into fact_users values (1),(2)") ok, msg, sql = TESTS["relationships"]( - ex.con, + ex, "fact_users", "user_id", {"to": "dim_users"}, @@ -41,7 +41,7 @@ def test_registry_relationships_runner(): ex.con.execute("delete from fact_users where user_id = 2") ok2, msg2, _ = TESTS["relationships"]( - ex.con, + ex, "fact_users", "user_id", {"to": "dim_users"}, diff --git a/tests/unit/artifacts/test_manifest_unit.py b/tests/unit/artifacts/test_manifest_unit.py index add2452..db1f10a 100644 --- a/tests/unit/artifacts/test_manifest_unit.py +++ b/tests/unit/artifacts/test_manifest_unit.py @@ -12,7 +12,7 @@ def test_manifest_minimal(tmp_path: Path): (tmp_path / "models").mkdir(parents=True) (tmp_path / "models" / "m.ff.sql").write_text("select 1 as x", encoding="utf-8") (tmp_path / "sources.yml").write_text( - "version: 2\nsources: []\n", + "version: 1\nsources: []\n", encoding="utf-8", ) diff --git a/tests/unit/cli/test_bootstrap_unit.py b/tests/unit/cli/test_bootstrap_unit.py index ef79b62..790a8dc 100644 --- a/tests/unit/cli/test_bootstrap_unit.py +++ b/tests/unit/cli/test_bootstrap_unit.py @@ -3,7 +3,6 @@ import datetime from pathlib import Path -from types import SimpleNamespace from typing import cast import pytest @@ -244,31 +243,3 @@ def test_resolve_project_path_missing_models(tmp_path: Path): # no models/ → should raise with pytest.raises(bootstrap.typer.BadParameter): bootstrap._resolve_project_path(str(tmp_path)) - - -# --------------------------------------------------------------------------- -# _get_test_con - just smoke test -# --------------------------------------------------------------------------- - - -@pytest.mark.unit -def 
test_get_test_con_prefers_executor_con(): - class ExecWithCon: - def __init__(self): - self.con = SimpleNamespace(execute=lambda *_: "ok") - - ex = ExecWithCon() - con = bootstrap._get_test_con(ex) - assert con.execute("SELECT 1") == "ok" - - -@pytest.mark.unit -def test_get_test_con_falls_back_to_executor(): - class ExecSimple: - def run(self): - return "ran" - - ex = ExecSimple() - con = bootstrap._get_test_con(ex) - # we just get the executor back - assert con is ex diff --git a/tests/unit/cli/test_source_cmd_unit.py b/tests/unit/cli/test_source_cmd_unit.py index 88a9296..ec68211 100644 --- a/tests/unit/cli/test_source_cmd_unit.py +++ b/tests/unit/cli/test_source_cmd_unit.py @@ -84,7 +84,7 @@ def fake_prepare_context(project, env_name, engine, vars): monkeypatch.setattr(source_cmd, "_prepare_context", fake_prepare_context, raising=True) - def fake_run_source_freshness(executor, con, engine): + def fake_run_source_freshness(executor, engine): return [ SourceFreshnessResult( source_name="crm", diff --git a/tests/unit/config/test_config_hook_unit.py b/tests/unit/config/test_config_hook_unit.py index 57106ae..39a4dab 100644 --- a/tests/unit/config/test_config_hook_unit.py +++ b/tests/unit/config/test_config_hook_unit.py @@ -9,7 +9,7 @@ def test_sql_model_config_materialized_view(tmp_path: Path): (tmp_path / "models").mkdir() (tmp_path / "sources.yml").write_text( - "version: 2\nsources: []\n", + "version: 1\nsources: []\n", encoding="utf-8", ) (tmp_path / "models" / "users.ff.sql").write_text( diff --git a/tests/unit/core/test_macros_loading_unit.py b/tests/unit/core/test_macros_loading_unit.py index e67d44a..3e61269 100644 --- a/tests/unit/core/test_macros_loading_unit.py +++ b/tests/unit/core/test_macros_loading_unit.py @@ -13,7 +13,7 @@ def test_macros_are_loaded_and_callable(tmp_path: Path): models = tmp_path / "models" / "macros" models.mkdir(parents=True, exist_ok=True) (tmp_path / "sources.yml").write_text( - "version: 2\nsources: []\n", + "version: 
1\nsources: []\n", encoding="utf-8", ) diff --git a/tests/unit/executors/test_databricks_spark_exec_unit.py b/tests/unit/executors/test_databricks_spark_exec_unit.py index db83455..258ff07 100644 --- a/tests/unit/executors/test_databricks_spark_exec_unit.py +++ b/tests/unit/executors/test_databricks_spark_exec_unit.py @@ -9,10 +9,6 @@ from fastflowtransform.core import REGISTRY, Node from fastflowtransform.executors import databricks_spark as mod -from fastflowtransform.executors.databricks_spark import ( - _SparkConnShim, - _split_db_table, -) from fastflowtransform.table_formats.spark_iceberg import IcebergFormatHandler @@ -40,14 +36,6 @@ def test_non_delta_leaves_catalog_unset(exec_factory): assert catalog_values == [] -@pytest.mark.unit -@pytest.mark.databricks_spark -def test_split_db_table_unit(): - assert _split_db_table("db.tbl") == ("db", "tbl") - assert _split_db_table("`db`.`tbl`") == ("db`", "`tbl") - assert _split_db_table("tbl") == (None, "tbl") - - @pytest.mark.unit @pytest.mark.databricks_spark def test_q_ident_unit(exec_minimal): @@ -327,19 +315,6 @@ def bad_upsert(executor, node_name, relation, fingerprint, engine): exec_minimal.on_node_built(node, "demo_tbl", "abc123") -@pytest.mark.unit -@pytest.mark.databricks_spark -def test_spark_conn_shim_execute_runs_select(monkeypatch): - """_SparkConnShim.execute should return rows collected from spark.sql.""" - fake_spark = MagicMock() - fake_spark.sql.return_value.collect.return_value = [("a",), ("b",)] - shim = _SparkConnShim(fake_spark) - - res = shim.execute("SELECT 'a'") - assert res.fetchall() == [("a",), ("b",)] - assert res.fetchone() == ("a",) - - @pytest.mark.unit @pytest.mark.databricks_spark def test_read_relation_uses_spark_table(exec_minimal): diff --git a/tests/unit/executors/test_shims_unit.py b/tests/unit/executors/test_shims_unit.py deleted file mode 100644 index e8549da..0000000 --- a/tests/unit/executors/test_shims_unit.py +++ /dev/null @@ -1,262 +0,0 @@ -# 
tests/unit/executors/test_shims_unit.py -from __future__ import annotations - -from collections.abc import Sequence -from types import SimpleNamespace -from typing import Any, cast - -import pytest -from sqlalchemy import text as sa_text -from sqlalchemy.engine import Engine - -from fastflowtransform.executors._shims import ( - BigQueryConnShim, - SAConnShim, - _rewrite_pg_create_or_replace_table, -) -from fastflowtransform.typing import Client - -# --------------------------------------------------------------------------- -# Fakes / helpers -# --------------------------------------------------------------------------- - - -class _FakeConn: - """Collects executed statements for assertions.""" - - def __init__(self) -> None: - self.executed: list[tuple[str, dict[str, Any] | None]] = [] - - # SQLAlchemy-style execute - def execute(self, stmt: Any, params: dict[str, Any] | None = None) -> Any: - # store string form for easier asserts - sql_str = stmt.text if hasattr(stmt, "text") else str(stmt) - self.executed.append((sql_str, params)) - # return something fetchable - return SimpleNamespace(fetchone=lambda: None, fetchall=lambda: []) - - # context manager API - def __enter__(self) -> _FakeConn: - return self - - def __exit__(self, exc_type, exc, tb) -> None: - return None - - # so SAConnShim can call conn.begin() in the fallback (we never do here) - def begin(self) -> _FakeConn: - return self - - -class _FakeEngine: - """Engine that always returns the same connection.""" - - def __init__(self) -> None: - self.conn = _FakeConn() - - def begin(self) -> _FakeConn: - return self.conn - - -# --------------------------------------------------------------------------- -# _rewrite_pg_create_or_replace_table -# --------------------------------------------------------------------------- - - -@pytest.mark.unit -def test_rewrite_pg_create_or_replace_table_simple(): - sql = "CREATE OR REPLACE TABLE public.t AS SELECT 1" - out = _rewrite_pg_create_or_replace_table(sql) - assert 
"DROP TABLE IF EXISTS public.t CASCADE;" in out - assert "CREATE TABLE public.t AS SELECT 1" in out - - -@pytest.mark.unit -def test_rewrite_pg_create_or_replace_table_with_schema_and_quotes(): - sql = ' create or replace table "raw"."users" as select * from src ' - out = _rewrite_pg_create_or_replace_table(sql) - # two statements - assert 'DROP TABLE IF EXISTS "raw"."users" CASCADE;' in out - assert 'CREATE TABLE "raw"."users" AS select * from src' in out - - -@pytest.mark.unit -def test_rewrite_pg_create_or_replace_table_untouched_for_other_sql(): - sql = "SELECT 1" - out = _rewrite_pg_create_or_replace_table(sql) - assert out == sql - - -# --------------------------------------------------------------------------- -# SAConnShim -# --------------------------------------------------------------------------- - - -@pytest.mark.unit -def test_sa_shim_executes_plain_sql_without_schema(): - eng = _FakeEngine() - shim = SAConnShim(cast(Engine, eng), schema=None) - - shim.execute("SELECT 1") - - executed = eng.conn.executed - assert len(executed) == 1 - sql, params = executed[0] - assert sql.strip().upper().startswith("SELECT 1") - assert params is None - - -@pytest.mark.unit -def test_sa_shim_sets_search_path_when_schema_given(): - eng = _FakeEngine() - shim = SAConnShim(cast(Engine, eng), schema="public") - - shim.execute("SELECT 42") - - executed = eng.conn.executed - # 1) SET LOCAL ... 
2) SELECT 42 - assert len(executed) == 2 - assert 'SET LOCAL search_path = "public"' in executed[0][0] - assert "SELECT 42" in executed[1][0] - - -@pytest.mark.unit -def test_sa_shim_rewrites_cor_table_into_two_statements(): - eng = _FakeEngine() - shim = SAConnShim(cast(Engine, eng), schema=None) - - shim.execute("CREATE OR REPLACE TABLE my_tbl AS SELECT 1") - - executed = eng.conn.executed - # should have been split into DROP + CREATE - assert len(executed) == 2 - assert "DROP TABLE IF EXISTS my_tbl CASCADE" in executed[0][0] - assert "CREATE TABLE my_tbl AS SELECT 1" in executed[1][0] - - -@pytest.mark.unit -def test_sa_shim_executes_iterable_sequentially(): - eng = _FakeEngine() - shim = SAConnShim(cast(Engine, eng), schema=None) - - shim.execute(["SELECT 1", "SELECT 2"]) - - executed = [sql for (sql, _) in eng.conn.executed] - assert executed == ["SELECT 1", "SELECT 2"] - - -@pytest.mark.unit -def test_sa_shim_executes_tuple_with_params_on_last_statement(): - eng = _FakeEngine() - shim = SAConnShim(cast(Engine, eng), schema=None) - - shim.execute(("SELECT :x", {"x": 10})) - - executed = eng.conn.executed - assert len(executed) == 1 - sql, params = executed[0] - assert "SELECT :x" in sql - assert params == {"x": 10} - - -@pytest.mark.unit -def test_sa_shim_executes_sqlalchemy_clauseelement(): - eng = _FakeEngine() - shim = SAConnShim(cast(Engine, eng), schema=None) - - stmt = sa_text("SELECT 1") - shim.execute(stmt) - - executed = eng.conn.executed - assert len(executed) == 1 - assert "SELECT 1" in executed[0][0] - - -# --------------------------------------------------------------------------- -# BigQueryConnShim -# --------------------------------------------------------------------------- - - -@pytest.mark.unit -def test_bq_shim_executes_single_sql(): - calls: dict[str, Any] = {} - - class FakeJob: - def __init__(self) -> None: - self.result_called = False - - def result(self): - self.result_called = True - # Simulate a RowIterator; list is fine for the 
wrapper. - return ["ROW-1"] - - class FakeClient: - def query(self, sql: str, location: str | None = None): - calls["sql"] = sql - calls["location"] = location - return FakeJob() - - fake = FakeClient() - shim = BigQueryConnShim(cast(Client, fake), location="EU") - res = shim.execute("SELECT 1") - - # Shim now returns a cursor-like wrapper - assert isinstance(res, BigQueryConnShim._ResultWrapper) - assert calls["sql"] == "SELECT 1" - # We don't pass location into client.query, so it should be None. - assert calls["location"] is None - # And fetchone() should give the first row. - assert res.fetchone() == "ROW-1" - assert res.fetchone() is None - - -@pytest.mark.unit -def test_bq_shim_executes_sequence_and_returns_last_job(): - seen: list[str] = [] - - class FakeJob: - def result(self) -> None: - return None - - class FakeClient: - def query(self, sql: str, location: str | None = None): - seen.append(sql) - return FakeJob() - - fake = FakeClient() - shim = BigQueryConnShim(cast(Client, fake), location="EU") - res = shim.execute(["SELECT 1", "SELECT 2", "SELECT 3"]) - - # should have executed all - assert seen == ["SELECT 1", "SELECT 2", "SELECT 3"] - # and returned a cursor-like wrapper over the last result - assert isinstance(res, BigQueryConnShim._ResultWrapper) - - -def test_bq_shim_raises_on_unsupported_type(): - fake_client = SimpleNamespace(query=lambda *a, **k: None) - - shim = BigQueryConnShim(client=cast(Client, fake_client)) - - with pytest.raises(TypeError): - shim.execute(123) - - -# --------------------------------------------------------------------------- -# Mixed / defensive -# --------------------------------------------------------------------------- - - -@pytest.mark.unit -def test_sa_shim_iterable_of_mixed_types(): - """Ensure iterable with strings and ClauseElements is executed in order.""" - eng = _FakeEngine() - shim = SAConnShim(cast(Engine, eng)) - - stmts: Sequence[Any] = ["SELECT 1", sa_text("SELECT 2"), "SELECT 3"] - shim.execute(stmts) - 
- executed = [sql for (sql, _) in eng.conn.executed] - # sql text may contain trailing semicolons/spaces from TextClause - assert "SELECT 1" in executed[0] - assert "SELECT 2" in executed[1] - assert "SELECT 3" in executed[2] diff --git a/tests/unit/executors/test_snowflake_snowpark_exec.py b/tests/unit/executors/test_snowflake_snowpark_exec.py index fdd1fd4..5dcc7a6 100644 --- a/tests/unit/executors/test_snowflake_snowpark_exec.py +++ b/tests/unit/executors/test_snowflake_snowpark_exec.py @@ -10,7 +10,6 @@ import pytest import fastflowtransform.executors.snowflake_snowpark as sf_mod -from fastflowtransform.executors.snowflake_snowpark import _SFResult # --------------------------------------------------------------------------- # 1) Install a fake snowflake.snowpark BEFORE importing the executor module @@ -147,7 +146,6 @@ def create(self) -> FakeSession: from fastflowtransform.core import Node # noqa: E402 from fastflowtransform.executors.snowflake_snowpark import ( # noqa: E402 SnowflakeSnowparkExecutor, - _SFCursorShim, ) @@ -184,8 +182,6 @@ def sf_exec(monkeypatch): def test_init_sets_db_schema_and_con(sf_exec): assert sf_exec.database == "DB1" assert sf_exec.schema == "SC1" - # con must be present - assert isinstance(sf_exec.con, _SFCursorShim) @pytest.mark.unit @@ -577,9 +573,8 @@ def fake_sql(sql: str): sf_exec.session.sql = fake_sql # type: ignore[assignment] # ACT - res = sf_exec.con.execute("SELECT 1") - - # ASSERT - assert isinstance(res, _SFResult) + res = sf_exec.execute_test_sql("SELECT 1") assert res.fetchall() == [(1, "x"), (2, "y")] + + res = sf_exec.execute_test_sql("SELECT 1") assert res.fetchone() == (1, "x") diff --git a/tests/unit/schema/test_schema_loader_unit.py b/tests/unit/schema/test_schema_loader_unit.py index 0153019..35c10bd 100644 --- a/tests/unit/schema/test_schema_loader_unit.py +++ b/tests/unit/schema/test_schema_loader_unit.py @@ -10,7 +10,7 @@ def test_parse_schema_yaml_column_tests(tmp_path: Path): (tmp_path / 
"models").mkdir(parents=True) (tmp_path / "models" / "users_enriched.yml").write_text( """ -version: 2 +version: 1 models: - name: users_enriched tags: [batch] diff --git a/tests/unit/test_contracts_unit.py b/tests/unit/test_contracts_unit.py new file mode 100644 index 0000000..1885935 --- /dev/null +++ b/tests/unit/test_contracts_unit.py @@ -0,0 +1,406 @@ +# tests/test_contracts_module.py +from __future__ import annotations + +import textwrap +from pathlib import Path + +import pytest + +from fastflowtransform.config.contracts import ( + ColumnContractModel, + ColumnDefaultsRuleModel, + ColumnMatchModel, + ContractsDefaultsModel, + ContractsFileModel, + PhysicalTypeConfig, +) +from fastflowtransform.contracts import ( + _apply_column_defaults, + _contract_tests_for_table, + _discover_contract_paths, + build_contract_tests, + load_contract_tests, + load_contracts, +) +from fastflowtransform.schema_loader import TestSpec as _TestSpec + +# --------------------------------------------------------------------------- +# Discovery + loading +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +def test_discover_contract_paths_no_models_dir(tmp_path: Path) -> None: + project_dir = tmp_path / "proj" + project_dir.mkdir() + + paths = _discover_contract_paths(project_dir) + assert paths == [] + + +@pytest.mark.unit +def test_discover_contract_paths_with_files(tmp_path: Path) -> None: + project_dir = tmp_path / "proj" + models_dir = project_dir / "models" / "staging" + models_dir.mkdir(parents=True, exist_ok=True) + + f1 = models_dir / "customers.contracts.yml" + f2 = models_dir / "orders.contracts.yml" + f1.write_text("version: 1\ntable: customers\ncolumns: {id: {}}\n", encoding="utf-8") + f2.write_text("version: 1\ntable: orders\ncolumns: {id: {}}\n", encoding="utf-8") + + paths = _discover_contract_paths(project_dir) + assert f1 in paths + assert f2 in paths + # deterministic order (sorted) + assert paths == sorted(paths) + 
+ +@pytest.mark.unit +def test_load_contracts_parses_and_maps_by_table(tmp_path: Path) -> None: + project_dir = tmp_path / "proj" + models_dir = project_dir / "models" + models_dir.mkdir(parents=True, exist_ok=True) + + # two different tables + (models_dir / "customers.contracts.yml").write_text( + textwrap.dedent( + """ + version: 1 + table: customers + columns: + id: {} + """ + ), + encoding="utf-8", + ) + + (models_dir / "orders.contracts.yml").write_text( + textwrap.dedent( + """ + version: 1 + table: orders + columns: + id: {} + """ + ), + encoding="utf-8", + ) + + contracts = load_contracts(project_dir) + assert set(contracts.keys()) == {"customers", "orders"} + assert isinstance(contracts["customers"], ContractsFileModel) + assert contracts["customers"].table == "customers" + assert "id" in contracts["customers"].columns + + +@pytest.mark.unit +def test_load_contracts_duplicate_tables_last_wins(tmp_path: Path) -> None: + project_dir = tmp_path / "proj" + models_dir = project_dir / "models" + models_dir.mkdir(parents=True, exist_ok=True) + + # First definition + (models_dir / "a_customers.contracts.yml").write_text( + textwrap.dedent( + """ + version: 1 + table: customers + columns: + id: + type: string + """ + ), + encoding="utf-8", + ) + + # Second definition (should win) + (models_dir / "z_customers.contracts.yml").write_text( + textwrap.dedent( + """ + version: 1 + table: customers + columns: + id: + type: integer + """ + ), + encoding="utf-8", + ) + + contracts = load_contracts(project_dir) + assert list(contracts.keys()) == ["customers"] + customers_cfg = contracts["customers"] + assert customers_cfg.columns["id"].type == "integer" + + +# --------------------------------------------------------------------------- +# Column defaults application +# --------------------------------------------------------------------------- + + +def _defaults_for_id_columns() -> ContractsDefaultsModel: + """ + Helper: one default rule for *_id columns. 
+ """ + rule = ColumnDefaultsRuleModel( + match=ColumnMatchModel(name=r".*_id$"), + type="integer", + physical=PhysicalTypeConfig(default="BIGINT"), + nullable=False, + unique=True, + enum=None, + ) + return ContractsDefaultsModel(columns=[rule]) + + +@pytest.mark.unit +def test_apply_column_defaults_no_defaults_returns_same_instance() -> None: + col = ColumnContractModel() # all None + result = _apply_column_defaults("customer_id", "customers", col, defaults=None) + + # We re-validate into a new instance, so equality on fields is what matters + assert result.type is None + assert result.nullable is None + assert result.unique is None + assert result.enum is None + assert result.physical is None + + +@pytest.mark.unit +def test_apply_column_defaults_matches_by_name_and_sets_missing_fields() -> None: + defaults = _defaults_for_id_columns() + col = ColumnContractModel() # no explicit settings + + eff = _apply_column_defaults("customer_id", "customers", col, defaults) + + assert eff.type == "integer" + assert eff.nullable is False + assert eff.unique is True + assert eff.enum is None + assert isinstance(eff.physical, PhysicalTypeConfig) + assert eff.physical.default == "BIGINT" + + +@pytest.mark.unit +def test_apply_column_defaults_respects_existing_values() -> None: + defaults = _defaults_for_id_columns() + # Explicitly set type and nullable; defaults should NOT override these. 
+ col = ColumnContractModel(type="string", nullable=True, unique=None) + + eff = _apply_column_defaults("customer_id", "customers", col, defaults) + + # Existing values are kept + assert eff.type == "string" + assert eff.nullable is True + # Only unset attributes are filled + assert eff.unique is True + assert isinstance(eff.physical, PhysicalTypeConfig) + + +@pytest.mark.unit +def test_apply_column_defaults_table_regex_filters_rules() -> None: + # Rule only applies to tables whose name matches 'orders_.*' + rule = ColumnDefaultsRuleModel( + match=ColumnMatchModel(name=r".*_id$", table=r"^orders_.*$"), + type="integer", + nullable=False, + ) + defaults = ContractsDefaultsModel(columns=[rule]) + + col = ColumnContractModel() + + # For non-matching table, no defaults applied + eff_customers = _apply_column_defaults("customer_id", "customers", col, defaults) + assert eff_customers.type is None + assert eff_customers.nullable is None + + # For matching table, defaults apply + eff_orders = _apply_column_defaults("customer_id", "orders_daily", col, defaults) + assert eff_orders.type == "integer" + assert eff_orders.nullable is False + + +@pytest.mark.unit +def test_apply_column_defaults_multiple_rules_extend_but_do_not_override_same_field() -> None: + # Rule1 sets nullable, rule2 sets type for all columns. 
+ rule1 = ColumnDefaultsRuleModel( + match=ColumnMatchModel(name=r".*"), + nullable=False, + ) + rule2 = ColumnDefaultsRuleModel( + match=ColumnMatchModel(name=r".*"), + type="integer", + ) + defaults = ContractsDefaultsModel(columns=[rule1, rule2]) + + col = ColumnContractModel() + eff = _apply_column_defaults("customer_id", "customers", col, defaults) + + # Both fields get filled (because they were None) + assert eff.nullable is False + assert eff.type == "integer" + + +# --------------------------------------------------------------------------- +# Contract → TestSpec expansion +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +def test_contract_tests_for_table_basic() -> None: + contract = ContractsFileModel( + version=1, + table="customers", + columns={ + "customer_id": ColumnContractModel(nullable=False, unique=True), + "status": ColumnContractModel(enum=["active", "inactive"]), + "amount": ColumnContractModel(min=0, max=100), + }, + ) + + specs = _contract_tests_for_table( + "customers", + contract, + defaults=None, + default_severity="error", + ) + + # We expect: + # - not_null + unique for customer_id + # - accepted_values for status + # - between for amount + types_by_col = {(s.column, s.type) for s in specs} + + assert ("customer_id", "not_null") in types_by_col + assert ("customer_id", "unique") in types_by_col + assert ("status", "accepted_values") in types_by_col + assert ("amount", "between") in types_by_col + + # All tests should have table="customers" and tag "contract" + for s in specs: + assert s.table == "customers" + assert "contract" in (s.tags or []) + + +@pytest.mark.unit +def test_contract_tests_for_table_with_physical_type_defaults() -> None: + # Column has no physical type itself; defaults will add it + contract = ContractsFileModel( + version=1, + table="customers", + columns={ + "customer_id": ColumnContractModel(nullable=False), + }, + ) + + defaults = ContractsDefaultsModel( + 
columns=[ + ColumnDefaultsRuleModel( + match=ColumnMatchModel(name=r"^customer_id$"), + # IMPORTANT: pass a mapping, not a PhysicalTypeConfig instance + physical=PhysicalTypeConfig(duckdb="BIGINT"), + ) + ] + ) + + specs = _contract_tests_for_table( + "customers", + contract, + defaults=defaults, + default_severity="error", + ) + + # Expect both column_physical_type and not_null + types_by_col = {(s.column, s.type) for s in specs} + assert ("customer_id", "column_physical_type") in types_by_col + assert ("customer_id", "not_null") in types_by_col + + physical_specs = [s for s in specs if s.type == "column_physical_type"] + assert len(physical_specs) == 1 + p = physical_specs[0] + + # params["physical"] is a PhysicalTypeConfig produced by Pydantic + phys_cfg = p.params["physical"] + assert hasattr(phys_cfg, "duckdb") + assert phys_cfg.duckdb == "BIGINT" + + +# --------------------------------------------------------------------------- +# build_contract_tests / load_contract_tests +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +def test_build_contract_tests_multiple_tables() -> None: + customers = ContractsFileModel( + version=1, + table="customers", + columns={"id": ColumnContractModel(nullable=False)}, + ) + orders = ContractsFileModel( + version=1, + table="orders", + columns={"id": ColumnContractModel(unique=True)}, + ) + + contracts = {"customers": customers, "orders": orders} + + specs = build_contract_tests(contracts, defaults=None, default_severity="warn") + + assert {s.table for s in specs} == {"customers", "orders"} + # one not_null + one unique + assert any(s.table == "customers" and s.type == "not_null" for s in specs) + assert any(s.table == "orders" and s.type == "unique" for s in specs) + # severity is propagated + assert all(s.severity == "warn" for s in specs) + + +@pytest.mark.unit +def test_load_contract_tests_integration_with_project_defaults(tmp_path: Path) -> None: + project_dir = tmp_path / 
"proj" + models_dir = project_dir / "models" / "staging" + models_dir.mkdir(parents=True, exist_ok=True) + + # Per-table contract + (models_dir / "customers.contracts.yml").write_text( + textwrap.dedent( + """ + version: 1 + table: customers + columns: + customer_id: {} + status: + enum: ["active", "inactive"] + """ + ), + encoding="utf-8", + ) + + # Project-level contracts.yml with a default for *_id columns + (project_dir / "contracts.yml").write_text( + textwrap.dedent( + """ + version: 1 + defaults: + columns: + - match: + name: ".*_id$" + nullable: false + """ + ), + encoding="utf-8", + ) + + specs = load_contract_tests(project_dir) + + # We expect: + # - not_null on customers.customer_id (from project defaults) + # - accepted_values on customers.status (from per-table enum) + by_col_type = {(s.column, s.type) for s in specs} + + assert ("customer_id", "not_null") in by_col_type + assert ("status", "accepted_values") in by_col_type + + # All specs should be TestSpec instances + assert all(isinstance(s, _TestSpec) for s in specs) diff --git a/tests/unit/test_testing_unit.py b/tests/unit/test_testing_unit.py index 837ab19..eda232f 100644 --- a/tests/unit/test_testing_unit.py +++ b/tests/unit/test_testing_unit.py @@ -1,18 +1,17 @@ # tests/unit/test_testing_unit.py from __future__ import annotations -from typing import Any +from typing import Any, cast import pytest +from fastflowtransform.executors.base import BaseExecutor from fastflowtransform.testing.base import ( TestFailure, - _exec, _fail, _pretty_sql, _scalar, accepted_values, - freshness, greater_equal, non_negative_sum, not_null, @@ -40,6 +39,29 @@ def fetchall(self) -> list[tuple]: return self._rows +class _FakeExecutor: + """ + Minimal executor-like helper for tests. 
+ + - execute_test_sql: returns handler(sql) + - execute: forwards to execute_test_sql so _scalar/_exec paths work + - optional compute_freshness_delay_minutes hook when provided + """ + + def __init__(self, handler: Any, freshness_handler: Any | None = None): + self.handler = handler + self.freshness_handler = freshness_handler + self.calls: list[Any] = [] + + def execute_test_sql(self, sql: Any) -> Any: + self.calls.append(sql) + return self.handler(sql) + + def execute(self, sql: Any) -> Any: + # allow _exec to call .execute on non-executor objects + return self.execute_test_sql(sql) + + # --------------------------------------------------------------------------- # _pretty_sql / _sql_list # --------------------------------------------------------------------------- @@ -73,129 +95,6 @@ def test_sql_list_various_types(): assert sql_list([None, "O'Reilly"]) == "NULL, 'O''Reilly'" -# --------------------------------------------------------------------------- -# _exec: branch 1 - connection has .execute -# --------------------------------------------------------------------------- - - -@pytest.mark.unit -def test_exec_direct_non_sqlalchemy(): - calls: list[Any] = [] - - class FakeCon: - def execute(self, sql): - calls.append(sql) - return _FakeResult([(1,)]) - - con = FakeCon() - res = _exec(con, "select 1") - assert isinstance(res, _FakeResult) - assert calls == ["select 1"] - - -@pytest.mark.unit -def test_exec_direct_sqlalchemy_like_string(monkeypatch): - # simulate a SA-like connection (module name contains "sqlalchemy") - class FakeSACon: - __module__ = "sqlalchemy.engine.mock" - - def __init__(self): - self.calls: list[Any] = [] - - def execute(self, stmt, params=None): - # sqlalchemy.text(...) 
should have been called - self.calls.append((stmt, params)) - return _FakeResult([(1,)]) - - con = FakeSACon() - res = _exec(con, "select 1") - assert isinstance(res, _FakeResult) - # first arg should be a TextClause - assert len(con.calls) == 1 - assert str(con.calls[0][0]).strip().lower().startswith("select 1") - - -@pytest.mark.unit -def test_exec_direct_sqlalchemy_like_tuple_params(): - class FakeSACon: - __module__ = "sqlalchemy.engine.mock" - - def __init__(self): - self.calls: list[Any] = [] - - def execute(self, stmt, params=None): - self.calls.append((stmt, params)) - return _FakeResult([(1,)]) - - con = FakeSACon() - res = _exec(con, ("select :x", {"x": 10})) - assert isinstance(res, _FakeResult) - assert len(con.calls) == 1 - sql_obj, params = con.calls[0] - assert "select :x" in str(sql_obj).lower() - assert params == {"x": 10} - - -# --------------------------------------------------------------------------- -# _exec: branch 2 - no .execute, but .begin() (SQLAlchemy fallback) -# --------------------------------------------------------------------------- - - -@pytest.mark.unit -def test_exec_fallback_begin_with_sequence(): - executed: list[str] = [] - - class FakeCtx: - def __init__(self, outer): - self.outer = outer - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - def execute(self, stmt, params=None): - # stmt may be TextClause - if hasattr(stmt, "text"): - executed.append(stmt.text) - else: - executed.append(str(stmt)) - return _FakeResult([(1,)]) - - class FakeCon: - def begin(self): - return FakeCtx(self) - - con = FakeCon() - res = _exec(con, ["select 1", "select 2"]) - assert isinstance(res, _FakeResult) - assert executed == ["select 1", "select 2"] - - -@pytest.mark.unit -def test_exec_fallback_unsupported_type_raises(): - class FakeCtx: - def __enter__(self): - """Enter.""" - return self - - def __exit__(self, exc_type, exc, tb): - """Exit.""" - return False - - def execute(self, *_a, **_k): - 
return _FakeResult([]) - - class FakeCon: - def begin(self): - return FakeCtx() - - con = FakeCon() - with pytest.raises(TypeError): - _exec(con, object()) - - # --------------------------------------------------------------------------- # _scalar # --------------------------------------------------------------------------- @@ -203,22 +102,16 @@ def begin(self): @pytest.mark.unit def test_scalar_returns_first_value(): - class FakeCon: - def execute(self, sql): - return _FakeResult([(42, "x")]) - - v = _scalar(FakeCon(), "select 42") + execu = _FakeExecutor(lambda sql: _FakeResult([(42, "x")])) + v = _scalar(cast(BaseExecutor, execu), "select 42") expected_value = 42 assert v == expected_value @pytest.mark.unit def test_scalar_returns_none_on_empty(): - class FakeCon: - def execute(self, sql): - return _FakeResult([]) - - v = _scalar(FakeCon(), "select 42") + execu = _FakeExecutor(lambda sql: _FakeResult([])) + v = _scalar(cast(BaseExecutor, execu), "select 42") assert v is None @@ -231,38 +124,20 @@ def execute(self, sql): def test_accepted_values_ok(): # first call: count(*) = 0 → ok # second call (sample) should not be executed - class FakeCon: - def __init__(self): - self.calls = 0 - - def execute(self, sql): - self.calls += 1 - if "count(*)" in sql: - return _FakeResult([(0,)]) - return _FakeResult([]) - - con = FakeCon() - accepted_values(con, "tbl", "col", values=["a", "b"]) - assert con.calls == 1 + execu = _FakeExecutor( + lambda sql: _FakeResult([(0,)]) if "count(*)" in sql else _FakeResult([]), + ) + accepted_values(cast(BaseExecutor, execu), "tbl", "col", values=["a", "b"]) + assert len(execu.calls) == 1 @pytest.mark.unit def test_accepted_values_fail_collects_samples(): - class FakeCon: - def __init__(self): - self.queries: list[str] = [] - - def execute(self, sql): - self.queries.append(sql) - if "count(*)" in sql: - return _FakeResult([(3,)]) - if "distinct" in sql: - return _FakeResult([("X",), ("Y",)]) - return _FakeResult([]) - - con = FakeCon() + 
execu = _FakeExecutor( + lambda sql: _FakeResult([(3,)]) if "count(*)" in sql else _FakeResult([("X",), ("Y",)]), + ) with pytest.raises(TestFailure) as exc: - accepted_values(con, "x.tbl", "kind", values=["A", "B"]) + accepted_values(cast(BaseExecutor, execu), "x.tbl", "kind", values=["A", "B"]) msg = str(exc.value) assert "x.tbl.kind has 3 value(s) outside accepted set" in msg # should include sample values @@ -276,33 +151,26 @@ def execute(self, sql): @pytest.mark.unit def test_not_null_ok(): - class FakeCon: - def execute(self, sql): - return _FakeResult([(0,)]) - + execu = _FakeExecutor(lambda sql: _FakeResult([(0,)])) # should not raise - not_null(FakeCon(), "tbl", "col") + not_null(cast(BaseExecutor, execu), "tbl", "col") @pytest.mark.unit def test_not_null_fails_on_nulls(): - class FakeCon: - def execute(self, sql): - return _FakeResult([(2,)]) - + execu = _FakeExecutor(lambda sql: _FakeResult([(2,)])) with pytest.raises(TestFailure) as exc: - not_null(FakeCon(), "tbl", "col") + not_null(cast(BaseExecutor, execu), "tbl", "col") assert "has 2 NULL-values" in str(exc.value) @pytest.mark.unit def test_not_null_wraps_db_error(): - class FakeCon: - def execute(self, sql): - raise RuntimeError("undefinedcolumn: foo HAVING") - + execu = _FakeExecutor( + lambda sql: (_ for _ in ()).throw(RuntimeError("undefinedcolumn: foo HAVING")) + ) with pytest.raises(TestFailure) as exc: - not_null(FakeCon(), "tbl", "col") + not_null(cast(BaseExecutor, execu), "tbl", "col") msg = str(exc.value).lower() assert "error in tbl.col" in msg assert "undefinedcolumn" in msg @@ -311,21 +179,15 @@ def execute(self, sql): @pytest.mark.unit def test_unique_ok(): - class FakeCon: - def execute(self, sql): - return _FakeResult([(0,)]) - - unique(FakeCon(), "tbl", "col") + execu = _FakeExecutor(lambda sql: _FakeResult([(0,)])) + unique(cast(BaseExecutor, execu), "tbl", "col") @pytest.mark.unit def test_unique_fails(): - class FakeCon: - def execute(self, sql): - return _FakeResult([(5,)]) - + 
execu = _FakeExecutor(lambda sql: _FakeResult([(5,)])) with pytest.raises(TestFailure) as exc: - unique(FakeCon(), "tbl", "col") + unique(cast(BaseExecutor, execu), "tbl", "col") assert "contains 5 duplicates" in str(exc.value) @@ -336,92 +198,50 @@ def execute(self, sql): @pytest.mark.unit def test_greater_equal_ok(): - class FakeCon: - def execute(self, sql): - # no rows with < threshold - return _FakeResult([(0,)]) - - greater_equal(FakeCon(), "tbl", "amount", threshold=10) + execu = _FakeExecutor(lambda sql: _FakeResult([(0,)])) + greater_equal(cast(BaseExecutor, execu), "tbl", "amount", threshold=10) @pytest.mark.unit def test_greater_equal_fails(): - class FakeCon: - def execute(self, sql): - return _FakeResult([(3,)]) - + execu = _FakeExecutor(lambda sql: _FakeResult([(3,)])) with pytest.raises(TestFailure) as exc: - greater_equal(FakeCon(), "tbl", "amount", threshold=10) + greater_equal(cast(BaseExecutor, execu), "tbl", "amount", threshold=10) assert "has 3 values < 10" in str(exc.value) @pytest.mark.unit def test_non_negative_sum_ok(): - class FakeCon: - def execute(self, sql): - return _FakeResult([(0,)]) - - non_negative_sum(FakeCon(), "tbl", "amount") + execu = _FakeExecutor(lambda sql: _FakeResult([(0,)])) + non_negative_sum(cast(BaseExecutor, execu), "tbl", "amount") @pytest.mark.unit def test_non_negative_sum_fails(): - class FakeCon: - def execute(self, sql): - return _FakeResult([(-5,)]) - + execu = _FakeExecutor(lambda sql: _FakeResult([(-5,)])) with pytest.raises(TestFailure) as exc: - non_negative_sum(FakeCon(), "tbl", "amount") + non_negative_sum(cast(BaseExecutor, execu), "tbl", "amount") assert "is negative: -5" in str(exc.value) @pytest.mark.unit def test_row_count_between_ok(): - class FakeCon: - def execute(self, sql): - return _FakeResult([(5,)]) - - row_count_between(FakeCon(), "tbl", min_rows=1, max_rows=10) + execu = _FakeExecutor(lambda sql: _FakeResult([(5,)])) + row_count_between(cast(BaseExecutor, execu), "tbl", min_rows=1, 
max_rows=10) @pytest.mark.unit def test_row_count_between_too_few(): - class FakeCon: - def execute(self, sql): - return _FakeResult([(0,)]) - + execu = _FakeExecutor(lambda sql: _FakeResult([(0,)])) with pytest.raises(TestFailure): - row_count_between(FakeCon(), "tbl", min_rows=1) + row_count_between(cast(BaseExecutor, execu), "tbl", min_rows=1) @pytest.mark.unit def test_row_count_between_too_many(): - class FakeCon: - def execute(self, sql): - return _FakeResult([(50,)]) - + execu = _FakeExecutor(lambda sql: _FakeResult([(50,)])) with pytest.raises(TestFailure): - row_count_between(FakeCon(), "tbl", min_rows=1, max_rows=10) - - -@pytest.mark.unit -def test_freshness_ok(): - class FakeCon: - def execute(self, sql): - # pretend last update was 3 min ago - return _FakeResult([(3.0,)]) - - freshness(FakeCon(), "tbl", "ts", max_delay_minutes=5) - - -@pytest.mark.unit -def test_freshness_too_old(): - class FakeCon: - def execute(self, sql): - return _FakeResult([(99.0,)]) - - with pytest.raises(TestFailure): - freshness(FakeCon(), "tbl", "ts", max_delay_minutes=10) + row_count_between(cast(BaseExecutor, execu), "tbl", min_rows=1, max_rows=10) # --------------------------------------------------------------------------- @@ -431,13 +251,9 @@ def execute(self, sql): @pytest.mark.unit def test_reconcile_equal_exact_ok(): - class FakeCon: - def execute(self, sql): - # both scalar_where calls will read this - return _FakeResult([(10,)]) - + execu = _FakeExecutor(lambda sql: _FakeResult([(10,)])) reconcile_equal( - FakeCon(), + cast(BaseExecutor, execu), left={"table": "a", "expr": "sum(x)"}, right={"table": "b", "expr": "sum(y)"}, ) @@ -445,18 +261,15 @@ def execute(self, sql): @pytest.mark.unit def test_reconcile_equal_abs_tolerance_ok(): - class FakeCon: - def __init__(self): - self.calls = 0 + calls: list[Any] = [] - def execute(self, sql): - self.calls += 1 - if self.calls == 1: - return _FakeResult([(10.0,)]) - return _FakeResult([(11.0,)]) + def handler(sql): + 
calls.append(sql) + return _FakeResult([(10.0,)]) if len(calls) == 1 else _FakeResult([(11.0,)]) + execu = _FakeExecutor(handler) reconcile_equal( - FakeCon(), + cast(BaseExecutor, execu), left={"table": "a", "expr": "v"}, right={"table": "b", "expr": "v"}, abs_tolerance=1.5, @@ -465,19 +278,16 @@ def execute(self, sql): @pytest.mark.unit def test_reconcile_equal_fails(): - class FakeCon: - def __init__(self): - self.calls = 0 + calls: list[Any] = [] - def execute(self, sql): - self.calls += 1 - if self.calls == 1: - return _FakeResult([(10.0,)]) - return _FakeResult([(20.0,)]) + def handler(sql): + calls.append(sql) + return _FakeResult([(10.0,)]) if len(calls) == 1 else _FakeResult([(20.0,)]) + execu = _FakeExecutor(handler) with pytest.raises(TestFailure): reconcile_equal( - FakeCon(), + cast(BaseExecutor, execu), left={"table": "a", "expr": "v"}, right={"table": "b", "expr": "v"}, ) @@ -485,19 +295,15 @@ def execute(self, sql): @pytest.mark.unit def test_reconcile_ratio_within_ok(): - class FakeCon: - def __init__(self): - self.calls = 0 + calls: list[Any] = [] - def execute(self, sql): - self.calls += 1 - if self.calls == 1: - return _FakeResult([(100.0,)]) - return _FakeResult([(50.0,)]) + def handler(sql): + calls.append(sql) + return _FakeResult([(100.0,)]) if len(calls) == 1 else _FakeResult([(50.0,)]) - # ratio = 100 / 50 = 2.0 + execu = _FakeExecutor(handler) reconcile_ratio_within( - FakeCon(), + cast(BaseExecutor, execu), left={"table": "l", "expr": "x"}, right={"table": "r", "expr": "y"}, min_ratio=1.5, @@ -507,15 +313,13 @@ def execute(self, sql): @pytest.mark.unit def test_reconcile_ratio_within_fails(): - class FakeCon: - def execute(self, sql): - if "from l" in sql: - return _FakeResult([(10.0,)]) - return _FakeResult([(100.0,)]) + execu = _FakeExecutor( + lambda sql: _FakeResult([(10.0,)]) if "from l" in sql else _FakeResult([(100.0,)]) + ) with pytest.raises(TestFailure): reconcile_ratio_within( - FakeCon(), + cast(BaseExecutor, execu), 
left={"table": "l", "expr": "x"}, right={"table": "r", "expr": "y"}, min_ratio=0.5, @@ -525,18 +329,15 @@ def execute(self, sql): @pytest.mark.unit def test_reconcile_diff_within_ok(): - class FakeCon: - def __init__(self): - self.calls = 0 + calls: list[Any] = [] - def execute(self, sql): - self.calls += 1 - if self.calls == 1: - return _FakeResult([(50.0,)]) - return _FakeResult([(53.0,)]) + def handler(sql): + calls.append(sql) + return _FakeResult([(50.0,)]) if len(calls) == 1 else _FakeResult([(53.0,)]) + execu = _FakeExecutor(handler) reconcile_diff_within( - FakeCon(), + cast(BaseExecutor, execu), left={"table": "l", "expr": "x"}, right={"table": "r", "expr": "y"}, max_abs_diff=5.0, @@ -545,15 +346,13 @@ def execute(self, sql): @pytest.mark.unit def test_reconcile_diff_within_fails(): - class FakeCon: - def execute(self, sql): - if "from l" in sql: - return _FakeResult([(10.0,)]) - return _FakeResult([(25.0,)]) + execu = _FakeExecutor( + lambda sql: _FakeResult([(10.0,)]) if "from l" in sql else _FakeResult([(25.0,)]) + ) with pytest.raises(TestFailure): reconcile_diff_within( - FakeCon(), + cast(BaseExecutor, execu), left={"table": "l", "expr": "x"}, right={"table": "r", "expr": "y"}, max_abs_diff=5.0, @@ -562,13 +361,9 @@ def execute(self, sql): @pytest.mark.unit def test_reconcile_coverage_ok(): - class FakeCon: - def execute(self, sql): - # anti-join count(*) == 0 - return _FakeResult([(0,)]) - + execu = _FakeExecutor(lambda sql: _FakeResult([(0,)])) reconcile_coverage( - FakeCon(), + cast(BaseExecutor, execu), source={"table": "src", "key": "id"}, target={"table": "tgt", "key": "id"}, ) @@ -576,13 +371,11 @@ def execute(self, sql): @pytest.mark.unit def test_reconcile_coverage_fails(): - class FakeCon: - def execute(self, sql): - return _FakeResult([(3,)]) + execu = _FakeExecutor(lambda sql: _FakeResult([(3,)])) with pytest.raises(TestFailure): reconcile_coverage( - FakeCon(), + cast(BaseExecutor, execu), source={"table": "src", "key": "id"}, 
target={"table": "tgt", "key": "id"}, ) @@ -590,12 +383,10 @@ def execute(self, sql): @pytest.mark.unit def test_relationships_ok(): - class FakeCon: - def execute(self, sql): - return _FakeResult([(0,)]) + execu = _FakeExecutor(lambda sql: _FakeResult([(0,)])) relationships( - FakeCon(), + cast(BaseExecutor, execu), table="fact_events", field="user_id", to_table="dim_users", @@ -605,13 +396,11 @@ def execute(self, sql): @pytest.mark.unit def test_relationships_fails_on_orphans(): - class FakeCon: - def execute(self, sql): - return _FakeResult([(5,)]) + execu = _FakeExecutor(lambda sql: _FakeResult([(5,)])) with pytest.raises(TestFailure): relationships( - FakeCon(), + cast(BaseExecutor, execu), table="fact_events", field="user_id", to_table="dim_users", @@ -621,13 +410,11 @@ def execute(self, sql): @pytest.mark.unit def test_relationships_wraps_db_errors(): - class FakeCon: - def execute(self, sql): - raise RuntimeError("no such column") + execu = _FakeExecutor(lambda sql: (_ for _ in ()).throw(RuntimeError("no such column"))) with pytest.raises(TestFailure) as exc: relationships( - FakeCon(), + cast(BaseExecutor, execu), table="fact_events", field="user_id", to_table="dim_users", diff --git a/tests/unit/testing/test_accepted_values_unit.py b/tests/unit/testing/test_accepted_values_unit.py index aef12c5..331396a 100644 --- a/tests/unit/testing/test_accepted_values_unit.py +++ b/tests/unit/testing/test_accepted_values_unit.py @@ -10,8 +10,8 @@ def test_accepted_values_pass_and_fail(): ex.con.execute("create table t(id int, email varchar)") ex.con.execute("insert into t values (1,'a@example.com'),(2,'b@example.com')") # Pass - accepted_values(ex.con, "t", "email", values=["a@example.com", "b@example.com"]) + accepted_values(ex, "t", "email", values=["a@example.com", "b@example.com"]) # Fail with pytest.raises(TestFailure): - accepted_values(ex.con, "t", "email", values=["a@example.com"]) + accepted_values(ex, "t", "email", values=["a@example.com"]) diff --git 
a/uv.lock b/uv.lock index 31076da..9cb4ee5 100644 --- a/uv.lock +++ b/uv.lock @@ -733,7 +733,7 @@ wheels = [ [[package]] name = "fastflowtransform" -version = "0.6.11" +version = "0.6.13" source = { editable = "." } dependencies = [ { name = "duckdb" }, From 3b2b70efd234409850e3dd0af60325f1eb78bdcf Mon Sep 17 00:00:00 2001 From: Marko Lekic Date: Mon, 8 Dec 2025 14:54:48 +0100 Subject: [PATCH 2/7] refactored seeding --- docs/Quickstart.md | 15 +- src/fastflowtransform/executors/base.py | 12 + .../executors/bigquery/_bigquery_mixin.py | 13 + .../executors/bigquery/base.py | 20 + .../executors/databricks_spark.py | 41 + src/fastflowtransform/executors/duckdb.py | 26 + src/fastflowtransform/executors/postgres.py | 24 + .../executors/snowflake_snowpark.py | 31 + src/fastflowtransform/seeding.py | 699 ++---------------- tests/unit/test_seeding_unit.py | 357 +-------- 10 files changed, 227 insertions(+), 1011 deletions(-) diff --git a/docs/Quickstart.md b/docs/Quickstart.md index 9ea17ab..6c61a43 100644 --- a/docs/Quickstart.md +++ b/docs/Quickstart.md @@ -18,7 +18,7 @@ The command is non-interactive, refuses to overwrite existing directories, and l python3 -m venv .venv . .venv/bin/activate # or source .venv/bin/activate pip install --upgrade pip -pip install -e . 
# run from the repo root; use `uv pip install --editable .` if you prefer uv +pip install fastflowtransform fft --help ``` @@ -26,20 +26,20 @@ Choose extras if you target other engines (combine as needed): ```bash # Postgres -pip install -e .[postgres] +pip install "fastflowtransform[postgres]" # BigQuery (pandas) or BigFrames -pip install -e .[bigquery] # pandas -pip install -e .[bigquery_bf] # BigFrames +pip install "fastflowtransform[bigquery]" # pandas +pip install "fastflowtransform[bigquery_bf]" # BigFrames # Databricks/Spark + Delta -pip install -e .[spark] +pip install "fastflowtransform[spark]" # Snowflake Snowpark -pip install -e .[snowflake] +pip install "fastflowtransform[snowflake]" # Everything -pip install -e .[full] +pip install "fastflowtransform[full]" ``` ## 2. Create project layout @@ -106,7 +106,6 @@ You should see log lines similar to `✓ L01 [DUCK] users.ff`. The resulting tab ## 7. Next steps - Add `project.yml` for reusable `vars:` and metadata -- Explore `fft docs` to generate HTML documentation - Use engine profiles under `profiles.yml` to target Postgres, BigQuery, or Databricks (path-based sources supported via `format` + `location` overrides) - Render the DAG site for this project: `fft dag demo --env dev --html` (find it under `demo/site/dag/index.html`) diff --git a/src/fastflowtransform/executors/base.py b/src/fastflowtransform/executors/base.py index 64f63df..306cf4c 100644 --- a/src/fastflowtransform/executors/base.py +++ b/src/fastflowtransform/executors/base.py @@ -1132,6 +1132,18 @@ def introspect_column_physical_type(self, table: str, column: str) -> str | None f"engine '{self.engine_name}'." ) + # ── Seed loading hook ─────────────────────────────────────────────── + def load_seed( + self, table: str, df: Any, schema: str | None = None + ) -> tuple[bool, str, bool]: # pragma: no cover - interface + """ + Materialize a seed DataFrame into the target engine. 
Executors that + support seeds should override and return True when handled. + """ + raise NotImplementedError( + f"Seeding is not implemented for executor engine '{self.engine_name}'." + ) + ENGINE_NAME: str = "generic" @property diff --git a/src/fastflowtransform/executors/bigquery/_bigquery_mixin.py b/src/fastflowtransform/executors/bigquery/_bigquery_mixin.py index 4190192..b4c0c2b 100644 --- a/src/fastflowtransform/executors/bigquery/_bigquery_mixin.py +++ b/src/fastflowtransform/executors/bigquery/_bigquery_mixin.py @@ -54,6 +54,19 @@ def _qualified_identifier( ) -> str: return self._qualify_identifier(relation, schema=dataset, catalog=project) + def _qualified_api_identifier( + self, relation: str, project: str | None = None, dataset: str | None = None + ) -> str: + """ + Build an API-safe identifier (project.dataset.table) without backticks. + """ + return self._qualify_identifier( + relation, + schema=dataset, + catalog=project, + quote=False, + ) + def _ensure_dataset(self) -> None: ds_id = f"{self.project}.{self.dataset}" try: diff --git a/src/fastflowtransform/executors/bigquery/base.py b/src/fastflowtransform/executors/bigquery/base.py index 5d997a2..1a16a9b 100644 --- a/src/fastflowtransform/executors/bigquery/base.py +++ b/src/fastflowtransform/executors/bigquery/base.py @@ -434,3 +434,23 @@ def introspect_column_physical_type(self, table: str, column: str) -> str | None if not rows: return None return rows[0][0] + + def load_seed(self, table: str, df: Any, schema: str | None = None) -> tuple[bool, str, bool]: + dataset_id = schema or self.dataset + + table_id = self._qualified_api_identifier( + table, + project=self.project, + dataset=dataset_id, + ) + full_name = table_id + self._ensure_dataset() + + job_config = bigquery.LoadJobConfig( + write_disposition=bigquery.WriteDisposition.WRITE_TRUNCATE + ) + + load_job = self.client.load_table_from_dataframe(df, table_id, job_config=job_config) + load_job.result() + + return True, full_name, False diff 
--git a/src/fastflowtransform/executors/databricks_spark.py b/src/fastflowtransform/executors/databricks_spark.py index 2e527ce..5b6e4a4 100644 --- a/src/fastflowtransform/executors/databricks_spark.py +++ b/src/fastflowtransform/executors/databricks_spark.py @@ -1269,6 +1269,47 @@ def execute_hook_sql(self, sql: str) -> None: # Reuse your existing single-statement executor self._execute_sql(stmt) + def load_seed( + self, table: str, df: pd.DataFrame, schema: str | None = None + ) -> tuple[bool, str, bool]: + cleaned_table = self._strip_quotes(table) + parts = self._identifier_parts(cleaned_table) + + created_schema = False + if schema and len(parts) == 1: + schema_part = self._strip_quotes(schema) + if schema_part: + # Ensure database exists when a separate schema is provided. + self._execute_sql(f"CREATE DATABASE IF NOT EXISTS {self._q_ident(schema_part)}") + created_schema = True + parts = [schema_part, parts[0]] + + if not parts: + raise ValueError(f"Invalid Spark table identifier: {table}") + + target_identifier = ".".join(parts) + target_sql = self._sql_identifier(target_identifier) + format_handler = getattr(self, "_format_handler", None) + + storage_meta = storage.get_seed_storage(target_identifier) + + sdf = self.spark.createDataFrame(df) + + allows_unmanaged = bool(getattr(format_handler, "allows_unmanaged_paths", lambda: True)()) + + if storage_meta.get("path") and allows_unmanaged: + try: + self._write_to_storage_path(target_identifier, sdf, storage_meta) + except Exception as exc: # pragma: no cover + raise RuntimeError(f"Spark seed load failed for {target_sql}: {exc}") from exc + else: + try: + self._save_df_as_table(target_identifier, sdf, storage={"path": None}) + except Exception as exc: # pragma: no cover + raise RuntimeError(f"Spark seed load failed for {target_sql}: {exc}") from exc + + return True, target_identifier, created_schema + # ---- Unit-test helpers ------------------------------------------------- def 
utest_load_relation_from_rows(self, relation: str, rows: list[dict]) -> None: diff --git a/src/fastflowtransform/executors/duckdb.py b/src/fastflowtransform/executors/duckdb.py index e99437d..8f6ebc8 100644 --- a/src/fastflowtransform/executors/duckdb.py +++ b/src/fastflowtransform/executors/duckdb.py @@ -741,3 +741,29 @@ def introspect_column_physical_type(self, table: str, column: str) -> str | None ).fetchall() return rows[0][0] if rows else None + + def load_seed( + self, table: str, df: pd.DataFrame, schema: str | None = None + ) -> tuple[bool, str, bool]: + target_schema = schema or self.schema + created_schema = False + + # Qualify identifier with optional schema/catalog + qualified = self._qualify_identifier(table, schema=target_schema, catalog=self.catalog) + + if target_schema and "." not in table: + safe_schema = _q(target_schema) + self._execute_sql(f"create schema if not exists {safe_schema}") + created_schema = True + + tmp = f"_ff_seed_{uuid.uuid4().hex[:8]}" + self.con.register(tmp, df) + try: + self._execute_sql(f'create or replace table {qualified} as select * from "{tmp}"') + finally: + with suppress(Exception): + self.con.unregister(tmp) + with suppress(Exception): + self._execute_sql(f'drop view if exists "{tmp}"') + + return True, qualified, created_schema diff --git a/src/fastflowtransform/executors/postgres.py b/src/fastflowtransform/executors/postgres.py index 0d3f7ef..9e83296 100644 --- a/src/fastflowtransform/executors/postgres.py +++ b/src/fastflowtransform/executors/postgres.py @@ -371,6 +371,30 @@ def _record_dataframe_stats(self, df: pd.DataFrame, duration_ms: int) -> None: ) ) + def load_seed( + self, table: str, df: pd.DataFrame, schema: str | None = None + ) -> tuple[bool, str, bool]: + target_schema = schema or self.schema + qualified = self._qualify_identifier(table, schema=target_schema) + + drop_sql = f"DROP TABLE IF EXISTS {qualified} CASCADE" + with self.engine.begin() as conn: + conn.exec_driver_sql(drop_sql) + + df.to_sql( 
+ table, + self.engine, + if_exists="replace", + index=False, + schema=target_schema, + method="multi", + ) + + with self.engine.begin() as conn: + conn.exec_driver_sql(f"ANALYZE {qualified}") + + return True, qualified, False + # ---------- Python view helper ---------- def _create_or_replace_view_from_table( self, view_name: str, backing_table: str, node: Node diff --git a/src/fastflowtransform/executors/snowflake_snowpark.py b/src/fastflowtransform/executors/snowflake_snowpark.py index f6aeac2..602e4e9 100644 --- a/src/fastflowtransform/executors/snowflake_snowpark.py +++ b/src/fastflowtransform/executors/snowflake_snowpark.py @@ -674,3 +674,34 @@ def introspect_column_physical_type(self, table: str, column: str) -> str | None return None # first column of first row return str(rows[0][0]) if rows[0] and rows[0][0] is not None else None + + def load_seed( + self, table: str, df: pd.DataFrame, schema: str | None = None + ) -> tuple[bool, str, bool]: + target_db = self.database + target_schema = schema or self.schema + + if not target_db or not target_schema: + raise RuntimeError("Snowflake seeding requires database and schema.") + + created_schema = False + if self.allow_create_schema: + self.session.sql( + f'CREATE SCHEMA IF NOT EXISTS "{target_db}"."{target_schema}"' + ).collect() + created_schema = True + + self.session.write_pandas( + df, + table_name=table, + database=target_db, + schema=target_schema, + auto_create_table=True, + quote_identifiers=False, + overwrite=True, + use_logical_type=True, + ) + + full_name = f'"{target_db}"."{target_schema}"."{table}"' + + return True, full_name, created_schema diff --git a/src/fastflowtransform/seeding.py b/src/fastflowtransform/seeding.py index 98d88da..0e2702b 100644 --- a/src/fastflowtransform/seeding.py +++ b/src/fastflowtransform/seeding.py @@ -3,24 +3,17 @@ import math import os -import shutil -import uuid -from collections.abc import Callable, Iterable from contextlib import suppress from datetime import UTC, 
datetime from pathlib import Path from time import perf_counter from typing import Any, NamedTuple, cast -from urllib.parse import unquote, urlparse import pandas as pd -from fastflowtransform import storage from fastflowtransform.config.seeds import SeedColumnConfig, SeedsSchemaConfig, load_seeds_schema -from fastflowtransform.executors.snowflake_snowpark import SnowflakeSnowparkExecutor from fastflowtransform.logging import echo from fastflowtransform.settings import EngineType -from fastflowtransform.typing import SDF, SparkAnalysisException, SparkSession # ----------------------------- File I/O & Schema (dtypes) ----------------------------- @@ -34,10 +27,34 @@ def _read_seed_file(path: Path) -> pd.DataFrame: raise ValueError(f"Unsupported seed file format: {path.name}") +def _table_key_candidates(table: str, seed_id: str | None = None) -> list[str]: + """Generate lookup keys for schema.yml sections (prioritize path-based IDs).""" + candidates: list[str] = [] + if seed_id: + candidates.append(seed_id) + if "/" in seed_id: + candidates.append(seed_id.replace("/", ".")) + candidates.append(table) + return [c for c in candidates if c] + + +def _resolve_dtypes_for_table( + schema_cfg: SeedsSchemaConfig | None, table: str, seed_id: str | None +) -> dict[str, str]: + if not schema_cfg: + return {} + for key in _table_key_candidates(table, seed_id): + dtypes = schema_cfg.dtypes.get(key) + if dtypes: + return dtypes + return {} + + def _apply_schema( df: pd.DataFrame, table: str, schema_cfg: SeedsSchemaConfig | None, + seed_id: str | None, ) -> pd.DataFrame: """ Apply optional pandas dtypes from seeds/schema.yml for a given table key. @@ -51,9 +68,7 @@ def _apply_schema( Casting errors are swallowed on purpose to avoid blocking seed loads. 
""" - if not schema_cfg: - return df - dtypes: dict[str, str] = schema_cfg.dtypes.get(table) or {} + dtypes = _resolve_dtypes_for_table(schema_cfg, table, seed_id) if not dtypes: return df @@ -98,12 +113,15 @@ def _canonical_type(value: str | None) -> str: def _column_schema_for_table( - schema_cfg: SeedsSchemaConfig | None, - table: str, + schema_cfg: SeedsSchemaConfig | None, table: str, seed_id: str | None ) -> dict[str, SeedColumnConfig]: if not schema_cfg: return {} - return schema_cfg.columns.get(table) or {} + for key in _table_key_candidates(table, seed_id): + column_cfg = schema_cfg.columns.get(key) + if column_cfg: + return column_cfg + return {} def _resolve_column_type_for_engine( @@ -145,12 +163,13 @@ def _apply_column_schema( table: str, schema_cfg: SeedsSchemaConfig | None, executor: Any, + seed_id: str | None, ) -> pd.DataFrame: - column_cfg = _column_schema_for_table(schema_cfg, table) + column_cfg = _column_schema_for_table(schema_cfg, table, seed_id) if not column_cfg: return df - engine = _engine_name_from_executor(executor) + engine = executor.engine_name missing = [col for col in column_cfg if col not in df.columns] if missing: cols = ", ".join(missing) @@ -160,8 +179,8 @@ def _apply_column_schema( ) df_out = df.copy() - for col in df.columns: - target_type = _resolve_column_type_for_engine(column_cfg.get(col), engine) + for col, col_cfg in column_cfg.items(): + target_type = _resolve_column_type_for_engine(col_cfg, engine) try: df_out[col] = _cast_series_to_type(df_out[col], target_type, table, col) except Exception as exc: @@ -182,100 +201,6 @@ def _inject_seed_metadata(df: pd.DataFrame, seed_id: str, path: Path) -> pd.Data return df_out -# -------------------------------- Identifier utilities -------------------------------- - - -def _dq(ident: str) -> str: - """Double-quote an SQL identifier (DuckDB/Postgres compatible).""" - return '"' + ident.replace('"', '""') + '"' - - -def _is_qualified(name: str) -> bool: - """Return True if the 
provided table name appears to be schema-qualified.""" - return "." in name - - -def _qualify(table: str, schema: str | None, catalog: str | None = None) -> str: - """ - Return a safely quoted, optionally schema-qualified identifier. - - Respects already-qualified names like raw.users or "raw"."users". - - Quotes each identifier part individually. - """ - if _is_qualified(table): - return ".".join(_dq(p) for p in table.split(".")) - catalog_part = catalog.strip() if isinstance(catalog, str) and catalog.strip() else None - if schema: - schema_part = schema.strip() - parts: list[str] = [] - if catalog_part: - parts.append(_dq(catalog_part)) - parts.append(_dq(schema_part)) - parts.append(_dq(table)) - return ".".join(parts) - if catalog_part: - return f"{_dq(catalog_part)}.{_dq(table)}" - return _dq(table) - - -def _spark_warehouse_base(spark: Any) -> Path | None: - """Resolve the Spark warehouse directory if it points to the local filesystem.""" - try: - conf_val = spark.conf.get("spark.sql.warehouse.dir", "spark-warehouse") - except Exception: - conf_val = "spark-warehouse" - - if not isinstance(conf_val, str): - conf_val = str(conf_val) - parsed = urlparse(conf_val) - scheme = (parsed.scheme or "").lower() - - if scheme and scheme != "file": - return None - - if scheme == "file": - # Treat file:// URIs as local filesystem paths. - if parsed.netloc and parsed.netloc not in {"", "localhost"}: - return None - raw_path = unquote(parsed.path or "") - if not raw_path: - return None - base = Path(raw_path) - else: - base = Path(conf_val) - - if not base.is_absolute(): - base = Path.cwd() / base - return base - - -def _spark_table_location(parts: list[str], spark: Any) -> Path | None: - """ - Best-effort guess of the filesystem location for a managed Spark table. - Works for default schema, schema.table, and catalog.schema.table patterns. 
- """ - base = _spark_warehouse_base(spark) - if base is None or not parts: - return None - - filtered = [p for p in parts if p] - if not filtered: - return None - - # Drop common catalog prefixes while retaining the schema name. - catalog_cutoff = 3 - if len(filtered) >= catalog_cutoff and filtered[0].lower() in {"spark_catalog", "spark"}: - filtered = filtered[1:] - - table = filtered[-1] - schema_cutoff = 2 - schema = filtered[-2] if len(filtered) >= schema_cutoff else None - - location = base - if schema: - location = location / f"{schema}.db" - return location / table - - # -------------------------------- Pretty echo helpers --------------------------------- @@ -337,46 +262,6 @@ class SeedTarget(NamedTuple): table: str -def _engine_name_from_executor(executor: Any) -> str: - """ - Infer a canonical engine name from the executor object. - - Preference: - 1) executor.engine_name (BaseExecutor-derived) - 2) Spark hint → "databricks_spark" - 3) SQLAlchemy dialect ("postgresql" → "postgres", "bigquery" → "bigquery") - 4) DuckDB heuristic (executor.con present) → "duckdb" - 5) "unknown" as last resort - """ - # 1) BaseExecutor-style engine_name - engine_name = getattr(executor, "engine_name", None) - if isinstance(engine_name, str) and engine_name.strip(): - return engine_name.strip() - - # 2) Spark-style executor - if getattr(executor, "spark", None) is not None: - return "databricks_spark" - - # 3) SQLAlchemy-based executors - eng = getattr(executor, "engine", None) - if eng is not None: - dialect_name = getattr(getattr(eng, "dialect", None), "name", None) - if isinstance(dialect_name, str) and dialect_name: - low = dialect_name.lower() - if low.startswith("postgres"): - return "postgres" - if low.startswith("bigquery"): - return "bigquery" - return low - - # 4) DuckDB-ish: has a .con (DuckDBPyConnection or similar) - if getattr(executor, "con", None) is not None: - return "duckdb" - - # 5) Fallback - return "unknown" - - def _seed_id(seeds_dir: Path, path: Path) 
-> str: """ Build a unique seed ID from the path relative to `seeds/`, without the extension. @@ -417,7 +302,6 @@ def _resolve_schema_and_table_by_cfg( return schema, table targets = schema_cfg.targets - engine = _engine_name_from_executor(executor) entry = targets.get(seed_id) if not entry: @@ -433,525 +317,32 @@ def _resolve_schema_and_table_by_cfg( return schema, table table = entry.table or table - engine = _engine_name_from_executor(executor) + engine = executor.engine_name engine_key = cast(EngineType, engine) schema = entry.schema_by_engine.get(engine_key) or entry.schema_ or schema return schema, table -# ------------------------------ Materialization (engines) ------------------------------ - -# ------------------------------------------------------------ -# Engine-specifig Handlers -# ------------------------------------------------------------ - - -def _handle_duckdb(table: str, df: pd.DataFrame, executor: Any, schema: str | None) -> bool: - con = getattr(executor, "con", None) - if con is None: - return False - - try: - import duckdb as _dd # Noqa PLC0415 - - is_duck_con = isinstance(con, _dd.DuckDBPyConnection) - except Exception: - is_duck_con = all(hasattr(con, m) for m in ("register", "execute")) - - if not is_duck_con: - return False - - catalog = getattr(executor, "catalog", None) - full_name = _qualify(table, schema, catalog) - created_schema = False - if schema and not _is_qualified(table): - con.execute(f"create schema if not exists {_dq(schema)}") - created_schema = True - - t0 = perf_counter() - tmp = f"_ff_seed_{uuid.uuid4().hex[:8]}" - con.register(tmp, df) - try: - con.execute(f"create or replace table {full_name} as select * from {_dq(tmp)}") - finally: - with suppress(Exception): - con.unregister(tmp) # duckdb >= 0.8 - with suppress(Exception): - con.execute(f"drop view if exists {_dq(tmp)}") - - dt_ms = int((perf_counter() - t0) * 1000) - _echo_seed_line( - full_name=full_name, - rows=len(df), - cols=df.shape[1], - engine="duckdb", - 
ms=dt_ms, - created_schema=created_schema, - action="replaced", - ) - return True - - -def _handle_sqlalchemy(table: str, df: pd.DataFrame, executor: Any, schema: str | None) -> bool: - eng = getattr(executor, "engine", None) - if eng is None: - return False - if "sqlalchemy" not in getattr(eng.__class__, "__module__", ""): - return False - - full_name = _qualify(table, schema) - dialect_name = getattr(getattr(eng, "dialect", None), "name", "") or "" - if dialect_name.lower() == "postgresql": - # Postgres blocks DROP TABLE when dependent views exist (e.g. stg_* views). - drop_sql = f"DROP TABLE IF EXISTS {full_name} CASCADE" - with eng.begin() as conn: - conn.exec_driver_sql(drop_sql) - - t0 = perf_counter() - df.to_sql(table, eng, if_exists="replace", index=False, schema=schema, method="multi") - dt_ms = int((perf_counter() - t0) * 1000) - - if dialect_name.lower() == "postgresql": - with eng.begin() as conn: - conn.exec_driver_sql(f"ANALYZE {full_name}") - - dialect = dialect_name or getattr(getattr(eng, "dialect", None), "name", "sqlalchemy") - _echo_seed_line( - full_name=full_name, - rows=len(df), - cols=df.shape[1], - engine=dialect, - ms=dt_ms, - created_schema=False, - action="replaced", - ) - return True - - -def _handle_bigquery(table: str, df: pd.DataFrame, executor: Any, schema: str | None) -> bool: - """ - Handle seeding for the BigQuery executor using the official client. - - We detect BigQuery by the presence of an attribute named ``client`` - that behaves like ``google.cloud.bigquery.Client``. The target dataset - is resolved as: - - 1) the provided ``schema`` argument (preferred; allows seeds/schema.yml - to control datasets explicitly), or - 2) an executor attribute such as ``dataset`` / ``dataset_id``. - - Notes: - - The dataset must already exist; this function does not create it. - - We use WRITE_TRUNCATE semantics (replace) to mirror the behavior of - the DuckDB / SQLAlchemy handlers. 
- """ - client = getattr(executor, "client", None) - if client is None: - return False - - # Prefer explicit schema from the caller / seeds/schema.yml. - dataset_id = ( - schema or getattr(executor, "dataset", None) or getattr(executor, "dataset_id", None) - ) - if not isinstance(dataset_id, str) or not dataset_id.strip(): - # Not a BigQuery executor we know how to handle. - return False - - dataset_id = dataset_id.strip() - - # Project: executor may expose it explicitly; otherwise fall back to the - # client project (Application Default Credentials, etc.). - project_id = getattr(executor, "project", None) - if not isinstance(project_id, str) or not project_id.strip(): - project_id = getattr(client, "project", None) - - if isinstance(project_id, str) and project_id.strip(): - table_id = f"{project_id.strip()}.{dataset_id}.{table}" - full_name = table_id - else: - # Dataset-qualified ID still works if a default project is set on the client. - table_id = f"{dataset_id}.{table}" - full_name = table_id - - try: - from google.cloud import bigquery # noqa PLC0415 type: ignore # pragma: no cover - except Exception as exc: # pragma: no cover - missing optional dependency - raise RuntimeError( - "google-cloud-bigquery is required for seeding into BigQuery, " - "but it is not installed. Install the BigQuery extras for " - "FastFlowTransform or add google-cloud-bigquery to your environment." - ) from exc - - job_config = bigquery.LoadJobConfig(write_disposition=bigquery.WriteDisposition.WRITE_TRUNCATE) - - t0 = perf_counter() - # Let the BigQuery client infer the schema from the pandas DataFrame. 
- executor._ensure_dataset() - load_job = client.load_table_from_dataframe(df, table_id, job_config=job_config) - load_job.result() # Wait for completion - dt_ms = int((perf_counter() - t0) * 1000) - - _echo_seed_line( - full_name=full_name, - rows=len(df), - cols=df.shape[1], - engine="bigquery", - ms=dt_ms, - created_schema=False, - action="replaced", - ) - return True - - -def _spark_ident(name: str) -> str: - """Return a Spark-safe identifier (escapes backticks).""" - return name.replace("`", "``") - - -def _prepare_spark_target( - table: str, - schema: str | None, - executor: Any, - spark: Any, -) -> tuple[str, str, Any, bool]: - """ - Build Spark target identifiers and detect the table location. - - Returns: - target_identifier: unquoted table identifier (db.table or table) - target_sql: quoted SQL identifier with backticks - target_location: filesystem location (may be None) - created_schema: whether a database schema was created implicitly - """ - created_schema = False - - if schema and not _is_qualified(table): - spark.sql(f"CREATE DATABASE IF NOT EXISTS `{_spark_ident(schema)}`") - created_schema = True - parts = [schema, table] - else: - parts = table.split(".") - - parts = [p for p in parts if p] - target_identifier = ".".join(parts) - target_sql = ".".join(f"`{_spark_ident(p)}`" for p in parts) - target_location = _spark_table_location(parts, spark) - return target_identifier, target_sql, target_location, created_schema - - -def _write_spark_seed_to_path( - spark: Any, - target_identifier: str, - sdf: Any, - storage_meta: dict[str, Any], - table_format: str | None, - table_options: dict[str, Any], -) -> str: - """Write the seed via a custom storage path configuration.""" - storage.spark_write_to_path( - spark, - target_identifier, - sdf, - storage=storage_meta, - default_format=table_format, - default_options=table_options, - ) - return "custom path" - - -def _write_spark_seed_managed( - executor: Any, - spark: SparkSession, - target_identifier: str, 
- sdf: SDF, - table_format: str | None, - table_options: dict[str, Any], -) -> str | None: - """ - Write the seed as a *managed* Spark table (no custom storage path). - - For engines like DatabricksSparkExecutor this ensures that the - table_format handler (Delta / Iceberg / etc.) is used, so Iceberg - seeds become proper Iceberg tables. - """ - try: - # Prefer engine-specific helper when available (e.g. DatabricksSparkExecutor) - save_df = getattr(executor, "_save_df_as_table", None) - if callable(save_df): - # Pass a truthy storage dict with path=None so we do *not* get - # redirected back into path-based storage again. - save_df( - target_identifier, - sdf, - storage={"path": None, "options": table_options or {}}, - ) - else: - # Generic Spark fallback: may not be format-aware, but keeps behavior - writer = sdf.write.mode("overwrite") - if table_format: - writer = writer.format(table_format) - if table_options: - writer = writer.options(**table_options) - writer.saveAsTable(target_identifier) - except Exception as exc: - raise RuntimeError( - f"Failed to materialize Spark seed '{target_identifier}' as managed table: {exc}" - ) from exc - # No temporary path to clean up for managed tables - return None - - -def _spark_write_table( - sdf: Any, - target_identifier: str, - table_format: str | None, - table_options: dict[str, Any], +def materialize_seed( + table: str, df: pd.DataFrame, executor: Any, schema: str | None = None ) -> None: - """Perform the actual Spark saveAsTable call with configured options.""" - writer = sdf.write.mode("overwrite") - if table_format: - writer = writer.format(table_format) - if table_options: - writer = writer.options(**table_options) - writer.saveAsTable(target_identifier) - - -def _reset_spark_table_and_location( - spark: Any, - target_sql: str, - target_location: Any, -) -> str | None: - """ - Drop the Spark table and remove the underlying location if possible. - - Returns: - A cleanup hint string (e.g. 'reset location') or None. 
- """ - with suppress(Exception): - spark.sql(f"DROP TABLE IF EXISTS {target_sql}") - - cleanup_hint: str | None = None - if target_location and target_location.exists(): - with suppress(Exception): - shutil.rmtree(target_location, ignore_errors=True) - cleanup_hint = "reset location" - return cleanup_hint - - -def _write_spark_seed_to_table( - spark: Any, - sdf: Any, - target_identifier: str, - target_sql: str, - target_location: Any, - table_format: str | None, - table_options: dict[str, Any], -) -> str | None: - """ - Write the seed as a managed Spark table, handling common location issues. - - Returns: - A cleanup hint string describing corrective actions, or None. - """ - cleanup_hint = _reset_spark_table_and_location(spark, target_sql, target_location) - - try: - _spark_write_table(sdf, target_identifier, table_format, table_options) - return cleanup_hint - except SparkAnalysisException as exc: - message = str(exc) - if target_location and "LOCATION_ALREADY_EXISTS" in message.upper(): - # Attempt to fix by resetting the table location and retrying once. - with suppress(Exception): - shutil.rmtree(target_location, ignore_errors=True) - cleanup_hint = "reset location" - _spark_write_table(sdf, target_identifier, table_format, table_options) - return cleanup_hint - raise RuntimeError(f"Spark seed load failed for {target_sql}: {exc}") from exc - except Exception as exc: # pragma: no cover - generic safety net - raise RuntimeError(f"Spark seed load failed for {target_sql}: {exc}") from exc - - -def _detect_spark_storage_format( - storage_meta: Any, - table_format: Any, -) -> str: - """ - Determine an effective storage format label (e.g. 'delta') from storage - metadata or executor configuration. 
- """ - storage_format = "" - if isinstance(storage_meta, dict): - raw_fmt = storage_meta.get("format") - if isinstance(raw_fmt, str) and raw_fmt.strip(): - storage_format = raw_fmt.strip().lower() - - if not storage_format and isinstance(table_format, str) and table_format.strip(): - storage_format = table_format.strip().lower() - - return storage_format - - -def _handle_spark( - table: str, - df: pd.DataFrame, - executor: Any, - schema: str | None, -) -> bool: - """Try to detect and handle Spark/Databricks for seeding.""" - spark = getattr(executor, "spark", None) - if spark is None: - return False - - target_identifier, target_sql, target_location, created_schema = _prepare_spark_target( - table=table, - schema=schema, - executor=executor, - spark=spark, - ) - - table_format = getattr(executor, "spark_table_format", None) - table_options = getattr(executor, "spark_table_options", None) or {} - format_handler = getattr(executor, "_format_handler", None) - - storage_meta = storage.get_seed_storage(target_identifier) - - t0 = perf_counter() - sdf = spark.createDataFrame(df) - - cleanup_hint: str | None = None - allows_unmanaged = bool(getattr(format_handler, "allows_unmanaged_paths", lambda: True)()) - - if storage_meta.get("path") and allows_unmanaged: - # Behavior for parquet/delta/etc: respect custom path. 
- cleanup_hint = _write_spark_seed_to_path( - spark=spark, - target_identifier=target_identifier, - sdf=sdf, - storage_meta=storage_meta, - table_format=table_format, - table_options=table_options, - ) - else: - # Behavior when no path is configured: table-based seed via executor handler - try: - if hasattr(executor, "_save_df_as_table"): - executor._save_df_as_table(target_identifier, sdf, storage={"path": None}) - cleanup_hint = None - else: - cleanup_hint = _write_spark_seed_to_table( - spark=spark, - sdf=sdf, - target_identifier=target_identifier, - target_sql=target_sql, - target_location=target_location, - table_format=table_format, - table_options=table_options, - ) - except Exception as exc: # pragma: no cover - raise RuntimeError(f"Spark seed load failed for {target_sql}: {exc}") from exc - - dt_ms = int((perf_counter() - t0) * 1000) - - storage_format = _detect_spark_storage_format(storage_meta, table_format) - engine_label = f"spark/{storage_format}" if storage_format else "spark" - - _echo_seed_line( - full_name=target_sql, - rows=len(df), - cols=df.shape[1], - engine=engine_label, - ms=dt_ms, - created_schema=created_schema, - action="replaced", - extra=cleanup_hint, - ) - return True - - -def _handle_snowflake_snowpark( - table: str, - df: pd.DataFrame, - executor: Any, - schema: str | None, -) -> bool: """ - Seed loader for SnowflakeSnowparkExecutor. - - Uses Session.write_pandas to create/overwrite the table in the configured - database + schema. + Materialize a DataFrame as a database table across engines. 
""" - if not isinstance(executor, SnowflakeSnowparkExecutor): - return False - - session = executor.session - target_db = getattr(executor, "database", None) - target_schema = schema or getattr(executor, "schema", None) - - if not target_db or not target_schema: - # Not enough info to build a fully-qualified target - return False - - # Optionally auto-create schema when allowed - created_schema = False - if getattr(executor, "allow_create_schema", False): - session.sql(f'CREATE SCHEMA IF NOT EXISTS "{target_db}"."{target_schema}"').collect() - created_schema = True - - full_name = _qualify(table, target_schema, target_db) - t0 = perf_counter() - # Use Snowpark's write_pandas: CREATE+OVERWRITE semantics - session.write_pandas( - df, - table_name=table, - database=target_db, - schema=target_schema, - auto_create_table=True, - quote_identifiers=False, - overwrite=True, - use_logical_type=True, - ) + result, full_name, created_schema = executor.load_seed(table, df, schema) dt_ms = int((perf_counter() - t0) * 1000) _echo_seed_line( full_name=full_name, rows=len(df), cols=df.shape[1], - engine="snowflake", + engine=executor.engine_name, ms=dt_ms, created_schema=created_schema, action="replaced", ) - return True - - -# ------------------------------------------------------------ -# Dispatcher -# ------------------------------------------------------------ - -Handler = Callable[[str, pd.DataFrame, Any, str | None], bool] - -_HANDLERS: Iterable[Handler] = ( - _handle_duckdb, - _handle_sqlalchemy, - _handle_spark, - _handle_bigquery, - _handle_snowflake_snowpark, -) - - -def materialize_seed( - table: str, df: pd.DataFrame, executor: Any, schema: str | None = None -) -> None: - """ - Materialize a DataFrame as a database table across engines. 
- """ - for handler in _HANDLERS: - if handler(table, df, executor, schema): - return - - raise RuntimeError("No compatible executor connection for seeding found.") + return result # ----------------------------------- Seeding runner ----------------------------------- @@ -1045,8 +436,8 @@ def seed_project(project_dir: Path, executor: Any, default_schema: str | None = df = _read_seed_file(path) # Use the resolved *table* key for schema enforcement (allows rename-aware mapping). - df = _apply_schema(df, table, schema_cfg) - df = _apply_column_schema(df, table, schema_cfg, executor) + df = _apply_schema(df, table, schema_cfg, seedid) + df = _apply_column_schema(df, table, schema_cfg, executor, seedid) df = _inject_seed_metadata(df, seedid, path) materialize_seed(table, df, executor, schema=schema) diff --git a/tests/unit/test_seeding_unit.py b/tests/unit/test_seeding_unit.py index ecec419..28c2d7c 100644 --- a/tests/unit/test_seeding_unit.py +++ b/tests/unit/test_seeding_unit.py @@ -4,13 +4,11 @@ import textwrap from pathlib import Path from types import SimpleNamespace -from typing import Any -from unittest.mock import MagicMock import pandas as pd import pytest -from fastflowtransform import seeding, storage +from fastflowtransform import seeding # --------------------------------------------------------------------------- # File I/O helpers @@ -49,7 +47,7 @@ def test_apply_schema_happy(): } schema_cfg = seeding.SeedsSchemaConfig.model_validate(cfg_raw) - out = seeding._apply_schema(df, "users", schema_cfg) + out = seeding._apply_schema(df, "users", schema_cfg, seed_id="users") assert str(out.dtypes["name"]).startswith("string") assert str(out.dtypes["age"]) in ("int64", "Int64") @@ -59,7 +57,7 @@ def test_apply_schema_ignores_missing_table_key(): df = pd.DataFrame({"id": [1]}) cfg_raw = {"dtypes": {"users": {"id": "int64"}}} schema_cfg = seeding.SeedsSchemaConfig.model_validate(cfg_raw) - out = seeding._apply_schema(df, "other", schema_cfg) + out = 
seeding._apply_schema(df, "other", schema_cfg, seed_id="other") assert out.equals(df) @@ -69,84 +67,10 @@ def test_apply_schema_soft_fails_on_bad_cast(): # force bad cast cfg_raw = {"dtypes": {"t": {"id": "int64"}}} schema_cfg = seeding.SeedsSchemaConfig.model_validate(cfg_raw) - out = seeding._apply_schema(df, "t", schema_cfg) + out = seeding._apply_schema(df, "t", schema_cfg, seed_id="t") assert len(out) == 1 -# --------------------------------------------------------------------------- -# Identifier helpers -# --------------------------------------------------------------------------- - - -@pytest.mark.unit -def test_dq_quotes_and_escapes(): - assert seeding._dq('a"b') == '"a""b"' - assert seeding._dq("tbl") == '"tbl"' - - -@pytest.mark.unit -def test_is_qualified(): - assert seeding._is_qualified("raw.users") is True - assert seeding._is_qualified("users") is False - - -@pytest.mark.unit -def test_qualify_unqualified_with_schema(): - out = seeding._qualify("users", "raw") - assert out == '"raw"."users"' - - -@pytest.mark.unit -def test_qualify_with_schema_and_catalog(): - out = seeding._qualify("users", "raw", "cat") - assert out == '"cat"."raw"."users"' - - -@pytest.mark.unit -def test_qualify_with_catalog_only(): - out = seeding._qualify("users", None, "cat") - assert out == '"cat"."users"' - - -@pytest.mark.unit -def test_qualify_already_qualified_preserves_parts(): - out = seeding._qualify("raw.users", None) - assert out == '"raw"."users"' - - -# --------------------------------------------------------------------------- -# Spark warehouse helpers -# --------------------------------------------------------------------------- - - -@pytest.mark.unit -@pytest.mark.databricks_spark -def test_spark_warehouse_base_local(tmp_path: Path): - fake_spark = SimpleNamespace( - conf=SimpleNamespace(get=lambda key, default=None: str(tmp_path / "wh")) - ) - base = seeding._spark_warehouse_base(fake_spark) - assert base == (tmp_path / "wh") - - -@pytest.mark.unit 
-@pytest.mark.databricks_spark -def test_spark_warehouse_base_remote_scheme(): - fake_spark = SimpleNamespace(conf=SimpleNamespace(get=lambda *_: "s3://bucket/warehouse")) - assert seeding._spark_warehouse_base(fake_spark) is None - - -@pytest.mark.unit -@pytest.mark.databricks_spark -def test_spark_table_location_strips_catalog(tmp_path: Path): - # warehouse dir is local - fake_spark = SimpleNamespace(conf=SimpleNamespace(get=lambda *_: str(tmp_path / "wh"))) - parts = ["spark_catalog", "default", "mytable"] - loc = seeding._spark_table_location(parts, fake_spark) - # should resolve to /default.db/mytable - assert loc == tmp_path / "wh" / "default.db" / "mytable" - - # --------------------------------------------------------------------------- # Pretty helpers # --------------------------------------------------------------------------- @@ -199,25 +123,6 @@ def fake_echo(msg: str) -> None: # --------------------------------------------------------------------------- -@pytest.mark.unit -def test_engine_name_from_executor_spark(): - ex = SimpleNamespace(spark=object()) - assert seeding._engine_name_from_executor(ex) == "databricks_spark" - - -@pytest.mark.unit -def test_engine_name_from_executor_sqlalchemy_like(): - eng = SimpleNamespace(dialect=SimpleNamespace(name="postgres")) - ex = SimpleNamespace(engine=eng) - assert seeding._engine_name_from_executor(ex) == "postgres" - - -@pytest.mark.unit -def test_engine_name_from_executor_duckdb_like(): - ex = SimpleNamespace(con=object()) - assert seeding._engine_name_from_executor(ex) == "duckdb" - - @pytest.mark.unit def test_seed_id_simple(tmp_path: Path): seeds_dir = tmp_path / "seeds" @@ -251,7 +156,9 @@ def test_resolve_schema_and_table_by_cfg_priority_engine_override(): } } # executor pretending to be postgres - ex = SimpleNamespace(engine=SimpleNamespace(dialect=SimpleNamespace(name="postgres"))) + ex = SimpleNamespace( + engine_name="postgres", engine=SimpleNamespace(dialect=SimpleNamespace(name="postgres")) + ) 
schema_cfg = seeding.SeedsSchemaConfig.model_validate(schema_cfg_raw) @@ -281,259 +188,11 @@ def test_resolve_schema_and_table_falls_back_to_default_schema(): assert table == "users" -# --------------------------------------------------------------------------- -# Handlers: DuckDB -# --------------------------------------------------------------------------- - - -@pytest.mark.unit -@pytest.mark.duckdb -def test_handle_duckdb_returns_false_for_non_duckdb_conn(): - executor = SimpleNamespace( - con=SimpleNamespace(register=lambda *a, **k: None, execute=lambda *a, **k: None) - ) - df = pd.DataFrame({"id": [1, 2]}) - - handled = seeding._handle_duckdb("users", df, executor, schema="raw") - - assert handled is False - - -@pytest.mark.unit -def test_handle_duckdb_returns_false_if_no_con(): - executor = SimpleNamespace() - df = pd.DataFrame({"id": [1]}) - handled = seeding._handle_duckdb("users", df, executor, schema=None) - assert handled is False - - -# --------------------------------------------------------------------------- -# Handlers: SQLAlchemy -# --------------------------------------------------------------------------- - - -@pytest.mark.unit -@pytest.mark.postgres -def test_handle_sqlalchemy_happy(monkeypatch): - calls = {} - - class FakeEngine: - __module__ = "sqlalchemy.engine" # to trigger detection - dialect = SimpleNamespace(name="postgres") - - class FakeDF(pd.DataFrame): - def to_sql(self, name, eng, if_exists, index, schema, method): - calls["name"] = name - calls["schema"] = schema - calls["if_exists"] = if_exists - - df = FakeDF({"id": [1, 2]}) - executor = SimpleNamespace(engine=FakeEngine()) - - handled = seeding._handle_sqlalchemy("seed_tbl", df, executor, schema="raw") - assert handled is True - assert calls["name"] == "seed_tbl" - assert calls["schema"] == "raw" - assert calls["if_exists"] == "replace" - - -@pytest.mark.unit -def test_handle_sqlalchemy_returns_false_if_no_engine(): - df = pd.DataFrame({"id": [1]}) - executor = SimpleNamespace() 
- assert seeding._handle_sqlalchemy("t", df, executor, None) is False - - -@pytest.mark.unit -def test_handle_sqlalchemy_returns_false_if_engine_not_sqlalchemy(): - df = pd.DataFrame({"id": [1]}) - executor = SimpleNamespace(engine=SimpleNamespace(__module__="not.sqlalchemy")) - assert seeding._handle_sqlalchemy("t", df, executor, None) is False - - -# --------------------------------------------------------------------------- -# Handlers: Spark -# --------------------------------------------------------------------------- - - -@pytest.mark.unit -@pytest.mark.databricks_spark -def test_handle_spark_happy_default_table(tmp_path: Path, monkeypatch): - # fake spark with local warehouse - fake_spark = MagicMock() - fake_spark.conf.get.return_value = str(tmp_path / "wh") - - # DataFrame path - fake_sdf = MagicMock() - fake_spark.createDataFrame.return_value = fake_sdf - - # writer chain - writer = MagicMock() - fake_sdf.write.mode.return_value = writer - writer.format.return_value = writer - writer.options.return_value = writer - - executor = SimpleNamespace( - spark=fake_spark, - spark_table_format="delta", - spark_table_options={"mergeSchema": "true"}, - ) - - df = pd.DataFrame({"id": [1]}) - handled = seeding._handle_spark("default.seed_tbl", df, executor, schema=None) - - assert handled is True - # drop table was attempted - fake_spark.sql.assert_any_call("DROP TABLE IF EXISTS `default`.`seed_tbl`") - # writer.saveAsTable called with identifier - writer.saveAsTable.assert_called_once_with("default.seed_tbl") - - -@pytest.mark.unit -@pytest.mark.databricks_spark -def test_handle_spark_uses_seed_storage(monkeypatch): - # 1) Seed-Storage Override setzen - storage.set_seed_storage( - {"raw.users": {"path": "/tmp/custom", "format": "parquet", "options": {"x": "1"}}} - ) - - # 2) Fake Spark + DataFrame - fake_spark = MagicMock() - fake_sdf = MagicMock() - fake_spark.createDataFrame.return_value = fake_sdf - - writer = MagicMock() - fake_sdf.write.mode.return_value = 
writer - writer.format.return_value = writer - writer.options.return_value = writer - - # 3) spark_write_to_path stubben, damit kein echtes FS angefasst wird - called: dict[str, Any] = {} - - def _fake_write_to_path( - spark, identifier, df, *, storage: dict, default_format, default_options - ): - called["spark"] = spark - called["identifier"] = identifier - called["df"] = df - called["storage"] = storage - called["default_format"] = default_format - called["default_options"] = default_options - - monkeypatch.setattr(seeding.storage, "spark_write_to_path", _fake_write_to_path) - - # 4) Executor-Stub - executor = SimpleNamespace( - spark=fake_spark, - spark_table_format=None, - spark_table_options=None, - ) - - df = pd.DataFrame({"id": [1]}) - - # 5) Aufruf - handled = seeding._handle_spark("raw.users", df, executor, schema=None) - assert handled is True - - # 6) Asserts - fake_spark.createDataFrame.assert_called_once_with(df) - assert called["spark"] is fake_spark - assert called["identifier"] == "raw.users" - assert called["storage"]["path"] == "/tmp/custom" - assert called["storage"]["format"] == "parquet" - - -# --------------------------------------------------------------------------- -# Dispatcher -# --------------------------------------------------------------------------- - - -@pytest.mark.unit -def test_materialize_seed_tries_all_and_raises(monkeypatch): - df = pd.DataFrame({"id": [1]}) - # executor without con/engine/spark - executor = SimpleNamespace() - - with pytest.raises(RuntimeError) as exc: - seeding.materialize_seed("t", df, executor, schema=None) - assert "No compatible executor" in str(exc.value) - - -@pytest.mark.unit -@pytest.mark.duckdb -def test_materialize_seed_stops_at_first_handler(monkeypatch): - df = pd.DataFrame({"id": [1]}) - - # first handler claims success - def h1(table, df, ex, schema): - return True - - # second handler should not be called - called = {"h2": False} - - def h2(table, df, ex, schema): - called["h2"] = True - 
return True - - monkeypatch.setattr(seeding, "_HANDLERS", (h1, h2)) - - seeding.materialize_seed("t", df, SimpleNamespace(), schema=None) - assert called["h2"] is False - - # --------------------------------------------------------------------------- # seed_project # --------------------------------------------------------------------------- -@pytest.mark.unit -@pytest.mark.duckdb -def test_seed_project_happy_duckdb(tmp_path: Path, monkeypatch): - # project structure - seeds_dir = tmp_path / "seeds" - seeds_dir.mkdir() - (seeds_dir / "raw").mkdir() - (seeds_dir / "raw" / "users.csv").write_text("id,name\n1,A\n", encoding="utf-8") - - # fake duckdb-like executor - exec_calls = [] - - class FakeCon: - def register(self, name, df): - exec_calls.append(("register", name)) - - def execute(self, sql): - exec_calls.append(("execute", sql)) - - def unregister(self, name): - exec_calls.append(("unregister", name)) - - executor = SimpleNamespace(con=FakeCon(), schema="public") - - # IMPORTANT: in environments where duckdb is installed, _handle_duckdb() - # does an isinstance(...) against the real DuckDB connection type. - # That would make our FakeCon fail. So we just force the handler to succeed. 
- def fake_handle_duckdb(table, df, ex, schema): - # simulate the real duckdb handler a bit - full_name = seeding._qualify(table, schema) - ex.con.register("_tmp", df) - ex.con.execute(f'create or replace table {full_name} as select * from "_tmp"') - ex.con.unregister("_tmp") - return True - - handlers = tuple(seeding._HANDLERS) - monkeypatch.setattr( - seeding, - "_HANDLERS", - (fake_handle_duckdb, *handlers[1:]), - ) - - count = seeding.seed_project(tmp_path, executor, default_schema=None) - assert count == 1 - # we should have a create or replace in there - assert any("create or replace table" in sql for (op, sql) in exec_calls if op == "execute") - - @pytest.mark.unit def test_seed_project_no_seeds_dir(tmp_path: Path): executor = SimpleNamespace() @@ -561,7 +220,7 @@ def test_seed_project_ambiguous_stems_raises(tmp_path: Path): encoding="utf-8", ) - executor = SimpleNamespace(schema="public") + executor = SimpleNamespace(schema="public", engine_name="duckdb") with pytest.raises(ValueError) as exc: seeding.seed_project(tmp_path, executor, default_schema=None) From b944b09a6cba284633ed23a58098ee7e4da9c78c Mon Sep 17 00:00:00 2001 From: Marko Lekic Date: Wed, 10 Dec 2025 09:12:05 +0100 Subject: [PATCH 3/7] Harmonized identifier logic in executors --- .../executors/_sql_identifier.py | 112 +++++++++++++++--- .../executors/bigquery/_bigquery_mixin.py | 82 ------------- .../executors/bigquery/base.py | 80 ++++++++++++- src/fastflowtransform/executors/duckdb.py | 9 +- src/fastflowtransform/executors/postgres.py | 12 +- .../executors/snowflake_snowpark.py | 28 +++-- tests/common/mock/bigquery.py | 18 ++- .../test_buildins_var_this_integration.py | 2 +- .../executors/test_bigquery_bf_exec_unit.py | 7 +- .../unit/executors/test_bigquery_exec_unit.py | 7 +- tests/unit/executors/test_duckdb_exec_unit.py | 12 -- .../unit/executors/test_postgres_exec_unit.py | 8 -- tests/unit/render/test_this_proxy_unit.py | 2 +- tests/unit/render/test_this_relation_unit.py | 6 +- 14 files 
changed, 223 insertions(+), 162 deletions(-) delete mode 100644 src/fastflowtransform/executors/bigquery/_bigquery_mixin.py diff --git a/src/fastflowtransform/executors/_sql_identifier.py b/src/fastflowtransform/executors/_sql_identifier.py index 2a8a3ae..c636c56 100644 --- a/src/fastflowtransform/executors/_sql_identifier.py +++ b/src/fastflowtransform/executors/_sql_identifier.py @@ -85,9 +85,31 @@ def _qualify_identifier( return ".".join(parts) return ".".join(self._quote_identifier(p) for p in parts) + # ---- Identifier normalization helpers ----------------------------------- + def _normalize_table_identifier(self, table: str) -> tuple[str | None, str]: + """ + Normalize a possibly qualified/quoted table identifier into (schema, table). + + - Strip simple quoting (`"`/`` ` ``) from each part. + - Accept up to 3-part names (catalog.schema.table) and drop the catalog. + - Return (schema, table) with schema possibly None. + """ + raw_parts = [p for p in table.split(".") if p] + parts = [p.strip().strip('`"') for p in raw_parts] + + if len(parts) >= 2: + return parts[-2] or None, parts[-1] + + table_name = parts[0] if parts else table + return None, table_name + + def _normalize_column_identifier(self, column: str) -> str: + """Strip simple quoting from a column identifier.""" + return column.strip().strip('`"') + # ---- Shared formatting hooks ----------------------------------------- - def _format_relation_for_ref(self, name: str) -> str: - return self._qualify_identifier(relation_for(name)) + # def _format_relation_for_ref(self, name: str) -> str: + # return self._qualify_identifier(relation_for(name)) def _pick_schema(self, cfg: dict[str, Any]) -> str | None: for key in ("schema", "dataset"): @@ -103,19 +125,79 @@ def _pick_catalog(self, cfg: dict[str, Any], schema: str | None) -> str | None: return candidate return self._default_catalog_for_source(schema) + # ---- Unified formatting entrypoint ----------------------------------- + def _format_identifier( + 
self, + name: str, + *, + purpose: str, + schema: str | None = None, + catalog: str | None = None, + quote: bool = True, + source_cfg: dict[str, Any] | None = None, + source_name: str | None = None, + table_name: str | None = None, + ) -> str: + """ + Central formatter for all identifier use-cases. + + purpose: + - "ref" / "this" / "test" / "seed" / "physical": qualify `name` + using defaults and optional overrides. + - "source": qualify based on a resolved source config (identifier + + optional schema/catalog); rejects path-based sources here. + """ + normalized = self._normalize_identifier(name) + + if purpose == "source": + cfg = dict(source_cfg or {}) + if cfg.get("location"): + raise NotImplementedError( + f"{getattr(self, 'engine_name', 'unknown')} executor " + "does not support path-based sources." + ) + + ident = cfg.get("identifier") or normalized + if not ident: + raise KeyError( + f"Source {source_name or ''}.{table_name or ''} " + "missing identifier" + ) + sch = self._clean_part(schema) or self._pick_schema(cfg) + cat = self._clean_part(catalog) or self._pick_catalog(cfg, sch) + return self._qualify_identifier(ident, schema=sch, catalog=cat, quote=quote) + + if purpose in {"ref", "this", "test", "seed", "physical"}: + sch = self._clean_part(schema) + cat = self._clean_part(catalog) + return self._qualify_identifier(normalized, schema=sch, catalog=cat, quote=quote) + + raise ValueError(f"Unknown identifier purpose: {purpose!r}") + + # ---- Default delegations using the unified formatter ------------------ + def _format_relation_for_ref(self, name: str) -> str: + return self._format_identifier(name, purpose="ref") + def _format_source_reference( self, cfg: dict[str, Any], source_name: str, table_name: str ) -> str: - if cfg.get("location"): - raise NotImplementedError( - f"{getattr(self, 'engine_name', 'unknown')} executor " - "does not support path-based sources." 
- ) - - ident = cfg.get("identifier") - if not ident: - raise KeyError(f"Source {source_name}.{table_name} missing identifier") - - schema = self._pick_schema(cfg) - catalog = self._pick_catalog(cfg, schema) - return self._qualify_identifier(ident, schema=schema, catalog=catalog) + return self._format_identifier( + cfg.get("identifier") or table_name, + purpose="source", + source_cfg=cfg, + source_name=source_name, + table_name=table_name, + ) + + def _format_test_table(self, table: str | None) -> str | None: + table = super()._format_test_table(table) # type: ignore[misc] + if not isinstance(table, str): + return table + return self._format_identifier(table, purpose="test") + + def _this_identifier(self, node: Any) -> str: + """ + Default {{ this }} identifier: reuse the formatter with logical name. + """ + name = getattr(node, "name", node) + return self._format_identifier(str(name), purpose="this") diff --git a/src/fastflowtransform/executors/bigquery/_bigquery_mixin.py b/src/fastflowtransform/executors/bigquery/_bigquery_mixin.py deleted file mode 100644 index b4c0c2b..0000000 --- a/src/fastflowtransform/executors/bigquery/_bigquery_mixin.py +++ /dev/null @@ -1,82 +0,0 @@ -# fastflowtransform/executors/_bigquery_mixin.py -from __future__ import annotations - -from fastflowtransform.executors._sql_identifier import SqlIdentifierMixin -from fastflowtransform.typing import NotFound, bigquery - - -class BigQueryIdentifierMixin(SqlIdentifierMixin): - """ - Mixin that provides common BigQuery helpers (identifier quoting, dataset creation). - Expect subclasses to define: self.project, self.dataset, self.client. 
- """ - - project: str - dataset: str - client: bigquery.Client - - def _bq_quote(self, value: str) -> str: - return value.replace("`", "\\`") - - def _quote_identifier(self, ident: str) -> str: - return self._bq_quote(ident) - - def _default_schema(self) -> str | None: - return self.dataset - - def _default_catalog(self) -> str | None: - return self.project - - def _should_include_catalog( - self, catalog: str | None, schema: str | None, *, explicit: bool - ) -> bool: - # BigQuery always expects a project + dataset. - return True - - def _qualify_identifier( - self, - ident: str, - *, - schema: str | None = None, - catalog: str | None = None, - quote: bool = True, - ) -> str: - proj = self._clean_part(catalog) or self._default_catalog() - dset = self._clean_part(schema) or self._default_schema() - normalized = self._normalize_identifier(ident) - parts = [proj, dset, normalized] - if not quote: - return ".".join(p for p in parts if p) - return f"`{'.'.join(self._bq_quote(p) for p in parts if p)}`" - - def _qualified_identifier( - self, relation: str, project: str | None = None, dataset: str | None = None - ) -> str: - return self._qualify_identifier(relation, schema=dataset, catalog=project) - - def _qualified_api_identifier( - self, relation: str, project: str | None = None, dataset: str | None = None - ) -> str: - """ - Build an API-safe identifier (project.dataset.table) without backticks. 
- """ - return self._qualify_identifier( - relation, - schema=dataset, - catalog=project, - quote=False, - ) - - def _ensure_dataset(self) -> None: - ds_id = f"{self.project}.{self.dataset}" - try: - self.client.get_dataset(ds_id) - return - except NotFound: - if not getattr(self, "allow_create_dataset", False): - raise - - ds_obj = bigquery.Dataset(ds_id) - if getattr(self, "location", None): - ds_obj.location = self.location # type: ignore[attr-defined] - self.client.create_dataset(ds_obj, exists_ok=True) diff --git a/src/fastflowtransform/executors/bigquery/base.py b/src/fastflowtransform/executors/bigquery/base.py index 1a16a9b..ad550a7 100644 --- a/src/fastflowtransform/executors/bigquery/base.py +++ b/src/fastflowtransform/executors/bigquery/base.py @@ -7,9 +7,9 @@ from fastflowtransform.core import Node, relation_for from fastflowtransform.executors._budget_runner import run_sql_with_budget from fastflowtransform.executors._snapshot_sql_mixin import SnapshotSqlMixin +from fastflowtransform.executors._sql_identifier import SqlIdentifierMixin from fastflowtransform.executors._test_utils import make_fetchable from fastflowtransform.executors.base import BaseExecutor -from fastflowtransform.executors.bigquery._bigquery_mixin import BigQueryIdentifierMixin from fastflowtransform.executors.budget import BudgetGuard from fastflowtransform.executors.query_stats import _TrackedQueryJob from fastflowtransform.meta import ensure_meta_table, upsert_meta @@ -18,7 +18,7 @@ TFrame = TypeVar("TFrame") -class BigQueryBaseExecutor(BigQueryIdentifierMixin, SnapshotSqlMixin, BaseExecutor[TFrame]): +class BigQueryBaseExecutor(SqlIdentifierMixin, SnapshotSqlMixin, BaseExecutor[TFrame]): """ Shared BigQuery executor logic (SQL, incremental, meta, DQ helpers). 
@@ -56,6 +56,73 @@ def __init__( location=self.location, ) + # ---- Identifier helpers ---- + def _bq_quote(self, value: str) -> str: + return value.replace("`", "\\`") + + def _quote_identifier(self, ident: str) -> str: + return self._bq_quote(ident) + + def _default_schema(self) -> str | None: + return self.dataset + + def _default_catalog(self) -> str | None: + return self.project + + def _should_include_catalog( + self, catalog: str | None, schema: str | None, *, explicit: bool + ) -> bool: + # BigQuery always expects a project + dataset. + return True + + def _qualify_identifier( + self, + ident: str, + *, + schema: str | None = None, + catalog: str | None = None, + quote: bool = True, + ) -> str: + proj = self._clean_part(catalog) or self._default_catalog() + dset = self._clean_part(schema) or self._default_schema() + normalized = self._normalize_identifier(ident) + parts = [proj, dset, normalized] + if not quote: + return ".".join(p for p in parts if p) + return f"`{'.'.join(self._bq_quote(p) for p in parts if p)}`" + + def _qualified_identifier( + self, relation: str, project: str | None = None, dataset: str | None = None + ) -> str: + return self._qualify_identifier(relation, schema=dataset, catalog=project) + + def _qualified_api_identifier( + self, relation: str, project: str | None = None, dataset: str | None = None + ) -> str: + """ + Build an API-safe identifier (project.dataset.table) without backticks. 
+ """ + return self._qualify_identifier( + relation, + schema=dataset, + catalog=project, + quote=False, + ) + + def _ensure_dataset(self) -> None: + ds_id = f"{self.project}.{self.dataset}" + try: + self.client.get_dataset(ds_id) + return + except NotFound: + if not getattr(self, "allow_create_dataset", False): + raise + + ds_obj = bigquery.Dataset(ds_id) + if getattr(self, "location", None): + ds_obj.location = self.location + self.client.create_dataset(ds_obj, exists_ok=True) + def execute_test_sql(self, stmt: Any) -> Any: """ Execute lightweight SQL for DQ tests using the BigQuery client. @@ -173,9 +240,14 @@ def _format_test_table(self, table: str | None) -> str | None: Ensure tests use fully-qualified BigQuery identifiers in fft test. """ table = super()._format_test_table(table) - if not isinstance(table, str) or not table.strip(): + if not isinstance(table, str): return table - return self._qualified_identifier(table.strip()) + stripped = table.strip() + if not stripped or stripped.startswith("`"): + return stripped + if "." in stripped: + return stripped + return self._qualified_identifier(stripped) # ---- SQL hooks ---- def _this_identifier(self, node: Node) -> str: diff --git a/src/fastflowtransform/executors/duckdb.py b/src/fastflowtransform/executors/duckdb.py index 8f6ebc8..ff70026 100644 --- a/src/fastflowtransform/executors/duckdb.py +++ b/src/fastflowtransform/executors/duckdb.py @@ -488,7 +488,7 @@ def _qualified(self, relation: str, *, quoted: bool = True) -> str: Return (catalog.)schema.relation if schema is set; otherwise just relation. When quoted=False, emit bare identifiers for APIs like con.table(). 
""" - return self._qualify_identifier(relation, quote=quoted) + return self._format_identifier(relation, purpose="physical", quote=quoted) def _read_relation(self, relation: str, node: Node, deps: Iterable[str]) -> pd.DataFrame: try: @@ -706,13 +706,10 @@ def introspect_column_physical_type(self, table: str, column: str) -> str | None """ DuckDB: read `data_type` from information_schema.columns. """ - if "." in table: - schema, table_name = table.split(".", 1) - else: - schema, table_name = None, table + schema, table_name = self._normalize_table_identifier(table) table_lower = table_name.lower() - column_lower = column.lower() + column_lower = self._normalize_column_identifier(column).lower() if schema: rows = self._execute_sql( diff --git a/src/fastflowtransform/executors/postgres.py b/src/fastflowtransform/executors/postgres.py index 9e83296..fc6b6a3 100644 --- a/src/fastflowtransform/executors/postgres.py +++ b/src/fastflowtransform/executors/postgres.py @@ -305,7 +305,7 @@ def _quote_identifier(self, ident: str) -> str: return self._q_ident(ident) def _qualified(self, relname: str, schema: str | None = None) -> str: - return self._qualify_identifier(relname, schema=schema) + return self._format_identifier(relname, purpose="physical", schema=schema) def _set_search_path(self, conn: Connection) -> None: if self.schema: @@ -699,10 +699,8 @@ def introspect_column_physical_type(self, table: str, column: str) -> str | None Postgres: read `data_type` from information_schema.columns for the current schema (or an explicit schema if table is qualified). """ - if "." 
in table: - schema, table_name = table.split(".", 1) - else: - schema, table_name = None, table + schema, table_name = self._normalize_table_identifier(table) + column_name = self._normalize_column_identifier(column) if schema: sql = """ @@ -713,7 +711,7 @@ def introspect_column_physical_type(self, table: str, column: str) -> str | None and lower(column_name) = lower(:column) limit 1 """ - params = {"schema": schema, "table": table_name, "column": column} + params = {"schema": schema, "table": table_name, "column": column_name} else: sql = """ select data_type @@ -723,7 +721,7 @@ def introspect_column_physical_type(self, table: str, column: str) -> str | None and lower(column_name) = lower(:column) limit 1 """ - params = {"table": table_name, "column": column} + params = {"table": table_name, "column": column_name} rows = self._execute_sql(sql, params).fetchall() return rows[0][0] if rows else None diff --git a/src/fastflowtransform/executors/snowflake_snowpark.py b/src/fastflowtransform/executors/snowflake_snowpark.py index 602e4e9..0f4913b 100644 --- a/src/fastflowtransform/executors/snowflake_snowpark.py +++ b/src/fastflowtransform/executors/snowflake_snowpark.py @@ -192,7 +192,7 @@ def _should_include_catalog( def _qualified(self, rel: str) -> str: # DATABASE.SCHEMA.TABLE (no quotes) - return self._qualify_identifier(rel, quote=False) + return self._format_identifier(rel, purpose="physical", quote=False) def _ensure_schema(self) -> None: """ @@ -372,25 +372,28 @@ def _this_identifier(self, node: Node) -> str: Identifier for {{ this }} in SQL models. Use fully-qualified DB.SCHEMA.TABLE so all build/read/test paths agree. 
""" - return self._qualify_identifier(relation_for(node.name), quote=False) + return self._format_identifier(relation_for(node.name), purpose="this", quote=False) def _format_source_reference( self, cfg: dict[str, Any], source_name: str, table_name: str ) -> str: - if cfg.get("location"): - raise NotImplementedError("Snowflake executor does not support path-based sources.") - ident = cfg.get("identifier") if not ident: raise KeyError(f"Source {source_name}.{table_name} missing identifier") - - sch = self._pick_schema(cfg) - db = self._pick_catalog(cfg, sch) - if not db or not sch: + formatted = self._format_identifier( + ident, + purpose="source", + source_cfg=cfg, + source_name=source_name, + table_name=table_name, + quote=False, + ) + # Ensure we resolved to DB.SCHEMA.TABLE; Snowflake needs both parts. + if "." not in formatted: raise KeyError( f"Source {source_name}.{table_name} missing database/schema for Snowflake" ) - return self._qualify_identifier(ident, schema=sch, catalog=db, quote=False) + return formatted def _create_or_replace_view(self, target_sql: str, select_body: str, node: Node) -> None: self._execute_sql(f"CREATE OR REPLACE VIEW {target_sql} AS {select_body}").collect() @@ -406,7 +409,8 @@ def _create_or_replace_view_from_table( self._execute_sql(f"CREATE OR REPLACE VIEW {view_id} AS SELECT * FROM {back_id}").collect() def _format_test_table(self, table: str | None) -> str | None: - formatted = super()._format_test_table(table) + # Bypass mixin qualification to avoid double-qualifying already dotted names. + formatted = BaseExecutor._format_test_table(self, table) if formatted is None: return None @@ -416,7 +420,7 @@ def _format_test_table(self, table: str | None) -> str | None: # Otherwise, treat it as a logical relation name and fully-qualify it # with the executor's configured database/schema. 
- return self._qualified(formatted) + return self._format_identifier(formatted, purpose="test", quote=False) # ---- Meta hook ---- def on_node_built(self, node: Node, relation: str, fingerprint: str) -> None: diff --git a/tests/common/mock/bigquery.py b/tests/common/mock/bigquery.py index 4ed702a..a0a0532 100644 --- a/tests/common/mock/bigquery.py +++ b/tests/common/mock/bigquery.py @@ -131,7 +131,10 @@ def __init__(self, project: str, location: str | None = None): # ---- Test helper ---- def add_dataset(self, ds_id: str) -> None: - self._datasets.setdefault(ds_id, FakeDataset(ds_id)) + key = ds_id + if "." not in key: + key = f"{self.project}.{key}" + self._datasets.setdefault(key, FakeDataset(key)) def add_table(self, dataset_id: str, table_id: str) -> None: self._tables.setdefault(dataset_id, []).append(SimpleNamespace(table_id=table_id)) @@ -221,13 +224,22 @@ def get_table(self, table_ref: str): raise FakeNotFound(f"table {table_ref} not found") def get_dataset(self, ds_id: str): - ds = self._datasets.get(ds_id) + key = ds_id + if "." not in key: + key = f"{self.project}.{key}" + ds = self._datasets.get(key) if ds is None: raise FakeNotFound(f"dataset {ds_id} not found") return ds def create_dataset(self, ds_obj: Any, exists_ok: bool | None = None): - ds_id = getattr(ds_obj, "dataset_id", ds_obj) + ds_id_raw = getattr(ds_obj, "dataset_id", ds_obj) + ds_id = ds_id_raw + if isinstance(ds_id_raw, FakeDatasetReference): + ds_id = f"{ds_id_raw.project}.{ds_id_raw.dataset_id}" + if isinstance(ds_id, str) and "." 
not in ds_id: + ds_id = f"{self.project}.{ds_id}" + ds = self._datasets.get(ds_id) if ds is None or not exists_ok: ds = FakeDataset(ds_id) diff --git a/tests/integration/core/test_buildins_var_this_integration.py b/tests/integration/core/test_buildins_var_this_integration.py index 2f9c3e8..aab4328 100644 --- a/tests/integration/core/test_buildins_var_this_integration.py +++ b/tests/integration/core/test_buildins_var_this_integration.py @@ -35,4 +35,4 @@ def test_var_overrides_and_this_object(tmp_path: Path): row = ex.con.execute("select * from m").fetchone() assert row is not None assert row[0] == "2099-01-01" - assert row[1] == "m" + assert row[1] == '"m"' diff --git a/tests/unit/executors/test_bigquery_bf_exec_unit.py b/tests/unit/executors/test_bigquery_bf_exec_unit.py index 38eda27..f979631 100644 --- a/tests/unit/executors/test_bigquery_bf_exec_unit.py +++ b/tests/unit/executors/test_bigquery_bf_exec_unit.py @@ -17,7 +17,6 @@ install_fake_bigquery, ) -import fastflowtransform.executors.bigquery._bigquery_mixin as bq_mix_mod import fastflowtransform.executors.bigquery.base as bq_base_mod import fastflowtransform.executors.bigquery.bigframes as bq_exec_mod from fastflowtransform.core import Node @@ -47,7 +46,7 @@ def read_gbq(self, table_id: str) -> Any: @pytest.fixture def bq_exec(monkeypatch): - _ = install_fake_bigquery(monkeypatch, [bq_exec_mod, bq_mix_mod, bq_base_mod]) + _ = install_fake_bigquery(monkeypatch, [bq_exec_mod, bq_base_mod]) # Test-only shim: ensure the fake bigquery module has DatasetReference, # which BigQueryBaseExecutor._execute_sql now relies on. 
@@ -141,7 +140,7 @@ def to_gbq(self, table_id, if_exists="replace"): @pytest.mark.unit @pytest.mark.bigquery def test_ensure_dataset_respects_flag(monkeypatch): - _ = install_fake_bigquery(monkeypatch, [bq_exec_mod, bq_mix_mod, bq_base_mod]) + _ = install_fake_bigquery(monkeypatch, [bq_exec_mod, bq_base_mod]) fake_bigframes = types.ModuleType("bigframes") fake_conf = types.ModuleType("bigframes._config") @@ -170,7 +169,7 @@ def test_ensure_dataset_respects_flag(monkeypatch): @pytest.mark.unit @pytest.mark.bigquery def test_ensure_dataset_creates_when_allowed(monkeypatch): - _ = install_fake_bigquery(monkeypatch, [bq_exec_mod, bq_mix_mod, bq_base_mod]) + _ = install_fake_bigquery(monkeypatch, [bq_exec_mod, bq_base_mod]) fake_bigframes = types.ModuleType("bigframes") fake_conf = types.ModuleType("bigframes._config") diff --git a/tests/unit/executors/test_bigquery_exec_unit.py b/tests/unit/executors/test_bigquery_exec_unit.py index 21c4d10..71f9586 100644 --- a/tests/unit/executors/test_bigquery_exec_unit.py +++ b/tests/unit/executors/test_bigquery_exec_unit.py @@ -17,7 +17,6 @@ install_fake_bigquery, ) -import fastflowtransform.executors.bigquery._bigquery_mixin as bq_mix_mod import fastflowtransform.executors.bigquery.base as bq_base_mod import fastflowtransform.executors.bigquery.pandas as bq_exec_mod from fastflowtransform.core import Node @@ -26,7 +25,7 @@ @pytest.fixture def bq_exec(monkeypatch): - _ = install_fake_bigquery(monkeypatch, [bq_exec_mod, bq_mix_mod]) + _ = install_fake_bigquery(monkeypatch, [bq_exec_mod]) fake_client = FakeClient(project="p1", location="EU") @@ -172,7 +171,7 @@ def test_format_source_reference(bq_exec): @pytest.mark.unit @pytest.mark.bigquery def test_ensure_dataset_respects_flag(monkeypatch): - _ = install_fake_bigquery(monkeypatch, [bq_exec_mod, bq_mix_mod]) + _ = install_fake_bigquery(monkeypatch, [bq_exec_mod]) fake_client = FakeClient(project="p1", location="EU") ex = bq_exec_mod.BigQueryExecutor( @@ -190,7 +189,7 @@ def 
test_ensure_dataset_respects_flag(monkeypatch): @pytest.mark.unit @pytest.mark.bigquery def test_ensure_dataset_creates_when_allowed(monkeypatch): - _ = install_fake_bigquery(monkeypatch, [bq_exec_mod, bq_mix_mod]) + _ = install_fake_bigquery(monkeypatch, [bq_exec_mod]) fake_client = FakeClient(project="p1", location="EU") ex = bq_exec_mod.BigQueryExecutor( diff --git a/tests/unit/executors/test_duckdb_exec_unit.py b/tests/unit/executors/test_duckdb_exec_unit.py index be7aa6e..3f87354 100644 --- a/tests/unit/executors/test_duckdb_exec_unit.py +++ b/tests/unit/executors/test_duckdb_exec_unit.py @@ -132,18 +132,6 @@ def test_format_source_reference_ok(duck_exec: DuckExecutor): assert ref == '"c1"."s1"."src_tbl"' -@pytest.mark.unit -@pytest.mark.duckdb -def test_format_source_reference_missing_identifier_raises(duck_exec: DuckExecutor): - cfg = { - "catalog": "c1", - "schema": "s1", - # no identifier! - } - with pytest.raises(KeyError): - duck_exec._format_source_reference(cfg, "src", "tbl") - - @pytest.mark.unit @pytest.mark.duckdb def test_format_source_reference_path_not_supported(duck_exec: DuckExecutor): diff --git a/tests/unit/executors/test_postgres_exec_unit.py b/tests/unit/executors/test_postgres_exec_unit.py index 46076cc..8b20c56 100644 --- a/tests/unit/executors/test_postgres_exec_unit.py +++ b/tests/unit/executors/test_postgres_exec_unit.py @@ -283,14 +283,6 @@ def test_format_source_reference_with_db_and_schema(fake_engine_and_conn): assert out == '"mydb"."other"."t_src"' -@pytest.mark.unit -@pytest.mark.postgres -def test_format_source_reference_missing_identifier(fake_engine_and_conn): - ex = PostgresExecutor("postgresql+psycopg://x", schema="public") - with pytest.raises(KeyError): - ex._format_source_reference({}, "src", "t") - - # --------------------------------------------------------------------------- # view / table creation # --------------------------------------------------------------------------- diff --git 
a/tests/unit/render/test_this_proxy_unit.py b/tests/unit/render/test_this_proxy_unit.py index a4ec945..666576c 100644 --- a/tests/unit/render/test_this_proxy_unit.py +++ b/tests/unit/render/test_this_proxy_unit.py @@ -16,4 +16,4 @@ def test_this_string_and_name(tmp_path: Path): env = Environment() ex = DuckExecutor() sql = ex.render_sql(node, env).strip().lower() - assert sql == "select 'm' as a, 'm' as b" + assert sql == "select '\"m\"' as a, '\"m\"' as b" diff --git a/tests/unit/render/test_this_relation_unit.py b/tests/unit/render/test_this_relation_unit.py index 71d0efe..5f08b16 100644 --- a/tests/unit/render/test_this_relation_unit.py +++ b/tests/unit/render/test_this_relation_unit.py @@ -1,6 +1,6 @@ # tests/unit/render/test_this_relation_unit.py import pytest -from jinja2 import Environment, FileSystemLoader, select_autoescape +from jinja2 import Environment, FileSystemLoader from fastflowtransform.core import Node from fastflowtransform.executors.duckdb import DuckExecutor @@ -9,7 +9,7 @@ def _env_for_tests() -> Environment: return Environment( loader=FileSystemLoader(["."]), - autoescape=select_autoescape([]), + autoescape=False, trim_blocks=True, lstrip_blocks=True, ) @@ -29,4 +29,4 @@ def test_this_renders_physical_relation(tmp_path): rendered = ex.render_sql(node, env).strip() # Assert - assert rendered.lower() == "select 'm' as rel" + assert rendered.lower() == "select '\"m\"' as rel" From acf0f84f65a1e5e25553a99a83f5076cae3faa27 Mon Sep 17 00:00:00 2001 From: Marko Lekic Date: Thu, 11 Dec 2025 18:01:37 +0100 Subject: [PATCH 4/7] Added contract runtime for duckdb --- src/fastflowtransform/cli/run.py | 29 +++ src/fastflowtransform/cli/test_cmd.py | 2 +- src/fastflowtransform/config/contracts.py | 74 +++++- src/fastflowtransform/contracts/__init__.py | 0 .../{contracts.py => contracts/core.py} | 2 +- .../contracts/runtime/__init__.py | 0 .../contracts/runtime/base.py | 222 +++++++++++++++++ .../contracts/runtime/duckdb.py | 230 ++++++++++++++++++ 
src/fastflowtransform/executors/base.py | 71 +++++- src/fastflowtransform/executors/duckdb.py | 96 +++++--- src/fastflowtransform/incremental.py | 54 ++++ tests/unit/test_contracts_unit.py | 2 +- 12 files changed, 743 insertions(+), 39 deletions(-) create mode 100644 src/fastflowtransform/contracts/__init__.py rename src/fastflowtransform/{contracts.py => contracts/core.py} (99%) create mode 100644 src/fastflowtransform/contracts/runtime/__init__.py create mode 100644 src/fastflowtransform/contracts/runtime/base.py create mode 100644 src/fastflowtransform/contracts/runtime/duckdb.py diff --git a/src/fastflowtransform/cli/run.py b/src/fastflowtransform/cli/run.py index d573c3b..5946466 100644 --- a/src/fastflowtransform/cli/run.py +++ b/src/fastflowtransform/cli/run.py @@ -58,6 +58,7 @@ load_budgets_config, ) from fastflowtransform.config.project import HookSpec +from fastflowtransform.contracts.core import _load_project_contracts, load_contracts from fastflowtransform.core import REGISTRY, Node, relation_for from fastflowtransform.dag import levels as dag_levels from fastflowtransform.executors.base import BaseExecutor @@ -1810,6 +1811,34 @@ def run( engine_.invocation_id = uuid4().hex engine_.run_started_at = datetime.now(UTC).isoformat(timespec="seconds") + # ---------- Runtime contracts: load + configure executor ---------- + try: + project_dir = Path(ctx.project) + except TypeError: + project_dir = Path(str(ctx.project)) + + try: + contracts_by_table = load_contracts(project_dir) + project_contracts = _load_project_contracts(project_dir) + except Exception as exc: + # If contracts parsing blows up, you can either: + # - treat it as fatal (like budgets.yml), or + # - log a warning and continue without contracts. + # For now we log and proceed, contracts are optional. 
+ warn(f"[contracts] Failed to load contracts from {project_dir}: {exc}") + contracts_by_table = {} + project_contracts = None + + # engine_.shared is (executor, run_sql_fn, run_py_fn) + try: + executor, _, _ = engine_.shared + except Exception: + executor = None + + if executor is not None and hasattr(executor, "configure_contracts"): + with suppress(Exception): + executor.configure_contracts(contracts_by_table, project_contracts) + bind_context( engine=ctx.profile.engine, env=env_name, diff --git a/src/fastflowtransform/cli/test_cmd.py b/src/fastflowtransform/cli/test_cmd.py index 48b4a82..e8dca96 100644 --- a/src/fastflowtransform/cli/test_cmd.py +++ b/src/fastflowtransform/cli/test_cmd.py @@ -24,7 +24,7 @@ BaseProjectTestConfig, parse_project_yaml_config, ) -from fastflowtransform.contracts import load_contract_tests +from fastflowtransform.contracts.core import load_contract_tests from fastflowtransform.core import REGISTRY from fastflowtransform.dag import topo_sort from fastflowtransform.errors import ModelExecutionError diff --git a/src/fastflowtransform/config/contracts.py b/src/fastflowtransform/config/contracts.py index cf34780..3eaf496 100644 --- a/src/fastflowtransform/config/contracts.py +++ b/src/fastflowtransform/config/contracts.py @@ -2,7 +2,7 @@ from __future__ import annotations from pathlib import Path -from typing import Any +from typing import Any, Literal import yaml from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator @@ -10,6 +10,8 @@ from fastflowtransform.config.loaders import NoDupLoader from fastflowtransform.errors import ContractsConfigError +SchemaEnforcementMode = Literal["off", "verify", "cast"] + class PhysicalTypeConfig(BaseModel): """ @@ -131,6 +133,23 @@ def _coerce_physical(cls, v: Any) -> Any: ) +class TableSchemaEnforcementModel(BaseModel): + """ + Per-table runtime schema enforcement configuration. 
+ + Example in *.contracts.yml: + + enforce_schema: + mode: cast # off | verify | cast + allow_extra_columns: false + """ + + model_config = ConfigDict(extra="forbid") + + mode: SchemaEnforcementMode = "off" + allow_extra_columns: bool = True + + class ContractsFileModel(BaseModel): """ One contracts file. @@ -162,6 +181,11 @@ class ContractsFileModel(BaseModel): table: str = Field(..., description="Logical/physical table name the contract applies to") columns: dict[str, ColumnContractModel] = Field(default_factory=dict) + enforce_schema: TableSchemaEnforcementModel | None = Field( + default=None, + description="Optional runtime schema enforcement config for this table", + ) + # --------------------------------------------------------------------------- # Project-level contracts (contracts.yml at project root) @@ -281,6 +305,49 @@ class ContractsDefaultsModel(BaseModel): columns: list[ColumnDefaultsRuleModel] = Field(default_factory=list) +class TableSchemaEnforcementOverrideModel(BaseModel): + """ + Per-table override in project-level contracts.yml + + Example: + + enforcement: + tables: + customers: + mode: cast + allow_extra_columns: false + """ + + model_config = ConfigDict(extra="forbid") + + mode: SchemaEnforcementMode | None = None + allow_extra_columns: bool | None = None + + +class ProjectSchemaEnforcementModel(BaseModel): + """ + Project-level schema enforcement defaults (contracts.yml). + + Example: + + version: 1 + + enforcement: + default_mode: verify # off | verify | cast + allow_extra_columns: true + tables: + customers: + mode: cast + allow_extra_columns: false + """ + + model_config = ConfigDict(extra="forbid") + + default_mode: SchemaEnforcementMode = "off" + allow_extra_columns: bool = True + tables: dict[str, TableSchemaEnforcementOverrideModel] = Field(default_factory=dict) + + class ProjectContractsModel(BaseModel): """ Top-level model for project-level contracts.yml. 
@@ -294,6 +361,11 @@ class ProjectContractsModel(BaseModel): version: int = 1 defaults: ContractsDefaultsModel = Field(default_factory=ContractsDefaultsModel) + enforcement: ProjectSchemaEnforcementModel | None = Field( + default=None, + description="Runtime schema enforcement defaults and per-table overrides", + ) + # ---- Parsers ----------------------------------------------------------------- diff --git a/src/fastflowtransform/contracts/__init__.py b/src/fastflowtransform/contracts/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/fastflowtransform/contracts.py b/src/fastflowtransform/contracts/core.py similarity index 99% rename from src/fastflowtransform/contracts.py rename to src/fastflowtransform/contracts/core.py index bd2c4a7..088dad8 100644 --- a/src/fastflowtransform/contracts.py +++ b/src/fastflowtransform/contracts/core.py @@ -1,4 +1,4 @@ -# fastflowtransform/contracts.py +# fastflowtransform/contracts/core.py from __future__ import annotations import re diff --git a/src/fastflowtransform/contracts/runtime/__init__.py b/src/fastflowtransform/contracts/runtime/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/fastflowtransform/contracts/runtime/base.py b/src/fastflowtransform/contracts/runtime/base.py new file mode 100644 index 0000000..cbbdef2 --- /dev/null +++ b/src/fastflowtransform/contracts/runtime/base.py @@ -0,0 +1,222 @@ +# fastflowtransform/contracts/runtime/base.by +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Protocol, TypeVar + +from fastflowtransform.config.contracts import ( + ContractsFileModel, + PhysicalTypeConfig, + ProjectContractsModel, + SchemaEnforcementMode, +) +from fastflowtransform.core import Node + + +class ContractExecutor(Protocol): + """ + Minimal surface that runtime contracts are allowed to use on an executor. + + Every engine that wants runtime contract support should conform to this. 
+ """ + + ENGINE_NAME: str + + def _execute_sql(self, sql: str, *args: Any, **kwargs: Any) -> Any: ... + def introspect_column_physical_type(self, table: str, column: str) -> str | None: ... + def introspect_table_physical_schema(self, table: str) -> dict[str, str]: ... + + +E = TypeVar("E", bound=ContractExecutor) + + +@dataclass +class RuntimeContractConfig: + mode: SchemaEnforcementMode + allow_extra_columns: bool + + +@dataclass +class RuntimeContractContext: + node: Node + relation: str # logical relation name, e.g. "customers" + physical_table: str # engine-specific identifier used in SQL (e.g. qualified) + contract: ContractsFileModel | None + project_contracts: ProjectContractsModel | None + config: RuntimeContractConfig + is_incremental: bool = False # future: incremental support + + +def _resolve_physical_type_for_engine( + cfg: PhysicalTypeConfig | None, + engine_name: str, +) -> str | None: + if cfg is None: + return None + engine = (engine_name or "").lower() + # exact engine key + if hasattr(cfg, engine): + v = getattr(cfg, engine) + if v: + return v + # engine base prefix before underscore; e.g. "snowflake" from "snowflake_snowpark" + if "_" in engine: + base = engine.split("_", 1)[0] + if hasattr(cfg, base): + v = getattr(cfg, base) + if v: + return v + # fallback to default + if cfg.default: + return cfg.default + return None + + +def expected_physical_schema( + *, + executor: ContractExecutor, + contract: ContractsFileModel | None, +) -> dict[str, str]: + """ + Build {column_name: expected_physical_type} for the given executor, + using the per-table ContractsFileModel. 
+ """ + if contract is None: + return {} + + engine = getattr(executor, "ENGINE_NAME", "") or "" + result: dict[str, str] = {} + + for col_name, col_model in (contract.columns or {}).items(): + phys = col_model.physical + typ = _resolve_physical_type_for_engine(phys, engine) + if typ: + result[col_name] = typ + + return result + + +def resolve_runtime_contract_config( + *, + table_name: str, + contract: ContractsFileModel | None, + project_contracts: ProjectContractsModel | None, +) -> RuntimeContractConfig: + # 1) table-level override + if contract and contract.enforce_schema is not None: + cfg = contract.enforce_schema + return RuntimeContractConfig( + mode=cfg.mode, + allow_extra_columns=cfg.allow_extra_columns, + ) + + # 2) project-level enforcement + proj = project_contracts.enforcement if project_contracts is not None else None + if proj is None: + return RuntimeContractConfig(mode="off", allow_extra_columns=True) + + table_override = (proj.tables or {}).get(table_name) + + mode: SchemaEnforcementMode = proj.default_mode + allow_extra = proj.allow_extra_columns + + if table_override is not None: + if table_override.mode is not None: + mode = table_override.mode + if table_override.allow_extra_columns is not None: + allow_extra = table_override.allow_extra_columns + + return RuntimeContractConfig(mode=mode, allow_extra_columns=allow_extra) + + +class BaseRuntimeContracts[E: ContractExecutor]: + """ + Base class for engine-specific runtime contract implementations. + + Executors use this via composition: `self.runtime_contracts = ...`. 
+ """ + + executor: E + + def __init__(self, executor: E): + self.executor = executor + + # ------------------------------------------------------------------ # + # Context builder used by the run-engine # + # ------------------------------------------------------------------ # + + def build_context( + self, + *, + node: Node, + relation: str, + physical_table: str, + contract: ContractsFileModel | None, + project_contracts: ProjectContractsModel | None, + is_incremental: bool = False, + ) -> RuntimeContractContext: + """ + Build a RuntimeContractContext with the correct RuntimeContractConfig. + + The caller (run-engine) decides which contract applies and passes: + - node: the fft Node being built + - relation: logical name (typically node.name) + - physical_table: fully-qualified identifier used in SQL + - contract: per-table ContractsFileModel, or None + - project_contracts: parsed project-level contracts.yml, or None + """ + # Use the contract's declared table name if present, otherwise fall + # back to the logical relation name for project-level overrides. + table_key = contract.table if contract is not None else relation + + cfg = resolve_runtime_contract_config( + table_name=table_key, + contract=contract, + project_contracts=project_contracts, + ) + + return RuntimeContractContext( + node=node, + relation=relation, + physical_table=physical_table, + contract=contract, + project_contracts=project_contracts, + config=cfg, + is_incremental=is_incremental, + ) + + # --- Hooks used by the run-engine ---------------------------- + + def apply_sql_contracts( + self, + *, + ctx: RuntimeContractContext, + select_body: str, + ) -> None: + """ + Entry point for SQL models. + + Engines override this to implement verify/cast mode. The default + implementation just does a plain CTAS (no enforcement). 
+ """ + # Default = "off" / do nothing special: + self.executor._execute_sql(f"create or replace table {ctx.physical_table} as {select_body}") + + def verify_after_materialization(self, *, ctx: RuntimeContractContext) -> None: + """ + Optional second step (e.g. verify mode). + + Called after the model has been materialized. Default is no-op. + """ + return + + def coerce_frame_schema(self, df: Any, ctx: RuntimeContractContext) -> Any: + """ + Optional hook for Python models: given a DataFrame-like object and the + RuntimeContractContext, return a new frame whose column types have been + coerced to match the expected physical schema (where reasonable). + + Default implementation is a no-op. Engine-specific subclasses may + override this (e.g. DuckDB + pandas). + """ + return df diff --git a/src/fastflowtransform/contracts/runtime/duckdb.py b/src/fastflowtransform/contracts/runtime/duckdb.py new file mode 100644 index 0000000..7b30347 --- /dev/null +++ b/src/fastflowtransform/contracts/runtime/duckdb.py @@ -0,0 +1,230 @@ +# fastflowtransform/contracts/runtime/duckdb.py +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any + +import pandas as pd + +from fastflowtransform.contracts.runtime.base import ( + BaseRuntimeContracts, + ContractExecutor, + RuntimeContractConfig, + RuntimeContractContext, + expected_physical_schema, +) + + +class DuckRuntimeContracts(BaseRuntimeContracts): + """ + Runtime schema contracts for DuckDB. + + Uses the shared ContractExecutor protocol only; all Duck-specific + behavior lives here, not in the executor. 
+    """
+
+    def __init__(self, executor: ContractExecutor):
+        super().__init__(executor)
+
+    # --- helpers ---------------------------------------------------------
+
+    def _verify(
+        self,
+        *,
+        table: str,
+        expected: Mapping[str, str],
+        cfg: RuntimeContractConfig,
+    ) -> None:
+        if not expected:
+            return
+
+        actual = {k.lower(): v for k, v in self.executor.introspect_table_physical_schema(table).items()}
+        exp_lower = {k.lower(): v for k, v in expected.items()}
+
+        problems: list[str] = []
+
+        for col, expected_type in expected.items():
+            key = col.lower()
+            if key not in actual:
+                problems.append(f"- missing column {col!r}")
+                continue
+            got = actual[key]
+            if got.lower() != expected_type.lower():
+                problems.append(f"- column {col!r}: expected type {expected_type!r}, got {got!r}")
+
+        if not cfg.allow_extra_columns:
+            extras = [c for c in actual if c not in exp_lower]
+            if extras:
+                problems.append(f"- extra columns present: {sorted(extras)}")
+
+        if problems:
+            raise RuntimeError(
+                f"[contracts] DuckDB schema enforcement failed for {table}:\n" + "\n".join(problems)
+            )
+
+    def _ctas_raw(self, target: str, select_body: str) -> None:
+        self.executor._execute_sql(f"create or replace table {target} as {select_body}")
+
+    def _ctas_cast_via_subquery(
+        self,
+        *,
+        ctx: RuntimeContractContext,
+        select_body: str,
+        expected: Mapping[str, str],
+    ) -> None:
+        """
+        Cast mode for SQL models:
+
+            create or replace table target as
+            select cast(col as TYPE) as col, ...
+            from (select_body) as src
+        """
+        if not expected:
+            self._ctas_raw(ctx.physical_table, select_body)
+            return
+
+        exp_lower = {k.lower(): v for k, v in expected.items()}
+
+        projections: list[str] = [f"cast({col} as {typ}) as {col}" for col, typ in expected.items()]
+
+        if ctx.config.allow_extra_columns:
+            # stage in a temp table in the same "namespace" as physical_table
+            tmp_name = f"{ctx.physical_table}__ff_contract_tmp".replace('"', "")
+            self._ctas_raw(tmp_name, select_body)
+            actual = self.executor.introspect_table_physical_schema(tmp_name)
+            for c in actual:
+                if c.lower() not in exp_lower:
+                    projections.append(c)
+            proj_sql = ", ".join(projections)
+            self.executor._execute_sql(
+                f"create or replace table {ctx.physical_table} as select {proj_sql} from {tmp_name}"
+            )
+            self.executor._execute_sql(f"drop table if exists {tmp_name}")
+        else:
+            proj_sql = ", ".join(projections)
+            wrapped = f"select {proj_sql} from ({select_body}) as src"
+            self._ctas_raw(ctx.physical_table, wrapped)
+
+    def coerce_frame_schema(self, df: Any, ctx: RuntimeContractContext) -> Any:
+        """
+        Coerce a pandas.DataFrame to match DuckDB physical types in 'cast' mode.
+
+        - Only runs when ctx.config.mode == "cast"
+        - Only for pandas.DataFrame
+        - Uses expected_physical_schema (per-table contracts)
+        """
+        if ctx.config.mode != "cast":
+            return df
+
+        if not isinstance(df, pd.DataFrame):
+            return df
+
+        expected = expected_physical_schema(
+            executor=self.executor,
+            contract=ctx.contract,
+        )
+        if not expected:
+            return df
+
+        coerced = df.copy()
+
+        for col, phys_type in expected.items():
+            if col not in coerced.columns:
+                # Missing columns are handled by verification later.
+                continue
+            coerced[col] = self._coerce_series_to_type(coerced[col], phys_type)
+
+        return coerced
+
+    def _coerce_series_to_type(self, s: pd.Series, phys_type: str) -> pd.Series:
+        """
+        Best-effort coercion of a pandas Series to a DuckDB physical type.
+ + This is intentionally simple and will raise on invalid casts so that + contract violations surface clearly. + """ + t = (phys_type or "").strip().lower() + if "(" in t: + t = t.split("(", 1)[0].strip() + + # Nullable boolean + if t in {"boolean", "bool"}: + return s.astype("boolean") + + # Integers + if t in {"tinyint", "smallint", "int", "integer", "bigint"}: + return pd.to_numeric(s, errors="raise").astype("Int64") + + # Floats / decimals + if t in {"float", "real", "double", "double precision", "decimal", "numeric"}: + return pd.to_numeric(s, errors="raise").astype("float64") + + # Text / strings + if t in {"char", "character", "varchar", "text", "string"}: + # Use pandas' nullable string dtype + return s.astype("string") + + # Date / timestamp + if t in {"date"}: + return pd.to_datetime(s, errors="raise").dt.date + + if t in {"timestamp", "timestamptz", "timestamp_ntz", "timestamp_ltz"}: + return pd.to_datetime(s, errors="raise") + + # For types like json, uuid, varbinary, etc., we leave as-is and rely + # on DuckDB's inference. + return s + + # --- BaseRuntimeContracts hooks ------------------------------------- + + def apply_sql_contracts( + self, + *, + ctx: RuntimeContractContext, + select_body: str, + ) -> None: + """ + Apply DuckDB runtime contracts for SQL models. 
+ """ + expected = expected_physical_schema( + executor=self.executor, + contract=ctx.contract, + ) + + mode = ctx.config.mode + + if mode == "off" or not expected: + self._ctas_raw(ctx.physical_table, select_body) + return + + if mode == "cast": + self._ctas_cast_via_subquery(ctx=ctx, select_body=select_body, expected=expected) + self._verify(table=ctx.physical_table, expected=expected, cfg=ctx.config) + return + + if mode == "verify": + self._ctas_raw(ctx.physical_table, select_body) + self._verify(table=ctx.physical_table, expected=expected, cfg=ctx.config) + return + + # unknown mode -> behave like off + self._ctas_raw(ctx.physical_table, select_body) + + def verify_after_materialization(self, *, ctx: RuntimeContractContext) -> None: + """ + If you want a second verification step (e.g. after incremental insert/merge), + you can call this from the run-engine. For now it's a thin wrapper. + """ + expected = expected_physical_schema( + executor=self.executor, + contract=ctx.contract, + ) + if not expected: + return + if ctx.config.mode not in {"verify", "cast"}: + return + self._verify( + table=ctx.physical_table, + expected=expected, + cfg=ctx.config, + ) diff --git a/src/fastflowtransform/executors/base.py b/src/fastflowtransform/executors/base.py index 306cf4c..5c27ca4 100644 --- a/src/fastflowtransform/executors/base.py +++ b/src/fastflowtransform/executors/base.py @@ -16,6 +16,7 @@ from fastflowtransform import incremental as _ff_incremental from fastflowtransform.api import context as _http_ctx +from fastflowtransform.config.contracts import ContractsFileModel, ProjectContractsModel from fastflowtransform.config.sources import resolve_source_entry from fastflowtransform.core import REGISTRY, Node, relation_for from fastflowtransform.errors import ModelExecutionError @@ -134,6 +135,21 @@ class BaseExecutor[TFrame](ABC): SNAPSHOT_HASH_COL = "_ff_snapshot_hash" SNAPSHOT_UPDATED_AT_COL = "_ff_updated_at" + _ff_contracts: Mapping[str, ContractsFileModel] | None = 
None + _ff_project_contracts: ProjectContractsModel | None = None + + def configure_contracts( + self, + contracts: Mapping[str, ContractsFileModel] | None, + project_contracts: ProjectContractsModel | None, + ) -> None: + """ + Inject parsed contracts into this executor instance. + The run engine should call this once at startup. + """ + self._ff_contracts = contracts or {} + self._ff_project_contracts = project_contracts + # ---------- SQL ---------- def render_sql( self, @@ -308,7 +324,30 @@ def run_sql(self, node: Node, env: Environment) -> None: echo_debug(preview) try: - self._apply_sql_materialization(node, target_sql, body, materialization) + runtime = getattr(self, "runtime_contracts", None) + # contracts only for TABLE materialization for now + if runtime is not None and materialization == "table": + contracts = getattr(self, "_ff_contracts", {}) or {} + project_contracts = getattr(self, "_ff_project_contracts", None) + + # keying: prefer the logical table name (contracts.table), + # but node.name or relation_for(node.name) is usually what you want. 
+ logical_name = relation_for(node.name) + contract = contracts.get(logical_name) or contracts.get(node.name) + + ctx = runtime.build_context( + node=node, + relation=logical_name, + physical_table=target_sql, + contract=contract, + project_contracts=project_contracts, + is_incremental=self._meta_is_incremental(meta), + ) + # Engine-specific enforcement (verify/cast/off) + runtime.apply_sql_contracts(ctx=ctx, select_body=body) + else: + # Old behavior + self._apply_sql_materialization(node, target_sql, body, materialization) except Exception as e: preview = f"-- materialized={materialization}\n-- target={target_sql}\n{body}" raise ModelExecutionError( @@ -630,11 +669,8 @@ def run_python(self, node: Node) -> None: self._reset_http_ctx(node) - # arg = self._build_python_args(node, deps) args, argmap = self._build_python_inputs(node, deps) requires = REGISTRY.py_requires.get(node.name, {}) - # if deps: - # self._validate_required(node.name, arg, requires) if deps: # Required-columns check works against the mapping self._validate_required(node.name, argmap, requires) @@ -646,6 +682,30 @@ def run_python(self, node: Node) -> None: meta = getattr(node, "meta", {}) or {} mat = self._resolve_materialization_strategy(meta) + # ---------- Runtime contracts for Python models ---------- + runtime = getattr(self, "runtime_contracts", None) + ctx = None + if runtime is not None: + contracts = getattr(self, "_ff_contracts", {}) or {} + project_contracts = getattr(self, "_ff_project_contracts", None) + + logical = target # usually relation_for(node.name) + contract = contracts.get(logical) or contracts.get(node.name) + + if contract is not None or project_contracts is not None: + physical_table = self._format_relation_for_ref(node.name) + ctx = runtime.build_context( + node=node, + relation=logical, + physical_table=physical_table, + contract=contract, + project_contracts=project_contracts, + is_incremental=(mat == "incremental"), + ) + # Allow runtime to coerce DataFrame types in 
cast mode + out = runtime.coerce_frame_schema(out, ctx) + + # ---------- Materialization ---------- if mat == "incremental": self._materialize_incremental(target, out, node, meta) elif mat == "view": @@ -653,6 +713,9 @@ def run_python(self, node: Node) -> None: else: self._materialize_relation(target, out, node) + if ctx is not None and runtime is not None: + runtime.verify_after_materialization(ctx=ctx) + self._snapshot_http_ctx(node) # ----------------- helpers ----------------- diff --git a/src/fastflowtransform/executors/duckdb.py b/src/fastflowtransform/executors/duckdb.py index ff70026..4eaa43a 100644 --- a/src/fastflowtransform/executors/duckdb.py +++ b/src/fastflowtransform/executors/duckdb.py @@ -13,6 +13,7 @@ import pandas as pd from duckdb import CatalogException +from fastflowtransform.contracts.runtime.duckdb import DuckRuntimeContracts from fastflowtransform.core import Node from fastflowtransform.executors._budget_runner import run_sql_with_budget from fastflowtransform.executors._snapshot_sql_mixin import SnapshotSqlMixin @@ -28,7 +29,8 @@ def _q(ident: str) -> str: class DuckExecutor(SqlIdentifierMixin, SnapshotSqlMixin, BaseExecutor[pd.DataFrame]): - ENGINE_NAME = "duckdb" + ENGINE_NAME: str = "duckdb" + runtime_contracts: DuckRuntimeContracts _FIXED_TYPE_SIZES: ClassVar[dict[str, int]] = { "boolean": 1, @@ -87,6 +89,8 @@ def __init__( self._execute_sql(f"create schema if not exists {safe_schema}") self._execute_sql(f"set schema '{self.schema}'") + self.runtime_contracts = DuckRuntimeContracts(self) + def execute_test_sql(self, stmt: Any) -> Any: """ Execute lightweight SQL for DQ tests using the underlying DuckDB connection. @@ -441,9 +445,19 @@ def _apply_catalog_override(self, name: str) -> bool: def clone(self) -> DuckExecutor: """ - Generates a new Executor instance with own connection for Thread-Worker. + Generates a new Executor instance with its own connection for Thread-Worker. + Copies runtime-contract configuration from the parent. 
""" - return DuckExecutor(self.db_path, schema=self.schema, catalog=self.catalog) + cloned = DuckExecutor(self.db_path, schema=self.schema, catalog=self.catalog) + + # Propagate contracts + project contracts to the clone + contracts = getattr(self, "_ff_contracts", None) + project_contracts = getattr(self, "_ff_project_contracts", None) + if contracts is not None or project_contracts is not None: + # configure_contracts lives on BaseExecutor + cloned.configure_contracts(contracts or {}, project_contracts) + + return cloned def _exec_many(self, sql: str) -> None: """ @@ -702,42 +716,62 @@ def utest_clean_target(self, relation: str) -> None: with suppress(Exception): self._execute_sql(f"drop table if exists {target}") - def introspect_column_physical_type(self, table: str, column: str) -> str | None: + def _introspect_columns_metadata( + self, + table: str, + column: str | None = None, + ) -> list[tuple[str, str]]: """ - DuckDB: read `data_type` from information_schema.columns. + Internal helper: return [(column_name, data_type), ...] for a DuckDB table. + + - Uses _normalize_table_identifier / _normalize_column_identifier + - Works with or without schema qualification + - Optionally restricts to a single column """ schema, table_name = self._normalize_table_identifier(table) table_lower = table_name.lower() - column_lower = self._normalize_column_identifier(column).lower() + params: list[str] = [table_lower] + + where_clauses: list[str] = ["lower(table_name) = lower(?)"] if schema: - rows = self._execute_sql( - """ - select data_type - from information_schema.columns - where lower(table_name) = lower(?) - and lower(table_schema)= lower(?) - and lower(column_name) = lower(?) - order by table_schema, ordinal_position - limit 1 - """, - [table_lower, schema.lower(), column_lower], - ).fetchall() - else: - rows = self._execute_sql( - """ - select data_type - from information_schema.columns - where lower(table_name) = lower(?) - and lower(column_name) = lower(?) 
- order by table_schema, ordinal_position - limit 1 - """, - [table_lower, column_lower], - ).fetchall() + where_clauses.append("lower(table_schema) = lower(?)") + params.append(schema.lower()) + + if column is not None: + column_lower = self._normalize_column_identifier(column).lower() + where_clauses.append("lower(column_name) = lower(?)") + params.append(column_lower) + + where_sql = " AND ".join(where_clauses) - return rows[0][0] if rows else None + sql = ( + "select column_name, data_type " + "from information_schema.columns " + f"where {where_sql} " + "order by table_schema, ordinal_position" + ) + + rows = self._execute_sql(sql, params).fetchall() + + # Normalize to plain strings + return [(str(name), str(dtype)) for (name, dtype) in rows] + + def introspect_column_physical_type(self, table: str, column: str) -> str | None: + """ + DuckDB: read `data_type` from information_schema.columns for a single column. + """ + rows = self._introspect_columns_metadata(table, column=column) + # rows: [(column_name, data_type), ...] + return rows[0][1] if rows else None + + def introspect_table_physical_schema(self, table: str) -> dict[str, str]: + """ + DuckDB: return {column_name: data_type} for all columns of `table`. 
+ """ + rows = self._introspect_columns_metadata(table, column=None) + return {name: dtype for (name, dtype) in rows} def load_seed( self, table: str, df: pd.DataFrame, schema: str | None = None diff --git a/src/fastflowtransform/incremental.py b/src/fastflowtransform/incremental.py index dc91219..341bca3 100644 --- a/src/fastflowtransform/incremental.py +++ b/src/fastflowtransform/incremental.py @@ -69,6 +69,55 @@ def _is_merge_not_supported_error(exc: Exception) -> bool: # ---------- Helper ---------- +def _apply_runtime_contracts_after_incremental(executor: Any, node: Any, relation: str) -> None: + """ + After an incremental model has been materialized (via create_table_as / + incremental_insert / incremental_merge), run runtime contracts in + verify/cast mode if the executor supports them. + + This is intentionally generic and works for any executor that exposes: + - runtime_contracts + - _ff_contracts + - _ff_project_contracts + - _format_relation_for_ref(name: str) -> str + """ + runtime = getattr(executor, "runtime_contracts", None) + if runtime is None: + return + + contracts = getattr(executor, "_ff_contracts", {}) or {} + project_contracts = getattr(executor, "_ff_project_contracts", None) + + # How you key contracts may vary slightly; common patterns: + # - contracts["customers"] + # - contracts[relation_for(node.name)] + logical = relation_for(node.name) + contract = contracts.get(logical) or contracts.get(node.name) + + # If there is no per-table contract and no project-level enforcement, + # there's nothing to do. + if contract is None and project_contracts is None: + return + + try: + physical = executor._format_relation_for_ref(node.name) + except AttributeError: + # Fallback: use the logical relation if the executor does not + # implement the more specific formatting hook. 
+ physical = relation + + ctx = runtime.build_context( + node=node, + relation=logical, + physical_table=physical, + contract=contract, + project_contracts=project_contracts, + is_incremental=True, + ) + + runtime.verify_after_materialization(ctx=ctx) + + def _safe_exists(executor: Any, relation: Any) -> bool: try: return bool(executor.exists_relation(relation)) @@ -271,6 +320,8 @@ def run_or_dispatch(executor: Any, node: Any, jenv: Any) -> None: if not exists: try: _create_table_as_or_replace(executor, relation, fallback_sql) + # Contracts: first incremental run creates the table → verify schema + _apply_runtime_contracts_after_incremental(executor, node, relation) except Exception as e: wrap_full_refresh(e) return @@ -285,7 +336,10 @@ def run_or_dispatch(executor: Any, node: Any, jenv: Any) -> None: fallback_sql=fallback_sql, on_full_refresh_error=wrap_full_refresh, ) + # Contracts: after merge/insert/full-refresh fallback, verify schema + _apply_runtime_contracts_after_incremental(executor, node, relation) except ModelExecutionError: + # already wrapped; propagate raise except Exception as e: wrap_incremental(e) diff --git a/tests/unit/test_contracts_unit.py b/tests/unit/test_contracts_unit.py index 1885935..e92c29b 100644 --- a/tests/unit/test_contracts_unit.py +++ b/tests/unit/test_contracts_unit.py @@ -14,7 +14,7 @@ ContractsFileModel, PhysicalTypeConfig, ) -from fastflowtransform.contracts import ( +from fastflowtransform.contracts.core import ( _apply_column_defaults, _contract_tests_for_table, _discover_contract_paths, From 0b02000eba92b3ea8a1a12beeeca9b1219389b17 Mon Sep 17 00:00:00 2001 From: Marko Lekic Date: Fri, 12 Dec 2025 14:57:53 +0100 Subject: [PATCH 5/7] Added contract enforcement to basic_demo for duckdb --- examples/basic_demo/contracts.yml | 37 ++++++ .../marts/mart_latest_signup.contracts.yml | 52 ++++++++ .../marts/mart_users_by_domain.contracts.yml | 53 ++++++++ src/fastflowtransform/cli/bootstrap.py | 27 +++- 
src/fastflowtransform/cli/run.py | 28 +--- src/fastflowtransform/cli/test_cmd.py | 5 +- src/fastflowtransform/config/contracts.py | 20 +++ .../contracts/runtime/base.py | 18 +++ .../contracts/runtime/duckdb.py | 125 ++++++++++-------- src/fastflowtransform/executors/base.py | 25 ++-- 10 files changed, 302 insertions(+), 88 deletions(-) create mode 100644 examples/basic_demo/contracts.yml create mode 100644 examples/basic_demo/models/marts/mart_latest_signup.contracts.yml create mode 100644 examples/basic_demo/models/marts/mart_users_by_domain.contracts.yml diff --git a/examples/basic_demo/contracts.yml b/examples/basic_demo/contracts.yml new file mode 100644 index 0000000..5ed90eb --- /dev/null +++ b/examples/basic_demo/contracts.yml @@ -0,0 +1,37 @@ +# Project-level contracts for the basic_demo. + +defaults: + # Reusable column-level rules applied by regex. + columns: + # All *_id columns are non-null integers + - match: + name: ".*_id$" + type: integer + nullable: false + + # All *_date columns are non-null dates + - match: + name: ".*_date$" + type: date + nullable: false + + # email_domain is non-null everywhere + - match: + name: "^email_domain$" + type: string + nullable: false + description: "Normalized email domain (lowercased)." 
+ +enforcement: + # Modes: off | verify | cast + default_mode: off + allow_extra_columns: true + + tables: + mart_users_by_domain: + mode: verify # only check schema, don't cast + allow_extra_columns: true + + mart_latest_signup: + mode: cast # cast into physical types, then verify + allow_extra_columns: true diff --git a/examples/basic_demo/models/marts/mart_latest_signup.contracts.yml b/examples/basic_demo/models/marts/mart_latest_signup.contracts.yml new file mode 100644 index 0000000..9200579 --- /dev/null +++ b/examples/basic_demo/models/marts/mart_latest_signup.contracts.yml @@ -0,0 +1,52 @@ +table: mart_latest_signup + +columns: + email_domain: + type: string + nullable: false + unique: true + description: "Email domain, one row per domain." + physical: + duckdb: VARCHAR + postgres: text + bigquery: STRING + databricks_spark: STRING + snowflake_snowpark: VARCHAR + default: STRING + + latest_user_id: + type: integer + nullable: false + description: "User ID of the most recent signup on this domain." + physical: + duckdb: INTEGER + postgres: integer + bigquery: INT64 + databricks_spark: INT + snowflake_snowpark: NUMBER + default: INTEGER + + latest_email: + type: string + nullable: false + regex: "^.+@.+$" + description: "Email address of the most recent signup." + physical: + duckdb: VARCHAR + postgres: text + bigquery: STRING + databricks_spark: STRING + snowflake_snowpark: VARCHAR + default: STRING + + latest_signup_date: + type: date + nullable: false + description: "Timestamp / date of the most recent signup." 
+ physical: + duckdb: DATE + postgres: date + bigquery: DATE + databricks_spark: DATE + snowflake_snowpark: DATE + default: DATE diff --git a/examples/basic_demo/models/marts/mart_users_by_domain.contracts.yml b/examples/basic_demo/models/marts/mart_users_by_domain.contracts.yml new file mode 100644 index 0000000..2f40725 --- /dev/null +++ b/examples/basic_demo/models/marts/mart_users_by_domain.contracts.yml @@ -0,0 +1,53 @@ +table: mart_users_by_domain + +columns: + email_domain: + type: string + nullable: false + unique: true + description: "Email domain (example.com, example.net, …)." + physical: + duckdb: VARCHAR + postgres: text + bigquery: STRING + databricks_spark: STRING + snowflake_snowpark: VARCHAR + default: STRING + + user_count: + type: integer + nullable: false + min: 0 + description: "Number of users per domain (COUNT(*))" + physical: + # COUNT(*) result types + duckdb: BIGINT + postgres: bigint + bigquery: INT64 + databricks_spark: BIGINT + snowflake_snowpark: NUMBER + default: BIGINT + + first_signup: + type: date + nullable: false + description: "Earliest signup date for this domain." + physical: + duckdb: DATE + postgres: date + bigquery: DATE + databricks_spark: DATE + snowflake_snowpark: DATE + default: DATE + + last_signup: + type: date + nullable: false + description: "Latest signup date for this domain." 
+ physical: + duckdb: DATE + postgres: date + bigquery: DATE + databricks_spark: DATE + snowflake_snowpark: DATE + default: DATE diff --git a/src/fastflowtransform/cli/bootstrap.py b/src/fastflowtransform/cli/bootstrap.py index ebb4a10..6145d84 100644 --- a/src/fastflowtransform/cli/bootstrap.py +++ b/src/fastflowtransform/cli/bootstrap.py @@ -4,6 +4,7 @@ import importlib import os from collections.abc import Callable +from contextlib import suppress from dataclasses import dataclass from pathlib import Path from typing import Any, NoReturn, cast @@ -14,10 +15,11 @@ from jinja2 import Environment from fastflowtransform.config.budgets import BudgetsConfig, load_budgets_config +from fastflowtransform.contracts.core import _load_project_contracts, load_contracts from fastflowtransform.core import REGISTRY from fastflowtransform.errors import DependencyNotFoundError from fastflowtransform.executors.base import BaseExecutor -from fastflowtransform.logging import echo +from fastflowtransform.logging import echo, warn from fastflowtransform.settings import ( EngineType, EnvSettings, @@ -150,6 +152,29 @@ def _merge(p: Path) -> None: os.environ.setdefault(key, value) +def configure_executor_contracts(project_dir: Path, executor: BaseExecutor | None) -> None: + """ + Load contracts from project_dir and attach them to the executor (if supported). + + Mirrors the behaviour in `fft run`: parse per-table contracts and the + project-level contracts.yml; on parse errors, log a warning and continue + without contracts. 
+ """ + if executor is None or not hasattr(executor, "configure_contracts"): + return + + try: + contracts_by_table = load_contracts(project_dir) + project_contracts = _load_project_contracts(project_dir) + except Exception as exc: + warn(f"[contracts] Failed to load contracts from {project_dir}: {exc}") + contracts_by_table = {} + project_contracts = None + + with suppress(Exception): + executor.configure_contracts(contracts_by_table, project_contracts) + + def _resolve_profile( env_name: str, engine: EngineType | None, proj: Path ) -> tuple[EnvSettings, Profile]: diff --git a/src/fastflowtransform/cli/run.py b/src/fastflowtransform/cli/run.py index 5946466..120ff07 100644 --- a/src/fastflowtransform/cli/run.py +++ b/src/fastflowtransform/cli/run.py @@ -27,7 +27,11 @@ compute_affected_models, get_changed_models, ) -from fastflowtransform.cli.bootstrap import CLIContext, _prepare_context +from fastflowtransform.cli.bootstrap import ( + CLIContext, + _prepare_context, + configure_executor_contracts, +) from fastflowtransform.cli.options import ( CacheMode, CacheOpt, @@ -58,7 +62,6 @@ load_budgets_config, ) from fastflowtransform.config.project import HookSpec -from fastflowtransform.contracts.core import _load_project_contracts, load_contracts from fastflowtransform.core import REGISTRY, Node, relation_for from fastflowtransform.dag import levels as dag_levels from fastflowtransform.executors.base import BaseExecutor @@ -1812,22 +1815,7 @@ def run( engine_.run_started_at = datetime.now(UTC).isoformat(timespec="seconds") # ---------- Runtime contracts: load + configure executor ---------- - try: - project_dir = Path(ctx.project) - except TypeError: - project_dir = Path(str(ctx.project)) - - try: - contracts_by_table = load_contracts(project_dir) - project_contracts = _load_project_contracts(project_dir) - except Exception as exc: - # If contracts parsing blows up, you can either: - # - treat it as fatal (like budgets.yml), or - # - log a warning and continue without 
contracts. - # For now we log and proceed, contracts are optional. - warn(f"[contracts] Failed to load contracts from {project_dir}: {exc}") - contracts_by_table = {} - project_contracts = None + project_dir = Path(ctx.project) # engine_.shared is (executor, run_sql_fn, run_py_fn) try: @@ -1835,9 +1823,7 @@ def run( except Exception: executor = None - if executor is not None and hasattr(executor, "configure_contracts"): - with suppress(Exception): - executor.configure_contracts(contracts_by_table, project_contracts) + configure_executor_contracts(project_dir, executor) bind_context( engine=ctx.profile.engine, diff --git a/src/fastflowtransform/cli/test_cmd.py b/src/fastflowtransform/cli/test_cmd.py index e8dca96..de48573 100644 --- a/src/fastflowtransform/cli/test_cmd.py +++ b/src/fastflowtransform/cli/test_cmd.py @@ -10,7 +10,7 @@ import typer -from fastflowtransform.cli.bootstrap import _prepare_context +from fastflowtransform.cli.bootstrap import _prepare_context, configure_executor_contracts from fastflowtransform.cli.options import ( EngineOpt, EnvOpt, @@ -464,13 +464,14 @@ def test( engine: EngineOpt = None, vars: VarsOpt = None, select: SelectOpt = None, - skip_build: SkipBuildOpt = False, + skip_build: SkipBuildOpt = True, ) -> None: ctx = _prepare_context(project, env_name, engine, vars) tokens, pred = _compile_selector(select) has_model_matches = any(pred(node) for node in REGISTRY.nodes.values()) legacy_tag_only = _is_legacy_test_token(tokens) and not has_model_matches execu, run_sql, run_py = ctx.make_executor() + configure_executor_contracts(ctx.project, execu) model_pred = (lambda _n: True) if legacy_tag_only else pred # Run models; if a model fails, show friendly error then exit(1). 
diff --git a/src/fastflowtransform/config/contracts.py b/src/fastflowtransform/config/contracts.py index 3eaf496..e6084e8 100644 --- a/src/fastflowtransform/config/contracts.py +++ b/src/fastflowtransform/config/contracts.py @@ -149,6 +149,16 @@ class TableSchemaEnforcementModel(BaseModel): mode: SchemaEnforcementMode = "off" allow_extra_columns: bool = True + @field_validator("mode", mode="before") + @classmethod + def _coerce_mode(cls, v: Any) -> Any: + # Allow bare `off` from YAML → False + if v is False: + return "off" + if isinstance(v, str): + return v.strip().lower() + return v + class ContractsFileModel(BaseModel): """ @@ -347,6 +357,16 @@ class ProjectSchemaEnforcementModel(BaseModel): allow_extra_columns: bool = True tables: dict[str, TableSchemaEnforcementOverrideModel] = Field(default_factory=dict) + @field_validator("default_mode", mode="before") + @classmethod + def _coerce_default_mode(cls, v: Any) -> Any: + if v is False: + return "off" + # Same comment as above if you ever want to accept `true`. + if isinstance(v, str): + return v.strip().lower() + return v + class ProjectContractsModel(BaseModel): """ diff --git a/src/fastflowtransform/contracts/runtime/base.py b/src/fastflowtransform/contracts/runtime/base.py index cbbdef2..79731c8 100644 --- a/src/fastflowtransform/contracts/runtime/base.py +++ b/src/fastflowtransform/contracts/runtime/base.py @@ -220,3 +220,21 @@ def coerce_frame_schema(self, df: Any, ctx: RuntimeContractContext) -> Any: override this (e.g. DuckDB + pandas). """ return df + + def materialize_python( + self, + *, + ctx: RuntimeContractContext, + df: Any, + ) -> bool: + """ + Optional hook for Python models. + + Engines override this to take over materialization for Python + models (e.g. to enforce contracts via explicit CASTs). + + Return True if you fully materialized ctx.physical_table yourself. + Return False to let the executor use its normal path + (_materialize_relation / _materialize_incremental). 
+        """
+        return False
diff --git a/src/fastflowtransform/contracts/runtime/duckdb.py b/src/fastflowtransform/contracts/runtime/duckdb.py
index 7b30347..6715e31 100644
--- a/src/fastflowtransform/contracts/runtime/duckdb.py
+++ b/src/fastflowtransform/contracts/runtime/duckdb.py
@@ -2,6 +2,7 @@
 from __future__ import annotations
 
 from collections.abc import Mapping
+from contextlib import suppress
 from typing import Any
 
 import pandas as pd
@@ -105,75 +106,86 @@ def _ctas_cast_via_subquery(
             wrapped = f"select {proj_sql} from ({select_body}) as src"
             self._ctas_raw(ctx.physical_table, wrapped)
 
-    def coerce_frame_schema(self, df: Any, ctx: RuntimeContractContext) -> Any:
+    def materialize_python(
+        self,
+        *,
+        ctx: RuntimeContractContext,
+        df: Any,
+    ) -> bool:
         """
-        Coerce a pandas.DataFrame to match DuckDB physical types in 'cast' mode.
+        DuckDB-specific materialization for Python models that enforces
+        contracts via explicit CASTs in DuckDB, not via pandas dtypes.
 
-        - Only runs when ctx.config.mode == "cast"
+        - Only active when mode != "off"
         - Only for pandas.DataFrame
-        - Uses expected_physical_schema (per-table contracts)
+        - Uses expected_physical_schema to build CAST expressions
         """
-        if ctx.config.mode != "cast":
-            return df
+        mode = ctx.config.mode
+        if mode == "off":
+            return False
 
         if not isinstance(df, pd.DataFrame):
-            return df
+            # We only know how to handle pandas frames here.
+            return False
 
         expected = expected_physical_schema(
             executor=self.executor,
             contract=ctx.contract,
         )
-        if not expected:
-            return df
-
-        coerced = df.copy()
-
-        for col, phys_type in expected.items():
-            if col not in coerced.columns:
-                # Missing columns are handled by verification later.
-                continue
-            coerced[col] = self._coerce_series_to_type(coerced[col], phys_type)
-
-        return coerced
-
-    def _coerce_series_to_type(self, s: pd.Series, phys_type: str) -> pd.Series:
-        """
-        Best-effort coercion of a pandas Series to a DuckDB physical type.
-
-        This is intentionally simple and will raise on invalid casts so that
-        contract violations surface clearly.
-        """
-        t = (phys_type or "").strip().lower()
-        if "(" in t:
-            t = t.split("(", 1)[0].strip()
-
-        # Nullable boolean
-        if t in {"boolean", "bool"}:
-            return s.astype("boolean")
-
-        # Integers
-        if t in {"tinyint", "smallint", "int", "integer", "bigint"}:
-            return pd.to_numeric(s, errors="raise").astype("Int64")
-
-        # Floats / decimals
-        if t in {"float", "real", "double", "double precision", "decimal", "numeric"}:
-            return pd.to_numeric(s, errors="raise").astype("float64")
-
-        # Text / strings
-        if t in {"char", "character", "varchar", "text", "string"}:
-            # Use pandas' nullable string dtype
-            return s.astype("string")
-
-        # Date / timestamp
-        if t in {"date"}:
-            return pd.to_datetime(s, errors="raise").dt.date
-
-        if t in {"timestamp", "timestamptz", "timestamp_ntz", "timestamp_ltz"}:
-            return pd.to_datetime(s, errors="raise")
+        # In verify mode, we still want to create the table via DuckDB; contracts
+        # are enforced by verify_after_materialization.
+        if mode == "cast" and not expected:
+            raise RuntimeError(
+                f"[contracts] cast mode enabled for {ctx.relation!r} "
+                "but no physical schema could be resolved."
+            )
 
-        # For types like json, uuid, varbinary, etc., we leave as-is and rely
-        # on DuckDB's inference.
-        return s
+        con = self.executor.con
+
+        tmp_name = f"__ff_py_src_{ctx.relation}".replace(".", "_")
+
+        # 1) Register DataFrame as temp relation in DuckDB
+        con.register(tmp_name, df)
+
+        try:
+            select_body = f'select * from "{tmp_name}"'
+
+            if mode == "cast":
+                # Reuse the same logic as SQL: explicit CAST(...) AS TYPE
+                self._ctas_cast_via_subquery(
+                    ctx=ctx,
+                    select_body=select_body,
+                    expected=expected,
+                )
+                self._verify(
+                    table=ctx.physical_table,
+                    expected=expected,
+                    cfg=ctx.config,
+                )
+
+            elif mode == "verify":
+                # Plain CTAS from the tmp, then verify that the resulting
+                # physical types match the contract.
+ self._ctas_raw(ctx.physical_table, select_body) + if expected: + self._verify( + table=ctx.physical_table, + expected=expected, + cfg=ctx.config, + ) + else: + # Unknown mode -> let executor handle it + return False + + return True + + finally: + with suppress(Exception): + con.unregister(tmp_name) + with suppress(Exception): + # In case older DuckDB uses views for registered tables + self.executor._execute_sql(f'drop view if exists "{tmp_name}"') # --- BaseRuntimeContracts hooks ------------------------------------- diff --git a/src/fastflowtransform/executors/base.py b/src/fastflowtransform/executors/base.py index 5c27ca4..cb880c3 100644 --- a/src/fastflowtransform/executors/base.py +++ b/src/fastflowtransform/executors/base.py @@ -685,6 +685,8 @@ def run_python(self, node: Node) -> None: # ---------- Runtime contracts for Python models ---------- runtime = getattr(self, "runtime_contracts", None) ctx = None + took_over = False + if runtime is not None: contracts = getattr(self, "_ff_contracts", {}) or {} project_contracts = getattr(self, "_ff_project_contracts", None) @@ -702,16 +704,23 @@ def run_python(self, node: Node) -> None: project_contracts=project_contracts, is_incremental=(mat == "incremental"), ) - # Allow runtime to coerce DataFrame types in cast mode - out = runtime.coerce_frame_schema(out, ctx) + + # Optional pre-coercion (default is no-op). 
+ if hasattr(runtime, "coerce_frame_schema"): + out = runtime.coerce_frame_schema(out, ctx) + + # Allow engine-specific runtime to take over Python materialization + if mat == "table" and hasattr(runtime, "materialize_python"): + took_over = bool(runtime.materialize_python(ctx=ctx, df=out)) # ---------- Materialization ---------- - if mat == "incremental": - self._materialize_incremental(target, out, node, meta) - elif mat == "view": - self._materialize_view(target, out, node) - else: - self._materialize_relation(target, out, node) + if not took_over: + if mat == "incremental": + self._materialize_incremental(target, out, node, meta) + elif mat == "view": + self._materialize_view(target, out, node) + else: + self._materialize_relation(target, out, node) if ctx is not None and runtime is not None: runtime.verify_after_materialization(ctx=ctx) From 9e632a007456754e892947cc3492bca9a5171b74 Mon Sep 17 00:00:00 2001 From: Marko Lekic Date: Sun, 14 Dec 2025 14:53:28 +0100 Subject: [PATCH 6/7] Added runtime contracts for schema enforcement for all engines --- docs/Contracts.md | 182 +++++++++- .../models/staging/customers.contracts.yml | 4 +- examples/incremental_demo/Makefile | 35 +- examples/incremental_demo/contracts.yml | 33 ++ .../fct_events_py_incremental.contracts.yml | 47 +++ .../fct_events_sql_inline.contracts.yml | 36 ++ .../common/fct_events_sql_yaml.contracts.yml | 36 ++ examples/incremental_demo/project.yml | 9 +- .../{seed_events.csv => seed_events_v1.csv} | 0 .../incremental_demo/seeds/seed_events_v2.csv | 5 + src/fastflowtransform/cli/bootstrap.py | 10 + .../contracts/runtime/base.py | 33 +- .../contracts/runtime/bigquery.py | 152 ++++++++ .../contracts/runtime/databricks_spark.py | 267 ++++++++++++++ .../contracts/runtime/postgres.py | 329 ++++++++++++++++++ .../contracts/runtime/snowflake_snowpark.py | 277 +++++++++++++++ src/fastflowtransform/executors/base.py | 10 + .../executors/bigquery/base.py | 62 ++-- .../executors/bigquery/bigframes.py | 5 +- 
.../executors/bigquery/pandas.py | 5 +- .../executors/databricks_spark.py | 60 +++- src/fastflowtransform/executors/postgres.py | 117 +++++-- .../executors/snowflake_snowpark.py | 210 ++++++++--- src/fastflowtransform/testing/base.py | 4 +- src/fastflowtransform/testing/registry.py | 4 + 25 files changed, 1799 insertions(+), 133 deletions(-) create mode 100644 examples/incremental_demo/contracts.yml create mode 100644 examples/incremental_demo/models/common/fct_events_py_incremental.contracts.yml create mode 100644 examples/incremental_demo/models/common/fct_events_sql_inline.contracts.yml create mode 100644 examples/incremental_demo/models/common/fct_events_sql_yaml.contracts.yml rename examples/incremental_demo/seeds/{seed_events.csv => seed_events_v1.csv} (100%) create mode 100644 examples/incremental_demo/seeds/seed_events_v2.csv create mode 100644 src/fastflowtransform/contracts/runtime/bigquery.py create mode 100644 src/fastflowtransform/contracts/runtime/databricks_spark.py create mode 100644 src/fastflowtransform/contracts/runtime/postgres.py create mode 100644 src/fastflowtransform/contracts/runtime/snowflake_snowpark.py diff --git a/docs/Contracts.md b/docs/Contracts.md index 305c8f2..d2df549 100644 --- a/docs/Contracts.md +++ b/docs/Contracts.md @@ -158,6 +158,26 @@ If the engine does not yet support physical type introspection, the test will fail with a clear “engine not yet supported” message instead of silently passing. +### Engine-canonical type names + +Physical type comparisons use the **canonical type strings reported by the engine**. + +That means: + +* Some engines expose aliases as canonical names in their catalogs. + + * Example (Postgres): + + * `timestamp` is an alias for `timestamp without time zone` + * `timestamptz` is an alias for `timestamp with time zone` +* FFT compares types after **engine-specific canonicalization**, so contracts can use common names like `timestamp`/`timestamptz` while still matching what Postgres reports. 
+ +If you see a mismatch like: + +> expected `timestamp`, got `timestamp without time zone` + +it means your Postgres executor/runtime is not canonicalizing types yet (or you’re using raw `information_schema.data_type`). In that case, update Postgres type introspection to use `pg_catalog.format_type(...)` so comparisons are consistent. + --- ### `nullable` @@ -277,6 +297,49 @@ defaults: nullable: false ``` +### `contracts.yml` enforcement configuration + +Example: + +```yaml +version: 1 + +defaults: + columns: + - match: + name: ".*_id$" + type: integer + nullable: false + +enforcement: + # Modes: off | verify | cast + default_mode: off + + # If true, contract enforcement only cares about declared columns. + # Extra columns produced by the model are allowed. + allow_extra_columns: true + + # Optional per-table overrides (by logical relation name) + tables: + mart_users_by_domain: + mode: verify + allow_extra_columns: true + + mart_latest_signup: + mode: cast + allow_extra_columns: true +``` + +Rules: + +* `enforcement.default_mode` applies to all tables unless overridden. +* `enforcement.tables.
.mode` overrides the default for a single table. +* `allow_extra_columns` controls whether the model output may contain columns not listed in the contract: + + * `true`: extra columns are ignored by enforcement (but still exist in the table) + * `false`: extra columns fail enforcement + + ### Column match rules Each entry under `defaults.columns` is a **column default rule**: @@ -410,6 +473,111 @@ customers.status accepted_values (tags: contract) You don’t need to write those tests yourself; they’re derived automatically from the contract files. +### Runtime enforcement (optional) + +In addition to turning contracts into `fft test` checks, FastFlowTransform can **enforce** contracts **at runtime** while building models. + +Runtime enforcement means: + +* FFT can **verify** that the materialized table matches the contract schema, and fail the run if not. +* FFT can **cast** the model output into the declared physical types before creating the table. + +This is configured in **project-level `contracts.yml`** under `enforcement`. + +#### Enforcement modes + +Contracts enforcement supports three modes: + +* `off` + Do not enforce at build time. (Contracts may still generate tests.) + +* `verify` + Build the table normally, then verify the physical schema matches the contract. + +* `cast` + Build the table by selecting from your model and **casting** contract columns into their declared physical types, then verify. + +> `cast` is useful when your warehouse would infer “close but not exact” types (e.g. `COUNT(*)` becoming a sized numeric type) and you want stable physical types across engines. + +### Failure messages + +If enforcement fails, FFT raises an error like: + +* Missing/extra columns +* Type mismatch (expected vs actual physical type) +* Non-null/unique contract failures (if those are enforced at runtime in your setup) + +The error includes the table name and a list of mismatches. 
+ +### Enforcement with incremental models + +When a model is materialized as `incremental`, FFT applies enforcement to the **incremental write path**, not only full refresh. + +Typical behavior: + +* On the first run, the model creates the target relation (full refresh behavior) and enforcement is applied. +* On subsequent runs, FFT computes a delta dataset and writes it using the engine’s incremental strategy (insert/merge/delete+insert, etc.). +* Enforcement is applied so the target table remains compatible with the contract. + +Practical recommendations: + +* If the incremental model relies on `unique_key`, make sure your source change simulation does not introduce duplicated keys in the delta. +* For “update simulation” in demos, prefer a **second full seed file** that represents the entire source after the update (not just appended rows), then rerun incremental. This produces a realistic “source changed” scenario without creating duplicates. + +### Tests vs runtime enforcement + +Contracts can be used in two independent ways: + +1. **Tests** (`fft test`) + Contracts generate test specs like `not_null`, `unique`, `accepted_values`, `regex_match`, and `column_physical_type`. + +2. **Runtime enforcement** (`fft run`) + Enforcement runs during model materialization and can fail the run early. + +You can use either one alone, or both together. + +### Enforcement for SQL models + +When enforcing contracts for a SQL model: + +* `verify` mode: + + 1. FFT creates the table/view normally from the model SQL + 2. FFT introspects the created object and compares the physical schema to the contract + +* `cast` mode: + + 1. FFT wraps the model SQL in a projection that casts the declared columns: + + ```sql + select + cast(col_a as <type>) as col_a, + cast(col_b as <type>) as col_b, + ... + -- optionally include extra columns if allow_extra_columns=true + from (<model sql>) as src + ``` + 2. FFT creates the table from that casted SELECT + 3. 
FFT verifies the resulting physical schema + +Notes: + +* Enforcement is best-effort: if a contract has no physical types for the current engine, `cast` mode cannot enforce and will fail with a clear error. +* `allow_extra_columns=true` means non-contracted columns are carried through unchanged. + +### Enforcement for Python models + +For Python models (pandas / Spark / Snowpark / BigFrames): + +* FFT first materializes the DataFrame result according to the executor. +* If enforcement is enabled, the runtime contracts layer may: + + * Stage the DataFrame into a temporary table (engine-specific) + * Re-create the target table using casts (`cast` mode) + * Or only verify the schema (`verify` mode) + +This allows a consistent enforcement mechanism even when the model result is not expressed as SQL. + --- ## Using contracts with `fft test` @@ -446,14 +614,8 @@ A few things contracts **do not** do yet: * Other engines may reject such tests with a clear “engine not supported” message. -The intended next step (not implemented yet) is an **“enforce schema”** mode -which uses contracts to drive actual table DDL (or casts) instead of only -post-hoc assertions. - -For now, contracts give you **schema-as-YAML** + **tests-from-contracts** in a -single, consistent place. - -Additional validation: +### Current limitations -* Duplicate YAML keys in contract files are rejected (the loader raises before - parsing). Fix or remove duplicates to proceed. +* Enforcement behavior can differ by engine depending on what the executor can introspect and how it stages/casts data. +* `cast` mode requires explicit `physical` types for the current engine. +* Some warehouses expose “decorated” physical types (e.g. `VARCHAR(16777216)`, `NUMBER(18,0)`) rather than a short base type name. Contracts should match the canonical/normalized representation used by the engine implementation. 
diff --git a/examples/dq_demo/models/staging/customers.contracts.yml b/examples/dq_demo/models/staging/customers.contracts.yml index 2ac46ff..63885c6 100644 --- a/examples/dq_demo/models/staging/customers.contracts.yml +++ b/examples/dq_demo/models/staging/customers.contracts.yml @@ -19,7 +19,7 @@ columns: duckdb: VARCHAR postgres: text bigquery: STRING - snowflake_snowpark: TEXT + snowflake_snowpark: VARCHAR databricks_spark: STRING status: @@ -32,7 +32,7 @@ columns: duckdb: VARCHAR postgres: text bigquery: STRING - snowflake_snowpark: TEXT + snowflake_snowpark: VARCHAR databricks_spark: STRING created_at: diff --git a/examples/incremental_demo/Makefile b/examples/incremental_demo/Makefile index 09629af..73a8cf0 100644 --- a/examples/incremental_demo/Makefile +++ b/examples/incremental_demo/Makefile @@ -1,10 +1,11 @@ -.PHONY: seed run_full run_incr dag test artifacts clean demo demo-open +.PHONY: seed seed_step2 run_full run_incr dag test artifacts clean demo demo-open # --- Config ------------------------------------------------------------------- DB ?= .local/incremental_demo.duckdb PROJECT ?= . 
UV ?= uv +LOCAL_SEEDS_DIR = $(PROJECT)/.local/seeds # Engine selector (duckdb|postgres|databricks_spark|bigquery|snowflake_snowpark) ENGINE ?= duckdb @@ -87,16 +88,21 @@ endif # --- Targets ------------------------------------------------------------------ -seed: - env $(BASE_ENV) $(UV) run fft seed "$(PROJECT)" --env $(PROFILE_ENV) +seed_v1: + @mkdir -p "$(LOCAL_SEEDS_DIR)" + @cp "$(PROJECT)/seeds/seed_events_v1.csv" "$(LOCAL_SEEDS_DIR)/seed_events.csv" + env $(BASE_ENV) FFT_SEEDS_DIR="$(LOCAL_SEEDS_DIR)" $(UV) run fft seed "$(PROJECT)" --env $(PROFILE_ENV) + +seed_v2: + @mkdir -p "$(LOCAL_SEEDS_DIR)" + @cp "$(PROJECT)/seeds/seed_events_v2.csv" "$(LOCAL_SEEDS_DIR)/seed_events.csv" + env $(BASE_ENV) FFT_SEEDS_DIR="$(LOCAL_SEEDS_DIR)" $(UV) run fft seed "$(PROJECT)" --env $(PROFILE_ENV) -# Full refresh (first run) run_full: - env $(RUN_ENV) $(UV) run fft run "$(PROJECT)" --env $(PROFILE_ENV) $(SELECT_FLAGS) --cache rw + env $(RUN_ENV) $(UV) run fft run "$(PROJECT)" --env $(PROFILE_ENV) $(SELECT_FLAGS) --cache=rw -# second/subsequent run: shows incremental/delta behaviour run_incr: - env $(RUN_ENV) $(UV) run fft run "$(PROJECT)" --env $(PROFILE_ENV) $(SELECT_FLAGS) --cache rw + env $(RUN_ENV) $(UV) run fft run "$(PROJECT)" --env $(PROFILE_ENV) $(SELECT_FLAGS) --cache=rw dag: env $(RUN_ENV) $(UV) run fft dag "$(PROJECT)" --env $(PROFILE_ENV) $(SELECT_FLAGS) --html @@ -122,13 +128,20 @@ demo-open: demo: clean @echo "== 🚀 Incremental Demo ($(ENGINE)) ==" - @echo "Profile=$(PROFILE_ENV) DB=$(DB) PROJECT=$(PROJECT) DBR_TABLE_FORMAT=$(DBR_TABLE_FORMAT)" - +$(MAKE) seed + @echo "== 1) First run (seed v1 → initial build) ==" + +$(MAKE) seed_v1 +$(MAKE) run_full @echo - @echo "== 🔁 Second run (Incremental/Delta) ==" + @echo "== 2) No-op run (same seed v1; should be mostly skipped) ==" + +$(MAKE) run_incr + @echo + @echo "== 3) Change seed data (seed v2 snapshot: update + new row) ==" + +$(MAKE) seed_v2 +$(MAKE) run_incr + @echo + @echo "== 4) DAG & artifacts ==" +$(MAKE) dag 
+$(MAKE) test +$(MAKE) artifacts - @echo "✅ Demo done. Open DAG here: $(PROJECT)/site/dag/index.html" + @echo + @echo "✅ Demo done. Open DAG at: $(PROJECT)/site/dag/index.html" diff --git a/examples/incremental_demo/contracts.yml b/examples/incremental_demo/contracts.yml new file mode 100644 index 0000000..a72cca2 --- /dev/null +++ b/examples/incremental_demo/contracts.yml @@ -0,0 +1,33 @@ +# Project-level contracts for incremental_demo. + +defaults: + columns: + - match: + name: "^event_id$" + type: integer + nullable: false + + - match: + name: "^updated_at$" + type: timestamp + nullable: false + +enforcement: + # Modes: off | verify | cast + default_mode: off + allow_extra_columns: true + + tables: + # incremental SQL examples + fct_events_sql_inline: + mode: verify + allow_extra_columns: true + + fct_events_sql_yaml: + mode: verify + allow_extra_columns: true + + # python incremental example + fct_events_py_incremental: + mode: verify + allow_extra_columns: true diff --git a/examples/incremental_demo/models/common/fct_events_py_incremental.contracts.yml b/examples/incremental_demo/models/common/fct_events_py_incremental.contracts.yml new file mode 100644 index 0000000..d047005 --- /dev/null +++ b/examples/incremental_demo/models/common/fct_events_py_incremental.contracts.yml @@ -0,0 +1,47 @@ +table: fct_events_py_incremental + +columns: + event_id: + type: integer + nullable: false + unique: true + physical: + duckdb: BIGINT + postgres: bigint + bigquery: INT64 + databricks_spark: BIGINT + snowflake_snowpark: NUMBER + default: BIGINT + + updated_at: + type: timestamp + nullable: false + physical: + duckdb: TIMESTAMP + postgres: timestamp + bigquery: TIMESTAMP + databricks_spark: TIMESTAMP + snowflake_snowpark: TIMESTAMP_NTZ + default: TIMESTAMP + + value: + type: integer + nullable: false + physical: + duckdb: BIGINT + postgres: bigint + bigquery: INT64 + databricks_spark: BIGINT + snowflake_snowpark: NUMBER + default: BIGINT + + value_x10: + type: integer + 
nullable: false + physical: + duckdb: BIGINT + postgres: bigint + bigquery: INT64 + databricks_spark: BIGINT + snowflake_snowpark: NUMBER + default: BIGINT diff --git a/examples/incremental_demo/models/common/fct_events_sql_inline.contracts.yml b/examples/incremental_demo/models/common/fct_events_sql_inline.contracts.yml new file mode 100644 index 0000000..51791d4 --- /dev/null +++ b/examples/incremental_demo/models/common/fct_events_sql_inline.contracts.yml @@ -0,0 +1,36 @@ +table: fct_events_sql_inline + +columns: + event_id: + type: integer + nullable: false + unique: true + physical: + duckdb: BIGINT + postgres: bigint + bigquery: INT64 + databricks_spark: BIGINT + snowflake_snowpark: NUMBER + default: BIGINT + + updated_at: + type: timestamp + nullable: false + physical: + duckdb: TIMESTAMP + postgres: timestamp + bigquery: TIMESTAMP + databricks_spark: TIMESTAMP + snowflake_snowpark: TIMESTAMP_NTZ + default: TIMESTAMP + + value: + type: integer + nullable: false + physical: + duckdb: BIGINT + postgres: bigint + bigquery: INT64 + databricks_spark: BIGINT + snowflake_snowpark: NUMBER + default: BIGINT diff --git a/examples/incremental_demo/models/common/fct_events_sql_yaml.contracts.yml b/examples/incremental_demo/models/common/fct_events_sql_yaml.contracts.yml new file mode 100644 index 0000000..8ca6245 --- /dev/null +++ b/examples/incremental_demo/models/common/fct_events_sql_yaml.contracts.yml @@ -0,0 +1,36 @@ +table: fct_events_sql_yaml + +columns: + event_id: + type: integer + nullable: false + unique: true + physical: + duckdb: BIGINT + postgres: bigint + bigquery: INT64 + databricks_spark: BIGINT + snowflake_snowpark: NUMBER + default: BIGINT + + updated_at: + type: timestamp + nullable: false + physical: + duckdb: TIMESTAMP + postgres: timestamp + bigquery: TIMESTAMP + databricks_spark: TIMESTAMP + snowflake_snowpark: TIMESTAMP_NTZ + default: TIMESTAMP + + value: + type: integer + nullable: false + physical: + duckdb: BIGINT + postgres: bigint + 
bigquery: INT64 + databricks_spark: BIGINT + snowflake_snowpark: NUMBER + default: BIGINT diff --git a/examples/incremental_demo/project.yml b/examples/incremental_demo/project.yml index e29c791..dba5f44 100644 --- a/examples/incremental_demo/project.yml +++ b/examples/incremental_demo/project.yml @@ -8,9 +8,6 @@ models: fct_events_sql_inline.ff: unique_key: "event_id" - fct_events_sql_yaml.ff: - unique_key: "event_id" - fct_events_sql_yaml.ff: unique_key: "event_id" # top-level shortcut incremental: @@ -20,6 +17,12 @@ models: fct_events_sql_inline_delta.ff: unique_key: "event_id" + fct_events_py_incremental: + unique_key: "event_id" + incremental: + enabled: true + updated_at_column: "updated_at" + seeds: {} tests: diff --git a/examples/incremental_demo/seeds/seed_events.csv b/examples/incremental_demo/seeds/seed_events_v1.csv similarity index 100% rename from examples/incremental_demo/seeds/seed_events.csv rename to examples/incremental_demo/seeds/seed_events_v1.csv diff --git a/examples/incremental_demo/seeds/seed_events_v2.csv b/examples/incremental_demo/seeds/seed_events_v2.csv new file mode 100644 index 0000000..d670163 --- /dev/null +++ b/examples/incremental_demo/seeds/seed_events_v2.csv @@ -0,0 +1,5 @@ +event_id,updated_at,value +1,2024-01-01T00:00:00,10 +2,2024-01-05T00:00:00,999 +3,2024-01-03T00:00:00,30 +4,2024-01-06T00:00:00,40 diff --git a/src/fastflowtransform/cli/bootstrap.py b/src/fastflowtransform/cli/bootstrap.py index 6145d84..d0e276f 100644 --- a/src/fastflowtransform/cli/bootstrap.py +++ b/src/fastflowtransform/cli/bootstrap.py @@ -367,6 +367,16 @@ def _make_executor(prof: Profile, jenv: Environment) -> tuple[BaseExecutor, Call if prof.bigquery.dataset is None: raise RuntimeError("BigQuery dataset must be set") + # Validate env-provided frame selector early (used by examples/Makefiles) + frame_env = os.getenv("FF_ENGINE_VARIANT") or os.getenv("BQ_FRAME") + if frame_env: + frame_normalized = frame_env.lower() + if frame_normalized not in 
{"pandas", "bigframes"}: + raise RuntimeError( + f"Unsupported BigQuery frame '{frame_env}'. " + "Set FF_ENGINE_VARIANT/BQ_FRAME to 'pandas' or 'bigframes'." + ) + if prof.bigquery.use_bigframes: BigQueryBFExecutor = _import_optional( "fastflowtransform.executors.bigquery.bigframes", diff --git a/src/fastflowtransform/contracts/runtime/base.py b/src/fastflowtransform/contracts/runtime/base.py index 79731c8..a1094dd 100644 --- a/src/fastflowtransform/contracts/runtime/base.py +++ b/src/fastflowtransform/contracts/runtime/base.py @@ -1,6 +1,7 @@ # fastflowtransform/contracts/runtime/base.by from __future__ import annotations +import re from dataclasses import dataclass from typing import Any, Protocol, TypeVar @@ -72,6 +73,34 @@ def _resolve_physical_type_for_engine( return None +def _canonicalize_physical_type(engine_name: str, typ: str | None) -> str | None: + """ + Apply minimal, engine-specific normalization so expected vs. actual types + compare predictably. Keep this small and focused on real metadata quirks. + """ + if typ is None: + return None + engine = (engine_name or "").lower() + t = typ.strip() + if not t: + return None + + # Snowflake: information_schema reports all string-family types as TEXT with a + # length column; normalize common aliases to VARCHAR and drop the huge default. + if engine.startswith("snowflake"): + upper = t.upper() + if upper in {"TEXT", "STRING", "CHAR", "CHARACTER"}: + return "VARCHAR" + if re.fullmatch(r"VARCHAR\s*\(\s*16777216\s*\)", upper): + return "VARCHAR" + if upper in {"DECIMAL", "NUMERIC"}: + return "NUMBER" + return upper + + # Default: case-insensitive comparison only. 
+ return t.upper() + + def expected_physical_schema( *, executor: ContractExecutor, @@ -91,7 +120,9 @@ def expected_physical_schema( phys = col_model.physical typ = _resolve_physical_type_for_engine(phys, engine) if typ: - result[col_name] = typ + canon = _canonicalize_physical_type(engine, typ) + if canon: + result[col_name] = canon return result diff --git a/src/fastflowtransform/contracts/runtime/bigquery.py b/src/fastflowtransform/contracts/runtime/bigquery.py new file mode 100644 index 0000000..d0b213c --- /dev/null +++ b/src/fastflowtransform/contracts/runtime/bigquery.py @@ -0,0 +1,152 @@ +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any + +from fastflowtransform.contracts.runtime.base import ( + BaseRuntimeContracts, + ContractExecutor, + RuntimeContractConfig, + RuntimeContractContext, + expected_physical_schema, +) + + +class BigQueryRuntimeContracts(BaseRuntimeContracts): + """ + Runtime schema contracts for BigQuery. + + Notes: + - executor._execute_sql returns a job-like object; we force execution via .result() + when present. + - CAST mode uses BigQuery's `src.* EXCEPT(col1, col2, ...)` to retain extra columns. 
+ """ + + def __init__(self, executor: ContractExecutor): + super().__init__(executor) + + # --- helpers --------------------------------------------------------- + + def _exec(self, sql: str) -> Any: + res = self.executor._execute_sql(sql) + # BigQuery QueryJob / our _TrackedQueryJob: execute via .result() + result_fn = getattr(res, "result", None) + if callable(result_fn): + return result_fn() + # Spark-like fallbacks (harmless here, but keeps this helper generic) + collect_fn = getattr(res, "collect", None) + if callable(collect_fn): + return collect_fn() + return res + + def _verify( + self, + *, + table: str, + expected: Mapping[str, str], + cfg: RuntimeContractConfig, + ) -> None: + if not expected: + return + + actual = self.executor.introspect_table_physical_schema(table) # {lower_name: TYPE} + exp_lower = {k.lower(): v for k, v in expected.items()} + + problems: list[str] = [] + + for col, expected_type in expected.items(): + key = col.lower() + if key not in actual: + problems.append(f"- missing column {col!r}") + continue + got = actual[key] + if str(got).lower() != str(expected_type).lower(): + problems.append(f"- column {col!r}: expected type {expected_type!r}, got {got!r}") + + if not cfg.allow_extra_columns: + extras = [c for c in actual if c not in exp_lower] + if extras: + problems.append(f"- extra columns present: {sorted(extras)}") + + if problems: + raise RuntimeError( + f"[contracts] BigQuery schema enforcement failed for {table}:\n" + + "\n".join(problems) + ) + + def _ctas_raw(self, target: str, select_body: str) -> None: + # BigQuery supports CREATE OR REPLACE TABLE ... AS