diff --git a/Cargo.lock b/Cargo.lock index 72956b9..f7d7b54 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -454,7 +454,7 @@ dependencies = [ "tokio", "tokio-tungstenite 0.17.2", "tower", - "tower-http", + "tower-http 0.3.5", "tower-layer", "tower-service", ] @@ -3091,6 +3091,7 @@ dependencies = [ "sqlx", "thiserror", "tokio", + "tower-http 0.4.4", "tracing", "tracing-subscriber", ] @@ -6151,6 +6152,25 @@ dependencies = [ "tower-service", ] +[[package]] +name = "tower-http" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61c5bb1d698276a2443e5ecfabc1008bf15a36c12e6a7176e7bf089ea9131140" +dependencies = [ + "bitflags 2.4.2", + "bytes", + "futures-core", + "futures-util", + "http 0.2.11", + "http-body", + "http-range-header", + "pin-project-lite", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "tower-layer" version = "0.3.2" diff --git a/Cargo.toml b/Cargo.toml index fd5f555..2dc3129 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -44,4 +44,4 @@ tracing-subscriber = { version = "0.3", features = [ ] } tokio = { version = "1.28.1", features = ["full", "rt"] } chrono = "0.4.33" - +tower-http = { version = "0.4.0", features = ["trace", "cors"] } diff --git a/src/server/mod.rs b/src/server/mod.rs index 7d743dd..a3697f6 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -6,6 +6,7 @@ use std::{ use axum::{extract::Extension, routing::get, Router, Server}; use sqlx::{Pool, Postgres}; +use tower_http::cors::{AllowHeaders, AllowMethods, AllowOrigin, CorsLayer}; use tracing::{debug, info}; use crate::{ @@ -34,12 +35,19 @@ pub async fn run_server(config: Config, db: Pool, _running_program: Ar debug!("Setting up HTTP service"); + // Configure CORS + let cors = CorsLayer::new() + .allow_origin(AllowOrigin::any()) + .allow_methods(AllowMethods::any()) + .allow_headers(AllowHeaders::any()); + let app = Router::new() .route("/health", get(health)) .route( "/api/v1/graphql", get(graphql_playground).post(graphql_handler), ) + .layer(cors) .layer(Extension(schema)) .layer(Extension(context)); let addr = SocketAddr::from_str(&format!("{}:{}", config.server_host(), port))