Skip to content

Commit

Permalink
0.1.24 - bump stitch to add weighting, and add weighting to the pytho…
Browse files Browse the repository at this point in the history
…n bindings and tests too
  • Loading branch information
mlb2251 committed Nov 29, 2023
1 parent b432e61 commit 09aaa45
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 3 deletions.
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "stitch_bindings"
version = "0.1.23"
version = "0.1.24"
edition = "2021"

[lib]
Expand All @@ -12,7 +12,7 @@ name = "stitch_core"

[dependencies]
# stitch_core = { path = "../stitch"}
stitch_core = { git = "https://github.com/mlb2251/stitch", rev = "f96ba0e"}
stitch_core = { git = "https://github.com/mlb2251/stitch", rev = "323da8c"}
pyo3 = {version = "0.17.3", features = ["extension-module"] }
clap = { version = "3.1.0" }

Expand Down
6 changes: 5 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,16 @@ use clap::Parser;
#[pyfunction(
programs,
tasks,
weights,
name_mapping,
panic_loud,
args
)]
fn compress_backend(
py: Python,
programs: Vec<String>,
tasks: Option<Vec<String>>,
weights: Option<Vec<f32>>,
name_mapping: Option<Vec<(String,String)>>,
panic_loud: bool,
args: String,
Expand All @@ -31,7 +34,7 @@ fn compress_backend(

// release the GIL and call compression
let (_step_results, json_res) = py.allow_threads(||
multistep_compression(&programs, tasks, name_mapping, None, &cfg)
multistep_compression(&programs, tasks, weights, name_mapping, None, &cfg)
);

// return as something you could json.loads(out) from in python
Expand All @@ -42,6 +45,7 @@ fn compress_backend(
#[pyfunction(
programs,
abstractions,
panic_loud,
args
)]
fn rewrite_backend(
Expand Down
2 changes: 2 additions & 0 deletions stitch_core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,7 @@ def compress(
"""

tasks = kwargs.pop("tasks", None)
weights = kwargs.pop("weights", None)
name_mapping = kwargs.pop("name_mapping", None)
panic_loud = kwargs.pop('panic_loud',False)

Expand All @@ -305,6 +306,7 @@ def compress(
res = compress_backend(
programs,
tasks,
weights,
name_mapping,
panic_loud,
args)
Expand Down
19 changes: 19 additions & 0 deletions tests/test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from stitch_core import compress, rewrite, StitchException, from_dreamcoder, Abstraction, name_mapping_stitch, stitch_to_dreamcoder
import json
import math

# simple test
programs = ["(a a a)", "(b b b)"]
Expand Down Expand Up @@ -68,4 +69,22 @@
# print(e)
pass

# 1x (default) weighting vs 2x weighting vs weighting the "g" programs more
programs = ["(f a a)", "(f b b)", "(f c c)", "(g d d)", "(g e e)"]
res = compress(programs, iterations=1)
res2x = compress(programs, iterations=1, weights=[2. for _ in programs])
res_uneven = compress(programs, iterations=1, weights=[1., 1., 1., 2., 2.])

assert res.json["original_cost"] *2 == res2x.json["original_cost"]
assert res.json["final_cost"] *2 == res2x.json["final_cost"]
assert res.abstractions[0].body == res2x.abstractions[0].body == "(f #0 #0)"
assert res_uneven.abstractions[0].body == "(g #0 #0)"

# make sure compression ratio is as expected
assert math.fabs(res_uneven.json["original_cost"]/res_uneven.json["final_cost"] - res_uneven.json["compression_ratio"]) < 0.00001

# assert res.rewritten == ['(fn_0 a)', '(fn_0 b)']
# assert res.abstractions[0].body == '(#0 #0 #0)'


print("Passed all tests")

0 comments on commit 09aaa45

Please sign in to comment.