From 15774de9f00df186ff46a0cebfeacd38f21e0bf5 Mon Sep 17 00:00:00 2001 From: LanceDB Robot Date: Sat, 27 Jun 2026 22:12:02 +0000 Subject: [PATCH] fix: serve still-valid cached AWS credentials when refresh fails AwsCredentialAdapter proactively refreshes credentials credentials_refresh_offset (default 60s) before they expire. When that proactive refresh hit a transient error from the underlying provider (e.g. an IMDS/STS HTTP connect timeout), get_credential discarded the still-valid cached credentials and returned a hard error, surfacing as a 500 for S3 and DynamoDB operations. Fall back to the cached credentials when a refresh fails but the cached credentials have not actually expired yet; the next call retries the refresh. Truly-expired credentials still surface the error rather than being used. --- .../src/object_store/providers/aws.rs | 187 ++++++++++++++---- 1 file changed, 147 insertions(+), 40 deletions(-) diff --git a/rust/lance-io/src/object_store/providers/aws.rs b/rust/lance-io/src/object_store/providers/aws.rs index 9aad637bce2..182d38c1080 100644 --- a/rust/lance-io/src/object_store/providers/aws.rs +++ b/rust/lance-io/src/object_store/providers/aws.rs @@ -350,6 +350,17 @@ impl AwsCredentialAdapter { const AWS_CREDS_CACHE_KEY: &str = "aws_credentials"; +/// Convert AWS SDK credentials into the object_store credential type. +fn to_object_store_credential( + creds: &aws_credential_types::Credentials, +) -> ObjectStoreAwsCredential { + ObjectStoreAwsCredential { + key_id: creds.access_key_id().to_string(), + secret_key: creds.secret_access_key().to_string(), + token: creds.session_token().map(|s| s.to_string()), + } +} + /// Convert std::time::SystemTime from AWS SDK to our mockable SystemTime fn to_system_time(time: std::time::SystemTime) -> SystemTime { let duration_since_epoch = time @@ -363,47 +374,65 @@ impl CredentialProvider for AwsCredentialAdapter { type Credential = ObjectStoreAwsCredential; async fn get_credential(&self) -> ObjectStoreResult> { - let cached_creds = { - let cache_value = self.cache.read().await.get(AWS_CREDS_CACHE_KEY).cloned(); - let expired = cache_value - .clone() - .map(|cred| { - cred.expiry() - .map(|exp| { - to_system_time(exp) - .checked_sub(self.credentials_refresh_offset) - .expect("this time should always be valid") - < SystemTime::now() - }) - // no expiry is never expire - .unwrap_or(false) - }) - .unwrap_or(true); // no cred is the same as expired; - if expired { None } else { cache_value.clone() } - }; - - if let Some(creds) = cached_creds { - Ok(Arc::new(Self::Credential { - key_id: creds.access_key_id().to_string(), - secret_key: creds.secret_access_key().to_string(), - token: creds.session_token().map(|s| s.to_string()), - })) - } else { - let refreshed_creds = - Arc::new(self.inner.provide_credentials().await.map_err(|e| { - Error::internal(format!("Failed to get AWS credentials: {:?}", e)) - })?); + let cached = self.cache.read().await.get(AWS_CREDS_CACHE_KEY).cloned(); + + // Credentials are proactively refreshed `credentials_refresh_offset` + // before they actually expire. During that window the cached + // credentials are still valid, so we only need a refresh, not a + // mandatory replacement. + let needs_refresh = cached + .as_ref() + .map(|cred| { + cred.expiry() + .map(|exp| { + to_system_time(exp) + .checked_sub(self.credentials_refresh_offset) + .expect("this time should always be valid") + < SystemTime::now() + }) + // no expiry is never expire + .unwrap_or(false) + }) + .unwrap_or(true); // no cred is the same as needing a refresh - self.cache - .write() - .await - .insert(AWS_CREDS_CACHE_KEY.to_string(), refreshed_creds.clone()); + if !needs_refresh { + // Safe to unwrap: needs_refresh is only false when a cached value exists. + return Ok(Arc::new(to_object_store_credential(cached.as_ref().unwrap()))); + } - Ok(Arc::new(Self::Credential { - key_id: refreshed_creds.access_key_id().to_string(), - secret_key: refreshed_creds.secret_access_key().to_string(), - token: refreshed_creds.session_token().map(|s| s.to_string()), - })) + match self.inner.provide_credentials().await { + Ok(refreshed_creds) => { + let refreshed_creds = Arc::new(refreshed_creds); + self.cache + .write() + .await + .insert(AWS_CREDS_CACHE_KEY.to_string(), refreshed_creds.clone()); + Ok(Arc::new(to_object_store_credential(&refreshed_creds))) + } + // The refresh failed (e.g. a transient IMDS/STS timeout). If we + // still hold cached credentials that have not actually expired yet, + // keep using them rather than failing the request — the next call + // will try to refresh again. This prevents a momentary credential + // provider blip during the proactive refresh window from turning + // into a hard error for the caller. + Err(e) => { + if let Some(creds) = cached { + let still_valid = creds + .expiry() + .map(|exp| to_system_time(exp) > SystemTime::now()) + // no expiry means the credentials never expire + .unwrap_or(true); + if still_valid { + log::warn!( + "Failed to refresh AWS credentials, \ + falling back to cached credentials that are still valid: {:?}", + e + ); + return Ok(Arc::new(to_object_store_credential(&creds))); + } + } + Err(Error::internal(format!("Failed to get AWS credentials: {:?}", e)).into()) + } } } } @@ -462,7 +491,7 @@ mod tests { use crate::object_store::StorageOptionsProvider; use mock_instant::thread_local::MockClock; use object_store::path::Path; - use std::sync::atomic::{AtomicBool, Ordering}; + use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use super::*; @@ -514,6 +543,84 @@ mod tests { assert!(mock_provider.called.load(Ordering::Relaxed)); } + /// A `ProvideCredentials` that returns valid credentials on its first call + /// and then fails on every subsequent call, simulating a transient + /// credential provider outage (e.g. an IMDS/STS connect timeout) occurring + /// during a proactive refresh. + #[derive(Debug)] + struct FlakyCredentialsProvider { + call_count: Arc, + expiry_secs: u64, + } + + impl ProvideCredentials for FlakyCredentialsProvider { + fn provide_credentials<'a>( + &'a self, + ) -> aws_credential_types::provider::future::ProvideCredentials<'a> + where + Self: 'a, + { + let n = self.call_count.fetch_add(1, Ordering::SeqCst); + let expiry = std::time::UNIX_EPOCH + Duration::from_secs(self.expiry_secs); + let result = if n == 0 { + Ok(aws_credential_types::Credentials::new( + "AKID_FRESH", + "SECRET_FRESH", + Some("TOKEN_FRESH".to_string()), + Some(expiry), + "flaky-test", + )) + } else { + Err( + aws_credential_types::provider::error::CredentialsError::provider_error( + Box::new(std::io::Error::new( + std::io::ErrorKind::TimedOut, + "simulated IMDS connect timeout", + )), + ), + ) + }; + aws_credential_types::provider::future::ProvideCredentials::ready(result) + } + } + + #[tokio::test] + async fn test_aws_credential_adapter_falls_back_to_cached_on_refresh_failure() { + // Base time 100_000s; credentials expire 30s later. With a 60s refresh + // offset, the cached credentials are immediately within the refresh + // window yet remain valid until 100_030s. + MockClock::set_system_time(Duration::from_secs(100_000)); + + let call_count = Arc::new(AtomicUsize::new(0)); + let provider = Arc::new(FlakyCredentialsProvider { + call_count: call_count.clone(), + expiry_secs: 100_030, + }); + let adapter = AwsCredentialAdapter::new(provider, Duration::from_secs(60)); + + // First call seeds the cache from the provider. + let cred = adapter.get_credential().await.unwrap(); + assert_eq!(cred.key_id, "AKID_FRESH"); + assert_eq!(call_count.load(Ordering::SeqCst), 1); + + // Second call attempts a refresh (creds are within the refresh window), + // the provider fails, but the cached creds are still valid -> serve them + // instead of erroring. + let cred = adapter.get_credential().await.unwrap(); + assert_eq!(cred.key_id, "AKID_FRESH"); + assert_eq!(call_count.load(Ordering::SeqCst), 2); + + // Once the cached creds have actually expired, a failing refresh must + // surface as an error rather than serving expired credentials. + MockClock::set_system_time(Duration::from_secs(100_031)); + let err = adapter.get_credential().await.unwrap_err(); + assert!( + err.to_string().contains("Failed to get AWS credentials"), + "unexpected error: {err}" + ); + assert_eq!(call_count.load(Ordering::SeqCst), 3); + } + #[test] fn test_s3_path_parsing() { let provider = AwsStoreProvider;