Skip to content

Commit c6ec1d7

Browse files
authored
Fix: When removing intervals also remove them for snapshots that have been deleted by the janitor process (#2489)
1 parent 606cc94 commit c6ec1d7

File tree

3 files changed

+73
-12
lines changed

3 files changed

+73
-12
lines changed

sqlmesh/core/snapshot/definition.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,10 @@ class SnapshotIntervals(PydanticModel, frozen=True):
162162
def snapshot_id(self) -> SnapshotId:
163163
return SnapshotId(name=self.name, identifier=self.identifier)
164164

165+
@property
def name_version(self) -> SnapshotNameVersion:
    """Returns the (name, version) pair of this record as a SnapshotNameVersion."""
    return SnapshotNameVersion(version=self.version, name=self.name)
168+
165169

166170
class SnapshotDataVersion(PydanticModel, frozen=True):
167171
fingerprint: SnapshotFingerprint

sqlmesh/core/state_sync/engine_adapter.py

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -645,26 +645,38 @@ def remove_interval(
645645
execution_time: t.Optional[TimeLike] = None,
646646
remove_shared_versions: bool = False,
647647
) -> None:
648+
intervals_to_remove: t.Sequence[
649+
t.Tuple[t.Union[SnapshotInfoLike, SnapshotIntervals], Interval]
650+
] = snapshot_intervals
648651
if remove_shared_versions:
649-
name_version_mapping = {
650-
s.name_version: (s, interval) for s, interval in snapshot_intervals
651-
}
652-
all_snapshots = self._get_snapshots_with_same_version(
653-
[s[0] for s in snapshot_intervals]
654-
)
655-
snapshot_intervals = [
656-
(snapshot, name_version_mapping[snapshot.name_version][1])
652+
name_version_mapping = {s.name_version: interval for s, interval in snapshot_intervals}
653+
all_snapshots = []
654+
for where in self._snapshot_name_version_filter(name_version_mapping, alias=None):
655+
all_snapshots.extend(
656+
[
657+
SnapshotIntervals(
658+
name=r[0], identifier=r[1], version=r[2], intervals=[], dev_intervals=[]
659+
)
660+
for r in self._fetchall(
661+
exp.select("name", "identifier", "version")
662+
.from_(self.intervals_table)
663+
.where(where)
664+
)
665+
]
666+
)
667+
intervals_to_remove = [
668+
(snapshot, name_version_mapping[snapshot.name_version])
657669
for snapshot in all_snapshots
658670
]
659671

660672
if logger.isEnabledFor(logging.INFO):
661-
snapshot_ids = ", ".join(str(s.snapshot_id) for s, _ in snapshot_intervals)
673+
snapshot_ids = ", ".join(str(s.snapshot_id) for s, _ in intervals_to_remove)
662674
logger.info("Removing interval for snapshots: %s", snapshot_ids)
663675

664676
for is_dev in (True, False):
665677
self.engine_adapter.insert_append(
666678
self.intervals_table,
667-
_intervals_to_df(snapshot_intervals, is_dev=is_dev, is_removed=True),
679+
_intervals_to_df(intervals_to_remove, is_dev=is_dev, is_removed=True),
668680
columns_to_types=self._interval_columns_to_types,
669681
)
670682

@@ -1262,7 +1274,9 @@ def _snapshot_id_filter(
12621274
)
12631275

12641276
def _snapshot_name_version_filter(
1265-
self, snapshot_name_versions: t.Iterable[SnapshotNameVersionLike], alias: str = "snapshots"
1277+
self,
1278+
snapshot_name_versions: t.Iterable[SnapshotNameVersionLike],
1279+
alias: t.Optional[str] = "snapshots",
12661280
) -> t.Iterator[exp.Condition]:
12671281
name_versions = sorted({(s.name, s.version) for s in snapshot_name_versions})
12681282
batches = self._snapshot_batches(name_versions)
@@ -1305,7 +1319,7 @@ def _transaction(self) -> t.Iterator[None]:
13051319

13061320

13071321
def _intervals_to_df(
1308-
snapshot_intervals: t.Sequence[t.Tuple[SnapshotInfoLike, Interval]],
1322+
snapshot_intervals: t.Sequence[t.Tuple[t.Union[SnapshotInfoLike, SnapshotIntervals], Interval]],
13091323
is_dev: bool,
13101324
is_removed: bool,
13111325
) -> pd.DataFrame:

tests/core/test_state_sync.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,49 @@ def test_remove_interval(state_sync: EngineAdapterStateSync, make_snapshot: t.Ca
330330
]
331331

332332

333+
def test_remove_interval_missing_snapshot(
    state_sync: EngineAdapterStateSync, make_snapshot: t.Callable
) -> None:
    """Checks remove_interval(..., remove_shared_versions=True) when a snapshot sharing the version was never pushed.

    Two snapshots share name "a" and version "a", but only the first one is
    pushed to the state. Intervals are recorded for both, and then a sub-range
    is removed through the pushed snapshot only.
    """
    pushed_model = SqlModel(
        name="a",
        cron="@daily",
        query=parse_one("select 1, ds"),
    )
    snapshot_a = make_snapshot(pushed_model, version="a")

    missing_model = SqlModel(
        name="a",
        cron="@daily",
        query=parse_one("select 2::INT, '2022-01-01'::TEXT AS ds"),
    )
    snapshot_b = make_snapshot(missing_model, version="a")

    # Only snapshot_a is pushed, which simulates snapshot_b being missing
    # from the state while its interval records still exist.
    state_sync.push_snapshots([snapshot_a])
    state_sync.add_interval(snapshot_a, "2020-01-01", "2020-01-10")
    state_sync.add_interval(snapshot_b, "2020-01-11", "2020-01-30")

    # Before the removal: a single snapshot is returned, carrying one
    # contiguous interval that spans both add_interval calls.
    before = state_sync.get_snapshots([snapshot_a, snapshot_b])
    assert len(before) == 1
    assert before[snapshot_a.snapshot_id].intervals == [
        (to_timestamp("2020-01-01"), to_timestamp("2020-01-31")),
    ]

    removed_range = snapshot_a.inclusive_exclusive("2020-01-15", "2020-01-17")
    state_sync.remove_interval(
        [(snapshot_a, removed_range)],
        remove_shared_versions=True,
    )

    # After the removal: still one snapshot, now with a gap where the
    # removed range used to be.
    after = state_sync.get_snapshots([snapshot_a, snapshot_b])
    assert len(after) == 1
    assert after[snapshot_a.snapshot_id].intervals == [
        (to_timestamp("2020-01-01"), to_timestamp("2020-01-15")),
        (to_timestamp("2020-01-18"), to_timestamp("2020-01-31")),
    ]
374+
375+
333376
def test_refresh_snapshot_intervals(
334377
state_sync: EngineAdapterStateSync, make_snapshot: t.Callable
335378
) -> None:

0 commit comments

Comments
 (0)