diff --git a/crates/aptos-jwk-consensus/src/jwk_observer.rs b/crates/aptos-jwk-consensus/src/jwk_observer.rs new file mode 100644 index 0000000000000..798869de3bf8e --- /dev/null +++ b/crates/aptos-jwk-consensus/src/jwk_observer.rs @@ -0,0 +1,136 @@ +// Copyright © Aptos Foundation + +use anyhow::Result; +use aptos_channels::aptos_channel; +use aptos_logger::{debug, info}; +use aptos_types::jwks::{jwk::JWK, Issuer}; +use futures::{FutureExt, StreamExt}; +use move_core_types::account_address::AccountAddress; +use serde::{Deserialize, Serialize}; +use std::time::Duration; +use tokio::{sync::oneshot, task::JoinHandle, time::MissedTickBehavior}; + +#[derive(Serialize, Deserialize)] +struct OpenIDConfiguration { + issuer: String, + jwks_uri: String, +} + +#[derive(Serialize, Deserialize)] +struct JWKsResponse { + keys: Vec, +} + +/// Given an Open ID configuration URL, fetch its JWKs. +pub async fn fetch_jwks(my_addr: AccountAddress, config_url: Vec) -> Result> { + if cfg!(feature = "smoke-test") { + use reqwest::header; + let maybe_url = String::from_utf8(config_url); + let jwk_url = maybe_url?; + let client = reqwest::Client::new(); + let JWKsResponse { keys } = client + .get(jwk_url.as_str()) + .header(header::COOKIE, my_addr.to_hex()) + .send() + .await? + .json() + .await?; + let jwks = keys.into_iter().map(JWK::from).collect(); + Ok(jwks) + } else { + let maybe_url = String::from_utf8(config_url); + let config_url = maybe_url?; + let client = reqwest::Client::new(); + let OpenIDConfiguration { jwks_uri, .. } = + client.get(config_url.as_str()).send().await?.json().await?; + let JWKsResponse { keys } = client.get(jwks_uri.as_str()).send().await?.json().await?; + let jwks = keys.into_iter().map(JWK::from).collect(); + Ok(jwks) + } +} + +/// A process thread that periodically fetch JWKs of a provider and push it back to JWKManager. +pub struct JWKObserver { + close_tx: oneshot::Sender<()>, + join_handle: JoinHandle<()>, +} + +impl JWKObserver { + pub fn spawn( + my_addr: AccountAddress, + issuer: Issuer, + config_url: Vec, + fetch_interval: Duration, + observation_tx: aptos_channel::Sender<(), (Issuer, Vec)>, + ) -> Self { + let (close_tx, close_rx) = oneshot::channel(); + let join_handle = tokio::spawn(Self::thread_main( + fetch_interval, + my_addr, + issuer.clone(), + config_url.clone(), + observation_tx, + close_rx, + )); + info!( + "[JWK] observer spawned, issuer={:?}, config_url={:?}", + String::from_utf8(issuer), + String::from_utf8(config_url) + ); + Self { + close_tx, + join_handle, + } + } + + async fn thread_main( + fetch_interval: Duration, + my_addr: AccountAddress, + issuer: Issuer, + open_id_config_url: Vec, + observation_tx: aptos_channel::Sender<(), (Issuer, Vec)>, + close_rx: oneshot::Receiver<()>, + ) { + let mut interval = tokio::time::interval(fetch_interval); + interval.set_missed_tick_behavior(MissedTickBehavior::Delay); + let mut close_rx = close_rx.into_stream(); + loop { + tokio::select! { + _ = interval.tick().fuse() => { + let result = fetch_jwks(my_addr, open_id_config_url.clone()).await; + debug!("observe_result={:?}", result); + if let Ok(mut jwks) = result { + jwks.sort(); + let _ = observation_tx.push((), (issuer.clone(), jwks)); + } + }, + _ = close_rx.select_next_some() => { + break; + } + } + } + } + + pub async fn shutdown(self) { + let Self { + close_tx, + join_handle, + } = self; + let _ = close_tx.send(()); + let _ = join_handle.await; + } +} + +#[ignore] +#[tokio::test] +async fn test_fetch_real_jwks() { + let jwks = fetch_jwks( + AccountAddress::ZERO, + "https://www.facebook.com/.well-known/openid-configuration/" + .as_bytes() + .to_vec(), + ) + .await + .unwrap(); + println!("{:?}", jwks); +} diff --git a/crates/aptos-jwk-consensus/src/lib.rs b/crates/aptos-jwk-consensus/src/lib.rs index d51816ec21376..25f214dcb76bb 100644 --- a/crates/aptos-jwk-consensus/src/lib.rs +++ b/crates/aptos-jwk-consensus/src/lib.rs @@ -30,6 +30,7 @@ pub fn start_jwk_consensus_runtime( } pub mod certified_update_producer; +pub mod jwk_observer; pub mod network; pub mod network_interface; pub mod observation_aggregation;