diff --git a/.gitignore b/.gitignore index 59da84b..c5d348d 100644 --- a/.gitignore +++ b/.gitignore @@ -27,3 +27,4 @@ MANIFEST # Environments .venv venv/ +.idea/ diff --git a/README.md b/README.md index e9be3a8..0439f6e 100644 --- a/README.md +++ b/README.md @@ -37,6 +37,9 @@ Available options: - `token_duration`: Validity period (in seconds) for retieved authorization tokens. - `aws_access_key_id`: Use a specific AWS access key to authenticate with AWS. - `aws_secret_access_key`: Use a specific AWS secret access key to authenticate with AWS. + - `assume_role`: Role ARN to assume with the current profile name to get the CodeArtifact credentials. + - `assume_role_session_name`: Name to attache to attach for the role session. If not specified, a name will be + selected by AWS SDK. For more explanation of these options see the [AWS CLI documentation](https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-files.html). @@ -56,6 +59,9 @@ profile_name=default aws_access_key_id=xxxxxxxxx aws_secret_access_key=xxxxxxxxx +# Assume the following role to obtain the credentials +assume_role=arn:aws:iam::xxxxxxxxx:role/xxxxxxxxx + ``` ### Multiple Section Configuration (EXPERIMENTAL) diff --git a/keyrings/codeartifact.py b/keyrings/codeartifact.py index 1586e0e..1dcb345 100644 --- a/keyrings/codeartifact.py +++ b/keyrings/codeartifact.py @@ -156,13 +156,7 @@ def get_password(self, service, username): name=repository_name, ) - # Create session with any supplied configuration. - session = boto3.Session( - region_name=region, - profile_name=config.get("profile_name"), - aws_access_key_id=config.get("aws_access_key_id"), - aws_secret_access_key=config.get("aws_secret_access_key"), - ) + session = self.get_boto_session(region=region, config=config) # Create a CodeArtifact client for this repository's region. client = session.client("codeartifact", region_name=region) @@ -193,3 +187,30 @@ def set_password(self, service, username, password): def delete_password(self, service, username): # Defer deleting a password to the next backend raise NotImplementedError() + + @staticmethod + def get_boto_session(*, region, config): + should_assume_role = config.get("assume_role") + + # Create session with any supplied configuration. + session = boto3.Session( + region_name=region, + profile_name=config.get("profile_name"), + aws_access_key_id=config.get("aws_access_key_id"), + aws_secret_access_key=config.get("aws_secret_access_key"), + ) + + if should_assume_role is not None: + assumed_role = session.client("sts").assume_role( + RoleArn=config["assume_role"], + RoleSessionName=config.get( + "assume_role_session_name", "KeyRingsCodeArtifact" + ), + ) + return boto3.Session( + aws_access_key_id=assumed_role["Credentials"]["AccessKeyId"], + aws_secret_access_key=assumed_role["Credentials"]["SecretAccessKey"], + aws_session_token=assumed_role["Credentials"]["SessionToken"], + ) + + return session diff --git a/requirements-test.txt b/requirements-test.txt index c68f3a0..23383e4 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -1,2 +1,3 @@ pytest >= 6 pytest-cov +pytest-mock diff --git a/tests/test_backend.py b/tests/test_backend.py index c423966..02ed317 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -8,11 +8,21 @@ from urllib.parse import urlunparse from botocore.client import BaseClient from datetime import datetime, timedelta -from keyrings.codeartifact import CodeArtifactBackend +from keyrings.codeartifact import CodeArtifactBackend, CodeArtifactKeyringConfig @pytest.fixture -def backend(): +def mocked_keyring_config(mocker): + mock_config_instance = mocker.create_autospec( + CodeArtifactKeyringConfig, spec_set=True + ) + mock_config = mocker.patch("keyrings.codeartifact.CodeArtifactKeyringConfig") + mock_config.return_value = mock_config_instance + return mock_config_instance + + +@pytest.fixture +def backend(mocked_keyring_config): # Find the system-wide keyring. original = keyring.get_keyring() @@ -33,6 +43,55 @@ def codeartifact_pypi_url(domain, owner, region, name): return codeartifact_url(domain, owner, region, f"/pypi/{name}/") +def make_check_codeartifact_api_call(*, config, domain, domain_owner): + assumed_role = False + assume_role = config.get("assume_role") + assume_session_name = config.get("assume_session_name") + should_assume_role = assume_role is not None + + def _make_api_call(client, *args, **kwargs): + nonlocal assumed_role + if should_assume_role and not assumed_role: + # We should only ever call GetAuthorizationToken + assert args[0] == "AssumeRole" + + # We should only ever supply these parameters. + assert args[1]["RoleArn"] == assume_role + if assume_session_name is not None: + assert args[1]["RoleSessionName"] == assume_session_name + assumed_role = True + return { + "Credentials": { + "AccessKeyId": "", + "SecretAccessKey": "", + "SessionToken": "", + } + } + else: + assert assumed_role == should_assume_role + + # We should only ever call GetAuthorizationToken + assert args[0] == "GetAuthorizationToken" + + # We should only ever supply these parameters. + assert args[1]["domain"] == domain + assert args[1]["domainOwner"] == domain_owner + assert args[1]["durationSeconds"] == 3600 + + tzinfo = datetime.now().astimezone().tzinfo + current_time = datetime.now(tz=tzinfo) + + # Compute the expiration based on the current timestamp. + expiration = timedelta(seconds=args[1]["durationSeconds"]) + + return { + "authorizationToken": "TOKEN", + "expiration": current_time + expiration, + } + + return _make_api_call + + def test_set_password_raises(backend): with pytest.raises(NotImplementedError): keyring.set_password("service", "username", "password") @@ -67,29 +126,34 @@ def test_get_credential_invalid_path(backend, service): assert not keyring.get_credential(service, None) -def test_get_credential_supported_host(backend, monkeypatch): - def _make_api_call(client, *args, **kwargs): - # We should only ever call GetAuthorizationToken - assert args[0] == "GetAuthorizationToken" - - # We should only ever supply these parameters. - assert args[1]["domain"] == "domain" - assert args[1]["domainOwner"] == "000000000000" - assert args[1]["durationSeconds"] == 3600 - - tzinfo = datetime.now().astimezone().tzinfo - current_time = datetime.now(tz=tzinfo) - - # Compute the expiration based on the current timestamp. - expiration = timedelta(seconds=args[1]["durationSeconds"]) - - return { - "authorizationToken": "TOKEN", - "expiration": current_time + expiration, - } - - monkeypatch.setattr(BaseClient, "_make_api_call", _make_api_call) - url = codeartifact_pypi_url("domain", "000000000000", "region", "name") +@pytest.mark.parametrize( + ["config"], + [ + ({},), + ( + { + "assume_role": "arn:aws:iam::000000000000:role/some-role", + "assume_role_session_name": "SomeSessionName", + }, + ), + ], +) +def test_get_credential_supported_host( + backend, config, mocked_keyring_config, monkeypatch +): + domain = "domain" + domain_owner = "000000000000" + + monkeypatch.setattr( + BaseClient, + "_make_api_call", + make_check_codeartifact_api_call( + config=config, domain=domain, domain_owner=domain_owner + ), + ) + mocked_keyring_config.lookup.return_value = config + + url = codeartifact_pypi_url(domain, domain_owner, "region", "name") credentials = backend.get_credential(url, None) assert credentials.username == "aws" diff --git a/tests/test_config.py b/tests/test_config.py index a6f68bd..5594d1a 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,21 +1,12 @@ # test_config.py -- config parsing tests - import pytest from io import StringIO -from os.path import dirname, join +from pathlib import Path from keyrings.codeartifact import CodeArtifactKeyringConfig - -@pytest.fixture -def config_file(): - working_directory = dirname(__file__) - - def _config_file(path): - return join(working_directory, "config", path) - - return _config_file +CONFIG_DIR = Path(__file__).parent / "config" @pytest.mark.parametrize( @@ -26,8 +17,8 @@ def _config_file(path): ("domain", "00000000", "ca-central-1", "repository"), ], ) -def test_parse_single_section_only(config_file, parameters): - config = CodeArtifactKeyringConfig(config_file("single_section.cfg")) +def test_parse_single_section_only(parameters): + config = CodeArtifactKeyringConfig(CONFIG_DIR / "single_section.cfg") # A single section has only one configuration. values = config.lookup(*parameters) @@ -89,8 +80,8 @@ def test_bogus_config_returns_empty_configuration(config_data): ), ], ) -def test_multiple_sections_with_defaults(config_file, query, expected): - path = config_file("multiple_sections_with_default.cfg") +def test_multiple_sections_with_defaults(query, expected): + path = CONFIG_DIR / "multiple_sections_with_default.cfg" config = CodeArtifactKeyringConfig(path) values = config.lookup(**query) @@ -115,8 +106,8 @@ def test_multiple_sections_with_defaults(config_file, query, expected): ), ], ) -def test_multiple_sections_no_defaults(config_file, query, expected): - path = config_file("multiple_sections_no_default.cfg") +def test_multiple_sections_no_defaults(query, expected): + path = CONFIG_DIR / "multiple_sections_no_default.cfg" config = CodeArtifactKeyringConfig(path) values = config.lookup(**query)