diff --git a/crates/component-macro/src/bindgen.rs b/crates/component-macro/src/bindgen.rs index 3dfd7f846ea4..a87616e55d82 100644 --- a/crates/component-macro/src/bindgen.rs +++ b/crates/component-macro/src/bindgen.rs @@ -1,10 +1,11 @@ use proc_macro2::{Span, TokenStream}; use std::collections::HashMap; +use std::collections::HashSet; use std::path::{Path, PathBuf}; use syn::parse::{Error, Parse, ParseStream, Result}; use syn::punctuated::Punctuated; use syn::{braced, token, Ident, Token}; -use wasmtime_wit_bindgen::{Opts, Ownership, TrappableError}; +use wasmtime_wit_bindgen::{AsyncConfig, Opts, Ownership, TrappableError}; use wit_parser::{PackageId, Resolve, UnresolvedPackage, WorldId}; pub struct Config { @@ -15,7 +16,7 @@ pub struct Config { } pub fn expand(input: &Config) -> Result { - if !cfg!(feature = "async") && input.opts.async_ { + if !cfg!(feature = "async") && input.opts.async_.maybe_async() { return Err(Error::new( Span::call_site(), "cannot enable async bindings unless `async` crate feature is active", @@ -45,6 +46,7 @@ impl Parse for Config { let mut world = None; let mut inline = None; let mut path = None; + let mut async_configured = false; if input.peek(token::Brace) { let content; @@ -71,7 +73,13 @@ impl Parse for Config { inline = Some(s.value()); } Opt::Tracing(val) => opts.tracing = val, - Opt::Async(val) => opts.async_ = val, + Opt::Async(val, span) => { + if async_configured { + return Err(Error::new(span, "cannot specify second async config")); + } + async_configured = true; + opts.async_ = val; + } Opt::TrappableErrorType(val) => opts.trappable_error_type = val, Opt::Ownership(val) => opts.ownership = val, Opt::Interfaces(s) => { @@ -171,6 +179,8 @@ mod kw { syn::custom_keyword!(ownership); syn::custom_keyword!(interfaces); syn::custom_keyword!(with); + syn::custom_keyword!(except_imports); + syn::custom_keyword!(only_imports); } enum Opt { @@ -178,7 +188,7 @@ enum Opt { Path(syn::LitStr), Inline(syn::LitStr), Tracing(bool), - Async(bool), + Async(AsyncConfig, Span), TrappableErrorType(Vec), Ownership(Ownership), Interfaces(syn::LitStr), @@ -205,9 +215,43 @@ impl Parse for Opt { input.parse::()?; Ok(Opt::Tracing(input.parse::()?.value)) } else if l.peek(Token![async]) { - input.parse::()?; + let span = input.parse::()?.span; input.parse::()?; - Ok(Opt::Async(input.parse::()?.value)) + if input.peek(syn::LitBool) { + match input.parse::()?.value { + true => Ok(Opt::Async(AsyncConfig::All, span)), + false => Ok(Opt::Async(AsyncConfig::None, span)), + } + } else { + let contents; + syn::braced!(contents in input); + + let l = contents.lookahead1(); + let ctor: fn(HashSet) -> AsyncConfig = if l.peek(kw::except_imports) { + contents.parse::()?; + contents.parse::()?; + AsyncConfig::AllExceptImports + } else if l.peek(kw::only_imports) { + contents.parse::()?; + contents.parse::()?; + AsyncConfig::OnlyImports + } else { + return Err(l.error()); + }; + + let list; + syn::bracketed!(list in contents); + let fields: Punctuated = + list.parse_terminated(Parse::parse, Token![,])?; + + if contents.peek(Token![,]) { + contents.parse::()?; + } + Ok(Opt::Async( + ctor(fields.iter().map(|s| s.value()).collect()), + span, + )) + } } else if l.peek(kw::ownership) { input.parse::()?; input.parse::()?; diff --git a/crates/wasmtime/src/component/mod.rs b/crates/wasmtime/src/component/mod.rs index 4129975def45..6003ef67cdc4 100644 --- a/crates/wasmtime/src/component/mod.rs +++ b/crates/wasmtime/src/component/mod.rs @@ -281,6 +281,22 @@ pub(crate) use self::store::ComponentStoreData; /// // This option defaults to `false`. /// async: true, /// +/// // Alternative mode of async configuration where this still implies +/// // async instantiation happens, for example, but more control is +/// // provided over which imports are async and which aren't. +/// // +/// // Note that in this mode all exports are still async. +/// async: { +/// // All imports are async except for functions with these names +/// except_imports: ["foo", "bar"], +/// +/// // All imports are synchronous except for functions with these names +/// // +/// // Note that this key cannot be specified with `except_imports`, +/// // only one or the other is accepted. +/// only_imports: ["foo", "bar"], +/// }, +/// /// // This can be used to translate WIT return values of the form /// // `result` into `Result` in Rust. /// // The `RustErrorType` structure will have an automatically generated diff --git a/crates/wit-bindgen/src/lib.rs b/crates/wit-bindgen/src/lib.rs index 44431ba2e0f3..21ae6367b489 100644 --- a/crates/wit-bindgen/src/lib.rs +++ b/crates/wit-bindgen/src/lib.rs @@ -3,7 +3,7 @@ use crate::types::{TypeInfo, Types}; use anyhow::{anyhow, bail, Context}; use heck::*; use indexmap::IndexMap; -use std::collections::{BTreeMap, HashMap}; +use std::collections::{BTreeMap, HashMap, HashSet}; use std::fmt::Write as _; use std::io::{Read, Write}; use std::mem; @@ -94,7 +94,7 @@ pub struct Opts { pub tracing: bool, /// Whether or not to use async rust functions and traits. - pub async_: bool, + pub async_: AsyncConfig, /// A list of "trappable errors" which are used to replace the `E` in /// `result` found in WIT. @@ -123,6 +123,42 @@ pub struct TrappableError { pub rust_type_name: String, } +#[derive(Default, Debug, Clone)] +pub enum AsyncConfig { + /// No functions are `async`. + #[default] + None, + /// All generated functions should be `async`. + All, + /// These imported functions should not be async, but everything else is. + AllExceptImports(HashSet), + /// These functions are the only imports that are async, all other imports + /// are sync. + /// + /// Note that all exports are still async in this situation. + OnlyImports(HashSet), +} + +impl AsyncConfig { + pub fn is_import_async(&self, f: &str) -> bool { + match self { + AsyncConfig::None => false, + AsyncConfig::All => true, + AsyncConfig::AllExceptImports(set) => !set.contains(f), + AsyncConfig::OnlyImports(set) => set.contains(f), + } + } + + pub fn maybe_async(&self) -> bool { + match self { + AsyncConfig::None => false, + AsyncConfig::All | AsyncConfig::AllExceptImports(_) | AsyncConfig::OnlyImports(_) => { + true + } + } + } +} + impl Opts { pub fn generate(&self, resolve: &Resolve, world: WorldId) -> String { let mut r = Wasmtime::default(); @@ -412,7 +448,7 @@ impl Wasmtime { } self.src.push_str("}\n"); - let (async_, async__, send, await_) = if self.opts.async_ { + let (async_, async__, send, await_) = if self.opts.async_.maybe_async() { ("async", "_async", ":Send", ".await") } else { ("", "", "", "") @@ -577,7 +613,7 @@ impl Wasmtime { } let world_camel = to_rust_upper_camel_case(&resolve.worlds[world].name); - if self.opts.async_ { + if self.opts.async_.maybe_async() { uwriteln!(self.src, "#[wasmtime::component::__internal::async_trait]") } uwrite!(self.src, "pub trait {world_camel}Imports"); @@ -646,7 +682,7 @@ impl Wasmtime { self.src.push_str(&name); } - let maybe_send = if self.opts.async_ { + let maybe_send = if self.opts.async_.maybe_async() { " + Send, T: Send" } else { "" @@ -854,7 +890,7 @@ impl<'a> InterfaceGenerator<'a> { self.rustdoc(docs); uwriteln!(self.src, "pub enum {camel} {{}}"); - if self.gen.opts.async_ { + if self.gen.opts.async_.maybe_async() { uwriteln!(self.src, "#[wasmtime::component::__internal::async_trait]") } @@ -1375,7 +1411,7 @@ impl<'a> InterfaceGenerator<'a> { let iface = &self.resolve.interfaces[id]; let owner = TypeOwner::Interface(id); - if self.gen.opts.async_ { + if self.gen.opts.async_.maybe_async() { uwriteln!(self.src, "#[wasmtime::component::__internal::async_trait]") } // Generate the `pub trait` which represents the host functionality for @@ -1400,7 +1436,7 @@ impl<'a> InterfaceGenerator<'a> { } uwriteln!(self.src, "}}"); - let where_clause = if self.gen.opts.async_ { + let where_clause = if self.gen.opts.async_.maybe_async() { "T: Send, U: Host + Send".to_string() } else { "U: Host".to_string() @@ -1443,7 +1479,7 @@ impl<'a> InterfaceGenerator<'a> { uwrite!( self.src, "{linker}.{}(\"{}\", ", - if self.gen.opts.async_ { + if self.gen.opts.async_.is_import_async(&func.name) { "func_wrap_async" } else { "func_wrap" @@ -1472,7 +1508,7 @@ impl<'a> InterfaceGenerator<'a> { self.src.push_str(", "); } self.src.push_str(") |"); - if self.gen.opts.async_ { + if self.gen.opts.async_.is_import_async(&func.name) { self.src.push_str(" Box::new(async move { \n"); } else { self.src.push_str(" { \n"); @@ -1541,7 +1577,7 @@ impl<'a> InterfaceGenerator<'a> { for (i, _) in func.params.iter().enumerate() { uwrite!(self.src, "arg{},", i); } - if self.gen.opts.async_ { + if self.gen.opts.async_.is_import_async(&func.name) { uwrite!(self.src, ").await;\n"); } else { uwrite!(self.src, ");\n"); @@ -1571,7 +1607,7 @@ impl<'a> InterfaceGenerator<'a> { uwrite!(self.src, "r\n"); } - if self.gen.opts.async_ { + if self.gen.opts.async_.is_import_async(&func.name) { // Need to close Box::new and async block self.src.push_str("})"); } else { @@ -1582,7 +1618,7 @@ impl<'a> InterfaceGenerator<'a> { fn generate_function_trait_sig(&mut self, func: &Function) { self.rustdoc(&func.docs); - if self.gen.opts.async_ { + if self.gen.opts.async_.is_import_async(&func.name) { self.push_str("async "); } self.push_str("fn "); @@ -1658,7 +1694,11 @@ impl<'a> InterfaceGenerator<'a> { ns: Option<&WorldKey>, func: &Function, ) { - let (async_, async__, await_) = if self.gen.opts.async_ { + // Exports must be async if anything could be async, it's just imports + // that get to be optionally async/sync. + let is_async = self.gen.opts.async_.maybe_async(); + + let (async_, async__, await_) = if is_async { ("async", "_async", ".await") } else { ("", "", "") @@ -1681,7 +1721,7 @@ impl<'a> InterfaceGenerator<'a> { self.src.push_str(") -> wasmtime::Result<"); self.print_result_ty(&func.results, TypeMode::Owned); - if self.gen.opts.async_ { + if is_async { self.src .push_str("> where ::Data: Send {\n"); } else { diff --git a/tests/all/component_model/bindgen.rs b/tests/all/component_model/bindgen.rs index dbe65a3ddc6f..f0160dc92203 100644 --- a/tests/all/component_model/bindgen.rs +++ b/tests/all/component_model/bindgen.rs @@ -316,3 +316,97 @@ mod resources_at_interface_level { Ok(()) } } + +mod async_config { + use super::*; + + wasmtime::component::bindgen!({ + inline: " + package foo:foo + + world t1 { + import x: func() + import y: func() + export z: func() + } + ", + async: true, + }); + + struct T; + + #[async_trait::async_trait] + impl T1Imports for T { + async fn x(&mut self) -> Result<()> { + Ok(()) + } + + async fn y(&mut self) -> Result<()> { + Ok(()) + } + } + + async fn _test_t1(t1: &T1, store: &mut Store<()>) { + let _ = t1.call_z(&mut *store).await; + } + + wasmtime::component::bindgen!({ + inline: " + package foo:foo + + world t2 { + import x: func() + import y: func() + export z: func() + } + ", + async: { + except_imports: ["x"], + }, + }); + + #[async_trait::async_trait] + impl T2Imports for T { + fn x(&mut self) -> Result<()> { + Ok(()) + } + + async fn y(&mut self) -> Result<()> { + Ok(()) + } + } + + async fn _test_t2(t2: &T2, store: &mut Store<()>) { + let _ = t2.call_z(&mut *store).await; + } + + wasmtime::component::bindgen!({ + inline: " + package foo:foo + + world t3 { + import x: func() + import y: func() + export z: func() + } + ", + async: { + only_imports: ["x"], + }, + }); + + #[async_trait::async_trait] + impl T3Imports for T { + async fn x(&mut self) -> Result<()> { + Ok(()) + } + + fn y(&mut self) -> Result<()> { + Ok(()) + } + } + + async fn _test_t3(t3: &T3, store: &mut Store<()>) { + let _ = t3.call_z(&mut *store).await; + } +}