diff --git a/crates/rustc_codegen_spirv/src/linker/dce.rs b/crates/rustc_codegen_spirv/src/linker/dce.rs index 3cc865a741..896406c4e9 100644 --- a/crates/rustc_codegen_spirv/src/linker/dce.rs +++ b/crates/rustc_codegen_spirv/src/linker/dce.rs @@ -7,9 +7,10 @@ //! *references* a rooted thing is also rooted, not the other way around - but that's the basic //! concept. -use rspirv::dr::{Function, Instruction, Module, Operand}; +use rspirv::dr::{Block, Function, Instruction, Module, Operand}; use rspirv::spirv::{Decoration, LinkageType, Op, StorageClass, Word}; -use rustc_data_structures::fx::FxIndexSet; +use rustc_data_structures::fx::{FxIndexMap, FxIndexSet}; +use std::hash::Hash; pub fn dce(module: &mut Module) { let mut rooted = collect_roots(module); @@ -137,11 +138,11 @@ fn kill_unrooted(module: &mut Module, rooted: &FxIndexSet) { } } -pub fn dce_phi(func: &mut Function) { +pub fn dce_phi(blocks: &mut FxIndexMap) { let mut used = FxIndexSet::default(); loop { let mut changed = false; - for inst in func.all_inst_iter() { + for inst in blocks.values().flat_map(|block| &block.instructions) { if inst.class.opcode != Op::Phi || used.contains(&inst.result_id.unwrap()) { for op in &inst.operands { if let Some(id) = op.id_ref_any() { @@ -154,7 +155,7 @@ pub fn dce_phi(func: &mut Function) { break; } } - for block in &mut func.blocks { + for block in blocks.values_mut() { block .instructions .retain(|inst| inst.class.opcode != Op::Phi || used.contains(&inst.result_id.unwrap())); diff --git a/crates/rustc_codegen_spirv/src/linker/mem2reg.rs b/crates/rustc_codegen_spirv/src/linker/mem2reg.rs index 628be71c8d..d3ffe52692 100644 --- a/crates/rustc_codegen_spirv/src/linker/mem2reg.rs +++ b/crates/rustc_codegen_spirv/src/linker/mem2reg.rs @@ -13,10 +13,14 @@ use super::simple_passes::outgoing_edges; use super::{apply_rewrite_rules, id}; use rspirv::dr::{Block, Function, Instruction, ModuleHeader, Operand}; use rspirv::spirv::{Op, Word}; -use rustc_data_structures::fx::{FxHashMap, FxHashSet}; +use rustc_data_structures::fx::{FxHashMap, FxHashSet, FxIndexMap}; use rustc_middle::bug; use std::collections::hash_map; +// HACK(eddyb) newtype instead of type alias to avoid mistakes. +#[derive(Copy, Clone, PartialEq, Eq, Hash)] +struct LabelId(Word); + pub fn mem2reg( header: &mut ModuleHeader, types_global_values: &mut Vec, @@ -24,8 +28,16 @@ pub fn mem2reg( constants: &FxHashMap, func: &mut Function, ) { - let reachable = compute_reachable(&func.blocks); - let preds = compute_preds(&func.blocks, &reachable); + // HACK(eddyb) this ad-hoc indexing might be useful elsewhere as well, but + // it's made completely irrelevant by SPIR-T so only applies to legacy code. + let mut blocks: FxIndexMap<_, _> = func + .blocks + .iter_mut() + .map(|block| (LabelId(block.label_id().unwrap()), block)) + .collect(); + + let reachable = compute_reachable(&blocks); + let preds = compute_preds(&blocks, &reachable); let idom = compute_idom(&preds, &reachable); let dominance_frontier = compute_dominance_frontier(&preds, &idom); loop { @@ -34,31 +46,27 @@ pub fn mem2reg( types_global_values, pointer_to_pointee, constants, - &mut func.blocks, + &mut blocks, &dominance_frontier, ); if !changed { break; } // mem2reg produces minimal SSA form, not pruned, so DCE the dead ones - super::dce::dce_phi(func); + super::dce::dce_phi(&mut blocks); } } -fn label_to_index(blocks: &[Block], id: Word) -> usize { - blocks - .iter() - .position(|b| b.label_id().unwrap() == id) - .unwrap() -} - -fn compute_reachable(blocks: &[Block]) -> Vec { - fn recurse(blocks: &[Block], reachable: &mut [bool], block: usize) { +fn compute_reachable(blocks: &FxIndexMap) -> Vec { + fn recurse(blocks: &FxIndexMap, reachable: &mut [bool], block: usize) { if !reachable[block] { reachable[block] = true; - for dest_id in outgoing_edges(&blocks[block]) { - let dest_idx = label_to_index(blocks, dest_id); - recurse(blocks, reachable, dest_idx); + for dest_id in outgoing_edges(blocks[block]) { + recurse( + blocks, + reachable, + blocks.get_index_of(&LabelId(dest_id)).unwrap(), + ); } } } @@ -67,17 +75,19 @@ fn compute_reachable(blocks: &[Block]) -> Vec { reachable } -fn compute_preds(blocks: &[Block], reachable_blocks: &[bool]) -> Vec> { +fn compute_preds( + blocks: &FxIndexMap, + reachable_blocks: &[bool], +) -> Vec> { let mut result = vec![vec![]; blocks.len()]; // Do not count unreachable blocks as valid preds of blocks for (source_idx, source) in blocks - .iter() + .values() .enumerate() .filter(|&(b, _)| reachable_blocks[b]) { for dest_id in outgoing_edges(source) { - let dest_idx = label_to_index(blocks, dest_id); - result[dest_idx].push(source_idx); + result[blocks.get_index_of(&LabelId(dest_id)).unwrap()].push(source_idx); } } result @@ -161,7 +171,7 @@ fn insert_phis_all( types_global_values: &mut Vec, pointer_to_pointee: &FxHashMap, constants: &FxHashMap, - blocks: &mut [Block], + blocks: &mut FxIndexMap, dominance_frontier: &[FxHashSet], ) -> bool { let var_maps_and_types = blocks[0] @@ -198,7 +208,11 @@ fn insert_phis_all( rewrite_rules: FxHashMap::default(), }; renamer.rename(0, None); - apply_rewrite_rules(&renamer.rewrite_rules, blocks); + // FIXME(eddyb) shouldn't this full rescan of the function be done once? + apply_rewrite_rules( + &renamer.rewrite_rules, + blocks.values_mut().map(|block| &mut **block), + ); remove_nops(blocks); } remove_old_variables(blocks, &var_maps_and_types); @@ -216,7 +230,7 @@ struct VarInfo { fn collect_access_chains( pointer_to_pointee: &FxHashMap, constants: &FxHashMap, - blocks: &[Block], + blocks: &FxIndexMap, base_var: Word, base_var_ty: Word, ) -> Option> { @@ -249,7 +263,7 @@ fn collect_access_chains( // Loop in case a previous block references a later AccessChain loop { let mut changed = false; - for inst in blocks.iter().flat_map(|b| &b.instructions) { + for inst in blocks.values().flat_map(|b| &b.instructions) { for (index, op) in inst.operands.iter().enumerate() { if let Operand::IdRef(id) = op { if variables.contains_key(id) { @@ -307,10 +321,10 @@ fn collect_access_chains( // same var map (e.g. `s.x = s.y;`). fn split_copy_memory( header: &mut ModuleHeader, - blocks: &mut [Block], + blocks: &mut FxIndexMap, var_map: &FxHashMap, ) { - for block in blocks { + for block in blocks.values_mut() { let mut inst_index = 0; while inst_index < block.instructions.len() { let inst = &block.instructions[inst_index]; @@ -369,7 +383,7 @@ fn has_store(block: &Block, var_map: &FxHashMap) -> bool { } fn insert_phis( - blocks: &[Block], + blocks: &FxIndexMap, dominance_frontier: &[FxHashSet], var_map: &FxHashMap, ) -> FxHashSet { @@ -378,7 +392,7 @@ fn insert_phis( let mut ever_on_work_list = FxHashSet::default(); let mut work_list = Vec::new(); let mut blocks_with_phi = FxHashSet::default(); - for (block_idx, block) in blocks.iter().enumerate() { + for (block_idx, block) in blocks.values().enumerate() { if has_store(block, var_map) { ever_on_work_list.insert(block_idx); work_list.push(block_idx); @@ -423,10 +437,10 @@ fn top_stack_or_undef( } } -struct Renamer<'a> { +struct Renamer<'a, 'b> { header: &'a mut ModuleHeader, types_global_values: &'a mut Vec, - blocks: &'a mut [Block], + blocks: &'a mut FxIndexMap, blocks_with_phi: FxHashSet, base_var_type: Word, var_map: &'a FxHashMap, @@ -436,7 +450,7 @@ struct Renamer<'a> { rewrite_rules: FxHashMap, } -impl Renamer<'_> { +impl Renamer<'_, '_> { // Returns the phi definition. fn insert_phi_value(&mut self, block: usize, from_block: usize) -> Word { let from_block_label = self.blocks[from_block].label_id().unwrap(); @@ -558,9 +572,8 @@ impl Renamer<'_> { } } - for dest_id in outgoing_edges(&self.blocks[block]).collect::>() { - // TODO: Don't do this find - let dest_idx = label_to_index(self.blocks, dest_id); + for dest_id in outgoing_edges(self.blocks[block]).collect::>() { + let dest_idx = self.blocks.get_index_of(&LabelId(dest_id)).unwrap(); self.rename(dest_idx, Some(block)); } @@ -570,8 +583,8 @@ impl Renamer<'_> { } } -fn remove_nops(blocks: &mut [Block]) { - for block in blocks { +fn remove_nops(blocks: &mut FxIndexMap) { + for block in blocks.values_mut() { block .instructions .retain(|inst| inst.class.opcode != Op::Nop); @@ -579,7 +592,7 @@ fn remove_nops(blocks: &mut [Block]) { } fn remove_old_variables( - blocks: &mut [Block], + blocks: &mut FxIndexMap, var_maps_and_types: &[(FxHashMap, u32)], ) { blocks[0].instructions.retain(|inst| { @@ -590,7 +603,7 @@ fn remove_old_variables( .all(|(var_map, _)| !var_map.contains_key(&result_id)) } }); - for block in blocks { + for block in blocks.values_mut() { block.instructions.retain(|inst| { !matches!(inst.class.opcode, Op::AccessChain | Op::InBoundsAccessChain) || inst.operands.iter().all(|op| { diff --git a/crates/rustc_codegen_spirv/src/linker/mod.rs b/crates/rustc_codegen_spirv/src/linker/mod.rs index f53823b753..5a6c062b5f 100644 --- a/crates/rustc_codegen_spirv/src/linker/mod.rs +++ b/crates/rustc_codegen_spirv/src/linker/mod.rs @@ -88,7 +88,10 @@ fn id(header: &mut ModuleHeader) -> Word { result } -fn apply_rewrite_rules(rewrite_rules: &FxHashMap, blocks: &mut [Block]) { +fn apply_rewrite_rules<'a>( + rewrite_rules: &FxHashMap, + blocks: impl IntoIterator, +) { let apply = |inst: &mut Instruction| { if let Some(ref mut id) = &mut inst.result_id { if let Some(&rewrite) = rewrite_rules.get(id) {