Skip to content

Commit

Permalink
Fix Recursive not shrinking to non-recursive case.
Browse files Browse the repository at this point in the history
  • Loading branch information
AltSysrq committed Feb 24, 2018
1 parent b107e3f commit 6dee486
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 53 deletions.
12 changes: 12 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,15 @@
## Unreleased

### New Additions

- `proptest::strategy::Union` and `proptest::strategy::TupleUnion` now work
with weighted strategies even if the sum of the weights overflows a `u32`.

### Bug Fixes

- Fixed values produced via `prop_recursive()` not shrinking from the recursive
to the non-recursive case.

## 0.5.0

### Potential Breaking Changes
Expand Down
95 changes: 45 additions & 50 deletions src/strategy/recursive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,32 +11,9 @@ use std::fmt;
use std::sync::Arc;

use strategy::traits::*;
use strategy::IndFlatten;
use strategy::statics::{Map, MapFn};
use strategy::unions::float_to_weight;
use test_runner::*;

/// A branching `MapFn` that picks `yes` (true) or
/// `no` (false) and clones the `Arc`.
struct BranchFn<T> {
/// Result on true.
yes: Arc<T>,
/// Result on false.
no: Arc<T>,
}

impl<T> Clone for BranchFn<T> {
fn clone(&self) -> Self {
Self { yes: Arc::clone(&self.yes), no: Arc::clone(&self.no) }
}
}

impl<T: fmt::Debug> MapFn<bool> for BranchFn<T> {
type Output = Arc<T>;
fn apply(&self, branch: bool) -> Self::Output {
Arc::clone(if branch { &self.yes } else { &self.no })
}
}

/// Return type from `Strategy::prop_recursive()`.
pub struct Recursive<B, F> {
pub(super) base: Arc<B>,
Expand Down Expand Up @@ -144,11 +121,15 @@ where
let recursed = (self.recurse)(Arc::clone(&strat));
let recursive_choice = Arc::new(recursed.boxed());
let non_recursive_choice = strat;
let branch_cond = ::bool::weighted(branch_probability.min(0.9));
let branch = IndFlatten(Map::new(branch_cond, BranchFn {
yes: recursive_choice,
no: non_recursive_choice,
}));
// Clamp the maximum branch probability to 0.9 to ensure we can
// generate non-recursive cases reasonably often.
let branch_probability = branch_probability.min(0.9);
let (weight_branch, weight_leaf) =
float_to_weight(branch_probability);
let branch = prop_oneof![
weight_leaf => non_recursive_choice,
weight_branch => recursive_choice,
];
strat = Arc::new(branch.boxed());
}

Expand All @@ -163,33 +144,33 @@ mod test {
use strategy::just::Just;
use super::*;

#[test]
fn test_recursive() {
#[derive(Clone, Debug)]
enum Tree {
Leaf,
Branch(Vec<Tree>),
}
#[derive(Clone, Debug, PartialEq)]
enum Tree {
Leaf,
Branch(Vec<Tree>),
}

impl Tree {
fn stats(&self) -> (u32, u32) {
match *self {
Tree::Leaf => (0, 1),
Tree::Branch(ref children) => {
let mut depth = 0;
let mut count = 0;
for child in children {
let (d, c) = child.stats();
depth = max(d, depth);
count += c;
}

(depth + 1, count + 1)
impl Tree {
fn stats(&self) -> (u32, u32) {
match *self {
Tree::Leaf => (0, 1),
Tree::Branch(ref children) => {
let mut depth = 0;
let mut count = 0;
for child in children {
let (d, c) = child.stats();
depth = max(d, depth);
count += c;
}

(depth + 1, count + 1)
}
}
}
}

#[test]
fn test_recursive() {
let mut max_depth = 0;
let mut max_count = 0;

Expand All @@ -209,4 +190,18 @@ mod test {
assert!(max_depth >= 3, "Only got max depth {}", max_depth);
assert!(max_count > 48, "Only got max count {}", max_count);
}

#[test]
fn simplifies_to_non_recursive() {
let strat = Just(Tree::Leaf).prop_recursive(4, 64, 16,
|element| ::collection::vec(element, 8..16).prop_map(Tree::Branch));

let mut runner = TestRunner::default();
for _ in 0..256 {
let mut value = strat.new_value(&mut runner).unwrap();
while value.simplify() { }

assert_eq!(Tree::Leaf, value.current());
}
}
}
3 changes: 3 additions & 0 deletions src/strategy/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,9 @@ pub trait Strategy : fmt::Debug {
/// `expected_branch_size` (though it is not a hard limit) since the
/// underlying code underestimates probabilities.
///
/// Shrinking shrinks both the inner values and attempts switching from
/// recursive to non-recursive cases.
///
/// ## Example
///
/// ```rust,norun
Expand Down
6 changes: 3 additions & 3 deletions src/strategy/unions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,11 @@ impl<T : Strategy> Union<T> {

fn pick_weighted<I : Iterator<Item = u32>>(runner: &mut TestRunner,
weights1: I, weights2: I) -> usize {
let sum = weights1.sum();
let sum = weights1.map(u64::from).sum();
let weighted_pick = rand::distributions::Range::new(0, sum)
.ind_sample(runner.rng());
weights2.scan(0, |state, w| {
*state += w;
weights2.scan(0u64, |state, w| {
*state += u64::from(w);
Some(*state)
}).filter(|&v| v <= weighted_pick).count()
}
Expand Down

0 comments on commit 6dee486

Please sign in to comment.