2121from sqlmesh .utils .errors import SQLMeshError , MissingDefaultCatalogError
2222
2323if t .TYPE_CHECKING :
24- from sqlmesh .core ._typing import SchemaName , TableName
24+ from sqlmesh .core ._typing import SchemaName , TableName , SessionProperties
2525 from sqlmesh .core .engine_adapter ._typing import DF , PySparkSession , Query
2626
2727logger = logging .getLogger (__name__ )
@@ -48,11 +48,9 @@ class DatabricksEngineAdapter(SparkEngineAdapter):
4848 },
4949 )
5050
def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
    """Build the adapter, then set up the Spark engine adapter if needed.

    Every positional and keyword argument is forwarded untouched to the
    parent constructor. No catalog is selected here; catalog alignment is
    performed by ``_begin_session``.
    """
    super().__init__(*args, **kwargs)
    self._set_spark_engine_adapter_if_needed()
5654
5755 @classmethod
5856 def can_access_spark_session (cls , disable_spark_session : bool ) -> bool :
@@ -121,7 +119,8 @@ def _set_spark_engine_adapter_if_needed(self) -> None:
121119 DatabricksSession .builder .remote (** connect_kwargs ).userAgent ("sqlmesh" ).getOrCreate ()
122120 )
123121 self ._spark_engine_adapter = SparkEngineAdapter (
124- partial (connection , spark = spark , catalog = catalog )
122+ partial (connection , spark = spark , catalog = catalog ),
123+ default_catalog = catalog ,
125124 )
126125
127126 @property
@@ -149,6 +148,11 @@ def spark(self) -> PySparkSession:
def catalog_support(self) -> CatalogSupport:
    """Return the catalog support level of this adapter (always ``FULL_SUPPORT``)."""
    return CatalogSupport.FULL_SUPPORT
151150
151+ def _begin_session (self , properties : SessionProperties ) -> t .Any :
152+ """Begin a new session."""
153+ # Align the different possible connectors to a single catalog
154+ self .set_current_catalog (self .default_catalog ) # type: ignore
155+
152156 def _end_session (self ) -> None :
153157 self ._connection_pool .set_attribute ("use_spark_engine_adapter" , False )
154158
@@ -181,7 +185,7 @@ def _fetch_native_df(
181185 """Fetches a DataFrame that can be either Pandas or PySpark from the cursor"""
182186 if self .is_spark_session_connection :
183187 return super ()._fetch_native_df (query , quote_identifiers = quote_identifiers )
184- if self ._use_spark_session :
188+ if self ._spark_engine_adapter :
185189 return self ._spark_engine_adapter ._fetch_native_df ( # type: ignore
186190 query , quote_identifiers = quote_identifiers
187191 )
@@ -211,6 +215,8 @@ def get_current_catalog(self) -> t.Optional[str]:
211215 pyspark_catalog = self ._spark_engine_adapter .get_current_catalog ()
212216 except (Py4JError , SparkConnectGrpcException ):
213217 pass
218+ elif self .is_spark_session_connection :
219+ pyspark_catalog = self .connection .spark .catalog .currentCatalog ()
214220 if not self .is_spark_session_connection :
215221 result = self .fetchone (exp .select (self .CURRENT_CATALOG_EXPRESSION ))
216222 sql_connector_catalog = result [0 ] if result else None
@@ -221,20 +227,26 @@ def get_current_catalog(self) -> t.Optional[str]:
221227 return pyspark_catalog or sql_connector_catalog
222228
def set_current_catalog(self, catalog_name: str) -> None:
    """Switch every connection used by this adapter to ``catalog_name``.

    Since Databricks splits commands across the Dataframe API and the SQL
    Connector (depending if databricks-connect is installed and a Dataframe
    is used), we need to ensure both are set to the same catalog since they
    maintain their default catalog separately.

    Args:
        catalog_name: Name of the catalog to make current.
    """

    def _set_spark_session_current_catalog(spark: PySparkSession) -> None:
        # Imported inside the helper, so these packages are only loaded when
        # a Spark session is actually being updated.
        from py4j.protocol import Py4JError
        from pyspark.errors.exceptions.connect import SparkConnectGrpcException

        try:
            # Note: Spark 3.4+ Only API
            spark.catalog.setCurrentCatalog(catalog_name)
        except (Py4JError, SparkConnectGrpcException):
            # Still best effort, but leave a trace so a catalog mismatch
            # between the connections can be diagnosed.
            logger.debug(
                "Unable to set the Spark session catalog to '%s'",
                catalog_name,
                exc_info=True,
            )

    # The SQL statement is always issued first, on the adapter's own connection.
    self.execute(exp.Use(this=exp.to_identifier(catalog_name), kind="CATALOG"))

    if self.is_spark_session_connection:
        _set_spark_session_current_catalog(self.connection.spark)

    if self._spark_engine_adapter:
        _set_spark_session_current_catalog(self._spark_engine_adapter.spark)
249+
238250 def _get_data_objects (
239251 self , schema_name : SchemaName , object_names : t .Optional [t .Set [str ]] = None
240252 ) -> t .List [DataObject ]:
0 commit comments