diff --git a/Cargo.lock b/Cargo.lock index f1d8d25e9a..4c711b14f0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -538,7 +538,7 @@ dependencies = [ name = "ckb-channel" version = "0.100.0-pre" dependencies = [ - "crossbeam-channel 0.3.9", + "crossbeam-channel", ] [[package]] @@ -1511,15 +1511,6 @@ dependencies = [ "itertools 0.9.0", ] -[[package]] -name = "crossbeam-channel" -version = "0.3.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8ec7fcd21571dc78f96cc96243cab8d8f035247c3efd16c687be154c3fa9efa" -dependencies = [ - "crossbeam-utils 0.6.5", -] - [[package]] name = "crossbeam-channel" version = "0.5.1" @@ -1527,7 +1518,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "06ed27e177f16d65f0f0c22a213e17c696ace5dd64b14258b52f9417ccb52db4" dependencies = [ "cfg-if 1.0.0", - "crossbeam-utils 0.8.5", + "crossbeam-utils", ] [[package]] @@ -1538,7 +1529,7 @@ checksum = "94af6efb46fef72616855b036a624cf27ba656ffc9be1b9a3c931cfc7749a9a9" dependencies = [ "cfg-if 1.0.0", "crossbeam-epoch", - "crossbeam-utils 0.8.5", + "crossbeam-utils", ] [[package]] @@ -1548,22 +1539,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4ec02e091aa634e2c3ada4a392989e7c3116673ef0ac5b72232439094d73b7fd" dependencies = [ "cfg-if 1.0.0", - "crossbeam-utils 0.8.5", + "crossbeam-utils", "lazy_static", "memoffset", "scopeguard", ] -[[package]] -name = "crossbeam-utils" -version = "0.6.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8306fcef4a7b563b76b7dd949ca48f52bc1141aa067d2ea09565f3e2652aa5c" -dependencies = [ - "cfg-if 0.1.10", - "lazy_static", -] - [[package]] name = "crossbeam-utils" version = "0.8.5" @@ -3751,9 +3732,9 @@ version = "1.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d78120e2c850279833f1dd3582f730c4ab53ed95aeaaaa862a2a5c71b1656d8e" dependencies = [ - "crossbeam-channel 0.5.1", + "crossbeam-channel", "crossbeam-deque", - "crossbeam-utils 0.8.5", + "crossbeam-utils", "lazy_static", "num_cpus", ] diff --git a/sync/src/tests/mod.rs b/sync/src/tests/mod.rs index aa787b02e6..218465ea7c 100644 --- a/sync/src/tests/mod.rs +++ b/sync/src/tests/mod.rs @@ -1,3 +1,4 @@ +use ckb_channel::{bounded, Receiver, Select, Sender}; use ckb_network::{ bytes::Bytes, Behaviour, CKBProtocolContext, CKBProtocolHandler, Peer, PeerIndex, ProtocolId, TargetSession, @@ -6,9 +7,8 @@ use ckb_util::RwLock; use futures::future::Future; use std::collections::{HashMap, HashSet}; use std::pin::Pin; -use std::sync::mpsc::{sync_channel, Receiver, SyncSender}; use std::sync::Arc; -use std::thread; +use std::thread::{self, JoinHandle}; use std::time::Duration; mod inflight_blocks; @@ -19,17 +19,43 @@ mod util; const DEFAULT_CHANNEL: usize = 128; -#[derive(Default)] +enum Msg { + Bytes(Bytes), + Empty, +} + +#[derive(Hash, Clone, PartialEq, Eq)] +enum Index { + Msg(ProtocolId, PeerIndex), + Timer(ProtocolId, u64), + Stop, +} + struct TestNode { pub peers: Vec, pub protocols: HashMap>>, - pub msg_senders: HashMap<(ProtocolId, PeerIndex), SyncSender>, - pub msg_receivers: HashMap<(ProtocolId, PeerIndex), Receiver>, - pub timer_senders: HashMap<(ProtocolId, u64), SyncSender<()>>, - pub timer_receivers: HashMap<(ProtocolId, u64), Receiver<()>>, + pub senders: HashMap>, + pub receivers: HashMap>, + pub stop: Sender, + pub th: Option>, } impl TestNode { + pub fn new() -> TestNode { + let (stop_tx, stop_rx) = bounded(1); + let mut receivers = HashMap::new(); + receivers.insert(Index::Stop, stop_rx); + + TestNode { + receivers, + senders: HashMap::new(), + protocols: HashMap::new(), + peers: Vec::new(), + stop: stop_tx, + th: None, + } + } + pub fn add_protocol( &mut self, protocol: ProtocolId, @@ -38,45 +64,42 @@ impl TestNode { ) { self.protocols.insert(protocol, Arc::clone(handler)); timers.iter().for_each(|timer| { - let (timer_sender, timer_receiver) = sync_channel(DEFAULT_CHANNEL); - self.timer_senders.insert((protocol, *timer), timer_sender); - self.timer_receivers - .insert((protocol, *timer), timer_receiver); + let (timer_sender, timer_receiver) = bounded(DEFAULT_CHANNEL); + let index = Index::Timer(protocol, *timer); + self.senders.insert(index.clone(), timer_sender); + self.receivers.insert(index, timer_receiver); }); handler.write().init(Arc::new(TestNetworkContext { protocol, - msg_senders: self.msg_senders.clone(), - timer_senders: self.timer_senders.clone(), + senders: self.senders.clone(), })) } pub fn connect(&mut self, remote: &mut TestNode, protocol: ProtocolId) { - let (local_sender, local_receiver) = sync_channel(DEFAULT_CHANNEL); + let (local_sender, local_receiver) = bounded(DEFAULT_CHANNEL); let local_index = self.peers.len(); self.peers.insert(local_index, local_index.into()); - self.msg_senders - .insert((protocol, local_index.into()), local_sender); + let local_ch_index = Index::Msg(protocol, local_index.into()); + self.senders.insert(local_ch_index.clone(), local_sender); - let (remote_sender, remote_receiver) = sync_channel(DEFAULT_CHANNEL); + let (remote_sender, remote_receiver) = bounded(DEFAULT_CHANNEL); let remote_index = remote.peers.len(); remote.peers.insert(remote_index, remote_index.into()); - remote - .msg_senders - .insert((protocol, remote_index.into()), remote_sender); - self.msg_receivers - .insert((protocol, remote_index.into()), remote_receiver); + let remote_ch_index = Index::Msg(protocol, local_index.into()); remote - .msg_receivers - .insert((protocol, local_index.into()), local_receiver); + .senders + .insert(remote_ch_index.clone(), remote_sender); + self.receivers.insert(remote_ch_index, remote_receiver); + + remote.receivers.insert(local_ch_index, local_receiver); if let Some(handler) = self.protocols.get(&protocol) { handler.write().connected( Arc::new(TestNetworkContext { protocol, - msg_senders: self.msg_senders.clone(), - timer_senders: self.timer_senders.clone(), + senders: self.senders.clone(), }), local_index.into(), "v1", @@ -87,8 +110,7 @@ impl TestNode { handler.write().connected( Arc::new(TestNetworkContext { protocol, - msg_senders: remote.msg_senders.clone(), - timer_senders: remote.timer_senders.clone(), + senders: remote.senders.clone(), }), local_index.into(), "v1", @@ -96,50 +118,85 @@ impl TestNode { } } - pub fn start bool>(&self, signal: &SyncSender<()>, pred: F) { - loop { - for ((protocol, peer), receiver) in &self.msg_receivers { - let _ = receiver.try_recv().map(|payload| { - if let Some(handler) = self.protocols.get(protocol) { - handler.write().received( - Arc::new(TestNetworkContext { - protocol: *protocol, - msg_senders: self.msg_senders.clone(), - timer_senders: self.timer_senders.clone(), - }), - *peer, - payload.clone(), - ) + pub fn start bool + Send + 'static>( + &mut self, + thread_name: String, + signal: Sender<()>, + pred: F, + ) { + let receivers: Vec<_> = self + .receivers + .iter() + .map(|(k, v)| (k.clone(), v.clone())) + .collect(); + + let protocols = self.protocols.clone(); + let senders = self.senders.clone(); + + let th = thread::Builder::new() + .name(thread_name) + .spawn(move || { + let mut sel = Select::new(); + for r in &receivers { + sel.recv(&r.1); + } + loop { + let index = sel.ready(); + let (index, rv) = &receivers[index]; + let res = rv.try_recv(); + + match index { + Index::Msg(protocol, peer) => { + if let Ok(Msg::Bytes(payload)) = res { + if let Some(handler) = protocols.get(protocol) { + handler.write().received( + Arc::new(TestNetworkContext { + protocol: *protocol, + senders: senders.clone(), + }), + *peer, + payload.clone(), + ) + }; + + if pred(payload) { + let _ = signal.send(()); + } + } + } + Index::Timer(protocol, timer) => { + if let Some(handler) = protocols.get(protocol) { + handler.write().notify( + Arc::new(TestNetworkContext { + protocol: *protocol, + senders: senders.clone(), + }), + *timer, + ) + } + } + Index::Stop => { + break; + } }; + } + }) + .expect("thread spawn"); - if pred(&payload) { - let _ = signal.send(()); - } - }); - } + self.th = Some(th); + } - for ((protocol, timer), receiver) in &self.timer_receivers { - let _ = receiver.try_recv().map(|_| { - if let Some(handler) = self.protocols.get(protocol) { - handler.write().notify( - Arc::new(TestNetworkContext { - protocol: *protocol, - msg_senders: self.msg_senders.clone(), - timer_senders: self.timer_senders.clone(), - }), - *timer, - ) - } - }); - } + pub fn stop(mut self) { + self.stop.send(Msg::Empty).expect("stop recv"); + if let Some(th) = self.th.take() { + th.join().expect("th join"); } } } struct TestNetworkContext { protocol: ProtocolId, - msg_senders: HashMap<(ProtocolId, PeerIndex), SyncSender>, - timer_senders: HashMap<(ProtocolId, u64), SyncSender<()>>, + senders: HashMap>, } impl CKBProtocolContext for TestNetworkContext { @@ -148,11 +205,12 @@ impl CKBProtocolContext for TestNetworkContext { } // Interact with underlying p2p service fn set_notify(&self, interval: Duration, token: u64) -> Result<(), ckb_network::Error> { - if let Some(sender) = self.timer_senders.get(&(self.protocol, token)) { + let index = Index::Timer(self.protocol, token); + if let Some(sender) = self.senders.get(&index) { let sender = sender.clone(); thread::spawn(move || loop { thread::sleep(interval); - let _ = sender.send(()); + let _ = sender.send(Msg::Empty); }); } Ok(()) @@ -199,8 +257,9 @@ impl CKBProtocolContext for TestNetworkContext { peer_index: PeerIndex, data: Bytes, ) -> Result<(), ckb_network::Error> { - if let Some(sender) = self.msg_senders.get(&(proto_id, peer_index)) { - let _ = sender.send(data); + let index = Index::Msg(proto_id, peer_index); + if let Some(sender) = self.senders.get(&index) { + let _ = sender.send(Msg::Bytes(data)); } Ok(()) } @@ -209,8 +268,9 @@ impl CKBProtocolContext for TestNetworkContext { peer_index: PeerIndex, data: Bytes, ) -> Result<(), ckb_network::Error> { - if let Some(sender) = self.msg_senders.get(&(self.protocol, peer_index)) { - let _ = sender.send(data); + let index = Index::Msg(self.protocol, peer_index); + if let Some(sender) = self.senders.get(&index) { + let _ = sender.send(Msg::Bytes(data)); } Ok(()) } @@ -223,9 +283,12 @@ impl CKBProtocolContext for TestNetworkContext { TargetSession::Single(peer) => self.send_message_to(peer, data).unwrap(), TargetSession::Filter(peers) => { for peer in self - .msg_senders + .senders .keys() - .map(|(_, id)| id) + .filter_map(|index| match index { + Index::Msg(_, id) => Some(id), + _ => None, + }) .copied() .collect::>() { @@ -249,7 +312,13 @@ impl CKBProtocolContext for TestNetworkContext { } fn with_peer_mut(&self, _peer_index: PeerIndex, _f: Box) {} fn connected_peers(&self) -> Vec { - self.msg_senders.keys().map(|k| k.1).collect::>() + self.senders + .keys() + .filter_map(|index| match index { + Index::Msg(_, peer_id) => Some(*peer_id), + _ => None, + }) + .collect::>() } fn report_peer(&self, _peer_index: PeerIndex, _behaviour: Behaviour) {} fn ban_peer(&self, _peer_index: PeerIndex, _duration: Duration, _reason: String) {} diff --git a/sync/src/tests/synchronizer.rs b/sync/src/tests/synchronizer.rs index 6f5950e6a5..becd9a594d 100644 --- a/sync/src/tests/synchronizer.rs +++ b/sync/src/tests/synchronizer.rs @@ -6,6 +6,7 @@ use crate::tests::TestNode; use crate::{SyncShared, Synchronizer}; use ckb_chain::chain::ChainService; use ckb_chain_spec::consensus::ConsensusBuilder; +use ckb_channel::bounded; use ckb_dao::DaoCalculator; use ckb_dao_utils::genesis_dao_data; use ckb_launcher::SharedBuilder; @@ -25,9 +26,7 @@ use ckb_util::RwLock; use ckb_verification_traits::Switch; use faketime::{self, unix_time_as_millis}; use std::collections::HashSet; -use std::sync::mpsc::sync_channel; use std::sync::Arc; -use std::thread; const DEFAULT_CHANNEL: usize = 128; @@ -42,36 +41,29 @@ fn basic_sync() { node1.connect(&mut node2, SupportProtocols::Sync.protocol_id()); - let (signal_tx1, signal_rx1) = sync_channel(DEFAULT_CHANNEL); - thread::Builder::new() - .name(thread_name.clone()) - .spawn(move || { - node1.start(&signal_tx1, |data| { - let msg = packed::SyncMessage::from_slice(&data) - .expect("sync message") - .to_enum(); - // terminate thread after 3 blocks - if let packed::SyncMessageUnionReader::SendBlock(reader) = msg.as_reader() { - let block = reader.block().to_entity().into_view(); - block.header().number() == 3 - } else { - false - } - }); - }) - .expect("thread spawn"); - - let (signal_tx2, _) = sync_channel(DEFAULT_CHANNEL); - thread::Builder::new() - .name(thread_name) - .spawn(move || { - node2.start(&signal_tx2, |_| false); - }) - .expect("thread spawn"); + let (signal_tx1, signal_rx1) = bounded(DEFAULT_CHANNEL); + node1.start(thread_name.clone(), signal_tx1, |data| { + let msg = packed::SyncMessage::from_slice(&data) + .expect("sync message") + .to_enum(); + // terminate thread after 3 blocks + if let packed::SyncMessageUnionReader::SendBlock(reader) = msg.as_reader() { + let block = reader.block().to_entity().into_view(); + block.header().number() == 3 + } else { + false + } + }); + + let (signal_tx2, _) = bounded(DEFAULT_CHANNEL); + node2.start(thread_name, signal_tx2, |_| false); // Wait node1 receive block from node2 let _ = signal_rx1.recv(); + node1.stop(); + node2.stop(); + assert_eq!(shared1.snapshot().tip_number(), 3); assert_eq!( shared1.snapshot().tip_number(), @@ -176,7 +168,7 @@ fn setup_node(height: u64) -> (TestNode, Shared) { pack.take_relay_tx_receiver(), )); let synchronizer = Synchronizer::new(chain_controller, sync_shared); - let mut node = TestNode::default(); + let mut node = TestNode::new(); let protocol = Arc::new(RwLock::new(synchronizer)) as Arc<_>; node.add_protocol( SupportProtocols::Sync.protocol_id(), diff --git a/tx-pool/src/process.rs b/tx-pool/src/process.rs index 5bfc10f7db..4748c1c31e 100644 --- a/tx-pool/src/process.rs +++ b/tx-pool/src/process.rs @@ -1077,18 +1077,28 @@ fn _submit_entry( entry: TxEntry, callbacks: &Callbacks, ) -> Result<(), Reject> { + let tx_hash = entry.transaction().hash(); match status { TxStatus::Fresh => { - tx_pool.add_pending(entry.clone()); - callbacks.call_pending(tx_pool, &entry); + if tx_pool.add_pending(entry.clone()) { + callbacks.call_pending(tx_pool, &entry); + } else { + return Err(Reject::Duplicated(tx_hash)); + } } TxStatus::Gap => { - tx_pool.add_gap(entry.clone()); - callbacks.call_pending(tx_pool, &entry); + if tx_pool.add_gap(entry.clone()) { + callbacks.call_pending(tx_pool, &entry); + } else { + return Err(Reject::Duplicated(tx_hash)); + } } TxStatus::Proposed => { - tx_pool.add_proposed(entry.clone())?; - callbacks.call_proposed(tx_pool, &entry, true); + if tx_pool.add_proposed(entry.clone())? { + callbacks.call_proposed(tx_pool, &entry, true); + } else { + return Err(Reject::Duplicated(tx_hash)); + } } } Ok(()) diff --git a/util/channel/Cargo.toml b/util/channel/Cargo.toml index 1093003731..521026b118 100644 --- a/util/channel/Cargo.toml +++ b/util/channel/Cargo.toml @@ -11,4 +11,4 @@ repository = "https://github.com/nervosnetwork/ckb" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -crossbeam-channel = "~0.3" +crossbeam-channel = "0.5.1" diff --git a/util/channel/src/lib.rs b/util/channel/src/lib.rs index b8d139912e..90755a11f3 100644 --- a/util/channel/src/lib.rs +++ b/util/channel/src/lib.rs @@ -1,6 +1,6 @@ //! Reexports `crossbeam_channel` to uniform the dependency version. pub use crossbeam_channel::{ - bounded, select, unbounded, Receiver, RecvError, RecvTimeoutError, SendError, Sender, + bounded, select, unbounded, Receiver, RecvError, RecvTimeoutError, Select, SendError, Sender, TrySendError, };