Skip to content

Commit 765b737

Browse files
authored
Fix: Create a temporary view to source the model schema instead of parsing the Redshift plan (#2308)
1 parent c1ce829 commit 765b737

File tree

4 files changed

+52
-882
lines changed

4 files changed

+52
-882
lines changed

sqlmesh/core/engine_adapter/base_postgres.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,11 @@ def columns(
3535
sql = (
3636
exp.select("column_name", "data_type")
3737
.from_(self.COLUMNS_TABLE)
38-
.where(
39-
f"table_name = '{table.alias_or_name}' AND table_schema = '{table.args['db'].name}'"
40-
)
38+
.where(f"table_name = '{table.alias_or_name}'")
4139
)
40+
if table.args.get("db"):
41+
sql = sql.where(f"table_schema = '{table.args['db'].name}'")
42+
4243
self.execute(sql)
4344
resp = self.cursor.fetchall()
4445
if not resp:

sqlmesh/core/engine_adapter/redshift.py

Lines changed: 24 additions & 144 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
import pandas as pd
66
from sqlglot import exp
7-
from sqlglot.errors import ErrorLevel
87

98
from sqlmesh.core.dialect import to_schema
109
from sqlmesh.core.engine_adapter.base_postgres import BasePostgresEngineAdapter
@@ -20,7 +19,7 @@
2019
SourceQuery,
2120
set_catalog,
2221
)
23-
from sqlmesh.utils.errors import SQLMeshError
22+
from sqlmesh.utils import random_id
2423

2524
if t.TYPE_CHECKING:
2625
from sqlmesh.core._typing import SchemaName, TableName
@@ -121,68 +120,36 @@ def _build_create_table_exp(
121120
and statement.expression.args.get("limit") is not None
122121
and statement.expression.args["limit"].expression.this == "0"
123122
):
123+
assert not isinstance(table_name_or_schema, exp.Schema)
124124
# redshift has a bug where CTAS statements have non deterministic types. if a limit
125125
# is applied to a ctas statement, VARCHAR types default to 1 in some instances.
126126
# this checks the explain plan from redshift and tries to detect when these optimizer
127127
# bugs occur and force a cast
128-
explain_statement = statement.copy()
129-
for select in explain_statement.find_all(exp.Select):
130-
if select.args.get("from"):
131-
select.set("limit", None)
132-
select.set("where", None)
128+
select_statement = statement.expression.copy()
129+
for select_or_union in select_statement.find_all(exp.Select, exp.Union):
130+
select_or_union.set("limit", None)
131+
select_or_union.set("where", None)
132+
133+
temp_view_name = exp.table_(f"#sqlmesh__{random_id()}")
134+
self.create_view(
135+
temp_view_name, select_statement, replace=False, no_schema_binding=False
136+
)
137+
columns_to_types_from_view = self.columns(temp_view_name)
133138

134-
explain_statement_sql = explain_statement.sql(
135-
dialect=self.dialect, identify=True, unsupported_level=ErrorLevel.IGNORE, copy=False
139+
schema = self._build_schema_exp(
140+
exp.to_table(table_name_or_schema),
141+
columns_to_types_from_view,
136142
)
137-
plan = parse_plan(
138-
"\n".join(r[0] for r in self.fetchall(f"EXPLAIN VERBOSE {explain_statement_sql}"))
143+
statement = super()._build_create_table_exp(
144+
schema,
145+
None,
146+
exists=exists,
147+
replace=replace,
148+
columns_to_types=columns_to_types_from_view,
149+
table_description=table_description,
150+
**kwargs,
139151
)
140152

141-
if plan:
142-
select = exp.Select().from_(statement.expression.subquery("_subquery"))
143-
statement.expression.replace(select)
144-
145-
for target in plan["targetlist"]: # type: ignore
146-
if target["name"] == "TARGETENTRY":
147-
resdom = target["resdom"]
148-
resname = resdom["resname"]
149-
if resname == "<>":
150-
# A synthetic column added by Redshift to compute a window function.
151-
continue
152-
if resname == "? column ?":
153-
table_name_str = (
154-
table_name_or_schema
155-
if isinstance(table_name_or_schema, str)
156-
else table_name_or_schema.sql(dialect=self.dialect)
157-
)
158-
raise SQLMeshError(f"Missing column name for table '{table_name_str}'")
159-
# https://github.com/postgres/postgres/blob/master/src/include/catalog/pg_type.dat
160-
restype = resdom["restype"]
161-
data_type: t.Optional[str] = None
162-
if restype == "1043":
163-
size = (
164-
int(resdom["restypmod"]) - 4
165-
if resdom["restypmod"] != "- 1"
166-
else "MAX"
167-
)
168-
# Cast NULL instead of the original projection to trick the planner into assigning a
169-
# correct type to the column.
170-
data_type = f"VARCHAR({size})"
171-
else:
172-
data_type = REDSHIFT_PLAN_TYPE_MAPPINGS.get(restype)
173-
174-
if data_type:
175-
select.select(
176-
exp.cast(
177-
exp.null(),
178-
data_type,
179-
dialect=self.dialect,
180-
).as_(resname),
181-
copy=False,
182-
)
183-
else:
184-
select.select(resname, copy=False)
185-
186153
return statement
187154

188155
def create_view(
@@ -209,7 +176,7 @@ def create_view(
209176
materialized,
210177
table_description=table_description,
211178
column_descriptions=column_descriptions,
212-
no_schema_binding=True,
179+
no_schema_binding=create_kwargs.pop("no_schema_binding", True),
213180
**create_kwargs,
214181
)
215182

@@ -309,90 +276,3 @@ def _get_data_objects(
309276
)
310277
for row in df.itertuples()
311278
]
312-
313-
314-
def parse_plan(plan: str) -> t.Optional[t.Dict]:
315-
"""Parse the output of a redshift explain verbose query plan into a Python dict."""
316-
from sqlglot import Tokenizer, TokenType
317-
from sqlglot.tokens import Token
318-
319-
tokens = Tokenizer().tokenize(plan)
320-
i = 0
321-
terminal_tokens = {TokenType.L_PAREN, TokenType.R_PAREN, TokenType.R_BRACE, TokenType.COLON}
322-
323-
def curr() -> t.Optional[TokenType]:
324-
return tokens[i].token_type if i < len(tokens) else None
325-
326-
def advance() -> Token:
327-
nonlocal i
328-
i += 1
329-
return tokens[i - 1]
330-
331-
def match(token_type: TokenType, raise_unmatched: bool = False) -> t.Optional[Token]:
332-
if curr() == token_type:
333-
return advance()
334-
if raise_unmatched:
335-
raise Exception(f"Expected {token_type}")
336-
return None
337-
338-
def parse_value() -> t.Any:
339-
if match(TokenType.L_PAREN):
340-
values = []
341-
while not match(TokenType.R_PAREN):
342-
values.append(parse_value())
343-
return values
344-
345-
nested = parse_nested()
346-
347-
if nested:
348-
return nested
349-
350-
value = []
351-
352-
while not curr() in terminal_tokens:
353-
value.append(advance().text)
354-
355-
return " ".join(value)
356-
357-
def parse_nested() -> t.Optional[t.Dict]:
358-
if not match(TokenType.L_BRACE):
359-
return None
360-
query_plan = {}
361-
query_plan["name"] = advance().text
362-
363-
while match(TokenType.COLON):
364-
key = advance().text
365-
366-
while match(TokenType.DOT):
367-
key += f".{advance().text}"
368-
369-
query_plan[key] = parse_value()
370-
371-
match(TokenType.R_BRACE, True)
372-
return query_plan
373-
374-
while curr():
375-
nested = parse_nested()
376-
377-
if nested and nested.get("name") in ("RESULT", "SEQSCAN"):
378-
return nested
379-
advance()
380-
return None
381-
382-
383-
# https://github.com/postgres/postgres/blob/master/src/include/catalog/pg_type.dat
384-
REDSHIFT_PLAN_TYPE_MAPPINGS = {
385-
"16": "BOOL",
386-
"18": "CHAR",
387-
"21": "SMALLINT",
388-
"23": "INT",
389-
"20": "BIGINT",
390-
"1700": "NUMERIC",
391-
"700": "FLOAT",
392-
"701": "DOUBLE",
393-
"1114": "TIMESTAMP",
394-
"1184": "TIMESTAMPTZ",
395-
"1083": "TIME",
396-
"1266": "TIMETZ",
397-
"1082": "DATE",
398-
}

sqlmesh/core/snapshot/evaluator.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1094,10 +1094,12 @@ def create(
10941094
column_descriptions=model.column_descriptions if is_table_deployable else None,
10951095
)
10961096

1097-
# only sql models have queries that can be tested
1098-
# additionally, we always create temp tables and sometimes
1099-
# we additionally created prod tables, so we only need to test one.
1100-
if model.is_sql and not is_table_deployable:
1097+
# Only sql models have queries that can be tested.
1098+
# Additionally, we always create temp tables and sometimes we additionally create prod tables,
1099+
# we need to make sure that we only dry run once.
1100+
# We also need to make sure that we don't dry run on Redshift because its planner / optimizer sometimes
1101+
# breaks on our CTAS queries due to us relying on the WHERE FALSE LIMIT 0 combo.
1102+
if model.is_sql and not is_table_deployable and self.adapter.dialect != "redshift":
11011103
logger.info("Dry running model '%s'", model.name)
11021104
self.adapter.fetchall(ctas_query)
11031105
else:

0 commit comments

Comments
 (0)