diff --git a/pyproject.toml b/pyproject.toml index e8919f7..d03ff07 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "fastflowtransform" -version = "0.6.9" +version = "0.6.10" description = "Python framework for SQL & Python data transformation, ETL pipelines, and dbt-style data modeling" readme = "README.md" license = { text = "Apache-2.0" } diff --git a/src/fastflowtransform/executors/_budget_runner.py b/src/fastflowtransform/executors/_budget_runner.py index 9ea4a95..8817f68 100644 --- a/src/fastflowtransform/executors/_budget_runner.py +++ b/src/fastflowtransform/executors/_budget_runner.py @@ -5,6 +5,7 @@ from time import perf_counter from typing import Any +from fastflowtransform.executors._query_stats_adapter import QueryStatsAdapter, RowcountStatsAdapter from fastflowtransform.executors.budget import BudgetGuard from fastflowtransform.executors.query_stats import QueryStats @@ -20,6 +21,7 @@ def run_sql_with_budget( estimate_fn: Callable[[str], int | None] | None = None, post_estimate_fn: Callable[[str, Any], int | None] | None = None, record_stats: bool = True, + stats_adapter: QueryStatsAdapter | None = None, ) -> Any: """ Shared helper for guarded SQL execution with timing + stats recording. @@ -51,36 +53,18 @@ def run_sql_with_budget( result = exec_fn() duration_ms = int((perf_counter() - started) * 1000) - rows: int | None = None - if rowcount_extractor is not None: - with suppress(Exception): - rows = rowcount_extractor(result) - - stats = QueryStats(bytes_processed=estimated_bytes, rows=rows, duration_ms=duration_ms) - - if stats.bytes_processed is None and post_estimate_fn is not None: - with suppress(Exception): - post_estimate = post_estimate_fn(sql, result) - if post_estimate is not None: - stats = QueryStats( - bytes_processed=post_estimate, - rows=stats.rows, - duration_ms=stats.duration_ms, - ) - - if extra_stats is not None: - with suppress(Exception): - extra = extra_stats(result) - if extra: - stats = QueryStats( - bytes_processed=extra.bytes_processed - if extra.bytes_processed is not None - else stats.bytes_processed, - rows=extra.rows if extra.rows is not None else stats.rows, - duration_ms=extra.duration_ms - if extra.duration_ms is not None - else stats.duration_ms, - ) + adapter = stats_adapter + if adapter is None and (rowcount_extractor or post_estimate_fn or extra_stats): + adapter = RowcountStatsAdapter( + rowcount_extractor=rowcount_extractor, + post_estimate_fn=post_estimate_fn, + extra_stats=extra_stats, + sql=sql, + ) + if adapter is None: + stats = QueryStats(bytes_processed=estimated_bytes, rows=None, duration_ms=duration_ms) + else: + stats = adapter.collect(result, duration_ms=duration_ms, estimated_bytes=estimated_bytes) executor._record_query_stats(stats) return result diff --git a/src/fastflowtransform/executors/_query_stats_adapter.py b/src/fastflowtransform/executors/_query_stats_adapter.py new file mode 100644 index 0000000..f2a01c1 --- /dev/null +++ b/src/fastflowtransform/executors/_query_stats_adapter.py @@ -0,0 +1,142 @@ +# fastflowtransform/executors/_query_stats_adapter.py +from __future__ import annotations + +from collections.abc import Callable +from typing import Any, Protocol + +from fastflowtransform.executors.query_stats import QueryStats + + +class QueryStatsAdapter(Protocol): + """Adapter interface to normalize stats extraction across engines.""" + + def collect( + self, result: Any, *, duration_ms: int | None, estimated_bytes: int | None + ) -> QueryStats: ... + + +class RowcountStatsAdapter: + """ + Default stats adapter for DB-API style executors that expose rowcount. + Preserves existing post_estimate/extra_stats hook behavior. + """ + + def __init__( + self, + *, + rowcount_extractor: Callable[[Any], int | None] | None = None, + post_estimate_fn: Callable[[str, Any], int | None] | None = None, + extra_stats: Callable[[Any], QueryStats | None] | None = None, + sql: str | None = None, + ) -> None: + self.rowcount_extractor = rowcount_extractor + self.post_estimate_fn = post_estimate_fn + self.extra_stats = extra_stats + self.sql = sql + + def collect( + self, result: Any, *, duration_ms: int | None, estimated_bytes: int | None + ) -> QueryStats: + rows: int | None = None + if self.rowcount_extractor is not None: + try: + rows = self.rowcount_extractor(result) + except Exception: + rows = None + + stats = QueryStats(bytes_processed=estimated_bytes, rows=rows, duration_ms=duration_ms) + + if stats.bytes_processed is None and self.post_estimate_fn is not None: + try: + post_estimate = self.post_estimate_fn(self.sql or "", result) + except Exception: + post_estimate = None + if post_estimate is not None: + stats = QueryStats( + bytes_processed=post_estimate, + rows=stats.rows, + duration_ms=stats.duration_ms, + ) + + if self.extra_stats is not None: + try: + extra = self.extra_stats(result) + except Exception: + extra = None + if extra: + stats = QueryStats( + bytes_processed=extra.bytes_processed + if extra.bytes_processed is not None + else stats.bytes_processed, + rows=extra.rows if extra.rows is not None else stats.rows, + duration_ms=( + extra.duration_ms if extra.duration_ms is not None else stats.duration_ms + ), + ) + + return stats + + +class JobStatsAdapter: + """ + Generic job-handle stats extractor (BigQuery/Snowflake/Spark-like objects). + Mirrors the previous `_record_query_job_stats` heuristics. + """ + + def collect(self, job: Any) -> QueryStats: + def _safe_int(val: Any) -> int | None: + try: + if val is None: + return None + return int(val) + except Exception: + return None + + bytes_processed = _safe_int( + getattr(job, "total_bytes_processed", None) or getattr(job, "bytes_processed", None) + ) + + rows = _safe_int( + getattr(job, "num_dml_affected_rows", None) + or getattr(job, "total_rows", None) + or getattr(job, "rowcount", None) + ) + + duration_ms: int | None = None + try: + started = getattr(job, "started", None) + ended = getattr(job, "ended", None) + if started is not None and ended is not None: + duration_ms = int((ended - started).total_seconds() * 1000) + except Exception: + duration_ms = None + + return QueryStats(bytes_processed=bytes_processed, rows=rows, duration_ms=duration_ms) + + +class SparkDataFrameStatsAdapter: + """ + Adapter for Spark DataFrames that mirrors existing Databricks behaviour: + - bytes via provided bytes_fn (plan-based best effort) + - rows left as None + - duration passed through + """ + + def __init__(self, bytes_fn: Callable[[Any], int | None]) -> None: + self.bytes_fn = bytes_fn + + def collect( + self, df: Any, *, duration_ms: int | None, estimated_bytes: int | None = None + ) -> QueryStats: + bytes_val = estimated_bytes + if bytes_val is None: + try: + bytes_val = self.bytes_fn(df) + except Exception: + bytes_val = None + + return QueryStats( + bytes_processed=bytes_val if bytes_val is not None and bytes_val > 0 else None, + rows=None, + duration_ms=duration_ms, + ) diff --git a/src/fastflowtransform/executors/_snapshot_sql_mixin.py b/src/fastflowtransform/executors/_snapshot_sql_mixin.py new file mode 100644 index 0000000..0e15da5 --- /dev/null +++ b/src/fastflowtransform/executors/_snapshot_sql_mixin.py @@ -0,0 +1,421 @@ +from __future__ import annotations + +from collections.abc import Callable, Iterable +from contextlib import suppress +from typing import TYPE_CHECKING, Any, cast + +from jinja2 import Environment + +from fastflowtransform.core import Node, relation_for +from fastflowtransform.logging import echo +from fastflowtransform.snapshots import resolve_snapshot_config + +if TYPE_CHECKING: + # Adjust this import to your actual path + from fastflowtransform.executors.base import BaseExecutor + + +class SnapshotSqlMixin: + """ + Shared SQL snapshot materialization (timestamp + check strategies). + + Engines provide small hooks for identifier qualification, expressions, + staging, and execution. All column names come from BaseExecutor constants. + """ + + def run_snapshot_sql(self, node: Node, env: Environment) -> None: + ex = cast("BaseExecutor[Any]", self) + + meta = self._snapshot_validate_node(node) + cfg = resolve_snapshot_config(node, meta) + + body = self._snapshot_render_body(node, env) + rel_name = relation_for(node.name) + target = self._snapshot_target_identifier(rel_name) + if not cfg.unique_key: + raise ValueError(f"{node.path}: snapshot models require a non-empty unique_key list.") + + vf = self.SNAPSHOT_VALID_FROM_COL # type: ignore[attr-defined] + vt = self.SNAPSHOT_VALID_TO_COL # type: ignore[attr-defined] + is_cur = self.SNAPSHOT_IS_CURRENT_COL # type: ignore[attr-defined] + hash_col = self.SNAPSHOT_HASH_COL # type: ignore[attr-defined] + upd_meta = self.SNAPSHOT_UPDATED_AT_COL # type: ignore[attr-defined] + + self._snapshot_prepare_target() + + # First run: create snapshot table + if not ex.exists_relation(rel_name): + sql = self._snapshot_first_run_sql( + body=body, + strategy=cfg.strategy, + unique_key=cfg.unique_key, + updated_at=cfg.updated_at, + check_cols=cfg.check_cols, + target=target, + vf=vf, + vt=vt, + is_cur=is_cur, + hash_col=hash_col, + upd_meta=upd_meta, + ) + self._snapshot_exec_and_wait(sql) + return + + # Incremental update + src_ref, cleanup = self._snapshot_source_ref(rel_name, body) + try: + keys_pred = " AND ".join([f"t.{k} = s.{k}" for k in cfg.unique_key]) or "FALSE" + + if cfg.strategy == "timestamp": + if not cfg.updated_at: + raise ValueError( + f"{node.path}: strategy='timestamp' snapshots require an updated_at column." + ) + change_condition = f"s.{cfg.updated_at} > t.{upd_meta}" + new_upd_expr = self._snapshot_updated_at_expr(cfg.updated_at, "s") + new_valid_from_expr = self._snapshot_updated_at_expr(cfg.updated_at, "s") + new_hash_expr = self._snapshot_null_hash() + else: + hash_expr_s = self._snapshot_hash_expr(cfg.check_cols, "s") + change_condition = f"COALESCE({hash_expr_s}, '') <> COALESCE(t.{hash_col}, '')" + new_upd_expr = ( + self._snapshot_updated_at_expr(cfg.updated_at, "s") + if cfg.updated_at + else self._snapshot_current_timestamp() + ) + new_valid_from_expr = self._snapshot_current_timestamp() + new_hash_expr = hash_expr_s + + close_sql = self._snapshot_close_sql( + target=target, + src_ref=src_ref, + keys_pred=keys_pred, + change_condition=change_condition, + vt=vt, + is_cur=is_cur, + ) + self._snapshot_exec_and_wait(close_sql) + + insert_sql = self._snapshot_insert_sql( + target=target, + src_ref=src_ref, + keys_pred=keys_pred, + first_key=cfg.unique_key[0], + new_upd_expr=new_upd_expr, + new_valid_from_expr=new_valid_from_expr, + new_hash_expr=new_hash_expr, + change_condition=change_condition, + vf=vf, + vt=vt, + is_cur=is_cur, + hash_col=hash_col, + upd_meta=upd_meta, + ) + self._snapshot_exec_and_wait(insert_sql) + finally: + with suppress(Exception): + cleanup() + + # ---- Core SQL builders ------------------------------------------------- + def _snapshot_first_run_sql( + self, + *, + body: str, + strategy: str, + unique_key: list[str], + updated_at: str | None, + check_cols: list[str], + target: str, + vf: str, + vt: str, + is_cur: str, + hash_col: str, + upd_meta: str, + ) -> str: + if not unique_key: + raise ValueError("Snapshot models require a non-empty unique_key list.") + + if strategy == "timestamp": + if not updated_at: + raise ValueError("strategy='timestamp' snapshots require an updated_at column.") + return f""" +{self._snapshot_create_keyword()} {target} AS +SELECT + s.*, + {self._snapshot_updated_at_expr(updated_at, "s")} AS {upd_meta}, + {self._snapshot_updated_at_expr(updated_at, "s")} AS {vf}, + {self._snapshot_null_timestamp()} AS {vt}, + TRUE AS {is_cur}, + {self._snapshot_null_hash()} AS {hash_col} +FROM ({body}) AS s +""" + + if not check_cols: + raise ValueError("strategy='check' snapshots require non-empty check_cols.") + + hash_expr = self._snapshot_hash_expr(check_cols, "s") + upd_expr = ( + self._snapshot_updated_at_expr(updated_at, "s") + if updated_at + else self._snapshot_current_timestamp() + ) + return f""" +{self._snapshot_create_keyword()} {target} AS +SELECT + s.*, + {upd_expr} AS {upd_meta}, + {self._snapshot_current_timestamp()} AS {vf}, + {self._snapshot_null_timestamp()} AS {vt}, + TRUE AS {is_cur}, + {hash_expr} AS {hash_col} +FROM ({body}) AS s +""" + + def _snapshot_close_sql( + self, + *, + target: str, + src_ref: str, + keys_pred: str, + change_condition: str, + vt: str, + is_cur: str, + ) -> str: + return f""" +UPDATE {target} AS t +SET + {vt} = {self._snapshot_current_timestamp()}, + {is_cur} = FALSE +FROM {src_ref} AS s +WHERE + {keys_pred} + AND t.{is_cur} = TRUE + AND {change_condition} +""" + + def _snapshot_insert_sql( + self, + *, + target: str, + src_ref: str, + keys_pred: str, + first_key: str, + new_upd_expr: str, + new_valid_from_expr: str, + new_hash_expr: str, + change_condition: str, + vf: str, + vt: str, + is_cur: str, + hash_col: str, + upd_meta: str, + ) -> str: + return f""" +INSERT INTO {target} +SELECT + s.*, + {new_upd_expr} AS {upd_meta}, + {new_valid_from_expr} AS {vf}, + {self._snapshot_null_timestamp()} AS {vt}, + TRUE AS {is_cur}, + {new_hash_expr} AS {hash_col} +FROM {src_ref} AS s +LEFT JOIN {target} AS t + ON {keys_pred} + AND t.{is_cur} = TRUE +WHERE + t.{first_key} IS NULL + OR {change_condition} +""" + + # ---- Pruning ----------------------------------------------------------- + def snapshot_prune( + self, + relation: str, + unique_key: list[str], + keep_last: int, + *, + dry_run: bool = False, + ) -> None: + """ + Delete older snapshot versions while keeping the most recent `keep_last` + rows per business key (including the current row). + """ + ex = cast("BaseExecutor[Any]", self) + + if keep_last <= 0: + return + + keys = [k for k in unique_key if k] + if not keys: + return + + target = self._snapshot_target_identifier(relation) + vf = self.SNAPSHOT_VALID_FROM_COL # type: ignore[attr-defined] + + key_select = ", ".join(keys) + part_by = ", ".join(keys) + + ranked_sql = f""" +SELECT + {key_select}, + {vf}, + ROW_NUMBER() OVER ( + PARTITION BY {part_by} + ORDER BY {vf} DESC + ) AS rn +FROM {target} +""" + + if dry_run: + sql = f""" +WITH ranked AS ( + {ranked_sql} +) +SELECT COUNT(*) AS rows_to_delete +FROM ranked +WHERE rn > {int(keep_last)} +""" + res = ex._execute_sql(sql) + count = self._snapshot_fetch_count(res) + echo( + f"[DRY-RUN] snapshot_prune({relation}): would delete {count} row(s) " + f"(keep_last={keep_last})" + ) + return + + join_pred = " AND ".join([f"t.{k} = r.{k}" for k in keys]) + delete_sql = f""" +DELETE FROM {target} t +USING ( + {ranked_sql} +) r +WHERE + r.rn > {int(keep_last)} + AND {join_pred} + AND t.{vf} = r.{vf} +""" + ex._execute_sql(delete_sql) + + # ---- Rendering helpers ------------------------------------------------- + def _snapshot_render_body(self, node: Node, env: Environment) -> str: + ex = cast("BaseExecutor[Any]", self) + + sql_rendered = ex.render_sql( + node, + env, + ref_resolver=lambda name: ex._resolve_ref(name, env), + source_resolver=ex._resolve_source, + ) + sql_clean = ex._strip_leading_config(sql_rendered).strip() + return ex._selectable_body(sql_clean).rstrip(" ;\n\t") + + def _snapshot_validate_node(self, node: Node) -> dict[str, Any]: + if node.kind != "sql": + raise TypeError( + f"Snapshot materialization is only supported for SQL models, " + f"got kind={node.kind!r} for {node.name}." + ) + + meta = getattr(node, "meta", {}) or {} + if not self._meta_is_snapshot(meta): # type: ignore[attr-defined] + raise ValueError(f"Node {node.name} is not configured with materialized='snapshot'.") + return meta + + # ---- Staging ----------------------------------------------------------- + def _snapshot_source_ref( + self, rel_name: str, select_body: str + ) -> tuple[str, Callable[[], None]]: + """ + Return (source_ref, cleanup). Default: inline subquery. + Engines can override to use temp views/tables and cleanup afterward. + """ + return f"({select_body})", lambda: None + + # ---- Hooks (must be provided by engines) ------------------------------- + def _snapshot_target_identifier(self, rel_name: str) -> str: # pragma: no cover - abstract + raise NotImplementedError + + def _snapshot_current_timestamp(self) -> str: # pragma: no cover - abstract + raise NotImplementedError + + def _snapshot_null_timestamp(self) -> str: # pragma: no cover - abstract + raise NotImplementedError + + def _snapshot_null_hash(self) -> str: # pragma: no cover - abstract + raise NotImplementedError + + def _snapshot_hash_expr(self, check_cols: list[str], src_alias: str) -> str: # pragma: no cover + raise NotImplementedError + + def _snapshot_updated_at_expr(self, updated_at: str, src_alias: str) -> str: + return f"{src_alias}.{updated_at}" + + def _snapshot_prepare_target(self) -> None: + """Hook for engines that need to ensure dataset/schema before writes.""" + return None + + def _snapshot_exec_and_wait(self, sql: str) -> None: + """ + Execute SQL and, if necessary, wait for completion (jobs, lazy DataFrames). + """ + res = self._execute_sql(sql) # type: ignore[attr-defined] + if res is None: + return + for attr in ("result", "collect"): + fn = getattr(res, attr, None) + if callable(fn): + with suppress(Exception): + fn() + break + + # ---- Helpers ----------------------------------------------------------- + def _snapshot_concat_expr(self, columns: list[str], src_alias: str) -> str: + parts = [ + self._snapshot_coalesce(self._snapshot_cast_as_string(f"{src_alias}.{col}"), "''") + for col in columns + ] + return " || '||' || ".join(parts) if parts else "''" + + def _snapshot_cast_as_string(self, expr: str) -> str: + return f"CAST({expr} AS STRING)" + + def _snapshot_coalesce(self, expr: str, default: str) -> str: + return f"COALESCE({expr}, {default})" + + def _snapshot_create_keyword(self) -> str: + """Hook to allow engines to override CREATE vs CREATE OR REPLACE.""" + return "CREATE TABLE" + + def _snapshot_fetch_count(self, res: Any) -> int: + """ + Best-effort extraction of a single COUNT(*) value from various result shapes. + """ + try: + if hasattr(res, "fetchone"): + row = res.fetchone() + if row is not None: + return int(row[0]) + if hasattr(res, "fetchall"): + rows = res.fetchall() + if rows: + return int(rows[0][0]) + result_fn = getattr(res, "result", None) + if callable(result_fn): + rows_obj = result_fn() + if isinstance(rows_obj, Iterable): + rows = list(rows_obj) + if rows: + return int(rows[0][0]) + collect_fn = getattr(res, "collect", None) + if callable(collect_fn): + rows_obj = collect_fn() + if isinstance(rows_obj, Iterable): + rows = list(rows_obj) + if rows: + return int(rows[0][0]) + if isinstance(res, list) and res: + return int(res[0][0]) + except Exception: + return 0 + return 0 diff --git a/src/fastflowtransform/executors/_sql_identifier.py b/src/fastflowtransform/executors/_sql_identifier.py new file mode 100644 index 0000000..2a8a3ae --- /dev/null +++ b/src/fastflowtransform/executors/_sql_identifier.py @@ -0,0 +1,121 @@ +# fastflowtransform/executors/_sql_identifier.py +from __future__ import annotations + +from typing import Any + +from fastflowtransform.core import relation_for + + +class SqlIdentifierMixin: + """ + Thin helper mixin for engines that qualify SQL identifiers with optional + catalog/database and schema. + + Subclasses must implement `_quote_identifier` and may override the + *_default_* / *_should_include_catalog methods to match engine quirks. + """ + + def _normalize_identifier(self, ident: str) -> str: + """ + Normalize fastflowtransform's logical identifiers: + - Strip `.ff` suffixes via relation_for + - Leave other strings untouched + """ + if not isinstance(ident, str): + return ident + return relation_for(ident) if ident.endswith(".ff") else ident + + def _clean_part(self, part: Any) -> str | None: + if not isinstance(part, str): + return None + stripped = part.strip() + return stripped or None + + def _quote_identifier(self, ident: str) -> str: # pragma: no cover - abstract + """Engine-specific quoting (e.g., \"name\" or `name`).""" + raise NotImplementedError + + def _default_schema(self) -> str | None: + return self._clean_part(getattr(self, "schema", None)) + + def _default_catalog(self) -> str | None: + return self._clean_part(getattr(self, "catalog", None)) + + def _default_catalog_for_source(self, schema: str | None) -> str | None: + """Hook to adjust catalog fallback for sources (override per engine).""" + return self._default_catalog() + + def _should_include_catalog( + self, catalog: str | None, schema: str | None, *, explicit: bool + ) -> bool: + """ + Decide whether to emit the catalog in a qualified identifier. + + explicit=True when the caller passed a catalog argument (as opposed to + using defaults), so engines can honour explicit catalogs even if they + normally omit them. + """ + return bool(catalog) + + def _qualify_identifier( + self, + ident: str, + *, + schema: str | None = None, + catalog: str | None = None, + quote: bool = True, + ) -> str: + """ + Assemble a qualified identifier (catalog.schema.ident) with engine + defaults and quoting. + """ + normalized = self._normalize_identifier(ident) + explicit_catalog = catalog is not None + sch = self._clean_part(schema) or self._default_schema() + cat = self._clean_part(catalog) if explicit_catalog else self._default_catalog() + + parts: list[str] = [] + if self._should_include_catalog(cat, sch, explicit=explicit_catalog) and cat: + parts.append(cat) + if sch: + parts.append(sch) + parts.append(normalized) + + if not quote: + return ".".join(parts) + return ".".join(self._quote_identifier(p) for p in parts) + + # ---- Shared formatting hooks ----------------------------------------- + def _format_relation_for_ref(self, name: str) -> str: + return self._qualify_identifier(relation_for(name)) + + def _pick_schema(self, cfg: dict[str, Any]) -> str | None: + for key in ("schema", "dataset"): + candidate = self._clean_part(cfg.get(key)) + if candidate: + return candidate + return self._default_schema() + + def _pick_catalog(self, cfg: dict[str, Any], schema: str | None) -> str | None: + for key in ("catalog", "database", "project"): + candidate = self._clean_part(cfg.get(key)) + if candidate: + return candidate + return self._default_catalog_for_source(schema) + + def _format_source_reference( + self, cfg: dict[str, Any], source_name: str, table_name: str + ) -> str: + if cfg.get("location"): + raise NotImplementedError( + f"{getattr(self, 'engine_name', 'unknown')} executor " + "does not support path-based sources." + ) + + ident = cfg.get("identifier") + if not ident: + raise KeyError(f"Source {source_name}.{table_name} missing identifier") + + schema = self._pick_schema(cfg) + catalog = self._pick_catalog(cfg, schema) + return self._qualify_identifier(ident, schema=schema, catalog=catalog) diff --git a/src/fastflowtransform/executors/base.py b/src/fastflowtransform/executors/base.py index f8357b2..4a33461 100644 --- a/src/fastflowtransform/executors/base.py +++ b/src/fastflowtransform/executors/base.py @@ -7,7 +7,6 @@ from abc import ABC, abstractmethod from collections.abc import Callable, Iterable, Mapping from contextlib import suppress -from datetime import datetime from pathlib import Path from typing import Any, TypeVar, cast @@ -20,6 +19,7 @@ from fastflowtransform.config.sources import resolve_source_entry from fastflowtransform.core import REGISTRY, Node, relation_for from fastflowtransform.errors import ModelExecutionError +from fastflowtransform.executors._query_stats_adapter import JobStatsAdapter from fastflowtransform.executors.budget import BudgetGuard from fastflowtransform.executors.query_stats import QueryStats from fastflowtransform.incremental import _normalize_unique_key @@ -424,6 +424,15 @@ def _execute_sql_direct(self, sql: str, node: Node) -> None: raise NotImplementedError("Direct DDL execution is not implemented for this executor.") con.execute(sql) + def _execute_sql( + self, sql: str, *args: Any, **kwargs: Any + ) -> Any: # pragma: no cover - abstract + """ + Engine-specific SQL execution hook used by shared helpers (snapshots, pruning, etc.). + Concrete executors override this with their own signatures and semantics. + """ + raise NotImplementedError + def _render_ephemeral_sql(self, name: str, env: Environment) -> str: """ Render the SQL for an 'ephemeral' model and return it as a parenthesized @@ -486,42 +495,8 @@ def _record_query_job_stats(self, job: Any) -> None: (BigQuery, Snowflake, Spark) can pass them here. Engines can override this if they want more precise logic. """ - - def _safe_int(val: Any) -> int | None: - try: - if val is None: - return None - return int(val) - except Exception: - return None - - # Heuristic attribute names - BigQuery and others may expose these. - bytes_processed = _safe_int( - getattr(job, "total_bytes_processed", None) or getattr(job, "bytes_processed", None) - ) - - rows = _safe_int( - getattr(job, "num_dml_affected_rows", None) - or getattr(job, "total_rows", None) - or getattr(job, "rowcount", None) - ) - - duration_ms: int | None = None - try: - started = getattr(job, "started", None) - ended = getattr(job, "ended", None) - if isinstance(started, datetime) and isinstance(ended, datetime): - duration_ms = int((ended - started).total_seconds() * 1000) - except Exception: - pass - - self._record_query_stats( - QueryStats( - bytes_processed=bytes_processed, - rows=rows, - duration_ms=duration_ms, - ) - ) + adapter = JobStatsAdapter() + self._record_query_stats(adapter.collect(job)) def configure_query_budget_limit(self, limit: int | None) -> None: """ diff --git a/src/fastflowtransform/executors/bigquery/_bigquery_mixin.py b/src/fastflowtransform/executors/bigquery/_bigquery_mixin.py index 3da7441..4190192 100644 --- a/src/fastflowtransform/executors/bigquery/_bigquery_mixin.py +++ b/src/fastflowtransform/executors/bigquery/_bigquery_mixin.py @@ -1,10 +1,11 @@ # fastflowtransform/executors/_bigquery_mixin.py from __future__ import annotations +from fastflowtransform.executors._sql_identifier import SqlIdentifierMixin from fastflowtransform.typing import NotFound, bigquery -class BigQueryIdentifierMixin: +class BigQueryIdentifierMixin(SqlIdentifierMixin): """ Mixin that provides common BigQuery helpers (identifier quoting, dataset creation). Expect subclasses to define: self.project, self.dataset, self.client. @@ -14,16 +15,44 @@ class BigQueryIdentifierMixin: dataset: str client: bigquery.Client - @staticmethod - def _bq_quote(value: str) -> str: + def _bq_quote(self, value: str) -> str: return value.replace("`", "\\`") + def _quote_identifier(self, ident: str) -> str: + return self._bq_quote(ident) + + def _default_schema(self) -> str | None: + return self.dataset + + def _default_catalog(self) -> str | None: + return self.project + + def _should_include_catalog( + self, catalog: str | None, schema: str | None, *, explicit: bool + ) -> bool: + # BigQuery always expects a project + dataset. + return True + + def _qualify_identifier( + self, + ident: str, + *, + schema: str | None = None, + catalog: str | None = None, + quote: bool = True, + ) -> str: + proj = self._clean_part(catalog) or self._default_catalog() + dset = self._clean_part(schema) or self._default_schema() + normalized = self._normalize_identifier(ident) + parts = [proj, dset, normalized] + if not quote: + return ".".join(p for p in parts if p) + return f"`{'.'.join(self._bq_quote(p) for p in parts if p)}`" + def _qualified_identifier( self, relation: str, project: str | None = None, dataset: str | None = None ) -> str: - proj = project or self.project - dset = dataset or self.dataset - return f"`{self._bq_quote(proj)}.{self._bq_quote(dset)}.{self._bq_quote(relation)}`" + return self._qualify_identifier(relation, schema=dataset, catalog=project) def _ensure_dataset(self) -> None: ds_id = f"{self.project}.{self.dataset}" diff --git a/src/fastflowtransform/executors/bigquery/base.py b/src/fastflowtransform/executors/bigquery/base.py index da47d98..e9729a5 100644 --- a/src/fastflowtransform/executors/bigquery/base.py +++ b/src/fastflowtransform/executors/bigquery/base.py @@ -1,24 +1,23 @@ # fastflowtransform/executors/bigquery/base.py from __future__ import annotations -from typing import Any, TypeVar +from typing import TypeVar from fastflowtransform.core import Node, relation_for from fastflowtransform.executors._budget_runner import run_sql_with_budget from fastflowtransform.executors._shims import BigQueryConnShim +from fastflowtransform.executors._snapshot_sql_mixin import SnapshotSqlMixin from fastflowtransform.executors.base import BaseExecutor from fastflowtransform.executors.bigquery._bigquery_mixin import BigQueryIdentifierMixin from fastflowtransform.executors.budget import BudgetGuard from fastflowtransform.executors.query_stats import _TrackedQueryJob -from fastflowtransform.logging import echo from fastflowtransform.meta import ensure_meta_table, upsert_meta -from fastflowtransform.snapshots import resolve_snapshot_config from fastflowtransform.typing import BadRequest, Client, NotFound, bigquery TFrame = TypeVar("TFrame") -class BigQueryBaseExecutor(BigQueryIdentifierMixin, BaseExecutor[TFrame]): +class BigQueryBaseExecutor(BigQueryIdentifierMixin, SnapshotSqlMixin, BaseExecutor[TFrame]): """ Shared BigQuery executor logic (SQL, incremental, meta, DQ helpers). @@ -129,33 +128,13 @@ def _format_test_table(self, table: str | None) -> str | None: return self._qualified_identifier(table.strip()) # ---- SQL hooks ---- - def _format_relation_for_ref(self, name: str) -> str: - return self._qualified_identifier(relation_for(name)) - def _this_identifier(self, node: Node) -> str: """ Ensure {{ this }} renders as a fully-qualified identifier so BigQuery incremental SQL (e.g., subqueries against {{ this }}) includes project and dataset. """ - return self._qualified_identifier(relation_for(node.name)) - - def _format_source_reference( - self, - cfg: dict[str, Any], - source_name: str, - table_name: str, - ) -> str: - if cfg.get("location"): - raise NotImplementedError("BigQuery executor does not support path-based sources.") - - ident = cfg.get("identifier") - if not ident: - raise KeyError(f"Source {source_name}.{table_name} missing identifier") - - proj = cfg.get("project") or cfg.get("database") or cfg.get("catalog") or self.project - dset = cfg.get("dataset") or cfg.get("schema") or self.dataset - return self._qualified_identifier(ident, project=proj, dataset=dset) + return self._qualify_identifier(relation_for(node.name)) def _apply_sql_materialization( self, @@ -189,6 +168,29 @@ def _create_or_replace_view_from_table( self._ensure_dataset() self._execute_sql(f"CREATE OR REPLACE VIEW {view_id} AS SELECT * FROM {back_id}").result() + # ---- Snapshot mixin hooks ---- + def _snapshot_prepare_target(self) -> None: + self._ensure_dataset() + + def _snapshot_target_identifier(self, rel_name: str) -> str: + return self._qualified_identifier(rel_name) + + def _snapshot_current_timestamp(self) -> str: + return "CURRENT_TIMESTAMP()" + + def _snapshot_null_timestamp(self) -> str: + return "CAST(NULL AS TIMESTAMP)" + + def _snapshot_null_hash(self) -> str: + return "CAST(NULL AS STRING)" + + def _snapshot_hash_expr(self, check_cols: list[str], src_alias: str) -> str: + concat_expr = self._snapshot_concat_expr(check_cols, src_alias) + return f"TO_HEX(MD5({concat_expr}))" + + def _snapshot_cast_as_string(self, expr: str) -> str: + return f"CAST({expr} AS STRING)" + # ---- Meta hook ---- def on_node_built(self, node: Node, relation: str, fingerprint: str) -> None: """ @@ -325,227 +327,6 @@ def alter_table_sync_schema( self._execute_sql(f"ALTER TABLE {target} ADD COLUMN {col} {typ}").result() # ── Snapshots API (shared for pandas + BigFrames) ───────────────────── - def run_snapshot_sql(self, node: Node, env: Any) -> None: - """ - Snapshot materialization for BigQuery SQL models. - - Uses the same semantics as the DuckDB/Postgres/Snowflake executors: - - First run: create table with snapshot metadata columns. - - Subsequent runs: - * close changed current rows (set valid_to, is_current=false) - * insert new current rows for new/changed keys. - """ - if node.kind != "sql": - raise TypeError( - f"Snapshot materialization is only supported for SQL models, " - f"got kind={node.kind!r} for {node.name}." - ) - - meta = getattr(node, "meta", {}) or {} - if not self._meta_is_snapshot(meta): - raise ValueError(f"Node {node.name} is not configured with materialized='snapshot'.") - - cfg = resolve_snapshot_config(node, meta) - strategy = cfg.strategy # "timestamp" | "check" - unique_key = cfg.unique_key # list[str] - updated_at = cfg.updated_at # str | None - check_cols = cfg.check_cols # list[str] - - if not unique_key: - raise ValueError(f"{node.path}: snapshot models require a non-empty unique_key list.") - - # ---- Render SQL and extract SELECT body ---- - sql_rendered = self.render_sql( - node, - env, - ref_resolver=lambda name: self._resolve_ref(name, env), - source_resolver=self._resolve_source, - ) - sql_clean = self._strip_leading_config(sql_rendered).strip() - body = self._selectable_body(sql_clean).rstrip(" ;\n\t") - - rel_name = relation_for(node.name) - target = self._qualified_identifier(rel_name) - - vf = BaseExecutor.SNAPSHOT_VALID_FROM_COL - vt = BaseExecutor.SNAPSHOT_VALID_TO_COL - is_cur = BaseExecutor.SNAPSHOT_IS_CURRENT_COL - hash_col = BaseExecutor.SNAPSHOT_HASH_COL - upd_meta = BaseExecutor.SNAPSHOT_UPDATED_AT_COL - - self._ensure_dataset() - - # ---- First run: create snapshot table ---- - if not self.exists_relation(rel_name): - if strategy == "timestamp": - if not updated_at: - raise ValueError( - f"{node.path}: strategy='timestamp' snapshots require an updated_at column." - ) - create_sql = f""" -CREATE TABLE {target} AS -SELECT - s.*, - s.{updated_at} AS {upd_meta}, - s.{updated_at} AS {vf}, - CAST(NULL AS TIMESTAMP) AS {vt}, - TRUE AS {is_cur}, - CAST(NULL AS STRING) AS {hash_col} -FROM ({body}) AS s -""" - else: # strategy == "check" - if not check_cols: - raise ValueError( - f"{node.path}: strategy='check' snapshots require non-empty check_cols." - ) - col_exprs = [f"COALESCE(CAST(s.{col} AS STRING), '')" for col in check_cols] - concat_expr = " || '||' || ".join(col_exprs) - hash_expr = f"TO_HEX(MD5({concat_expr}))" - upd_expr = f"s.{updated_at}" if updated_at else "CURRENT_TIMESTAMP()" - create_sql = f""" -CREATE TABLE {target} AS -SELECT - s.*, - {upd_expr} AS {upd_meta}, - CURRENT_TIMESTAMP() AS {vf}, - CAST(NULL AS TIMESTAMP) AS {vt}, - TRUE AS {is_cur}, - {hash_expr} AS {hash_col} -FROM ({body}) AS s -""" - self._execute_sql(create_sql).result() - return - - # ---- Incremental snapshot update ---- - keys_pred = " AND ".join([f"t.{k} = s.{k}" for k in unique_key]) - - if strategy == "timestamp": - if not updated_at: - raise ValueError( - f"{node.path}: strategy='timestamp' snapshots require an updated_at column." - ) - change_condition = f"s.{updated_at} > t.{upd_meta}" - new_upd_expr = f"s.{updated_at}" - new_valid_from_expr = f"s.{updated_at}" - new_hash_expr = "NULL" - else: - col_exprs_s = [f"COALESCE(CAST(s.{col} AS STRING), '')" for col in check_cols] - concat_expr_s = " || '||' || ".join(col_exprs_s) - hash_expr_s = f"TO_HEX(MD5({concat_expr_s}))" - change_condition = f"COALESCE({hash_expr_s}, '') <> COALESCE(t.{hash_col}, '')" - new_upd_expr = f"s.{updated_at}" if updated_at else "CURRENT_TIMESTAMP()" - new_valid_from_expr = "CURRENT_TIMESTAMP()" - new_hash_expr = hash_expr_s - - # 1) Close changed current rows - close_sql = f""" -UPDATE {target} AS t -SET - {vt} = CURRENT_TIMESTAMP(), - {is_cur} = FALSE -FROM ({body}) AS s -WHERE - {keys_pred} - AND t.{is_cur} = TRUE - AND {change_condition} -""" - self._execute_sql(close_sql).result() - - # 2) Insert new current versions (new keys or changed rows) - first_key = unique_key[0] - insert_sql = f""" -INSERT INTO {target} -SELECT - s.*, - {new_upd_expr} AS {upd_meta}, - {new_valid_from_expr} AS {vf}, - CAST(NULL AS TIMESTAMP) AS {vt}, - TRUE AS {is_cur}, - {new_hash_expr} AS {hash_col} -FROM ({body}) AS s -LEFT JOIN {target} AS t - ON {keys_pred} - AND t.{is_cur} = TRUE -WHERE - t.{first_key} IS NULL - OR {change_condition} -""" - self._execute_sql(insert_sql).result() - - def snapshot_prune( - self, - relation: str, - unique_key: list[str], - keep_last: int, - *, - dry_run: bool = False, - ) -> None: - """ - Delete older snapshot versions while keeping the most recent `keep_last` - rows per business key (including the current row). - """ - if keep_last <= 0: - return - - keys = [k for k in unique_key if k] - if not keys: - return - - target = self._qualified_identifier( - relation, - project=self.project, - dataset=self.dataset, - ) - vf = BaseExecutor.SNAPSHOT_VALID_FROM_COL - key_select = ", ".join(keys) - part_by = ", ".join(keys) - - ranked_sql = f""" -SELECT - {key_select}, - {vf}, - ROW_NUMBER() OVER ( - PARTITION BY {part_by} - ORDER BY {vf} DESC - ) AS rn -FROM {target} -""" - - if dry_run: - sql = f""" -WITH ranked AS ( - {ranked_sql} -) -SELECT COUNT(*) AS rows_to_delete -FROM ranked -WHERE rn > {int(keep_last)} -""" - job = self.client.query(sql, location=self.location) - rows = list(job.result()) - count = int(rows[0][0]) if rows else 0 - - echo( - f"[DRY-RUN] snapshot_prune({relation}): would delete {count} row(s) " - f"(keep_last={keep_last})" - ) - return - - join_pred = " AND ".join([f"t.{k} = r.{k}" for k in keys]) - delete_sql = f""" -DELETE FROM {target} AS t -WHERE EXISTS ( - WITH ranked AS ( - {ranked_sql} - ) - SELECT 1 - FROM ranked AS r - WHERE - r.rn > {int(keep_last)} - AND {join_pred} - AND t.{vf} = r.{vf} -) -""" - self._execute_sql(delete_sql).result() def execute_hook_sql(self, sql: str) -> None: """ diff --git a/src/fastflowtransform/executors/databricks_spark.py b/src/fastflowtransform/executors/databricks_spark.py index a93d2b3..3261820 100644 --- a/src/fastflowtransform/executors/databricks_spark.py +++ b/src/fastflowtransform/executors/databricks_spark.py @@ -16,13 +16,13 @@ from fastflowtransform.core import REGISTRY, Node, relation_for from fastflowtransform.errors import ModelExecutionError from fastflowtransform.executors._budget_runner import run_sql_with_budget +from fastflowtransform.executors._query_stats_adapter import SparkDataFrameStatsAdapter from fastflowtransform.executors._spark_imports import ( get_spark_functions, get_spark_window, ) from fastflowtransform.executors.base import BaseExecutor from fastflowtransform.executors.budget import BudgetGuard -from fastflowtransform.executors.query_stats import QueryStats from fastflowtransform.logging import echo, echo_debug from fastflowtransform.meta import ensure_meta_table, upsert_meta from fastflowtransform.snapshots import resolve_snapshot_config @@ -571,14 +571,9 @@ def _write_to_storage_path(self, relation: str, df: SDF, storage_meta: dict[str, self.spark.catalog.refreshByPath(path) def _record_spark_dataframe_stats(self, df: SDF, duration_ms: int) -> None: - bytes_est = self._spark_dataframe_bytes(df) - self._record_query_stats( - QueryStats( - bytes_processed=bytes_est, - rows=None, - duration_ms=duration_ms, - ) - ) + adapter = SparkDataFrameStatsAdapter(self._spark_dataframe_bytes) + stats = adapter.collect(df, duration_ms=duration_ms) + self._record_query_stats(stats) def _spark_dataframe_bytes(self, df: SDF) -> int | None: try: diff --git a/src/fastflowtransform/executors/duckdb.py b/src/fastflowtransform/executors/duckdb.py index a4374bf..8c31ad8 100644 --- a/src/fastflowtransform/executors/duckdb.py +++ b/src/fastflowtransform/executors/duckdb.py @@ -4,7 +4,7 @@ import json import re import uuid -from collections.abc import Iterable +from collections.abc import Callable, Iterable from contextlib import suppress from pathlib import Path from typing import Any, ClassVar @@ -12,22 +12,21 @@ import duckdb import pandas as pd from duckdb import CatalogException -from jinja2 import Environment -from fastflowtransform.core import Node, relation_for +from fastflowtransform.core import Node from fastflowtransform.executors._budget_runner import run_sql_with_budget +from fastflowtransform.executors._snapshot_sql_mixin import SnapshotSqlMixin +from fastflowtransform.executors._sql_identifier import SqlIdentifierMixin from fastflowtransform.executors.base import BaseExecutor from fastflowtransform.executors.budget import BudgetGuard -from fastflowtransform.logging import echo from fastflowtransform.meta import ensure_meta_table, upsert_meta -from fastflowtransform.snapshots import resolve_snapshot_config def _q(ident: str) -> str: return '"' + ident.replace('"', '""') + '"' -class DuckExecutor(BaseExecutor[pd.DataFrame]): +class DuckExecutor(SqlIdentifierMixin, SnapshotSqlMixin, BaseExecutor[pd.DataFrame]): ENGINE_NAME = "duckdb" _FIXED_TYPE_SIZES: ClassVar[dict[str, int]] = { @@ -422,29 +421,39 @@ def _exec_many(self, sql: str) -> None: self._execute_sql(stmt) # ---- Frame hooks ---- + def _quote_identifier(self, ident: str) -> str: + return _q(ident) + + def _should_include_catalog( + self, catalog: str | None, schema: str | None, *, explicit: bool + ) -> bool: + """ + DuckDB includes catalog only when explicitly provided or when it matches + the schema (mirrors previous behaviour). + """ + if explicit: + return bool(catalog) + return bool(catalog and schema and catalog.lower() == schema.lower()) + + def _default_catalog_for_source(self, schema: str | None) -> str | None: + """ + For sources, fall back to DuckDB's detected catalog when: + - schema is set and matches the catalog, or + - neither schema nor catalog was provided (keep old fallback) + """ + cat = self._default_catalog() + if not cat: + return None + if schema is None or cat.lower() == schema.lower(): + return cat + return None + def _qualified(self, relation: str, *, quoted: bool = True) -> str: """ Return (catalog.)schema.relation if schema is set; otherwise just relation. When quoted=False, emit bare identifiers for APIs like con.table(). """ - rel = relation_for(relation) if relation.endswith(".ff") else relation - rel_part = _q(rel) if quoted else rel - if not self.schema: - return rel_part - parts: list[str] = [] - cat_clean = None - include_catalog = False - if isinstance(self.catalog, str): - cat_trimmed = self.catalog.strip() - if cat_trimmed and cat_trimmed.lower() == self.schema.strip().lower(): - include_catalog = True - cat_clean = cat_trimmed - if include_catalog and cat_clean is not None: - parts.append(_q(cat_clean) if quoted else cat_clean) - schema_clean = self.schema.strip() - parts.append(_q(schema_clean) if quoted else schema_clean) - parts.append(rel_part) - return ".".join(parts) + return self._qualify_identifier(relation, quote=quoted) def _read_relation(self, relation: str, node: Node, deps: Iterable[str]) -> pd.DataFrame: try: @@ -488,45 +497,6 @@ def _frame_name(self) -> str: return "pandas" # ---- SQL hooks ---- - def _format_relation_for_ref(self, name: str) -> str: - return self._qualified(relation_for(name)) - - def _format_source_reference( - self, cfg: dict[str, Any], source_name: str, table_name: str - ) -> str: - location = cfg.get("location") - if location: - raise NotImplementedError("DuckDB executor does not support path-based sources yet.") - - identifier = cfg.get("identifier") - if not identifier: - raise KeyError(f"Source {source_name}.{table_name} missing identifier") - - catalog_cfg = cfg.get("catalog") or cfg.get("database") - catalog = ( - catalog_cfg.strip() if isinstance(catalog_cfg, str) and catalog_cfg.strip() else None - ) - schema_candidate = cfg.get("schema") or self.schema - schema = ( - schema_candidate.strip() - if isinstance(schema_candidate, str) and schema_candidate.strip() - else None - ) - if catalog is None and schema and isinstance(self.catalog, str): - cat_clean = self.catalog.strip() - if cat_clean and cat_clean.lower() == schema.lower(): - catalog = cat_clean - if catalog is None and schema is None and isinstance(self.catalog, str): - cat_clean = self.catalog.strip() - catalog = cat_clean or None - parts: list[str] = [] - if catalog: - parts.append(catalog) - if schema: - parts.append(schema) - parts.append(identifier) - return ".".join(_q(str(part)) for part in parts if part) - def _create_or_replace_view(self, target_sql: str, select_body: str, node: Node) -> None: self._execute_sql(f"create or replace view {target_sql} as {select_body}") @@ -575,8 +545,6 @@ def incremental_merge(self, relation: str, select_sql: str, unique_key: list[str Fallback strategy for DuckDB: - DELETE collisions via DELETE ... USING () - We intentionally do NOT use a CTE here, because we execute two separate - statements and DuckDB won't see the CTE from the previous statement. """ # 1) clean inner SELECT body = self._selectable_body(select_sql).strip().rstrip(";\n\t ") @@ -620,223 +588,6 @@ def alter_table_sync_schema( except Exception: self._execute_sql(f"alter table {target} add column {col} varchar") - def run_snapshot_sql(self, node: Node, env: Environment) -> None: - """ - Snapshot materialization for DuckDB. - - Config (node.meta): - - materialized='snapshot' - - snapshot: { ... } # strategy + per-strategy hints - - unique_key: str | list[str] - - Behaviour: - - First run: create table with one current row per unique key. - - Subsequent runs: - * close changed current rows (set valid_to, is_current=false) - * insert new current rows for new/changed keys. - """ - if node.kind != "sql": - raise TypeError( - f"Snapshot materialization is only supported for SQL models, " - f"got kind={node.kind!r} for {node.name}." - ) - - meta = getattr(node, "meta", {}) or {} - if not self._meta_is_snapshot(meta): - raise ValueError(f"Node {node.name} is not configured with materialized='snapshot'.") - - # ---- Extract & normalise snapshot config (shared helper) ---- - cfg = resolve_snapshot_config(node, meta) - strategy = cfg.strategy - unique_key = cfg.unique_key - updated_at = cfg.updated_at - check_cols = cfg.check_cols - - # ---- Render SQL and extract SELECT body ---- - sql_rendered = self.render_sql( - node, - env, - ref_resolver=lambda name: self._resolve_ref(name, env), - source_resolver=self._resolve_source, - ) - sql = self._strip_leading_config(sql_rendered).strip() - body = self._selectable_body(sql).rstrip(" ;\n\t") - - rel_name = relation_for(node.name) - target = self._qualified(rel_name) - - vf = BaseExecutor.SNAPSHOT_VALID_FROM_COL - vt = BaseExecutor.SNAPSHOT_VALID_TO_COL - is_cur = BaseExecutor.SNAPSHOT_IS_CURRENT_COL - hash_col = BaseExecutor.SNAPSHOT_HASH_COL - upd_meta = BaseExecutor.SNAPSHOT_UPDATED_AT_COL - - # ---- First run: create snapshot table ---- - if not self.exists_relation(rel_name): - if strategy == "timestamp": - # valid_from + updated_at come from the source updated_at column - create_sql = f""" -create table {target} as -select - s.*, - s.{updated_at} as {upd_meta}, - s.{updated_at} as {vf}, - cast(null as timestamp) as {vt}, - true as {is_cur}, - cast(null as varchar) as {hash_col} -from ({body}) as s -""" - else: # strategy == "check" - # Hash over check_cols to detect changes - col_exprs = [f"coalesce(cast(s.{col} as varchar), '')" for col in check_cols] - concat_expr = " || '||' || ".join(col_exprs) - hash_expr = f"cast(md5({concat_expr}) as varchar)" - upd_expr = f"s.{updated_at}" if updated_at else "current_timestamp" - create_sql = f""" -create table {target} as -select - s.*, - {upd_expr} as {upd_meta}, - current_timestamp as {vf}, - cast(null as timestamp) as {vt}, - true as {is_cur}, - {hash_expr} as {hash_col} -from ({body}) as s -""" - self._execute_sql(create_sql) - return - - # ---- Incremental snapshot update ---- - - # Stage current source rows in a temp view for reuse - src_view_name = f"__ff_snapshot_src_{rel_name}".replace(".", "_") - src_quoted = _q(src_view_name) - self._execute_sql(f"create or replace temp view {src_quoted} as {body}") - - try: - # Join predicate on unique keys - keys_pred = " AND ".join([f"t.{k} = s.{k}" for k in unique_key]) - - # Change condition & hash for staging rows - if strategy == "timestamp": - change_condition = f"s.{updated_at} > t.{upd_meta}" - hash_expr_s = "NULL" - new_upd_expr = f"s.{updated_at}" - new_valid_from_expr = f"s.{updated_at}" - new_hash_expr = "NULL" - else: - col_exprs_s = [f"coalesce(cast(s.{col} as varchar), '')" for col in check_cols] - concat_expr_s = " || '||' || ".join(col_exprs_s) - hash_expr_s = f"cast(md5({concat_expr_s}) as varchar)" - change_condition = f"coalesce({hash_expr_s}, '') <> coalesce(t.{hash_col}, '')" - new_upd_expr = f"s.{updated_at}" if updated_at else "current_timestamp" - new_valid_from_expr = "current_timestamp" - new_hash_expr = hash_expr_s - - # 1) Close changed current rows - close_sql = f""" -update {target} as t -set - {vt} = current_timestamp, - {is_cur} = false -from {src_quoted} as s -where - {keys_pred} - and t.{is_cur} = true - and {change_condition}; -""" - self._execute_sql(close_sql) - - # 2) Insert new current versions (new keys or changed rows) - first_key = unique_key[0] - insert_sql = f""" -insert into {target} -select - s.*, - {new_upd_expr} as {upd_meta}, - {new_valid_from_expr} as {vf}, - cast(null as timestamp) as {vt}, - true as {is_cur}, - {new_hash_expr} as {hash_col} -from {src_quoted} as s -left join {target} as t - on {keys_pred} - and t.{is_cur} = true -where - t.{first_key} is null - or {change_condition}; -""" - self._execute_sql(insert_sql) - finally: - with suppress(Exception): - self._execute_sql(f"drop view if exists {src_quoted}") - - def snapshot_prune( - self, - relation: str, - unique_key: list[str], - keep_last: int, - *, - dry_run: bool = False, - ) -> None: - """ - Delete older snapshot versions while keeping the most recent `keep_last` - rows per business key (including the current row). - """ - if keep_last <= 0: - return - - target = self._qualified(relation) - vf = BaseExecutor.SNAPSHOT_VALID_FROM_COL - keys = [k for k in unique_key if k] - - if not keys: - return - - part_by = ", ".join([k for k in keys]) - key_select = ", ".join(keys) - - ranked_sql = f""" -select - {key_select}, - {vf}, - row_number() over ( - partition by {part_by} - order by {vf} desc - ) as rn -from {target} -""" - - if dry_run: - sql = f""" -with ranked as ( - {ranked_sql} -) -select count(*) as rows_to_delete -from ranked -where rn > {int(keep_last)} -""" - res = self._execute_sql(sql).fetchone() - rows = int(res[0]) if res else 0 - - echo( - f"[DRY-RUN] snapshot_prune({relation}): would delete {rows} row(s) " - f"(keep_last={keep_last})" - ) - return - - delete_sql = f""" -delete from {target} t -using ( - {ranked_sql} -) r -where - r.rn > {int(keep_last)} - and {" AND ".join([f"t.{k} = r.{k}" for k in keys])} - and t.{vf} = r.{vf}; -""" - self._execute_sql(delete_sql) - def execute_hook_sql(self, sql: str) -> None: """ Execute one or multiple SQL statements for pre/post/on_run hooks. @@ -845,6 +596,38 @@ def execute_hook_sql(self, sql: str) -> None: """ self._exec_many(sql) + # ---- Snapshot mixin hooks ---- + def _snapshot_target_identifier(self, rel_name: str) -> str: + return self._qualified(rel_name) + + def _snapshot_current_timestamp(self) -> str: + return "current_timestamp" + + def _snapshot_null_timestamp(self) -> str: + return "cast(null as timestamp)" + + def _snapshot_null_hash(self) -> str: + return "cast(null as varchar)" + + def _snapshot_hash_expr(self, check_cols: list[str], src_alias: str) -> str: + concat_expr = self._snapshot_concat_expr(check_cols, src_alias) + return f"cast(md5({concat_expr}) as varchar)" + + def _snapshot_cast_as_string(self, expr: str) -> str: + return f"cast({expr} as varchar)" + + def _snapshot_source_ref( + self, rel_name: str, select_body: str + ) -> tuple[str, Callable[[], None]]: + src_view_name = f"__ff_snapshot_src_{rel_name}".replace(".", "_") + src_quoted = _q(src_view_name) + self._execute_sql(f"create or replace temp view {src_quoted} as {select_body}") + + def _cleanup() -> None: + self._execute_sql(f"drop view if exists {src_quoted}") + + return src_quoted, _cleanup + # ---- Unit-test helpers ------------------------------------------------- def utest_load_relation_from_rows(self, relation: str, rows: list[dict]) -> None: diff --git a/src/fastflowtransform/executors/postgres.py b/src/fastflowtransform/executors/postgres.py index 81e5128..dacc4d8 100644 --- a/src/fastflowtransform/executors/postgres.py +++ b/src/fastflowtransform/executors/postgres.py @@ -1,28 +1,27 @@ # fastflowtransform/executors/postgres.py import json -from collections.abc import Iterable +from collections.abc import Callable, Iterable from time import perf_counter from typing import Any import pandas as pd -from jinja2 import Environment from sqlalchemy import create_engine, text from sqlalchemy.engine import Connection, Engine from sqlalchemy.exc import ProgrammingError, SQLAlchemyError -from fastflowtransform.core import Node, relation_for +from fastflowtransform.core import Node from fastflowtransform.errors import ModelExecutionError, ProfileConfigError from fastflowtransform.executors._budget_runner import run_sql_with_budget from fastflowtransform.executors._shims import SAConnShim +from fastflowtransform.executors._snapshot_sql_mixin import SnapshotSqlMixin +from fastflowtransform.executors._sql_identifier import SqlIdentifierMixin from fastflowtransform.executors.base import BaseExecutor from fastflowtransform.executors.budget import BudgetGuard from fastflowtransform.executors.query_stats import QueryStats -from fastflowtransform.logging import echo from fastflowtransform.meta import ensure_meta_table, upsert_meta -from fastflowtransform.snapshots import resolve_snapshot_config -class PostgresExecutor(BaseExecutor[pd.DataFrame]): +class PostgresExecutor(SqlIdentifierMixin, SnapshotSqlMixin, BaseExecutor[pd.DataFrame]): ENGINE_NAME = "postgres" _DEFAULT_PG_ROW_WIDTH = 128 _BUDGET_GUARD = BudgetGuard( @@ -269,11 +268,11 @@ def _q_ident(self, ident: str) -> str: # Simple, safe quoting for identifiers return '"' + ident.replace('"', '""') + '"' + def _quote_identifier(self, ident: str) -> str: + return self._q_ident(ident) + def _qualified(self, relname: str, schema: str | None = None) -> str: - sch = schema or self.schema - if sch: - return f"{self._q_ident(sch)}.{self._q_ident(relname)}" - return self._q_ident(relname) + return self._qualify_identifier(relname, schema=schema) def _set_search_path(self, conn: Connection | SAConnShim) -> None: if self.schema: @@ -358,26 +357,6 @@ def _create_or_replace_view_from_table( def _frame_name(self) -> str: return "pandas" - # ---- SQL hooks ---- - def _format_relation_for_ref(self, name: str) -> str: - return self._qualified(relation_for(name)) - - def _format_source_reference( - self, cfg: dict[str, Any], source_name: str, table_name: str - ) -> str: - if cfg.get("location"): - raise NotImplementedError("Postgres executor does not support path-based sources.") - - ident = cfg.get("identifier") - if not ident: - raise KeyError(f"Source {source_name}.{table_name} missing identifier") - - relation = self._qualified(ident, schema=cfg.get("schema")) - database = cfg.get("database") or cfg.get("catalog") - if database: - return f"{self._q_ident(database)}.{relation}" - return relation - def _create_or_replace_view(self, target_sql: str, select_body: str, node: Node) -> None: try: self._execute_sql(f"DROP VIEW IF EXISTS {target_sql} CASCADE") @@ -497,229 +476,37 @@ def alter_table_sync_schema( self._execute_sql(f'alter table {qrel} add column "{c}" text', conn=conn) # ── Snapshot API ────────────────────────────────────────────────────── + def _snapshot_target_identifier(self, rel_name: str) -> str: + return self._qualified(rel_name) - def run_snapshot_sql(self, node: Node, env: Environment) -> None: - """ - Snapshot materialization for Postgres. + def _snapshot_current_timestamp(self) -> str: + return "current_timestamp" - Config: - - materialized='snapshot' - - snapshot={...} and/or top-level strategy/updated_at/check_cols - - unique_key / primary_key + def _snapshot_null_timestamp(self) -> str: + return "cast(null as timestamp)" - Behaviour: - - First run: create table with one current row per unique key. - - Subsequent runs: - * close changed current rows (set valid_to, is_current=false) - * insert new current rows for new/changed keys. - """ - if node.kind != "sql": - raise TypeError( - f"Snapshot materialization is only supported for SQL models, " - f"got kind={node.kind!r} for {node.name}." - ) + def _snapshot_null_hash(self) -> str: + return "cast(null as text)" - meta = getattr(node, "meta", {}) or {} - if not self._meta_is_snapshot(meta): - raise ValueError(f"Node {node.name} is not configured with materialized='snapshot'.") - - # Shared normalisation: supports nested 'snapshot={...}' OR flattened config. - cfg = resolve_snapshot_config(node, meta) - strategy = cfg.strategy - unique_key = cfg.unique_key - updated_at = cfg.updated_at - check_cols = cfg.check_cols - - # ---- Render SQL and extract SELECT body ---- - sql_rendered = self.render_sql( - node, - env, - ref_resolver=lambda name: self._resolve_ref(name, env), - source_resolver=self._resolve_source, - ) - sql = self._strip_leading_config(sql_rendered).strip() - body = self._selectable_body(sql).rstrip(" ;\n\t") - - rel_name = relation_for(node.name) - target = self._qualified(rel_name) - - vf = BaseExecutor.SNAPSHOT_VALID_FROM_COL - vt = BaseExecutor.SNAPSHOT_VALID_TO_COL - is_cur = BaseExecutor.SNAPSHOT_IS_CURRENT_COL - hash_col = BaseExecutor.SNAPSHOT_HASH_COL - upd_meta = BaseExecutor.SNAPSHOT_UPDATED_AT_COL - - # ---- First run: create snapshot table ---- - if not self.exists_relation(rel_name): - if strategy == "timestamp": - # valid_from + updated_at come from the source updated_at column - create_sql = f""" -create table {target} as -select - s.*, - s.{updated_at} as {upd_meta}, - s.{updated_at} as {vf}, - cast(null as timestamp) as {vt}, - true as {is_cur}, - cast(null as text) as {hash_col} -from ({body}) as s -""" - else: # strategy == "check" - # Hash over check_cols to detect changes - col_exprs = [f"coalesce(cast(s.{col} as text), '')" for col in check_cols] - concat_expr = " || '||' || ".join(col_exprs) - hash_expr = f"md5({concat_expr})" - upd_expr = f"s.{updated_at}" if updated_at else "current_timestamp" - create_sql = f""" -create table {target} as -select - s.*, - {upd_expr} as {upd_meta}, - current_timestamp as {vf}, - cast(null as timestamp) as {vt}, - true as {is_cur}, - {hash_expr} as {hash_col} -from ({body}) as s -""" - self._execute_sql(create_sql) - return + def _snapshot_hash_expr(self, check_cols: list[str], src_alias: str) -> str: + concat_expr = self._snapshot_concat_expr(check_cols, src_alias) + return f"md5({concat_expr})" - # ---- Incremental snapshot update ---- + def _snapshot_cast_as_string(self, expr: str) -> str: + return f"cast({expr} as text)" - # Stage current source rows in a temporary table for reuse + def _snapshot_source_ref( + self, rel_name: str, select_body: str + ) -> tuple[str, Callable[[], None]]: src_name = f"__ff_snapshot_src_{rel_name}".replace(".", "_") src_q = self._q_ident(src_name) + self._execute_sql(f"drop table if exists {src_q}") + self._execute_sql(f"create temporary table {src_q} as {select_body}") - with self.engine.begin() as conn: - # (Re-)create temp staging table - self._execute_sql(f"drop table if exists {src_q}", conn=conn) - self._execute_sql(f"create temporary table {src_q} as {body}", conn=conn) - - # Join predicate on unique keys - keys_pred = " AND ".join([f"t.{k} = s.{k}" for k in unique_key]) - - # Change condition & hash for staging rows - if strategy == "timestamp": - change_condition = f"s.{updated_at} > t.{upd_meta}" - hash_expr_s = "NULL" - new_upd_expr = f"s.{updated_at}" - new_valid_from_expr = f"s.{updated_at}" - new_hash_expr = "NULL" - else: - col_exprs_s = [f"coalesce(cast(s.{col} as text), '')" for col in check_cols] - concat_expr_s = " || '||' || ".join(col_exprs_s) - hash_expr_s = f"md5({concat_expr_s})" - change_condition = ( - f"coalesce({hash_expr_s}, '') <> coalesce(t.{hash_col}::text, '')" - ) - new_upd_expr = f"s.{updated_at}" if updated_at else "current_timestamp" - new_valid_from_expr = "current_timestamp" - new_hash_expr = hash_expr_s - - # 1) Close changed current rows - close_sql = f""" -update {target} as t -set - {vt} = current_timestamp, - {is_cur} = false -from {src_q} as s -where - {keys_pred} - and t.{is_cur} = true - and {change_condition}; -""" - self._execute_sql(close_sql, conn=conn) - - # 2) Insert new current versions (new keys or changed rows) - first_key = unique_key[0] - insert_sql = f""" -insert into {target} -select - s.*, - {new_upd_expr} as {upd_meta}, - {new_valid_from_expr} as {vf}, - cast(null as timestamp) as {vt}, - true as {is_cur}, - {new_hash_expr} as {hash_col} -from {src_q} as s -left join {target} as t - on {keys_pred} - and t.{is_cur} = true -where - t.{first_key} is null - or {change_condition}; -""" - self._execute_sql(insert_sql, conn=conn) - - # Temp table will be dropped automatically at end of session; dropping - # explicitly here is harmless and keeps the connection clean for tests. - self._execute_sql(f"drop table if exists {src_q}", conn=conn) - self._analyze_relations([target], conn=conn) - - def snapshot_prune( - self, - relation: str, - unique_key: list[str], - keep_last: int, - *, - dry_run: bool = False, - ) -> None: - """ - Delete older snapshot versions while keeping the most recent `keep_last` - rows per business key (including the current row). - """ - if keep_last <= 0: - return - - vf = BaseExecutor.SNAPSHOT_VALID_FROM_COL - keys = [k for k in unique_key if k] - if not keys: - return - - target = self._qualified(relation) - part_by = ", ".join(keys) - key_select = ", ".join(keys) - - ranked_sql = f""" -select - {key_select}, - {vf}, - row_number() over ( - partition by {part_by} - order by {vf} desc - ) as rn -from {target} -""" - - if dry_run: - sql = f""" -with ranked as ( - {ranked_sql} -) -select count(*) as rows_to_delete -from ranked -where rn > {int(keep_last)} -""" - res = self._execute_sql(sql).fetchone() - rows = int(res[0]) if res else 0 - echo( - f"[DRY-RUN] snapshot_prune({relation}): would delete {rows} row(s) " - f"(keep_last={keep_last})" - ) - return + def _cleanup() -> None: + self._execute_sql(f"drop table if exists {src_q}") - delete_sql = f""" -delete from {target} t -using ( - {ranked_sql} -) r -where - r.rn > {int(keep_last)} - and {" AND ".join([f"t.{k} = r.{k}" for k in keys])} - and t.{vf} = r.{vf}; -""" - self._execute_sql(delete_sql) - self._analyze_relations([relation]) + return src_q, _cleanup def execute_hook_sql(self, sql: str) -> None: """ diff --git a/src/fastflowtransform/executors/snowflake_snowpark.py b/src/fastflowtransform/executors/snowflake_snowpark.py index 5241370..400bf97 100644 --- a/src/fastflowtransform/executors/snowflake_snowpark.py +++ b/src/fastflowtransform/executors/snowflake_snowpark.py @@ -2,26 +2,25 @@ from __future__ import annotations import json -from collections.abc import Iterable +from collections.abc import Callable, Iterable from contextlib import suppress from time import perf_counter from typing import Any, cast import pandas as pd -from jinja2 import Environment from fastflowtransform.core import Node, relation_for from fastflowtransform.executors._budget_runner import run_sql_with_budget +from fastflowtransform.executors._snapshot_sql_mixin import SnapshotSqlMixin +from fastflowtransform.executors._sql_identifier import SqlIdentifierMixin from fastflowtransform.executors.base import BaseExecutor from fastflowtransform.executors.budget import BudgetGuard from fastflowtransform.executors.query_stats import QueryStats -from fastflowtransform.logging import echo from fastflowtransform.meta import ensure_meta_table, upsert_meta -from fastflowtransform.snapshots import resolve_snapshot_config from fastflowtransform.typing import SNDF, SnowparkSession as Session -class SnowflakeSnowparkExecutor(BaseExecutor[SNDF]): +class SnowflakeSnowparkExecutor(SqlIdentifierMixin, SnapshotSqlMixin, BaseExecutor[SNDF]): ENGINE_NAME = "snowflake_snowpark" """Snowflake executor operating on Snowpark DataFrames (no pandas).""" _BUDGET_GUARD = BudgetGuard( @@ -150,9 +149,25 @@ def _exec_many(self, sql: str) -> None: def _q(self, s: str) -> str: return '"' + s.replace('"', '""') + '"' + def _quote_identifier(self, ident: str) -> str: + # Keep identifiers unquoted to match legacy Snowflake behaviour. + return ident + + def _default_schema(self) -> str | None: + return self.schema + + def _default_catalog(self) -> str | None: + return self.database + + def _should_include_catalog( + self, catalog: str | None, schema: str | None, *, explicit: bool + ) -> bool: + # Always include database when present; Snowflake expects DB.SCHEMA.TABLE. + return bool(catalog) + def _qualified(self, rel: str) -> str: # DATABASE.SCHEMA.TABLE (no quotes) - return f"{self.database}.{self.schema}.{rel}" + return self._qualify_identifier(rel, quote=False) def _ensure_schema(self) -> None: """ @@ -327,15 +342,12 @@ def _frame_name(self) -> str: return "Snowpark" # ---- SQL hooks ---- - def _format_relation_for_ref(self, name: str) -> str: - return self._qualified(relation_for(name)) - def _this_identifier(self, node: Node) -> str: """ Identifier for {{ this }} in SQL models. Use fully-qualified DB.SCHEMA.TABLE so all build/read/test paths agree. """ - return self._qualified(relation_for(node.name)) + return self._qualify_identifier(relation_for(node.name), quote=False) def _format_source_reference( self, cfg: dict[str, Any], source_name: str, table_name: str @@ -347,13 +359,13 @@ def _format_source_reference( if not ident: raise KeyError(f"Source {source_name}.{table_name} missing identifier") - db = cfg.get("database") or cfg.get("catalog") or self.database - sch = cfg.get("schema") or self.schema + sch = self._pick_schema(cfg) + db = self._pick_catalog(cfg, sch) if not db or not sch: raise KeyError( f"Source {source_name}.{table_name} missing database/schema for Snowflake" ) - return f"{db}.{sch}.{ident}" + return self._qualify_identifier(ident, schema=sch, catalog=db, quote=False) def _create_or_replace_view(self, target_sql: str, select_body: str, node: Node) -> None: self._execute_sql(f"CREATE OR REPLACE VIEW {target_sql} AS {select_body}").collect() @@ -487,223 +499,42 @@ def alter_table_sync_schema( cols_sql = ", ".join(f"{self._q(c)} STRING" for c in to_add) self._execute_sql(f"ALTER TABLE {qrel} ADD COLUMN {cols_sql}").collect() - # ── Snapshot API ───────────────────────────────────────────────────── - def run_snapshot_sql(self, node: Node, env: Environment) -> None: - """ - Snapshot materialization for Snowflake Snowpark. + # ---- Snapshot API (mixin hooks) -------------------------------------- + def _snapshot_target_identifier(self, rel_name: str) -> str: + return self._qualified(rel_name) - Uses the shared snapshot config resolver so all engines share the - same semantics and validation. - """ - if node.kind != "sql": - raise TypeError( - f"Snapshot materialization is only supported for SQL models, " - f"got kind={node.kind!r} for {node.name}." - ) + def _snapshot_current_timestamp(self) -> str: + return "CURRENT_TIMESTAMP()" - meta = getattr(node, "meta", {}) or {} - if not self._meta_is_snapshot(meta): - raise ValueError(f"Node {node.name} is not configured with materialized='snapshot'.") + def _snapshot_create_keyword(self) -> str: + return "CREATE OR REPLACE TABLE" - cfg = resolve_snapshot_config(node, meta) + def _snapshot_null_timestamp(self) -> str: + return "CAST(NULL AS TIMESTAMP)" - # Render model SQL and extract the SELECT body - rendered = self.render_sql( - node, - env, - ref_resolver=lambda name: self._resolve_ref(name, env), - source_resolver=self._resolve_source, - ) - sql = self._strip_leading_config(rendered).strip() - body = self._selectable_body(sql).rstrip(";\n\t ") - - rel_name = relation_for(node.name) - target = self._qualified(rel_name) - - vf = BaseExecutor.SNAPSHOT_VALID_FROM_COL - vt = BaseExecutor.SNAPSHOT_VALID_TO_COL - is_cur = BaseExecutor.SNAPSHOT_IS_CURRENT_COL - hash_col = BaseExecutor.SNAPSHOT_HASH_COL - upd_meta = BaseExecutor.SNAPSHOT_UPDATED_AT_COL - - # ---- First run: create snapshot table ---- - if not self.exists_relation(rel_name): - if cfg.strategy == "timestamp": - # cfg.updated_at is guaranteed non-None by resolve_snapshot_config - if cfg.updated_at is None: # defensive, for type-checkers - raise ValueError( - "strategy='timestamp' snapshot requires a non-null updated_at column." - ) + def _snapshot_null_hash(self) -> str: + return "CAST(NULL AS VARCHAR)" - create_sql = f""" -CREATE OR REPLACE TABLE {target} AS -SELECT - s.*, - s.{cfg.updated_at} AS {upd_meta}, - s.{cfg.updated_at} AS {vf}, - CAST(NULL AS TIMESTAMP) AS {vt}, - TRUE AS {is_cur}, - CAST(NULL AS VARCHAR) AS {hash_col} -FROM ({body}) AS s -""" - else: # strategy == "check" - # hash over check_cols to detect changes - col_exprs = [f"COALESCE(CAST(s.{col} AS VARCHAR), '')" for col in cfg.check_cols] - concat_expr = " || '||' || ".join(col_exprs) or "''" - hash_expr = f"CAST(MD5({concat_expr}) AS VARCHAR)" - upd_expr = ( - f"s.{cfg.updated_at}" if cfg.updated_at is not None else "CURRENT_TIMESTAMP()" - ) + def _snapshot_hash_expr(self, check_cols: list[str], src_alias: str) -> str: + concat_expr = self._snapshot_concat_expr(check_cols, src_alias) + return f"CAST(MD5({concat_expr}) AS VARCHAR)" - create_sql = f""" -CREATE OR REPLACE TABLE {target} AS -SELECT - s.*, - {upd_expr} AS {upd_meta}, - CURRENT_TIMESTAMP() AS {vf}, - CAST(NULL AS TIMESTAMP) AS {vt}, - TRUE AS {is_cur}, - {hash_expr} AS {hash_col} -FROM ({body}) AS s -""" - self._execute_sql(create_sql).collect() - return + def _snapshot_cast_as_string(self, expr: str) -> str: + return f"CAST({expr} AS VARCHAR)" - # ---- Incremental snapshot update ---- + def _snapshot_source_ref( + self, rel_name: str, select_body: str + ) -> tuple[str, Callable[[], None]]: src_name = f"__ff_snapshot_src_{rel_name}".replace(".", "_") + src_quoted = self._q(src_name) + self._execute_sql( + f"CREATE OR REPLACE TEMPORARY VIEW {src_quoted} AS {select_body}" + ).collect() - # Use a temporary view for the current source rows - self._execute_sql(f"CREATE OR REPLACE TEMPORARY VIEW {src_name} AS {body}").collect() - - try: - keys_pred = " AND ".join([f"t.{k} = s.{k}" for k in cfg.unique_key]) or "FALSE" - - if cfg.strategy == "timestamp": - if cfg.updated_at is None: - raise ValueError( - "strategy='timestamp' snapshot requires a non-null updated_at column." - ) - change_condition = f"s.{cfg.updated_at} > t.{upd_meta}" - hash_expr_s = "NULL" - new_upd_expr = f"s.{cfg.updated_at}" - new_valid_from_expr = f"s.{cfg.updated_at}" - new_hash_expr = "NULL" - else: - col_exprs_s = [f"COALESCE(CAST(s.{col} AS VARCHAR), '')" for col in cfg.check_cols] - concat_expr_s = " || '||' || ".join(col_exprs_s) or "''" - hash_expr_s = f"CAST(MD5({concat_expr_s}) AS VARCHAR)" - change_condition = f"COALESCE({hash_expr_s}, '') <> COALESCE(t.{hash_col}, '')" - new_upd_expr = ( - f"s.{cfg.updated_at}" if cfg.updated_at is not None else "CURRENT_TIMESTAMP()" - ) - new_valid_from_expr = "CURRENT_TIMESTAMP()" - new_hash_expr = hash_expr_s - - # 1) Close changed current rows - close_sql = f""" -UPDATE {target} AS t -SET - {vt} = CURRENT_TIMESTAMP(), - {is_cur} = FALSE -FROM {src_name} AS s -WHERE - {keys_pred} - AND t.{is_cur} = TRUE - AND {change_condition} -""" - self._execute_sql(close_sql).collect() - - # 2) Insert new current versions (new keys or changed rows) - first_key = cfg.unique_key[0] - insert_sql = f""" -INSERT INTO {target} -SELECT - s.*, - {new_upd_expr} AS {upd_meta}, - {new_valid_from_expr} AS {vf}, - CAST(NULL AS TIMESTAMP) AS {vt}, - TRUE AS {is_cur}, - {new_hash_expr} AS {hash_col} -FROM {src_name} AS s -LEFT JOIN {target} AS t - ON {keys_pred} - AND t.{is_cur} = TRUE -WHERE - t.{first_key} IS NULL - OR {change_condition} -""" - self._execute_sql(insert_sql).collect() - finally: - with suppress(Exception): - self._execute_sql(f"DROP VIEW IF EXISTS {src_name}").collect() - - def snapshot_prune( - self, - relation: str, - unique_key: list[str], - keep_last: int, - *, - dry_run: bool = False, - ) -> None: - """ - Delete older snapshot versions while keeping the most recent `keep_last` - rows per business key (including the current row). - """ - if keep_last <= 0: - return - - keys = [k for k in unique_key if k] - if not keys: - return - - target = self._qualified(relation) - vf = BaseExecutor.SNAPSHOT_VALID_FROM_COL - - part_by = ", ".join(keys) - key_select = ", ".join(keys) - - ranked_sql = f""" -SELECT - {key_select}, - {vf}, - ROW_NUMBER() OVER ( - PARTITION BY {part_by} - ORDER BY {vf} DESC - ) AS rn -FROM {target} -""" - - if dry_run: - sql = f""" -WITH ranked AS ( - {ranked_sql} -) -SELECT COUNT(*) AS rows_to_delete -FROM ranked -WHERE rn > {int(keep_last)} -""" - res_raw = self._execute_sql(sql).collect() - # Snowflake returns a list of Row objects; treat them as tuples for typing. - res = cast("list[tuple[Any, ...]]", res_raw) - rows = int(res[0][0]) if res else 0 - - echo( - f"[DRY-RUN] snapshot_prune({relation}): would delete {rows} row(s) " - f"(keep_last={keep_last})" - ) - return + def _cleanup() -> None: + self._execute_sql(f"DROP VIEW IF EXISTS {src_quoted}").collect() - delete_sql = f""" -DELETE FROM {target} t -USING ( - {ranked_sql} -) r -WHERE - r.rn > {int(keep_last)} - AND {" AND ".join([f"t.{k} = r.{k}" for k in keys])} - AND t.{vf} = r.{vf} -""" - self._execute_sql(delete_sql).collect() + return src_quoted, _cleanup def execute_hook_sql(self, sql: str) -> None: """ diff --git a/uv.lock b/uv.lock index f431107..dfa5e7a 100644 --- a/uv.lock +++ b/uv.lock @@ -733,7 +733,7 @@ wheels = [ [[package]] name = "fastflowtransform" -version = "0.6.9" +version = "0.6.10" source = { editable = "." } dependencies = [ { name = "duckdb" },