From dfa5126f2c65843c34701cacddf2cbcfb0d7ff11 Mon Sep 17 00:00:00 2001 From: jfecher Date: Mon, 18 Mar 2024 13:44:03 -0500 Subject: [PATCH] feat: RC optimization pass (#4560) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # Description ## Problem\* `inc_rc` and `dec_rc` instructions can bloat unconstrained code with unneeded rc changes on otherwise immutable arrays. ## Summary\* Adds an optimization pass to remove `inc_rc vN .. dec_rc vN` pairs as long as there are not `array_set` instructions in the same function which may mutate an array of the same type. ## Additional Context I thought of tracking all inc and dec instructions in the function originally but eventually limited it to finding just those in the function's entry block and exit block respectively. The later is the only place we currently issue dec_rc instructions anyway. This restriction greatly simplifies the code since we do not have to merge intermediate results across several blocks, nor do we have to handle inc/dec in loops. This pass applies to both acir and brillig functions since acir functions can still be called in an unconstrained context. ## Documentation\* Check one: - [x] No documentation needed. - [ ] Documentation included in this PR. - [ ] **[Exceptional Case]** Documentation to be submitted in a separate PR. # PR Checklist\* - [x] I have tested the changes locally. - [x] I have formatted the changes with [Prettier](https://prettier.io/) and/or `cargo fmt` on default settings. --------- Co-authored-by: Álvaro Rodríguez --- compiler/noirc_evaluator/src/ssa.rs | 6 +- .../src/ssa/function_builder/mod.rs | 73 ++-- compiler/noirc_evaluator/src/ssa/opt/mod.rs | 1 + compiler/noirc_evaluator/src/ssa/opt/rc.rs | 327 ++++++++++++++++++ .../src/monomorphization/debug.rs | 4 +- 5 files changed, 379 insertions(+), 32 deletions(-) create mode 100644 compiler/noirc_evaluator/src/ssa/opt/rc.rs diff --git a/compiler/noirc_evaluator/src/ssa.rs b/compiler/noirc_evaluator/src/ssa.rs index 56cb76adbe4..808cf7533c9 100644 --- a/compiler/noirc_evaluator/src/ssa.rs +++ b/compiler/noirc_evaluator/src/ssa.rs @@ -48,6 +48,7 @@ pub(crate) fn optimize_into_acir( let ssa_gen_span_guard = ssa_gen_span.enter(); let ssa = SsaBuilder::new(program, print_ssa_passes, force_brillig_output)? .run_pass(Ssa::defunctionalize, "After Defunctionalization:") + .run_pass(Ssa::remove_paired_rc, "After Removing Paired rc_inc & rc_decs:") .run_pass(Ssa::inline_functions, "After Inlining:") // Run mem2reg with the CFG separated into blocks .run_pass(Ssa::mem2reg, "After Mem2Reg:") @@ -59,10 +60,7 @@ pub(crate) fn optimize_into_acir( // Run mem2reg once more with the flattened CFG to catch any remaining loads/stores .run_pass(Ssa::mem2reg, "After Mem2Reg:") .run_pass(Ssa::fold_constants, "After Constant Folding:") - .run_pass( - Ssa::fold_constants_using_constraints, - "After Constant Folding With Constraint Info:", - ) + .run_pass(Ssa::fold_constants_using_constraints, "After Constraint Folding:") .run_pass(Ssa::dead_instruction_elimination, "After Dead Instruction Elimination:") .finish(); diff --git a/compiler/noirc_evaluator/src/ssa/function_builder/mod.rs b/compiler/noirc_evaluator/src/ssa/function_builder/mod.rs index 2c39c83b342..aa5a7fedd92 100644 --- a/compiler/noirc_evaluator/src/ssa/function_builder/mod.rs +++ b/compiler/noirc_evaluator/src/ssa/function_builder/mod.rs @@ -195,12 +195,9 @@ impl FunctionBuilder { self.call_stack.clone() } - /// Insert a Load instruction at the end of the current block, loading from the given offset - /// of the given address which should point to a previous Allocate instruction. Note that - /// this is limited to loading a single value. Loading multiple values (such as a tuple) - /// will require multiple loads. - /// 'offset' is in units of FieldElements here. So loading the fourth FieldElement stored in - /// an array will have an offset of 3. + /// Insert a Load instruction at the end of the current block, loading from the given address + /// which should point to a previous Allocate instruction. Note that this is limited to loading + /// a single value. Loading multiple values (such as a tuple) will require multiple loads. /// Returns the element that was loaded. pub(crate) fn insert_load(&mut self, address: ValueId, type_to_load: Type) -> ValueId { self.insert_instruction(Instruction::Load { address }, Some(vec![type_to_load])).first() @@ -221,11 +218,9 @@ impl FunctionBuilder { operator: BinaryOp, rhs: ValueId, ) -> ValueId { - assert_eq!( - self.type_of_value(lhs), - self.type_of_value(rhs), - "ICE - Binary instruction operands must have the same type" - ); + let lhs_type = self.type_of_value(lhs); + let rhs_type = self.type_of_value(rhs); + assert_eq!(lhs_type, rhs_type, "ICE - Binary instruction operands must have the same type"); let instruction = Instruction::Binary(Binary { lhs, rhs, operator }); self.insert_instruction(instruction, None).first() } @@ -309,6 +304,18 @@ impl FunctionBuilder { self.insert_instruction(Instruction::ArraySet { array, index, value }, None).first() } + /// Insert an instruction to increment an array's reference count. This only has an effect + /// in unconstrained code where arrays are reference counted and copy on write. + pub(crate) fn insert_inc_rc(&mut self, value: ValueId) { + self.insert_instruction(Instruction::IncrementRc { value }, None); + } + + /// Insert an instruction to decrement an array's reference count. This only has an effect + /// in unconstrained code where arrays are reference counted and copy on write. + pub(crate) fn insert_dec_rc(&mut self, value: ValueId) { + self.insert_instruction(Instruction::DecrementRc { value }, None); + } + /// Terminates the current block with the given terminator instruction fn terminate_block_with(&mut self, terminator: TerminatorInstruction) { self.current_function.dfg.set_block_terminator(self.current_block, terminator); @@ -384,51 +391,65 @@ impl FunctionBuilder { /// within the given value. If the given value is not an array and does not contain /// any arrays, this does nothing. pub(crate) fn increment_array_reference_count(&mut self, value: ValueId) { - self.update_array_reference_count(value, true); + self.update_array_reference_count(value, true, None); } /// Insert instructions to decrement the reference count of any array(s) stored /// within the given value. If the given value is not an array and does not contain /// any arrays, this does nothing. pub(crate) fn decrement_array_reference_count(&mut self, value: ValueId) { - self.update_array_reference_count(value, false); + self.update_array_reference_count(value, false, None); } /// Increment or decrement the given value's reference count if it is an array. /// If it is not an array, this does nothing. Note that inc_rc and dec_rc instructions /// are ignored outside of unconstrained code. - pub(crate) fn update_array_reference_count(&mut self, value: ValueId, increment: bool) { + fn update_array_reference_count( + &mut self, + value: ValueId, + increment: bool, + load_address: Option, + ) { match self.type_of_value(value) { Type::Numeric(_) => (), Type::Function => (), Type::Reference(element) => { if element.contains_an_array() { - let value = self.insert_load(value, element.as_ref().clone()); - self.increment_array_reference_count(value); + let reference = value; + let value = self.insert_load(reference, element.as_ref().clone()); + self.update_array_reference_count(value, increment, Some(reference)); } } typ @ Type::Array(..) | typ @ Type::Slice(..) => { // If there are nested arrays or slices, we wait until ArrayGet // is issued to increment the count of that array. - let instruction = if increment { - Instruction::IncrementRc { value } - } else { - Instruction::DecrementRc { value } + let update_rc = |this: &mut Self, value| { + if increment { + this.insert_inc_rc(value); + } else { + this.insert_dec_rc(value); + } }; - self.insert_instruction(instruction, None); + + update_rc(self, value); + let dfg = &self.current_function.dfg; // This is a bit odd, but in brillig the inc_rc instruction operates on // a copy of the array's metadata, so we need to re-store a loaded array // even if there have been no other changes to it. - if let Value::Instruction { instruction, .. } = &self.current_function.dfg[value] { - let instruction = &self.current_function.dfg[*instruction]; + if let Some(address) = load_address { + // If we already have a load from the Type::Reference case, avoid inserting + // another load and rc update. + self.insert_store(address, value); + } else if let Value::Instruction { instruction, .. } = &dfg[value] { + let instruction = &dfg[*instruction]; if let Instruction::Load { address } = instruction { // We can't re-use `value` in case the original address was stored // to again in the meantime. So introduce another load. let address = *address; - let value = self.insert_load(address, typ); - self.insert_instruction(Instruction::IncrementRc { value }, None); - self.insert_store(address, value); + let new_load = self.insert_load(address, typ); + update_rc(self, new_load); + self.insert_store(address, new_load); } } } diff --git a/compiler/noirc_evaluator/src/ssa/opt/mod.rs b/compiler/noirc_evaluator/src/ssa/opt/mod.rs index a315695f7db..8f98b3fb17f 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/mod.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/mod.rs @@ -12,6 +12,7 @@ mod die; pub(crate) mod flatten_cfg; mod inlining; mod mem2reg; +mod rc; mod remove_bit_shifts; mod simplify_cfg; mod unrolling; diff --git a/compiler/noirc_evaluator/src/ssa/opt/rc.rs b/compiler/noirc_evaluator/src/ssa/opt/rc.rs new file mode 100644 index 00000000000..4766bc3e8d2 --- /dev/null +++ b/compiler/noirc_evaluator/src/ssa/opt/rc.rs @@ -0,0 +1,327 @@ +use std::collections::{HashMap, HashSet}; + +use crate::ssa::{ + ir::{ + basic_block::BasicBlockId, + function::Function, + instruction::{Instruction, InstructionId, TerminatorInstruction}, + types::Type, + value::ValueId, + }, + ssa_gen::Ssa, +}; + +impl Ssa { + /// This pass removes `inc_rc` and `dec_rc` instructions + /// as long as there are no `array_set` instructions to an array + /// of the same type in between. + /// + /// Note that this pass is very conservative since the array_set + /// instruction does not need to be to the same array. This is because + /// the given array may alias another array (e.g. function parameters or + /// a `load`ed array from a reference). + #[tracing::instrument(level = "trace", skip(self))] + pub(crate) fn remove_paired_rc(mut self) -> Ssa { + for function in self.functions.values_mut() { + remove_paired_rc(function); + } + self + } +} + +#[derive(Default)] +struct Context { + // All inc_rc instructions encountered without a corresponding dec_rc. + // These are only searched for in the first block of a function. + // + // The type of the array being operated on is recorded. + // If an array_set to that array type is encountered, that is also recorded. + inc_rcs: HashMap>, +} + +struct IncRc { + id: InstructionId, + array: ValueId, + possibly_mutated: bool, +} + +/// This function is very simplistic for now. It takes advantage of the fact that dec_rc +/// instructions are currently issued only at the end of a function for parameters and will +/// only check the first and last block for inc & dec rc instructions to be removed. The rest +/// of the function is still checked for array_set instructions. +/// +/// This restriction lets this function largely ignore merging intermediate results from other +/// blocks and handling loops. +fn remove_paired_rc(function: &mut Function) { + // `dec_rc` is only issued for parameters currently so we can speed things + // up a bit by skipping any functions without them. + if !contains_array_parameter(function) { + return; + } + + let mut context = Context::default(); + + context.find_rcs_in_entry_block(function); + context.scan_for_array_sets(function); + let to_remove = context.find_rcs_to_remove(function); + remove_instructions(to_remove, function); +} + +fn contains_array_parameter(function: &mut Function) -> bool { + let mut parameters = function.parameters().iter(); + parameters.any(|parameter| function.dfg.type_of_value(*parameter).contains_an_array()) +} + +impl Context { + fn find_rcs_in_entry_block(&mut self, function: &Function) { + let entry = function.entry_block(); + + for instruction in function.dfg[entry].instructions() { + if let Instruction::IncrementRc { value } = &function.dfg[*instruction] { + let typ = function.dfg.type_of_value(*value); + + // We assume arrays aren't mutated until we find an array_set + let inc_rc = IncRc { id: *instruction, array: *value, possibly_mutated: false }; + self.inc_rcs.entry(typ).or_default().push(inc_rc); + } + } + } + + /// Find each array_set instruction in the function and mark any arrays used + /// by the inc_rc instructions as possibly mutated if they're the same type. + fn scan_for_array_sets(&mut self, function: &Function) { + for block in function.reachable_blocks() { + for instruction in function.dfg[block].instructions() { + if let Instruction::ArraySet { array, .. } = function.dfg[*instruction] { + let typ = function.dfg.type_of_value(array); + if let Some(inc_rcs) = self.inc_rcs.get_mut(&typ) { + for inc_rc in inc_rcs { + inc_rc.possibly_mutated = true; + } + } + } + } + } + } + + /// Find each dec_rc instruction and if the most recent inc_rc instruction for the same value + /// is not possibly mutated, then we can remove them both. Returns each such pair. + fn find_rcs_to_remove(&mut self, function: &Function) -> HashSet { + let last_block = Self::find_last_block(function); + let mut to_remove = HashSet::new(); + + for instruction in function.dfg[last_block].instructions() { + if let Instruction::DecrementRc { value } = &function.dfg[*instruction] { + if let Some(inc_rc) = self.pop_rc_for(*value, function) { + if !inc_rc.possibly_mutated { + to_remove.insert(inc_rc.id); + to_remove.insert(*instruction); + } + } + } + } + + to_remove + } + + /// Finds the block of the function with the Return instruction + fn find_last_block(function: &Function) -> BasicBlockId { + for block in function.reachable_blocks() { + if matches!( + function.dfg[block].terminator(), + Some(TerminatorInstruction::Return { .. }) + ) { + return block; + } + } + + unreachable!("SSA Function {} has no reachable return instruction!", function.id()) + } + + /// Finds and pops the IncRc for the given array value if possible. + fn pop_rc_for(&mut self, value: ValueId, function: &Function) -> Option { + let typ = function.dfg.type_of_value(value); + + let rcs = self.inc_rcs.get_mut(&typ)?; + let position = rcs.iter().position(|inc_rc| inc_rc.array == value)?; + + Some(rcs.remove(position)) + } +} + +fn remove_instructions(to_remove: HashSet, function: &mut Function) { + if !to_remove.is_empty() { + for block in function.reachable_blocks() { + function.dfg[block] + .instructions_mut() + .retain(|instruction| !to_remove.contains(instruction)); + } + } +} + +#[cfg(test)] +mod test { + use std::rc::Rc; + + use crate::ssa::{ + function_builder::FunctionBuilder, + ir::{ + basic_block::BasicBlockId, dfg::DataFlowGraph, function::RuntimeType, + instruction::Instruction, map::Id, types::Type, + }, + }; + + fn count_inc_rcs(block: BasicBlockId, dfg: &DataFlowGraph) -> usize { + dfg[block] + .instructions() + .iter() + .filter(|instruction_id| { + matches!(dfg[**instruction_id], Instruction::IncrementRc { .. }) + }) + .count() + } + + fn count_dec_rcs(block: BasicBlockId, dfg: &DataFlowGraph) -> usize { + dfg[block] + .instructions() + .iter() + .filter(|instruction_id| { + matches!(dfg[**instruction_id], Instruction::DecrementRc { .. }) + }) + .count() + } + + #[test] + fn single_block_fn_return_array() { + // This is the output for the program with a function: + // unconstrained fn foo(x: [Field; 2]) -> [[Field; 2]; 1] { + // [array] + // } + // + // fn foo { + // b0(v0: [Field; 2]): + // inc_rc v0 + // inc_rc v0 + // dec_rc v0 + // return [v0] + // } + let main_id = Id::test_new(0); + let mut builder = FunctionBuilder::new("foo".into(), main_id, RuntimeType::Brillig); + + let inner_array_type = Type::Array(Rc::new(vec![Type::field()]), 2); + let v0 = builder.add_parameter(inner_array_type.clone()); + + builder.insert_inc_rc(v0); + builder.insert_inc_rc(v0); + builder.insert_dec_rc(v0); + + let outer_array_type = Type::Array(Rc::new(vec![inner_array_type]), 1); + let array = builder.array_constant(vec![v0].into(), outer_array_type); + builder.terminate_with_return(vec![array]); + + let ssa = builder.finish().remove_paired_rc(); + let main = ssa.main(); + let entry = main.entry_block(); + + assert_eq!(count_inc_rcs(entry, &main.dfg), 1); + assert_eq!(count_dec_rcs(entry, &main.dfg), 0); + } + + #[test] + fn single_block_mutation() { + // fn mutator(mut array: [Field; 2]) { + // array[0] = 5; + // } + // + // fn mutator { + // b0(v0: [Field; 2]): + // v1 = allocate + // store v0 at v1 + // inc_rc v0 + // v2 = load v1 + // v7 = array_set v2, index u64 0, value Field 5 + // store v7 at v1 + // dec_rc v0 + // return + // } + let main_id = Id::test_new(0); + let mut builder = FunctionBuilder::new("mutator".into(), main_id, RuntimeType::Acir); + + let array_type = Type::Array(Rc::new(vec![Type::field()]), 2); + let v0 = builder.add_parameter(array_type.clone()); + + let v1 = builder.insert_allocate(array_type.clone()); + builder.insert_store(v1, v0); + builder.insert_inc_rc(v0); + let v2 = builder.insert_load(v1, array_type); + + let zero = builder.numeric_constant(0u128, Type::unsigned(64)); + let five = builder.field_constant(5u128); + let v7 = builder.insert_array_set(v2, zero, five); + + builder.insert_store(v1, v7); + builder.insert_dec_rc(v0); + builder.terminate_with_return(vec![]); + + let ssa = builder.finish().remove_paired_rc(); + let main = ssa.main(); + let entry = main.entry_block(); + + // No changes, the array is possibly mutated + assert_eq!(count_inc_rcs(entry, &main.dfg), 1); + assert_eq!(count_dec_rcs(entry, &main.dfg), 1); + } + + // Similar to single_block_mutation but for a function which + // uses a mutable reference parameter. + #[test] + fn single_block_mutation_through_reference() { + // fn mutator2(array: &mut [Field; 2]) { + // array[0] = 5; + // } + // + // fn mutator2 { + // b0(v0: &mut [Field; 2]): + // v1 = load v0 + // inc_rc v1 + // store v1 at v0 + // v2 = load v0 + // v7 = array_set v2, index u64 0, value Field 5 + // store v7 at v0 + // v8 = load v0 + // dec_rc v8 + // store v8 at v0 + // return + // } + let main_id = Id::test_new(0); + let mut builder = FunctionBuilder::new("mutator2".into(), main_id, RuntimeType::Acir); + + let array_type = Type::Array(Rc::new(vec![Type::field()]), 2); + let reference_type = Type::Reference(Rc::new(array_type.clone())); + + let v0 = builder.add_parameter(reference_type); + + let v1 = builder.insert_load(v0, array_type.clone()); + builder.insert_inc_rc(v1); + builder.insert_store(v0, v1); + + let v2 = builder.insert_load(v1, array_type.clone()); + let zero = builder.numeric_constant(0u128, Type::unsigned(64)); + let five = builder.field_constant(5u128); + let v7 = builder.insert_array_set(v2, zero, five); + + builder.insert_store(v0, v7); + let v8 = builder.insert_load(v0, array_type); + builder.insert_dec_rc(v8); + builder.insert_store(v0, v8); + builder.terminate_with_return(vec![]); + + let ssa = builder.finish().remove_paired_rc(); + let main = ssa.main(); + let entry = main.entry_block(); + + // No changes, the array is possibly mutated + assert_eq!(count_inc_rcs(entry, &main.dfg), 1); + assert_eq!(count_dec_rcs(entry, &main.dfg), 1); + } +} diff --git a/compiler/noirc_frontend/src/monomorphization/debug.rs b/compiler/noirc_frontend/src/monomorphization/debug.rs index cf4e0ab792e..3a03177f8ec 100644 --- a/compiler/noirc_frontend/src/monomorphization/debug.rs +++ b/compiler/noirc_frontend/src/monomorphization/debug.rs @@ -195,8 +195,8 @@ fn element_type_at_index(ptype: &PrintableType, i: usize) -> &PrintableType { PrintableType::Tuple { types } => &types[i], PrintableType::Struct { name: _name, fields } => &fields[i].1, PrintableType::String { length: _length } => &PrintableType::UnsignedInteger { width: 8 }, - _ => { - panic!["expected type with sub-fields, found terminal type"] + other => { + panic!["expected type with sub-fields, found terminal type: {other:?}"] } } }