diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 009573d910..50c242056f 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -4400,10 +4400,13 @@ def _parse_column_as_identifier() -> exp.Expr | None: return self._parse_wrapped_csv(_parse_column_as_identifier, optional=True) def _parse_join( - self, skip_join_token: bool = False, parse_bracket: bool = False + self, + skip_join_token: bool = False, + parse_bracket: bool = False, + alias_tokens: t.Collection[TokenType] | None = None, ) -> exp.Join | None: if self._match(TokenType.COMMA): - table = self._try_parse(self._parse_table) + table = self._try_parse(lambda: self._parse_table(alias_tokens=alias_tokens)) cross_join = self.expression(exp.Join(this=table)) if table else None if cross_join and self.JOINS_HAVE_EQUAL_PRECEDENCE: @@ -4430,10 +4433,12 @@ def _parse_join( if not skip_join_token and not join and not outer_apply and not cross_apply: return None - kwargs: dict[str, t.Any] = {"this": self._parse_table(parse_bracket=parse_bracket)} + kwargs: dict[str, t.Any] = { + "this": self._parse_table(parse_bracket=parse_bracket, alias_tokens=alias_tokens) + } if kind and kind.token_type == TokenType.ARRAY and self._match(TokenType.COMMA): kwargs["expressions"] = self._parse_csv( - lambda: self._parse_table(parse_bracket=parse_bracket) + lambda: self._parse_table(parse_bracket=parse_bracket, alias_tokens=alias_tokens) ) if method: @@ -4459,7 +4464,7 @@ def _parse_join( and not (kind and kind.token_type in (TokenType.CROSS, TokenType.ARRAY)) ): index = self._index - joins: list | None = list(self._parse_joins()) + joins: list | None = list(self._parse_joins(alias_tokens=alias_tokens)) if joins and self._match(TokenType.ON): kwargs["on"] = self._parse_disjunction() @@ -4860,7 +4865,7 @@ def _parse_table( this.set("version", self._parse_version()) if joins: - for join in self._parse_joins(): + for join in self._parse_joins(alias_tokens=alias_tokens): this.append("joins", join) if self._match_pair(TokenType.WITH, TokenType.ORDINALITY): @@ -5056,8 +5061,10 @@ def _parse_pivots(self) -> list[exp.Pivot] | None: return None return list(iter(self._parse_pivot, None)) or None - def _parse_joins(self) -> t.Iterator[exp.Join]: - return iter(self._parse_join, None) + def _parse_joins( + self, alias_tokens: t.Collection[TokenType] | None = None + ) -> t.Iterator[exp.Join]: + return iter(lambda: self._parse_join(alias_tokens=alias_tokens), None) def _parse_unpivot_columns(self) -> exp.UnpivotColumns | None: if not self._match(TokenType.INTO): diff --git a/sqlglot/parsers/clickhouse.py b/sqlglot/parsers/clickhouse.py index 8d72bbfeeb..fbe22e4555 100644 --- a/sqlglot/parsers/clickhouse.py +++ b/sqlglot/parsers/clickhouse.py @@ -663,9 +663,14 @@ def _parse_join_parts( return is_global, side or kind, kind_pre or kind def _parse_join( - self, skip_join_token: bool = False, parse_bracket: bool = False + self, + skip_join_token: bool = False, + parse_bracket: bool = False, + alias_tokens: t.Collection[TokenType] | None = None, ) -> exp.Join | None: - join = super()._parse_join(skip_join_token=skip_join_token, parse_bracket=True) + join = super()._parse_join( + skip_join_token=skip_join_token, parse_bracket=True, alias_tokens=alias_tokens + ) if join: method = join.args.get("method") join.set("method", None) diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py index 3c9ac41ce2..b1eb4f455a 100644 --- a/tests/dialects/test_mysql.py +++ b/tests/dialects/test_mysql.py @@ -244,6 +244,13 @@ def test_ddl(self): self.validate_identity( "UPDATE foo JOIN bar ON TRUE SET foo.a = bar.a WHERE foo.id = bar.id" ) + self.validate_identity( + "UPDATE items, month SET items.price = month.price WHERE items.id = month.id" + ) + self.validate_identity("UPDATE a CROSS JOIN b SET a.x = 1") + self.validate_identity( + "UPDATE a, b LEFT JOIN c ON b.id = c.id SET a.x = 1, b.y = 2, c.z = 3" + ) # PARTITION BY RANGE - simple column self.validate_identity(