From fd02c1feaf993ccd3e4ab04e9cd81beae5793c79 Mon Sep 17 00:00:00 2001 From: Dmitry Markin Date: Tue, 25 Jun 2024 15:03:59 +0300 Subject: [PATCH] kademlia: Preserve publisher & expiration time in DHT records (#162) This PR fixes a bug with publisher & expiration time not being preserved in DHT records. Resolves https://github.com/paritytech/litep2p/issues/129. --- src/protocol/libp2p/kademlia/config.rs | 6 +- src/protocol/libp2p/kademlia/handle.rs | 35 ++------ src/protocol/libp2p/kademlia/message.rs | 114 ++++++++++++++++++++---- src/protocol/libp2p/kademlia/mod.rs | 32 +++++-- src/protocol/libp2p/kademlia/types.rs | 19 ++-- tests/protocol/kademlia.rs | 37 ++++++-- 6 files changed, 172 insertions(+), 71 deletions(-) diff --git a/src/protocol/libp2p/kademlia/config.rs b/src/protocol/libp2p/kademlia/config.rs index f457932a..6a50c599 100644 --- a/src/protocol/libp2p/kademlia/config.rs +++ b/src/protocol/libp2p/kademlia/config.rs @@ -65,6 +65,9 @@ pub struct Config { /// Incoming records validation mode. pub(super) validation_mode: IncomingRecordValidationMode, + /// Default record TTl. + pub(super) record_ttl: Duration, + /// TX channel for sending events to `KademliaHandle`. pub(super) event_tx: Sender, @@ -94,13 +97,14 @@ impl Config { protocol_names, update_mode, validation_mode, + record_ttl, codec: ProtocolCodec::UnsignedVarint(None), replication_factor, known_peers, cmd_rx, event_tx, }, - KademliaHandle::new(cmd_tx, event_rx, record_ttl), + KademliaHandle::new(cmd_tx, event_rx), ) } diff --git a/src/protocol/libp2p/kademlia/handle.rs b/src/protocol/libp2p/kademlia/handle.rs index 8a3af6b5..5d3b4630 100644 --- a/src/protocol/libp2p/kademlia/handle.rs +++ b/src/protocol/libp2p/kademlia/handle.rs @@ -31,7 +31,6 @@ use std::{ num::NonZeroUsize, pin::Pin, task::{Context, Poll}, - time::{Duration, Instant}, }; /// Quorum. @@ -223,23 +222,15 @@ pub struct KademliaHandle { /// Next query ID. next_query_id: usize, - - /// Default TTL for the records. - record_ttl: Duration, } impl KademliaHandle { /// Create new [`KademliaHandle`]. - pub(super) fn new( - cmd_tx: Sender, - event_rx: Receiver, - record_ttl: Duration, - ) -> Self { + pub(super) fn new(cmd_tx: Sender, event_rx: Receiver) -> Self { Self { cmd_tx, event_rx, next_query_id: 0usize, - record_ttl, } } @@ -265,9 +256,7 @@ impl KademliaHandle { } /// Store record to DHT. - pub async fn put_record(&mut self, mut record: Record) -> QueryId { - record.expires = record.expires.or_else(|| Some(Instant::now() + self.record_ttl)); - + pub async fn put_record(&mut self, record: Record) -> QueryId { let query_id = self.next_query_id(); let _ = self.cmd_tx.send(KademliaCommand::PutRecord { record, query_id }).await; @@ -277,12 +266,10 @@ impl KademliaHandle { /// Store record to DHT to the given peers. pub async fn put_record_to_peers( &mut self, - mut record: Record, + record: Record, peers: Vec, update_local_store: bool, ) -> QueryId { - record.expires = record.expires.or_else(|| Some(Instant::now() + self.record_ttl)); - let query_id = self.next_query_id(); let _ = self .cmd_tx @@ -314,9 +301,7 @@ impl KademliaHandle { /// Store the record in the local store. Used in combination with /// [`IncomingRecordValidationMode::Manual`]. - pub async fn store_record(&mut self, mut record: Record) { - record.expires = record.expires.or_else(|| Some(Instant::now() + self.record_ttl)); - + pub async fn store_record(&mut self, record: Record) { let _ = self.cmd_tx.send(KademliaCommand::StoreRecord { record }).await; } @@ -337,9 +322,7 @@ impl KademliaHandle { } /// Try to initiate `PUT_VALUE` query and if the channel is clogged, return an error. - pub fn try_put_record(&mut self, mut record: Record) -> Result { - record.expires = record.expires.or_else(|| Some(Instant::now() + self.record_ttl)); - + pub fn try_put_record(&mut self, record: Record) -> Result { let query_id = self.next_query_id(); self.cmd_tx .try_send(KademliaCommand::PutRecord { record, query_id }) @@ -351,12 +334,10 @@ impl KademliaHandle { /// return an error. pub fn try_put_record_to_peers( &mut self, - mut record: Record, + record: Record, peers: Vec, update_local_store: bool, ) -> Result { - record.expires = record.expires.or_else(|| Some(Instant::now() + self.record_ttl)); - let query_id = self.next_query_id(); self.cmd_tx .try_send(KademliaCommand::PutRecordToPeers { @@ -384,9 +365,7 @@ impl KademliaHandle { /// Try to store the record in the local store, and if the channel is clogged, return an error. /// Used in combination with [`IncomingRecordValidationMode::Manual`]. - pub fn try_store_record(&mut self, mut record: Record) -> Result<(), ()> { - record.expires = record.expires.or_else(|| Some(Instant::now() + self.record_ttl)); - + pub fn try_store_record(&mut self, record: Record) -> Result<(), ()> { self.cmd_tx.try_send(KademliaCommand::StoreRecord { record }).map_err(|_| ()) } } diff --git a/src/protocol/libp2p/kademlia/message.rs b/src/protocol/libp2p/kademlia/message.rs index 7ae25bf9..e0729aa8 100644 --- a/src/protocol/libp2p/kademlia/message.rs +++ b/src/protocol/libp2p/kademlia/message.rs @@ -18,14 +18,18 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use crate::protocol::libp2p::kademlia::{ - record::{Key as RecordKey, Record}, - schema, - types::KademliaPeer, +use crate::{ + protocol::libp2p::kademlia::{ + record::{Key as RecordKey, Record}, + schema, + types::KademliaPeer, + }, + PeerId, }; use bytes::{Bytes, BytesMut}; use prost::Message; +use std::time::{Duration, Instant}; /// Logging target for the file. const LOG_TARGET: &str = "litep2p::ipfs::kademlia::message"; @@ -78,16 +82,11 @@ impl KademliaMessage { } /// Create `PUT_VALUE` message for `record`. - // TODO: set ttl pub fn put_value(record: Record) -> Bytes { let message = schema::kademlia::Message { key: record.key.clone().into(), r#type: schema::kademlia::MessageType::PutValue.into(), - record: Some(schema::kademlia::Record { - key: record.key.into(), - value: record.value, - ..Default::default() - }), + record: Some(record_to_schema(record)), cluster_level_raw: 10, ..Default::default() }; @@ -140,11 +139,7 @@ impl KademliaMessage { cluster_level_raw: 10, r#type: schema::kademlia::MessageType::GetValue.into(), closer_peers: peers.iter().map(|peer| peer.into()).collect(), - record: record.map(|record| schema::kademlia::Record { - key: record.key.to_vec(), - value: record.value, - ..Default::default() - }), + record: record.map(record_to_schema), ..Default::default() }; @@ -174,7 +169,7 @@ impl KademliaMessage { let record = message.record?; Some(Self::PutValue { - record: Record::new(record.key, record.value), + record: record_from_schema(record)?, }) } 1 => { @@ -185,9 +180,15 @@ impl KademliaMessage { false => Some(RecordKey::from(message.key.clone())), }; + let record = if let Some(record) = message.record { + Some(record_from_schema(record)?) + } else { + None + }; + Some(Self::GetRecord { key, - record: message.record.map(|record| Record::new(record.key, record.value)), + record, peers: message .closer_peers .iter() @@ -207,3 +208,82 @@ impl KademliaMessage { } } } + +fn record_to_schema(record: Record) -> schema::kademlia::Record { + schema::kademlia::Record { + key: record.key.into(), + value: record.value, + time_received: String::new(), + publisher: record.publisher.map(|peer_id| peer_id.to_bytes()).unwrap_or_default(), + ttl: record + .expires + .map(|expires| { + let now = Instant::now(); + if expires > now { + u32::try_from((expires - now).as_secs()).unwrap_or(u32::MAX) + } else { + 1 // because 0 means "does not expire" + } + }) + .unwrap_or(0), + } +} + +fn record_from_schema(record: schema::kademlia::Record) -> Option { + Some(Record { + key: record.key.into(), + value: record.value, + publisher: if !record.publisher.is_empty() { + Some(PeerId::from_bytes(&record.publisher).ok()?) + } else { + None + }, + expires: if record.ttl > 0 { + Some(Instant::now() + Duration::from_secs(record.ttl as u64)) + } else { + None + }, + }) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn non_empty_publisher_and_ttl_are_preserved() { + let expires = Instant::now() + Duration::from_secs(3600); + + let record = Record { + key: vec![1, 2, 3].into(), + value: vec![17], + publisher: Some(PeerId::random()), + expires: Some(expires), + }; + + let got_record = record_from_schema(record_to_schema(record.clone())).unwrap(); + + assert_eq!(got_record.key, record.key); + assert_eq!(got_record.value, record.value); + assert_eq!(got_record.publisher, record.publisher); + + // Check that the expiration time is sane. + let got_expires = got_record.expires.unwrap(); + assert!(got_expires - expires >= Duration::ZERO); + assert!(got_expires - expires < Duration::from_secs(10)); + } + + #[test] + fn empty_publisher_and_ttl_are_preserved() { + let record = Record { + key: vec![1, 2, 3].into(), + value: vec![17], + publisher: None, + expires: None, + }; + + let got_record = record_from_schema(record_to_schema(record.clone())).unwrap(); + + assert_eq!(got_record, record); + } +} diff --git a/src/protocol/libp2p/kademlia/mod.rs b/src/protocol/libp2p/kademlia/mod.rs index 24f3f2fd..40a278f6 100644 --- a/src/protocol/libp2p/kademlia/mod.rs +++ b/src/protocol/libp2p/kademlia/mod.rs @@ -45,7 +45,10 @@ use futures::StreamExt; use multiaddr::Multiaddr; use tokio::sync::mpsc::{Receiver, Sender}; -use std::collections::{hash_map::Entry, HashMap}; +use std::{ + collections::{hash_map::Entry, HashMap}, + time::{Duration, Instant}, +}; pub use self::handle::RecordsType; pub use config::{Config, ConfigBuilder}; @@ -115,7 +118,7 @@ pub(crate) struct Kademlia { service: TransportService, /// Local Kademlia key. - _local_key: Key, + local_key: Key, /// Connected peers, peers: HashMap, @@ -147,6 +150,9 @@ pub(crate) struct Kademlia { /// Incoming records validation mode. validation_mode: IncomingRecordValidationMode, + /// Default record TTL. + record_ttl: Duration, + /// Query engine. engine: QueryEngine, @@ -175,12 +181,13 @@ impl Kademlia { cmd_rx: config.cmd_rx, store: MemoryStore::new(), event_tx: config.event_tx, - _local_key: local_key, + local_key, pending_dials: HashMap::new(), executor: QueryExecutor::new(), pending_substreams: HashMap::new(), update_mode: config.update_mode, validation_mode: config.validation_mode, + record_ttl: config.record_ttl, replication_factor: config.replication_factor, engine: QueryEngine::new(local_peer_id, config.replication_factor, PARALLELISM_FACTOR), } @@ -775,9 +782,15 @@ impl Kademlia { self.routing_table.closest(Key::from(peer), self.replication_factor).into() ); } - Some(KademliaCommand::PutRecord { record, query_id }) => { + Some(KademliaCommand::PutRecord { mut record, query_id }) => { tracing::debug!(target: LOG_TARGET, ?query_id, key = ?record.key, "store record to DHT"); + // For `PUT_VALUE` requests originating locally we are always the publisher. + record.publisher = Some(self.local_key.clone().into_preimage()); + + // Make sure TTL is set. + record.expires = record.expires.or_else(|| Some(Instant::now() + self.record_ttl)); + let key = Key::new(record.key.clone()); self.store.put(record.clone()); @@ -788,9 +801,12 @@ impl Kademlia { self.routing_table.closest(key, self.replication_factor).into(), ); } - Some(KademliaCommand::PutRecordToPeers { record, query_id, peers, update_local_store }) => { + Some(KademliaCommand::PutRecordToPeers { mut record, query_id, peers, update_local_store }) => { tracing::debug!(target: LOG_TARGET, ?query_id, key = ?record.key, "store record to DHT to specified peers"); + // Make sure TTL is set. + record.expires = record.expires.or_else(|| Some(Instant::now() + self.record_ttl)); + if update_local_store { self.store.put(record.clone()); } @@ -854,13 +870,16 @@ impl Kademlia { self.service.add_known_address(&peer, addresses.into_iter()); } - Some(KademliaCommand::StoreRecord { record }) => { + Some(KademliaCommand::StoreRecord { mut record }) => { tracing::debug!( target: LOG_TARGET, key = ?record.key, "store record in local store", ); + // Make sure TTL is set. + record.expires = record.expires.or_else(|| Some(Instant::now() + self.record_ttl)); + self.store.put(record); } None => return Err(Error::EssentialTaskClosed), @@ -914,6 +933,7 @@ mod tests { replication_factor: 20usize, update_mode: RoutingTableUpdateMode::Automatic, validation_mode: IncomingRecordValidationMode::Automatic, + record_ttl: Duration::from_secs(36 * 60 * 60), event_tx, cmd_rx, }; diff --git a/src/protocol/libp2p/kademlia/types.rs b/src/protocol/libp2p/kademlia/types.rs index fe9d04eb..a0542653 100644 --- a/src/protocol/libp2p/kademlia/types.rs +++ b/src/protocol/libp2p/kademlia/types.rs @@ -49,7 +49,7 @@ construct_uint! { /// the hash digests, interpreted as an integer. See [`Key::distance`]. #[derive(Clone, Debug)] pub struct Key { - _preimage: T, + preimage: T, bytes: KeyBytes, } @@ -59,17 +59,17 @@ impl Key { /// /// The preimage of type `T` is preserved. /// See [`Key::into_preimage`] for more details. - pub fn new(_preimage: T) -> Key + pub fn new(preimage: T) -> Key where T: Borrow<[u8]>, { - let bytes = KeyBytes::new(_preimage.borrow()); - Key { _preimage, bytes } + let bytes = KeyBytes::new(preimage.borrow()); + Key { preimage, bytes } } /// Convert [`Key`] into its preimage. pub fn into_preimage(self) -> T { - self._preimage + self.preimage } /// Computes the distance of the keys according to the XOR metric. @@ -94,8 +94,8 @@ impl Key { /// /// Only used for testing #[cfg(test)] - pub fn from_bytes(bytes: KeyBytes, _preimage: T) -> Key { - Self { bytes, _preimage } + pub fn from_bytes(bytes: KeyBytes, preimage: T) -> Key { + Self { bytes, preimage } } } @@ -108,10 +108,7 @@ impl From> for KeyBytes { impl From for Key { fn from(p: PeerId) -> Self { let bytes = KeyBytes(Sha256::digest(p.to_bytes())); - Key { - _preimage: p, - bytes, - } + Key { preimage: p, bytes } } } diff --git a/tests/protocol/kademlia.rs b/tests/protocol/kademlia.rs index 9c0441a6..893871e9 100644 --- a/tests/protocol/kademlia.rs +++ b/tests/protocol/kademlia.rs @@ -172,7 +172,11 @@ async fn records_are_stored_automatically() { event = kad_handle2.next() => { match event { Some(KademliaEvent::IncomingRecord { record: got_record }) => { - assert_eq!(got_record, record); + assert_eq!(got_record.key, record.key); + assert_eq!(got_record.value, record.value); + assert_eq!(got_record.publisher.unwrap(), *litep2p1.local_peer_id()); + assert!(got_record.expires.is_some()); + // Check if the record was stored. let _ = kad_handle2 .get_record(RecordKey::from(vec![1, 2, 3]), Quorum::One).await; @@ -180,7 +184,11 @@ async fn records_are_stored_automatically() { Some(KademliaEvent::GetRecordSuccess { query_id: _, records }) => { match records { RecordsType::LocalStore(got_record) => { - assert_eq!(got_record, record); + assert_eq!(got_record.key, record.key); + assert_eq!(got_record.value, record.value); + assert_eq!(got_record.publisher.unwrap(), *litep2p1.local_peer_id()); + assert!(got_record.expires.is_some()); + break } RecordsType::Network(_) => { @@ -245,8 +253,11 @@ async fn records_are_stored_manually() { event = kad_handle2.next() => { match event { Some(KademliaEvent::IncomingRecord { record: got_record }) => { - assert_eq!(got_record, record); - assert!(got_record.expires.is_none()); + assert_eq!(got_record.key, record.key); + assert_eq!(got_record.value, record.value); + assert_eq!(got_record.publisher.unwrap(), *litep2p1.local_peer_id()); + assert!(got_record.expires.is_some()); + kad_handle2.store_record(got_record).await; // Check if the record was stored. @@ -256,9 +267,11 @@ async fn records_are_stored_manually() { Some(KademliaEvent::GetRecordSuccess { query_id: _, records }) => { match records { RecordsType::LocalStore(got_record) => { + assert_eq!(got_record.key, record.key); + assert_eq!(got_record.value, record.value); + assert_eq!(got_record.publisher.unwrap(), *litep2p1.local_peer_id()); assert!(got_record.expires.is_some()); - record.expires = got_record.expires; - assert_eq!(got_record, record); + break } RecordsType::Network(_) => { @@ -325,7 +338,10 @@ async fn not_validated_records_are_not_stored() { event = kad_handle2.next() => { match event { Some(KademliaEvent::IncomingRecord { record: got_record }) => { - assert_eq!(got_record, record); + assert_eq!(got_record.key, record.key); + assert_eq!(got_record.value, record.value); + assert_eq!(got_record.publisher.unwrap(), *litep2p1.local_peer_id()); + assert!(got_record.expires.is_some()); // Do not call `kad_handle2.store_record(record).await`. // Check if the record was stored. @@ -424,9 +440,14 @@ async fn get_record_retrieves_remote_records() { } RecordsType::Network(records) => { assert_eq!(records.len(), 1); + let PeerRecord { peer, record } = records.first().unwrap(); assert_eq!(peer, litep2p1.local_peer_id()); - assert_eq!(record, &original_record); + assert_eq!(record.key, original_record.key); + assert_eq!(record.value, original_record.value); + assert_eq!(record.publisher.unwrap(), *litep2p1.local_peer_id()); + assert!(record.expires.is_some()); + break } }