From afcd310fefb9b916a0394eb216566e55b9ef82e3 Mon Sep 17 00:00:00 2001 From: stoader Date: Fri, 13 Feb 2026 16:42:13 +0100 Subject: [PATCH 1/2] Support minting Azure access token from Vault sourced Azure credentials also from Vault sourced GCP service account keys --- go.mod | 3 +- go.sum | 2 + pkg/aws/creds.go | 10 +- pkg/azure/creds.go | 28 ++-- pkg/gcp/creds.go | 10 +- pkg/generic/creds.go | 10 +- pkg/oauth2cc/oauth2cc.go | 5 +- pkg/oci/creds.go | 20 +-- pkg/util/credential.go | 8 + pkg/vault/azure.go | 109 +++++++++++++ pkg/vault/creds.go | 326 +++++++++++++++++++++++++++++++-------- pkg/vault/gcp.go | 115 ++++++++++++++ pkg/vault/option.go | 58 +++++++ 13 files changed, 574 insertions(+), 130 deletions(-) create mode 100644 pkg/vault/azure.go create mode 100644 pkg/vault/gcp.go diff --git a/go.mod b/go.mod index 8f62fe0..1c5cef9 100644 --- a/go.mod +++ b/go.mod @@ -42,6 +42,7 @@ require ( github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.3 // indirect github.com/aws/smithy-go v1.22.4 // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect + github.com/cenkalti/backoff/v5 v5.0.3 github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/emicklei/go-restful/v3 v3.11.0 // indirect github.com/evanphx/json-patch/v5 v5.9.11 // indirect @@ -52,7 +53,7 @@ require ( github.com/go-openapi/jsonpointer v0.21.0 // indirect github.com/go-openapi/jsonreference v0.20.2 // indirect github.com/go-openapi/swag v0.23.0 // indirect - github.com/go-viper/mapstructure/v2 v2.4.0 // indirect + github.com/go-viper/mapstructure/v2 v2.4.0 github.com/gogo/protobuf v1.3.2 // indirect github.com/google/gnostic-models v0.6.9 // indirect github.com/google/go-cmp v0.7.0 // indirect diff --git a/go.sum b/go.sum index 8f03c6d..61dfc65 100644 --- a/go.sum +++ b/go.sum @@ -78,6 +78,8 @@ github.com/blang/semver/v4 v4.0.0 h1:1PFHFE6yCCTv8C1TeyNNarDzntLi7wMI5i/pzqYIsAM github.com/blang/semver/v4 v4.0.0/go.mod h1:IbckMUScFkM3pff0VJDNKRiT6TG/YpiHIM2yvyW5YoQ= github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= +github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1xcsSM= +github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cncf/xds/go v0.0.0-20250326154945-ae57f3c0d45f h1:C5bqEmzEPLsHm9Mv73lSE9e9bKV23aB1vxOsmZrkl3k= diff --git a/pkg/aws/creds.go b/pkg/aws/creds.go index b3778e2..d5f9f39 100644 --- a/pkg/aws/creds.go +++ b/pkg/aws/creds.go @@ -78,10 +78,7 @@ func (cp *credentialsProvider) refreshCredentialsLoop(ctx context.Context, provi // Get credentials awsCreds, err := provider.Retrieve(ctx) if err != nil { - util.SendToChannel(credsChan, credential.Result{ - Credential: nil, - Err: errors.WrapIf(err, "failed to retrieve credentials"), - }) + util.SendErrorToChannel(credsChan, errors.WrapIf(err, "failed to retrieve credentials")) return } @@ -111,10 +108,7 @@ func (cp *credentialsProvider) refreshCredentialsLoop(ctx context.Context, provi // If credentials are already expired, this is an error if timeUntilExpiry <= 0 { - util.SendToChannel(credsChan, credential.Result{ - Credential: nil, - Err: errors.NewWithDetails("received already expired credentials", "expiresAt", awsCreds.Expires), - }) + util.SendErrorToChannel(credsChan, errors.NewWithDetails("received already expired credentials", "expiresAt", awsCreds.Expires)) return } diff --git a/pkg/azure/creds.go b/pkg/azure/creds.go index 99177a9..95469db 100644 --- a/pkg/azure/creds.go +++ b/pkg/azure/creds.go @@ -86,10 +86,17 @@ func (cp *credentialsProvider) refreshCredentialsLoop( Scopes: []string{cfg.scope}, }) if err != nil { - util.SendToChannel(credsChan, credential.Result{ - Credential: nil, - Err: errors.WrapIf(err, "failed to retrieve credentials"), - }) + util.SendErrorToChannel(credsChan, errors.WrapIf(err, "failed to retrieve credentials")) + + return + } + + // Calculate when to refresh + timeUntilExpiry := time.Until(token.ExpiresOn) + + // If credentials are already expired, this is an error + if timeUntilExpiry <= 0 { + util.SendErrorToChannel(credsChan, errors.NewWithDetails("received already expired credentials", "expiresAt", token.ExpiresOn)) return } @@ -107,19 +114,6 @@ func (cp *credentialsProvider) refreshCredentialsLoop( }) cp.logger.V(2).Info("Sent credentials", "expires", token.ExpiresOn) - // Calculate when to refresh - timeUntilExpiry := time.Until(token.ExpiresOn) - - // If credentials are already expired, this is an error - if timeUntilExpiry <= 0 { - util.SendToChannel(credsChan, credential.Result{ - Credential: nil, - Err: errors.NewWithDetails("received already expired credentials", "expiresAt", token.ExpiresOn), - }) - - return - } - refreshBuffer := util.CalculateRefreshBuffer(timeUntilExpiry) refreshTime := timeUntilExpiry - refreshBuffer diff --git a/pkg/gcp/creds.go b/pkg/gcp/creds.go index c7780ed..22a04ce 100644 --- a/pkg/gcp/creds.go +++ b/pkg/gcp/creds.go @@ -90,10 +90,7 @@ func (cp *credentialsProvider) refreshCredentialsLoop( for { accessToken, err := genAccessTokenFunc() if err != nil { - util.SendToChannel(credsChan, credential.Result{ - Credential: nil, - Err: errors.WrapIf(err, "failed to get access token"), - }) + util.SendErrorToChannel(credsChan, errors.WrapIf(err, "failed to get access token")) return } @@ -113,10 +110,7 @@ func (cp *credentialsProvider) refreshCredentialsLoop( // If credentials are already expired, this is an error if timeUntilExpiry <= 0 { - util.SendToChannel(credsChan, credential.Result{ - Credential: nil, - Err: errors.NewWithDetails("received already expired credentials", "expiresAt", accessToken.Expiry), - }) + util.SendErrorToChannel(credsChan, errors.NewWithDetails("received already expired credentials", "expiresAt", accessToken.Expiry)) return } diff --git a/pkg/generic/creds.go b/pkg/generic/creds.go index de6bb00..23d1763 100644 --- a/pkg/generic/creds.go +++ b/pkg/generic/creds.go @@ -159,10 +159,7 @@ loop: token, err = tokenProvider.GetToken(ctx, opts...) if err != nil { - util.SendToChannel(credsChan, credential.Result{ - Credential: nil, - Err: errors.WrapIf(err, "could not create token"), - }) + util.SendErrorToChannel(credsChan, errors.WrapIf(err, "could not create token")) return } @@ -190,10 +187,7 @@ loop: // If credentials are already expired, this is an error if timeUntilExpiry <= 0 { - util.SendToChannel(credsChan, credential.Result{ - Credential: nil, - Err: errors.NewWithDetails("received already expired credentials", "expiresAt", token.ExpiresAt.Format(time.DateTime)), - }) + util.SendErrorToChannel(credsChan, errors.NewWithDetails("received already expired credentials", "expiresAt", token.ExpiresAt.Format(time.DateTime))) return } diff --git a/pkg/oauth2cc/oauth2cc.go b/pkg/oauth2cc/oauth2cc.go index 2571ff8..af94d47 100644 --- a/pkg/oauth2cc/oauth2cc.go +++ b/pkg/oauth2cc/oauth2cc.go @@ -262,10 +262,7 @@ func (r *tokenRetriever) tokenRefresherLoop(ctx context.Context) { // If credentials are already expired, this is an error if timeUntilExpiry <= 0 { - util.SendToChannel(r.ch, credential.Result{ - Credential: nil, - Err: errors.NewWithDetails("received already expired token", "expiresAt", token.Expiry), - }) + util.SendErrorToChannel(r.ch, errors.NewWithDetails("received already expired token", "expiresAt", token.Expiry)) return } diff --git a/pkg/oci/creds.go b/pkg/oci/creds.go index 27b41b0..43e7f0b 100644 --- a/pkg/oci/creds.go +++ b/pkg/oci/creds.go @@ -95,30 +95,21 @@ func (cp *credentialsProvider) refreshCredentialsLoop(ctx context.Context, cfg * for { idToken, err := cfg.identityTokenProvider.GetToken(ctx) if err != nil { - util.SendToChannel(credsChan, credential.Result{ - Credential: nil, - Err: errors.WrapIf(err, "failed to get identity token"), - }) + util.SendErrorToChannel(credsChan, errors.WrapIf(err, "failed to get identity token")) return } authToken, err := exchangeToken(ctx, tokenEndpoint, cfg.clientID, cfg.clientSecret, idToken.Token, publicKey) if err != nil { - util.SendToChannel(credsChan, credential.Result{ - Credential: nil, - Err: errors.WrapIf(err, "token exchange failed"), - }) + util.SendErrorToChannel(credsChan, errors.WrapIf(err, "token exchange failed")) return } expTime, err := getTokenExpiration(authToken) if err != nil { - util.SendToChannel(credsChan, credential.Result{ - Credential: nil, - Err: errors.WrapIf(err, "failed to get expiration time of the received UPST"), - }) + util.SendErrorToChannel(credsChan, errors.WrapIf(err, "failed to get expiration time of the received UPST")) return } @@ -134,10 +125,7 @@ func (cp *credentialsProvider) refreshCredentialsLoop(ctx context.Context, cfg * // If credentials are already expired, this is an error if timeUntilExpiry <= 0 { - util.SendToChannel(credsChan, credential.Result{ - Credential: nil, - Err: errors.NewWithDetails("received already expired credentials", "expiresAt", ociCredential.ExpiresAt), - }) + util.SendErrorToChannel(credsChan, errors.NewWithDetails("received already expired credentials", "expiresAt", ociCredential.ExpiresAt)) return } diff --git a/pkg/util/credential.go b/pkg/util/credential.go index 8a454ac..1fa0615 100644 --- a/pkg/util/credential.go +++ b/pkg/util/credential.go @@ -44,3 +44,11 @@ func SendToChannel(credsChan chan credential.Result, result credential.Result) { credsChan <- result } } + +// SendErrorToChannel is a helper function to send an error result to the credentials channel. +func SendErrorToChannel(credsChan chan credential.Result, err error) { + SendToChannel(credsChan, credential.Result{ + Credential: nil, + Err: err, + }) +} diff --git a/pkg/vault/azure.go b/pkg/vault/azure.go new file mode 100644 index 0000000..96cd4d7 --- /dev/null +++ b/pkg/vault/azure.go @@ -0,0 +1,109 @@ +// Copyright (c) 2026 Riptides Labs, Inc. +// SPDX-License-Identifier: MIT + +package vault + +import ( + "context" + "time" + + "emperror.dev/errors" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azidentity" + "github.com/cenkalti/backoff/v5" + "github.com/go-logr/logr" + + "go.riptides.io/tokenex/pkg/credential" + "go.riptides.io/tokenex/pkg/util" +) + +// VaultAzureSecret represents the structure of the secret data returned by Vault's Azure secrets engine when configured to return Azure credentials. +type vaultAzureSecret struct { + ClientID string `mapstructure:"client_id"` + ClientSecret string `mapstructure:"client_secret"` +} + +type azureAccessTokenProvider struct { + tenantID string + clientID string + clientSecret string + scopes []string + + logger logr.Logger +} + +// GetCredentials begins the process of exchanging the client ID and client secret for an Azure access token and refreshing it as needed until the context is canceled. +func (r *azureAccessTokenProvider) GetCredentials(ctx context.Context, credsChan chan credential.Result) { + b := backoff.NewExponentialBackOff() + + for { + azClientCreds, err := azidentity.NewClientSecretCredential(r.tenantID, r.clientID, r.clientSecret, nil) + if err != nil { + util.SendErrorToChannel(credsChan, errors.WrapIf(err, "failed to create Azure client secret credential")) + + return + } + + token, err := backoff.Retry(ctx, func() (azcore.AccessToken, error) { + token, err := azClientCreds.GetToken(ctx, policy.TokenRequestOptions{ + Scopes: r.scopes, + }) + + return token, err + }, backoff.WithBackOff(b), backoff.WithMaxElapsedTime(30*time.Second)) + if err != nil { + util.SendErrorToChannel(credsChan, errors.WrapIf(err, "failed to get Azure access token from client secret credential")) + + return + } + + // Calculate when to refresh + timeUntilExpiry := time.Until(token.ExpiresOn) + + // If credentials are already expired, this is an error + if timeUntilExpiry <= 0 { + util.SendErrorToChannel(credsChan, errors.NewWithDetails("received already expired access token from Azure", "expiresAt", token.ExpiresOn)) + + return + } + + // Send credentials + azureCredential := &credential.Oauth2Creds{ + AccessToken: token.Token, + TokenType: "Bearer", + Expiry: token.ExpiresOn, + } + util.SendToChannel(credsChan, credential.Result{ + Credential: azureCredential, + Err: nil, + Event: credential.UpdateEventType, + }) + r.logger.V(2).Info("Sent credentials", "expires", token.ExpiresOn) + + refreshBuffer := util.CalculateRefreshBuffer(timeUntilExpiry) + refreshTime := timeUntilExpiry - refreshBuffer + + if !token.RefreshOn.IsZero() { + // if refresh time is recommended in the received token, use that + r.logger.V(2).Info("Using RefreshOn time from token", "refreshOn", token.RefreshOn) + + rt := time.Until(token.RefreshOn) + if rt > 0 { + refreshTime = rt + } else { + r.logger.V(2).Info("RefreshOn time is in the past, using calculated refresh time", "refreshTime", refreshTime) + } + } + + select { + case <-ctx.Done(): + r.logger.V(1).Info("Context cancelled, stopping credential refresh") + + return + case <-time.After(refreshTime): + // Continue to next iteration to refresh + r.logger.V(2).Info("Refreshing access token") + } + } +} diff --git a/pkg/vault/creds.go b/pkg/vault/creds.go index 584a018..e893013 100644 --- a/pkg/vault/creds.go +++ b/pkg/vault/creds.go @@ -9,6 +9,7 @@ import ( "emperror.dev/errors" "github.com/go-logr/logr" + "github.com/go-viper/mapstructure/v2" jwtauth "github.com/openbao/openbao/api/auth/jwt/v2" "github.com/openbao/openbao/api/v2" @@ -21,6 +22,72 @@ import ( // ErrDataNotFound is returned when secret at a the specified secret path does not exist in Vault. var ErrDataNotFound = errors.New("data not found") +// gcpCredentialConfig holds configuration specific to GCP credentials returned by the Google Cloud secrets engine in Vault. +type gcpCredentialConfig struct { + // Exchange service account key material for a short-lived access token + // This option is used only when the Google Cloud secrets engine is configured to return service account keys. + exchangeSAKeyForAccessToken bool + + // accessTokenScopes specifies the scopes to request when exchanging a service account key for an access token. + // This option is used only when exchangeSAKeyForAccessToken is true. + // If not set, it defaults to ["https://www.googleapis.com/auth/cloud-platform"]. + accessTokenScopes []string +} + +func (g *gcpCredentialConfig) ExchangeSAKeyForAccessToken() bool { + if g != nil { + return g.exchangeSAKeyForAccessToken + } + + return false +} + +func (g *gcpCredentialConfig) AccessTokenScopes() []string { + if g != nil { + return g.accessTokenScopes + } + + return nil +} + +// azureCredentialConfig holds configuration specific to Azure credentials returned by the Azure secrets engine in Vault. +type azureCredentialConfig struct { + // exchangeForAccessToken indicates whether to exchange the client ID and client secret returned by Vault for an Azure access token. + exchangeForAccessToken bool + + // tenantID is the Azure tenant ID to use when exchanging Vault credentials for an Azure access token. This is required when exchangeForAccessToken is true. + tenantID string + + // accessTokenScopes specifies the scopes to request when exchanging Azure credentials for an access token. + // This option is used only when exchangeForAccessToken is true. + // If not set, it defaults to ["https://management.azure.com/.default"]. + accessTokenScopes []string +} + +func (a *azureCredentialConfig) ExchangeForAccessToken() bool { + if a != nil { + return a.exchangeForAccessToken + } + + return false +} + +func (a *azureCredentialConfig) TenantID() string { + if a != nil { + return a.tenantID + } + + return "" +} + +func (a *azureCredentialConfig) AccessTokenScopes() []string { + if a != nil { + return a.accessTokenScopes + } + + return nil +} + // credentialsConfig holds the configuration for GetCredentials. type credentialsConfig struct { jwtAuthMethodPath string @@ -29,6 +96,9 @@ type credentialsConfig struct { pollInterval time.Duration reqData map[string][]string identityTokenProvider token.IdentityTokenProvider + + gcp *gcpCredentialConfig + azure *azureCredentialConfig } // credentialData holds the secret data and expiration information returned from Vault. @@ -98,6 +168,12 @@ func validateConfig(cfg *credentialsConfig) error { return errors.New("identity token provider must be specified") } + if cfg.azure != nil { + if cfg.azure.exchangeForAccessToken && cfg.azure.tenantID == "" { + return errors.New("Azure tenant ID must be specified when exchange for access token is enabled") + } + } + return nil } @@ -133,17 +209,38 @@ func (cp *credentialsProvider) authenticateWithJWT(ctx context.Context, idToken return nil } +func (cp *credentialsProvider) authenticate(ctx context.Context, cfg *credentialsConfig) error { + // Get ID token + idToken, err := cfg.identityTokenProvider.GetToken(ctx) + if err != nil { + return errors.WrapIf(err, "failed to get ID token") + } + + // Authenticate with Vault using JWT + err = cp.authenticateWithJWT(ctx, idToken, cfg.jwtAuthMethodPath, cfg.jwtAuthRoleName) + if err != nil { + return errors.WrapIfWithDetails(err, "failed to authenticate with Vault", "auth_path", cfg.jwtAuthMethodPath, "role", cfg.jwtAuthRoleName) + } + + return nil +} + // retrieveCredentials retrieves a secret from Vault at the specified path. // For dynamic secrets (with a lease), expiration is based on the lease duration. // For static secrets (no lease), expiration is based on the secret's TTL if available, // or falls back to the poll interval to ensure periodic refresh. -func (cp *credentialsProvider) retrieveCredentials(ctx context.Context, secretPath string, pollInterval time.Duration, reqData map[string][]string) (*credentialData, error) { - secret, err := cp.client.Logical().ReadWithDataWithContext(ctx, secretPath, reqData) +func (cp *credentialsProvider) retrieveCredentials(ctx context.Context, cfg *credentialsConfig) (*credentialData, error) { + err := cp.authenticate(ctx, cfg) if err != nil { - return nil, errors.WrapIfWithDetails(err, "failed to read secret", "path", secretPath) + return nil, errors.WrapIf(err, "failed to authenticate with Vault") + } + + secret, err := cp.client.Logical().ReadWithDataWithContext(ctx, cfg.secretFullPath, cfg.reqData) + if err != nil { + return nil, errors.WrapIfWithDetails(err, "failed to read secret", "path", cfg.secretFullPath) } if secret == nil { - return nil, errors.WithDetails(ErrDataNotFound, "path", secretPath) + return nil, errors.WithDetails(ErrDataNotFound, "path", cfg.secretFullPath) } var expiresAt time.Time @@ -164,7 +261,7 @@ func (cp *credentialsProvider) retrieveCredentials(ctx context.Context, secretPa // If no TTL is present, fall back to using the poll interval to ensure the secret is periodically refreshed. ttl, err := secret.TokenTTL() if err != nil { - return nil, errors.WrapIfWithDetails(err, "failed to get secret TTL", "path", secretPath) + return nil, errors.WrapIfWithDetails(err, "failed to get secret TTL", "path", cfg.secretFullPath) } // Add a small leeway to allow Vault to rotate static credentials before we attempt to refresh. @@ -172,7 +269,7 @@ func (cp *credentialsProvider) retrieveCredentials(ctx context.Context, secretPa if ttl == 0 { // No TTL means the secret does not expire and Vault will not rotate it automatically. // In this case, set the expiration to the poll interval to ensure we periodically check for updates to the secret in Vault. - ttl = pollInterval + ttl = cfg.pollInterval staticCredsRotationLeeway = 0 // No leeway needed since Vault won't rotate this credential. } expiresAt = time.Now().Add(ttl) @@ -184,89 +281,182 @@ func (cp *credentialsProvider) retrieveCredentials(ctx context.Context, secretPa }, nil } -// refreshCredentialsLoop handles the credential retrieval and refresh loop. -func (cp *credentialsProvider) refreshCredentialsLoop(ctx context.Context, cfg *credentialsConfig, credsChan chan credential.Result) { - for { - // Get ID token - idToken, err := cfg.identityTokenProvider.GetToken(ctx) - if err != nil { - util.SendToChannel(credsChan, credential.Result{ - Credential: nil, - Err: errors.WrapIf(err, "failed to get ID token"), - }) +// startGcpAccessTokenProvider starts a worker goroutine that exchanges a GCP service account key for an access token and refreshes it as needed until the context is canceled. +func (cp *credentialsProvider) startGcpAccessTokenProvider(ctx context.Context, secretData map[string]any, scopes []string) (<-chan credential.Result, error) { + var serviceAccountKeySecret gcpServiceAccountKeySecret + if err := mapstructure.Decode(secretData, &serviceAccountKeySecret); err != nil { + return nil, errors.WrapIf(err, "failed to decode Vault secret data into GCP service account key credentials structure") + } - return - } + keyJSON, err := serviceAccountKeySecret.ServiceAccountKeyJSON() + if err != nil { + return nil, errors.WrapIf(err, "failed to get service account key JSON") + } - // Authenticate with Vault using JWT - err = cp.authenticateWithJWT(ctx, idToken, cfg.jwtAuthMethodPath, cfg.jwtAuthRoleName) - if err != nil { - util.SendToChannel(credsChan, credential.Result{ - Credential: nil, - Err: errors.WrapIf(err, "failed to authenticate with Vault"), - }) + if len(scopes) == 0 { + scopes = []string{"https://www.googleapis.com/auth/cloud-platform"} + } - return - } + provider := gcpAccessTokenProvider{ + serviceAccountKeyJSON: keyJSON, + scopes: scopes, + logger: logr.FromContextOrDiscard(ctx).WithName("gcp_access_token_provider"), + } - // Retrieve the secret - creds, err := cp.retrieveCredentials(ctx, cfg.secretFullPath, cfg.pollInterval, cfg.reqData) - if err != nil { - util.SendToChannel(credsChan, credential.Result{ - Credential: nil, - Err: errors.WrapIf(err, "failed to retrieve secret"), - }) + credsChan := make(chan credential.Result, 1) + go func() { + defer close(credsChan) + provider.GetCredentials(ctx, credsChan) + }() - return - } + return credsChan, nil +} - // Calculate when to refresh - timeUntilExpiry := time.Until(creds.ExpiresAt) +// startAzureAccessTokenProvider starts a worker goroutine that exchanges Azure credentials for an access token and refreshes it as needed until the context is canceled. +func (cp *credentialsProvider) startAzureAccessTokenProvider(ctx context.Context, secretData map[string]any, tenantID string, scopes []string) (<-chan credential.Result, error) { + var azureSecret vaultAzureSecret - // If credentials are already expired, this is an error - if timeUntilExpiry <= 0 { - util.SendToChannel(credsChan, credential.Result{ - Credential: nil, - Err: errors.NewWithDetails("received already expired credentials", "secret_path", cfg.secretFullPath, "expiresAt", creds.ExpiresAt), - }) + if err := mapstructure.Decode(secretData, &azureSecret); err != nil { + return nil, errors.WrapIf(err, "failed to decode Vault secret data into Azure credentials structure") + } - return - } + if len(scopes) == 0 { + scopes = []string{"https://management.azure.com/.default"} + } - // Send credentials - util.SendToChannel(credsChan, credential.Result{ - Credential: &credential.VaultSecret{ - Data: creds.Data, - }, - Err: nil, - Event: credential.UpdateEventType, - }) + provider := azureAccessTokenProvider{ + tenantID: tenantID, + clientID: azureSecret.ClientID, + clientSecret: azureSecret.ClientSecret, + scopes: scopes, + } - cp.logger.V(2).Info("Published Vault secret", "secret_path", cfg.secretFullPath, "expiresAt", creds.ExpiresAt) + credsChan := make(chan credential.Result, 1) + go func() { + defer close(credsChan) + provider.GetCredentials(ctx, credsChan) + }() - // Apply refresh buffer - var refreshBuffer, refreshTime time.Duration + return credsChan, nil +} - if !creds.RefreshOn.IsZero() { - // if refresh time is specified in the received credentials, use that - cp.logger.V(2).Info("Using RefreshOn time from credentials", "refreshOn", creds.RefreshOn) +// shouldStartWorker determines whether a worker goroutine should be started to handle credential exchange and refresh based on the configuration. +func shouldStartWorker(cfg *credentialsConfig) bool { + return cfg.gcp.ExchangeSAKeyForAccessToken() || cfg.azure.ExchangeForAccessToken() +} - refreshTime = time.Until(creds.RefreshOn) - } else { - refreshBuffer = util.CalculateRefreshBuffer(timeUntilExpiry) - refreshTime = timeUntilExpiry - refreshBuffer +func (cp *credentialsProvider) startWorker(ctx context.Context, cfg *credentialsConfig, credsData map[string]any) (<-chan credential.Result, context.CancelFunc, error) { + if cfg.gcp.ExchangeSAKeyForAccessToken() { + // get access token using the service account key data in the Vault secret + // and send the access token through the channel instead of the raw service account key data + ctx, cancel := context.WithCancel(ctx) + credsChan, err := cp.startGcpAccessTokenProvider(ctx, credsData, cfg.gcp.accessTokenScopes) + if err != nil { + return nil, cancel, errors.WrapIf(err, "failed to start GCP access token provider") } - cp.logger.V(1).Info("Scheduling credential refresh", "refreshIn", refreshTime, "refreshBuffer", refreshBuffer, "secret_path", cfg.secretFullPath) + return credsChan, cancel, nil + } + if cfg.azure.ExchangeForAccessToken() { + // get access token using the client ID and client secret data in the Vault secret + ctx, cancel := context.WithCancel(ctx) + credsChan, err := cp.startAzureAccessTokenProvider(ctx, credsData, cfg.azure.TenantID(), cfg.azure.AccessTokenScopes()) + if err != nil { + return nil, cancel, errors.WrapIf(err, "failed to start Azure access token provider") + } + + return credsChan, cancel, nil + } + + return nil, nil, nil +} + +// refreshCredentialsLoop handles the credential retrieval and refresh loop. +func (cp *credentialsProvider) refreshCredentialsLoop(ctx context.Context, cfg *credentialsConfig, credsChan chan credential.Result) { + var cancelWorker context.CancelFunc + var workerChan <-chan credential.Result + + var refreshBuffer, refreshTime time.Duration + + logger := cp.logger.WithValues("secret_path", cfg.secretFullPath) + + for { select { case <-ctx.Done(): - cp.logger.V(1).Info("Context cancelled, stopping credential refresh") + logger.V(1).Info("Context cancelled, stopping credential refresh") return case <-time.After(refreshTime): - // Continue to next iteration to refresh - cp.logger.V(2).Info("Refreshing credentials", "secret_path", cfg.secretFullPath) + logger.V(2).Info("Refreshing credentials") + if cancelWorker != nil { + cancelWorker() // cancel any existing worker goroutine before starting a new one to refresh credentials + cancelWorker = nil + } + + // Retrieve the secret + creds, err := cp.retrieveCredentials(ctx, cfg) + if err != nil { + util.SendErrorToChannel(credsChan, errors.WrapIfWithDetails(err, "failed to retrieve secret", "secret_path", cfg.secretFullPath)) + + return + } + + // Calculate when to refresh + timeUntilExpiry := time.Until(creds.ExpiresAt) + + // If credentials are already expired, this is an error + if timeUntilExpiry <= 0 { + util.SendErrorToChannel(credsChan, errors.NewWithDetails("received already expired credentials", "secret_path", cfg.secretFullPath, "expiresAt", creds.ExpiresAt)) + + return + } + + if shouldStartWorker(cfg) { + workerChan, cancelWorker, err = cp.startWorker(logr.NewContext(ctx, logger), cfg, creds.Data) + if err != nil { + util.SendErrorToChannel(credsChan, errors.WrapIf(err, "failed to start credential worker")) + + return + } + } else { + // Send credentials + util.SendToChannel(credsChan, credential.Result{ + Credential: &credential.VaultSecret{ + Data: creds.Data, + }, + Err: nil, + Event: credential.UpdateEventType, + }) + logger.V(2).Info("Published Vault secret", "secret_path", cfg.secretFullPath, "expiresAt", creds.ExpiresAt) + } + + if !creds.RefreshOn.IsZero() { + // if refresh time is specified in the received credentials, use that + logger.V(2).Info("Using RefreshOn time from credentials", "refreshOn", creds.RefreshOn) + + refreshTime = time.Until(creds.RefreshOn) + } else { + refreshBuffer = util.CalculateRefreshBuffer(timeUntilExpiry) + refreshTime = timeUntilExpiry - refreshBuffer + } + + logger.V(1).Info("Scheduling credential refresh", "refreshIn", refreshTime, "refreshBuffer", refreshBuffer, "secret_path", cfg.secretFullPath) + case result, ok := <-workerChan: + if !ok { + // Worker channel closed, likely due to an error in the worker goroutine + util.SendErrorToChannel(credsChan, errors.New("credential worker stopped unexpectedly")) + + return + } + + if result.Err != nil { + util.SendErrorToChannel(credsChan, result.Err) + + return + } + + util.SendToChannel(credsChan, result) } } } diff --git a/pkg/vault/gcp.go b/pkg/vault/gcp.go new file mode 100644 index 0000000..f252b51 --- /dev/null +++ b/pkg/vault/gcp.go @@ -0,0 +1,115 @@ +// Copyright (c) 2026 Riptides Labs, Inc. +// SPDX-License-Identifier: MIT + +package vault + +import ( + "context" + "encoding/base64" + "strings" + "time" + + "emperror.dev/errors" + "github.com/cenkalti/backoff/v5" + "github.com/go-logr/logr" + "golang.org/x/oauth2" + "golang.org/x/oauth2/google" + + "go.riptides.io/tokenex/pkg/credential" + "go.riptides.io/tokenex/pkg/util" +) + +// gcpServiceAccountKeySecret represents the structure of the service account key material returned by Vault's Google Cloud secrets engine when configured to return service account keys. +// It contains the base64-encoded private key data. +type gcpServiceAccountKeySecret struct { + PrivateKeyData string `mapstructure:"private_key_data"` +} + +// ServiceAccountKeyJSON decodes the base64-encoded private key data from the Vault secret and returns it as a byte slice. +func (s *gcpServiceAccountKeySecret) ServiceAccountKeyJSON() ([]byte, error) { + key, err := base64.StdEncoding.DecodeString(s.PrivateKeyData) + if err != nil { + return nil, errors.WrapIf(err, "private_key_data in Vault secret data for GCP credentials is not valid base64 encoded string") + } + + return key, nil +} + +// gcpAccessTokenProvider is responsible for exchanging a GCP service account key for an access token and refreshing it as needed until stopped. +type gcpAccessTokenProvider struct { + serviceAccountKeyJSON []byte + scopes []string + + logger logr.Logger +} + +// GetCredentials begins the process of exchanging the service account key for an access token and refreshing it as needed until the context is canceled. +func (r *gcpAccessTokenProvider) GetCredentials(ctx context.Context, credsChan chan credential.Result) { + b := backoff.NewExponentialBackOff() + + for { + // use the service account key to authenticate to GCP and obtain an access token + gcpCreds, err := google.CredentialsFromJSON(ctx, r.serviceAccountKeyJSON, r.scopes...) + if err != nil { + util.SendErrorToChannel(credsChan, errors.WrapIf(err, "failed to obtain GCP credentials from service account key")) + + return + } + + // if the SA key was just created it's posisble that it may take a few seconds for GCP to propagate the key and allow it to be used for authentication. + // This can result in transient errors when trying to exchange the key for an access token. + token, err := backoff.Retry(ctx, func() (*oauth2.Token, error) { + token, err := gcpCreds.TokenSource.Token() + if err != nil { + if strings.Contains(err.Error(), "invalid_grant") { + r.logger.V(2).Info("Received invalid_grant error when exchanging service account key for access token, likely due to GCP propagation delay. Retrying...", "error", err) + + return nil, err + } + + return nil, backoff.Permanent(errors.WrapIf(err, "failed to exchange service account key for access token")) + } + + return token, nil + }, backoff.WithBackOff(b), backoff.WithMaxElapsedTime(30*time.Second)) + if err != nil { + util.SendErrorToChannel(credsChan, errors.WrapIf(err, "failed to obtain access token from GCP using service account key")) + + return + } + + // Calculate when to refresh + timeUntilExpiry := time.Until(token.Expiry) + + // If credentials are already expired, this is an error + if timeUntilExpiry <= 0 { + util.SendErrorToChannel(credsChan, errors.NewWithDetails("received already expired access token from GCP", "expiresAt", token.Expiry)) + + return + } + + gcpCredential := credential.Oauth2Creds(*token) + util.SendToChannel(credsChan, credential.Result{ + Credential: &gcpCredential, + Err: nil, + Event: credential.UpdateEventType, + }) + + r.logger.V(2).Info("Published access token", "expiresAt", token.Expiry) + + refreshBuffer := util.CalculateRefreshBuffer(timeUntilExpiry) + refreshTime := timeUntilExpiry - refreshBuffer + + r.logger.V(0).Info("Scheduling access token refresh", "refreshIn", refreshTime, "refreshBuffer", refreshBuffer) + + select { + case <-ctx.Done(): + r.logger.V(1).Info("Context cancelled, stopping access token refresh") + + return + case <-time.After(refreshTime): + // Continue to next iteration to refresh + r.logger.V(2).Info("Refreshing access token") + } + } +} diff --git a/pkg/vault/option.go b/pkg/vault/option.go index d7ee200..8b30244 100644 --- a/pkg/vault/option.go +++ b/pkg/vault/option.go @@ -98,3 +98,61 @@ func WithRequestData(data map[string][]string) option.Option { c.reqData = data }) } + +// WithGCPServiceAccountKeyExchange configures the credential exchange to handle GCP service account keys returned by Vault's GCP secrets engine. +// This option is only applicable if the Vault secret being retrieved contains GCP service account key data. +// When enabled, if the Vault secret contains GCP service account key data, the credential exchange will use that key data to obtain an access token from GCP and return the access token instead of the raw service account key data. +func WithGCPServiceAccountKeyExchange() option.Option { + return withCredentialsOption(func(c *credentialsConfig) { + if c.gcp == nil { + c.gcp = &gcpCredentialConfig{} + } + c.gcp.exchangeSAKeyForAccessToken = true + }) +} + +// WithGCPServiceAccountKeyExchangeScopes sets the scopes to be used when exchanging a GCP service account key for an access token from GCP. +// This option is only applicable if WithGCPServiceAccountKeyExchange is enabled. +// If not set, the default scope used for exchange is "https://www.googleapis.com/auth/cloud-platform". +func WithGCPServiceAccountKeyExchangeScopes(scopes []string) option.Option { + return withCredentialsOption(func(c *credentialsConfig) { + if c.gcp == nil { + c.gcp = &gcpCredentialConfig{} + } + c.gcp.accessTokenScopes = scopes + }) +} + +// WithAzureClientSecretExchange configures the credential exchange to handle Azure credentials returned by Vault's Azure secrets engine. +// This option is only applicable if the Vault secret being retrieved contains Azure credential data (client ID, client secret). +func WithAzureClientSecretExchange() option.Option { + return withCredentialsOption(func(c *credentialsConfig) { + if c.azure == nil { + c.azure = &azureCredentialConfig{} + } + c.azure.exchangeForAccessToken = true + }) +} + +// WithAzureClientSecretExchangeScopes sets the scopes to be used when exchanging Azure credentials for an access token from Azure. +// This option is only applicable if WithAzureClientSecretExchange is enabled. +// If not set, the default scope used for exchange is "https://management.azure.com/.default". +func WithAzureClientSecretExchangeScopes(scopes []string) option.Option { + return withCredentialsOption(func(c *credentialsConfig) { + if c.azure == nil { + c.azure = &azureCredentialConfig{} + } + c.azure.accessTokenScopes = scopes + }) +} + +// WithAzureTenantID sets the tenant ID to be used when exchanging Azure credentials for an access token from Azure. +// This option is required if WithAzureClientSecretExchange is enabled. +func WithAzureTenantID(tenantID string) option.Option { + return withCredentialsOption(func(c *credentialsConfig) { + if c.azure == nil { + c.azure = &azureCredentialConfig{} + } + c.azure.tenantID = tenantID + }) +} From db76e37e33e1d972c9857fa2e78c90412803ca2b Mon Sep 17 00:00:00 2001 From: stoader Date: Fri, 13 Feb 2026 17:09:07 +0100 Subject: [PATCH 2/2] address review comments --- pkg/vault/creds.go | 5 +++-- pkg/vault/gcp.go | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/pkg/vault/creds.go b/pkg/vault/creds.go index e893013..a1644b0 100644 --- a/pkg/vault/creds.go +++ b/pkg/vault/creds.go @@ -329,6 +329,7 @@ func (cp *credentialsProvider) startAzureAccessTokenProvider(ctx context.Context clientID: azureSecret.ClientID, clientSecret: azureSecret.ClientSecret, scopes: scopes, + logger: logr.FromContextOrDiscard(ctx).WithName("azure_access_token_provider"), } credsChan := make(chan credential.Result, 1) @@ -428,7 +429,7 @@ func (cp *credentialsProvider) refreshCredentialsLoop(ctx context.Context, cfg * Err: nil, Event: credential.UpdateEventType, }) - logger.V(2).Info("Published Vault secret", "secret_path", cfg.secretFullPath, "expiresAt", creds.ExpiresAt) + logger.V(2).Info("Published Vault secret", "expiresAt", creds.ExpiresAt) } if !creds.RefreshOn.IsZero() { @@ -441,7 +442,7 @@ func (cp *credentialsProvider) refreshCredentialsLoop(ctx context.Context, cfg * refreshTime = timeUntilExpiry - refreshBuffer } - logger.V(1).Info("Scheduling credential refresh", "refreshIn", refreshTime, "refreshBuffer", refreshBuffer, "secret_path", cfg.secretFullPath) + logger.V(1).Info("Scheduling credential refresh", "refreshIn", refreshTime, "refreshBuffer", refreshBuffer) case result, ok := <-workerChan: if !ok { // Worker channel closed, likely due to an error in the worker goroutine diff --git a/pkg/vault/gcp.go b/pkg/vault/gcp.go index f252b51..a332fd0 100644 --- a/pkg/vault/gcp.go +++ b/pkg/vault/gcp.go @@ -56,7 +56,7 @@ func (r *gcpAccessTokenProvider) GetCredentials(ctx context.Context, credsChan c return } - // if the SA key was just created it's posisble that it may take a few seconds for GCP to propagate the key and allow it to be used for authentication. + // if the SA key was just created it's possible that it may take a few seconds for GCP to propagate the key and allow it to be used for authentication. // This can result in transient errors when trying to exchange the key for an access token. token, err := backoff.Retry(ctx, func() (*oauth2.Token, error) { token, err := gcpCreds.TokenSource.Token()