diff --git a/oak_functions/abi/src/lib.rs b/oak_functions/abi/src/lib.rs index f51b31fbb8a..e6a501ade06 100644 --- a/oak_functions/abi/src/lib.rs +++ b/oak_functions/abi/src/lib.rs @@ -17,6 +17,8 @@ //! Type, constant and Wasm host function definitions for the Oak-Functions application //! binary interface (ABI). +pub use crate::proto::ExtensionHandle; + pub mod proto { include!(concat!(env!("OUT_DIR"), "/oak.functions.abi.rs")); include!(concat!(env!("OUT_DIR"), "/oak.functions.lookup_data.rs")); @@ -83,6 +85,7 @@ extern "C" { ) -> u32; pub fn invoke( + handle: ExtensionHandle, request_ptr: *const u8, request_len: usize, response_ptr_ptr: *mut *mut u8, diff --git a/oak_functions/loader/fuzz/fuzz_targets/data/fuzzable.wasm b/oak_functions/loader/fuzz/fuzz_targets/data/fuzzable.wasm index 0dce4fb1f60..7d3a2723e36 100755 Binary files a/oak_functions/loader/fuzz/fuzz_targets/data/fuzzable.wasm and b/oak_functions/loader/fuzz/fuzz_targets/data/fuzzable.wasm differ diff --git a/oak_functions/loader/src/lookup.rs b/oak_functions/loader/src/lookup.rs index 9d97752f1c1..20b1a563b33 100644 --- a/oak_functions/loader/src/lookup.rs +++ b/oak_functions/loader/src/lookup.rs @@ -21,7 +21,7 @@ use crate::{ OakApiNativeExtension, WasmState, ABI_USIZE, }, }; -use oak_functions_abi::proto::OakStatus; +use oak_functions_abi::{proto::OakStatus, ExtensionHandle}; use oak_functions_lookup::{LookupData, LookupDataManager}; use oak_logger::OakLogger; use std::sync::Arc; @@ -126,4 +126,8 @@ where fn terminate(&mut self) -> anyhow::Result<()> { Ok(()) } + + fn get_handle(&mut self) -> ExtensionHandle { + ExtensionHandle::LookupHandle + } } diff --git a/oak_functions/loader/src/metrics.rs b/oak_functions/loader/src/metrics.rs index 76018543a48..5eb776f4ede 100644 --- a/oak_functions/loader/src/metrics.rs +++ b/oak_functions/loader/src/metrics.rs @@ -22,7 +22,7 @@ use crate::{ }, }; use alloc::sync::Arc; -use oak_functions_abi::proto::OakStatus; +use oak_functions_abi::{proto::OakStatus, ExtensionHandle}; use oak_functions_metrics::{ PrivateMetricsAggregator, PrivateMetricsExtension, PrivateMetricsProxy, }; @@ -111,6 +111,10 @@ impl OakApiNativeExtension for PrivateMetricsExtension { fn terminate(&mut self) -> anyhow::Result<()> { self.publish_metrics() } + + fn get_handle(&mut self) -> ExtensionHandle { + ExtensionHandle::MetricsHandle + } } /// Provides logic for the host ABI function [`report_metric`](https://github.com/project-oak/oak/blob/main/docs/oak_functions_abi.md#report_metric). diff --git a/oak_functions/loader/src/server.rs b/oak_functions/loader/src/server.rs index ab6622df66b..06fee50aade 100644 --- a/oak_functions/loader/src/server.rs +++ b/oak_functions/loader/src/server.rs @@ -19,7 +19,9 @@ use anyhow::Context; use byteorder::{ByteOrder, LittleEndian}; use futures::future::FutureExt; use log::Level; -use oak_functions_abi::proto::{OakStatus, Request, Response, ServerPolicy, StatusCode}; +use oak_functions_abi::proto::{ + ExtensionHandle, OakStatus, Request, Response, ServerPolicy, StatusCode, +}; use oak_logger::OakLogger; use serde::Deserialize; use std::{collections::HashMap, convert::TryInto, str, sync::Arc, time::Duration}; @@ -33,14 +35,15 @@ const ALLOC_FUNCTION_NAME: &str = "alloc"; const READ_REQUEST: usize = 0; const WRITE_RESPONSE: usize = 1; const WRITE_LOG_MESSAGE: usize = 3; +const INVOKE: usize = 4; const EXTENSION_INDEX_OFFSET: usize = 10; // Type aliases for positions and offsets in Wasm linear memory. Any future 64-bit version // of Wasm would use different types. pub type AbiPointer = u32; pub type AbiPointerOffset = u32; -// Type alias for the ChannelHandle type, which has to be cast into a ChannelHandle. -pub type AbiChannelHandle = i32; +// Type alias for the ExtensionHandle type, which has to be cast into a ExtensionHandle. +pub type AbiExtensionHandle = i32; /// Wasm type identifier for position/offset values in linear memory. Any future 64-bit version of /// Wasm would use a different value. pub const ABI_USIZE: ValueType = ValueType::I32; @@ -121,12 +124,15 @@ pub trait OakApiNativeExtension { args: wasmi::RuntimeArgs, ) -> Result, wasmi::Trap>; - /// Metadata about this Extension, including the exported host function name, and the function's - /// signature. + /// Metadata about this Extension, including the exported host function name, the function's + /// signature, and the corresponding ExtensionHandle. fn get_metadata(&self) -> (String, wasmi::Signature); /// Performs any cleanup or terminating behavior necessary before destroying the WasmState. fn terminate(&mut self) -> anyhow::Result<()>; + + /// Gets the `ExtensionHandle` for this extension. + fn get_handle(&mut self) -> ExtensionHandle; } pub trait ExtensionFactory { @@ -332,6 +338,46 @@ impl WasmState { Ok(()) } + pub fn invoke_extension_with_handle( + &mut self, + handle: AbiExtensionHandle, + args: wasmi::RuntimeArgs, + ) -> Result, wasmi::Trap> { + let handle: ExtensionHandle = + ExtensionHandle::from_i32(handle).expect("Fail to parse handle."); + + // TODO(#2664): Quick solution following impelementation of invoking an extension in + // `invoke_index`. Once we refactored the interface of `invoke` to not require + // `WasmState` any more, we can simplify this. + + // First, we get all extensions from WasmState. + let mut extensions_indices = self + .extensions_indices + .take() + .expect("no extensions_indices is set"); + + // Then we get the extension which has the given handle by looking at the values of the + // extension_indices. + let extension = extensions_indices + .iter_mut() + .find_map(|(_, extension)| { + if extension.get_handle() == handle { + Some(extension) + } else { + None + } + }) + .expect("Fail to find extension with given handle."); + + // We invoke the found extension. + let result = from_oak_status_result(extension.invoke(self, args)?); + + // We put the extension indices back. + self.extensions_indices = Some(extensions_indices); + + result + } + pub fn alloc(&mut self, len: u32) -> AbiPointer { let result = self.instance.as_ref().unwrap().invoke_export( ALLOC_FUNCTION_NAME, @@ -371,6 +417,7 @@ impl wasmi::Externals for WasmState { WRITE_LOG_MESSAGE => from_oak_status_result( self.write_log_message(args.nth_checked(0)?, args.nth_checked(1)?), ), + INVOKE => self.invoke_extension_with_handle(args.nth_checked(0)?, args), _ => { let mut extensions_indices = self @@ -689,7 +736,19 @@ fn oak_functions_resolve_func(field_name: &str) -> Option<(usize, wasmi::Signatu Some(ValueType::I32), ), ), - + "invoke" => ( + INVOKE, + wasmi::Signature::new( + &[ + ABI_USIZE, // handle + ABI_USIZE, // request_ptr + ABI_USIZE, // request_len + ABI_USIZE, // response_ptr_ptr + ABI_USIZE, // response_len_ptr + ][..], + Some(ValueType::I32), + ), + ), _ => return None, }; diff --git a/oak_functions/loader/src/testing.rs b/oak_functions/loader/src/testing.rs index dc592825c73..609be646aa2 100644 --- a/oak_functions/loader/src/testing.rs +++ b/oak_functions/loader/src/testing.rs @@ -17,18 +17,18 @@ use crate::{ logger::Logger, server::{ - AbiPointer, AbiPointerOffset, BoxedExtension, BoxedExtensionFactory, ExtensionFactory, - OakApiNativeExtension, ABI_USIZE, + AbiExtensionHandle, AbiPointer, AbiPointerOffset, BoxedExtension, BoxedExtensionFactory, + ExtensionFactory, OakApiNativeExtension, ABI_USIZE, }, }; use log::Level; -use oak_functions_abi::proto::OakStatus; +use oak_functions_abi::{proto::OakStatus, ExtensionHandle}; use serde::{Deserialize, Serialize}; use wasmi::ValueType; /// Host function name for testing. -const TESTING_ABI_FUNCTION_NAME: &str = "invoke"; +const TESTING_ABI_FUNCTION_NAME: &str = "testing"; impl OakApiNativeExtension for TestingExtension { fn invoke( @@ -36,10 +36,13 @@ impl OakApiNativeExtension for TestingExtension { wasm_state: &mut crate::server::WasmState, args: wasmi::RuntimeArgs, ) -> Result, wasmi::Trap> { - let request_ptr: AbiPointer = args.nth_checked(0)?; - let request_len: AbiPointerOffset = args.nth_checked(1)?; - let response_ptr_ptr: AbiPointer = args.nth_checked(2)?; - let response_len_ptr: AbiPointer = args.nth_checked(3)?; + // For consistency we also get the first argument, but we do not need it, as we did read the + // handle already to decide to call the invoke of this extension. + let _handle: AbiExtensionHandle = args.nth_checked(0)?; + let request_ptr: AbiPointer = args.nth_checked(1)?; + let request_len: AbiPointerOffset = args.nth_checked(2)?; + let response_ptr_ptr: AbiPointer = args.nth_checked(3)?; + let response_len_ptr: AbiPointer = args.nth_checked(4)?; let extension_args = wasm_state .read_extension_args(request_ptr, request_len) @@ -75,6 +78,10 @@ impl OakApiNativeExtension for TestingExtension { fn terminate(&mut self) -> anyhow::Result<()> { Ok(()) } + + fn get_handle(&mut self) -> ExtensionHandle { + ExtensionHandle::TestingHandle + } } fn testing(message: Vec) -> Result, OakStatus> { diff --git a/oak_functions/loader/src/tf.rs b/oak_functions/loader/src/tf.rs index d04a25f3e92..52e68e9c642 100644 --- a/oak_functions/loader/src/tf.rs +++ b/oak_functions/loader/src/tf.rs @@ -23,7 +23,7 @@ use crate::{ }; use anyhow::Context; use bytes::Bytes; -use oak_functions_abi::proto::OakStatus; +use oak_functions_abi::proto::{ExtensionHandle, OakStatus}; use oak_functions_tf_inference::{parse_model, TensorFlowModel}; use prost::Message; use std::{fs::File, io::Read, sync::Arc}; @@ -87,6 +87,10 @@ impl OakApiNativeExtension for TensorFlowModel { fn terminate(&mut self) -> anyhow::Result<()> { Ok(()) } + + fn get_handle(&mut self) -> ExtensionHandle { + ExtensionHandle::TfHandle + } } /// Provides logic for the host ABI function [`tf_model_infer`](https://github.com/project-oak/oak/blob/main/docs/oak_functions_abi.md#tf_model_infer). diff --git a/oak_functions/proto/abi.proto b/oak_functions/proto/abi.proto index 665f5c92b88..8ec4b6c58fc 100644 --- a/oak_functions/proto/abi.proto +++ b/oak_functions/proto/abi.proto @@ -21,14 +21,16 @@ package oak.functions.abi; option java_multiple_files = true; option java_package = "oak.functions.abi"; -// A ChannelHandle corresonds to a channel. The Endpoint which does not own the channel uses a -// ChannelHandle to read, write, or wait on the corresponding channel. -enum ChannelHandle { - CHANNEL_HANDLE_UNSPECIFIED = 0; +// The `ExtensionHandle` indicates which extension to invoke in the Oak Functions runtime. +// We assume every extension exposes exactly one method to invoke. +// `ExtensionHandle`s are exchanged as `i32` values. +enum ExtensionHandle { + INVALID_HANDLE = 0; // Handle for an extension used for testing Wasm modules. - TESTING = 1; - // Handle for an extension to look up a key in the lookup data. - LOOKUP_DATA = 2; + TESTING_HANDLE = 1; + LOOKUP_HANDLE = 2; + METRICS_HANDLE = 3; + TF_HANDLE = 4; } // Status values exchanged as i32 values across the Node Wasm interface. diff --git a/oak_functions/sdk/oak_functions/src/lib.rs b/oak_functions/sdk/oak_functions/src/lib.rs index d780c102479..829f551fbce 100644 --- a/oak_functions/sdk/oak_functions/src/lib.rs +++ b/oak_functions/sdk/oak_functions/src/lib.rs @@ -158,6 +158,7 @@ pub fn invoke(request: &[u8]) -> Result, OakStatus> { let mut response_len: usize = 0; let status_code = unsafe { oak_functions_abi::invoke( + oak_functions_abi::ExtensionHandle::TestingHandle, request.as_ptr(), request.len(), &mut response_ptr,