Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

wasi-nn: track upstream specification #9056

Merged
merged 1 commit into from
Aug 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions ci/vendor-wit.sh
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,6 @@ rm -rf $cache_dir
# Separately (for now), vendor the `wasi-nn` WIT files since their retrieval is
# slightly different than above.
repo=https://raw.githubusercontent.com/WebAssembly/wasi-nn
revision=e2310b
revision=0.2.0-rc-2024-06-25
curl -L $repo/$revision/wasi-nn.witx -o crates/wasi-nn/witx/wasi-nn.witx
# TODO: the in-tree `wasi-nn` implementation does not yet fully support the
# latest WIT specification on `main`. To create a baseline for moving forward,
# the in-tree WIT incorporates some but not all of the upstream changes. This
# TODO can be removed once the implementation catches up with the spec.
# curl -L $repo/$revision/wit/wasi-nn.wit -o crates/wasi-nn/wit/wasi-nn.wit
curl -L $repo/$revision/wit/wasi-nn.wit -o crates/wasi-nn/wit/wasi-nn.wit
169 changes: 121 additions & 48 deletions crates/wasi-nn/src/wit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

use crate::backend::Id;
use crate::{Backend, Registry};
use anyhow::anyhow;
use std::collections::HashMap;
use std::hash::Hash;
use std::{fmt, str::FromStr};
Expand Down Expand Up @@ -54,29 +55,63 @@ impl<'a> WasiNnView<'a> {
}
}

pub enum Error {
/// A wasi-nn error; this appears on the Wasm side as a component model
/// resource.
#[derive(Debug)]
pub struct Error {
code: ErrorCode,
data: anyhow::Error,
}

/// Construct an [`Error`] resource and immediately return it.
///
/// The WIT specification currently relies on "errors as resources;" this helper
/// macro hides some of that complexity. If [#75] is adopted ("errors as
/// records"), this macro is no longer necessary.
///
/// [#75]: https://github.com/WebAssembly/wasi-nn/pull/75
macro_rules! bail {
($self:ident, $code:expr, $data:expr) => {
let e = Error {
code: $code,
data: $data.into(),
};
tracing::error!("failure: {e:?}");
let r = $self.table.push(e)?;
return Ok(Err(r));
};
}

impl From<wasmtime::component::ResourceTableError> for Error {
fn from(error: wasmtime::component::ResourceTableError) -> Self {
Self {
code: ErrorCode::Trap,
data: error.into(),
}
}
}

/// The list of error codes available to the `wasi-nn` API; this should match
/// what is specified in WIT.
#[derive(Debug)]
pub enum ErrorCode {
/// Caller module passed an invalid argument.
InvalidArgument,
/// Invalid encoding.
InvalidEncoding,
/// The operation timed out.
Timeout,
/// Runtime Error.
/// Runtime error.
RuntimeError,
/// Unsupported operation.
UnsupportedOperation,
/// Graph is too large.
TooLarge,
/// Graph not found.
NotFound,
/// A runtime error occurred that we should trap on; see `StreamError`.
Trap(anyhow::Error),
}

impl From<wasmtime::component::ResourceTableError> for Error {
fn from(error: wasmtime::component::ResourceTableError) -> Self {
Self::Trap(error.into())
}
/// A runtime error that Wasmtime should trap on; this will not appear in
/// the WIT specification.
Trap,
}

/// Generate the traits and types from the `wasi-nn` WIT specification.
Expand All @@ -91,6 +126,7 @@ mod gen_ {
"wasi:nn/graph/graph": crate::Graph,
"wasi:nn/tensor/tensor": crate::Tensor,
"wasi:nn/inference/graph-execution-context": crate::ExecutionContext,
"wasi:nn/errors/error": super::Error,
},
trappable_error_type: {
"wasi:nn/errors/error" => super::Error,
Expand Down Expand Up @@ -131,36 +167,45 @@ impl gen::graph::Host for WasiNnView<'_> {
builders: Vec<GraphBuilder>,
encoding: GraphEncoding,
target: ExecutionTarget,
) -> Result<Resource<crate::Graph>, Error> {
) -> wasmtime::Result<Result<Resource<Graph>, Resource<Error>>> {
tracing::debug!("load {encoding:?} {target:?}");
if let Some(backend) = self.ctx.backends.get_mut(&encoding) {
let slices = builders.iter().map(|s| s.as_slice()).collect::<Vec<_>>();
match backend.load(&slices, target.into()) {
Ok(graph) => {
let graph = self.table.push(graph)?;
Ok(graph)
Ok(Ok(graph))
}
Err(error) => {
tracing::error!("failed to load graph: {error:?}");
Err(Error::RuntimeError)
bail!(self, ErrorCode::RuntimeError, error);
}
}
} else {
Err(Error::InvalidEncoding)
bail!(
self,
ErrorCode::InvalidEncoding,
anyhow!("unable to find a backend for this encoding")
);
}
}

fn load_by_name(&mut self, name: String) -> Result<Resource<Graph>, Error> {
fn load_by_name(
&mut self,
name: String,
) -> wasmtime::Result<Result<Resource<Graph>, Resource<Error>>> {
use core::result::Result::*;
tracing::debug!("load by name {name:?}");
let registry = &self.ctx.registry;
if let Some(graph) = registry.get(&name) {
let graph = graph.clone();
let graph = self.table.push(graph)?;
Ok(graph)
Ok(Ok(graph))
} else {
tracing::error!("failed to find graph with name: {name}");
Err(Error::NotFound)
bail!(
self,
ErrorCode::NotFound,
anyhow!("failed to find graph with name: {name}")
);
}
}
}
Expand All @@ -169,18 +214,17 @@ impl gen::graph::HostGraph for WasiNnView<'_> {
fn init_execution_context(
&mut self,
graph: Resource<Graph>,
) -> Result<Resource<GraphExecutionContext>, Error> {
) -> wasmtime::Result<Result<Resource<GraphExecutionContext>, Resource<Error>>> {
use core::result::Result::*;
tracing::debug!("initialize execution context");
let graph = self.table.get(&graph)?;
match graph.init_execution_context() {
Ok(exec_context) => {
let exec_context = self.table.push(exec_context)?;
Ok(exec_context)
Ok(Ok(exec_context))
}
Err(error) => {
tracing::error!("failed to initialize execution context: {error:?}");
Err(Error::RuntimeError)
bail!(self, ErrorCode::RuntimeError, error);
}
}
}
Expand All @@ -197,47 +241,46 @@ impl gen::inference::HostGraphExecutionContext for WasiNnView<'_> {
exec_context: Resource<GraphExecutionContext>,
name: String,
tensor: Resource<Tensor>,
) -> Result<(), Error> {
) -> wasmtime::Result<Result<(), Resource<Error>>> {
let tensor = self.table.get(&tensor)?;
tracing::debug!("set input {name:?}: {tensor:?}");
let tensor = tensor.clone(); // TODO: avoid copying the tensor
let exec_context = self.table.get_mut(&exec_context)?;
if let Err(e) = exec_context.set_input(Id::Name(name), &tensor) {
tracing::error!("failed to set input: {e:?}");
Err(Error::InvalidArgument)
if let Err(error) = exec_context.set_input(Id::Name(name), &tensor) {
bail!(self, ErrorCode::InvalidArgument, error);
} else {
Ok(())
Ok(Ok(()))
}
}

fn compute(&mut self, exec_context: Resource<GraphExecutionContext>) -> Result<(), Error> {
fn compute(
&mut self,
exec_context: Resource<GraphExecutionContext>,
) -> wasmtime::Result<Result<(), Resource<Error>>> {
let exec_context = &mut self.table.get_mut(&exec_context)?;
tracing::debug!("compute");
match exec_context.compute() {
Ok(()) => Ok(()),
Ok(()) => Ok(Ok(())),
Err(error) => {
tracing::error!("failed to compute: {error:?}");
Err(Error::RuntimeError)
bail!(self, ErrorCode::RuntimeError, error);
}
}
}

#[doc = r" Extract the outputs after inference."]
fn get_output(
&mut self,
exec_context: Resource<GraphExecutionContext>,
name: String,
) -> Result<Resource<Tensor>, Error> {
) -> wasmtime::Result<Result<Resource<Tensor>, Resource<Error>>> {
let exec_context = self.table.get_mut(&exec_context)?;
tracing::debug!("get output {name:?}");
match exec_context.get_output(Id::Name(name)) {
Ok(tensor) => {
let tensor = self.table.push(tensor)?;
Ok(tensor)
Ok(Ok(tensor))
}
Err(error) => {
tracing::error!("failed to get output: {error:?}");
Err(Error::RuntimeError)
bail!(self, ErrorCode::RuntimeError, error);
}
}
}
Expand Down Expand Up @@ -285,21 +328,51 @@ impl gen::tensor::HostTensor for WasiNnView<'_> {
}
}

impl gen::tensor::Host for WasiNnView<'_> {}
impl gen::errors::HostError for WasiNnView<'_> {
fn new(
&mut self,
_code: gen::errors::ErrorCode,
_data: String,
) -> wasmtime::Result<Resource<Error>> {
unimplemented!("this should be removed; see https://github.com/WebAssembly/wasi-nn/pull/76")
}

fn code(&mut self, error: Resource<Error>) -> wasmtime::Result<gen::errors::ErrorCode> {
let error = self.table.get(&error)?;
match error.code {
ErrorCode::InvalidArgument => Ok(gen::errors::ErrorCode::InvalidArgument),
ErrorCode::InvalidEncoding => Ok(gen::errors::ErrorCode::InvalidEncoding),
ErrorCode::Timeout => Ok(gen::errors::ErrorCode::Timeout),
ErrorCode::RuntimeError => Ok(gen::errors::ErrorCode::RuntimeError),
ErrorCode::UnsupportedOperation => Ok(gen::errors::ErrorCode::UnsupportedOperation),
ErrorCode::TooLarge => Ok(gen::errors::ErrorCode::TooLarge),
ErrorCode::NotFound => Ok(gen::errors::ErrorCode::NotFound),
ErrorCode::Trap => Err(anyhow!(error.data.to_string())),
}
}

fn data(&mut self, error: Resource<Error>) -> wasmtime::Result<String> {
let error = self.table.get(&error)?;
Ok(error.data.to_string())
}

fn drop(&mut self, error: Resource<Error>) -> wasmtime::Result<()> {
self.table.delete(error)?;
Ok(())
}
}

impl gen::errors::Host for WasiNnView<'_> {
fn convert_error(&mut self, err: Error) -> wasmtime::Result<gen::errors::Error> {
match err {
Error::InvalidArgument => Ok(gen::errors::Error::InvalidArgument),
Error::InvalidEncoding => Ok(gen::errors::Error::InvalidEncoding),
Error::Timeout => Ok(gen::errors::Error::Timeout),
Error::RuntimeError => Ok(gen::errors::Error::RuntimeError),
Error::UnsupportedOperation => Ok(gen::errors::Error::UnsupportedOperation),
Error::TooLarge => Ok(gen::errors::Error::TooLarge),
Error::NotFound => Ok(gen::errors::Error::NotFound),
Error::Trap(e) => Err(e),
fn convert_error(&mut self, err: Error) -> wasmtime::Result<Error> {
if matches!(err.code, ErrorCode::Trap) {
Err(err.data)
} else {
Ok(err)
}
}
}

impl gen::tensor::Host for WasiNnView<'_> {}
impl gen::inference::Host for WasiNnView<'_> {}

impl Hash for gen::graph::GraphEncoding {
Expand Down
21 changes: 18 additions & 3 deletions crates/wasi-nn/wit/wasi-nn.wit
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package wasi:nn;
package wasi:nn@0.2.0-rc-2024-06-25;

/// `wasi-nn` is a WASI API for performing machine learning (ML) inference. The API is not (yet)
/// capable of performing ML training. WebAssembly programs that want to use a host's ML
Expand Down Expand Up @@ -134,7 +134,7 @@ interface inference {

/// TODO: create function-specific errors (https://github.com/WebAssembly/wasi-nn/issues/42)
interface errors {
enum error {
enum error-code {
// Caller module passed an invalid argument.
invalid-argument,
// Invalid encoding.
Expand All @@ -148,6 +148,21 @@ interface errors {
// Graph is too large.
too-large,
// Graph not found.
not-found
not-found,
// The operation is insecure or has insufficient privilege to be performed.
// e.g., cannot access a hardware feature requested
security,
// The operation failed for an unspecified reason.
unknown
}

resource error {
constructor(code: error-code, data: string);

/// Return the error code.
code: func() -> error-code;

/// Errors can propagated with backend specific status through a string value.
data: func() -> string;
}
}
Loading