From 69bb64fa34667810e96ea85c7594595522ccdce1 Mon Sep 17 00:00:00 2001 From: Tom French <15848336+TomAFrench@users.noreply.github.com> Date: Thu, 28 Nov 2024 16:02:10 +0000 Subject: [PATCH] chore: deduplicate constants across blocks (#9972) Please read [contributing guidelines](CONTRIBUTING.md) and remove this line. --------- Co-authored-by: Ary Borenszweig --- .../src/ssa/opt/constant_folding.rs | 259 +++++++++++++----- 1 file changed, 188 insertions(+), 71 deletions(-) diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs index 9f55e69868c..9ee9a52b5ad 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs @@ -19,7 +19,7 @@ //! //! This is the only pass which removes duplicated pure [`Instruction`]s however and so is needed when //! different blocks are merged, i.e. after the [`flatten_cfg`][super::flatten_cfg] pass. -use std::collections::HashSet; +use std::collections::{HashSet, VecDeque}; use acvm::{acir::AcirField, FieldElement}; use iter_extended::vecmap; @@ -28,6 +28,7 @@ use crate::ssa::{ ir::{ basic_block::BasicBlockId, dfg::{DataFlowGraph, InsertInstructionResult}, + dom::DominatorTree, function::Function, instruction::{Instruction, InstructionId}, types::Type, @@ -67,10 +68,10 @@ impl Function { /// The structure of this pass is simple: /// Go through each block and re-insert all instructions. pub(crate) fn constant_fold(&mut self, use_constraint_info: bool) { - let mut context = Context { use_constraint_info, ..Default::default() }; - context.block_queue.push(self.entry_block()); + let mut context = Context::new(self, use_constraint_info); + context.block_queue.push_back(self.entry_block()); - while let Some(block) = context.block_queue.pop() { + while let Some(block) = context.block_queue.pop_front() { if context.visited_blocks.contains(&block) { continue; } @@ -81,34 +82,62 @@ impl Function { } } -#[derive(Default)] struct Context { use_constraint_info: bool, /// Maps pre-folded ValueIds to the new ValueIds obtained by re-inserting the instruction. visited_blocks: HashSet, - block_queue: Vec, + block_queue: VecDeque, + + /// Contains sets of values which are constrained to be equivalent to each other. + /// + /// The mapping's structure is `side_effects_enabled_var => (constrained_value => [(block, simplified_value)])`. + /// + /// We partition the maps of constrained values according to the side-effects flag at the point + /// at which the values are constrained. This prevents constraints which are only sometimes enforced + /// being used to modify the rest of the program. + /// + /// We also keep track of how a value was simplified to other values per block. That is, + /// a same ValueId could have been simplified to one value in one block and to another value + /// in another block. + constraint_simplification_mappings: + HashMap>>, + + // Cache of instructions without any side-effects along with their outputs. + cached_instruction_results: InstructionResultCache, + + dom: DominatorTree, } /// HashMap from (Instruction, side_effects_enabled_var) to the results of the instruction. /// Stored as a two-level map to avoid cloning Instructions during the `.get` call. -type InstructionResultCache = HashMap, Vec>>; +/// +/// In addition to each result, the original BasicBlockId is stored as well. This allows us +/// to deduplicate instructions across blocks as long as the new block dominates the original. +type InstructionResultCache = HashMap, ResultCache>>; + +/// Records the results of all duplicate [`Instruction`]s along with the blocks in which they sit. +/// +/// For more information see [`InstructionResultCache`]. +#[derive(Default)] +struct ResultCache { + results: Vec<(BasicBlockId, Vec)>, +} impl Context { + fn new(function: &Function, use_constraint_info: bool) -> Self { + Self { + use_constraint_info, + visited_blocks: Default::default(), + block_queue: Default::default(), + constraint_simplification_mappings: Default::default(), + cached_instruction_results: Default::default(), + dom: DominatorTree::with_function(function), + } + } + fn fold_constants_in_block(&mut self, function: &mut Function, block: BasicBlockId) { let instructions = function.dfg[block].take_instructions(); - // Cache of instructions without any side-effects along with their outputs. - let mut cached_instruction_results = HashMap::default(); - - // Contains sets of values which are constrained to be equivalent to each other. - // - // The mapping's structure is `side_effects_enabled_var => (constrained_value => simplified_value)`. - // - // We partition the maps of constrained values according to the side-effects flag at the point - // at which the values are constrained. This prevents constraints which are only sometimes enforced - // being used to modify the rest of the program. - let mut constraint_simplification_mappings: HashMap> = - HashMap::default(); let mut side_effects_enabled_var = function.dfg.make_constant(FieldElement::one(), Type::bool()); @@ -117,8 +146,6 @@ impl Context { &mut function.dfg, block, instruction_id, - &mut cached_instruction_results, - &mut constraint_simplification_mappings, &mut side_effects_enabled_var, ); } @@ -126,22 +153,26 @@ impl Context { } fn fold_constants_into_instruction( - &self, + &mut self, dfg: &mut DataFlowGraph, block: BasicBlockId, id: InstructionId, - instruction_result_cache: &mut InstructionResultCache, - constraint_simplification_mappings: &mut HashMap>, side_effects_enabled_var: &mut ValueId, ) { let constraint_simplification_mapping = - constraint_simplification_mappings.entry(*side_effects_enabled_var).or_default(); - let instruction = Self::resolve_instruction(id, dfg, constraint_simplification_mapping); + self.constraint_simplification_mappings.get(side_effects_enabled_var); + let instruction = Self::resolve_instruction( + id, + block, + dfg, + &mut self.dom, + constraint_simplification_mapping, + ); let old_results = dfg.instruction_results(id).to_vec(); // If a copy of this instruction exists earlier in the block, then reuse the previous results. if let Some(cached_results) = - Self::get_cached(dfg, instruction_result_cache, &instruction, *side_effects_enabled_var) + self.get_cached(dfg, &instruction, *side_effects_enabled_var, block) { Self::replace_result_ids(dfg, &old_results, cached_results); return; @@ -156,9 +187,8 @@ impl Context { instruction.clone(), new_results, dfg, - instruction_result_cache, - constraint_simplification_mapping, *side_effects_enabled_var, + block, ); // If we just inserted an `Instruction::EnableSideEffectsIf`, we need to update `side_effects_enabled_var` @@ -171,8 +201,10 @@ impl Context { /// Fetches an [`Instruction`] by its [`InstructionId`] and fully resolves its inputs. fn resolve_instruction( instruction_id: InstructionId, + block: BasicBlockId, dfg: &DataFlowGraph, - constraint_simplification_mapping: &HashMap, + dom: &mut DominatorTree, + constraint_simplification_mapping: Option<&HashMap>>, ) -> Instruction { let instruction = dfg[instruction_id].clone(); @@ -183,19 +215,30 @@ impl Context { // constraints to the cache. fn resolve_cache( dfg: &DataFlowGraph, - cache: &HashMap, + dom: &mut DominatorTree, + cache: Option<&HashMap>>, value_id: ValueId, + block: BasicBlockId, ) -> ValueId { let resolved_id = dfg.resolve(value_id); - match cache.get(&resolved_id) { - Some(cached_value) => resolve_cache(dfg, cache, *cached_value), - None => resolved_id, + let Some(cached_values) = cache.and_then(|cache| cache.get(&resolved_id)) else { + return resolved_id; + }; + + for (cached_block, cached_value) in cached_values { + // We can only use the simplified value if it was simplified in a block that dominates the current one + if dom.dominates(*cached_block, block) { + return resolve_cache(dfg, dom, cache, *cached_value, block); + } } + + resolved_id } // Resolve any inputs to ensure that we're comparing like-for-like instructions. - instruction - .map_values(|value_id| resolve_cache(dfg, constraint_simplification_mapping, value_id)) + instruction.map_values(|value_id| { + resolve_cache(dfg, dom, constraint_simplification_mapping, value_id, block) + }) } /// Pushes a new [`Instruction`] into the [`DataFlowGraph`] which applies any optimizations @@ -229,39 +272,23 @@ impl Context { } fn cache_instruction( - &self, + &mut self, instruction: Instruction, instruction_results: Vec, dfg: &DataFlowGraph, - instruction_result_cache: &mut InstructionResultCache, - constraint_simplification_mapping: &mut HashMap, side_effects_enabled_var: ValueId, + block: BasicBlockId, ) { if self.use_constraint_info { // If the instruction was a constraint, then create a link between the two `ValueId`s // to map from the more complex to the simpler value. if let Instruction::Constrain(lhs, rhs, _) = instruction { // These `ValueId`s should be fully resolved now. - match (&dfg[lhs], &dfg[rhs]) { - // Ignore trivial constraints - (Value::NumericConstant { .. }, Value::NumericConstant { .. }) => (), - - // Prefer replacing with constants where possible. - (Value::NumericConstant { .. }, _) => { - constraint_simplification_mapping.insert(rhs, lhs); - } - (_, Value::NumericConstant { .. }) => { - constraint_simplification_mapping.insert(lhs, rhs); - } - // Otherwise prefer block parameters over instruction results. - // This is as block parameters are more likely to be a single witness rather than a full expression. - (Value::Param { .. }, Value::Instruction { .. }) => { - constraint_simplification_mapping.insert(rhs, lhs); - } - (Value::Instruction { .. }, Value::Param { .. }) => { - constraint_simplification_mapping.insert(lhs, rhs); - } - (_, _) => (), + if let Some((complex, simple)) = simplify(dfg, lhs, rhs) { + self.get_constraint_map(side_effects_enabled_var) + .entry(complex) + .or_default() + .push((block, simple)); } } } @@ -273,13 +300,22 @@ impl Context { self.use_constraint_info && instruction.requires_acir_gen_predicate(dfg); let predicate = use_predicate.then_some(side_effects_enabled_var); - instruction_result_cache + self.cached_instruction_results .entry(instruction) .or_default() - .insert(predicate, instruction_results); + .entry(predicate) + .or_default() + .cache(block, instruction_results); } } + fn get_constraint_map( + &mut self, + side_effects_enabled_var: ValueId, + ) -> &mut HashMap> { + self.constraint_simplification_mappings.entry(side_effects_enabled_var).or_default() + } + /// Replaces a set of [`ValueId`]s inside the [`DataFlowGraph`] with another. fn replace_result_ids( dfg: &mut DataFlowGraph, @@ -292,22 +328,59 @@ impl Context { } fn get_cached<'a>( + &'a mut self, dfg: &DataFlowGraph, - instruction_result_cache: &'a mut InstructionResultCache, instruction: &Instruction, side_effects_enabled_var: ValueId, - ) -> Option<&'a Vec> { - let results_for_instruction = instruction_result_cache.get(instruction); + block: BasicBlockId, + ) -> Option<&'a [ValueId]> { + let results_for_instruction = self.cached_instruction_results.get(instruction)?; - // See if there's a cached version with no predicate first - if let Some(results) = results_for_instruction.and_then(|map| map.get(&None)) { - return Some(results); - } + let predicate = self.use_constraint_info && instruction.requires_acir_gen_predicate(dfg); + let predicate = predicate.then_some(side_effects_enabled_var); + + results_for_instruction.get(&predicate)?.get(block, &mut self.dom) + } +} - let predicate = - instruction.requires_acir_gen_predicate(dfg).then_some(side_effects_enabled_var); +impl ResultCache { + /// Records that an `Instruction` in block `block` produced the result values `results`. + fn cache(&mut self, block: BasicBlockId, results: Vec) { + self.results.push((block, results)); + } - results_for_instruction.and_then(|map| map.get(&predicate)) + /// Returns a set of [`ValueId`]s produced from a copy of this [`Instruction`] which sits + /// within a block which dominates `block`. + /// + /// We require that the cached instruction's block dominates `block` in order to avoid + /// cycles causing issues (e.g. two instructions being replaced with the results of each other + /// such that neither instruction exists anymore.) + fn get(&self, block: BasicBlockId, dom: &mut DominatorTree) -> Option<&[ValueId]> { + for (origin_block, results) in &self.results { + if dom.dominates(*origin_block, block) { + return Some(results); + } + } + None + } +} + +/// Check if one expression is simpler than the other. +/// Returns `Some((complex, simple))` if a simplification was found, otherwise `None`. +/// Expects the `ValueId`s to be fully resolved. +fn simplify(dfg: &DataFlowGraph, lhs: ValueId, rhs: ValueId) -> Option<(ValueId, ValueId)> { + match (&dfg[lhs], &dfg[rhs]) { + // Ignore trivial constraints + (Value::NumericConstant { .. }, Value::NumericConstant { .. }) => None, + + // Prefer replacing with constants where possible. + (Value::NumericConstant { .. }, _) => Some((rhs, lhs)), + (_, Value::NumericConstant { .. }) => Some((lhs, rhs)), + // Otherwise prefer block parameters over instruction results. + // This is as block parameters are more likely to be a single witness rather than a full expression. + (Value::Param { .. }, Value::Instruction { .. }) => Some((rhs, lhs)), + (Value::Instruction { .. }, Value::Param { .. }) => Some((lhs, rhs)), + (_, _) => None, } } @@ -673,4 +746,48 @@ mod test { let ending_instruction_count = instructions.len(); assert_eq!(ending_instruction_count, 1); } + + #[test] + fn deduplicate_across_blocks() { + // fn main f0 { + // b0(v0: u1): + // v1 = not v0 + // jmp b1() + // b1(): + // v2 = not v0 + // return v2 + // } + let main_id = Id::test_new(0); + + // Compiling main + let mut builder = FunctionBuilder::new("main".into(), main_id); + let b1 = builder.insert_block(); + + let v0 = builder.add_parameter(Type::bool()); + let _v1 = builder.insert_not(v0); + builder.terminate_with_jmp(b1, Vec::new()); + + builder.switch_to_block(b1); + let v2 = builder.insert_not(v0); + builder.terminate_with_return(vec![v2]); + + let ssa = builder.finish(); + let main = ssa.main(); + assert_eq!(main.dfg[main.entry_block()].instructions().len(), 1); + assert_eq!(main.dfg[b1].instructions().len(), 1); + + // Expected output: + // + // fn main f0 { + // b0(v0: u1): + // v1 = not v0 + // jmp b1() + // b1(): + // return v1 + // } + let ssa = ssa.fold_constants_using_constraints(); + let main = ssa.main(); + assert_eq!(main.dfg[main.entry_block()].instructions().len(), 1); + assert_eq!(main.dfg[b1].instructions().len(), 0); + } }