package wasi:[email protected];

/// `wasi-nn` is a WASI API for performing machine learning (ML) inference. The API is not (yet)
/// capable of performing ML training. WebAssembly programs that want to use a host's ML
/// capabilities can access these capabilities through `wasi-nn`'s core abstractions: _graphs_ and
/// _tensors_. A user `load`s an ML model -- instantiated as a _graph_ -- to use in an ML _backend_.
/// Then, the user passes _tensor_ inputs to the _graph_, computes the inference, and retrieves the
/// _tensor_ outputs.
///
/// This example world shows how to use these primitives together.
world ml {
    import tensor;
    import graph;
    import inference;
    import errors;
}
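
// A minimal end-to-end sketch of how a guest might use this world, written as
// Rust-like pseudocode against hypothetical wit-bindgen-generated bindings
// (all names, e.g. `model_bytes`, `input_data`, and the "input"/"output"
// tensor names, are illustrative rather than part of this specification):
//
//     let graph = load(&[model_bytes], GraphEncoding::Onnx, ExecutionTarget::Cpu)?;
//     let ctx = graph.init_execution_context()?;
//     let input = Tensor::new(&[1, 3, 224, 224], TensorType::Fp32, &input_data);
//     ctx.set_input("input", input)?;
//     ctx.compute()?;
//     let output = ctx.get_output("output")?;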

/// All inputs and outputs to an ML inference are represented as `tensor`s.
interface tensor {
    /// The dimensions of a tensor.
    ///
    /// The array length matches the tensor rank and each element in the array describes the size of
    /// each dimension.
    type tensor-dimensions = list<u32>;
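
    // For example (illustrative): a 2x3 matrix has rank 2, so its dimensions
    // would be `[2, 3]`; as noted on the `tensor` resource below, a tensor
    // holding a single value uses `[1]`.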

    /// The type of the elements in a tensor.
    enum tensor-type {
        FP16,
        FP32,
        FP64,
        BF16,
        U8,
        I32,
        I64
    }

    /// The tensor data.
    ///
    /// Initially conceived as a sparse representation, the data is currently stored densely: each
    /// empty cell is filled with zeros, and the array length must match the product of all of the
    /// dimensions and the number of bytes in the type (e.g., a 2x2 tensor with 4-byte f32 elements
    /// would have a data array of length 16). Naturally, this representation requires some
    /// knowledge of how to lay out data in memory--e.g., using row-major ordering--and could
    /// perhaps be improved.
    type tensor-data = list<u8>;
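
    // A worked example of the layout described above (illustrative, not
    // normative): a 2x2 tensor of 4-byte `FP32` elements
    //
    //     [[1.0, 2.0],
    //      [3.0, 4.0]]
    //
    // has 2 * 2 = 4 elements of 4 bytes each, so its `tensor-data` is 16
    // bytes long: the `f32` encodings of 1.0, 2.0, 3.0, and 4.0 in row-major
    // order (assuming the little-endian layout native to WebAssembly).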

    resource tensor {
        constructor(dimensions: tensor-dimensions, ty: tensor-type, data: tensor-data);

        // Describe the size of the tensor (e.g., 2x2x2x2 -> [2, 2, 2, 2]). To represent a tensor
        // containing a single value, use `[1]` for the tensor dimensions.
        dimensions: func() -> tensor-dimensions;

        // Describe the type of element in the tensor (e.g., `f32`).
        ty: func() -> tensor-type;

        // Return the tensor data.
        data: func() -> tensor-data;
    }
}

/// A `graph` is a loaded instance of a specific ML model (e.g., MobileNet) for a specific ML
/// framework (e.g., TensorFlow):
interface graph {
    use errors.{error};
    use tensor.{tensor};
    use inference.{graph-execution-context};

    /// An execution graph for performing inference (i.e., a model).
    resource graph {
        init-execution-context: func() -> result<graph-execution-context, error>;
    }

    /// Describes the encoding of the graph. This allows the API to be implemented by various
    /// backends that encode (i.e., serialize) their graph IR with different formats.
    enum graph-encoding {
        openvino,
        onnx,
        tensorflow,
        pytorch,
        tensorflowlite,
        ggml,
        autodetect,
    }

    /// Define where the graph should be executed.
    enum execution-target {
        cpu,
        gpu,
        tpu
    }

    /// The graph initialization data.
    ///
    /// This gets bundled up into an array of buffers because implementing backends may encode their
    /// graph IR in parts (e.g., OpenVINO stores its IR and weights separately).
    type graph-builder = list<u8>;

    /// Load a `graph` from an opaque sequence of bytes to use for inference.
    load: func(builder: list<graph-builder>, encoding: graph-encoding, target: execution-target) -> result<graph, error>;
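
    // A hypothetical illustration of a multi-part builder, assuming an
    // OpenVINO model split into its IR (`model.xml`) and weights
    // (`model.bin`), in Rust-like pseudocode (variable names invented):
    //
    //     let graph = load(&[xml_bytes, bin_bytes], GraphEncoding::Openvino,
    //                      ExecutionTarget::Cpu)?;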

    /// Load a `graph` by name.
    ///
    /// How the host expects the names to be passed and how it stores the graphs for retrieval via
    /// this function is **implementation-specific**. This allows hosts to choose name schemes that
    /// range from simple to complex (e.g., URLs?) and caching mechanisms of various kinds.
    load-by-name: func(name: string) -> result<graph, error>;
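
    // For example (the name scheme here is purely hypothetical, since naming
    // is implementation-specific):
    //
    //     let graph = load_by_name("mobilenet-v2")?;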
}

/// An inference "session" is encapsulated by a `graph-execution-context`. This structure binds a
/// `graph` to input tensors before `compute`-ing an inference:
interface inference {
    use errors.{error};
    use tensor.{tensor, tensor-data};

    /// Bind a `graph` to the input and output tensors for an inference.
    ///
    /// TODO: this may no longer be necessary in WIT
    /// (https://github.com/WebAssembly/wasi-nn/issues/43)
    resource graph-execution-context {
        /// Define the inputs to use for inference.
        set-input: func(name: string, tensor: tensor) -> result<_, error>;

        /// Compute the inference on the given inputs.
        ///
        /// Note the expected sequence of calls: `set-input`, `compute`, `get-output`. TODO: this
        /// expectation could be removed as a part of
        /// https://github.com/WebAssembly/wasi-nn/issues/43.
        compute: func() -> result<_, error>;

        /// Extract the outputs after inference.
        get-output: func(name: string) -> result<tensor, error>;
    }
}

/// TODO: create function-specific errors (https://github.com/WebAssembly/wasi-nn/issues/42)
interface errors {
    enum error-code {
        // Caller module passed an invalid argument.
        invalid-argument,
        // Invalid encoding.
        invalid-encoding,
        // The operation timed out.
        timeout,
        // Runtime error.
        runtime-error,
        // Unsupported operation.
        unsupported-operation,
        // Graph is too large.
        too-large,
        // Graph not found.
        not-found,
        // The operation is insecure or has insufficient privilege to be performed
        // (e.g., it cannot access a requested hardware feature).
        security,
        // The operation failed for an unspecified reason.
        unknown
    }

    resource error {
        /// Return the error code.
        code: func() -> error-code;

        /// Errors can be propagated with backend-specific status through a string value.
        data: func() -> string;
    }
}
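
// A hypothetical sketch of inspecting a returned `error` resource, again in
// Rust-like pseudocode against generated bindings:
//
//     if let Err(e) = ctx.compute() {
//         // `code()` yields the portable `error-code`; `data()` carries any
//         // backend-specific detail as a string.
//         eprintln!("inference failed: {:?}: {}", e.code(), e.data());
//     }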