Skip to content

Commit

Permalink
Invoke Testing extension through ExtensionHandle (#2660)
Browse files Browse the repository at this point in the history
Add get_handle in OakApiNativeExtension, invoke testing extension through handle.
  • Loading branch information
mariaschett authored Apr 5, 2022
1 parent 7553bc7 commit 20c4ca2
Show file tree
Hide file tree
Showing 9 changed files with 108 additions and 24 deletions.
3 changes: 3 additions & 0 deletions oak_functions/abi/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
//! Type, constant and Wasm host function definitions for the Oak-Functions application
//! binary interface (ABI).

pub use crate::proto::ExtensionHandle;

pub mod proto {
include!(concat!(env!("OUT_DIR"), "/oak.functions.abi.rs"));
include!(concat!(env!("OUT_DIR"), "/oak.functions.lookup_data.rs"));
Expand Down Expand Up @@ -83,6 +85,7 @@ extern "C" {
) -> u32;

pub fn invoke(
handle: ExtensionHandle,
request_ptr: *const u8,
request_len: usize,
response_ptr_ptr: *mut *mut u8,
Expand Down
Binary file modified oak_functions/loader/fuzz/fuzz_targets/data/fuzzable.wasm
Binary file not shown.
6 changes: 5 additions & 1 deletion oak_functions/loader/src/lookup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use crate::{
OakApiNativeExtension, WasmState, ABI_USIZE,
},
};
use oak_functions_abi::proto::OakStatus;
use oak_functions_abi::{proto::OakStatus, ExtensionHandle};
use oak_functions_lookup::{LookupData, LookupDataManager};
use oak_logger::OakLogger;
use std::sync::Arc;
Expand Down Expand Up @@ -126,4 +126,8 @@ where
fn terminate(&mut self) -> anyhow::Result<()> {
Ok(())
}

fn get_handle(&mut self) -> ExtensionHandle {
ExtensionHandle::LookupHandle
}
}
6 changes: 5 additions & 1 deletion oak_functions/loader/src/metrics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use crate::{
},
};
use alloc::sync::Arc;
use oak_functions_abi::proto::OakStatus;
use oak_functions_abi::{proto::OakStatus, ExtensionHandle};
use oak_functions_metrics::{
PrivateMetricsAggregator, PrivateMetricsExtension, PrivateMetricsProxy,
};
Expand Down Expand Up @@ -111,6 +111,10 @@ impl OakApiNativeExtension for PrivateMetricsExtension<Logger> {
fn terminate(&mut self) -> anyhow::Result<()> {
self.publish_metrics()
}

fn get_handle(&mut self) -> ExtensionHandle {
ExtensionHandle::MetricsHandle
}
}

/// Provides logic for the host ABI function [`report_metric`](https://github.com/project-oak/oak/blob/main/docs/oak_functions_abi.md#report_metric).
Expand Down
71 changes: 65 additions & 6 deletions oak_functions/loader/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ use anyhow::Context;
use byteorder::{ByteOrder, LittleEndian};
use futures::future::FutureExt;
use log::Level;
use oak_functions_abi::proto::{OakStatus, Request, Response, ServerPolicy, StatusCode};
use oak_functions_abi::proto::{
ExtensionHandle, OakStatus, Request, Response, ServerPolicy, StatusCode,
};
use oak_logger::OakLogger;
use serde::Deserialize;
use std::{collections::HashMap, convert::TryInto, str, sync::Arc, time::Duration};
Expand All @@ -33,14 +35,15 @@ const ALLOC_FUNCTION_NAME: &str = "alloc";
const READ_REQUEST: usize = 0;
const WRITE_RESPONSE: usize = 1;
const WRITE_LOG_MESSAGE: usize = 3;
const INVOKE: usize = 4;
const EXTENSION_INDEX_OFFSET: usize = 10;

// Type aliases for positions and offsets in Wasm linear memory. Any future 64-bit version
// of Wasm would use different types.
pub type AbiPointer = u32;
pub type AbiPointerOffset = u32;
// Type alias for the ChannelHandle type, which has to be cast into a ChannelHandle.
pub type AbiChannelHandle = i32;
// Type alias for the ExtensionHandle type, which has to be cast into a ExtensionHandle.
pub type AbiExtensionHandle = i32;
/// Wasm type identifier for position/offset values in linear memory. Any future 64-bit version of
/// Wasm would use a different value.
pub const ABI_USIZE: ValueType = ValueType::I32;
Expand Down Expand Up @@ -121,12 +124,15 @@ pub trait OakApiNativeExtension {
args: wasmi::RuntimeArgs,
) -> Result<Result<(), OakStatus>, wasmi::Trap>;

/// Metadata about this Extension, including the exported host function name, and the function's
/// signature.
/// Metadata about this Extension, including the exported host function name, the function's
/// signature, and the corresponding ExtensionHandle.
fn get_metadata(&self) -> (String, wasmi::Signature);

/// Performs any cleanup or terminating behavior necessary before destroying the WasmState.
fn terminate(&mut self) -> anyhow::Result<()>;

/// Gets the `ExtensionHandle` for this extension.
fn get_handle(&mut self) -> ExtensionHandle;
}

pub trait ExtensionFactory {
Expand Down Expand Up @@ -332,6 +338,46 @@ impl WasmState {
Ok(())
}

pub fn invoke_extension_with_handle(
&mut self,
handle: AbiExtensionHandle,
args: wasmi::RuntimeArgs,
) -> Result<Option<wasmi::RuntimeValue>, wasmi::Trap> {
let handle: ExtensionHandle =
ExtensionHandle::from_i32(handle).expect("Fail to parse handle.");

// TODO(#2664): Quick solution following impelementation of invoking an extension in
// `invoke_index`. Once we refactored the interface of `invoke` to not require
// `WasmState` any more, we can simplify this.

// First, we get all extensions from WasmState.
let mut extensions_indices = self
.extensions_indices
.take()
.expect("no extensions_indices is set");

// Then we get the extension which has the given handle by looking at the values of the
// extension_indices.
let extension = extensions_indices
.iter_mut()
.find_map(|(_, extension)| {
if extension.get_handle() == handle {
Some(extension)
} else {
None
}
})
.expect("Fail to find extension with given handle.");

// We invoke the found extension.
let result = from_oak_status_result(extension.invoke(self, args)?);

// We put the extension indices back.
self.extensions_indices = Some(extensions_indices);

result
}

pub fn alloc(&mut self, len: u32) -> AbiPointer {
let result = self.instance.as_ref().unwrap().invoke_export(
ALLOC_FUNCTION_NAME,
Expand Down Expand Up @@ -371,6 +417,7 @@ impl wasmi::Externals for WasmState {
WRITE_LOG_MESSAGE => from_oak_status_result(
self.write_log_message(args.nth_checked(0)?, args.nth_checked(1)?),
),
INVOKE => self.invoke_extension_with_handle(args.nth_checked(0)?, args),

_ => {
let mut extensions_indices = self
Expand Down Expand Up @@ -689,7 +736,19 @@ fn oak_functions_resolve_func(field_name: &str) -> Option<(usize, wasmi::Signatu
Some(ValueType::I32),
),
),

"invoke" => (
INVOKE,
wasmi::Signature::new(
&[
ABI_USIZE, // handle
ABI_USIZE, // request_ptr
ABI_USIZE, // request_len
ABI_USIZE, // response_ptr_ptr
ABI_USIZE, // response_len_ptr
][..],
Some(ValueType::I32),
),
),
_ => return None,
};

Expand Down
23 changes: 15 additions & 8 deletions oak_functions/loader/src/testing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,29 +17,32 @@
use crate::{
logger::Logger,
server::{
AbiPointer, AbiPointerOffset, BoxedExtension, BoxedExtensionFactory, ExtensionFactory,
OakApiNativeExtension, ABI_USIZE,
AbiExtensionHandle, AbiPointer, AbiPointerOffset, BoxedExtension, BoxedExtensionFactory,
ExtensionFactory, OakApiNativeExtension, ABI_USIZE,
},
};

use log::Level;
use oak_functions_abi::proto::OakStatus;
use oak_functions_abi::{proto::OakStatus, ExtensionHandle};
use serde::{Deserialize, Serialize};
use wasmi::ValueType;

/// Host function name for testing.
const TESTING_ABI_FUNCTION_NAME: &str = "invoke";
const TESTING_ABI_FUNCTION_NAME: &str = "testing";

impl OakApiNativeExtension for TestingExtension<Logger> {
fn invoke(
&mut self,
wasm_state: &mut crate::server::WasmState,
args: wasmi::RuntimeArgs,
) -> Result<Result<(), oak_functions_abi::proto::OakStatus>, wasmi::Trap> {
let request_ptr: AbiPointer = args.nth_checked(0)?;
let request_len: AbiPointerOffset = args.nth_checked(1)?;
let response_ptr_ptr: AbiPointer = args.nth_checked(2)?;
let response_len_ptr: AbiPointer = args.nth_checked(3)?;
// For consistency we also get the first argument, but we do not need it, as we did read the
// handle already to decide to call the invoke of this extension.
let _handle: AbiExtensionHandle = args.nth_checked(0)?;
let request_ptr: AbiPointer = args.nth_checked(1)?;
let request_len: AbiPointerOffset = args.nth_checked(2)?;
let response_ptr_ptr: AbiPointer = args.nth_checked(3)?;
let response_len_ptr: AbiPointer = args.nth_checked(4)?;

let extension_args = wasm_state
.read_extension_args(request_ptr, request_len)
Expand Down Expand Up @@ -75,6 +78,10 @@ impl OakApiNativeExtension for TestingExtension<Logger> {
fn terminate(&mut self) -> anyhow::Result<()> {
Ok(())
}

fn get_handle(&mut self) -> ExtensionHandle {
ExtensionHandle::TestingHandle
}
}

fn testing(message: Vec<u8>) -> Result<Vec<u8>, OakStatus> {
Expand Down
6 changes: 5 additions & 1 deletion oak_functions/loader/src/tf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use crate::{
};
use anyhow::Context;
use bytes::Bytes;
use oak_functions_abi::proto::OakStatus;
use oak_functions_abi::proto::{ExtensionHandle, OakStatus};
use oak_functions_tf_inference::{parse_model, TensorFlowModel};
use prost::Message;
use std::{fs::File, io::Read, sync::Arc};
Expand Down Expand Up @@ -87,6 +87,10 @@ impl OakApiNativeExtension for TensorFlowModel<Logger> {
fn terminate(&mut self) -> anyhow::Result<()> {
Ok(())
}

fn get_handle(&mut self) -> ExtensionHandle {
ExtensionHandle::TfHandle
}
}

/// Provides logic for the host ABI function [`tf_model_infer`](https://github.com/project-oak/oak/blob/main/docs/oak_functions_abi.md#tf_model_infer).
Expand Down
16 changes: 9 additions & 7 deletions oak_functions/proto/abi.proto
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,16 @@ package oak.functions.abi;
option java_multiple_files = true;
option java_package = "oak.functions.abi";

// A ChannelHandle corresonds to a channel. The Endpoint which does not own the channel uses a
// ChannelHandle to read, write, or wait on the corresponding channel.
enum ChannelHandle {
CHANNEL_HANDLE_UNSPECIFIED = 0;
// The `ExtensionHandle` indicates which extension to invoke in the Oak Functions runtime.
// We assume every extension exposes exactly one method to invoke.
// `ExtensionHandle`s are exchanged as `i32` values.
enum ExtensionHandle {
INVALID_HANDLE = 0;
// Handle for an extension used for testing Wasm modules.
TESTING = 1;
// Handle for an extension to look up a key in the lookup data.
LOOKUP_DATA = 2;
TESTING_HANDLE = 1;
LOOKUP_HANDLE = 2;
METRICS_HANDLE = 3;
TF_HANDLE = 4;
}

// Status values exchanged as i32 values across the Node Wasm interface.
Expand Down
1 change: 1 addition & 0 deletions oak_functions/sdk/oak_functions/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ pub fn invoke(request: &[u8]) -> Result<Vec<u8>, OakStatus> {
let mut response_len: usize = 0;
let status_code = unsafe {
oak_functions_abi::invoke(
oak_functions_abi::ExtensionHandle::TestingHandle,
request.as_ptr(),
request.len(),
&mut response_ptr,
Expand Down

0 comments on commit 20c4ca2

Please sign in to comment.