From 9e4dd59c4eed1b4d0e85a11560c728ac0bdb52c1 Mon Sep 17 00:00:00 2001 From: "mergify[bot]" <37929162+mergify[bot]@users.noreply.github.com> Date: Wed, 11 Dec 2024 08:13:30 -0600 Subject: [PATCH] v2.1: use trailing_zeros for threadset iteration (backport of #3871) (#3897) --- .../thread_aware_account_locks.rs | 53 ++++++++++++++++++- 1 file changed, 52 insertions(+), 1 deletion(-) diff --git a/core/src/banking_stage/transaction_scheduler/thread_aware_account_locks.rs b/core/src/banking_stage/transaction_scheduler/thread_aware_account_locks.rs index b279102756eed4..f3e3d3f8d683af 100644 --- a/core/src/banking_stage/transaction_scheduler/thread_aware_account_locks.rs +++ b/core/src/banking_stage/transaction_scheduler/thread_aware_account_locks.rs @@ -421,7 +421,7 @@ impl ThreadSet { #[inline(always)] pub(crate) fn contained_threads_iter(self) -> impl Iterator { - (0..MAX_THREADS).filter(move |thread_id| self.contains(*thread_id)) + ThreadSetIterator(self.0) } #[inline(always)] @@ -430,6 +430,22 @@ impl ThreadSet { } } +struct ThreadSetIterator(u64); + +impl Iterator for ThreadSetIterator { + type Item = ThreadId; + + fn next(&mut self) -> Option { + if self.0 == 0 { + None + } else { + let thread_id = self.0.trailing_zeros() as ThreadId; + self.0 &= self.0 - 1; + Some(thread_id) + } + } +} + #[cfg(test)] mod tests { use super::*; @@ -739,4 +755,39 @@ mod tests { let any_threads = ThreadSet::any(MAX_THREADS); assert_eq!(any_threads.num_threads(), MAX_THREADS as u32); } + + #[test] + fn test_thread_set_iter() { + let mut thread_set = ThreadSet::none(); + assert!(thread_set.contained_threads_iter().next().is_none()); + + thread_set.insert(4); + assert_eq!( + thread_set.contained_threads_iter().collect::>(), + vec![4] + ); + + thread_set.insert(5); + assert_eq!( + thread_set.contained_threads_iter().collect::>(), + vec![4, 5] + ); + thread_set.insert(63); + assert_eq!( + thread_set.contained_threads_iter().collect::>(), + vec![4, 5, 63] + ); + + thread_set.remove(5); + assert_eq!( + thread_set.contained_threads_iter().collect::>(), + vec![4, 63] + ); + + let thread_set = ThreadSet::any(64); + assert_eq!( + thread_set.contained_threads_iter().collect::>(), + (0..64).collect::>() + ); + } }