Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions sqlglot/parsers/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,16 @@ class DatabricksParser(SparkParser):
"CURDATE": lambda self: self._parse_curdate(),
}

FUNCTION_PARSERS = {
**SparkParser.FUNCTION_PARSERS,
"REGR_AVGX": lambda self: self._parse_regr(exp.RegrAvgx),
"REGR_AVGY": lambda self: self._parse_regr(exp.RegrAvgy),
"REGR_COUNT": lambda self: self._parse_regr(exp.RegrCount),
"REGR_INTERCEPT": lambda self: self._parse_regr(exp.RegrIntercept),
"REGR_R2": lambda self: self._parse_regr(exp.RegrR2),
"REGR_SLOPE": lambda self: self._parse_regr(exp.RegrSlope),
Comment on lines +35 to +41

@geooo109 geooo109 Jul 3, 2026

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you verify where the DISTINCT is applied on for each function of the REGR list ? (on 1-arg or on both args)

For example in REGR_AVGX , REGR_AVGY as it seems the distinct is applied on 1-arg (x and y respectively). On the other hand, forREGR_COUNT distinct is applied on both args (as a tuple). So, the parsing function should seperate the args based on this ^ and not seperate it for all the functions in the REGR_ list.

So, let's verify each function and parse accordingly.

}

FACTOR = {
**SparkParser.FACTOR,
TokenType.COLON: exp.JSONExtract,
Expand Down Expand Up @@ -59,3 +69,13 @@ def _parse_cluster_property(self):
if self._match_texts(("AUTO", "NONE")):
return self.expression(exp.ClusterProperty(this=self._prev.text.upper()))
return super()._parse_cluster_property()

def _parse_regr(self, expr_type: type[exp.AggFunc]) -> exp.AggFunc:

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks pretty similar to _parse_quantile_function of hive right ?

args: list[exp.Expr] = []
if self._match(TokenType.DISTINCT):
args.append(self.expression(exp.Distinct(expressions=[self._parse_lambda()])))
self._match(TokenType.COMMA)
else:
self._match(TokenType.ALL)
args.extend(self._parse_function_args())
return self.expression(expr_type(this=seq_get(args, 0), expression=seq_get(args, 1)))
11 changes: 11 additions & 0 deletions sqlglot/typing/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,23 @@

EXPRESSION_METADATA = {
**EXPRESSION_METADATA,
**{
exp_type: {"returns": exp.DType.DOUBLE}
for exp_type in {
exp.RegrAvgx,
exp.RegrAvgy,
exp.RegrIntercept,
exp.RegrR2,
exp.RegrSlope,
}
},
**{
exp_type: {"returns": exp.DType.INT}
for exp_type in {
exp.RegexpCount,
}
},
exp.RegrCount: {"returns": exp.DType.BIGINT},
exp.RegexpExtractAll: {
"annotator": lambda self, e: self._set_type(
e, exp.DataType.from_str("ARRAY<STRING>", dialect="databricks")
Expand Down
104 changes: 104 additions & 0 deletions tests/fixtures/optimizer/annotate_functions.sql
Original file line number Diff line number Diff line change
Expand Up @@ -4219,6 +4219,10 @@ DOUBLE;
REGR_AVGX(tbl.decfloat_col, tbl.decfloat_col);
DECFLOAT;

# dialect: databricks
REGR_AVGX(DISTINCT tbl.double_col, tbl.double_col);
DOUBLE;

# dialect: snowflake
REGR_AVGY(tbl.double_col, tbl.double_col);
DOUBLE;
Expand All @@ -4231,6 +4235,26 @@ DOUBLE;
REGR_AVGY(tbl.decfloat_col, tbl.decfloat_col);
DECFLOAT;

# dialect: databricks
REGR_AVGY(tbl.double_col, tbl.double_col);
DOUBLE;

# dialect: databricks
REGR_AVGY(tbl.int_col, tbl.int_col);
DOUBLE;
Comment thread
geooo109 marked this conversation as resolved.

# dialect: databricks
REGR_AVGY(ALL tbl.double_col, tbl.double_col);
DOUBLE;

# dialect: databricks
REGR_AVGY(DISTINCT tbl.double_col, tbl.double_col);
DOUBLE;

# dialect: databricks
REGR_AVGY(tbl.double_col, tbl.double_col) OVER (PARTITION BY 1);
DOUBLE;

# dialect: snowflake
REGR_COUNT(tbl.double_col, tbl.double_col);
DOUBLE;
Expand All @@ -4247,6 +4271,26 @@ DOUBLE;
REGR_COUNT(tbl.decfloat_col, tbl.decfloat_col);
DECFLOAT;

# dialect: databricks
REGR_COUNT(tbl.double_col, tbl.double_col);
BIGINT;

# dialect: databricks
REGR_COUNT(tbl.int_col, tbl.int_col);
BIGINT;
Comment thread
geooo109 marked this conversation as resolved.

# dialect: databricks
REGR_COUNT(ALL tbl.double_col, tbl.double_col);
BIGINT;

# dialect: databricks
REGR_COUNT(DISTINCT tbl.double_col, tbl.double_col);
BIGINT;

# dialect: databricks
REGR_COUNT(tbl.double_col, tbl.double_col) OVER (PARTITION BY 1);
BIGINT;

# dialect: snowflake
REGR_INTERCEPT(tbl.double_col, tbl.double_col);
DOUBLE;
Expand All @@ -4263,6 +4307,26 @@ DOUBLE;
REGR_INTERCEPT(tbl.decfloat_col, tbl.decfloat_col);
DECFLOAT;

# dialect: databricks
REGR_INTERCEPT(tbl.double_col, tbl.double_col);
DOUBLE;

# dialect: databricks
REGR_INTERCEPT(tbl.int_col, tbl.int_col);
DOUBLE;
Comment thread
geooo109 marked this conversation as resolved.

# dialect: databricks
REGR_INTERCEPT(ALL tbl.double_col, tbl.double_col);
DOUBLE;

# dialect: databricks
REGR_INTERCEPT(DISTINCT tbl.double_col, tbl.double_col);
DOUBLE;

# dialect: databricks
REGR_INTERCEPT(tbl.double_col, tbl.double_col) OVER (PARTITION BY 1);
DOUBLE;

# dialect: snowflake
REGR_R2(tbl.double_col, tbl.double_col);
DOUBLE;
Expand All @@ -4279,6 +4343,26 @@ DOUBLE;
REGR_R2(tbl.decfloat_col, tbl.decfloat_col);
DECFLOAT;

# dialect: databricks
REGR_R2(tbl.double_col, tbl.double_col);
DOUBLE;

# dialect: databricks
REGR_R2(tbl.int_col, tbl.int_col);
DOUBLE;
Comment thread
geooo109 marked this conversation as resolved.

# dialect: databricks
REGR_R2(ALL tbl.double_col, tbl.double_col);
DOUBLE;

# dialect: databricks
REGR_R2(DISTINCT tbl.double_col, tbl.double_col);
DOUBLE;

# dialect: databricks
REGR_R2(tbl.double_col, tbl.double_col) OVER (PARTITION BY 1);
DOUBLE;

# dialect: snowflake
REGR_SXX(tbl.double_col, tbl.double_col);
DOUBLE;
Expand Down Expand Up @@ -4343,6 +4427,26 @@ DOUBLE;
REGR_SLOPE(tbl.decfloat_col, tbl.decfloat_col);
DECFLOAT;

# dialect: databricks
REGR_SLOPE(tbl.double_col, tbl.double_col);
DOUBLE;

# dialect: databricks
REGR_SLOPE(tbl.int_col, tbl.int_col);
DOUBLE;
Comment thread
geooo109 marked this conversation as resolved.

# dialect: databricks
REGR_SLOPE(ALL tbl.double_col, tbl.double_col);
DOUBLE;

# dialect: databricks
REGR_SLOPE(DISTINCT tbl.double_col, tbl.double_col);
DOUBLE;

# dialect: databricks
REGR_SLOPE(tbl.double_col, tbl.double_col) OVER (PARTITION BY 1);
DOUBLE;

# dialect: snowflake
REGR_VALX(NULL, 2.0);
DOUBLE;
Expand Down
Loading