From 4333b869729bd00b8703721e206249b237a25e72 Mon Sep 17 00:00:00 2001 From: pcpthm Date: Mon, 16 Sep 2019 04:37:52 +0000 Subject: [PATCH] Improve BTreeSet::Intersection::size_hint The commented invariant that an iterator is smaller than other iterator was violated after next is called and two iterators are consumed at different rates. --- src/liballoc/collections/btree/set.rs | 44 +++++++++++++-------------- src/liballoc/tests/btree/set.rs | 11 +++++++ 2 files changed, 33 insertions(+), 22 deletions(-) diff --git a/src/liballoc/collections/btree/set.rs b/src/liballoc/collections/btree/set.rs index d3af910a82c27..0cb91ba4c81da 100644 --- a/src/liballoc/collections/btree/set.rs +++ b/src/liballoc/collections/btree/set.rs @@ -3,7 +3,7 @@ use core::borrow::Borrow; use core::cmp::Ordering::{self, Less, Greater, Equal}; -use core::cmp::max; +use core::cmp::{max, min}; use core::fmt::{self, Debug}; use core::iter::{Peekable, FromIterator, FusedIterator}; use core::ops::{BitOr, BitAnd, BitXor, Sub, RangeBounds}; @@ -187,8 +187,8 @@ pub struct Intersection<'a, T: 'a> { } enum IntersectionInner<'a, T: 'a> { Stitch { - small_iter: Iter<'a, T>, // for size_hint, should be the smaller of the sets - other_iter: Iter<'a, T>, + a: Iter<'a, T>, + b: Iter<'a, T>, }, Search { small_iter: Iter<'a, T>, @@ -201,12 +201,12 @@ impl fmt::Debug for Intersection<'_, T> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match &self.inner { IntersectionInner::Stitch { - small_iter, - other_iter, + a, + b, } => f .debug_tuple("Intersection") - .field(&small_iter) - .field(&other_iter) + .field(&a) + .field(&b) .finish(), IntersectionInner::Search { small_iter, @@ -397,8 +397,8 @@ impl BTreeSet { // Iterate both sets jointly, spotting matches along the way. Intersection { inner: IntersectionInner::Stitch { - small_iter: small.iter(), - other_iter: other.iter(), + a: small.iter(), + b: other.iter(), }, } } else { @@ -1221,11 +1221,11 @@ impl Clone for Intersection<'_, T> { Intersection { inner: match &self.inner { IntersectionInner::Stitch { - small_iter, - other_iter, + a, + b, } => IntersectionInner::Stitch { - small_iter: small_iter.clone(), - other_iter: other_iter.clone(), + a: a.clone(), + b: b.clone(), }, IntersectionInner::Search { small_iter, @@ -1245,16 +1245,16 @@ impl<'a, T: Ord> Iterator for Intersection<'a, T> { fn next(&mut self) -> Option<&'a T> { match &mut self.inner { IntersectionInner::Stitch { - small_iter, - other_iter, + a, + b, } => { - let mut small_next = small_iter.next()?; - let mut other_next = other_iter.next()?; + let mut a_next = a.next()?; + let mut b_next = b.next()?; loop { - match Ord::cmp(small_next, other_next) { - Less => small_next = small_iter.next()?, - Greater => other_next = other_iter.next()?, - Equal => return Some(small_next), + match Ord::cmp(a_next, b_next) { + Less => a_next = a.next()?, + Greater => b_next = b.next()?, + Equal => return Some(a_next), } } } @@ -1272,7 +1272,7 @@ impl<'a, T: Ord> Iterator for Intersection<'a, T> { fn size_hint(&self) -> (usize, Option) { let min_len = match &self.inner { - IntersectionInner::Stitch { small_iter, .. } => small_iter.len(), + IntersectionInner::Stitch { a, b } => min(a.len(), b.len()), IntersectionInner::Search { small_iter, .. } => small_iter.len(), }; (0, Some(min_len)) diff --git a/src/liballoc/tests/btree/set.rs b/src/liballoc/tests/btree/set.rs index 62ccb53fcea18..35db18c39c83a 100644 --- a/src/liballoc/tests/btree/set.rs +++ b/src/liballoc/tests/btree/set.rs @@ -90,6 +90,17 @@ fn test_intersection() { &[1, 3, 11, 77, 103]); } +#[test] +fn test_intersection_size_hint() { + let x: BTreeSet = [3, 4].iter().copied().collect(); + let y: BTreeSet = [1, 2, 3].iter().copied().collect(); + let mut iter = x.intersection(&y); + assert_eq!(iter.size_hint(), (0, Some(2))); + assert_eq!(iter.next(), Some(&3)); + assert_eq!(iter.size_hint(), (0, Some(0))); + assert_eq!(iter.next(), None); +} + #[test] fn test_difference() { fn check_difference(a: &[i32], b: &[i32], expected: &[i32]) {