Skip to content

Commit

Permalink
Implement support for ID tokens
Browse files Browse the repository at this point in the history
For google stuff these are relevant when trying to invoke e.g. Cloud
Run services. I'm not at all knowledgeable enough with OAuth to be able
to tell if what I'm doing here is correct. e.g. I can imagine people
potentially wanting to access both id token and access token from the
same oauth response?

This is a breaking change, however, as `AccessToken` got renamed to just
`Token` (since it now encompasses more than just `access_token` and
there are some changes to the `TokenInfo` type too.

Sponsored by: standard.ai
  • Loading branch information
nagisa committed Jul 9, 2021
1 parent e63aa4b commit 25de05b
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 70 deletions.
11 changes: 4 additions & 7 deletions src/authenticator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::installed::{InstalledFlow, InstalledFlowReturnMethod};
use crate::refresh::RefreshFlow;
use crate::service_account::{ServiceAccountFlow, ServiceAccountFlowOpts, ServiceAccountKey};
use crate::storage::{self, Storage, TokenStorage};
use crate::types::{AccessToken, ApplicationSecret, TokenInfo};
use crate::types::{ApplicationSecret, Token, TokenInfo};
use private::AuthFlow;

use futures::lock::Mutex;
Expand Down Expand Up @@ -53,7 +53,7 @@ where
C: hyper::client::connect::Connect + Clone + Send + Sync + 'static,
{
/// Return the current token for the provided scopes.
pub async fn token<'a, T>(&'a self, scopes: &'a [T]) -> Result<AccessToken, Error>
pub async fn token<'a, T>(&'a self, scopes: &'a [T]) -> Result<Token, Error>
where
T: AsRef<str>,
{
Expand All @@ -62,10 +62,7 @@ where

/// Return a token for the provided scopes, but don't reuse cached tokens. Instead,
/// always fetch a new token from the OAuth server.
pub async fn force_refreshed_token<'a, T>(
&'a self,
scopes: &'a [T],
) -> Result<AccessToken, Error>
pub async fn force_refreshed_token<'a, T>(&'a self, scopes: &'a [T]) -> Result<Token, Error>
where
T: AsRef<str>,
{
Expand All @@ -77,7 +74,7 @@ where
&'a self,
scopes: &'a [T],
force_refresh: bool,
) -> Result<AccessToken, Error>
) -> Result<Token, Error>
where
T: AsRef<str>,
{
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,4 +96,4 @@ pub use crate::service_account::ServiceAccountKey;

#[doc(inline)]
pub use crate::error::Error;
pub use crate::types::{AccessToken, ApplicationSecret, ConsoleApplicationSecret};
pub use crate::types::{ApplicationSecret, ConsoleApplicationSecret, Token};
10 changes: 9 additions & 1 deletion src/service_account.rs
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ mod tests {
const TEST_PRIVATE_KEY_PATH: &'static str = "examples/Sanguine-69411a0c0eea.json";

// Uncomment this test to verify that we can successfully obtain tokens.
//#[tokio::test]
// #[tokio::test]
#[allow(dead_code)]
async fn test_service_account_e2e() {
let key = read_service_account_key(TEST_PRIVATE_KEY_PATH)
Expand All @@ -240,6 +240,14 @@ mod tests {
acc.token(&client, &["https://www.googleapis.com/auth/pubsub"])
.await
);
println!(
"{:?}",
acc.token(
&client,
&["https://some.scope/likely-to-hand-out-id-tokens"]
)
.await
);
}

#[tokio::test]
Expand Down
3 changes: 2 additions & 1 deletion src/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,8 @@ mod tests {
#[tokio::test]
async fn test_disk_storage() {
let new_token = |access_token: &str| TokenInfo {
access_token: access_token.to_owned(),
id_token: None,
access_token: Some(access_token.to_owned()),
refresh_token: None,
expires_at: None,
};
Expand Down
108 changes: 59 additions & 49 deletions src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,62 +3,70 @@ use crate::error::{AuthErrorOr, Error};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};

/// Represents an access token returned by oauth2 servers. All access tokens are
/// Bearer tokens. Other types of tokens are not supported.
/// Represents a token returned by oauth2 servers. All tokens are Bearer tokens. Other types of
/// tokens are not supported.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Deserialize, Serialize)]
pub struct AccessToken {
value: String,
pub struct Token {
id_token: Option<String>,
access_token: Option<String>,
expires_at: Option<DateTime<Utc>>,
}

impl AccessToken {
impl Token {
/// A string representation of the ID token.
pub fn id_token(&self) -> Option<&str> {
self.id_token.as_deref()
}

/// A string representation of the access token.
pub fn as_str(&self) -> &str {
&self.value
pub fn access_token(&self) -> Option<&str> {
self.access_token.as_deref()
}

/// The time the access token will expire, if any.
/// The time at which the tokens will expire, if any.
pub fn expiration_time(&self) -> Option<DateTime<Utc>> {
self.expires_at
}

/// Determine if the access token is expired.
/// This will report that the token is expired 1 minute prior to the
/// expiration time to ensure that when the token is actually sent to the
/// server it's still valid.
///
/// This will report that the token is expired 1 minute prior to the expiration time to ensure
/// that when the token is actually sent to the server it's still valid.
pub fn is_expired(&self) -> bool {
// Consider the token expired if it's within 1 minute of it's expiration
// time.
// Consider the token expired if it's within 1 minute of it's expiration time.
self.expires_at
.map(|expiration_time| expiration_time - chrono::Duration::minutes(1) <= Utc::now())
.unwrap_or(false)
}
}

impl AsRef<str> for AccessToken {
fn as_ref(&self) -> &str {
self.as_str()
}
}

impl From<TokenInfo> for AccessToken {
fn from(value: TokenInfo) -> Self {
AccessToken {
value: value.access_token,
expires_at: value.expires_at,
impl From<TokenInfo> for Token {
fn from(
TokenInfo {
access_token,
id_token,
expires_at,
..
}: TokenInfo,
) -> Self {
Token {
access_token,
id_token,
expires_at,
}
}
}

/// Represents a token as returned by OAuth2 servers.
///
/// It is produced by all authentication flows.
/// It authenticates certain operations, and must be refreshed once
/// it reached it's expiry date.
/// It authenticates certain operations, and must be refreshed once it reached it's expiry date.
#[derive(Clone, PartialEq, Debug, Deserialize, Serialize)]
pub struct TokenInfo {
/// used when authorizing calls to oauth2 enabled services.
pub access_token: Option<String>,
/// used when authenticating calls to oauth2 enabled services.
pub access_token: String,
pub id_token: Option<String>,
/// used to refresh an expired access_token.
pub refresh_token: Option<String>,
/// The time when the token expires.
Expand All @@ -68,38 +76,40 @@ pub struct TokenInfo {
impl TokenInfo {
pub(crate) fn from_json(json_data: &[u8]) -> Result<TokenInfo, Error> {
#[derive(Deserialize)]
struct RawToken {
access_token: String,
refresh_token: Option<String>,
token_type: String,
struct TokenSchema<'a> {
id_token: Option<&'a str>,
access_token: Option<&'a str>,
refresh_token: Option<&'a str>,
token_type: Option<&'a str>,
expires_in: Option<i64>,
}

let RawToken {
let TokenSchema {
id_token,
access_token,
refresh_token,
token_type,
expires_in,
} = serde_json::from_slice::<AuthErrorOr<RawToken>>(json_data)?.into_result()?;

if token_type.to_lowercase().as_str() != "bearer" {
use std::io;
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!(
r#"unknown token type returned; expected "bearer" found {}"#,
token_type
),
)
.into());
} = serde_json::from_slice::<AuthErrorOr<_>>(json_data)?.into_result()?;
match token_type {
Some(token_ty) if !token_ty.eq_ignore_ascii_case("bearer") => {
use std::io;
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!(
r#"unknown token type returned; expected "bearer" found {}"#,
token_ty
),
)
.into());
}
_ => (),
}

let expires_at = expires_in
.map(|seconds_from_now| Utc::now() + chrono::Duration::seconds(seconds_from_now));

Ok(TokenInfo {
access_token,
refresh_token,
id_token: id_token.map(String::from),
access_token: access_token.map(String::from),
refresh_token: refresh_token.map(String::from),
expires_at,
})
}
Expand Down
22 changes: 11 additions & 11 deletions tests/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ async fn test_device_success() {
.token(&["https://www.googleapis.com/scope/1"])
.await
.expect("token failed");
assert_eq!("accesstoken", token.as_str());
assert_eq!("accesstoken", token.access_token().expect("should have access token"));
}

#[tokio::test]
Expand Down Expand Up @@ -262,7 +262,7 @@ async fn test_installed_interactive_success() {
.token(&["https://googleapis.com/some/scope"])
.await
.expect("failed to get token");
assert_eq!("accesstoken", tok.as_str());
assert_eq!("accesstoken", tok.access_token().expect("should have access token"));
}

#[tokio::test]
Expand Down Expand Up @@ -291,7 +291,7 @@ async fn test_installed_redirect_success() {
.token(&["https://googleapis.com/some/scope"])
.await
.expect("failed to get token");
assert_eq!("accesstoken", tok.as_str());
assert_eq!("accesstoken", tok.access_token().expect("should have access token"));
}

#[tokio::test]
Expand Down Expand Up @@ -362,7 +362,7 @@ async fn test_service_account_success() {
.token(&["https://www.googleapis.com/auth/pubsub"])
.await
.expect("token failed");
assert!(tok.as_str().contains("ya29.c.ElouBywiys0Ly"));
assert!(tok.access_token().expect("should have access token").contains("ya29.c.ElouBywiys0Ly"));
assert!(Utc::now() + chrono::Duration::seconds(3600) >= tok.expiration_time().unwrap());
}

Expand Down Expand Up @@ -413,7 +413,7 @@ async fn test_refresh() {
.token(&["https://googleapis.com/some/scope"])
.await
.expect("failed to get token");
assert_eq!("accesstoken", tok.as_str());
assert_eq!("accesstoken", tok.access_token().expect("should have access token"));

server.expect(
Expectation::matching(all_of![
Expand All @@ -434,7 +434,7 @@ async fn test_refresh() {
.token(&["https://googleapis.com/some/scope"])
.await
.expect("failed to get token");
assert_eq!("accesstoken2", tok.as_str());
assert_eq!("accesstoken2", tok.access_token().expect("should have access token"));

server.expect(
Expectation::matching(all_of![
Expand All @@ -455,7 +455,7 @@ async fn test_refresh() {
.token(&["https://googleapis.com/some/scope"])
.await
.expect("failed to get token");
assert_eq!("accesstoken3", tok.as_str());
assert_eq!("accesstoken3", tok.access_token().expect("should have access token"));

server.expect(
Expectation::matching(all_of![
Expand Down Expand Up @@ -515,7 +515,7 @@ async fn test_memory_storage() {
.token(&["https://googleapis.com/some/scope"])
.await
.expect("failed to get token");
assert_eq!(token1.as_str(), "accesstoken");
assert_eq!(token1.access_token().expect("should have access token"), "accesstoken");
assert_eq!(token1, token2);

// Create a new authenticator. This authenticator does not share a cache
Expand All @@ -541,7 +541,7 @@ async fn test_memory_storage() {
.token(&["https://googleapis.com/some/scope"])
.await
.expect("failed to get token");
assert_eq!(token3.as_str(), "accesstoken2");
assert_eq!(token3.access_token().expect("should have access token"), "accesstoken2");
}

#[tokio::test]
Expand Down Expand Up @@ -583,7 +583,7 @@ async fn test_disk_storage() {
.token(&["https://googleapis.com/some/scope"])
.await
.expect("failed to get token");
assert_eq!(token1.as_str(), "accesstoken");
assert_eq!(token1.access_token().expect("should have access token"), "accesstoken");
assert_eq!(token1, token2);
}

Expand All @@ -605,6 +605,6 @@ async fn test_disk_storage() {
.token(&["https://googleapis.com/some/scope"])
.await
.expect("failed to get token");
assert_eq!(token1.as_str(), "accesstoken");
assert_eq!(token1.access_token().expect("should have access token"), "accesstoken");
assert_eq!(token1, token2);
}

0 comments on commit 25de05b

Please sign in to comment.