-
Notifications
You must be signed in to change notification settings - Fork 252
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support for authenticating using azureauth cli (#1464)
- Loading branch information
Showing
3 changed files
with
214 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
209 changes: 209 additions & 0 deletions
209
sdk/identity/src/token_credentials/azureauth_cli_credentials.rs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,209 @@ | ||
use azure_core::{ | ||
auth::{AccessToken, TokenCredential, TokenResponse}, | ||
error::{Error, ErrorKind, ResultExt}, | ||
}; | ||
use oauth2::ClientId; | ||
use serde::Deserialize; | ||
use std::process::Command; | ||
use std::str; | ||
use time::OffsetDateTime; | ||
|
||
mod unix_date_string { | ||
use azure_core::error::{Error, ErrorKind}; | ||
use serde::{Deserialize, Deserializer}; | ||
use time::OffsetDateTime; | ||
|
||
pub fn parse(s: &str) -> azure_core::Result<OffsetDateTime> { | ||
let as_i64 = s.parse().map_err(|_| { | ||
Error::with_message(ErrorKind::DataConversion, || { | ||
format!("unable to parse expiration_date '{s}") | ||
}) | ||
})?; | ||
|
||
OffsetDateTime::from_unix_timestamp(as_i64).map_err(|_| { | ||
Error::with_message(ErrorKind::DataConversion, || { | ||
format!("unable to parse expiration_date '{s}") | ||
}) | ||
}) | ||
} | ||
|
||
pub fn deserialize<'de, D>(deserializer: D) -> Result<OffsetDateTime, D::Error> | ||
where | ||
D: Deserializer<'de>, | ||
{ | ||
let s = String::deserialize(deserializer)?; | ||
parse(&s).map_err(serde::de::Error::custom) | ||
} | ||
} | ||
|
||
#[derive(Debug, Clone, Deserialize)] | ||
struct CliTokenResponse { | ||
pub token: AccessToken, | ||
#[serde(with = "unix_date_string")] | ||
pub expiration_date: OffsetDateTime, | ||
} | ||
|
||
/// Authentication Mode | ||
/// | ||
/// Note: While the azureauth CLI supports devicecode, users wishing to use | ||
/// devicecode should use `azure_identity::device_code_flow` | ||
#[derive(Debug, Clone, Copy)] | ||
pub enum AzureauthCliMode { | ||
All, | ||
IntegratedWindowsAuth, | ||
Broker, | ||
Web, | ||
} | ||
|
||
/// Enables authentication to Azure Active Directory using Azure CLI to obtain an access token. | ||
pub struct AzureauthCliCredential { | ||
tenant_id: String, | ||
client_id: ClientId, | ||
modes: Vec<AzureauthCliMode>, | ||
prompt_hint: Option<String>, | ||
} | ||
|
||
impl AzureauthCliCredential { | ||
/// Create a new `AzureCliCredential` | ||
pub fn new<T, C>(tenant_id: T, client_id: C) -> Self | ||
where | ||
T: Into<String>, | ||
C: Into<ClientId>, | ||
{ | ||
Self { | ||
tenant_id: tenant_id.into(), | ||
client_id: client_id.into(), | ||
modes: Vec::new(), | ||
prompt_hint: None, | ||
} | ||
} | ||
|
||
pub fn add_mode(mut self, mode: AzureauthCliMode) -> Self { | ||
self.modes.push(mode); | ||
self | ||
} | ||
|
||
pub fn with_modes(mut self, modes: Vec<AzureauthCliMode>) -> Self { | ||
self.modes = modes; | ||
self | ||
} | ||
|
||
pub fn with_prompt_hint<S>(mut self, hint: S) -> Self | ||
where | ||
S: Into<String>, | ||
{ | ||
self.prompt_hint = Some(hint.into()); | ||
self | ||
} | ||
|
||
fn get_access_token(&self, resource: &str) -> azure_core::Result<CliTokenResponse> { | ||
// try using azureauth.exe first, such that azureauth through WSL is | ||
// used first if possible. | ||
let (cmd_name, use_windows_features) = if Command::new("azureauth.exe") | ||
.arg("--version") | ||
.output() | ||
.map(|x| x.status.success()) | ||
.unwrap_or(false) | ||
{ | ||
("azureauth.exe", true) | ||
} else { | ||
("azureauth", false) | ||
}; | ||
|
||
let mut cmd = Command::new(cmd_name); | ||
cmd.args([ | ||
"aad", | ||
"--scope", | ||
&format!("{resource}/.default"), | ||
resource, | ||
"--client", | ||
self.client_id.as_str(), | ||
"--tenant", | ||
self.tenant_id.as_str(), | ||
"--output", | ||
"json", | ||
]); | ||
|
||
if let Some(prompt_hint) = &self.prompt_hint { | ||
cmd.args(["--prompt-hint", prompt_hint]); | ||
} | ||
|
||
for mode in &self.modes { | ||
match mode { | ||
AzureauthCliMode::All => { | ||
cmd.args(["--mode", "all"]); | ||
} | ||
AzureauthCliMode::IntegratedWindowsAuth => { | ||
if use_windows_features { | ||
cmd.args(["--mode", "iwa"]); | ||
} | ||
} | ||
AzureauthCliMode::Broker => { | ||
if use_windows_features { | ||
cmd.args(["--mode", "broker"]); | ||
} | ||
} | ||
AzureauthCliMode::Web => { | ||
cmd.args(["--mode", "web"]); | ||
} | ||
}; | ||
} | ||
|
||
let result = cmd.output(); | ||
|
||
let output = result.map_err(|e| match e.kind() { | ||
std::io::ErrorKind::NotFound => { | ||
Error::message(ErrorKind::Other, "azureauth CLI not installed") | ||
} | ||
error_kind => Error::with_message(ErrorKind::Other, || { | ||
format!("Unknown error of kind: {error_kind:?}") | ||
}), | ||
})?; | ||
|
||
if !output.status.success() { | ||
let output = String::from_utf8_lossy(&output.stderr); | ||
return Err(Error::with_message(ErrorKind::Credential, || { | ||
format!("'azureauth' command failed: {output}") | ||
})); | ||
} | ||
|
||
let output = String::from_utf8(output.stdout)?; | ||
|
||
let token_response = serde_json::from_str::<CliTokenResponse>(&output) | ||
.map_kind(ErrorKind::DataConversion)?; | ||
Ok(token_response) | ||
} | ||
} | ||
|
||
#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))] | ||
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)] | ||
impl TokenCredential for AzureauthCliCredential { | ||
async fn get_token(&self, resource: &str) -> azure_core::Result<TokenResponse> { | ||
let tr = self.get_access_token(resource)?; | ||
Ok(TokenResponse::new(tr.token, tr.expiration_date)) | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use super::*; | ||
|
||
#[test] | ||
fn parse_example() -> azure_core::Result<()> { | ||
let src = r#"{ | ||
"user": "[email protected]", | ||
"display_name": "Example User", | ||
"token": "security token here", | ||
"expiration_date": "1700166595" | ||
}"#; | ||
|
||
let response: CliTokenResponse = serde_json::from_str(src)?; | ||
assert_eq!(response.token.secret(), "security token here"); | ||
assert_eq!( | ||
response.expiration_date, | ||
OffsetDateTime::from_unix_timestamp(1700166595).expect("known valid date") | ||
); | ||
|
||
Ok(()) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters