diff --git a/src/banking_stage.rs b/src/banking_stage.rs index ef06f0c3834f85..c689b8d9cf4eb8 100644 --- a/src/banking_stage.rs +++ b/src/banking_stage.rs @@ -7,7 +7,8 @@ use crate::compute_leader_confirmation_service::ComputeLeaderConfirmationService use crate::counter::Counter; use crate::entry::Entry; use crate::packet::Packets; -use crate::poh_recorder::{PohRecorder, PohRecorderError}; +use crate::packet::SharedPackets; +use crate::poh_recorder::PohRecorder; use crate::poh_service::{PohService, PohServiceConfig}; use crate::result::{Error, Result}; use crate::service::Service; @@ -19,7 +20,6 @@ use solana_sdk::hash::Hash; use solana_sdk::pubkey::Pubkey; use solana_sdk::timing; use solana_sdk::transaction::Transaction; -use std::sync::atomic::Ordering; use std::sync::mpsc::{channel, Receiver, RecvTimeoutError}; use std::sync::{Arc, Mutex}; use std::thread::{self, Builder, JoinHandle}; @@ -27,22 +27,16 @@ use std::time::Duration; use std::time::Instant; use sys_info; -#[derive(Debug, PartialEq, Eq, Clone)] -pub enum BankingStageReturnType { - LeaderRotation(u64), - ChannelDisconnected, -} +pub type UnprocessedPackets = Vec<(SharedPackets, usize)>; // `usize` is the index of the first unprocessed packet in `SharedPackets` // number of threads is 1 until mt bank is ready pub const NUM_THREADS: u32 = 10; /// Stores the stage's thread handle and output receiver. pub struct BankingStage { - /// Handle to the stage's thread. - bank_thread_hdls: Vec>>, + bank_thread_hdls: Vec>, poh_service: PohService, compute_confirmation_service: ComputeLeaderConfirmationService, - max_tick_height: u64, } impl BankingStage { @@ -76,58 +70,32 @@ impl BankingStage { ); // Many banks that process transactions in parallel. - let bank_thread_hdls: Vec>> = (0 - ..Self::num_threads()) + let bank_thread_hdls: Vec> = (0..Self::num_threads()) .map(|_| { let thread_bank = bank.clone(); let thread_verified_receiver = shared_verified_receiver.clone(); let thread_poh_recorder = poh_recorder.clone(); - let thread_banking_exit = poh_service.poh_exit.clone(); Builder::new() .name("solana-banking-stage-tx".to_string()) .spawn(move || { - let return_result = loop { - if let Err(e) = Self::process_packets( + let mut unprocessed_packets: UnprocessedPackets = vec![]; + loop { + match Self::process_packets( &thread_bank, &thread_verified_receiver, &thread_poh_recorder, ) { - debug!("process_packets error: {:?}", e); - match e { - Error::RecvTimeoutError(RecvTimeoutError::Timeout) => (), - Error::RecvTimeoutError(RecvTimeoutError::Disconnected) => { - break Some(BankingStageReturnType::ChannelDisconnected); - } - Error::RecvError(_) => { - break Some(BankingStageReturnType::ChannelDisconnected); - } - Error::SendError => { - break Some(BankingStageReturnType::ChannelDisconnected); - } - Error::BankError(BankError::RecordFailure) => { - warn!("Bank failed to record"); - break Some(BankingStageReturnType::ChannelDisconnected); - } - Error::BankError(BankError::MaxHeightReached) => { - // Bank has reached its max tick height. Exit quietly - // and wait for the PohRecorder to start leader rotation - break None; - } - _ => { - error!("solana-banking-stage-tx: unhandled error: {:?}", e) - } + Err(Error::RecvTimeoutError(RecvTimeoutError::Timeout)) => (), + Ok(more_unprocessed_packets) => { + unprocessed_packets.extend(more_unprocessed_packets); + } + Err(err) => { + debug!("solana-banking-stage-tx: exit due to {:?}", err); + break; } } - if thread_banking_exit.load(Ordering::Relaxed) { - break None; - } - }; - - // Signal exit only on "Some" error - if return_result.is_some() { - thread_banking_exit.store(true, Ordering::Relaxed); } - return_result + unprocessed_packets }) .unwrap() }) @@ -137,7 +105,6 @@ impl BankingStage { bank_thread_hdls, poh_service, compute_confirmation_service, - max_tick_height, }, entry_receiver, ) @@ -155,22 +122,28 @@ impl BankingStage { .collect() } + /// Sends transactions to the bank. + /// + /// Returns the number of transactions successfully processed by the bank, which may be less + /// than the total number if max PoH height was reached and the bank halted fn process_transactions( bank: &Arc, transactions: &[Transaction], poh: &PohRecorder, - ) -> Result<()> { - debug!("transactions: {}", transactions.len()); + ) -> Result<(usize)> { let mut chunk_start = 0; while chunk_start != transactions.len() { let chunk_end = chunk_start + Entry::num_will_fit(&transactions[chunk_start..]); - bank.process_and_record_transactions(&transactions[chunk_start..chunk_end], poh)?; - + let result = + bank.process_and_record_transactions(&transactions[chunk_start..chunk_end], poh); + if Err(BankError::MaxHeightReached) == result { + break; + } + result?; chunk_start = chunk_end; } - debug!("done process_transactions"); - Ok(()) + Ok(chunk_start) } /// Process the incoming packets @@ -178,7 +151,7 @@ impl BankingStage { bank: &Arc, verified_receiver: &Arc>>, poh: &PohRecorder, - ) -> Result<()> { + ) -> Result { let recv_start = Instant::now(); let mms = verified_receiver .lock() @@ -196,29 +169,45 @@ impl BankingStage { let count = mms.iter().map(|x| x.1.len()).sum(); let proc_start = Instant::now(); let mut new_tx_count = 0; + + let mut unprocessed_packets = vec![]; + let mut bank_shutdown = false; for (msgs, vers) in mms { + if bank_shutdown { + unprocessed_packets.push((msgs, 0)); + continue; + } + let transactions = Self::deserialize_transactions(&msgs.read().unwrap()); reqs_len += transactions.len(); debug!("transactions received {}", transactions.len()); - - let transactions: Vec<_> = transactions - .into_iter() - .zip(vers) - .filter_map(|(tx, ver)| match tx { - None => None, - Some(tx) => { - if tx.verify_refs() && ver != 0 { - Some(tx) - } else { - None + let (verified_transactions, verified_transaction_index): (Vec<_>, Vec<_>) = + transactions + .into_iter() + .zip(vers) + .zip(0..) + .filter_map(|((tx, ver), index)| match tx { + None => None, + Some(tx) => { + if tx.verify_refs() && ver != 0 { + Some((tx, index)) + } else { + None + } } - } - }) - .collect(); - debug!("verified transactions {}", transactions.len()); - Self::process_transactions(bank, &transactions, poh)?; - new_tx_count += transactions.len(); + }) + .unzip(); + + debug!("verified transactions {}", verified_transactions.len()); + + let processed = Self::process_transactions(bank, &verified_transactions, poh)?; + if processed < verified_transactions.len() { + bank_shutdown = true; + // Collect any unprocessed transactions in this batch for forwarding + unprocessed_packets.push((msgs, verified_transaction_index[processed])); + } + new_tx_count += processed; } inc_new_counter_info!( @@ -237,49 +226,43 @@ impl BankingStage { ); inc_new_counter_info!("banking_stage-process_packets", count); inc_new_counter_info!("banking_stage-process_transactions", new_tx_count); - Ok(()) + + Ok(unprocessed_packets) + } + + pub fn join_and_collect_unprocessed_packets(&mut self) -> UnprocessedPackets { + let mut unprocessed_packets: UnprocessedPackets = vec![]; + for bank_thread_hdl in self.bank_thread_hdls.drain(..) { + match bank_thread_hdl.join() { + Ok(more_unprocessed_packets) => { + unprocessed_packets.extend(more_unprocessed_packets) + } + err => warn!("bank_thread_hdl join failed: {:?}", err), + } + } + unprocessed_packets } } impl Service for BankingStage { - type JoinReturnType = Option; - - fn join(self) -> thread::Result> { - let mut return_value = None; + type JoinReturnType = (); + fn join(self) -> thread::Result<()> { for bank_thread_hdl in self.bank_thread_hdls { - let thread_return_value = bank_thread_hdl.join()?; - if thread_return_value.is_some() { - return_value = thread_return_value; - } + bank_thread_hdl.join()?; } - self.compute_confirmation_service.join()?; - - let poh_return_value = self.poh_service.join()?; - match poh_return_value { - Ok(_) => (), - Err(Error::PohRecorderError(PohRecorderError::MaxHeightReached)) => { - return_value = Some(BankingStageReturnType::LeaderRotation(self.max_tick_height)); - } - Err(Error::SendError) => { - return_value = Some(BankingStageReturnType::ChannelDisconnected); - } - Err(_) => (), - } - - Ok(return_value) + let _ = self.poh_service.join()?; + Ok(()) } } #[cfg(test)] mod tests { use super::*; - use crate::bank::Bank; - use crate::banking_stage::BankingStageReturnType; use crate::entry::EntrySlice; use crate::genesis_block::GenesisBlock; - use crate::leader_scheduler::DEFAULT_TICKS_PER_SLOT; + use crate::leader_scheduler::{LeaderSchedulerConfig, DEFAULT_TICKS_PER_SLOT}; use crate::packet::to_packets; use solana_sdk::signature::{Keypair, KeypairUtil}; use solana_sdk::system_transaction::SystemTransaction; @@ -301,32 +284,7 @@ mod tests { &to_validator_sender, ); drop(verified_sender); - assert_eq!( - banking_stage.join().unwrap(), - Some(BankingStageReturnType::ChannelDisconnected) - ); - } - - #[test] - fn test_banking_stage_shutdown2() { - let (genesis_block, _mint_keypair) = GenesisBlock::new(2); - let bank = Arc::new(Bank::new(&genesis_block)); - let (_verified_sender, verified_receiver) = channel(); - let (to_validator_sender, _) = channel(); - let (banking_stage, entry_receiver) = BankingStage::new( - &bank, - verified_receiver, - PohServiceConfig::default(), - &bank.last_id(), - DEFAULT_TICKS_PER_SLOT, - genesis_block.bootstrap_leader_id, - &to_validator_sender, - ); - drop(entry_receiver); - assert_eq!( - banking_stage.join().unwrap(), - Some(BankingStageReturnType::ChannelDisconnected) - ); + banking_stage.join().unwrap(); } #[test] @@ -352,10 +310,7 @@ mod tests { assert!(entries.len() != 0); assert!(entries.verify(&start_hash)); assert_eq!(entries[entries.len() - 1].id, bank.last_id()); - assert_eq!( - banking_stage.join().unwrap(), - Some(BankingStageReturnType::ChannelDisconnected) - ); + banking_stage.join().unwrap(); } #[test] @@ -409,10 +364,7 @@ mod tests { last_id = entries.last().unwrap().id; }); drop(entry_receiver); - assert_eq!( - banking_stage.join().unwrap(), - Some(BankingStageReturnType::ChannelDisconnected) - ); + banking_stage.join().unwrap(); } #[test] fn test_banking_stage_entryfication() { @@ -461,10 +413,7 @@ mod tests { .send(vec![(packets[0].clone(), vec![1u8])]) .unwrap(); drop(verified_sender); - assert_eq!( - banking_stage.join().unwrap(), - Some(BankingStageReturnType::ChannelDisconnected) - ); + banking_stage.join().unwrap(); // Collect the ledger and feed it to a new bank. let entries: Vec<_> = entry_receiver.iter().flat_map(|x| x).collect(); @@ -484,13 +433,12 @@ mod tests { } // Test that when the max_tick_height is reached, the banking stage exits - // with reason BankingStageReturnType::LeaderRotation #[test] fn test_max_tick_height_shutdown() { let (genesis_block, _mint_keypair) = GenesisBlock::new(2); let bank = Arc::new(Bank::new(&genesis_block)); - let (_verified_sender_, verified_receiver) = channel(); - let (to_validator_sender, _to_validator_receiver) = channel(); + let (verified_sender, verified_receiver) = channel(); + let (to_validator_sender, to_validator_receiver) = channel(); let max_tick_height = 10; let (banking_stage, _entry_receiver) = BankingStage::new( &bank, @@ -501,9 +449,58 @@ mod tests { genesis_block.bootstrap_leader_id, &to_validator_sender, ); + assert_eq!(to_validator_receiver.recv().unwrap(), max_tick_height); + drop(verified_sender); + banking_stage.join().unwrap(); + } + + #[test] + fn test_returns_unprocessed_packet() { + solana_logger::setup(); + let (genesis_block, mint_keypair) = GenesisBlock::new(2); + let leader_scheduler_config = LeaderSchedulerConfig::new(1, 1, 1); + let bank = Arc::new(Bank::new_with_leader_scheduler_config( + &genesis_block, + &leader_scheduler_config, + )); + let (verified_sender, verified_receiver) = channel(); + let (to_validator_sender, to_validator_receiver) = channel(); + let (mut banking_stage, _entry_receiver) = BankingStage::new( + &bank, + verified_receiver, + PohServiceConfig::default(), + &bank.last_id(), + leader_scheduler_config.ticks_per_slot, + genesis_block.bootstrap_leader_id, + &to_validator_sender, + ); + + // Wait for Poh recorder to hit max height assert_eq!( - banking_stage.join().unwrap(), - Some(BankingStageReturnType::LeaderRotation(max_tick_height)) + to_validator_receiver.recv().unwrap(), + leader_scheduler_config.ticks_per_slot ); + + // Now send a transaction to the banking stage + let transaction = SystemTransaction::new_account( + &mint_keypair, + Keypair::new().pubkey(), + 2, + genesis_block.last_id(), + 0, + ); + + let packets = to_packets(&[transaction]); + verified_sender + .send(vec![(packets[0].clone(), vec![1u8])]) + .unwrap(); + + // Shut down the banking stage, it should give back the transaction + drop(verified_sender); + let unprocessed_packets = banking_stage.join_and_collect_unprocessed_packets(); + assert_eq!(unprocessed_packets.len(), 1); + let (packets, start_index) = &unprocessed_packets[0]; + assert_eq!(packets.read().unwrap().packets.len(), 1); // TODO: maybe compare actual packet contents too + assert_eq!(*start_index, 0); } } diff --git a/src/tpu.rs b/src/tpu.rs index a66ee136284c7b..636f60e7acdd09 100644 --- a/src/tpu.rs +++ b/src/tpu.rs @@ -2,7 +2,7 @@ //! multi-stage transaction processing pipeline in software. use crate::bank::Bank; -use crate::banking_stage::BankingStage; +use crate::banking_stage::{BankingStage, UnprocessedPackets}; use crate::blocktree::Blocktree; use crate::broadcast_service::BroadcastService; use crate::cluster_info::ClusterInfo; @@ -115,7 +115,7 @@ impl Tpu { tpu } - fn tpu_mode_close(&self) { + fn mode_close(&self) { match &self.tpu_mode { Some(TpuMode::Leader(svcs)) => { svcs.fetch_stage.close(); @@ -127,8 +127,39 @@ impl Tpu { } } + fn forward_unprocessed_packets( + tpu: &std::net::SocketAddr, + unprocessed_packets: UnprocessedPackets, + ) -> std::io::Result<()> { + let socket = UdpSocket::bind("0.0.0.0:0")?; + for (packets, start_index) in unprocessed_packets { + let packets = packets.read().unwrap(); + for packet in packets.packets.iter().skip(start_index) { + socket.send_to(&packet.data[..packet.meta.size], tpu)?; + } + } + Ok(()) + } + + fn close_and_forward_unprocessed_packets(&mut self) { + self.mode_close(); + + if let Some(TpuMode::Leader(svcs)) = self.tpu_mode.take().as_mut() { + let unprocessed_packets = svcs.banking_stage.join_and_collect_unprocessed_packets(); + + if !unprocessed_packets.is_empty() { + let tpu = self.cluster_info.read().unwrap().leader_data().unwrap().tpu; + info!("forwarding unprocessed packets to new leader at {:?}", tpu); + Tpu::forward_unprocessed_packets(&tpu, unprocessed_packets).unwrap_or_else(|err| { + warn!("Failed to forward unprocessed transactions: {:?}", err) + }); + } + } + } + pub fn switch_to_forwarder(&mut self, transactions_sockets: Vec) { - self.tpu_mode_close(); + self.close_and_forward_unprocessed_packets(); + let tpu_forwarder = TpuForwarder::new(transactions_sockets, self.cluster_info.clone()); self.tpu_mode = Some(TpuMode::Forwarder(ForwarderServices::new(tpu_forwarder))); } @@ -148,7 +179,7 @@ impl Tpu { to_validator_sender: &TpuRotationSender, blocktree: &Arc, ) { - self.tpu_mode_close(); + self.close_and_forward_unprocessed_packets(); self.exit = Arc::new(AtomicBool::new(false)); let (packet_sender, packet_receiver) = channel(); @@ -214,7 +245,7 @@ impl Tpu { } pub fn close(self) -> thread::Result<()> { - self.tpu_mode_close(); + self.mode_close(); self.join() } } diff --git a/tests/multinode.rs b/tests/multinode.rs index 977242e8eb2768..8dc402061f252e 100644 --- a/tests/multinode.rs +++ b/tests/multinode.rs @@ -1703,6 +1703,7 @@ fn test_broadcast_last_tick() { .expect("Expected to be able to reconstruct entries from blob") .0[0]; assert_eq!(actual_last_tick, expected_last_tick); + break; } else { assert!(!b_r.is_last_in_slot()); } @@ -2078,13 +2079,12 @@ fn test_two_fullnodes_rotate_every_second_tick() { } #[test] -#[ignore] fn test_one_fullnode_rotate_every_tick_with_transactions() { test_fullnode_rotate(1, 1, false, true); } #[test] #[ignore] -fn test_two_fullnodes_rotate_every_tick_with_transacations() { +fn test_two_fullnodes_rotate_every_tick_with_transactions() { test_fullnode_rotate(1, 1, true, true); }