diff --git a/sqlglot/parsers/databricks.py b/sqlglot/parsers/databricks.py index 2805c1d933..345f5fc310 100644 --- a/sqlglot/parsers/databricks.py +++ b/sqlglot/parsers/databricks.py @@ -31,6 +31,16 @@ class DatabricksParser(SparkParser): "CURDATE": lambda self: self._parse_curdate(), } + FUNCTION_PARSERS = { + **SparkParser.FUNCTION_PARSERS, + "REGR_AVGX": lambda self: self._parse_regr(exp.RegrAvgx), + "REGR_AVGY": lambda self: self._parse_regr(exp.RegrAvgy), + "REGR_COUNT": lambda self: self._parse_regr(exp.RegrCount), + "REGR_INTERCEPT": lambda self: self._parse_regr(exp.RegrIntercept), + "REGR_R2": lambda self: self._parse_regr(exp.RegrR2), + "REGR_SLOPE": lambda self: self._parse_regr(exp.RegrSlope), + } + FACTOR = { **SparkParser.FACTOR, TokenType.COLON: exp.JSONExtract, @@ -59,3 +69,13 @@ def _parse_cluster_property(self): if self._match_texts(("AUTO", "NONE")): return self.expression(exp.ClusterProperty(this=self._prev.text.upper())) return super()._parse_cluster_property() + + def _parse_regr(self, expr_type: type[exp.AggFunc]) -> exp.AggFunc: + args: list[exp.Expr] = [] + if self._match(TokenType.DISTINCT): + args.append(self.expression(exp.Distinct(expressions=[self._parse_lambda()]))) + self._match(TokenType.COMMA) + else: + self._match(TokenType.ALL) + args.extend(self._parse_function_args()) + return self.expression(expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))) diff --git a/sqlglot/typing/databricks.py b/sqlglot/typing/databricks.py index aa0f35147d..77b6c3e181 100644 --- a/sqlglot/typing/databricks.py +++ b/sqlglot/typing/databricks.py @@ -5,12 +5,23 @@ EXPRESSION_METADATA = { **EXPRESSION_METADATA, + **{ + exp_type: {"returns": exp.DType.DOUBLE} + for exp_type in { + exp.RegrAvgx, + exp.RegrAvgy, + exp.RegrIntercept, + exp.RegrR2, + exp.RegrSlope, + } + }, **{ exp_type: {"returns": exp.DType.INT} for exp_type in { exp.RegexpCount, } }, + exp.RegrCount: {"returns": exp.DType.BIGINT}, exp.RegexpExtractAll: { "annotator": lambda self, e: self._set_type( e, exp.DataType.from_str("ARRAY", dialect="databricks") diff --git a/tests/fixtures/optimizer/annotate_functions.sql b/tests/fixtures/optimizer/annotate_functions.sql index 271d8fff9e..88d7b3349b 100644 --- a/tests/fixtures/optimizer/annotate_functions.sql +++ b/tests/fixtures/optimizer/annotate_functions.sql @@ -4219,6 +4219,10 @@ DOUBLE; REGR_AVGX(tbl.decfloat_col, tbl.decfloat_col); DECFLOAT; +# dialect: databricks +REGR_AVGX(DISTINCT tbl.double_col, tbl.double_col); +DOUBLE; + # dialect: snowflake REGR_AVGY(tbl.double_col, tbl.double_col); DOUBLE; @@ -4231,6 +4235,26 @@ DOUBLE; REGR_AVGY(tbl.decfloat_col, tbl.decfloat_col); DECFLOAT; +# dialect: databricks +REGR_AVGY(tbl.double_col, tbl.double_col); +DOUBLE; + +# dialect: databricks +REGR_AVGY(tbl.int_col, tbl.int_col); +DOUBLE; + +# dialect: databricks +REGR_AVGY(ALL tbl.double_col, tbl.double_col); +DOUBLE; + +# dialect: databricks +REGR_AVGY(DISTINCT tbl.double_col, tbl.double_col); +DOUBLE; + +# dialect: databricks +REGR_AVGY(tbl.double_col, tbl.double_col) OVER (PARTITION BY 1); +DOUBLE; + # dialect: snowflake REGR_COUNT(tbl.double_col, tbl.double_col); DOUBLE; @@ -4247,6 +4271,26 @@ DOUBLE; REGR_COUNT(tbl.decfloat_col, tbl.decfloat_col); DECFLOAT; +# dialect: databricks +REGR_COUNT(tbl.double_col, tbl.double_col); +BIGINT; + +# dialect: databricks +REGR_COUNT(tbl.int_col, tbl.int_col); +BIGINT; + +# dialect: databricks +REGR_COUNT(ALL tbl.double_col, tbl.double_col); +BIGINT; + +# dialect: databricks +REGR_COUNT(DISTINCT tbl.double_col, tbl.double_col); +BIGINT; + +# dialect: databricks +REGR_COUNT(tbl.double_col, tbl.double_col) OVER (PARTITION BY 1); +BIGINT; + # dialect: snowflake REGR_INTERCEPT(tbl.double_col, tbl.double_col); DOUBLE; @@ -4263,6 +4307,26 @@ DOUBLE; REGR_INTERCEPT(tbl.decfloat_col, tbl.decfloat_col); DECFLOAT; +# dialect: databricks +REGR_INTERCEPT(tbl.double_col, tbl.double_col); +DOUBLE; + +# dialect: databricks +REGR_INTERCEPT(tbl.int_col, tbl.int_col); +DOUBLE; + +# dialect: databricks +REGR_INTERCEPT(ALL tbl.double_col, tbl.double_col); +DOUBLE; + +# dialect: databricks +REGR_INTERCEPT(DISTINCT tbl.double_col, tbl.double_col); +DOUBLE; + +# dialect: databricks +REGR_INTERCEPT(tbl.double_col, tbl.double_col) OVER (PARTITION BY 1); +DOUBLE; + # dialect: snowflake REGR_R2(tbl.double_col, tbl.double_col); DOUBLE; @@ -4279,6 +4343,26 @@ DOUBLE; REGR_R2(tbl.decfloat_col, tbl.decfloat_col); DECFLOAT; +# dialect: databricks +REGR_R2(tbl.double_col, tbl.double_col); +DOUBLE; + +# dialect: databricks +REGR_R2(tbl.int_col, tbl.int_col); +DOUBLE; + +# dialect: databricks +REGR_R2(ALL tbl.double_col, tbl.double_col); +DOUBLE; + +# dialect: databricks +REGR_R2(DISTINCT tbl.double_col, tbl.double_col); +DOUBLE; + +# dialect: databricks +REGR_R2(tbl.double_col, tbl.double_col) OVER (PARTITION BY 1); +DOUBLE; + # dialect: snowflake REGR_SXX(tbl.double_col, tbl.double_col); DOUBLE; @@ -4343,6 +4427,26 @@ DOUBLE; REGR_SLOPE(tbl.decfloat_col, tbl.decfloat_col); DECFLOAT; +# dialect: databricks +REGR_SLOPE(tbl.double_col, tbl.double_col); +DOUBLE; + +# dialect: databricks +REGR_SLOPE(tbl.int_col, tbl.int_col); +DOUBLE; + +# dialect: databricks +REGR_SLOPE(ALL tbl.double_col, tbl.double_col); +DOUBLE; + +# dialect: databricks +REGR_SLOPE(DISTINCT tbl.double_col, tbl.double_col); +DOUBLE; + +# dialect: databricks +REGR_SLOPE(tbl.double_col, tbl.double_col) OVER (PARTITION BY 1); +DOUBLE; + # dialect: snowflake REGR_VALX(NULL, 2.0); DOUBLE;