Skip to content
Closed
1 change: 1 addition & 0 deletions NEXT_CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
### Internal Changes
* Replace the async-disabling mechanism on token refresh failure with a 1-minute retry backoff. Previously, a single failed async refresh would disable proactive token renewal until the token expired. Now, the SDK waits a short cooldown period and retries, improving resilience to transient errors.
* Extract `_resolve_profile` to simplify config file loading and improve `__settings__` error messages.
* Add `host_type` to `HostMetadata` and `HostType.from_api_value()` for normalizing host type strings from the discovery endpoint.

### API Changes
* Add `create_catalog()`, `create_synced_table()`, `delete_catalog()`, `delete_synced_table()`, `get_catalog()` and `get_synced_table()` methods for [w.postgres](https://databricks-sdk-py.readthedocs.io/en/latest/workspace/postgres/postgres.html) workspace-level service.
Expand Down
19 changes: 19 additions & 0 deletions databricks/sdk/client_types.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from enum import Enum
from typing import Optional


class HostType(Enum):
Expand All @@ -8,6 +9,24 @@ class HostType(Enum):
WORKSPACE = "workspace"
UNIFIED = "unified"

@staticmethod
def from_api_value(value: str) -> Optional["HostType"]:
"""Normalize a host_type string from the API to a HostType enum value.

Maps "workspace" -> WORKSPACE, "account" -> ACCOUNTS, "unified" -> UNIFIED.
Returns None for unrecognized or empty values.
"""
if not value:
return None
normalized = value.lower()
if normalized == "workspace":
return HostType.WORKSPACE
if normalized == "account":
return HostType.ACCOUNTS
if normalized == "unified":
return HostType.UNIFIED
return None


class ClientType(Enum):
"""Enum representing the type of client configuration."""
Expand Down
13 changes: 12 additions & 1 deletion databricks/sdk/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,7 @@ def __init__(
self._header_factory = None
self._inner = {}
self._user_agent_other_info = []
self._resolved_host_type = None
self._custom_headers = custom_headers or {}
if credentials_strategy and credentials_provider:
raise ValueError("When providing `credentials_strategy` field, `credential_provider` cannot be specified.")
Expand Down Expand Up @@ -655,10 +656,20 @@ def _resolve_host_metadata(self) -> None:
if not self.cloud and meta.cloud:
logger.debug(f"Resolved cloud from host metadata: {meta.cloud.value}")
self.cloud = meta.cloud
if self._resolved_host_type is None and meta.host_type:
resolved = HostType.from_api_value(meta.host_type)
if resolved is not None:
logger.debug(f"Resolved host_type from host metadata: {meta.host_type}")
self._resolved_host_type = resolved
if not self.token_audience and meta.token_federation_default_oidc_audiences:
audience = meta.token_federation_default_oidc_audiences[0]
logger.debug(
f"Resolved token_audience from host metadata token_federation_default_oidc_audiences: {audience}"
)
self.token_audience = audience
# Account hosts use account_id as the OIDC token audience instead of the token endpoint.
# This is a special case: when the metadata has no workspace_id, the host is acting as an
# account-level endpoint and the audience must be scoped to the account.
# TODO: Add explicit audience to the metadata discovery endpoint.
if not self.token_audience and not meta.workspace_id and self.account_id:
logger.debug(f"Setting token_audience to account_id for account host: {self.account_id}")
self.token_audience = self.account_id
Expand Down
6 changes: 6 additions & 0 deletions databricks/sdk/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,8 @@ class HostMetadata:
account_id: Optional[str] = None
workspace_id: Optional[str] = None
cloud: Optional[Cloud] = None
host_type: Optional[str] = None
token_federation_default_oidc_audiences: Optional[List[str]] = None

@staticmethod
def from_dict(d: dict) -> "HostMetadata":
Expand All @@ -456,6 +458,8 @@ def from_dict(d: dict) -> "HostMetadata":
account_id=d.get("account_id"),
workspace_id=d.get("workspace_id"),
cloud=Cloud.parse(d.get("cloud", "")),
host_type=d.get("host_type"),
token_federation_default_oidc_audiences=d.get("token_federation_default_oidc_audiences"),
)

def as_dict(self) -> dict:
Expand All @@ -464,6 +468,8 @@ def as_dict(self) -> dict:
"account_id": self.account_id,
"workspace_id": self.workspace_id,
"cloud": self.cloud.value if self.cloud else None,
"host_type": self.host_type,
"token_federation_default_oidc_audiences": self.token_federation_default_oidc_audiences,
}


Expand Down
190 changes: 190 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -973,3 +973,193 @@ def test_resolve_host_metadata_does_not_overwrite_token_audience(mocker):
token_audience="custom-audience",
)
assert config.token_audience == "custom-audience"


# ---------------------------------------------------------------------------
# HostType.from_api_value tests
# ---------------------------------------------------------------------------


@pytest.mark.parametrize(
    ("api_value", "expected"),
    [
        # Recognized names, in lower / capitalized / upper case.
        ("workspace", HostType.WORKSPACE),
        ("Workspace", HostType.WORKSPACE),
        ("WORKSPACE", HostType.WORKSPACE),
        ("account", HostType.ACCOUNTS),
        ("Account", HostType.ACCOUNTS),
        ("ACCOUNT", HostType.ACCOUNTS),
        ("unified", HostType.UNIFIED),
        ("Unified", HostType.UNIFIED),
        ("UNIFIED", HostType.UNIFIED),
        # Unrecognized, empty and missing values.
        ("unknown", None),
        ("", None),
        (None, None),
    ],
)
def test_host_type_from_api_value(api_value, expected):
    """HostType.from_api_value yields the expected member (or None) for each input."""
    actual = HostType.from_api_value(api_value)
    assert actual == expected


# ---------------------------------------------------------------------------
# HostMetadata.from_dict with host_type field
# ---------------------------------------------------------------------------


def test_host_metadata_from_dict_with_host_type():
    """A host_type key in the input dict is carried over onto the parsed HostMetadata."""
    parsed = HostMetadata.from_dict(
        {
            "oidc_endpoint": "https://host/oidc",
            "account_id": "acc-1",
            "host_type": "workspace",
        }
    )
    assert parsed.host_type == "workspace"


def test_host_metadata_from_dict_without_host_type():
    """HostMetadata.from_dict leaves host_type as None when the key is absent."""
    parsed = HostMetadata.from_dict({"oidc_endpoint": "https://host/oidc"})
    assert parsed.host_type is None


# ---------------------------------------------------------------------------
# _resolve_host_metadata populates _resolved_host_type
# ---------------------------------------------------------------------------


def test_resolve_host_metadata_populates_resolved_host_type(mocker):
    """A Config built against metadata reporting host_type "unified" ends up with HostType.UNIFIED."""
    discovered = HostMetadata.from_dict(
        {
            "oidc_endpoint": f"{_DUMMY_WS_HOST}/oidc",
            "host_type": "unified",
        }
    )
    mocker.patch("databricks.sdk.config.get_host_metadata", return_value=discovered)

    cfg = Config(host=_DUMMY_WS_HOST, token="t")

    assert cfg._resolved_host_type == HostType.UNIFIED


def test_resolve_host_metadata_does_not_overwrite_existing_resolved_host_type(mocker):
    """A _resolved_host_type that is already set survives another metadata resolution."""
    discovered = HostMetadata.from_dict(
        {
            "oidc_endpoint": f"{_DUMMY_WS_HOST}/oidc",
            "host_type": "account",
        }
    )
    mocker.patch("databricks.sdk.config.get_host_metadata", return_value=discovered)

    cfg = Config(host=_DUMMY_WS_HOST, token="t")
    # Pin a value that differs from what the mocked metadata reports, then
    # run the resolution step again.
    cfg._resolved_host_type = HostType.UNIFIED
    cfg._resolve_host_metadata()

    assert cfg._resolved_host_type == HostType.UNIFIED


def test_resolve_host_metadata_resolved_host_type_none_when_missing(mocker):
    """With no host_type in the mocked metadata, _resolved_host_type remains None."""
    discovered = HostMetadata.from_dict({"oidc_endpoint": f"{_DUMMY_WS_HOST}/oidc"})
    mocker.patch("databricks.sdk.config.get_host_metadata", return_value=discovered)

    cfg = Config(host=_DUMMY_WS_HOST, token="t")

    assert cfg._resolved_host_type is None


def test_resolve_host_metadata_resolved_host_type_unknown_string(mocker):
    """An unrecognized host_type string in the metadata leaves _resolved_host_type as None."""
    discovered = HostMetadata.from_dict(
        {
            "oidc_endpoint": f"{_DUMMY_WS_HOST}/oidc",
            "host_type": "some_future_type",
        }
    )
    mocker.patch("databricks.sdk.config.get_host_metadata", return_value=discovered)

    cfg = Config(host=_DUMMY_WS_HOST, token="t")

    assert cfg._resolved_host_type is None


# ---------------------------------------------------------------------------
# token_federation_default_oidc_audiences resolution from host metadata
# ---------------------------------------------------------------------------


def test_resolve_host_metadata_sets_token_audience_from_token_federation_default_oidc_audiences(mocker):
    """token_audience is taken from the metadata's token_federation_default_oidc_audiences entry."""
    discovered = HostMetadata.from_dict(
        {
            "oidc_endpoint": f"{_DUMMY_WS_HOST}/oidc",
            "account_id": _DUMMY_ACCOUNT_ID,
            "workspace_id": _DUMMY_WORKSPACE_ID,
            "token_federation_default_oidc_audiences": [f"{_DUMMY_WS_HOST}/oidc/v1/token"],
        }
    )
    mocker.patch("databricks.sdk.config.get_host_metadata", return_value=discovered)

    cfg = Config(host=_DUMMY_WS_HOST, token="t")

    assert cfg.token_audience == f"{_DUMMY_WS_HOST}/oidc/v1/token"


def test_resolve_host_metadata_token_federation_default_oidc_audiences_takes_priority_over_account_id_fallback(mocker):
    """The server-provided audience wins over the account_id fallback."""
    discovered = HostMetadata.from_dict(
        {
            "oidc_endpoint": f"{_DUMMY_ACC_HOST}/oidc/accounts/{_DUMMY_ACCOUNT_ID}",
            "account_id": _DUMMY_ACCOUNT_ID,
            "token_federation_default_oidc_audiences": ["custom-audience-from-server"],
        }
    )
    mocker.patch("databricks.sdk.config.get_host_metadata", return_value=discovered)

    cfg = Config(host=_DUMMY_ACC_HOST, token="t", account_id=_DUMMY_ACCOUNT_ID)

    # The metadata carries no workspace_id and the config has an account_id,
    # yet the audience must still be the one listed in the metadata.
    assert cfg.token_audience == "custom-audience-from-server"


def test_resolve_host_metadata_token_federation_default_oidc_audiences_does_not_override_existing_token_audience(
    mocker,
):
    """A token_audience passed to Config is kept even when the metadata lists default audiences."""
    discovered = HostMetadata.from_dict(
        {
            "oidc_endpoint": f"{_DUMMY_WS_HOST}/oidc",
            "account_id": _DUMMY_ACCOUNT_ID,
            "workspace_id": _DUMMY_WORKSPACE_ID,
            "token_federation_default_oidc_audiences": [f"{_DUMMY_WS_HOST}/oidc/v1/token"],
        }
    )
    mocker.patch("databricks.sdk.config.get_host_metadata", return_value=discovered)

    cfg = Config(host=_DUMMY_WS_HOST, token="t", token_audience="my-custom-audience")

    assert cfg.token_audience == "my-custom-audience"


def test_resolve_host_metadata_falls_back_to_account_id_when_no_token_federation_default_oidc_audiences(mocker):
    """Without default audiences or a workspace_id in the metadata, token_audience is the account_id."""
    discovered = HostMetadata.from_dict(
        {
            "oidc_endpoint": f"{_DUMMY_ACC_HOST}/oidc/accounts/{_DUMMY_ACCOUNT_ID}",
            "account_id": _DUMMY_ACCOUNT_ID,
        }
    )
    mocker.patch("databricks.sdk.config.get_host_metadata", return_value=discovered)

    cfg = Config(host=_DUMMY_ACC_HOST, token="t", account_id=_DUMMY_ACCOUNT_ID)

    # Neither token_federation_default_oidc_audiences nor workspace_id is
    # present, so the account_id fallback applies.
    assert cfg.token_audience == _DUMMY_ACCOUNT_ID
Loading