Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhance async configuration of bindgen! macro #6942

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 50 additions & 6 deletions crates/component-macro/src/bindgen.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use proc_macro2::{Span, TokenStream};
use std::collections::HashMap;
use std::collections::HashSet;
use std::path::{Path, PathBuf};
use syn::parse::{Error, Parse, ParseStream, Result};
use syn::punctuated::Punctuated;
use syn::{braced, token, Ident, Token};
use wasmtime_wit_bindgen::{Opts, Ownership, TrappableError};
use wasmtime_wit_bindgen::{AsyncConfig, Opts, Ownership, TrappableError};
use wit_parser::{PackageId, Resolve, UnresolvedPackage, WorldId};

pub struct Config {
Expand All @@ -15,7 +16,7 @@ pub struct Config {
}

pub fn expand(input: &Config) -> Result<TokenStream> {
if !cfg!(feature = "async") && input.opts.async_ {
if !cfg!(feature = "async") && input.opts.async_.maybe_async() {
return Err(Error::new(
Span::call_site(),
"cannot enable async bindings unless `async` crate feature is active",
Expand Down Expand Up @@ -45,6 +46,7 @@ impl Parse for Config {
let mut world = None;
let mut inline = None;
let mut path = None;
let mut async_configured = false;

if input.peek(token::Brace) {
let content;
Expand All @@ -71,7 +73,13 @@ impl Parse for Config {
inline = Some(s.value());
}
Opt::Tracing(val) => opts.tracing = val,
Opt::Async(val) => opts.async_ = val,
Opt::Async(val, span) => {
if async_configured {
return Err(Error::new(span, "cannot specify second async config"));
}
async_configured = true;
opts.async_ = val;
}
Opt::TrappableErrorType(val) => opts.trappable_error_type = val,
Opt::Ownership(val) => opts.ownership = val,
Opt::Interfaces(s) => {
Expand Down Expand Up @@ -171,14 +179,16 @@ mod kw {
syn::custom_keyword!(ownership);
syn::custom_keyword!(interfaces);
syn::custom_keyword!(with);
syn::custom_keyword!(except_imports);
syn::custom_keyword!(only_imports);
}

enum Opt {
World(syn::LitStr),
Path(syn::LitStr),
Inline(syn::LitStr),
Tracing(bool),
Async(bool),
Async(AsyncConfig, Span),
TrappableErrorType(Vec<TrappableError>),
Ownership(Ownership),
Interfaces(syn::LitStr),
Expand All @@ -205,9 +215,43 @@ impl Parse for Opt {
input.parse::<Token![:]>()?;
Ok(Opt::Tracing(input.parse::<syn::LitBool>()?.value))
} else if l.peek(Token![async]) {
input.parse::<Token![async]>()?;
let span = input.parse::<Token![async]>()?.span;
input.parse::<Token![:]>()?;
Ok(Opt::Async(input.parse::<syn::LitBool>()?.value))
if input.peek(syn::LitBool) {
match input.parse::<syn::LitBool>()?.value {
true => Ok(Opt::Async(AsyncConfig::All, span)),
false => Ok(Opt::Async(AsyncConfig::None, span)),
}
} else {
let contents;
syn::braced!(contents in input);

let l = contents.lookahead1();
let ctor: fn(HashSet<String>) -> AsyncConfig = if l.peek(kw::except_imports) {
contents.parse::<kw::except_imports>()?;
contents.parse::<Token![:]>()?;
AsyncConfig::AllExceptImports
} else if l.peek(kw::only_imports) {
contents.parse::<kw::only_imports>()?;
contents.parse::<Token![:]>()?;
AsyncConfig::OnlyImports
} else {
return Err(l.error());
};

let list;
syn::bracketed!(list in contents);
let fields: Punctuated<syn::LitStr, Token![,]> =
list.parse_terminated(Parse::parse, Token![,])?;

if contents.peek(Token![,]) {
contents.parse::<Token![,]>()?;
}
Ok(Opt::Async(
ctor(fields.iter().map(|s| s.value()).collect()),
span,
))
}
} else if l.peek(kw::ownership) {
input.parse::<kw::ownership>()?;
input.parse::<Token![:]>()?;
Expand Down
16 changes: 16 additions & 0 deletions crates/wasmtime/src/component/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,22 @@ pub(crate) use self::store::ComponentStoreData;
/// // This option defaults to `false`.
/// async: true,
///
/// // Alternative mode of async configuration where this still implies
/// // async instantiation happens, for example, but more control is
/// // provided over which imports are async and which aren't.
/// //
/// // Note that in this mode all exports are still async.
/// async: {
/// // All imports are async except for functions with these names
/// except_imports: ["foo", "bar"],
///
/// // All imports are synchronous except for functions with these names
/// //
/// // Note that this key cannot be specified with `except_imports`,
/// // only one or the other is accepted.
/// only_imports: ["foo", "bar"],
/// },
///
/// // This can be used to translate WIT return values of the form
/// // `result<T, error-type>` into `Result<T, RustErrorType>` in Rust.
/// // The `RustErrorType` structure will have an automatically generated
Expand Down
70 changes: 55 additions & 15 deletions crates/wit-bindgen/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::types::{TypeInfo, Types};
use anyhow::{anyhow, bail, Context};
use heck::*;
use indexmap::IndexMap;
use std::collections::{BTreeMap, HashMap};
use std::collections::{BTreeMap, HashMap, HashSet};
use std::fmt::Write as _;
use std::io::{Read, Write};
use std::mem;
Expand Down Expand Up @@ -94,7 +94,7 @@ pub struct Opts {
pub tracing: bool,

/// Whether or not to use async rust functions and traits.
pub async_: bool,
pub async_: AsyncConfig,

/// A list of "trappable errors" which are used to replace the `E` in
/// `result<T, E>` found in WIT.
Expand Down Expand Up @@ -123,6 +123,42 @@ pub struct TrappableError {
pub rust_type_name: String,
}

#[derive(Default, Debug, Clone)]
pub enum AsyncConfig {
/// No functions are `async`.
#[default]
None,
/// All generated functions should be `async`.
All,
/// These imported functions should not be async, but everything else is.
AllExceptImports(HashSet<String>),
/// These functions are the only imports that are async, all other imports
/// are sync.
///
/// Note that all exports are still async in this situation.
OnlyImports(HashSet<String>),
}

impl AsyncConfig {
pub fn is_import_async(&self, f: &str) -> bool {
match self {
AsyncConfig::None => false,
AsyncConfig::All => true,
AsyncConfig::AllExceptImports(set) => !set.contains(f),
AsyncConfig::OnlyImports(set) => set.contains(f),
}
}

pub fn maybe_async(&self) -> bool {
match self {
AsyncConfig::None => false,
AsyncConfig::All | AsyncConfig::AllExceptImports(_) | AsyncConfig::OnlyImports(_) => {
true
}
}
}
}

impl Opts {
pub fn generate(&self, resolve: &Resolve, world: WorldId) -> String {
let mut r = Wasmtime::default();
Expand Down Expand Up @@ -412,7 +448,7 @@ impl Wasmtime {
}
self.src.push_str("}\n");

let (async_, async__, send, await_) = if self.opts.async_ {
let (async_, async__, send, await_) = if self.opts.async_.maybe_async() {
("async", "_async", ":Send", ".await")
} else {
("", "", "", "")
Expand Down Expand Up @@ -577,7 +613,7 @@ impl Wasmtime {
}

let world_camel = to_rust_upper_camel_case(&resolve.worlds[world].name);
if self.opts.async_ {
if self.opts.async_.maybe_async() {
uwriteln!(self.src, "#[wasmtime::component::__internal::async_trait]")
}
uwrite!(self.src, "pub trait {world_camel}Imports");
Expand Down Expand Up @@ -646,7 +682,7 @@ impl Wasmtime {
self.src.push_str(&name);
}

let maybe_send = if self.opts.async_ {
let maybe_send = if self.opts.async_.maybe_async() {
" + Send, T: Send"
} else {
""
Expand Down Expand Up @@ -854,7 +890,7 @@ impl<'a> InterfaceGenerator<'a> {
self.rustdoc(docs);
uwriteln!(self.src, "pub enum {camel} {{}}");

if self.gen.opts.async_ {
if self.gen.opts.async_.maybe_async() {
uwriteln!(self.src, "#[wasmtime::component::__internal::async_trait]")
}

Expand Down Expand Up @@ -1375,7 +1411,7 @@ impl<'a> InterfaceGenerator<'a> {
let iface = &self.resolve.interfaces[id];
let owner = TypeOwner::Interface(id);

if self.gen.opts.async_ {
if self.gen.opts.async_.maybe_async() {
uwriteln!(self.src, "#[wasmtime::component::__internal::async_trait]")
}
// Generate the `pub trait` which represents the host functionality for
Expand All @@ -1400,7 +1436,7 @@ impl<'a> InterfaceGenerator<'a> {
}
uwriteln!(self.src, "}}");

let where_clause = if self.gen.opts.async_ {
let where_clause = if self.gen.opts.async_.maybe_async() {
"T: Send, U: Host + Send".to_string()
} else {
"U: Host".to_string()
Expand Down Expand Up @@ -1443,7 +1479,7 @@ impl<'a> InterfaceGenerator<'a> {
uwrite!(
self.src,
"{linker}.{}(\"{}\", ",
if self.gen.opts.async_ {
if self.gen.opts.async_.is_import_async(&func.name) {
"func_wrap_async"
} else {
"func_wrap"
Expand Down Expand Up @@ -1472,7 +1508,7 @@ impl<'a> InterfaceGenerator<'a> {
self.src.push_str(", ");
}
self.src.push_str(") |");
if self.gen.opts.async_ {
if self.gen.opts.async_.is_import_async(&func.name) {
self.src.push_str(" Box::new(async move { \n");
} else {
self.src.push_str(" { \n");
Expand Down Expand Up @@ -1541,7 +1577,7 @@ impl<'a> InterfaceGenerator<'a> {
for (i, _) in func.params.iter().enumerate() {
uwrite!(self.src, "arg{},", i);
}
if self.gen.opts.async_ {
if self.gen.opts.async_.is_import_async(&func.name) {
uwrite!(self.src, ").await;\n");
} else {
uwrite!(self.src, ");\n");
Expand Down Expand Up @@ -1571,7 +1607,7 @@ impl<'a> InterfaceGenerator<'a> {
uwrite!(self.src, "r\n");
}

if self.gen.opts.async_ {
if self.gen.opts.async_.is_import_async(&func.name) {
// Need to close Box::new and async block
self.src.push_str("})");
} else {
Expand All @@ -1582,7 +1618,7 @@ impl<'a> InterfaceGenerator<'a> {
fn generate_function_trait_sig(&mut self, func: &Function) {
self.rustdoc(&func.docs);

if self.gen.opts.async_ {
if self.gen.opts.async_.is_import_async(&func.name) {
self.push_str("async ");
}
self.push_str("fn ");
Expand Down Expand Up @@ -1658,7 +1694,11 @@ impl<'a> InterfaceGenerator<'a> {
ns: Option<&WorldKey>,
func: &Function,
) {
let (async_, async__, await_) = if self.gen.opts.async_ {
// Exports must be async if anything could be async, it's just imports
// that get to be optionally async/sync.
let is_async = self.gen.opts.async_.maybe_async();

let (async_, async__, await_) = if is_async {
("async", "_async", ".await")
} else {
("", "", "")
Expand All @@ -1681,7 +1721,7 @@ impl<'a> InterfaceGenerator<'a> {
self.src.push_str(") -> wasmtime::Result<");
self.print_result_ty(&func.results, TypeMode::Owned);

if self.gen.opts.async_ {
if is_async {
self.src
.push_str("> where <S as wasmtime::AsContext>::Data: Send {\n");
} else {
Expand Down
Loading