diff --git a/ci/vendor-wit.sh b/ci/vendor-wit.sh index 005f4d3e0730..8f4d8184dfe7 100755 --- a/ci/vendor-wit.sh +++ b/ci/vendor-wit.sh @@ -65,10 +65,6 @@ rm -rf $cache_dir # Separately (for now), vendor the `wasi-nn` WIT files since their retrieval is # slightly different than above. repo=https://raw.githubusercontent.com/WebAssembly/wasi-nn -revision=e2310b +revision=0.2.0-rc-2024-06-25 curl -L $repo/$revision/wasi-nn.witx -o crates/wasi-nn/witx/wasi-nn.witx -# TODO: the in-tree `wasi-nn` implementation does not yet fully support the -# latest WIT specification on `main`. To create a baseline for moving forward, -# the in-tree WIT incorporates some but not all of the upstream changes. This -# TODO can be removed once the implementation catches up with the spec. -# curl -L $repo/$revision/wit/wasi-nn.wit -o crates/wasi-nn/wit/wasi-nn.wit +curl -L $repo/$revision/wit/wasi-nn.wit -o crates/wasi-nn/wit/wasi-nn.wit diff --git a/crates/wasi-nn/src/wit.rs b/crates/wasi-nn/src/wit.rs index 40f6fc4c1ff6..e0c740232a65 100644 --- a/crates/wasi-nn/src/wit.rs +++ b/crates/wasi-nn/src/wit.rs @@ -17,6 +17,7 @@ use crate::backend::Id; use crate::{Backend, Registry}; +use anyhow::anyhow; use std::collections::HashMap; use std::hash::Hash; use std::{fmt, str::FromStr}; @@ -54,14 +55,53 @@ impl<'a> WasiNnView<'a> { } } -pub enum Error { +/// A wasi-nn error; this appears on the Wasm side as a component model +/// resource. +#[derive(Debug)] +pub struct Error { + code: ErrorCode, + data: anyhow::Error, +} + +/// Construct an [`Error`] resource and immediately return it. +/// +/// The WIT specification currently relies on "errors as resources;" this helper +/// macro hides some of that complexity. If [#75] is adopted ("errors as +/// records"), this macro is no longer necessary. +/// +/// [#75]: https://github.com/WebAssembly/wasi-nn/pull/75 +macro_rules! 
bail { + ($self:ident, $code:expr, $data:expr) => { + let e = Error { + code: $code, + data: $data.into(), + }; + tracing::error!("failure: {e:?}"); + let r = $self.table.push(e)?; + return Ok(Err(r)); + }; +} + +impl From<wasmtime::component::ResourceTableError> for Error { + fn from(error: wasmtime::component::ResourceTableError) -> Self { + Self { + code: ErrorCode::Trap, + data: error.into(), + } + } +} + +/// The list of error codes available to the `wasi-nn` API; this should match +/// what is specified in WIT. +#[derive(Debug)] +pub enum ErrorCode { /// Caller module passed an invalid argument. InvalidArgument, /// Invalid encoding. InvalidEncoding, /// The operation timed out. Timeout, - /// Runtime Error. + /// Runtime error. RuntimeError, /// Unsupported operation. UnsupportedOperation, @@ -69,14 +109,9 @@ pub enum Error { TooLarge, /// Graph not found. NotFound, - /// A runtime error occurred that we should trap on; see `StreamError`. - Trap(anyhow::Error), -} - -impl From<wasmtime::component::ResourceTableError> for Error { - fn from(error: wasmtime::component::ResourceTableError) -> Self { - Self::Trap(error.into()) - } + /// A runtime error that Wasmtime should trap on; this will not appear in + /// the WIT specification. + Trap, } /// Generate the traits and types from the `wasi-nn` WIT specification. 
@@ -91,6 +126,7 @@ mod gen_ { "wasi:nn/graph/graph": crate::Graph, "wasi:nn/tensor/tensor": crate::Tensor, "wasi:nn/inference/graph-execution-context": crate::ExecutionContext, + "wasi:nn/errors/error": super::Error, }, trappable_error_type: { "wasi:nn/errors/error" => super::Error, @@ -131,36 +167,45 @@ impl gen::graph::Host for WasiNnView<'_> { builders: Vec<GraphBuilder>, encoding: GraphEncoding, target: ExecutionTarget, - ) -> Result<Resource<Graph>, Error> { + ) -> wasmtime::Result<Result<Resource<Graph>, Resource<Error>>> { tracing::debug!("load {encoding:?} {target:?}"); if let Some(backend) = self.ctx.backends.get_mut(&encoding) { let slices = builders.iter().map(|s| s.as_slice()).collect::<Vec<_>>(); match backend.load(&slices, target.into()) { Ok(graph) => { let graph = self.table.push(graph)?; - Ok(graph) + Ok(Ok(graph)) } Err(error) => { - tracing::error!("failed to load graph: {error:?}"); - Err(Error::RuntimeError) + bail!(self, ErrorCode::RuntimeError, error); } } } else { - Err(Error::InvalidEncoding) + bail!( + self, + ErrorCode::InvalidEncoding, + anyhow!("unable to find a backend for this encoding") + ); } } - fn load_by_name(&mut self, name: String) -> Result<Resource<Graph>, Error> { + fn load_by_name( + &mut self, + name: String, + ) -> wasmtime::Result<Result<Resource<Graph>, Resource<Error>>> { use core::result::Result::*; tracing::debug!("load by name {name:?}"); let registry = &self.ctx.registry; if let Some(graph) = registry.get(&name) { let graph = graph.clone(); let graph = self.table.push(graph)?; - Ok(graph) + Ok(Ok(graph)) } else { - tracing::error!("failed to find graph with name: {name}"); - Err(Error::NotFound) + bail!( + self, + ErrorCode::NotFound, + anyhow!("failed to find graph with name: {name}") + ); } } } @@ -169,18 +214,17 @@ impl gen::graph::HostGraph for WasiNnView<'_> { fn init_execution_context( &mut self, graph: Resource<Graph>, - ) -> Result<Resource<GraphExecutionContext>, Error> { + ) -> wasmtime::Result<Result<Resource<GraphExecutionContext>, Resource<Error>>> { use core::result::Result::*; tracing::debug!("initialize execution context"); let graph = self.table.get(&graph)?; match 
graph.init_execution_context() { Ok(exec_context) => { let exec_context = self.table.push(exec_context)?; - Ok(exec_context) + Ok(Ok(exec_context)) } Err(error) => { - tracing::error!("failed to initialize execution context: {error:?}"); - Err(Error::RuntimeError) + bail!(self, ErrorCode::RuntimeError, error); } } } @@ -197,47 +241,46 @@ impl gen::inference::HostGraphExecutionContext for WasiNnView<'_> { exec_context: Resource<GraphExecutionContext>, name: String, tensor: Resource<Tensor>, - ) -> Result<(), Error> { + ) -> wasmtime::Result<Result<(), Resource<Error>>> { let tensor = self.table.get(&tensor)?; tracing::debug!("set input {name:?}: {tensor:?}"); let tensor = tensor.clone(); // TODO: avoid copying the tensor let exec_context = self.table.get_mut(&exec_context)?; - if let Err(e) = exec_context.set_input(Id::Name(name), &tensor) { - tracing::error!("failed to set input: {e:?}"); - Err(Error::InvalidArgument) + if let Err(error) = exec_context.set_input(Id::Name(name), &tensor) { + bail!(self, ErrorCode::InvalidArgument, error); } else { - Ok(()) + Ok(Ok(())) } } - fn compute(&mut self, exec_context: Resource<GraphExecutionContext>) -> Result<(), Error> { + fn compute( + &mut self, + exec_context: Resource<GraphExecutionContext>, + ) -> wasmtime::Result<Result<(), Resource<Error>>> { let exec_context = &mut self.table.get_mut(&exec_context)?; tracing::debug!("compute"); match exec_context.compute() { - Ok(()) => Ok(()), + Ok(()) => Ok(Ok(())), Err(error) => { - tracing::error!("failed to compute: {error:?}"); - Err(Error::RuntimeError) + bail!(self, ErrorCode::RuntimeError, error); } } } - #[doc = r" Extract the outputs after inference."] fn get_output( &mut self, exec_context: Resource<GraphExecutionContext>, name: String, - ) -> Result<Resource<Tensor>, Error> { + ) -> wasmtime::Result<Result<Resource<Tensor>, Resource<Error>>> { let exec_context = self.table.get_mut(&exec_context)?; tracing::debug!("get output {name:?}"); match exec_context.get_output(Id::Name(name)) { Ok(tensor) => { let tensor = self.table.push(tensor)?; - Ok(tensor) + Ok(Ok(tensor)) } Err(error) => { - tracing::error!("failed to get output: {error:?}"); - Err(Error::RuntimeError) + 
bail!(self, ErrorCode::RuntimeError, error); } } } @@ -285,21 +328,51 @@ impl gen::tensor::HostTensor for WasiNnView<'_> { } } -impl gen::tensor::Host for WasiNnView<'_> {} +impl gen::errors::HostError for WasiNnView<'_> { + fn new( + &mut self, + _code: gen::errors::ErrorCode, + _data: String, + ) -> wasmtime::Result<Resource<Error>> { + unimplemented!("this should be removed; see https://github.com/WebAssembly/wasi-nn/pull/76") + } + + fn code(&mut self, error: Resource<Error>) -> wasmtime::Result<gen::errors::ErrorCode> { + let error = self.table.get(&error)?; + match error.code { + ErrorCode::InvalidArgument => Ok(gen::errors::ErrorCode::InvalidArgument), + ErrorCode::InvalidEncoding => Ok(gen::errors::ErrorCode::InvalidEncoding), + ErrorCode::Timeout => Ok(gen::errors::ErrorCode::Timeout), + ErrorCode::RuntimeError => Ok(gen::errors::ErrorCode::RuntimeError), + ErrorCode::UnsupportedOperation => Ok(gen::errors::ErrorCode::UnsupportedOperation), + ErrorCode::TooLarge => Ok(gen::errors::ErrorCode::TooLarge), + ErrorCode::NotFound => Ok(gen::errors::ErrorCode::NotFound), + ErrorCode::Trap => Err(anyhow!(error.data.to_string())), + } + } + + fn data(&mut self, error: Resource<Error>) -> wasmtime::Result<String> { + let error = self.table.get(&error)?; + Ok(error.data.to_string()) + } + + fn drop(&mut self, error: Resource<Error>) -> wasmtime::Result<()> { + self.table.delete(error)?; + Ok(()) + } +} + impl gen::errors::Host for WasiNnView<'_> { - fn convert_error(&mut self, err: Error) -> wasmtime::Result<gen::errors::Error> { - match err { - Error::InvalidArgument => Ok(gen::errors::Error::InvalidArgument), - Error::InvalidEncoding => Ok(gen::errors::Error::InvalidEncoding), - Error::Timeout => Ok(gen::errors::Error::Timeout), - Error::RuntimeError => Ok(gen::errors::Error::RuntimeError), - Error::UnsupportedOperation => Ok(gen::errors::Error::UnsupportedOperation), - Error::TooLarge => Ok(gen::errors::Error::TooLarge), - Error::NotFound => Ok(gen::errors::Error::NotFound), - Error::Trap(e) => Err(e), + fn convert_error(&mut self, err: Error) -> 
wasmtime::Result<Error> { + if matches!(err.code, ErrorCode::Trap) { + Err(err.data) + } else { + Ok(err) } } } + +impl gen::tensor::Host for WasiNnView<'_> {} impl gen::inference::Host for WasiNnView<'_> {} impl Hash for gen::graph::GraphEncoding { diff --git a/crates/wasi-nn/wit/wasi-nn.wit b/crates/wasi-nn/wit/wasi-nn.wit index b8ffd22e8c04..872e8cd7d745 100644 --- a/crates/wasi-nn/wit/wasi-nn.wit +++ b/crates/wasi-nn/wit/wasi-nn.wit @@ -1,4 +1,4 @@ -package wasi:nn; +package wasi:nn@0.2.0-rc-2024-06-25; /// `wasi-nn` is a WASI API for performing machine learning (ML) inference. The API is not (yet) /// capable of performing ML training. WebAssembly programs that want to use a host's ML @@ -134,7 +134,7 @@ interface inference { /// TODO: create function-specific errors (https://github.com/WebAssembly/wasi-nn/issues/42) interface errors { - enum error { + enum error-code { // Caller module passed an invalid argument. invalid-argument, // Invalid encoding. @@ -148,6 +148,21 @@ interface errors { // Graph is too large. too-large, // Graph not found. - not-found + not-found, + // The operation is insecure or has insufficient privilege to be performed. + // e.g., cannot access a hardware feature requested + security, + // The operation failed for an unspecified reason. + unknown + } + + resource error { + constructor(code: error-code, data: string); + + /// Return the error code. + code: func() -> error-code; + + /// Errors can propagated with backend specific status through a string value. + data: func() -> string; } }