From 21ce2b17b68c2972973175cd6c9583a9cc58a51b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andor=20Uhl=C3=A1r?= Date: Tue, 10 Jan 2023 17:29:58 +0100 Subject: [PATCH] Switch axum `Extension` usage to `State` State is the new, type-safe way of passing around global state between handlers. --- server/svix-server/src/core/idempotency.rs | 24 ++++---- server/svix-server/src/core/permissions.rs | 56 +++++++------------ server/svix-server/src/core/security.rs | 14 ++--- server/svix-server/src/lib.rs | 35 +++++++++--- .../src/v1/endpoints/application.rs | 17 +++--- .../svix-server/src/v1/endpoints/attempt.rs | 28 +++++----- server/svix-server/src/v1/endpoints/auth.rs | 8 +-- .../src/v1/endpoints/endpoint/crud.rs | 41 +++++++++----- .../src/v1/endpoints/endpoint/headers.rs | 11 ++-- .../src/v1/endpoints/endpoint/mod.rs | 11 ++-- .../src/v1/endpoints/endpoint/recovery.rs | 8 ++- .../src/v1/endpoints/endpoint/secrets.rs | 13 ++--- .../src/v1/endpoints/event_type.rs | 19 ++++--- server/svix-server/src/v1/endpoints/health.rs | 20 ++++--- .../svix-server/src/v1/endpoints/message.rs | 23 ++++---- server/svix-server/src/v1/mod.rs | 12 ++-- server/svix-server/tests/utils/mod.rs | 25 ++++++--- server/svix-server/tests/worker.rs | 38 +++++++++---- 18 files changed, 227 insertions(+), 176 deletions(-) diff --git a/server/svix-server/src/core/idempotency.rs b/server/svix-server/src/core/idempotency.rs index 8aaa25756..ca61d3826 100644 --- a/server/svix-server/src/core/idempotency.rs +++ b/server/svix-server/src/core/idempotency.rs @@ -348,7 +348,7 @@ where mod tests { use std::{net::TcpListener, sync::Arc}; - use axum::{extract::Extension, routing::post, Router, Server}; + use axum::{extract::State, routing::post, Router, Server}; use http::StatusCode; use reqwest::Client; use tokio::{sync::Mutex, task::JoinHandle}; @@ -361,6 +361,12 @@ mod tests { types::{BaseId, OrganizationId}, }; + #[derive(Clone)] + struct TestAppState { + count: Arc>, + wait: Option, + } + /// Starts a basic Axum server with one endpoint which counts the number of times the endpoint /// has been polled from. This will be nested in the [`IdempotencyService`] such that, providing /// a key may result in the count not increasing and a prior result being displayed. @@ -372,7 +378,7 @@ mod tests { /// that points to the count of the server such that its internal state may be monitored. async fn start_service( wait: Option, - ) -> (JoinHandle<()>, String, Arc>) { + ) -> (JoinHandle<()>, String, Arc>) { dotenv::dotenv().ok(); let cache = cache::memory::new(); @@ -396,8 +402,7 @@ mod tests { service, } })) - .layer(Extension(count)) - .layer(Extension(wait)) + .with_state(TestAppState { count, wait }) .into_make_service(), ) .await @@ -409,10 +414,7 @@ mod tests { } /// Only to be used via [`start_service`] -- this is the actual endpoint implementation - async fn service_endpoint( - Extension(count): Extension>>, - Extension(wait): Extension>, - ) -> String { + async fn service_endpoint(State(TestAppState { wait, count }): State) -> String { let mut count = count.lock().await; *count += 1; @@ -597,7 +599,7 @@ mod tests { service, } })) - .layer(Extension(count)) + .with_state(TestAppState { count, wait: None }) .into_make_service(), ) .await @@ -609,7 +611,9 @@ mod tests { } /// Only to be used via [`start_empty_service`] -- this is the actual endpoint implementation - async fn empty_service_endpoint(Extension(count): Extension>>) -> StatusCode { + async fn empty_service_endpoint( + State(TestAppState { count, .. }): State, + ) -> StatusCode { let mut count = count.lock().await; *count += 1; diff --git a/server/svix-server/src/core/permissions.rs b/server/svix-server/src/core/permissions.rs index c9ef0ed29..2e44d5753 100644 --- a/server/svix-server/src/core/permissions.rs +++ b/server/svix-server/src/core/permissions.rs @@ -2,14 +2,13 @@ use axum::{ async_trait, extract::{FromRequestParts, Path}, http::request::Parts, - Extension, }; -use sea_orm::DatabaseConnection; use crate::{ ctx, db::models::{application, applicationmetadata}, error::{Error, HttpError, Result}, + AppState, }; use super::{ @@ -22,13 +21,10 @@ pub struct ReadAll { } #[async_trait] -impl FromRequestParts for ReadAll -where - S: Send + Sync, -{ +impl FromRequestParts for ReadAll { type Rejection = Error; - async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + async fn from_request_parts(parts: &mut Parts, state: &AppState) -> Result { let permissions = ctx!(permissions_from_bearer(parts, state).await)?; let org_id = permissions.org_id(); Ok(Self { org_id }) @@ -51,13 +47,10 @@ impl Permissions { } #[async_trait] -impl FromRequestParts for Organization -where - S: Send + Sync, -{ +impl FromRequestParts for Organization { type Rejection = Error; - async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + async fn from_request_parts(parts: &mut Parts, state: &AppState) -> Result { let permissions = permissions_from_bearer(parts, state).await?; let org_id = match permissions.access_level { @@ -74,22 +67,17 @@ pub struct Application { } #[async_trait] -impl FromRequestParts for Application -where - S: Send + Sync, -{ +impl FromRequestParts for Application { type Rejection = Error; - async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + async fn from_request_parts(parts: &mut Parts, state: &AppState) -> Result { let permissions = permissions_from_bearer(parts, state).await?; let Path(ApplicationPathParams { app_id }) = ctx!(Path::::from_request_parts(parts, state).await)?; - let Extension(ref db) = - ctx!(Extension::::from_request_parts(parts, state).await)?; let app = ctx!( application::Entity::secure_find_by_id_or_uid(permissions.org_id(), app_id.to_owned(),) - .one(db) + .one(&state.db) .await )? .ok_or_else(|| HttpError::not_found(None, None))?; @@ -106,22 +94,17 @@ pub struct OrganizationWithApplication { } #[async_trait] -impl FromRequestParts for OrganizationWithApplication -where - S: Send + Sync, -{ +impl FromRequestParts for OrganizationWithApplication { type Rejection = Error; - async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + async fn from_request_parts(parts: &mut Parts, state: &AppState) -> Result { let Organization { org_id } = ctx!(Organization::from_request_parts(parts, state).await)?; let Path(ApplicationPathParams { app_id }) = ctx!(Path::::from_request_parts(parts, state).await)?; - let Extension(ref db) = - ctx!(Extension::::from_request_parts(parts, state).await)?; let app = ctx!( application::Entity::secure_find_by_id_or_uid(org_id, app_id.to_owned(),) - .one(db) + .one(&state.db) .await )? .ok_or_else(|| HttpError::not_found(None, None))?; @@ -135,22 +118,21 @@ pub struct ApplicationWithMetadata { } #[async_trait] -impl FromRequestParts for ApplicationWithMetadata -where - S: Send + Sync, -{ +impl FromRequestParts for ApplicationWithMetadata { type Rejection = Error; - async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + async fn from_request_parts(parts: &mut Parts, state: &AppState) -> Result { let permissions = permissions_from_bearer(parts, state).await?; let Path(ApplicationPathParams { app_id }) = ctx!(Path::::from_request_parts(parts, state).await)?; - let Extension(ref db) = - ctx!(Extension::::from_request_parts(parts, state).await)?; let (app, metadata) = ctx!( - application::Model::fetch_with_metadata(db, permissions.org_id(), app_id.to_owned()) - .await + application::Model::fetch_with_metadata( + &state.db, + permissions.org_id(), + app_id.to_owned() + ) + .await )? .ok_or_else(|| HttpError::not_found(None, None))?; diff --git a/server/svix-server/src/core/security.rs b/server/svix-server/src/core/security.rs index 10686a77b..2d9d650d5 100644 --- a/server/svix-server/src/core/security.rs +++ b/server/svix-server/src/core/security.rs @@ -4,7 +4,7 @@ use std::fmt::Debug; use axum::{ - extract::{Extension, FromRequestParts, TypedHeader}, + extract::{FromRequestParts, TypedHeader}, headers::{authorization::Bearer, Authorization}, }; @@ -14,9 +14,9 @@ use jwt_simple::prelude::*; use validator::Validate; use crate::{ - cfg::Configuration, ctx, error::{HttpError, Result}, + AppState, }; use super::types::{ApplicationId, OrganizationId}; @@ -62,17 +62,11 @@ pub struct CustomClaim { pub organization: Option, } -pub async fn permissions_from_bearer( - parts: &mut Parts, - state: &S, -) -> Result { - let Extension(ref cfg) = - ctx!(Extension::::from_request_parts(parts, state).await)?; - +pub async fn permissions_from_bearer(parts: &mut Parts, state: &AppState) -> Result { let TypedHeader(Authorization(bearer)) = ctx!(TypedHeader::>::from_request_parts(parts, state).await)?; - let claims = parse_bearer(&cfg.jwt_secret, &bearer) + let claims = parse_bearer(&state.cfg.jwt_secret, &bearer) .ok_or_else(|| HttpError::unauthorized(None, Some("Invalid token".to_string())))?; permissions_from_jwt(claims) } diff --git a/server/svix-server/src/lib.rs b/server/svix-server/src/lib.rs index f3fb43d78..33784a751 100644 --- a/server/svix-server/src/lib.rs +++ b/server/svix-server/src/lib.rs @@ -4,12 +4,15 @@ #![warn(clippy::all)] #![forbid(unsafe_code)] -use axum::{extract::Extension, Router}; +use axum::Router; +use crate::core::cache::Cache; use cfg::ConfigurationInner; use lazy_static::lazy_static; use opentelemetry::runtime::Tokio; use opentelemetry_otlp::WithExportConfig; +use queue::TaskQueueProducer; +use sea_orm::DatabaseConnection; use std::{ net::TcpListener, sync::atomic::{AtomicBool, Ordering}, @@ -22,7 +25,9 @@ use tracing_subscriber::{prelude::*, util::SubscriberInitExt}; use crate::{ cfg::{CacheBackend, Configuration}, core::{ - cache, idempotency::IdempotencyService, operational_webhooks::OperationalWebhookSenderInner, + cache, + idempotency::IdempotencyService, + operational_webhooks::{OperationalWebhookSender, OperationalWebhookSenderInner}, }, db::init_db, expired_message_cleaner::expired_message_cleaner_loop, @@ -77,6 +82,15 @@ pub async fn run(cfg: Configuration, listener: Option) { run_with_prefix(None, cfg, listener).await } +#[derive(Clone)] +pub struct AppState { + db: DatabaseConnection, + queue_tx: TaskQueueProducer, + cfg: Configuration, + cache: Cache, + op_webhooks: OperationalWebhookSender, +} + // Made public for the purpose of E2E testing in which a queue prefix is necessary to avoid tests // consuming from each others' queues pub async fn run_with_prefix( @@ -110,8 +124,16 @@ pub async fn run_with_prefix( let svc_cache = cache.clone(); // build our application with a route + let app_state = AppState { + db: pool.clone(), + queue_tx: queue_tx.clone(), + cfg: cfg.clone(), + cache: cache.clone(), + op_webhooks: op_webhook_sender.clone(), + }; + let v1_router = v1::router().with_state::<()>(app_state); let app = Router::new() - .nest("/api/v1", v1::router()) + .nest("/api/v1", v1_router) .merge(docs::router()) .layer( ServiceBuilder::new().layer_fn(move |service| IdempotencyService { @@ -125,12 +147,7 @@ pub async fn run_with_prefix( .allow_methods(Any) .allow_headers(AllowHeaders::mirror_request()) .max_age(Duration::from_secs(600)), - ) - .layer(Extension(pool.clone())) - .layer(Extension(queue_tx.clone())) - .layer(Extension(cfg.clone())) - .layer(Extension(cache.clone())) - .layer(Extension(op_webhook_sender.clone())); + ); let with_api = cfg.api_enabled; let with_worker = cfg.worker_enabled; diff --git a/server/svix-server/src/v1/endpoints/application.rs b/server/svix-server/src/v1/endpoints/application.rs index fed8a0128..924caf7f1 100644 --- a/server/svix-server/src/v1/endpoints/application.rs +++ b/server/svix-server/src/v1/endpoints/application.rs @@ -19,16 +19,17 @@ use crate::{ validation_error, EmptyResponse, ListResponse, ModelIn, ModelOut, Pagination, PaginationLimit, ValidatedJson, ValidatedQuery, }, + AppState, }; use axum::{ - extract::{Extension, Path}, + extract::{Path, State}, routing::{get, post}, Json, Router, }; use chrono::{DateTime, Utc}; use hyper::StatusCode; +use sea_orm::ActiveModelTrait; use sea_orm::ActiveValue::Set; -use sea_orm::{ActiveModelTrait, DatabaseConnection}; use serde::{Deserialize, Serialize}; use svix_server_derive::ModelOut; use validator::{Validate, ValidationError}; @@ -184,7 +185,7 @@ impl From<(application::Model, applicationmetadata::Model)> for ApplicationOut { } async fn list_applications( - Extension(ref db): Extension, + State(AppState { ref db, .. }): State, pagination: ValidatedQuery>, permissions::Organization { org_id }: permissions::Organization, ) -> Result>> { @@ -213,7 +214,7 @@ pub struct CreateApplicationQuery { } async fn create_application( - Extension(ref db): Extension, + State(AppState { ref db, .. }): State, query: ValidatedQuery, permissions::Organization { org_id }: permissions::Organization, ValidatedJson(data): ValidatedJson, @@ -256,7 +257,7 @@ async fn get_application( } async fn update_application( - Extension(ref db): Extension, + State(AppState { ref db, .. }): State, Path(app_id): Path, permissions::Organization { org_id }: permissions::Organization, ValidatedJson(data): ValidatedJson, @@ -294,7 +295,7 @@ async fn update_application( } async fn patch_application( - Extension(ref db): Extension, + State(AppState { ref db, .. }): State, permissions::OrganizationWithApplication { app }: permissions::OrganizationWithApplication, ValidatedJson(data): ValidatedJson, ) -> Result> { @@ -315,7 +316,7 @@ async fn patch_application( } async fn delete_application( - Extension(ref db): Extension, + State(AppState { ref db, .. }): State, permissions::OrganizationWithApplication { app }: permissions::OrganizationWithApplication, ) -> Result<(StatusCode, Json)> { let mut app: application::ActiveModel = app.into(); @@ -325,7 +326,7 @@ async fn delete_application( Ok((StatusCode::NO_CONTENT, Json(EmptyResponse {}))) } -pub fn router() -> Router { +pub fn router() -> Router { Router::new() .route("/app/", post(create_application).get(list_applications)) .route( diff --git a/server/svix-server/src/v1/endpoints/attempt.rs b/server/svix-server/src/v1/endpoints/attempt.rs index 37723366e..a6693dca2 100644 --- a/server/svix-server/src/v1/endpoints/attempt.rs +++ b/server/svix-server/src/v1/endpoints/attempt.rs @@ -14,7 +14,7 @@ use crate::{ db::models::{endpoint, message, messagedestination}, err_database, error::{Error, HttpError, Result}, - queue::{MessageTask, TaskQueueProducer}, + queue::MessageTask, v1::{ endpoints::message::MessageOut, utils::{ @@ -22,9 +22,10 @@ use crate::{ MessageListFetchOptions, ModelOut, PaginationLimit, ReversibleIterator, ValidatedQuery, }, }, + AppState, }; use axum::{ - extract::{Extension, Path}, + extract::{Path, State}, routing::{get, post}, Json, Router, }; @@ -119,7 +120,7 @@ pub struct ListAttemptedMessagesQueryParameters { /// Fetches a list of [`AttemptedMessageOut`]s associated with a given app and endpoint. async fn list_attempted_messages( - Extension(ref db): Extension, + State(AppState { ref db, .. }): State, ValidatedQuery(pagination): ValidatedQuery>>, ValidatedQuery(ListAttemptedMessagesQueryParameters { channel, @@ -286,7 +287,7 @@ fn list_attempts_by_endpoint_or_message_filters( /// Fetches a list of [`MessageAttemptOut`]s for a given endpoint ID async fn list_attempts_by_endpoint( - Extension(ref db): Extension, + State(AppState { ref db, .. }): State, ValidatedQuery(pagination): ValidatedQuery>>, ValidatedQuery(ListAttemptsByEndpointQueryParameters { status, @@ -357,7 +358,7 @@ pub struct ListAttemptsByMsgQueryParameters { /// Fetches a list of [`MessageAttemptOut`]s for a given message ID async fn list_attempts_by_msg( - Extension(ref db): Extension, + State(AppState { ref db, .. }): State, ValidatedQuery(pagination): ValidatedQuery>>, ValidatedQuery(ListAttemptsByMsgQueryParameters { status, @@ -458,7 +459,7 @@ impl MessageEndpointOut { } async fn list_attempted_destinations( - Extension(ref db): Extension, + State(AppState { ref db, .. }): State, ValidatedQuery(mut pagination): ValidatedQuery>, Path((_app_id, msg_id)): Path<(ApplicationIdOrUid, MessageIdOrUid)>, permissions::Application { app }: permissions::Application, @@ -514,7 +515,7 @@ pub struct ListAttemptsForEndpointQueryParameters { } async fn list_attempts_for_endpoint( - extension: Extension, + state: State, pagination: ValidatedQuery>>, ValidatedQuery(ListAttemptsForEndpointQueryParameters { channel, @@ -527,7 +528,7 @@ async fn list_attempts_for_endpoint( auth_app: permissions::Application, ) -> Result>> { list_messageattempts( - extension, + state, pagination, ValidatedQuery(AttemptListFetchOptions { endpoint_id: Some(endp_id), @@ -555,7 +556,7 @@ pub struct AttemptListFetchOptions { } async fn list_messageattempts( - Extension(ref db): Extension, + State(AppState { ref db, .. }): State, ValidatedQuery(pagination): ValidatedQuery>>, ValidatedQuery(AttemptListFetchOptions { endpoint_id, @@ -624,7 +625,7 @@ async fn list_messageattempts( } async fn get_messageattempt( - Extension(ref db): Extension, + State(AppState { ref db, .. }): State, Path((_app_id, msg_id, attempt_id)): Path<( ApplicationIdOrUid, MessageIdOrUid, @@ -650,8 +651,9 @@ async fn get_messageattempt( } async fn resend_webhook( - Extension(ref db): Extension, - Extension(queue_tx): Extension, + State(AppState { + ref db, queue_tx, .. + }): State, Path((_app_id, msg_id, endp_id)): Path<(ApplicationIdOrUid, MessageIdOrUid, EndpointIdOrUid)>, permissions::Application { app }: permissions::Application, ) -> Result<(StatusCode, Json)> { @@ -700,7 +702,7 @@ async fn resend_webhook( Ok((StatusCode::ACCEPTED, Json(EmptyResponse {}))) } -pub fn router() -> Router { +pub fn router() -> Router { Router::new() // NOTE: [`list_messageattempts`] is deprecated .route( diff --git a/server/svix-server/src/v1/endpoints/auth.rs b/server/svix-server/src/v1/endpoints/auth.rs index c11dcce3d..879662c15 100644 --- a/server/svix-server/src/v1/endpoints/auth.rs +++ b/server/svix-server/src/v1/endpoints/auth.rs @@ -1,11 +1,11 @@ -use axum::{routing::post, Extension, Json, Router}; +use axum::{extract::State, routing::post, Json, Router}; use serde::{Deserialize, Serialize}; use crate::{ - cfg::Configuration, core::{permissions, security::generate_app_token}, error::{HttpError, Result}, v1::utils::api_not_implemented, + AppState, }; #[derive(Deserialize, Serialize)] @@ -15,7 +15,7 @@ pub struct DashboardAccessOut { } async fn dashboard_access( - Extension(cfg): Extension, + State(AppState { cfg, .. }): State, permissions::OrganizationWithApplication { app }: permissions::OrganizationWithApplication, ) -> Result> { let token = generate_app_token(&cfg.jwt_secret, app.org_id, app.id.clone())?; @@ -35,7 +35,7 @@ async fn dashboard_access( Ok(Json(DashboardAccessOut { url, token })) } -pub fn router() -> Router { +pub fn router() -> Router { Router::new() .route("/auth/dashboard-access/:app_id/", post(dashboard_access)) .route("/auth/logout/", post(api_not_implemented)) diff --git a/server/svix-server/src/v1/endpoints/endpoint/crud.rs b/server/svix-server/src/v1/endpoints/endpoint/crud.rs index 3f0b458de..2c5d6b679 100644 --- a/server/svix-server/src/v1/endpoints/endpoint/crud.rs +++ b/server/svix-server/src/v1/endpoints/endpoint/crud.rs @@ -1,7 +1,7 @@ use std::{collections::HashSet, mem}; use axum::{ - extract::{Extension, Path}, + extract::{Path, State}, Json, }; use hyper::StatusCode; @@ -28,11 +28,12 @@ use crate::{ EmptyResponse, ListResponse, ModelIn, ModelOut, Pagination, PaginationLimit, ValidatedJson, ValidatedQuery, }, + AppState, }; use hack::EventTypeNameResult; pub(super) async fn list_endpoints( - Extension(ref db): Extension, + State(AppState { ref db, .. }): State, pagination: ValidatedQuery>, permissions::Application { app }: permissions::Application, ) -> Result>> { @@ -99,9 +100,12 @@ async fn create_endp_from_data( } pub(super) async fn create_endpoint( - Extension(ref db): Extension, - Extension(ref cfg): Extension, - Extension(op_webhooks): Extension, + State(AppState { + ref db, + ref cfg, + op_webhooks, + .. + }): State, permissions::Application { app }: permissions::Application, ValidatedJson(data): ValidatedJson, ) -> Result<(StatusCode, Json)> { @@ -116,7 +120,7 @@ pub(super) async fn create_endpoint( } pub(super) async fn get_endpoint( - Extension(ref db): Extension, + State(AppState { ref db, .. }): State, Path((_app_id, endp_id)): Path<(ApplicationIdOrUid, EndpointIdOrUid)>, permissions::Application { app }: permissions::Application, ) -> Result> { @@ -160,9 +164,12 @@ async fn update_endp_from_data( } pub(super) async fn update_endpoint( - Extension(ref db): Extension, - Extension(ref cfg): Extension, - Extension(ref op_webhooks): Extension, + State(AppState { + ref db, + ref cfg, + ref op_webhooks, + .. + }): State, Path((_app_id, endp_id)): Path<(ApplicationIdOrUid, EndpointIdOrUid)>, permissions::Application { app }: permissions::Application, ValidatedJson(mut data): ValidatedJson, @@ -188,9 +195,12 @@ pub(super) async fn update_endpoint( } pub(super) async fn patch_endpoint( - Extension(ref db): Extension, - Extension(cfg): Extension, - Extension(ref op_webhooks): Extension, + State(AppState { + ref db, + cfg, + ref op_webhooks, + .. + }): State, Path((_app_id, endp_id)): Path<(ApplicationIdOrUid, EndpointIdOrUid)>, permissions::Application { app }: permissions::Application, ValidatedJson(data): ValidatedJson, @@ -217,8 +227,11 @@ pub(super) async fn patch_endpoint( } pub(super) async fn delete_endpoint( - Extension(ref db): Extension, - Extension(op_webhooks): Extension, + State(AppState { + ref db, + ref op_webhooks, + .. + }): State, Path((_app_id, endp_id)): Path<(ApplicationIdOrUid, EndpointIdOrUid)>, permissions::Application { app }: permissions::Application, ) -> Result<(StatusCode, Json)> { diff --git a/server/svix-server/src/v1/endpoints/endpoint/headers.rs b/server/svix-server/src/v1/endpoints/endpoint/headers.rs index cc4c8e605..887131718 100644 --- a/server/svix-server/src/v1/endpoints/endpoint/headers.rs +++ b/server/svix-server/src/v1/endpoints/endpoint/headers.rs @@ -1,9 +1,9 @@ use axum::{ - extract::{Extension, Path}, + extract::{Path, State}, Json, }; use hyper::StatusCode; -use sea_orm::{ActiveModelTrait, DatabaseConnection}; +use sea_orm::ActiveModelTrait; use super::{EndpointHeadersIn, EndpointHeadersOut, EndpointHeadersPatchIn}; use crate::{ @@ -15,10 +15,11 @@ use crate::{ db::models::endpoint, error::{HttpError, Result}, v1::utils::{EmptyResponse, ModelIn, ValidatedJson}, + AppState, }; pub(super) async fn get_endpoint_headers( - Extension(ref db): Extension, + State(AppState { ref db, .. }): State, Path((_app_id, endp_id)): Path<(ApplicationIdOrUid, EndpointIdOrUid)>, permissions::Application { app }: permissions::Application, ) -> Result> { @@ -36,7 +37,7 @@ pub(super) async fn get_endpoint_headers( } pub(super) async fn update_endpoint_headers( - Extension(ref db): Extension, + State(AppState { ref db, .. }): State, Path((_app_id, endp_id)): Path<(ApplicationIdOrUid, EndpointIdOrUid)>, permissions::Application { app }: permissions::Application, ValidatedJson(data): ValidatedJson, @@ -56,7 +57,7 @@ pub(super) async fn update_endpoint_headers( } pub(super) async fn patch_endpoint_headers( - Extension(ref db): Extension, + State(AppState { ref db, .. }): State, Path((_app_id, endp_id)): Path<(ApplicationIdOrUid, EndpointIdOrUid)>, permissions::Application { app }: permissions::Application, ValidatedJson(data): ValidatedJson, diff --git a/server/svix-server/src/v1/endpoints/endpoint/mod.rs b/server/svix-server/src/v1/endpoints/endpoint/mod.rs index d7473768d..bd27b85db 100644 --- a/server/svix-server/src/v1/endpoints/endpoint/mod.rs +++ b/server/svix-server/src/v1/endpoints/endpoint/mod.rs @@ -28,17 +28,16 @@ use crate::{ validate_no_control_characters, validate_no_control_characters_unrequired, validation_error, ModelIn, }, + AppState, }; use axum::{ - extract::{Extension, Path}, + extract::{Path, State}, routing::{get, post}, Json, Router, }; use chrono::{DateTime, Utc}; -use sea_orm::{ - ActiveValue::Set, ColumnTrait, DatabaseConnection, FromQueryResult, QueryFilter, QuerySelect, -}; +use sea_orm::{ActiveValue::Set, ColumnTrait, FromQueryResult, QueryFilter, QuerySelect}; use serde::{Deserialize, Serialize}; use std::{collections::HashMap, collections::HashSet}; use url::Url; @@ -487,7 +486,7 @@ pub struct EndpointStatsQueryOut { } async fn endpoint_stats( - Extension(ref db): Extension, + State(AppState { ref db, .. }): State, Path((_app_id, endp_id)): Path<(ApplicationIdOrUid, EndpointIdOrUid)>, permissions::Application { app }: permissions::Application, ) -> crate::error::Result> { @@ -527,7 +526,7 @@ async fn endpoint_stats( })) } -pub fn router() -> Router { +pub fn router() -> Router { Router::new() .route( "/app/:app_id/endpoint/", diff --git a/server/svix-server/src/v1/endpoints/endpoint/recovery.rs b/server/svix-server/src/v1/endpoints/endpoint/recovery.rs index bcdbb2dba..22ed38b2e 100644 --- a/server/svix-server/src/v1/endpoints/endpoint/recovery.rs +++ b/server/svix-server/src/v1/endpoints/endpoint/recovery.rs @@ -1,5 +1,5 @@ use axum::{ - extract::{Extension, Path}, + extract::{Path, State}, Json, }; use chrono::{DateTime, Utc}; @@ -21,6 +21,7 @@ use crate::{ error::{HttpError, Result, ValidationErrorItem}, queue::{MessageTask, TaskQueueProducer}, v1::utils::{EmptyResponse, ValidatedJson}, + AppState, }; async fn bulk_recover_failed_messages( @@ -74,8 +75,9 @@ async fn bulk_recover_failed_messages( } pub(super) async fn recover_failed_webhooks( - Extension(ref db): Extension, - Extension(queue_tx): Extension, + State(AppState { + ref db, queue_tx, .. + }): State, Path((_app_id, endp_id)): Path<(ApplicationIdOrUid, EndpointIdOrUid)>, permissions::Application { app }: permissions::Application, ValidatedJson(data): ValidatedJson, diff --git a/server/svix-server/src/v1/endpoints/endpoint/secrets.rs b/server/svix-server/src/v1/endpoints/endpoint/secrets.rs index e5cfe5904..ae158289a 100644 --- a/server/svix-server/src/v1/endpoints/endpoint/secrets.rs +++ b/server/svix-server/src/v1/endpoints/endpoint/secrets.rs @@ -1,16 +1,16 @@ use axum::{ - extract::{Extension, Path}, + extract::{Path, State}, Json, }; use chrono::{Duration, Utc}; use hyper::StatusCode; +use sea_orm::ActiveModelTrait; use sea_orm::ActiveValue::Set; -use sea_orm::{ActiveModelTrait, DatabaseConnection}; use std::iter; use super::{EndpointSecretOut, EndpointSecretRotateIn}; use crate::{ - cfg::{Configuration, DefaultSignatureType}, + cfg::DefaultSignatureType, core::{ cryptography::Encryption, permissions, @@ -23,6 +23,7 @@ use crate::{ db::models::endpoint, error::{HttpError, Result}, v1::utils::{EmptyResponse, ValidatedJson}, + AppState, }; pub(super) fn generate_secret( @@ -36,8 +37,7 @@ pub(super) fn generate_secret( } pub(super) async fn get_endpoint_secret( - Extension(ref db): Extension, - Extension(cfg): Extension, + State(AppState { ref db, cfg, .. }): State, Path((_app_id, endp_id)): Path<(ApplicationIdOrUid, EndpointIdOrUid)>, permissions::Application { app }: permissions::Application, ) -> Result> { @@ -53,8 +53,7 @@ pub(super) async fn get_endpoint_secret( } pub(super) async fn rotate_endpoint_secret( - Extension(ref db): Extension, - Extension(cfg): Extension, + State(AppState { ref db, cfg, .. }): State, Path((_app_id, endp_id)): Path<(ApplicationIdOrUid, EndpointIdOrUid)>, permissions::Application { app }: permissions::Application, ValidatedJson(data): ValidatedJson, diff --git a/server/svix-server/src/v1/endpoints/event_type.rs b/server/svix-server/src/v1/endpoints/event_type.rs index 9c6303385..b2e455558 100644 --- a/server/svix-server/src/v1/endpoints/event_type.rs +++ b/server/svix-server/src/v1/endpoints/event_type.rs @@ -16,16 +16,17 @@ use crate::{ ListResponse, ModelIn, ModelOut, Pagination, PaginationLimit, ValidatedJson, ValidatedQuery, }, + AppState, }; use axum::{ - extract::{Extension, Path}, + extract::{Path, State}, routing::{get, post}, Json, Router, }; use chrono::{DateTime, Utc}; use hyper::StatusCode; use sea_orm::{entity::prelude::*, ActiveValue::Set, QueryOrder}; -use sea_orm::{ActiveModelTrait, DatabaseConnection, QuerySelect}; +use sea_orm::{ActiveModelTrait, QuerySelect}; use serde::{Deserialize, Serialize}; use svix_server_derive::{ModelIn, ModelOut}; use validator::Validate; @@ -167,7 +168,7 @@ pub struct ListFetchOptions { } async fn list_event_types( - Extension(ref db): Extension, + State(AppState { ref db, .. }): State, pagination: ValidatedQuery>, fetch_options: ValidatedQuery, permissions::ReadAll { org_id }: permissions::ReadAll, @@ -203,7 +204,7 @@ async fn list_event_types( } async fn create_event_type( - Extension(ref db): Extension, + State(AppState { ref db, .. }): State, permissions::Organization { org_id }: permissions::Organization, ValidatedJson(data): ValidatedJson, ) -> Result<(StatusCode, Json)> { @@ -239,7 +240,7 @@ async fn create_event_type( } async fn get_event_type( - Extension(ref db): Extension, + State(AppState { ref db, .. }): State, Path(evtype_name): Path, permissions::ReadAll { org_id }: permissions::ReadAll, ) -> Result> { @@ -253,7 +254,7 @@ async fn get_event_type( } async fn update_event_type( - Extension(ref db): Extension, + State(AppState { ref db, .. }): State, Path(evtype_name): Path, permissions::Organization { org_id }: permissions::Organization, ValidatedJson(data): ValidatedJson, @@ -289,7 +290,7 @@ async fn update_event_type( } async fn patch_event_type( - Extension(ref db): Extension, + State(AppState { ref db, .. }): State, Path(evtype_name): Path, permissions::Organization { org_id }: permissions::Organization, ValidatedJson(data): ValidatedJson, @@ -309,7 +310,7 @@ async fn patch_event_type( } async fn delete_event_type( - Extension(ref db): Extension, + State(AppState { ref db, .. }): State, Path(evtype_name): Path, permissions::Organization { org_id }: permissions::Organization, ) -> Result<(StatusCode, Json)> { @@ -326,7 +327,7 @@ async fn delete_event_type( Ok((StatusCode::NO_CONTENT, Json(EmptyResponse {}))) } -pub fn router() -> Router { +pub fn router() -> Router { Router::new() .route( "/event-type/", diff --git a/server/svix-server/src/v1/endpoints/health.rs b/server/svix-server/src/v1/endpoints/health.rs index 6d5368af6..f03e7131a 100644 --- a/server/svix-server/src/v1/endpoints/health.rs +++ b/server/svix-server/src/v1/endpoints/health.rs @@ -3,13 +3,14 @@ use std::time::Duration; -use axum::{http::StatusCode, routing::get, Extension, Json, Router}; -use sea_orm::{query::Statement, ConnectionTrait, DatabaseBackend, DatabaseConnection}; +use axum::{extract::State, http::StatusCode, routing::get, Json, Router}; +use sea_orm::{query::Statement, ConnectionTrait, DatabaseBackend}; use serde::{Deserialize, Serialize}; use crate::{ - core::cache::{kv_def, Cache, CacheBehavior, CacheKey, CacheValue}, - queue::{QueueTask, TaskQueueProducer}, + core::cache::{kv_def, CacheBehavior, CacheKey, CacheValue}, + queue::QueueTask, + AppState, }; async fn ping() -> StatusCode { @@ -74,9 +75,12 @@ struct HealthCheckCacheValue(()); kv_def!(HealthCheckCacheKey, HealthCheckCacheValue); async fn health( - Extension(ref db): Extension, - Extension(queue_tx): Extension, - Extension(cache): Extension, + State(AppState { + ref db, + queue_tx, + cache, + .. + }): State, ) -> (StatusCode, Json) { // SELECT 1 FROM any table let database: HealthStatus = db @@ -117,7 +121,7 @@ async fn health( ) } -pub fn router() -> Router { +pub fn router() -> Router { Router::new() .route("/health/ping/", get(ping).head(ping)) .route("/health/", get(health).head(health)) diff --git a/server/svix-server/src/v1/endpoints/message.rs b/server/svix-server/src/v1/endpoints/message.rs index 847377823..5aeaa4a2c 100644 --- a/server/svix-server/src/v1/endpoints/message.rs +++ b/server/svix-server/src/v1/endpoints/message.rs @@ -2,7 +2,6 @@ // SPDX-License-Identifier: MIT use crate::{ - cache::Cache, core::{ message_app::CreateMessageApp, permissions, @@ -13,23 +12,24 @@ use crate::{ }, ctx, err_generic, error::{HttpError, Result}, - queue::{MessageTaskBatch, TaskQueueProducer}, + queue::MessageTaskBatch, v1::utils::{ apply_pagination, iterator_from_before_or_after, validation_error, ListResponse, MessageListFetchOptions, ModelIn, ModelOut, PaginationLimit, ReversibleIterator, ValidatedJson, ValidatedQuery, }, + AppState, }; use axum::{ - extract::{Extension, Path}, + extract::{Path, State}, routing::{get, post}, Json, Router, }; use chrono::{DateTime, Duration, Utc}; use hyper::StatusCode; use sea_orm::entity::prelude::*; +use sea_orm::ActiveModelTrait; use sea_orm::{sea_query::Expr, ActiveValue::Set}; -use sea_orm::{ActiveModelTrait, DatabaseConnection}; use serde::{Deserialize, Serialize}; use svix_server_derive::{ModelIn, ModelOut}; @@ -165,7 +165,7 @@ pub struct ListMessagesQueryParams { } async fn list_messages( - Extension(ref db): Extension, + State(AppState { ref db, .. }): State, ValidatedQuery(pagination): ValidatedQuery>>, ValidatedQuery(ListMessagesQueryParams { channel, @@ -219,9 +219,12 @@ pub struct CreateMessageQueryParams { } async fn create_message( - Extension(ref db): Extension, - Extension(queue_tx): Extension, - Extension(cache): Extension, + State(AppState { + ref db, + queue_tx, + cache, + .. + }): State, ValidatedQuery(CreateMessageQueryParams { with_content }): ValidatedQuery< CreateMessageQueryParams, >, @@ -279,7 +282,7 @@ pub struct GetMessageQueryParams { with_content: bool, } async fn get_message( - Extension(ref db): Extension, + State(AppState { ref db, .. }): State, Path((_app_id, msg_id)): Path<(ApplicationIdOrUid, MessageIdOrUid)>, ValidatedQuery(GetMessageQueryParams { with_content }): ValidatedQuery, permissions::Application { app }: permissions::Application, @@ -298,7 +301,7 @@ async fn get_message( Ok(Json(msg_out)) } -pub fn router() -> Router { +pub fn router() -> Router { Router::new() .route("/app/:app_id/msg/", post(create_message).get(list_messages)) .route("/app/:app_id/msg/:msg_id/", get(get_message)) diff --git a/server/svix-server/src/v1/mod.rs b/server/svix-server/src/v1/mod.rs index c8b2599ed..60941c844 100644 --- a/server/svix-server/src/v1/mod.rs +++ b/server/svix-server/src/v1/mod.rs @@ -4,13 +4,16 @@ use axum::Router; use tower_http::trace::TraceLayer; -use crate::core::otel_spans::{AxumOtelOnFailure, AxumOtelOnResponse, AxumOtelSpanCreator}; +use crate::{ + core::otel_spans::{AxumOtelOnFailure, AxumOtelOnResponse, AxumOtelSpanCreator}, + AppState, +}; pub mod endpoints; pub mod utils; -pub fn router() -> Router { - let ret = Router::new() +pub fn router() -> Router { + let ret: Router = Router::new() .merge(endpoints::health::router()) .merge(endpoints::auth::router()) .merge(endpoints::application::router()) @@ -39,6 +42,7 @@ mod development { use crate::error::{Error, Result}; use crate::v1::utils::EmptyResponse; + use crate::AppState; struct EchoData { pub headers: String, @@ -65,7 +69,7 @@ mod development { Ok(Json(EmptyResponse {})) } - pub fn router() -> Router { + pub fn router() -> Router { Router::new().route("/development/echo/", get(echo).post(echo)) } } diff --git a/server/svix-server/tests/utils/mod.rs b/server/svix-server/tests/utils/mod.rs index 94ea32adb..b91fe0c44 100644 --- a/server/svix-server/tests/utils/mod.rs +++ b/server/svix-server/tests/utils/mod.rs @@ -266,6 +266,13 @@ pub struct TestReceiver { pub response_status_code: Arc>, } +#[derive(Clone)] +pub struct TestAppState { + tx: mpsc::Sender, + header_tx: mpsc::Sender, + response_status_code: Arc>, +} + #[derive(Clone)] pub struct ResponseStatusCode { pub status_code: axum::http::StatusCode, @@ -288,9 +295,11 @@ impl TestReceiver { "/", axum::routing::post(test_receiver_route).get(test_receiver_route), ) - .layer(axum::extract::Extension(tx)) - .layer(axum::extract::Extension(header_tx)) - .layer(axum::extract::Extension(response_status_code.clone())) + .with_state(TestAppState { + tx, + header_tx, + response_status_code: response_status_code.clone(), + }) .into_make_service(); let jh = tokio::spawn(async move { @@ -316,11 +325,11 @@ impl TestReceiver { } async fn test_receiver_route( - axum::extract::Extension(ref tx): axum::extract::Extension>, - axum::extract::Extension(ref header_tx): axum::extract::Extension>, - axum::extract::Extension(response_status_code): axum::extract::Extension< - Arc>, - >, + axum::extract::State(TestAppState { + tx, + header_tx, + response_status_code, + }): axum::extract::State, headers: HeaderMap, axum::Json(json): axum::Json, ) -> axum::http::StatusCode { diff --git a/server/svix-server/tests/worker.rs b/server/svix-server/tests/worker.rs index cbf229393..a59893b89 100644 --- a/server/svix-server/tests/worker.rs +++ b/server/svix-server/tests/worker.rs @@ -2,7 +2,7 @@ //! As such they are included with integration tests for organizational purposes. use std::{net::TcpListener, sync::Arc, time::Duration}; -use axum::extract::Extension; +use axum::extract::State; use http::StatusCode; use svix_server::v1::{ endpoints::{attempt::MessageAttemptOut, endpoint::EndpointOut}, @@ -14,7 +14,6 @@ mod utils; use utils::{ common_calls::{create_test_app, create_test_endpoint, create_test_message}, get_default_test_config, run_with_retries, start_svix_server, start_svix_server_with_cfg, - ResponseStatusCode, }; /// Runs a full Axum server with two endoints. The first endpoint redirects to the second endpoint @@ -26,6 +25,12 @@ struct RedirectionVisitReportingReceiver { pub has_been_visited: Arc>, } +#[derive(Clone)] +struct RedirectionVisitReportingState { + has_been_visited: Arc>, + resp_with: axum::http::StatusCode, +} + impl RedirectionVisitReportingReceiver { pub fn start(resp_with: axum::http::StatusCode) -> Self { let listener = TcpListener::bind("127.0.0.1:0").unwrap(); @@ -43,8 +48,10 @@ impl RedirectionVisitReportingReceiver { axum::routing::post(visit_reporting_receiver_route) .get(visit_reporting_receiver_route), ) - .layer(Extension(has_been_visited.clone())) - .layer(Extension(resp_with)) + .with_state(RedirectionVisitReportingState { + has_been_visited: has_been_visited.clone(), + resp_with, + }) .into_make_service(); let jh = tokio::spawn(async move { @@ -68,11 +75,13 @@ async fn redirecting_receiver_route() -> axum::response::Redirect { } async fn visit_reporting_receiver_route( - Extension(visited): Extension>>, - Extension(status_code): Extension>>, + State(RedirectionVisitReportingState { + has_been_visited: visited, + resp_with, + }): State, ) -> StatusCode { *visited.lock().await = true; - status_code.lock().await.status_code + resp_with } // The worker has @@ -224,6 +233,12 @@ struct SporadicallyFailingReceiver { pub jh: tokio::task::JoinHandle<()>, } +#[derive(Clone)] +struct SporadicallyFailingState { + count: Arc>, + resp_with: (http::StatusCode, http::StatusCode), +} + impl SporadicallyFailingReceiver { pub fn start(resp_with: (http::StatusCode, http::StatusCode)) -> Self { let listener = TcpListener::bind("127.0.0.1:0").unwrap(); @@ -236,8 +251,7 @@ impl SporadicallyFailingReceiver { "/", axum::routing::post(sporadically_failing_route).get(sporadically_failing_route), ) - .layer(Extension(count)) - .layer(Extension(resp_with)) + .with_state(SporadicallyFailingState { count, resp_with }) .into_make_service(); let jh = tokio::spawn(async move { @@ -253,8 +267,10 @@ impl SporadicallyFailingReceiver { } async fn sporadically_failing_route( - Extension(count): Extension>>, - Extension((resp_ok, resp_fail)): Extension<(StatusCode, StatusCode)>, + State(SporadicallyFailingState { + count, + resp_with: (resp_ok, resp_fail), + }): State, ) -> StatusCode { let mut count = count.lock().await; *count += 1;