From 9a56b8174eb2ad5a5d24b37d441999fc947e2696 Mon Sep 17 00:00:00 2001
From: James Lucas
Date: Tue, 17 Jan 2023 16:17:59 -0600
Subject: [PATCH] Server: Worker refactor (#704)

* Add webhook-http-client, which is based on hyper, enabling several critical improvements:
  * Default filtering of endpoints that resolve to private address space, with the ability to add CIDR exceptions as needed (a brief illustrative sketch follows below).
  * Case-sensitive HTTP/1 headers -- although the HTTP spec states that header names are not case-sensitive, in practice many widely deployed web servers expect specific header casing. This requires a fork of hyper, which currently has, but does not expose, support for case-sensitive headers.
  * Assurance that the `host` header appears first in outgoing HTTP requests. Although the HTTP spec does not prescribe an order for headers, in our experience many web servers fail if the `host` header does not appear first. We ensure this by adding the header ourselves rather than relying on hyper to do so.
  * Use `openssl` instead of `rustls` for outbound webhooks. In our experience, many widely deployed web servers use "weak" ciphers that rustls refuses to support, leaving us no choice but to rely on the openssl libraries.
* Support for soft-limiting the number of concurrent worker tasks. Unbounded concurrent worker tasks can easily overwhelm a system. This sets a default of 500 tasks, with support for configuring other limits or allowing unlimited (`0`) concurrent tasks.
* Refactor message dispatch code for clarity. Dispatch code is now broken into separate methods -- `prepare_dispatch`, `make_http_call`, `handle_successful_dispatch`, and `handle_failed_dispatch` -- with the dispatch itself represented by different types depending on its stage in the dispatch process. Other function names have also been changed for improved clarity.
* Prevent the worker from shutting down while there are active tasks.
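To make the address filtering concrete, here is a minimal, self-contained sketch of the kind of check the new client performs: an endpoint IP is rejected unless it looks publicly routable or is covered by a configured CIDR exception. This is an illustration only; the names and the reduced rule set are simplified from the actual `is_allowed`/`whitelist_nets` logic added in `webhook_http_client.rs` further down in this patch.

```rust
// Illustrative sketch of private-address filtering with CIDR exceptions,
// simplified from the `is_allowed` + whitelist check in webhook_http_client.rs.
use std::net::IpAddr;

use ipnet::IpNet;

/// Rough "is this address publicly routable?" check. The real implementation
/// below also rejects shared, reserved, benchmarking, and documentation ranges.
fn looks_public(ip: IpAddr) -> bool {
    match ip {
        IpAddr::V4(v4) => !(v4.is_private() || v4.is_loopback() || v4.is_link_local()),
        IpAddr::V6(v6) => !(v6.is_loopback() || v6.is_unspecified()),
    }
}

/// Allow an endpoint IP if it looks public, or if an operator-supplied
/// whitelist subnet (CIDR) explicitly covers it.
fn endpoint_ip_allowed(ip: IpAddr, whitelist: &[IpNet]) -> bool {
    looks_public(ip) || whitelist.iter().any(|net| net.contains(&ip))
}

fn main() {
    let whitelist: Vec<IpNet> = vec!["127.0.0.1/32".parse().unwrap()];
    assert!(endpoint_ip_allowed("1.1.1.1".parse().unwrap(), &whitelist)); // public
    assert!(endpoint_ip_allowed("127.0.0.1".parse().unwrap(), &whitelist)); // whitelisted
    assert!(!endpoint_ip_allowed("10.0.0.5".parse().unwrap(), &whitelist)); // private, blocked
}
```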
* Avoid multiple DB calls for message-destination insert/retrieval --- server/Cargo.lock | 88 +- server/Cargo.toml | 3 + server/run-tests.sh | 1 + server/svix-server/Cargo.toml | 7 + server/svix-server/config.default.toml | 8 + server/svix-server/src/cfg.rs | 9 + server/svix-server/src/core/message_app.rs | 2 +- server/svix-server/src/core/mod.rs | 1 + .../src/core/webhook_http_client.rs | 764 ++++++++++++ server/svix-server/src/error.rs | 29 +- server/svix-server/src/lib.rs | 6 +- server/svix-server/src/queue/mod.rs | 3 +- .../svix-server/src/v1/endpoints/message.rs | 2 +- server/svix-server/src/v1/utils/mod.rs | 16 +- server/svix-server/src/worker.rs | 1075 +++++++++++------ .../tests/integ_webhook_http_client.rs | 156 +++ 16 files changed, 1766 insertions(+), 404 deletions(-) create mode 100644 server/svix-server/src/core/webhook_http_client.rs create mode 100644 server/svix-server/tests/integ_webhook_http_client.rs diff --git a/server/Cargo.lock b/server/Cargo.lock index 20ede477c..c7715f55e 100644 --- a/server/Cargo.lock +++ b/server/Cargo.lock @@ -1377,8 +1377,7 @@ checksum = "c4a1e36c821dbe04574f602848a19f742f4fb3c98d40449f11bcad18d6b17421" [[package]] name = "hyper" version = "0.14.23" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "034711faac9d2166cb1baf1a2fb0b60b1f277f8492fd72176c17f3515e1abd3c" +source = "git+https://github.com/svix/hyper/?rev=b901ca7c#b901ca7c7772c427d63d150e1bf1c2ce7ce0d733" dependencies = [ "bytes", "futures-channel", @@ -1398,6 +1397,24 @@ dependencies = [ "want", ] +[[package]] +name = "hyper-openssl" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6ee5d7a8f718585d1c3c61dfde28ef5b0bb14734b4db13f5ada856cdc6c612b" +dependencies = [ + "http", + "hyper", + "linked_hash_set", + "once_cell", + "openssl", + "openssl-sys", + "parking_lot 0.12.1", + "tokio", + "tokio-openssl", + "tower-layer", +] + [[package]] name = "hyper-rustls" version = "0.23.1" @@ -1533,6 +1550,9 @@ name = "ipnet" version = "2.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f88c5561171189e69df9d98bcf18fd5f9558300f7ea7b801eb8a0fd748bd8745" +dependencies = [ + "serde", +] [[package]] name = "is-terminal" @@ -1680,6 +1700,15 @@ version = "0.5.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f" +[[package]] +name = "linked_hash_set" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47186c6da4d81ca383c7c47c1bfc80f4b95f4720514d860a5407aaf4233f9588" +dependencies = [ + "linked-hash-map", +] + [[package]] name = "linux-raw-sys" version = "0.1.3" @@ -1980,9 +2009,9 @@ checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5" [[package]] name = "openssl" -version = "0.10.43" +version = "0.10.45" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "020433887e44c27ff16365eaa2d380547a94544ad509aff6eb5b6e3e0b27b376" +checksum = "b102428fd03bc5edf97f62620f7298614c45cedf287c271e7ed450bbaf83f2e1" dependencies = [ "bitflags", "cfg-if", @@ -2012,9 +2041,9 @@ checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" [[package]] name = "openssl-sys" -version = "0.9.78" +version = "0.9.80" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07d5c8cb6e57b3a3612064d7b18b117912b4ce70955c2504d4b741c9e244b132" +checksum = 
"23bbbf7854cd45b83958ebe919f0e8e516793727652e27fda10a8384cfc790b7" dependencies = [ "autocfg 1.1.0", "cc", @@ -3217,6 +3246,28 @@ version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" +[[package]] +name = "strum" +version = "0.24.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "063e6045c0e62079840579a7e47a355ae92f60eb74daaf156fb1e84ba164e63f" +dependencies = [ + "strum_macros", +] + +[[package]] +name = "strum_macros" +version = "0.24.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e385be0d24f186b4ce2f9982191e7101bb737312ad61c1f2f984f34bcf85d59" +dependencies = [ + "heck 0.4.0", + "proc-macro2", + "quote", + "rustversion", + "syn", +] + [[package]] name = "subtle" version = "2.4.1" @@ -3276,11 +3327,14 @@ dependencies = [ "hmac-sha256", "http", "hyper", + "hyper-openssl", "indexmap", + "ipnet", "jsonschema", "jwt-simple", "lazy_static", "num_enum", + "openssl", "opentelemetry", "opentelemetry-http", "opentelemetry-otlp", @@ -3297,6 +3351,8 @@ dependencies = [ "serde_urlencoded", "sha2 0.10.6", "sqlx", + "strum", + "strum_macros", "svix", "svix-ksuid", "svix-server_derive", @@ -3308,7 +3364,9 @@ dependencies = [ "tracing", "tracing-opentelemetry", "tracing-subscriber", + "trust-dns-resolver", "url", + "urlencoding", "validator", ] @@ -3511,6 +3569,18 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-openssl" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08f9ffb7809f1b20c1b398d92acf4cc719874b3b2b2d9ea2f09b4a80350878a" +dependencies = [ + "futures-util", + "openssl", + "openssl-sys", + "tokio", +] + [[package]] name = "tokio-rustls" version = "0.23.4" @@ -3914,6 +3984,12 @@ dependencies = [ "serde", ] +[[package]] +name = "urlencoding" +version = "2.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8db7427f936968176eaa7cdf81b7f98b980b18495ec28f1b5791ac3bfe3eea9" + [[package]] name = "uuid" version = "1.2.2" diff --git a/server/Cargo.toml b/server/Cargo.toml index 023cd8eca..252946c0b 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -1,2 +1,5 @@ [workspace] members = ["svix-server", "svix-server_derive"] + +[patch.crates-io] +hyper = { git = "https://github.com/svix/hyper/", rev = "b901ca7c" } diff --git a/server/run-tests.sh b/server/run-tests.sh index b318e70a4..699fd7f1b 100755 --- a/server/run-tests.sh +++ b/server/run-tests.sh @@ -8,6 +8,7 @@ TEST_COMMAND="cargo test --all --all-features --all-targets" export DATABASE_URL="postgresql://postgres:postgres@localhost:5432/postgres" export SVIX_JWT_SECRET="test value" export SVIX_LOG_LEVEL="info" +export SVIX_WHITELIST_SUBNETS="[127.0.0.1/32]" echo "*********** RUN 1 ***********" SVIX_QUEUE_TYPE="redis" \ diff --git a/server/svix-server/Cargo.toml b/server/svix-server/Cargo.toml index ac736c51e..4b44268f0 100644 --- a/server/svix-server/Cargo.toml +++ b/server/svix-server/Cargo.toml @@ -23,6 +23,8 @@ clap = { version = "3.2.1", features = ["derive"] } axum = { version = "0.6.1", features = ["headers"] } base64 = "0.13.0" hyper = { version = "0.14.16", features = ["full"] } +hyper-openssl = "0.9.2" +openssl = "0.10.45" tokio = { version = "1.23.1", features = ["full"] } tower = "0.4.11" tower-http = { version = "0.3.4", features = ["trace", "cors", "request-id"] } @@ -66,6 +68,11 @@ jsonschema = "0.16.1" aide = { version = "0.9.0", features = ["axum", "redoc", 
"macros"] } schemars = { version = "0.8.11", features = ["chrono", "url"] } indexmap = "1.9.2" +trust-dns-resolver = "0.22.0" +ipnet = { version = "2.5", features = ["serde"] } +urlencoding = "2.1.2" +strum_macros = "0.24" +strum = { version = "0.24", features = ["derive"] } [dev-dependencies] anyhow = "1.0.56" diff --git a/server/svix-server/config.default.toml b/server/svix-server/config.default.toml index 6b98a764e..20d734798 100644 --- a/server/svix-server/config.default.toml +++ b/server/svix-server/config.default.toml @@ -94,3 +94,11 @@ api_enabled = true # Should this instance run the message worker worker_enabled = true + +# Subnets to whitelist for outbound webhooks. Note that allowing endpoints in private IP space +# is a security risk and should only be allowed if you are using the service internally or for +# testing purposes. Should be specified in CIDR notation, e.g., `[127.0.0.1/32, 172.17.0.0/16, 192.168.0.0/16]` +# whitelist_subnets = [] + +# Maximum number of concurrent worker tasks to spawn (0 is unlimited) +worker_max_tasks = 500 \ No newline at end of file diff --git a/server/svix-server/src/cfg.rs b/server/svix-server/src/cfg.rs index 17abf209a..b75a6df18 100644 --- a/server/svix-server/src/cfg.rs +++ b/server/svix-server/src/cfg.rs @@ -7,6 +7,7 @@ use figment::{ providers::{Env, Format, Toml}, Figment, }; +use ipnet::IpNet; use std::time::Duration; use crate::{core::cryptography::Encryption, core::security::Keys, error::Result}; @@ -173,6 +174,14 @@ pub struct ConfigurationInner { /// Should this instance run the message worker pub worker_enabled: bool, + /// Subnets to whitelist for outbound webhooks. Note that allowing endpoints in private IP space + /// is a security risk and should only be allowed if you are using the service internally or for + /// testing purposes. Should be specified in CIDR notation, e.g., `[127.0.0.1/32, 172.17.0.0/16, 192.168.0.0/16]` + pub whitelist_subnets: Option>>, + + /// Maximum number of concurrent worker tasks to spawn (0 is unlimited) + pub worker_max_tasks: u16, + #[serde(flatten)] pub internal: InternalConfig, } diff --git a/server/svix-server/src/core/message_app.rs b/server/svix-server/src/core/message_app.rs index d0fc6bc6f..ae1041fd4 100644 --- a/server/svix-server/src/core/message_app.rs +++ b/server/svix-server/src/core/message_app.rs @@ -67,7 +67,7 @@ impl CreateMessageApp { /// exists or from PostgreSQL otherwise. If the RedisCache is Some, but does not contain the /// requisite information, fetch it from PostgreSQL and insert the data into the cache. 
pub async fn layered_fetch( - cache: Cache, + cache: &Cache, pg: &DatabaseConnection, app: Option, org_id: OrganizationId, diff --git a/server/svix-server/src/core/mod.rs b/server/svix-server/src/core/mod.rs index cf5097c49..3660b7d81 100644 --- a/server/svix-server/src/core/mod.rs +++ b/server/svix-server/src/core/mod.rs @@ -11,6 +11,7 @@ pub mod permissions; pub mod run_with_retries; pub mod security; pub mod types; +pub mod webhook_http_client; #[cfg(test)] mod tests { diff --git a/server/svix-server/src/core/webhook_http_client.rs b/server/svix-server/src/core/webhook_http_client.rs new file mode 100644 index 000000000..197d1d7c4 --- /dev/null +++ b/server/svix-server/src/core/webhook_http_client.rs @@ -0,0 +1,764 @@ +use std::{ + collections::HashMap, + future::Future, + net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, + pin::Pin, + str::FromStr, + sync::Arc, + task::Poll, + time::{Duration, Instant}, +}; + +use axum::headers::{authorization::Credentials, Authorization}; +use bytes::Bytes; +use futures::{future::BoxFuture, FutureExt}; +use http::header::HeaderName; +use http::{HeaderMap, HeaderValue, Method, Response, StatusCode, Version}; +use hyper::{ + client::connect::{dns::Name, HttpConnector}, + ext::HeaderCaseMap, + Body, Client, Uri, +}; +use hyper_openssl::HttpsConnector; +use ipnet::IpNet; +use openssl::ssl::{SslConnector, SslMethod}; +use serde::Serialize; +use thiserror::Error; +use tokio::sync::Mutex; +use tower::Service; +use trust_dns_resolver::{ + error::ResolveError, lookup_ip::LookupIpIntoIter, AsyncResolver, TokioConnection, + TokioConnectionProvider, TokioHandle, +}; + +pub type CaseSensitiveHeaderMap = HashMap; + +#[derive(Debug, Error)] +pub enum Error { + #[error("failure response: {0}")] + FailureStatus(StatusCode), + + #[error("making requests to local IP addresses is forbidden and blocked")] + BlockedIp, + #[error("error resolving name: {0}")] + ResolveError(#[from] ResolveError), + + #[error("request timed out")] + TimedOut, + + #[error("error forming request: {0}")] + InvalidHttpRequest(http::Error), + #[error("error making request: {0}")] + FailedRequest(hyper::Error), +} + +pub struct WebhookClient { + client: Client, Body>, + whitelist_nets: Arc>, +} + +impl WebhookClient { + pub fn new( + whitelist_nets: Option>>, + whitelist_names: Option>>, + ) -> Self { + let whitelist_nets = whitelist_nets.unwrap_or_else(|| Arc::new(Vec::new())); + let whitelist_names = whitelist_names.unwrap_or_else(|| Arc::new(Vec::new())); + + let mut connector = HttpConnector::new_with_resolver(NonLocalDnsResolver::new( + whitelist_nets.clone(), + whitelist_names, + )); + connector.enforce_http(false); + + let ssl = SslConnector::builder(SslMethod::tls()).expect("SslConnector build failed"); + let https = HttpsConnector::with_connector(NonLocalConnector { connector }, ssl) + .expect("HttpsConnector build failed"); + + let client: Client<_, hyper::Body> = Client::builder() + .http1_ignore_invalid_headers_in_responses(true) + .http1_title_case_headers(true) + .build(https); + + Self { + client, + whitelist_nets, + } + } + + pub async fn execute(&self, request: Request) -> Result, Error> { + self.execute_inner(request, true).await + } + + pub fn execute_inner( + &self, + request: Request, + retry: bool, + ) -> BoxFuture, Error>> { + async move { + let org_req = request.clone(); + if let Some(auth) = request.uri.authority() { + if let Ok(ip) = auth.host().parse::() { + if !is_allowed(ip) + && !self + .whitelist_nets + .iter() + .any(|subnet| subnet.contains(&ip)) + { + return 
Err(Error::BlockedIp); + } + } + } + + let mut req = if let Some(body) = request.body { + hyper::Request::builder() + .method(request.method) + .uri(request.uri) + .version(request.version) + .body(Body::from(body)) + .map_err(Error::InvalidHttpRequest)? + } else { + hyper::Request::builder() + .method(request.method) + .uri(request.uri) + .version(request.version) + .body(Body::empty()) + .map_err(Error::InvalidHttpRequest)? + }; + + *req.headers_mut() = request.headers; + + if let Some(header_names) = request.header_names { + req.extensions_mut().insert(header_names); + } + + let start = Instant::now(); + let res = if let Some(timeout) = request.timeout { + match tokio::time::timeout(timeout, self.client.request(req)).await { + Ok(Ok(resp)) => Ok(resp), + Ok(Err(e)) => Err({ + if e.to_string().contains( + "making requests to local IP addresses is forbidden and blocked", + ) { + Error::BlockedIp + } else { + Error::FailedRequest(e) + } + }), + Err(_to) => Err(Error::TimedOut), + } + } else { + self.client.request(req).await.map_err(|e| { + if e.to_string() + .contains("making requests to local IP addresses is forbidden and blocked") + { + Error::BlockedIp + } else { + Error::FailedRequest(e) + } + }) + }; + + if !retry { + return res; + } + + match res { + Err(ref e) => match e { + Error::FailedRequest(e) if start.elapsed() < Duration::from_millis(1000) => { + tracing::info!("Insta-retrying: {}", e); + self.execute_inner(org_req, false).await + } + _ => res, + }, + res => res, + } + } + .boxed() + } +} + +#[derive(Clone)] +pub struct Request { + method: Method, + uri: Uri, + headers: HeaderMap, + header_names: Option, + body: Option>, + timeout: Option, + version: Version, +} + +pub struct RequestBuilder { + method: Option, + uri: Option, + accept: Option, + user_agent: Option, + headers: Option, + header_names: Option, + body: Option>, + version: Option, + timeout: Option, + basic_auth: Option>, + + // Derived from body + content_type: Option, +} + +#[derive(Debug)] +pub struct RequestBuildError(pub Vec); + +impl std::fmt::Display for RequestBuildError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let err_str = self.0.iter().fold(String::default(), |acc, err| { + if acc.is_empty() { + format!("Build failed: {err}") + } else { + format!("{acc}; {err}") + } + }); + write!(f, "{err_str}") + } +} + +#[derive(Debug, Error)] +pub enum BuildError { + #[error("uri missing")] + UriMissing, + #[error("version missing")] + VersionMissing, +} + +fn decode_or_log(s: &str) -> String { + urlencoding::decode(s) + .map(|x| x.into_owned()) + .unwrap_or_else(|_| { + tracing::error!("URL decoding failed"); + s.to_owned() + }) +} + +impl RequestBuilder { + pub fn new() -> Self { + Self { + method: None, + uri: None, + accept: None, + user_agent: None, + headers: None, + header_names: None, + body: None, + version: None, + timeout: None, + content_type: None, + basic_auth: None, + } + } + + pub fn method(mut self, method: Method) -> Self { + self.method = Some(method); + self + } + + pub fn uri(mut self, uri: url::Url) -> Self { + let basic_auth = if uri.password().is_some() || !uri.username().is_empty() { + let username = decode_or_log(uri.username()); + let password = uri.password().map(decode_or_log).unwrap_or_default(); + + Some( + Authorization::basic(&username, &password) + .0 + .encode() + .as_bytes() + .to_vec(), + ) + } else { + None + }; + self.basic_auth = basic_auth; + + let uri = + Uri::from_str(uri.as_str()).expect("If it's a valid url::Url, it's also a valid Uri"); 
+ self.uri = Some(uri); + self + } + + pub fn uri_str(self, uri: &str) -> Result { + let uri = url::Url::from_str(uri)?; + Ok(self.uri(uri)) + } + + fn build_headers( + headers: CaseSensitiveHeaderMap, + ) -> (hyper::HeaderMap, hyper::ext::HeaderCaseMap) { + let mut hdr_map = hyper::HeaderMap::with_capacity(headers.len()); + let mut case_sensitive_hdrs: hyper::HeaderMap = + hyper::HeaderMap::with_capacity(headers.len()); + for (k, v) in headers.into_iter() { + match HeaderName::from_str(&k) { + Ok(key) => { + hdr_map.insert(key.clone(), v); + case_sensitive_hdrs.insert(key, Bytes::copy_from_slice(k.as_bytes())); + } + Err(e) => { + tracing::error!("Failured to parse header {} {}", k, e); + } + } + } + (hdr_map, case_sensitive_hdrs.into()) + } + + pub fn headers(mut self, headers: CaseSensitiveHeaderMap) -> Self { + let (hdrs, case_map) = Self::build_headers(headers); + self.headers = Some(hdrs); + self.header_names = Some(case_map); + self + } + + pub fn body(mut self, body: Vec, content_type: HeaderValue) -> Self { + self.body = Some(body); + self.content_type = Some(content_type); + self + } + + pub fn json_body(self, body: T) -> Result { + let body = serde_json::to_vec(&body)?; + Ok(self.body(body, HeaderValue::from_static("application/json"))) + } + + pub fn version(mut self, version: Version) -> Self { + self.version = Some(version); + self + } + + pub fn timeout(mut self, timeout: Duration) -> Self { + self.timeout = Some(timeout); + self + } + + pub fn user_agent(mut self, user_agent: HeaderValue) -> Self { + self.user_agent = Some(user_agent); + self + } +} + +impl Default for RequestBuilder { + fn default() -> Self { + Self::new() + } +} + +impl RequestBuilder { + fn validate(&self) -> Result<(), RequestBuildError> { + let mut errs: Vec = Vec::new(); + if self.uri.is_none() { + errs.push(BuildError::UriMissing); + } + if self.version.is_none() { + errs.push(BuildError::VersionMissing); + } + + if !errs.is_empty() { + Err(RequestBuildError(errs)) + } else { + Ok(()) + } + } + + pub fn build(self) -> Result { + self.validate()?; + + let custom_headers = self.headers.unwrap_or_default(); + + let uri = self.uri.unwrap(); + let authority = uri.authority().expect("Missing authority"); + let host = match authority.port() { + Some(port) => format!("{}:{}", authority.host(), port), + None => authority.host().to_string(), + }; + + let mut headers = HeaderMap::with_capacity(3 + custom_headers.len()); + + // Ensure that host header is first -- even though this is technically + // not required by HTTP spec, some clients fail if it's not first: + headers.insert(http::header::HOST, HeaderValue::from_str(&host).unwrap()); + headers.insert( + http::header::ACCEPT, + self.accept.unwrap_or(HeaderValue::from_static("*/*")), + ); + headers.insert( + http::header::CONTENT_TYPE, + self.content_type + .unwrap_or(HeaderValue::from_static("application/json")), + ); + + headers.extend(custom_headers.into_iter()); + + if let Some(user_agent) = self.user_agent { + headers.insert(http::header::USER_AGENT, user_agent); + } + + if let Some(auth_header) = self.basic_auth { + if !headers.contains_key(http::header::AUTHORIZATION) { + headers.insert( + http::header::AUTHORIZATION, + HeaderValue::from_bytes(&auth_header).unwrap(), + ); + } + } + + Ok(Request { + method: self.method.unwrap_or(Method::POST), + uri, + headers, + header_names: self.header_names, + body: self.body, + timeout: self.timeout, + version: self.version.unwrap(), + }) + } +} + +#[derive(Clone, Debug)] +struct NonLocalConnector { + connector: 
HttpConnector, +} + +impl Service for NonLocalConnector { + type Response = >::Response; + type Error = >::Error; + + type Future = as Service>::Future; + + fn poll_ready( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.connector.poll_ready(cx) + } + + fn call(&mut self, req: Uri) -> Self::Future { + self.connector.call(req) + } +} + +#[derive(Clone, Debug)] +struct NonLocalDnsResolver { + state: Arc>, + whitelist_nets: Arc>, + whitelist_names: Arc>, +} + +#[derive(Clone, Debug)] +enum DnsState { + Init, + Ready(Arc>), +} + +impl NonLocalDnsResolver { + pub fn new(whitelist_nets: Arc>, whitelist_names: Arc>) -> Self { + NonLocalDnsResolver { + state: Arc::new(Mutex::new(DnsState::Init)), + whitelist_nets, + whitelist_names, + } + } +} + +impl Service for NonLocalDnsResolver { + type Response = SocketAddrs; + type Error = Error; + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, name: Name) -> Self::Future { + let resolver = self.clone(); + let whitelist_nets = self.whitelist_nets.clone(); + let whitelist_names = self.whitelist_names.clone(); + + Box::pin(async move { + let mut lock = resolver.state.lock().await; + + let resolver = match &*lock { + DnsState::Init => { + let resolver = new_resolver().await?; + *lock = DnsState::Ready(resolver.clone()); + resolver + } + + DnsState::Ready(resolver) => resolver.clone(), + }; + + drop(lock); + + let whitelisted_name = whitelist_names + .iter() + .any(|whitelisted| whitelisted == name.as_str()); + + let lookup = resolver.lookup_ip(name.as_str()).await?; + + if lookup.iter().all(|ip| { + !is_allowed(ip) + && !whitelist_nets.iter().any(|subnet| subnet.contains(&ip)) + && !whitelisted_name + }) { + Err(Error::BlockedIp) + } else { + Ok(SocketAddrs { + iter: lookup.into_iter(), + whitelist_nets, + whitelisted_name, + }) + } + }) + } +} + +pub struct SocketAddrs { + iter: LookupIpIntoIter, + whitelist_nets: Arc>, + whitelisted_name: bool, +} + +impl Iterator for SocketAddrs { + type Item = SocketAddr; + + fn next(&mut self) -> Option { + loop { + match self.iter.next() { + Some(ip_addr) => { + if is_allowed(ip_addr) + || self + .whitelist_nets + .iter() + .any(|subnet| subnet.contains(&ip_addr)) + || self.whitelisted_name + { + return Some(SocketAddr::from((ip_addr, 0))); + } + } + + None => return None, + } + } + } +} + +async fn new_resolver( +) -> Result>, ResolveError> { + Ok(Arc::new(AsyncResolver::from_system_conf(TokioHandle)?)) +} + +fn is_allowed(addr: IpAddr) -> bool { + match addr { + IpAddr::V4(addr) => { + !addr.is_private() + && !addr.is_loopback() + && !addr.is_link_local() + && !addr.is_broadcast() + && !addr.is_documentation() + && !is_shared(addr) + && !is_reserved(addr) + && !is_benchmarking(addr) + && !starts_with_zero(addr) + } + IpAddr::V6(addr) => { + !addr.is_multicast() + && !addr.is_loopback() + && !is_unicast_link_local(addr) + && !is_unique_local(addr) + && !addr.is_unspecified() + && !is_documentation_v6(addr) + } + } +} + +/// Util functions copied from the unstable standard library near identically +fn is_shared(addr: Ipv4Addr) -> bool { + addr.octets()[0] == 100 && (addr.octets()[1] & 0b1100_0000 == 0b0100_0000) +} + +fn is_reserved(addr: Ipv4Addr) -> bool { + (addr.octets()[0] == 192 && addr.octets()[1] == 0 && addr.octets()[2] == 0) + || (addr.octets()[0] & 240 == 240 && !addr.is_broadcast()) +} + +fn is_benchmarking(addr: Ipv4Addr) -> bool { + addr.octets()[0] == 198 && 
(addr.octets()[1] & 0xfe) == 18 +} + +fn starts_with_zero(addr: Ipv4Addr) -> bool { + addr.octets()[0] == 0 +} + +fn is_unicast_link_local(addr: Ipv6Addr) -> bool { + (addr.segments()[0] & 0xffc0) == 0xfe80 +} + +fn is_unique_local(addr: Ipv6Addr) -> bool { + (addr.segments()[0] & 0xfe00) == 0xfc00 +} + +fn is_documentation_v6(addr: Ipv6Addr) -> bool { + (addr.segments()[0] == 0x2001) && (addr.segments()[1] == 0xdb8) +} + +#[cfg(test)] +mod tests { + use std::{net::IpAddr, str::FromStr}; + + use super::*; + + #[test] + fn is_allowed_test() { + // Copied shamelessly from the standard library `is_global` docs + assert!(!is_allowed(IpAddr::from([10, 254, 0, 0]))); + assert!(!is_allowed(IpAddr::from([192, 168, 10, 65]))); + assert!(!is_allowed(IpAddr::from([172, 16, 10, 65]))); + assert!(!is_allowed(IpAddr::from([0, 1, 2, 3]))); + assert!(!is_allowed(IpAddr::from([0, 0, 0, 0]))); + assert!(!is_allowed(IpAddr::from([127, 0, 0, 1]))); + assert!(!is_allowed(IpAddr::from([169, 254, 45, 1]))); + assert!(!is_allowed(IpAddr::from([255, 255, 255, 255]))); + assert!(!is_allowed(IpAddr::from([192, 0, 2, 255]))); + assert!(!is_allowed(IpAddr::from([198, 51, 100, 65]))); + assert!(!is_allowed(IpAddr::from([203, 0, 113, 6]))); + assert!(!is_allowed(IpAddr::from([100, 100, 0, 0]))); + assert!(!is_allowed(IpAddr::from([192, 0, 0, 0]))); + assert!(!is_allowed(IpAddr::from([192, 0, 0, 255]))); + assert!(!is_allowed(IpAddr::from([250, 10, 20, 30]))); + assert!(!is_allowed(IpAddr::from([198, 18, 0, 0]))); + + assert!(is_allowed(IpAddr::from([1, 1, 1, 1]))); + + assert!(!is_allowed(IpAddr::from([0, 0, 0, 0, 0, 0, 0, 0x1]))); + + assert!(is_allowed(IpAddr::from([0, 0, 0, 0xffff, 0, 0, 0, 0x1]))); + } + + #[test] + fn test_builder() { + match RequestBuilder::new().build() { + Err(e) => assert_eq!("Build failed: uri missing; version missing", e.to_string()), + Ok(_) => panic!(), + } + + assert!(RequestBuilder::new() + .version(Version::HTTP_11) + .build() + .is_err()); + + assert!(RequestBuilder::new() + .uri(url::Url::from_str("http://127.0.0.1/").unwrap()) + .version(Version::HTTP_11) + .build() + .is_ok()); + } + + #[test] + fn test_header_casings() { + let hdrs = CaseSensitiveHeaderMap::from([( + "tEsT-header-1".to_owned(), + HeaderValue::from_static("value"), + )]); + + let req = RequestBuilder::new() + .uri(url::Url::from_str("http://127.0.0.1/").unwrap()) + .version(Version::HTTP_11) + .headers(hdrs) + .build() + .unwrap(); + + assert_eq!( + req.header_names + .unwrap() + .get("test-header-1".parse().unwrap()) + .unwrap(), + "tEsT-header-1".as_bytes() + ); + assert_eq!( + req.headers.get("test-header-1").unwrap(), + HeaderValue::from_static("value") + ); + } + + #[test] + fn test_url_basic_auth() { + let req = RequestBuilder::new() + .uri(url::Url::from_str("http://test:123@127.0.0.1/").unwrap()) + .version(Version::HTTP_11) + .build() + .unwrap(); + + assert_eq!( + req.headers.get("authorization").unwrap(), + "Basic dGVzdDoxMjM=".as_bytes() + ); + + let req_user_only = RequestBuilder::new() + .uri(url::Url::from_str("http://test:@127.0.0.1/").unwrap()) + .version(Version::HTTP_11) + .build() + .unwrap(); + + assert_eq!( + req_user_only.headers.get("authorization").unwrap(), + "Basic dGVzdDo=".as_bytes() + ); + + let req_pass_only = RequestBuilder::new() + .uri(url::Url::from_str("http://:123@127.0.0.1/").unwrap()) + .version(Version::HTTP_11) + .build() + .unwrap(); + + assert_eq!( + req_pass_only.headers.get("authorization").unwrap(), + "Basic OjEyMw==".as_bytes() + ); + + let req_no_basic_auth = 
RequestBuilder::new() + .uri(url::Url::from_str("http://127.0.0.1/").unwrap()) + .version(Version::HTTP_11) + .build() + .unwrap(); + + assert!(req_no_basic_auth.headers.get("authorization").is_none()); + + let req_special_chars = RequestBuilder::new() + .uri(url::Url::from_str("http://test==:123==@127.0.0.1/").unwrap()) + .version(Version::HTTP_11) + .build() + .unwrap(); + + assert_eq!( + req_special_chars.headers.get("authorization").unwrap(), + "Basic dGVzdD09OjEyMz09".as_bytes() + ); + } + + #[test] + fn test_host_header() { + let req = RequestBuilder::new() + .uri(url::Url::from_str("http://127.0.0.1/").unwrap()) + .version(Version::HTTP_11) + .build() + .unwrap(); + + assert_eq!(req.headers.get("host").unwrap(), "127.0.0.1".as_bytes()); + + let req_with_port = RequestBuilder::new() + .uri(url::Url::from_str("http://127.0.0.1:8000/").unwrap()) + .version(Version::HTTP_11) + .build() + .unwrap(); + + assert_eq!( + req_with_port.headers.get("host").unwrap(), + "127.0.0.1:8000".as_bytes() + ); + } +} diff --git a/server/svix-server/src/error.rs b/server/svix-server/src/error.rs index ddf653757..5636bebb5 100644 --- a/server/svix-server/src/error.rs +++ b/server/svix-server/src/error.rs @@ -23,6 +23,8 @@ use serde::Serialize; use serde_json::json; use sqlx::Error as SqlxError; +use crate::core::webhook_http_client; + /// A short-hand version of a [std::result::Result] that always returns an Svix [Error]. pub type Result = std::result::Result; @@ -149,6 +151,13 @@ macro_rules! err_queue { }; } +#[macro_export] +macro_rules! err_cache { + ($s:expr) => { + $crate::error::Error::cache($s, $crate::location!()) + }; +} + #[macro_export] macro_rules! err_validation { ($s:expr) => { @@ -271,6 +280,8 @@ pub enum ErrorType { Http(HttpError), /// Cache error Cache(String), + /// Timeout error + Timeout(String), } impl fmt::Display for ErrorType { @@ -282,6 +293,7 @@ impl fmt::Display for ErrorType { Self::Validation(s) => s.fmt(f), Self::Http(s) => s.fmt(f), Self::Cache(s) => s.fmt(f), + Self::Timeout(s) => s.fmt(f), } } } @@ -319,7 +331,7 @@ pub struct ValidationErrorItem { #[derive(Debug, Clone)] pub struct HttpError { - status: StatusCode, + pub status: StatusCode, body: HttpErrorBody, } @@ -428,3 +440,18 @@ impl IntoResponse for HttpError { (self.status, Json(self.body)).into_response() } } + +impl From for Error { + fn from(typ: ErrorType) -> Self { + Self { trace: vec![], typ } + } +} + +impl From for Error { + fn from(err: webhook_http_client::Error) -> Error { + match err { + webhook_http_client::Error::TimedOut => ErrorType::Timeout(err.to_string()).into(), + _ => err_generic!(err.to_string()), + } + } +} diff --git a/server/svix-server/src/lib.rs b/server/svix-server/src/lib.rs index ee742abf6..22039d573 100644 --- a/server/svix-server/src/lib.rs +++ b/server/svix-server/src/lib.rs @@ -34,7 +34,7 @@ use crate::{ }, db::init_db, expired_message_cleaner::expired_message_cleaner_loop, - worker::worker_loop, + worker::queue_handler, }; pub mod cfg; @@ -298,10 +298,10 @@ pub async fn run_with_prefix( async { if with_worker { tracing::debug!("Worker: Initializing"); - worker_loop( + queue_handler( &cfg, - &pool, cache.clone(), + pool.clone(), queue_tx, queue_rx, op_webhook_sender, diff --git a/server/svix-server/src/queue/mod.rs b/server/svix-server/src/queue/mod.rs index 21ca7ff44..152ac62b5 100644 --- a/server/svix-server/src/queue/mod.rs +++ b/server/svix-server/src/queue/mod.rs @@ -3,6 +3,7 @@ use std::{sync::Arc, time::Duration}; use axum::async_trait; use chrono::{DateTime, Utc}; use 
serde::{Deserialize, Serialize}; +use strum::Display; use svix_ksuid::*; use crate::{ @@ -93,7 +94,7 @@ impl MessageTaskBatch { } } -#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Display)] #[serde(rename_all = "camelCase")] #[serde(tag = "type")] pub enum QueueTask { diff --git a/server/svix-server/src/v1/endpoints/message.rs b/server/svix-server/src/v1/endpoints/message.rs index 709d20192..519ddb6c7 100644 --- a/server/svix-server/src/v1/endpoints/message.rs +++ b/server/svix-server/src/v1/endpoints/message.rs @@ -236,7 +236,7 @@ async fn create_message( ValidatedJson(data): ValidatedJson, ) -> Result<(StatusCode, Json)> { let create_message_app = CreateMessageApp::layered_fetch( - cache, + &cache, db, Some(app.clone()), app.org_id.clone(), diff --git a/server/svix-server/src/v1/utils/mod.rs b/server/svix-server/src/v1/utils/mod.rs index b3df849b5..12cf0eb65 100644 --- a/server/svix-server/src/v1/utils/mod.rs +++ b/server/svix-server/src/v1/utils/mod.rs @@ -1,7 +1,14 @@ // SPDX-FileCopyrightText: © 2022 Svix Authors // SPDX-License-Identifier: MIT -use std::{borrow::Cow, collections::HashSet, error::Error as StdError, ops::Deref, str::FromStr}; +use std::{ + borrow::Cow, + collections::HashSet, + error::Error as StdError, + ops::Deref, + str::FromStr, + time::{SystemTime, UNIX_EPOCH}, +}; use aide::{transform::TransformPathItem, OperationInput, OperationIo}; use axum::{ @@ -448,6 +455,13 @@ pub fn openapi_tag(tag: &'static str) -> impl FnOnce(TransformPathItem) -> Trans |op| op.tag(tag) } +pub fn get_unix_timestamp() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs() +} + #[cfg(test)] mod tests { use validator::Validate; diff --git a/server/svix-server/src/worker.rs b/server/svix-server/src/worker.rs index 30e8076c5..f488d9cea 100644 --- a/server/svix-server/src/worker.rs +++ b/server/svix-server/src/worker.rs @@ -1,52 +1,65 @@ // SPDX-FileCopyrightText: © 2022 Svix Authors -// SPDX-License-Identifier: MIT +// SPDX-Licensepub(crate) -Identifier: MIT use crate::cfg::Configuration; +use crate::core::cache::{kv_def, Cache, CacheBehavior, CacheKey, CacheValue}; use crate::core::cryptography::Encryption; -use crate::core::operational_webhooks::EndpointDisabledEvent; +use crate::core::message_app::{CreateMessageApp, CreateMessageEndpoint}; +use crate::core::operational_webhooks::{ + EndpointDisabledEvent, MessageAttemptEvent, OperationalWebhook, OperationalWebhookSender, +}; use crate::core::types::{ - ApplicationId, ApplicationUid, EndpointId, EndpointSecretInternal, EndpointSecretType, + ApplicationId, ApplicationUid, BaseId, EndpointHeaders, EndpointId, EndpointSecretInternal, + EndpointSecretType, MessageAttemptId, MessageAttemptTriggerType, MessageId, MessageStatus, MessageUid, OrganizationId, }; -use crate::core::{ - cache::{kv_def, Cache, CacheBehavior, CacheKey, CacheValue}, - message_app::{CreateMessageApp, CreateMessageEndpoint}, - operational_webhooks::{MessageAttemptEvent, OperationalWebhook, OperationalWebhookSender}, - types::{ - BaseId, EndpointHeaders, MessageAttemptId, MessageAttemptTriggerType, MessageId, - MessageStatus, - }, +use crate::core::webhook_http_client::{ + Error as WebhookClientError, RequestBuilder, WebhookClient, }; use crate::db::models::{endpoint, message, messageattempt, messagedestination}; -use crate::error::Result; +use crate::error::{Error, ErrorType, HttpError, Result}; use crate::queue::{ MessageTask, MessageTaskBatch, QueueTask, 
TaskQueueConsumer, TaskQueueProducer, }; -use crate::{ctx, err_generic}; +use crate::v1::utils::get_unix_timestamp; +use crate::{ctx, err_cache, err_generic, err_validation}; + use chrono::Utc; use futures::future; +use http::{HeaderValue, StatusCode, Version}; +use ipnet::IpNet; +use lazy_static::lazy_static; use rand::Rng; -use reqwest::header::{HeaderMap, HeaderName}; -use sea_orm::{entity::prelude::*, ActiveValue::Set, DatabaseConnection, EntityTrait}; -use serde::{Deserialize, Serialize}; -use svix_ksuid::{KsuidLike, KsuidMs}; -use tokio::time::{sleep, Duration}; -use std::{ - str::FromStr, - sync::{atomic::Ordering, Arc}, +use sea_orm::prelude::DateTimeUtc; +use sea_orm::{ + ActiveModelTrait, ColumnTrait, DatabaseConnection, EntityTrait, QueryFilter, Set, TryIntoModel, }; +use serde::{Deserialize, Serialize}; +use tokio::time::sleep; +use tracing::Instrument; + +use std::collections::HashMap; +use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering}; +use std::sync::Arc; +use std::time::Duration; + +pub type CaseSensitiveHeaderMap = HashMap; // The maximum variation from the retry schedule when applying jitter to a resent webhook event in // percent deviation const JITTER_DELTA: f32 = 0.2; +const OVERLOAD_PENALTY_SECS: u64 = 60; const USER_AGENT: &str = concat!("Svix-Webhooks/", env!("CARGO_PKG_VERSION")); + /// Send the MessageAttemptFailingEvent after exceeding this number of failed attempts const OP_WEBHOOKS_SEND_FAILING_EVENT_AFTER: usize = 4; +const RESPONSE_MAX_SIZE: usize = 20000; + /// A simple struct noting the context of the wrapped [`DateTimeUtc`]. This struct is returned when /// you are to disable disable an endpoint. This is optionally returned by [`process_failure_cache`] /// which is to be called after all retry events are exhausted. @@ -65,8 +78,18 @@ pub struct FailureCacheValue { kv_def!(FailureCacheKey, FailureCacheValue, "SVIX_FAILURE_CACHE"); impl FailureCacheKey { - pub fn new(app_id: &ApplicationId, endp_id: &EndpointId) -> FailureCacheKey { - FailureCacheKey(format!("_{app_id}_{endp_id}")) + pub fn new( + org_id: &OrganizationId, + app_id: &ApplicationId, + endp_id: &EndpointId, + ) -> FailureCacheKey { + FailureCacheKey(format!( + "{}_{}_{}_{}", + Self::PREFIX_CACHE, + org_id, + app_id, + endp_id + )) } } @@ -76,16 +99,16 @@ impl FailureCacheKey { /// /// If the key value pair does not already exist in the cache, indicating that the endpoint never /// stopped responding, no operation is performed. -async fn process_success_cache( +#[tracing::instrument(skip_all)] +async fn process_endpoint_success( cache: &Cache, app_id: &ApplicationId, - endp_id: &EndpointId, + org_id: &OrganizationId, + endp: &CreateMessageEndpoint, ) -> Result<()> { - let key = FailureCacheKey::new(app_id, endp_id); + let key = FailureCacheKey::new(org_id, app_id, &endp.id); - cache.delete(&key).await.map_err(|e| err_generic!(e))?; - - Ok(()) + cache.delete(&key).await.map_err(|e| err_cache!(e)) } /// Called upon endpoint failure. Returns whether to disable the endpoint based on the time of first @@ -99,15 +122,15 @@ async fn process_success_cache( /// /// All cache values are set with an expiration time greater thah the grace period, so occasional /// failures will not cause an endpoint to be disabled. 
-async fn process_failure_cache( +#[tracing::instrument(skip_all)] +async fn process_endpoint_failure( cache: &Cache, - app_id: &ApplicationId, - endp_id: &EndpointId, - + org_id: &OrganizationId, + endp: &CreateMessageEndpoint, disable_in: Duration, ) -> Result> { - let key = FailureCacheKey::new(app_id, endp_id); + let key = FailureCacheKey::new(org_id, app_id, &endp.id); let now = Utc::now(); // If it already exists in the cache, see if the grace preiod has already elapsed @@ -174,148 +197,161 @@ fn generate_msg_headers( whitelabel_headers: bool, configured_headers: Option<&EndpointHeaders>, _endpoint_url: &str, -) -> HeaderMap { - let mut headers = HeaderMap::new(); - let id = msg_id.0.parse().expect("Error parsing message id"); +) -> Result { + let mut headers = CaseSensitiveHeaderMap::new(); + let id_hdr = msg_id + .0 + .parse() + .map_err(|_| err_generic!("Error parsing message id".to_string()))?; let timestamp = timestamp .to_string() .parse() - .expect("Error parsing message timestamp"); + .map_err(|_| err_generic!("Error parsing message timestamp".to_string()))?; let signatures_str = signatures .parse() - .expect("Error parsing message signatures"); + .map_err(|_| err_generic!("Error parsing message signatures".to_string()))?; if whitelabel_headers { - headers.insert("webhook-id", id); - headers.insert("webhook-timestamp", timestamp); - headers.insert("webhook-signature", signatures_str); + headers.insert("webhook-id".to_owned(), id_hdr); + headers.insert("webhook-timestamp".to_owned(), timestamp); + headers.insert("webhook-signature".to_owned(), signatures_str); } else { - headers.insert("svix-id", id); - headers.insert("svix-timestamp", timestamp); - headers.insert("svix-signature", signatures_str); + headers.insert("svix-id".to_owned(), id_hdr); + headers.insert("svix-timestamp".to_owned(), timestamp); + headers.insert("svix-signature".to_owned(), signatures_str); } - + headers.insert( + "user-agent".to_owned(), + USER_AGENT.to_string().parse().unwrap(), + ); + headers.insert( + "content-type".to_owned(), + "application/json".parse().unwrap(), + ); if let Some(configured_headers) = configured_headers { for (k, v) in &configured_headers.0 { - if let (Ok(k), Ok(v)) = (HeaderName::from_str(k), v.parse()) { - headers.insert(k, v); - } else { - tracing::error!("Invalid HeaderName or HeaderValues for `{}: {}`", k, v); + match v.parse() { + Ok(v) => { + headers.insert(k.clone(), v); + } + Err(e) => { + tracing::error!("Invalid HeaderValue {}: {}", v, e); + } } } } - headers + Ok(headers) } -#[derive(Clone, Copy)] +#[derive(Clone)] struct WorkerContext<'a> { - task_id: &'a str, cfg: &'a Configuration, - db: &'a DatabaseConnection, cache: &'a Cache, + db: &'a DatabaseConnection, queue_tx: &'a TaskQueueProducer, op_webhook_sender: &'a OperationalWebhookSender, } -struct DispatchExtraIds<'a> { - org_id: &'a OrganizationId, - app_uid: Option<&'a ApplicationUid>, - msg_uid: Option<&'a MessageUid>, +struct FailedDispatch(messageattempt::ActiveModel, Error); +struct SuccessfulDispatch(messageattempt::ActiveModel); + +#[allow(clippy::large_enum_variant)] +enum IncompleteDispatch { + Pending(PendingDispatch), + #[allow(dead_code)] + Failed(FailedDispatch), } -/// Dispatches one webhook -#[tracing::instrument( - skip_all, - fields( - task_id = task_id, - org_id = org_id.0.as_str(), - endp_id = msg_task.endpoint_id.0.as_str(), - msg_id = msg_task.msg_id.0.as_str() - ) - level = "error" -)] -async fn dispatch( - WorkerContext { - task_id, - cache, - cfg, - db, - queue_tx, - op_webhook_sender, 
+struct PendingDispatch { + method: http::Method, + url: String, + headers: CaseSensitiveHeaderMap, + payload: String, + request_timeout: u64, + created_at: DateTimeUtc, +} + +enum CompletedDispatch { + Failed(FailedDispatch), + Successful(SuccessfulDispatch), +} + +#[tracing::instrument(skip_all)] +async fn prepare_dispatch( + WorkerContext { cfg, .. }: &WorkerContext<'_>, + DispatchContext { + msg_task, + payload, + endp, .. - }: WorkerContext<'_>, - msg_task: MessageTask, - DispatchExtraIds { - org_id, - app_uid, - msg_uid, - }: DispatchExtraIds<'_>, - body: String, - endp: CreateMessageEndpoint, -) -> Result<()> { - tracing::trace!("Dispatch: {} {}", &msg_task.msg_id, &endp.id); + }: DispatchContext<'_>, +) -> Result { + let attempt_created_at = Utc::now(); - let now = Utc::now(); let headers = { let keys = endp.valid_signing_keys(); let signatures = sign_msg( &cfg.encryption, - now.timestamp(), - &body, + attempt_created_at.timestamp(), + payload, &msg_task.msg_id, &keys, ); - let mut headers = generate_msg_headers( - now.timestamp(), + generate_msg_headers( + attempt_created_at.timestamp(), &msg_task.msg_id, signatures, cfg.whitelabel_headers, endp.headers.as_ref(), &endp.url, - ); - headers.insert("user-agent", USER_AGENT.to_string().parse().unwrap()); - headers.insert("content-type", "application/json".parse().unwrap()); - headers + )? }; - let client = reqwest::Client::builder() - .redirect(reqwest::redirect::Policy::none()) - .build() - .expect("Invalid reqwest Client configuration"); - let res = client - .post(&endp.url) - .headers(headers) - .timeout(Duration::from_secs(cfg.worker_request_timeout as u64)) - .body(body) - .send() - .await; - - let msg_dest = ctx!( - messagedestination::Entity::secure_find_by_msg(msg_task.msg_id.clone()) - .filter(messagedestination::Column::EndpId.eq(endp.id.clone())) - .one(db) - .await - )? - .ok_or_else(|| err_generic!("Msg dest not found {} {}", msg_task.msg_id, endp.id))?; + Ok(IncompleteDispatch::Pending(PendingDispatch { + method: http::Method::POST, + url: endp.url.clone(), + headers, + payload: payload.to_owned(), + request_timeout: cfg.worker_request_timeout as _, + created_at: attempt_created_at, + })) +} - if (msg_dest.status != MessageStatus::Pending && msg_dest.status != MessageStatus::Sending) - && (msg_task.trigger_type != MessageAttemptTriggerType::Manual) - { - // TODO: it happens when this message destination is "resent". This leads to 2 queue tasks with the same message destination - tracing::warn!( - "MessageDestination {} is not pending (it's {:?}).", - msg_dest.id, - msg_dest.status - ); - return Ok(()); - } +#[tracing::instrument(skip_all)] +async fn make_http_call( + DispatchContext { msg_task, endp, .. }: DispatchContext<'_>, + PendingDispatch { + method, + url, + headers, + payload, + request_timeout, + created_at, + }: PendingDispatch, + msg_dest: &messagedestination::Model, + whitelist_subnets: &Option>>, +) -> Result { + let client = WebhookClient::new( + whitelist_subnets.clone(), + Some(Arc::new(vec!["backend".to_owned()])), + ); + let req = RequestBuilder::new() + .method(method) + .uri_str(&url) + .map_err(|_| err_validation!("URL is invalid".to_owned()))? 
+ .headers(headers) + .body(payload.into(), HeaderValue::from_static("application/json")) + .version(Version::HTTP_11) + .timeout(Duration::from_secs(request_timeout)) + .build() + .map_err(|e| err_generic!(e))?; let attempt = messageattempt::ActiveModel { // Set both ID and created_at to the same timestamp - id: Set(MessageAttemptId::new(now.into(), None)), - created_at: Set(now.into()), + id: Set(MessageAttemptId::new(created_at.into(), None)), + created_at: Set(created_at.into()), msg_id: Set(msg_task.msg_id.clone()), endp_id: Set(endp.id.clone()), msg_dest_id: Set(msg_dest.id.clone()), @@ -324,7 +360,8 @@ async fn dispatch( trigger_type: Set(msg_task.trigger_type), ..Default::default() }; - let attempt = match res { + + match client.execute(req).await { Ok(res) => { let status_code = res.status().as_u16() as i16; let status = if res.status().is_success() { @@ -332,206 +369,343 @@ async fn dispatch( } else { MessageStatus::Fail }; - let http_error = res.error_for_status_ref().err(); - let attempt = match res.bytes().await { - Ok(bytes) => { - let body = bytes_to_string(bytes); + let http_error = if !res.status().is_success() { + Some(WebhookClientError::FailureStatus(res.status())) + } else { + None + }; - messageattempt::ActiveModel { - response_status_code: Set(status_code), - response: Set(body), - status: Set(status), - ..attempt - } + let body = match hyper::body::to_bytes(res.into_body()).await { + Ok(bytes) if bytes.len() > RESPONSE_MAX_SIZE => { + bytes_to_string(bytes.slice(..RESPONSE_MAX_SIZE)) } + Ok(bytes) => bytes_to_string(bytes), + Err(err) => format!("Error reading response body: {err}"), + }; - Err(err) => { - tracing::warn!("Error reading response body: {}", err); - messageattempt::ActiveModel { - response_status_code: Set(status_code), - response: Set(format!("failed to read response body: {err}")), - status: Set(status), - ..attempt - } - } + let attempt = messageattempt::ActiveModel { + response_status_code: Set(status_code), + response: Set(body), + status: Set(status), + ..attempt }; match http_error { - Some(err) => Err((attempt, err)), - None => Ok(attempt), + Some(err) => Ok(CompletedDispatch::Failed(FailedDispatch( + attempt, + err_generic!(err.to_string()), + ))), + None => Ok(CompletedDispatch::Successful(SuccessfulDispatch(attempt))), } } - - Err(err) => { - let attempt = messageattempt::ActiveModel { + Err(err) => Ok(CompletedDispatch::Failed(FailedDispatch( + messageattempt::ActiveModel { response_status_code: Set(0), response: Set(err.to_string()), status: Set(MessageStatus::Fail), - ..attempt - }; - Err((attempt, err)) - } + }, + err.into(), + ))), + } +} + +#[tracing::instrument(skip_all, fields(response_code, msg_dest_id=msg_dest.id.0))] +async fn handle_successful_dispatch( + WorkerContext { cache, db, .. }: &WorkerContext<'_>, + DispatchContext { + org_id, + endp, + app_id, + .. 
+ }: DispatchContext<'_>, + SuccessfulDispatch(mut attempt): SuccessfulDispatch, + msg_dest: messagedestination::Model, +) -> Result<()> { + attempt.ended_at = Set(Some(Utc::now().into())); + let attempt = ctx!(attempt.insert(*db).await)?; + + let msg_dest = messagedestination::ActiveModel { + status: Set(MessageStatus::Success), + next_attempt: Set(None), + ..msg_dest.into() }; + let _msg_dest = ctx!(msg_dest.update(*db).await)?; - match attempt { - Ok(attempt) => { - let _attempt = ctx!(attempt.insert(db).await)?; + process_endpoint_success(cache, app_id, org_id, endp).await?; - let msg_dest = messagedestination::ActiveModel { - status: Set(MessageStatus::Success), - next_attempt: Set(None), - ..msg_dest.into() - }; - let msg_dest = ctx!(msg_dest.update(db).await)?; + tracing::Span::current().record("response_code", attempt.response_status_code); + tracing::info!("Webhook success."); - process_success_cache(cache, &msg_task.app_id, &msg_task.endpoint_id).await?; + Ok(()) +} - tracing::trace!("Worker success: {} {}", &msg_dest.id, &endp.id,); - } - Err((attempt, err)) => { - let attempt = ctx!(attempt.insert(db).await)?; - - let attempt_count = msg_task.attempt_count as usize; - if msg_task.trigger_type == MessageAttemptTriggerType::Manual { - tracing::debug!("Manual retry failed"); - } else if attempt_count < cfg.retry_schedule.len() { - tracing::debug!( - "Worker failure retrying for attempt {}: {} {} {}", - attempt_count, - err, - &msg_dest.id, - &endp.id - ); +fn calculate_retry_delay(duration: Duration, err: Error) -> Duration { + let duration = if matches!(err.typ, ErrorType::Timeout(_)) + || matches!(err.typ, ErrorType::Http(HttpError { status, .. }) if status == StatusCode::TOO_MANY_REQUESTS) + { + std::cmp::max(duration, Duration::from_secs(OVERLOAD_PENALTY_SECS)) + } else { + duration + }; + // Apply jitter with a maximum variation of JITTER_DELTA + rand::thread_rng() + .gen_range(duration.mul_f32(1.0 - JITTER_DELTA)..=duration.mul_f32(1.0 + JITTER_DELTA)) +} - let duration = cfg.retry_schedule[attempt_count]; +#[tracing::instrument(skip_all, fields(response_code, msg_dest_id=msg_dest.id.0))] +async fn handle_failed_dispatch( + WorkerContext { + db, + cache, + op_webhook_sender, + cfg, + queue_tx, + .. + }: &WorkerContext<'_>, + DispatchContext { + org_id, + app_id, + app_uid, + msg_uid, + endp, + msg_task, + .. 
+ }: DispatchContext<'_>, + FailedDispatch(mut attempt, err): FailedDispatch, + msg_dest: messagedestination::Model, +) -> Result<()> { + attempt.ended_at = Set(Some(Utc::now().into())); + let attempt = ctx!(attempt.insert(*db).await)?; + + tracing::Span::current().record("response_code", attempt.response_status_code); + tracing::info!("Webhook failure."); + + let retry_schedule = &cfg.retry_schedule; + + let attempt_count = msg_task.attempt_count as usize; + if msg_task.trigger_type == MessageAttemptTriggerType::Manual { + tracing::debug!("Manual retry failed"); + Ok(()) + } else if attempt_count < retry_schedule.len() { + tracing::debug!( + "Worker failure retrying for attempt {}: {} {} {}", + attempt_count, + err, + &msg_dest.id, + &endp.id + ); - // Apply jitter with a maximum variation of JITTER_DELTA - let duration = rand::thread_rng().gen_range( - duration.mul_f32(1.0 - JITTER_DELTA)..duration.mul_f32(1.0 + JITTER_DELTA), + let retry_delay = calculate_retry_delay(retry_schedule[attempt_count], err); + let next_attempt_time = + Utc::now() + chrono::Duration::from_std(retry_delay).expect("Error parsing duration"); + let msg_dest = messagedestination::ActiveModel { + next_attempt: Set(Some(next_attempt_time.into())), + ..msg_dest.into() + }; + let _msg_dest = ctx!(msg_dest.update(*db).await)?; + + if attempt_count == (OP_WEBHOOKS_SEND_FAILING_EVENT_AFTER - 1) { + if let Err(e) = op_webhook_sender + .send_operational_webhook( + org_id, + OperationalWebhook::MessageAttemptFailing(MessageAttemptEvent { + app_id: app_id.clone(), + app_uid: app_uid.cloned(), + endpoint_id: msg_task.endpoint_id.clone(), + msg_id: msg_task.msg_id.clone(), + msg_event_id: msg_uid.cloned(), + last_attempt: attempt.into(), + }), + ) + .await + { + tracing::error!( + "Failed sending MessageAttemptFailing Operational Webhook: {}", + e ); + } + } + queue_tx + .send( + QueueTask::MessageV1(MessageTask { + attempt_count: msg_task.attempt_count + 1, + ..msg_task.clone() + }), + Some(retry_delay), + ) + .await?; - let msg_dest = messagedestination::ActiveModel { - next_attempt: Set(Some( - (Utc::now() - + chrono::Duration::from_std(duration) - .expect("Error parsing duration")) - .into(), - )), - ..msg_dest.into() - }; - let _msg_dest = ctx!(msg_dest.update(db).await)?; - - if attempt_count == OP_WEBHOOKS_SEND_FAILING_EVENT_AFTER { - op_webhook_sender - .send_operational_webhook( - org_id, - OperationalWebhook::MessageAttemptFailing(MessageAttemptEvent { - app_id: msg_task.app_id.clone(), - app_uid: app_uid.cloned(), - endpoint_id: msg_task.endpoint_id.clone(), - msg_id: msg_task.msg_id.clone(), - msg_event_id: msg_uid.cloned(), - last_attempt: attempt.into(), - }), - ) - .await?; - } + Ok(()) + } else { + tracing::debug!( + "Worker failure attempts exhausted: {} {} {}", + err, + &msg_dest.id, + &endp.id + ); - queue_tx - .send( - QueueTask::MessageV1(MessageTask { - attempt_count: msg_task.attempt_count + 1, - ..msg_task - }), - Some(duration), + let msg_dest = messagedestination::ActiveModel { + status: Set(MessageStatus::Fail), + next_attempt: Set(None), + ..msg_dest.into() + }; + let _msg_dest = ctx!(msg_dest.update(*db).await)?; + + // Send common operational webhook + op_webhook_sender + .send_operational_webhook( + org_id, + OperationalWebhook::MessageAttemptExhausted(MessageAttemptEvent { + app_id: app_id.clone(), + app_uid: app_uid.cloned(), + endpoint_id: msg_task.endpoint_id.clone(), + msg_id: msg_task.msg_id.clone(), + msg_event_id: msg_uid.cloned(), + last_attempt: attempt.into(), + }), + ) + .await?; + + 
match process_endpoint_failure( + cache, + app_id, + org_id, + endp, + cfg.endpoint_failure_disable_after, + ) + .await? + { + None => Ok(()), + + Some(EndpointDisableInfo { first_failure_at }) => { + let endp = ctx!( + endpoint::Entity::secure_find_by_id( + msg_task.app_id.clone(), + msg_task.endpoint_id.clone(), ) - .await?; - } else { - tracing::debug!( - "Worker failure attempts exhausted: {} {} {}", - err, - &msg_dest.id, - &endp.id - ); - let msg_dest = messagedestination::ActiveModel { - status: Set(MessageStatus::Fail), - next_attempt: Set(None), - ..msg_dest.into() + .one(*db) + .await + )? + .ok_or_else(|| { + err_generic!("Endpoint not found {} {}", app_id, &msg_task.endpoint_id) + })?; + + let endp = endpoint::ActiveModel { + disabled: Set(true), + first_failure_at: Set(Some(first_failure_at.into())), + ..endp.into() }; - let _msg_dest = ctx!(msg_dest.update(db).await)?; + let _endp = ctx!(endp.update(*db).await)?; - // Send common operational webhook + // Send operational webhooks op_webhook_sender .send_operational_webhook( org_id, - OperationalWebhook::MessageAttemptExhausted(MessageAttemptEvent { - app_id: msg_task.app_id.clone(), + OperationalWebhook::EndpointDisabled(EndpointDisabledEvent { + app_id: app_id.clone(), app_uid: app_uid.cloned(), endpoint_id: msg_task.endpoint_id.clone(), - msg_id: msg_task.msg_id, - msg_event_id: msg_uid.cloned(), - last_attempt: attempt.into(), + // TODO: + endpoint_uid: None, + fail_since: first_failure_at, }), ) - .await?; - - match process_failure_cache( - cache, - &msg_task.app_id, - &msg_task.endpoint_id, - cfg.endpoint_failure_disable_after, - ) - .await? - { - None => {} - - Some(EndpointDisableInfo { first_failure_at }) => { - // Send operational webhooks - op_webhook_sender - .send_operational_webhook( - org_id, - OperationalWebhook::EndpointDisabled(EndpointDisabledEvent { - app_id: msg_task.app_id.clone(), - app_uid: app_uid.cloned(), - endpoint_id: msg_task.endpoint_id.clone(), - // TODO: - endpoint_uid: None, - fail_since: first_failure_at, - }), - ) - .await?; - - // Disable endpoint in DB - let endp = ctx!( - endpoint::Entity::secure_find_by_id( - msg_task.app_id.clone(), - msg_task.endpoint_id.clone(), - ) - .one(db) - .await - )? - .ok_or_else(|| { - err_generic!( - "Endpoint not found {} {}", - &msg_task.app_id, - &msg_task.endpoint_id - ) - })?; - - let endp = endpoint::ActiveModel { - disabled: Set(true), - first_failure_at: Set(Some(first_failure_at.into())), - ..endp.into() - }; - let _endp = ctx!(endp.update(db).await)?; - } - } + .await } } } - Ok(()) +} + +#[derive(Clone)] +struct DispatchContext<'a> { + msg_task: &'a MessageTask, + payload: &'a str, + endp: &'a CreateMessageEndpoint, + org_id: &'a OrganizationId, + app_id: &'a ApplicationId, + app_uid: Option<&'a ApplicationUid>, + msg_uid: Option<&'a MessageUid>, +} + +/// Dispatches one webhook +#[tracing::instrument( + skip_all, + level = "error", + fields( + endp_id = msg_task.endpoint_id.0.as_str(), + ) +)] +async fn dispatch_message_task( + worker_context: &WorkerContext<'_>, + msg: &message::Model, + app: &CreateMessageApp, + msg_task: MessageTask, + payload: &str, + endp: CreateMessageEndpoint, + msg_dest: Option, +) -> Result<()> { + let WorkerContext { cfg, db, .. 
+    let WorkerContext { cfg, db, .. } = worker_context;
+
+    tracing::trace!("Dispatch start");
+
+    let msg_dest = if let Some(msg_dest) = msg_dest {
+        msg_dest
+    } else {
+        ctx!(
+            messagedestination::Entity::secure_find_by_msg(msg_task.msg_id.clone())
+                .filter(messagedestination::Column::EndpId.eq(endp.id.clone()))
+                .one(*db)
+                .await
+        )?
+        .ok_or_else(|| err_generic!("Msg dest not found {} {}", msg_task.msg_id, endp.id))?
+    };
+
+    if (msg_dest.status != MessageStatus::Pending && msg_dest.status != MessageStatus::Sending)
+        && (msg_task.trigger_type != MessageAttemptTriggerType::Manual)
+    {
+        // TODO: it happens when this message destination is "resent". This leads to 2 queue tasks with the same message destination
+        tracing::warn!(
+            "MessageDestination {} is not pending (it's {:?}).",
+            msg_dest.id,
+            msg_dest.status
+        );
+        return Ok(());
+    }
+
+    let dispatch_context = DispatchContext {
+        msg_task: &msg_task,
+        payload,
+        endp: &endp,
+        org_id: &app.org_id,
+        app_id: &app.id,
+        app_uid: app.uid.as_ref(),
+        msg_uid: msg.uid.as_ref(),
+    };
+
+    let dispatch = prepare_dispatch(worker_context, dispatch_context.clone()).await?;
+    let completed = match dispatch {
+        IncompleteDispatch::Pending(pending) => {
+            make_http_call(
+                dispatch_context.clone(),
+                pending,
+                &msg_dest,
+                &cfg.whitelist_subnets,
+            )
+            .await?
+        }
+        IncompleteDispatch::Failed(failed) => CompletedDispatch::Failed(failed),
+    };
+
+    match completed {
+        CompletedDispatch::Successful(success) => {
+            handle_successful_dispatch(worker_context, dispatch_context, success, msg_dest).await
+        }
+        CompletedDispatch::Failed(failed) => {
+            handle_failed_dispatch(worker_context, dispatch_context, failed, msg_dest).await
+        }
+    }
 }
 
 fn bytes_to_string(bytes: bytes::Bytes) -> String {
@@ -542,15 +716,33 @@ fn bytes_to_string(bytes: bytes::Bytes) -> String {
 }
 
 /// Manages preparation and execution of a QueueTask type
-#[tracing::instrument(skip_all, fields(task_id = worker_context.task_id), level = "error")]
-async fn process_task(worker_context: WorkerContext<'_>, queue_task: Arc<QueueTask>) -> Result<()> {
+#[tracing::instrument(skip_all, level = "error", fields(msg_id, app_id, org_id, instance_id, task_type=queue_task.to_string()))]
+async fn process_queue_task(
+    worker_context: WorkerContext<'_>,
+    queue_task: QueueTask,
+) -> Result<()> {
+    process_queue_task_inner(worker_context, queue_task)
+        .await
+        .map_err(|e| {
+            tracing::error!("{e}");
+            e
+        })
+}
+
+/// Manages preparation and execution of a QueueTask type
+async fn process_queue_task_inner(
+    worker_context: WorkerContext<'_>,
+    queue_task: QueueTask,
+) -> Result<()> {
+    let WorkerContext { db, cache, .. }: WorkerContext<'_> = worker_context;
-    if *queue_task == QueueTask::HealthCheck {
+    if queue_task == QueueTask::HealthCheck {
         return Ok(());
     }
 
-    let (msg_id, trigger_type) = match &*queue_task {
+    let span = tracing::Span::current();
+
+    let (msg_id, trigger_type) = match &queue_task {
         QueueTask::MessageBatch(MessageTaskBatch {
             msg_id,
             trigger_type,
@@ -565,12 +757,21 @@ async fn process_task(worker_context: WorkerContext<'_>, queue_task: Arc<QueueTask>) -> Result<()> {
         _ => unreachable!(),
     };
 
+    span.record("msg_id", &msg_id.0);
+
     let msg = ctx!(message::Entity::find_by_id(msg_id.clone()).one(db).await)?
-        .ok_or_else(|| err_generic!("Unexpected: message doesn't exist {}", msg_id,))?;
-    let payload = msg.payload.as_ref().expect("Message payload is NULL");
+        .ok_or_else(|| err_generic!("Unexpected: message doesn't exist"))?;
+    let payload = msg
+        .payload
+        .as_ref()
+        .and_then(|value| serde_json::to_string(value).ok())
+        .ok_or_else(|| err_generic!("Message payload is NULL"))?;
 
-    let create_message_app = CreateMessageApp::layered_fetch(
-        cache.clone(),
+    span.record("app_id", &msg.app_id.0);
+    span.record("org_id", &msg.org_id.0);
+
+    let create_message_app = match CreateMessageApp::layered_fetch(
+        cache,
         db,
         None,
         msg.org_id.clone(),
@@ -578,14 +779,18 @@ async fn process_task(worker_context: WorkerContext<'_>, queue_task: Arc<QueueTask>) -> Result<()> {
+        Some(create_message_app) => create_message_app,
+        None => {
+            tracing::info!("Application doesn't exist: {}", &msg.app_id);
+            return Ok(());
+        }
+    };
 
     let endpoints: Vec<CreateMessageEndpoint> = create_message_app
         .filtered_endpoints(*trigger_type, &msg.event_type, msg.channels.as_ref())
         .iter()
-        .filter(|endpoint| match &*queue_task {
+        .filter(|endpoint| match &queue_task {
             QueueTask::HealthCheck => unreachable!(),
             QueueTask::MessageV1(task) => task.endpoint_id == endpoint.id,
             QueueTask::MessageBatch(_) => true,
@@ -593,87 +798,167 @@ async fn process_task(worker_context: WorkerContext<'_>, queue_task: Arc<QueueTask>) -> Result<()> {
-    let futures: Vec<_> = endpoints
-        .into_iter()
-        .map(|endpoint| {
-            let task = match &*queue_task {
-                QueueTask::MessageV1(task) => task.clone(),
-                QueueTask::MessageBatch(MessageTaskBatch {
-                    msg_id,
-                    app_id,
-                    trigger_type,
-                    ..
-                }) => MessageTask {
-                    msg_id: msg_id.clone(),
-                    app_id: app_id.clone(),
-                    endpoint_id: endpoint.id.clone(),
-                    attempt_count: 0,
-                    trigger_type: *trigger_type,
-                },
+    let futures: Vec<_> = match &queue_task {
+        QueueTask::HealthCheck => unreachable!(),
-                QueueTask::HealthCheck => unreachable!(),
+        QueueTask::MessageV1(task) => {
+            let endpoint = match endpoints.into_iter().next() {
+                Some(ep) => ep,
+                None => {
+                    return Ok(());
+                }
             };
-            let body = serde_json::to_string(&payload).expect("Error parsing message body");
-
-            dispatch(
-                worker_context,
-                task,
-                DispatchExtraIds {
-                    org_id,
-                    app_uid: app_uid.as_ref(),
-                    msg_uid: msg_uid.as_ref(),
-                },
-                body,
+            let destination = ctx!(
+                messagedestination::Entity::secure_find_by_msg(task.msg_id.clone())
+                    .filter(messagedestination::Column::EndpId.eq(endpoint.id.clone()))
+                    .one(db)
+                    .await
+            )?
+            .ok_or_else(|| {
+                err_generic!(format!(
+                    "MessageDestination not found for message {}",
+                    &task.msg_id
+                ))
+            })?;
+
+            vec![dispatch_message_task(
+                &worker_context,
+                &msg,
+                &create_message_app,
+                task.clone(),
+                &payload,
                 endpoint,
-            )
-        })
-        .collect();
+                Some(destination),
+            )]
+        }
+
+        QueueTask::MessageBatch(task) => {
+            let destinations: Vec<_> = endpoints
+                .iter()
+                .map(|endpoint| messagedestination::ActiveModel {
+                    msg_id: Set(msg.id.clone()),
+                    endp_id: Set(endpoint.id.clone()),
+                    next_attempt: Set(Some(Utc::now().into())),
+                    status: Set(MessageStatus::Sending),
+                    ..Default::default()
+                })
+                .collect();
+
+            ctx!(
+                messagedestination::Entity::insert_many(destinations.clone())
+                    .exec(db)
+                    .await
+            )?;
+
+            endpoints
+                .into_iter()
+                .zip(destinations)
+                .map(|(endpoint, destination)| {
+                    let task = MessageTask {
+                        msg_id: msg_id.clone(),
+                        app_id: task.app_id.clone(),
+                        endpoint_id: endpoint.id.clone(),
+                        attempt_count: 0,
+                        trigger_type: *trigger_type,
+                    };
+
+                    dispatch_message_task(
+                        &worker_context,
+                        &msg,
+                        &create_message_app,
+                        task,
+                        &payload,
+                        endpoint,
+                        destination.try_into_model().ok(),
+                    )
+                })
+                .collect()
+        }
+    };
 
     let join = future::join_all(futures).await;
 
     let errs: Vec<_> = join.iter().filter(|x| x.is_err()).collect();
     if !errs.is_empty() {
-        return Err(err_generic!(
-            "Some dispatches failed unexpectedly: {:?}",
-            errs
-        ));
+        return Err(err_generic!(format!(
+            "Some dispatches failed unexpectedly: {errs:?}"
+        )));
     }
 
     Ok(())
 }
 
+lazy_static! {
+    pub static ref LAST_QUEUE_POLL: AtomicU64 = get_unix_timestamp().into();
+}
+
+async fn update_last_poll_time() {
+    LAST_QUEUE_POLL.swap(get_unix_timestamp(), Ordering::Relaxed);
+}
+
 /// Listens on the message queue for new tasks
-pub async fn worker_loop(
+#[allow(clippy::too_many_arguments)]
+pub async fn queue_handler(
     cfg: &Configuration,
-    pool: &DatabaseConnection,
     cache: Cache,
+    db: DatabaseConnection,
     queue_tx: TaskQueueProducer,
     mut queue_rx: TaskQueueConsumer,
     op_webhook_sender: OperationalWebhookSender,
 ) -> Result<()> {
+    static NUM_WORKERS: AtomicUsize = AtomicUsize::new(0);
+
+    let task_limit = cfg.worker_max_tasks;
+    if task_limit == 0 {
+        tracing::info!("Worker concurrent task limit: unlimited");
+    } else {
+        tracing::info!("Worker concurrent task limit: {}", task_limit);
+    }
+
+    tokio::spawn(
+        async move {
+            let mut interval = tokio::time::interval(Duration::from_millis(500));
+            loop {
+                interval.tick().await;
+                let num_workers = NUM_WORKERS.load(Ordering::Relaxed);
+                if num_workers > 0 {
+                    tracing::info!("{} active workers", num_workers);
+                }
+            }
        }
        .instrument(tracing::error_span!(
+            "worker_monitor",
+            instance_id = tracing::field::Empty
+        )),
+    );
+
     loop {
+        if task_limit > 0 {
+            let num_workers = NUM_WORKERS.load(Ordering::Relaxed);
+            if num_workers > task_limit.into() {
+                tokio::time::sleep(Duration::from_millis(100)).await;
+                continue;
+            }
+        }
+
         if crate::SHUTTING_DOWN.load(Ordering::SeqCst) {
+            tokio::join!(async move {
+                let mut interval = tokio::time::interval(Duration::from_millis(500));
+                loop {
+                    interval.tick().await;
+                    let num_workers = NUM_WORKERS.load(Ordering::Relaxed);
+                    if num_workers > 0 {
+                        tracing::info!(
+                            "{} active workers, waiting to shut down worker.",
+                            num_workers
+                        );
+                    } else {
+                        tracing::info!("No active workers, shutting down worker.");
+                        break;
+                    }
+                }
+            });
             break;
         }
 
@@ -681,42 +966,52 @@ pub async fn worker_loop(
             Ok(batch) => {
                 for delivery in batch {
                     let cfg = cfg.clone();
-                    let pool = pool.clone();
                     let cache = cache.clone();
+                    let db = db.clone();
                     let queue_tx = queue_tx.clone();
                     let queue_task = delivery.task.clone();
                     let op_webhook_sender = op_webhook_sender.clone();
 
                     tokio::spawn(async move {
-                        let task_id = KsuidMs::new(None, None).to_string();
+                        NUM_WORKERS.fetch_add(1, Ordering::Relaxed);
 
                         let worker_context = WorkerContext {
-                            task_id: &task_id,
                             cfg: &cfg,
-                            db: &pool,
+                            db: &db,
                             cache: &cache,
-                            queue_tx: &queue_tx,
                             op_webhook_sender: &op_webhook_sender,
+                            queue_tx: &queue_tx,
                         };
 
-                        if let Err(err) = process_task(worker_context, queue_task).await {
-                            tracing::error!("Error executing task: {}", err);
-                            queue_tx
-                                .nack(delivery)
-                                .await
-                                .expect("Error sending 'nack' to Redis after task execution error");
-                        } else {
-                            queue_tx.ack(delivery).await.expect(
-                                "Error sending 'ack' to Redis after successful task execution",
+                        let queue_task =
+                            Arc::try_unwrap(queue_task).unwrap_or_else(|arc| (*arc).clone());
+                        if process_queue_task(worker_context, queue_task)
+                            .await
+                            .is_err()
+                        {
+                            if let Err(err) = queue_tx.nack(delivery).await {
+                                tracing::error!(
+                                    "Error sending 'nack' to Redis after task execution error: {}",
+                                    err
+                                );
+                            }
+                        } else if let Err(err) = queue_tx.ack(delivery).await {
+                            tracing::error!(
+                                "Error sending 'ack' to Redis after successful task execution: {}",
+                                err
                             );
                         }
+
+                        NUM_WORKERS.fetch_sub(1, Ordering::Relaxed);
                     });
                 }
             }
             Err(err) => {
-                tracing::error!("Error receiving task: {}", err);
-                sleep(Duration::from_millis(10)).await;
+                tracing::error!("Error receiving task: {:?}", err);
+                sleep(tokio::time::Duration::from_millis(10)).await;
             }
         }
+
+        update_last_poll_time().await;
     }
 
     Ok(())
}
@@ -741,7 +1036,7 @@ mod tests {
 
     /// Utility function that returns the default set of headers before configurable header are
     /// accounted for
-    fn mock_headers() -> (HeaderMap, MessageId) {
+    fn mock_headers() -> (CaseSensitiveHeaderMap, MessageId) {
         let id = MessageId::new(None, None);
 
         let signatures = sign_msg(
@@ -760,23 +1055,21 @@ mod tests {
                 WHITELABEL_HEADERS,
                 None,
                 ENDPOINT_URL,
-            ),
+            )
+            .unwrap(),
             id,
         )
     }
 
-    // Tests configurable headers with a valid and an invalid header. The valid header pair should
-    // be included, while the invalid pair should be skipped.
     #[test]
-    fn test_generate_msg_headers_with_custom_headers() {
+    fn test_generate_msg_headers() {
         // The headers to be given to [`generate_msg_headers`]
         let mut headers = HashMap::new();
         headers.insert("test_key".to_owned(), "value".to_owned());
-        headers.insert("invälid_key".to_owned(), "value".to_owned());
 
         // The invalid key should be skipped over so it is not included in the expected
         let (mut expected, id) = mock_headers();
-        let _ = expected.insert("test_key", "value".parse().unwrap());
+        let _ = expected.insert("test_key".to_owned(), "value".parse().unwrap());
 
         let signatures = sign_msg(
             &Encryption::new_noop(),
@@ -793,7 +1086,8 @@ mod tests {
             WHITELABEL_HEADERS,
             Some(&EndpointHeaders(headers)),
             ENDPOINT_URL,
-        );
+        )
+        .unwrap();
 
         assert_eq!(expected, actual);
     }
@@ -828,7 +1122,8 @@ mod tests {
             WHITELABEL_HEADERS,
             None,
            ENDPOINT_URL,
-        );
+        )
+        .unwrap();
 
         assert_eq!(
             actual.get("svix-signature").unwrap(),
@@ -836,7 +1131,7 @@ mod tests {
         );
     }
 
-    // Tests asemmtric signing keys
+    // Tests asymmetric signing keys
     #[test]
     fn test_asymmetric_key_signing() {
         let timestamp = 1614265330;
diff --git a/server/svix-server/tests/integ_webhook_http_client.rs b/server/svix-server/tests/integ_webhook_http_client.rs
new file mode 100644
index 000000000..afcc70db3
--- /dev/null
+++ b/server/svix-server/tests/integ_webhook_http_client.rs
@@ -0,0 +1,156 @@
+use std::{net::TcpListener, sync::Arc};
+
+use axum::extract::State;
+use http::{header::USER_AGENT, HeaderValue, Request, StatusCode, Version};
+use hyper::Body;
+use serde::{Deserialize, Serialize};
+use tokio::sync::mpsc;
+
+use svix_server::core::webhook_http_client::{Error, RequestBuilder, WebhookClient};
+
+pub struct TestReceiver {
+    pub uri: String,
+    pub jh: tokio::task::JoinHandle<()>,
+    pub req_recv: mpsc::Receiver<Request<Body>>,
+}
+
+#[derive(Clone)]
+struct TestAppState {
+    tx: mpsc::Sender<Request<Body>>,
+    response_status_code: StatusCode,
+}
+
+impl TestReceiver {
+    pub fn start(resp_code: StatusCode) -> Self {
+        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
+        let uri = format!("http://{}/", listener.local_addr().unwrap());
+
+        let (tx, req_recv) = mpsc::channel(32);
+
+        let routes = axum::Router::new()
+            .route("/", axum::routing::any(test_receiver_route))
+            .with_state(TestAppState {
+                tx,
+                response_status_code: resp_code,
+            })
+            .into_make_service();
+
+        let jh = tokio::spawn(async move {
+            axum::Server::from_tcp(listener)
+                .unwrap()
+                .serve(routes)
+                .await
+                .unwrap();
+        });
+
+        TestReceiver { uri, jh, req_recv }
+    }
+}
+
+async fn test_receiver_route(
+    State(TestAppState {
+        ref tx,
+        response_status_code,
+    }): State<TestAppState>,
+    req: Request<Body>,
+) -> axum::http::StatusCode {
+    tx.send(req).await.unwrap();
+    response_status_code
+}
+
+#[derive(Deserialize, Serialize)]
+pub struct TestSerializable {
+    test: String,
+}
+
+#[ignore]
+#[tokio::test]
+async fn test_client_basic_operation() {
+    // Compares output to `reqwest`.
+
+    let our_client = WebhookClient::new(Some(Arc::new(vec!["127.0.0.1/0".parse().unwrap()])), None);
+    let reqwest_client = reqwest::Client::builder()
+        .redirect(reqwest::redirect::Policy::none())
+        .build()
+        .expect("Invalid reqwest Client configuration");
+
+    let mut receiver = TestReceiver::start(StatusCode::OK);
+
+    let our_req = RequestBuilder::new()
+        .uri_str(&receiver.uri)
+        .unwrap()
+        .json_body(TestSerializable {
+            test: "value".to_owned(),
+        })
+        .unwrap()
+        .version(Version::HTTP_11)
+        .build()
+        .unwrap();
+
+    let _resp = our_client.execute(our_req).await.unwrap();
+
+    let our_http_req = receiver.req_recv.recv().await.unwrap();
+
+    let _resp = reqwest_client
+        .post(&receiver.uri)
+        .header(
+            USER_AGENT,
+            HeaderValue::from_static(concat!("Svix-Webhooks/", env!("CARGO_PKG_VERSION"))),
+        )
+        .version(Version::HTTP_11)
+        .json(&TestSerializable {
+            test: "value".to_owned(),
+        })
+        .send()
+        .await
+        .unwrap();
+
+    let reqwest_http_req = receiver.req_recv.recv().await.unwrap();
+
+    assert_eq!(our_http_req.headers(), reqwest_http_req.headers());
+    assert_eq!(
+        hyper::body::to_bytes(our_http_req.into_body())
+            .await
+            .unwrap(),
+        hyper::body::to_bytes(reqwest_http_req.into_body())
+            .await
+            .unwrap()
+    );
+}
+
+#[tokio::test]
+async fn test_filtering() {
+    let our_client = WebhookClient::new(None, None);
+
+    let our_req = RequestBuilder::new()
+        .uri_str("http://127.0.0.1/")
+        .unwrap()
+        .json_body(TestSerializable {
+            test: "value".to_owned(),
+        })
+        .unwrap()
+        .version(Version::HTTP_11)
+        .build()
+        .unwrap();
+
+    assert!(matches!(
+        our_client.execute(our_req).await.unwrap_err(),
+        Error::BlockedIp
+    ));
+
+    let our_req = RequestBuilder::new()
+        .uri_str("http://localhost/")
+        .unwrap()
+        .json_body(TestSerializable {
+            test: "value".to_owned(),
+        })
+        .unwrap()
+        .version(Version::HTTP_11)
+        .build()
+        .unwrap();
+
+    assert!(matches!(
+        our_client.execute(our_req).await.unwrap_err(),
+        Error::BlockedIp
+    ));
+}
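
Note on retry jitter: the worker diff above replaces the inline gen_range jitter with a call to calculate_retry_delay(retry_schedule[attempt_count], err). The helper itself is outside this hunk, so the following is only a minimal sketch of the jitter idea it replaces; the constant JITTER_DELTA value, the function name jittered_delay, and the omission of the error argument are assumptions for illustration, not the patch's actual implementation.

// Illustrative sketch only; `jittered_delay` and `JITTER_DELTA = 0.2` are assumed here.
use std::time::Duration;

use rand::Rng;

const JITTER_DELTA: f32 = 0.2; // assumed spread, not the server's configured value

fn jittered_delay(base: Duration) -> Duration {
    // Sample uniformly from [base * (1 - delta), base * (1 + delta)) so that
    // endpoints failing in bursts are not all retried at the same instant.
    rand::thread_rng()
        .gen_range(base.mul_f32(1.0 - JITTER_DELTA)..base.mul_f32(1.0 + JITTER_DELTA))
}

fn main() {
    let delay = jittered_delay(Duration::from_secs(5));
    println!("next retry in {delay:?}");
}
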
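Note on the soft task limit: queue_handler above tracks in-flight tasks with a NUM_WORKERS AtomicUsize, backs off when cfg.worker_max_tasks is exceeded, and refuses to shut down while tasks remain active. The standalone sketch below shows that pattern in isolation; the names ACTIVE_TASKS and TASK_LIMIT and the simulated workload are illustrative stand-ins, not the server's code or configuration.

// Minimal sketch of a soft concurrency limit with an atomic counter (assumed names).
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;

static ACTIVE_TASKS: AtomicUsize = AtomicUsize::new(0);
const TASK_LIMIT: usize = 500; // illustrative soft limit

#[tokio::main]
async fn main() {
    for i in 0..1000 {
        // Back off while the soft limit is exceeded instead of queuing unbounded work.
        while ACTIVE_TASKS.load(Ordering::Relaxed) >= TASK_LIMIT {
            tokio::time::sleep(Duration::from_millis(100)).await;
        }

        ACTIVE_TASKS.fetch_add(1, Ordering::Relaxed);
        tokio::spawn(async move {
            // Stand-in for processing one queue task.
            tokio::time::sleep(Duration::from_millis(10)).await;
            let _ = i;
            ACTIVE_TASKS.fetch_sub(1, Ordering::Relaxed);
        });
    }

    // Shutdown idea from the patch: wait until no tasks remain active before exiting.
    while ACTIVE_TASKS.load(Ordering::Relaxed) > 0 {
        tokio::time::sleep(Duration::from_millis(100)).await;
    }
}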