Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 147 additions & 40 deletions rust/lance-io/src/object_store/providers/aws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,17 @@

const AWS_CREDS_CACHE_KEY: &str = "aws_credentials";

/// Convert AWS SDK credentials into the object_store credential type.
fn to_object_store_credential(
creds: &aws_credential_types::Credentials,
) -> ObjectStoreAwsCredential {
ObjectStoreAwsCredential {
key_id: creds.access_key_id().to_string(),
secret_key: creds.secret_access_key().to_string(),
token: creds.session_token().map(|s| s.to_string()),
}
}

/// Convert std::time::SystemTime from AWS SDK to our mockable SystemTime
fn to_system_time(time: std::time::SystemTime) -> SystemTime {
let duration_since_epoch = time
Expand All @@ -363,47 +374,65 @@
type Credential = ObjectStoreAwsCredential;

async fn get_credential(&self) -> ObjectStoreResult<Arc<Self::Credential>> {
let cached_creds = {
let cache_value = self.cache.read().await.get(AWS_CREDS_CACHE_KEY).cloned();
let expired = cache_value
.clone()
.map(|cred| {
cred.expiry()
.map(|exp| {
to_system_time(exp)
.checked_sub(self.credentials_refresh_offset)
.expect("this time should always be valid")
< SystemTime::now()
})
// no expiry is never expire
.unwrap_or(false)
})
.unwrap_or(true); // no cred is the same as expired;
if expired { None } else { cache_value.clone() }
};

if let Some(creds) = cached_creds {
Ok(Arc::new(Self::Credential {
key_id: creds.access_key_id().to_string(),
secret_key: creds.secret_access_key().to_string(),
token: creds.session_token().map(|s| s.to_string()),
}))
} else {
let refreshed_creds =
Arc::new(self.inner.provide_credentials().await.map_err(|e| {
Error::internal(format!("Failed to get AWS credentials: {:?}", e))
})?);
let cached = self.cache.read().await.get(AWS_CREDS_CACHE_KEY).cloned();

// Credentials are proactively refreshed `credentials_refresh_offset`
// before they actually expire. During that window the cached
// credentials are still valid, so we only need a refresh, not a
// mandatory replacement.
let needs_refresh = cached
.as_ref()
.map(|cred| {
cred.expiry()
.map(|exp| {
to_system_time(exp)
.checked_sub(self.credentials_refresh_offset)
.expect("this time should always be valid")
< SystemTime::now()
})
// no expiry is never expire
.unwrap_or(false)
})
.unwrap_or(true); // no cred is the same as needing a refresh

Check warning on line 397 in rust/lance-io/src/object_store/providers/aws.rs

View workflow job for this annotation

GitHub Actions / format

Diff in /home/runner/work/lance/lance/rust/lance-io/src/object_store/providers/aws.rs
self.cache
.write()
.await
.insert(AWS_CREDS_CACHE_KEY.to_string(), refreshed_creds.clone());
if !needs_refresh {
// Safe to unwrap: needs_refresh is only false when a cached value exists.
return Ok(Arc::new(to_object_store_credential(cached.as_ref().unwrap())));
}

Ok(Arc::new(Self::Credential {
key_id: refreshed_creds.access_key_id().to_string(),
secret_key: refreshed_creds.secret_access_key().to_string(),
token: refreshed_creds.session_token().map(|s| s.to_string()),
}))
match self.inner.provide_credentials().await {
Ok(refreshed_creds) => {
let refreshed_creds = Arc::new(refreshed_creds);
self.cache
.write()
.await
.insert(AWS_CREDS_CACHE_KEY.to_string(), refreshed_creds.clone());
Ok(Arc::new(to_object_store_credential(&refreshed_creds)))
}
// The refresh failed (e.g. a transient IMDS/STS timeout). If we
// still hold cached credentials that have not actually expired yet,
// keep using them rather than failing the request — the next call
// will try to refresh again. This prevents a momentary credential
// provider blip during the proactive refresh window from turning
// into a hard error for the caller.
Err(e) => {
if let Some(creds) = cached {
let still_valid = creds
.expiry()
.map(|exp| to_system_time(exp) > SystemTime::now())
// no expiry means the credentials never expire
.unwrap_or(true);
if still_valid {
log::warn!(
"Failed to refresh AWS credentials, \
falling back to cached credentials that are still valid: {:?}",
e
);
return Ok(Arc::new(to_object_store_credential(&creds)));
}
}
Err(Error::internal(format!("Failed to get AWS credentials: {:?}", e)).into())
}
}
}
}
Expand Down Expand Up @@ -462,7 +491,7 @@
use crate::object_store::StorageOptionsProvider;
use mock_instant::thread_local::MockClock;
use object_store::path::Path;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};

use super::*;

Expand Down Expand Up @@ -514,6 +543,84 @@
assert!(mock_provider.called.load(Ordering::Relaxed));
}

/// A `ProvideCredentials` that returns valid credentials on its first call
/// and then fails on every subsequent call, simulating a transient
/// credential provider outage (e.g. an IMDS/STS connect timeout) occurring
/// during a proactive refresh.
#[derive(Debug)]
struct FlakyCredentialsProvider {
call_count: Arc<AtomicUsize>,
expiry_secs: u64,
}

impl ProvideCredentials for FlakyCredentialsProvider {
fn provide_credentials<'a>(
&'a self,
) -> aws_credential_types::provider::future::ProvideCredentials<'a>
where
Self: 'a,
{
let n = self.call_count.fetch_add(1, Ordering::SeqCst);
let expiry = std::time::UNIX_EPOCH + Duration::from_secs(self.expiry_secs);
let result = if n == 0 {
Ok(aws_credential_types::Credentials::new(
"AKID_FRESH",
"SECRET_FRESH",
Some("TOKEN_FRESH".to_string()),
Some(expiry),
"flaky-test",
))
} else {
Err(
aws_credential_types::provider::error::CredentialsError::provider_error(
Box::new(std::io::Error::new(
std::io::ErrorKind::TimedOut,
"simulated IMDS connect timeout",
)),
),
)
};
aws_credential_types::provider::future::ProvideCredentials::ready(result)
}
}

#[tokio::test]
async fn test_aws_credential_adapter_falls_back_to_cached_on_refresh_failure() {
// Base time 100_000s; credentials expire 30s later. With a 60s refresh
// offset, the cached credentials are immediately within the refresh
// window yet remain valid until 100_030s.
MockClock::set_system_time(Duration::from_secs(100_000));

let call_count = Arc::new(AtomicUsize::new(0));
let provider = Arc::new(FlakyCredentialsProvider {
call_count: call_count.clone(),
expiry_secs: 100_030,
});
let adapter = AwsCredentialAdapter::new(provider, Duration::from_secs(60));

// First call seeds the cache from the provider.
let cred = adapter.get_credential().await.unwrap();
assert_eq!(cred.key_id, "AKID_FRESH");
assert_eq!(call_count.load(Ordering::SeqCst), 1);

// Second call attempts a refresh (creds are within the refresh window),
// the provider fails, but the cached creds are still valid -> serve them
// instead of erroring.
let cred = adapter.get_credential().await.unwrap();
assert_eq!(cred.key_id, "AKID_FRESH");
assert_eq!(call_count.load(Ordering::SeqCst), 2);

// Once the cached creds have actually expired, a failing refresh must
// surface as an error rather than serving expired credentials.
MockClock::set_system_time(Duration::from_secs(100_031));
let err = adapter.get_credential().await.unwrap_err();
assert!(
err.to_string().contains("Failed to get AWS credentials"),
"unexpected error: {err}"
);
assert_eq!(call_count.load(Ordering::SeqCst), 3);
}

#[test]
fn test_s3_path_parsing() {
let provider = AwsStoreProvider;
Expand Down
Loading