diff --git a/crates/cli-support/src/lib.rs b/crates/cli-support/src/lib.rs index abe5ec226ee..ea165386ddc 100755 --- a/crates/cli-support/src/lib.rs +++ b/crates/cli-support/src/lib.rs @@ -338,8 +338,15 @@ impl Bindgen { .context("failed getting Wasm module")?, }; - self.threads - .run(&mut module) + let mut start = module.start.take().map(|start| { + let mut builder = walrus::FunctionBuilder::new(&mut module.types, &[], &[]); + builder.func_body().call(start); + builder + }); + + let thread_count = self + .threads + .run(&mut module, &mut start) .with_context(|| "failed to prepare module for threading")?; // If requested, turn all mangled symbols into prettier unmangled @@ -380,9 +387,11 @@ impl Bindgen { // interface types. wit::process( &mut module, + &mut start, programs, self.externref, self.wasm_interface_types, + thread_count, self.emit_start, )?; @@ -435,6 +444,8 @@ impl Bindgen { .context("failed to transform return pointers into multi-value Wasm")?; } + module.start = start.map(|builder| builder.finish(Vec::new(), &mut module.funcs)); + // We've done a whole bunch of transformations to the wasm module, many // of which leave "garbage" lying around, so let's prune out all our // unnecessary things here. diff --git a/crates/cli-support/src/wit/mod.rs b/crates/cli-support/src/wit/mod.rs index d098eab93bc..82d34d38e31 100644 --- a/crates/cli-support/src/wit/mod.rs +++ b/crates/cli-support/src/wit/mod.rs @@ -8,6 +8,7 @@ use std::str; use walrus::MemoryId; use walrus::{ExportId, FunctionId, ImportId, Module}; use wasm_bindgen_shared::struct_function_export_name; +use wasm_bindgen_threads_xform::ThreadCount; mod incoming; mod nonstandard; @@ -20,6 +21,7 @@ pub use self::standard::*; struct Context<'a> { start_found: bool, module: &'a mut Module, + start: &'a mut Option, adapters: NonstandardWitSection, aux: WasmBindgenAux, /// All of the wasm module's exported functions. @@ -34,6 +36,7 @@ struct Context<'a> { descriptors: HashMap, externref_enabled: bool, wasm_interface_types: bool, + thread_count: Option, support_start: bool, } @@ -47,9 +50,11 @@ struct InstructionBuilder<'a, 'b> { pub fn process( module: &mut Module, + start: &mut Option, programs: Vec, externref_enabled: bool, wasm_interface_types: bool, + thread_count: Option, support_start: bool, ) -> Result<(NonstandardWitSectionId, WasmBindgenAuxId), Error> { let mut cx = Context { @@ -63,9 +68,11 @@ pub fn process( unique_crate_identifier: "", memory: wasm_bindgen_wasm_conventions::get_memory(module).ok(), module, + start, start_found: false, externref_enabled, wasm_interface_types, + thread_count, support_start, }; cx.init()?; @@ -312,14 +319,13 @@ impl<'a> Context<'a> { self.module .add_import_func(PLACEHOLDER_MODULE, "__wbindgen_init_externref_table", ty); - self.module.start = Some(match self.module.start { - Some(prev_start) => { - let mut builder = walrus::FunctionBuilder::new(&mut self.module.types, &[], &[]); - builder.func_body().call(import).call(prev_start); - builder.finish(Vec::new(), &mut self.module.funcs) - } - None => import, - }); + let module = &mut self.module; + let builder = self + .start + .get_or_insert_with(|| walrus::FunctionBuilder::new(&mut module.types, &[], &[])); + + builder.func_body().call_at(0, import); + self.bind_intrinsic(import_id, Intrinsic::InitExternrefTable)?; Ok(()) @@ -481,22 +487,21 @@ impl<'a> Context<'a> { return Ok(()); } - let prev_start = match self.module.start { - Some(f) => f, - None => { - self.module.start = Some(id); - return Ok(()); - } - }; + let module = &mut self.module; + let builder = self + .start + .get_or_insert_with(|| walrus::FunctionBuilder::new(&mut module.types, &[], &[])); - // Note that we call the previous start function, if any, first. This is + // Note that we leave the previous start function, if any, first. This is // because the start function currently only shows up when it's injected // through thread/externref transforms. These injected start functions // need to happen before user code, so we always schedule them first. - let mut builder = walrus::FunctionBuilder::new(&mut self.module.types, &[], &[]); - builder.func_body().call(prev_start).call(id); - let new_start = builder.finish(Vec::new(), &mut self.module.funcs); - self.module.start = Some(new_start); + if let Some(thread_count) = self.thread_count { + thread_count.wrap_start(builder, id) + } else { + builder.func_body().call(id); + } + Ok(()) } diff --git a/crates/threads-xform/src/lib.rs b/crates/threads-xform/src/lib.rs index 5f8361c21b0..169f0eaebc5 100644 --- a/crates/threads-xform/src/lib.rs +++ b/crates/threads-xform/src/lib.rs @@ -23,6 +23,9 @@ pub struct Config { enabled: bool, } +#[derive(Clone, Copy)] +pub struct ThreadCount(walrus::LocalId); + impl Config { /// Create a new configuration with default settings. pub fn new() -> Config { @@ -103,9 +106,13 @@ impl Config { /// * Some stack space is prepared for each thread after the first one. /// /// More and/or less may happen here over time, stay tuned! - pub fn run(&self, module: &mut Module) -> Result<(), Error> { + pub fn run( + &self, + module: &mut Module, + start: &mut Option, + ) -> Result, Error> { if !self.is_enabled(module) { - return Ok(()); + return Ok(None); } let memory = wasm_conventions::get_memory(module)?; @@ -157,7 +164,7 @@ impl Config { let _ = module.exports.add("__stack_alloc", stack.alloc); - inject_start(module, &tls, &stack, thread_counter_addr, memory)?; + let thread_count = inject_start(module, start, &tls, &stack, thread_counter_addr, memory)?; // we expose a `__wbindgen_thread_destroy()` helper function that deallocates stack space. // @@ -177,7 +184,21 @@ impl Config { // call while the leader is destroying its stack! You should make sure that this cannot happen. inject_destroy(module, &tls, &stack, memory)?; - Ok(()) + Ok(Some(thread_count)) + } +} + +impl ThreadCount { + pub fn wrap_start(self, builder: &mut walrus::FunctionBuilder, start: FunctionId) { + // We only want to call the start function if we are in the first thread. + // The thread counter should be 0 for the first thread. + builder.func_body().local_get(self.0).if_else( + None, + |_| {}, + |body| { + body.call(start); + }, + ); } } @@ -296,26 +317,22 @@ struct Stack { fn inject_start( module: &mut Module, + start: &mut Option, tls: &Tls, stack: &Stack, thread_counter_addr: i32, memory: MemoryId, -) -> Result<(), Error> { +) -> Result { use walrus::ir::*; assert!(stack.size % PAGE_SIZE == 0); - let mut builder = walrus::FunctionBuilder::new(&mut module.types, &[], &[]); + let builder = + start.get_or_insert_with(|| walrus::FunctionBuilder::new(&mut module.types, &[], &[])); let local = module.locals.add(ValType::I32); + let thread_count = module.locals.add(ValType::I32); let mut body = builder.func_body(); - // Call previous start function if one is available. Currently this is - // always true because LLVM injects a call to `__wasm_init_memory` as the - // start function which, well, initializes memory. - if let Some(prev) = module.start.take() { - body.call(prev); - } - let malloc = find_function(module, "__wbindgen_malloc")?; // Perform an if/else based on whether we're the first thread or not. Our @@ -324,6 +341,7 @@ fn inject_start( body.i32_const(thread_counter_addr) .i32_const(1) .atomic_rmw(memory, AtomicOp::Add, AtomicWidth::I32, ATOMIC_MEM_ARG) + .local_tee(thread_count) .if_else( None, // If our thread id is nonzero then we're the second or greater thread, so @@ -360,13 +378,7 @@ fn inject_start( .global_get(tls.base) .call(tls.init); - // Finish off our newly generated function. - let start_id = builder.finish(Vec::new(), &mut module.funcs); - - // ... and finally flag it as the new start function - module.start = Some(start_id); - - Ok(()) + Ok(ThreadCount(thread_count)) } fn inject_destroy( diff --git a/crates/threads-xform/tests/all.rs b/crates/threads-xform/tests/all.rs index 1101ccd2fcb..5bb7cb5fbc1 100644 --- a/crates/threads-xform/tests/all.rs +++ b/crates/threads-xform/tests/all.rs @@ -22,7 +22,7 @@ fn runtest(test: &Test) -> Result { let config = wasm_bindgen_threads_xform::Config::new(); - config.run(&mut module)?; + config.run(&mut module, &mut None)?; walrus::passes::gc::run(&mut module); let features = {