From 45a7b21d99507a6921865c9abe1c16d81573a6d2 Mon Sep 17 00:00:00 2001 From: Stanimal Date: Mon, 6 Apr 2020 16:10:04 +0200 Subject: [PATCH] Persistent store and forward - Implemented sqlite persistence for store and forward (SAF). - Nodes keep track of a timestamp that they last requested messages for - Nodes use that timestamp when requesting messages to reduce the number of duplicate messages requested - Request for SAF messages is broken up into 3 responses: Discovery, Join and ExplicitlyAddressed - respectively, discovery messages, join messages and messages that are explicitly addressed to this node. - DB migrations run on node startup - Sqlite operations run on tokio blocking threads - Sqlite connection interface that allows usage of sqlite's in-memory database connection in addition to the file-system database --- applications/tari_base_node/src/builder.rs | 7 +- comms/dht/Cargo.toml | 3 + comms/dht/diesel.toml | 5 + comms/dht/examples/memorynet.rs | 2 +- comms/dht/migrations/.gitkeep | 0 .../2020-04-01-095825_initial/down.sql | 2 + .../2020-04-01-095825_initial/up.sql | 27 ++ comms/dht/src/actor.rs | 99 ++++- comms/dht/src/builder.rs | 7 +- comms/dht/src/config.rs | 17 +- comms/dht/src/dht.rs | 35 +- comms/dht/src/envelope.rs | 29 +- comms/dht/src/inbound/dht_handler/task.rs | 10 +- comms/dht/src/inbound/error.rs | 3 - comms/dht/src/inbound/message.rs | 21 +- comms/dht/src/lib.rs | 11 + comms/dht/src/outbound/serialize.rs | 13 +- comms/dht/src/proto/store_forward.proto | 15 +- comms/dht/src/proto/tari.dht.store_forward.rs | 22 +- comms/dht/src/schema.rs | 26 ++ comms/dht/src/storage/connection.rs | 122 ++++++ comms/dht/src/storage/database.rs | 77 ++++ comms/dht/src/storage/dht_setting_entry.rs | 51 +++ comms/dht/src/storage/error.rs | 37 ++ comms/dht/src/storage/mod.rs | 33 ++ comms/dht/src/store_forward/database/mod.rs | 165 +++++++ .../store_forward/database/stored_message.rs | 89 ++++ comms/dht/src/store_forward/error.rs | 15 +- comms/dht/src/store_forward/message.rs | 61 ++- comms/dht/src/store_forward/mod.rs | 25 +- .../src/store_forward/saf_handler/layer.rs | 15 +- .../store_forward/saf_handler/middleware.rs | 10 +- .../dht/src/store_forward/saf_handler/task.rs | 251 ++++++----- comms/dht/src/store_forward/service.rs | 274 ++++++++++++ comms/dht/src/store_forward/store.rs | 413 +++++++++++++----- comms/dht/src/test_utils/dht_actor_mock.rs | 36 +- comms/dht/src/test_utils/makers.rs | 2 +- comms/dht/src/test_utils/mod.rs | 13 +- .../src/test_utils/store_and_forward_mock.rs | 135 ++++++ .../src/{store_forward/state.rs => utils.rs} | 43 +- comms/src/message/envelope.rs | 4 + .../storage/src/lmdb_store/store.rs | 2 +- 42 files changed, 1874 insertions(+), 353 deletions(-) create mode 100644 comms/dht/diesel.toml create mode 100644 comms/dht/migrations/.gitkeep create mode 100644 comms/dht/migrations/2020-04-01-095825_initial/down.sql create mode 100644 comms/dht/migrations/2020-04-01-095825_initial/up.sql create mode 100644 comms/dht/src/schema.rs create mode 100644 comms/dht/src/storage/connection.rs create mode 100644 comms/dht/src/storage/database.rs create mode 100644 comms/dht/src/storage/dht_setting_entry.rs create mode 100644 comms/dht/src/storage/error.rs create mode 100644 comms/dht/src/storage/mod.rs create mode 100644 comms/dht/src/store_forward/database/mod.rs create mode 100644 comms/dht/src/store_forward/database/stored_message.rs create mode 100644 comms/dht/src/store_forward/service.rs create mode 100644 comms/dht/src/test_utils/store_and_forward_mock.rs rename 
comms/dht/src/{store_forward/state.rs => utils.rs} (61%) diff --git a/applications/tari_base_node/src/builder.rs b/applications/tari_base_node/src/builder.rs index dc4df6915b..b347f2f219 100644 --- a/applications/tari_base_node/src/builder.rs +++ b/applications/tari_base_node/src/builder.rs @@ -46,7 +46,7 @@ use tari_comms::{ ConnectionManagerEvent, PeerManager, }; -use tari_comms_dht::Dht; +use tari_comms_dht::{DbConnectionUrl, Dht, DhtConfig}; use tari_core::{ base_node::{ chain_metadata_service::{ChainMetadataHandle, ChainMetadataServiceInitializer}, @@ -841,7 +841,10 @@ async fn setup_base_node_comms( max_concurrent_inbound_tasks: 100, outbound_buffer_size: 100, // TODO - make this configurable - dht: Default::default(), + dht: DhtConfig { + database_url: DbConnectionUrl::File(config.data_dir.join("dht.db")), + ..Default::default() + }, // TODO: This should be false unless testing locally - make this configurable allow_test_addresses: true, listener_liveness_whitelist_cidrs: config.listener_liveness_whitelist_cidrs.clone(), diff --git a/comms/dht/Cargo.toml b/comms/dht/Cargo.toml index 2f2679fd99..9e8e4717f9 100644 --- a/comms/dht/Cargo.toml +++ b/comms/dht/Cargo.toml @@ -22,6 +22,8 @@ bitflags = "1.2.0" bytes = "0.4.12" chrono = "0.4.9" derive-error = "0.0.4" +diesel = {version="1.4", features = ["sqlite", "serde_json", "chrono"]} +diesel_migrations = "1.4" digest = "0.8.1" futures= {version= "^0.3.1"} log = "0.4.8" @@ -34,6 +36,7 @@ serde_repr = "0.1.5" tokio = {version="0.2.10", features=["rt-threaded", "blocking"]} tower= "0.3.0" ttl_cache = "0.5.1" + # tower-filter dependencies pin-project = "0.4" diff --git a/comms/dht/diesel.toml b/comms/dht/diesel.toml new file mode 100644 index 0000000000..92267c829f --- /dev/null +++ b/comms/dht/diesel.toml @@ -0,0 +1,5 @@ +# For documentation on how to configure this file, +# see diesel.rs/guides/configuring-diesel-cli + +[print_schema] +file = "src/schema.rs" diff --git a/comms/dht/examples/memorynet.rs b/comms/dht/examples/memorynet.rs index 66d5cc6d82..ef04d93a96 100644 --- a/comms/dht/examples/memorynet.rs +++ b/comms/dht/examples/memorynet.rs @@ -403,7 +403,7 @@ async fn do_store_and_forward_discovery( println!("Waiting a few seconds for discovery to propagate around the network..."); time::delay_for(Duration::from_secs(8)).await; - let mut total_messages = drain_messaging_events(messaging_rx, true).await; + let mut total_messages = drain_messaging_events(messaging_rx, false).await; banner!("🤓 {} is coming back online", get_name(node_identity.node_id())); let (tx, ims_rx) = mpsc::channel(1); diff --git a/comms/dht/migrations/.gitkeep b/comms/dht/migrations/.gitkeep new file mode 100644 index 0000000000..e69de29bb2 diff --git a/comms/dht/migrations/2020-04-01-095825_initial/down.sql b/comms/dht/migrations/2020-04-01-095825_initial/down.sql new file mode 100644 index 0000000000..6e0a2cbbd7 --- /dev/null +++ b/comms/dht/migrations/2020-04-01-095825_initial/down.sql @@ -0,0 +1,2 @@ +DROP TABLE IF EXISTS stored_messages; +DROP TABLE IF EXISTS dht_settings; diff --git a/comms/dht/migrations/2020-04-01-095825_initial/up.sql b/comms/dht/migrations/2020-04-01-095825_initial/up.sql new file mode 100644 index 0000000000..3805ecaa31 --- /dev/null +++ b/comms/dht/migrations/2020-04-01-095825_initial/up.sql @@ -0,0 +1,27 @@ +CREATE TABLE stored_messages ( + id INTEGER NOT NULL PRIMARY KEY, + version INT NOT NULL, + origin_pubkey TEXT NOT NULL, + origin_signature TEXT NOT NULL, + message_type INT NOT NULL, + destination_pubkey TEXT, + 
destination_node_id TEXT, + header BLOB NOT NULL, + body BLOB NOT NULL, + is_encrypted BOOLEAN NOT NULL CHECK (is_encrypted IN (0,1)), + priority INT NOT NULL, + stored_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP +); + +CREATE INDEX idx_stored_messages_destination_pubkey ON stored_messages (destination_pubkey); +CREATE INDEX idx_stored_messages_destination_node_id ON stored_messages (destination_node_id); +CREATE INDEX idx_stored_messages_stored_at ON stored_messages (stored_at); +CREATE INDEX idx_stored_messages_priority ON stored_messages (priority); + +CREATE TABLE dht_settings ( + id INTEGER PRIMARY KEY NOT NULL, + key TEXT NOT NULL, + value BLOB NOT NULL +); + +CREATE UNIQUE INDEX idx_dht_settings_key ON dht_settings (key); diff --git a/comms/dht/src/actor.rs b/comms/dht/src/actor.rs index 88cf9896bd..3c30ea6215 100644 --- a/comms/dht/src/actor.rs +++ b/comms/dht/src/actor.rs @@ -32,16 +32,15 @@ use crate::{ discovery::DhtDiscoveryError, outbound::{OutboundMessageRequester, SendMessageParams}, proto::{dht::JoinMessage, envelope::DhtMessageType, store_forward::StoredMessagesRequest}, + storage::{DbConnection, DhtDatabase, DhtSettingKey, StorageError}, DhtConfig, }; -use chrono::{DateTime, Utc}; use derive_error::Error; use futures::{ channel::{mpsc, mpsc::SendError, oneshot}, future, future::BoxFuture, stream::{Fuse, FuturesUnordered}, - FutureExt, SinkExt, StreamExt, }; @@ -60,7 +59,10 @@ use tari_comms::{ }, types::CommsPublicKey, }; -use tari_crypto::tari_utilities::ByteArray; +use tari_crypto::tari_utilities::{ + message_format::{MessageFormat, MessageFormatError}, + ByteArray, +}; use tari_shutdown::ShutdownSignal; use tari_storage::IterationResult; use ttl_cache::TtlCache; @@ -80,6 +82,11 @@ pub enum DhtActorError { SendFailed(String), DiscoveryError(DhtDiscoveryError), BlockingJoinError(tokio::task::JoinError), + StorageError(StorageError), + #[error(no_from)] + StoredValueFailedToDeserialize(MessageFormatError), + #[error(no_from)] + FailedToSerializeValue(MessageFormatError), } impl From for DhtActorError { @@ -98,23 +105,27 @@ impl From for DhtActorError { pub enum DhtRequest { /// Send a Join request to the network SendJoin, - /// Send a request for stored messages, optionally specifying a date time that the foreign node should - /// use to filter the returned messages. - SendRequestStoredMessages(Option>), + /// Send requests to neighbours for stored messages + SendRequestStoredMessages, /// Inserts a message signature to the msg hash cache. 
This operation replies with a boolean /// which is true if the signature already exists in the cache, otherwise false MsgHashCacheInsert(Vec<u8>, oneshot::Sender<bool>), /// Fetch selected peers according to the broadcast strategy SelectPeers(BroadcastStrategy, oneshot::Sender<Vec<Peer>>), + GetSetting(DhtSettingKey, oneshot::Sender<Result<Option<Vec<u8>>, DhtActorError>>), + SetSetting(DhtSettingKey, Vec<u8>), } impl Display for DhtRequest { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + use DhtRequest::*; match self { - DhtRequest::SendJoin => f.write_str("SendJoin"), - DhtRequest::SendRequestStoredMessages(d) => f.write_str(&format!("SendRequestStoredMessages ({:?})", d)), - DhtRequest::MsgHashCacheInsert(_, _) => f.write_str("MsgHashCacheInsert"), - DhtRequest::SelectPeers(s, _) => f.write_str(&format!("SelectPeers (Strategy={})", s)), + SendJoin => f.write_str("SendJoin"), + SendRequestStoredMessages => f.write_str("SendRequestStoredMessages"), + MsgHashCacheInsert(_, _) => f.write_str("MsgHashCacheInsert"), + SelectPeers(s, _) => f.write_str(&format!("SelectPeers (Strategy={})", s)), + GetSetting(key, _) => f.write_str(&format!("GetSetting (key={})", key)), + SetSetting(key, value) => f.write_str(&format!("SetSetting (key={}, value={} bytes)", key, value.len())), } } } @@ -151,10 +162,25 @@ impl DhtRequester { } pub async fn send_request_stored_messages(&mut self) -> Result<(), DhtActorError> { - self.sender - .send(DhtRequest::SendRequestStoredMessages(None)) - .await - .map_err(Into::into) + self.sender.send(DhtRequest::SendRequestStoredMessages).await?; + Ok(()) + } + + pub async fn get_setting<T: MessageFormat>(&mut self, key: DhtSettingKey) -> Result<Option<T>, DhtActorError> { + let (reply_tx, reply_rx) = oneshot::channel(); + self.sender.send(DhtRequest::GetSetting(key, reply_tx)).await?; + match reply_rx.await.map_err(|_| DhtActorError::ReplyCanceled)?? { + Some(bytes) => T::from_binary(&bytes) + .map(Some) + .map_err(DhtActorError::StoredValueFailedToDeserialize), + None => Ok(None), + } + } + + pub async fn set_setting<T: MessageFormat>(&mut self, key: DhtSettingKey, value: T) -> Result<(), DhtActorError> { + let bytes = value.to_binary().map_err(DhtActorError::FailedToSerializeValue)?; + self.sender.send(DhtRequest::SetSetting(key, bytes)).await?; + Ok(()) + } } @@ -191,18 +217,22 @@ impl<'a> DhtActor<'a> { } } - pub async fn run(mut self) { + pub async fn run(mut self) -> Result<(), DhtActorError> { + let conn = DbConnection::connect_url(self.config.database_url.clone()).await?; + let output = conn.migrate().await?; + info!(target: LOG_TARGET, "Dht database migration:\n{}", output); + let db = DhtDatabase::new(conn); + let mut shutdown_signal = self .shutdown_signal .take() - .expect("DhtActor initialized without shutdown_signal") - .fuse(); + .expect("DhtActor initialized without shutdown_signal"); loop { futures::select!
{ request = self.request_rx.select_next_some() => { debug!(target: LOG_TARGET, "DhtActor received message: {}", request); - let handler = self.request_handler(request); + let handler = self.request_handler(db.clone(), request); self.pending_jobs.push(handler); }, @@ -227,9 +257,11 @@ impl<'a> DhtActor<'a> { } } } + + Ok(()) } - fn request_handler(&mut self, request: DhtRequest) -> BoxFuture<'a, Result<(), DhtActorError>> { + fn request_handler(&mut self, db: DhtDatabase, request: DhtRequest) -> BoxFuture<'a, Result<(), DhtActorError>> { use DhtRequest::*; match request { SendJoin => { @@ -265,16 +297,31 @@ impl<'a> DhtActor<'a> { } }) }, - SendRequestStoredMessages(maybe_since) => { + SendRequestStoredMessages => { let node_identity = Arc::clone(&self.node_identity); let outbound_requester = self.outbound_requester.clone(); Box::pin(Self::request_stored_messages( node_identity, outbound_requester, + db, self.config.num_neighbouring_nodes, - maybe_since, )) }, + GetSetting(key, reply_tx) => Box::pin(async move { + let _ = reply_tx.send(db.get_value(key).await.map_err(Into::into)); + Ok(()) + }), + SetSetting(key, value) => Box::pin(async move { + match db.set_value(key, value).await { + Ok(_) => { + info!(target: LOG_TARGET, "Dht setting '{}' set", key); + }, + Err(err) => { + error!(target: LOG_TARGET, "set_setting failed because {:?}", err); + }, + } + Ok(()) + }), } } @@ -313,10 +360,16 @@ impl<'a> DhtActor<'a> { async fn request_stored_messages( node_identity: Arc, mut outbound_requester: OutboundMessageRequester, + db: DhtDatabase, num_neighbouring_nodes: usize, - maybe_since: Option>, ) -> Result<(), DhtActorError> { + let request = db + .get_value(DhtSettingKey::SafLastRequestTimestamp) + .await? + .map(StoredMessagesRequest::since) + .unwrap_or_else(StoredMessagesRequest::new); + outbound_requester .send_message_no_header( SendMessageParams::new() @@ -328,7 +381,7 @@ impl<'a> DhtActor<'a> { ) .with_dht_message_type(DhtMessageType::SafRequestMessages) .finish(), - maybe_since.map(StoredMessagesRequest::since).unwrap_or_default(), + request, ) .await .map_err(|err| DhtActorError::SendFailed(format!("Failed to send request for stored messages: {}", err)))?; diff --git a/comms/dht/src/builder.rs b/comms/dht/src/builder.rs index ad6f175dc4..43e2512b24 100644 --- a/comms/dht/src/builder.rs +++ b/comms/dht/src/builder.rs @@ -20,7 +20,7 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::{outbound::DhtOutboundRequest, Dht, DhtConfig}; +use crate::{outbound::DhtOutboundRequest, DbConnectionUrl, Dht, DhtConfig}; use futures::channel::mpsc; use std::{sync::Arc, time::Duration}; use tari_comms::{ @@ -80,6 +80,11 @@ impl DhtBuilder { self } + pub fn with_database_url(mut self, database_url: DbConnectionUrl) -> Self { + self.config.database_url = database_url; + self + } + pub fn with_signature_cache_ttl(mut self, ttl: Duration) -> Self { self.config.msg_hash_cache_ttl = ttl; self diff --git a/comms/dht/src/config.rs b/comms/dht/src/config.rs index e8d0e3be53..0f920a10e6 100644 --- a/comms/dht/src/config.rs +++ b/comms/dht/src/config.rs @@ -20,20 +20,22 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-use crate::envelope::Network; +use crate::{envelope::Network, storage::DbConnectionUrl}; use std::time::Duration; /// The default maximum number of messages that can be stored using the Store-and-forward middleware pub const SAF_MSG_CACHE_STORAGE_CAPACITY: usize = 10_000; /// The default time-to-live duration used for storage of low priority messages by the Store-and-forward middleware -pub const SAF_LOW_PRIORITY_MSG_STORAGE_TTL: Duration = Duration::from_secs(6 * 60 * 60); +pub const SAF_LOW_PRIORITY_MSG_STORAGE_TTL: Duration = Duration::from_secs(6 * 60 * 60); // 6 hours /// The default time-to-live duration used for storage of high priority messages by the Store-and-forward middleware -pub const SAF_HIGH_PRIORITY_MSG_STORAGE_TTL: Duration = Duration::from_secs(24 * 60 * 60); +pub const SAF_HIGH_PRIORITY_MSG_STORAGE_TTL: Duration = Duration::from_secs(2 * 24 * 60 * 60); // 2 days /// The default number of peer nodes that a message has to be closer to, to be considered a neighbour pub const DEFAULT_NUM_NEIGHBOURING_NODES: usize = 10; #[derive(Debug, Clone)] pub struct DhtConfig { + /// The `DbConnectionUrl` for the Dht database. Default: In-memory database + pub database_url: DbConnectionUrl, /// The size of the buffer (channel) which holds pending outbound message requests. /// Default: 20 pub outbound_buffer_size: usize, @@ -53,8 +55,10 @@ pub struct DhtConfig { /// Default: 6 hours pub saf_low_priority_msg_storage_ttl: Duration, /// The time-to-live duration used for storage of high priority messages by the Store-and-forward middleware. - /// Default: 24 hours + /// Default: 2 days pub saf_high_priority_msg_storage_ttl: Duration, + /// The limit on the message size to store in SAF storage in bytes. Default 500kb + pub saf_max_message_size: usize, /// The max capacity of the message hash cache /// Default: 1000 pub msg_hash_cache_capacity: usize, @@ -92,6 +96,7 @@ impl DhtConfig { pub fn default_local_test() -> Self { Self { network: Network::LocalTest, + database_url: DbConnectionUrl::Memory, ..Default::default() } } @@ -102,14 +107,16 @@ impl Default for DhtConfig { Self { num_neighbouring_nodes: DEFAULT_NUM_NEIGHBOURING_NODES, saf_num_closest_nodes: 10, - saf_max_returned_messages: 1000, + saf_max_returned_messages: 100, outbound_buffer_size: 20, saf_msg_cache_storage_capacity: SAF_MSG_CACHE_STORAGE_CAPACITY, saf_low_priority_msg_storage_ttl: SAF_LOW_PRIORITY_MSG_STORAGE_TTL, saf_high_priority_msg_storage_ttl: SAF_HIGH_PRIORITY_MSG_STORAGE_TTL, + saf_max_message_size: 512 * 1024, // 512 kb msg_hash_cache_capacity: 10_000, msg_hash_cache_ttl: Duration::from_secs(300), broadcast_cooldown_max_attempts: 3, + database_url: DbConnectionUrl::Memory, broadcast_cooldown_period: Duration::from_secs(60 * 30), discovery_request_timeout: Duration::from_secs(2 * 60), network: Network::TestNet, diff --git a/comms/dht/src/dht.rs b/comms/dht/src/dht.rs index 0af7ac2261..15e444a3da 100644 --- a/comms/dht/src/dht.rs +++ b/comms/dht/src/dht.rs @@ -31,6 +31,7 @@ use crate::{ outbound::DhtOutboundRequest, proto::envelope::DhtMessageType, store_forward, + store_forward::{StoreAndForwardRequest, StoreAndForwardRequester, StoreAndForwardService}, tower_filter, DhtConfig, }; @@ -60,7 +61,9 @@ pub struct Dht { outbound_tx: mpsc::Sender, /// Sender for DHT requests dht_sender: mpsc::Sender, - /// Sender for DHT requests + /// Sender for SAF requests + saf_sender: mpsc::Sender, + /// Sender for DHT discovery requests discovery_sender: mpsc::Sender, /// Connection manager actor requester connection_manager: 
ConnectionManagerRequester, @@ -78,6 +81,7 @@ impl Dht { { let (dht_sender, dht_receiver) = mpsc::channel(20); let (discovery_sender, discovery_receiver) = mpsc::channel(20); + let (saf_sender, saf_receiver) = mpsc::channel(20); let dht = Self { node_identity, @@ -85,10 +89,18 @@ impl Dht { config, outbound_tx, dht_sender, + saf_sender, connection_manager, discovery_sender, }; + if dht.node_identity.features().contains(PeerFeatures::DHT_STORE_FORWARD) { + task::spawn( + dht.store_and_forward_service(saf_receiver, shutdown_signal.clone()) + .run(), + ); + } + task::spawn(dht.actor(dht_receiver, shutdown_signal.clone()).run()); task::spawn(dht.discovery_service(discovery_receiver, shutdown_signal).run()); @@ -130,6 +142,15 @@ impl Dht { ) } + fn store_and_forward_service( + &self, + request_rx: mpsc::Receiver, + shutdown_signal: ShutdownSignal, + ) -> StoreAndForwardService + { + StoreAndForwardService::new(self.config.clone(), request_rx, shutdown_signal) + } + /// Return a new OutboundMessageRequester connected to the receiver pub fn outbound_requester(&self) -> OutboundMessageRequester { OutboundMessageRequester::new(self.outbound_tx.clone()) @@ -145,6 +166,11 @@ impl Dht { DhtDiscoveryRequester::new(self.discovery_sender.clone(), self.config.discovery_request_timeout) } + /// Returns a requester for the StoreAndForwardService associated with this instance + pub fn store_and_forward_requester(&self) -> StoreAndForwardRequester { + StoreAndForwardRequester::new(self.saf_sender.clone()) + } + /// Returns an the full DHT stack as a `tower::layer::Layer`. This can be composed with /// other inbound middleware services which expect an DecryptedDhtMessage pub fn inbound_middleware_layer( @@ -163,9 +189,6 @@ impl Dht { S: Service + Clone + Send + Sync + 'static, S::Future: Send, { - let saf_storage = Arc::new(store_forward::SafStorage::new( - self.config.saf_msg_cache_storage_capacity, - )); let builder = ServiceBuilder::new() .layer(inbound::DeserializeLayer::new()) .layer(inbound::ValidateLayer::new( @@ -192,11 +215,11 @@ impl Dht { self.config.clone(), Arc::clone(&self.peer_manager), Arc::clone(&self.node_identity), - Arc::clone(&saf_storage), + self.store_and_forward_requester(), )) .layer(store_forward::MessageHandlerLayer::new( self.config.clone(), - saf_storage, + self.store_and_forward_requester(), self.dht_requester(), Arc::clone(&self.node_identity), Arc::clone(&self.peer_manager), diff --git a/comms/dht/src/envelope.rs b/comms/dht/src/envelope.rs index 01822b31ad..ae1d0d6893 100644 --- a/comms/dht/src/envelope.rs +++ b/comms/dht/src/envelope.rs @@ -69,11 +69,36 @@ bitflags! 
{ } } +impl DhtMessageFlags { + pub fn is_encrypted(self) -> bool { + self.contains(Self::ENCRYPTED) + } +} + impl DhtMessageType { pub fn is_dht_message(self) -> bool { + self.is_dht_discovery() || self.is_dht_join() + } + + pub fn is_dht_discovery(self) -> bool { + match self { + DhtMessageType::Discovery => true, + _ => false, + } + } + + pub fn is_dht_join(self) -> bool { + match self { + DhtMessageType::Join => true, + _ => false, + } + } + + pub fn is_saf_message(self) -> bool { + use DhtMessageType::*; match self { - DhtMessageType::None => false, - _ => true, + SafRequestMessages | SafStoredMessages => true, + _ => false, } } } diff --git a/comms/dht/src/inbound/dht_handler/task.rs b/comms/dht/src/inbound/dht_handler/task.rs index f76dda820c..9a3dc372d1 100644 --- a/comms/dht/src/inbound/dht_handler/task.rs +++ b/comms/dht/src/inbound/dht_handler/task.rs @@ -85,9 +85,13 @@ where S: Service .take() .expect("ProcessDhtMessage initialized without message"); - // If this message failed to decrypt, this middleware is not interested in it + // If this message failed to decrypt, we stop it going further at this layer if message.decryption_failed() { - return self.next_service.oneshot(message).await; + debug!( + target: LOG_TARGET, + "Message that failed to decrypt will be discarded here. DhtHeader={}", message.dht_header + ); + return Ok(()); } match message.dht_header.message_type { @@ -189,7 +193,7 @@ where S: Service let body = decryption_result.expect("already checked that this message decrypted successfully"); let join_msg = body .decode_part::(0)? - .ok_or_else(|| DhtInboundError::InvalidJoinNetAddresses)?; + .ok_or_else(|| DhtInboundError::InvalidMessageBody)?; let addresses = join_msg .addresses diff --git a/comms/dht/src/inbound/error.rs b/comms/dht/src/inbound/error.rs index 43e548f4a6..d383728263 100644 --- a/comms/dht/src/inbound/error.rs +++ b/comms/dht/src/inbound/error.rs @@ -28,7 +28,6 @@ use tari_comms::{message::MessageError, peer_manager::PeerManagerError}; #[derive(Debug, Error)] pub enum DhtInboundError { MessageError(MessageError), - // MessageFormatError(MessageFormatError), PeerManagerError(PeerManagerError), DhtOutboundError(DhtOutboundError), /// Failed to decode message @@ -39,8 +38,6 @@ pub enum DhtInboundError { InvalidNodeId, /// All given addresses were invalid InvalidAddresses, - /// One or more NetAddress in the join message were invalid - InvalidJoinNetAddresses, DhtDiscoveryError(DhtDiscoveryError), #[error(msg_embedded, no_from, non_std)] OriginRequired(String), diff --git a/comms/dht/src/inbound/message.rs b/comms/dht/src/inbound/message.rs index 980b037d4c..6224bf340c 100644 --- a/comms/dht/src/inbound/message.rs +++ b/comms/dht/src/inbound/message.rs @@ -20,7 +20,10 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-use crate::{consts::DHT_ENVELOPE_HEADER_VERSION, envelope::DhtMessageHeader}; +use crate::{ + consts::DHT_ENVELOPE_HEADER_VERSION, + envelope::{DhtMessageFlags, DhtMessageHeader}, +}; use std::{ fmt::{Display, Error, Formatter}, sync::Arc, }; @@ -119,4 +122,20 @@ impl DecryptedDhtMessage { .map(|o| &o.public_key) .unwrap_or(&self.source_peer.public_key) } + + /// Returns true if the message is or was encrypted + pub fn is_encrypted(&self) -> bool { + self.dht_header.flags.contains(DhtMessageFlags::ENCRYPTED) + } + + pub fn has_origin(&self) -> bool { + self.dht_header.origin.is_some() + } + + pub fn body_size(&self) -> usize { + match self.decryption_result.as_ref() { + Ok(b) => b.total_size(), + Err(b) => b.len(), + } + } } diff --git a/comms/dht/src/lib.rs b/comms/dht/src/lib.rs index 37f4e18686..e4d2500db2 100644 --- a/comms/dht/src/lib.rs +++ b/comms/dht/src/lib.rs @@ -107,6 +107,11 @@ // Details: https://doc.rust-lang.org/beta/unstable-book/language-features/type-alias-impl-trait.html #![feature(type_alias_impl_trait)] +#[macro_use] +extern crate diesel; +#[macro_use] +extern crate diesel_migrations; + #[macro_use] mod macros; @@ -132,9 +137,15 @@ pub use dht::Dht; mod discovery; pub use discovery::DhtDiscoveryRequester; +mod storage; +pub use storage::DbConnectionUrl; + mod logging_middleware; mod proto; mod tower_filter; +mod utils; + +mod schema; pub mod broadcast_strategy; pub mod domain_message; diff --git a/comms/dht/src/outbound/serialize.rs b/comms/dht/src/outbound/serialize.rs index c7be4e5b42..5dc07daf54 100644 --- a/comms/dht/src/outbound/serialize.rs +++ b/comms/dht/src/outbound/serialize.rs @@ -32,7 +32,7 @@ use tari_comms::{ utils::signature, Bytes, }; -use tari_crypto::tari_utilities::{hex::Hex, message_format::MessageFormat}; +use tari_crypto::tari_utilities::message_format::MessageFormat; use tower::{layer::Layer, Service, ServiceExt}; const LOG_TARGET: &str = "comms::dht::serialize"; @@ -97,10 +97,9 @@ where S: Service // If forwarding the message, the DhtHeader already has a signature that should not change if is_forwarded { - trace!( + debug!( target: LOG_TARGET, - "Forwarded message {:?}. Message will not be signed", - message.tag + "Message ({}) is being forwarded so this node will NOT sign it", message.tag ); } else { // Sign the body if the origin public key was previously specified. @@ -108,12 +107,6 @@ where S: Service let signature = signature::sign(&mut OsRng, node_identity.secret_key().clone(), &body) .map_err(PipelineError::from_debug)?; origin.signature = signature.to_binary().map_err(PipelineError::from_debug)?; - trace!( - target: LOG_TARGET, - "Signed message {:?}: {}", - message.tag, - origin.signature.to_hex() - ); } } diff --git a/comms/dht/src/proto/store_forward.proto b/comms/dht/src/proto/store_forward.proto index 702aee4225..c7d26c416f 100644 --- a/comms/dht/src/proto/store_forward.proto +++ b/comms/dht/src/proto/store_forward.proto @@ -11,6 +11,7 @@ package tari.dht.store_forward; // will be sent. message StoredMessagesRequest { google.protobuf.Timestamp since = 1; + uint32 request_id = 2; } // Storage for a single message envelope, including the date and time when the element was stored @@ -18,10 +19,22 @@ message StoredMessage { google.protobuf.Timestamp stored_at = 1; uint32 version = 2; tari.dht.envelope.DhtHeader dht_header = 3; - bytes encrypted_body = 4; + bytes body = 4; } // The StoredMessages contains the set of applicable messages retrieved from a neighbouring peer node.
message StoredMessagesResponse { repeated StoredMessage messages = 1; + uint32 request_id = 2; + enum SafResponseType { + // All applicable messages + General = 0; + // Send messages explicitly addressed to the requesting node or within the requesting node's region + ExplicitlyAddressed = 1; + // Send Discovery messages that could be for the requester + Discovery = 2; + // Send Join messages that the requester could be interested in + Join = 3; + } + SafResponseType response_type = 3; } diff --git a/comms/dht/src/proto/tari.dht.store_forward.rs b/comms/dht/src/proto/tari.dht.store_forward.rs index f1f881adaf..e0b58b90b7 100644 --- a/comms/dht/src/proto/tari.dht.store_forward.rs +++ b/comms/dht/src/proto/tari.dht.store_forward.rs @@ -5,6 +5,8 @@ pub struct StoredMessagesRequest { #[prost(message, optional, tag = "1")] pub since: ::std::option::Option<::prost_types::Timestamp>, + #[prost(uint32, tag = "2")] + pub request_id: u32, } /// Storage for a single message envelope, including the date and time when the element was stored #[derive(Clone, PartialEq, ::prost::Message)] @@ -16,11 +18,29 @@ pub struct StoredMessage { #[prost(message, optional, tag = "3")] pub dht_header: ::std::option::Option, #[prost(bytes, tag = "4")] - pub encrypted_body: std::vec::Vec, + pub body: std::vec::Vec, } /// The StoredMessages contains the set of applicable messages retrieved from a neighbouring peer node. #[derive(Clone, PartialEq, ::prost::Message)] pub struct StoredMessagesResponse { #[prost(message, repeated, tag = "1")] pub messages: ::std::vec::Vec, + #[prost(uint32, tag = "2")] + pub request_id: u32, + #[prost(enumeration = "stored_messages_response::SafResponseType", tag = "3")] + pub response_type: i32, +} +pub mod stored_messages_response { + #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] + #[repr(i32)] + pub enum SafResponseType { + /// All applicable messages + General = 0, + /// Send messages explicitly addressed to the requesting node or within the requesting node's region + ExplicitlyAddressed = 1, + /// Send Discovery messages that could be for the requester + Discovery = 2, + /// Send Join messages that the requester could be interested in + Join = 3, + } } diff --git a/comms/dht/src/schema.rs b/comms/dht/src/schema.rs new file mode 100644 index 0000000000..2ebbfb28b0 --- /dev/null +++ b/comms/dht/src/schema.rs @@ -0,0 +1,26 @@ +table! { + dht_settings (id) { + id -> Integer, + key -> Text, + value -> Binary, + } +} + +table! { + stored_messages (id) { + id -> Integer, + version -> Integer, + origin_pubkey -> Text, + origin_signature -> Text, + message_type -> Integer, + destination_pubkey -> Nullable, + destination_node_id -> Nullable, + header -> Binary, + body -> Binary, + is_encrypted -> Bool, + priority -> Integer, + stored_at -> Timestamp, + } +} + +allow_tables_to_appear_in_same_query!(dht_settings, stored_messages,); diff --git a/comms/dht/src/storage/connection.rs b/comms/dht/src/storage/connection.rs new file mode 100644 index 0000000000..8c7037108a --- /dev/null +++ b/comms/dht/src/storage/connection.rs @@ -0,0 +1,122 @@ +// Copyright 2020. The Tari Project +// +// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the +// following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following +// disclaimer. +// +// 2. 
Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the +// following disclaimer in the documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote +// products derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +use crate::storage::error::StorageError; +use diesel::{Connection, SqliteConnection}; +use std::{ + io, + path::PathBuf, + sync::{Arc, Mutex}, +}; +use tokio::task; + +#[derive(Clone, Debug)] +pub enum DbConnectionUrl { + /// In-memory database. Each connection has its own database + Memory, + /// In-memory database shared with more than one in-process connection according to the given identifier + MemoryShared(String), + /// Database persisted on disk + File(PathBuf), +} + +impl DbConnectionUrl { + pub fn to_url_string(&self) -> String { + use DbConnectionUrl::*; + match self { + Memory => ":memory:".to_owned(), + MemoryShared(identifier) => format!("file:{}?mode=memory&cache=shared", identifier), + File(path) => path + .to_str() + .expect("Invalid non-UTF8 character in database path") + .to_owned(), + } + } +} + +#[derive(Clone)] +pub struct DbConnection { + inner: Arc<Mutex<SqliteConnection>>, +} + +impl DbConnection { + #[cfg(test)] + pub async fn connect_memory(name: String) -> Result<Self, StorageError> { + Self::connect_url(DbConnectionUrl::MemoryShared(name)).await + } + + pub async fn connect_url(db_url: DbConnectionUrl) -> Result<Self, StorageError> { + let conn = task::spawn_blocking(move || { + let conn = SqliteConnection::establish(&db_url.to_url_string())?; + conn.execute("PRAGMA foreign_keys = ON; PRAGMA busy_timeout = 60000;")?; + Result::<_, StorageError>::Ok(conn) + }) + .await??; + + Ok(Self::new(conn)) + } + + fn new(conn: SqliteConnection) -> Self { + Self { + inner: Arc::new(Mutex::new(conn)), + } + } + + pub async fn migrate(&self) -> Result<String, StorageError> { + embed_migrations!("./migrations"); + + self.with_connection_async(|conn| { + let mut buf = io::Cursor::new(Vec::new()); + embedded_migrations::run_with_output(conn, &mut buf) + .map_err(|err| StorageError::DatabaseMigrationFailed(format!("Database migration failed {}", err)))?; + Ok(String::from_utf8_lossy(&buf.into_inner()).to_string()) + }) + .await + } + + pub async fn with_connection_async<F, R>(&self, f: F) -> Result<R, StorageError> + where + F: FnOnce(&SqliteConnection) -> Result<R, StorageError> + Send + 'static, + R: Send + 'static, + { + let conn_mutex = self.inner.clone(); + let ret = task::spawn_blocking(move || { + let lock = acquire_lock!(conn_mutex); + f(&*lock) + }) + .await??; + Ok(ret) + } +} + +#[cfg(test)] +mod test { + use super::*; + use tari_test_utils::random; + + #[tokio_macros::test_basic] + async fn connect_and_migrate() {
+ let conn = DbConnection::connect_memory(random::string(8)).await.unwrap(); + let output = conn.migrate().await.unwrap(); + assert!(output.starts_with("Running migration")); + } +} diff --git a/comms/dht/src/storage/database.rs b/comms/dht/src/storage/database.rs new file mode 100644 index 0000000000..c24f7fb236 --- /dev/null +++ b/comms/dht/src/storage/database.rs @@ -0,0 +1,77 @@ +// Copyright 2020, The Tari Project +// +// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the +// following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following +// disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the +// following disclaimer in the documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote +// products derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +use super::{dht_setting_entry::DhtSettingsEntry, DbConnection, StorageError}; +use crate::{ + schema::dht_settings, + storage::{dht_setting_entry::NewDhtSettingEntry, DhtSettingKey}, +}; +use diesel::{ExpressionMethods, QueryDsl, RunQueryDsl}; +use tari_crypto::tari_utilities::message_format::MessageFormat; + +#[derive(Clone)] +pub struct DhtDatabase { + connection: DbConnection, +} + +impl DhtDatabase { + pub fn new(connection: DbConnection) -> Self { + Self { connection } + } + + pub async fn get_value(&self, key: DhtSettingKey) -> Result, StorageError> { + match self.get_value_bytes(key).await? 
{ + Some(bytes) => T::from_binary(&bytes).map(Some).map_err(Into::into), + None => Ok(None), + } + } + + pub async fn get_value_bytes(&self, key: DhtSettingKey) -> Result>, StorageError> { + self.connection + .with_connection_async(move |conn| { + dht_settings::table + .filter(dht_settings::key.eq(key.to_string())) + .first(conn) + .map(|rec: DhtSettingsEntry| Some(rec.value)) + .or_else(|err| match err { + diesel::result::Error::NotFound => Ok(None), + err => Err(err.into()), + }) + }) + .await + } + + pub async fn set_value(&self, key: DhtSettingKey, value: Vec) -> Result<(), StorageError> { + self.connection + .with_connection_async(move |conn| { + diesel::replace_into(dht_settings::table) + .values(NewDhtSettingEntry { + key: key.to_string(), + value, + }) + .execute(conn) + .map(|_| ()) + .map_err(Into::into) + }) + .await + } +} diff --git a/comms/dht/src/storage/dht_setting_entry.rs b/comms/dht/src/storage/dht_setting_entry.rs new file mode 100644 index 0000000000..4733149b74 --- /dev/null +++ b/comms/dht/src/storage/dht_setting_entry.rs @@ -0,0 +1,51 @@ +// Copyright 2020, The Tari Project +// +// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the +// following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following +// disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the +// following disclaimer in the documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote +// products derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +use crate::schema::dht_settings; +use std::fmt; + +#[derive(Debug, Clone, Copy)] +pub enum DhtSettingKey { + /// The timestamp of the last time this node made a SAF request + SafLastRequestTimestamp, +} + +impl fmt::Display for DhtSettingKey { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{:?}", self) + } +} + +#[derive(Clone, Debug, Insertable)] +#[table_name = "dht_settings"] +pub struct NewDhtSettingEntry { + pub key: String, + pub value: Vec, +} + +#[derive(Clone, Debug, Queryable, Identifiable)] +#[table_name = "dht_settings"] +pub struct DhtSettingsEntry { + pub id: i32, + pub key: String, + pub value: Vec, +} diff --git a/comms/dht/src/storage/error.rs b/comms/dht/src/storage/error.rs new file mode 100644 index 0000000000..3706da7bbb --- /dev/null +++ b/comms/dht/src/storage/error.rs @@ -0,0 +1,37 @@ +// Copyright 2020, The Tari Project +// +// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the +// following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following +// disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the +// following disclaimer in the documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote +// products derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +use derive_error::Error; +use tari_crypto::tari_utilities::message_format::MessageFormatError; +use tokio::task; + +#[derive(Debug, Error)] +pub enum StorageError { + /// Database path contained non-UTF8 characters that are not supported by the host OS + InvalidUnicodePath, + JoinError(task::JoinError), + ConnectionError(diesel::ConnectionError), + #[error(msg_embedded, no_from, non_std)] + DatabaseMigrationFailed(String), + ResultError(diesel::result::Error), + MessageFormatError(MessageFormatError), +} diff --git a/comms/dht/src/storage/mod.rs b/comms/dht/src/storage/mod.rs new file mode 100644 index 0000000000..39f113762b --- /dev/null +++ b/comms/dht/src/storage/mod.rs @@ -0,0 +1,33 @@ +// Copyright 2019. The Tari Project +// +// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the +// following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following +// disclaimer. +// +// 2. 
Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the +// following disclaimer in the documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote +// products derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +mod connection; +pub use connection::{DbConnection, DbConnectionUrl}; + +mod error; +pub use error::StorageError; + +mod dht_setting_entry; +pub use dht_setting_entry::{DhtSettingKey, DhtSettingsEntry}; + +mod database; +pub use database::DhtDatabase; diff --git a/comms/dht/src/store_forward/database/mod.rs b/comms/dht/src/store_forward/database/mod.rs new file mode 100644 index 0000000000..12b772a1d5 --- /dev/null +++ b/comms/dht/src/store_forward/database/mod.rs @@ -0,0 +1,165 @@ +// Copyright 2020, The Tari Project +// +// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the +// following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following +// disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the +// following disclaimer in the documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote +// products derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +mod stored_message; +pub use stored_message::{NewStoredMessage, StoredMessage}; + +use crate::{ + envelope::DhtMessageType, + schema::stored_messages, + storage::{DbConnection, StorageError}, + store_forward::message::StoredMessagePriority, +}; +use chrono::{DateTime, NaiveDateTime, Utc}; +use diesel::{BoolExpressionMethods, ExpressionMethods, QueryDsl, RunQueryDsl}; +use tari_comms::types::CommsPublicKey; +use tari_crypto::tari_utilities::hex::Hex; + +pub struct StoreAndForwardDatabase { + connection: DbConnection, +} + +impl StoreAndForwardDatabase { + pub fn new(connection: DbConnection) -> Self { + Self { connection } + } + + pub async fn insert_message(&self, message: NewStoredMessage) -> Result<(), StorageError> { + self.connection + .with_connection_async(|conn| { + diesel::insert_into(stored_messages::table) + .values(message) + .execute(conn)?; + Ok(()) + }) + .await + } + + pub async fn find_messages_for_public_key( + &self, + public_key: &CommsPublicKey, + since: Option>, + limit: i64, + ) -> Result, StorageError> + { + let pk_hex = public_key.to_hex(); + self.connection + .with_connection_async(move |conn| { + let mut query = stored_messages::table + .select(stored_messages::all_columns) + .filter(stored_messages::destination_pubkey.eq(pk_hex)) + .filter(stored_messages::message_type.eq(DhtMessageType::None as i32)) + .into_boxed(); + + if let Some(since) = since { + query = query.filter(stored_messages::stored_at.ge(since.naive_utc())); + } + + query + .order_by(stored_messages::stored_at.desc()) + .limit(limit) + .get_results(conn) + .map_err(Into::into) + }) + .await + } + + pub async fn find_messages_of_type_for_pubkey( + &self, + public_key: &CommsPublicKey, + message_type: DhtMessageType, + since: Option>, + limit: i64, + ) -> Result, StorageError> + { + let pk_hex = public_key.to_hex(); + self.connection + .with_connection_async(move |conn| { + let mut query = stored_messages::table + .select(stored_messages::all_columns) + .filter( + stored_messages::destination_pubkey + .eq(pk_hex) + .or(stored_messages::destination_pubkey.is_null()), + ) + .filter(stored_messages::message_type.eq(message_type as i32)) + .into_boxed(); + + if let Some(since) = since { + query = query.filter(stored_messages::stored_at.ge(since.naive_utc())); + } + + query + .order_by(stored_messages::stored_at.desc()) + .limit(limit) + .get_results(conn) + .map_err(Into::into) + }) + .await + } + + #[cfg(test)] + pub(crate) async fn get_all_messages(&self) -> Result, StorageError> { + self.connection + .with_connection_async(|conn| { + stored_messages::table + .select(stored_messages::all_columns) + .get_results(conn) + .map_err(Into::into) + }) + .await + } + + pub(crate) async fn delete_messages_with_priority_older_than( + &self, + priority: StoredMessagePriority, + since: NaiveDateTime, + ) -> Result + { + self.connection + .with_connection_async(move |conn| { + diesel::delete(stored_messages::table) + .filter(stored_messages::stored_at.lt(since)) + .filter(stored_messages::priority.eq(priority as i32)) + .execute(conn) + .map_err(Into::into) + }) + .await + } +} + +#[cfg(test)] +mod test { + use super::*; + use tari_test_utils::random; + + #[tokio_macros::test_basic] + async fn insert_messages() { + let conn = DbConnection::connect_memory(random::string(8)).await.unwrap(); + // let conn = DbConnection::connect_path("/tmp/tmp.db").await.unwrap(); + conn.migrate().await.unwrap(); + let db = StoreAndForwardDatabase::new(conn); + db.insert_message(Default::default()).await.unwrap(); + let messages = 
db.get_all_messages().await.unwrap(); + assert_eq!(messages.len(), 1); + } +} diff --git a/comms/dht/src/store_forward/database/stored_message.rs b/comms/dht/src/store_forward/database/stored_message.rs new file mode 100644 index 0000000000..d3d9826674 --- /dev/null +++ b/comms/dht/src/store_forward/database/stored_message.rs @@ -0,0 +1,89 @@ +// Copyright 2020, The Tari Project +// +// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the +// following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following +// disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the +// following disclaimer in the documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote +// products derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +use crate::{ + envelope::DhtMessageHeader, + proto::envelope::DhtHeader, + schema::stored_messages, + store_forward::message::StoredMessagePriority, +}; +use chrono::NaiveDateTime; +use std::convert::TryInto; +use tari_comms::message::MessageExt; +use tari_crypto::tari_utilities::hex::Hex; + +#[derive(Clone, Debug, Insertable, Default)] +#[table_name = "stored_messages"] +pub struct NewStoredMessage { + pub version: i32, + pub origin_pubkey: String, + pub origin_signature: String, + pub message_type: i32, + pub destination_pubkey: Option, + pub destination_node_id: Option, + pub header: Vec, + pub body: Vec, + pub is_encrypted: bool, + pub priority: i32, +} + +impl NewStoredMessage { + pub fn try_construct( + version: u32, + dht_header: DhtMessageHeader, + priority: StoredMessagePriority, + body: Vec, + ) -> Option + { + Some(Self { + version: version.try_into().ok()?, + origin_pubkey: dht_header.origin.as_ref().map(|o| o.public_key.to_hex())?, + origin_signature: dht_header.origin.as_ref().map(|o| o.signature.to_hex())?, + message_type: dht_header.message_type as i32, + destination_pubkey: dht_header.destination.public_key().map(|pk| pk.to_hex()), + destination_node_id: dht_header.destination.node_id().map(|node_id| node_id.to_hex()), + body, + is_encrypted: dht_header.flags.is_encrypted(), + priority: priority as i32, + header: { + let dht_header: DhtHeader = dht_header.into(); + dht_header.to_encoded_bytes().ok()? 
+ }, + }) + } +} + +#[derive(Clone, Debug, Queryable, Identifiable)] +pub struct StoredMessage { + pub id: i32, + pub version: i32, + pub origin_pubkey: String, + pub origin_signature: String, + pub message_type: i32, + pub destination_pubkey: Option, + pub destination_node_id: Option, + pub header: Vec, + pub body: Vec, + pub is_encrypted: bool, + pub priority: i32, + pub stored_at: NaiveDateTime, +} diff --git a/comms/dht/src/store_forward/error.rs b/comms/dht/src/store_forward/error.rs index dc0f084949..a5573b11f6 100644 --- a/comms/dht/src/store_forward/error.rs +++ b/comms/dht/src/store_forward/error.rs @@ -20,12 +20,12 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::{actor::DhtActorError, envelope::DhtMessageError, outbound::DhtOutboundError}; +use crate::{actor::DhtActorError, envelope::DhtMessageError, outbound::DhtOutboundError, storage::StorageError}; use derive_error::Error; use prost::DecodeError; use std::io; use tari_comms::{message::MessageError, peer_manager::PeerManagerError}; -use tari_crypto::tari_utilities::ciphers::cipher::CipherError; +use tari_crypto::tari_utilities::{byte_array::ByteArrayError, ciphers::cipher::CipherError}; #[derive(Debug, Error)] pub enum StoreAndForwardError { @@ -58,4 +58,15 @@ pub enum StoreAndForwardError { MessageOriginRequired, /// The message was malformed MalformedMessage, + + StorageError(StorageError), + /// The store and forward service requester channel closed + RequesterChannelClosed, + /// The request was cancelled by the store and forward service + RequestCancelled, + /// The message was not valid for store and forward + InvalidStoreMessage, + /// The envelope version is invalid + InvalidEnvelopeVersion, + MalformedNodeId(ByteArrayError), } diff --git a/comms/dht/src/store_forward/message.rs b/comms/dht/src/store_forward/message.rs index 3c98383b82..5cdcc01c3b 100644 --- a/comms/dht/src/store_forward/message.rs +++ b/comms/dht/src/store_forward/message.rs @@ -21,13 +21,22 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
use crate::{ - envelope::DhtMessageHeader, - proto::store_forward::{StoredMessage, StoredMessagesRequest, StoredMessagesResponse}, + proto::{ + envelope::DhtHeader, + store_forward::{StoredMessage, StoredMessagesRequest, StoredMessagesResponse}, + }, + store_forward::{database, StoreAndForwardError}, }; -use chrono::{DateTime, Utc}; +use chrono::{DateTime, NaiveDateTime, Utc}; +use prost::Message; use prost_types::Timestamp; +use rand::{rngs::OsRng, RngCore}; +use std::{ + cmp, + convert::{TryFrom, TryInto}, +}; -/// Utility function that converts a `chrono::DateTime` to a `prost::Timestamp` +/// Utility function that converts a `chrono::DateTime` to a `prost::Timestamp` pub(crate) fn datetime_to_timestamp(datetime: DateTime) -> Timestamp { Timestamp { seconds: datetime.timestamp(), @@ -35,26 +44,54 @@ pub(crate) fn datetime_to_timestamp(datetime: DateTime) -> Timestamp { } } +/// Utility function that converts a `prost::Timestamp` to a `chrono::DateTime` +pub(crate) fn timestamp_to_datetime(timestamp: Timestamp) -> DateTime { + let naive = NaiveDateTime::from_timestamp(timestamp.seconds, cmp::max(0, timestamp.nanos) as u32); + DateTime::from_utc(naive, Utc) +} + impl StoredMessagesRequest { + pub fn new() -> Self { + Self { + since: None, + request_id: OsRng.next_u32(), + } + } + pub fn since(since: DateTime) -> Self { Self { since: Some(datetime_to_timestamp(since)), + request_id: OsRng.next_u32(), } } } +#[cfg(test)] impl StoredMessage { - pub fn new(version: u32, dht_header: DhtMessageHeader, encrypted_body: Vec) -> Self { + pub fn new(version: u32, dht_header: crate::envelope::DhtMessageHeader, encrypted_body: Vec) -> Self { Self { version, dht_header: Some(dht_header.into()), - encrypted_body, + body: encrypted_body, stored_at: Some(datetime_to_timestamp(Utc::now())), } } +} - pub fn has_required_fields(&self) -> bool { - self.dht_header.is_some() +impl TryFrom for StoredMessage { + type Error = StoreAndForwardError; + + fn try_from(message: database::StoredMessage) -> Result { + let dht_header = DhtHeader::decode(message.header.as_slice())?; + Ok(Self { + stored_at: Some(datetime_to_timestamp(DateTime::from_utc(message.stored_at, Utc))), + version: message + .version + .try_into() + .map_err(|_| StoreAndForwardError::InvalidEnvelopeVersion)?, + body: message.body, + dht_header: Some(dht_header), + }) } } @@ -64,8 +101,8 @@ impl StoredMessagesResponse { } } -impl From> for StoredMessagesResponse { - fn from(messages: Vec) -> Self { - Self { messages } - } +#[derive(Debug, Copy, Clone)] +pub enum StoredMessagePriority { + Low = 1, + High = 10, } diff --git a/comms/dht/src/store_forward/mod.rs b/comms/dht/src/store_forward/mod.rs index 9b9e08b282..aa8d9a91e9 100644 --- a/comms/dht/src/store_forward/mod.rs +++ b/comms/dht/src/store_forward/mod.rs @@ -20,17 +20,24 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+type SafResult = Result; + +mod service; +pub use service::{StoreAndForwardRequest, StoreAndForwardRequester, StoreAndForwardService}; + +mod database; +pub use database::StoredMessage; + mod error; +pub use error::StoreAndForwardError; + mod forward; +pub use forward::ForwardLayer; + mod message; + mod saf_handler; -mod state; -mod store; +pub use saf_handler::MessageHandlerLayer; -pub use self::{ - error::StoreAndForwardError, - forward::ForwardLayer, - saf_handler::MessageHandlerLayer, - state::SafStorage, - store::StoreLayer, -}; +mod store; +pub use store::StoreLayer; diff --git a/comms/dht/src/store_forward/saf_handler/layer.rs b/comms/dht/src/store_forward/saf_handler/layer.rs index 6622a3b713..d829d46838 100644 --- a/comms/dht/src/store_forward/saf_handler/layer.rs +++ b/comms/dht/src/store_forward/saf_handler/layer.rs @@ -21,14 +21,19 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use super::middleware::MessageHandlerMiddleware; -use crate::{actor::DhtRequester, config::DhtConfig, outbound::OutboundMessageRequester, store_forward::SafStorage}; +use crate::{ + actor::DhtRequester, + config::DhtConfig, + outbound::OutboundMessageRequester, + store_forward::StoreAndForwardRequester, +}; use std::sync::Arc; use tari_comms::peer_manager::{NodeIdentity, PeerManager}; use tower::layer::Layer; pub struct MessageHandlerLayer { config: DhtConfig, - store: Arc, + saf_requester: StoreAndForwardRequester, dht_requester: DhtRequester, peer_manager: Arc, node_identity: Arc, @@ -38,7 +43,7 @@ pub struct MessageHandlerLayer { impl MessageHandlerLayer { pub fn new( config: DhtConfig, - store: Arc, + saf_requester: StoreAndForwardRequester, dht_requester: DhtRequester, node_identity: Arc, peer_manager: Arc, @@ -47,7 +52,7 @@ impl MessageHandlerLayer { { Self { config, - store, + saf_requester, dht_requester, node_identity, peer_manager, @@ -63,7 +68,7 @@ impl Layer for MessageHandlerLayer { MessageHandlerMiddleware::new( self.config.clone(), service, - Arc::clone(&self.store), + self.saf_requester.clone(), self.dht_requester.clone(), Arc::clone(&self.node_identity), Arc::clone(&self.peer_manager), diff --git a/comms/dht/src/store_forward/saf_handler/middleware.rs b/comms/dht/src/store_forward/saf_handler/middleware.rs index 94736306b0..2a01ac713f 100644 --- a/comms/dht/src/store_forward/saf_handler/middleware.rs +++ b/comms/dht/src/store_forward/saf_handler/middleware.rs @@ -26,7 +26,7 @@ use crate::{ config::DhtConfig, inbound::DecryptedDhtMessage, outbound::OutboundMessageRequester, - store_forward::SafStorage, + store_forward::StoreAndForwardRequester, }; use futures::{task::Context, Future}; use std::{sync::Arc, task::Poll}; @@ -40,7 +40,7 @@ use tower::Service; pub struct MessageHandlerMiddleware { config: DhtConfig, next_service: S, - store: Arc, + saf_requester: StoreAndForwardRequester, dht_requester: DhtRequester, peer_manager: Arc, node_identity: Arc, @@ -51,7 +51,7 @@ impl MessageHandlerMiddleware { pub fn new( config: DhtConfig, next_service: S, - store: Arc, + saf_requester: StoreAndForwardRequester, dht_requester: DhtRequester, node_identity: Arc, peer_manager: Arc, @@ -60,7 +60,7 @@ impl MessageHandlerMiddleware { { Self { config, - store, + saf_requester, dht_requester, next_service, node_identity, @@ -86,7 +86,7 @@ where S: Service + Cl MessageHandlerTask::new( self.config.clone(), self.next_service.clone(), - Arc::clone(&self.store), + self.saf_requester.clone(), self.dht_requester.clone(), Arc::clone(&self.peer_manager), 
self.outbound_service.clone(), diff --git a/comms/dht/src/store_forward/saf_handler/task.rs b/comms/dht/src/store_forward/saf_handler/task.rs index bcbf3696da..9a0274037d 100644 --- a/comms/dht/src/store_forward/saf_handler/task.rs +++ b/comms/dht/src/store_forward/saf_handler/task.rs @@ -24,15 +24,28 @@ use crate::{ actor::DhtRequester, config::DhtConfig, crypt, - envelope::{Destination, DhtMessageFlags, DhtMessageHeader, DhtMessageOrigin, NodeDestination}, + envelope::{DhtMessageFlags, DhtMessageHeader, DhtMessageOrigin, NodeDestination}, inbound::{DecryptedDhtMessage, DhtInboundMessage}, outbound::{OutboundMessageRequester, SendMessageParams}, proto::{ envelope::DhtMessageType, - store_forward::{StoredMessage, StoredMessagesRequest, StoredMessagesResponse}, + store_forward::{ + stored_messages_response::SafResponseType, + StoredMessage as ProtoStoredMessage, + StoredMessagesRequest, + StoredMessagesResponse, + }, + }, + storage::DhtSettingKey, + store_forward::{ + error::StoreAndForwardError, + message::{datetime_to_timestamp, timestamp_to_datetime}, + service::FetchStoredMessageQuery, + StoreAndForwardRequester, }, - store_forward::{error::StoreAndForwardError, SafStorage}, + utils::try_convert_all, }; +use chrono::Utc; use digest::Digest; use futures::{future, stream, Future, StreamExt}; use log::*; @@ -45,7 +58,6 @@ use tari_comms::{ types::Challenge, utils::signature, }; -use tari_crypto::tari_utilities::ByteArray; use tower::{Service, ServiceExt}; const LOG_TARGET: &str = "comms::dht::store_forward"; @@ -58,7 +70,7 @@ pub struct MessageHandlerTask { outbound_service: OutboundMessageRequester, node_identity: Arc, message: Option, - store: Arc, + saf_requester: StoreAndForwardRequester, } impl MessageHandlerTask @@ -68,7 +80,7 @@ where S: Service pub fn new( config: DhtConfig, next_service: S, - store: Arc, + saf_requester: StoreAndForwardRequester, dht_requester: DhtRequester, peer_manager: Arc, outbound_service: OutboundMessageRequester, @@ -78,7 +90,7 @@ where S: Service { Self { config, - store, + saf_requester, dht_requester, next_service, peer_manager, @@ -94,10 +106,10 @@ where S: Service .take() .expect("DhtInboundMessageTask initialized without message"); - if message.dht_header.message_type.is_dht_message() && message.decryption_failed() { + if message.dht_header.message_type.is_saf_message() && message.decryption_failed() { debug!( target: LOG_TARGET, - "Received SAFRetrieveMessages message which could not decrypt from NodeId={}. Discarding message.", + "Received store and forward message which could not decrypt from NodeId={}. Discarding message.", message.source_peer.node_id ); return Ok(()); @@ -157,59 +169,60 @@ where S: Service return Ok(()); } + let source_pubkey = Box::new(message.source_peer.public_key.clone()); + // Compile a set of stored messages for the requesting peer - let messages = self.store.with_lock(|mut store| { - store - .iter() - // All messages within start_time (if specified) - .filter(|(_, msg)| { - retrieve_msgs.since.as_ref().map(|since| msg.stored_at.as_ref().map(|s| since.seconds <= s.seconds).unwrap_or( false)).unwrap_or( true) - }) - .filter(|(_, msg)|{ - if msg.dht_header.is_none() { - warn!(target: LOG_TARGET, "Message was stored without a header. This should never happen!"); - return false; - } - let dht_header = msg.dht_header.as_ref().expect("previously checked"); - - match &dht_header.destination { - None=> false, - // The stored message was sent with an undisclosed recipient. 
Perhaps this node - // is interested in it - Some(Destination::Unknown(_)) => true, - // Was the stored message sent for the requesting node public key? - Some(Destination::PublicKey(pk)) => pk.as_slice() == message.source_peer.public_key.as_bytes(), - // Was the stored message sent for the requesting node node id? - Some( Destination::NodeId(node_id)) => node_id.as_slice() == message.source_peer.node_id.as_bytes(), - } - }) - .take(self.config.saf_max_returned_messages) - .map(|(_, msg)| msg) - .cloned() - .collect::>() - }); + let mut query = FetchStoredMessageQuery::new(source_pubkey); + if let Some(since) = retrieve_msgs.since.map(timestamp_to_datetime) { + query.since(since); + } - let stored_messages: StoredMessagesResponse = messages.into(); + let response_types = vec![ + SafResponseType::Discovery, + SafResponseType::Join, + SafResponseType::ExplicitlyAddressed, + ]; - trace!( - target: LOG_TARGET, - "Responding to received message retrieval request with {} message(s)", - stored_messages.messages().len() - ); - self.outbound_service - .send_message_no_header( - SendMessageParams::new() - .direct_public_key(message.source_peer.public_key.clone()) - .with_dht_message_type(DhtMessageType::SafStoredMessages) - .finish(), - stored_messages, - ) - .await?; + for resp_type in response_types { + query.with_response_type(resp_type); + let messages = self.saf_requester.fetch_messages(query.clone()).await?; + + if messages.is_empty() { + debug!( + target: LOG_TARGET, + "No {:?} stored messages for peer '{}'", + resp_type, + message.source_peer.node_id.short_str() + ); + continue; + } + + let stored_messages = StoredMessagesResponse { + messages: try_convert_all(messages)?, + request_id: retrieve_msgs.request_id, + response_type: resp_type as i32, + }; + + info!( + target: LOG_TARGET, + "Responding to received message retrieval request with {} message(s)", + stored_messages.messages().len() + ); + self.outbound_service + .send_message_no_header( + SendMessageParams::new() + .direct_public_key(message.source_peer.public_key.clone()) + .with_dht_message_type(DhtMessageType::SafStoredMessages) + .finish(), + stored_messages, + ) + .await?; + } Ok(()) } - async fn handle_stored_messages(self, message: DecryptedDhtMessage) -> Result<(), StoreAndForwardError> { + async fn handle_stored_messages(mut self, message: DecryptedDhtMessage) -> Result<(), StoreAndForwardError> { trace!( target: LOG_TARGET, "Received stored messages from {}", @@ -230,6 +243,38 @@ where S: Service response.messages().len() ); + let last_timestamp = self + .dht_requester + .get_setting(DhtSettingKey::SafLastRequestTimestamp) + .await? 
+ .map(datetime_to_timestamp); + + let max_stored_timestamp = response + .messages() + .iter() + .fold(last_timestamp, |acc, m| match &acc { + Some(since) => { + match &m.stored_at { + Some(ts) => { + // If this timestamp is greater than the last one we have + if ts.seconds > since.seconds && ts.seconds < Utc::now().timestamp() { + m.stored_at.clone() + } else { + acc + } + }, + None => acc, + } + }, + None => m.stored_at.clone(), + }) + .map(timestamp_to_datetime) + .unwrap_or_else(Utc::now); + + self.dht_requester + .set_setting(DhtSettingKey::SafLastRequestTimestamp, max_stored_timestamp) + .await?; + let tasks = response .messages .into_iter() @@ -309,7 +354,7 @@ where S: Service fn process_incoming_stored_message( &self, source_peer: Arc, - message: StoredMessage, + message: ProtoStoredMessage, ) -> impl Future> { let node_identity = Arc::clone(&self.node_identity); @@ -347,15 +392,14 @@ where S: Service // Check that the destination is either undisclosed Self::check_destination(&config, &peer_manager, &node_identity, &dht_header).await?; // Verify the signature - Self::check_signature(origin, &message.encrypted_body)?; + Self::check_signature(origin, &message.body)?; // Check that the message has not already been received. - Self::check_duplicate(&mut dht_requester, &message.encrypted_body).await?; + Self::check_duplicate(&mut dht_requester, &message.body).await?; // Attempt to decrypt the message (if applicable), and deserialize it - let decrypted_body = - Self::maybe_decrypt_and_deserialize(&node_identity, origin, dht_flags, &message.encrypted_body)?; + let decrypted_body = Self::maybe_decrypt_and_deserialize(&node_identity, origin, dht_flags, &message.body)?; - let inbound_msg = DhtInboundMessage::new(dht_header, Arc::clone(&source_peer), message.encrypted_body); + let inbound_msg = DhtInboundMessage::new(dht_header, Arc::clone(&source_peer), message.body); Ok(DecryptedDhtMessage::succeeded(decrypted_body, inbound_msg)) } @@ -430,9 +474,12 @@ mod test { use super::*; use crate::{ envelope::DhtMessageFlags, - store_forward::message::datetime_to_timestamp, + proto::envelope::DhtHeader, + store_forward::{message::StoredMessagePriority, StoredMessage}, test_utils::{ create_dht_actor_mock, + create_store_and_forward_mock, + make_dht_header, make_dht_inbound_message, make_node_identity, make_peer_manager, @@ -443,17 +490,34 @@ mod test { use chrono::Utc; use futures::channel::mpsc; use prost::Message; - use std::time::Duration; use tari_comms::{message::MessageExt, wrap_in_envelope_body}; + use tari_crypto::tari_utilities::hex::Hex; use tokio::runtime::Handle; // TODO: unit tests for static functions (check_signature, etc) + fn make_stored_message(node_identity: &NodeIdentity, dht_header: DhtMessageHeader) -> StoredMessage { + StoredMessage { + id: 1, + version: 0, + origin_pubkey: node_identity.public_key().to_hex(), + origin_signature: String::new(), + message_type: DhtMessageType::None as i32, + destination_pubkey: None, + destination_node_id: None, + header: DhtHeader::from(dht_header).to_encoded_bytes().unwrap(), + body: b"A".to_vec(), + is_encrypted: false, + priority: StoredMessagePriority::High as i32, + stored_at: Utc::now().naive_utc(), + } + } + #[tokio_macros::test_basic] async fn request_stored_messages() { let rt_handle = Handle::current(); let spy = service_spy(); - let storage = Arc::new(SafStorage::new(10)); + let (requester, mock_state) = create_store_and_forward_mock(); let peer_manager = make_peer_manager(); let (oms_tx, mut oms_rx) = mpsc::channel(1); @@ -461,33 
+525,14 @@ mod test { let node_identity = make_node_identity(); // Recent message - let inbound_msg = make_dht_inbound_message(&node_identity, vec![], DhtMessageFlags::empty()); - storage.insert( - vec![0], - StoredMessage::new(0, inbound_msg.dht_header, b"A".to_vec()), - Duration::from_secs(60), - ); - - // Expired message - let inbound_msg = make_dht_inbound_message(&node_identity, vec![], DhtMessageFlags::empty()); - storage.insert( - vec![1], - StoredMessage::new(0, inbound_msg.dht_header, vec![]), - Duration::from_secs(0), - ); - - // Out of time range - let inbound_msg = make_dht_inbound_message(&node_identity, vec![], DhtMessageFlags::empty()); - let mut msg = StoredMessage::new(0, inbound_msg.dht_header, vec![]); - msg.stored_at = Some(datetime_to_timestamp( - Utc::now().checked_sub_signed(chrono::Duration::days(1)).unwrap(), - )); + let dht_header = make_dht_header(&node_identity, &[], DhtMessageFlags::empty()); + mock_state + .add_message(make_stored_message(&node_identity, dht_header)) + .await; + let since = Utc::now().checked_sub_signed(chrono::Duration::seconds(60)).unwrap(); let mut message = DecryptedDhtMessage::succeeded( - wrap_in_envelope_body!(StoredMessagesRequest::since( - Utc::now().checked_sub_signed(chrono::Duration::seconds(60)).unwrap() - )) - .unwrap(), + wrap_in_envelope_body!(StoredMessagesRequest::since(since)).unwrap(), make_dht_inbound_message(&node_identity, vec![], DhtMessageFlags::ENCRYPTED), ); message.dht_header.message_type = DhtMessageType::SafRequestMessages; @@ -498,11 +543,11 @@ mod test { let task = MessageHandlerTask::new( Default::default(), spy.to_service::(), - storage, + requester, dht_requester, peer_manager, OutboundMessageRequester::new(oms_tx), - node_identity, + node_identity.clone(), message, ); @@ -512,15 +557,21 @@ mod test { let body = EnvelopeBody::decode(body.as_slice()).unwrap(); let msg = body.decode_part::(0).unwrap().unwrap(); assert_eq!(msg.messages().len(), 1); - assert_eq!(msg.messages()[0].encrypted_body, b"A"); + assert_eq!(msg.messages()[0].body, b"A"); assert!(!spy.is_called()); + + assert_eq!(mock_state.call_count(), 1); + let calls = mock_state.take_calls().await; + assert!(calls[0].contains("FetchMessages")); + assert!(calls[0].contains(node_identity.public_key().to_hex().as_str())); + assert!(calls[0].contains(format!("{:?}", since).as_str())); } #[tokio_macros::test_basic] async fn receive_stored_messages() { let rt_handle = Handle::current(); let spy = service_spy(); - let storage = Arc::new(SafStorage::new(10)); + let (requester, _) = create_store_and_forward_mock(); let peer_manager = make_peer_manager(); let (oms_tx, _) = mpsc::channel(1); @@ -559,8 +610,8 @@ mod test { .await .unwrap(); - let msg1 = StoredMessage::new(0, inbound_msg_a.dht_header.clone(), msg_a); - let msg2 = StoredMessage::new(0, inbound_msg_b.dht_header, msg_b); + let msg1 = ProtoStoredMessage::new(0, inbound_msg_a.dht_header.clone(), msg_a); + let msg2 = ProtoStoredMessage::new(0, inbound_msg_b.dht_header, msg_b); // Cleartext message let clear_msg = wrap_in_envelope_body!(b"Clear".to_vec()) .unwrap() @@ -568,10 +619,12 @@ mod test { .unwrap(); let clear_header = make_dht_inbound_message(&node_identity, clear_msg.clone(), DhtMessageFlags::empty()).dht_header; - let msg_clear = StoredMessage::new(0, clear_header, clear_msg); + let msg_clear = ProtoStoredMessage::new(0, clear_header, clear_msg); let mut message = DecryptedDhtMessage::succeeded( wrap_in_envelope_body!(StoredMessagesResponse { messages: vec![msg1.clone(), msg2, msg_clear], + 
request_id: 123, + response_type: 0 }) .unwrap(), make_dht_inbound_message(&node_identity, vec![], DhtMessageFlags::ENCRYPTED), @@ -586,7 +639,7 @@ mod test { let task = MessageHandlerTask::new( Default::default(), spy.to_service::(), - storage, + requester, dht_requester, peer_manager, OutboundMessageRequester::new(oms_tx), @@ -601,11 +654,11 @@ mod test { // Deserialize each request into the message (a vec of a single byte in this case) let msgs = requests .into_iter() - .map(|req| req.success().unwrap().decode_part::>(0).unwrap().unwrap()) + .map(|req| req.success().unwrap().decode_part::>(0).unwrap().unwrap()) .collect::>>(); assert!(msgs.contains(&b"A".to_vec())); assert!(msgs.contains(&b"B".to_vec())); assert!(msgs.contains(&b"Clear".to_vec())); - assert_eq!(mock_state.call_count(), msgs.len()); + mock_state.get_setting(&DhtSettingKey::SafLastRequestTimestamp).unwrap(); } } diff --git a/comms/dht/src/store_forward/service.rs b/comms/dht/src/store_forward/service.rs new file mode 100644 index 0000000000..fb41681f1b --- /dev/null +++ b/comms/dht/src/store_forward/service.rs @@ -0,0 +1,274 @@ +// Copyright 2020, The Tari Project +// +// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the +// following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following +// disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the +// following disclaimer in the documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote +// products derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +use super::{ + database::{NewStoredMessage, StoreAndForwardDatabase, StoredMessage}, + message::StoredMessagePriority, + SafResult, + StoreAndForwardError, +}; +use crate::{ + envelope::DhtMessageType, + proto::store_forward::stored_messages_response::SafResponseType, + storage::DbConnection, + DhtConfig, +}; +use chrono::{DateTime, NaiveDateTime, Utc}; +use futures::{ + channel::{mpsc, oneshot}, + stream::Fuse, + SinkExt, + StreamExt, +}; +use log::*; +use std::{convert::TryFrom, time::Duration}; +use tari_comms::types::CommsPublicKey; +use tari_shutdown::ShutdownSignal; +use tokio::time; + +const LOG_TARGET: &str = "comms::dht::store_forward::actor"; +/// The interval to initiate a database cleanup. 
+/// This involves cleaning up messages which have been stored too long according to their priority +const CLEANUP_INTERVAL: Duration = Duration::from_secs(10 * 60); // 10 mins + +#[derive(Debug, Clone)] +pub struct FetchStoredMessageQuery { + public_key: Box, + since: Option>, + response_type: SafResponseType, +} + +impl FetchStoredMessageQuery { + pub fn new(public_key: Box) -> Self { + Self { + public_key, + since: None, + response_type: SafResponseType::General, + } + } + + pub fn since(&mut self, since: DateTime) -> &mut Self { + self.since = Some(since); + self + } + + pub fn with_response_type(&mut self, response_type: SafResponseType) -> &mut Self { + self.response_type = response_type; + self + } +} + +#[derive(Debug)] +pub enum StoreAndForwardRequest { + FetchMessages(FetchStoredMessageQuery, oneshot::Sender>>), + InsertMessage(NewStoredMessage), +} + +#[derive(Clone)] +pub struct StoreAndForwardRequester { + sender: mpsc::Sender, +} + +impl StoreAndForwardRequester { + pub fn new(sender: mpsc::Sender) -> Self { + Self { sender } + } + + pub async fn fetch_messages( + &mut self, + request: FetchStoredMessageQuery, + ) -> Result, StoreAndForwardError> + { + let (reply_tx, reply_rx) = oneshot::channel(); + self.sender + .send(StoreAndForwardRequest::FetchMessages(request, reply_tx)) + .await + .map_err(|_| StoreAndForwardError::RequesterChannelClosed)?; + reply_rx.await.map_err(|_| StoreAndForwardError::RequestCancelled)? + } + + pub async fn insert_message(&mut self, message: NewStoredMessage) -> Result<(), StoreAndForwardError> { + self.sender + .send(StoreAndForwardRequest::InsertMessage(message)) + .await + .map_err(|_| StoreAndForwardError::RequesterChannelClosed)?; + Ok(()) + } +} + +pub struct StoreAndForwardService { + config: DhtConfig, + request_rx: Fuse>, + shutdown_signal: Option, +} + +impl StoreAndForwardService { + pub fn new( + config: DhtConfig, + request_rx: mpsc::Receiver, + shutdown_signal: ShutdownSignal, + ) -> Self + { + Self { + config, + request_rx: request_rx.fuse(), + shutdown_signal: Some(shutdown_signal), + } + } + + pub(crate) async fn connect_database(&self) -> SafResult { + let conn = DbConnection::connect_url(self.config.database_url.clone()).await?; + let output = conn.migrate().await?; + info!(target: LOG_TARGET, "Store and forward database migration:\n{}", output); + Ok(StoreAndForwardDatabase::new(conn)) + } + + pub async fn run(mut self) -> SafResult<()> { + let db = self.connect_database().await?; + let mut shutdown_signal = self + .shutdown_signal + .take() + .expect("StoreAndForwardActor initialized without shutdown_signal"); + + let mut cleanup_ticker = time::interval(CLEANUP_INTERVAL).fuse(); + + // Do initial cleanup to account for time passed since being offline + if let Err(err) = self.cleanup(&db).await { + error!( + target: LOG_TARGET, + "Error when performing store and forward cleanup: {:?}", err + ); + } + + loop { + futures::select! 
{ + request = self.request_rx.select_next_some() => { + self.handle_request(&db, request).await; + }, + + _ = cleanup_ticker.next() => { + if let Err(err) = self.cleanup(&db).await { + error!(target: LOG_TARGET, "Error when performing store and forward cleanup: {:?}", err); + } + }, + + _ = shutdown_signal => { + info!(target: LOG_TARGET, "StoreAndForwardActor is shutting down because the shutdown signal was triggered"); + break; + } + } + } + + Ok(()) + } + + async fn handle_request(&self, db: &StoreAndForwardDatabase, request: StoreAndForwardRequest) { + use StoreAndForwardRequest::*; + match request { + FetchMessages(query, reply_tx) => match self.handle_fetch_message_query(db, query).await { + Ok(messages) => { + let _ = reply_tx.send(Ok(messages)); + }, + Err(err) => { + error!( + target: LOG_TARGET, + "find_messages_by_public_key failed because '{:?}'", err + ); + let _ = reply_tx.send(Err(err)); + }, + }, + InsertMessage(msg) => { + let public_key = msg.destination_pubkey.clone(); + match db.insert_message(msg).await { + Ok(_) => info!( + target: LOG_TARGET, + "Store and forward message stored for public key '{}'", + public_key.unwrap_or_else(|| "".to_string()) + ), + Err(err) => { + error!(target: LOG_TARGET, "insert_message failed because '{:?}'", err); + }, + } + }, + } + } + + async fn handle_fetch_message_query( + &self, + db: &StoreAndForwardDatabase, + query: FetchStoredMessageQuery, + ) -> SafResult> + { + let limit = i64::try_from(self.config.saf_max_returned_messages) + .ok() + .or(Some(std::i64::MAX)) + .unwrap(); + let messages = match query.response_type { + SafResponseType::General => { + db.find_messages_for_public_key(&query.public_key, query.since, limit) + .await? + }, + SafResponseType::Join => { + db.find_messages_of_type_for_pubkey(&query.public_key, DhtMessageType::Join, query.since, limit) + .await? + }, + SafResponseType::Discovery => { + db.find_messages_of_type_for_pubkey(&query.public_key, DhtMessageType::Discovery, query.since, limit) + .await? + }, + SafResponseType::ExplicitlyAddressed => { + db.find_messages_for_public_key(&query.public_key, query.since, limit) + .await? + }, + }; + + Ok(messages) + } + + async fn cleanup(&self, db: &StoreAndForwardDatabase) -> SafResult<()> { + let num_removed = db + .delete_messages_with_priority_older_than( + StoredMessagePriority::Low, + since(self.config.saf_low_priority_msg_storage_ttl), + ) + .await?; + info!(target: LOG_TARGET, "Cleaned {} old low priority messages", num_removed); + + let num_removed = db + .delete_messages_with_priority_older_than( + StoredMessagePriority::High, + since(self.config.saf_high_priority_msg_storage_ttl), + ) + .await?; + info!(target: LOG_TARGET, "Cleaned {} old high priority messages", num_removed); + Ok(()) + } +} + +fn since(period: Duration) -> NaiveDateTime { + use chrono::Duration as OldDuration; + let period = OldDuration::from_std(period).expect("period was out of range for chrono::Duration"); + Utc::now() + .naive_utc() + .checked_sub_signed(period) + .expect("period overflowed when used with checked_sub_signed") +} diff --git a/comms/dht/src/store_forward/store.rs b/comms/dht/src/store_forward/store.rs index 6aca6b2b2f..dd99f2bb94 100644 --- a/comms/dht/src/store_forward/store.rs +++ b/comms/dht/src/store_forward/store.rs @@ -20,11 +20,17 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
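// Illustrative sketch, not part of this patch: the request/reply pattern exposed by
// the service above. Callers clone the cheap `StoreAndForwardRequester` (an mpsc
// sender) and await the oneshot reply; this is how the SAF handler fetches each
// SafResponseType in turn. Types are assumed to be imported as in the files above.

async fn fetch_discovery_messages_for(
    mut requester: StoreAndForwardRequester,
    peer_public_key: CommsPublicKey,
) -> Result<Vec<StoredMessage>, StoreAndForwardError> {
    let mut query = FetchStoredMessageQuery::new(Box::new(peer_public_key));
    query.with_response_type(SafResponseType::Discovery);
    // The service runs the sqlite query on a blocking thread and replies via oneshot
    requester.fetch_messages(query).await
}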
+use super::StoreAndForwardRequester; use crate::{ - envelope::{DhtMessageFlags, NodeDestination}, + envelope::NodeDestination, inbound::DecryptedDhtMessage, - proto::store_forward::StoredMessage, - store_forward::{error::StoreAndForwardError, state::SafStorage}, + proto::dht::JoinMessage, + store_forward::{ + database::NewStoredMessage, + error::StoreAndForwardError, + message::StoredMessagePriority, + SafResult, + }, DhtConfig, }; use futures::{task::Context, Future}; @@ -32,9 +38,10 @@ use log::*; use std::{sync::Arc, task::Poll}; use tari_comms::{ message::MessageExt, - peer_manager::{NodeIdentity, PeerManager}, + peer_manager::{NodeId, NodeIdentity, PeerManager}, pipeline::PipelineError, }; +use tari_crypto::tari_utilities::ByteArray; use tower::{layer::Layer, Service, ServiceExt}; const LOG_TARGET: &str = "comms::middleware::forward"; @@ -44,7 +51,7 @@ pub struct StoreLayer { peer_manager: Arc, config: DhtConfig, node_identity: Arc, - storage: Arc, + saf_requester: StoreAndForwardRequester, } impl StoreLayer { @@ -52,14 +59,14 @@ impl StoreLayer { config: DhtConfig, peer_manager: Arc, node_identity: Arc, - storage: Arc, + saf_requester: StoreAndForwardRequester, ) -> Self { Self { peer_manager, config, node_identity, - storage, + saf_requester, } } } @@ -73,7 +80,7 @@ impl Layer for StoreLayer { self.config.clone(), Arc::clone(&self.peer_manager), Arc::clone(&self.node_identity), - Arc::clone(&self.storage), + self.saf_requester.clone(), ) } } @@ -84,8 +91,7 @@ pub struct StoreMiddleware { config: DhtConfig, peer_manager: Arc, node_identity: Arc, - - storage: Arc, + saf_requester: StoreAndForwardRequester, } impl StoreMiddleware { @@ -94,7 +100,7 @@ impl StoreMiddleware { config: DhtConfig, peer_manager: Arc, node_identity: Arc, - storage: Arc, + saf_requester: StoreAndForwardRequester, ) -> Self { Self { @@ -102,7 +108,7 @@ impl StoreMiddleware { config, peer_manager, node_identity, - storage, + saf_requester, } } } @@ -125,7 +131,7 @@ where S: Service + Cl self.config.clone(), Arc::clone(&self.peer_manager), Arc::clone(&self.node_identity), - Arc::clone(&self.storage), + self.saf_requester.clone(), ) .handle(msg) } @@ -135,7 +141,10 @@ where S: Service + Cl /// to the next service. struct StoreTask { next_service: S, - storage: Option, + peer_manager: Arc, + config: DhtConfig, + node_identity: Arc, + saf_requester: StoreAndForwardRequester, } impl StoreTask { @@ -144,16 +153,14 @@ impl StoreTask { config: DhtConfig, peer_manager: Arc, node_identity: Arc, - storage: Arc, + saf_requester: StoreAndForwardRequester, ) -> Self { Self { - storage: Some(InnerStorage { - config, - peer_manager, - node_identity, - storage, - }), + config, + peer_manager, + node_identity, + saf_requester, next_service, } } @@ -162,95 +169,201 @@ impl StoreTask { impl StoreTask where S: Service { + /// Determine if this is a message we should store for our peers and, if so, store it. + /// + /// The criteria for storing a message is: + /// 1. Messages MUST have a message origin set and be encrypted (Join messages are the exception) + /// 1. Unencrypted Join messages - this increases the knowledge the network has of peers (Low priority) + /// 1. Encrypted Discovery messages - so that nodes are aware of other nodes that are looking for them (High + /// priority) 1. Encrypted messages addressed to the neighbourhood - some node in the neighbourhood may be + /// interested in this message (High priority) 1. 
Encrypted messages addressed to a particular public key or + /// node id that this node knows about async fn handle(mut self, message: DecryptedDhtMessage) -> Result<(), PipelineError> { + if let Some(priority) = self + .get_storage_priority(&message) + .await + .map_err(PipelineError::from_debug)? + { + self.store(priority, message.clone()) + .await + .map_err(PipelineError::from_debug)?; + } + + trace!(target: LOG_TARGET, "Passing message to next service"); + self.next_service.oneshot(message).await?; + + Ok(()) + } + + async fn get_storage_priority(&self, message: &DecryptedDhtMessage) -> SafResult> { + let log_not_eligible = |reason: &str| { + debug!( + target: LOG_TARGET, + "Message from peer '{}' not eligible for SAF storage because {}", + message.source_peer.node_id.short_str(), + reason + ); + }; + + if message.body_size() > self.config.saf_max_message_size { + log_not_eligible(&format!( + "the message body exceeded the maximum storage size (body size={}, max={})", + message.body_size(), + self.config.saf_max_message_size + )); + return Ok(None); + } + + if message.origin_public_key() == self.node_identity.public_key() { + log_not_eligible("not storing message from this node"); + return Ok(None); + } + match message.success() { + // The message decryption was successful, or the message was not encrypted Some(_) => { - // If message was not originally encrypted and has an origin we want to store a copy for others - if message.dht_header.origin.is_some() && !message.dht_header.flags.contains(DhtMessageFlags::ENCRYPTED) - { - debug!( - target: LOG_TARGET, - "Cleartext message sent from origin {}. Adding to SAF storage.", - message.origin_public_key() - ); - let mut storage = self.storage.take().expect("StoreTask initialized without storage"); - storage - .store(message.clone()) - .await - .map_err(PipelineError::from_debug)?; + // If the message doesnt have an origin we wont store it + if !message.has_origin() { + log_not_eligible("it does not have an origin"); + return Ok(None); + } + + // If this node decrypted the message, no need to store it + if message.is_encrypted() { + log_not_eligible("the message was encrypted for this node"); + return Ok(None); } - trace!(target: LOG_TARGET, "Passing message to next service"); - self.next_service.oneshot(message).await?; + // If this is a join message, we may want to store it if it's for our neighbourhood + if message.dht_header.message_type.is_dht_join() { + return match self.get_priority_for_dht_join(message).await? { + Some(priority) => Ok(Some(priority)), + None => { + log_not_eligible("the join message was not considered in this node's neighbourhood"); + Ok(None) + }, + }; + } + + log_not_eligible("it is not an eligible DhtMessageType (e.g. Join)"); + // Otherwise, don't store + Ok(None) }, + // This node could not decrypt the message None => { - if message.dht_header.origin.is_none() { - // TODO: #banheuristic + if !message.has_origin() { + // TODO: #banheuristic - the source should not have propagated this message warn!( target: LOG_TARGET, "Store task received an encrypted message with no source. This message is invalid and should \ not be stored or propagated. Dropping message. Sent by node '{}'", message.source_peer.node_id.short_str() ); - return Ok(()); + return Ok(None); } - debug!( - target: LOG_TARGET, - "Decryption failed for message. Adding to SAF storage." 
- ); - let mut storage = self.storage.take().expect("StoreTask initialized without storage"); - storage.store(message).await.map_err(PipelineError::from_debug)?; + + // The destination of the message will determine if we store it + self.get_priority_by_destination(message).await }, } - - Ok(()) } -} -struct InnerStorage { - peer_manager: Arc, - config: DhtConfig, - node_identity: Arc, - storage: Arc, -} - -impl InnerStorage { - async fn store(&mut self, message: DecryptedDhtMessage) -> Result<(), StoreAndForwardError> { - let DecryptedDhtMessage { - version, - decryption_result, - dht_header, - .. - } = message; + async fn get_priority_for_dht_join( + &self, + message: &DecryptedDhtMessage, + ) -> SafResult> + { + debug_assert!(message.dht_header.message_type.is_dht_join() && !message.is_encrypted()); + + let body = message + .decryption_result + .as_ref() + .expect("already checked that this message is not encrypted"); + let join_msg = body + .decode_part::(0)? + .ok_or_else(|| StoreAndForwardError::InvalidEnvelopeBody)?; + let node_id = NodeId::from_bytes(&join_msg.node_id).map_err(StoreAndForwardError::MalformedNodeId)?; + + // If this join request is for a peer that we'd consider to be a neighbour, store it for other neighbours + if self + .peer_manager + .in_network_region( + &node_id, + self.node_identity.node_id(), + self.config.num_neighbouring_nodes, + ) + .await? + { + return Ok(Some(StoredMessagePriority::Low)); + } - let origin = dht_header.origin.as_ref().expect("already checked"); + Ok(None) + } - let body = match decryption_result { - Ok(body) => body.to_encoded_bytes()?, - Err(encrypted_body) => encrypted_body, + async fn get_priority_by_destination( + &self, + message: &DecryptedDhtMessage, + ) -> SafResult> + { + let log_not_eligible = |reason: &str| { + debug!( + target: LOG_TARGET, + "Message from peer '{}' not eligible for SAF storage because {}", + message.source_peer.node_id.short_str(), + reason + ); }; let peer_manager = &self.peer_manager; let node_identity = &self.node_identity; - match &dht_header.destination { - NodeDestination::Unknown => { - self.storage.insert( - origin.signature.clone(), - StoredMessage::new(version, dht_header, body), - self.config.saf_low_priority_msg_storage_ttl, - ); + if message.dht_header.destination == node_identity.public_key() || + message.dht_header.destination == node_identity.node_id() + { + log_not_eligible("the message is destined for this node"); + return Ok(None); + } + + use NodeDestination::*; + match &message.dht_header.destination { + Unknown => { + // No destination provided, only discovery messages are currently important enough to be stored + if message.dht_header.message_type.is_dht_discovery() { + Ok(Some(StoredMessagePriority::Low)) + } else { + log_not_eligible("destination is unknown, and message is not a Discovery"); + Ok(None) + } }, - NodeDestination::PublicKey(dest_public_key) => { - if peer_manager.exists(&dest_public_key).await { - self.storage.insert( - origin.signature.clone(), - StoredMessage::new(version, dht_header, body), - self.config.saf_high_priority_msg_storage_ttl, - ); + PublicKey(dest_public_key) => { + // If we know the destination peer, keep the message for them + match peer_manager.find_by_public_key(&dest_public_key).await { + Ok(peer) => { + if peer.is_banned() { + log_not_eligible( + "origin peer is banned. 
** This should not happen because it should have been checked \ + earlier in the pipeline **", + ); + Ok(None) + } else if peer.is_offline() || peer.is_recently_offline() { + Ok(Some(StoredMessagePriority::High)) + } else { + // TODO: Could be that we propagated this message to the peer in which case there is no need + // to store the message + Ok(Some(StoredMessagePriority::Low)) + } + }, + Err(err) if err.is_peer_not_found() => { + log_not_eligible(&format!( + "this node does not know the destination public key '{}'", + dest_public_key + )); + Ok(None) + }, + Err(err) => Err(err.into()), } }, - NodeDestination::NodeId(dest_node_id) => { + NodeId(dest_node_id) => { if peer_manager.exists_node_id(&dest_node_id).await || peer_manager .in_network_region( @@ -260,15 +373,43 @@ impl InnerStorage { ) .await? { - self.storage.insert( - origin.signature.clone(), - StoredMessage::new(version, dht_header, body), - self.config.saf_high_priority_msg_storage_ttl, - ); + Ok(Some(StoredMessagePriority::High)) + } else { + log_not_eligible(&format!( + "this node does not know the destination node id '{}' or does not consider it a neighbouring \ + node id", + dest_node_id + )); + Ok(None) } }, + } + } + + async fn store(&mut self, priority: StoredMessagePriority, message: DecryptedDhtMessage) -> SafResult<()> { + let DecryptedDhtMessage { + version, + decryption_result, + dht_header, + .. + } = message; + + let body = match decryption_result { + Ok(body) => body.to_encoded_bytes()?, + Err(encrypted_body) => encrypted_body, }; + debug!( + target: LOG_TARGET, + "Storing message from peer '{}' ({} bytes)", + message.source_peer.node_id.short_str(), + body.len() + ); + + let stored_message = NewStoredMessage::try_construct(version, dht_header, priority, body) + .ok_or_else(|| StoreAndForwardError::InvalidStoreMessage)?; + self.saf_requester.insert_message(stored_message).await?; + Ok(()) } } @@ -278,20 +419,29 @@ mod test { use super::*; use crate::{ envelope::DhtMessageFlags, - test_utils::{make_dht_inbound_message, make_node_identity, make_peer_manager, service_spy}, + proto::envelope::DhtMessageType, + test_utils::{ + create_store_and_forward_mock, + make_dht_inbound_message, + make_node_identity, + make_peer_manager, + service_spy, + }, }; - use chrono::{DateTime, Utc}; - use std::time::{Duration, UNIX_EPOCH}; + use chrono::Utc; + use std::time::Duration; use tari_comms::wrap_in_envelope_body; + use tari_crypto::tari_utilities::hex::Hex; + use tari_test_utils::async_assert_eventually; #[tokio_macros::test_basic] async fn cleartext_message_no_origin() { - let storage = Arc::new(SafStorage::new(1)); + let (requester, mock_state) = create_store_and_forward_mock(); let spy = service_spy(); let peer_manager = make_peer_manager(); let node_identity = make_node_identity(); - let mut service = StoreLayer::new(Default::default(), peer_manager, node_identity, storage.clone()) + let mut service = StoreLayer::new(Default::default(), peer_manager, node_identity, requester) .layer(spy.to_service::()); let mut inbound_msg = make_dht_inbound_message(&make_node_identity(), b"".to_vec(), DhtMessageFlags::empty()); @@ -299,66 +449,95 @@ mod test { let msg = DecryptedDhtMessage::succeeded(wrap_in_envelope_body!(Vec::new()).unwrap(), inbound_msg); service.call(msg).await.unwrap(); assert!(spy.is_called()); - storage.with_lock(|mut lock| { - assert_eq!(lock.iter().count(), 0); - }); + let messages = mock_state.get_messages().await; + assert_eq!(messages.len(), 0); } #[tokio_macros::test_basic] - async fn 
cleartext_message_with_origin() { - let storage = Arc::new(SafStorage::new(1)); + async fn cleartext_join_message() { + let (requester, mock_state) = create_store_and_forward_mock(); let spy = service_spy(); let peer_manager = make_peer_manager(); let node_identity = make_node_identity(); - let mut service = StoreLayer::new(Default::default(), peer_manager, node_identity, storage.clone()) - .layer(spy.to_service::()); + let join_msg_bytes = JoinMessage { + node_id: node_identity.node_id().to_vec(), + addresses: vec![], + peer_features: 0, + } + .to_encoded_bytes() + .unwrap(); + let mut service = StoreLayer::new(Default::default(), peer_manager, node_identity, requester) + .layer(spy.to_service::()); let inbound_msg = make_dht_inbound_message(&make_node_identity(), b"".to_vec(), DhtMessageFlags::empty()); - let msg = DecryptedDhtMessage::succeeded(wrap_in_envelope_body!(Vec::new()).unwrap(), inbound_msg); + + let mut msg = DecryptedDhtMessage::succeeded(wrap_in_envelope_body!(join_msg_bytes).unwrap(), inbound_msg); + msg.dht_header.message_type = DhtMessageType::Join; service.call(msg).await.unwrap(); assert!(spy.is_called()); - storage.with_lock(|mut lock| { - assert_eq!(lock.iter().count(), 1); - }); + + // Because we dont wait for the message to reach the mock/service before continuing (for efficiency and it's not + // necessary) we need to wait for the call to happen eventually - it should be almost instant + async_assert_eventually!( + mock_state.call_count(), + expect = 1, + max_attempts = 10, + interval = Duration::from_millis(10), + ); + let messages = mock_state.get_messages().await; + assert_eq!(messages[0].message_type, DhtMessageType::Join as i32); } #[tokio_macros::test_basic] async fn decryption_succeeded_no_store() { - let storage = Arc::new(SafStorage::new(1)); + let (requester, mock_state) = create_store_and_forward_mock(); let spy = service_spy(); let peer_manager = make_peer_manager(); let node_identity = make_node_identity(); - let mut service = StoreLayer::new(Default::default(), peer_manager, node_identity, storage.clone()) + let mut service = StoreLayer::new(Default::default(), peer_manager, node_identity, requester) .layer(spy.to_service::()); let inbound_msg = make_dht_inbound_message(&make_node_identity(), b"".to_vec(), DhtMessageFlags::ENCRYPTED); let msg = DecryptedDhtMessage::succeeded(wrap_in_envelope_body!(b"secret".to_vec()).unwrap(), inbound_msg); service.call(msg).await.unwrap(); assert!(spy.is_called()); - storage.with_lock(|mut lock| { - assert_eq!(lock.iter().count(), 0); - }); + + assert_eq!(mock_state.call_count(), 0); } #[tokio_macros::test_basic] async fn decryption_failed_should_store() { - let storage = Arc::new(SafStorage::new(1)); + let (requester, mock_state) = create_store_and_forward_mock(); let spy = service_spy(); let peer_manager = make_peer_manager(); + let origin_node_identity = make_node_identity(); + peer_manager.add_peer(origin_node_identity.to_peer()).await.unwrap(); let node_identity = make_node_identity(); - let mut service = StoreLayer::new(Default::default(), peer_manager, node_identity, Arc::clone(&storage)) + let mut service = StoreLayer::new(Default::default(), peer_manager, node_identity, requester) .layer(spy.to_service::()); - let inbound_msg = make_dht_inbound_message(&make_node_identity(), b"".to_vec(), DhtMessageFlags::empty()); + let mut inbound_msg = make_dht_inbound_message(&origin_node_identity, b"".to_vec(), DhtMessageFlags::empty()); + inbound_msg.dht_header.destination = + 
NodeDestination::PublicKey(Box::new(origin_node_identity.public_key().clone())); let msg = DecryptedDhtMessage::failed(inbound_msg.clone()); service.call(msg).await.unwrap(); - assert_eq!(spy.is_called(), false); - let msg = storage - .remove(&inbound_msg.dht_header.origin.unwrap().signature) - .unwrap(); - let timestamp: DateTime = (UNIX_EPOCH + Duration::from_secs(msg.stored_at.unwrap().seconds as u64)).into(); - assert!((Utc::now() - timestamp).num_seconds() <= 5); + assert_eq!(spy.is_called(), true); + + async_assert_eventually!( + mock_state.call_count(), + expect = 1, + max_attempts = 10, + interval = Duration::from_millis(10), + ); + + let message = mock_state.get_messages().await.remove(0); + assert_eq!( + message.origin_signature, + inbound_msg.dht_header.origin.unwrap().signature.to_hex() + ); + let duration = Utc::now().naive_utc().signed_duration_since(message.stored_at); + assert!(duration.num_seconds() <= 5); } } diff --git a/comms/dht/src/test_utils/dht_actor_mock.rs b/comms/dht/src/test_utils/dht_actor_mock.rs index a30445ce5b..47abd0339a 100644 --- a/comms/dht/src/test_utils/dht_actor_mock.rs +++ b/comms/dht/src/test_utils/dht_actor_mock.rs @@ -20,12 +20,18 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::actor::{DhtRequest, DhtRequester}; +use crate::{ + actor::{DhtRequest, DhtRequester}, + storage::DhtSettingKey, +}; use futures::{channel::mpsc, stream::Fuse, StreamExt}; -use std::sync::{ - atomic::{AtomicBool, AtomicUsize, Ordering}, - Arc, - RwLock, +use std::{ + collections::HashMap, + sync::{ + atomic::{AtomicBool, AtomicUsize, Ordering}, + Arc, + RwLock, + }, }; use tari_comms::peer_manager::Peer; @@ -39,6 +45,7 @@ pub struct DhtMockState { signature_cache_insert: Arc, call_count: Arc, select_peers: Arc>>, + settings: Arc>>>, } impl DhtMockState { @@ -47,6 +54,7 @@ impl DhtMockState { signature_cache_insert: Arc::new(AtomicBool::new(false)), call_count: Arc::new(AtomicUsize::new(0)), select_peers: Arc::new(RwLock::new(Vec::new())), + settings: Arc::new(RwLock::new(HashMap::new())), } } @@ -64,8 +72,8 @@ impl DhtMockState { self.call_count.fetch_add(1, Ordering::SeqCst); } - pub fn call_count(&self) -> usize { - self.call_count.load(Ordering::SeqCst) + pub fn get_setting(&self, key: &DhtSettingKey) -> Option> { + self.settings.read().unwrap().get(&key.to_string()).map(Clone::clone) } } @@ -105,7 +113,19 @@ impl DhtActorMock { let lock = self.state.select_peers.read().unwrap(); reply_tx.send(lock.clone()).unwrap(); }, - SendRequestStoredMessages(_) => {}, + SendRequestStoredMessages => {}, + GetSetting(key, reply_tx) => { + let _ = reply_tx.send(Ok(self + .state + .settings + .read() + .unwrap() + .get(&key.to_string()) + .map(Clone::clone))); + }, + SetSetting(key, value) => { + self.state.settings.write().unwrap().insert(key.to_string(), value); + }, } } } diff --git a/comms/dht/src/test_utils/makers.rs b/comms/dht/src/test_utils/makers.rs index d04211c927..650de26e4c 100644 --- a/comms/dht/src/test_utils/makers.rs +++ b/comms/dht/src/test_utils/makers.rs @@ -91,7 +91,7 @@ pub fn make_comms_inbound_message(node_identity: &NodeIdentity, message: Bytes, ) } -pub fn make_dht_header(node_identity: &NodeIdentity, message: &Vec, flags: DhtMessageFlags) -> DhtMessageHeader { +pub fn make_dht_header(node_identity: &NodeIdentity, message: &[u8], flags: DhtMessageFlags) -> DhtMessageHeader { DhtMessageHeader { version: 
0, destination: NodeDestination::Unknown, diff --git a/comms/dht/src/test_utils/mod.rs b/comms/dht/src/test_utils/mod.rs index 90f9358d43..0550c92e78 100644 --- a/comms/dht/src/test_utils/mod.rs +++ b/comms/dht/src/test_utils/mod.rs @@ -38,11 +38,16 @@ macro_rules! unwrap_oms_send_msg { } mod dht_actor_mock; -mod dht_discovery_mock; -mod makers; -mod service; - pub use dht_actor_mock::{create_dht_actor_mock, DhtMockState}; + +mod dht_discovery_mock; pub use dht_discovery_mock::{create_dht_discovery_mock, DhtDiscoveryMockState}; + +mod makers; pub use makers::*; + +mod service; pub use service::{service_fn, service_spy}; + +mod store_and_forward_mock; +pub use store_and_forward_mock::{create_store_and_forward_mock, StoreAndForwardMockState}; diff --git a/comms/dht/src/test_utils/store_and_forward_mock.rs b/comms/dht/src/test_utils/store_and_forward_mock.rs new file mode 100644 index 0000000000..e338b323c4 --- /dev/null +++ b/comms/dht/src/test_utils/store_and_forward_mock.rs @@ -0,0 +1,135 @@ +// Copyright 2019, The Tari Project +// +// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the +// following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following +// disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the +// following disclaimer in the documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote +// products derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
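// Illustrative sketch, not part of this patch: the watermark update the SAF handler
// performs after processing a batch of stored messages. This assumes
// `DhtRequester::set_setting` accepts a `DateTime<Utc>` for
// `DhtSettingKey::SafLastRequestTimestamp` and returns `Result<(), DhtActorError>`
// (the mock above stores settings as raw bytes, so the exact signature lives in the
// dht actor, not shown here). Future-dated timestamps are ignored so a misbehaving
// peer cannot push the watermark forward.

use chrono::{DateTime, Utc};

async fn update_saf_watermark(
    dht_requester: &mut DhtRequester,
    stored_at_timestamps: Vec<DateTime<Utc>>,
) -> Result<(), DhtActorError> {
    let now = Utc::now();
    let newest = stored_at_timestamps
        .into_iter()
        .filter(|ts| *ts <= now)
        .max()
        .unwrap_or(now);
    dht_requester
        .set_setting(DhtSettingKey::SafLastRequestTimestamp, newest)
        .await
}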
+ +use crate::store_forward::{StoreAndForwardRequest, StoreAndForwardRequester, StoredMessage}; +use chrono::Utc; +use futures::{channel::mpsc, stream::Fuse, StreamExt}; +use log::*; +use rand::{rngs::OsRng, RngCore}; +use std::sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, +}; +use tokio::{runtime, sync::RwLock}; + +const LOG_TARGET: &str = "comms::dht::discovery_mock"; + +pub fn create_store_and_forward_mock() -> (StoreAndForwardRequester, StoreAndForwardMockState) { + let (tx, rx) = mpsc::channel(10); + + let mock = StoreAndForwardMock::new(rx.fuse()); + let state = mock.get_shared_state(); + runtime::Handle::current().spawn(mock.run()); + (StoreAndForwardRequester::new(tx), state) +} + +#[derive(Debug, Clone)] +pub struct StoreAndForwardMockState { + call_count: Arc, + stored_messages: Arc>>, + calls: Arc>>, +} + +impl StoreAndForwardMockState { + pub fn new() -> Self { + Self { + call_count: Arc::new(AtomicUsize::new(0)), + stored_messages: Arc::new(RwLock::new(Vec::new())), + calls: Arc::new(RwLock::new(Vec::new())), + } + } + + pub fn inc_call_count(&self) { + self.call_count.fetch_add(1, Ordering::SeqCst); + } + + pub async fn add_call(&self, call: &StoreAndForwardRequest) { + self.inc_call_count(); + self.calls.write().await.push(format!("{:?}", call)); + } + + pub fn call_count(&self) -> usize { + self.call_count.load(Ordering::SeqCst) + } + + pub async fn get_messages(&self) -> Vec { + self.stored_messages.read().await.clone() + } + + pub async fn add_message(&self, message: StoredMessage) { + self.stored_messages.write().await.push(message) + } + + pub async fn take_calls(&self) -> Vec { + self.calls.write().await.drain(..).collect() + } +} + +pub struct StoreAndForwardMock { + receiver: Fuse>, + state: StoreAndForwardMockState, +} + +impl StoreAndForwardMock { + pub fn new(receiver: Fuse>) -> Self { + Self { + receiver, + state: StoreAndForwardMockState::new(), + } + } + + pub fn get_shared_state(&self) -> StoreAndForwardMockState { + self.state.clone() + } + + pub async fn run(mut self) { + while let Some(req) = self.receiver.next().await { + self.handle_request(req).await; + } + } + + async fn handle_request(&self, req: StoreAndForwardRequest) { + use StoreAndForwardRequest::*; + trace!(target: LOG_TARGET, "StoreAndForwardMock received request {:?}", req); + self.state.add_call(&req).await; + match req { + FetchMessages(_, reply_tx) => { + let msgs = self.state.stored_messages.read().await; + let _ = reply_tx.send(Ok(msgs.clone())); + }, + InsertMessage(msg) => self.state.stored_messages.write().await.push(StoredMessage { + id: OsRng.next_u32() as i32, + version: msg.version, + origin_pubkey: msg.origin_pubkey, + origin_signature: msg.origin_signature, + message_type: msg.message_type, + destination_pubkey: msg.destination_pubkey, + destination_node_id: msg.destination_node_id, + header: msg.header, + body: msg.body, + is_encrypted: msg.is_encrypted, + priority: msg.priority, + stored_at: Utc::now().naive_utc(), + }), + } + } +} diff --git a/comms/dht/src/store_forward/state.rs b/comms/dht/src/utils.rs similarity index 61% rename from comms/dht/src/store_forward/state.rs rename to comms/dht/src/utils.rs index 40a1bf513b..b99e334815 100644 --- a/comms/dht/src/store_forward/state.rs +++ b/comms/dht/src/utils.rs @@ -20,37 +20,18 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
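// Illustrative sketch, not part of this patch: typical test wiring with the mock
// above. It hands back a real `StoreAndForwardRequester` backed by an in-memory Vec
// plus a state handle for assertions, so middleware can be exercised without a
// sqlite database (a tokio runtime must be running because the mock is spawned).

#[tokio_macros::test_basic]
async fn store_and_forward_mock_usage() {
    let (requester, mock_state) = create_store_and_forward_mock();
    // Hand `requester` to the code under test, e.g.
    // StoreLayer::new(Default::default(), peer_manager, node_identity, requester)
    assert_eq!(mock_state.call_count(), 0);
    assert!(mock_state.get_messages().await.is_empty());
    drop(requester);
}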
-use crate::proto::store_forward::StoredMessage; -use std::{ - sync::{RwLock, RwLockWriteGuard}, - time::Duration, -}; -use ttl_cache::TtlCache; +use std::convert::TryInto; -pub type SignatureBytes = Vec; - -pub struct SafStorage { - message_cache: RwLock>, -} - -impl SafStorage { - pub fn new(cache_capacity: usize) -> Self { - Self { - message_cache: RwLock::new(TtlCache::new(cache_capacity)), - } - } - - pub fn insert(&self, key: SignatureBytes, message: StoredMessage, ttl: Duration) -> Option { - acquire_write_lock!(self.message_cache).insert(key, message, ttl) - } - - pub fn with_lock(&self, f: F) -> T - where F: FnOnce(RwLockWriteGuard>) -> T { - f(acquire_write_lock!(self.message_cache)) - } - - #[cfg(test)] - pub fn remove(&self, key: &SignatureBytes) -> Option { - acquire_write_lock!(self.message_cache).remove(key) +/// Tries to convert a series of `T`s to `U`s, returning an error at the first failure +pub fn try_convert_all(into_iter: I) -> Result, T::Error> +where + I: IntoIterator, + T: TryInto, +{ + let iter = into_iter.into_iter(); + let mut result = Vec::with_capacity(iter.size_hint().0); + for item in iter { + result.push(item.try_into()?); } + Ok(result) } diff --git a/comms/src/message/envelope.rs b/comms/src/message/envelope.rs index fcdf975c57..79b574d50c 100644 --- a/comms/src/message/envelope.rs +++ b/comms/src/message/envelope.rs @@ -151,6 +151,10 @@ impl EnvelopeBody { self.parts.len() } + pub fn total_size(&self) -> usize { + self.parts.iter().fold(0, |acc, b| acc + b.len()) + } + pub fn is_empty(&self) -> bool { self.parts.is_empty() } diff --git a/infrastructure/storage/src/lmdb_store/store.rs b/infrastructure/storage/src/lmdb_store/store.rs index ef4cc863d9..97a5813cda 100644 --- a/infrastructure/storage/src/lmdb_store/store.rs +++ b/infrastructure/storage/src/lmdb_store/store.rs @@ -622,7 +622,7 @@ mod test { #[test] fn test_lmdb_builder() { - let mut store = LMDBBuilder::new() + let store = LMDBBuilder::new() .set_path(env::temp_dir()) .set_environment_size(500) .set_max_number_of_databases(10)
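// Illustrative sketch, not part of this patch: `try_convert_all` in use. In the SAF
// handler it converts database rows into protobuf StoredMessages; it works for any
// fallible element-wise conversion and fails fast on the first bad element.

fn shrink(values: Vec<u32>) -> Result<Vec<u8>, std::num::TryFromIntError> {
    // u32 -> u8 fails for values > 255, which fails the whole conversion
    try_convert_all(values)
}

#[test]
fn try_convert_all_fails_fast() {
    assert_eq!(shrink(vec![1, 2, 3]).unwrap(), vec![1u8, 2, 3]);
    assert!(shrink(vec![1, 500]).is_err());
}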