Skip to content

Commit d6e6d07

Browse files
authored
Chore: Refactor the scheduler API to support running of precomputed missing intervals (#3120)
1 parent eae04ba commit d6e6d07

File tree

1 file changed

+50
-24
lines changed

1 file changed

+50
-24
lines changed

sqlmesh/core/scheduler.py

Lines changed: 50 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from sqlmesh.core.snapshot.definition import SnapshotId
2727
from sqlmesh.core.state_sync import StateSync
2828
from sqlmesh.utils import format_exception
29-
from sqlmesh.utils.concurrency import concurrent_apply_to_dag
29+
from sqlmesh.utils.concurrency import concurrent_apply_to_dag, NodeExecutionFailedError
3030
from sqlmesh.utils.dag import DAG
3131
from sqlmesh.utils.date import (
3232
TimeLike,
@@ -347,14 +347,60 @@ def run(
347347
if not batches:
348348
return True
349349

350-
dag = self._dag(batches)
351-
352350
self.console.start_evaluation_progress(
353351
{snapshot: len(intervals) for snapshot, intervals in batches.items()},
354352
environment_naming_info,
355353
self.default_catalog,
356354
)
357355

356+
errors, skipped_intervals = self.run_batches(
357+
batches=batches,
358+
deployability_index=deployability_index,
359+
execution_time=execution_time,
360+
circuit_breaker=circuit_breaker,
361+
)
362+
363+
self.console.stop_evaluation_progress(success=not errors)
364+
365+
skipped_snapshots = {i[0] for i in skipped_intervals}
366+
for skipped in skipped_snapshots:
367+
log_message = f"SKIPPED snapshot {skipped}\n"
368+
self.console.log_status_update(log_message)
369+
logger.info(log_message)
370+
371+
for error in errors:
372+
if isinstance(error.__cause__, CircuitBreakerError):
373+
raise error.__cause__
374+
sid = error.node[0]
375+
formatted_exception = "".join(format_exception(error.__cause__ or error))
376+
log_message = f"FAILED processing snapshot {sid}\n{formatted_exception}"
377+
self.console.log_error(log_message)
378+
# Log with INFO level to prevent duplicate messages in the console.
379+
logger.info(log_message)
380+
381+
return not errors
382+
383+
def run_batches(
384+
self,
385+
batches: SnapshotToBatches,
386+
deployability_index: DeployabilityIndex,
387+
execution_time: t.Optional[TimeLike] = None,
388+
circuit_breaker: t.Optional[t.Callable[[], bool]] = None,
389+
) -> t.Tuple[t.List[NodeExecutionFailedError[SchedulingUnit]], t.List[SchedulingUnit]]:
390+
"""Runs precomputed batches of missing intervals.
391+
392+
Args:
393+
batches: The batches of snapshots and intervals to evaluate.
394+
deployability_index: Determines snapshots that are deployable in the context of this evaluation.
395+
execution_time: The date/time time reference to use for execution time.
396+
circuit_breaker: An optional handler which checks if the run should be aborted.
397+
398+
Returns:
399+
A tuple of errors and skipped intervals.
400+
"""
401+
execution_time = execution_time or now()
402+
dag = self._dag(batches)
403+
358404
snapshots_by_name = {snapshot.name: snapshot for snapshot in self.snapshots.values()}
359405

360406
def evaluate_node(node: SchedulingUnit) -> None:
@@ -383,7 +429,7 @@ def evaluate_node(node: SchedulingUnit) -> None:
383429

384430
try:
385431
with self.snapshot_evaluator.concurrent_context():
386-
errors, skipped_intervals = concurrent_apply_to_dag(
432+
return concurrent_apply_to_dag(
387433
dag,
388434
evaluate_node,
389435
self.max_workers,
@@ -392,26 +438,6 @@ def evaluate_node(node: SchedulingUnit) -> None:
392438
finally:
393439
self.state_sync.recycle()
394440

395-
self.console.stop_evaluation_progress(success=not errors)
396-
397-
skipped_snapshots = {i[0] for i in skipped_intervals}
398-
for skipped in skipped_snapshots:
399-
log_message = f"SKIPPED snapshot {skipped}\n"
400-
self.console.log_status_update(log_message)
401-
logger.info(log_message)
402-
403-
for error in errors:
404-
if isinstance(error.__cause__, CircuitBreakerError):
405-
raise error.__cause__
406-
sid = error.node[0]
407-
formatted_exception = "".join(format_exception(error.__cause__ or error))
408-
log_message = f"FAILED processing snapshot {sid}\n{formatted_exception}"
409-
self.console.log_error(log_message)
410-
# Log with INFO level to prevent duplicate messages in the console.
411-
logger.info(log_message)
412-
413-
return not errors
414-
415441
def _dag(self, batches: SnapshotToBatches) -> DAG[SchedulingUnit]:
416442
"""Builds a DAG of snapshot intervals to be evaluated.
417443

0 commit comments

Comments
 (0)