feat(perf): Remove unused last loads in mem2reg #5905

Closed
200 changes: 190 additions & 10 deletions

compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs
@@ -56,7 +56,7 @@
//!
//! Repeating this algorithm for each block in the function in program order should result in
//! optimizing out most known loads. However, identifying all aliases correctly has been proven
//! undecidable in general (Landi, 1992). So this pass will not always optimize out all loads
//! that could theoretically be optimized out. This pass can be performed at any time in the
//! SSA optimization pipeline, although it will be more successful the simpler the program's CFG is.
//! This pass is currently performed several times to enable other passes - most notably being
@@ -111,12 +111,18 @@
/// Load and Store instructions that should be removed at the end of the pass.
///
/// We avoid removing individual instructions as we go since removing elements
/// from the middle of Vecs many times will be slower than a single call to `retain`.
instructions_to_remove: BTreeSet<InstructionId>,

/// Track a value's last load across all blocks.
/// If a value is not used in any more loads, we can remove the last store to that value.
last_loads: HashMap<ValueId, (InstructionId, BasicBlockId)>,
last_loads: HashMap<ValueId, (InstructionId, BasicBlockId, u32)>,

/// Track whether a load result was used across all blocks.
load_results: HashMap<ValueId, PerFuncLoadResultContext>,

/// Track whether a reference was passed into another function call.
stores_used_in_calls: HashMap<ValueId, Vec<(InstructionId, BasicBlockId)>>,

/// Flag for tracking whether we had to perform a re-load as part of the Brillig CoW optimization.
/// Stores made as part of this optimization should not be removed.
@@ -138,6 +144,19 @@
inside_rc_reload: Option<bool>,
}

#[derive(Debug, Clone)]
struct PerFuncLoadResultContext {
load_counter: u32,
load_instruction: InstructionId,
instructions_using_result: Vec<(InstructionId, BasicBlockId)>,
}

impl PerFuncLoadResultContext {
fn new(load_instruction: InstructionId) -> Self {
Self { load_counter: 0, load_instruction, instructions_using_result: vec![] }
}
}

impl<'f> PerFunctionContext<'f> {
fn new(function: &'f mut Function) -> Self {
let cfg = ControlFlowGraph::with_function(function);
@@ -150,7 +169,9 @@
blocks: BTreeMap::new(),
instructions_to_remove: BTreeSet::new(),
last_loads: HashMap::default(),
load_results: HashMap::default(),
inside_rc_reload: None,
stores_used_in_calls: HashMap::default(),
}
}

@@ -168,6 +189,27 @@
self.analyze_block(block, references);
}

let mut loads_removed = HashMap::default();
for (_, PerFuncLoadResultContext { load_counter, load_instruction, .. }) in
self.load_results.iter()
{
let Instruction::Load { address } = self.inserter.function.dfg[*load_instruction]
else {
panic!("Should only have a load instruction here");
};

if *load_counter == 0 {
if let Some(counter) = loads_removed.get_mut(&address) {
*counter += 1;
} else {
loads_removed.insert(address, 1);
}

self.instructions_to_remove.insert(*load_instruction);
}
}

let mut not_removed_stores: HashMap<ValueId, (InstructionId, u32)> = HashMap::default();
// If we never load from an address within a function we can remove all stores to that address.
// This rule does not apply to reference parameters, which we must also check for before removing these stores.
for (block_id, block) in self.blocks.iter() {
@@ -189,15 +231,103 @@
let last_load_not_in_return = self
.last_loads
.get(store_address)
.map(|(_, last_load_block)| *last_load_block != *block_id)
.map(|(_, last_load_block, _)| *last_load_block != *block_id)
.unwrap_or(true);
!is_return_value && last_load_not_in_return
} else if let (Some((_, _, last_loads_counter)), Some(loads_removed_counter)) =
(self.last_loads.get(store_address), loads_removed.get(store_address))
{
*last_loads_counter == *loads_removed_counter
} else {
self.last_loads.get(store_address).is_none()
};

if remove_load && !is_reference_param {
let is_not_used_in_reference_param =
self.stores_used_in_calls.get(store_address).is_none();
if remove_load && !is_reference_param && is_not_used_in_reference_param {
self.instructions_to_remove.insert(*store_instruction);
if let Some((_, counter)) = not_removed_stores.get_mut(store_address) {
*counter -= 1;
}
} else if let Some((_, counter)) = not_removed_stores.get_mut(store_address) {
*counter += 1;
} else {
not_removed_stores.insert(*store_address, (*store_instruction, 1));
}
}
}

self.load_results.retain(|_, PerFuncLoadResultContext { load_instruction, .. }| {
let Instruction::Load { address } = self.inserter.function.dfg[*load_instruction]
else {
panic!("Should only have a load instruction here");
};
not_removed_stores.contains_key(&address)
});

let mut new_instructions = HashMap::default();
for (store_address, (store_instruction, store_counter)) in not_removed_stores {
let Instruction::Store { value, .. } = self.inserter.function.dfg[store_instruction]
else {
panic!("Should only have a store instruction");
};

if store_counter != 0 {
continue;
}
self.instructions_to_remove.insert(store_instruction);

if let (Some((_, _, last_loads_counter)), Some(loads_removed_counter)) =
(self.last_loads.get(&store_address), loads_removed.get(&store_address))
{
if *last_loads_counter < *loads_removed_counter {
panic!("The number of loads removed should not be more than all loads");
}
}

for (
result,
PerFuncLoadResultContext {
load_counter,
load_instruction,
instructions_using_result,
},
) in self.load_results.iter()
{
let Instruction::Load { address } = self.inserter.function.dfg[*load_instruction]
else {
panic!("Should only have a load instruction here");
};
if address != store_address {
continue;
}

if *load_counter > 0 {
self.inserter.map_value(*result, value);
for (instruction, block_id) in instructions_using_result {
let new_instruction =
self.inserter.push_instruction(*instruction, *block_id);
if let Some(new_instruction) = new_instruction {
new_instructions
.insert((*instruction, block_id), Some(new_instruction));
} else {
new_instructions.insert((*instruction, block_id), None);
}
}

self.instructions_to_remove.insert(*load_instruction);
}
}
}

// Re-assign or delete any mapped instructions after the final loads were removed.
for ((old_instruction, block_id), new_instruction) in new_instructions {
let instructions = self.inserter.function.dfg[*block_id].instructions_mut();
if let Some(index) = instructions.iter().position(|v| *v == old_instruction) {
if let Some(new_instruction) = new_instruction {
instructions[index] = new_instruction;
} else {
instructions.remove(index);
}
}
}
@@ -286,6 +416,16 @@
return;
}

self.inserter.function.dfg[instruction].for_each_value(|value| {
if let Some(PerFuncLoadResultContext {
load_counter, instructions_using_result, ..
}) = self.load_results.get_mut(&value)
{
*load_counter += 1;
instructions_using_result.push((instruction, block_id));
}
});

match &self.inserter.function.dfg[instruction] {
Instruction::Load { address } => {
let address = self.inserter.function.dfg.resolve(*address);
@@ -304,7 +444,15 @@
references.aliases.insert(Expression::Other(result), AliasSet::known(result));
references.set_known_value(result, address);

self.last_loads.insert(address, (instruction, block_id));
self.load_results.insert(result, PerFuncLoadResultContext::new(instruction));

let load_counter =
if let Some((_, _, load_counter)) = self.last_loads.get(&address) {
*load_counter + 1
} else {
1
};
self.last_loads.insert(address, (instruction, block_id, load_counter));
}
}
Instruction::Store { address, value } => {
@@ -317,17 +465,29 @@
// function calls in-between, we can remove the previous store.
if let Some(last_store) = references.last_stores.get(&address) {
self.instructions_to_remove.insert(*last_store);
if let Some(PerFuncLoadResultContext { load_counter, .. }) =
self.load_results.get_mut(&value)
{
*load_counter -= 1;
}
}

let known_value = references.get_known_value(value);
if let Some(known_value) = known_value {
let known_value_is_address = known_value == address;
if let Some(from_rc) = self.inside_rc_reload {
if known_value_is_address && !from_rc {
if known_value_is_address {
if let Some(from_rc) = self.inside_rc_reload {
if !from_rc {
self.instructions_to_remove.insert(instruction);
}
} else {
self.instructions_to_remove.insert(instruction);
}
} else if known_value_is_address {
self.instructions_to_remove.insert(instruction);
if let Some(PerFuncLoadResultContext { load_counter, .. }) =
self.load_results.get_mut(&value)
{
*load_counter -= 1;
}
}
}

@@ -382,7 +542,17 @@
references.aliases.insert(expression, aliases);
}
}
Instruction::Call { arguments, .. } => self.mark_all_unknown(arguments, references),
Instruction::Call { arguments, .. } => {
for arg in arguments {
if self.inserter.function.dfg.value_is_reference(*arg) {
self.stores_used_in_calls
.entry(*arg)
.or_default()
.push((instruction, block_id));
}
}
self.mark_all_unknown(arguments, references);
}
_ => (),
}

@@ -488,7 +658,17 @@
fn handle_terminator(&mut self, block: BasicBlockId, references: &mut Block) {
self.inserter.map_terminator_in_place(block);

match self.inserter.function.dfg[block].unwrap_terminator() {
let terminator = self.inserter.function.dfg[block].unwrap_terminator();

terminator.for_each_value(|value| {
if let Some(PerFuncLoadResultContext { load_counter, .. }) =
self.load_results.get_mut(&value)
{
*load_counter += 1;
}
});

match terminator {
TerminatorInstruction::JmpIf { .. } => (), // Nothing to do
TerminatorInstruction::Jmp { destination, arguments, .. } => {
let destination_parameters = self.inserter.function.dfg[*destination].parameters();
@@ -7,10 +7,8 @@ struct EnumEmulation {
unconstrained fn main() -> pub Field {
let mut emulated_enum = EnumEmulation { a: Option::some(1), b: Option::none(), c: Option::none() };

// Do a copy to optimize out loads in the loop
let copy_enum = emulated_enum;
@vezenovm (Contributor Author) commented on Sep 3, 2024:

After this PR:
Even after removing this copy optimization, we get a Brillig bytecode size improvement from 55 to 44 opcodes.

Keeping the copy optimization, we get the following final SSA, which produces 24 Brillig opcodes:

After Array Set Optimizations:
brillig fn main f0 {
  b0():
    jmp b1(u32 0)
  b1(v6: u32):
    v27 = eq v6, u32 0
    jmpif v27 then: b2, else: b3
  b2():
    v28 = add v6, u32 1
    jmp b1(v28)
  b3():
    return Field 2
}

@vezenovm (Contributor Author) commented:

Ok this PR now produces the optimal SSA specified above without the copy optimization. Just getting a failure on uhashmap that is the same as #5897.
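
As a minimal sketch of the pattern this change targets (hypothetical SSA for illustration only, not output from this PR; the exact printer syntax may differ): the result of the final load of v0 is never used by any later instruction, so that load can be removed, and once no loads of v0 remain, the store feeding it can be removed as well.

Before mem2reg:

brillig fn example f0 {
  b0():
    v0 = allocate
    store Field 1 at v0
    v1 = load v0
    return Field 2
}

After mem2reg:

brillig fn example f0 {
  b0():
    v0 = allocate
    return Field 2
}

A later dead-instruction-elimination pass would typically clean up the now-unused allocate.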

for _ in 0..1 {
assert_eq(copy_enum.a.unwrap(), 1);
assert_eq(emulated_enum.a.unwrap(), 1);
}

emulated_enum.a = Option::some(2);