diff --git a/README.md b/README.md index c5dc22e..d5b03a7 100644 --- a/README.md +++ b/README.md @@ -146,8 +146,8 @@ The processor generates two types of files per channel: | `OUTPUT_DIR` | Directory for output files | - | | `CHUNK_SIZE_MB` | Size of each data chunk in MB | `1` | | `IMPORTER_ENABLED` | Enable Pennsieve upload | `false` | -| `PENNSIEVE_API_KEY` | Pennsieve API key | - | -| `PENNSIEVE_API_SECRET` | Pennsieve API secret | - | +| `SESSION_TOKEN` | Pennsieve session token | - | +| `REFRESH_TOKEN` | Pennsieve refresh token | - | | `PENNSIEVE_API_HOST` | Pennsieve API endpoint | `https://api.pennsieve.net` | | `PENNSIEVE_API_HOST2` | Pennsieve API2 endpoint | `https://api2.pennsieve.net` | | `INTEGRATION_ID` | Workflow instance ID | - | diff --git a/processor/clients/__init__.py b/processor/clients/__init__.py index 3ec433b..350c3a3 100644 --- a/processor/clients/__init__.py +++ b/processor/clients/__init__.py @@ -1,4 +1,6 @@ -from .authentication_client import AuthenticationClient as AuthenticationClient +from .authentication_client import AuthProvider as AuthProvider +from .authentication_client import KeySecretAuthProvider as KeySecretAuthProvider +from .authentication_client import TokenAuthProvider as TokenAuthProvider from .base_client import BaseClient as BaseClient from .base_client import SessionManager as SessionManager from .import_client import ImportClient as ImportClient diff --git a/processor/clients/authentication_client.py b/processor/clients/authentication_client.py index 8cb9fc2..6266b3b 100644 --- a/processor/clients/authentication_client.py +++ b/processor/clients/authentication_client.py @@ -1,5 +1,7 @@ +import base64 import json import logging +from abc import ABC, abstractmethod import boto3 import requests @@ -7,42 +9,152 @@ log = logging.getLogger() -class AuthenticationClient: +class AuthProvider(ABC): + """Interface for authentication strategies. + + All auth methods ultimately produce a session token and the ability to + refresh it. Implementations differ only in how they bootstrap. + """ + + @abstractmethod + def get_session_token(self) -> str: + """Return the current session token.""" + ... + + @abstractmethod + def refresh(self) -> str: + """Refresh and return a new session token.""" + ... + + +class CognitoClient: + """Shared Cognito interaction logic used by all auth providers.""" + def __init__(self, api_host): self.api_host = api_host + self._cognito_config = None + + def _get_cognito_config(self): + if self._cognito_config is not None: + return self._cognito_config - def authenticate(self, api_key, api_secret): url = f"{self.api_host}/authentication/cognito-config" + response = requests.get(url) + response.raise_for_status() + data = json.loads(response.content) - try: - response = requests.get(url) - response.raise_for_status() - data = json.loads(response.content) + self._cognito_config = { + "app_client_id": data["userPool"]["appClientId"], + "region": data["region"], + } + return self._cognito_config - cognito_app_client_id = data["tokenPool"]["appClientId"] - cognito_region = data["region"] + def _get_idp_client(self): + config = self._get_cognito_config() + return boto3.client( + "cognito-idp", + region_name=config["region"], + aws_access_key_id="", + aws_secret_access_key="", + ) - cognito_idp_client = boto3.client( - "cognito-idp", - region_name=cognito_region, - aws_access_key_id="", - aws_secret_access_key="", - ) + def authenticate(self, api_key, api_secret): + """Exchange API key/secret for session + refresh tokens via Cognito USER_PASSWORD_AUTH.""" + config = self._get_cognito_config() + idp_client = self._get_idp_client() - login_response = cognito_idp_client.initiate_auth( - AuthFlow="USER_PASSWORD_AUTH", - AuthParameters={"USERNAME": api_key, "PASSWORD": api_secret}, - ClientId=cognito_app_client_id, - ) + login_response = idp_client.initiate_auth( + AuthFlow="USER_PASSWORD_AUTH", + AuthParameters={"USERNAME": api_key, "PASSWORD": api_secret}, + ClientId=config["app_client_id"], + ) + + auth_result = login_response["AuthenticationResult"] + return auth_result["AccessToken"], auth_result["RefreshToken"] + + @staticmethod + def _decode_token(token): + """Decode a JWT payload without verification (for extracting claims like device_key).""" + payload = token.split(".")[1] + # JWT base64url encoding may lack padding + padding = 4 - len(payload) % 4 + if padding != 4: + payload += "=" * padding + return json.loads(base64.urlsafe_b64decode(payload)) + + def refresh_token(self, refresh_token, session_token=None): + """Use a refresh token to obtain a new access token via Cognito REFRESH_TOKEN_AUTH.""" + config = self._get_cognito_config() + idp_client = self._get_idp_client() + + auth_parameters = {"REFRESH_TOKEN": refresh_token} + + device_key = None + if session_token: + try: + decoded = self._decode_token(session_token) + device_key = decoded.get("device_key") + if device_key: + log.info(f"extracted device_key from session token: {device_key}") + except Exception as e: + log.warning(f"failed to extract device_key from session token: {e}") - access_token = login_response["AuthenticationResult"]["AccessToken"] - return access_token - except requests.HTTPError as e: - log.error(f"failed to reach authentication server with error: {e}") - raise e - except json.JSONDecodeError as e: - log.error(f"failed to decode authentication response with error: {e}") - raise e - except Exception as e: - log.error(f"failed to authenticate with error: {e}") - raise e + if device_key: + auth_parameters["DEVICE_KEY"] = device_key + + response = idp_client.initiate_auth( + AuthFlow="REFRESH_TOKEN_AUTH", + AuthParameters=auth_parameters, + ClientId=config["app_client_id"], + ) + + return response["AuthenticationResult"]["AccessToken"] + + +class TokenAuthProvider(AuthProvider): + """Auth provider for pre-supplied session + refresh tokens (production path).""" + + def __init__(self, api_host, session_token, refresh_token): + self._session_token = session_token + self._refresh_token = refresh_token + self._cognito = CognitoClient(api_host) + + def get_session_token(self) -> str: + return self._session_token + + def refresh(self) -> str: + if not self._refresh_token: + raise RuntimeError("cannot refresh session: no refresh token available") + log.info("refreshing session token using refresh token") + self._session_token = self._cognito.refresh_token(self._refresh_token, self._session_token) + return self._session_token + + +class KeySecretAuthProvider(AuthProvider): + """Auth provider that authenticates with API key/secret (local development path). + + Authenticates eagerly on construction to obtain session + refresh tokens, + then refreshes using the same Cognito refresh flow as TokenAuthProvider. + """ + + def __init__(self, api_host, api_key, api_secret): + self._api_key = api_key + self._api_secret = api_secret + self._cognito = CognitoClient(api_host) + + log.info("authenticating with API key/secret") + self._session_token, self._refresh_token = self._cognito.authenticate(api_key, api_secret) + + def get_session_token(self) -> str: + return self._session_token + + def refresh(self) -> str: + if self._refresh_token: + log.info("refreshing session token using refresh token") + self._session_token = self._cognito.refresh_token(self._refresh_token, self._session_token) + else: + log.info("no refresh token, re-authenticating with API key/secret") + self._session_token, self._refresh_token = self._cognito.authenticate( + self._api_key, self._api_secret + ) + return self._session_token diff --git a/processor/clients/base_client.py b/processor/clients/base_client.py index 5d224db..d93d59d 100644 --- a/processor/clients/base_client.py +++ b/processor/clients/base_client.py @@ -5,24 +5,17 @@ log = logging.getLogger() -# encapsulates a shared API session and re-authentication functionality +# encapsulates a shared API session and token refresh class SessionManager: - def __init__(self, authentication_client, api_key, api_secret): - self.authentication_client = authentication_client - self.api_key = api_key - self.api_secret = api_secret - - self.__session_token = None + def __init__(self, auth_provider): + self._auth_provider = auth_provider @property def session_token(self): - if self.__session_token is None: - self.refresh_session() - - return self.__session_token + return self._auth_provider.get_session_token() def refresh_session(self): - self.__session_token = self.authentication_client.authenticate(self.api_key, self.api_secret) + self._auth_provider.refresh() class BaseClient: diff --git a/processor/clients/workflow_client.py b/processor/clients/workflow_client.py index 13442a0..f284c8c 100644 --- a/processor/clients/workflow_client.py +++ b/processor/clients/workflow_client.py @@ -25,7 +25,7 @@ def __init__(self, api_host, session_manager): # with an empty body even when a workflow instance does not exist @BaseClient.retry_with_refresh def get_workflow_instance(self, workflow_instance_id): - url = f"{self.api_host}/compute/workflows/instances/{workflow_instance_id}" + url = f"{self.api_host}/compute/workflows/runs/{workflow_instance_id}" headers = {"Accept": "application/json", "Authorization": f"Bearer {self.session_manager.session_token}"} diff --git a/processor/config.py b/processor/config.py index 769498d..42a16b0 100644 --- a/processor/config.py +++ b/processor/config.py @@ -1,6 +1,9 @@ +import logging import os import uuid +log = logging.getLogger() + class Config: def __init__(self): @@ -15,6 +18,8 @@ def __init__(self): # has been converted to use a different variable to represent the workflow instance ID self.WORKFLOW_INSTANCE_ID = os.getenv("INTEGRATION_ID", str(uuid.uuid4())) + self.SESSION_TOKEN = os.getenv("SESSION_TOKEN") + self.REFRESH_TOKEN = os.getenv("REFRESH_TOKEN") self.API_KEY = os.getenv("PENNSIEVE_API_KEY") self.API_SECRET = os.getenv("PENNSIEVE_API_SECRET") self.API_HOST = os.getenv("PENNSIEVE_API_HOST", "https://api.pennsieve.net") diff --git a/processor/importer.py b/processor/importer.py index 5d72b32..6107ebd 100644 --- a/processor/importer.py +++ b/processor/importer.py @@ -10,11 +10,9 @@ import backoff import requests from clients import ( - AuthenticationClient, ImportClient, ImportFile, PackagesClient, - SessionManager, TimeSeriesClient, WorkflowClient, ) @@ -34,7 +32,9 @@ """ -def import_timeseries(api_host, api2_host, api_key, api_secret, workflow_instance_id, file_directory): +def import_timeseries( + api_host, api2_host, session_manager, workflow_instance_id, file_directory +): # gather all the time series files from the output directory timeseries_data_files = [] timeseries_channel_files = [] @@ -50,10 +50,6 @@ def import_timeseries(api_host, api2_host, api_key, api_secret, workflow_instanc log.info("no time series channels or data") return None - # authentication against the Pennsieve API - authorization_client = AuthenticationClient(api_host) - session_manager = SessionManager(authorization_client, api_key, api_secret) - # fetch workflow instance for parameters (dataset_id, package_id, etc.) workflow_client = WorkflowClient(api2_host, session_manager) workflow_instance = workflow_client.get_workflow_instance(workflow_instance_id) diff --git a/processor/main.py b/processor/main.py index 3677fcd..3865c0f 100644 --- a/processor/main.py +++ b/processor/main.py @@ -45,11 +45,22 @@ # note: this will be moved to a separated post-processor once the analysis pipeline is more # easily able to handle > 3 processors if config.IMPORTER_ENABLED: - importer = import_timeseries( + from clients.authentication_client import KeySecretAuthProvider, TokenAuthProvider + from clients.base_client import SessionManager + + if config.SESSION_TOKEN: + auth_provider = TokenAuthProvider(config.API_HOST, config.SESSION_TOKEN, config.REFRESH_TOKEN) + elif config.API_KEY and config.API_SECRET: + auth_provider = KeySecretAuthProvider(config.API_HOST, config.API_KEY, config.API_SECRET) + else: + raise RuntimeError("no authentication credentials provided: set SESSION_TOKEN or API_KEY/API_SECRET") + + session_manager = SessionManager(auth_provider) + + import_timeseries( config.API_HOST, config.API_HOST2, - config.API_KEY, - config.API_SECRET, + session_manager, config.WORKFLOW_INSTANCE_ID, config.OUTPUT_DIR, ) diff --git a/tests/conftest.py b/tests/conftest.py index bcf5916..d26bf01 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -53,14 +53,6 @@ def mock_session_manager(): return manager -@pytest.fixture -def mock_authentication_client(): - """Mock authentication client.""" - client = Mock() - client.authenticate = Mock(return_value="mock-access-token") - return client - - @pytest.fixture def sample_timestamps(): """Sample evenly-spaced timestamps at 1000 Hz.""" diff --git a/tests/test_authentication_client.py b/tests/test_authentication_client.py index 5ec9d31..23a3050 100644 --- a/tests/test_authentication_client.py +++ b/tests/test_authentication_client.py @@ -1,69 +1,77 @@ +import base64 import json from unittest.mock import Mock, patch import pytest import responses -from clients.authentication_client import AuthenticationClient +from clients.authentication_client import ( + CognitoClient, + KeySecretAuthProvider, + TokenAuthProvider, +) -class TestAuthenticationClientInit: - """Tests for AuthenticationClient initialization.""" +def _make_jwt(payload): + """Build a fake JWT with the given payload dict (no signature verification).""" + header = base64.urlsafe_b64encode(json.dumps({"alg": "RS256"}).encode()).rstrip(b"=").decode() + body = base64.urlsafe_b64encode(json.dumps(payload).encode()).rstrip(b"=").decode() + return f"{header}.{body}.fake-signature" - def test_initialization(self): - """Test basic initialization.""" - client = AuthenticationClient("https://api.test.com") - assert client.api_host == "https://api.test.com" +class TestCognitoClient: + """Tests for shared CognitoClient logic.""" -class TestAuthenticationClientAuthenticate: - """Tests for AuthenticationClient.authenticate method.""" + def test_initialization(self): + client = CognitoClient("https://api.test.com") + assert client.api_host == "https://api.test.com" + assert client._cognito_config is None @responses.activate def test_authenticate_success(self): - """Test successful authentication flow.""" - # Mock cognito config response responses.add( responses.GET, "https://api.test.com/authentication/cognito-config", - json={"tokenPool": {"appClientId": "test-client-id"}, "region": "us-east-1"}, + json={"userPool": {"appClientId": "test-client-id"}, "region": "us-east-1"}, status=200, ) - # Mock boto3 cognito client mock_cognito_client = Mock() mock_cognito_client.initiate_auth.return_value = { - "AuthenticationResult": {"AccessToken": "test-access-token-12345"} + "AuthenticationResult": { + "AccessToken": "test-access-token-12345", + "RefreshToken": "test-refresh-token-67890", + } } with patch("clients.authentication_client.boto3.client", return_value=mock_cognito_client): - client = AuthenticationClient("https://api.test.com") - token = client.authenticate("api-key", "api-secret") + client = CognitoClient("https://api.test.com") + access_token, refresh_token = client.authenticate("api-key", "api-secret") - assert token == "test-access-token-12345" + assert access_token == "test-access-token-12345" + assert refresh_token == "test-refresh-token-67890" @responses.activate def test_authenticate_calls_cognito_with_correct_params(self): - """Test that Cognito is called with correct parameters.""" responses.add( responses.GET, "https://api.test.com/authentication/cognito-config", - json={"tokenPool": {"appClientId": "my-app-client-id"}, "region": "us-west-2"}, + json={"userPool": {"appClientId": "my-app-client-id"}, "region": "us-west-2"}, status=200, ) mock_cognito_client = Mock() - mock_cognito_client.initiate_auth.return_value = {"AuthenticationResult": {"AccessToken": "token"}} + mock_cognito_client.initiate_auth.return_value = { + "AuthenticationResult": {"AccessToken": "token", "RefreshToken": "refresh"} + } with patch("clients.authentication_client.boto3.client", return_value=mock_cognito_client) as mock_boto: - client = AuthenticationClient("https://api.test.com") + client = CognitoClient("https://api.test.com") client.authenticate("my-api-key", "my-api-secret") - # Check boto3 client was created with correct parameters mock_boto.assert_called_once_with( "cognito-idp", region_name="us-west-2", aws_access_key_id="", aws_secret_access_key="" ) - # Check initiate_auth was called with correct parameters mock_cognito_client.initiate_auth.assert_called_once_with( AuthFlow="USER_PASSWORD_AUTH", AuthParameters={"USERNAME": "my-api-key", "PASSWORD": "my-api-secret"}, @@ -72,7 +80,6 @@ def test_authenticate_calls_cognito_with_correct_params(self): @responses.activate def test_authenticate_raises_on_config_http_error(self): - """Test that HTTP errors from config endpoint are raised.""" responses.add( responses.GET, "https://api.test.com/authentication/cognito-config", @@ -80,80 +87,110 @@ def test_authenticate_raises_on_config_http_error(self): status=500, ) - client = AuthenticationClient("https://api.test.com") + client = CognitoClient("https://api.test.com") with pytest.raises(Exception): client.authenticate("key", "secret") @responses.activate - def test_authenticate_raises_on_invalid_json(self): - """Test that invalid JSON response raises error.""" + def test_refresh_token_success(self): responses.add( - responses.GET, "https://api.test.com/authentication/cognito-config", body="not valid json", status=200 + responses.GET, + "https://api.test.com/authentication/cognito-config", + json={"userPool": {"appClientId": "test-client-id"}, "region": "us-east-1"}, + status=200, ) - client = AuthenticationClient("https://api.test.com") + mock_cognito_client = Mock() + mock_cognito_client.initiate_auth.return_value = { + "AuthenticationResult": {"AccessToken": "refreshed-access-token"} + } - with pytest.raises(json.JSONDecodeError): - client.authenticate("key", "secret") + with patch("clients.authentication_client.boto3.client", return_value=mock_cognito_client): + client = CognitoClient("https://api.test.com") + token = client.refresh_token("my-refresh-token") + + assert token == "refreshed-access-token" @responses.activate - def test_authenticate_raises_on_cognito_error(self): - """Test that Cognito errors are raised.""" + def test_refresh_token_calls_cognito_with_correct_params(self): responses.add( responses.GET, "https://api.test.com/authentication/cognito-config", - json={"tokenPool": {"appClientId": "client-id"}, "region": "us-east-1"}, + json={"userPool": {"appClientId": "my-app-client-id"}, "region": "us-west-2"}, status=200, ) mock_cognito_client = Mock() - mock_cognito_client.initiate_auth.side_effect = Exception("Cognito auth failed") + mock_cognito_client.initiate_auth.return_value = {"AuthenticationResult": {"AccessToken": "token"}} with patch("clients.authentication_client.boto3.client", return_value=mock_cognito_client): - client = AuthenticationClient("https://api.test.com") + client = CognitoClient("https://api.test.com") + client.refresh_token("the-refresh-token") - with pytest.raises(Exception, match="Cognito auth failed"): - client.authenticate("key", "secret") + mock_cognito_client.initiate_auth.assert_called_once_with( + AuthFlow="REFRESH_TOKEN_AUTH", + AuthParameters={"REFRESH_TOKEN": "the-refresh-token"}, + ClientId="my-app-client-id", + ) @responses.activate - def test_authenticate_extracts_access_token(self): - """Test that access token is correctly extracted from response.""" + def test_refresh_token_includes_device_key_from_session_token(self): + """Test that device_key is extracted from session token and included in refresh params.""" responses.add( responses.GET, "https://api.test.com/authentication/cognito-config", - json={"tokenPool": {"appClientId": "client-id"}, "region": "us-east-1"}, + json={"userPool": {"appClientId": "client-id"}, "region": "us-east-1"}, status=200, ) mock_cognito_client = Mock() - mock_cognito_client.initiate_auth.return_value = { - "AuthenticationResult": { - "AccessToken": "the-access-token", - "RefreshToken": "refresh-token", - "IdToken": "id-token", - "ExpiresIn": 3600, - } - } + mock_cognito_client.initiate_auth.return_value = {"AuthenticationResult": {"AccessToken": "token"}} + + session_token = _make_jwt({"device_key": "us-east-1_device-abc-123"}) with patch("clients.authentication_client.boto3.client", return_value=mock_cognito_client): - client = AuthenticationClient("https://api.test.com") - token = client.authenticate("key", "secret") + client = CognitoClient("https://api.test.com") + client.refresh_token("the-refresh-token", session_token=session_token) - # Should return only the access token - assert token == "the-access-token" + mock_cognito_client.initiate_auth.assert_called_once_with( + AuthFlow="REFRESH_TOKEN_AUTH", + AuthParameters={"REFRESH_TOKEN": "the-refresh-token", "DEVICE_KEY": "us-east-1_device-abc-123"}, + ClientId="client-id", + ) + @responses.activate + def test_refresh_token_without_device_key_in_session_token(self): + """Test that refresh works without device_key when token doesn't contain one.""" + responses.add( + responses.GET, + "https://api.test.com/authentication/cognito-config", + json={"userPool": {"appClientId": "client-id"}, "region": "us-east-1"}, + status=200, + ) -class TestAuthenticationClientEdgeCases: - """Edge case tests for AuthenticationClient.""" + mock_cognito_client = Mock() + mock_cognito_client.initiate_auth.return_value = {"AuthenticationResult": {"AccessToken": "token"}} + + session_token = _make_jwt({"sub": "user-123"}) + + with patch("clients.authentication_client.boto3.client", return_value=mock_cognito_client): + client = CognitoClient("https://api.test.com") + client.refresh_token("the-refresh-token", session_token=session_token) + + mock_cognito_client.initiate_auth.assert_called_once_with( + AuthFlow="REFRESH_TOKEN_AUTH", + AuthParameters={"REFRESH_TOKEN": "the-refresh-token"}, + ClientId="client-id", + ) @responses.activate - def test_authenticate_with_empty_credentials(self): - """Test authentication with empty credentials.""" + def test_refresh_token_without_session_token(self): + """Test that refresh works without session_token (no device_key extraction attempted).""" responses.add( responses.GET, "https://api.test.com/authentication/cognito-config", - json={"tokenPool": {"appClientId": "client-id"}, "region": "us-east-1"}, + json={"userPool": {"appClientId": "client-id"}, "region": "us-east-1"}, status=200, ) @@ -161,35 +198,127 @@ def test_authenticate_with_empty_credentials(self): mock_cognito_client.initiate_auth.return_value = {"AuthenticationResult": {"AccessToken": "token"}} with patch("clients.authentication_client.boto3.client", return_value=mock_cognito_client): - client = AuthenticationClient("https://api.test.com") - # Empty credentials should still be passed to Cognito - client.authenticate("", "") + client = CognitoClient("https://api.test.com") + client.refresh_token("the-refresh-token") - mock_cognito_client.initiate_auth.assert_called_once() - call_args = mock_cognito_client.initiate_auth.call_args - assert call_args[1]["AuthParameters"]["USERNAME"] == "" - assert call_args[1]["AuthParameters"]["PASSWORD"] == "" + mock_cognito_client.initiate_auth.assert_called_once_with( + AuthFlow="REFRESH_TOKEN_AUTH", + AuthParameters={"REFRESH_TOKEN": "the-refresh-token"}, + ClientId="client-id", + ) @responses.activate - def test_authenticate_with_different_regions(self): - """Test authentication with different AWS regions.""" - for region in ["us-east-1", "us-west-2", "eu-west-1", "ap-northeast-1"]: - responses.reset() - responses.add( - responses.GET, - "https://api.test.com/authentication/cognito-config", - json={"tokenPool": {"appClientId": "client-id"}, "region": region}, - status=200, - ) - - mock_cognito_client = Mock() - mock_cognito_client.initiate_auth.return_value = {"AuthenticationResult": {"AccessToken": "token"}} - - with patch("clients.authentication_client.boto3.client", return_value=mock_cognito_client) as mock_boto: - client = AuthenticationClient("https://api.test.com") - client.authenticate("key", "secret") - - # Verify correct region was used - mock_boto.assert_called_with( - "cognito-idp", region_name=region, aws_access_key_id="", aws_secret_access_key="" - ) + def test_cognito_config_cached_across_calls(self): + responses.add( + responses.GET, + "https://api.test.com/authentication/cognito-config", + json={"userPool": {"appClientId": "client-id"}, "region": "us-east-1"}, + status=200, + ) + + mock_cognito_client = Mock() + mock_cognito_client.initiate_auth.return_value = {"AuthenticationResult": {"AccessToken": "token"}} + + with patch("clients.authentication_client.boto3.client", return_value=mock_cognito_client): + client = CognitoClient("https://api.test.com") + client.refresh_token("refresh-token") + client.refresh_token("refresh-token") + + # Config endpoint should only be called once despite two refresh calls + assert len(responses.calls) == 1 + + +class TestTokenAuthProvider: + """Tests for TokenAuthProvider (production path: pre-supplied tokens).""" + + def test_get_session_token(self): + provider = TokenAuthProvider.__new__(TokenAuthProvider) + provider._session_token = "my-session-token" + provider._refresh_token = "my-refresh-token" + provider._cognito = Mock() + + assert provider.get_session_token() == "my-session-token" + + def test_refresh_updates_session_token(self): + mock_cognito = Mock() + mock_cognito.refresh_token.return_value = "new-access-token" + + provider = TokenAuthProvider.__new__(TokenAuthProvider) + provider._session_token = "old-token" + provider._refresh_token = "my-refresh-token" + provider._cognito = mock_cognito + + result = provider.refresh() + + assert result == "new-access-token" + assert provider.get_session_token() == "new-access-token" + mock_cognito.refresh_token.assert_called_once_with("my-refresh-token", "old-token") + + def test_refresh_raises_without_refresh_token(self): + provider = TokenAuthProvider.__new__(TokenAuthProvider) + provider._session_token = "session-token" + provider._refresh_token = None + provider._cognito = Mock() + + with pytest.raises(RuntimeError, match="no refresh token"): + provider.refresh() + + +class TestKeySecretAuthProvider: + """Tests for KeySecretAuthProvider (local dev path: key/secret → tokens).""" + + @responses.activate + def test_authenticates_eagerly_on_init(self): + responses.add( + responses.GET, + "https://api.test.com/authentication/cognito-config", + json={"userPool": {"appClientId": "client-id"}, "region": "us-east-1"}, + status=200, + ) + + mock_cognito_client = Mock() + mock_cognito_client.initiate_auth.return_value = { + "AuthenticationResult": { + "AccessToken": "initial-access-token", + "RefreshToken": "initial-refresh-token", + } + } + + with patch("clients.authentication_client.boto3.client", return_value=mock_cognito_client): + provider = KeySecretAuthProvider("https://api.test.com", "my-key", "my-secret") + + assert provider.get_session_token() == "initial-access-token" + + def test_refresh_uses_refresh_token(self): + mock_cognito = Mock() + mock_cognito.refresh_token.return_value = "refreshed-token" + + provider = KeySecretAuthProvider.__new__(KeySecretAuthProvider) + provider._api_key = "key" + provider._api_secret = "secret" + provider._session_token = "old-token" + provider._refresh_token = "my-refresh-token" + provider._cognito = mock_cognito + + result = provider.refresh() + + assert result == "refreshed-token" + assert provider.get_session_token() == "refreshed-token" + mock_cognito.refresh_token.assert_called_once_with("my-refresh-token", "old-token") + + def test_refresh_re_authenticates_when_no_refresh_token(self): + mock_cognito = Mock() + mock_cognito.authenticate.return_value = ("new-access", "new-refresh") + + provider = KeySecretAuthProvider.__new__(KeySecretAuthProvider) + provider._api_key = "key" + provider._api_secret = "secret" + provider._session_token = "old-token" + provider._refresh_token = None + provider._cognito = mock_cognito + + result = provider.refresh() + + assert result == "new-access" + assert provider._refresh_token == "new-refresh" + mock_cognito.authenticate.assert_called_once_with("key", "secret") diff --git a/tests/test_base_client.py b/tests/test_base_client.py index 08f7555..4c13767 100644 --- a/tests/test_base_client.py +++ b/tests/test_base_client.py @@ -8,64 +8,23 @@ class TestSessionManager: """Tests for SessionManager class.""" - def test_initialization(self, mock_authentication_client): - """Test basic initialization.""" - manager = SessionManager( - authentication_client=mock_authentication_client, api_key="test-api-key", api_secret="test-api-secret" - ) - - assert manager.authentication_client == mock_authentication_client - assert manager.api_key == "test-api-key" - assert manager.api_secret == "test-api-secret" - - def test_session_token_lazy_initialization(self, mock_authentication_client): - """Test that session token is lazily initialized on first access.""" - manager = SessionManager(mock_authentication_client, "key", "secret") - - # Token should not be fetched yet - mock_authentication_client.authenticate.assert_not_called() - - # Access token - token = manager.session_token - - # Now authenticate should have been called - mock_authentication_client.authenticate.assert_called_once_with("key", "secret") - assert token == "mock-access-token" - - def test_session_token_cached(self, mock_authentication_client): - """Test that session token is cached after first access.""" - manager = SessionManager(mock_authentication_client, "key", "secret") - - # Access token twice - token1 = manager.session_token - token2 = manager.session_token - - # Authenticate should only be called once - mock_authentication_client.authenticate.assert_called_once() - assert token1 == token2 + def test_session_token_delegates_to_auth_provider(self): + """Test that session_token reads from the auth provider.""" + mock_provider = Mock() + mock_provider.get_session_token.return_value = "my-token" + manager = SessionManager(mock_provider) - def test_refresh_session(self, mock_authentication_client): - """Test manual session refresh.""" - manager = SessionManager(mock_authentication_client, "key", "secret") + assert manager.session_token == "my-token" + mock_provider.get_session_token.assert_called_once() - # Access token to initialize - _ = manager.session_token - assert mock_authentication_client.authenticate.call_count == 1 + def test_refresh_session_delegates_to_auth_provider(self): + """Test that refresh_session calls auth provider's refresh.""" + mock_provider = Mock() + manager = SessionManager(mock_provider) - # Refresh session - mock_authentication_client.authenticate.return_value = "new-token" manager.refresh_session() - assert mock_authentication_client.authenticate.call_count == 2 - assert manager.session_token == "new-token" - - def test_refresh_session_without_prior_access(self, mock_authentication_client): - """Test refresh_session can be called without prior token access.""" - manager = SessionManager(mock_authentication_client, "key", "secret") - - manager.refresh_session() - - mock_authentication_client.authenticate.assert_called_once_with("key", "secret") + mock_provider.refresh.assert_called_once() class TestBaseClient: @@ -203,9 +162,11 @@ def test_method(self): class TestBaseClientIntegration: """Integration tests for BaseClient with SessionManager.""" - def test_client_uses_session_token(self, mock_authentication_client): - """Test that client methods can access session token.""" - session_manager = SessionManager(mock_authentication_client, "key", "secret") + def test_client_uses_session_token(self): + """Test that client methods can access session token via auth provider.""" + mock_provider = Mock() + mock_provider.get_session_token.return_value = "my-access-token" + session_manager = SessionManager(mock_provider) class TestClient(BaseClient): @BaseClient.retry_with_refresh @@ -215,11 +176,16 @@ def get_auth_header(self): client = TestClient(session_manager) header = client.get_auth_header() - assert header == "Bearer mock-access-token" + assert header == "Bearer my-access-token" + + def test_retry_refreshes_token_and_succeeds(self): + """Test that a 401 triggers refresh via auth provider and the retry uses the new token.""" + mock_provider = Mock() + # get_session_token is only called on the successful retry (first attempt raises before reading token) + mock_provider.get_session_token.return_value = "refreshed-token" + + session_manager = SessionManager(mock_provider) - def test_refresh_updates_token_for_next_call(self, mock_authentication_client): - """Test that after refresh, subsequent calls use new token.""" - session_manager = SessionManager(mock_authentication_client, "key", "secret") call_count = [0] class TestClient(BaseClient): @@ -232,11 +198,9 @@ def get_token(self): raise requests.exceptions.HTTPError(response=response) return self.session_manager.session_token - # First call returns 'mock-access-token', refresh returns 'refreshed-token' - mock_authentication_client.authenticate.side_effect = ["mock-access-token", "refreshed-token"] - client = TestClient(session_manager) - client.get_token() + token = client.get_token() - # The refresh_session was called, showing the retry mechanism worked - assert call_count[0] == 2 # Verifies retry happened + assert token == "refreshed-token" + mock_provider.refresh.assert_called_once() + assert call_count[0] == 2 diff --git a/tests/test_config.py b/tests/test_config.py index 3b48ae7..324d8c3 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -123,21 +123,21 @@ def test_local_environment_custom_api_hosts(self, tmp_path): assert config.API_HOST == "https://custom.api.com" assert config.API_HOST2 == "https://custom.api2.com" - def test_local_environment_api_credentials(self, tmp_path): - """Test API credentials loading.""" + def test_local_environment_session_tokens(self, tmp_path): + """Test session token loading.""" env_vars = { "ENVIRONMENT": "local", "INPUT_DIR": str(tmp_path), "OUTPUT_DIR": str(tmp_path), - "PENNSIEVE_API_KEY": "test-api-key", - "PENNSIEVE_API_SECRET": "test-api-secret", + "SESSION_TOKEN": "test-session-token", + "REFRESH_TOKEN": "test-refresh-token", } with patch.dict(os.environ, env_vars, clear=True): config = Config() - assert config.API_KEY == "test-api-key" - assert config.API_SECRET == "test-api-secret" + assert config.SESSION_TOKEN == "test-session-token" + assert config.REFRESH_TOKEN == "test-refresh-token" def test_local_importer_enabled_override(self, tmp_path): """Test IMPORTER_ENABLED can be overridden in local environment.""" @@ -232,8 +232,8 @@ def test_missing_output_dir_local(self, tmp_path): config = Config() assert config.OUTPUT_DIR is None - def test_missing_api_credentials(self, tmp_path): - """Test Config with missing API credentials.""" + def test_missing_session_tokens(self, tmp_path): + """Test Config with missing session tokens.""" env_vars = { "ENVIRONMENT": "local", "INPUT_DIR": str(tmp_path), @@ -242,8 +242,8 @@ def test_missing_api_credentials(self, tmp_path): with patch.dict(os.environ, env_vars, clear=True): config = Config() - assert config.API_KEY is None - assert config.API_SECRET is None + assert config.SESSION_TOKEN is None + assert config.REFRESH_TOKEN is None def test_chunk_size_conversion_to_int(self, tmp_path): """Test that CHUNK_SIZE_MB is converted to integer.""" diff --git a/tests/test_workflow_client.py b/tests/test_workflow_client.py index e1518c7..fd52f2e 100644 --- a/tests/test_workflow_client.py +++ b/tests/test_workflow_client.py @@ -49,7 +49,7 @@ def test_get_workflow_instance_success(self, mock_session_manager): """Test successful workflow instance retrieval.""" responses.add( responses.GET, - "https://api.test.com/workflows/instances/wf-instance-123", + "https://api.test.com/compute/workflows/runs/wf-instance-123", json={"uuid": "wf-instance-123", "datasetId": "dataset-456", "packageIds": ["pkg-1", "pkg-2", "pkg-3"]}, status=200, ) @@ -67,7 +67,7 @@ def test_get_workflow_instance_includes_auth_header(self, mock_session_manager): """Test that authorization header is included.""" responses.add( responses.GET, - "https://api.test.com/workflows/instances/wf-123", + "https://api.test.com/compute/workflows/runs/wf-123", json={"uuid": "wf-123", "datasetId": "ds-1", "packageIds": []}, status=200, ) @@ -82,7 +82,7 @@ def test_get_workflow_instance_includes_auth_header(self, mock_session_manager): def test_get_workflow_instance_raises_on_http_error(self, mock_session_manager): """Test that HTTP errors are raised.""" responses.add( - responses.GET, "https://api.test.com/workflows/instances/wf-123", json={"error": "Not found"}, status=404 + responses.GET, "https://api.test.com/compute/workflows/runs/wf-123", json={"error": "Not found"}, status=404 ) client = WorkflowClient("https://api.test.com", mock_session_manager) @@ -93,7 +93,7 @@ def test_get_workflow_instance_raises_on_http_error(self, mock_session_manager): @responses.activate def test_get_workflow_instance_raises_on_invalid_json(self, mock_session_manager): """Test that invalid JSON raises error.""" - responses.add(responses.GET, "https://api.test.com/workflows/instances/wf-123", body="not json", status=200) + responses.add(responses.GET, "https://api.test.com/compute/workflows/runs/wf-123", body="not json", status=200) client = WorkflowClient("https://api.test.com", mock_session_manager) @@ -105,7 +105,7 @@ def test_get_workflow_instance_with_single_package(self, mock_session_manager): """Test workflow instance with single package ID.""" responses.add( responses.GET, - "https://api.test.com/workflows/instances/wf-123", + "https://api.test.com/compute/workflows/runs/wf-123", json={"uuid": "wf-123", "datasetId": "ds-1", "packageIds": ["single-pkg"]}, status=200, ) @@ -125,12 +125,12 @@ def test_get_workflow_instance_retries_on_401(self, mock_session_manager): """Test that get_workflow_instance retries after 401.""" # First call returns 401 responses.add( - responses.GET, "https://api.test.com/workflows/instances/wf-123", json={"error": "Unauthorized"}, status=401 + responses.GET, "https://api.test.com/compute/workflows/runs/wf-123", json={"error": "Unauthorized"}, status=401 ) # Second call succeeds responses.add( responses.GET, - "https://api.test.com/workflows/instances/wf-123", + "https://api.test.com/compute/workflows/runs/wf-123", json={"uuid": "wf-123", "datasetId": "ds-1", "packageIds": []}, status=200, ) @@ -146,12 +146,12 @@ def test_get_workflow_instance_retries_on_403(self, mock_session_manager): """Test that get_workflow_instance retries after 403.""" # First call returns 403 responses.add( - responses.GET, "https://api.test.com/workflows/instances/wf-123", json={"error": "Forbidden"}, status=403 + responses.GET, "https://api.test.com/compute/workflows/runs/wf-123", json={"error": "Forbidden"}, status=403 ) # Second call succeeds responses.add( responses.GET, - "https://api.test.com/workflows/instances/wf-123", + "https://api.test.com/compute/workflows/runs/wf-123", json={"uuid": "wf-123", "datasetId": "ds-1", "packageIds": []}, status=200, )