Skip to content

Commit

Permalink
use tagged lambdas and variables
Browse files Browse the repository at this point in the history
  • Loading branch information
kavigupta committed Nov 14, 2023
1 parent 9a30d03 commit 00c763c
Show file tree
Hide file tree
Showing 12 changed files with 191 additions and 44 deletions.
7 changes: 4 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@ rand = "0.8.4"
parking_lot = "0.12.0"
colorful = "0.2.1"
rustc-hash = "1.1.0"
lambdas = { git = "https://github.com/mlb2251/lambdas", rev = "56bc9ee"}
lambdas = "0.2.0"
# lambdas = { git = "https://github.com/mlb2251/lambdas", rev = "56bc9ee"}


# [patch.crates-io]
# lambdas = { path = "../lambdas"}
[patch.crates-io]
lambdas = { path = "../lambdas"}


# enable for flamegraphs
Expand Down
4 changes: 4 additions & 0 deletions data/basic/simple3.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
[
"(a (lam_1 (a a)))",
"(b (lam_1 (b b)))"
]
4 changes: 4 additions & 0 deletions data/basic/simple4.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
[
"(a (lam_1 (a $0_1 $0_1)))",
"(b (lam_1 (b $0_1 $0_1)))"
]
4 changes: 4 additions & 0 deletions data/basic/simple5.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
[
"(a (lam_1 (a $0_1 $0_1)))",
"(b (lam_1 (b $0_2 $0_2)))"
]
41 changes: 41 additions & 0 deletions data/expected_outputs/simple3-a1-i1.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
{
"cmd": null,
"args": null,
"original_cost": 606,
"final_cost": 402,
"compression_ratio": 1.507462686567164,
"num_abstractions": 1,
"original": [
"(a (lam_1 (a a)))",
"(b (lam_1 (b b)))"
],
"rewritten": [
"(fn_0 a)",
"(fn_0 b)"
],
"rewritten_dreamcoder": null,
"test_output": null,
"abstractions": [
{
"body": "(#0 (lam_1 (#0 #0)))",
"dreamcoder": "#(lambda ($0 (lambda_1 ($1 $1))))",
"arity": 1,
"name": "fn_0",
"utility": 201,
"final_cost": 402,
"compression_ratio": 1.507462686567164,
"cumulative_compression_ratio": 1.507462686567164,
"num_uses": 2,
"rewritten": null,
"rewritten_dreamcoder": null,
"uses": [
{
"fn_0 a": "(a (lam_1 (a a)))"
},
{
"fn_0 b": "(b (lam_1 (b b)))"
}
]
}
]
}
39 changes: 39 additions & 0 deletions data/expected_outputs/simple4-a1-i1.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
{
"original_cost": 808,
"final_cost": 402,
"compression_ratio": 2.009950248756219,
"num_abstractions": 1,
"original": [
"(a (lam_1 (a $0_1 $0_1)))",
"(b (lam_1 (b $0_1 $0_1)))"
],
"rewritten": [
"(fn_0 a)",
"(fn_0 b)"
],
"rewritten_dreamcoder": null,
"abstractions": [
{
"body": "(#0 (lam_1 (#0 $0_1 $0_1)))",
"dreamcoder": "#(lambda ($0 (lambda_1 ($1 $0_1 $0_1))))",
"arity": 1,
"name": "fn_0",
"utility": 202,
"final_cost": 402,
"compression_ratio": 2.009950248756219,
"cumulative_compression_ratio": 2.009950248756219,
"num_uses": 2,
"rewritten": null,
"rewritten_dreamcoder": null,
"uses": [
{
"fn_0 a": "(a (lam_1 (a $0_1 $0_1)))"
},
{
"fn_0 b": "(b (lam_1 (b $0_1 $0_1)))"
}
],
"dc_comparison_millis": null
}
]
}
16 changes: 16 additions & 0 deletions data/expected_outputs/simple5-a1-i1.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
{
"original_cost": 808,
"final_cost": 808,
"compression_ratio": 1.0,
"num_abstractions": 0,
"original": [
"(a (lam_1 (a $0_1 $0_1)))",
"(b (lam_1 (b $0_2 $0_2)))"
],
"rewritten": [
"(a (lam_1 (a $0_1 $0_1)))",
"(b (lam_1 (b $0_2 $0_2)))"
],
"rewritten_dreamcoder": null,
"abstractions": []
}
71 changes: 44 additions & 27 deletions src/compression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -282,11 +282,11 @@ fn zids_of_ivar_of_expr(expr: &ExprOwned, zid_of_zip: &FxHashMap<Vec<ZNode>,ZId>
fn helper(expr: Expr, curr_zip: &mut Vec<ZNode>, zids_of_ivar: &mut Vec<Vec<ZId>>, zid_of_zip: &FxHashMap<Vec<ZNode>,ZId>) -> Result<(), ()> {
match expr.node() {
Node::Prim(_) => {},
Node::Var(_) => {},
Node::Var(_, _) => {},
Node::IVar(i) => {
zids_of_ivar[*i as usize].push(zid_of_zip.get(curr_zip).cloned().ok_or(())?);
},
Node::Lam(b) => {
Node::Lam(b, _) => {
curr_zip.push(ZNode::Body);
helper(expr.get(*b), curr_zip, zids_of_ivar, zid_of_zip)?;
curr_zip.pop();
Expand Down Expand Up @@ -380,7 +380,7 @@ impl Pattern {
// function type body would be effectively arity 3 and dreamcoder doesnt support this sort of thing.
match_locations.retain(|node| node != f);

if let Node::Lam(_) = &set[*f] {
if let Node::Lam(_, _) = &set[*f] {
panic!("corpus was not in beta-normal form")
}
}
Expand All @@ -396,7 +396,7 @@ impl Pattern {
}
assert!(match_locations.contains(x), "corpus was not in eta long form (?). This appeared both to the left and right of an app: {}; for example it is to the right in: {}", set.get(*x), set.get(node));
},
Node::Lam(b) => {
Node::Lam(b, _) => {
if !AnalyzedExpr::new(FreeVarAnalysis).analyze_get(set.get(*b)).is_empty() {
continue
}
Expand All @@ -409,7 +409,10 @@ impl Pattern {

// to guarantee eta long we cant allow abstractions to start with a lambda at the top
if cfg.eta_long {
match_locations.retain(|node| expands_to_of_node(&set[*node]) != ExpandsTo::Lam);
match_locations.retain(|node| match expands_to_of_node(&set[*node]) {
ExpandsTo::Lam(_) => false,
_ => true,
});
}

let utility_upper_bound = utility_upper_bound(&match_locations, body_utility, cost_of_node_all, num_paths_to_node, cost_fn, cfg);
Expand Down Expand Up @@ -440,12 +443,12 @@ impl Pattern {
// no ivar zip match, so recurse
match &shared.set[curr_node] {
Node::Prim(p) => set.add(Node::Prim(p.clone())),
Node::Var(v) => set.add(Node::Var(*v)),
Node::Lam(b) => {
Node::Var(v, tag) => set.add(Node::Var(*v, *tag)),
Node::Lam(b, tag) => {
curr_zip.push(ZNode::Body);
let b_idx = helper(set, *b, curr_zip, zips, shared);
curr_zip.pop();
set.add(Node::Lam(b_idx))
set.add(Node::Lam(b_idx, *tag))
}
Node::App(f,x) => {
curr_zip.push(ZNode::Func);
Expand Down Expand Up @@ -481,9 +484,9 @@ impl Pattern {
/// Tells us what a hole will expand into at this node.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Ord, PartialOrd)]
pub enum ExpandsTo {
Lam,
Lam(Tag),
App,
Var(i32),
Var(i32, Tag),
Prim(Symbol),
IVar(i32),
}
Expand All @@ -494,9 +497,9 @@ impl ExpandsTo {
#[allow(dead_code)]
fn has_holes(&self) -> bool {
match self {
ExpandsTo::Lam => true,
ExpandsTo::Lam(_) => true,
ExpandsTo::App => true,
ExpandsTo::Var(_) => false,
ExpandsTo::Var(_, _) => false,
ExpandsTo::Prim(_) => false,
ExpandsTo::IVar(_) => false,
}
Expand All @@ -511,9 +514,21 @@ impl ExpandsTo {
impl std::fmt::Display for ExpandsTo {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
match self {
ExpandsTo::Lam => write!(f, "(lam ??)"),
ExpandsTo::Lam(tag) => {
write!(f, "(lam")?;
if *tag != -1 {
write!(f, "_{}", tag)?;
}
write!(f, " ??)")
},
ExpandsTo::App => write!(f, "(?? ??)"),
ExpandsTo::Var(v) => write!(f, "${v}"),
ExpandsTo::Var(v, tag) => {
write!(f, "${v}")?;
if *tag != -1 {
write!(f, "_{}", tag)?;
}
Ok(())
},
ExpandsTo::Prim(p) => write!(f, "{p}"),
ExpandsTo::IVar(v) => write!(f, "#{v}"),
}
Expand All @@ -535,9 +550,9 @@ pub struct Arg {

fn expands_to_of_node(node: &Node) -> ExpandsTo {
match node {
Node::Var(i) => ExpandsTo::Var(*i),
Node::Var(i, tag) => ExpandsTo::Var(*i, *tag),
Node::Prim(p) => ExpandsTo::Prim(p.clone()),
Node::Lam(_) => ExpandsTo::Lam,
Node::Lam(_, tag) => ExpandsTo::Lam(*tag),
Node::App(_,_) => ExpandsTo::App,
Node::IVar(i) => ExpandsTo::IVar(*i),
}
Expand Down Expand Up @@ -952,7 +967,7 @@ fn stitch_search(

// Pruning (FREE VARS): if an invention has free variables in the body then it's not a real function and we can discard it
// Here we just check if our expansion just yielded a variable, and if that is bound based on how many lambdas there are above it.
if let ExpandsTo::Var(i) = expands_to {
if let ExpandsTo::Var(i, _) = expands_to {
if i >= shared.zip_of_zid[hole_zid].iter().filter(|znode|**znode == ZNode::Body).count() as i32 {
if !shared.cfg.no_stats { shared.stats.lock().deref_mut().free_vars_fired += 1; };
if tracked && !shared.cfg.quiet { println!("{} pruned by free var in body when expanding {} to {}", "[TRACK]".red().bold(), original_pattern.to_expr(&shared), original_pattern.show_track_expansion(hole_zid, &shared)) }
Expand All @@ -962,9 +977,9 @@ fn stitch_search(

// update the body utility
let body_utility = original_pattern.body_utility + match &expands_to {
ExpandsTo::Lam => shared.cost_fn.cost_lam,
ExpandsTo::Lam(_) => shared.cost_fn.cost_lam,
ExpandsTo::App => shared.cost_fn.cost_app,
ExpandsTo::Var(_) => shared.cost_fn.cost_var,
ExpandsTo::Var(_, _) => shared.cost_fn.cost_var,
ExpandsTo::Prim(p) => *shared.cost_fn.cost_prim.get(p).unwrap_or(&shared.cost_fn.cost_prim_default),
ExpandsTo::IVar(_) => 0,
};
Expand All @@ -988,7 +1003,7 @@ fn stitch_search(
// add any new holes to the list of holes
let mut holes = holes_after_pop.clone();
match expands_to {
ExpandsTo::Lam => {
ExpandsTo::Lam(_) => {
// add new holes
holes.push(shared.extensions_of_zid[hole_zid].body.unwrap());
}
Expand Down Expand Up @@ -1258,7 +1273,7 @@ fn get_zippers(

match node {
Node::IVar(_) => { unreachable!() }
Node::Var(_) | Node::Prim(_) => {},
Node::Var(_, _) | Node::Prim(_) => {},
Node::App(f,x) => {
// bubble from `f`
for f_zid in zids_of_node[&f].iter() {
Expand Down Expand Up @@ -1297,7 +1312,7 @@ fn get_zippers(

}
},
Node::Lam(b) => {
Node::Lam(b, _) => {
for b_zid in zids_of_node[&b].iter() {

// clone and extend zip to get new zid for this node
Expand Down Expand Up @@ -1598,9 +1613,9 @@ fn bottom_up_utility_correction(pattern: &Pattern, shared:&SharedData, utility_o
for node in shared.corpus_span.clone() {

let utility_without_rewrite: i32 = match &shared.set[node] {
Node::Lam(b) => cumulative_utility_of_node[*b],
Node::Lam(b, _) => cumulative_utility_of_node[*b],
Node::App(f,x) => cumulative_utility_of_node[*f] + cumulative_utility_of_node[*x],
Node::Prim(_) | Node::Var(_) => 0,
Node::Prim(_) | Node::Var(_, _) => 0,
Node::IVar(_) => unreachable!(),
};

Expand Down Expand Up @@ -1723,8 +1738,8 @@ fn use_counts(pattern: &Pattern, zip_of_zid: &[Vec<ZNode>], arg_of_zid_node: &[F
}
match &set[curr_node] {
Node::Prim(_) => {},
Node::Var(_) => {},
Node::Lam(b) => {
Node::Var(_, _) => {},
Node::Lam(b, _) => {
curr_zip.push(ZNode::Body);
let new_zid = extensions_of_zid[curr_zid].body.unwrap();
helper(*b, match_loc, curr_zip, new_zid, zips, zids, arg_of_zid_node, extensions_of_zid, set, counts, analyzed_ivars);
Expand Down Expand Up @@ -2239,7 +2254,9 @@ pub fn json_of_step_results(step_results: &[CompressionStepResult], train_progra
let final_cost = min_cost(rewritten, &tasks, cost_fn);
let rewritten = step_results.iter().last().map(|res| &res.rewritten).unwrap_or(train_programs).iter().map(|p| p.to_string()).collect::<Vec<String>>();
let rewritten_dreamcoder = if !cfg.step.rewritten_dreamcoder { None } else {
let rewritten_dreamcoder = step_results.iter().last().map(|res| res.rewritten_dreamcoder.clone().unwrap()).unwrap_or_else(||train_programs.iter().map(|p| p.to_string().replace("(lam ", "(lambda ")).collect::<Vec<String>>());
let rewritten_dreamcoder = step_results.iter().last().map(|res| res.rewritten_dreamcoder.clone().unwrap()).unwrap_or_else(||train_programs.iter().map(
|p| p.to_string().replace("(lam ", "(lambda ").replace("(lam_", "(lambda_")
).collect::<Vec<String>>());
Some(rewritten_dreamcoder)
};
json!({
Expand Down
6 changes: 3 additions & 3 deletions src/egraphs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,16 +58,16 @@ pub fn insert_arg_ivars(e: &mut ExprMut, set_to: i32, init_depth: i32, analyzed_

match e.node().clone() {
Node::Prim(_) => e.idx,
Node::Var(i) => if i == init_depth { e.set.add(Node::IVar(set_to)) } else { e.idx },
Node::Var(i, _) => if i == init_depth { e.set.add(Node::IVar(set_to)) } else { e.idx },
Node::IVar(_) => e.idx,
Node::App(f, x) => {
let f = insert_arg_ivars(&mut e.get(f), set_to, init_depth, analyzed_free_vars);
let x = insert_arg_ivars(&mut e.get(x), set_to, init_depth, analyzed_free_vars);
e.set.add(Node::App(f, x))
},
Node::Lam(b) => {
Node::Lam(b, tag) => {
let b = insert_arg_ivars(&mut e.get(b), set_to, init_depth + 1, analyzed_free_vars);
e.set.add(Node::Lam(b))
e.set.add(Node::Lam(b, tag))
},
}
}
Loading

0 comments on commit 00c763c

Please sign in to comment.