Skip to content

Commit

Permalink
HTTP handler for status match
Browse files Browse the repository at this point in the history
Quite a bit of changes here, mostly because of a need to have test server for shard endpoint.
  • Loading branch information
akoshelev committed Nov 29, 2024
1 parent aa8077c commit 3b7c242
Show file tree
Hide file tree
Showing 11 changed files with 557 additions and 120 deletions.
29 changes: 27 additions & 2 deletions ipa-core/src/net/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@ use crate::{
},
executor::IpaRuntime,
helpers::{
query::{PrepareQuery, QueryConfig, QueryInput},
query::{CompareStatusRequest, PrepareQuery, QueryConfig, QueryInput},
TransportIdentity,
},
net::{http_serde, Error, CRYPTO_PROVIDER},
net::{error::ShardQueryStatusMismatchError, http_serde, Error, CRYPTO_PROVIDER},
protocol::{Gate, QueryId},
};

Expand Down Expand Up @@ -384,6 +384,31 @@ impl<F: ConnectionFlavor> IpaHttpClient<F> {
let resp = self.request(req).await?;
resp_ok(resp).await
}

/// This API is used by leader shards in MPC to request query status information on peers.
/// If a given peer has status that doesn't match the one provided by the leader, it responds
/// with 412 error and encodes its status inside the response body. Otherwise, 200 is returned.
///
/// # Errors
/// If the request has illegal arguments, or fails to be delivered
pub async fn status_match(&self, data: CompareStatusRequest) -> Result<(), Error> {
let req = http_serde::query::status_match::try_into_http_request(
&data,
self.scheme.clone(),
self.authority.clone(),
)?;
let resp = self.request(req).await?;

match resp.status() {
StatusCode::OK => Ok(()),
StatusCode::PRECONDITION_FAILED => {
let bytes = response_to_bytes(resp).await?;
let err = serde_json::from_slice::<ShardQueryStatusMismatchError>(&bytes)?;
Err(err.into())
}
_ => Err(Error::from_failed_resp(resp).await),
}
}
}

impl IpaHttpClient<Helper> {
Expand Down
23 changes: 21 additions & 2 deletions ipa-core/src/net/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ use axum::{
};

use crate::{
error::BoxError, net::client::ResponseFromEndpoint, protocol::QueryId, sharding::ShardIndex,
error::BoxError, net::client::ResponseFromEndpoint, protocol::QueryId, query::QueryStatus,
sharding::ShardIndex,
};

#[derive(thiserror::Error, Debug)]
Expand Down Expand Up @@ -59,8 +60,13 @@ pub enum Error {
#[source]
inner: hyper_util::client::legacy::Error,
},
#[error("{error}")]
#[error("{code}: {error}")]
Application { code: StatusCode, error: BoxError },
#[error(transparent)]
ShardQueryStatusMismatch {
#[from]
error: ShardQueryStatusMismatchError,
},
}

impl Error {
Expand Down Expand Up @@ -142,6 +148,12 @@ pub struct ShardError {
pub source: Error,
}

#[derive(Debug, thiserror::Error, serde::Deserialize, serde::Serialize)]
#[error("Query status mismatch. Actual status: {actual}")]
pub struct ShardQueryStatusMismatchError {
pub actual: QueryStatus,
}

impl IntoResponse for Error {
fn into_response(self) -> Response {
let status_code = match self {
Expand All @@ -165,6 +177,13 @@ impl IntoResponse for Error {
| Self::MissingExtension(_) => StatusCode::INTERNAL_SERVER_ERROR,

Self::Application { code, .. } => code,
Self::ShardQueryStatusMismatch { error } => {
return (
StatusCode::PRECONDITION_FAILED,
serde_json::to_string(&error).unwrap(),
)
.into_response()
}
};
(status_code, self.to_string()).into_response()
}
Expand Down
44 changes: 44 additions & 0 deletions ipa-core/src/net/http_serde.rs
Original file line number Diff line number Diff line change
Expand Up @@ -608,4 +608,48 @@ pub mod query {

pub const AXUM_PATH: &str = "/:query_id/kill";
}

pub mod status_match {
use serde::{Deserialize, Serialize};

use crate::{helpers::query::CompareStatusRequest, query::QueryStatus};

#[derive(Serialize, Deserialize)]
pub struct StatusQueryString {
pub status: QueryStatus,
}

impl StatusQueryString {
fn url_encode(&self) -> String {
// todo: serde urlencoded
format!("status={}", self.status)
}
}

impl From<QueryStatus> for StatusQueryString {
fn from(value: QueryStatus) -> Self {
Self { status: value }
}
}

pub fn try_into_http_request(
req: &CompareStatusRequest,
scheme: axum::http::uri::Scheme,
authority: axum::http::uri::Authority,
) -> crate::net::http_serde::OutgoingRequest {
let uri = axum::http::uri::Uri::builder()
.scheme(scheme)
.authority(authority)
.path_and_query(format!(
"{}/{}/status-match?{}",
crate::net::http_serde::query::BASE_AXUM_PATH,
req.query_id.as_ref(),
StatusQueryString::from(req.status).url_encode(),
))
.build()?;
Ok(hyper::Request::get(uri).body(axum::body::Body::empty())?)
}

pub const AXUM_PATH: &str = "/:query_id/status-match";
}
}
9 changes: 6 additions & 3 deletions ipa-core/src/net/server/handlers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,21 @@ mod query;

use axum::Router;

use crate::net::{http_serde, transport::MpcHttpTransport, ShardHttpTransport};
use crate::{
net::{http_serde, transport::MpcHttpTransport, HttpTransport, Shard},
sync::Arc,
};

pub fn mpc_router(transport: MpcHttpTransport) -> Router {
echo::router().nest(
http_serde::query::BASE_AXUM_PATH,
Router::new()
.merge(query::query_router(transport.clone()))
.merge(query::h2h_router(transport)),
.merge(query::h2h_router(transport.inner_transport)),
)
}

pub fn shard_router(transport: ShardHttpTransport) -> Router {
pub fn shard_router(transport: Arc<HttpTransport<Shard>>) -> Router {
echo::router().nest(
http_serde::query::BASE_AXUM_PATH,
Router::new().merge(query::s2s_router(transport)),
Expand Down
33 changes: 13 additions & 20 deletions ipa-core/src/net/server/handlers/query/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ mod kill;
mod prepare;
mod results;
mod status;
mod status_match;
mod step;

use std::marker::PhantomData;
Expand All @@ -21,8 +22,8 @@ use tower::{layer::layer_fn, Service};

use crate::{
net::{
server::ClientIdentity, transport::MpcHttpTransport, ConnectionFlavor, Helper, Shard,
ShardHttpTransport,
server::ClientIdentity, transport::MpcHttpTransport, ConnectionFlavor, Helper,
HttpTransport, Shard,
},
sync::Arc,
};
Expand All @@ -48,19 +49,20 @@ pub fn query_router(transport: MpcHttpTransport) -> Router {
/// particular query, to coordinate servicing that query.
//
// It might make sense to split the query and h2h handlers into two modules.
pub fn h2h_router(transport: MpcHttpTransport) -> Router {
pub fn h2h_router(transport: Arc<HttpTransport<Helper>>) -> Router {
Router::new()
.merge(step::router(Arc::clone(&transport.inner_transport)))
.merge(prepare::router(transport.inner_transport))
.merge(step::router(Arc::clone(&transport)))
.merge(prepare::router(transport))
.layer(layer_fn(HelperAuthentication::<_, Helper>::new))
}

/// Construct router for shard-to-shard communications similar to [`h2h_router`].
pub fn s2s_router(transport: ShardHttpTransport) -> Router {
pub fn s2s_router(transport: Arc<HttpTransport<Shard>>) -> Router {
Router::new()
.merge(step::router(Arc::clone(&transport.inner_transport)))
.merge(prepare::router(Arc::clone(&transport.inner_transport)))
.merge(results::router(transport.inner_transport))
.merge(step::router(Arc::clone(&transport)))
.merge(prepare::router(Arc::clone(&transport)))
.merge(results::router(Arc::clone(&transport)))
.merge(status_match::router(transport))
.layer(layer_fn(HelperAuthentication::<_, Shard>::new))
}

Expand Down Expand Up @@ -125,12 +127,11 @@ pub mod test_helpers {
use std::{any::Any, sync::Arc};

use axum::body::Body;
use http_body_util::BodyExt;
use hyper::{http::request, StatusCode};

use crate::{
helpers::{HelperIdentity, RequestHandler},
net::test::TestServer,
net::{test::TestServer, Helper},
};

/// Helper trait for optionally adding an extension to a request.
Expand Down Expand Up @@ -178,14 +179,6 @@ pub mod test_helpers {
req: hyper::Request<Body>,
handler: Arc<dyn RequestHandler<HelperIdentity>>,
) -> bytes::Bytes {
let test_server = TestServer::builder()
.with_request_handler(handler)
.build()
.await;
let resp = test_server.server.handle_req(req).await;
let status = resp.status();
assert_eq!(StatusCode::OK, status);

resp.into_body().collect().await.unwrap().to_bytes()
TestServer::<Helper>::oneshot_success(req, handler).await
}
}
Loading

0 comments on commit 3b7c242

Please sign in to comment.