Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sqlglot-integration-tests
3 changes: 3 additions & 0 deletions sqlglot/expressions/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -1758,6 +1758,9 @@ class Pivot(Expression):
"default_on_null": False,
"into": False,
"with_": False,
"identify_pivot_strings": False,
"prefixed_pivot_columns": False,
"pivot_column_naming": False,
}

@property
Expand Down
77 changes: 77 additions & 0 deletions sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2459,6 +2459,77 @@ def tablesample_sql(

return f" {tablesample_keyword or self.TABLESAMPLE_KEYWORDS} {method}{expr}{seed}"

def _pivot_in_value_aliases(self, expression: exp.Pivot) -> list[exp.Expression] | None:
Comment thread
georgesittas marked this conversation as resolved.
# Returns the rewritten field.expressions list with PivotAlias wrappers injected where
# the stored column name differs from the target dialect's natural output.
columns = expression.args.get("columns")
if not columns or len(expression.fields) != 1:
return None

parser_cls = self.dialect.parser_class
# if the source and target emit identical values, exit early
if (
expression.args.get("identify_pivot_strings") == parser_cls.IDENTIFY_PIVOT_STRINGS
and expression.args.get("prefixed_pivot_columns") == parser_cls.PREFIXED_PIVOT_COLUMNS
and expression.args.get("pivot_column_naming") == parser_cls.PIVOT_COLUMN_NAMING
):
return None

in_exprs = expression.fields[0].expressions
step = len(columns) // len(in_exprs)

# Derive the per-value suffix from the first stored column vs the first IN-list value.
# This correctly handles dialects (e.g. Spark single-agg) that ignore agg aliases.
source_identify = expression.args.get("identify_pivot_strings", False)
first_base = in_exprs[0].sql() if source_identify else in_exprs[0].alias_or_name
first_stored = columns[0].name

# exit if only suffix matches, not prefix. (e.g. BigQuery, which cannot be fixed)
if not first_stored.lower().startswith(first_base.lower()):
# Should we emit an unsupported here?
return None
suffix = first_stored[len(first_base) :]

target_identify = parser_cls.IDENTIFY_PIVOT_STRINGS
target_naming = parser_cls.PIVOT_COLUMN_NAMING

# Whether the target dialect would append an agg-name suffix for this pivot.
# Spark single-agg uniquely drops the agg alias entirely.
target_has_suffix = (len(expression.expressions) > 1 or target_naming != "spark") and any(
getattr(a, "alias", None) for a in expression.expressions
)
source_has_suffix = suffix != ""

new_exprs: list[exp.Expression] = []
modified = False
for val_idx, e in enumerate(in_exprs):
i = val_idx * step
stored_full = columns[i].name
stored_value = stored_full[: -len(suffix)] if suffix else stored_full
target_value = e.sql() if target_identify else e.alias_or_name

if isinstance(e, exp.PivotAlias):
new_exprs.append(e)
continue

# Source had a suffix, target won't apply one (e.g. DuckDB→Spark single-agg
# aliased): inject the full stored column name as the IN-list alias so the
# target uses it verbatim as the column name.
if source_has_suffix and not target_has_suffix:
new_exprs.append(
exp.PivotAlias(this=e, alias=exp.to_identifier(stored_full, quoted=True))
)
modified = True
# Value-part mismatch (e.g. Snowflake's literal-style values vs others).
elif stored_value.lower() != target_value.lower():
new_exprs.append(
exp.PivotAlias(this=e, alias=exp.to_identifier(stored_value, quoted=True))
)
modified = True
Comment on lines +2519 to +2528
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is quoted set to True preemptively?

else:
new_exprs.append(e)
return new_exprs if modified else None

def pivot_sql(self, expression: exp.Pivot) -> str:
expressions = self.expressions(expression, flat=True)
direction = "UNPIVOT" if expression.unpivot else "PIVOT"
Expand All @@ -2478,6 +2549,12 @@ def pivot_sql(self, expression: exp.Pivot) -> str:
sql = f"{direction} {this}{on}{into}{using}{group}"
return self.prepend_ctes(expression, sql)

if not expression.unpivot:
# Wrap IN-list values with explicit aliases where the target dialect would differ
new_field_exprs = self._pivot_in_value_aliases(expression)
if new_field_exprs is not None:
expression.fields[0].set("expressions", new_field_exprs)

alias = self.sql(expression, "alias")
alias = f" AS {alias}" if alias else ""

Expand Down
4 changes: 4 additions & 0 deletions sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1755,6 +1755,7 @@ def _parse_partitioned_by_bucket_or_truncate(self) -> exp.Expr | None:

PREFIXED_PIVOT_COLUMNS: t.ClassVar = False
IDENTIFY_PIVOT_STRINGS: t.ClassVar = False
PIVOT_COLUMN_NAMING: t.ClassVar[str] = ""

LOG_DEFAULTS_TO_LN: t.ClassVar = False

Expand Down Expand Up @@ -5252,6 +5253,9 @@ def _parse_pivot(self) -> exp.Pivot | None:
columns.append(exp.to_identifier("_".join(fld_parts)))

pivot.set("columns", columns)
pivot.set("identify_pivot_strings", self.IDENTIFY_PIVOT_STRINGS)
pivot.set("prefixed_pivot_columns", self.PREFIXED_PIVOT_COLUMNS)
pivot.set("pivot_column_naming", self.PIVOT_COLUMN_NAMING)

return pivot

Expand Down
1 change: 1 addition & 0 deletions sqlglot/parsers/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def _convert_text_type(dtype: exp.DataType) -> exp.DataType:

class DuckDBParser(parser.Parser):
MAP_KEYS_ARE_ARBITRARY_EXPRESSIONS = True
PIVOT_COLUMN_NAMING = "duckdb"

NO_PAREN_FUNCTIONS = {
**parser.Parser.NO_PAREN_FUNCTIONS,
Expand Down
1 change: 1 addition & 0 deletions sqlglot/parsers/spark2.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def build_as_cast(to_type: str) -> t.Callable[[list], exp.Expr]:
class Spark2Parser(HiveParser):
TRIM_PATTERN_FIRST = True
CHANGE_COLUMN_ALTER_SYNTAX = True
PIVOT_COLUMN_NAMING = "spark"

FUNCTIONS = {
**HiveParser.FUNCTIONS,
Expand Down
5 changes: 5 additions & 0 deletions tests/dialects/test_duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -827,6 +827,11 @@ def test_duckdb(self):
"SELECT * FROM produce PIVOT(SUM(sales) FOR quarter IN ('Q1', 'Q2'))",
read={
"duckdb": "SELECT * FROM produce PIVOT(SUM(sales) FOR quarter IN ('Q1', 'Q2'))",
},
)
self.validate_all(
"SELECT * FROM produce PIVOT(SUM(sales) FOR quarter IN ('Q1' AS \"'Q1'\", 'Q2' AS \"'Q2'\"))",
read={
"snowflake": "SELECT * FROM produce PIVOT(SUM(produce.sales) FOR produce.quarter IN ('Q1', 'Q2'))",
},
)
Expand Down
6 changes: 3 additions & 3 deletions tests/dialects/test_spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,19 +712,19 @@ def test_spark(self):
},
)
self.validate_all(
"SELECT piv.Q1 FROM (SELECT * FROM produce PIVOT(SUM(sales) FOR quarter IN ('Q1', 'Q2'))) AS piv",
"SELECT piv.Q1 FROM (SELECT * FROM produce PIVOT(SUM(sales) FOR quarter IN ('Q1' AS `'Q1'`, 'Q2' AS `'Q2'`))) AS piv",
read={
"snowflake": "SELECT piv.Q1 FROM produce PIVOT(SUM(sales) FOR quarter IN ('Q1', 'Q2')) piv",
},
)
self.validate_all(
"SELECT piv.Q1 FROM (SELECT * FROM (SELECT * FROM produce) PIVOT(SUM(sales) FOR quarter IN ('Q1', 'Q2'))) AS piv",
"SELECT piv.Q1 FROM (SELECT * FROM (SELECT * FROM produce) PIVOT(SUM(sales) FOR quarter IN ('Q1' AS `'Q1'`, 'Q2' AS `'Q2'`))) AS piv",
read={
"snowflake": "SELECT piv.Q1 FROM (SELECT * FROM produce) PIVOT(SUM(sales) FOR quarter IN ('Q1', 'Q2')) piv",
},
)
self.validate_all(
"SELECT * FROM produce PIVOT(SUM(produce.sales) FOR quarter IN ('Q1', 'Q2'))",
"SELECT * FROM produce PIVOT(SUM(produce.sales) FOR quarter IN ('Q1' AS `'Q1'`, 'Q2' AS `'Q2'`))",
read={
"snowflake": "SELECT * FROM produce PIVOT (SUM(produce.sales) FOR produce.quarter IN ('Q1', 'Q2'))",
},
Expand Down
Loading