Skip to content

Commit

Permalink
Add get_handle in OakApiNativeExtension, update comments
Browse files Browse the repository at this point in the history
  • Loading branch information
mariaschett committed Apr 5, 2022
1 parent 760bf6b commit 455299c
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 32 deletions.
12 changes: 6 additions & 6 deletions oak_functions/loader/src/lookup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ where
Ok(extension_result)
}

fn get_metadata(&self) -> (String, wasmi::Signature, ExtensionHandle) {
fn get_metadata(&self) -> (String, wasmi::Signature) {
let signature = wasmi::Signature::new(
&[
ABI_USIZE, // key_ptr
Expand All @@ -120,14 +120,14 @@ where
Some(ValueType::I32),
);

(
LOOKUP_ABI_FUNCTION_NAME.to_string(),
signature,
ExtensionHandle::LookupHandle,
)
(LOOKUP_ABI_FUNCTION_NAME.to_string(), signature)
}

fn terminate(&mut self) -> anyhow::Result<()> {
Ok(())
}

fn get_handle(&mut self) -> ExtensionHandle {
ExtensionHandle::LookupHandle
}
}
12 changes: 6 additions & 6 deletions oak_functions/loader/src/metrics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ impl OakApiNativeExtension for PrivateMetricsExtension<Logger> {

/// Each Oak Functions application can have at most one instance of PrivateMetricsProxy. So it
/// is fine to return a constant name in the metadata.
fn get_metadata(&self) -> (String, wasmi::Signature, ExtensionHandle) {
fn get_metadata(&self) -> (String, wasmi::Signature) {
let signature = wasmi::Signature::new(
&[
ABI_USIZE, // buf_ptr
Expand All @@ -105,16 +105,16 @@ impl OakApiNativeExtension for PrivateMetricsExtension<Logger> {
Some(ValueType::I32),
);

(
METRICS_ABI_FUNCTION_NAME.to_string(),
signature,
ExtensionHandle::MetricsHandle,
)
(METRICS_ABI_FUNCTION_NAME.to_string(), signature)
}

fn terminate(&mut self) -> anyhow::Result<()> {
self.publish_metrics()
}

fn get_handle(&mut self) -> ExtensionHandle {
ExtensionHandle::MetricsHandle
}
}

/// Provides logic for the host ABI function [`report_metric`](https://github.com/project-oak/oak/blob/main/docs/oak_functions_abi.md#report_metric).
Expand Down
14 changes: 9 additions & 5 deletions oak_functions/loader/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,13 @@ pub trait OakApiNativeExtension {

/// Metadata about this Extension, including the exported host function name, the function's
/// signature, and the corresponding ExtensionHandle.
fn get_metadata(&self) -> (String, wasmi::Signature, ExtensionHandle);
fn get_metadata(&self) -> (String, wasmi::Signature);

/// Performs any cleanup or terminating behavior necessary before destroying the WasmState.
fn terminate(&mut self) -> anyhow::Result<()>;

/// Gets the `ExtensionHandle` for this extension.
fn get_handle(&mut self) -> ExtensionHandle;
}

pub trait ExtensionFactory {
Expand Down Expand Up @@ -343,7 +346,9 @@ impl WasmState {
let handle: ExtensionHandle =
ExtensionHandle::from_i32(handle).expect("Fail to parse handle.");

// Quick solution following impelementation of invoking an extension in `invoke_index`.
// TODO(#2664) Quick solution following impelementation of invoking an extension in
// `invoke_index`. Once we refactored the interface of `invoke` to not require
// `WasmState` any more, we can simplify this.

// First, we get all extensions from WasmState.
let mut extensions_indices = self
Expand All @@ -356,8 +361,7 @@ impl WasmState {
let extension = extensions_indices
.iter_mut()
.find_map(|(_, extension)| {
let (_, _, handle_of_extension) = extension.get_metadata();
if handle_of_extension == handle {
if extension.get_handle() == handle {
Some(extension)
} else {
None
Expand Down Expand Up @@ -658,7 +662,7 @@ impl WasmHandler {

for (ind, factory) in self.extension_factories.iter().enumerate() {
let extension = factory.create()?;
let (name, signature, _) = extension.get_metadata();
let (name, signature) = extension.get_metadata();
extensions_indices.insert(ind + EXTENSION_INDEX_OFFSET, extension);
extensions_metadata.insert(name, (ind + EXTENSION_INDEX_OFFSET, signature));
}
Expand Down
20 changes: 11 additions & 9 deletions oak_functions/loader/src/testing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
use crate::{
logger::Logger,
server::{
AbiPointer, AbiPointerOffset, BoxedExtension, BoxedExtensionFactory, ExtensionFactory,
OakApiNativeExtension, ABI_USIZE,
AbiExtensionHandle, AbiPointer, AbiPointerOffset, BoxedExtension, BoxedExtensionFactory,
ExtensionFactory, OakApiNativeExtension, ABI_USIZE,
},
};

Expand All @@ -36,7 +36,9 @@ impl OakApiNativeExtension for TestingExtension<Logger> {
wasm_state: &mut crate::server::WasmState,
args: wasmi::RuntimeArgs,
) -> Result<Result<(), oak_functions_abi::proto::OakStatus>, wasmi::Trap> {
// We still read the args after the ExtensionHandle at position 0.
// For consistency we also get the first argument, but we do not need it, as we did read the
// handle already to decide to call the invoke of this extension.
let _handle: AbiExtensionHandle = args.nth_checked(0)?;
let request_ptr: AbiPointer = args.nth_checked(1)?;
let request_len: AbiPointerOffset = args.nth_checked(2)?;
let response_ptr_ptr: AbiPointer = args.nth_checked(3)?;
Expand All @@ -59,7 +61,7 @@ impl OakApiNativeExtension for TestingExtension<Logger> {
Ok(result)
}

fn get_metadata(&self) -> (String, wasmi::Signature, ExtensionHandle) {
fn get_metadata(&self) -> (String, wasmi::Signature) {
let signature = wasmi::Signature::new(
&[
ABI_USIZE, // request_ptr
Expand All @@ -70,16 +72,16 @@ impl OakApiNativeExtension for TestingExtension<Logger> {
Some(ValueType::I32),
);

(
TESTING_ABI_FUNCTION_NAME.to_string(),
signature,
oak_functions_abi::ExtensionHandle::TestingHandle,
)
(TESTING_ABI_FUNCTION_NAME.to_string(), signature)
}

fn terminate(&mut self) -> anyhow::Result<()> {
Ok(())
}

fn get_handle(&mut self) -> ExtensionHandle {
ExtensionHandle::TestingHandle
}
}

fn testing(message: Vec<u8>) -> Result<Vec<u8>, OakStatus> {
Expand Down
12 changes: 6 additions & 6 deletions oak_functions/loader/src/tf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ impl OakApiNativeExtension for TensorFlowModel<Logger> {

/// Each Oak Functions application can have at most one instance of TensorFlowModule. So it is
/// fine to return a constant name in the metadata.
fn get_metadata(&self) -> (String, wasmi::Signature, ExtensionHandle) {
fn get_metadata(&self) -> (String, wasmi::Signature) {
let signature = wasmi::Signature::new(
&[
ABI_USIZE, // input_ptr
Expand All @@ -81,16 +81,16 @@ impl OakApiNativeExtension for TensorFlowModel<Logger> {
Some(ValueType::I32),
);

(
TF_ABI_FUNCTION_NAME.to_string(),
signature,
ExtensionHandle::TfHandle,
)
(TF_ABI_FUNCTION_NAME.to_string(), signature)
}

fn terminate(&mut self) -> anyhow::Result<()> {
Ok(())
}

fn get_handle(&mut self) -> ExtensionHandle {
ExtensionHandle::TfHandle
}
}

/// Provides logic for the host ABI function [`tf_model_infer`](https://github.com/project-oak/oak/blob/main/docs/oak_functions_abi.md#tf_model_infer).
Expand Down

0 comments on commit 455299c

Please sign in to comment.