Skip to content

Commit 430336f

Browse files
eakmanrq and tobymao
authored
fix!: normalize scd type 2 columns (#2458)
* fix!: normalize scd type 2 columns * add migration script * add pydantic 1 support * bump migration number --------- Co-authored-by: Toby Mao <toby.mao@gmail.com>
1 parent 85b4403 commit 430336f

File tree

11 files changed

+233
-162
lines changed

11 files changed

+233
-162
lines changed

sqlmesh/core/engine_adapter/base.py

Lines changed: 71 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1150,10 +1150,10 @@ def scd_type_2_by_time(
11501150
target_table: TableName,
11511151
source_table: QueryOrDF,
11521152
unique_key: t.Sequence[exp.Expression],
1153-
valid_from_name: str,
1154-
valid_to_name: str,
1153+
valid_from_col: exp.Column,
1154+
valid_to_col: exp.Column,
11551155
execution_time: TimeLike,
1156-
updated_at_name: str,
1156+
updated_at_col: exp.Column,
11571157
invalidate_hard_deletes: bool = True,
11581158
updated_at_as_valid_from: bool = False,
11591159
columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
@@ -1166,10 +1166,10 @@ def scd_type_2_by_time(
11661166
target_table=target_table,
11671167
source_table=source_table,
11681168
unique_key=unique_key,
1169-
valid_from_name=valid_from_name,
1170-
valid_to_name=valid_to_name,
1169+
valid_from_col=valid_from_col,
1170+
valid_to_col=valid_to_col,
11711171
execution_time=execution_time,
1172-
updated_at_name=updated_at_name,
1172+
updated_at_col=updated_at_col,
11731173
invalidate_hard_deletes=invalidate_hard_deletes,
11741174
updated_at_as_valid_from=updated_at_as_valid_from,
11751175
columns_to_types=columns_to_types,
@@ -1183,8 +1183,8 @@ def scd_type_2_by_column(
11831183
target_table: TableName,
11841184
source_table: QueryOrDF,
11851185
unique_key: t.Sequence[exp.Expression],
1186-
valid_from_name: str,
1187-
valid_to_name: str,
1186+
valid_from_col: exp.Column,
1187+
valid_to_col: exp.Column,
11881188
execution_time: TimeLike,
11891189
check_columns: t.Union[exp.Star, t.Sequence[exp.Column]],
11901190
invalidate_hard_deletes: bool = True,
@@ -1199,8 +1199,8 @@ def scd_type_2_by_column(
11991199
target_table=target_table,
12001200
source_table=source_table,
12011201
unique_key=unique_key,
1202-
valid_from_name=valid_from_name,
1203-
valid_to_name=valid_to_name,
1202+
valid_from_col=valid_from_col,
1203+
valid_to_col=valid_to_col,
12041204
execution_time=execution_time,
12051205
check_columns=check_columns,
12061206
columns_to_types=columns_to_types,
@@ -1216,11 +1216,11 @@ def _scd_type_2(
12161216
target_table: TableName,
12171217
source_table: QueryOrDF,
12181218
unique_key: t.Sequence[exp.Expression],
1219-
valid_from_name: str,
1220-
valid_to_name: str,
1219+
valid_from_col: exp.Column,
1220+
valid_to_col: exp.Column,
12211221
execution_time: TimeLike,
12221222
invalidate_hard_deletes: bool = True,
1223-
updated_at_name: t.Optional[str] = None,
1223+
updated_at_col: t.Optional[exp.Column] = None,
12241224
check_columns: t.Optional[t.Union[exp.Star, t.Sequence[exp.Column]]] = None,
12251225
updated_at_as_valid_from: bool = False,
12261226
execution_time_as_valid_from: bool = False,
@@ -1233,6 +1233,9 @@ def _scd_type_2(
12331233
source_table, columns_to_types, target_table=target_table, batch_size=0
12341234
)
12351235
columns_to_types = columns_to_types or self.columns(target_table)
1236+
valid_from_name = valid_from_col.name
1237+
valid_to_name = valid_to_col.name
1238+
updated_at_name = updated_at_col.name if updated_at_col else None
12361239
if (
12371240
valid_from_name not in columns_to_types
12381241
or valid_to_name not in columns_to_types
@@ -1243,7 +1246,7 @@ def _scd_type_2(
12431246
raise SQLMeshError(f"Could not get columns_to_types. Does {target_table} exist?")
12441247
if not unique_key:
12451248
raise SQLMeshError("unique_key must be provided for SCD Type 2")
1246-
if check_columns and updated_at_name:
1249+
if check_columns and updated_at_col:
12471250
raise SQLMeshError(
12481251
"Cannot use both `check_columns` and `updated_at_name` for SCD Type 2"
12491252
)
@@ -1270,7 +1273,7 @@ def _scd_type_2(
12701273
table_columns = [exp.column(c, quoted=True) for c in columns_to_types]
12711274
if updated_at_name:
12721275
select_source_columns.append(
1273-
exp.cast(updated_at_name, time_data_type).as_(updated_at_name)
1276+
exp.cast(updated_at_col, time_data_type).as_(updated_at_col.this) # type: ignore
12741277
)
12751278

12761279
# If a star is provided, we include all unmanaged columns in the check.
@@ -1282,25 +1285,29 @@ def _scd_type_2(
12821285
check_columns = [exp.column(col) for col in unmanaged_columns]
12831286
execution_ts = to_time_column(execution_time, time_data_type)
12841287
if updated_at_as_valid_from:
1285-
if not updated_at_name:
1288+
if not updated_at_col:
12861289
raise SQLMeshError(
12871290
"Cannot use `updated_at_as_valid_from` without `updated_at_name` for SCD Type 2"
12881291
)
1289-
update_valid_from_start: t.Union[str, exp.Expression] = updated_at_name
1292+
update_valid_from_start: t.Union[str, exp.Expression] = updated_at_col
12901293
elif execution_time_as_valid_from:
12911294
update_valid_from_start = execution_ts
12921295
else:
12931296
update_valid_from_start = to_time_column("1970-01-01 00:00:00+00:00", time_data_type)
1294-
insert_valid_from_start = execution_ts if check_columns else exp.column(updated_at_name) # type: ignore
1297+
insert_valid_from_start = execution_ts if check_columns else updated_at_col # type: ignore
12951298
# joined._exists IS NULL is saying "if the row is deleted"
12961299
delete_check = (
12971300
exp.column("_exists", "joined").is_(exp.Null()) if invalidate_hard_deletes else None
12981301
)
1302+
prefixed_valid_to_col = valid_to_col.copy()
1303+
prefixed_valid_to_col.this.set("this", f"t_{prefixed_valid_to_col.name}")
1304+
prefixed_valid_from_col = valid_from_col.copy()
1305+
prefixed_valid_from_col.this.set("this", f"t_{valid_from_col.name}")
12991306
if check_columns:
13001307
row_check_conditions = []
13011308
for col in check_columns:
13021309
t_col = col.copy()
1303-
t_col.set("this", exp.to_identifier(f"t_{col.name}"))
1310+
t_col.this.set("this", f"t_{col.name}")
13041311
row_check_conditions.extend(
13051312
[
13061313
col.neq(t_col),
@@ -1312,7 +1319,7 @@ def _scd_type_2(
13121319
unique_key_conditions = []
13131320
for col in unique_key:
13141321
t_col = col.copy()
1315-
t_col.set("this", exp.to_identifier(f"t_{col.name}"))
1322+
t_col.this.set("this", f"t_{col.name}")
13161323
unique_key_conditions.extend(
13171324
[t_col.is_(exp.Null()).not_(), col.is_(exp.Null()).not_()]
13181325
)
@@ -1331,52 +1338,62 @@ def _scd_type_2(
13311338
),
13321339
execution_ts,
13331340
)
1334-
.else_(exp.column(f"t_{valid_to_name}"))
1335-
.as_(valid_to_name)
1341+
.else_(prefixed_valid_to_col)
1342+
.as_(valid_to_col.this)
13361343
)
13371344
valid_from_case_stmt = exp.func(
13381345
"COALESCE",
1339-
exp.column(f"t_{valid_from_name}"),
1346+
prefixed_valid_from_col,
13401347
update_valid_from_start,
1341-
).as_(valid_from_name)
1348+
).as_(valid_from_col.this)
13421349
else:
1343-
assert updated_at_name is not None
1344-
updated_row_filter = exp.column(updated_at_name) > exp.column(f"t_{updated_at_name}")
1350+
assert updated_at_col is not None
1351+
prefixed_updated_at_col = updated_at_col.copy()
1352+
prefixed_updated_at_col.this.set("this", f"t_{updated_at_col.name}")
1353+
updated_row_filter = updated_at_col > prefixed_updated_at_col
13451354

1346-
valid_to_case_stmt_builder = exp.Case().when(
1347-
updated_row_filter, exp.column(updated_at_name)
1348-
)
1355+
valid_to_case_stmt_builder = exp.Case().when(updated_row_filter, updated_at_col)
13491356
if delete_check:
13501357
valid_to_case_stmt_builder = valid_to_case_stmt_builder.when(
13511358
delete_check, execution_ts
13521359
)
1353-
valid_to_case_stmt = valid_to_case_stmt_builder.else_(
1354-
exp.column(f"t_{valid_to_name}")
1355-
).as_(valid_to_name)
1360+
valid_to_case_stmt = valid_to_case_stmt_builder.else_(prefixed_valid_to_col).as_(
1361+
valid_to_col.this
1362+
)
13561363

13571364
valid_from_case_stmt = (
13581365
exp.Case()
13591366
.when(
13601367
exp.and_(
1361-
exp.column(f"t_{valid_from_name}").is_(exp.Null()),
1368+
prefixed_valid_from_col.is_(exp.Null()),
13621369
exp.column("_exists", "latest_deleted").is_(exp.Null()).not_(),
13631370
),
13641371
exp.Case()
13651372
.when(
1366-
exp.column(valid_to_name, "latest_deleted") > exp.column(updated_at_name),
1367-
exp.column(valid_to_name, "latest_deleted"),
1373+
exp.column(valid_to_col.this, "latest_deleted") > updated_at_col,
1374+
exp.column(valid_to_col.this, "latest_deleted"),
13681375
)
1369-
.else_(exp.column(updated_at_name)),
1376+
.else_(updated_at_col),
13701377
)
1371-
.when(exp.column(f"t_{valid_from_name}").is_(exp.Null()), update_valid_from_start)
1372-
.else_(exp.column(f"t_{valid_from_name}"))
1373-
).as_(valid_from_name)
1378+
.when(prefixed_valid_from_col.is_(exp.Null()), update_valid_from_start)
1379+
.else_(prefixed_valid_from_col)
1380+
).as_(valid_from_col.this)
13741381

13751382
existing_rows_query = exp.select(*table_columns).from_(target_table)
13761383
if truncate:
13771384
existing_rows_query = existing_rows_query.limit(0)
13781385

13791386
with source_queries[0] as source_query:
1387+
prefixed_columns_to_types = []
1388+
for column in columns_to_types:
1389+
prefixed_col = exp.column(column).copy()
1390+
prefixed_col.this.set("this", f"t_{prefixed_col.name}")
1391+
prefixed_columns_to_types.append(prefixed_col)
1392+
prefixed_unmanaged_columns = []
1393+
for column in unmanaged_columns:
1394+
prefixed_col = exp.column(column).copy()
1395+
prefixed_col.this.set("this", f"t_{prefixed_col.name}")
1396+
prefixed_unmanaged_columns.append(prefixed_col)
13801397
query = (
13811398
exp.Select() # type: ignore
13821399
.with_(
@@ -1388,17 +1405,17 @@ def _scd_type_2(
13881405
# Historical Records that Do Not Change
13891406
.with_(
13901407
"static",
1391-
existing_rows_query.where(f"{valid_to_name} IS NOT NULL"),
1408+
existing_rows_query.where(valid_to_col.is_(exp.Null()).not_()),
13921409
)
13931410
# Latest Records that can be updated
13941411
.with_(
13951412
"latest",
1396-
existing_rows_query.where(f"{valid_to_name} IS NULL"),
1413+
existing_rows_query.where(valid_to_col.is_(exp.Null())),
13971414
)
13981415
# Deleted records which can be used to determine `valid_from` for undeleted source records
13991416
.with_(
14001417
"deleted",
1401-
exp.select(*[f"static.{col}" for col in columns_to_types])
1418+
exp.select(*[exp.column(col, "static") for col in columns_to_types])
14021419
.from_("static")
14031420
.join(
14041421
"latest",
@@ -1410,15 +1427,15 @@ def _scd_type_2(
14101427
),
14111428
join_type="left",
14121429
)
1413-
.where(f"latest.{valid_to_name} IS NULL"),
1430+
.where(exp.column(valid_to_col.this, "latest").is_(exp.Null())),
14141431
)
14151432
# Get the latest `valid_to` deleted record for each unique key
14161433
.with_(
14171434
"latest_deleted",
14181435
exp.select(
14191436
exp.true().as_("_exists"),
14201437
*(part.as_(f"_key{i}") for i, part in enumerate(unique_key)),
1421-
f"MAX({valid_to_name}) AS {valid_to_name}",
1438+
exp.Max(this=valid_to_col).as_(valid_to_col.this),
14221439
)
14231440
.from_("deleted")
14241441
.group_by(*unique_key),
@@ -1430,8 +1447,8 @@ def _scd_type_2(
14301447
exp.select(
14311448
exp.column("_exists", table="source"),
14321449
*(
1433-
exp.column(col, table="latest").as_(f"t_{col}")
1434-
for col in columns_to_types
1450+
exp.column(col, table="latest").as_(prefixed_columns_to_types[i].this)
1451+
for i, col in enumerate(columns_to_types)
14351452
),
14361453
*(exp.column(col, table="source").as_(col) for col in unmanaged_columns),
14371454
)
@@ -1450,8 +1467,10 @@ def _scd_type_2(
14501467
exp.select(
14511468
exp.column("_exists", table="source"),
14521469
*(
1453-
exp.column(col, table="latest").as_(f"t_{col}")
1454-
for col in columns_to_types
1470+
exp.column(col, table="latest").as_(
1471+
prefixed_columns_to_types[i].this
1472+
)
1473+
for i, col in enumerate(columns_to_types)
14551474
),
14561475
*(
14571476
exp.column(col, table="source").as_(col)
@@ -1478,10 +1497,10 @@ def _scd_type_2(
14781497
*(
14791498
exp.func(
14801499
"COALESCE",
1481-
exp.column(f"t_{col}", table="joined"),
1500+
exp.column(prefixed_unmanaged_columns[i].this, table="joined"),
14821501
exp.column(col, table="joined"),
14831502
).as_(col)
1484-
for col in unmanaged_columns
1503+
for i, col in enumerate(unmanaged_columns)
14851504
),
14861505
valid_from_case_stmt,
14871506
valid_to_case_stmt,
@@ -1505,8 +1524,8 @@ def _scd_type_2(
15051524
"inserted_rows",
15061525
exp.select(
15071526
*unmanaged_columns,
1508-
insert_valid_from_start.as_(valid_from_name),
1509-
to_time_column(exp.null(), time_data_type).as_(valid_to_name),
1527+
insert_valid_from_start.as_(valid_from_col.this), # type: ignore
1528+
to_time_column(exp.null(), time_data_type).as_(valid_to_col.this),
15101529
)
15111530
.from_("joined")
15121531
.where(updated_row_filter),

sqlmesh/core/engine_adapter/trino.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -205,11 +205,11 @@ def _scd_type_2(
205205
target_table: TableName,
206206
source_table: QueryOrDF,
207207
unique_key: t.Sequence[exp.Expression],
208-
valid_from_name: str,
209-
valid_to_name: str,
208+
valid_from_col: exp.Column,
209+
valid_to_col: exp.Column,
210210
execution_time: TimeLike,
211211
invalidate_hard_deletes: bool = True,
212-
updated_at_name: t.Optional[str] = None,
212+
updated_at_col: t.Optional[exp.Column] = None,
213213
check_columns: t.Optional[t.Union[exp.Star, t.Sequence[exp.Column]]] = None,
214214
updated_at_as_valid_from: bool = False,
215215
execution_time_as_valid_from: bool = False,
@@ -225,11 +225,11 @@ def _scd_type_2(
225225
target_table,
226226
source_table,
227227
unique_key,
228-
valid_from_name,
229-
valid_to_name,
228+
valid_from_col,
229+
valid_to_col,
230230
execution_time,
231231
invalidate_hard_deletes,
232-
updated_at_name,
232+
updated_at_col,
233233
check_columns,
234234
updated_at_as_valid_from,
235235
execution_time_as_valid_from,

0 commit comments

Comments (0)