Skip to content

Commit

Permalink
Add fallback_on_interrupt to the ProvideCredentials trait (#2246)
Browse files Browse the repository at this point in the history
* Implement RFC for providing fallback credentials

This commit implements the changes checklist in the RFC for providing
fallback credentials.

* Remove needless lifetime parameter

* Update CHANGELOG.next.toml

---------

Co-authored-by: Yuki Saito <[email protected]>
  • Loading branch information
ysaito1001 and ysaito1001 authored Jan 27, 2023
1 parent e6c3a4b commit 980b5c4
Show file tree
Hide file tree
Showing 5 changed files with 193 additions and 6 deletions.
48 changes: 47 additions & 1 deletion CHANGELOG.next.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,50 @@
# message = "Fix typos in module documentation for generated crates"
# references = ["smithy-rs#920"]
# meta = { "breaking" = false, "tada" = false, "bug" = false, "target" = "client | server | all"}
# author = "rcoh"
# author = "rcoh"

[[aws-sdk-rust]]
message = """
Provide a way to retrieve fallback credentials if a call to `provide_credentials` is interrupted. An interrupt can occur when a timeout future is raced against a future for `provide_credentials`, and the former wins the race. A new method, `fallback_on_interrupt` on the `ProvideCredentials` trait, can be used in that case. The following code snippet from `LazyCredentialsCache::provide_cached_credentials` has been updated like so:
Before:
```rust
let timeout_future = self.sleeper.sleep(self.load_timeout);
// --snip--
let future = Timeout::new(provider.provide_credentials(), timeout_future);
let result = cache
.get_or_load(|| {
async move {
let credentials = future.await.map_err(|_err| {
CredentialsError::provider_timed_out(load_timeout)
})??;
// --snip--
}
}).await;
// --snip--
```
After:
```rust
let timeout_future = self.sleeper.sleep(self.load_timeout);
// --snip--
let future = Timeout::new(provider.provide_credentials(), timeout_future);
let result = cache
.get_or_load(|| {
async move {
let credentials = match future.await {
Ok(creds) => creds?,
Err(_err) => match provider.fallback_on_interrupt() { // can provide fallback credentials
Some(creds) => creds,
None => return Err(CredentialsError::provider_timed_out(load_timeout)),
}
};
// --snip--
}
}).await;
// --snip--
```
"""
references = ["smithy-rs#2246"]
meta = { "breaking" = false, "tada" = false, "bug" = false }
author = "ysaito1001"
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
use std::borrow::Cow;

use aws_credential_types::provider::{self, future, ProvideCredentials};
use aws_credential_types::Credentials;
use tracing::Instrument;

use crate::environment::credentials::EnvironmentVariableCredentialsProvider;
Expand Down Expand Up @@ -83,6 +84,10 @@ impl ProvideCredentials for DefaultCredentialsChain {
{
future::ProvideCredentials::new(self.credentials())
}

fn fallback_on_interrupt(&self) -> Option<Credentials> {
self.provider_chain.fallback_on_interrupt()
}
}

/// Builder for [`DefaultCredentialsChain`](DefaultCredentialsChain)
Expand Down
116 changes: 115 additions & 1 deletion aws/rust-runtime/aws-config/src/meta/credentials/chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
* SPDX-License-Identifier: Apache-2.0
*/

use aws_credential_types::provider::{self, error::CredentialsError, future, ProvideCredentials};
use aws_credential_types::{
provider::{self, error::CredentialsError, future, ProvideCredentials},
Credentials,
};
use aws_smithy_types::error::display::DisplayErrorContext;
use std::borrow::Cow;
use tracing::Instrument;
Expand Down Expand Up @@ -104,4 +107,115 @@ impl ProvideCredentials for CredentialsProviderChain {
{
future::ProvideCredentials::new(self.credentials())
}

fn fallback_on_interrupt(&self) -> Option<Credentials> {
for (_, provider) in &self.providers {
match provider.fallback_on_interrupt() {
creds @ Some(_) => return creds,
None => {}
}
}
None
}
}

#[cfg(test)]
mod tests {
use std::time::Duration;

use aws_credential_types::{
credential_fn::provide_credentials_fn,
provider::{error::CredentialsError, future, ProvideCredentials},
Credentials,
};
use aws_smithy_async::future::timeout::Timeout;

use crate::meta::credentials::CredentialsProviderChain;

#[derive(Debug)]
struct FallbackCredentials(Credentials);

impl ProvideCredentials for FallbackCredentials {
fn provide_credentials<'a>(&'a self) -> future::ProvideCredentials<'a>
where
Self: 'a,
{
future::ProvideCredentials::new(async {
tokio::time::sleep(Duration::from_millis(200)).await;
Ok(self.0.clone())
})
}

fn fallback_on_interrupt(&self) -> Option<Credentials> {
Some(self.0.clone())
}
}

#[tokio::test]
async fn fallback_credentials_should_be_returned_from_provider2_on_timeout_while_provider2_was_providing_credentials(
) {
let chain = CredentialsProviderChain::first_try(
"provider1",
provide_credentials_fn(|| async {
tokio::time::sleep(Duration::from_millis(200)).await;
Err(CredentialsError::not_loaded(
"no providers in chain provided credentials",
))
}),
)
.or_else("provider2", FallbackCredentials(Credentials::for_tests()));

// Let the first call to `provide_credentials` succeed.
let expected = chain.provide_credentials().await.unwrap();

// Let the second call fail with an external timeout.
let timeout = Timeout::new(
chain.provide_credentials(),
tokio::time::sleep(Duration::from_millis(300)),
);
match timeout.await {
Ok(_) => assert!(false, "provide_credentials completed before timeout future"),
Err(_err) => match chain.fallback_on_interrupt() {
Some(actual) => assert_eq!(actual, expected),
None => assert!(
false,
"provide_credentials timed out and no credentials returned from fallback_on_interrupt"
),
},
};
}

#[tokio::test]
async fn fallback_credentials_should_be_returned_from_provider2_on_timeout_while_provider1_was_providing_credentials(
) {
let chain = CredentialsProviderChain::first_try(
"provider1",
provide_credentials_fn(|| async {
tokio::time::sleep(Duration::from_millis(200)).await;
Err(CredentialsError::not_loaded(
"no providers in chain provided credentials",
))
}),
)
.or_else("provider2", FallbackCredentials(Credentials::for_tests()));

// Let the first call to `provide_credentials` succeed.
let expected = chain.provide_credentials().await.unwrap();

// Let the second call fail with an external timeout.
let timeout = Timeout::new(
chain.provide_credentials(),
tokio::time::sleep(Duration::from_millis(100)),
);
match timeout.await {
Ok(_) => assert!(false, "provide_credentials completed before timeout future"),
Err(_err) => match chain.fallback_on_interrupt() {
Some(actual) => assert_eq!(actual, expected),
None => assert!(
false,
"provide_credentials timed out and no credentials returned from fallback_on_interrupt"
),
},
};
}
}
15 changes: 12 additions & 3 deletions aws/rust-runtime/aws-credential-types/src/cache/lazy_caching.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,19 @@ impl ProvideCachedCredentials for LazyCredentialsCache {
let result = cache
.get_or_load(|| {
let span = info_span!("lazy_load_credentials");
let provider = provider.clone();
async move {
let credentials = future.await.map_err(|_err| {
CredentialsError::provider_timed_out(load_timeout)
})??;
let credentials = match future.await {
Ok(creds) => creds?,
Err(_err) => match provider.fallback_on_interrupt() {
Some(creds) => creds,
None => {
return Err(CredentialsError::provider_timed_out(
load_timeout,
))
}
},
};
// If the credentials don't have an expiration time, then create a default one
let expiry = credentials
.expiry()
Expand Down
15 changes: 14 additions & 1 deletion aws/rust-runtime/aws-credential-types/src/provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ pub mod future {

type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;

/// Future new-type that the `ProvideCredentials` trait must return.
/// Future new-type that `ProvideCredentials::provide_credentials` must return.
#[derive(Debug)]
pub struct ProvideCredentials<'a>(NowOrLater<super::Result, BoxFuture<'a, super::Result>>);

Expand Down Expand Up @@ -280,6 +280,19 @@ pub trait ProvideCredentials: Send + Sync + std::fmt::Debug {
fn provide_credentials<'a>(&'a self) -> future::ProvideCredentials<'a>
where
Self: 'a;

/// Returns fallback credentials.
///
/// This method should be used as a fallback plan, i.e., when
/// a call to `provide_credentials` is interrupted and its future
/// fails to complete.
///
/// The fallback credentials should be set aside and ready to be returned
/// immediately. Therefore, the user should NOT go fetch new credentials
/// within this method, which might cause a long-running operation.
fn fallback_on_interrupt(&self) -> Option<Credentials> {
None
}
}

impl ProvideCredentials for Credentials {
Expand Down

0 comments on commit 980b5c4

Please sign in to comment.