diff --git a/sim-lib/src/defined_activity.rs b/sim-lib/src/defined_activity.rs index e259d1de..c27cf0eb 100644 --- a/sim-lib/src/defined_activity.rs +++ b/sim-lib/src/defined_activity.rs @@ -1,4 +1,6 @@ -use crate::{DestinationGenerator, NodeInfo, PaymentGenerationError, PaymentGenerator}; +use crate::{ + DestinationGenerator, NodeInfo, PaymentGenerationError, PaymentGenerator, ValueOrRange, +}; use std::fmt; use tokio::time::Duration; @@ -7,8 +9,8 @@ pub struct DefinedPaymentActivity { destination: NodeInfo, start: Duration, count: Option, - wait: Duration, - amount: u64, + wait: ValueOrRange, + amount: ValueOrRange, } impl DefinedPaymentActivity { @@ -16,8 +18,8 @@ impl DefinedPaymentActivity { destination: NodeInfo, start: Duration, count: Option, - wait: Duration, - amount: u64, + wait: ValueOrRange, + amount: ValueOrRange, ) -> Self { DefinedPaymentActivity { destination, @@ -33,7 +35,7 @@ impl fmt::Display for DefinedPaymentActivity { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!( f, - "static payment of {} to {} every {:?}", + "static payment of {} to {} every {}s", self.amount, self.destination, self.wait ) } @@ -55,7 +57,7 @@ impl PaymentGenerator for DefinedPaymentActivity { } fn next_payment_wait(&self) -> Duration { - self.wait + Duration::from_secs(self.wait.value() as u64) } fn payment_amount( @@ -67,7 +69,7 @@ impl PaymentGenerator for DefinedPaymentActivity { "destination amount must not be set for defined activity generator".to_string(), )) } else { - Ok(self.amount) + Ok(self.amount.value()) } } } @@ -75,9 +77,9 @@ impl PaymentGenerator for DefinedPaymentActivity { #[cfg(test)] mod tests { use super::DefinedPaymentActivity; + use super::*; use crate::test_utils::{create_nodes, get_random_keypair}; use crate::{DestinationGenerator, PaymentGenerationError, PaymentGenerator}; - use std::time::Duration; #[test] fn test_defined_activity_generator() { @@ -91,8 +93,8 @@ mod tests { node.clone(), Duration::from_secs(0), None, - Duration::from_secs(60), - payment_amt, + crate::ValueOrRange::Value(60), + crate::ValueOrRange::Value(payment_amt), ); let (dest, dest_capacity) = generator.choose_destination(source.1); diff --git a/sim-lib/src/lib.rs b/sim-lib/src/lib.rs index 503ebb3a..716682ec 100644 --- a/sim-lib/src/lib.rs +++ b/sim-lib/src/lib.rs @@ -4,6 +4,7 @@ use bitcoin::Network; use csv::WriterBuilder; use lightning::ln::features::NodeFeatures; use lightning::ln::PaymentHash; +use rand::Rng; use random_activity::RandomActivityError; use serde::{Deserialize, Serialize}; use std::collections::HashSet; @@ -129,6 +130,47 @@ pub struct SimParams { pub activity: Vec, } +/// Either a value or a range parsed from the simulation file. +#[derive(Debug, Clone, Copy, Serialize, Deserialize)] +#[serde(untagged)] +pub enum ValueOrRange { + Value(T), + Range(T, T), +} + +impl ValueOrRange +where + T: std::cmp::PartialOrd + rand_distr::uniform::SampleUniform + Copy, +{ + /// Get the enclosed value. If value is defined as a range, sample from it uniformly at random. + pub fn value(&self) -> T { + match self { + ValueOrRange::Value(x) => *x, + ValueOrRange::Range(x, y) => { + let mut rng = rand::thread_rng(); + rng.gen_range(*x..*y) + }, + } + } +} + +impl Display for ValueOrRange +where + T: Display, +{ + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + ValueOrRange::Value(x) => write!(f, "{x}"), + ValueOrRange::Range(x, y) => write!(f, "({x}-{y})"), + } + } +} + +/// The payment amount in msat. Either a value or a range. +type Amount = ValueOrRange; +/// The interval of seconds between payments. Either a value or a range. +type Interval = ValueOrRange; + /// Data structure used to parse information from the simulation file. It allows source and destination to be /// [NodeId], which enables the use of public keys and aliases in the simulation description. #[derive(Debug, Clone, Serialize, Deserialize)] @@ -146,9 +188,11 @@ pub struct ActivityParser { #[serde(default)] pub count: Option, /// The interval of the event, as in every how many seconds the payment is performed. - pub interval_secs: u16, + #[serde(with = "serializers::serde_value_or_range")] + pub interval_secs: Interval, /// The amount of m_sat to used in this payment. - pub amount_msat: u64, + #[serde(with = "serializers::serde_value_or_range")] + pub amount_msat: Amount, } /// Data structure used internally by the simulator. Both source and destination are represented as [PublicKey] here. @@ -164,9 +208,9 @@ pub struct ActivityDefinition { /// The number of payments to send over the course of the simulation. pub count: Option, /// The interval of the event, as in every how many seconds the payment is performed. - pub interval_secs: u16, + pub interval_secs: Interval, /// The amount of m_sat to used in this payment. - pub amount_msat: u64, + pub amount_msat: Amount, } #[derive(Debug, Error)] @@ -731,7 +775,7 @@ impl Simulation { description.destination.clone(), Duration::from_secs(description.start_secs.into()), description.count, - Duration::from_secs(description.interval_secs.into()), + description.interval_secs, description.amount_msat, ); diff --git a/sim-lib/src/serializers.rs b/sim-lib/src/serializers.rs index 3fd46fa7..2ce82892 100644 --- a/sim-lib/src/serializers.rs +++ b/sim-lib/src/serializers.rs @@ -45,6 +45,42 @@ pub mod serde_node_id { } } +pub mod serde_value_or_range { + use super::*; + use serde::de::Error; + + use crate::ValueOrRange; + + pub fn serialize(x: &ValueOrRange, serializer: S) -> Result + where + S: serde::Serializer, + T: std::fmt::Display, + { + serializer.serialize_str(&match x { + ValueOrRange::Value(p) => p.to_string(), + ValueOrRange::Range(x, y) => format!("[{}, {}]", x, y), + }) + } + + pub fn deserialize<'de, D, T>(deserializer: D) -> Result, D::Error> + where + D: serde::Deserializer<'de>, + T: serde::Deserialize<'de> + std::cmp::PartialOrd + std::fmt::Display + Copy, + { + let a = ValueOrRange::deserialize(deserializer)?; + if let ValueOrRange::Range(x, y) = a { + if x >= y { + return Err(D::Error::custom(format!( + "Cannot parse range. Ranges must be strictly increasing (i.e. [x, y] with x > y). Received [{}, {}]", + x, y + ))); + } + } + + Ok(a) + } +} + pub fn deserialize_path<'de, D>(deserializer: D) -> Result where D: serde::Deserializer<'de>,