Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(deps): update axum 0.6.20 to 0.7.3 #66

Merged
merged 9 commits into from
Jan 29, 2024
Merged
370 changes: 263 additions & 107 deletions Cargo.lock

Large diffs are not rendered by default.

22 changes: 13 additions & 9 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,26 @@ aws-config = { version = "1.0.1", optional = true }
aws-sdk-kms = { version = "1.3.0", optional = true }
base64 = "0.21.2"
futures = "0.3.28"
tracing = { version = "0.1.40" }
tracing-appender = { version = "0.2.2" }
tracing-attributes = "0.1.27"
tracing-subscriber = { version = "0.3.17", default-features = true, features = ["env-filter", "json", "registry"] }
gethostname = "0.4.3"
rustc-hash = "1.1"
once_cell = "1.18.0"
vaultrs = { version = "0.7.0", optional = true }

# Tokio Dependencies
tokio = { version = "1.33.0", features = ["macros", "rt-multi-thread"] }
axum = "0.6.20"
hyper = "0.14.27"
axum = "0.7.3"
hyper = "1.0.1"
tower = { version = "0.4.13", features = ["limit", "buffer", "load-shed"] }
tower-http = { version = "0.4.4", features = ["trace"] }

tower-http = { version = "0.5.0", features = ["trace"] }
tracing = { version = "0.1.40" }
tracing-appender = { version = "0.2.2" }
tracing-attributes = "0.1.27"
tracing-subscriber = { version = "0.3.17", default-features = true, features = [
"env-filter",
"json",
"registry",
] }
http-body-util = "0.1.0"

diesel = { version = "2.1.3", features = ["postgres", "serde_json", "time"] }
diesel-async = { version = "0.4.1", features = ["postgres", "deadpool"] }
Expand All @@ -61,7 +65,7 @@ argh = "0.1.12"

[dev-dependencies]
rand = "0.8.5"
axum-test = "13.0.1"
axum-test = "14.2.2"

[build-dependencies]
cargo_metadata = "0.15.4"
Expand Down
13 changes: 7 additions & 6 deletions src/app.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use axum::routing;
use error_stack::ResultExt;
use hyper::server::conn;
use masking::PeekInterface;
#[cfg(feature = "key_custodian")]
use tokio::sync::{mpsc::Sender, RwLock};
Expand Down Expand Up @@ -52,7 +51,7 @@ pub async fn server1_builder(
state: Arc<RwLock<AppState>>,
server_tx: Sender<()>,
) -> Result<
hyper::Server<conn::AddrIncoming, routing::IntoMakeService<axum::Router>>,
axum::serve::Serve<routing::IntoMakeService<axum::Router>, axum::Router>,
error::ConfigurationError,
>
where
Expand All @@ -71,7 +70,9 @@ where
.with_state(shared_state)
.route("/health", routing::get(routes::health::health));

let server = axum::Server::try_bind(&socket_addr)?.serve(router.into_make_service());
let tcp_listener = tokio::net::TcpListener::bind(&socket_addr).await?;
let server = axum::serve(tcp_listener, router.into_make_service());

Ok(server)
}

Expand All @@ -82,7 +83,7 @@ where
pub async fn server2_builder(
state: &AppState,
) -> Result<
hyper::Server<conn::AddrIncoming, routing::IntoMakeService<axum::Router>>,
axum::serve::Serve<routing::IntoMakeService<axum::Router>, axum::Router>,
error::ConfigurationError,
>
where
Expand Down Expand Up @@ -121,8 +122,8 @@ where
.level(tracing::Level::ERROR),
),
);

let server = axum::Server::try_bind(&socket_addr)?.serve(router.into_make_service());
let tcp_listener = tokio::net::TcpListener::bind(&socket_addr).await?;
let server = axum::serve(tcp_listener, router.into_make_service());
Ok(server)
}

Expand Down
6 changes: 3 additions & 3 deletions src/bin/locker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
{
let state_lock = state.clone();

let (server1_tx, mut server1_rx) = tokio::sync::mpsc::channel::<()>(1);
let (server1_tx, server1_rx) = tokio::sync::mpsc::channel::<()>(1);

let server1 = tartarus::app::server1_builder(state_lock, server1_tx.clone())
.await?
.with_graceful_shutdown(graceful_shutdown_server1(&mut server1_rx));
.with_graceful_shutdown(graceful_shutdown_server1(server1_rx));

logger::info!(
"Key Custodian started [{:?}] [{:?}]",
Expand All @@ -42,7 +42,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
}

#[cfg(feature = "key_custodian")]
async fn graceful_shutdown_server1(recv: &mut tokio::sync::mpsc::Receiver<()>) {
async fn graceful_shutdown_server1(mut recv: tokio::sync::mpsc::Receiver<()>) {
recv.recv().await;
logger::info!("Shutting down the server1 gracefully.");
}
20 changes: 16 additions & 4 deletions src/bin/utils.rs
ShankarSinghC marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,13 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
let pub_key = read_file_to_string(
&public_key.ok_or(error::CryptoError::InvalidData("public key not found"))?,
)?;
jwe_operation(|x| {
JWEncryption::new(priv_key, pub_key, jwe::RSA_OAEP_256, jwe::RSA_OAEP).encrypt(x)
jwe_operation(|payload| {
JWEncryption::new(priv_key, pub_key, jwe::RSA_OAEP_256, jwe::RSA_OAEP)
.encrypt(payload)
.and_then(|payload| {
Ok(serde_json::to_vec(&payload)
.map_err(error::CryptoError::SerdeJsonError)?)
})
})?;
}
SubCommand::JweDecrypt(JweD {
Expand All @@ -94,8 +99,15 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
let pub_key = read_file_to_string(
&public_key.ok_or(error::CryptoError::InvalidData("private key not found"))?,
)?;
jwe_operation(|x| {
JWEncryption::new(priv_key, pub_key, jwe::RSA_OAEP_256, jwe::RSA_OAEP).decrypt(x)
jwe_operation(|payload| {
serde_json::from_slice(&payload)
.map_err(error::CryptoError::SerdeJsonError)
.map_err(Into::into)
.and_then(|payload| {
JWEncryption::new(priv_key, pub_key, jwe::RSA_OAEP_256, jwe::RSA_OAEP)
.decrypt(payload)
})
// (x)
})?;
}
}
Expand Down
12 changes: 5 additions & 7 deletions src/crypto/jw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,10 @@ impl JweBody {
}
}

impl super::Encryption<Vec<u8>, Vec<u8>> for JWEncryption {
impl super::Encryption<Vec<u8>, JweBody> for JWEncryption {
type ReturnType<'a, T> = Result<T, ContainerError<error::CryptoError>>;

fn encrypt(&self, input: Vec<u8>) -> Self::ReturnType<'_, Vec<u8>> {
fn encrypt(&self, input: Vec<u8>) -> Self::ReturnType<'_, JweBody> {
let payload = input;
let jws_encoded = jws_sign_payload(&payload, self.private_key.peek().as_bytes())?;
let jws_body = JwsBody::from_dotted_str(&jws_encoded).ok_or(
Expand All @@ -107,13 +107,11 @@ impl super::Encryption<Vec<u8>, Vec<u8>> for JWEncryption {
)?;
let jwe_body = JweBody::from_str(&jwe_encrypted)
.ok_or(error::CryptoError::InvalidData("JWE data incomplete"))?;
Ok(serde_json::to_vec(&jwe_body).map_err(error::CryptoError::from)?)
Ok(jwe_body)
}

fn decrypt(&self, input: Vec<u8>) -> Self::ReturnType<'_, Vec<u8>> {
let jwe_body: JweBody = serde_json::from_slice(&input).map_err(error::CryptoError::from)?;
let jwe_encoded = jwe_body.get_dotted_jwe();
// let algo = jwe::RSA_OAEP_256;
fn decrypt(&self, input: JweBody) -> Self::ReturnType<'_, Vec<u8>> {
let jwe_encoded = input.get_dotted_jwe();
let jwe_decrypted =
decrypt_jwe(&jwe_encoded, self.private_key.peek(), self.decryption_algo)?;

Expand Down
2 changes: 2 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ pub enum ConfigurationError {
ServerError(#[from] hyper::Error),
#[error("invalid host for socket")]
AddressError(#[from] std::net::AddrParseError),
#[error("invalid host for socket")]
IOError(#[from] std::io::Error),
#[error("Error while connecting/creating database pool")]
DatabaseError,
#[error("Failed to KMS decrypt: {0}")]
Expand Down
65 changes: 26 additions & 39 deletions src/middleware.rs
ShankarSinghC marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -1,63 +1,50 @@
use crate::app::AppState;
use crate::crypto::jw::JWEncryption;
use crate::crypto::jw::{self, JWEncryption};
use crate::crypto::Encryption;
use crate::error::{self, ContainerError, ResultContainerExt};
use axum::{
body::BoxBody,
extract,
http::{Request, Response},
middleware::Next,
};
use error_stack::ResultExt;
use hyper::body::HttpBody;
use hyper::Body;
use axum::body::Body;
use axum::http::{request, response};
use axum::{extract, http::Request, middleware::Next};

use http_body_util::BodyExt;
use josekit::jwe;

/// Middleware providing implementation to perform JWE + JWS encryption and decryption around the
/// card APIs
pub async fn middleware(
extract::State(state): extract::State<AppState>,
request: Request<Body>,
next: Next<Body>,
) -> Result<Response<BoxBody>, ContainerError<error::ApiError>> {
let (parts, body) = request.into_parts();

let request_body =
hyper::body::to_bytes(body)
.await
.change_error(error::ApiError::RequestMiddlewareError(
"Failed to read request body for jwe decryption",
))?;

parts: request::Parts,
axum::Json(jwe_body): axum::Json<jw::JweBody>,
next: Next,
) -> Result<(response::Parts, axum::Json<jw::JweBody>), ContainerError<error::ApiError>> {
let keys = JWEncryption {
private_key: state.config.secrets.locker_private_key,
public_key: state.config.secrets.tenant_public_key,
encryption_algo: jwe::RSA_OAEP,
decryption_algo: jwe::RSA_OAEP_256,
};

let jwe_decrypted = keys.decrypt(request_body.to_vec())?;
let jwe_decrypted = keys.decrypt(jwe_body)?;

let next_layer_payload = Request::from_parts(parts, Body::from(jwe_decrypted));

let response = next.run(next_layer_payload).await;

let (parts, body) = response.into_parts();
let (mut parts, body) = next.run(next_layer_payload).await.into_parts();

let response_body = hyper::body::to_bytes(body).await.change_error(
error::ApiError::ResponseMiddlewareError("Failed to read response body for jws signing"),
)?;
let response_body = body
.collect()
.await
.change_error(error::ApiError::ResponseMiddlewareError(
"Failed to read response body for jws signing",
))?
.to_bytes();

let jws_signed = keys.encrypt(response_body.to_vec())?;
let jwe_payload = keys.encrypt(response_body.to_vec())?;

let jwt = String::from_utf8(jws_signed).change_error(
error::ApiError::ResponseMiddlewareError("Could not convert to UTF-8"),
)?;
parts.headers = hyper::HeaderMap::new();
parts.headers.append(
hyper::header::CONTENT_TYPE,
axum::http::HeaderValue::from_static("application/json"),
);

Ok(axum::http::response::Builder::new()
.status(parts.status)
.body(jwt.map_err(axum::Error::new).boxed_unsync())
.change_context(error::ApiError::ResponseMiddlewareError(
"failed while generating the response",
))?)
Ok((parts, axum::Json(jwe_payload)))
}
Loading