Skip to content

Commit

Permalink
chore(ssa refactor): Simplify inlining pass and fix inlining failure (#…
Browse files Browse the repository at this point in the history
…1337)

* Fix bug in inlining pass

* Work on functions with multiple returns too

* Forgot to translate_block
  • Loading branch information
jfecher authored May 12, 2023
1 parent da47368 commit 7df3bb1
Showing 1 changed file with 127 additions and 27 deletions.
154 changes: 127 additions & 27 deletions crates/noirc_evaluator/src/ssa_refactor/opt/inlining.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,6 @@ struct PerFunctionContext<'function> {
/// Maps InstructionIds from the function being inlined to the function being inlined into.
instructions: HashMap<InstructionId, InstructionId>,

/// The TerminatorInstruction::Return in the source_function will be mapped to a jmp to
/// this block in the destination function instead.
return_destination: BasicBlockId,

/// True if we're currently working on the main function.
inlining_main: bool,
}
Expand Down Expand Up @@ -124,7 +120,12 @@ impl InlineContext {

/// Inlines a function into the current function and returns the translated return values
/// of the inlined function.
fn inline_function(&mut self, ssa: &Ssa, id: FunctionId, arguments: &[ValueId]) -> &[ValueId] {
fn inline_function(
&mut self,
ssa: &Ssa,
id: FunctionId,
arguments: &[ValueId],
) -> Vec<ValueId> {
self.recursion_level += 1;

if self.recursion_level > RECURSION_LIMIT {
Expand All @@ -143,9 +144,7 @@ impl InlineContext {
let current_block = context.context.builder.current_block();
context.blocks.insert(source_function.entry_block(), current_block);

context.inline_blocks(ssa);
let return_destination = context.return_destination;
self.builder.block_parameters(return_destination)
context.inline_blocks(ssa)
}

/// Finish inlining and return the new Ssa struct with the inlined version of main.
Expand Down Expand Up @@ -175,10 +174,7 @@ impl<'function> PerFunctionContext<'function> {
/// for containing the mapping between parameters in the source_function and
/// the arguments of the destination function.
fn new(context: &'function mut InlineContext, source_function: &'function Function) -> Self {
// Create the block to return to but don't insert its parameters until we
// have the types of the actual return values later.
Self {
return_destination: context.builder.insert_block(),
context,
source_function,
blocks: HashMap::new(),
Expand Down Expand Up @@ -265,20 +261,60 @@ impl<'function> PerFunctionContext<'function> {
}

/// Inline all reachable blocks within the source_function into the destination function.
fn inline_blocks(&mut self, ssa: &Ssa) {
fn inline_blocks(&mut self, ssa: &Ssa) -> Vec<ValueId> {
let mut seen_blocks = HashSet::new();
let mut block_queue = vec![self.source_function.entry_block()];

// This Vec will contain each block with a Return instruction along with the
// returned values of that block.
let mut function_returns = vec![];

while let Some(source_block_id) = block_queue.pop() {
let translated_block_id = self.translate_block(source_block_id, &mut block_queue);
self.context.builder.switch_to_block(translated_block_id);

seen_blocks.insert(source_block_id);
self.inline_block(ssa, source_block_id);
self.handle_terminator_instruction(source_block_id, &mut block_queue);

if let Some((block, values)) =
self.handle_terminator_instruction(source_block_id, &mut block_queue)
{
function_returns.push((block, values));
}
}

self.context.builder.switch_to_block(self.return_destination);
self.handle_function_returns(function_returns)
}

/// Handle inlining a function's possibly multiple return instructions.
/// If there is only 1 return we can just continue inserting into that block.
/// If there are multiple, we'll need to create a join block to jump to with each value.
fn handle_function_returns(
&mut self,
mut returns: Vec<(BasicBlockId, Vec<ValueId>)>,
) -> Vec<ValueId> {
// Clippy complains if this were written as an if statement
match returns.len() {
1 => {
let (return_block, return_values) = returns.remove(0);
self.context.builder.switch_to_block(return_block);
return_values
}
n if n > 1 => {
// If there is more than 1 return instruction we'll need to create a single block we
// can return to and continue inserting in afterwards.
let return_block = self.context.builder.insert_block();

for (block, return_values) in returns {
self.context.builder.switch_to_block(block);
self.context.builder.terminate_with_jmp(return_block, return_values);
}

self.context.builder.switch_to_block(return_block);
self.context.builder.block_parameters(return_block).to_vec()
}
_ => unreachable!("Inlined function had no return values"),
}
}

/// Inline each instruction in the given block into the function being inlined into.
Expand Down Expand Up @@ -307,7 +343,7 @@ impl<'function> PerFunctionContext<'function> {
let old_results = self.source_function.dfg.instruction_results(call_id);
let arguments = vecmap(arguments, |arg| self.translate_value(*arg));
let new_results = self.context.inline_function(ssa, function, &arguments);
Self::insert_new_instruction_results(&mut self.values, old_results, new_results);
Self::insert_new_instruction_results(&mut self.values, old_results, &new_results);
}

/// Push the given instruction from the source_function into the current block of the
Expand Down Expand Up @@ -340,16 +376,20 @@ impl<'function> PerFunctionContext<'function> {
/// Handle the given terminator instruction from the given source function block.
/// This will push any new blocks to the destination function as needed, add them
/// to the block queue, and set the terminator instruction for the current block.
///
/// If the terminator instruction was a Return, this will return the block this instruction
/// was in as well as the values that were returned.
fn handle_terminator_instruction(
&mut self,
block_id: BasicBlockId,
block_queue: &mut Vec<BasicBlockId>,
) {
) -> Option<(BasicBlockId, Vec<ValueId>)> {
match self.source_function.dfg[block_id].terminator() {
Some(TerminatorInstruction::Jmp { destination, arguments }) => {
let destination = self.translate_block(*destination, block_queue);
let arguments = vecmap(arguments, |arg| self.translate_value(*arg));
self.context.builder.terminate_with_jmp(destination, arguments);
None
}
Some(TerminatorInstruction::JmpIf {
condition,
Expand All @@ -360,21 +400,15 @@ impl<'function> PerFunctionContext<'function> {
let then_block = self.translate_block(*then_destination, block_queue);
let else_block = self.translate_block(*else_destination, block_queue);
self.context.builder.terminate_with_jmpif(condition, then_block, else_block);
None
}
Some(TerminatorInstruction::Return { return_values }) => {
let return_values = vecmap(return_values, |value| self.translate_value(*value));

if self.inlining_main {
self.context.builder.terminate_with_return(return_values);
} else {
for value in &return_values {
// Add the block parameters for the return block here since we don't do
// it when inserting the block in PerFunctionContext::new
let typ = self.context.builder.current_function.dfg.type_of_value(*value);
self.context.builder.add_block_parameter(self.return_destination, typ);
}
self.context.builder.terminate_with_jmp(self.return_destination, return_values);
self.context.builder.terminate_with_return(return_values.clone());
}
let block_id = self.translate_block(block_id, block_queue);
Some((block_id, return_values))
}
None => unreachable!("Block has no terminator instruction"),
}
Expand All @@ -384,7 +418,7 @@ impl<'function> PerFunctionContext<'function> {
#[cfg(test)]
mod test {
use crate::ssa_refactor::{
ir::{map::Id, types::Type},
ir::{instruction::BinaryOp, map::Id, types::Type},
ssa_builder::FunctionBuilder,
};

Expand Down Expand Up @@ -418,4 +452,70 @@ mod test {
let inlined = ssa.inline_functions();
assert_eq!(inlined.functions.len(), 1);
}

#[test]
fn complex_inlining() {
// This SSA is from issue #1327 which previously failed to inline properly
//
// fn main f0 {
// b0(v0: Field):
// v7 = call f2(f1)
// v13 = call f3(v7)
// v16 = call v13(v0)
// return v16
// }
// fn square f1 {
// b0(v0: Field):
// v2 = mul v0, v0
// return v2
// }
// fn id1 f2 {
// b0(v0: function):
// return v0
// }
// fn id2 f3 {
// b0(v0: function):
// return v0
// }
let main_id = Id::test_new(0);
let square_id = Id::test_new(1);
let id1_id = Id::test_new(2);
let id2_id = Id::test_new(3);

// Compiling main
let mut builder = FunctionBuilder::new("main".into(), main_id);
let main_v0 = builder.add_parameter(Type::field());

let main_f1 = builder.import_function(square_id);
let main_f2 = builder.import_function(id1_id);
let main_f3 = builder.import_function(id2_id);

let main_v7 = builder.insert_call(main_f2, vec![main_f1], vec![Type::Function])[0];
let main_v13 = builder.insert_call(main_f3, vec![main_v7], vec![Type::Function])[0];
let main_v16 = builder.insert_call(main_v13, vec![main_v0], vec![Type::field()])[0];
builder.terminate_with_return(vec![main_v16]);

// Compiling square f1
builder.new_function("square".into(), square_id);
let square_v0 = builder.add_parameter(Type::field());
let square_v2 = builder.insert_binary(square_v0, BinaryOp::Mul, square_v0);
builder.terminate_with_return(vec![square_v2]);

// Compiling id1 f2
builder.new_function("id1".into(), id1_id);
let id1_v0 = builder.add_parameter(Type::Function);
builder.terminate_with_return(vec![id1_v0]);

// Compiling id2 f3
builder.new_function("id2".into(), id2_id);
let id2_v0 = builder.add_parameter(Type::Function);
builder.terminate_with_return(vec![id2_v0]);

// Done, now we test that we can successfully inline all functions.
let ssa = builder.finish();
assert_eq!(ssa.functions.len(), 4);

let inlined = ssa.inline_functions();
assert_eq!(inlined.functions.len(), 1);
}
}

0 comments on commit 7df3bb1

Please sign in to comment.