Skip to content

Commit 7144853

Browse files
committed
Chore: Reintroduce tagging queries with correlation ID
1 parent 317df56 commit 7144853

File tree

5 files changed

+71
-22
lines changed

5 files changed

+71
-22
lines changed

sqlmesh/core/context.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -451,14 +451,8 @@ def engine_adapter(self) -> EngineAdapter:
451451
@property
452452
def snapshot_evaluator(self) -> SnapshotEvaluator:
453453
if not self._snapshot_evaluator:
454-
self._snapshot_evaluator = SnapshotEvaluator(
455-
{
456-
gateway: adapter.with_settings(log_level=logging.INFO)
457-
for gateway, adapter in self.engine_adapters.items()
458-
},
459-
ddl_concurrent_tasks=self.concurrent_tasks,
460-
selected_gateway=self.selected_gateway,
461-
)
454+
self._snapshot_evaluator = self._create_snapshot_evaluator(log_level=logging.INFO)
455+
462456
return self._snapshot_evaluator
463457

464458
def execution_context(
@@ -520,7 +514,11 @@ def upsert_model(self, model: t.Union[str, Model], **kwargs: t.Any) -> Model:
520514

521515
return model
522516

523-
def scheduler(self, environment: t.Optional[str] = None) -> Scheduler:
517+
def scheduler(
518+
self,
519+
environment: t.Optional[str] = None,
520+
snapshot_evaluator: t.Optional[SnapshotEvaluator] = None,
521+
) -> Scheduler:
524522
"""Returns the built-in scheduler.
525523
526524
Args:
@@ -542,9 +540,11 @@ def scheduler(self, environment: t.Optional[str] = None) -> Scheduler:
542540
if not snapshots:
543541
raise ConfigError("No models were found")
544542

545-
return self.create_scheduler(snapshots)
543+
return self.create_scheduler(snapshots, snapshot_evaluator or self.snapshot_evaluator)
546544

547-
def create_scheduler(self, snapshots: t.Iterable[Snapshot]) -> Scheduler:
545+
def create_scheduler(
546+
self, snapshots: t.Iterable[Snapshot], snapshot_evaluator: SnapshotEvaluator
547+
) -> Scheduler:
548548
"""Creates the built-in scheduler.
549549
550550
Args:
@@ -555,7 +555,7 @@ def create_scheduler(self, snapshots: t.Iterable[Snapshot]) -> Scheduler:
555555
"""
556556
return Scheduler(
557557
snapshots,
558-
self.snapshot_evaluator,
558+
snapshot_evaluator,
559559
self.state_sync,
560560
default_catalog=self.default_catalog,
561561
max_workers=self.concurrent_tasks,
@@ -3064,6 +3064,16 @@ def load_model_tests(
30643064

30653065
return model_tests
30663066

3067+
def _create_snapshot_evaluator(self, **kwargs: t.Any) -> SnapshotEvaluator:
    """Build a new SnapshotEvaluator over this context's engine adapters.

    Args:
        **kwargs: Forwarded unchanged to ``with_settings`` on every engine adapter.

    Returns:
        A SnapshotEvaluator constructed from the per-gateway adapters returned by
        ``with_settings``, together with this context's ``concurrent_tasks`` and
        ``selected_gateway``.
    """
    configured_adapters: t.Dict[str, EngineAdapter] = {}
    for gateway_name, engine_adapter in self.engine_adapters.items():
        # Each gateway gets its own adapter derived from the context's adapter.
        configured_adapters[gateway_name] = engine_adapter.with_settings(**kwargs)

    return SnapshotEvaluator(
        configured_adapters,
        ddl_concurrent_tasks=self.concurrent_tasks,
        selected_gateway=self.selected_gateway,
    )
3076+
30673077

30683078
class Context(GenericContext[Config]):
30693079
CONFIG_TYPE = Config

sqlmesh/core/engine_adapter/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,15 +147,15 @@ def __init__(
147147
self._multithreaded = multithreaded
148148
self.correlation_id = correlation_id
149149

150-
def with_settings(self, log_level: int, **kwargs: t.Any) -> EngineAdapter:
150+
def with_settings(self, log_level: int = logging.DEBUG, **kwargs: t.Any) -> EngineAdapter:
151151
adapter = self.__class__(
152152
self._connection_pool,
153153
dialect=self.dialect,
154154
sql_gen_kwargs=self._sql_gen_kwargs,
155155
default_catalog=self._default_catalog,
156156
execute_log_level=log_level,
157157
register_comments=self._register_comments,
158-
null_connection=True,
158+
null_connection=self._extra_config.pop("null_connection", True),
159159
multithreaded=self._multithreaded,
160160
pretty_sql=self._pretty_sql,
161161
**self._extra_config,

sqlmesh/core/plan/evaluator.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
)
3939
from sqlmesh.utils import to_snake_case
4040
from sqlmesh.core.state_sync import StateSync
41+
from sqlmesh.utils import CorrelationId
4142
from sqlmesh.utils.concurrency import NodeExecutionFailedError
4243
from sqlmesh.utils.errors import PlanError, SQLMeshError
4344
from sqlmesh.utils.dag import DAG
@@ -71,7 +72,7 @@ def __init__(
7172
self,
7273
state_sync: StateSync,
7374
snapshot_evaluator: SnapshotEvaluator,
74-
create_scheduler: t.Callable[[t.Iterable[Snapshot]], Scheduler],
75+
create_scheduler: t.Callable[[t.Iterable[Snapshot], SnapshotEvaluator], Scheduler],
7576
default_catalog: t.Optional[str],
7677
console: t.Optional[Console] = None,
7778
):
@@ -89,6 +90,7 @@ def evaluate(
8990
) -> None:
9091
self._circuit_breaker = circuit_breaker
9192

93+
self.set_correlation_id(CorrelationId.from_plan_id(plan.plan_id))
9294
self.console.start_plan_evaluation(plan)
9395
analytics.collector.on_plan_apply_start(
9496
plan=plan,
@@ -228,7 +230,7 @@ def visit_backfill_stage(self, stage: stages.BackfillStage, plan: EvaluatablePla
228230
self.console.log_success("SKIP: No model batches to execute")
229231
return
230232

231-
scheduler = self.create_scheduler(stage.all_snapshots.values())
233+
scheduler = self.create_scheduler(stage.all_snapshots.values(), self.snapshot_evaluator)
232234
errors, _ = scheduler.run_merged_intervals(
233235
merged_intervals=stage.snapshot_to_intervals,
234236
deployability_index=stage.deployability_index,
@@ -249,7 +251,7 @@ def visit_audit_only_run_stage(
249251
return
250252

251253
# If there are any snapshots to be audited, we'll reuse the scheduler's internals to audit them
252-
scheduler = self.create_scheduler(audit_snapshots)
254+
scheduler = self.create_scheduler(audit_snapshots, self.snapshot_evaluator)
253255
completion_status = scheduler.audit(
254256
plan.environment,
255257
plan.start,
@@ -349,6 +351,13 @@ def visit_finalize_environment_stage(
349351
) -> None:
350352
self.state_sync.finalize(plan.environment)
351353

354+
def set_correlation_id(self, correlation_id: CorrelationId) -> None:
    """Attach the given correlation ID to the snapshot evaluator's engine adapters.

    Adapters whose ``correlation_id`` already matches are left untouched; every other
    entry in ``self.snapshot_evaluator.adapters`` is replaced in place with the result
    of ``adapter.with_settings(correlation_id=correlation_id)``.

    Args:
        correlation_id: The correlation ID the adapters should carry.
    """
    adapters = self.snapshot_evaluator.adapters
    # Iterate over a copy of the items because entries are reassigned inside the loop.
    for gateway, current in list(adapters.items()):
        if not (correlation_id != current.correlation_id):
            continue
        # NOTE(review): with_settings is called without log_level, so the replacement
        # adapter presumably falls back to with_settings' default log level rather than
        # the level the current adapter was created with -- confirm this is intended.
        adapters[gateway] = current.with_settings(correlation_id=correlation_id)
360+
352361
def _promote_snapshots(
353362
self,
354363
plan: EvaluatablePlan,

sqlmesh/core/snapshot/evaluator.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -122,11 +122,7 @@ def __init__(
122122
self.adapters = (
123123
adapters if isinstance(adapters, t.Dict) else {selected_gateway or "": adapters}
124124
)
125-
self.adapter = (
126-
next(iter(self.adapters.values()))
127-
if not selected_gateway
128-
else self.adapters[selected_gateway]
129-
)
125+
self.selected_gateway = selected_gateway
130126
self.ddl_concurrent_tasks = ddl_concurrent_tasks
131127

132128
def evaluate(
@@ -603,6 +599,14 @@ def close(self) -> None:
603599
except Exception:
604600
logger.exception("Failed to close Snapshot Evaluator")
605601

602+
@property
def adapter(self) -> EngineAdapter:
    """Return the engine adapter for the selected gateway.

    When no gateway is selected (``selected_gateway`` is falsy), the first adapter in
    ``self.adapters`` is returned instead. The lookup happens on every access, so it
    reflects the current contents of ``self.adapters``.
    """
    if self.selected_gateway:
        return self.adapters[self.selected_gateway]
    return next(iter(self.adapters.values()))
609+
606610
def _evaluate_snapshot(
607611
self,
608612
snapshot: Snapshot,

tests/core/test_integration.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@
7171
from sqlmesh.utils.errors import NoChangesPlanError, SQLMeshError, PlanError, ConfigError
7272
from sqlmesh.utils.pydantic import validate_string
7373
from tests.conftest import DuckDBMetadata, SushiDataValidator
74+
from sqlmesh.utils import CorrelationId
7475
from tests.utils.test_helpers import use_terminal_console
7576
from tests.utils.test_filesystem import create_temp_file
7677

@@ -6815,3 +6816,28 @@ def test_scd_type_2_full_restatement_no_start_date(init_and_plan_context: t.Call
68156816
# valid_from should be the epoch, valid_to should be NaT
68166817
assert str(row["valid_from"]) == "1970-01-01 00:00:00"
68176818
assert pd.isna(row["valid_to"])
6819+
6820+
6821+
def test_plan_evaluator_correlation_id(tmp_path: Path):
    """Check that each applied plan tags its logged SQL with the plan's correlation ID."""

    def _correlation_id_in_sqls(correlation_id: CorrelationId, mock_logger):
        # The first positional argument of every recorded call is treated as the SQL text.
        marker = f"/* {correlation_id} */"
        logged_sqls = [recorded[0][0] for recorded in mock_logger.call_args_list]
        return any(marker in sql for sql in logged_sqls)

    ctx = Context(paths=[tmp_path], config=Config())
    model_path = Path("models", "test.sql")

    # Case: Ensure that the correlation id (plan_id) is included in the SQL for each plan
    for i in range(2):
        # Rewrite the model with a different query each round so there is a change to plan.
        create_temp_file(
            tmp_path,
            model_path,
            f"MODEL (name test.a, kind FULL); SELECT {i} AS col",
        )

        with mock.patch("sqlmesh.core.engine_adapter.base.EngineAdapter._log_sql") as mock_logger:
            ctx.load()
            plan = ctx.plan(auto_apply=True, no_prompts=True)

        correlation_id = CorrelationId.from_plan_id(plan.plan_id)
        assert str(correlation_id) == f"SQLMESH_PLAN: {plan.plan_id}"

        assert _correlation_id_in_sqls(correlation_id, mock_logger)

0 commit comments

Comments
 (0)