diff --git a/sqlglot/typing/spark2.py b/sqlglot/typing/spark2.py index c734a740bb..9def31ec7a 100644 --- a/sqlglot/typing/spark2.py +++ b/sqlglot/typing/spark2.py @@ -12,33 +12,34 @@ from sqlglot.typing import ExprMetadataType -def _annotate_by_similar_args( - self: TypeAnnotator, expression: E, *args: str, target_type: exp.DataType | exp.DType -) -> E: +def _annotate_by_similar_args(self: TypeAnnotator, expression: E, *arg_keys: str) -> E: """ - Infers the type of the expression according to the following rules: - - If all args are of the same type OR any arg is of target_type, the expr is inferred as such - - If any arg is of UNKNOWN type and none of target_type, the expr is inferred as UNKNOWN - """ - expressions: list[exp.Expr] = [] - for arg in args: - arg_expr = expression.args.get(arg) - expressions.extend(expr for expr in ensure_list(arg_expr) if expr) + Type inference for CONCAT-family expressions (CONCAT, LPAD, RPAD). - last_datatype = None + - All-BINARY → BINARY (the binary overload). + - Otherwise, if any arg has a known, non-array, non-binary type → STRING. + Spark coerces scalars (dates, ints, etc.) to string when mixed with a + string-resolving arg. The binary exclusion preserves the binary+unknown + case as UNKNOWN: Spark can't disambiguate the string vs. binary overload + there. + - Else → UNKNOWN. Covers all-unknown, binary+unknown, and anything + involving arrays (array handling is intentionally out of scope here). + """ + arg_exprs: list[exp.Expression] = [] + for key in arg_keys: + arg_exprs.extend(e for e in ensure_list(expression.args.get(key)) if e) - has_unknown = False - for expr in expressions: - if expr.is_type(exp.DType.UNKNOWN): - has_unknown = True - elif expr.is_type(target_type): - has_unknown = False - last_datatype = target_type - break - else: - last_datatype = expr.type + if arg_exprs and all(e.is_type(exp.DType.BINARY) for e in arg_exprs): + result: exp.DataType | exp.DType = exp.DType.BINARY + elif any( + e.type is not None and not e.is_type(exp.DType.UNKNOWN, exp.DType.ARRAY, exp.DType.BINARY) + for e in arg_exprs + ): + result = exp.DType.TEXT + else: + result = exp.DType.UNKNOWN - self._set_type(expression, exp.DType.UNKNOWN if has_unknown else last_datatype) + self._set_type(expression, result) return expression @@ -72,15 +73,9 @@ def _annotate_by_similar_args( ) }, exp.AtTimeZone: {"returns": exp.DType.TIMESTAMP}, - exp.Concat: { - "annotator": lambda self, e: _annotate_by_similar_args( - self, e, "expressions", target_type=exp.DType.TEXT - ) - }, + exp.Concat: {"annotator": lambda self, e: _annotate_by_similar_args(self, e, "expressions")}, exp.NextDay: {"returns": exp.DType.DATE}, exp.Pad: { - "annotator": lambda self, e: _annotate_by_similar_args( - self, e, "this", "fill_pattern", target_type=exp.DType.TEXT - ) + "annotator": lambda self, e: _annotate_by_similar_args(self, e, "this", "fill_pattern") }, } diff --git a/tests/fixtures/optimizer/annotate_functions.sql b/tests/fixtures/optimizer/annotate_functions.sql index 8799b883fc..d71e13bc1a 100644 --- a/tests/fixtures/optimizer/annotate_functions.sql +++ b/tests/fixtures/optimizer/annotate_functions.sql @@ -257,6 +257,26 @@ UNKNOWN; CONCAT(unknown, unknown); UNKNOWN; +# dialect: spark2, spark, databricks +CONCAT('x', tbl.date_col); +STRING; + +# dialect: spark2, spark, databricks +CONCAT(tbl.date_col, tbl.date_col); +STRING; + +# dialect: spark2, spark, databricks +CONCAT('x', tbl.bin_col); +STRING; + +# dialect: spark2, spark, databricks +CONCAT(tbl.date_col, tbl.int_col); +STRING; + +# dialect: spark2, spark, databricks +LPAD('x', 10, tbl.date_col); +STRING; + # dialect: spark2, spark, databricks LPAD(tbl.bin_col, 1, tbl.bin_col); BINARY;