Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 72 additions & 68 deletions cognite/client/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import atexit
import inspect
import json
import operator
import tempfile
import threading
import time
Expand Down Expand Up @@ -463,72 +464,78 @@ def _get_token(self, convert_timestamps: bool = True) -> dict[str, Any]:
return token

def _refresh_access_token(self) -> tuple[str, float]:
# First check if a token cache exists on disk. If yes, find and use:
# - A valid access token.
# - A valid refresh token, and if so, use it automatically to redeem a new access token.
# Token resolution order (cheapest option first):
# 1. Valid access token in cache → use directly, no network call
# 2. Refresh token in cache → exchange for new AT, one network call
# 3. Device code flow → interactive, requires user action
credentials = None
for token in self.__app.token_cache.search(self.__app.token_cache.CredentialType.REFRESH_TOKEN):
if "expires_on" in token and token["expires_on"] > time.time():
credentials = token

# 1. Check for a still-valid access token. search() does NOT filter by expiry,
# so we check manually and respect the leeway to avoid handing out a near-expired token.
for token in self.__app.token_cache.search(
self.__app.token_cache.CredentialType.ACCESS_TOKEN,
query={"client_id": self.client_id},
):
expiry = int(token.get("expires_on", 0)) - time.time() - self.token_expiry_leeway_seconds
if expiry > 0:
credentials = {"access_token": token["secret"], "expires_in": expiry}
break

# 2. No valid AT — try to silently redeem a refresh token.
if credentials is None:
rt_entry = None
for token in self.__app.token_cache.search(
self.__app.token_cache.CredentialType.REFRESH_TOKEN,
query={"client_id": self.client_id},
):
rt_entry = token
break # MSAL RTs have no 'expires_on'; use the first found

if rt_entry is not None:
# Pass the full RT cache entry (not just the secret string) so MSAL's
# on_removing_rt callback can properly remove it on invalid_grant.
# Exclude OIDC meta-scopes that are not valid in token-endpoint requests.
oidc_scopes = frozenset({"openid", "profile", "email", "offline_access"})
resp = self.__app.client.obtain_token_by_refresh_token(
rt_entry,
rt_getter=operator.itemgetter("secret"),
scope=" ".join(s for s in self.__scopes if s not in oidc_scopes),
)
if isinstance(resp, dict) and "error" not in resp:
credentials = resp
# else: RT rejected by server, fall through to device code flow

if credentials is not None:
credentials = self.__app.client.obtain_token_by_refresh_token(credentials.get("secret", ""))
else:
for token in self.__app.token_cache.search(self.__app.token_cache.CredentialType.ACCESS_TOKEN):
if expiry := int(token.get("expires_on", 0)) - time.time() > 0:
credentials = {
"access_token": token.get("secret"),
"expires_in": expiry,
}
break
# If we're unable to find (or acquire a new) access token, we initiate the device code auth flow.
self._verify_credentials(credentials)
return credentials["access_token"], time.time() + float(credentials["expires_in"])

# 3. If we're unable to find (or acquire a new) access token, we initiate the device code auth flow.
# The msal device_code flow does not support setting the audience, so we need to handle it manually.
# We use the http client instantiated as part of the msal client, as well as the details found
# in oauth discovery.
if credentials is None:
data = {
"scope": self.scope_string(),
"client_id": self.client_id,
}
for key, value in self.__token_custom_args.items():
data[key] = value

device_flow_endpoint = self._get_device_authorization_endpoint()
device_flow_response = self._get_device_code_response(device_flow_endpoint, data)
if "verification_uri" in device_flow_response:
print( # noqa: T201
f"Visit {device_flow_response['verification_uri']} and enter the code: {device_flow_response.get('user_code', 'ERROR')}"
)
elif "message" in device_flow_response:
print( # noqa: T201
f"Device code: {device_flow_response.get('message', device_flow_response.get('user_code', 'ERROR'))}"
)
else:
raise CogniteOAuthError(
device_flow_response.get("error", ""), device_flow_response.get("error_description", "")
)

if "interval" not in device_flow_response:
# Set default interval according to standard
device_flow_response["interval"] = 5
if "expires_in" in device_flow_response:
# msal library uses expires_at instead of the standard expires_in
device_flow_response["expires_at"] = float(device_flow_response["expires_in"]) + time.time()
# Poll for token
credentials = self.__app.client.obtain_token_by_device_flow(
flow=device_flow_response,
data=dict(
data,
code=device_flow_response.get(
"device_code"
), # Hack from msal library to get the code from the device flow, not standard
),
data = {"scope": self.scope_string(), "client_id": self.client_id, **self.__token_custom_args}
device_flow_endpoint = self._get_device_authorization_endpoint()
response = self._get_device_code_response(device_flow_endpoint, data)
if "verification_uri" in response:
print(f"Visit {response['verification_uri']} and enter the code: {response.get('user_code', 'ERROR')}") # noqa: T201
elif "message" in response:
print(f"Device code: {response.get('message', response.get('user_code', 'ERROR'))}") # noqa: T201
else:
raise CogniteAuthError(
f"Error initiating device flow: {response.get('error')} - {response.get('error_description')}"
)

self._verify_credentials(credentials)
self.__app.token_cache.add(
dict(credentials, environment=self.__app.authority.instance),
if "interval" not in response:
response["interval"] = 5 # Set default interval according to standard
if "expires_in" in response:
# msal library uses expires_at instead of the standard expires_in
response["expires_at"] = float(response["expires_in"]) + time.time()

credentials = self.__app.client.obtain_token_by_device_flow(
flow=response,
# Hack from msal library to get the code from the device flow, not standard:
data=dict(data, code=response.get("device_code")),
)
self._verify_credentials(credentials)
return credentials["access_token"], time.time() + float(credentials["expires_in"])

@classmethod
Expand Down Expand Up @@ -580,18 +587,16 @@ def default_for_entra_id(
mem_cache_only: bool = False,
) -> OAuthDeviceCode:
"""
Create an OAuthDeviceCode instance for Azure with default URLs and scopes. It uses the pre-configured Cognite
app registration for device code flow. If you need device code flow with another app registration, instantiate
OAuthDeviceCode directly.
Create an OAuthDeviceCode instance for Azure with default URLs and scopes.

The default configuration creates the URLs based on the tenant id and cluster:

* Authority URL: "https://login.microsoftonline.com/{tenant_id}"
* Scopes: [f"https://{cdf_cluster}.cognitedata.com/.default"]
* Scopes: [f"https://{cdf_cluster}.cognitedata.com/IDENTITY", f"https://{cdf_cluster}.cognitedata.com/user_impersonation", "profile", "openid", "offline_access"]

Args:
tenant_id (str): The Azure tenant id
client_id (str): An app registration that allows device code flow.
client_id (str): Your app registration client id. Must have device code flow enabled.
cdf_cluster (str): The CDF cluster where the CDF project is located.
token_cache_path (Path | None): Location to store token cache, defaults to os temp directory/cognitetokencache.{client_id}.bin.
token_expiry_leeway_seconds (int): The token is refreshed at the earliest when this number of seconds is left before expiry. Default: 30 sec
Expand All @@ -602,18 +607,18 @@ def default_for_entra_id(
"""
return cls(
authority_url=f"https://login.microsoftonline.com/{tenant_id}",
client_id=client_id, # Default application for CDF API for device code flow
client_id=client_id,
scopes=[
f"https://{cdf_cluster}.cognitedata.com/IDENTITY",
f"https://{cdf_cluster}.cognitedata.com/user_impersonation",
"profile",
"openid",
"offline_access", # required for Azure to issue a refresh token
],
token_cache_path=token_cache_path,
token_expiry_leeway_seconds=token_expiry_leeway_seconds,
clear_cache=clear_cache,
mem_cache_only=mem_cache_only,
audience=f"https://{cdf_cluster}.cognitedata.com",
)

@classmethod
Expand Down Expand Up @@ -697,9 +702,8 @@ def scopes(self) -> list[str]:
return self.__scopes

def _refresh_access_token(self) -> tuple[str, float]:
# First check if a token cache exists on disk. If yes, find and use:
# - A valid access token.
# - A valid refresh token, and if so, use it automatically to redeem a new access token.
# Try the in-memory token cache silently (MSAL checks AT first, then RT automatically).
# Falls through to interactive flow if nothing usable is found.
credentials = None
if accounts := self.__app.get_accounts():
credentials = self.__app.acquire_token_silent(scopes=self.__scopes, account=accounts[0])
Expand Down
Loading
Loading