Skip to content

Commit

Permalink
fix segfault when invoking compiled function with optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
tgeng committed Jun 27, 2024
1 parent 91772b8 commit d2c1463
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 8 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,4 @@ itertools = "0.12.0"
enum-ordinalize = "4.3.0"

[profile.dev]
opt-level = 0
opt-level = 3
1 change: 0 additions & 1 deletion runtime/src/runtime_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,6 @@ pub unsafe extern "C" fn runtime_process_simple_handler_result(
// Actually the result is not needed since the continuation is a disposer loader continuation, which always
// ignores the last result.
let result_ptr = next_base_address.sub(1);
// plus 1 to tag the pointer a an SPtr.
result_ptr.write(UniformType::to_uniform_sptr(empty_struct_ptr()));
// Here we use the matching handler because the jump needs to restore execution to the state at the matching
// handler. This mismatch between fp, sp, and lr with the continuation and argument stack is fine because
Expand Down
3 changes: 1 addition & 2 deletions src/backend/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,6 @@ impl BuiltinFunction {
// transform thunk is set on the argument stack by `runtime_register_handler` when it
// creates the transform loader continuation.
let transform_thunk = builder.ins().load(I64, MemFlags::new(), base_address, 0);
Self::call_built_in(m, builder, BuiltinFunction::DebugHelper, &[base_address, last_result_ptr, transform_thunk]);

let inst = Self::call_built_in(m, builder, BuiltinFunction::PopHandler, &[]);
let handler_parameter = builder.inst_results(inst)[0];
Expand Down Expand Up @@ -456,7 +455,7 @@ impl BuiltinFunction {
let base_address = builder.block_params(entry_block)[0];
let current_continuation = builder.block_params(entry_block)[1];

// transform thunk is set on the argument stack by `runtime_register_handler` when it
// disposer thunk is set on the argument stack by `runtime_register_handler` when it
// creates the transform loader continuation.
let disposer_thunk = builder.ins().load(I64, MemFlags::new(), base_address, 0);

Expand Down
44 changes: 40 additions & 4 deletions src/backend/compiler.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::arch::global_asm;
use std::collections::HashMap;
use cranelift::codegen::isa::CallConv;
use cranelift::prelude::*;
Expand Down Expand Up @@ -61,13 +62,48 @@ impl Default for Compiler<JITModule> {
}
}

#[cfg(target_arch = "aarch64")]
global_asm!(r#"
.global _invoke_compiled_function
_invoke_compiled_function:
// Store all callee-saved registers on the stack
stp x19, x20, [sp, #-16]!
stp x21, x22, [sp, #-16]!
stp x23, x24, [sp, #-16]!
stp x25, x26, [sp, #-16]!
stp x27, x28, [sp, #-16]!
stp x29, x30, [sp, #-16]!
// x0 is the first argument, which is the function pointer
blr x0 // Call the function
// Restore all callee-saved registers from the stack
ldp x29, x30, [sp], #16
ldp x27, x28, [sp], #16
ldp x25, x26, [sp], #16
ldp x23, x24, [sp], #16
ldp x21, x22, [sp], #16
ldp x19, x20, [sp], #16
ret
"#);

extern "C" {
/// compiled function are often invoked with cranelift's tail call convention , which doesn't
/// respect callee-saved registers (callee-saved registers are simply not used). But rust
/// functions expects callee-saved registers to be preserved. So we need to wrap the compiled
/// function and save & restore the callee-saved registers when invoking the compiled function.
fn invoke_compiled_function(f: fn() -> usize) -> usize;
}

impl Compiler<JITModule> {
pub fn finalize_and_get_main(&mut self) -> fn() -> usize {
pub fn finalize_and_get_main(&mut self) -> impl Fn() -> usize {
self.module.finalize_definitions().unwrap();
let main_func_id = self.local_functions.get(MAIN_WRAPPER_NAME).unwrap();
unsafe {
let func_ptr = self.module.get_finalized_function(*main_func_id);
std::mem::transmute::<_, fn() -> usize>(func_ptr)
let func_ptr = self.module.get_finalized_function(*main_func_id);
move || unsafe {
invoke_compiled_function(std::mem::transmute::<_, fn() -> usize>(func_ptr))
}
}
}
Expand Down

0 comments on commit d2c1463

Please sign in to comment.