Skip to content

Commit

Permalink
feat: Unroll loops iteratively (#4779)
Browse files Browse the repository at this point in the history
# Description

## Problem\*

Resolves #4736

## Summary\*

Instead of trying to unroll loops once, adds a meta pass that tries to
unroll as many times as necessary, simplifying and doing mem2reg between
unrolls.
- For the cases where the program unrolls succesfully at the first try,
no compile time overhead is created
- For the cases where the program does contain an unknown at compile
time loop bound, it'll try one more time before failing, since the stop
condition is that "this unroll retry generated the same amount of errors
as the previous one"
- For the cases where the program doesn't contain an unknown at compile
time loop bound, instead of failing it'll do as many unrolls as
necessary to unroll the loop.

## Additional Context



## Documentation\*

Check one:
- [x] No documentation needed.
- [ ] Documentation included in this PR.
- [ ] **[For Experimental Features]** Documentation to be submitted in a
separate PR.

# PR Checklist\*

- [x] I have tested the changes locally.
- [x] I have formatted the changes with [Prettier](https://prettier.io/)
and/or `cargo fmt` on default settings.

---------

Co-authored-by: jfecher <[email protected]>
  • Loading branch information
sirasistant and jfecher authored Apr 11, 2024
1 parent 606ff44 commit f831b0b
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 21 deletions.
2 changes: 1 addition & 1 deletion compiler/noirc_evaluator/src/ssa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ pub(crate) fn optimize_into_acir(
.run_pass(Ssa::mem2reg, "After Mem2Reg:")
.run_pass(Ssa::as_slice_optimization, "After `as_slice` optimization")
.try_run_pass(Ssa::evaluate_assert_constant, "After Assert Constant:")?
.try_run_pass(Ssa::unroll_loops, "After Unrolling:")?
.try_run_pass(Ssa::unroll_loops_iteratively, "After Unrolling:")?
.run_pass(Ssa::simplify_cfg, "After Simplifying:")
.run_pass(Ssa::flatten_cfg, "After Flattening:")
.run_pass(Ssa::remove_bit_shifts, "After Removing Bit Shifts:")
Expand Down
65 changes: 45 additions & 20 deletions compiler/noirc_evaluator/src/ssa/opt/unrolling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,41 @@ use crate::{
use fxhash::FxHashMap as HashMap;

impl Ssa {
/// Unroll all loops in each SSA function.
/// Loop unrolling can return errors, since ACIR functions need to be fully unrolled.
/// This meta-pass will keep trying to unroll loops and simplifying the SSA until no more errors are found.
pub(crate) fn unroll_loops_iteratively(mut ssa: Ssa) -> Result<Ssa, RuntimeError> {
// Try to unroll loops first:
let mut unroll_errors;
(ssa, unroll_errors) = ssa.try_to_unroll_loops();

// Keep unrolling until no more errors are found
while !unroll_errors.is_empty() {
let prev_unroll_err_count = unroll_errors.len();

// Simplify the SSA before retrying

// Do a mem2reg after the last unroll to aid simplify_cfg
ssa = ssa.mem2reg();
ssa = ssa.simplify_cfg();
// Do another mem2reg after simplify_cfg to aid the next unroll
ssa = ssa.mem2reg();

// Unroll again
(ssa, unroll_errors) = ssa.try_to_unroll_loops();
// If we didn't manage to unroll any more loops, exit
if unroll_errors.len() >= prev_unroll_err_count {
return Err(unroll_errors.swap_remove(0));
}
}
Ok(ssa)
}

/// Tries to unroll all loops in each SSA function.
/// If any loop cannot be unrolled, it is left as-is or in a partially unrolled state.
/// Returns the ssa along with all unrolling errors encountered
#[tracing::instrument(level = "trace", skip(self))]
pub(crate) fn unroll_loops(mut self) -> Result<Ssa, RuntimeError> {
pub(crate) fn try_to_unroll_loops(mut self) -> (Ssa, Vec<RuntimeError>) {
let mut errors = vec![];
for function in self.functions.values_mut() {
// Loop unrolling in brillig can lead to a code explosion currently. This can
// also be true for ACIR, but we have no alternative to unrolling in ACIR.
Expand All @@ -46,12 +77,9 @@ impl Ssa {
continue;
}

// This check is always true with the addition of the above guard, but I'm
// keeping it in case the guard on brillig functions is ever removed.
let abort_on_error = matches!(function.runtime(), RuntimeType::Acir(_));
find_all_loops(function).unroll_each_loop(function, abort_on_error)?;
errors.extend(find_all_loops(function).unroll_each_loop(function));
}
Ok(self)
(self, errors)
}
}

Expand Down Expand Up @@ -115,34 +143,29 @@ fn find_all_loops(function: &Function) -> Loops {
impl Loops {
/// Unroll all loops within a given function.
/// Any loops which fail to be unrolled (due to using non-constant indices) will be unmodified.
fn unroll_each_loop(
mut self,
function: &mut Function,
abort_on_error: bool,
) -> Result<(), RuntimeError> {
fn unroll_each_loop(mut self, function: &mut Function) -> Vec<RuntimeError> {
let mut unroll_errors = vec![];
while let Some(next_loop) = self.yet_to_unroll.pop() {
// If we've previously modified a block in this loop we need to refresh the context.
// This happens any time we have nested loops.
if next_loop.blocks.iter().any(|block| self.modified_blocks.contains(block)) {
let mut new_context = find_all_loops(function);
new_context.failed_to_unroll = self.failed_to_unroll;
return new_context.unroll_each_loop(function, abort_on_error);
return new_context.unroll_each_loop(function);
}

// Don't try to unroll the loop again if it is known to fail
if !self.failed_to_unroll.contains(&next_loop.header) {
match unroll_loop(function, &self.cfg, &next_loop) {
Ok(_) => self.modified_blocks.extend(next_loop.blocks),
Err(call_stack) if abort_on_error => {
return Err(RuntimeError::UnknownLoopBound { call_stack });
}
Err(_) => {
Err(call_stack) => {
self.failed_to_unroll.insert(next_loop.header);
unroll_errors.push(RuntimeError::UnknownLoopBound { call_stack });
}
}
}
}
Ok(())
unroll_errors
}
}

Expand Down Expand Up @@ -585,7 +608,8 @@ mod tests {
// }
// The final block count is not 1 because unrolling creates some unnecessary jmps.
// If a simplify cfg pass is ran afterward, the expected block count will be 1.
let ssa = ssa.unroll_loops().expect("All loops should be unrolled");
let (ssa, errors) = ssa.try_to_unroll_loops();
assert_eq!(errors.len(), 0, "All loops should be unrolled");
assert_eq!(ssa.main().reachable_blocks().len(), 5);
}

Expand Down Expand Up @@ -634,6 +658,7 @@ mod tests {
assert_eq!(ssa.main().reachable_blocks().len(), 4);

// Expected that we failed to unroll the loop
assert!(ssa.unroll_loops().is_err());
let (_, errors) = ssa.try_to_unroll_loops();
assert_eq!(errors.len(), 1, "Expected to fail to unroll loop");
}
}
6 changes: 6 additions & 0 deletions test_programs/execution_success/slice_loop/Nargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
[package]
name = "slice_loop"
type = "bin"
authors = [""]

[dependencies]
11 changes: 11 additions & 0 deletions test_programs/execution_success/slice_loop/Prover.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
[[points]]
x = "1"
y = "2"

[[points]]
x = "3"
y = "4"

[[points]]
x = "5"
y = "6"
32 changes: 32 additions & 0 deletions test_programs/execution_success/slice_loop/src/main.nr
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
struct Point {
x: Field,
y: Field,
}

impl Point {
fn serialize(self) -> [Field; 2] {
[self.x, self.y]
}
}

fn sum(values: [Field]) -> Field {
let mut sum = 0;
for value in values {
sum = sum + value;
}
sum
}

fn main(points: [Point; 3]) {
let mut serialized_points = &[];
for point in points {
serialized_points = serialized_points.append(point.serialize().as_slice());
}
// Do a compile-time check that needs the previous loop to be unrolled
if serialized_points.len() > 5 {
let empty_point = Point { x: 0, y: 0 };
serialized_points = serialized_points.append(empty_point.serialize().as_slice());
}
// Do a sum that needs both the previous loop and the previous if to have been simplified
assert_eq(sum(serialized_points), 21);
}

0 comments on commit f831b0b

Please sign in to comment.