Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ require (
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.30.3 // indirect
github.com/aws/smithy-go v1.22.4 // indirect
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
github.com/cenkalti/backoff/v5 v5.0.3
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/emicklei/go-restful/v3 v3.11.0 // indirect
github.com/evanphx/json-patch/v5 v5.9.11 // indirect
Expand All @@ -52,7 +53,7 @@ require (
github.com/go-openapi/jsonpointer v0.21.0 // indirect
github.com/go-openapi/jsonreference v0.20.2 // indirect
github.com/go-openapi/swag v0.23.0 // indirect
github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
github.com/go-viper/mapstructure/v2 v2.4.0
github.com/gogo/protobuf v1.3.2 // indirect
github.com/google/gnostic-models v0.6.9 // indirect
github.com/google/go-cmp v0.7.0 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ github.com/blang/semver/v4 v4.0.0 h1:1PFHFE6yCCTv8C1TeyNNarDzntLi7wMI5i/pzqYIsAM
github.com/blang/semver/v4 v4.0.0/go.mod h1:IbckMUScFkM3pff0VJDNKRiT6TG/YpiHIM2yvyW5YoQ=
github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8=
github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1xcsSM=
github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cncf/xds/go v0.0.0-20250326154945-ae57f3c0d45f h1:C5bqEmzEPLsHm9Mv73lSE9e9bKV23aB1vxOsmZrkl3k=
Expand Down
10 changes: 2 additions & 8 deletions pkg/aws/creds.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,7 @@ func (cp *credentialsProvider) refreshCredentialsLoop(ctx context.Context, provi
// Get credentials
awsCreds, err := provider.Retrieve(ctx)
if err != nil {
util.SendToChannel(credsChan, credential.Result{
Credential: nil,
Err: errors.WrapIf(err, "failed to retrieve credentials"),
})
util.SendErrorToChannel(credsChan, errors.WrapIf(err, "failed to retrieve credentials"))

return
}
Expand Down Expand Up @@ -111,10 +108,7 @@ func (cp *credentialsProvider) refreshCredentialsLoop(ctx context.Context, provi

// If credentials are already expired, this is an error
if timeUntilExpiry <= 0 {
util.SendToChannel(credsChan, credential.Result{
Credential: nil,
Err: errors.NewWithDetails("received already expired credentials", "expiresAt", awsCreds.Expires),
})
util.SendErrorToChannel(credsChan, errors.NewWithDetails("received already expired credentials", "expiresAt", awsCreds.Expires))

return
}
Expand Down
28 changes: 11 additions & 17 deletions pkg/azure/creds.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,17 @@ func (cp *credentialsProvider) refreshCredentialsLoop(
Scopes: []string{cfg.scope},
})
if err != nil {
util.SendToChannel(credsChan, credential.Result{
Credential: nil,
Err: errors.WrapIf(err, "failed to retrieve credentials"),
})
util.SendErrorToChannel(credsChan, errors.WrapIf(err, "failed to retrieve credentials"))

return
}

// Calculate when to refresh
timeUntilExpiry := time.Until(token.ExpiresOn)

// If credentials are already expired, this is an error
if timeUntilExpiry <= 0 {
util.SendErrorToChannel(credsChan, errors.NewWithDetails("received already expired credentials", "expiresAt", token.ExpiresOn))

return
}
Expand All @@ -107,19 +114,6 @@ func (cp *credentialsProvider) refreshCredentialsLoop(
})
cp.logger.V(2).Info("Sent credentials", "expires", token.ExpiresOn)

// Calculate when to refresh
timeUntilExpiry := time.Until(token.ExpiresOn)

// If credentials are already expired, this is an error
if timeUntilExpiry <= 0 {
util.SendToChannel(credsChan, credential.Result{
Credential: nil,
Err: errors.NewWithDetails("received already expired credentials", "expiresAt", token.ExpiresOn),
})

return
}

refreshBuffer := util.CalculateRefreshBuffer(timeUntilExpiry)
refreshTime := timeUntilExpiry - refreshBuffer

Expand Down
10 changes: 2 additions & 8 deletions pkg/gcp/creds.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,7 @@ func (cp *credentialsProvider) refreshCredentialsLoop(
for {
accessToken, err := genAccessTokenFunc()
if err != nil {
util.SendToChannel(credsChan, credential.Result{
Credential: nil,
Err: errors.WrapIf(err, "failed to get access token"),
})
util.SendErrorToChannel(credsChan, errors.WrapIf(err, "failed to get access token"))

return
}
Expand All @@ -113,10 +110,7 @@ func (cp *credentialsProvider) refreshCredentialsLoop(

// If credentials are already expired, this is an error
if timeUntilExpiry <= 0 {
util.SendToChannel(credsChan, credential.Result{
Credential: nil,
Err: errors.NewWithDetails("received already expired credentials", "expiresAt", accessToken.Expiry),
})
util.SendErrorToChannel(credsChan, errors.NewWithDetails("received already expired credentials", "expiresAt", accessToken.Expiry))

return
}
Expand Down
10 changes: 2 additions & 8 deletions pkg/generic/creds.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,10 +159,7 @@ loop:

token, err = tokenProvider.GetToken(ctx, opts...)
if err != nil {
util.SendToChannel(credsChan, credential.Result{
Credential: nil,
Err: errors.WrapIf(err, "could not create token"),
})
util.SendErrorToChannel(credsChan, errors.WrapIf(err, "could not create token"))

return
}
Expand Down Expand Up @@ -190,10 +187,7 @@ loop:

// If credentials are already expired, this is an error
if timeUntilExpiry <= 0 {
util.SendToChannel(credsChan, credential.Result{
Credential: nil,
Err: errors.NewWithDetails("received already expired credentials", "expiresAt", token.ExpiresAt.Format(time.DateTime)),
})
util.SendErrorToChannel(credsChan, errors.NewWithDetails("received already expired credentials", "expiresAt", token.ExpiresAt.Format(time.DateTime)))

return
}
Expand Down
5 changes: 1 addition & 4 deletions pkg/oauth2cc/oauth2cc.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,10 +262,7 @@ func (r *tokenRetriever) tokenRefresherLoop(ctx context.Context) {

// If credentials are already expired, this is an error
if timeUntilExpiry <= 0 {
util.SendToChannel(r.ch, credential.Result{
Credential: nil,
Err: errors.NewWithDetails("received already expired token", "expiresAt", token.Expiry),
})
util.SendErrorToChannel(r.ch, errors.NewWithDetails("received already expired token", "expiresAt", token.Expiry))

return
}
Expand Down
20 changes: 4 additions & 16 deletions pkg/oci/creds.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,30 +95,21 @@ func (cp *credentialsProvider) refreshCredentialsLoop(ctx context.Context, cfg *
for {
idToken, err := cfg.identityTokenProvider.GetToken(ctx)
if err != nil {
util.SendToChannel(credsChan, credential.Result{
Credential: nil,
Err: errors.WrapIf(err, "failed to get identity token"),
})
util.SendErrorToChannel(credsChan, errors.WrapIf(err, "failed to get identity token"))

return
}

authToken, err := exchangeToken(ctx, tokenEndpoint, cfg.clientID, cfg.clientSecret, idToken.Token, publicKey)
if err != nil {
util.SendToChannel(credsChan, credential.Result{
Credential: nil,
Err: errors.WrapIf(err, "token exchange failed"),
})
util.SendErrorToChannel(credsChan, errors.WrapIf(err, "token exchange failed"))

return
}

expTime, err := getTokenExpiration(authToken)
if err != nil {
util.SendToChannel(credsChan, credential.Result{
Credential: nil,
Err: errors.WrapIf(err, "failed to get expiration time of the received UPST"),
})
util.SendErrorToChannel(credsChan, errors.WrapIf(err, "failed to get expiration time of the received UPST"))

return
}
Expand All @@ -134,10 +125,7 @@ func (cp *credentialsProvider) refreshCredentialsLoop(ctx context.Context, cfg *

// If credentials are already expired, this is an error
if timeUntilExpiry <= 0 {
util.SendToChannel(credsChan, credential.Result{
Credential: nil,
Err: errors.NewWithDetails("received already expired credentials", "expiresAt", ociCredential.ExpiresAt),
})
util.SendErrorToChannel(credsChan, errors.NewWithDetails("received already expired credentials", "expiresAt", ociCredential.ExpiresAt))

return
}
Expand Down
8 changes: 8 additions & 0 deletions pkg/util/credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,11 @@ func SendToChannel(credsChan chan credential.Result, result credential.Result) {
credsChan <- result
}
}

// SendErrorToChannel is a helper function to send an error result to the credentials channel.
func SendErrorToChannel(credsChan chan credential.Result, err error) {
SendToChannel(credsChan, credential.Result{
Credential: nil,
Err: err,
})
}
109 changes: 109 additions & 0 deletions pkg/vault/azure.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
// Copyright (c) 2026 Riptides Labs, Inc.
// SPDX-License-Identifier: MIT

package vault

import (
"context"
"time"

"emperror.dev/errors"
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/cenkalti/backoff/v5"
"github.com/go-logr/logr"

"go.riptides.io/tokenex/pkg/credential"
"go.riptides.io/tokenex/pkg/util"
)

// VaultAzureSecret represents the structure of the secret data returned by Vault's Azure secrets engine when configured to return Azure credentials.
type vaultAzureSecret struct {
ClientID string `mapstructure:"client_id"`
ClientSecret string `mapstructure:"client_secret"`
}

type azureAccessTokenProvider struct {
tenantID string
clientID string
clientSecret string
scopes []string

logger logr.Logger
}

// GetCredentials begins the process of exchanging the client ID and client secret for an Azure access token and refreshing it as needed until the context is canceled.
func (r *azureAccessTokenProvider) GetCredentials(ctx context.Context, credsChan chan credential.Result) {
b := backoff.NewExponentialBackOff()

for {
azClientCreds, err := azidentity.NewClientSecretCredential(r.tenantID, r.clientID, r.clientSecret, nil)
if err != nil {
util.SendErrorToChannel(credsChan, errors.WrapIf(err, "failed to create Azure client secret credential"))

return
}

token, err := backoff.Retry(ctx, func() (azcore.AccessToken, error) {
token, err := azClientCreds.GetToken(ctx, policy.TokenRequestOptions{
Scopes: r.scopes,
})

return token, err
}, backoff.WithBackOff(b), backoff.WithMaxElapsedTime(30*time.Second))
if err != nil {
util.SendErrorToChannel(credsChan, errors.WrapIf(err, "failed to get Azure access token from client secret credential"))

return
}

// Calculate when to refresh
timeUntilExpiry := time.Until(token.ExpiresOn)

// If credentials are already expired, this is an error
if timeUntilExpiry <= 0 {
util.SendErrorToChannel(credsChan, errors.NewWithDetails("received already expired access token from Azure", "expiresAt", token.ExpiresOn))

return
}

// Send credentials
azureCredential := &credential.Oauth2Creds{
AccessToken: token.Token,
TokenType: "Bearer",
Expiry: token.ExpiresOn,
}
util.SendToChannel(credsChan, credential.Result{
Credential: azureCredential,
Err: nil,
Event: credential.UpdateEventType,
})
r.logger.V(2).Info("Sent credentials", "expires", token.ExpiresOn)

refreshBuffer := util.CalculateRefreshBuffer(timeUntilExpiry)
refreshTime := timeUntilExpiry - refreshBuffer

if !token.RefreshOn.IsZero() {
// if refresh time is recommended in the received token, use that
r.logger.V(2).Info("Using RefreshOn time from token", "refreshOn", token.RefreshOn)

rt := time.Until(token.RefreshOn)
if rt > 0 {
refreshTime = rt
} else {
r.logger.V(2).Info("RefreshOn time is in the past, using calculated refresh time", "refreshTime", refreshTime)
}
}

select {
case <-ctx.Done():
r.logger.V(1).Info("Context cancelled, stopping credential refresh")

return
case <-time.After(refreshTime):
// Continue to next iteration to refresh
r.logger.V(2).Info("Refreshing access token")
}
}
}
Loading
Loading