diff --git a/deepnote_toolkit/sql/sql_execution.py b/deepnote_toolkit/sql/sql_execution.py index 77ea872..2710c06 100644 --- a/deepnote_toolkit/sql/sql_execution.py +++ b/deepnote_toolkit/sql/sql_execution.py @@ -117,6 +117,7 @@ class ExecuteSqlError(Exception): _handle_iam_params(sql_alchemy_dict) _handle_federated_auth_params(sql_alchemy_dict) + _handle_mssql_trusted_connection(sql_alchemy_dict) requires_bigquery_oauth = ( sql_alchemy_dict["url"] == "bigquery://?user_supplied_client=true" @@ -351,6 +352,38 @@ def _handle_federated_auth_params(sql_alchemy_dict: dict[str, Any]) -> None: ) +def _handle_mssql_trusted_connection(sql_alchemy_dict: dict[str, Any]) -> None: + """Handle MS SQL Server Windows Authentication (trusted connections) in-place. + + When 'mssql_trusted_connection' is True in params, removes username and password + from the connection URL to enable Windows Authentication via pymssql. + """ + params = sql_alchemy_dict.get("params", {}) + + if not params.get("mssql_trusted_connection"): + return + + url_obj = make_url(sql_alchemy_dict["url"]) + + if not url_obj.drivername.startswith("mssql"): + logger.warning( + "mssql_trusted_connection is only supported for MS SQL Server connections" + ) + return + + # Build new URL without username/password for Windows Authentication + new_url = URL.create( + drivername=url_obj.drivername, + host=url_obj.host, + port=url_obj.port, + database=url_obj.database, + query=url_obj.query, + ) + + sql_alchemy_dict["url"] = str(new_url) + del params["mssql_trusted_connection"] + + @contextlib.contextmanager def _create_sql_ssh_uri(ssh_enabled, sql_alchemy_dict): server = None diff --git a/tests/unit/test_sql_execution.py b/tests/unit/test_sql_execution.py index de46ba4..69b35f5 100644 --- a/tests/unit/test_sql_execution.py +++ b/tests/unit/test_sql_execution.py @@ -940,3 +940,88 @@ def test_federated_auth_params_bigquery_missing_params( # Verify the dict was not modified self.assertEqual(sql_alchemy_dict, original_dict) + + +class TestMssqlTrustedConnection(unittest.TestCase): + """Tests for MS SQL Server Windows Authentication (trusted connections).""" + + def test_trusted_connection_removes_credentials(self): + """Test that mssql_trusted_connection removes username and password from URL.""" + from deepnote_toolkit.sql.sql_execution import _handle_mssql_trusted_connection + + sql_alchemy_dict = { + "url": "mssql+pymssql://user:password@myserver:1433/mydb", + "params": {"mssql_trusted_connection": True}, + } + + _handle_mssql_trusted_connection(sql_alchemy_dict) + + # Verify credentials were removed from URL + self.assertIn("mssql+pymssql://", sql_alchemy_dict["url"]) + self.assertIn("myserver", sql_alchemy_dict["url"]) + self.assertIn("mydb", sql_alchemy_dict["url"]) + self.assertNotIn("user", sql_alchemy_dict["url"]) + self.assertNotIn("password", sql_alchemy_dict["url"]) + + # Verify the flag was removed from params + self.assertNotIn("mssql_trusted_connection", sql_alchemy_dict["params"]) + + def test_trusted_connection_preserves_query_params(self): + """Test that existing query parameters are preserved.""" + from deepnote_toolkit.sql.sql_execution import _handle_mssql_trusted_connection + + sql_alchemy_dict = { + "url": "mssql+pymssql://user:password@myserver/mydb?charset=utf8", + "params": {"mssql_trusted_connection": True}, + } + + _handle_mssql_trusted_connection(sql_alchemy_dict) + + # Verify query params are preserved + self.assertIn("charset=utf8", sql_alchemy_dict["url"]) + + def test_trusted_connection_not_enabled(self): + """Test that no action is taken when mssql_trusted_connection is not set.""" + from deepnote_toolkit.sql.sql_execution import _handle_mssql_trusted_connection + + sql_alchemy_dict = { + "url": "mssql+pymssql://user:password@myserver/mydb", + "params": {}, + } + original_url = sql_alchemy_dict["url"] + + _handle_mssql_trusted_connection(sql_alchemy_dict) + + self.assertEqual(sql_alchemy_dict["url"], original_url) + + def test_trusted_connection_false(self): + """Test that no action is taken when mssql_trusted_connection is False.""" + from deepnote_toolkit.sql.sql_execution import _handle_mssql_trusted_connection + + sql_alchemy_dict = { + "url": "mssql+pymssql://user:password@myserver/mydb", + "params": {"mssql_trusted_connection": False}, + } + original_url = sql_alchemy_dict["url"] + + _handle_mssql_trusted_connection(sql_alchemy_dict) + + self.assertEqual(sql_alchemy_dict["url"], original_url) + + @mock.patch("deepnote_toolkit.sql.sql_execution.logger") + def test_trusted_connection_non_mssql_url_warns(self, mock_logger): + """Test that a warning is logged for non-mssql URLs.""" + from deepnote_toolkit.sql.sql_execution import _handle_mssql_trusted_connection + + sql_alchemy_dict = { + "url": "postgresql://user:password@localhost/db", + "params": {"mssql_trusted_connection": True}, + } + original_url = sql_alchemy_dict["url"] + + _handle_mssql_trusted_connection(sql_alchemy_dict) + + mock_logger.warning.assert_called_once_with( + "mssql_trusted_connection is only supported for MS SQL Server connections" + ) + self.assertEqual(sql_alchemy_dict["url"], original_url)