diff --git a/src/lib.rs b/src/lib.rs index 6e157e00..951e4ede 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -190,6 +190,7 @@ pub mod params; pub mod service; use crate::service::body::BodyStreamExt; +use chrono::{DateTime, Utc}; use http::{HeaderMap, HeaderValue, Method, Uri}; use std::convert::{Infallible, TryInto}; use std::fmt; @@ -760,18 +761,54 @@ impl DefaultOctocrabBuilderConfig { pub type DynBody = dyn http_body::Body + Send + Unpin; +#[derive(Debug, Clone)] +struct CachedTokenInner { + expiration: Option>, + secret: SecretString, +} + +impl CachedTokenInner { + fn new(secret: SecretString, expiration: Option>) -> Self { + Self { secret, expiration } + } + + fn expose_secret(&self) -> &str { + self.secret.expose_secret() + } +} + /// A cached API access token (which may be None) -pub struct CachedToken(RwLock>); +pub struct CachedToken(RwLock>); impl CachedToken { fn clear(&self) { *self.0.write().unwrap() = None; } - fn get(&self) -> Option { - self.0.read().unwrap().clone() + + /// Returns a valid token if it exists and is not expired or if there is no expiration date. + fn valid_token_with_buffer(&self, buffer: chrono::Duration) -> Option { + let inner = self.0.read().unwrap(); + + if let Some(token) = inner.as_ref() { + if let Some(exp) = token.expiration { + if exp - Utc::now() > buffer { + return Some(token.secret.clone()); + } + } else { + return Some(token.secret.clone()); + } + } + + None + } + + fn valid_token(&self) -> Option { + self.valid_token_with_buffer(chrono::Duration::seconds(30)) } - fn set(&self, value: String) { - *self.0.write().unwrap() = Some(SecretString::new(value)); + + fn set(&self, token: String, expiration: Option>) { + *self.0.write().unwrap() = + Some(CachedTokenInner::new(SecretString::new(token), expiration)); } } @@ -1349,7 +1386,21 @@ impl Octocrab { let token_object = InstallationToken::from_response(crate::map_github_error(response).await?).await?; - token.set(token_object.token.clone()); + + let expiration = token_object + .expires_at + .map(|time| { + DateTime::::from_str(&time).map_err(|e| error::Error::Other { + source: Box::new(e), + backtrace: snafu::Backtrace::generate(), + }) + }) + .transpose()?; + + tracing::debug!("Token expires at: {:?}", expiration); + + token.set(token_object.token.clone(), expiration); + Ok(SecretString::new(token_object.token)) } @@ -1407,7 +1458,7 @@ impl Octocrab { Some(HeaderValue::from_bytes(&buf).expect("base64 is always valid HeaderValue")) } AuthState::Installation { ref token, .. } => { - let token = if let Some(token) = token.get() { + let token = if let Some(token) = token.valid_token() { token } else { self.request_installation_auth_token().await? @@ -1523,4 +1574,57 @@ mod tests { .await .unwrap(); } + + use super::*; + use chrono::Duration; + + #[test] + fn clear_token() { + let cache = CachedToken(RwLock::new(None)); + cache.set("secret".to_string(), None); + cache.clear(); + + assert!(cache.valid_token().is_none(), "Token was not cleared."); + } + + #[test] + fn no_token_when_expired() { + let cache = CachedToken(RwLock::new(None)); + let expiration = Utc::now() + Duration::seconds(9); + cache.set("secret".to_string(), Some(expiration)); + + assert!( + cache + .valid_token_with_buffer(Duration::seconds(10)) + .is_none(), + "Token should be considered expired due to buffer." + ); + } + + #[test] + fn get_valid_token_outside_buffer() { + let cache = CachedToken(RwLock::new(None)); + let expiration = Utc::now() + Duration::seconds(12); + cache.set("secret".to_string(), Some(expiration)); + + assert!( + cache + .valid_token_with_buffer(Duration::seconds(10)) + .is_some(), + "Token should still be valid outside of buffer." + ); + } + + #[test] + fn get_valid_token_without_expiration() { + let cache = CachedToken(RwLock::new(None)); + cache.set("secret".to_string(), None); + + assert!( + cache + .valid_token_with_buffer(Duration::seconds(10)) + .is_some(), + "Token with no expiration should always be considered valid." + ); + } }