Skip to content

Commit

Permalink
linker/inline: use OpPhi instead of OpVariable for return values.
Browse files Browse the repository at this point in the history
  • Loading branch information
eddyb authored and Firestar99 committed Oct 19, 2024
1 parent 7e5c74d commit ce5558c
Showing 1 changed file with 67 additions and 69 deletions.
136 changes: 67 additions & 69 deletions crates/rustc_codegen_spirv/src/linker/inline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@ pub fn inline(sess: &Session, module: &mut Module) -> super::Result<()> {
header,
debug_string_source: &mut module.debug_string_source,
annotations: &mut module.annotations,
types_global_values: &mut module.types_global_values,

legal_globals,

Expand Down Expand Up @@ -493,7 +492,6 @@ struct Inliner<'a, 'b> {
header: &'b mut ModuleHeader,
debug_string_source: &'b mut Vec<Instruction>,
annotations: &'b mut Vec<Instruction>,
types_global_values: &'b mut Vec<Instruction>,

legal_globals: FxHashMap<Word, LegalGlobal>,
functions_that_may_abort: FxHashSet<Word>,
Expand Down Expand Up @@ -523,29 +521,6 @@ impl Inliner<'_, '_> {
}
}

fn ptr_ty(&mut self, pointee: Word) -> Word {
// TODO: This is horribly slow, fix this
let existing = self.types_global_values.iter().find(|inst| {
inst.class.opcode == Op::TypePointer
&& inst.operands[0].unwrap_storage_class() == StorageClass::Function
&& inst.operands[1].unwrap_id_ref() == pointee
});
if let Some(existing) = existing {
return existing.result_id.unwrap();
}
let inst_id = self.id();
self.types_global_values.push(Instruction::new(
Op::TypePointer,
None,
Some(inst_id),
vec![
Operand::StorageClass(StorageClass::Function),
Operand::IdRef(pointee),
],
));
inst_id
}

fn inline_fn(
&mut self,
function: &mut Function,
Expand Down Expand Up @@ -622,15 +597,19 @@ impl Inliner<'_, '_> {
.insert(caller.def_id().unwrap());
}

let call_result_type = {
let mut maybe_call_result_phi = {
let ty = call_inst.result_type.unwrap();
if ty == self.op_type_void_id {
None
} else {
Some(ty)
Some(Instruction::new(
Op::Phi,
Some(ty),
Some(call_inst.result_id.unwrap()),
vec![],
))
}
};
let call_result_id = call_inst.result_id.unwrap();

// Get the debug "source location" instruction that applies to the call.
let custom_ext_inst_set_import = self.custom_ext_inst_set_import;
Expand Down Expand Up @@ -667,17 +646,12 @@ impl Inliner<'_, '_> {
});
let mut rewrite_rules = callee_parameters.zip(call_arguments).collect();

let return_variable = if call_result_type.is_some() {
Some(self.id())
} else {
None
};
let return_jump = self.id();
// Rewrite OpReturns of the callee.
let mut inlined_callee_blocks = self.get_inlined_blocks(
callee,
call_debug_src_loc_inst,
return_variable,
maybe_call_result_phi.as_mut(),
return_jump,
);
// Clone the IDs of the callee, because otherwise they'd be defined multiple times if the
Expand All @@ -686,6 +660,55 @@ impl Inliner<'_, '_> {
apply_rewrite_rules(&rewrite_rules, &mut inlined_callee_blocks);
self.apply_rewrite_for_decorations(&rewrite_rules);

if let Some(call_result_phi) = &mut maybe_call_result_phi {
// HACK(eddyb) new IDs should be generated earlier, to avoid pushing
// callee IDs to `call_result_phi.operands` only to rewrite them here.
for op in &mut call_result_phi.operands {
if let Some(id) = op.id_ref_any_mut() {
if let Some(&rewrite) = rewrite_rules.get(id) {
*id = rewrite;
}
}
}

// HACK(eddyb) this special-casing of the single-return case is
// really necessary for passes like `mem2reg` which are not capable
// of skipping through the extraneous `OpPhi`s on their own.
if let [returned_value, _return_block] = &call_result_phi.operands[..] {
let call_result_id = call_result_phi.result_id.unwrap();
let returned_value_id = returned_value.unwrap_id_ref();

maybe_call_result_phi = None;

// HACK(eddyb) this is a conservative approximation of all the
// instructions that could potentially reference the call result.
let reaching_insts = {
let (pre_call_blocks, call_and_post_call_blocks) =
caller.blocks.split_at_mut(block_idx);
(pre_call_blocks.iter_mut().flat_map(|block| {
block
.instructions
.iter_mut()
.take_while(|inst| inst.class.opcode == Op::Phi)
}))
.chain(
call_and_post_call_blocks
.iter_mut()
.flat_map(|block| &mut block.instructions),
)
};
for reaching_inst in reaching_insts {
for op in &mut reaching_inst.operands {
if let Some(id) = op.id_ref_any_mut() {
if *id == call_result_id {
*id = returned_value_id;
}
}
}
}
}
}

// Split the block containing the `OpFunctionCall` into pre-call vs post-call.
let pre_call_block_idx = block_idx;
#[expect(unused)]
Expand All @@ -701,18 +724,6 @@ impl Inliner<'_, '_> {
.unwrap();
assert!(call.class.opcode == Op::FunctionCall);

if let Some(call_result_type) = call_result_type {
// Generate the storage space for the return value: Do this *after* the split above,
// because if block_idx=0, inserting a variable here shifts call_index.
let ret_var_inst = Instruction::new(
Op::Variable,
Some(self.ptr_ty(call_result_type)),
Some(return_variable.unwrap()),
vec![Operand::StorageClass(StorageClass::Function)],
);
self.insert_opvariables(&mut caller.blocks[0], [ret_var_inst]);
}

// Insert non-entry inlined callee blocks just after the pre-call block.
let non_entry_inlined_callee_blocks = inlined_callee_blocks.drain(1..);
let num_non_entry_inlined_callee_blocks = non_entry_inlined_callee_blocks.len();
Expand All @@ -721,18 +732,9 @@ impl Inliner<'_, '_> {
non_entry_inlined_callee_blocks,
);

if let Some(call_result_type) = call_result_type {
// Add the load of the result value after the inlined function. Note there's guaranteed no
// OpPhi instructions since we just split this block.
post_call_block_insts.insert(
0,
Instruction::new(
Op::Load,
Some(call_result_type),
Some(call_result_id),
vec![Operand::IdRef(return_variable.unwrap())],
),
);
if let Some(call_result_phi) = maybe_call_result_phi {
// Add the `OpPhi` for the call result value, after the inlined function.
post_call_block_insts.insert(0, call_result_phi);
}

// Insert the post-call block, after all the inlined callee blocks.
Expand Down Expand Up @@ -899,7 +901,7 @@ impl Inliner<'_, '_> {
&mut self,
callee: &Function,
call_debug_src_loc_inst: Option<&Instruction>,
return_variable: Option<Word>,
mut maybe_call_result_phi: Option<&mut Instruction>,
return_jump: Word,
) -> Vec<Block> {
let Self {
Expand Down Expand Up @@ -997,17 +999,13 @@ impl Inliner<'_, '_> {
if let Op::Return | Op::ReturnValue = terminator.class.opcode {
if Op::ReturnValue == terminator.class.opcode {
let return_value = terminator.operands[0].id_ref_any().unwrap();
block.instructions.push(Instruction::new(
Op::Store,
None,
None,
vec![
Operand::IdRef(return_variable.unwrap()),
Operand::IdRef(return_value),
],
));
let call_result_phi = maybe_call_result_phi.as_deref_mut().unwrap();
call_result_phi.operands.extend([
Operand::IdRef(return_value),
Operand::IdRef(block.label_id().unwrap()),
]);
} else {
assert!(return_variable.is_none());
assert!(maybe_call_result_phi.is_none());
}
terminator =
Instruction::new(Op::Branch, None, None, vec![Operand::IdRef(return_jump)]);
Expand Down

0 comments on commit ce5558c

Please sign in to comment.