From 63aa63d847274def0bb9328383ec79f1a3698ee2 Mon Sep 17 00:00:00 2001 From: Ary Borenszweig Date: Mon, 11 Nov 2024 11:45:02 -0300 Subject: [PATCH] Introduce UnrollMode to preserve original checks in acir mode --- compiler/noirc_evaluator/src/ssa.rs | 14 ++-- .../src/ssa/opt/inline_const_brillig_calls.rs | 3 +- compiler/noirc_evaluator/src/ssa/opt/mod.rs | 2 +- .../noirc_evaluator/src/ssa/opt/unrolling.rs | 80 ++++++++++++++----- 4 files changed, 70 insertions(+), 29 deletions(-) diff --git a/compiler/noirc_evaluator/src/ssa.rs b/compiler/noirc_evaluator/src/ssa.rs index b6a18baf6ea..33cfb918ece 100644 --- a/compiler/noirc_evaluator/src/ssa.rs +++ b/compiler/noirc_evaluator/src/ssa.rs @@ -33,6 +33,7 @@ use noirc_frontend::{ hir_def::{function::FunctionSignature, types::Type as HirType}, monomorphization::ast::Program, }; +use opt::unrolling::UnrollMode; use tracing::{span, Level}; use self::{ @@ -130,6 +131,7 @@ fn optimize_ssa(builder: SsaBuilder, inliner_aggressiveness: i64) -> Result Result { let builder = builder // Run mem2reg with the CFG separated into blocks @@ -148,7 +151,7 @@ fn optimize_ssa_after_inline_const_brillig_calls( Ssa::evaluate_static_assert_and_assert_constant, "After `static_assert` and `assert_constant`:", )? - .try_run_pass(Ssa::unroll_loops_iteratively, "After Unrolling:")? + .try_run_pass(|ssa| Ssa::unroll_loops_iteratively(ssa, unroll_mode), "After Unrolling:")? .run_pass(Ssa::simplify_cfg, "After Simplifying (2nd):") .run_pass(Ssa::flatten_cfg, "After Flattening:") .run_pass(Ssa::remove_bit_shifts, "After Removing Bit Shifts:") @@ -457,11 +460,10 @@ impl SsaBuilder { } /// The same as `run_pass` but for passes that may fail - fn try_run_pass( - mut self, - pass: fn(Ssa) -> Result, - msg: &str, - ) -> Result { + fn try_run_pass(mut self, pass: F, msg: &str) -> Result + where + F: FnOnce(Ssa) -> Result, + { self.ssa = time(msg, self.print_codegen_timings, || pass(self.ssa))?; Ok(self.print(msg)) } diff --git a/compiler/noirc_evaluator/src/ssa/opt/inline_const_brillig_calls.rs b/compiler/noirc_evaluator/src/ssa/opt/inline_const_brillig_calls.rs index e29842768ae..3b28c78e7fe 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/inline_const_brillig_calls.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/inline_const_brillig_calls.rs @@ -12,7 +12,7 @@ use crate::{ instruction::{Instruction, InstructionId, TerminatorInstruction}, value::{Value, ValueId}, }, - optimize_ssa_after_inline_const_brillig_calls, Ssa, SsaBuilder, + optimize_ssa_after_inline_const_brillig_calls, Ssa, SsaBuilder, UnrollMode, }, }; @@ -235,6 +235,7 @@ fn optimize( // a single function. For inlining to work we need to know all other functions that // exist (so we can inline them). Here we choose to skip this optimization for simplicity reasons. false, + UnrollMode::Brillig, )?; Ok(ssa.functions.pop_first().unwrap().1) } diff --git a/compiler/noirc_evaluator/src/ssa/opt/mod.rs b/compiler/noirc_evaluator/src/ssa/opt/mod.rs index c9aa5e17efe..9cd71e08973 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/mod.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/mod.rs @@ -21,4 +21,4 @@ mod remove_if_else; mod resolve_is_unconstrained; mod runtime_separation; mod simplify_cfg; -mod unrolling; +pub(crate) mod unrolling; diff --git a/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs b/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs index 003877e85f7..55681618853 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs @@ -37,13 +37,29 @@ use crate::{ }; use fxhash::FxHashMap as HashMap; +/// In this mode we are unrolling. +#[derive(Debug, Clone, Copy)] +pub(crate) enum UnrollMode { + /// This is the normal unroll mode, where we are unrolling a brillig function. + Acir, + /// There's one optimization, `inline_const_brillig_calls`, where we try to optimize + /// brillig functions with all constant arguments. For that we turn the brillig function + /// into an acir one and try to optimize it. This brillig function could then have + /// `break` statements that can't be possible in acir, so in this mode we don't panic + /// if we end up in unexpected situations that might have been produced by `break`. + Brillig, +} + impl Ssa { /// Loop unrolling can return errors, since ACIR functions need to be fully unrolled. /// This meta-pass will keep trying to unroll loops and simplifying the SSA until no more errors are found. - pub(crate) fn unroll_loops_iteratively(mut ssa: Ssa) -> Result { + pub(crate) fn unroll_loops_iteratively( + mut ssa: Ssa, + mode: UnrollMode, + ) -> Result { // Try to unroll loops first: let mut unroll_errors; - (ssa, unroll_errors) = ssa.try_to_unroll_loops(); + (ssa, unroll_errors) = ssa.try_to_unroll_loops(mode); // Keep unrolling until no more errors are found while !unroll_errors.is_empty() { @@ -58,7 +74,7 @@ impl Ssa { ssa = ssa.mem2reg(); // Unroll again - (ssa, unroll_errors) = ssa.try_to_unroll_loops(); + (ssa, unroll_errors) = ssa.try_to_unroll_loops(mode); // If we didn't manage to unroll any more loops, exit if unroll_errors.len() >= prev_unroll_err_count { return Err(unroll_errors.swap_remove(0)); @@ -71,22 +87,22 @@ impl Ssa { /// If any loop cannot be unrolled, it is left as-is or in a partially unrolled state. /// Returns the ssa along with all unrolling errors encountered #[tracing::instrument(level = "trace", skip(self))] - pub(crate) fn try_to_unroll_loops(mut self) -> (Ssa, Vec) { + pub(crate) fn try_to_unroll_loops(mut self, mode: UnrollMode) -> (Ssa, Vec) { let mut errors = vec![]; for function in self.functions.values_mut() { - function.try_to_unroll_loops(&mut errors); + function.try_to_unroll_loops(mode, &mut errors); } (self, errors) } } impl Function { - pub(crate) fn try_to_unroll_loops(&mut self, errors: &mut Vec) { + pub(crate) fn try_to_unroll_loops(&mut self, mode: UnrollMode, errors: &mut Vec) { // Loop unrolling in brillig can lead to a code explosion currently. This can // also be true for ACIR, but we have no alternative to unrolling in ACIR. // Brillig also generally prefers smaller code rather than faster code. if !matches!(self.runtime(), RuntimeType::Brillig(_)) { - errors.extend(find_all_loops(self).unroll_each_loop(self)); + errors.extend(find_all_loops(self).unroll_each_loop(self, mode)); } } } @@ -151,7 +167,7 @@ fn find_all_loops(function: &Function) -> Loops { impl Loops { /// Unroll all loops within a given function. /// Any loops which fail to be unrolled (due to using non-constant indices) will be unmodified. - fn unroll_each_loop(mut self, function: &mut Function) -> Vec { + fn unroll_each_loop(mut self, function: &mut Function, mode: UnrollMode) -> Vec { let mut unroll_errors = vec![]; while let Some(next_loop) = self.yet_to_unroll.pop() { // If we've previously modified a block in this loop we need to refresh the context. @@ -161,13 +177,13 @@ impl Loops { new_context.failed_to_unroll = self.failed_to_unroll; return unroll_errors .into_iter() - .chain(new_context.unroll_each_loop(function)) + .chain(new_context.unroll_each_loop(function, mode)) .collect(); } // Don't try to unroll the loop again if it is known to fail if !self.failed_to_unroll.contains(&next_loop.header) { - match unroll_loop(function, &self.cfg, &next_loop) { + match unroll_loop(function, &self.cfg, &next_loop, mode) { Ok(_) => self.modified_blocks.extend(next_loop.blocks), Err(call_stack) => { self.failed_to_unroll.insert(next_loop.header); @@ -217,12 +233,13 @@ fn unroll_loop( function: &mut Function, cfg: &ControlFlowGraph, loop_: &Loop, + mode: UnrollMode, ) -> Result<(), CallStack> { - let mut unroll_into = get_pre_header(cfg, loop_)?; + let mut unroll_into = get_pre_header(cfg, loop_, mode)?; let mut jump_value = get_induction_variable(function, unroll_into)?; while let Some(context) = unroll_loop_header(function, loop_, unroll_into, jump_value)? { - let (last_block, last_value) = context.unroll_loop_iteration()?; + let (last_block, last_value) = context.unroll_loop_iteration(mode)?; unroll_into = last_block; jump_value = last_value; } @@ -233,16 +250,28 @@ fn unroll_loop( /// The loop pre-header is the block that comes before the loop begins. Generally a header block /// is expected to have 2 predecessors: the pre-header and the final block of the loop which jumps /// back to the beginning. -fn get_pre_header(cfg: &ControlFlowGraph, loop_: &Loop) -> Result { +fn get_pre_header( + cfg: &ControlFlowGraph, + loop_: &Loop, + mode: UnrollMode, +) -> Result { let mut pre_header = cfg .predecessors(loop_.header) .filter(|predecessor| *predecessor != loop_.back_edge_start) .collect::>(); - if pre_header.len() == 1 { - Ok(pre_header.remove(0)) - } else { - Err(CallStack::new()) + match mode { + UnrollMode::Acir => { + assert_eq!(pre_header.len(), 1); + Ok(pre_header.remove(0)) + } + UnrollMode::Brillig => { + if pre_header.len() == 1 { + Ok(pre_header.remove(0)) + } else { + Err(CallStack::new()) + } + } } } @@ -367,7 +396,10 @@ impl<'f> LoopIteration<'f> { /// It is expected the terminator instructions are set up to branch into an empty block /// for further unrolling. When the loop is finished this will need to be mutated to /// jump to the end of the loop instead. - fn unroll_loop_iteration(mut self) -> Result<(BasicBlockId, ValueId), CallStack> { + fn unroll_loop_iteration( + mut self, + mode: UnrollMode, + ) -> Result<(BasicBlockId, ValueId), CallStack> { let mut next_blocks = self.unroll_loop_block(); while let Some(block) = next_blocks.pop() { @@ -380,7 +412,12 @@ impl<'f> LoopIteration<'f> { } } - self.induction_value.ok_or_else(CallStack::new) + match mode { + UnrollMode::Acir => Ok(self + .induction_value + .expect("Expected to find the induction variable by end of loop iteration")), + UnrollMode::Brillig => self.induction_value.ok_or_else(CallStack::new), + } } /// Unroll a single block in the current iteration of the loop @@ -512,6 +549,7 @@ mod tests { use crate::ssa::{ function_builder::FunctionBuilder, ir::{instruction::BinaryOp, map::Id, types::Type}, + opt::unrolling::UnrollMode, }; #[test] @@ -632,7 +670,7 @@ mod tests { // } // The final block count is not 1 because unrolling creates some unnecessary jmps. // If a simplify cfg pass is ran afterward, the expected block count will be 1. - let (ssa, errors) = ssa.try_to_unroll_loops(); + let (ssa, errors) = ssa.try_to_unroll_loops(UnrollMode::Acir); assert_eq!(errors.len(), 0, "All loops should be unrolled"); assert_eq!(ssa.main().reachable_blocks().len(), 5); } @@ -682,7 +720,7 @@ mod tests { assert_eq!(ssa.main().reachable_blocks().len(), 4); // Expected that we failed to unroll the loop - let (_, errors) = ssa.try_to_unroll_loops(); + let (_, errors) = ssa.try_to_unroll_loops(UnrollMode::Acir); assert_eq!(errors.len(), 1, "Expected to fail to unroll loop"); } }