diff --git a/src/librustc_mir/hair/pattern/_match.rs b/src/librustc_mir/hair/pattern/_match.rs index e928d4e5f03a6..e47e2899507ee 100644 --- a/src/librustc_mir/hair/pattern/_match.rs +++ b/src/librustc_mir/hair/pattern/_match.rs @@ -194,6 +194,7 @@ use std::cmp::{self, Ordering, min, max}; use std::fmt; use std::iter::{FromIterator, IntoIterator}; use std::ops::RangeInclusive; +use std::u128; pub fn expand_pattern<'a, 'tcx>(cx: &MatchCheckCtxt<'a, 'tcx>, pat: Pattern<'tcx>) -> &'a Pattern<'tcx> @@ -799,6 +800,7 @@ fn max_slice_length<'p, 'a: 'p, 'tcx: 'a, I>( /// /// `IntRange` is never used to encode an empty range or a "range" that wraps /// around the (offset) space: i.e. `range.lo <= range.hi`. +#[derive(Clone)] struct IntRange<'tcx> { pub range: RangeInclusive, pub ty: Ty<'tcx>, @@ -1400,9 +1402,7 @@ fn should_treat_range_exhaustively(tcx: TyCtxt<'_, 'tcx, 'tcx>, ctor: &Construct /// patterns that apply to that range (specifically: the patterns that *intersect* with that range) /// change. /// Our solution, therefore, is to split the range constructor into subranges at every single point -/// the group of intersecting patterns changes, which we can compute by converting each pattern to -/// a range and recording its endpoints, then creating subranges between each consecutive pair of -/// endpoints. +/// the group of intersecting patterns changes (using the method described below). /// And voilà! We're testing precisely those ranges that we need to, without any exhaustive matching /// on actual integers. The nice thing about this is that the number of subranges is linear in the /// number of rows in the matrix (i.e. the number of cases in the `match` statement), so we don't @@ -1414,14 +1414,14 @@ fn should_treat_range_exhaustively(tcx: TyCtxt<'_, 'tcx, 'tcx>, ctor: &Construct /// |-------| |-------| |----| || /// |---------| /// -/// We truncate the ranges so that they lie inside each range constructor and then split them -/// up into equivalence classes so the ranges are no longer overlapping: +/// We split the ranges up into equivalence classes so the ranges are no longer overlapping: /// /// |--|--|||-||||--||---|||-------| |-|||| || /// -/// The logic for determining how to split the ranges is a little involved: we need to make sure -/// that we have a new range for each subrange for which a different set of rows coïncides, but -/// essentially reduces to case analysis on the endpoints of the ranges. +/// The logic for determining how to split the ranges is fairly straightforward: we calculate +/// boundaries for each interval range, sort them, then create constructors for each new interval +/// between every pair of boundary points. (This essentially sums up to performing the intuitive +/// merging operation depicted above.) fn split_grouped_constructors<'p, 'a: 'p, 'tcx: 'a>( tcx: TyCtxt<'a, 'tcx, 'tcx>, ctors: Vec>, @@ -1440,84 +1440,54 @@ fn split_grouped_constructors<'p, 'a: 'p, 'tcx: 'a>( // `NotUseful`, which is the default case anyway, and can be ignored. let ctor_range = IntRange::from_ctor(tcx, &ctor).unwrap(); - // We're going to collect all the endpoints in the new pattern so we can create - // subranges between them. - // If there's a single point, we need to identify it as belonging - // to a length-1 range, so it can be treated as an individual - // constructor, rather than as an endpoint. To do this, we keep track of which - // endpoint a point corresponds to. Whenever a point corresponds to both a start - // and an end, then we create a unit range for it. - #[derive(PartialEq, Clone, Copy, Debug)] - enum Endpoint { - Start, - End, - Both, - }; - let mut points = FxHashMap::default(); - let add_endpoint = |points: &mut FxHashMap<_, _>, x, endpoint| { - points.entry(x).and_modify(|ex_x| { - if *ex_x != endpoint { - *ex_x = Endpoint::Both - } - }).or_insert(endpoint); - }; - let add_endpoints = |points: &mut FxHashMap<_, _>, lo, hi| { - // Insert the endpoints, taking care to keep track of to - // which endpoints a point corresponds. - add_endpoint(points, lo, Endpoint::Start); - add_endpoint(points, hi, Endpoint::End); - }; - let (lo, hi) = (*ctor_range.range.start(), *ctor_range.range.end()); - add_endpoints(&mut points, lo, hi); - // We're going to iterate through every row pattern, adding endpoints in. - for row in m.iter() { - if let Some(r) = IntRange::from_pat(tcx, row[0]) { - // We're only interested in endpoints that lie (at least partially) - // within the subrange domain. - if let Some(r) = ctor_range.intersection(&r) { - let (r_lo, r_hi) = r.range.into_inner(); - add_endpoints(&mut points, r_lo, r_hi); - } - } + /// Represents a border between 2 integers. Because the intervals spanning borders + /// must be able to cover every integer, we need 2^128 + 1 such borders. + #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] + enum Border { + JustBefore(u128), + AfterMax, } - // The patterns were iterated in an arbitrary order (i.e. in the order the user - // wrote them), so we need to make sure our endpoints are sorted. - let mut points: Vec<(u128, Endpoint)> = points.into_iter().collect(); - points.sort_unstable_by_key(|(x, _)| *x); - let mut points = points.into_iter(); - let mut a = points.next().unwrap(); - - // Iterate through pairs of points, adding the subranges to `split_ctors`. - // We have to be careful about the orientation of the points as endpoints, to make - // sure we're enumerating precisely the correct ranges. Too few and the matching is - // actually incorrect. Too many and our diagnostics are poorer. This involves some - // case analysis. - // In essence, we need to ensure that every time the set of row-ranges that are - // overlapping changes (as we go through the values covered by the ranges), we split - // into a new subrange. - while let Some(b) = points.next() { - // a < b (strictly) - if let Endpoint::Both = a.1 { - split_ctors.push(IntRange::range_to_ctor(tcx, ty, a.0..=a.0)); - } - // Integer overflow cannot occur here, because only the first point may be - // u128::MIN and only the last may be u128::MAX. - let c = match a.1 { - Endpoint::Start => a.0, - Endpoint::End | Endpoint::Both => a.0 + 1, - }; - let d = match b.1 { - Endpoint::Start | Endpoint::Both => b.0 - 1, - Endpoint::End => b.0, + // A function for extracting the borders of an integer interval. + fn range_borders(r: IntRange<'_>) -> impl Iterator { + let (lo, hi) = r.range.into_inner(); + let from = Border::JustBefore(lo); + let to = match hi.checked_add(1) { + Some(m) => Border::JustBefore(m), + None => Border::AfterMax, }; - // In some cases, we won't need an intermediate range between two ranges - // lie immediately adjacent to one another. - if c <= d { - split_ctors.push(IntRange::range_to_ctor(tcx, ty, c..=d)); - } + vec![from, to].into_iter() + } - a = b; + // `borders` is the set of borders between equivalence classes: each equivalence + // class lies between 2 borders. + let row_borders = m.iter() + .flat_map(|row| IntRange::from_pat(tcx, row[0])) + .flat_map(|range| ctor_range.intersection(&range)) + .flat_map(|range| range_borders(range)); + let ctor_borders = range_borders(ctor_range.clone()); + let mut borders: Vec<_> = row_borders.chain(ctor_borders).collect(); + borders.sort_unstable(); + + // We're going to iterate through every pair of borders, making sure that each + // represents an interval of nonnegative length, and convert each such interval + // into a constructor. + for IntRange { range, .. } in borders.windows(2).filter_map(|window| { + match (window[0], window[1]) { + (Border::JustBefore(n), Border::JustBefore(m)) => { + if n < m { + Some(IntRange { range: n..=(m - 1), ty }) + } else { + None + } + } + (Border::JustBefore(n), Border::AfterMax) => { + Some(IntRange { range: n..=u128::MAX, ty }) + } + (Border::AfterMax, _) => None, + } + }) { + split_ctors.push(IntRange::range_to_ctor(tcx, ty, range)); } } // Any other constructor can be used unchanged. diff --git a/src/test/ui/exhaustive_integer_patterns.rs b/src/test/ui/exhaustive_integer_patterns.rs index 50fc825e74e09..a8e9e74905c7b 100644 --- a/src/test/ui/exhaustive_integer_patterns.rs +++ b/src/test/ui/exhaustive_integer_patterns.rs @@ -158,8 +158,8 @@ fn main() { _ => {} } - const lim: u128 = u128::MAX - 1; + const LIM: u128 = u128::MAX - 1; match 0u128 { //~ ERROR non-exhaustive patterns - 0 ..= lim => {} + 0 ..= LIM => {} } }