diff --git a/compiler/noirc_evaluator/src/ssa/opt/die.rs b/compiler/noirc_evaluator/src/ssa/opt/die.rs index 8d3fa9cc615..5f81a82accc 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/die.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/die.rs @@ -127,8 +127,9 @@ impl Context { .push(instructions_len - instruction_index - 1); } } else { - use Instruction::*; - if matches!(instruction, IncrementRc { .. } | DecrementRc { .. }) { + // We can't remove rc instructions if they're loaded from a reference + // since we'd have no way of knowing whether the reference is still used. + if Self::is_inc_dec_instruction_on_known_array(instruction, &function.dfg) { self.rc_instructions.push((*instruction_id, block_id)); } else { instruction.for_each_value(|value| { @@ -140,7 +141,7 @@ impl Context { rc_tracker.track_inc_rcs_to_remove(*instruction_id, function); } - self.instructions_to_remove.extend(rc_tracker.get_non_mutated_arrays()); + self.instructions_to_remove.extend(rc_tracker.get_non_mutated_arrays(&function.dfg)); self.instructions_to_remove.extend(rc_tracker.rc_pairs_to_remove); // If there are some instructions that might trigger an out of bounds error, @@ -337,6 +338,28 @@ impl Context { inserted_check } + + /// True if this is a `Instruction::IncrementRc` or `Instruction::DecrementRc` + /// operating on an array directly from a `Instruction::MakeArray` or an + /// intrinsic known to return a fresh array. + fn is_inc_dec_instruction_on_known_array( + instruction: &Instruction, + dfg: &DataFlowGraph, + ) -> bool { + use Instruction::*; + if let IncrementRc { value } | DecrementRc { value } = instruction { + if let Value::Instruction { instruction, .. } = &dfg[*value] { + return match &dfg[*instruction] { + MakeArray { .. } => true, + Call { func, .. } => { + matches!(&dfg[*func], Value::Intrinsic(_) | Value::ForeignFunction(_)) + } + _ => false, + }; + } + } + false + } } fn instruction_might_result_in_out_of_bounds( @@ -513,7 +536,7 @@ struct RcTracker { // We also separately track all IncrementRc instructions and all arrays which have been mutably borrowed. // If an array has not been mutably borrowed we can then safely remove all IncrementRc instructions on that array. inc_rcs: HashMap>, - mut_borrowed_arrays: HashSet, + mutated_array_types: HashSet, // The SSA often creates patterns where after simplifications we end up with repeat // IncrementRc instructions on the same value. We track whether the previous instruction was an IncrementRc, // and if the current instruction is also an IncrementRc on the same value we remove the current instruction. @@ -567,25 +590,28 @@ impl RcTracker { } } - self.mut_borrowed_arrays.insert(*array); + self.mutated_array_types.insert(typ); } Instruction::Store { value, .. } => { - // We are very conservative and say that any store of an array value means it has the potential - // to be mutated. This is done due to the tracking of mutable borrows still being per block. + // We are very conservative and say that any store of an array value means that any + // array of that type has the potential to be mutated. This is done due to the + // tracking of mutable borrows still being per block and that we don't have the + // aliasing information from mem2reg. let typ = function.dfg.type_of_value(*value); if matches!(&typ, Type::Array(..) | Type::Slice(..)) { - self.mut_borrowed_arrays.insert(*value); + self.mutated_array_types.insert(typ); } } _ => {} } } - fn get_non_mutated_arrays(&self) -> HashSet { + fn get_non_mutated_arrays(&self, dfg: &DataFlowGraph) -> HashSet { self.inc_rcs .keys() .filter_map(|value| { - if !self.mut_borrowed_arrays.contains(value) { + let typ = dfg.type_of_value(*value); + if !self.mutated_array_types.contains(&typ) { Some(&self.inc_rcs[value]) } else { None @@ -858,4 +884,25 @@ mod test { let ssa = ssa.dead_instruction_elimination(); assert_normalized_ssa_equals(ssa, expected); } + + #[test] + fn does_not_remove_inc_or_dec_rc_of_if_they_are_loaded_from_a_reference() { + let src = " + brillig(inline) fn borrow_mut f0 { + b0(v0: &mut [Field; 3]): + v1 = load v0 -> [Field; 3] + inc_rc v1 // this one shouldn't be removed + v2 = load v0 -> [Field; 3] + inc_rc v2 // this one shouldn't be removed + v3 = load v0 -> [Field; 3] + v6 = array_set v3, index u32 0, value Field 5 + store v6 at v0 + dec_rc v6 + return + } + "; + let ssa = Ssa::from_str(src).unwrap(); + let ssa = ssa.dead_instruction_elimination(); + assert_normalized_ssa_equals(ssa, src); + } } diff --git a/compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs b/compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs index 53a31ae57c1..1750ecb453a 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs @@ -142,6 +142,10 @@ struct PerFunctionContext<'f> { /// instruction that aliased that reference. /// If that store has been set for removal, we can also remove this instruction. aliased_references: HashMap>, + + /// Track whether the last instruction is an inc_rc/dec_rc instruction. + /// If it is we should not remove any repeat last loads. + inside_rc_reload: bool, } impl<'f> PerFunctionContext<'f> { @@ -158,6 +162,7 @@ impl<'f> PerFunctionContext<'f> { last_loads: HashMap::default(), calls_reference_input: HashSet::default(), aliased_references: HashMap::default(), + inside_rc_reload: false, } } @@ -435,7 +440,7 @@ impl<'f> PerFunctionContext<'f> { let result = self.inserter.function.dfg.instruction_results(instruction)[0]; let previous_result = self.inserter.function.dfg.instruction_results(*last_load)[0]; - if *previous_address == address { + if *previous_address == address && !self.inside_rc_reload { self.inserter.map_value(result, previous_result); self.instructions_to_remove.insert(instruction); } @@ -553,6 +558,18 @@ impl<'f> PerFunctionContext<'f> { } _ => (), } + + self.track_rc_reload_state(instruction); + } + + fn track_rc_reload_state(&mut self, instruction: InstructionId) { + match &self.inserter.function.dfg[instruction] { + // We just had an increment or decrement to an array's reference counter + Instruction::IncrementRc { .. } | Instruction::DecrementRc { .. } => { + self.inside_rc_reload = true; + } + _ => self.inside_rc_reload = false, + } } fn contains_references(typ: &Type) -> bool {