Skip to content

Commit 0e51dd8

Browse files
Feat!: Add support for merge_filter and dbt incremental_predicates for Incremental By Unique Key (#3540)
1 parent 1f71537 commit 0e51dd8

File tree

13 files changed

+354
-15
lines changed

13 files changed

+354
-15
lines changed

sqlmesh/core/dialect.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -429,6 +429,8 @@ def _parse_props(self: Parser) -> t.Optional[exp.Expression]:
429429
lambda: _parse_macro_or_clause(self, self._parse_when_matched),
430430
optional=True,
431431
)
432+
elif name == "merge_filter":
433+
value = self._parse_conjunction()
432434
elif self._match(TokenType.L_PAREN):
433435
value = self.expression(exp.Tuple, expressions=self._parse_csv(self._parse_equality))
434436
self._match_r_paren()
@@ -1260,3 +1262,19 @@ def extract_func_call(
12601262

12611263
def is_meta_expression(v: t.Any) -> bool:
    """Return True when `v` is an instance of `Audit`, `Metric` or `Model`."""
    meta_expression_types = (Audit, Metric, Model)
    return isinstance(v, meta_expression_types)
1265+
1266+
1267+
def replace_merge_table_aliases(expression: exp.Expression) -> exp.Expression:
    """
    Point a column that is qualified with a user-facing merge table name at the
    corresponding SQLMesh merge alias.

    Columns qualified with "target" / "dbt_internal_dest" are re-qualified with
    MERGE_TARGET_ALIAS, and columns qualified with "source" / "dbt_internal_source"
    with MERGE_SOURCE_ALIAS. The qualifier is compared case-insensitively.

    Args:
        expression: The node to inspect. Only `exp.Column` nodes are modified
            (in place); any other node is returned untouched.

    Returns:
        The same `expression` object that was passed in.
    """
    # Imported locally rather than at module level.
    # NOTE(review): presumably to avoid a circular import with the engine adapter - confirm.
    from sqlmesh.core.engine_adapter.base import MERGE_SOURCE_ALIAS, MERGE_TARGET_ALIAS

    if not isinstance(expression, exp.Column):
        return expression

    qualifier = expression.table.lower()
    if qualifier in ("target", "dbt_internal_dest"):
        merge_alias = MERGE_TARGET_ALIAS
    elif qualifier in ("source", "dbt_internal_source"):
        merge_alias = MERGE_SOURCE_ALIAS
    else:
        # Unqualified columns and columns of any other table are left alone.
        return expression

    expression.set("table", exp.to_identifier(merge_alias))
    return expression

sqlmesh/core/engine_adapter/base.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1804,6 +1804,7 @@ def merge(
18041804
columns_to_types: t.Optional[t.Dict[str, exp.DataType]],
18051805
unique_key: t.Sequence[exp.Expression],
18061806
when_matched: t.Optional[exp.Whens] = None,
1807+
merge_filter: t.Optional[exp.Expression] = None,
18071808
) -> None:
18081809
source_queries, columns_to_types = self._get_source_queries_and_columns_to_types(
18091810
source_table, columns_to_types, target_table=target_table
@@ -1815,6 +1816,9 @@ def merge(
18151816
for part in unique_key
18161817
)
18171818
)
1819+
if merge_filter:
1820+
on = exp.and_(merge_filter, on)
1821+
18181822
if not when_matched:
18191823
when_matched = exp.Whens()
18201824
when_matched.append(

sqlmesh/core/engine_adapter/mixins.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def merge(
3131
columns_to_types: t.Optional[t.Dict[str, exp.DataType]],
3232
unique_key: t.Sequence[exp.Expression],
3333
when_matched: t.Optional[exp.Whens] = None,
34+
merge_filter: t.Optional[exp.Expression] = None,
3435
) -> None:
3536
logical_merge(
3637
self,
@@ -39,6 +40,7 @@ def merge(
3940
columns_to_types,
4041
unique_key,
4142
when_matched=when_matched,
43+
merge_filter=merge_filter,
4244
)
4345

4446

@@ -409,6 +411,7 @@ def logical_merge(
409411
columns_to_types: t.Optional[t.Dict[str, exp.DataType]],
410412
unique_key: t.Sequence[exp.Expression],
411413
when_matched: t.Optional[exp.Whens] = None,
414+
merge_filter: t.Optional[exp.Expression] = None,
412415
) -> None:
413416
"""
414417
Merge implementation for engine adapters that do not support merge natively.
@@ -420,10 +423,12 @@ def logical_merge(
420423
within the temporary table are omitted.
421424
4. Drop the temporary table.
422425
"""
423-
if when_matched:
426+
if when_matched or merge_filter:
427+
prop = "when_matched" if when_matched else "merge_filter"
424428
raise SQLMeshError(
425-
"This engine does not support MERGE expressions and therefore `when_matched` is not supported."
429+
f"This engine does not support MERGE expressions and therefore `{prop}` is not supported."
426430
)
431+
427432
engine_adapter._replace_by_key(
428433
target_table, source_table, columns_to_types, unique_key, is_unique_key=True
429434
)

sqlmesh/core/engine_adapter/postgres.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ def merge(
107107
columns_to_types: t.Optional[t.Dict[str, exp.DataType]],
108108
unique_key: t.Sequence[exp.Expression],
109109
when_matched: t.Optional[exp.Whens] = None,
110+
merge_filter: t.Optional[exp.Expression] = None,
110111
) -> None:
111112
# Merge isn't supported until Postgres 15
112113
merge_impl = (
@@ -120,4 +121,5 @@ def merge(
120121
columns_to_types,
121122
unique_key,
122123
when_matched=when_matched,
124+
merge_filter=merge_filter,
123125
)

sqlmesh/core/model/kind.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -444,6 +444,7 @@ class IncrementalByUniqueKeyKind(_IncrementalBy):
444444
)
445445
unique_key: SQLGlotListOfFields
446446
when_matched: t.Optional[exp.Whens] = None
447+
merge_filter: t.Optional[exp.Expression] = None
447448
batch_concurrency: t.Literal[1] = 1
448449

449450
@field_validator("when_matched", mode="before")
@@ -453,17 +454,6 @@ def _when_matched_validator(
453454
v: t.Optional[t.Union[str, exp.Whens]],
454455
values: t.Dict[str, t.Any],
455456
) -> t.Optional[exp.Whens]:
456-
def replace_table_references(expression: exp.Expression) -> exp.Expression:
457-
from sqlmesh.core.engine_adapter.base import MERGE_SOURCE_ALIAS, MERGE_TARGET_ALIAS
458-
459-
if isinstance(expression, exp.Column):
460-
if expression.table.lower() == "target":
461-
expression.set("table", exp.to_identifier(MERGE_TARGET_ALIAS))
462-
elif expression.table.lower() == "source":
463-
expression.set("table", exp.to_identifier(MERGE_SOURCE_ALIAS))
464-
465-
return expression
466-
467457
if v is None:
468458
return v
469459
if isinstance(v, str):
@@ -474,14 +464,30 @@ def replace_table_references(expression: exp.Expression) -> exp.Expression:
474464

475465
return t.cast(exp.Whens, d.parse_one(v, into=exp.Whens, dialect=get_dialect(values)))
476466

477-
return t.cast(exp.Whens, v.transform(replace_table_references))
467+
return t.cast(exp.Whens, v.transform(d.replace_merge_table_aliases))
468+
469+
@field_validator("merge_filter", mode="before")
@field_validator_v1_args
def _merge_filter_validator(
    cls,
    v: t.Optional[t.Union[str, exp.Expression]],
    values: t.Dict[str, t.Any],
) -> t.Optional[exp.Expression]:
    """
    Normalize the `merge_filter` property into an expression that uses the merge aliases.

    Args:
        v: The raw value: None, a SQL string, or an already parsed expression.
        values: The fields collected so far; used to look up the dialect
            when a string has to be parsed.

    Returns:
        None when no filter was given, otherwise the expression with its table
        references rewritten by `d.replace_merge_table_aliases`.
    """
    if v is None:
        return v

    if isinstance(v, str):
        # The annotation admits `str` because this runs in "before" mode and the
        # value may arrive as raw SQL text.
        v = d.parse_one(v.strip(), dialect=get_dialect(values))

    # Rewrite aliases for parsed strings too, so a filter written as text does not
    # keep bare `source.` / `target.` references. Columns that already carry the
    # merge aliases do not match the rewritten names and pass through unchanged.
    return v.transform(d.replace_merge_table_aliases)
478483

479484
@property
480485
def data_hash_values(self) -> t.List[t.Optional[str]]:
481486
return [
482487
*super().data_hash_values,
483488
*(gen(k) for k in self.unique_key),
484489
gen(self.when_matched) if self.when_matched is not None else None,
490+
gen(self.merge_filter) if self.merge_filter is not None else None,
485491
]
486492

487493
def to_expression(
@@ -494,6 +500,7 @@ def to_expression(
494500
{
495501
"unique_key": exp.Tuple(expressions=self.unique_key),
496502
"when_matched": self.when_matched,
503+
"merge_filter": self.merge_filter,
497504
}
498505
),
499506
],

sqlmesh/core/model/meta.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -436,6 +436,12 @@ def when_matched(self) -> t.Optional[exp.Whens]:
436436
return self.kind.when_matched
437437
return None
438438

439+
@property
def merge_filter(self) -> t.Optional[exp.Expression]:
    """The kind's `merge_filter` when the kind is `IncrementalByUniqueKeyKind`, otherwise None."""
    kind = self.kind
    return kind.merge_filter if isinstance(kind, IncrementalByUniqueKeyKind) else None
444+
439445
@property
440446
def catalog(self) -> t.Optional[str]:
441447
"""Returns the catalog of a model."""

sqlmesh/core/snapshot/evaluator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1392,6 +1392,7 @@ def insert(
13921392
columns_to_types=model.columns_to_types,
13931393
unique_key=model.unique_key,
13941394
when_matched=model.when_matched,
1395+
merge_filter=model.merge_filter,
13951396
)
13961397

13971398
def append(
@@ -1407,6 +1408,7 @@ def append(
14071408
columns_to_types=model.columns_to_types,
14081409
unique_key=model.unique_key,
14091410
when_matched=model.when_matched,
1411+
merge_filter=model.merge_filter,
14101412
)
14111413

14121414

sqlmesh/dbt/model.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,17 @@ def model_kind(self, context: DbtContext) -> ModelKind:
294294
f"{self.canonical_name(context)}: SQLMesh incremental by unique key strategy is not compatible with '{strategy}'"
295295
f" incremental strategy. Supported strategies include {collection_to_str(INCREMENTAL_BY_UNIQUE_KEY_STRATEGIES)}."
296296
)
297+
298+
if self.incremental_predicates:
299+
dialect = self.dialect(context)
300+
incremental_kind_kwargs["merge_filter"] = exp.and_(
301+
*[
302+
d.parse_one(predicate, dialect=dialect)
303+
for predicate in self.incremental_predicates
304+
],
305+
dialect=dialect,
306+
).transform(d.replace_merge_table_aliases)
307+
297308
return IncrementalByUniqueKeyKind(
298309
unique_key=self.unique_key,
299310
disable_restatement=disable_restatement,

tests/core/engine_adapter/test_base.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1126,6 +1126,79 @@ def test_merge_when_matched_multiple(make_mocked_engine_adapter: t.Callable, ass
11261126
)
11271127

11281128

1129+
def test_merge_filter(make_mocked_engine_adapter: t.Callable, assert_exp_eq):
    """Check that `merge_filter` is AND-ed in front of the unique key match in the MERGE ON clause."""
    merge_target = "__MERGE_TARGET__"
    merge_source = "__MERGE_SOURCE__"

    # WHEN MATCHED: always take the source `val`; take the source `ts` unless it
    # is NULL, in which case the target keeps its own value.
    update_assignments = [
        exp.column("val", merge_target).eq(exp.column("val", merge_source)),
        exp.column("ts", merge_target).eq(
            exp.Coalesce(
                this=exp.column("ts", merge_source),
                expressions=[exp.column("ts", merge_target)],
            )
        ),
    ]
    when_matched = exp.Whens(
        expressions=[
            exp.When(
                matched=True,
                source=False,
                then=exp.Update(expressions=update_assignments),
            )
        ]
    )

    # The filter touches both sides of the merge so both aliases appear in the output.
    source_id_is_positive = exp.GT(
        this=exp.column("ID", merge_source),
        expression=exp.Literal(this="0", is_string=False),
    )
    target_ts_before_cutoff = exp.LT(
        this=exp.column("ts", merge_target),
        expression=exp.Timestamp(this=exp.column("2020-02-05", quoted=True)),
    )
    merge_filter = exp.And(this=source_id_is_positive, expression=target_ts_before_cutoff)

    columns_to_types = {
        "ID": exp.DataType.build("int"),
        "ts": exp.DataType.build("timestamp"),
        "val": exp.DataType.build("int"),
    }
    source_query = t.cast(exp.Select, parse_one('SELECT "ID", ts, val FROM source'))

    adapter = make_mocked_engine_adapter(EngineAdapter)
    adapter.merge(
        target_table="target",
        source_table=source_query,
        columns_to_types=columns_to_types,
        unique_key=[exp.to_identifier("ID", quoted=True)],
        when_matched=when_matched,
        merge_filter=merge_filter,
    )

    executed_sql = adapter.cursor.execute.call_args[0][0]
    assert_exp_eq(
        executed_sql,
        """
MERGE INTO "target" AS "__MERGE_TARGET__"
USING (
SELECT "ID", "ts", "val"
FROM "source"
) AS "__MERGE_SOURCE__"
ON (
"__MERGE_SOURCE__"."ID" > 0
AND "__MERGE_TARGET__"."ts" < TIMESTAMP("2020-02-05")
)
AND "__MERGE_TARGET__"."ID" = "__MERGE_SOURCE__"."ID"
WHEN MATCHED THEN
UPDATE SET
"__MERGE_TARGET__"."val" = "__MERGE_SOURCE__"."val",
"__MERGE_TARGET__"."ts" = COALESCE("__MERGE_SOURCE__"."ts", "__MERGE_TARGET__"."ts")
WHEN NOT MATCHED THEN
INSERT ("ID", "ts", "val")
VALUES (
"__MERGE_SOURCE__"."ID",
"__MERGE_SOURCE__"."ts",
"__MERGE_SOURCE__"."val"
);
""",
    )
1200+
1201+
11291202
def test_scd_type_2_by_time(make_mocked_engine_adapter: t.Callable):
11301203
adapter = make_mocked_engine_adapter(EngineAdapter)
11311204

0 commit comments

Comments
 (0)