From 138dbc596a749756047b01e163d17dd4188d9bb5 Mon Sep 17 00:00:00 2001 From: Daniel Vigovszky Date: Wed, 17 Jul 2024 16:55:40 +0200 Subject: [PATCH] Async ResourceStore trait --- Cargo.lock | 11 ++-- wasm-rpc/Cargo.toml | 3 +- wasm-rpc/src/wasmtime.rs | 137 ++++++++++++++++++++------------------- 3 files changed, 80 insertions(+), 71 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4f92eaa6..6e6d16ac 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -276,9 +276,9 @@ dependencies = [ [[package]] name = "async-recursion" -version = "1.0.5" +version = "1.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5fd55a5ba1179988837d24ab4c7cc8ed6efdeff578ede0416b4225a5fca35bd0" +checksum = "3b43422f69d8ff38f95f1b2bb76517c91589a924d1559a0e935d7c8ce0274c11" dependencies = [ "proc-macro2", "quote", @@ -1704,6 +1704,7 @@ name = "golem-wasm-rpc" version = "0.0.0" dependencies = [ "arbitrary", + "async-recursion", "async-trait", "bigdecimal", "bincode", @@ -3022,8 +3023,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "22505a5c94da8e3b7c2996394d1c933236c4d743e81a410bcca4e6989fc066a4" dependencies = [ "bytes", - "heck 0.4.1", - "itertools 0.11.0", + "heck 0.5.0", + "itertools 0.12.1", "log", "multimap", "once_cell", @@ -3043,7 +3044,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "81bddcdb20abf9501610992b6759a4c888aef7d1a7247ef75e2404275ac24af1" dependencies = [ "anyhow", - "itertools 0.11.0", + "itertools 0.12.1", "proc-macro2", "quote", "syn 2.0.67", diff --git a/wasm-rpc/Cargo.toml b/wasm-rpc/Cargo.toml index 923c712c..e552601a 100644 --- a/wasm-rpc/Cargo.toml +++ b/wasm-rpc/Cargo.toml @@ -17,6 +17,7 @@ crate-type = ["cdylib", "rlib"] wit-bindgen-rt = { version = "0.26.0", features = ["bitflags"] } arbitrary = { version = "1.3.2", features = ["derive"], optional = true } +async-recursion = { version = "1.1.1", optional = true } async-trait = { version = "0.1.77", optional = true } bigdecimal = { version = "0.4.5", optional = true } bincode = { version = "2.0.0-rc.3", optional = true } @@ -61,7 +62,7 @@ serde = ["dep:serde"] stub = [] text = ["wasmtime", "dep:wasm-wave"] typeinfo = ["dep:golem-wasm-ast"] -wasmtime = ["dep:wasmtime", "dep:wasmtime-wasi", "typeinfo"] +wasmtime = ["dep:wasmtime", "dep:wasmtime-wasi", "dep:async-recursion", "typeinfo"] [package.metadata.component] diff --git a/wasm-rpc/src/wasmtime.rs b/wasm-rpc/src/wasmtime.rs index 4af82607..52de3c0d 100644 --- a/wasm-rpc/src/wasmtime.rs +++ b/wasm-rpc/src/wasmtime.rs @@ -13,6 +13,8 @@ // limitations under the License. use crate::{Uri, Value}; +use async_recursion::async_recursion; +use async_trait::async_trait; use golem_wasm_ast::analysis::AnalysedType; use wasmtime::component::{types, ResourceAny, Type, Val}; @@ -22,11 +24,12 @@ pub enum EncodingError { Unknown { details: String }, } +#[async_trait] pub trait ResourceStore { fn self_uri(&self) -> Uri; - fn add(&mut self, resource: ResourceAny) -> u64; - fn get(&mut self, resource_id: u64) -> Option; - fn borrow(&self, resource_id: u64) -> Option; + async fn add(&mut self, resource: ResourceAny) -> u64; + async fn get(&mut self, resource_id: u64) -> Option; + async fn borrow(&self, resource_id: u64) -> Option; } pub struct DecodeParamResult { @@ -44,10 +47,11 @@ impl DecodeParamResult { } /// Converts a Value to a wasmtime Val based on the available type information. -pub fn decode_param( +#[async_recursion] +pub async fn decode_param( param: &Value, param_type: &Type, - resource_store: &mut impl ResourceStore, + resource_store: &mut (impl ResourceStore + Send), ) -> Result { match param_type { Type::Bool => match param { @@ -107,7 +111,7 @@ pub fn decode_param( let mut decoded_values = Vec::new(); let mut resource_ids_to_drop = Vec::new(); for value in values { - let decoded_param = decode_param(value, &ty.ty(), resource_store)?; + let decoded_param = decode_param(value, &ty.ty(), resource_store).await?; decoded_values.push(decoded_param.val); resource_ids_to_drop.extend(decoded_param.resources_to_drop); } @@ -124,7 +128,7 @@ pub fn decode_param( let mut resource_ids_to_drop = Vec::new(); for (value, field) in values.iter().zip(ty.fields()) { - let decoded_param = decode_param(value, &field.ty, resource_store)?; + let decoded_param = decode_param(value, &field.ty, resource_store).await?; record_values.push((field.name.to_string(), decoded_param.val)); resource_ids_to_drop.extend(decoded_param.resources_to_drop); } @@ -142,7 +146,7 @@ pub fn decode_param( let mut resource_ids_to_drop = Vec::new(); for (value, ty) in values.iter().zip(ty.types()) { - let decoded_param = decode_param(value, &ty, resource_store)?; + let decoded_param = decode_param(value, &ty, resource_store).await?; tuple_values.push(decoded_param.val); resource_ids_to_drop.extend(decoded_param.resources_to_drop); } @@ -168,10 +172,10 @@ pub fn decode_param( let name = case.name; match case.ty { Some(ref case_ty) => { - let decoded_value = case_value - .as_ref() - .map(|v| decode_param(v, case_ty, resource_store)) - .transpose()?; + let decoded_value = match case_value { + Some(v) => Some(decode_param(v, case_ty, resource_store).await?), + None => None, + }; match decoded_value { Some(decoded_value) => Ok(DecodeParamResult { val: Val::Variant( @@ -219,7 +223,7 @@ pub fn decode_param( Type::Option(ty) => match param { Value::Option(value) => match value { Some(value) => { - let decoded_value = decode_param(value, &ty.ty(), resource_store)?; + let decoded_value = decode_param(value, &ty.ty(), resource_store).await?; Ok(DecodeParamResult { val: Val::Option(Some(Box::new(decoded_value.val))), resources_to_drop: decoded_value.resources_to_drop, @@ -235,10 +239,10 @@ pub fn decode_param( let ok_ty = ty.ok().ok_or(EncodingError::ValueMismatch { details: "could not get ok type".to_string(), })?; - let decoded_value = value - .as_ref() - .map(|v| decode_param(v, &ok_ty, resource_store)) - .transpose()?; + let decoded_value = match value { + Some(v) => Some(decode_param(v, &ok_ty, resource_store).await?), + None => None, + }; match decoded_value { Some(decoded_value) => Ok(DecodeParamResult { val: Val::Result(Ok(Some(Box::new(decoded_value.val)))), @@ -251,10 +255,11 @@ pub fn decode_param( let err_ty = ty.err().ok_or(EncodingError::ValueMismatch { details: "could not get err type".to_string(), })?; - let decoded_value = value - .as_ref() - .map(|v| decode_param(v, &err_ty, resource_store)) - .transpose()?; + let decoded_value = match value { + Some(v) => Some(decode_param(v, &err_ty, resource_store).await?), + None => None, + }; + match decoded_value { Some(decoded_value) => Ok(DecodeParamResult { val: Val::Result(Err(Some(Box::new(decoded_value.val)))), @@ -288,7 +293,7 @@ pub fn decode_param( Type::Own(_) => match param { Value::Handle { uri, resource_id } => { if resource_store.self_uri() == *uri { - match resource_store.get(*resource_id) { + match resource_store.get(*resource_id).await { Some(resource) => Ok(DecodeParamResult { val: Val::Resource(resource), resources_to_drop: vec![resource], @@ -309,7 +314,7 @@ pub fn decode_param( Type::Borrow(_) => match param { Value::Handle { uri, resource_id } => { if resource_store.self_uri() == *uri { - match resource_store.borrow(*resource_id) { + match resource_store.borrow(*resource_id).await { Some(resource) => Ok(DecodeParamResult::simple(Val::Resource(resource))), None => Err(EncodingError::ValueMismatch { details: "resource not found".to_string(), @@ -328,10 +333,11 @@ pub fn decode_param( } /// Converts a wasmtime Val to a Golem protobuf Val -pub fn encode_output( +#[async_recursion] +pub async fn encode_output( value: &Val, typ: &Type, - resource_store: &mut impl ResourceStore, + resource_store: &mut (impl ResourceStore + Send), ) -> Result { match value { Val::Bool(bool) => Ok(Value::Bool(*bool)), @@ -351,7 +357,8 @@ pub fn encode_output( if let Type::List(list_type) = typ { let mut encoded_values = Vec::new(); for value in (*list).iter() { - encoded_values.push(encode_output(value, &list_type.ty(), resource_store)?); + encoded_values + .push(encode_output(value, &list_type.ty(), resource_store).await?); } Ok(Value::List(encoded_values)) } else { @@ -362,11 +369,11 @@ pub fn encode_output( } Val::Record(record) => { if let Type::Record(record_type) = typ { - let encoded_values = record - .iter() - .zip(record_type.fields()) - .map(|((_name, value), field)| encode_output(value, &field.ty, resource_store)) - .collect::, EncodingError>>()?; + let mut encoded_values = Vec::new(); + for ((_name, value), field) in record.iter().zip(record_type.fields()) { + let field = encode_output(value, &field.ty, resource_store).await?; + encoded_values.push(field); + } Ok(Value::Record(encoded_values)) } else { Err(EncodingError::ValueMismatch { @@ -376,11 +383,11 @@ pub fn encode_output( } Val::Tuple(tuple) => { if let Type::Tuple(tuple_type) = typ { - let encoded_values = tuple - .iter() - .zip(tuple_type.types()) - .map(|(v, t)| encode_output(v, &t, resource_store)) - .collect::, EncodingError>>()?; + let mut encoded_values = Vec::new(); + for (v, t) in tuple.iter().zip(tuple_type.types()) { + let value = encode_output(v, &t, resource_store).await?; + encoded_values.push(value); + } Ok(Value::Tuple(encoded_values)) } else { Err(EncodingError::ValueMismatch { @@ -398,9 +405,8 @@ pub fn encode_output( details: format!("Could not find case for variant {}", name), })?; - let encoded_output = value - .as_ref() - .map(|v| { + let encoded_output = match value { + Some(v) => Some( encode_output( v, &case.ty.ok_or(EncodingError::ValueMismatch { @@ -408,8 +414,11 @@ pub fn encode_output( })?, resource_store, ) - }) - .transpose()?; + .await?, + ), + None => None, + }; + Ok(Value::Variant { case_idx: discriminant as u32, case_value: encoded_output.map(Box::new), @@ -439,7 +448,8 @@ pub fn encode_output( Val::Option(option) => match option { Some(value) => { if let Type::Option(option_type) = typ { - let encoded_output = encode_output(value, &option_type.ty(), resource_store)?; + let encoded_output = + encode_output(value, &option_type.ty(), resource_store).await?; Ok(Value::Option(Some(Box::new(encoded_output)))) } else { Err(EncodingError::ValueMismatch { @@ -453,31 +463,28 @@ pub fn encode_output( if let Type::Result(result_type) = typ { match result { Ok(value) => { - let encoded_output = value - .as_ref() - .map(|v| { - result_type - .ok() - .ok_or(EncodingError::ValueMismatch { - details: "Could not get ok type for result".to_string(), - }) - .and_then(|t| encode_output(v, &t, resource_store)) - }) - .transpose()?; + let encoded_output = match value { + Some(v) => { + let t = result_type.ok().ok_or(EncodingError::ValueMismatch { + details: "Could not get ok type for result".to_string(), + })?; + + Some(encode_output(v, &t, resource_store).await?) + } + None => None, + }; Ok(Value::Result(Ok(encoded_output.map(Box::new)))) } Err(value) => { - let encoded_output = value - .as_ref() - .map(|v| { - result_type - .err() - .ok_or(EncodingError::ValueMismatch { - details: "Could not get error type for result".to_string(), - }) - .and_then(|t| encode_output(v, &t, resource_store)) - }) - .transpose()?; + let encoded_output = match value { + Some(v) => { + let t = result_type.err().ok_or(EncodingError::ValueMismatch { + details: "Could not get error type for result".to_string(), + })?; + Some(encode_output(v, &t, resource_store).await?) + } + None => None, + }; Ok(Value::Result(Err(encoded_output.map(Box::new)))) } } @@ -505,7 +512,7 @@ pub fn encode_output( } } Val::Resource(resource) => { - let id = resource_store.add(*resource); + let id = resource_store.add(*resource).await; Ok(Value::Handle { uri: resource_store.self_uri(), resource_id: id,