Skip to content

Commit

Permalink
Switch axum Extension usage to State (#761)
Browse files Browse the repository at this point in the history
## Motivation
`State` is the new, type-safe way of passing around global state between
handlers in axum. `Extension`s were only resolved at runtime, therefore
if you forgot to pass an extension with `layer()` it'd only fail at
runtime. With `State` a handler can't be registered on the router if its
state can't pass all the required parameters to the handler.

## Solution
Replace all usage of `Extension` with state.

Closes #726
  • Loading branch information
svix-gabriel authored Jan 10, 2023
2 parents a31cda4 + 21ce2b1 commit ef43606
Show file tree
Hide file tree
Showing 18 changed files with 227 additions and 176 deletions.
24 changes: 14 additions & 10 deletions server/svix-server/src/core/idempotency.rs
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@ where
mod tests {
use std::{net::TcpListener, sync::Arc};

use axum::{extract::Extension, routing::post, Router, Server};
use axum::{extract::State, routing::post, Router, Server};
use http::StatusCode;
use reqwest::Client;
use tokio::{sync::Mutex, task::JoinHandle};
Expand All @@ -361,6 +361,12 @@ mod tests {
types::{BaseId, OrganizationId},
};

#[derive(Clone)]
struct TestAppState {
count: Arc<Mutex<u16>>,
wait: Option<std::time::Duration>,
}

/// Starts a basic Axum server with one endpoint which counts the number of times the endpoint
/// has been polled from. This will be nested in the [`IdempotencyService`] such that, providing
/// a key may result in the count not increasing and a prior result being displayed.
Expand All @@ -372,7 +378,7 @@ mod tests {
/// that points to the count of the server such that its internal state may be monitored.
async fn start_service(
wait: Option<std::time::Duration>,
) -> (JoinHandle<()>, String, Arc<Mutex<usize>>) {
) -> (JoinHandle<()>, String, Arc<Mutex<u16>>) {
dotenv::dotenv().ok();

let cache = cache::memory::new();
Expand All @@ -396,8 +402,7 @@ mod tests {
service,
}
}))
.layer(Extension(count))
.layer(Extension(wait))
.with_state(TestAppState { count, wait })
.into_make_service(),
)
.await
Expand All @@ -409,10 +414,7 @@ mod tests {
}

/// Only to be used via [`start_service`] -- this is the actual endpoint implementation
async fn service_endpoint(
Extension(count): Extension<Arc<Mutex<usize>>>,
Extension(wait): Extension<Option<std::time::Duration>>,
) -> String {
async fn service_endpoint(State(TestAppState { wait, count }): State<TestAppState>) -> String {
let mut count = count.lock().await;
*count += 1;

Expand Down Expand Up @@ -597,7 +599,7 @@ mod tests {
service,
}
}))
.layer(Extension(count))
.with_state(TestAppState { count, wait: None })
.into_make_service(),
)
.await
Expand All @@ -609,7 +611,9 @@ mod tests {
}

/// Only to be used via [`start_empty_service`] -- this is the actual endpoint implementation
async fn empty_service_endpoint(Extension(count): Extension<Arc<Mutex<u16>>>) -> StatusCode {
async fn empty_service_endpoint(
State(TestAppState { count, .. }): State<TestAppState>,
) -> StatusCode {
let mut count = count.lock().await;
*count += 1;

Expand Down
56 changes: 19 additions & 37 deletions server/svix-server/src/core/permissions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,13 @@ use axum::{
async_trait,
extract::{FromRequestParts, Path},
http::request::Parts,
Extension,
};
use sea_orm::DatabaseConnection;

use crate::{
ctx,
db::models::{application, applicationmetadata},
error::{Error, HttpError, Result},
AppState,
};

use super::{
Expand All @@ -22,13 +21,10 @@ pub struct ReadAll {
}

#[async_trait]
impl<S> FromRequestParts<S> for ReadAll
where
S: Send + Sync,
{
impl FromRequestParts<AppState> for ReadAll {
type Rejection = Error;

async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self> {
async fn from_request_parts(parts: &mut Parts, state: &AppState) -> Result<Self> {
let permissions = ctx!(permissions_from_bearer(parts, state).await)?;
let org_id = permissions.org_id();
Ok(Self { org_id })
Expand All @@ -51,13 +47,10 @@ impl Permissions {
}

#[async_trait]
impl<S> FromRequestParts<S> for Organization
where
S: Send + Sync,
{
impl FromRequestParts<AppState> for Organization {
type Rejection = Error;

async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self> {
async fn from_request_parts(parts: &mut Parts, state: &AppState) -> Result<Self> {
let permissions = permissions_from_bearer(parts, state).await?;

let org_id = match permissions.access_level {
Expand All @@ -74,22 +67,17 @@ pub struct Application {
}

#[async_trait]
impl<S> FromRequestParts<S> for Application
where
S: Send + Sync,
{
impl FromRequestParts<AppState> for Application {
type Rejection = Error;

async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self> {
async fn from_request_parts(parts: &mut Parts, state: &AppState) -> Result<Self> {
let permissions = permissions_from_bearer(parts, state).await?;

let Path(ApplicationPathParams { app_id }) =
ctx!(Path::<ApplicationPathParams>::from_request_parts(parts, state).await)?;
let Extension(ref db) =
ctx!(Extension::<DatabaseConnection>::from_request_parts(parts, state).await)?;
let app = ctx!(
application::Entity::secure_find_by_id_or_uid(permissions.org_id(), app_id.to_owned(),)
.one(db)
.one(&state.db)
.await
)?
.ok_or_else(|| HttpError::not_found(None, None))?;
Expand All @@ -106,22 +94,17 @@ pub struct OrganizationWithApplication {
}

#[async_trait]
impl<S> FromRequestParts<S> for OrganizationWithApplication
where
S: Send + Sync,
{
impl FromRequestParts<AppState> for OrganizationWithApplication {
type Rejection = Error;

async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self> {
async fn from_request_parts(parts: &mut Parts, state: &AppState) -> Result<Self> {
let Organization { org_id } = ctx!(Organization::from_request_parts(parts, state).await)?;

let Path(ApplicationPathParams { app_id }) =
ctx!(Path::<ApplicationPathParams>::from_request_parts(parts, state).await)?;
let Extension(ref db) =
ctx!(Extension::<DatabaseConnection>::from_request_parts(parts, state).await)?;
let app = ctx!(
application::Entity::secure_find_by_id_or_uid(org_id, app_id.to_owned(),)
.one(db)
.one(&state.db)
.await
)?
.ok_or_else(|| HttpError::not_found(None, None))?;
Expand All @@ -135,22 +118,21 @@ pub struct ApplicationWithMetadata {
}

#[async_trait]
impl<S> FromRequestParts<S> for ApplicationWithMetadata
where
S: Send + Sync,
{
impl FromRequestParts<AppState> for ApplicationWithMetadata {
type Rejection = Error;

async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self> {
async fn from_request_parts(parts: &mut Parts, state: &AppState) -> Result<Self> {
let permissions = permissions_from_bearer(parts, state).await?;

let Path(ApplicationPathParams { app_id }) =
ctx!(Path::<ApplicationPathParams>::from_request_parts(parts, state).await)?;
let Extension(ref db) =
ctx!(Extension::<DatabaseConnection>::from_request_parts(parts, state).await)?;
let (app, metadata) = ctx!(
application::Model::fetch_with_metadata(db, permissions.org_id(), app_id.to_owned())
.await
application::Model::fetch_with_metadata(
&state.db,
permissions.org_id(),
app_id.to_owned()
)
.await
)?
.ok_or_else(|| HttpError::not_found(None, None))?;

Expand Down
14 changes: 4 additions & 10 deletions server/svix-server/src/core/security.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
use std::fmt::Debug;

use axum::{
extract::{Extension, FromRequestParts, TypedHeader},
extract::{FromRequestParts, TypedHeader},
headers::{authorization::Bearer, Authorization},
};

Expand All @@ -14,9 +14,9 @@ use jwt_simple::prelude::*;
use validator::Validate;

use crate::{
cfg::Configuration,
ctx,
error::{HttpError, Result},
AppState,
};

use super::types::{ApplicationId, OrganizationId};
Expand Down Expand Up @@ -62,17 +62,11 @@ pub struct CustomClaim {
pub organization: Option<String>,
}

pub async fn permissions_from_bearer<S: Send + Sync>(
parts: &mut Parts,
state: &S,
) -> Result<Permissions> {
let Extension(ref cfg) =
ctx!(Extension::<Configuration>::from_request_parts(parts, state).await)?;

pub async fn permissions_from_bearer(parts: &mut Parts, state: &AppState) -> Result<Permissions> {
let TypedHeader(Authorization(bearer)) =
ctx!(TypedHeader::<Authorization<Bearer>>::from_request_parts(parts, state).await)?;

let claims = parse_bearer(&cfg.jwt_secret, &bearer)
let claims = parse_bearer(&state.cfg.jwt_secret, &bearer)
.ok_or_else(|| HttpError::unauthorized(None, Some("Invalid token".to_string())))?;
permissions_from_jwt(claims)
}
Expand Down
35 changes: 26 additions & 9 deletions server/svix-server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@
#![warn(clippy::all)]
#![forbid(unsafe_code)]

use axum::{extract::Extension, Router};
use axum::Router;

use crate::core::cache::Cache;
use cfg::ConfigurationInner;
use lazy_static::lazy_static;
use opentelemetry::runtime::Tokio;
use opentelemetry_otlp::WithExportConfig;
use queue::TaskQueueProducer;
use sea_orm::DatabaseConnection;
use std::{
net::TcpListener,
sync::atomic::{AtomicBool, Ordering},
Expand All @@ -22,7 +25,9 @@ use tracing_subscriber::{prelude::*, util::SubscriberInitExt};
use crate::{
cfg::{CacheBackend, Configuration},
core::{
cache, idempotency::IdempotencyService, operational_webhooks::OperationalWebhookSenderInner,
cache,
idempotency::IdempotencyService,
operational_webhooks::{OperationalWebhookSender, OperationalWebhookSenderInner},
},
db::init_db,
expired_message_cleaner::expired_message_cleaner_loop,
Expand Down Expand Up @@ -77,6 +82,15 @@ pub async fn run(cfg: Configuration, listener: Option<TcpListener>) {
run_with_prefix(None, cfg, listener).await
}

#[derive(Clone)]
pub struct AppState {
db: DatabaseConnection,
queue_tx: TaskQueueProducer,
cfg: Configuration,
cache: Cache,
op_webhooks: OperationalWebhookSender,
}

// Made public for the purpose of E2E testing in which a queue prefix is necessary to avoid tests
// consuming from each others' queues
pub async fn run_with_prefix(
Expand Down Expand Up @@ -110,8 +124,16 @@ pub async fn run_with_prefix(

let svc_cache = cache.clone();
// build our application with a route
let app_state = AppState {
db: pool.clone(),
queue_tx: queue_tx.clone(),
cfg: cfg.clone(),
cache: cache.clone(),
op_webhooks: op_webhook_sender.clone(),
};
let v1_router = v1::router().with_state::<()>(app_state);
let app = Router::new()
.nest("/api/v1", v1::router())
.nest("/api/v1", v1_router)
.merge(docs::router())
.layer(
ServiceBuilder::new().layer_fn(move |service| IdempotencyService {
Expand All @@ -125,12 +147,7 @@ pub async fn run_with_prefix(
.allow_methods(Any)
.allow_headers(AllowHeaders::mirror_request())
.max_age(Duration::from_secs(600)),
)
.layer(Extension(pool.clone()))
.layer(Extension(queue_tx.clone()))
.layer(Extension(cfg.clone()))
.layer(Extension(cache.clone()))
.layer(Extension(op_webhook_sender.clone()));
);

let with_api = cfg.api_enabled;
let with_worker = cfg.worker_enabled;
Expand Down
17 changes: 9 additions & 8 deletions server/svix-server/src/v1/endpoints/application.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,17 @@ use crate::{
validation_error, EmptyResponse, ListResponse, ModelIn, ModelOut, Pagination,
PaginationLimit, ValidatedJson, ValidatedQuery,
},
AppState,
};
use axum::{
extract::{Extension, Path},
extract::{Path, State},
routing::{get, post},
Json, Router,
};
use chrono::{DateTime, Utc};
use hyper::StatusCode;
use sea_orm::ActiveModelTrait;
use sea_orm::ActiveValue::Set;
use sea_orm::{ActiveModelTrait, DatabaseConnection};
use serde::{Deserialize, Serialize};
use svix_server_derive::ModelOut;
use validator::{Validate, ValidationError};
Expand Down Expand Up @@ -184,7 +185,7 @@ impl From<(application::Model, applicationmetadata::Model)> for ApplicationOut {
}

async fn list_applications(
Extension(ref db): Extension<DatabaseConnection>,
State(AppState { ref db, .. }): State<AppState>,
pagination: ValidatedQuery<Pagination<ApplicationId>>,
permissions::Organization { org_id }: permissions::Organization,
) -> Result<Json<ListResponse<ApplicationOut>>> {
Expand Down Expand Up @@ -213,7 +214,7 @@ pub struct CreateApplicationQuery {
}

async fn create_application(
Extension(ref db): Extension<DatabaseConnection>,
State(AppState { ref db, .. }): State<AppState>,
query: ValidatedQuery<CreateApplicationQuery>,
permissions::Organization { org_id }: permissions::Organization,
ValidatedJson(data): ValidatedJson<ApplicationIn>,
Expand Down Expand Up @@ -256,7 +257,7 @@ async fn get_application(
}

async fn update_application(
Extension(ref db): Extension<DatabaseConnection>,
State(AppState { ref db, .. }): State<AppState>,
Path(app_id): Path<ApplicationIdOrUid>,
permissions::Organization { org_id }: permissions::Organization,
ValidatedJson(data): ValidatedJson<ApplicationIn>,
Expand Down Expand Up @@ -294,7 +295,7 @@ async fn update_application(
}

async fn patch_application(
Extension(ref db): Extension<DatabaseConnection>,
State(AppState { ref db, .. }): State<AppState>,
permissions::OrganizationWithApplication { app }: permissions::OrganizationWithApplication,
ValidatedJson(data): ValidatedJson<ApplicationPatch>,
) -> Result<Json<ApplicationOut>> {
Expand All @@ -315,7 +316,7 @@ async fn patch_application(
}

async fn delete_application(
Extension(ref db): Extension<DatabaseConnection>,
State(AppState { ref db, .. }): State<AppState>,
permissions::OrganizationWithApplication { app }: permissions::OrganizationWithApplication,
) -> Result<(StatusCode, Json<EmptyResponse>)> {
let mut app: application::ActiveModel = app.into();
Expand All @@ -325,7 +326,7 @@ async fn delete_application(
Ok((StatusCode::NO_CONTENT, Json(EmptyResponse {})))
}

pub fn router() -> Router {
pub fn router() -> Router<AppState> {
Router::new()
.route("/app/", post(create_application).get(list_applications))
.route(
Expand Down
Loading

0 comments on commit ef43606

Please sign in to comment.