@@ -444,6 +444,7 @@ class IncrementalByUniqueKeyKind(_IncrementalBy):
444444 )
445445 unique_key : SQLGlotListOfFields
446446 when_matched : t .Optional [exp .Whens ] = None
447+ merge_filter : t .Optional [exp .Expression ] = None
447448 batch_concurrency : t .Literal [1 ] = 1
448449
449450 @field_validator ("when_matched" , mode = "before" )
@@ -453,17 +454,6 @@ def _when_matched_validator(
453454 v : t .Optional [t .Union [str , exp .Whens ]],
454455 values : t .Dict [str , t .Any ],
455456 ) -> t .Optional [exp .Whens ]:
456- def replace_table_references (expression : exp .Expression ) -> exp .Expression :
457- from sqlmesh .core .engine_adapter .base import MERGE_SOURCE_ALIAS , MERGE_TARGET_ALIAS
458-
459- if isinstance (expression , exp .Column ):
460- if expression .table .lower () == "target" :
461- expression .set ("table" , exp .to_identifier (MERGE_TARGET_ALIAS ))
462- elif expression .table .lower () == "source" :
463- expression .set ("table" , exp .to_identifier (MERGE_SOURCE_ALIAS ))
464-
465- return expression
466-
467457 if v is None :
468458 return v
469459 if isinstance (v , str ):
@@ -474,14 +464,30 @@ def replace_table_references(expression: exp.Expression) -> exp.Expression:
474464
475465 return t .cast (exp .Whens , d .parse_one (v , into = exp .Whens , dialect = get_dialect (values )))
476466
477- return t .cast (exp .Whens , v .transform (replace_table_references ))
467+ return t .cast (exp .Whens , v .transform (d .replace_merge_table_aliases ))
468+
469+ @field_validator ("merge_filter" , mode = "before" )
470+ @field_validator_v1_args
471+ def _merge_filter_validator (
472+ cls ,
473+ v : t .Optional [exp .Expression ],
474+ values : t .Dict [str , t .Any ],
475+ ) -> t .Optional [exp .Expression ]:
476+ if v is None :
477+ return v
478+ if isinstance (v , str ):
479+ v = v .strip ()
480+ return d .parse_one (v , dialect = get_dialect (values ))
481+
482+ return v .transform (d .replace_merge_table_aliases )
478483
479484 @property
480485 def data_hash_values (self ) -> t .List [t .Optional [str ]]:
481486 return [
482487 * super ().data_hash_values ,
483488 * (gen (k ) for k in self .unique_key ),
484489 gen (self .when_matched ) if self .when_matched is not None else None ,
490+ gen (self .merge_filter ) if self .merge_filter is not None else None ,
485491 ]
486492
487493 def to_expression (
@@ -494,6 +500,7 @@ def to_expression(
494500 {
495501 "unique_key" : exp .Tuple (expressions = self .unique_key ),
496502 "when_matched" : self .when_matched ,
503+ "merge_filter" : self .merge_filter ,
497504 }
498505 ),
499506 ],
0 commit comments