diff --git a/sqlglot/dialects/databricks.py b/sqlglot/dialects/databricks.py
index 44291626d0..ce3b9c411f 100644
--- a/sqlglot/dialects/databricks.py
+++ b/sqlglot/dialects/databricks.py
@@ -9,6 +9,50 @@
 from sqlglot.parsers.databricks import DatabricksParser
 from sqlglot.tokens import TokenType
 from sqlglot.optimizer.annotate_types import TypeAnnotator
+from sqlglot.typing.spark import EXPRESSION_METADATA as SPARK_EXPRESSION_METADATA
+
+
+def _string_promotes(values: list[exp.Expression]) -> bool:
+    """
+    Report whether a least-common-type function string-promotes its value arguments.
+
+    Databricks resolves COALESCE/IF/CASE via Spark's findWiderCommonType string
+    promotion: when an argument is text and the rest are non-boolean/non-binary
+    atomics, the common type is text. boolean+text and binary+text have no common
+    type (query-time DATATYPE_MISMATCH), so those are reported as False and
+    deferred to numeric widening.
+    """
+    if not any(v.is_type(*exp.DataType.TEXT_TYPES) for v in values):
+        return False
+    return all(not v.is_type(exp.DType.BOOLEAN, exp.DType.BINARY) for v in values)
+
+
+def _annotate_coalesce(self: TypeAnnotator, e: exp.Coalesce) -> exp.Coalesce:
+    """Annotate COALESCE as TEXT if its arguments string-promote, else by its args."""
+    candidates = [arg for arg in (e.this, *e.expressions) if arg]
+    if not _string_promotes(candidates):
+        return self._annotate_by_args(e, "this", "expressions", promote=True)
+    self._set_type(e, exp.DType.TEXT)
+    return e
+
+
+def _annotate_if(self: TypeAnnotator, e: exp.If) -> exp.If:
+    """Annotate IF as TEXT if its true/false branches string-promote, else by its args."""
+    branches = [arm for arm in (e.args.get("true"), e.args.get("false")) if arm]
+    if not _string_promotes(branches):
+        return self._annotate_by_args(e, "true", "false", promote=True)
+    self._set_type(e, exp.DType.TEXT)
+    return e
+
+
+def _annotate_case(self: TypeAnnotator, e: exp.Case) -> exp.Case:
+    """Annotate CASE as TEXT if its branch results and default string-promote, else by its args."""
+    thens = [branch.args["true"] for branch in e.args["ifs"]]
+    candidates = [value for value in (*thens, e.args.get("default")) if value]
+    if not _string_promotes(candidates):
+        return self._annotate_by_args(e, *thens, "default")
+    self._set_type(e, exp.DType.TEXT)
+    return e
 
 
 class Databricks(Spark):
@@ -25,6 +69,13 @@ class Databricks(Spark):
         exp.DType.INTERVAL,
     }
 
+
EXPRESSION_METADATA = { + **SPARK_EXPRESSION_METADATA, + exp.Coalesce: {"annotator": _annotate_coalesce}, + exp.If: {"annotator": _annotate_if}, + exp.Case: {"annotator": _annotate_case}, + } + class JSONPathTokenizer(Spark.JSONPathTokenizer): IDENTIFIERS = ["`", '"'] diff --git a/tests/fixtures/optimizer/annotate_functions.sql b/tests/fixtures/optimizer/annotate_functions.sql index 4487444e34..2cef6f271f 100644 --- a/tests/fixtures/optimizer/annotate_functions.sql +++ b/tests/fixtures/optimizer/annotate_functions.sql @@ -307,11 +307,11 @@ STRING; # dialect: databricks IF(cond, tbl.str_col, tbl.double_col); -DOUBLE; +STRING; # dialect: databricks IF(cond, tbl.double_col, tbl.str_col); -DOUBLE; +STRING; # dialect: hive, spark2, spark IF(cond, tbl.date_col, tbl.str_col); @@ -323,11 +323,11 @@ STRING; # dialect: databricks IF(cond, tbl.date_col, tbl.str_col); -DATE; +STRING; # dialect: databricks IF(cond, tbl.str_col, tbl.date_col); -DATE; +STRING; # dialect: hive, spark2, spark, databricks IF(cond, tbl.date_col, tbl.timestamp_col); @@ -371,19 +371,19 @@ STRING; # dialect: databricks COALESCE(tbl.str_col, tbl.bigint_col); -BIGINT; +STRING; # dialect: databricks COALESCE(tbl.bigint_col, tbl.str_col); -BIGINT; +STRING; # dialect: databricks COALESCE(tbl.str_col, NULL, tbl.bigint_col); -BIGINT; +STRING; # dialect: databricks COALESCE(tbl.bigint_col, NULL, tbl.str_col); -BIGINT; +STRING; # dialect: databricks COALESCE(tbl.bool_col, tbl.str_col); @@ -395,12 +395,28 @@ STRING; # dialect: databricks COALESCE(tbl.interval_col, tbl.str_col); -INTERVAL; +STRING; # dialect: databricks COALESCE(tbl.bin_col, tbl.str_col); BINARY; +# dialect: databricks +COALESCE(tbl.int_col, tbl.str_col); +STRING; + +# dialect: databricks +NVL(tbl.int_col, tbl.str_col); +STRING; + +# dialect: databricks +CASE WHEN cond THEN tbl.int_col ELSE tbl.str_col END; +STRING; + +# dialect: databricks +COALESCE(tbl.int_col, tbl.bigint_col); +BIGINT; + # dialect: spark, databricks LOCALTIMESTAMP(); 
TIMESTAMPNTZ;