Skip to content
57 changes: 26 additions & 31 deletions sqlglot/typing/spark2.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,33 +12,34 @@
from sqlglot.typing import ExprMetadataType


def _annotate_by_similar_args(self: TypeAnnotator, expression: E, *arg_keys: str) -> E:
    """
    Type inference for CONCAT-family expressions (CONCAT, LPAD, RPAD).

    The arguments named by `arg_keys` are looked up in `expression.args`,
    flattened into a single list (missing or falsy entries are dropped), and
    the result type is chosen as follows:

    - All-BINARY → BINARY (the binary overload).
    - Otherwise, if any arg has a known, non-array, non-binary type → STRING.
      Spark coerces scalars (dates, ints, etc.) to string when mixed with a
      string-resolving arg. The binary exclusion preserves the binary+unknown
      case as UNKNOWN: Spark can't disambiguate the string vs. binary overload
      there.
    - Else → UNKNOWN. Covers no args at all, all-unknown, binary+unknown, and
      array args that are not accompanied by a known scalar type (array
      handling is intentionally out of scope here). Note that an array mixed
      with a known scalar still falls under the STRING rule above.

    Args:
        expression: The expression to annotate; its type is set in place.
        *arg_keys: Names of the entries in `expression.args` whose types drive
            the inference (each entry may be a single expression or a list).

    Returns:
        The same `expression`, after its type has been set.
    """
    arg_exprs: list[exp.Expression] = []
    for key in arg_keys:
        # An arg may hold a single node or a list of nodes; normalize and skip empties.
        arg_exprs.extend(e for e in ensure_list(expression.args.get(key)) if e)

    # `all()` is vacuously true on an empty list, so require at least one arg
    # before selecting the binary overload.
    if arg_exprs and all(e.is_type(exp.DType.BINARY) for e in arg_exprs):
        result: exp.DataType | exp.DType = exp.DType.BINARY
    elif any(
        # Args that were never annotated (type is None) are treated like UNKNOWN
        # and therefore cannot trigger the STRING result on their own.
        e.type is not None and not e.is_type(exp.DType.UNKNOWN, exp.DType.ARRAY, exp.DType.BINARY)
        for e in arg_exprs
    ):
        result = exp.DType.TEXT
    else:
        result = exp.DType.UNKNOWN

    self._set_type(expression, result)
    return expression


Expand Down Expand Up @@ -72,15 +73,9 @@ def _annotate_by_similar_args(
)
},
exp.AtTimeZone: {"returns": exp.DType.TIMESTAMP},
exp.Concat: {
"annotator": lambda self, e: _annotate_by_similar_args(
self, e, "expressions", target_type=exp.DType.TEXT
)
},
exp.Concat: {"annotator": lambda self, e: _annotate_by_similar_args(self, e, "expressions")},
exp.NextDay: {"returns": exp.DType.DATE},
exp.Pad: {
"annotator": lambda self, e: _annotate_by_similar_args(
self, e, "this", "fill_pattern", target_type=exp.DType.TEXT
)
"annotator": lambda self, e: _annotate_by_similar_args(self, e, "this", "fill_pattern")
},
}
20 changes: 20 additions & 0 deletions tests/fixtures/optimizer/annotate_functions.sql
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,26 @@ UNKNOWN;
CONCAT(unknown, unknown);
UNKNOWN;

# dialect: spark2, spark, databricks
CONCAT('x', tbl.date_col);
STRING;

# dialect: spark2, spark, databricks
CONCAT(tbl.date_col, tbl.date_col);
STRING;

# dialect: spark2, spark, databricks
CONCAT('x', tbl.bin_col);
STRING;

# dialect: spark2, spark, databricks
CONCAT(tbl.date_col, tbl.int_col);
STRING;

# dialect: spark2, spark, databricks
LPAD('x', 10, tbl.date_col);
STRING;

# dialect: spark2, spark, databricks
LPAD(tbl.bin_col, 1, tbl.bin_col);
BINARY;
Expand Down
Loading