Skip to content

Commit

Permalink
refactor: reduce settings duplication w/ deserialize_with
Browse files Browse the repository at this point in the history
also kill the unused router ssl settings
  • Loading branch information
pjenvey committed Mar 30, 2023
1 parent fa9109d commit 4f3e450
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 111 deletions.
93 changes: 58 additions & 35 deletions autoconnect/autoconnect-settings/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,12 @@ extern crate slog;
extern crate slog_scope;
extern crate serde_derive;

use std::io;
use std::net::ToSocketAddrs;
use std::{io, net::ToSocketAddrs, time::Duration};

use config::{Config, ConfigError, Environment, File};
use fernet::Fernet;
use lazy_static::lazy_static;
use serde_derive::Deserialize;
use serde::{Deserialize, Deserializer};

pub const ENV_PREFIX: &str = "autoconnect";

Expand Down Expand Up @@ -52,20 +51,20 @@ pub struct Settings {
pub router_port: u16,
/// The DNS name to use for internal routing
pub router_hostname: Option<String>,
/// TLS key for internal router connections (may be empty if ELB is used)
pub router_ssl_key: Option<String>,
/// TLS certificate for internal router connections (may be empty if ELB is used)
pub router_ssl_cert: Option<String>,
/// TLS DiffieHellman parameter for internal router connections
pub router_ssl_dh_param: Option<String>,
/// The server based ping interval (also used for Broadcast sends)
pub auto_ping_interval: f64,
#[serde(deserialize_with = "deserialize_f64_to_duration")]
pub auto_ping_interval: Duration,
/// How long to wait for a response Pong before being timed out and connection drop
pub auto_ping_timeout: f64,
#[serde(deserialize_with = "deserialize_f64_to_duration")]
pub auto_ping_timeout: Duration,
/// Max number of websocket connections to allow
pub max_connections: u32,
/// How long to wait for the initial connection handshake.
#[serde(deserialize_with = "deserialize_u32_to_duration")]
pub open_handshake_timeout: Duration,
/// How long to wait while closing a connection for the response handshake.
pub close_handshake_timeout: u32,
#[serde(deserialize_with = "deserialize_u32_to_duration")]
pub close_handshake_timeout: Duration,
/// The URL scheme (http/https) for the endpoint URL
pub endpoint_scheme: String,
/// The host url for the endpoint URL (differs from `hostname` and `resolve_hostname`)
Expand All @@ -89,7 +88,8 @@ pub struct Settings {
/// Broadcast token for authentication
pub megaphone_api_token: Option<String>,
/// How often to poll the server for new data
pub megaphone_poll_interval: u32,
#[serde(deserialize_with = "deserialize_u32_to_duration")]
pub megaphone_poll_interval: Duration,
/// Use human readable (simplified, non-JSON)
pub human_logs: bool,
/// Maximum allowed number of backlogged messages. Exceeding this number will
Expand All @@ -109,13 +109,11 @@ impl Default for Settings {
resolve_hostname: false,
router_port: 8081,
router_hostname: None,
router_ssl_key: None,
router_ssl_cert: None,
router_ssl_dh_param: None,
auto_ping_interval: 300.0,
auto_ping_timeout: 4.0,
auto_ping_interval: Duration::from_secs(300),
auto_ping_timeout: Duration::from_secs(4),
max_connections: 0,
close_handshake_timeout: 0,
open_handshake_timeout: Duration::from_secs(5),
close_handshake_timeout: Duration::from_secs(0),
endpoint_scheme: "http".to_owned(),
endpoint_hostname: "localhost".to_owned(),
endpoint_port: 8082,
Expand All @@ -127,7 +125,7 @@ impl Default for Settings {
db_settings: "".to_owned(),
megaphone_api_url: None,
megaphone_api_token: None,
megaphone_poll_interval: 30,
megaphone_poll_interval: Duration::from_secs(30),
human_logs: false,
msg_limit: 100,
max_pending_notification_queue: 10,
Expand All @@ -149,15 +147,13 @@ impl Settings {
s = s.add_source(Environment::with_prefix(&ENV_PREFIX.to_uppercase()).separator("__"));

let built = s.build()?;
built.try_deserialize::<Settings>()
let s = built.try_deserialize::<Settings>()?;
s.validate()?;
Ok(s)
}

pub fn router_url(&self) -> String {
let router_scheme = if self.router_ssl_key.is_none() {
"http"
} else {
"https"
};
let router_scheme = "http";
let url = format!(
"{}://{}",
router_scheme,
Expand Down Expand Up @@ -195,6 +191,41 @@ impl Settings {
HOSTNAME.clone()
}
}

pub fn validate(&self) -> Result<(), ConfigError> {
let non_zero = |val: Duration, name| {
if val.is_zero() {
return Err(ConfigError::Message(format!(
"Invalid {}_{}: cannot be 0",
ENV_PREFIX, name
)));
}
Ok(())
};
non_zero(self.megaphone_poll_interval, "MEGAPHONE_POLL_INTERVAL")?;
non_zero(self.auto_ping_interval, "AUTO_PING_INTERVAL")?;
non_zero(self.auto_ping_timeout, "AUTO_PING_TIMEOUT")?;
Ok(())
}
}

fn deserialize_u32_to_duration<'de, D>(deserializer: D) -> Result<Duration, D::Error>
where
D: Deserializer<'de>,
{
let seconds: u32 = Deserialize::deserialize(deserializer)?;
Ok(Duration::from_secs(seconds.into()))
}

fn deserialize_f64_to_duration<'de, D>(deserializer: D) -> Result<Duration, D::Error>
where
D: Deserializer<'de>,
{
let seconds: f64 = Deserialize::deserialize(deserializer)?;
Ok(Duration::new(
seconds as u64,
(seconds.fract() * 1_000_000_000.0) as u32,
))
}

#[cfg(test)]
Expand All @@ -215,15 +246,6 @@ mod tests {
settings.router_port = 8080;
let url = settings.router_url();
assert_eq!("http://testname:8080", url);

settings.router_port = 443;
settings.router_ssl_key = Some("key".to_string());
let url = settings.router_url();
assert_eq!("https://testname", url);

settings.router_port = 8080;
let url = settings.router_url();
assert_eq!("https://testname:8080", url);
}

#[test]
Expand Down Expand Up @@ -272,6 +294,7 @@ mod tests {
&settings.crypto_key,
"[mqCGb8D-N7mqx6iWJov9wm70Us6kA9veeXdb8QUuzLQ=]"
);
assert_eq!(settings.open_handshake_timeout, Duration::from_secs(5));

// reset (just in case)
if let Ok(p) = v1 {
Expand Down
94 changes: 18 additions & 76 deletions autoconnect/autoconnect-settings/src/options.rs
Original file line number Diff line number Diff line change
@@ -1,70 +1,35 @@
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;

use cadence::StatsdClient;
use fernet::{Fernet, MultiFernet};

use crate::{Settings, ENV_PREFIX};
use autoconnect_common::registry::ClientRegistry;
use autopush_common::db::{client::DbClient, dynamodb::DdbClientImpl, DbSettings, StorageType};
use autopush_common::{
errors::{ApcErrorKind, Result},
metrics::new_metrics,
};

fn ito_dur(seconds: u32) -> Option<Duration> {
if seconds == 0 {
None
} else {
Some(Duration::new(seconds.into(), 0))
}
}

fn fto_dur(seconds: f64) -> Option<Duration> {
if seconds == 0.0 {
None
} else {
Some(Duration::new(
seconds as u64,
(seconds.fract() * 1_000_000_000.0) as u32,
))
}
}
use crate::{Settings, ENV_PREFIX};

/// A thread safe set of options specific to the Server. These are compiled from [crate::Settings]
#[derive(Clone)]
pub struct AppState {
pub router_port: u16,
pub port: u16,
/// Encryption object for the endpoint URL
pub fernet: MultiFernet,
pub metrics: Arc<StatsdClient>,
/// Handle to the data storage object
pub db_client: Box<dyn DbClient>,
pub ssl_key: Option<PathBuf>,
pub ssl_cert: Option<PathBuf>,
pub ssl_dh_param: Option<PathBuf>,
pub open_handshake_timeout: Option<Duration>,
pub auto_ping_interval: Duration,
pub auto_ping_timeout: Duration,
pub max_connections: Option<u32>,
pub close_handshake_timeout: Option<Duration>,
pub metrics: Arc<StatsdClient>,

/// Encryption object for the endpoint URL
pub fernet: MultiFernet,
/// The connected WebSocket clients
pub registry: Arc<ClientRegistry>,

pub settings: Settings,
pub router_url: String,
pub endpoint_url: String,
pub statsd_host: Option<String>,
pub statsd_port: u16,
pub megaphone_api_url: Option<String>,
pub megaphone_api_token: Option<String>,
pub megaphone_poll_interval: Duration,
pub human_logs: bool,
pub msg_limit: u32,
pub registry: Arc<ClientRegistry>,
pub max_pending_notification_queue: usize,
}

impl AppState {
pub fn from_settings(settings: &Settings) -> Result<Self> {
pub fn from_settings(settings: Settings) -> Result<Self> {
let crypto_key = &settings.crypto_key;
if !(crypto_key.starts_with('[') && crypto_key.ends_with(']')) {
return Err(
Expand All @@ -90,9 +55,6 @@ impl AppState {
settings.statsd_port,
)?);

let router_url = settings.router_url();
let endpoint_url = settings.endpoint_url();

let db_settings = DbSettings {
dsn: settings.db_dsn.clone(),
db_settings: settings.db_settings.clone(),
Expand All @@ -101,38 +63,18 @@ impl AppState {
StorageType::DynamoDb => Box::new(DdbClientImpl::new(metrics.clone(), &db_settings)?),
StorageType::INVALID => panic!("Invalid Storage type. Check {}_DB_DSN.", ENV_PREFIX),
};

let router_url = settings.router_url();
let endpoint_url = settings.endpoint_url();

Ok(Self {
port: settings.port,
fernet,
metrics,
db_client,
router_port: settings.router_port,
statsd_host: settings.statsd_host.clone(),
statsd_port: settings.statsd_port,
metrics,
fernet,
registry: Arc::new(ClientRegistry::default()),
settings,
router_url,
endpoint_url,
ssl_key: settings.router_ssl_key.clone().map(PathBuf::from),
ssl_cert: settings.router_ssl_cert.clone().map(PathBuf::from),
ssl_dh_param: settings.router_ssl_dh_param.clone().map(PathBuf::from),
auto_ping_interval: fto_dur(settings.auto_ping_interval)
.expect("auto ping interval cannot be 0"),
auto_ping_timeout: fto_dur(settings.auto_ping_timeout)
.expect("auto ping timeout cannot be 0"),
close_handshake_timeout: ito_dur(settings.close_handshake_timeout),
max_connections: if settings.max_connections == 0 {
None
} else {
Some(settings.max_connections)
},
open_handshake_timeout: ito_dur(5),
megaphone_api_url: settings.megaphone_api_url.clone(),
megaphone_api_token: settings.megaphone_api_token.clone(),
megaphone_poll_interval: ito_dur(settings.megaphone_poll_interval)
.expect("megaphone poll interval cannot be 0"),
human_logs: settings.human_logs,
msg_limit: settings.msg_limit,
registry: Arc::new(ClientRegistry::default()),
max_pending_notification_queue: settings.max_pending_notification_queue as usize,
})
}
}

0 comments on commit 4f3e450

Please sign in to comment.