Skip to content

Commit

Permalink
Fix: Add 'mTLS' to response header.
Browse files Browse the repository at this point in the history
  • Loading branch information
Tehforsch authored and nichtsfrei committed Jul 11, 2024
1 parent b6350b0 commit 155da37
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 37 deletions.
45 changes: 40 additions & 5 deletions rust/openvasd/src/controller/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use std::sync::RwLock;
use async_trait::async_trait;
use storage::DefaultDispatcher;

use crate::{config, notus::NotusWrapper, response, scheduling};
use crate::{config, notus::NotusWrapper, response, scheduling, tls::TlsConfig};

use models::scanner::{
Error, ScanDeleter, ScanResultFetcher, ScanResults, ScanStarter, ScanStopper,
Expand All @@ -25,6 +25,7 @@ pub struct ContextBuilder<S, DB, T> {
storage: DB,
feed_config: Option<crate::config::Feed>,
api_key: Option<String>,
tls_config: Option<TlsConfig>,
enable_get_scans: bool,
marker: std::marker::PhantomData<S>,
response: response::Response,
Expand All @@ -43,6 +44,7 @@ impl<S>
storage: crate::storage::inmemory::Storage::default(),
feed_config: None,
api_key: None,
tls_config: None,
marker: std::marker::PhantomData,
enable_get_scans: false,
response: response::Response::default(),
Expand Down Expand Up @@ -75,9 +77,12 @@ impl<S, DB, T> ContextBuilder<S, DB, T> {
/// Sets the api key.
pub fn api_key(mut self, api_key: impl Into<Option<String>>) -> Self {
self.api_key = api_key.into();
if self.api_key.is_some() {
self.response.add_authentication("x-api-key");
}
self
}

/// Set the TLS config
pub fn tls_config(mut self, tls_config: Option<TlsConfig>) -> Self {
self.tls_config = tls_config;
self
}

Expand Down Expand Up @@ -106,6 +111,7 @@ impl<S, DB, T> ContextBuilder<S, DB, T> {
storage: _,
feed_config,
api_key,
tls_config,
enable_get_scans,
marker,
response,
Expand All @@ -118,6 +124,7 @@ impl<S, DB, T> ContextBuilder<S, DB, T> {
storage,
feed_config,
api_key,
tls_config,
enable_get_scans,
marker,
response,
Expand All @@ -141,6 +148,7 @@ impl<S, DB> ContextBuilder<S, DB, NoScanner> {
let Self {
feed_config,
api_key,
tls_config,
enable_get_scans,
scanner: _,
marker: _,
Expand All @@ -156,6 +164,7 @@ impl<S, DB> ContextBuilder<S, DB, NoScanner> {
feed_config,
marker: std::marker::PhantomData,
api_key,
tls_config,
enable_get_scans,
response,
notus,
Expand All @@ -166,7 +175,31 @@ impl<S, DB> ContextBuilder<S, DB, NoScanner> {
}

impl<S, DB> ContextBuilder<S, DB, Scanner<S>> {
pub fn build(self) -> Context<S, DB> {
fn configure_authentication_methods(&mut self) {
let tls_config = if let Some(tls_config) = self.tls_config.take() {
if tls_config.has_clients && self.api_key.is_some() {
tracing::warn!("Client certificates and api key are configured. To disable the possibility to bypass client verification the API key is ignored.");
self.api_key = None;
}
Some(tls_config)
} else {
None
};
self.tls_config = tls_config;
match (self.tls_config.is_some(), self.api_key.is_some()) {
(true, true) => unreachable!(),
(true, false) => self.response.add_authentication("mTLS"),
(false, true) => self.response.add_authentication("x-api-key"),
(false, false) => {
tracing::warn!(
"Neither mTLS nor an API key are set. /scans endpoint is unsecured."
);
}
}
}

pub fn build(mut self) -> Context<S, DB> {
self.configure_authentication_methods();
let scheduler = scheduling::Scheduler::new(
self.scheduler_config.unwrap_or_default(),
self.scanner.0,
Expand All @@ -178,6 +211,7 @@ impl<S, DB> ContextBuilder<S, DB, Scanner<S>> {
feed_config: self.feed_config,
abort: Default::default(),
api_key: self.api_key,
tls_config: self.tls_config,
enable_get_scans: self.enable_get_scans,
notus: self.notus,
mode: self.mode,
Expand All @@ -196,6 +230,7 @@ pub struct Context<S, DB> {
///
/// When none api key is set, no authentication is required.
pub api_key: Option<String>,
pub tls_config: Option<TlsConfig>,
/// Whether to enable the GET /scans endpoint
pub enable_get_scans: bool,
pub mode: config::Mode,
Expand Down
27 changes: 6 additions & 21 deletions rust/openvasd/src/controller/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,7 @@ use std::{
sync::{Arc, RwLock},
};

use crate::{
config,
tls::{self},
};
use crate::config;
pub use context::{Context, ContextBuilder, NoOpScanner};
use hyper_util::rt::{TokioExecutor, TokioIo};
use tokio::net::TcpListener;
Expand Down Expand Up @@ -59,6 +56,7 @@ fn retrieve_and_reset(id: Arc<RwLock<ClientIdentifier>>) -> ClientIdentifier {
*ci = ClientIdentifier::Unknown;
cci
}

pub async fn run<'a, S, DB>(
mut ctx: Context<S, DB>,
config: &config::Config,
Expand All @@ -73,43 +71,30 @@ where
+ 'static,
DB: crate::storage::Storage + std::marker::Send + 'static + std::marker::Sync,
{
let tlsc = {
if let Some((c, conf, has_clients)) = tls::tls_config(config)? {
if has_clients && ctx.api_key.is_some() {
tracing::warn!("Client certificates and api key are configured. To disable the possibility to bypass client verification the API key is ignored.");
ctx.api_key = None;
}
Some((c, conf))
} else {
None
}
};
if tlsc.is_none() && ctx.api_key.is_none() {
tracing::warn!("Neither mTLS nor an API key are set. /scans endpoint is unsecured.");
}
let addr = config.listener.address;
let addr: SocketAddr = addr;
let incoming = TcpListener::bind(&addr).await?;

let tls_config = ctx.tls_config.take();
let controller = std::sync::Arc::new(ctx);
tracing::info!(?config.mode, "running in");
if config.mode == config::Mode::Service {
tokio::spawn(crate::controller::results::fetch(Arc::clone(&controller)));
}
tokio::spawn(crate::controller::feed::fetch(Arc::clone(&controller)));

if let Some((ci, conf)) = tlsc {
if let Some(tls_config) = tls_config {
use hyper::server::conn::http2::Builder;
tracing::info!("listening on https://{}", addr);

let config = Arc::new(conf);
let config = Arc::new(tls_config.config);
let tls_acceptor = tokio_rustls::TlsAcceptor::from(config);

loop {
let (tcp_stream, _remote_addr) = incoming.accept().await?;

let tls_acceptor = tls_acceptor.clone();
let identifier = ci.clone();
let identifier = tls_config.client_identifier.clone();
let ctx = controller.clone();
tokio::spawn(async move {
let tls_stream = match tls_acceptor.accept(tcp_stream).await {
Expand Down
17 changes: 10 additions & 7 deletions rust/openvasd/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ fn create_context<DB, ScanHandler>(
db: DB,
sh: ScanHandler,
config: &config::Config,
) -> controller::Context<ScanHandler, DB>
) -> Result<controller::Context<ScanHandler, DB>, Box<dyn std::error::Error + Send + Sync>>
where
ScanHandler: ScanStarter
+ ScanStopper
Expand All @@ -46,15 +46,18 @@ where
Err(e) => tracing::warn!("Notus Scanner disabled: {e}"),
}

ctx_builder
let tls_config = tls::tls_config(config)?;

Ok(ctx_builder
.mode(config.mode.clone())
.scheduler_config(config.scheduler.clone())
.feed_config(config.feed.clone())
.scanner(sh)
.api_key(config.endpoints.key.clone())
.tls_config(tls_config)
.enable_get_scans(config.endpoints.enable_get_scans)
.storage(db)
.build()
.build())
}

async fn run<S>(
Expand Down Expand Up @@ -92,7 +95,7 @@ where
storage::redis::Storage::new(ic, config.storage.redis.url.clone(), feeds),
scanner,
config,
);
)?;
controller::run(ctx, config).await
}
config::StorageType::InMemory => {
Expand All @@ -102,7 +105,7 @@ where
storage::inmemory::Storage::new(crate::crypt::ChaCha20Crypt::default(), feeds),
scanner,
config,
);
)?;
controller::run(ctx, config).await
}
config::StorageType::FileSystem => {
Expand All @@ -115,7 +118,7 @@ where
storage::file::encrypted(&config.storage.fs.path, key, feeds)?,
scanner,
config,
);
)?;
controller::run(ctx, config).await
} else {
tracing::warn!(
Expand All @@ -125,7 +128,7 @@ where
storage::file::unencrypted(&config.storage.fs.path, feeds)?,
scanner,
config,
);
)?;
controller::run(ctx, config).await
}
}
Expand Down
21 changes: 17 additions & 4 deletions rust/openvasd/src/tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,14 @@ pub fn config_to_tls_paths(
)))
}

pub type TlsData = (Arc<RwLock<ClientIdentifier>>, ServerConfig, bool);
#[derive(Debug)]
pub struct TlsConfig {
pub client_identifier: Arc<RwLock<ClientIdentifier>>,
pub config: ServerConfig,
pub has_clients: bool,
}

pub fn tls_config(config: &crate::config::Config) -> Result<Option<TlsData>, Error> {
pub fn tls_config(config: &crate::config::Config) -> Result<Option<TlsConfig>, Error> {
let (key, certs, clients) = match config_to_tls_paths(config)? {
Some(x) => x,
None => return Ok(None),
Expand All @@ -191,7 +196,11 @@ pub fn tls_config(config: &crate::config::Config) -> Result<Option<TlsData>, Err

config.alpn_protocols = vec![b"h2".to_vec()];

Ok(Some((client_identifier, config, !clients.is_empty())))
Ok(Some(TlsConfig {
client_identifier,
config,
has_clients: !clients.is_empty(),
}))
} else {
match &config.tls.client_certs {
Some(clicerts) => {
Expand Down Expand Up @@ -222,7 +231,11 @@ pub fn tls_config(config: &crate::config::Config) -> Result<Option<TlsData>, Err

config.alpn_protocols = vec![b"h2".to_vec()];

Ok(Some((client_identifier, config, !clients.is_empty())))
Ok(Some(TlsConfig {
client_identifier,
config,
has_clients: !clients.is_empty(),
}))
}
}

Expand Down

0 comments on commit 155da37

Please sign in to comment.