diff --git a/rust/openvasd/src/controller/context.rs b/rust/openvasd/src/controller/context.rs index 1b544aef5..8abafbdee 100644 --- a/rust/openvasd/src/controller/context.rs +++ b/rust/openvasd/src/controller/context.rs @@ -7,7 +7,7 @@ use std::sync::RwLock; use async_trait::async_trait; use storage::DefaultDispatcher; -use crate::{config, notus::NotusWrapper, response, scheduling}; +use crate::{config, notus::NotusWrapper, response, scheduling, tls::TlsConfig}; use models::scanner::{ Error, ScanDeleter, ScanResultFetcher, ScanResults, ScanStarter, ScanStopper, @@ -25,6 +25,7 @@ pub struct ContextBuilder { storage: DB, feed_config: Option, api_key: Option, + tls_config: Option, enable_get_scans: bool, marker: std::marker::PhantomData, response: response::Response, @@ -43,6 +44,7 @@ impl storage: crate::storage::inmemory::Storage::default(), feed_config: None, api_key: None, + tls_config: None, marker: std::marker::PhantomData, enable_get_scans: false, response: response::Response::default(), @@ -75,9 +77,12 @@ impl ContextBuilder { /// Sets the api key. pub fn api_key(mut self, api_key: impl Into>) -> Self { self.api_key = api_key.into(); - if self.api_key.is_some() { - self.response.add_authentication("x-api-key"); - } + self + } + + /// Set the TLS config + pub fn tls_config(mut self, tls_config: Option) -> Self { + self.tls_config = tls_config; self } @@ -106,6 +111,7 @@ impl ContextBuilder { storage: _, feed_config, api_key, + tls_config, enable_get_scans, marker, response, @@ -118,6 +124,7 @@ impl ContextBuilder { storage, feed_config, api_key, + tls_config, enable_get_scans, marker, response, @@ -141,6 +148,7 @@ impl ContextBuilder { let Self { feed_config, api_key, + tls_config, enable_get_scans, scanner: _, marker: _, @@ -156,6 +164,7 @@ impl ContextBuilder { feed_config, marker: std::marker::PhantomData, api_key, + tls_config, enable_get_scans, response, notus, @@ -166,7 +175,31 @@ impl ContextBuilder { } impl ContextBuilder> { - pub fn build(self) -> Context { + fn configure_authentication_methods(&mut self) { + let tls_config = if let Some(tls_config) = self.tls_config.take() { + if tls_config.has_clients && self.api_key.is_some() { + tracing::warn!("Client certificates and api key are configured. To disable the possibility to bypass client verification the API key is ignored."); + self.api_key = None; + } + Some(tls_config) + } else { + None + }; + self.tls_config = tls_config; + match (self.tls_config.is_some(), self.api_key.is_some()) { + (true, true) => unreachable!(), + (true, false) => self.response.add_authentication("mTLS"), + (false, true) => self.response.add_authentication("x-api-key"), + (false, false) => { + tracing::warn!( + "Neither mTLS nor an API key are set. /scans endpoint is unsecured." + ); + } + } + } + + pub fn build(mut self) -> Context { + self.configure_authentication_methods(); let scheduler = scheduling::Scheduler::new( self.scheduler_config.unwrap_or_default(), self.scanner.0, @@ -178,6 +211,7 @@ impl ContextBuilder> { feed_config: self.feed_config, abort: Default::default(), api_key: self.api_key, + tls_config: self.tls_config, enable_get_scans: self.enable_get_scans, notus: self.notus, mode: self.mode, @@ -196,6 +230,7 @@ pub struct Context { /// /// When none api key is set, no authentication is required. pub api_key: Option, + pub tls_config: Option, /// Whether to enable the GET /scans endpoint pub enable_get_scans: bool, pub mode: config::Mode, diff --git a/rust/openvasd/src/controller/mod.rs b/rust/openvasd/src/controller/mod.rs index 119589e1e..97b6207b7 100644 --- a/rust/openvasd/src/controller/mod.rs +++ b/rust/openvasd/src/controller/mod.rs @@ -12,10 +12,7 @@ use std::{ sync::{Arc, RwLock}, }; -use crate::{ - config, - tls::{self}, -}; +use crate::config; pub use context::{Context, ContextBuilder, NoOpScanner}; use hyper_util::rt::{TokioExecutor, TokioIo}; use tokio::net::TcpListener; @@ -59,6 +56,7 @@ fn retrieve_and_reset(id: Arc>) -> ClientIdentifier { *ci = ClientIdentifier::Unknown; cci } + pub async fn run<'a, S, DB>( mut ctx: Context, config: &config::Config, @@ -73,24 +71,11 @@ where + 'static, DB: crate::storage::Storage + std::marker::Send + 'static + std::marker::Sync, { - let tlsc = { - if let Some((c, conf, has_clients)) = tls::tls_config(config)? { - if has_clients && ctx.api_key.is_some() { - tracing::warn!("Client certificates and api key are configured. To disable the possibility to bypass client verification the API key is ignored."); - ctx.api_key = None; - } - Some((c, conf)) - } else { - None - } - }; - if tlsc.is_none() && ctx.api_key.is_none() { - tracing::warn!("Neither mTLS nor an API key are set. /scans endpoint is unsecured."); - } let addr = config.listener.address; let addr: SocketAddr = addr; let incoming = TcpListener::bind(&addr).await?; + let tls_config = ctx.tls_config.take(); let controller = std::sync::Arc::new(ctx); tracing::info!(?config.mode, "running in"); if config.mode == config::Mode::Service { @@ -98,18 +83,18 @@ where } tokio::spawn(crate::controller::feed::fetch(Arc::clone(&controller))); - if let Some((ci, conf)) = tlsc { + if let Some(tls_config) = tls_config { use hyper::server::conn::http2::Builder; tracing::info!("listening on https://{}", addr); - let config = Arc::new(conf); + let config = Arc::new(tls_config.config); let tls_acceptor = tokio_rustls::TlsAcceptor::from(config); loop { let (tcp_stream, _remote_addr) = incoming.accept().await?; let tls_acceptor = tls_acceptor.clone(); - let identifier = ci.clone(); + let identifier = tls_config.client_identifier.clone(); let ctx = controller.clone(); tokio::spawn(async move { let tls_stream = match tls_acceptor.accept(tcp_stream).await { diff --git a/rust/openvasd/src/main.rs b/rust/openvasd/src/main.rs index 7907c805a..c18effcf5 100644 --- a/rust/openvasd/src/main.rs +++ b/rust/openvasd/src/main.rs @@ -24,7 +24,7 @@ fn create_context( db: DB, sh: ScanHandler, config: &config::Config, -) -> controller::Context +) -> Result, Box> where ScanHandler: ScanStarter + ScanStopper @@ -46,15 +46,18 @@ where Err(e) => tracing::warn!("Notus Scanner disabled: {e}"), } - ctx_builder + let tls_config = tls::tls_config(config)?; + + Ok(ctx_builder .mode(config.mode.clone()) .scheduler_config(config.scheduler.clone()) .feed_config(config.feed.clone()) .scanner(sh) .api_key(config.endpoints.key.clone()) + .tls_config(tls_config) .enable_get_scans(config.endpoints.enable_get_scans) .storage(db) - .build() + .build()) } async fn run( @@ -92,7 +95,7 @@ where storage::redis::Storage::new(ic, config.storage.redis.url.clone(), feeds), scanner, config, - ); + )?; controller::run(ctx, config).await } config::StorageType::InMemory => { @@ -102,7 +105,7 @@ where storage::inmemory::Storage::new(crate::crypt::ChaCha20Crypt::default(), feeds), scanner, config, - ); + )?; controller::run(ctx, config).await } config::StorageType::FileSystem => { @@ -115,7 +118,7 @@ where storage::file::encrypted(&config.storage.fs.path, key, feeds)?, scanner, config, - ); + )?; controller::run(ctx, config).await } else { tracing::warn!( @@ -125,7 +128,7 @@ where storage::file::unencrypted(&config.storage.fs.path, feeds)?, scanner, config, - ); + )?; controller::run(ctx, config).await } } diff --git a/rust/openvasd/src/tls.rs b/rust/openvasd/src/tls.rs index 54f6b124e..76d51dc5e 100644 --- a/rust/openvasd/src/tls.rs +++ b/rust/openvasd/src/tls.rs @@ -165,9 +165,14 @@ pub fn config_to_tls_paths( ))) } -pub type TlsData = (Arc>, ServerConfig, bool); +#[derive(Debug)] +pub struct TlsConfig { + pub client_identifier: Arc>, + pub config: ServerConfig, + pub has_clients: bool, +} -pub fn tls_config(config: &crate::config::Config) -> Result, Error> { +pub fn tls_config(config: &crate::config::Config) -> Result, Error> { let (key, certs, clients) = match config_to_tls_paths(config)? { Some(x) => x, None => return Ok(None), @@ -191,7 +196,11 @@ pub fn tls_config(config: &crate::config::Config) -> Result, Err config.alpn_protocols = vec![b"h2".to_vec()]; - Ok(Some((client_identifier, config, !clients.is_empty()))) + Ok(Some(TlsConfig { + client_identifier, + config, + has_clients: !clients.is_empty(), + })) } else { match &config.tls.client_certs { Some(clicerts) => { @@ -222,7 +231,11 @@ pub fn tls_config(config: &crate::config::Config) -> Result, Err config.alpn_protocols = vec![b"h2".to_vec()]; - Ok(Some((client_identifier, config, !clients.is_empty()))) + Ok(Some(TlsConfig { + client_identifier, + config, + has_clients: !clients.is_empty(), + })) } }