From 7246462294d8279055278f03b3d844d1d1e16b30 Mon Sep 17 00:00:00 2001 From: Ary Borenszweig Date: Mon, 18 Nov 2024 12:53:16 -0300 Subject: [PATCH 01/12] WIP pass to inline const brillig calls --- compiler/noirc_evaluator/src/ssa.rs | 2 + .../src/ssa/opt/inline_const_brillig_calls.rs | 113 ++++++++++++++++++ compiler/noirc_evaluator/src/ssa/opt/mod.rs | 1 + 3 files changed, 116 insertions(+) create mode 100644 compiler/noirc_evaluator/src/ssa/opt/inline_const_brillig_calls.rs diff --git a/compiler/noirc_evaluator/src/ssa.rs b/compiler/noirc_evaluator/src/ssa.rs index 6d6cb11511f..1964cd37d3e 100644 --- a/compiler/noirc_evaluator/src/ssa.rs +++ b/compiler/noirc_evaluator/src/ssa.rs @@ -141,6 +141,8 @@ pub(crate) fn optimize_into_acir( ssa.to_brillig(options.enable_brillig_logging) }); + let ssa = ssa.inline_const_brillig_calls(&brillig); + let artifacts = time("SSA to ACIR", options.print_codegen_timings, || { ssa.into_acir(&brillig, options.expression_width) })?; diff --git a/compiler/noirc_evaluator/src/ssa/opt/inline_const_brillig_calls.rs b/compiler/noirc_evaluator/src/ssa/opt/inline_const_brillig_calls.rs new file mode 100644 index 00000000000..eec4a41afd6 --- /dev/null +++ b/compiler/noirc_evaluator/src/ssa/opt/inline_const_brillig_calls.rs @@ -0,0 +1,113 @@ +use std::collections::HashSet; + +use acvm::{blackbox_solver::StubbedBlackBoxSolver, brillig_vm::VMStatus}; + +use crate::{ + brillig::{brillig_ir::artifact::Label, Brillig}, + ssa::{ + ir::{ + function::{Function, FunctionId}, + instruction::{Instruction, InstructionId}, + value::{Value, ValueId}, + }, + Ssa, + }, +}; + +impl Ssa { + pub(crate) fn inline_const_brillig_calls(mut self, brillig: &Brillig) -> Self { + // Keep track of which brillig functions we couldn't completely inline: we'll remove the ones we could. + let mut brillig_functions_we_could_not_inline = HashSet::new(); + + for func in self.functions.values_mut() { + func.inline_const_brillig_calls(&brillig, &mut brillig_functions_we_could_not_inline); + } + + self + } +} + +/// Result of trying to optimize an instruction (any instruction) in this pass. +enum OptimizeResult { + /// Nothing was done because the instruction wasn't a call to a brillig function, + /// or some arguments to it were not constants. + NotABrilligCall, + /// The instruction was a call to a brillig function, but we couldn't optimize it. + CannotOptimize(FunctionId), + /// The instruction was a call to a brillig function and we were able to optimize it, + /// returning the optimized function and the constant values it returned. + Optimized(Function, Vec), +} + +impl Function { + pub(crate) fn inline_const_brillig_calls( + &mut self, + brillig: &Brillig, + brillig_functions_we_could_not_inline: &mut HashSet, + ) { + for block_id in self.reachable_blocks() { + for instruction_id in self.dfg[block_id].take_instructions() { + let optimize_result = self.optimize_const_brillig_call( + instruction_id, + brillig, + brillig_functions_we_could_not_inline, + ); + match optimize_result { + OptimizeResult::NotABrilligCall => { + self.dfg[block_id].instructions_mut().push(instruction_id); + } + OptimizeResult::CannotOptimize(func_id) => { + self.dfg[block_id].instructions_mut().push(instruction_id); + brillig_functions_we_could_not_inline.insert(func_id); + } + OptimizeResult::Optimized(function, return_values) => { + // Replace the instruction results with the constant values we got + // let current_results = self.dfg.instruction_results(instruction_id).to_vec(); + // assert_eq!(return_values.len(), current_results.len()); + + // for (current_result_id, return_value_id) in + // current_results.iter().zip(return_values) + // { + // let new_return_value_id = + // function.copy_constant_to_function(return_value_id, self); + // self.dfg.set_value_from_id(*current_result_id, new_return_value_id); + // } + } + } + } + } + } + + /// Tries to optimize an instruction if it's a call that points to a brillig function, + /// and all its arguments are constant. + fn optimize_const_brillig_call( + &mut self, + instruction_id: InstructionId, + brillig: &Brillig, + brillig_functions_we_could_not_inline: &mut HashSet, + ) -> OptimizeResult { + let instruction = &self.dfg[instruction_id]; + let Instruction::Call { func: func_id, arguments } = instruction else { + return OptimizeResult::NotABrilligCall; + }; + + let func_value = &self.dfg[*func_id]; + let Value::Function(func_id) = func_value else { + return OptimizeResult::NotABrilligCall; + }; + let func_id = *func_id; + dbg!(func_id); + + let Some(brillig_artifact) = brillig.find_by_label(Label::function(func_id)) else { + return OptimizeResult::NotABrilligCall; + }; + + if !arguments.iter().all(|argument| self.dfg.is_constant(*argument)) { + return OptimizeResult::CannotOptimize(func_id); + } + + // TODO... + + OptimizeResult::CannotOptimize(func_id) + } +} diff --git a/compiler/noirc_evaluator/src/ssa/opt/mod.rs b/compiler/noirc_evaluator/src/ssa/opt/mod.rs index 098f62bceba..c3160565b23 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/mod.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/mod.rs @@ -11,6 +11,7 @@ mod constant_folding; mod defunctionalize; mod die; pub(crate) mod flatten_cfg; +mod inline_const_brillig_calls; mod inlining; mod mem2reg; mod normalize_value_ids; From b2e6e107d7ef9457c006717838f214ffce18ca11 Mon Sep 17 00:00:00 2001 From: Ary Borenszweig Date: Mon, 18 Nov 2024 14:05:51 -0300 Subject: [PATCH 02/12] Turn the `gen_brillig_for` method into a function --- compiler/noirc_evaluator/src/acir/mod.rs | 52 ++---------------- .../src/brillig/brillig_gen.rs | 54 +++++++++++++++++-- 2 files changed, 54 insertions(+), 52 deletions(-) diff --git a/compiler/noirc_evaluator/src/acir/mod.rs b/compiler/noirc_evaluator/src/acir/mod.rs index 16b07b40863..e030302e60f 100644 --- a/compiler/noirc_evaluator/src/acir/mod.rs +++ b/compiler/noirc_evaluator/src/acir/mod.rs @@ -24,12 +24,10 @@ mod big_int; mod brillig_directive; mod generated_acir; +use crate::brillig::brillig_gen::gen_brillig_for; use crate::brillig::{ brillig_gen::brillig_fn::FunctionContext as BrilligFunctionContext, - brillig_ir::{ - artifact::{BrilligParameter, GeneratedBrillig}, - BrilligContext, - }, + brillig_ir::artifact::{BrilligParameter, GeneratedBrillig}, Brillig, }; use crate::errors::{InternalError, InternalWarning, RuntimeError, SsaReport}; @@ -516,7 +514,7 @@ impl<'a> Context<'a> { let outputs: Vec = vecmap(main_func.returns(), |result_id| dfg.type_of_value(*result_id).into()); - let code = self.gen_brillig_for(main_func, arguments.clone(), brillig)?; + let code = gen_brillig_for(main_func, arguments.clone(), brillig)?; // We specifically do not attempt execution of the brillig code being generated as this can result in it being // replaced with constraints on witnesses to the program outputs. @@ -873,8 +871,7 @@ impl<'a> Context<'a> { None, )? } else { - let code = - self.gen_brillig_for(func, arguments.clone(), brillig)?; + let code = gen_brillig_for(func, arguments.clone(), brillig)?; let generated_pointer = self.shared_context.new_generated_pointer(); let output_values = self.acir_context.brillig_call( @@ -994,47 +991,6 @@ impl<'a> Context<'a> { .collect() } - fn gen_brillig_for( - &self, - func: &Function, - arguments: Vec, - brillig: &Brillig, - ) -> Result, InternalError> { - // Create the entry point artifact - let mut entry_point = BrilligContext::new_entry_point_artifact( - arguments, - BrilligFunctionContext::return_values(func), - func.id(), - ); - entry_point.name = func.name().to_string(); - - // Link the entry point with all dependencies - while let Some(unresolved_fn_label) = entry_point.first_unresolved_function_call() { - let artifact = &brillig.find_by_label(unresolved_fn_label); - let artifact = match artifact { - Some(artifact) => artifact, - None => { - return Err(InternalError::General { - message: format!("Cannot find linked fn {unresolved_fn_label}"), - call_stack: CallStack::new(), - }) - } - }; - entry_point.link_with(artifact); - // Insert the range of opcode locations occupied by a procedure - if let Some(procedure_id) = artifact.procedure { - let num_opcodes = entry_point.byte_code.len(); - let previous_num_opcodes = entry_point.byte_code.len() - artifact.byte_code.len(); - // We subtract one as to keep the range inclusive on both ends - entry_point - .procedure_locations - .insert(procedure_id, (previous_num_opcodes, num_opcodes - 1)); - } - } - // Generate the final bytecode - Ok(entry_point.finish()) - } - /// Handles an ArrayGet or ArraySet instruction. /// To set an index of the array (and create a new array in doing so), pass Some(value) for /// store_value. To just retrieve an index of the array, pass None for store_value. diff --git a/compiler/noirc_evaluator/src/brillig/brillig_gen.rs b/compiler/noirc_evaluator/src/brillig/brillig_gen.rs index 786a03031d6..f158b5912e8 100644 --- a/compiler/noirc_evaluator/src/brillig/brillig_gen.rs +++ b/compiler/noirc_evaluator/src/brillig/brillig_gen.rs @@ -9,11 +9,17 @@ mod variable_liveness; use acvm::FieldElement; use self::{brillig_block::BrilligBlock, brillig_fn::FunctionContext}; -use super::brillig_ir::{ - artifact::{BrilligArtifact, Label}, - BrilligContext, +use super::{ + brillig_ir::{ + artifact::{BrilligArtifact, BrilligParameter, GeneratedBrillig, Label}, + BrilligContext, + }, + Brillig, +}; +use crate::{ + errors::InternalError, + ssa::ir::{dfg::CallStack, function::Function}, }; -use crate::ssa::ir::function::Function; /// Converting an SSA function into Brillig bytecode. pub(crate) fn convert_ssa_function( @@ -36,3 +42,43 @@ pub(crate) fn convert_ssa_function( artifact.name = func.name().to_string(); artifact } + +pub(crate) fn gen_brillig_for( + func: &Function, + arguments: Vec, + brillig: &Brillig, +) -> Result, InternalError> { + // Create the entry point artifact + let mut entry_point = BrilligContext::new_entry_point_artifact( + arguments, + FunctionContext::return_values(func), + func.id(), + ); + entry_point.name = func.name().to_string(); + + // Link the entry point with all dependencies + while let Some(unresolved_fn_label) = entry_point.first_unresolved_function_call() { + let artifact = &brillig.find_by_label(unresolved_fn_label); + let artifact = match artifact { + Some(artifact) => artifact, + None => { + return Err(InternalError::General { + message: format!("Cannot find linked fn {unresolved_fn_label}"), + call_stack: CallStack::new(), + }) + } + }; + entry_point.link_with(artifact); + // Insert the range of opcode locations occupied by a procedure + if let Some(procedure_id) = artifact.procedure { + let num_opcodes = entry_point.byte_code.len(); + let previous_num_opcodes = entry_point.byte_code.len() - artifact.byte_code.len(); + // We subtract one as to keep the range inclusive on both ends + entry_point + .procedure_locations + .insert(procedure_id, (previous_num_opcodes, num_opcodes - 1)); + } + } + // Generate the final bytecode + Ok(entry_point.finish()) +} From c2786eac5041afbac55b6ba39c51fc376e27f297 Mon Sep 17 00:00:00 2001 From: Ary Borenszweig Date: Mon, 18 Nov 2024 19:48:52 -0300 Subject: [PATCH 03/12] Inline brillig calls with all constants by executing them with the VM --- .../src/brillig/brillig_gen/brillig_fn.rs | 19 + .../src/ssa/opt/inline_const_brillig_calls.rs | 440 ++++++++++++++++-- 2 files changed, 414 insertions(+), 45 deletions(-) diff --git a/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_fn.rs b/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_fn.rs index 2779be103cd..2ea4accc2c5 100644 --- a/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_fn.rs +++ b/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_fn.rs @@ -68,6 +68,25 @@ impl FunctionContext { } } + pub(crate) fn try_ssa_type_to_parameter(typ: &Type) -> Option { + match typ { + Type::Numeric(_) | Type::Reference(_) => { + Some(BrilligParameter::SingleAddr(get_bit_size_from_ssa_type(typ))) + } + Type::Array(item_type, size) => { + let mut parameters = Vec::new(); + for item_typ in item_type.iter() { + let Some(param) = FunctionContext::try_ssa_type_to_parameter(item_typ) else { + return None; + }; + parameters.push(param); + } + Some(BrilligParameter::Array(parameters, *size)) + } + _ => None, + } + } + /// Collects the return values of a given function pub(crate) fn return_values(func: &Function) -> Vec { func.returns() diff --git a/compiler/noirc_evaluator/src/ssa/opt/inline_const_brillig_calls.rs b/compiler/noirc_evaluator/src/ssa/opt/inline_const_brillig_calls.rs index eec4a41afd6..9812f2cd054 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/inline_const_brillig_calls.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/inline_const_brillig_calls.rs @@ -1,13 +1,25 @@ -use std::collections::HashSet; +use std::collections::{BTreeMap, HashSet}; -use acvm::{blackbox_solver::StubbedBlackBoxSolver, brillig_vm::VMStatus}; +use acvm::{ + blackbox_solver::StubbedBlackBoxSolver, + brillig_vm::{MemoryValue, VMStatus, VM}, + FieldElement, +}; + +use im::Vector; use crate::{ - brillig::{brillig_ir::artifact::Label, Brillig}, + brillig::{ + brillig_gen::{brillig_fn::FunctionContext, gen_brillig_for}, + Brillig, + }, ssa::{ ir::{ - function::{Function, FunctionId}, + basic_block::BasicBlockId, + dfg::DataFlowGraph, + function::{Function, FunctionId, RuntimeType}, instruction::{Instruction, InstructionId}, + types::Type, value::{Value, ValueId}, }, Ssa, @@ -16,98 +28,436 @@ use crate::{ impl Ssa { pub(crate) fn inline_const_brillig_calls(mut self, brillig: &Brillig) -> Self { + // Collect all brillig functions so that later we can find them when processing a call instruction + let mut brillig_functions: BTreeMap = BTreeMap::new(); + for (func_id, func) in &self.functions { + if let RuntimeType::Brillig(..) = func.runtime() { + let cloned_function = Function::clone_with_id(*func_id, func); + brillig_functions.insert(*func_id, cloned_function); + }; + } + // Keep track of which brillig functions we couldn't completely inline: we'll remove the ones we could. let mut brillig_functions_we_could_not_inline = HashSet::new(); for func in self.functions.values_mut() { - func.inline_const_brillig_calls(&brillig, &mut brillig_functions_we_could_not_inline); + func.inline_const_brillig_calls( + brillig, + &brillig_functions, + &mut brillig_functions_we_could_not_inline, + ); + } + + // Remove the brillig functions that are no longer called + for func_id in brillig_functions.keys() { + // We never want to remove the main function (it could be `unconstrained` or it + // could have been turned into brillig if `--force-brillig` was given) + if self.main_id == *func_id { + continue; + } + + if brillig_functions_we_could_not_inline.contains(func_id) { + continue; + } + + // We also don't want to remove entry points + if self.entry_point_to_generated_index.contains_key(func_id) { + continue; + } + + self.functions.remove(func_id); } self } } -/// Result of trying to optimize an instruction (any instruction) in this pass. -enum OptimizeResult { +/// Result of trying to evaluate an instruction (any instruction) in this pass. +enum EvaluationResult { /// Nothing was done because the instruction wasn't a call to a brillig function, /// or some arguments to it were not constants. NotABrilligCall, - /// The instruction was a call to a brillig function, but we couldn't optimize it. - CannotOptimize(FunctionId), - /// The instruction was a call to a brillig function and we were able to optimize it, - /// returning the optimized function and the constant values it returned. - Optimized(Function, Vec), + /// The instruction was a call to a brillig function, but we couldn't evaluate it. + CannotEvaluate(FunctionId), + /// The instruction was a call to a brillig function and we were able to evaluate it, + /// returning evaluation memory values. + Evaluated(Vec>), } impl Function { pub(crate) fn inline_const_brillig_calls( &mut self, brillig: &Brillig, + brillig_functions: &BTreeMap, brillig_functions_we_could_not_inline: &mut HashSet, ) { for block_id in self.reachable_blocks() { for instruction_id in self.dfg[block_id].take_instructions() { - let optimize_result = self.optimize_const_brillig_call( - instruction_id, - brillig, - brillig_functions_we_could_not_inline, - ); - match optimize_result { - OptimizeResult::NotABrilligCall => { + let evaluation_result = + self.evaluate_const_brillig_call(instruction_id, brillig, brillig_functions); + match evaluation_result { + EvaluationResult::NotABrilligCall => { self.dfg[block_id].instructions_mut().push(instruction_id); } - OptimizeResult::CannotOptimize(func_id) => { + EvaluationResult::CannotEvaluate(func_id) => { self.dfg[block_id].instructions_mut().push(instruction_id); brillig_functions_we_could_not_inline.insert(func_id); } - OptimizeResult::Optimized(function, return_values) => { + EvaluationResult::Evaluated(memory_values) => { // Replace the instruction results with the constant values we got - // let current_results = self.dfg.instruction_results(instruction_id).to_vec(); - // assert_eq!(return_values.len(), current_results.len()); - - // for (current_result_id, return_value_id) in - // current_results.iter().zip(return_values) - // { - // let new_return_value_id = - // function.copy_constant_to_function(return_value_id, self); - // self.dfg.set_value_from_id(*current_result_id, new_return_value_id); - // } + let result_ids = self.dfg.instruction_results(instruction_id).to_vec(); + + let mut memory_index = 0; + for result_id in result_ids { + self.replace_result_id_with_memory_value( + result_id, + block_id, + &memory_values, + &mut memory_index, + ); + } } } } } } - /// Tries to optimize an instruction if it's a call that points to a brillig function, - /// and all its arguments are constant. - fn optimize_const_brillig_call( + /// Replaces `result_id` by taking memory values from `memory_values` starting at `memory_index` + /// depending on the type of the ValueId (it will read multiple memory values if it's an array). + fn replace_result_id_with_memory_value( + &mut self, + result_id: ValueId, + block_id: BasicBlockId, + memory_values: &[MemoryValue], + memory_index: &mut usize, + ) { + let typ = self.dfg.type_of_value(result_id); + let new_value = + self.new_value_for_type_and_memory_values(typ, block_id, memory_values, memory_index); + self.dfg.set_value_from_id(result_id, new_value); + } + + /// Creates a new value inside this function by reading it from `memory_values` starting at + /// `memory_index` depending on the given Type: if it's an array multiple values will be read + /// and a new `make_array` instruction will be created. + fn new_value_for_type_and_memory_values( &mut self, + typ: Type, + block_id: BasicBlockId, + memory_values: &[MemoryValue], + memory_index: &mut usize, + ) -> ValueId { + match typ { + Type::Numeric(_) => { + let memory = memory_values[*memory_index]; + *memory_index += 1; + + let field_value = match memory { + MemoryValue::Field(field_value) => field_value, + MemoryValue::Integer(u128_value, _) => u128_value.into(), + }; + self.dfg.make_constant(field_value, typ) + } + Type::Array(types, length) => { + let mut new_array_values = Vector::new(); + for _ in 0..length { + for typ in types.iter() { + let new_value = self.new_value_for_type_and_memory_values( + typ.clone(), + block_id, + memory_values, + memory_index, + ); + new_array_values.push_back(new_value); + } + } + + let instruction = Instruction::MakeArray { + elements: new_array_values, + typ: Type::Array(types, length), + }; + let instruction_id = self.dfg.make_instruction(instruction, None); + self.dfg[block_id].instructions_mut().push(instruction_id); + *self.dfg.instruction_results(instruction_id).first().unwrap() + } + Type::Reference(_) => { + panic!("Unexpected reference type in brillig function result") + } + Type::Slice(_) => { + panic!("Unexpected slice type in brillig function result") + } + Type::Function => { + panic!("Unexpected function type in brillig function result") + } + } + } + + /// Tries to evaluate an instruction if it's a call that points to a brillig function, + /// and all its arguments are constant. + /// We do this by directly executing the function with a brillig VM. + fn evaluate_const_brillig_call( + &self, instruction_id: InstructionId, brillig: &Brillig, - brillig_functions_we_could_not_inline: &mut HashSet, - ) -> OptimizeResult { + brillig_functions: &BTreeMap, + ) -> EvaluationResult { let instruction = &self.dfg[instruction_id]; let Instruction::Call { func: func_id, arguments } = instruction else { - return OptimizeResult::NotABrilligCall; + return EvaluationResult::NotABrilligCall; }; let func_value = &self.dfg[*func_id]; let Value::Function(func_id) = func_value else { - return OptimizeResult::NotABrilligCall; + return EvaluationResult::NotABrilligCall; }; - let func_id = *func_id; - dbg!(func_id); - let Some(brillig_artifact) = brillig.find_by_label(Label::function(func_id)) else { - return OptimizeResult::NotABrilligCall; + let Some(func) = brillig_functions.get(func_id) else { + return EvaluationResult::NotABrilligCall; }; if !arguments.iter().all(|argument| self.dfg.is_constant(*argument)) { - return OptimizeResult::CannotOptimize(func_id); + return EvaluationResult::CannotEvaluate(*func_id); + } + + let mut brillig_arguments = Vec::new(); + for argument in arguments { + let typ = self.dfg.type_of_value(*argument); + let Some(parameter) = FunctionContext::try_ssa_type_to_parameter(&typ) else { + return EvaluationResult::CannotEvaluate(*func_id); + }; + brillig_arguments.push(parameter); + } + + // Check that return value types are supported by brillig + for return_id in func.returns().iter() { + let typ = func.dfg.type_of_value(*return_id); + if FunctionContext::try_ssa_type_to_parameter(&typ).is_none() { + return EvaluationResult::CannotEvaluate(*func_id); + } + } + + let Ok(generated_brillig) = gen_brillig_for(func, brillig_arguments, brillig) else { + return EvaluationResult::CannotEvaluate(*func_id); + }; + + let mut calldata = Vec::new(); + for argument in arguments { + value_id_to_calldata(*argument, &self.dfg, &mut calldata); + } + + let bytecode = &generated_brillig.byte_code; + let foreign_call_results = Vec::new(); + let black_box_solver = StubbedBlackBoxSolver; + let profiling_active = false; + let mut vm = + VM::new(calldata, bytecode, foreign_call_results, &black_box_solver, profiling_active); + let vm_status: VMStatus<_> = vm.process_opcodes(); + let VMStatus::Finished { return_data_offset, return_data_size } = vm_status else { + return EvaluationResult::CannotEvaluate(*func_id); + }; + + let memory = + vm.get_memory()[return_data_offset..(return_data_offset + return_data_size)].to_vec(); + + EvaluationResult::Evaluated(memory) + } +} + +fn value_id_to_calldata(value_id: ValueId, dfg: &DataFlowGraph, calldata: &mut Vec) { + if let Some(value) = dfg.get_numeric_constant(value_id) { + calldata.push(value); + return; + } + + if let Some((values, _type)) = dfg.get_array_constant(value_id) { + for value in values { + value_id_to_calldata(value, dfg, calldata); } + return; + } + + panic!("Expected ValueId to be numeric constant or array constant"); +} + +#[cfg(test)] +mod test { + use crate::ssa::opt::assert_normalized_ssa_equals; + + use super::Ssa; + + #[test] + fn inlines_brillig_call_without_arguments() { + let src = " + acir(inline) fn main f0 { + b0(): + v0 = call f1() -> Field + return v0 + } + + brillig(inline) fn one f1 { + b0(): + v0 = add Field 2, Field 3 + return v0 + } + "; + let ssa = Ssa::from_str(src).unwrap(); + let brillig = ssa.to_brillig(false); + + let expected = " + acir(inline) fn main f0 { + b0(): + return Field 5 + } + "; + let ssa = ssa.inline_const_brillig_calls(&brillig); + assert_normalized_ssa_equals(ssa, expected); + } + + #[test] + fn inlines_brillig_call_with_two_field_arguments() { + let src = " + acir(inline) fn main f0 { + b0(): + v0 = call f1(Field 2, Field 3) -> Field + return v0 + } + + brillig(inline) fn one f1 { + b0(v0: Field, v1: Field): + v2 = add v0, v1 + return v2 + } + "; + let ssa = Ssa::from_str(src).unwrap(); + let brillig = ssa.to_brillig(false); + + let expected = " + acir(inline) fn main f0 { + b0(): + return Field 5 + } + "; + let ssa = ssa.inline_const_brillig_calls(&brillig); + assert_normalized_ssa_equals(ssa, expected); + } + + #[test] + fn inlines_brillig_call_with_two_i32_arguments() { + let src = " + acir(inline) fn main f0 { + b0(): + v0 = call f1(i32 2, i32 3) -> i32 + return v0 + } + + brillig(inline) fn one f1 { + b0(v0: i32, v1: i32): + v2 = add v0, v1 + return v2 + } + "; + let ssa = Ssa::from_str(src).unwrap(); + let brillig = ssa.to_brillig(false); + + let expected = " + acir(inline) fn main f0 { + b0(): + return i32 5 + } + "; + let ssa = ssa.inline_const_brillig_calls(&brillig); + assert_normalized_ssa_equals(ssa, expected); + } + + #[test] + fn inlines_brillig_call_with_array_return() { + let src = " + acir(inline) fn main f0 { + b0(): + v0 = call f1(Field 2, Field 3, Field 4) -> [Field; 3] + return v0 + } + + brillig(inline) fn one f1 { + b0(v0: Field, v1: Field, v2: Field): + v3 = make_array [v0, v1, v2] : [Field; 3] + return v3 + } + "; + let ssa = Ssa::from_str(src).unwrap(); + let brillig = ssa.to_brillig(false); + + let expected = " + acir(inline) fn main f0 { + b0(): + v3 = make_array [Field 2, Field 3, Field 4] : [Field; 3] + return v3 + } + "; + let ssa = ssa.inline_const_brillig_calls(&brillig); + assert_normalized_ssa_equals(ssa, expected); + } + + #[test] + fn inlines_brillig_call_with_composite_array_return() { + let src = " + acir(inline) fn main f0 { + b0(): + v0 = call f1(Field 2, i32 3, Field 4, i32 5) -> [(Field, i32); 2] + return v0 + } + + brillig(inline) fn one f1 { + b0(v0: Field, v1: i32, v2: i32, v3: Field): + v4 = make_array [v0, v1, v2, v3] : [(Field, i32); 2] + return v4 + } + "; + let ssa = Ssa::from_str(src).unwrap(); + let brillig = ssa.to_brillig(false); + + let expected = " + acir(inline) fn main f0 { + b0(): + v4 = make_array [Field 2, i32 3, Field 4, i32 5] : [(Field, i32); 2] + return v4 + } + "; + let ssa = ssa.inline_const_brillig_calls(&brillig); + assert_normalized_ssa_equals(ssa, expected); + } + + #[test] + fn inlines_brillig_call_with_array_arguments() { + let src = " + acir(inline) fn main f0 { + b0(): + v0 = make_array [Field 2, Field 3] : [Field; 2] + v1 = call f1(v0) -> Field + return v1 + } - // TODO... + brillig(inline) fn one f1 { + b0(v0: [Field; 2]): + inc_rc v0 + v2 = array_get v0, index u32 0 -> Field + v4 = array_get v0, index u32 1 -> Field + v5 = add v2, v4 + dec_rc v0 + return v5 + } + "; + let ssa = Ssa::from_str(src).unwrap(); + let brillig = ssa.to_brillig(false); - OptimizeResult::CannotOptimize(func_id) + let expected = " + acir(inline) fn main f0 { + b0(): + v2 = make_array [Field 2, Field 3] : [Field; 2] + return Field 5 + } + "; + let ssa = ssa.inline_const_brillig_calls(&brillig); + assert_normalized_ssa_equals(ssa, expected); } } From d2489bff4e0fe4780f0e94235dfe4eb215f5b84f Mon Sep 17 00:00:00 2001 From: Ary Borenszweig Date: Tue, 19 Nov 2024 15:29:12 -0300 Subject: [PATCH 04/12] Move logic to constant folding --- compiler/noirc_evaluator/src/ssa.rs | 17 +- .../src/ssa/opt/constant_folding.rs | 493 +++++++++++++++++- .../src/ssa/opt/inline_const_brillig_calls.rs | 463 ---------------- compiler/noirc_evaluator/src/ssa/opt/mod.rs | 1 - 4 files changed, 487 insertions(+), 487 deletions(-) delete mode 100644 compiler/noirc_evaluator/src/ssa/opt/inline_const_brillig_calls.rs diff --git a/compiler/noirc_evaluator/src/ssa.rs b/compiler/noirc_evaluator/src/ssa.rs index 1964cd37d3e..344ac114a03 100644 --- a/compiler/noirc_evaluator/src/ssa.rs +++ b/compiler/noirc_evaluator/src/ssa.rs @@ -141,7 +141,22 @@ pub(crate) fn optimize_into_acir( ssa.to_brillig(options.enable_brillig_logging) }); - let ssa = ssa.inline_const_brillig_calls(&brillig); + let ssa_gen_span = span!(Level::TRACE, "ssa_generation"); + let ssa_gen_span_guard = ssa_gen_span.enter(); + + let ssa = SsaBuilder { + ssa, + print_ssa_passes: options.enable_ssa_logging, + print_codegen_timings: options.print_codegen_timings, + } + .run_pass( + |ssa| ssa.fold_constants_with_brillig(&brillig), + "After Constant Folding with Brillig:", + ) + .run_pass(Ssa::dead_instruction_elimination, "After Dead Instruction Elimination:") + .finish(); + + drop(ssa_gen_span_guard); let artifacts = time("SSA to ACIR", options.print_codegen_timings, || { ssa.into_acir(&brillig, options.expression_width) diff --git a/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs b/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs index 32f66e5a0f0..b7ab3c65b49 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs @@ -19,22 +19,34 @@ //! //! This is the only pass which removes duplicated pure [`Instruction`]s however and so is needed when //! different blocks are merged, i.e. after the [`flatten_cfg`][super::flatten_cfg] pass. -use std::collections::{HashSet, VecDeque}; +use std::collections::{BTreeMap, HashSet, VecDeque}; -use acvm::{acir::AcirField, FieldElement}; +use acvm::{ + acir::AcirField, + brillig_vm::{MemoryValue, VMStatus, VM}, + FieldElement, +}; +use bn254_blackbox_solver::Bn254BlackBoxSolver; +use im::Vector; use iter_extended::vecmap; -use crate::ssa::{ - ir::{ - basic_block::BasicBlockId, - dfg::{DataFlowGraph, InsertInstructionResult}, - dom::DominatorTree, - function::Function, - instruction::{Instruction, InstructionId}, - types::Type, - value::{Value, ValueId}, +use crate::{ + brillig::{ + brillig_gen::{brillig_fn::FunctionContext, gen_brillig_for}, + Brillig, + }, + ssa::{ + ir::{ + basic_block::BasicBlockId, + dfg::{DataFlowGraph, InsertInstructionResult}, + dom::DominatorTree, + function::{Function, FunctionId, RuntimeType}, + instruction::{Instruction, InstructionId}, + types::Type, + value::{Value, ValueId}, + }, + ssa_gen::Ssa, }, - ssa_gen::Ssa, }; use fxhash::FxHashMap as HashMap; @@ -44,8 +56,10 @@ impl Ssa { /// See [`constant_folding`][self] module for more information. #[tracing::instrument(level = "trace", skip(self))] pub(crate) fn fold_constants(mut self) -> Ssa { + let mut brillig_functions_we_could_not_inline = HashSet::new(); + for function in self.functions.values_mut() { - function.constant_fold(false); + function.constant_fold(false, None, &mut brillig_functions_we_could_not_inline); } self } @@ -57,9 +71,49 @@ impl Ssa { /// See [`constant_folding`][self] module for more information. #[tracing::instrument(level = "trace", skip(self))] pub(crate) fn fold_constants_using_constraints(mut self) -> Ssa { + let mut brillig_functions_we_could_not_inline = HashSet::new(); + + for function in self.functions.values_mut() { + function.constant_fold(true, None, &mut brillig_functions_we_could_not_inline); + } + self + } + + #[tracing::instrument(level = "trace", skip(self, brillig))] + pub(crate) fn fold_constants_with_brillig(mut self, brillig: &Brillig) -> Ssa { + // Collect all brillig functions so that later we can find them when processing a call instruction + let mut brillig_functions: BTreeMap = BTreeMap::new(); + for (func_id, func) in &self.functions { + if let RuntimeType::Brillig(..) = func.runtime() { + let cloned_function = Function::clone_with_id(*func_id, func); + brillig_functions.insert(*func_id, cloned_function); + }; + } + + // Keep track of which brillig functions we couldn't completely inline: we'll remove the ones we could. + let mut brillig_functions_we_could_not_inline = HashSet::new(); + + let brillig_info = Some(BrilligInfo { brillig, brillig_functions: &brillig_functions }); + for function in self.functions.values_mut() { - function.constant_fold(true); + function.constant_fold(false, brillig_info, &mut brillig_functions_we_could_not_inline); + } + + // Remove the brillig functions that are no longer called + for func_id in brillig_functions.keys() { + // We never want to remove the main function (it could be `unconstrained` or it + // could have been turned into brillig if `--force-brillig` was given). + // We also don't want to remove entry points. + if self.main_id == *func_id + || brillig_functions_we_could_not_inline.contains(func_id) + || self.entry_point_to_generated_index.contains_key(func_id) + { + continue; + } + + self.functions.remove(func_id); } + self } } @@ -67,8 +121,13 @@ impl Ssa { impl Function { /// The structure of this pass is simple: /// Go through each block and re-insert all instructions. - pub(crate) fn constant_fold(&mut self, use_constraint_info: bool) { - let mut context = Context::new(self, use_constraint_info); + pub(crate) fn constant_fold( + &mut self, + use_constraint_info: bool, + brillig_info: Option, + brillig_functions_we_could_not_inline: &mut HashSet, + ) { + let mut context = Context::new(self, use_constraint_info, brillig_info); context.block_queue.push_back(self.entry_block()); while let Some(block) = context.block_queue.pop_front() { @@ -77,13 +136,14 @@ impl Function { } context.visited_blocks.insert(block); - context.fold_constants_in_block(self, block); + context.fold_constants_in_block(self, block, brillig_functions_we_could_not_inline); } } } -struct Context { +struct Context<'a> { use_constraint_info: bool, + brillig_info: Option>, /// Maps pre-folded ValueIds to the new ValueIds obtained by re-inserting the instruction. visited_blocks: HashSet, block_queue: VecDeque, @@ -103,6 +163,12 @@ struct Context { dom: DominatorTree, } +#[derive(Copy, Clone)] +pub(crate) struct BrilligInfo<'a> { + brillig: &'a Brillig, + brillig_functions: &'a BTreeMap, +} + /// HashMap from (Instruction, side_effects_enabled_var) to the results of the instruction. /// Stored as a two-level map to avoid cloning Instructions during the `.get` call. /// @@ -118,10 +184,15 @@ struct ResultCache { result: Option<(BasicBlockId, Vec)>, } -impl Context { - fn new(function: &Function, use_constraint_info: bool) -> Self { +impl<'brillig> Context<'brillig> { + fn new( + function: &Function, + use_constraint_info: bool, + brillig_info: Option>, + ) -> Self { Self { use_constraint_info, + brillig_info, visited_blocks: Default::default(), block_queue: Default::default(), constraint_simplification_mappings: Default::default(), @@ -130,7 +201,12 @@ impl Context { } } - fn fold_constants_in_block(&mut self, function: &mut Function, block: BasicBlockId) { + fn fold_constants_in_block( + &mut self, + function: &mut Function, + block: BasicBlockId, + brillig_functions_we_could_not_inline: &mut HashSet, + ) { let instructions = function.dfg[block].take_instructions(); let mut side_effects_enabled_var = @@ -142,6 +218,7 @@ impl Context { block, instruction_id, &mut side_effects_enabled_var, + brillig_functions_we_could_not_inline, ); } self.block_queue.extend(function.dfg[block].successors()); @@ -153,6 +230,7 @@ impl Context { mut block: BasicBlockId, id: InstructionId, side_effects_enabled_var: &mut ValueId, + brillig_functions_we_could_not_inline: &mut HashSet, ) { let constraint_simplification_mapping = self.get_constraint_map(*side_effects_enabled_var); let instruction = Self::resolve_instruction(id, dfg, constraint_simplification_mapping); @@ -178,7 +256,15 @@ impl Context { } // Otherwise, try inserting the instruction again to apply any optimizations using the newly resolved inputs. - let new_results = Self::push_instruction(id, instruction.clone(), &old_results, block, dfg); + let new_results = Self::push_instruction( + id, + instruction.clone(), + &old_results, + block, + dfg, + self.brillig_info, + brillig_functions_we_could_not_inline, + ); Self::replace_result_ids(dfg, &old_results, &new_results); @@ -237,7 +323,41 @@ impl Context { old_results: &[ValueId], block: BasicBlockId, dfg: &mut DataFlowGraph, + brillig_info: Option, + brillig_functions_we_could_not_inline: &mut HashSet, ) -> Vec { + // Check if this is a call to a brillig function with all constant arguments. + // If so, we can try to evaluate that function and replace the results with the evaluation results. + if let Some(brillig_info) = brillig_info { + let evaluation_result = Self::evaluate_const_brillig_call( + &instruction, + brillig_info.brillig, + brillig_info.brillig_functions, + dfg, + ); + + match evaluation_result { + EvaluationResult::NotABrilligCall => (), + EvaluationResult::CannotEvaluate(id) => { + brillig_functions_we_could_not_inline.insert(id); + } + EvaluationResult::Evaluated(memory_values) => { + let mut memory_index = 0; + let new_results = vecmap(old_results, |old_result| { + let typ = dfg.type_of_value(*old_result); + Self::new_value_for_type_and_memory_values( + typ, + block, + &memory_values, + &mut memory_index, + dfg, + ) + }); + return new_results; + } + } + } + let ctrl_typevars = instruction .requires_ctrl_typevars() .then(|| vecmap(old_results, |result| dfg.type_of_value(*result))); @@ -342,6 +462,131 @@ impl Context { results_for_instruction.get(&predicate)?.get(block, &mut self.dom) } + + /// Tries to evaluate an instruction if it's a call that points to a brillig function, + /// and all its arguments are constant. + /// We do this by directly executing the function with a brillig VM. + fn evaluate_const_brillig_call( + instruction: &Instruction, + brillig: &Brillig, + brillig_functions: &BTreeMap, + dfg: &mut DataFlowGraph, + ) -> EvaluationResult { + let Instruction::Call { func: func_id, arguments } = instruction else { + return EvaluationResult::NotABrilligCall; + }; + + let func_value = &dfg[*func_id]; + let Value::Function(func_id) = func_value else { + return EvaluationResult::NotABrilligCall; + }; + + let Some(func) = brillig_functions.get(func_id) else { + return EvaluationResult::NotABrilligCall; + }; + + if !arguments.iter().all(|argument| dfg.is_constant(*argument)) { + return EvaluationResult::CannotEvaluate(*func_id); + } + + let mut brillig_arguments = Vec::new(); + for argument in arguments { + let typ = dfg.type_of_value(*argument); + let Some(parameter) = FunctionContext::try_ssa_type_to_parameter(&typ) else { + return EvaluationResult::CannotEvaluate(*func_id); + }; + brillig_arguments.push(parameter); + } + + // Check that return value types are supported by brillig + for return_id in func.returns().iter() { + let typ = func.dfg.type_of_value(*return_id); + if FunctionContext::try_ssa_type_to_parameter(&typ).is_none() { + return EvaluationResult::CannotEvaluate(*func_id); + } + } + + let Ok(generated_brillig) = gen_brillig_for(func, brillig_arguments, brillig) else { + return EvaluationResult::CannotEvaluate(*func_id); + }; + + let mut calldata = Vec::new(); + for argument in arguments { + value_id_to_calldata(*argument, dfg, &mut calldata); + } + + let bytecode = &generated_brillig.byte_code; + let foreign_call_results = Vec::new(); + let black_box_solver = Bn254BlackBoxSolver; + let profiling_active = false; + let mut vm = + VM::new(calldata, bytecode, foreign_call_results, &black_box_solver, profiling_active); + let vm_status: VMStatus<_> = vm.process_opcodes(); + let VMStatus::Finished { return_data_offset, return_data_size } = vm_status else { + return EvaluationResult::CannotEvaluate(*func_id); + }; + + let memory = + vm.get_memory()[return_data_offset..(return_data_offset + return_data_size)].to_vec(); + + EvaluationResult::Evaluated(memory) + } + + /// Creates a new value inside this function by reading it from `memory_values` starting at + /// `memory_index` depending on the given Type: if it's an array multiple values will be read + /// and a new `make_array` instruction will be created. + fn new_value_for_type_and_memory_values( + typ: Type, + block_id: BasicBlockId, + memory_values: &[MemoryValue], + memory_index: &mut usize, + dfg: &mut DataFlowGraph, + ) -> ValueId { + match typ { + Type::Numeric(_) => { + let memory = memory_values[*memory_index]; + *memory_index += 1; + + let field_value = match memory { + MemoryValue::Field(field_value) => field_value, + MemoryValue::Integer(u128_value, _) => u128_value.into(), + }; + dfg.make_constant(field_value, typ) + } + Type::Array(types, length) => { + let mut new_array_values = Vector::new(); + for _ in 0..length { + for typ in types.iter() { + let new_value = Self::new_value_for_type_and_memory_values( + typ.clone(), + block_id, + memory_values, + memory_index, + dfg, + ); + new_array_values.push_back(new_value); + } + } + + let instruction = Instruction::MakeArray { + elements: new_array_values, + typ: Type::Array(types, length), + }; + let instruction_id = dfg.make_instruction(instruction, None); + dfg[block_id].instructions_mut().push(instruction_id); + *dfg.instruction_results(instruction_id).first().unwrap() + } + Type::Reference(_) => { + panic!("Unexpected reference type in brillig function result") + } + Type::Slice(_) => { + panic!("Unexpected slice type in brillig function result") + } + Type::Function => { + panic!("Unexpected function type in brillig function result") + } + } + } } impl ResultCache { @@ -376,6 +621,34 @@ enum CacheResult<'a> { NeedToHoistToCommonBlock(BasicBlockId, &'a [ValueId]), } +/// Result of trying to evaluate an instruction (any instruction) in this pass. +enum EvaluationResult { + /// Nothing was done because the instruction wasn't a call to a brillig function, + /// or some arguments to it were not constants. + NotABrilligCall, + /// The instruction was a call to a brillig function, but we couldn't evaluate it. + CannotEvaluate(FunctionId), + /// The instruction was a call to a brillig function and we were able to evaluate it, + /// returning evaluation memory values. + Evaluated(Vec>), +} + +fn value_id_to_calldata(value_id: ValueId, dfg: &DataFlowGraph, calldata: &mut Vec) { + if let Some(value) = dfg.get_numeric_constant(value_id) { + calldata.push(value); + return; + } + + if let Some((values, _type)) = dfg.get_array_constant(value_id) { + for value in values { + value_id_to_calldata(value, dfg, calldata); + } + return; + } + + panic!("Expected ValueId to be numeric constant or array constant"); +} + #[cfg(test)] mod test { use std::sync::Arc; @@ -854,4 +1127,180 @@ mod test { let ssa = ssa.fold_constants_using_constraints(); assert_normalized_ssa_equals(ssa, expected); } + + #[test] + fn inlines_brillig_call_without_arguments() { + let src = " + acir(inline) fn main f0 { + b0(): + v0 = call f1() -> Field + return v0 + } + + brillig(inline) fn one f1 { + b0(): + v0 = add Field 2, Field 3 + return v0 + } + "; + let ssa = Ssa::from_str(src).unwrap(); + let brillig = ssa.to_brillig(false); + + let expected = " + acir(inline) fn main f0 { + b0(): + return Field 5 + } + "; + let ssa = ssa.fold_constants_with_brillig(&brillig); + assert_normalized_ssa_equals(ssa, expected); + } + + #[test] + fn inlines_brillig_call_with_two_field_arguments() { + let src = " + acir(inline) fn main f0 { + b0(): + v0 = call f1(Field 2, Field 3) -> Field + return v0 + } + + brillig(inline) fn one f1 { + b0(v0: Field, v1: Field): + v2 = add v0, v1 + return v2 + } + "; + let ssa = Ssa::from_str(src).unwrap(); + let brillig = ssa.to_brillig(false); + + let expected = " + acir(inline) fn main f0 { + b0(): + return Field 5 + } + "; + let ssa = ssa.fold_constants_with_brillig(&brillig); + assert_normalized_ssa_equals(ssa, expected); + } + + #[test] + fn inlines_brillig_call_with_two_i32_arguments() { + let src = " + acir(inline) fn main f0 { + b0(): + v0 = call f1(i32 2, i32 3) -> i32 + return v0 + } + + brillig(inline) fn one f1 { + b0(v0: i32, v1: i32): + v2 = add v0, v1 + return v2 + } + "; + let ssa = Ssa::from_str(src).unwrap(); + let brillig = ssa.to_brillig(false); + + let expected = " + acir(inline) fn main f0 { + b0(): + return i32 5 + } + "; + let ssa = ssa.fold_constants_with_brillig(&brillig); + assert_normalized_ssa_equals(ssa, expected); + } + + #[test] + fn inlines_brillig_call_with_array_return() { + let src = " + acir(inline) fn main f0 { + b0(): + v0 = call f1(Field 2, Field 3, Field 4) -> [Field; 3] + return v0 + } + + brillig(inline) fn one f1 { + b0(v0: Field, v1: Field, v2: Field): + v3 = make_array [v0, v1, v2] : [Field; 3] + return v3 + } + "; + let ssa = Ssa::from_str(src).unwrap(); + let brillig = ssa.to_brillig(false); + + let expected = " + acir(inline) fn main f0 { + b0(): + v3 = make_array [Field 2, Field 3, Field 4] : [Field; 3] + return v3 + } + "; + let ssa = ssa.fold_constants_with_brillig(&brillig); + assert_normalized_ssa_equals(ssa, expected); + } + + #[test] + fn inlines_brillig_call_with_composite_array_return() { + let src = " + acir(inline) fn main f0 { + b0(): + v0 = call f1(Field 2, i32 3, Field 4, i32 5) -> [(Field, i32); 2] + return v0 + } + + brillig(inline) fn one f1 { + b0(v0: Field, v1: i32, v2: i32, v3: Field): + v4 = make_array [v0, v1, v2, v3] : [(Field, i32); 2] + return v4 + } + "; + let ssa = Ssa::from_str(src).unwrap(); + let brillig = ssa.to_brillig(false); + + let expected = " + acir(inline) fn main f0 { + b0(): + v4 = make_array [Field 2, i32 3, Field 4, i32 5] : [(Field, i32); 2] + return v4 + } + "; + let ssa = ssa.fold_constants_with_brillig(&brillig); + assert_normalized_ssa_equals(ssa, expected); + } + + #[test] + fn inlines_brillig_call_with_array_arguments() { + let src = " + acir(inline) fn main f0 { + b0(): + v0 = make_array [Field 2, Field 3] : [Field; 2] + v1 = call f1(v0) -> Field + return v1 + } + + brillig(inline) fn one f1 { + b0(v0: [Field; 2]): + inc_rc v0 + v2 = array_get v0, index u32 0 -> Field + v4 = array_get v0, index u32 1 -> Field + v5 = add v2, v4 + dec_rc v0 + return v5 + } + "; + let ssa = Ssa::from_str(src).unwrap(); + let brillig = ssa.to_brillig(false); + + let expected = " + acir(inline) fn main f0 { + b0(): + v2 = make_array [Field 2, Field 3] : [Field; 2] + return Field 5 + } + "; + let ssa = ssa.fold_constants_with_brillig(&brillig); + assert_normalized_ssa_equals(ssa, expected); + } } diff --git a/compiler/noirc_evaluator/src/ssa/opt/inline_const_brillig_calls.rs b/compiler/noirc_evaluator/src/ssa/opt/inline_const_brillig_calls.rs deleted file mode 100644 index 9812f2cd054..00000000000 --- a/compiler/noirc_evaluator/src/ssa/opt/inline_const_brillig_calls.rs +++ /dev/null @@ -1,463 +0,0 @@ -use std::collections::{BTreeMap, HashSet}; - -use acvm::{ - blackbox_solver::StubbedBlackBoxSolver, - brillig_vm::{MemoryValue, VMStatus, VM}, - FieldElement, -}; - -use im::Vector; - -use crate::{ - brillig::{ - brillig_gen::{brillig_fn::FunctionContext, gen_brillig_for}, - Brillig, - }, - ssa::{ - ir::{ - basic_block::BasicBlockId, - dfg::DataFlowGraph, - function::{Function, FunctionId, RuntimeType}, - instruction::{Instruction, InstructionId}, - types::Type, - value::{Value, ValueId}, - }, - Ssa, - }, -}; - -impl Ssa { - pub(crate) fn inline_const_brillig_calls(mut self, brillig: &Brillig) -> Self { - // Collect all brillig functions so that later we can find them when processing a call instruction - let mut brillig_functions: BTreeMap = BTreeMap::new(); - for (func_id, func) in &self.functions { - if let RuntimeType::Brillig(..) = func.runtime() { - let cloned_function = Function::clone_with_id(*func_id, func); - brillig_functions.insert(*func_id, cloned_function); - }; - } - - // Keep track of which brillig functions we couldn't completely inline: we'll remove the ones we could. - let mut brillig_functions_we_could_not_inline = HashSet::new(); - - for func in self.functions.values_mut() { - func.inline_const_brillig_calls( - brillig, - &brillig_functions, - &mut brillig_functions_we_could_not_inline, - ); - } - - // Remove the brillig functions that are no longer called - for func_id in brillig_functions.keys() { - // We never want to remove the main function (it could be `unconstrained` or it - // could have been turned into brillig if `--force-brillig` was given) - if self.main_id == *func_id { - continue; - } - - if brillig_functions_we_could_not_inline.contains(func_id) { - continue; - } - - // We also don't want to remove entry points - if self.entry_point_to_generated_index.contains_key(func_id) { - continue; - } - - self.functions.remove(func_id); - } - - self - } -} - -/// Result of trying to evaluate an instruction (any instruction) in this pass. -enum EvaluationResult { - /// Nothing was done because the instruction wasn't a call to a brillig function, - /// or some arguments to it were not constants. - NotABrilligCall, - /// The instruction was a call to a brillig function, but we couldn't evaluate it. - CannotEvaluate(FunctionId), - /// The instruction was a call to a brillig function and we were able to evaluate it, - /// returning evaluation memory values. - Evaluated(Vec>), -} - -impl Function { - pub(crate) fn inline_const_brillig_calls( - &mut self, - brillig: &Brillig, - brillig_functions: &BTreeMap, - brillig_functions_we_could_not_inline: &mut HashSet, - ) { - for block_id in self.reachable_blocks() { - for instruction_id in self.dfg[block_id].take_instructions() { - let evaluation_result = - self.evaluate_const_brillig_call(instruction_id, brillig, brillig_functions); - match evaluation_result { - EvaluationResult::NotABrilligCall => { - self.dfg[block_id].instructions_mut().push(instruction_id); - } - EvaluationResult::CannotEvaluate(func_id) => { - self.dfg[block_id].instructions_mut().push(instruction_id); - brillig_functions_we_could_not_inline.insert(func_id); - } - EvaluationResult::Evaluated(memory_values) => { - // Replace the instruction results with the constant values we got - let result_ids = self.dfg.instruction_results(instruction_id).to_vec(); - - let mut memory_index = 0; - for result_id in result_ids { - self.replace_result_id_with_memory_value( - result_id, - block_id, - &memory_values, - &mut memory_index, - ); - } - } - } - } - } - } - - /// Replaces `result_id` by taking memory values from `memory_values` starting at `memory_index` - /// depending on the type of the ValueId (it will read multiple memory values if it's an array). - fn replace_result_id_with_memory_value( - &mut self, - result_id: ValueId, - block_id: BasicBlockId, - memory_values: &[MemoryValue], - memory_index: &mut usize, - ) { - let typ = self.dfg.type_of_value(result_id); - let new_value = - self.new_value_for_type_and_memory_values(typ, block_id, memory_values, memory_index); - self.dfg.set_value_from_id(result_id, new_value); - } - - /// Creates a new value inside this function by reading it from `memory_values` starting at - /// `memory_index` depending on the given Type: if it's an array multiple values will be read - /// and a new `make_array` instruction will be created. - fn new_value_for_type_and_memory_values( - &mut self, - typ: Type, - block_id: BasicBlockId, - memory_values: &[MemoryValue], - memory_index: &mut usize, - ) -> ValueId { - match typ { - Type::Numeric(_) => { - let memory = memory_values[*memory_index]; - *memory_index += 1; - - let field_value = match memory { - MemoryValue::Field(field_value) => field_value, - MemoryValue::Integer(u128_value, _) => u128_value.into(), - }; - self.dfg.make_constant(field_value, typ) - } - Type::Array(types, length) => { - let mut new_array_values = Vector::new(); - for _ in 0..length { - for typ in types.iter() { - let new_value = self.new_value_for_type_and_memory_values( - typ.clone(), - block_id, - memory_values, - memory_index, - ); - new_array_values.push_back(new_value); - } - } - - let instruction = Instruction::MakeArray { - elements: new_array_values, - typ: Type::Array(types, length), - }; - let instruction_id = self.dfg.make_instruction(instruction, None); - self.dfg[block_id].instructions_mut().push(instruction_id); - *self.dfg.instruction_results(instruction_id).first().unwrap() - } - Type::Reference(_) => { - panic!("Unexpected reference type in brillig function result") - } - Type::Slice(_) => { - panic!("Unexpected slice type in brillig function result") - } - Type::Function => { - panic!("Unexpected function type in brillig function result") - } - } - } - - /// Tries to evaluate an instruction if it's a call that points to a brillig function, - /// and all its arguments are constant. - /// We do this by directly executing the function with a brillig VM. - fn evaluate_const_brillig_call( - &self, - instruction_id: InstructionId, - brillig: &Brillig, - brillig_functions: &BTreeMap, - ) -> EvaluationResult { - let instruction = &self.dfg[instruction_id]; - let Instruction::Call { func: func_id, arguments } = instruction else { - return EvaluationResult::NotABrilligCall; - }; - - let func_value = &self.dfg[*func_id]; - let Value::Function(func_id) = func_value else { - return EvaluationResult::NotABrilligCall; - }; - - let Some(func) = brillig_functions.get(func_id) else { - return EvaluationResult::NotABrilligCall; - }; - - if !arguments.iter().all(|argument| self.dfg.is_constant(*argument)) { - return EvaluationResult::CannotEvaluate(*func_id); - } - - let mut brillig_arguments = Vec::new(); - for argument in arguments { - let typ = self.dfg.type_of_value(*argument); - let Some(parameter) = FunctionContext::try_ssa_type_to_parameter(&typ) else { - return EvaluationResult::CannotEvaluate(*func_id); - }; - brillig_arguments.push(parameter); - } - - // Check that return value types are supported by brillig - for return_id in func.returns().iter() { - let typ = func.dfg.type_of_value(*return_id); - if FunctionContext::try_ssa_type_to_parameter(&typ).is_none() { - return EvaluationResult::CannotEvaluate(*func_id); - } - } - - let Ok(generated_brillig) = gen_brillig_for(func, brillig_arguments, brillig) else { - return EvaluationResult::CannotEvaluate(*func_id); - }; - - let mut calldata = Vec::new(); - for argument in arguments { - value_id_to_calldata(*argument, &self.dfg, &mut calldata); - } - - let bytecode = &generated_brillig.byte_code; - let foreign_call_results = Vec::new(); - let black_box_solver = StubbedBlackBoxSolver; - let profiling_active = false; - let mut vm = - VM::new(calldata, bytecode, foreign_call_results, &black_box_solver, profiling_active); - let vm_status: VMStatus<_> = vm.process_opcodes(); - let VMStatus::Finished { return_data_offset, return_data_size } = vm_status else { - return EvaluationResult::CannotEvaluate(*func_id); - }; - - let memory = - vm.get_memory()[return_data_offset..(return_data_offset + return_data_size)].to_vec(); - - EvaluationResult::Evaluated(memory) - } -} - -fn value_id_to_calldata(value_id: ValueId, dfg: &DataFlowGraph, calldata: &mut Vec) { - if let Some(value) = dfg.get_numeric_constant(value_id) { - calldata.push(value); - return; - } - - if let Some((values, _type)) = dfg.get_array_constant(value_id) { - for value in values { - value_id_to_calldata(value, dfg, calldata); - } - return; - } - - panic!("Expected ValueId to be numeric constant or array constant"); -} - -#[cfg(test)] -mod test { - use crate::ssa::opt::assert_normalized_ssa_equals; - - use super::Ssa; - - #[test] - fn inlines_brillig_call_without_arguments() { - let src = " - acir(inline) fn main f0 { - b0(): - v0 = call f1() -> Field - return v0 - } - - brillig(inline) fn one f1 { - b0(): - v0 = add Field 2, Field 3 - return v0 - } - "; - let ssa = Ssa::from_str(src).unwrap(); - let brillig = ssa.to_brillig(false); - - let expected = " - acir(inline) fn main f0 { - b0(): - return Field 5 - } - "; - let ssa = ssa.inline_const_brillig_calls(&brillig); - assert_normalized_ssa_equals(ssa, expected); - } - - #[test] - fn inlines_brillig_call_with_two_field_arguments() { - let src = " - acir(inline) fn main f0 { - b0(): - v0 = call f1(Field 2, Field 3) -> Field - return v0 - } - - brillig(inline) fn one f1 { - b0(v0: Field, v1: Field): - v2 = add v0, v1 - return v2 - } - "; - let ssa = Ssa::from_str(src).unwrap(); - let brillig = ssa.to_brillig(false); - - let expected = " - acir(inline) fn main f0 { - b0(): - return Field 5 - } - "; - let ssa = ssa.inline_const_brillig_calls(&brillig); - assert_normalized_ssa_equals(ssa, expected); - } - - #[test] - fn inlines_brillig_call_with_two_i32_arguments() { - let src = " - acir(inline) fn main f0 { - b0(): - v0 = call f1(i32 2, i32 3) -> i32 - return v0 - } - - brillig(inline) fn one f1 { - b0(v0: i32, v1: i32): - v2 = add v0, v1 - return v2 - } - "; - let ssa = Ssa::from_str(src).unwrap(); - let brillig = ssa.to_brillig(false); - - let expected = " - acir(inline) fn main f0 { - b0(): - return i32 5 - } - "; - let ssa = ssa.inline_const_brillig_calls(&brillig); - assert_normalized_ssa_equals(ssa, expected); - } - - #[test] - fn inlines_brillig_call_with_array_return() { - let src = " - acir(inline) fn main f0 { - b0(): - v0 = call f1(Field 2, Field 3, Field 4) -> [Field; 3] - return v0 - } - - brillig(inline) fn one f1 { - b0(v0: Field, v1: Field, v2: Field): - v3 = make_array [v0, v1, v2] : [Field; 3] - return v3 - } - "; - let ssa = Ssa::from_str(src).unwrap(); - let brillig = ssa.to_brillig(false); - - let expected = " - acir(inline) fn main f0 { - b0(): - v3 = make_array [Field 2, Field 3, Field 4] : [Field; 3] - return v3 - } - "; - let ssa = ssa.inline_const_brillig_calls(&brillig); - assert_normalized_ssa_equals(ssa, expected); - } - - #[test] - fn inlines_brillig_call_with_composite_array_return() { - let src = " - acir(inline) fn main f0 { - b0(): - v0 = call f1(Field 2, i32 3, Field 4, i32 5) -> [(Field, i32); 2] - return v0 - } - - brillig(inline) fn one f1 { - b0(v0: Field, v1: i32, v2: i32, v3: Field): - v4 = make_array [v0, v1, v2, v3] : [(Field, i32); 2] - return v4 - } - "; - let ssa = Ssa::from_str(src).unwrap(); - let brillig = ssa.to_brillig(false); - - let expected = " - acir(inline) fn main f0 { - b0(): - v4 = make_array [Field 2, i32 3, Field 4, i32 5] : [(Field, i32); 2] - return v4 - } - "; - let ssa = ssa.inline_const_brillig_calls(&brillig); - assert_normalized_ssa_equals(ssa, expected); - } - - #[test] - fn inlines_brillig_call_with_array_arguments() { - let src = " - acir(inline) fn main f0 { - b0(): - v0 = make_array [Field 2, Field 3] : [Field; 2] - v1 = call f1(v0) -> Field - return v1 - } - - brillig(inline) fn one f1 { - b0(v0: [Field; 2]): - inc_rc v0 - v2 = array_get v0, index u32 0 -> Field - v4 = array_get v0, index u32 1 -> Field - v5 = add v2, v4 - dec_rc v0 - return v5 - } - "; - let ssa = Ssa::from_str(src).unwrap(); - let brillig = ssa.to_brillig(false); - - let expected = " - acir(inline) fn main f0 { - b0(): - v2 = make_array [Field 2, Field 3] : [Field; 2] - return Field 5 - } - "; - let ssa = ssa.inline_const_brillig_calls(&brillig); - assert_normalized_ssa_equals(ssa, expected); - } -} diff --git a/compiler/noirc_evaluator/src/ssa/opt/mod.rs b/compiler/noirc_evaluator/src/ssa/opt/mod.rs index c3160565b23..098f62bceba 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/mod.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/mod.rs @@ -11,7 +11,6 @@ mod constant_folding; mod defunctionalize; mod die; pub(crate) mod flatten_cfg; -mod inline_const_brillig_calls; mod inlining; mod mem2reg; mod normalize_value_ids; From f9a670dd93c0897f0297f1fb731d0ff16c6321dd Mon Sep 17 00:00:00 2001 From: Ary Borenszweig Date: Tue, 19 Nov 2024 15:29:40 -0300 Subject: [PATCH 05/12] Add comment --- compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs b/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs index b7ab3c65b49..9b7e6a91e41 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs @@ -627,6 +627,7 @@ enum EvaluationResult { /// or some arguments to it were not constants. NotABrilligCall, /// The instruction was a call to a brillig function, but we couldn't evaluate it. + /// This can occur in the situation where the brillig function reaches a "trap" or a foreign call opcode. CannotEvaluate(FunctionId), /// The instruction was a call to a brillig function and we were able to evaluate it, /// returning evaluation memory values. From de53916f02c0bad75cb724935f719ad0b1f35e58 Mon Sep 17 00:00:00 2001 From: Ary Borenszweig Date: Tue, 19 Nov 2024 15:30:51 -0300 Subject: [PATCH 06/12] Use `?` --- .../noirc_evaluator/src/brillig/brillig_gen/brillig_fn.rs | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_fn.rs b/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_fn.rs index 2ea4accc2c5..442673a41ae 100644 --- a/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_fn.rs +++ b/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_fn.rs @@ -74,12 +74,9 @@ impl FunctionContext { Some(BrilligParameter::SingleAddr(get_bit_size_from_ssa_type(typ))) } Type::Array(item_type, size) => { - let mut parameters = Vec::new(); + let mut parameters = Vec::with_capacity(item_type.len()); for item_typ in item_type.iter() { - let Some(param) = FunctionContext::try_ssa_type_to_parameter(item_typ) else { - return None; - }; - parameters.push(param); + parameters.push(FunctionContext::try_ssa_type_to_parameter(item_typ)?); } Some(BrilligParameter::Array(parameters, *size)) } From d67392d132b3559d87da0e70341f7256ee04c433 Mon Sep 17 00:00:00 2001 From: Ary Borenszweig Date: Tue, 19 Nov 2024 15:58:31 -0300 Subject: [PATCH 07/12] Refactor --- .../src/ssa/opt/constant_folding.rs | 79 ++++++++++++------- 1 file changed, 49 insertions(+), 30 deletions(-) diff --git a/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs b/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs index 9b7e6a91e41..d51bf23a579 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs @@ -326,36 +326,15 @@ impl<'brillig> Context<'brillig> { brillig_info: Option, brillig_functions_we_could_not_inline: &mut HashSet, ) -> Vec { - // Check if this is a call to a brillig function with all constant arguments. - // If so, we can try to evaluate that function and replace the results with the evaluation results. - if let Some(brillig_info) = brillig_info { - let evaluation_result = Self::evaluate_const_brillig_call( - &instruction, - brillig_info.brillig, - brillig_info.brillig_functions, - dfg, - ); - - match evaluation_result { - EvaluationResult::NotABrilligCall => (), - EvaluationResult::CannotEvaluate(id) => { - brillig_functions_we_could_not_inline.insert(id); - } - EvaluationResult::Evaluated(memory_values) => { - let mut memory_index = 0; - let new_results = vecmap(old_results, |old_result| { - let typ = dfg.type_of_value(*old_result); - Self::new_value_for_type_and_memory_values( - typ, - block, - &memory_values, - &mut memory_index, - dfg, - ) - }); - return new_results; - } - } + if let Some(new_results) = Self::try_inline_brillig_call_with_all_constants( + &instruction, + old_results, + block, + dfg, + brillig_info, + brillig_functions_we_could_not_inline, + ) { + return new_results; } let ctrl_typevars = instruction @@ -463,6 +442,46 @@ impl<'brillig> Context<'brillig> { results_for_instruction.get(&predicate)?.get(block, &mut self.dom) } + /// Checks if the given instruction is a call to a brillig function with all constant arguments. + /// If so, we can try to evaluate that function and replace the results with the evaluation results. + fn try_inline_brillig_call_with_all_constants( + instruction: &Instruction, + old_results: &[ValueId], + block: BasicBlockId, + dfg: &mut DataFlowGraph, + brillig_info: Option, + brillig_functions_we_could_not_inline: &mut HashSet, + ) -> Option> { + let evaluation_result = Self::evaluate_const_brillig_call( + instruction, + brillig_info?.brillig, + brillig_info?.brillig_functions, + dfg, + ); + + match evaluation_result { + EvaluationResult::NotABrilligCall => None, + EvaluationResult::CannotEvaluate(func_id) => { + brillig_functions_we_could_not_inline.insert(func_id); + None + } + EvaluationResult::Evaluated(memory_values) => { + let mut memory_index = 0; + let new_results = vecmap(old_results, |old_result| { + let typ = dfg.type_of_value(*old_result); + Self::new_value_for_type_and_memory_values( + typ, + block, + &memory_values, + &mut memory_index, + dfg, + ) + }); + Some(new_results) + } + } + } + /// Tries to evaluate an instruction if it's a call that points to a brillig function, /// and all its arguments are constant. /// We do this by directly executing the function with a brillig VM. From 7b3bd1c218544b9859bca912499efb4575199a49 Mon Sep 17 00:00:00 2001 From: Ary Borenszweig Date: Tue, 19 Nov 2024 15:58:52 -0300 Subject: [PATCH 08/12] Disallow references in this optimization --- .../src/brillig/brillig_gen/brillig_fn.rs | 16 -------------- .../src/ssa/opt/constant_folding.rs | 22 ++++++++++++++++--- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_fn.rs b/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_fn.rs index 442673a41ae..2779be103cd 100644 --- a/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_fn.rs +++ b/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_fn.rs @@ -68,22 +68,6 @@ impl FunctionContext { } } - pub(crate) fn try_ssa_type_to_parameter(typ: &Type) -> Option { - match typ { - Type::Numeric(_) | Type::Reference(_) => { - Some(BrilligParameter::SingleAddr(get_bit_size_from_ssa_type(typ))) - } - Type::Array(item_type, size) => { - let mut parameters = Vec::with_capacity(item_type.len()); - for item_typ in item_type.iter() { - parameters.push(FunctionContext::try_ssa_type_to_parameter(item_typ)?); - } - Some(BrilligParameter::Array(parameters, *size)) - } - _ => None, - } - } - /// Collects the return values of a given function pub(crate) fn return_values(func: &Function) -> Vec { func.returns() diff --git a/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs b/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs index d51bf23a579..6a84285ac65 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs @@ -32,7 +32,8 @@ use iter_extended::vecmap; use crate::{ brillig::{ - brillig_gen::{brillig_fn::FunctionContext, gen_brillig_for}, + brillig_gen::gen_brillig_for, + brillig_ir::{artifact::BrilligParameter, brillig_variable::get_bit_size_from_ssa_type}, Brillig, }, ssa::{ @@ -511,7 +512,7 @@ impl<'brillig> Context<'brillig> { let mut brillig_arguments = Vec::new(); for argument in arguments { let typ = dfg.type_of_value(*argument); - let Some(parameter) = FunctionContext::try_ssa_type_to_parameter(&typ) else { + let Some(parameter) = type_to_brillig_parameter(&typ) else { return EvaluationResult::CannotEvaluate(*func_id); }; brillig_arguments.push(parameter); @@ -520,7 +521,7 @@ impl<'brillig> Context<'brillig> { // Check that return value types are supported by brillig for return_id in func.returns().iter() { let typ = func.dfg.type_of_value(*return_id); - if FunctionContext::try_ssa_type_to_parameter(&typ).is_none() { + if type_to_brillig_parameter(&typ).is_none() { return EvaluationResult::CannotEvaluate(*func_id); } } @@ -653,6 +654,21 @@ enum EvaluationResult { Evaluated(Vec>), } +/// Similar to FunctionContext::ssa_type_to_parameter but never panics and disallows reference types. +pub(crate) fn type_to_brillig_parameter(typ: &Type) -> Option { + match typ { + Type::Numeric(_) => Some(BrilligParameter::SingleAddr(get_bit_size_from_ssa_type(typ))), + Type::Array(item_type, size) => { + let mut parameters = Vec::with_capacity(item_type.len()); + for item_typ in item_type.iter() { + parameters.push(type_to_brillig_parameter(item_typ)?); + } + Some(BrilligParameter::Array(parameters, *size)) + } + _ => None, + } +} + fn value_id_to_calldata(value_id: ValueId, dfg: &DataFlowGraph, calldata: &mut Vec) { if let Some(value) = dfg.get_numeric_constant(value_id) { calldata.push(value); From a2d7b19fabb513f53861b6b42c2817c087937b5a Mon Sep 17 00:00:00 2001 From: Ary Borenszweig Date: Tue, 19 Nov 2024 16:43:33 -0300 Subject: [PATCH 09/12] Remove unused brillig functions in a separate pass --- compiler/noirc_evaluator/src/ssa.rs | 1 + .../src/ssa/opt/constant_folding.rs | 81 +++++----------- compiler/noirc_evaluator/src/ssa/opt/mod.rs | 1 + .../opt/remove_unused_brillig_functions.rs | 94 +++++++++++++++++++ 4 files changed, 117 insertions(+), 60 deletions(-) create mode 100644 compiler/noirc_evaluator/src/ssa/opt/remove_unused_brillig_functions.rs diff --git a/compiler/noirc_evaluator/src/ssa.rs b/compiler/noirc_evaluator/src/ssa.rs index 344ac114a03..1acdaf57aab 100644 --- a/compiler/noirc_evaluator/src/ssa.rs +++ b/compiler/noirc_evaluator/src/ssa.rs @@ -153,6 +153,7 @@ pub(crate) fn optimize_into_acir( |ssa| ssa.fold_constants_with_brillig(&brillig), "After Constant Folding with Brillig:", ) + .run_pass(Ssa::remove_unused_brillig_functions, "After Remove Unused Brillig Functions:") .run_pass(Ssa::dead_instruction_elimination, "After Dead Instruction Elimination:") .finish(); diff --git a/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs b/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs index 6a84285ac65..16fa97a6520 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs @@ -57,10 +57,8 @@ impl Ssa { /// See [`constant_folding`][self] module for more information. #[tracing::instrument(level = "trace", skip(self))] pub(crate) fn fold_constants(mut self) -> Ssa { - let mut brillig_functions_we_could_not_inline = HashSet::new(); - for function in self.functions.values_mut() { - function.constant_fold(false, None, &mut brillig_functions_we_could_not_inline); + function.constant_fold(false, None); } self } @@ -72,10 +70,8 @@ impl Ssa { /// See [`constant_folding`][self] module for more information. #[tracing::instrument(level = "trace", skip(self))] pub(crate) fn fold_constants_using_constraints(mut self) -> Ssa { - let mut brillig_functions_we_could_not_inline = HashSet::new(); - for function in self.functions.values_mut() { - function.constant_fold(true, None, &mut brillig_functions_we_could_not_inline); + function.constant_fold(true, None); } self } @@ -91,28 +87,10 @@ impl Ssa { }; } - // Keep track of which brillig functions we couldn't completely inline: we'll remove the ones we could. - let mut brillig_functions_we_could_not_inline = HashSet::new(); - let brillig_info = Some(BrilligInfo { brillig, brillig_functions: &brillig_functions }); for function in self.functions.values_mut() { - function.constant_fold(false, brillig_info, &mut brillig_functions_we_could_not_inline); - } - - // Remove the brillig functions that are no longer called - for func_id in brillig_functions.keys() { - // We never want to remove the main function (it could be `unconstrained` or it - // could have been turned into brillig if `--force-brillig` was given). - // We also don't want to remove entry points. - if self.main_id == *func_id - || brillig_functions_we_could_not_inline.contains(func_id) - || self.entry_point_to_generated_index.contains_key(func_id) - { - continue; - } - - self.functions.remove(func_id); + function.constant_fold(false, brillig_info); } self @@ -126,7 +104,6 @@ impl Function { &mut self, use_constraint_info: bool, brillig_info: Option, - brillig_functions_we_could_not_inline: &mut HashSet, ) { let mut context = Context::new(self, use_constraint_info, brillig_info); context.block_queue.push_back(self.entry_block()); @@ -137,7 +114,7 @@ impl Function { } context.visited_blocks.insert(block); - context.fold_constants_in_block(self, block, brillig_functions_we_could_not_inline); + context.fold_constants_in_block(self, block); } } } @@ -202,12 +179,7 @@ impl<'brillig> Context<'brillig> { } } - fn fold_constants_in_block( - &mut self, - function: &mut Function, - block: BasicBlockId, - brillig_functions_we_could_not_inline: &mut HashSet, - ) { + fn fold_constants_in_block(&mut self, function: &mut Function, block: BasicBlockId) { let instructions = function.dfg[block].take_instructions(); let mut side_effects_enabled_var = @@ -219,7 +191,6 @@ impl<'brillig> Context<'brillig> { block, instruction_id, &mut side_effects_enabled_var, - brillig_functions_we_could_not_inline, ); } self.block_queue.extend(function.dfg[block].successors()); @@ -231,7 +202,6 @@ impl<'brillig> Context<'brillig> { mut block: BasicBlockId, id: InstructionId, side_effects_enabled_var: &mut ValueId, - brillig_functions_we_could_not_inline: &mut HashSet, ) { let constraint_simplification_mapping = self.get_constraint_map(*side_effects_enabled_var); let instruction = Self::resolve_instruction(id, dfg, constraint_simplification_mapping); @@ -256,16 +226,25 @@ impl<'brillig> Context<'brillig> { } } - // Otherwise, try inserting the instruction again to apply any optimizations using the newly resolved inputs. - let new_results = Self::push_instruction( - id, - instruction.clone(), + let new_results = + // First try to inline a call to a brillig function with all constant arguments. + Self::try_inline_brillig_call_with_all_constants( + &instruction, &old_results, block, dfg, self.brillig_info, - brillig_functions_we_could_not_inline, - ); + ) + .unwrap_or_else(|| { + // Otherwise, try inserting the instruction again to apply any optimizations using the newly resolved inputs. + Self::push_instruction( + id, + instruction.clone(), + &old_results, + block, + dfg, + ) + }); Self::replace_result_ids(dfg, &old_results, &new_results); @@ -324,20 +303,7 @@ impl<'brillig> Context<'brillig> { old_results: &[ValueId], block: BasicBlockId, dfg: &mut DataFlowGraph, - brillig_info: Option, - brillig_functions_we_could_not_inline: &mut HashSet, ) -> Vec { - if let Some(new_results) = Self::try_inline_brillig_call_with_all_constants( - &instruction, - old_results, - block, - dfg, - brillig_info, - brillig_functions_we_could_not_inline, - ) { - return new_results; - } - let ctrl_typevars = instruction .requires_ctrl_typevars() .then(|| vecmap(old_results, |result| dfg.type_of_value(*result))); @@ -451,7 +417,6 @@ impl<'brillig> Context<'brillig> { block: BasicBlockId, dfg: &mut DataFlowGraph, brillig_info: Option, - brillig_functions_we_could_not_inline: &mut HashSet, ) -> Option> { let evaluation_result = Self::evaluate_const_brillig_call( instruction, @@ -461,11 +426,7 @@ impl<'brillig> Context<'brillig> { ); match evaluation_result { - EvaluationResult::NotABrilligCall => None, - EvaluationResult::CannotEvaluate(func_id) => { - brillig_functions_we_could_not_inline.insert(func_id); - None - } + EvaluationResult::NotABrilligCall | EvaluationResult::CannotEvaluate(_) => None, EvaluationResult::Evaluated(memory_values) => { let mut memory_index = 0; let new_results = vecmap(old_results, |old_result| { diff --git a/compiler/noirc_evaluator/src/ssa/opt/mod.rs b/compiler/noirc_evaluator/src/ssa/opt/mod.rs index 098f62bceba..2e865b50de7 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/mod.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/mod.rs @@ -18,6 +18,7 @@ mod rc; mod remove_bit_shifts; mod remove_enable_side_effects; mod remove_if_else; +mod remove_unused_brillig_functions; mod resolve_is_unconstrained; mod runtime_separation; mod simplify_cfg; diff --git a/compiler/noirc_evaluator/src/ssa/opt/remove_unused_brillig_functions.rs b/compiler/noirc_evaluator/src/ssa/opt/remove_unused_brillig_functions.rs new file mode 100644 index 00000000000..307042d1121 --- /dev/null +++ b/compiler/noirc_evaluator/src/ssa/opt/remove_unused_brillig_functions.rs @@ -0,0 +1,94 @@ +use std::collections::HashSet; + +use crate::ssa::{ + ir::{function::RuntimeType, instruction::Instruction, value::Value}, + Ssa, +}; + +impl Ssa { + #[tracing::instrument(level = "trace", skip(self))] + pub(crate) fn remove_unused_brillig_functions(mut self) -> Ssa { + // Compute the set of all brillig functions that exist in the program + let mut brillig_function_ids = HashSet::new(); + for (func_id, func) in &self.functions { + if let RuntimeType::Brillig(..) = func.runtime() { + brillig_function_ids.insert(*func_id); + }; + } + + // Remove from the above set functions that are called + for function in self.functions.values() { + for block_id in function.reachable_blocks() { + for instruction_id in function.dfg[block_id].instructions() { + let instruction = &function.dfg[*instruction_id]; + let Instruction::Call { func: func_id, arguments: _ } = instruction else { + continue; + }; + + let func_value = &function.dfg[*func_id]; + let Value::Function(func_id) = func_value else { continue }; + + brillig_function_ids.remove(func_id); + } + } + } + + // The ones that remain are never called: let's remove them. + for func_id in brillig_function_ids { + // We never want to remove the main function (it could be `unconstrained` or it + // could have been turned into brillig if `--force-brillig` was given). + // We also don't want to remove entry points. + if self.main_id == func_id || self.entry_point_to_generated_index.contains_key(&func_id) + { + continue; + } + + self.functions.remove(&func_id); + } + self + } +} + +#[cfg(test)] +mod test { + use crate::ssa::opt::assert_normalized_ssa_equals; + + use super::Ssa; + + #[test] + fn removes_unused_brillig_functions() { + // In the SSA below the function `two` is never called so we expected it to be removed. + let src = " + acir(inline) fn main f0 { + b0(): + call f1() + return + } + + brillig(inline) fn one f1 { + b0(): + return + } + + brillig(inline) fn two f2 { + b0(): + return + } + "; + let ssa = Ssa::from_str(src).unwrap(); + + let expected = " + acir(inline) fn main f0 { + b0(): + call f1() + return + } + brillig(inline) fn one f1 { + b0(): + return + } + "; + let ssa = ssa.remove_unused_brillig_functions(); + assert_normalized_ssa_equals(ssa, expected); + } +} From e14748ba602740649e7d233b644906d4b8c35624 Mon Sep 17 00:00:00 2001 From: Ary Borenszweig Date: Tue, 19 Nov 2024 16:44:25 -0300 Subject: [PATCH 10/12] Add a comment --- compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs b/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs index 16fa97a6520..35656d9b0e1 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs @@ -76,6 +76,8 @@ impl Ssa { self } + /// Performs constant folding on each instruction while also replacing calls to brillig functions + /// with all constant arguments by trying to evaluate those calls. #[tracing::instrument(level = "trace", skip(self, brillig))] pub(crate) fn fold_constants_with_brillig(mut self, brillig: &Brillig) -> Ssa { // Collect all brillig functions so that later we can find them when processing a call instruction From ab1ebf9bfdf9dca08bf77cd4e3818f9c961dc865 Mon Sep 17 00:00:00 2001 From: Ary Borenszweig Date: Tue, 19 Nov 2024 17:02:56 -0300 Subject: [PATCH 11/12] cargo fmt --- compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs b/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs index 35656d9b0e1..0227fda0e8c 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs @@ -228,7 +228,7 @@ impl<'brillig> Context<'brillig> { } } - let new_results = + let new_results = // First try to inline a call to a brillig function with all constant arguments. Self::try_inline_brillig_call_with_all_constants( &instruction, From 10f65b41bfa67f113c3d3cf8916d94d5a89425d9 Mon Sep 17 00:00:00 2001 From: Ary Borenszweig Date: Tue, 19 Nov 2024 19:41:08 -0300 Subject: [PATCH 12/12] Remove unused brillig functions in same pass as it makes more sense --- compiler/noirc_evaluator/src/ssa.rs | 1 - .../src/ssa/opt/constant_folding.rs | 39 ++++++++ compiler/noirc_evaluator/src/ssa/opt/mod.rs | 1 - .../opt/remove_unused_brillig_functions.rs | 94 ------------------- 4 files changed, 39 insertions(+), 96 deletions(-) delete mode 100644 compiler/noirc_evaluator/src/ssa/opt/remove_unused_brillig_functions.rs diff --git a/compiler/noirc_evaluator/src/ssa.rs b/compiler/noirc_evaluator/src/ssa.rs index 1acdaf57aab..344ac114a03 100644 --- a/compiler/noirc_evaluator/src/ssa.rs +++ b/compiler/noirc_evaluator/src/ssa.rs @@ -153,7 +153,6 @@ pub(crate) fn optimize_into_acir( |ssa| ssa.fold_constants_with_brillig(&brillig), "After Constant Folding with Brillig:", ) - .run_pass(Ssa::remove_unused_brillig_functions, "After Remove Unused Brillig Functions:") .run_pass(Ssa::dead_instruction_elimination, "After Dead Instruction Elimination:") .finish(); diff --git a/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs b/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs index 0227fda0e8c..019bace33a3 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs @@ -95,6 +95,45 @@ impl Ssa { function.constant_fold(false, brillig_info); } + // It could happen that we inlined all calls to a given brillig function. + // In that case it's unused so we can remove it. This is what we check next. + self.remove_unused_brillig_functions(brillig_functions) + } + + fn remove_unused_brillig_functions( + mut self, + mut brillig_functions: BTreeMap, + ) -> Ssa { + // Remove from the above map functions that are called + for function in self.functions.values() { + for block_id in function.reachable_blocks() { + for instruction_id in function.dfg[block_id].instructions() { + let instruction = &function.dfg[*instruction_id]; + let Instruction::Call { func: func_id, arguments: _ } = instruction else { + continue; + }; + + let func_value = &function.dfg[*func_id]; + let Value::Function(func_id) = func_value else { continue }; + + brillig_functions.remove(func_id); + } + } + } + + // The ones that remain are never called: let's remove them. + for func_id in brillig_functions.keys() { + // We never want to remove the main function (it could be `unconstrained` or it + // could have been turned into brillig if `--force-brillig` was given). + // We also don't want to remove entry points. + if self.main_id == *func_id || self.entry_point_to_generated_index.contains_key(func_id) + { + continue; + } + + self.functions.remove(func_id); + } + self } } diff --git a/compiler/noirc_evaluator/src/ssa/opt/mod.rs b/compiler/noirc_evaluator/src/ssa/opt/mod.rs index 2e865b50de7..098f62bceba 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/mod.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/mod.rs @@ -18,7 +18,6 @@ mod rc; mod remove_bit_shifts; mod remove_enable_side_effects; mod remove_if_else; -mod remove_unused_brillig_functions; mod resolve_is_unconstrained; mod runtime_separation; mod simplify_cfg; diff --git a/compiler/noirc_evaluator/src/ssa/opt/remove_unused_brillig_functions.rs b/compiler/noirc_evaluator/src/ssa/opt/remove_unused_brillig_functions.rs deleted file mode 100644 index 307042d1121..00000000000 --- a/compiler/noirc_evaluator/src/ssa/opt/remove_unused_brillig_functions.rs +++ /dev/null @@ -1,94 +0,0 @@ -use std::collections::HashSet; - -use crate::ssa::{ - ir::{function::RuntimeType, instruction::Instruction, value::Value}, - Ssa, -}; - -impl Ssa { - #[tracing::instrument(level = "trace", skip(self))] - pub(crate) fn remove_unused_brillig_functions(mut self) -> Ssa { - // Compute the set of all brillig functions that exist in the program - let mut brillig_function_ids = HashSet::new(); - for (func_id, func) in &self.functions { - if let RuntimeType::Brillig(..) = func.runtime() { - brillig_function_ids.insert(*func_id); - }; - } - - // Remove from the above set functions that are called - for function in self.functions.values() { - for block_id in function.reachable_blocks() { - for instruction_id in function.dfg[block_id].instructions() { - let instruction = &function.dfg[*instruction_id]; - let Instruction::Call { func: func_id, arguments: _ } = instruction else { - continue; - }; - - let func_value = &function.dfg[*func_id]; - let Value::Function(func_id) = func_value else { continue }; - - brillig_function_ids.remove(func_id); - } - } - } - - // The ones that remain are never called: let's remove them. - for func_id in brillig_function_ids { - // We never want to remove the main function (it could be `unconstrained` or it - // could have been turned into brillig if `--force-brillig` was given). - // We also don't want to remove entry points. - if self.main_id == func_id || self.entry_point_to_generated_index.contains_key(&func_id) - { - continue; - } - - self.functions.remove(&func_id); - } - self - } -} - -#[cfg(test)] -mod test { - use crate::ssa::opt::assert_normalized_ssa_equals; - - use super::Ssa; - - #[test] - fn removes_unused_brillig_functions() { - // In the SSA below the function `two` is never called so we expected it to be removed. - let src = " - acir(inline) fn main f0 { - b0(): - call f1() - return - } - - brillig(inline) fn one f1 { - b0(): - return - } - - brillig(inline) fn two f2 { - b0(): - return - } - "; - let ssa = Ssa::from_str(src).unwrap(); - - let expected = " - acir(inline) fn main f0 { - b0(): - call f1() - return - } - brillig(inline) fn one f1 { - b0(): - return - } - "; - let ssa = ssa.remove_unused_brillig_functions(); - assert_normalized_ssa_equals(ssa, expected); - } -}