Skip to content

Commit

Permalink
Introduce UnrollMode to preserve original checks in acir mode
Browse files Browse the repository at this point in the history
  • Loading branch information
asterite committed Nov 11, 2024
1 parent e83fdbd commit 63aa63d
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 29 deletions.
14 changes: 8 additions & 6 deletions compiler/noirc_evaluator/src/ssa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ use noirc_frontend::{
hir_def::{function::FunctionSignature, types::Type as HirType},
monomorphization::ast::Program,
};
use opt::unrolling::UnrollMode;
use tracing::{span, Level};

use self::{
Expand Down Expand Up @@ -130,6 +131,7 @@ fn optimize_ssa(builder: SsaBuilder, inliner_aggressiveness: i64) -> Result<Ssa,
builder,
inliner_aggressiveness,
true, // inline functions with no predicates
UnrollMode::Acir,
)?;
Ok(ssa)
}
Expand All @@ -138,6 +140,7 @@ fn optimize_ssa_after_inline_const_brillig_calls(
builder: SsaBuilder,
inliner_aggressiveness: i64,
inline_functions_with_no_predicates: bool,
unroll_mode: UnrollMode,
) -> Result<Ssa, RuntimeError> {
let builder = builder
// Run mem2reg with the CFG separated into blocks
Expand All @@ -148,7 +151,7 @@ fn optimize_ssa_after_inline_const_brillig_calls(
Ssa::evaluate_static_assert_and_assert_constant,
"After `static_assert` and `assert_constant`:",
)?
.try_run_pass(Ssa::unroll_loops_iteratively, "After Unrolling:")?
.try_run_pass(|ssa| Ssa::unroll_loops_iteratively(ssa, unroll_mode), "After Unrolling:")?
.run_pass(Ssa::simplify_cfg, "After Simplifying (2nd):")
.run_pass(Ssa::flatten_cfg, "After Flattening:")
.run_pass(Ssa::remove_bit_shifts, "After Removing Bit Shifts:")
Expand Down Expand Up @@ -457,11 +460,10 @@ impl SsaBuilder {
}

/// The same as `run_pass` but for passes that may fail
fn try_run_pass(
mut self,
pass: fn(Ssa) -> Result<Ssa, RuntimeError>,
msg: &str,
) -> Result<Self, RuntimeError> {
fn try_run_pass<F>(mut self, pass: F, msg: &str) -> Result<Self, RuntimeError>
where
F: FnOnce(Ssa) -> Result<Ssa, RuntimeError>,
{
self.ssa = time(msg, self.print_codegen_timings, || pass(self.ssa))?;
Ok(self.print(msg))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::{
instruction::{Instruction, InstructionId, TerminatorInstruction},
value::{Value, ValueId},
},
optimize_ssa_after_inline_const_brillig_calls, Ssa, SsaBuilder,
optimize_ssa_after_inline_const_brillig_calls, Ssa, SsaBuilder, UnrollMode,
},
};

Expand Down Expand Up @@ -235,6 +235,7 @@ fn optimize(
// a single function. For inlining to work we need to know all other functions that
// exist (so we can inline them). Here we choose to skip this optimization for simplicity reasons.
false,
UnrollMode::Brillig,
)?;
Ok(ssa.functions.pop_first().unwrap().1)
}
2 changes: 1 addition & 1 deletion compiler/noirc_evaluator/src/ssa/opt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,4 @@ mod remove_if_else;
mod resolve_is_unconstrained;
mod runtime_separation;
mod simplify_cfg;
mod unrolling;
pub(crate) mod unrolling;
80 changes: 59 additions & 21 deletions compiler/noirc_evaluator/src/ssa/opt/unrolling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,29 @@ use crate::{
};
use fxhash::FxHashMap as HashMap;

/// In this mode we are unrolling.
#[derive(Debug, Clone, Copy)]
pub(crate) enum UnrollMode {
/// This is the normal unroll mode, where we are unrolling a brillig function.
Acir,
/// There's one optimization, `inline_const_brillig_calls`, where we try to optimize
/// brillig functions with all constant arguments. For that we turn the brillig function
/// into an acir one and try to optimize it. This brillig function could then have
/// `break` statements that can't be possible in acir, so in this mode we don't panic
/// if we end up in unexpected situations that might have been produced by `break`.
Brillig,
}

impl Ssa {
/// Loop unrolling can return errors, since ACIR functions need to be fully unrolled.
/// This meta-pass will keep trying to unroll loops and simplifying the SSA until no more errors are found.
pub(crate) fn unroll_loops_iteratively(mut ssa: Ssa) -> Result<Ssa, RuntimeError> {
pub(crate) fn unroll_loops_iteratively(
mut ssa: Ssa,
mode: UnrollMode,
) -> Result<Ssa, RuntimeError> {
// Try to unroll loops first:
let mut unroll_errors;
(ssa, unroll_errors) = ssa.try_to_unroll_loops();
(ssa, unroll_errors) = ssa.try_to_unroll_loops(mode);

// Keep unrolling until no more errors are found
while !unroll_errors.is_empty() {
Expand All @@ -58,7 +74,7 @@ impl Ssa {
ssa = ssa.mem2reg();

// Unroll again
(ssa, unroll_errors) = ssa.try_to_unroll_loops();
(ssa, unroll_errors) = ssa.try_to_unroll_loops(mode);
// If we didn't manage to unroll any more loops, exit
if unroll_errors.len() >= prev_unroll_err_count {
return Err(unroll_errors.swap_remove(0));
Expand All @@ -71,22 +87,22 @@ impl Ssa {
/// If any loop cannot be unrolled, it is left as-is or in a partially unrolled state.
/// Returns the ssa along with all unrolling errors encountered
#[tracing::instrument(level = "trace", skip(self))]
pub(crate) fn try_to_unroll_loops(mut self) -> (Ssa, Vec<RuntimeError>) {
pub(crate) fn try_to_unroll_loops(mut self, mode: UnrollMode) -> (Ssa, Vec<RuntimeError>) {
let mut errors = vec![];
for function in self.functions.values_mut() {
function.try_to_unroll_loops(&mut errors);
function.try_to_unroll_loops(mode, &mut errors);
}
(self, errors)
}
}

impl Function {
pub(crate) fn try_to_unroll_loops(&mut self, errors: &mut Vec<RuntimeError>) {
pub(crate) fn try_to_unroll_loops(&mut self, mode: UnrollMode, errors: &mut Vec<RuntimeError>) {
// Loop unrolling in brillig can lead to a code explosion currently. This can
// also be true for ACIR, but we have no alternative to unrolling in ACIR.
// Brillig also generally prefers smaller code rather than faster code.
if !matches!(self.runtime(), RuntimeType::Brillig(_)) {
errors.extend(find_all_loops(self).unroll_each_loop(self));
errors.extend(find_all_loops(self).unroll_each_loop(self, mode));
}
}
}
Expand Down Expand Up @@ -151,7 +167,7 @@ fn find_all_loops(function: &Function) -> Loops {
impl Loops {
/// Unroll all loops within a given function.
/// Any loops which fail to be unrolled (due to using non-constant indices) will be unmodified.
fn unroll_each_loop(mut self, function: &mut Function) -> Vec<RuntimeError> {
fn unroll_each_loop(mut self, function: &mut Function, mode: UnrollMode) -> Vec<RuntimeError> {
let mut unroll_errors = vec![];
while let Some(next_loop) = self.yet_to_unroll.pop() {
// If we've previously modified a block in this loop we need to refresh the context.
Expand All @@ -161,13 +177,13 @@ impl Loops {
new_context.failed_to_unroll = self.failed_to_unroll;
return unroll_errors
.into_iter()
.chain(new_context.unroll_each_loop(function))
.chain(new_context.unroll_each_loop(function, mode))
.collect();
}

// Don't try to unroll the loop again if it is known to fail
if !self.failed_to_unroll.contains(&next_loop.header) {
match unroll_loop(function, &self.cfg, &next_loop) {
match unroll_loop(function, &self.cfg, &next_loop, mode) {
Ok(_) => self.modified_blocks.extend(next_loop.blocks),
Err(call_stack) => {
self.failed_to_unroll.insert(next_loop.header);
Expand Down Expand Up @@ -217,12 +233,13 @@ fn unroll_loop(
function: &mut Function,
cfg: &ControlFlowGraph,
loop_: &Loop,
mode: UnrollMode,
) -> Result<(), CallStack> {
let mut unroll_into = get_pre_header(cfg, loop_)?;
let mut unroll_into = get_pre_header(cfg, loop_, mode)?;
let mut jump_value = get_induction_variable(function, unroll_into)?;

while let Some(context) = unroll_loop_header(function, loop_, unroll_into, jump_value)? {
let (last_block, last_value) = context.unroll_loop_iteration()?;
let (last_block, last_value) = context.unroll_loop_iteration(mode)?;
unroll_into = last_block;
jump_value = last_value;
}
Expand All @@ -233,16 +250,28 @@ fn unroll_loop(
/// The loop pre-header is the block that comes before the loop begins. Generally a header block
/// is expected to have 2 predecessors: the pre-header and the final block of the loop which jumps
/// back to the beginning.
fn get_pre_header(cfg: &ControlFlowGraph, loop_: &Loop) -> Result<BasicBlockId, CallStack> {
fn get_pre_header(
cfg: &ControlFlowGraph,
loop_: &Loop,
mode: UnrollMode,
) -> Result<BasicBlockId, CallStack> {
let mut pre_header = cfg
.predecessors(loop_.header)
.filter(|predecessor| *predecessor != loop_.back_edge_start)
.collect::<Vec<_>>();

if pre_header.len() == 1 {
Ok(pre_header.remove(0))
} else {
Err(CallStack::new())
match mode {
UnrollMode::Acir => {
assert_eq!(pre_header.len(), 1);
Ok(pre_header.remove(0))
}
UnrollMode::Brillig => {
if pre_header.len() == 1 {
Ok(pre_header.remove(0))
} else {
Err(CallStack::new())
}
}
}
}

Expand Down Expand Up @@ -367,7 +396,10 @@ impl<'f> LoopIteration<'f> {
/// It is expected the terminator instructions are set up to branch into an empty block
/// for further unrolling. When the loop is finished this will need to be mutated to
/// jump to the end of the loop instead.
fn unroll_loop_iteration(mut self) -> Result<(BasicBlockId, ValueId), CallStack> {
fn unroll_loop_iteration(
mut self,
mode: UnrollMode,
) -> Result<(BasicBlockId, ValueId), CallStack> {
let mut next_blocks = self.unroll_loop_block();

while let Some(block) = next_blocks.pop() {
Expand All @@ -380,7 +412,12 @@ impl<'f> LoopIteration<'f> {
}
}

self.induction_value.ok_or_else(CallStack::new)
match mode {
UnrollMode::Acir => Ok(self
.induction_value
.expect("Expected to find the induction variable by end of loop iteration")),
UnrollMode::Brillig => self.induction_value.ok_or_else(CallStack::new),
}
}

/// Unroll a single block in the current iteration of the loop
Expand Down Expand Up @@ -512,6 +549,7 @@ mod tests {
use crate::ssa::{
function_builder::FunctionBuilder,
ir::{instruction::BinaryOp, map::Id, types::Type},
opt::unrolling::UnrollMode,
};

#[test]
Expand Down Expand Up @@ -632,7 +670,7 @@ mod tests {
// }
// The final block count is not 1 because unrolling creates some unnecessary jmps.
// If a simplify cfg pass is ran afterward, the expected block count will be 1.
let (ssa, errors) = ssa.try_to_unroll_loops();
let (ssa, errors) = ssa.try_to_unroll_loops(UnrollMode::Acir);
assert_eq!(errors.len(), 0, "All loops should be unrolled");
assert_eq!(ssa.main().reachable_blocks().len(), 5);
}
Expand Down Expand Up @@ -682,7 +720,7 @@ mod tests {
assert_eq!(ssa.main().reachable_blocks().len(), 4);

// Expected that we failed to unroll the loop
let (_, errors) = ssa.try_to_unroll_loops();
let (_, errors) = ssa.try_to_unroll_loops(UnrollMode::Acir);
assert_eq!(errors.len(), 1, "Expected to fail to unroll loop");
}
}

0 comments on commit 63aa63d

Please sign in to comment.