diff --git a/block_engine/src/auth.rs b/block_engine/src/auth.rs deleted file mode 100644 index 592240350a65df..00000000000000 --- a/block_engine/src/auth.rs +++ /dev/null @@ -1,182 +0,0 @@ -use std::{cell::RefCell, rc::Rc, sync::Arc}; - -use jito_protos::block_engine::{ - block_engine_relayer_client::BlockEngineRelayerClient, AccountsOfInterestRequest, - AccountsOfInterestUpdate, PacketBatchUpdate, SubscribeExpiringPacketsResponse, -}; -use solana_sdk::signature::{Keypair, Signer}; -use tokio::sync::mpsc::{channel, Sender}; -use tokio_stream::wrappers::ReceiverStream; -use tonic::{ - codegen::InterceptedService, metadata::MetadataValue, service::Interceptor, transport::Channel, - Code, Response, Status, Streaming, -}; - -// Auth header keys -pub const MESSAGE_BIN: &str = "message-bin"; -pub const PUBKEY_BIN: &str = "public-key-bin"; -pub const SIGNATURE_BIN: &str = "signature-bin"; - -/// Intercepts requests and adds the necessary headers for auth. -#[derive(Clone)] -pub struct AuthInterceptor { - /// Used to sign the server generated token. - keypair: Arc, - token: Rc>, -} - -impl AuthInterceptor { - pub fn new(keypair: Arc, token: Rc>) -> Self { - AuthInterceptor { keypair, token } - } - - pub fn should_retry( - status: &Status, - token: Rc>, - max_retries: usize, - n_retries: usize, - ) -> bool { - if max_retries == n_retries { - return false; - } - - let mut token = token.borrow_mut(); - if let Some(new_token) = Self::maybe_new_auth_token(status, &token) { - *token = new_token; - true - } else { - false - } - } - - /// Checks to see if the server returned a token to be signed and if it does not equal the current - /// token then the new token is returned and authentication can be retried. - fn maybe_new_auth_token(status: &Status, current_token: &str) -> Option { - if status.code() != Code::Unauthenticated { - return None; - } - - let msg = status.message().split_whitespace().collect::>(); - if msg.len() != 2 { - return None; - } - - if msg[0] != "token:" { - return None; - } - - if msg[1] != current_token { - Some(msg[1].to_string()) - } else { - None - } - } -} - -impl Interceptor for AuthInterceptor { - fn call(&mut self, mut request: tonic::Request<()>) -> Result, Status> { - // Prefix with pubkey and hash it in order to ensure BlockEngine doesn't have us sign a malicious transaction. - let token = format!("{}-{}", self.keypair.pubkey(), self.token.take(),); - let hashed_token = solana_sdk::hash::hash(token.as_bytes()); - - request.metadata_mut().append_bin( - PUBKEY_BIN, - MetadataValue::from_bytes(&self.keypair.pubkey().to_bytes()), - ); - request.metadata_mut().append_bin( - MESSAGE_BIN, - MetadataValue::from_bytes(hashed_token.to_bytes().as_slice()), - ); - request.metadata_mut().append_bin( - SIGNATURE_BIN, - MetadataValue::from_bytes( - self.keypair - .sign_message(hashed_token.to_bytes().as_slice()) - .as_ref(), - ), - ); - - Ok(request) - } -} - -/// Wrapper client that takes care of extracting the auth challenge and retrying requests. -pub struct AuthClient { - inner: BlockEngineRelayerClient>, - token: Rc>, - max_retries: usize, -} - -impl AuthClient { - pub fn new( - inner: BlockEngineRelayerClient>, - max_retries: usize, - ) -> Self { - let token = Rc::new(RefCell::new(String::default())); - Self { - inner, - token, - max_retries, - } - } - - pub async fn subscribe_accounts_of_interest( - &mut self, - req: AccountsOfInterestRequest, - ) -> Result>, Status> { - let mut n_retries = 0; - loop { - return match self.inner.subscribe_accounts_of_interest(req.clone()).await { - Ok(resp) => Ok(resp), - Err(status) => { - if AuthInterceptor::should_retry( - &status, - self.token.clone(), - self.max_retries, - n_retries, - ) { - n_retries += 1; - continue; - } - Err(status) - } - }; - } - } - - pub async fn start_expiring_packet_stream( - &mut self, - buffer: usize, - ) -> Result< - ( - Sender, - Response, - ), - Status, - > { - let mut n_retries = 0; - loop { - let (tx, rx) = channel::(buffer); - let receiver_stream = ReceiverStream::new(rx); - return match self - .inner - .start_expiring_packet_stream(receiver_stream) - .await - { - Ok(resp) => Ok((tx, resp)), - Err(status) => { - if AuthInterceptor::should_retry( - &status, - self.token.clone(), - self.max_retries, - n_retries, - ) { - n_retries += 1; - continue; - } - Err(status) - } - }; - } - } -} diff --git a/block_engine/src/block_engine.rs b/block_engine/src/block_engine.rs index 847037b9c99c0d..35d42fc378b20d 100644 --- a/block_engine/src/block_engine.rs +++ b/block_engine/src/block_engine.rs @@ -1,15 +1,18 @@ use std::{ - cell::RefCell, collections::{hash_map::RandomState, HashSet}, - rc::Rc, str::FromStr, - sync::Arc, + sync::{Arc, Mutex}, thread, thread::{Builder, JoinHandle}, time::{Duration, SystemTime}, }; use jito_protos::{ + auth::{ + auth_service_client::AuthServiceClient, GenerateAuthChallengeRequest, + GenerateAuthTokensRequest, GenerateAuthTokensResponse, RefreshAccessTokenRequest, Role, + Token, + }, block_engine::{ accounts_of_interest_update, block_engine_relayer_client::BlockEngineRelayerClient, packet_batch_update::Msg, AccountsOfInterestRequest, AccountsOfInterestUpdate, @@ -22,17 +25,43 @@ use jito_protos::{ use log::{error, *}; use prost_types::Timestamp; use solana_perf::packet::PacketBatch; -use solana_sdk::{pubkey::Pubkey, signer::keypair::Keypair}; +use solana_sdk::{pubkey::Pubkey, signature::Signer, signer::keypair::Keypair}; use thiserror::Error; use tokio::{ runtime::Runtime, select, - sync::mpsc::{Receiver, Sender}, + sync::mpsc::{channel, Receiver, Sender}, time::{interval, sleep}, }; -use tonic::{transport::Endpoint, Response, Status, Streaming}; +use tokio_stream::wrappers::ReceiverStream; +use tonic::{ + codegen::InterceptedService, + service::Interceptor, + transport::{Channel, Endpoint}, + Response, Status, Streaming, +}; -use crate::auth::{AuthClient, AuthInterceptor}; +struct AuthInterceptor { + access_token: Arc>, +} + +impl AuthInterceptor { + pub fn new(access_token: Arc>) -> Self { + AuthInterceptor { access_token } + } +} + +impl Interceptor for AuthInterceptor { + fn call(&mut self, mut request: tonic::Request<()>) -> Result, Status> { + request.metadata_mut().insert( + "authorization", + format!("Bearer {}", self.access_token.lock().unwrap().value) + .parse() + .unwrap(), + ); + Ok(request) + } +} pub struct BlockEnginePackets { pub packet_batches: Vec, @@ -42,17 +71,11 @@ pub struct BlockEnginePackets { #[derive(Error, Debug)] pub enum BlockEngineError { - #[error("connection closed")] - ConnectionClosedError, - - #[error("malformed message")] - MalformedMessage, + #[error("auth service failed: {0}")] + AuthServiceFailure(String), - #[error("heartbeat timeout")] - HeartbeatTimeout, - - #[error("GRPC error: {0}")] - GrpcError(#[from] Status), + #[error("block engine failed: {0}")] + BlockEngineFailure(String), } pub type BlockEngineResult = Result; @@ -65,11 +88,13 @@ pub struct BlockEngineRelayerHandler { impl BlockEngineRelayerHandler { pub fn new( block_engine_url: String, + auth_service_url: String, block_engine_receiver: Receiver, keypair: Arc, ) -> BlockEngineRelayerHandler { let block_engine_forwarder = Self::start_block_engine_relayer_stream( block_engine_url, + auth_service_url, block_engine_receiver, keypair, ); @@ -84,44 +109,28 @@ impl BlockEngineRelayerHandler { fn start_block_engine_relayer_stream( block_engine_url: String, + auth_service_url: String, mut block_engine_receiver: Receiver, keypair: Arc, ) -> JoinHandle<()> { - let endpoint = Endpoint::from_shared(block_engine_url.to_string()) - .expect("valid block engine endpoint"); Builder::new() .name("jito_block_engine_relayer_stream".into()) .spawn(move || { let rt = Runtime::new().unwrap(); rt.block_on(async move { - let auth_interceptor = - AuthInterceptor::new(keypair, Rc::new(RefCell::new(String::default()))); loop { - sleep(Duration::from_secs(1)).await; - - info!("connecting to block engine at url: {:?}", block_engine_url); - match endpoint.connect().await { - Ok(channel) => { - let client = BlockEngineRelayerClient::with_interceptor( - channel, - auth_interceptor.clone(), - ); - let mut client = AuthClient::new(client, 3); - - match Self::start_event_loop( - &mut client, - &mut block_engine_receiver, - ) - .await - { - Ok(_) => {} - Err(e) => { - error!("error with packet stream: {:?}", e); - } - } - } + match Self::auth_and_connect( + &block_engine_url, + &auth_service_url, + &mut block_engine_receiver, + &keypair, + ) + .await + { + Ok(_) => {} Err(e) => { - error!("error connecting: {:?}", e); + error!("error authenticating and connecting: {:?}", e); + sleep(Duration::from_secs(2)).await; } } } @@ -130,24 +139,142 @@ impl BlockEngineRelayerHandler { .unwrap() } + /// Relayers are whitelisted in the block engine. In order to auth, a challenge-response handshake + /// is performed. After that, the relayer can fetch an access and refresh JWT token that's provided + /// in request headers to the block engine. + async fn auth( + auth_client: &mut AuthServiceClient, + keypair: &Arc, + ) -> BlockEngineResult<(Token, Token)> { + let auth_response = auth_client + .generate_auth_challenge(GenerateAuthChallengeRequest { + role: Role::Relayer.into(), + pubkey: keypair.pubkey().to_bytes().to_vec(), + }) + .await + .map_err(|e| BlockEngineError::AuthServiceFailure(e.to_string()))?; + + let challenge = format!( + "{}-{}", + keypair.pubkey(), + auth_response.into_inner().challenge + ); + let signed_challenge = keypair.sign_message(challenge.as_bytes()).as_ref().to_vec(); + + let GenerateAuthTokensResponse { + access_token: maybe_access_token, + refresh_token: maybe_refresh_token, + } = auth_client + .generate_auth_tokens(GenerateAuthTokensRequest { + challenge, + client_pubkey: keypair.pubkey().as_ref().to_vec(), + signed_challenge, + }) + .await + .map_err(|e| BlockEngineError::AuthServiceFailure(e.to_string()))? + .into_inner(); + + if maybe_access_token.is_none() || maybe_refresh_token.is_none() { + return Err(BlockEngineError::AuthServiceFailure( + "failed to get valid auth tokens".to_string(), + )); + } + let access_token = maybe_access_token.unwrap(); + let refresh_token = maybe_refresh_token.unwrap(); + + if access_token.expires_at_utc.is_none() || refresh_token.expires_at_utc.is_none() { + return Err(BlockEngineError::AuthServiceFailure( + "auth tokens don't have valid expiration time".to_string(), + )); + } + + Ok((access_token, refresh_token)) + } + + /// Authenticates the relayer with the block engine and connects to the forwarding service + async fn auth_and_connect( + block_engine_url: &str, + auth_service_url: &str, + block_engine_receiver: &mut Receiver, + keypair: &Arc, + ) -> BlockEngineResult<()> { + let auth_endpoint = Endpoint::from_str(auth_service_url).expect("valid auth url"); + let channel = auth_endpoint + .connect() + .await + .map_err(|e| BlockEngineError::AuthServiceFailure(e.to_string()))?; + let mut auth_client = AuthServiceClient::new(channel); + + let (access_token, mut refresh_token) = Self::auth(&mut auth_client, keypair).await?; + + let access_token_expiration = + SystemTime::try_from(access_token.expires_at_utc.as_ref().unwrap().clone()).unwrap(); + let refresh_token_expiration = + SystemTime::try_from(refresh_token.expires_at_utc.as_ref().unwrap().clone()).unwrap(); + + info!( + "access_token_expiration: {:?}, refresh_token_expiration: {:?}", + access_token_expiration + .duration_since(SystemTime::now()) + .unwrap(), + refresh_token_expiration + .duration_since(SystemTime::now()) + .unwrap() + ); + + let shared_access_token = Arc::new(Mutex::new(access_token)); + let auth_interceptor = AuthInterceptor::new(shared_access_token.clone()); + + let block_engine_endpoint = + Endpoint::from_str(block_engine_url).expect("valid block engine url"); + let block_engine_channel = block_engine_endpoint + .connect() + .await + .map_err(|e| BlockEngineError::BlockEngineFailure(e.to_string()))?; + let block_engine_client = + BlockEngineRelayerClient::with_interceptor(block_engine_channel, auth_interceptor); + Self::start_event_loop( + block_engine_client, + block_engine_receiver, + auth_client, + keypair, + &mut refresh_token, + shared_access_token, + ) + .await + } + /// Starts the bi-directional packet stream. /// The relayer will send heartbeats and packets to the block engine. /// The block engine will send heartbeats back to the relayer. /// If there's a missed heartbeat or any issues responding to each other, they'll disconnect and /// try to re-establish connection async fn start_event_loop( - client: &mut AuthClient, + mut client: BlockEngineRelayerClient>, block_engine_receiver: &mut Receiver, + auth_client: AuthServiceClient, + keypair: &Arc, + refresh_token: &mut Token, + shared_access_token: Arc>, ) -> BlockEngineResult<()> { let subscribe_aoi_stream = client .subscribe_accounts_of_interest(AccountsOfInterestRequest {}) - .await?; - let (packet_msg_sender, _response) = client.start_expiring_packet_stream(100).await?; + .await + .map_err(|e| BlockEngineError::BlockEngineFailure(e.to_string()))?; + let (packet_msg_sender, packet_msg_receiver) = channel(100); + let _response = client + .start_expiring_packet_stream(ReceiverStream::new(packet_msg_receiver)) + .await + .map_err(|e| BlockEngineError::BlockEngineFailure(e.to_string()))?; Self::handle_packet_stream( packet_msg_sender, block_engine_receiver, subscribe_aoi_stream, + auth_client, + keypair, + refresh_token, + shared_access_token, ) .await } @@ -156,6 +283,10 @@ impl BlockEngineRelayerHandler { block_engine_packet_sender: Sender, block_engine_receiver: &mut Receiver, subscribe_aoi_stream: Response>, + mut auth_client: AuthServiceClient, + keypair: &Arc, + refresh_token: &mut Token, + shared_access_token: Arc>, ) -> BlockEngineResult<()> { let mut aoi_stream = subscribe_aoi_stream.into_inner(); @@ -165,7 +296,9 @@ impl BlockEngineRelayerHandler { let mut accounts_of_interest: HashSet = HashSet::new(); let mut heartbeat_count = 0; let heartbeat = interval(Duration::from_millis(500)); + let refresh_interval = interval(Duration::from_secs(60)); tokio::pin!(heartbeat); + tokio::pin!(refresh_interval); loop { select! { _ = heartbeat.tick() => { @@ -178,7 +311,84 @@ impl BlockEngineRelayerHandler { block_engine_batches = block_engine_receiver.recv() => { Self::forward_packets(&block_engine_packet_sender, block_engine_batches).await?; } + _ = refresh_interval.tick() => { + Self::maybe_refresh_auth(&mut auth_client, keypair, refresh_token, &shared_access_token).await?; + } + } + } + } + + /// Refresh authentication tokens if they're about to expire + async fn maybe_refresh_auth( + auth_client: &mut AuthServiceClient, + keypair: &Arc, + refresh_token: &mut Token, + shared_access_token: &Arc>, + ) -> BlockEngineResult<()> { + // expires_at_utc is checked for None when establishing connection + let access_token_expiration_time = shared_access_token + .lock() + .unwrap() + .expires_at_utc + .as_ref() + .unwrap() + .clone(); + + let access_token_expiration_time = + SystemTime::try_from(access_token_expiration_time).unwrap(); + let access_token_duration_left = + access_token_expiration_time.duration_since(SystemTime::now()); + + let refresh_token_expiration_time = + SystemTime::try_from(refresh_token.expires_at_utc.as_ref().unwrap().clone()).unwrap(); + let refresh_token_duration_left = + refresh_token_expiration_time.duration_since(SystemTime::now()); + + let is_access_token_expiring_soon = match access_token_duration_left { + Ok(dur) => dur < Duration::from_secs(5 * 60), + Err(_) => true, + }; + let is_refresh_token_expiring_soon = match refresh_token_duration_left { + Ok(dur) => dur < Duration::from_secs(5 * 60), + Err(_) => true, + }; + + match ( + is_refresh_token_expiring_soon, + is_access_token_expiring_soon, + ) { + (true, _) => { + // re-run the authentication process from the beginning + let (access_token, new_refresh_token) = Self::auth(auth_client, keypair).await?; + + *refresh_token = new_refresh_token; + *shared_access_token.lock().unwrap() = access_token; + info!("access and refresh token were refreshed"); + + Ok(()) + } + (false, true) => { + // fetch a new access token + let response = auth_client + .refresh_access_token(RefreshAccessTokenRequest { + refresh_token: refresh_token.value.clone(), + }) + .await + .map_err(|e| BlockEngineError::AuthServiceFailure(e.to_string()))?; + + let maybe_access_token = response.into_inner().access_token; + if maybe_access_token.is_none() { + return Err(BlockEngineError::AuthServiceFailure( + "missing access token".to_string(), + )); + } + + *shared_access_token.lock().unwrap() = maybe_access_token.unwrap(); + info!("access token was refreshed"); + + Ok(()) } + (false, false) => Ok(()), } } @@ -188,7 +398,9 @@ impl BlockEngineRelayerHandler { ) -> BlockEngineResult<()> { match maybe_msg { Ok(Some(aoi_update)) => match aoi_update.msg { - None => Err(BlockEngineError::MalformedMessage), + None => Err(BlockEngineError::BlockEngineFailure( + "AOI message malformed".to_string(), + )), Some(accounts_of_interest_update::Msg::Add(accounts)) => { let accounts: HashSet = accounts .accounts @@ -223,8 +435,10 @@ impl BlockEngineRelayerHandler { Ok(()) } }, - Ok(None) => Err(BlockEngineError::ConnectionClosedError), - Err(e) => Err(e.into()), + Ok(None) => Err(BlockEngineError::BlockEngineFailure( + "disconnected".to_string(), + )), + Err(e) => Err(BlockEngineError::BlockEngineFailure(e.to_string())), } } @@ -233,8 +447,8 @@ impl BlockEngineRelayerHandler { block_engine_packet_sender: &Sender, block_engine_batches: Option, ) -> BlockEngineResult<()> { - let block_engine_batches = - block_engine_batches.ok_or(BlockEngineError::ConnectionClosedError)?; + let block_engine_batches = block_engine_batches + .ok_or_else(|| BlockEngineError::BlockEngineFailure("disconnected".to_string()))?; if block_engine_packet_sender .send(PacketBatchUpdate { msg: Some(Msg::Batches(ExpiringPacketBatch { @@ -258,7 +472,9 @@ impl BlockEngineRelayerHandler { .await .is_err() { - Err(BlockEngineError::ConnectionClosedError) + Err(BlockEngineError::BlockEngineFailure( + "disconnected".to_string(), + )) } else { Ok(()) } @@ -279,7 +495,9 @@ impl BlockEngineRelayerHandler { .await .is_err() { - return Err(BlockEngineError::ConnectionClosedError); + return Err(BlockEngineError::BlockEngineFailure( + "disconnected".to_string(), + )); } Ok(()) diff --git a/block_engine/src/lib.rs b/block_engine/src/lib.rs index 62ef10d36eb9c0..c95acf45ecd0bd 100644 --- a/block_engine/src/lib.rs +++ b/block_engine/src/lib.rs @@ -1,2 +1 @@ -pub mod auth; pub mod block_engine; diff --git a/jito-protos/build.rs b/jito-protos/build.rs index e1b22a7fe4f47e..fe90005139ed55 100644 --- a/jito-protos/build.rs +++ b/jito-protos/build.rs @@ -4,6 +4,7 @@ fn main() { configure() .compile( &[ + "protos/auth.proto", "protos/block.proto", "protos/block_engine.proto", "protos/bundle.proto", diff --git a/jito-protos/protos b/jito-protos/protos index b3dd2e31c760aa..1f1e2660ca26a3 160000 --- a/jito-protos/protos +++ b/jito-protos/protos @@ -1 +1 @@ -Subproject commit b3dd2e31c760aa5670a97ddd75b6b3ecb753fbd9 +Subproject commit 1f1e2660ca26a33ebdef27c0a2f06d7ef38916bc diff --git a/jito-protos/src/lib.rs b/jito-protos/src/lib.rs index 312a9f482dce1a..584243f17fa08d 100644 --- a/jito-protos/src/lib.rs +++ b/jito-protos/src/lib.rs @@ -1,5 +1,9 @@ pub mod convert; +pub mod auth { + tonic::include_proto!("auth"); +} + pub mod block { tonic::include_proto!("block"); } diff --git a/transaction-relayer/src/main.rs b/transaction-relayer/src/main.rs index 25d45687f4ade8..ff1862cac5cd1b 100644 --- a/transaction-relayer/src/main.rs +++ b/transaction-relayer/src/main.rs @@ -16,7 +16,7 @@ use jito_rpc::load_balancer::LoadBalancer; use jito_transaction_relayer::forwarder::start_forward_and_delay_thread; use log::info; use solana_net_utils::multi_bind_in_range; -use solana_sdk::signature::{Keypair, Signer}; +use solana_sdk::signature::{read_keypair_file, Signer}; use tokio::sync::mpsc::channel; #[derive(Parser, Debug)] @@ -82,6 +82,14 @@ struct Args { /// Block engine address #[clap(long, env, value_parser, default_value = "http://127.0.0.1:13334")] block_engine_url: String, + + /// Authentication service address. Keypairs are authenticated against the block engine + #[clap(long, env, value_parser, default_value = "http://127.0.0.1:14444")] + auth_service_url: String, + + /// Keypair path + #[clap(long, env, value_parser)] + keypair_path: String, } struct Sockets { @@ -143,7 +151,7 @@ fn main() { let sockets = get_sockets(&args); - let keypair = Arc::new(Keypair::new()); + let keypair = Arc::new(read_keypair_file(args.keypair_path).expect("keypair file exists")); solana_metrics::set_host_id(keypair.pubkey().to_string()); info!("Relayer started with pubkey: {}", keypair.pubkey()); @@ -190,8 +198,12 @@ fn main() { block_engine_sender, 1, ); - let block_engine_forwarder = - BlockEngineRelayerHandler::new(args.block_engine_url, block_engine_receiver, keypair); + let block_engine_forwarder = BlockEngineRelayerHandler::new( + args.block_engine_url, + args.auth_service_url, + block_engine_receiver, + keypair, + ); let server_addr = SocketAddr::new(args.grpc_bind_ip, args.grpc_bind_port); let relayer_server = RelayerImpl::new(