From e135d94f1341333a64e1c4a20ebfbc4bd91b149a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A5kon=20V=2E=20Treider?= Date: Tue, 5 May 2026 10:18:42 +0200 Subject: [PATCH] fix(credentials): token cache reuse issue for OAuthDeviceCode (#2602) --- cognite/client/credentials.py | 140 +++++++++--------- tests/tests_unit/test_credential_providers.py | 53 ++++--- 2 files changed, 98 insertions(+), 95 deletions(-) diff --git a/cognite/client/credentials.py b/cognite/client/credentials.py index b11ae91475..fce5ba3c55 100644 --- a/cognite/client/credentials.py +++ b/cognite/client/credentials.py @@ -3,6 +3,7 @@ import atexit import inspect import json +import operator import tempfile import threading import time @@ -463,72 +464,78 @@ def _get_token(self, convert_timestamps: bool = True) -> dict[str, Any]: return token def _refresh_access_token(self) -> tuple[str, float]: - # First check if a token cache exists on disk. If yes, find and use: - # - A valid access token. - # - A valid refresh token, and if so, use it automatically to redeem a new access token. + # Token resolution order (cheapest option first): + # 1. Valid access token in cache → use directly, no network call + # 2. Refresh token in cache → exchange for new AT, one network call + # 3. Device code flow → interactive, requires user action credentials = None - for token in self.__app.token_cache.search(self.__app.token_cache.CredentialType.REFRESH_TOKEN): - if "expires_on" in token and token["expires_on"] > time.time(): - credentials = token + + # 1. Check for a still-valid access token. search() does NOT filter by expiry, + # so we check manually and respect the leeway to avoid handing out a near-expired token. + for token in self.__app.token_cache.search( + self.__app.token_cache.CredentialType.ACCESS_TOKEN, + query={"client_id": self.client_id}, + ): + expiry = int(token.get("expires_on", 0)) - time.time() - self.token_expiry_leeway_seconds + if expiry > 0: + credentials = {"access_token": token["secret"], "expires_in": expiry} break + + # 2. No valid AT — try to silently redeem a refresh token. + if credentials is None: + rt_entry = None + for token in self.__app.token_cache.search( + self.__app.token_cache.CredentialType.REFRESH_TOKEN, + query={"client_id": self.client_id}, + ): + rt_entry = token + break # MSAL RTs have no 'expires_on'; use the first found + + if rt_entry is not None: + # Pass the full RT cache entry (not just the secret string) so MSAL's + # on_removing_rt callback can properly remove it on invalid_grant. + # Exclude OIDC meta-scopes that are not valid in token-endpoint requests. + oidc_scopes = frozenset({"openid", "profile", "email", "offline_access"}) + resp = self.__app.client.obtain_token_by_refresh_token( + rt_entry, + rt_getter=operator.itemgetter("secret"), + scope=" ".join(s for s in self.__scopes if s not in oidc_scopes), + ) + if isinstance(resp, dict) and "error" not in resp: + credentials = resp + # else: RT rejected by server, fall through to device code flow + if credentials is not None: - credentials = self.__app.client.obtain_token_by_refresh_token(credentials.get("secret", "")) - else: - for token in self.__app.token_cache.search(self.__app.token_cache.CredentialType.ACCESS_TOKEN): - if expiry := int(token.get("expires_on", 0)) - time.time() > 0: - credentials = { - "access_token": token.get("secret"), - "expires_in": expiry, - } - break - # If we're unable to find (or acquire a new) access token, we initiate the device code auth flow. + self._verify_credentials(credentials) + return credentials["access_token"], time.time() + float(credentials["expires_in"]) + + # 3. If we're unable to find (or acquire a new) access token, we initiate the device code auth flow. # The msal device_code flow does not support setting the audience, so we need to handle it manually. # We use the http client instantiated as part of the msal client, as well as the details found # in oauth discovery. - if credentials is None: - data = { - "scope": self.scope_string(), - "client_id": self.client_id, - } - for key, value in self.__token_custom_args.items(): - data[key] = value - - device_flow_endpoint = self._get_device_authorization_endpoint() - device_flow_response = self._get_device_code_response(device_flow_endpoint, data) - if "verification_uri" in device_flow_response: - print( # noqa: T201 - f"Visit {device_flow_response['verification_uri']} and enter the code: {device_flow_response.get('user_code', 'ERROR')}" - ) - elif "message" in device_flow_response: - print( # noqa: T201 - f"Device code: {device_flow_response.get('message', device_flow_response.get('user_code', 'ERROR'))}" - ) - else: - raise CogniteOAuthError( - device_flow_response.get("error", ""), device_flow_response.get("error_description", "") - ) - - if "interval" not in device_flow_response: - # Set default interval according to standard - device_flow_response["interval"] = 5 - if "expires_in" in device_flow_response: - # msal library uses expires_at instead of the standard expires_in - device_flow_response["expires_at"] = float(device_flow_response["expires_in"]) + time.time() - # Poll for token - credentials = self.__app.client.obtain_token_by_device_flow( - flow=device_flow_response, - data=dict( - data, - code=device_flow_response.get( - "device_code" - ), # Hack from msal library to get the code from the device flow, not standard - ), + data = {"scope": self.scope_string(), "client_id": self.client_id, **self.__token_custom_args} + device_flow_endpoint = self._get_device_authorization_endpoint() + response = self._get_device_code_response(device_flow_endpoint, data) + if "verification_uri" in response: + print(f"Visit {response['verification_uri']} and enter the code: {response.get('user_code', 'ERROR')}") # noqa: T201 + elif "message" in response: + print(f"Device code: {response.get('message', response.get('user_code', 'ERROR'))}") # noqa: T201 + else: + raise CogniteAuthError( + f"Error initiating device flow: {response.get('error')} - {response.get('error_description')}" ) - - self._verify_credentials(credentials) - self.__app.token_cache.add( - dict(credentials, environment=self.__app.authority.instance), + if "interval" not in response: + response["interval"] = 5 # Set default interval according to standard + if "expires_in" in response: + # msal library uses expires_at instead of the standard expires_in + response["expires_at"] = float(response["expires_in"]) + time.time() + + credentials = self.__app.client.obtain_token_by_device_flow( + flow=response, + # Hack from msal library to get the code from the device flow, not standard: + data=dict(data, code=response.get("device_code")), ) + self._verify_credentials(credentials) return credentials["access_token"], time.time() + float(credentials["expires_in"]) @classmethod @@ -580,18 +587,16 @@ def default_for_entra_id( mem_cache_only: bool = False, ) -> OAuthDeviceCode: """ - Create an OAuthDeviceCode instance for Azure with default URLs and scopes. It uses the pre-configured Cognite - app registration for device code flow. If you need device code flow with another app registration, instantiate - OAuthDeviceCode directly. + Create an OAuthDeviceCode instance for Azure with default URLs and scopes. The default configuration creates the URLs based on the tenant id and cluster: * Authority URL: "https://login.microsoftonline.com/{tenant_id}" - * Scopes: [f"https://{cdf_cluster}.cognitedata.com/.default"] + * Scopes: [f"https://{cdf_cluster}.cognitedata.com/IDENTITY", f"https://{cdf_cluster}.cognitedata.com/user_impersonation", "profile", "openid", "offline_access"] Args: tenant_id (str): The Azure tenant id - client_id (str): An app registration that allows device code flow. + client_id (str): Your app registration client id. Must have device code flow enabled. cdf_cluster (str): The CDF cluster where the CDF project is located. token_cache_path (Path | None): Location to store token cache, defaults to os temp directory/cognitetokencache.{client_id}.bin. token_expiry_leeway_seconds (int): The token is refreshed at the earliest when this number of seconds is left before expiry. Default: 30 sec @@ -602,18 +607,18 @@ def default_for_entra_id( """ return cls( authority_url=f"https://login.microsoftonline.com/{tenant_id}", - client_id=client_id, # Default application for CDF API for device code flow + client_id=client_id, scopes=[ f"https://{cdf_cluster}.cognitedata.com/IDENTITY", f"https://{cdf_cluster}.cognitedata.com/user_impersonation", "profile", "openid", + "offline_access", # required for Azure to issue a refresh token ], token_cache_path=token_cache_path, token_expiry_leeway_seconds=token_expiry_leeway_seconds, clear_cache=clear_cache, mem_cache_only=mem_cache_only, - audience=f"https://{cdf_cluster}.cognitedata.com", ) @classmethod @@ -697,9 +702,8 @@ def scopes(self) -> list[str]: return self.__scopes def _refresh_access_token(self) -> tuple[str, float]: - # First check if a token cache exists on disk. If yes, find and use: - # - A valid access token. - # - A valid refresh token, and if so, use it automatically to redeem a new access token. + # Try the in-memory token cache silently (MSAL checks AT first, then RT automatically). + # Falls through to interactive flow if nothing usable is found. credentials = None if accounts := self.__app.get_accounts(): credentials = self.__app.acquire_token_silent(scopes=self.__scopes, account=accounts[0]) diff --git a/tests/tests_unit/test_credential_providers.py b/tests/tests_unit/test_credential_providers.py index 0aa454e33e..a69645d569 100644 --- a/tests/tests_unit/test_credential_providers.py +++ b/tests/tests_unit/test_credential_providers.py @@ -48,11 +48,11 @@ def test_invalid_not_dict(self, config: dict, error_type: type[Exception], error class TestToken: def test_token_auth_header(self) -> None: creds = Token("abc") - assert "Authorization", "Bearer abc" == creds.authorization_header() + assert creds.authorization_header() == ("Authorization", "Bearer abc") def test_token_factory_auth_header(self) -> None: creds = Token(lambda: "abc") - assert "Authorization", "Bearer abc" == creds.authorization_header() + assert creds.authorization_header() == ("Authorization", "Bearer abc") def test_token_non_string(self) -> None: with pytest.raises( @@ -71,7 +71,7 @@ def test_token_non_string(self) -> None: def test_load(self, config: dict) -> None: creds = Token.load(config) assert isinstance(creds, Token) - assert "Authorization", "Bearer abc" == creds.authorization_header() + assert creds.authorization_header() == ("Authorization", "Bearer abc") @pytest.mark.parametrize( "config", @@ -87,7 +87,7 @@ def test_load(self, config: dict) -> None: def test_create_from_credential_provider(self, config: dict) -> None: creds = CredentialProvider.load(config) assert isinstance(creds, Token) - assert "Authorization", "Bearer abc" == creds.authorization_header() + assert creds.authorization_header() == ("Authorization", "Bearer abc") class TestOAuthDeviceCode: @@ -114,7 +114,7 @@ def test_access_token_generated(self, mock_public_client: MagicMock, expires_in: } creds = OAuthDeviceCode(**self.DEFAULT_PROVIDER_ARGS) creds._refresh_access_token() - assert "Authorization", "Bearer azure_token" == creds.authorization_header() + assert creds.authorization_header() == ("Authorization", "Bearer azure_token") @patch("cognite.client.credentials.PublicClientApplication") def test_entra_id_uses_authority_endpoint(self, mock_public_client: MagicMock) -> None: @@ -145,20 +145,18 @@ def test_entra_id_uses_authority_endpoint(self, mock_public_client: MagicMock) - call_args = mock_public_client().http_client.post.call_args assert call_args[0][0] == "https://login.microsoftonline.com/xyz/oauth2/v2.0/devicecode" - assert "Authorization", "Bearer azure_token" == creds.authorization_header() + assert creds.authorization_header() == ("Authorization", "Bearer azure_token") @patch("cognite.client.credentials.PublicClientApplication") def test_load(self, mock_public_client: MagicMock) -> None: creds = OAuthDeviceCode.load(dict(self.DEFAULT_PROVIDER_ARGS)) assert isinstance(creds, OAuthDeviceCode) - assert "Authorization", "Bearer azure_token" == creds.authorization_header() @patch("cognite.client.credentials.PublicClientApplication") def test_create_from_credential_provider(self, mock_public_client: MagicMock) -> None: config = {"device_code": dict(self.DEFAULT_PROVIDER_ARGS)} creds = CredentialProvider.load(config) assert isinstance(creds, OAuthDeviceCode) - assert "Authorization", "Bearer azure_token" == creds.authorization_header() @patch("cognite.client.credentials.PublicClientApplication") def test_oauth_discovery_url_device_flow(self, mock_public_client: MagicMock) -> None: @@ -215,7 +213,7 @@ def http_client_side_effect(url: str, **kwargs: Any) -> Any: call_args = mock_public_client().http_client.post.call_args assert call_args[0][0] == "https://auth0.example.com/oauth/device/code" - assert "Authorization", "Bearer auth0_token" == creds.authorization_header() + assert creds.authorization_header() == ("Authorization", "Bearer auth0_token") @patch("cognite.client.credentials.PublicClientApplication") def test_device_code_msal_response(self, mock_public_client: MagicMock) -> None: @@ -244,7 +242,7 @@ def __init__(self, text: str) -> None: creds = OAuthDeviceCode(**self.DEFAULT_PROVIDER_ARGS) creds._refresh_access_token() - assert "Authorization", "Bearer token" == creds.authorization_header() + assert creds.authorization_header() == ("Authorization", "Bearer token") @patch("cognite.client.credentials.PublicClientApplication") def test_oidc_discovery_url_failure_error(self, mock_public_client: MagicMock) -> None: @@ -333,7 +331,7 @@ def test_device_flow_response_invalid_client(self, mock_public_client: MagicMock creds = OAuthDeviceCode(**self.DEFAULT_PROVIDER_ARGS) - with pytest.raises(CogniteAuthError, match=r"Error generating access token: 'invalid_client'"): + with pytest.raises(CogniteAuthError, match=r"Error initiating device flow: invalid_client"): creds._refresh_access_token() @patch("cognite.client.credentials.PublicClientApplication") @@ -341,10 +339,14 @@ def test_refresh_token_from_cache(self, mock_public_client: MagicMock) -> None: # Mock a valid refresh token in cache mock_refresh_token = { "secret": "refresh_token_secret", - "expires_on": time.time() + 3600, # Valid for 1 hour + "client_id": "test-client-id", } - mock_public_client().token_cache.search.return_value = [mock_refresh_token] + # AT search returns nothing; RT search returns the entry + mock_public_client().token_cache.search.side_effect = [ + [], # No valid access tokens + [mock_refresh_token], # Refresh token found + ] mock_public_client().client.obtain_token_by_refresh_token.return_value = { "access_token": "new_access_token", "expires_in": 3600, @@ -353,16 +355,16 @@ def test_refresh_token_from_cache(self, mock_public_client: MagicMock) -> None: creds = OAuthDeviceCode(**self.DEFAULT_PROVIDER_ARGS) creds._refresh_access_token() - # Verify refresh token was used - mock_public_client().client.obtain_token_by_refresh_token.assert_called_once_with("refresh_token_secret") - assert "Authorization", "Bearer new_access_token" == creds.authorization_header() + # Verify refresh token was used (full entry + rt_getter kwarg) + call_args = mock_public_client().client.obtain_token_by_refresh_token.call_args + assert call_args.args[0] == mock_refresh_token + assert call_args.kwargs["rt_getter"](mock_refresh_token) == "refresh_token_secret" @patch("cognite.client.credentials.PublicClientApplication") def test_access_token_from_cache(self, mock_public_client: MagicMock) -> None: # Mock no refresh token, but valid access token mock_public_client().token_cache.search.side_effect = [ - [], # No refresh tokens - [ # Access tokens + [ # Access token found on first search { "secret": "cached_access_token", "expires_on": str(int(time.time() + 3600)), # Valid for 1 hour @@ -371,11 +373,10 @@ def test_access_token_from_cache(self, mock_public_client: MagicMock) -> None: ] creds = OAuthDeviceCode(**self.DEFAULT_PROVIDER_ARGS) - creds._refresh_access_token() # Verify device flow was NOT triggered + assert creds.authorization_header() == ("Authorization", "Bearer cached_access_token") mock_public_client().http_client.post.assert_not_called() - assert "Authorization", "Bearer cached_access_token" == creds.authorization_header() class TestOAuthInteractive: @@ -396,20 +397,18 @@ def test_access_token_generated(self, mock_public_client: MagicMock, expires_in: } creds = OAuthInteractive(**self.DEFAULT_PROVIDER_ARGS) creds._refresh_access_token() - assert "Authorization", "Bearer azure_token" == creds.authorization_header() + assert creds.authorization_header() == ("Authorization", "Bearer azure_token") @patch("cognite.client.credentials.PublicClientApplication") def test_load(self, mock_public_client: MagicMock) -> None: creds = OAuthInteractive.load(dict(self.DEFAULT_PROVIDER_ARGS)) assert isinstance(creds, OAuthInteractive) - assert "Authorization", "Bearer azure_token" == creds.authorization_header() @patch("cognite.client.credentials.PublicClientApplication") def test_create_from_credential_provider(self, mock_public_client: MagicMock) -> None: config = {"interactive": dict(self.DEFAULT_PROVIDER_ARGS)} creds = CredentialProvider.load(config) assert isinstance(creds, OAuthInteractive) - assert "Authorization", "Bearer azure_token" == creds.authorization_header() class TestOauthClientCredentials: @@ -429,7 +428,7 @@ def test_access_token_generated(self, mock_oauth_client: MagicMock, expires_in: mock_oauth_client().fetch_token.return_value = {"access_token": "azure_token", "expires_in": expires_in} creds = OAuthClientCredentials(**self.DEFAULT_PROVIDER_ARGS) creds._refresh_access_token() - assert "Authorization", "Bearer azure_token" == creds.authorization_header() + assert creds.authorization_header() == ("Authorization", "Bearer azure_token") @patch("authlib.integrations.httpx_client.OAuth2Client") def test_access_token_not_generated_due_to_error(self, mock_oauth_client: MagicMock) -> None: @@ -448,8 +447,8 @@ def test_access_token_expired(self, mock_oauth_client: MagicMock) -> None: {"access_token": "azure_token_refreshed", "expires_in": 1000}, ] creds = OAuthClientCredentials(**self.DEFAULT_PROVIDER_ARGS) - assert "Authorization", "Bearer azure_token_expired" == creds.authorization_header() - assert "Authorization", "Bearer azure_token_refreshed" == creds.authorization_header() + assert creds.authorization_header() == ("Authorization", "Bearer azure_token_expired") + assert creds.authorization_header() == ("Authorization", "Bearer azure_token_refreshed") def test_load(self) -> None: creds = OAuthClientCredentials.load(dict(self.DEFAULT_PROVIDER_ARGS)) @@ -483,7 +482,7 @@ def test_access_token_generated(self, mock_msal_app: MagicMock) -> None: "expires_in": 1000, } creds = OAuthClientCertificate(**self.DEFAULT_PROVIDER_ARGS) - assert "Authorization", "Bearer azure_token" == creds.authorization_header() + assert creds.authorization_header() == ("Authorization", "Bearer azure_token") @patch("cognite.client.credentials.ConfidentialClientApplication") def test_load(self, mock_msal_app: MagicMock) -> None: