Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "fastflowtransform"
version = "0.6.9"
version = "0.6.10"
description = "Python framework for SQL & Python data transformation, ETL pipelines, and dbt-style data modeling"
readme = "README.md"
license = { text = "Apache-2.0" }
Expand Down
44 changes: 14 additions & 30 deletions src/fastflowtransform/executors/_budget_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from time import perf_counter
from typing import Any

from fastflowtransform.executors._query_stats_adapter import QueryStatsAdapter, RowcountStatsAdapter
from fastflowtransform.executors.budget import BudgetGuard
from fastflowtransform.executors.query_stats import QueryStats

Expand All @@ -20,6 +21,7 @@ def run_sql_with_budget(
estimate_fn: Callable[[str], int | None] | None = None,
post_estimate_fn: Callable[[str, Any], int | None] | None = None,
record_stats: bool = True,
stats_adapter: QueryStatsAdapter | None = None,
) -> Any:
"""
Shared helper for guarded SQL execution with timing + stats recording.
Expand Down Expand Up @@ -51,36 +53,18 @@ def run_sql_with_budget(
result = exec_fn()
duration_ms = int((perf_counter() - started) * 1000)

rows: int | None = None
if rowcount_extractor is not None:
with suppress(Exception):
rows = rowcount_extractor(result)

stats = QueryStats(bytes_processed=estimated_bytes, rows=rows, duration_ms=duration_ms)

if stats.bytes_processed is None and post_estimate_fn is not None:
with suppress(Exception):
post_estimate = post_estimate_fn(sql, result)
if post_estimate is not None:
stats = QueryStats(
bytes_processed=post_estimate,
rows=stats.rows,
duration_ms=stats.duration_ms,
)

if extra_stats is not None:
with suppress(Exception):
extra = extra_stats(result)
if extra:
stats = QueryStats(
bytes_processed=extra.bytes_processed
if extra.bytes_processed is not None
else stats.bytes_processed,
rows=extra.rows if extra.rows is not None else stats.rows,
duration_ms=extra.duration_ms
if extra.duration_ms is not None
else stats.duration_ms,
)
adapter = stats_adapter
if adapter is None and (rowcount_extractor or post_estimate_fn or extra_stats):
adapter = RowcountStatsAdapter(
rowcount_extractor=rowcount_extractor,
post_estimate_fn=post_estimate_fn,
extra_stats=extra_stats,
sql=sql,
)
if adapter is None:
stats = QueryStats(bytes_processed=estimated_bytes, rows=None, duration_ms=duration_ms)
else:
stats = adapter.collect(result, duration_ms=duration_ms, estimated_bytes=estimated_bytes)

executor._record_query_stats(stats)
return result
142 changes: 142 additions & 0 deletions src/fastflowtransform/executors/_query_stats_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
# fastflowtransform/executors/_query_stats_adapter.py
from __future__ import annotations

from collections.abc import Callable
from typing import Any, Protocol

from fastflowtransform.executors.query_stats import QueryStats


class QueryStatsAdapter(Protocol):
    """
    Structural interface for engine-specific query-statistics collectors.

    Implementations turn an engine's raw execution result into a normalized
    ``QueryStats`` record, given the measured wall-clock duration and any
    upfront byte estimate.
    """

    def collect(
        self, result: Any, *, duration_ms: int | None, estimated_bytes: int | None
    ) -> QueryStats:
        """Build a ``QueryStats`` from *result*; either hint may be None."""
        ...


class RowcountStatsAdapter:
    """
    Default stats adapter for DB-API style executors that expose rowcount.

    Preserves the historical hook semantics:
      * ``rowcount_extractor`` supplies the affected-row count,
      * ``post_estimate_fn`` fills in bytes only when no upfront estimate exists,
      * ``extra_stats`` overrides any field it reports as non-None.
    Every hook is best-effort: exceptions raised by a hook are swallowed.
    """

    def __init__(
        self,
        *,
        rowcount_extractor: Callable[[Any], int | None] | None = None,
        post_estimate_fn: Callable[[str, Any], int | None] | None = None,
        extra_stats: Callable[[Any], QueryStats | None] | None = None,
        sql: str | None = None,
    ) -> None:
        self.rowcount_extractor = rowcount_extractor
        self.post_estimate_fn = post_estimate_fn
        self.extra_stats = extra_stats
        self.sql = sql

    def collect(
        self, result: Any, *, duration_ms: int | None, estimated_bytes: int | None
    ) -> QueryStats:
        """Assemble a ``QueryStats`` from the hooks configured on this adapter."""
        bytes_val = estimated_bytes
        rows_val: int | None = None
        dur_val = duration_ms

        if self.rowcount_extractor is not None:
            try:
                rows_val = self.rowcount_extractor(result)
            except Exception:
                rows_val = None

        # The post-hoc byte estimate only kicks in when nothing was known upfront.
        if bytes_val is None and self.post_estimate_fn is not None:
            try:
                estimate = self.post_estimate_fn(self.sql or "", result)
            except Exception:
                estimate = None
            if estimate is not None:
                bytes_val = estimate

        # An engine-specific extra_stats hook wins field-by-field wherever it
        # reports a concrete (non-None) value.
        if self.extra_stats is not None:
            try:
                overrides = self.extra_stats(result)
            except Exception:
                overrides = None
            if overrides:
                if overrides.bytes_processed is not None:
                    bytes_val = overrides.bytes_processed
                if overrides.rows is not None:
                    rows_val = overrides.rows
                if overrides.duration_ms is not None:
                    dur_val = overrides.duration_ms

        return QueryStats(bytes_processed=bytes_val, rows=rows_val, duration_ms=dur_val)


class JobStatsAdapter:
    """
    Generic job-handle stats extractor (BigQuery/Snowflake/Spark-like objects).

    Reads best-effort attributes from a finished query-job handle:
      * bytes:    ``total_bytes_processed``, then ``bytes_processed``
      * rows:     ``num_dml_affected_rows``, ``total_rows``, ``rowcount``
      * duration: ``ended - started`` when both timestamps exist

    The keyword-only ``duration_ms`` / ``estimated_bytes`` fallbacks make the
    class usable through the ``QueryStatsAdapter`` protocol; both default to
    None so the historical ``collect(job)`` call shape keeps working.
    """

    @staticmethod
    def _safe_int(val: Any) -> int | None:
        """Coerce *val* to int; None and conversion failures map to None."""
        try:
            if val is None:
                return None
            return int(val)
        except Exception:
            return None

    @classmethod
    def _first_attr(cls, job: Any, names: tuple[str, ...]) -> int | None:
        # Pick the first attribute that is present *and* coercible to int.
        # `or`-chaining the getattr calls would incorrectly skip legitimate
        # zero values (e.g. total_bytes_processed == 0 for a cache hit, or
        # num_dml_affected_rows == 0 for a no-op DML statement).
        for name in names:
            val = cls._safe_int(getattr(job, name, None))
            if val is not None:
                return val
        return None

    def collect(
        self,
        job: Any,
        *,
        duration_ms: int | None = None,
        estimated_bytes: int | None = None,
    ) -> QueryStats:
        """
        Build a ``QueryStats`` from *job*.

        ``estimated_bytes`` is used only when the job exposes no byte count;
        ``duration_ms`` is used only when start/end timestamps are missing.
        """
        bytes_processed = self._first_attr(
            job, ("total_bytes_processed", "bytes_processed")
        )
        if bytes_processed is None:
            bytes_processed = estimated_bytes

        rows = self._first_attr(job, ("num_dml_affected_rows", "total_rows", "rowcount"))

        dur: int | None = None
        try:
            started = getattr(job, "started", None)
            ended = getattr(job, "ended", None)
            if started is not None and ended is not None:
                dur = int((ended - started).total_seconds() * 1000)
        except Exception:
            dur = None
        if dur is None:
            dur = duration_ms

        return QueryStats(bytes_processed=bytes_processed, rows=rows, duration_ms=dur)


class SparkDataFrameStatsAdapter:
    """
    Adapter for Spark DataFrames that mirrors existing Databricks behaviour:
      * bytes come from the supplied ``bytes_fn`` (plan-based best effort),
      * rows are never reported (left as None),
      * duration is passed straight through.
    """

    def __init__(self, bytes_fn: Callable[[Any], int | None]) -> None:
        self.bytes_fn = bytes_fn

    def collect(
        self, df: Any, *, duration_ms: int | None, estimated_bytes: int | None = None
    ) -> QueryStats:
        """Derive ``QueryStats`` for *df*; non-positive byte counts become None."""
        size = estimated_bytes
        if size is None:
            try:
                size = self.bytes_fn(df)
            except Exception:
                size = None

        # Zero/negative plan estimates carry no signal — collapse them to None.
        normalized = size if size is not None and size > 0 else None
        return QueryStats(bytes_processed=normalized, rows=None, duration_ms=duration_ms)
Loading