Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: pull out cfg simplification changes #10279

Merged
merged 4 commits into from
Nov 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,16 @@ impl<'f> Context<'f> {
};
self.condition_stack.push(cond_context);
self.insert_current_side_effects_enabled();

// We disallow this case as it results in the `else_destination` block
// being inlined before the `then_destination` block due to block deduplication in the work queue.
//
// The `else_destination` block then gets treated as if it were the `then_destination` block
// and has the incorrect condition applied to it.
assert_ne!(
self.branch_ends[if_entry], *then_destination,
"ICE: branches merge inside of `then` branch"
);
vec![self.branch_ends[if_entry], *else_destination, *then_destination]
}

Expand Down Expand Up @@ -1526,4 +1536,23 @@ mod test {
_ => unreachable!("Should have terminator instruction"),
}
}

#[test]
#[should_panic = "ICE: branches merge inside of `then` branch"]
fn panics_if_branches_merge_within_then_branch() {
//! This is a regression test for https://github.com/noir-lang/noir/issues/6620

let src = "
acir(inline) fn main f0 {
b0(v0: u1):
jmpif v0 then: b2, else: b1
b2():
return
b1():
jmp b2()
}
";
let merged_ssa = Ssa::from_str(src).unwrap();
let _ = merged_ssa.flatten_cfg();
}
}
112 changes: 111 additions & 1 deletion noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/simplify_cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ use crate::ssa::{
basic_block::BasicBlockId,
cfg::ControlFlowGraph,
function::{Function, RuntimeType},
instruction::TerminatorInstruction,
instruction::{Instruction, TerminatorInstruction},
value::Value,
},
ssa_gen::Ssa,
};
Expand All @@ -31,6 +32,7 @@ impl Ssa {
/// 4. Removing any blocks which have no instructions other than a single terminating jmp.
/// 5. Replacing any jmpifs with constant conditions with jmps. If this causes the block to have
/// only 1 successor then (2) also will be applied.
/// 6. Replacing any jmpifs with a negated condition with a jmpif with a un-negated condition and reversed branches.
///
/// Currently, 1 is unimplemented.
#[tracing::instrument(level = "trace", skip(self))]
Expand All @@ -55,6 +57,8 @@ impl Function {
stack.extend(self.dfg[block].successors().filter(|block| !visited.contains(block)));
}

check_for_negated_jmpif_condition(self, block, &mut cfg);

// This call is before try_inline_into_predecessor so that if it succeeds in changing a
// jmpif into a jmp, the block may then be inlined entirely into its predecessor in try_inline_into_predecessor.
check_for_constant_jmpif(self, block, &mut cfg);
Expand Down Expand Up @@ -184,6 +188,55 @@ fn check_for_double_jmp(function: &mut Function, block: BasicBlockId, cfg: &mut
cfg.recompute_block(function, block);
}

/// Optimize a jmpif on a negated condition by swapping the branches.
fn check_for_negated_jmpif_condition(
function: &mut Function,
block: BasicBlockId,
cfg: &mut ControlFlowGraph,
) {
if matches!(function.runtime(), RuntimeType::Acir(_)) {
// Swapping the `then` and `else` branches of a `JmpIf` within an ACIR function
// can result in the situation where the branches merge together again in the `then` block, e.g.
//
// acir(inline) fn main f0 {
// b0(v0: u1):
// jmpif v0 then: b2, else: b1
// b2():
// return
// b1():
// jmp b2()
// }
//
// This breaks the `flatten_cfg` pass as it assumes that merges only happen in
// the `else` block or a 3rd block.
//
// See: https://github.com/noir-lang/noir/pull/5891#issuecomment-2500219428
return;
}

if let Some(TerminatorInstruction::JmpIf {
condition,
then_destination,
else_destination,
call_stack,
}) = function.dfg[block].terminator()
{
if let Value::Instruction { instruction, .. } = function.dfg[*condition] {
if let Instruction::Not(negated_condition) = function.dfg[instruction] {
let call_stack = call_stack.clone();
let jmpif = TerminatorInstruction::JmpIf {
condition: negated_condition,
then_destination: *else_destination,
else_destination: *then_destination,
call_stack,
};
function.dfg[block].set_terminator(jmpif);
cfg.recompute_block(function, block);
}
}
}
}

/// If the given block has block parameters, replace them with the jump arguments from the predecessor.
///
/// Currently, if this function is needed, `try_inline_into_predecessor` will also always apply,
Expand Down Expand Up @@ -246,6 +299,8 @@ mod test {
map::Id,
types::Type,
},
opt::assert_normalized_ssa_equals,
Ssa,
};
use acvm::acir::AcirField;

Expand Down Expand Up @@ -359,4 +414,59 @@ mod test {
other => panic!("Unexpected terminator {other:?}"),
}
}

#[test]
fn swap_negated_jmpif_branches_in_brillig() {
let src = "
brillig(inline) fn main f0 {
b0(v0: u1):
v1 = allocate -> &mut Field
store Field 0 at v1
v3 = not v0
jmpif v3 then: b1, else: b2
b1():
store Field 2 at v1
jmp b2()
b2():
v5 = load v1 -> Field
v6 = eq v5, Field 2
constrain v5 == Field 2
return
}";
let ssa = Ssa::from_str(src).unwrap();

let expected = "
brillig(inline) fn main f0 {
b0(v0: u1):
v1 = allocate -> &mut Field
store Field 0 at v1
v3 = not v0
jmpif v0 then: b2, else: b1
b2():
v5 = load v1 -> Field
v6 = eq v5, Field 2
constrain v5 == Field 2
return
b1():
store Field 2 at v1
jmp b2()
}";
assert_normalized_ssa_equals(ssa.simplify_cfg(), expected);
}

#[test]
fn does_not_swap_negated_jmpif_branches_in_acir() {
let src = "
acir(inline) fn main f0 {
b0(v0: u1):
v1 = not v0
jmpif v1 then: b1, else: b2
b1():
jmp b2()
b2():
return
}";
let ssa = Ssa::from_str(src).unwrap();
assert_normalized_ssa_equals(ssa.simplify_cfg(), src);
}
}
Loading