Skip to content

Commit

Permalink
fix up tests
Browse files Browse the repository at this point in the history
  • Loading branch information
oflatt committed Aug 20, 2024
1 parent 335744a commit 43cf77e
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 26 deletions.
32 changes: 25 additions & 7 deletions src/egraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -541,11 +541,7 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
subst: &Subst,
) -> Explanation<L> {
let left = self.add_expr_uncanonical(left_expr);
let right = self.add_instantiation_noncanonical(
right_pattern,
subst,
Some(ExistanceReason::Direct),
);
let right = self.add_instantiation_noncanonical(right_pattern, subst, None);

if self.find(left) != self.find(right) {
panic!(
Expand Down Expand Up @@ -1192,8 +1188,30 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
// add the lhs directly
let id1 =
self.add_instantiation_noncanonical(from_pat, subst, Some(ExistanceReason::Direct));
// add the rhs directly
let id2 = self.add_instantiation_noncanonical(to_pat, subst, Some(ExistanceReason::Direct));
// add the rhs, with reason equal to lhs
let id2 =
self.add_instantiation_noncanonical(to_pat, subst, Some(ExistanceReason::EqualTo(id1)));

let did_union = self.perform_union(id1, id2, Some(Justification::Rule(rule_name.into())));
(self.find(id1), did_union)
}

/// Like `union_instantiations`, but assumes that the `from_pat` and substitution
/// is guaranteed to match the egraph already.
/// Using this method makes existance explanations more precise.
pub fn union_instantiations_guaranteed_match(
&mut self,
from_pat: &PatternAst<L>,
to_pat: &PatternAst<L>,
subst: &Subst,
rule_name: impl Into<Symbol>,
) -> (Id, bool) {
// add the lhs without an existance reason,
// assuming it matches
let id1 = self.add_instantiation_noncanonical(from_pat, subst, None);
// add the rhs, making it equal to the lhs
let id2 =
self.add_instantiation_noncanonical(to_pat, subst, Some(ExistanceReason::EqualTo(id1)));

let did_union = self.perform_union(id1, id2, Some(Justification::Rule(rule_name.into())));
(self.find(id1), did_union)
Expand Down
2 changes: 1 addition & 1 deletion src/explain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1325,7 +1325,7 @@ impl<'x, L: Language> ExplainNodes<'x, L> {
let adj = self.explain_adjacent(connection, cache, enode_cache, false);
let mut exp = self.explain_term_existance(adjacent_id, adj, cache, enode_cache);
exp.push(rest_of_proof);
return exp;
exp
}
ExistanceReason::Direct => {
vec![self.node_to_explanation(term, enode_cache), rest_of_proof]
Expand Down
29 changes: 17 additions & 12 deletions src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ pub fn test_runner<L, A>(
goals: &[Pattern<L>],
check_fn: Option<fn(Runner<L, A, ()>)>,
should_check: bool,
check_existance_explanations: bool,
) where
L: Language + Display + FromOp + 'static,
A: Analysis<L> + Default,
Expand Down Expand Up @@ -109,17 +110,19 @@ pub fn test_runner<L, A>(

// now check for existance of the goal
// it should exist due to the start expression
let mut existance_proof = runner.explain_existance_pattern(&goal.ast, &subst);
existance_proof.check_proof(rules);
let first_term_in_existance_proof: RecExpr<L> =
existance_proof.make_flat_explanation()[0].get_recexpr();
if !has_initial_expression {
assert_eq!(
first_term_in_existance_proof,
start,
"Existance proof failed to find original term. Existance proof: {}",
existance_proof.get_flat_string()
);
if check_existance_explanations {
let mut existance_proof = runner.explain_existance_pattern(&goal.ast, &subst);
existance_proof.check_proof(rules);
let first_term_in_existance_proof: RecExpr<L> =
existance_proof.make_flat_explanation()[0].get_recexpr();
if !has_initial_expression {
assert_eq!(
first_term_in_existance_proof,
start,
"Existance proof failed to find original term. Existance proof: {}",
existance_proof.get_flat_string()
);
}
}

runner = runner.with_explanation_length_optimization();
Expand Down Expand Up @@ -270,7 +273,8 @@ macro_rules! test_fn {
$start:literal
=>
$($goal:literal),+ $(,)?
$(@check $check_fn:expr)?
$(@check $check_fn:expr,)?
$(@existance $check_existance_explanations:expr)?
) => {

$(#[$meta])*
Expand All @@ -286,6 +290,7 @@ macro_rules! test_fn {
&[$( $goal.parse().unwrap() ),+],
None $(.or(Some($check_fn)))?,
check,
true $(&& $check_existance_explanations)?,
)
}};
}
20 changes: 16 additions & 4 deletions tests/lambda.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ impl Applier<Lambda, LambdaAnalysis> for CaptureAvoid {
fn apply_one(
&self,
egraph: &mut EGraph,
eclass: Id,
term: Id,
subst: &Subst,
searcher_ast: Option<&PatternAst<Lambda>>,
rule_name: Symbol,
Expand All @@ -185,13 +185,13 @@ impl Applier<Lambda, LambdaAnalysis> for CaptureAvoid {
let v2_free_in_e = egraph[e].data.free.contains(&v2);
if v2_free_in_e {
let mut subst = subst.clone();
let sym = Lambda::Symbol(format!("_{}", eclass).into());
let sym = Lambda::Symbol(format!("_{}", term).into());
subst.insert(self.fresh, egraph.add(sym));
self.if_free
.apply_one(egraph, eclass, &subst, searcher_ast, rule_name)
.apply_one(egraph, term, &subst, searcher_ast, rule_name)
} else {
self.if_not_free
.apply_one(egraph, eclass, subst, searcher_ast, rule_name)
.apply_one(egraph, term, subst, searcher_ast, rule_name)
}
}
}
Expand All @@ -205,6 +205,7 @@ egg::test_fn! {
// "(lam x (+ 4 (let y 4 (var y))))",
// "(lam x (+ 4 4))",
"(lam x 8))",
@existance false
}

egg::test_fn! {
Expand All @@ -214,6 +215,7 @@ egg::test_fn! {
(+ (var a) (var b)))"
=>
"(+ (var a) (var b))"
@existance false
}

egg::test_fn! {
Expand All @@ -226,18 +228,21 @@ egg::test_fn! {
// (+ (var ?a) 1))",
// "(+ 0 1)",
"1",
@existance false
}

egg::test_fn! {
#[should_panic(expected = "Could not prove goal 0")]
lambda_capture, rules(),
"(let x 1 (lam x (var x)))" => "(lam x 1)"
@existance false
}

egg::test_fn! {
#[should_panic(expected = "Could not prove goal 0")]
lambda_capture_free, rules(),
"(let y (+ (var x) (var x)) (lam x (var y)))" => "(lam x (+ (var x) (var x)))"
@existance false
}

egg::test_fn! {
Expand All @@ -249,6 +254,7 @@ egg::test_fn! {
(app (var add-five) 1))))"
=>
"7"
@existance false
}

egg::test_fn! {
Expand All @@ -262,11 +268,13 @@ egg::test_fn! {
(app (lam ?y (+ 1 (var ?y)))
(var ?x))))",
"(lam ?x (+ (var ?x) 2))"
@existance false
}

egg::test_fn! {
lambda_if_simple, rules(),
"(if (= 1 1) 7 9)" => "7"
@existance false
}

egg::test_fn! {
Expand All @@ -283,6 +291,7 @@ egg::test_fn! {
(var add1)))))))))"
=>
"(lam ?x (+ (var ?x) 7))"
@existance false
}

egg::test_fn! {
Expand All @@ -308,6 +317,7 @@ egg::test_fn! {
2))))"
=>
"(lam ?x (+ (var ?x) 2))"
@existance false
}

egg::test_fn! {
Expand All @@ -322,6 +332,7 @@ egg::test_fn! {
// "(+ (if false 0 1) (if true 0 1))",
// "(+ 1 0)",
"1",
@existance false
}

egg::test_fn! {
Expand All @@ -342,6 +353,7 @@ egg::test_fn! {
(+ (var n) -2)))))))
(app (var fib) 4))"
=> "3"
@existance false
}

#[test]
Expand Down
4 changes: 2 additions & 2 deletions tests/math.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ impl Analysis<Math> for ConstantFold {
let data = egraph[id].data.clone();
if let Some((c, pat)) = data {
if egraph.are_explanations_enabled() {
egraph.union_instantiations(
egraph.union_instantiations_guaranteed_match(
&pat,
&format!("{}", c).parse().unwrap(),
&Default::default(),
Expand Down Expand Up @@ -221,7 +221,7 @@ egg::test_fn! {
"(+ 1 (+ 2 (+ 3 (+ 4 (+ 5 (+ 6 7))))))"
=>
"(+ 7 (+ 6 (+ 5 (+ 4 (+ 3 (+ 2 1))))))"
@check |r: Runner<Math, ()>| assert_eq!(r.egraph.number_of_classes(), 127)
@check |r: Runner<Math, ()>| assert_eq!(r.egraph.number_of_classes(), 127),
}

egg::test_fn! {
Expand Down

0 comments on commit 43cf77e

Please sign in to comment.