diff --git a/src/inclusive_set.rs b/src/inclusive_set.rs index 7d3d90d..c98c978 100644 --- a/src/inclusive_set.rs +++ b/src/inclusive_set.rs @@ -1,7 +1,7 @@ use core::borrow::Borrow; use core::fmt::{self, Debug}; use core::iter::{DoubleEndedIterator, FromIterator}; -use core::ops::RangeInclusive; +use core::ops::{BitAnd, BitOr, RangeInclusive}; #[cfg(feature = "serde1")] use core::marker::PhantomData; @@ -14,6 +14,12 @@ use serde::{ use crate::std_ext::*; use crate::RangeInclusiveMap; +/// Intersection iterator over two [`RangeInclusiveSet`]. +pub type Intersection<'a, T> = crate::operations::Intersection<'a, RangeInclusive, Iter<'a, T>>; + +/// Union iterator over two [`RangeInclusiveSet`]. +pub type Union<'a, T> = crate::operations::Union<'a, RangeInclusive, Iter<'a, T>>; + #[derive(Clone, Hash, Default, Eq, PartialEq, PartialOrd, Ord)] /// A set whose items are stored as ranges bounded /// inclusively below and above `(start..=end)`. @@ -172,6 +178,16 @@ where pub fn last(&self) -> Option<&RangeInclusive> { self.rm.last_range_value().map(|(range, _)| range) } + + /// Return an iterator over the union of two range sets. + pub fn union<'a>(&'a self, other: &'a Self) -> Union<'a, T> { + Union::new(self.iter(), other.iter()) + } + + /// Return an iterator over the intersection of two range sets. + pub fn intersection<'a>(&'a self, other: &'a Self) -> Intersection<'a, T> { + Intersection::new(self.iter(), other.iter()) + } } /// An iterator over the ranges of a `RangeInclusiveSet`. @@ -437,6 +453,22 @@ macro_rules! 
range_inclusive_set { }}; } +impl BitAnd for &RangeInclusiveSet { + type Output = RangeInclusiveSet; + + fn bitand(self, other: Self) -> Self::Output { + self.intersection(other).collect() + } +} + +impl BitOr for &RangeInclusiveSet { + type Output = RangeInclusiveSet; + + fn bitor(self, other: Self) -> Self::Output { + self.union(other).collect() + } +} + #[cfg(test)] mod tests { use super::*; @@ -493,12 +525,17 @@ mod tests { assert_eq!(forward, backward); } - #[proptest] - fn test_arbitrary_set_u8(ranges: Vec>) { - let ranges: Vec<_> = ranges + // necessary due to assertion on empty ranges + fn filter_ranges(ranges: Vec>) -> Vec> { + ranges .into_iter() .filter(|range| range.start() != range.end()) - .collect(); + .collect() + } + + #[proptest] + fn test_arbitrary_set_u8(ranges: Vec>) { + let ranges: Vec<_> = filter_ranges(ranges); let set = ranges .iter() .fold(RangeInclusiveSet::new(), |mut set, range| { @@ -530,6 +567,69 @@ mod tests { ); } + #[proptest] + fn test_union_overlaps_u8(left: Vec>, right: Vec>) { + let left: RangeInclusiveSet<_> = filter_ranges(left).into_iter().collect(); + let right: RangeInclusiveSet<_> = filter_ranges(right).into_iter().collect(); + + let mut union = RangeInclusiveSet::new(); + for range in left.union(&right) { + // there should not be any overlaps in the ranges returned by the union + assert!(union.overlapping(&range).next().is_none()); + union.insert(range); + } + } + + #[proptest] + fn test_union_contains_u8(left: Vec>, right: Vec>) { + let left: RangeInclusiveSet<_> = filter_ranges(left).into_iter().collect(); + let right: RangeInclusiveSet<_> = filter_ranges(right).into_iter().collect(); + let union: RangeInclusiveSet<_> = left.union(&right).collect(); + + // value should be in the union if and only if it is in either set + for value in 0..u8::MAX { + assert_eq!( + union.contains(&value), + left.contains(&value) || right.contains(&value) + ); + } + } + + #[proptest] + fn test_intersection_contains_u8( + left: Vec>, 
right: Vec>, + ) { + let left: RangeInclusiveSet<_> = filter_ranges(left).into_iter().collect(); + let right: RangeInclusiveSet<_> = filter_ranges(right).into_iter().collect(); + let union: RangeInclusiveSet<_> = left.intersection(&right).collect(); + + // value should be in the intersection if and only if it is in both sets + for value in 0..u8::MAX { + assert_eq!( + union.contains(&value), + left.contains(&value) && right.contains(&value) + ); + } + } + + #[proptest] + fn test_intersection_overlaps_u8( + left: Vec>, + right: Vec>, + ) { + let left: RangeInclusiveSet<_> = filter_ranges(left).into_iter().collect(); + let right: RangeInclusiveSet<_> = filter_ranges(right).into_iter().collect(); + + let mut union = RangeInclusiveSet::new(); + for range in left.intersection(&right) { + // there should not be any overlaps in the ranges returned by the + // intersection + assert!(union.overlapping(&range).next().is_none()); + union.insert(range); + } + } + trait RangeInclusiveSetExt { fn to_vec(&self) -> Vec>; } diff --git a/src/lib.rs b/src/lib.rs index c2e7998..b994c88 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -134,6 +134,7 @@ extern crate alloc; pub mod inclusive_map; pub mod inclusive_set; pub mod map; +pub(crate) mod operations; pub mod set; #[cfg(test)] diff --git a/src/operations.rs b/src/operations.rs new file mode 100644 index 0000000..c5cb5ac --- /dev/null +++ b/src/operations.rs @@ -0,0 +1,246 @@ +use core::cmp::Ordering; +use core::iter::{FusedIterator, Peekable}; +use core::ops::{Range, RangeInclusive}; + +/// Trait to determine the ordering of the start and end of a range. 
+trait RangeOrder { + /// Ordering of start value + fn order_start(&self, other: &Self) -> Ordering; + + /// Ordering of end value + fn order_end(&self, other: &Self) -> Ordering; +} + +impl RangeOrder for Range { + fn order_start(&self, other: &Self) -> Ordering { + self.start.cmp(&other.start) + } + + fn order_end(&self, other: &Self) -> Ordering { + self.end.cmp(&other.end) + } +} + +impl RangeOrder for RangeInclusive { + fn order_start(&self, other: &Self) -> Ordering { + self.start().cmp(other.start()) + } + + fn order_end(&self, other: &Self) -> Ordering { + self.end().cmp(&other.end()) + } +} + +/// Range which can be merged with a next range if they overlap. +trait RangeMerge: Sized { + /// Merges this range and the next range, if they overlap. + fn merge(&mut self, next: &Self) -> bool; +} + +impl RangeMerge for Range { + fn merge(&mut self, other: &Self) -> bool { + if !self.contains(&other.start) { + return false; + } + + if other.end > self.end { + self.end = other.end.clone(); + } + + true + } +} + +impl RangeMerge for RangeInclusive { + fn merge(&mut self, other: &Self) -> bool { + if !self.contains(other.start()) { + return false; + } + + if other.end() > self.end() { + *self = RangeInclusive::new(self.start().clone(), other.end().clone()); + } + + true + } +} + +/// Range which can be intersected with another range. +trait RangeIntersect: Sized { + /// Returns the intersection of this range and the next range, or `None` if they do not overlap. 
+ fn intersect(&self, next: &Self) -> Option; +} + +impl RangeIntersect for Range { + fn intersect(&self, other: &Self) -> Option { + let start = (&self.start).max(&other.start); + let end = (&self.end).min(&other.end); + + if start >= end { + return None; + } + + Some(start.clone()..end.clone()) + } +} + +impl RangeIntersect for RangeInclusive { + fn intersect(&self, other: &Self) -> Option { + let start = self.start().max(other.start()); + let end = self.end().min(other.end()); + + if start > end { + return None; + } + + Some(start.clone()..=end.clone()) + } +} + +#[test] +fn test_intersect() { + assert_eq!((0..5).intersect(&(0..3)), Some(0..3)); + assert_eq!((0..3).intersect(&(0..5)), Some(0..3)); + assert_eq!((0..3).intersect(&(3..3)), None); +} + +/// Iterator that produces the union of two iterators of sorted ranges. +pub struct Union<'a, T, L, R = L> +where + T: 'a, + L: Iterator, + R: Iterator, +{ + left: Peekable, + right: Peekable, +} + +impl<'a, T, L, R> Union<'a, T, L, R> +where + T: 'a, + L: Iterator, + R: Iterator, +{ + /// Create new Union iterator. + /// + /// Requires that the two iterators produce sorted ranges. 
+ pub fn new(left: L, right: R) -> Self { + Self { + left: left.peekable(), + right: right.peekable(), + } + } +} + +impl<'a, R, I> Iterator for Union<'a, R, I> +where + R: RangeOrder + RangeMerge + Clone, + I: Iterator, +{ + type Item = R; + + fn next(&mut self) -> Option { + // get start range + let mut range = match (self.left.peek(), self.right.peek()) { + // if there is only one possible range, pick that + (Some(_), None) => self.left.next().unwrap(), + (None, Some(_)) => self.right.next().unwrap(), + // when there are two ranges, pick the one with the earlier start + (Some(left), Some(right)) => { + if left.order_start(right).is_lt() { + self.left.next().unwrap() + } else { + self.right.next().unwrap() + } + } + // otherwise we are done + (None, None) => return None, + } + .clone(); + + // peek into next value of iterator and merge if it overlaps the current range + let mut join = |iter: &mut Peekable| { + if let Some(next) = iter.peek() { + if range.merge(next) { + iter.next().unwrap(); + return true; + } + } + false + }; + + // keep merging ranges as long as we can + loop { + if !(join(&mut self.left) || join(&mut self.right)) { + break; + } + } + + Some(range) + } +} + +impl<'a, R, I> FusedIterator for Union<'a, R, I> +where + R: RangeOrder + RangeMerge + Clone, + I: Iterator, +{ +} + +/// Iterator that produces the intersection of two iterators of sorted ranges. +pub struct Intersection<'a, T, L, R = L> +where + T: 'a, + L: Iterator, + R: Iterator, +{ + left: Peekable, + right: Peekable, +} + +impl<'a, T, L, R> Intersection<'a, T, L, R> +where + T: 'a, + L: Iterator, + R: Iterator, +{ + /// Create new Intersection iterator. + /// + /// Requires that the two iterators produce sorted ranges. 
+ pub fn new(left: L, right: R) -> Self { + Self { + left: left.peekable(), + right: right.peekable(), + } + } +} + +impl<'a, R, I> Iterator for Intersection<'a, R, I> +where + R: RangeOrder + RangeIntersect + Clone, + I: Iterator, +{ + type Item = R; + + fn next(&mut self) -> Option { + loop { + // if we don't have at least two ranges, there cannot be an intersection + let (Some(left), Some(right)) = (self.left.peek(), self.right.peek()) else { + return None; + }; + + let intersection = left.intersect(right); + + // pop the range that ends earlier + if left.order_end(right).is_lt() { + self.left.next(); + } else { + self.right.next(); + } + + if let Some(intersection) = intersection { + return Some(intersection); + } + } + } +} diff --git a/src/set.rs b/src/set.rs index 122d143..99fc0f0 100644 --- a/src/set.rs +++ b/src/set.rs @@ -1,7 +1,7 @@ use core::borrow::Borrow; use core::fmt::{self, Debug}; -use core::iter::{DoubleEndedIterator, FromIterator}; -use core::ops::Range; +use core::iter::FromIterator; +use core::ops::{BitAnd, BitOr, Range}; use core::prelude::v1::*; #[cfg(feature = "serde1")] @@ -14,6 +14,12 @@ use serde::{ use crate::RangeMap; +/// Intersection iterator over two [`RangeSet`]. +pub type Intersection<'a, T> = crate::operations::Intersection<'a, Range, Iter<'a, T>>; + +/// Union iterator over two [`RangeSet`]. +pub type Union<'a, T> = crate::operations::Union<'a, Range, Iter<'a, T>>; + #[derive(Clone, Hash, Default, Eq, PartialEq, PartialOrd, Ord)] /// A set whose items are stored as (half-open) ranges bounded /// inclusively below and exclusively above `(start..end)`. @@ -78,6 +84,16 @@ where self.rm.is_empty() } + /// Return an iterator over the intersection of two range sets. + pub fn intersection<'a>(&'a self, other: &'a Self) -> Intersection<'a, T> { + Intersection::new(self.iter(), other.iter()) + } + + /// Return an iterator over the union of two range sets. 
+ pub fn union<'a>(&'a self, other: &'a Self) -> Union<'a, T> { + Union::new(self.iter(), other.iter()) + } + /// Insert a range into the set. /// /// If the inserted range either overlaps or is immediately adjacent @@ -378,6 +394,22 @@ where } } +impl BitAnd for &RangeSet { + type Output = RangeSet; + + fn bitand(self, other: Self) -> Self::Output { + self.intersection(other).collect() + } +} + +impl BitOr for &RangeSet { + type Output = RangeSet; + + fn bitor(self, other: Self) -> Self::Output { + self.union(other).collect() + } +} + impl From<[Range; N]> for RangeSet { fn from(value: [Range; N]) -> Self { let mut set = Self::new(); @@ -459,12 +491,17 @@ mod tests { assert_eq!(forward, backward); } - #[proptest] - fn test_arbitrary_set_u8(ranges: Vec>) { - let ranges: Vec<_> = ranges + // necessary due to assertion on empty ranges + fn filter_ranges(ranges: Vec>) -> Vec> { + ranges .into_iter() .filter(|range| range.start != range.end) - .collect(); + .collect() + } + + #[proptest] + fn test_arbitrary_set_u8(ranges: Vec>) { + let ranges = filter_ranges(ranges); let set = ranges.iter().fold(RangeSet::new(), |mut set, range| { set.insert(range.clone()); set @@ -494,6 +531,63 @@ mod tests { ); } + #[proptest] + fn test_union_overlaps_u8(left: Vec>, right: Vec>) { + let left: RangeSet<_> = filter_ranges(left).into_iter().collect(); + let right: RangeSet<_> = filter_ranges(right).into_iter().collect(); + + let mut union = RangeSet::new(); + for range in left.union(&right) { + // there should not be any overlaps in the ranges returned by the union + assert!(union.overlapping(&range).next().is_none()); + union.insert(range); + } + } + + #[proptest] + fn test_union_contains_u8(left: Vec>, right: Vec>) { + let left: RangeSet<_> = filter_ranges(left).into_iter().collect(); + let right: RangeSet<_> = filter_ranges(right).into_iter().collect(); + let union: RangeSet<_> = left.union(&right).collect(); + + // value should be in the union if and only if it is in either set + 
for value in 0..u8::MAX { + assert_eq!( + union.contains(&value), + left.contains(&value) || right.contains(&value) + ); + } + } + + #[proptest] + fn test_intersection_contains_u8(left: Vec>, right: Vec>) { + let left: RangeSet<_> = filter_ranges(left).into_iter().collect(); + let right: RangeSet<_> = filter_ranges(right).into_iter().collect(); + let union: RangeSet<_> = left.intersection(&right).collect(); + + // value should be in the intersection if and only if it is in both sets + for value in 0..u8::MAX { + assert_eq!( + union.contains(&value), + left.contains(&value) && right.contains(&value) + ); + } + } + + #[proptest] + fn test_intersection_overlaps_u8(left: Vec>, right: Vec>) { + let left: RangeSet<_> = filter_ranges(left).into_iter().collect(); + let right: RangeSet<_> = filter_ranges(right).into_iter().collect(); + + let mut union = RangeSet::new(); + for range in left.intersection(&right) { + // there should not be any overlaps in the ranges returned by the + // intersection + assert!(union.overlapping(&range).next().is_none()); + union.insert(range); + } + } + trait RangeSetExt { fn to_vec(&self) -> Vec>; }