From d2c1463c24ef0e8be959d6e4509a6f039ffcbc01 Mon Sep 17 00:00:00 2001 From: Tianyu Geng Date: Wed, 26 Jun 2024 23:20:56 -0700 Subject: [PATCH] fix segfault when invoking compiled function with optimization --- Cargo.toml | 2 +- runtime/src/runtime_utils.rs | 1 - src/backend/common.rs | 3 +-- src/backend/compiler.rs | 44 ++++++++++++++++++++++++++++++++---- 4 files changed, 42 insertions(+), 8 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 311d2e6..89555ae 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,4 +24,4 @@ itertools = "0.12.0" enum-ordinalize = "4.3.0" [profile.dev] -opt-level = 0 \ No newline at end of file +opt-level = 3 \ No newline at end of file diff --git a/runtime/src/runtime_utils.rs b/runtime/src/runtime_utils.rs index 485eff8..4ac9168 100644 --- a/runtime/src/runtime_utils.rs +++ b/runtime/src/runtime_utils.rs @@ -410,7 +410,6 @@ pub unsafe extern "C" fn runtime_process_simple_handler_result( // Actually the result is not needed since the continuation is a disposer loader continuation, which always // ignores the last result. let result_ptr = next_base_address.sub(1); - // plus 1 to tag the pointer a an SPtr. result_ptr.write(UniformType::to_uniform_sptr(empty_struct_ptr())); // Here we use the matching handler because the jump needs to restore execution to the state at the matching // handler. This mismatch between fp, sp, and lr with the continuation and argument stack is fine because diff --git a/src/backend/common.rs b/src/backend/common.rs index 988d512..581512d 100644 --- a/src/backend/common.rs +++ b/src/backend/common.rs @@ -410,7 +410,6 @@ impl BuiltinFunction { // transform thunk is set on the argument stack by `runtime_register_handler` when it // creates the transform loader continuation. 
let transform_thunk = builder.ins().load(I64, MemFlags::new(), base_address, 0); - Self::call_built_in(m, builder, BuiltinFunction::DebugHelper, &[base_address, last_result_ptr, transform_thunk]); let inst = Self::call_built_in(m, builder, BuiltinFunction::PopHandler, &[]); let handler_parameter = builder.inst_results(inst)[0]; @@ -456,7 +455,7 @@ impl BuiltinFunction { let base_address = builder.block_params(entry_block)[0]; let current_continuation = builder.block_params(entry_block)[1]; - // transform thunk is set on the argument stack by `runtime_register_handler` when it + // disposer thunk is set on the argument stack by `runtime_register_handler` when it // creates the transform loader continuation. let disposer_thunk = builder.ins().load(I64, MemFlags::new(), base_address, 0); diff --git a/src/backend/compiler.rs b/src/backend/compiler.rs index 7a23ab3..5ff0b27 100644 --- a/src/backend/compiler.rs +++ b/src/backend/compiler.rs @@ -1,3 +1,4 @@ +use std::arch::global_asm; use std::collections::HashMap; use cranelift::codegen::isa::CallConv; use cranelift::prelude::*; @@ -61,13 +62,48 @@ impl Default for Compiler { } } +#[cfg(target_arch = "aarch64")] +global_asm!(r#" + .global _invoke_compiled_function + + _invoke_compiled_function: + // Store all callee-saved registers on the stack + stp x19, x20, [sp, #-16]! + stp x21, x22, [sp, #-16]! + stp x23, x24, [sp, #-16]! + stp x25, x26, [sp, #-16]! + stp x27, x28, [sp, #-16]! + stp x29, x30, [sp, #-16]! + + // x0 is the first argument, which is the function pointer + blr x0 // Call the function + + // Restore all callee-saved registers from the stack + ldp x29, x30, [sp], #16 + ldp x27, x28, [sp], #16 + ldp x25, x26, [sp], #16 + ldp x23, x24, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x19, x20, [sp], #16 + + ret +"#); + +extern "C" { + /// Compiled functions are often invoked with cranelift's tail call convention, which doesn't + /// respect callee-saved registers (callee-saved registers are simply not used). 
But Rust + /// functions expect callee-saved registers to be preserved. So we need to wrap the compiled + /// function and save & restore the callee-saved registers when invoking the compiled function. + fn invoke_compiled_function(f: fn() -> usize) -> usize; +} + impl Compiler { - pub fn finalize_and_get_main(&mut self) -> fn() -> usize { + pub fn finalize_and_get_main(&mut self) -> impl Fn() -> usize { self.module.finalize_definitions().unwrap(); let main_func_id = self.local_functions.get(MAIN_WRAPPER_NAME).unwrap(); - unsafe { - let func_ptr = self.module.get_finalized_function(*main_func_id); - std::mem::transmute::<_, fn() -> usize>(func_ptr) + let func_ptr = self.module.get_finalized_function(*main_func_id); + move || unsafe { + invoke_compiled_function(std::mem::transmute::<_, fn() -> usize>(func_ptr)) } } }