Skip to content

Commit c6af34b

Browse files
authored
fix: db properly support with_log_level (#3799)
1 parent fd11150 commit c6af34b

File tree

2 files changed

+25
-13
lines changed

2 files changed

+25
-13
lines changed

sqlmesh/core/engine_adapter/databricks.py

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from sqlmesh.utils.errors import SQLMeshError, MissingDefaultCatalogError
2222

2323
if t.TYPE_CHECKING:
24-
from sqlmesh.core._typing import SchemaName, TableName
24+
from sqlmesh.core._typing import SchemaName, TableName, SessionProperties
2525
from sqlmesh.core.engine_adapter._typing import DF, PySparkSession, Query
2626

2727
logger = logging.getLogger(__name__)
@@ -48,11 +48,9 @@ class DatabricksEngineAdapter(SparkEngineAdapter):
4848
},
4949
)
5050

51-
def __init__(self, *args: t.Any, **kwargs: t.Any):
51+
def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
5252
super().__init__(*args, **kwargs)
5353
self._set_spark_engine_adapter_if_needed()
54-
# Set the default catalog for both connections to make sure they are aligned
55-
self.set_current_catalog(self.default_catalog) # type: ignore
5654

5755
@classmethod
5856
def can_access_spark_session(cls, disable_spark_session: bool) -> bool:
@@ -121,7 +119,8 @@ def _set_spark_engine_adapter_if_needed(self) -> None:
121119
DatabricksSession.builder.remote(**connect_kwargs).userAgent("sqlmesh").getOrCreate()
122120
)
123121
self._spark_engine_adapter = SparkEngineAdapter(
124-
partial(connection, spark=spark, catalog=catalog)
122+
partial(connection, spark=spark, catalog=catalog),
123+
default_catalog=catalog,
125124
)
126125

127126
@property
@@ -149,6 +148,11 @@ def spark(self) -> PySparkSession:
149148
def catalog_support(self) -> CatalogSupport:
150149
return CatalogSupport.FULL_SUPPORT
151150

151+
def _begin_session(self, properties: SessionProperties) -> t.Any:
152+
"""Begin a new session."""
153+
# Align the different possible connectors to a single catalog
154+
self.set_current_catalog(self.default_catalog) # type: ignore
155+
152156
def _end_session(self) -> None:
153157
self._connection_pool.set_attribute("use_spark_engine_adapter", False)
154158

@@ -181,7 +185,7 @@ def _fetch_native_df(
181185
"""Fetches a DataFrame that can be either Pandas or PySpark from the cursor"""
182186
if self.is_spark_session_connection:
183187
return super()._fetch_native_df(query, quote_identifiers=quote_identifiers)
184-
if self._use_spark_session:
188+
if self._spark_engine_adapter:
185189
return self._spark_engine_adapter._fetch_native_df( # type: ignore
186190
query, quote_identifiers=quote_identifiers
187191
)
@@ -211,6 +215,8 @@ def get_current_catalog(self) -> t.Optional[str]:
211215
pyspark_catalog = self._spark_engine_adapter.get_current_catalog()
212216
except (Py4JError, SparkConnectGrpcException):
213217
pass
218+
elif self.is_spark_session_connection:
219+
pyspark_catalog = self.connection.spark.catalog.currentCatalog()
214220
if not self.is_spark_session_connection:
215221
result = self.fetchone(exp.select(self.CURRENT_CATALOG_EXPRESSION))
216222
sql_connector_catalog = result[0] if result else None
@@ -221,20 +227,26 @@ def get_current_catalog(self) -> t.Optional[str]:
221227
return pyspark_catalog or sql_connector_catalog
222228

223229
def set_current_catalog(self, catalog_name: str) -> None:
224-
# Since Databricks splits commands across the Dataframe API and the SQL Connector
225-
# (depending if databricks-connect is installed and a Dataframe is used) we need to ensure both
226-
# are set to the same catalog since they maintain their default catalog separately
227-
self.execute(exp.Use(this=exp.to_identifier(catalog_name), kind="CATALOG"))
228-
if self._use_spark_session:
230+
def _set_spark_session_current_catalog(spark: PySparkSession) -> None:
229231
from py4j.protocol import Py4JError
230232
from pyspark.errors.exceptions.connect import SparkConnectGrpcException
231233

232234
try:
233235
# Note: Spark 3.4+ Only API
234-
self._spark_engine_adapter.set_current_catalog(catalog_name) # type: ignore
236+
spark.catalog.setCurrentCatalog(catalog_name)
235237
except (Py4JError, SparkConnectGrpcException):
236238
pass
237239

240+
# Since Databricks splits commands across the Dataframe API and the SQL Connector
241+
# (depending if databricks-connect is installed and a Dataframe is used) we need to ensure both
242+
# are set to the same catalog since they maintain their default catalog separately
243+
self.execute(exp.Use(this=exp.to_identifier(catalog_name), kind="CATALOG"))
244+
if self.is_spark_session_connection:
245+
_set_spark_session_current_catalog(self.connection.spark)
246+
247+
if self._spark_engine_adapter:
248+
_set_spark_session_current_catalog(self._spark_engine_adapter.spark)
249+
238250
def _get_data_objects(
239251
self, schema_name: SchemaName, object_names: t.Optional[t.Set[str]] = None
240252
) -> t.List[DataObject]:

tests/core/engine_adapter/test_databricks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def test_set_current_catalog(mocker: MockFixture, make_mocked_engine_adapter: t.
103103
adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="test_catalog")
104104
adapter.set_current_catalog("test_catalog2")
105105

106-
assert to_sql_calls(adapter) == ["USE CATALOG `test_catalog`", "USE CATALOG `test_catalog2`"]
106+
assert to_sql_calls(adapter) == ["USE CATALOG `test_catalog2`"]
107107

108108

109109
def test_get_current_catalog(mocker: MockFixture, make_mocked_engine_adapter: t.Callable):

0 commit comments

Comments (0)