Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

A couple minor bug fixes #144

Merged
merged 3 commits into from
Jul 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 28 additions & 4 deletions circuit_passes/src/bucket_interpreter/env/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ use compiler::intermediate_representation::BucketId;
use function_env::FunctionEnvData;
use indexmap::IndexSet;
use crate::passes::loop_unroll::body_extractor::LoopBodyExtractor;
use crate::passes::loop_unroll::{ToOriginalLocation, FuncArgIdx};
use crate::passes::loop_unroll::{FuncArgIdx, ToOriginalLocation, LOOP_BODY_FN_PREFIX};
use crate::passes::GlobalPassData;
use self::extracted_func_env::ExtractedFuncEnvData;
use self::template_env::TemplateEnvData;
use self::unrolled_block_env::UnrolledBlockEnvData;
Expand All @@ -22,6 +23,7 @@ mod function_env;
mod unrolled_block_env;
mod extracted_func_env;

const DEBUG_INTERPRETER: bool = false;
const PRINT_ENV_SORTED: bool = true;

#[inline]
Expand Down Expand Up @@ -222,6 +224,10 @@ impl<'a> Env<'a> {
Env::Template(TemplateEnvData::new(libs))
}

pub fn new_unroll_block_env(base: Env<'a>, extractor: &'a LoopBodyExtractor) -> Self {
Env::UnrolledBlock(UnrolledBlockEnvData::new(base, extractor))
}

pub fn new_source_func_env(
base: Env<'a>,
caller: &BucketId,
Expand All @@ -231,7 +237,7 @@ impl<'a> Env<'a> {
Env::Function(FunctionEnvData::new(base, caller, call_stack, libs))
}

pub fn new_extracted_func_env(
fn _new_extracted_func_env(
base: Env<'a>,
caller: &BucketId,
remap: ToOriginalLocation,
Expand All @@ -240,10 +246,28 @@ impl<'a> Env<'a> {
Env::ExtractedFunction(ExtractedFuncEnvData::new(base, caller, remap, arenas))
}

pub fn new_unroll_block_env(base: Env<'a>, extractor: &'a LoopBodyExtractor) -> Self {
Env::UnrolledBlock(UnrolledBlockEnvData::new(base, extractor))
pub fn new_extracted_func_env(
base: Env<'a>,
caller: &BucketId,
callee_name: &str,
gdat: Ref<GlobalPassData>,
) -> Self {
if callee_name.starts_with(LOOP_BODY_FN_PREFIX) {
if DEBUG_INTERPRETER {
println!("\ncurrent env = {}", base);
println!("callee_name = {}", callee_name);
println!("base.get_vars_sort() = {:?}", base.get_vars_sort());
println!("callee function data = {:?}", gdat.get_data_for_func(callee_name));
}
let fdat = &gdat.get_data_for_func(callee_name)[&base.get_vars_sort()];
Self::_new_extracted_func_env(base, caller, fdat.0.clone(), fdat.1.clone())
} else {
Self::_new_extracted_func_env(base, caller, Default::default(), Default::default())
}
}
}

impl Env<'_> {
// READ OPERATIONS
pub fn peel_extracted_func(self) -> Self {
match self {
Expand Down
13 changes: 2 additions & 11 deletions circuit_passes/src/bucket_interpreter/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ use compiler::num_bigint::BigInt;
use observer::Observer;
use program_structure::error_code::ReportCode;
use crate::passes::builders::{build_compute, build_u32_value};
use crate::passes::loop_unroll::LOOP_BODY_FN_PREFIX;
use crate::passes::GlobalPassData;
use self::env::{CallStackFrame, Env, LibraryAccess};
use self::error::BadInterp;
Expand Down Expand Up @@ -745,15 +744,8 @@ impl BucketInterpreter<'_, '_> {
// calls below (that give ownership of the 'env' object into the new Env instance)
// to avoid copying the entire 'env' instance (which is likely more expensive).
let instructions = env.get_function(name).body.clone();
let mut res = (vec![], {
if name.starts_with(LOOP_BODY_FN_PREFIX) {
let gdat = self.global_data.borrow();
let fdat = &gdat.get_data_for_func(name)[&env.get_vars_sort()];
Env::new_extracted_func_env(env, &bucket.id, fdat.0.clone(), fdat.1.clone())
} else {
Env::new_extracted_func_env(env, &bucket.id, Default::default(), Default::default())
}
});
let mut res =
(vec![], Env::new_extracted_func_env(env, &bucket.id, name, self.global_data.borrow()));
//NOTE: Do not change scope for the new interpreter because the mem lookups
// within 'write_collector.rs' need to use the original function context.
let interp = self.mem.build_interpreter_with_flags(
Expand Down Expand Up @@ -820,7 +812,6 @@ impl BucketInterpreter<'_, '_> {
// unless self.flags.allow_nondetermined_return() == false because
// that case could result in no return statements being observed.
let func_val = if body_val.is_empty() && !self.flags.allow_nondetermined_return() {
// Some(Value::Unknown) // TODO: return the correct number of Unknowns
Result::Ok(vec![Value::Unknown; callee.returns.iter().product::<usize>()])
} else {
let vals = into_result(body_val, "value returned from function");
Expand Down
48 changes: 34 additions & 14 deletions circuit_passes/src/bucket_interpreter/write_collector.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,27 @@
use std::collections::HashSet;
use code_producers::llvm_elements::stdlib::GENERATED_FN_PREFIX;
use compiler::intermediate_representation::{
ir_interface::{AddressType, FinalData, LocationRule, LogBucketArg, ReturnType, StoreBucket},
Instruction, InstructionList, InstructionPointer,
};
use super::{env::Env, error::BadInterp, value::Value, BucketInterpreter, InterpRes};
use super::{
env::{Env, LibraryAccess},
error::BadInterp,
value::Value,
BucketInterpreter, InterpRes,
};

pub(crate) fn set_writes_to_unknown<'e>(
interp: &BucketInterpreter,
body: &InstructionList,
env: Env<'e>,
) -> Result<Env<'e>, BadInterp> {
let mut checker = Writes::default();
Result::from(checker.collect_writes(interp, body, env.clone())).map_or_else(
let mut collector = Writes::default();
Result::from(collector.check_body(interp, body, env.clone())).map_or_else(
|b| Result::Err(b),
// For the Ok case, ignore the Env computed within the body
// and just set Unknown to all writes that were found.
|_| checker.set_unknowns(env),
|_| collector.set_unknowns(env),
)
}

Expand Down Expand Up @@ -43,7 +49,7 @@ impl Writes {
.set_subcmps_to_unknown(self.subcmps)
}

fn collect_writes<'e>(
fn check_body<'e>(
&mut self,
interp: &BucketInterpreter,
body: &InstructionList,
Expand All @@ -65,13 +71,27 @@ impl Writes {
match inst {
Instruction::Store(b) => self.check_store_bucket(interp, b, env),
Instruction::Constraint(b) => self.check_inst(interp, b.unwrap(), env),
Instruction::Block(b) => self.collect_writes(interp, &b.body, env),
Instruction::Block(b) => self.check_body(interp, &b.body, env),
Instruction::Branch(b) => {
self.check_branch(interp, &b.cond, &b.if_branch, &b.else_branch, env)
}
Instruction::Loop(b) => {
self.check_branch(interp, &b.continue_condition, &b.body, &vec![], env)
}
Instruction::Call(b) if b.symbol.starts_with(GENERATED_FN_PREFIX) => {
let callee_name = &b.symbol;
let callee_body = env.get_function(callee_name).body.clone();
self.check_body(
interp,
&callee_body,
Env::new_extracted_func_env(
env,
&b.id,
callee_name,
interp.global_data.borrow(),
),
)
}
i => {
debug_assert!(!ContainsStore::contains_store(i));
InterpRes::Continue(env)
Expand All @@ -93,10 +113,10 @@ impl Writes {
// If the condition is unknown, collect all writes from both branches (even if
// there is a return in either, hence an InterpRes::Return result is ignored
// in both cases) and produce InterpRes::Continue with the original Env.
if let InterpRes::Err(e) = self.collect_writes(interp, true_branch, env.clone()) {
if let InterpRes::Err(e) = self.check_body(interp, true_branch, env.clone()) {
return InterpRes::Err(e);
}
if let InterpRes::Err(e) = self.collect_writes(interp, false_branch, env.clone()) {
if let InterpRes::Err(e) = self.check_body(interp, false_branch, env.clone()) {
return InterpRes::Err(e);
}
InterpRes::Continue(env)
Expand All @@ -105,17 +125,17 @@ impl Writes {
// If the condition is true, collect all writes from the false branch
// (ignoring an InterpRes::Return result as above) and then analyze
// and return the result from the true branch.
if let InterpRes::Err(e) = self.collect_writes(interp, false_branch, env.clone()) {
if let InterpRes::Err(e) = self.check_body(interp, false_branch, env.clone()) {
return InterpRes::Err(e);
}
self.collect_writes(interp, true_branch, env)
self.check_body(interp, true_branch, env)
}
Ok(Some(false)) => {
// Reverse of the true case.
if let InterpRes::Err(e) = self.collect_writes(interp, true_branch, env.clone()) {
if let InterpRes::Err(e) = self.check_body(interp, true_branch, env.clone()) {
return InterpRes::Err(e);
}
self.collect_writes(interp, false_branch, env)
self.check_body(interp, false_branch, env)
}
}
}
Expand Down Expand Up @@ -366,7 +386,7 @@ mod tests {
];

let mut checker = Writes::default();
let collect_res = checker.collect_writes(&interp, &body, env);
let collect_res = checker.check_body(&interp, &body, env);
assert!(!matches!(collect_res, InterpRes::Err(_)));
// EXPECT:
// - variables A, B (only index 0 in the vector), and C are written
Expand Down Expand Up @@ -434,7 +454,7 @@ mod tests {
];

let mut checker = Writes::default();
let collect_res = checker.collect_writes(&interp, &body, env);
let collect_res = checker.check_body(&interp, &body, env);
assert!(!matches!(collect_res, InterpRes::Err(_)));
// EXPECT:
// - no variables are written
Expand Down
7 changes: 6 additions & 1 deletion circuit_passes/src/passes/loop_unroll/observer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,12 @@ impl LoopUnrollObserver<'_> {
}
match cond {
// If the conditional becomes unknown just give up.
None => return Ok(None),
None => {
if DEBUG_LOOP_UNROLL {
println!("[UNROLL][try_unroll_loop] OUTCOME: not safe to move or unroll, condition unknown");
}
return Ok(None);
}
// When conditional becomes `false`, iteration count is complete.
Some(false) => break,
// Otherwise, continue counting.
Expand Down
2 changes: 1 addition & 1 deletion circuit_passes/src/passes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -786,7 +786,7 @@ impl GlobalPassData {

pub fn get_data_for_func(
&self,
name: &String,
name: &str,
) -> &BTreeMap<UnrolledIterLvars, (ToOriginalLocation, HashSet<FuncArgIdx>)> {
match self.extract_func_orig_loc.get(name) {
Some(x) => x,
Expand Down
61 changes: 42 additions & 19 deletions circuit_passes/src/passes/simplification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,20 +62,45 @@ impl<'d> SimplificationPass<'d> {
fn insert<V: PartialEq>(
map: &RefCell<HashMap<BucketId, Option<V>>>,
bucket_id: BucketId,
v: V,
new_val: Option<V>,
) {
map.borrow_mut()
.entry(bucket_id)
// If the entry exists and it's not the same as the new value, set to None.
// If the entry exists and it's not the same as the new value,
// or the new value itself is None, then set to result to None
// to indicate that no replacement should be made.
.and_modify(|old| {
if let Some(old_val) = old {
if *old_val != v {
*old = None
match &new_val {
None => *old = None,
Some(x) => {
if *x != *old_val {
*old = None
}
}
}
}
})
// If no entry exists, store the new value.
.or_insert(Some(v));
.or_insert(new_val);
}

fn store_computed_value(
map: &RefCell<HashMap<BucketId, Option<Value>>>,
bucket_id: BucketId,
computed_value: Value,
) -> Result<bool, BadInterp> {
if computed_value.is_unknown() {
// When the bucket's value is Unknown from any execution, add None to the map so
// the bucket will not be replaced (even if known at an execution found later),
// return 'true' so buckets nested within this bucket will be observed.
Self::insert(map, bucket_id, None);
Ok(true)
} else {
// Add known value to the map, return 'false' so observation will not continue within.
Self::insert(map, bucket_id, Some(computed_value));
Ok(false)
}
}
}

Expand All @@ -84,12 +109,7 @@ impl Observer<Env<'_>> for SimplificationPass<'_> {
let interp = self.build_interpreter();
let v = interp.compute_compute_bucket(bucket, env, false)?;
let v = result_types::into_single_result(v, "ComputeBucket")?;
if !v.is_unknown() {
Self::insert(&self.compute_replacements, bucket.id, v);
Ok(false)
} else {
Ok(true)
}
Self::store_computed_value(&self.compute_replacements, bucket.id, v)
}

fn on_call_bucket(&self, bucket: &CallBucket, env: &Env) -> Result<bool, BadInterp> {
Expand All @@ -99,12 +119,10 @@ impl Observer<Env<'_>> for SimplificationPass<'_> {
// rather than 'into_single_result()' and return 'true' in the None case
// so buckets nested within this bucket will be observed.
if let Some(v) = result_types::into_single_option(v) {
if !v.is_unknown() {
Self::insert(&self.call_replacements, bucket.id, v);
return Ok(false);
}
Self::store_computed_value(&self.call_replacements, bucket.id, v)
} else {
Ok(true)
}
Ok(true)
}

fn on_constraint_bucket(
Expand All @@ -113,7 +131,8 @@ impl Observer<Env<'_>> for SimplificationPass<'_> {
env: &Env,
) -> Result<bool, BadInterp> {
self.within_constraint.replace(true);
// Match the expected structure of ConstraintBucket instances but don't fail if there's something different.
// Match the expected structure of ConstraintBucket instances
// but don't fail if there's something different.
match bucket {
ConstraintBucket::Equality(e) => {
if let Instruction::Assert(AssertBucket { evaluate, .. }) = e.as_ref() {
Expand All @@ -130,7 +149,11 @@ impl Observer<Env<'_>> for SimplificationPass<'_> {
}
// If at least one is a known value, then we can (likely) simplify
if values.iter().any(Value::is_known) {
Self::insert(&self.constraint_eq_replacements, e.get_id(), values);
Self::insert(
&self.constraint_eq_replacements,
e.get_id(),
Some(values),
);
}
}
}
Expand Down Expand Up @@ -172,7 +195,7 @@ impl Observer<Env<'_>> for SimplificationPass<'_> {
Self::insert(
&self.constraint_sub_replacements,
e.get_id(),
(src, dest, dest_address_type),
Some((src, dest, dest_address_type)),
);
}
}
Expand Down
Loading