Skip to content

Commit

Permalink
Merge 0a91844 into 0db5610
Browse files Browse the repository at this point in the history
  • Loading branch information
vezenovm authored Sep 4, 2024
2 parents 0db5610 + 0a91844 commit 85faf64
Show file tree
Hide file tree
Showing 2 changed files with 191 additions and 13 deletions.
200 changes: 190 additions & 10 deletions compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,13 @@ struct PerFunctionContext<'f> {

/// Track a value's last load across all blocks.
/// If a value is not used in anymore loads we can remove the last store to that value.
last_loads: HashMap<ValueId, (InstructionId, BasicBlockId)>,
last_loads: HashMap<ValueId, (InstructionId, BasicBlockId, u32)>,

/// Track whether a load result was used across all blocks.
load_results: HashMap<ValueId, PerFuncLoadResultContext>,

/// Track whether a reference was passed into another entry point
stores_used_in_calls: HashMap<ValueId, Vec<(InstructionId, BasicBlockId)>>,

/// Flag for tracking whether we had to perform a re-load as part of the Brillig CoW optimization.
/// Stores made as part of this optimization should not be removed.
Expand All @@ -138,6 +144,19 @@ struct PerFunctionContext<'f> {
inside_rc_reload: Option<bool>,
}

#[derive(Debug, Clone)]
struct PerFuncLoadResultContext {
load_counter: u32,
load_instruction: InstructionId,
instructions_using_result: Vec<(InstructionId, BasicBlockId)>,
}

impl PerFuncLoadResultContext {
fn new(load_instruction: InstructionId) -> Self {
Self { load_counter: 0, load_instruction, instructions_using_result: vec![] }
}
}

impl<'f> PerFunctionContext<'f> {
fn new(function: &'f mut Function) -> Self {
let cfg = ControlFlowGraph::with_function(function);
Expand All @@ -150,7 +169,9 @@ impl<'f> PerFunctionContext<'f> {
blocks: BTreeMap::new(),
instructions_to_remove: BTreeSet::new(),
last_loads: HashMap::default(),
load_results: HashMap::default(),
inside_rc_reload: None,
stores_used_in_calls: HashMap::default(),
}
}

Expand All @@ -168,6 +189,27 @@ impl<'f> PerFunctionContext<'f> {
self.analyze_block(block, references);
}

let mut loads_removed = HashMap::default();
for (_, PerFuncLoadResultContext { load_counter, load_instruction, .. }) in
self.load_results.iter()
{
let Instruction::Load { address } = self.inserter.function.dfg[*load_instruction]
else {
panic!("Should only have a load instruction here");
};

if *load_counter == 0 {
if let Some(counter) = loads_removed.get_mut(&address) {
*counter += 1;
} else {
loads_removed.insert(address, 1);
}

self.instructions_to_remove.insert(*load_instruction);
}
}

let mut not_removed_stores: HashMap<ValueId, (InstructionId, u32)> = HashMap::default();
// If we never load from an address within a function we can remove all stores to that address.
// This rule does not apply to reference parameters, which we must also check for before removing these stores.
for (block_id, block) in self.blocks.iter() {
Expand All @@ -189,15 +231,103 @@ impl<'f> PerFunctionContext<'f> {
let last_load_not_in_return = self
.last_loads
.get(store_address)
.map(|(_, last_load_block)| *last_load_block != *block_id)
.map(|(_, last_load_block, _)| *last_load_block != *block_id)
.unwrap_or(true);
!is_return_value && last_load_not_in_return
} else if let (Some((_, _, last_loads_counter)), Some(loads_removed_counter)) =
(self.last_loads.get(store_address), loads_removed.get(store_address))
{
*last_loads_counter == *loads_removed_counter
} else {
self.last_loads.get(store_address).is_none()
};

if remove_load && !is_reference_param {
let is_not_used_in_reference_param =
self.stores_used_in_calls.get(store_address).is_none();
if remove_load && !is_reference_param && is_not_used_in_reference_param {
self.instructions_to_remove.insert(*store_instruction);
if let Some((_, counter)) = not_removed_stores.get_mut(store_address) {
*counter -= 1;
}
} else if let Some((_, counter)) = not_removed_stores.get_mut(store_address) {
*counter += 1;
} else {
not_removed_stores.insert(*store_address, (*store_instruction, 1));
}
}
}

self.load_results.retain(|_, PerFuncLoadResultContext { load_instruction, .. }| {
let Instruction::Load { address } = self.inserter.function.dfg[*load_instruction]
else {
panic!("Should only have a load instruction here");
};
not_removed_stores.contains_key(&address)
});

let mut new_instructions = HashMap::default();
for (store_address, (store_instruction, store_counter)) in not_removed_stores {
let Instruction::Store { value, .. } = self.inserter.function.dfg[store_instruction]
else {
panic!("Should only have a store instruction");
};

if store_counter != 0 {
continue;
}
self.instructions_to_remove.insert(store_instruction);

if let (Some((_, _, last_loads_counter)), Some(loads_removed_counter)) =
(self.last_loads.get(&store_address), loads_removed.get(&store_address))
{
if *last_loads_counter < *loads_removed_counter {
panic!("The number of loads removed should not be more than all loads");
}
}

for (
result,
PerFuncLoadResultContext {
load_counter,
load_instruction,
instructions_using_result,
},
) in self.load_results.iter()
{
let Instruction::Load { address } = self.inserter.function.dfg[*load_instruction]
else {
panic!("Should only have a load instruction here");
};
if address != store_address {
continue;
}

if *load_counter > 0 {
self.inserter.map_value(*result, value);
for (instruction, block_id) in instructions_using_result {
let new_instruction =
self.inserter.push_instruction(*instruction, *block_id);
if let Some(new_instruction) = new_instruction {
new_instructions
.insert((*instruction, block_id), Some(new_instruction));
} else {
new_instructions.insert((*instruction, block_id), None);
}
}

self.instructions_to_remove.insert(*load_instruction);
}
}
}

// Re-assign or delete any mapped instructions after the final loads were removed.
for ((old_instruction, block_id), new_instruction) in new_instructions {
let instructions = self.inserter.function.dfg[*block_id].instructions_mut();
if let Some(index) = instructions.iter().position(|v| *v == old_instruction) {
if let Some(new_instruction) = new_instruction {
instructions[index] = new_instruction;
} else {
instructions.remove(index);
}
}
}
Expand Down Expand Up @@ -286,6 +416,16 @@ impl<'f> PerFunctionContext<'f> {
return;
}

self.inserter.function.dfg[instruction].for_each_value(|value| {
if let Some(PerFuncLoadResultContext {
load_counter, instructions_using_result, ..
}) = self.load_results.get_mut(&value)
{
*load_counter += 1;
instructions_using_result.push((instruction, block_id));
}
});

match &self.inserter.function.dfg[instruction] {
Instruction::Load { address } => {
let address = self.inserter.function.dfg.resolve(*address);
Expand All @@ -304,7 +444,15 @@ impl<'f> PerFunctionContext<'f> {
references.aliases.insert(Expression::Other(result), AliasSet::known(result));
references.set_known_value(result, address);

self.last_loads.insert(address, (instruction, block_id));
self.load_results.insert(result, PerFuncLoadResultContext::new(instruction));

let load_counter =
if let Some((_, _, load_counter)) = self.last_loads.get(&address) {
*load_counter + 1
} else {
1
};
self.last_loads.insert(address, (instruction, block_id, load_counter));
}
}
Instruction::Store { address, value } => {
Expand All @@ -317,17 +465,29 @@ impl<'f> PerFunctionContext<'f> {
// function calls in-between, we can remove the previous store.
if let Some(last_store) = references.last_stores.get(&address) {
self.instructions_to_remove.insert(*last_store);
if let Some(PerFuncLoadResultContext { load_counter, .. }) =
self.load_results.get_mut(&value)
{
*load_counter -= 1;
}
}

let known_value = references.get_known_value(value);
if let Some(known_value) = known_value {
let known_value_is_address = known_value == address;
if let Some(from_rc) = self.inside_rc_reload {
if known_value_is_address && !from_rc {
if known_value_is_address {
if let Some(from_rc) = self.inside_rc_reload {
if !from_rc {
self.instructions_to_remove.insert(instruction);
}
} else {
self.instructions_to_remove.insert(instruction);
}
} else if known_value_is_address {
self.instructions_to_remove.insert(instruction);
if let Some(PerFuncLoadResultContext { load_counter, .. }) =
self.load_results.get_mut(&value)
{
*load_counter -= 1;
}
}
}

Expand Down Expand Up @@ -382,7 +542,17 @@ impl<'f> PerFunctionContext<'f> {
references.aliases.insert(expression, aliases);
}
}
Instruction::Call { arguments, .. } => self.mark_all_unknown(arguments, references),
Instruction::Call { arguments, .. } => {
for arg in arguments {
if self.inserter.function.dfg.value_is_reference(*arg) {
self.stores_used_in_calls
.entry(*arg)
.or_default()
.push((instruction, block_id));
}
}
self.mark_all_unknown(arguments, references);
}
_ => (),
}

Expand Down Expand Up @@ -488,7 +658,17 @@ impl<'f> PerFunctionContext<'f> {
fn handle_terminator(&mut self, block: BasicBlockId, references: &mut Block) {
self.inserter.map_terminator_in_place(block);

match self.inserter.function.dfg[block].unwrap_terminator() {
let terminator = self.inserter.function.dfg[block].unwrap_terminator();

terminator.for_each_value(|value| {
if let Some(PerFuncLoadResultContext { load_counter, .. }) =
self.load_results.get_mut(&value)
{
*load_counter += 1;
}
});

match terminator {
TerminatorInstruction::JmpIf { .. } => (), // Nothing to do
TerminatorInstruction::Jmp { destination, arguments, .. } => {
let destination_parameters = self.inserter.function.dfg[*destination].parameters();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,8 @@ struct EnumEmulation {
unconstrained fn main() -> pub Field {
let mut emulated_enum = EnumEmulation { a: Option::some(1), b: Option::none(), c: Option::none() };

// Do a copy to optimize out loads in the loop
let copy_enum = emulated_enum;
for _ in 0..1 {
assert_eq(copy_enum.a.unwrap(), 1);
assert_eq(emulated_enum.a.unwrap(), 1);
}

emulated_enum.a = Option::some(2);
Expand Down

0 comments on commit 85faf64

Please sign in to comment.