From 158ff848068d7c81ebb56dbca2ebeea87fea3a3a Mon Sep 17 00:00:00 2001 From: Tiziano Santoro Date: Tue, 28 Jan 2020 10:29:26 +0000 Subject: [PATCH] Define gRPC Invocation in Rust SDK (#530) - Encapsulate a request receiver and a response sender into a single object which implements `Encodable` and `Decodable`. - Add `#[must_use]` to a couple places where return status codes were not checked. We should probably consider using `Result` types everywhere. Fixes #528 --- examples/chat/module/rust/src/backend.rs | 4 +- oak/server/rust/oak_runtime/src/lib.rs | 2 + sdk/rust/oak/src/grpc/invocation.rs | 60 +++++++++++++ sdk/rust/oak/src/grpc/mod.rs | 105 ++++++++++++----------- sdk/rust/oak_tests/src/lib.rs | 12 ++- 5 files changed, 128 insertions(+), 55 deletions(-) create mode 100644 sdk/rust/oak/src/grpc/invocation.rs diff --git a/examples/chat/module/rust/src/backend.rs b/examples/chat/module/rust/src/backend.rs index 63a65b3e1a0..5f592bd891f 100644 --- a/examples/chat/module/rust/src/backend.rs +++ b/examples/chat/module/rust/src/backend.rs @@ -57,7 +57,9 @@ impl Room { fn handle_command(&mut self, command: Command) -> Result<(), oak::OakError> { match command { Command::Join(h) => { - self.clients.push(oak::grpc::ChannelResponseWriter::new(h)); + let sender = oak::io::Sender::new(h); + self.clients + .push(oak::grpc::ChannelResponseWriter::new(sender)); Ok(()) } Command::SendMessage(message_bytes) => { diff --git a/oak/server/rust/oak_runtime/src/lib.rs b/oak/server/rust/oak_runtime/src/lib.rs index 8d7705397bf..165de0361c9 100644 --- a/oak/server/rust/oak_runtime/src/lib.rs +++ b/oak/server/rust/oak_runtime/src/lib.rs @@ -70,6 +70,7 @@ impl MockChannel { .collect(), } } + #[must_use] pub fn write_message(&mut self, msg: OakMessage) -> u32 { if let Some(status) = self.write_status { return status; @@ -156,6 +157,7 @@ impl ChannelHalf { .expect("corrupt channel ref") .read_message(size, actual_size, handle_count, actual_handle_count) } + #[must_use] pub fn write_message(&mut self, msg: OakMessage) -> u32 { if self.direction != Direction::Write { return OakStatus::ERR_BAD_HANDLE.value() as u32; diff --git a/sdk/rust/oak/src/grpc/invocation.rs b/sdk/rust/oak/src/grpc/invocation.rs new file mode 100644 index 00000000000..8fa3aaa5130 --- /dev/null +++ b/sdk/rust/oak/src/grpc/invocation.rs @@ -0,0 +1,60 @@ +// +// Copyright 2020 The Project Oak Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +/// A gRPC invocation, consisting of exactly two channels: one to read incoming requests from the +/// client, and one to write outgoing responses to the client. +pub struct Invocation { + pub request_receiver: crate::io::Receiver, + pub response_sender: crate::io::Sender, +} + +// TODO(#389): Automatically generate this code. +impl crate::io::Encodable for Invocation { + fn encode(&self) -> Result { + let bytes = vec![]; + let handles = vec![ + self.request_receiver.handle.handle, + self.response_sender.handle.handle, + ]; + Ok(crate::io::Message { bytes, handles }) + } +} + +// TODO(#389): Automatically generate this code. +impl crate::io::Decodable for Invocation { + fn decode(message: &crate::io::Message) -> Result { + if !message.bytes.is_empty() { + panic!( + "incorrect number of bytes received: {} (expected: 0)", + message.bytes.len() + ); + } + if message.handles.len() != 2 { + panic!( + "incorrect number of handles received: {} (expected: 2)", + message.handles.len() + ); + } + Ok(Invocation { + request_receiver: crate::io::Receiver::new(crate::ReadHandle { + handle: message.handles[0], + }), + response_sender: crate::io::Sender::new(crate::WriteHandle { + handle: message.handles[1], + }), + }) + } +} diff --git a/sdk/rust/oak/src/grpc/mod.rs b/sdk/rust/oak/src/grpc/mod.rs index 48eff8e2f79..17cb1da37a1 100644 --- a/sdk/rust/oak/src/grpc/mod.rs +++ b/sdk/rust/oak/src/grpc/mod.rs @@ -17,10 +17,14 @@ //! Functionality to help Oak Nodes interact with gRPC. pub use crate::proto::code::Code; -use crate::{proto, Handle, OakError, OakStatus, ReadHandle, WriteHandle}; -use log::{error, info}; +use crate::{proto, OakError, OakStatus, ReadHandle}; +use log::info; +use proto::grpc_encap::{GrpcRequest, GrpcResponse}; use protobuf::{Message, ProtobufEnum}; +mod invocation; +pub use invocation::Invocation; + /// Result type that uses a [`proto::status::Status`] type for error values. pub type Result = std::result::Result; @@ -32,14 +36,29 @@ pub fn build_status(code: Code, msg: &str) -> proto::status::Status { status } +impl crate::io::Encodable for GrpcRequest { + fn encode(&self) -> std::result::Result { + let bytes = self.write_to_bytes()?; + let handles = Vec::new(); + Ok(crate::io::Message { bytes, handles }) + } +} + +impl crate::io::Decodable for GrpcRequest { + fn decode(message: &crate::io::Message) -> std::result::Result { + let value = protobuf::parse_from_bytes(&message.bytes)?; + Ok(value) + } +} + /// Channel-holding object that encapsulates response messages into /// `GrpcResponse` wrapper messages and writes serialized versions to a send -/// channel. +/// channel. pub struct ChannelResponseWriter { - channel: crate::io::Sender, + sender: crate::io::Sender, } -impl crate::io::Encodable for proto::grpc_encap::GrpcResponse { +impl crate::io::Encodable for GrpcResponse { fn encode(&self) -> std::result::Result { let bytes = self.write_to_bytes()?; let handles = Vec::new(); @@ -56,15 +75,13 @@ pub enum WriteMode { } impl ChannelResponseWriter { - pub fn new(out_handle: crate::WriteHandle) -> Self { - ChannelResponseWriter { - channel: crate::io::Sender::new(out_handle), - } + pub fn new(sender: crate::io::Sender) -> Self { + ChannelResponseWriter { sender } } /// Retrieve the Oak handle underlying the writer object. pub fn handle(self) -> crate::WriteHandle { - self.channel.handle + self.sender.handle } /// Write out a gRPC response and optionally close out the method @@ -76,7 +93,7 @@ impl ChannelResponseWriter { ) -> std::result::Result<(), OakError> { // Put the serialized response into a GrpcResponse message wrapper and // serialize it into the channel. - let mut grpc_rsp = proto::grpc_encap::GrpcResponse::new(); + let mut grpc_rsp = GrpcResponse::new(); let mut any = protobuf::well_known_types::Any::new(); rsp.write_to_writer(&mut any.value)?; grpc_rsp.set_rsp_msg(any); @@ -84,9 +101,9 @@ impl ChannelResponseWriter { WriteMode::KeepOpen => false, WriteMode::Close => true, }); - self.channel.send(&grpc_rsp)?; + self.sender.send(&grpc_rsp)?; if mode == WriteMode::Close { - self.channel.close()?; + self.sender.close()?; } Ok(()) } @@ -94,15 +111,15 @@ impl ChannelResponseWriter { /// Write an empty gRPC response and optionally close out the method /// invocation. Any errors from the channel are silently dropped. pub fn write_empty(&mut self, mode: WriteMode) -> std::result::Result<(), OakError> { - let mut grpc_rsp = proto::grpc_encap::GrpcResponse::new(); + let mut grpc_rsp = GrpcResponse::new(); grpc_rsp.set_rsp_msg(protobuf::well_known_types::Any::new()); grpc_rsp.set_last(match mode { WriteMode::KeepOpen => false, WriteMode::Close => true, }); - self.channel.send(&grpc_rsp)?; + self.sender.send(&grpc_rsp)?; if mode == WriteMode::Close { - self.channel.close()?; + self.sender.close()?; } Ok(()) } @@ -111,13 +128,13 @@ impl ChannelResponseWriter { pub fn close(&mut self, result: Result<()>) -> std::result::Result<(), OakError> { // Build a final GrpcResponse message wrapper and serialize it into the // channel. - let mut grpc_rsp = proto::grpc_encap::GrpcResponse::new(); + let mut grpc_rsp = GrpcResponse::new(); grpc_rsp.set_last(true); if let Err(status) = result { grpc_rsp.set_status(status); } - self.channel.send(&grpc_rsp)?; - self.channel.close()?; + self.sender.send(&grpc_rsp)?; + self.sender.close()?; Ok(()) } } @@ -162,47 +179,33 @@ pub fn event_loop( if !grpc_in_handle.handle.is_valid() { return Err(OakStatus::ERR_CHANNEL_CLOSED); } - let read_handles = vec![grpc_in_handle]; - let mut space = crate::new_handle_space(&read_handles); - + let invocation_receiver = crate::io::Receiver::new(grpc_in_handle); loop { - // Block until there is a method notification message to read on an - // input channel. - crate::prep_handle_space(&mut space); - // TODO: Use higher-level wait function from SDK instead of the ABI one. - let status = - unsafe { oak_abi::wait_on_channels(space.as_mut_ptr(), read_handles.len() as u32) }; - crate::result_from_status(status as i32, ())?; - - let mut buf = Vec::::new(); - let mut handles = Vec::::with_capacity(2); - crate::channel_read(grpc_in_handle, &mut buf, &mut handles)?; - if !buf.is_empty() { - error!("unexpected data received in gRPC notification message") - } - if handles.len() != 2 { - panic!( - "unexpected number of handles {} received alongside gRPC request", - handles.len() - ) - } - let req_handle = ReadHandle { handle: handles[0] }; - let rsp_handle = WriteHandle { handle: handles[1] }; + // Explicitly call `wait` and then `try_receive` here because calling `receive` would hide + // any `OakStatus::ERR_TERMINATED` errors that may occur in the `wait` phase, since they + // would be wrapped in a `OakError` value that would then always be unwrapped and panic. + invocation_receiver.wait()?; + let invocation: Invocation = invocation_receiver.try_receive().unwrap_or_else(|err| { + panic!("could not receive gRPC invocation: {}", err); + }); // Read a single encapsulated request message from the read half. - let mut buf = Vec::::with_capacity(1024); - let mut handles = Vec::::new(); - crate::channel_read(req_handle, &mut buf, &mut handles)?; - let _ = crate::channel_close(req_handle.handle); - let req: proto::grpc_encap::GrpcRequest = - protobuf::parse_from_bytes(&buf).expect("failed to parse GrpcRequest message"); + let req: GrpcRequest = invocation.request_receiver.receive().unwrap_or_else(|err| { + panic!("could not read gRPC request: {:?}", err); + }); + // Since we are expecting a single message, close the channel immediately. + // This will change when we implement client streaming (#97). + invocation.request_receiver.close().unwrap_or_else(|err| { + panic!("could not close gRPC request channel: {:?}", err); + }); if !req.last { + // TODO(#97): Implement client streaming. panic!("Support for streaming requests not yet implemented"); } node.invoke( &req.method_name, req.get_req_msg().value.as_slice(), - ChannelResponseWriter::new(rsp_handle), + ChannelResponseWriter::new(invocation.response_sender), ); } } diff --git a/sdk/rust/oak_tests/src/lib.rs b/sdk/rust/oak_tests/src/lib.rs index b3c4277a65c..1e3d64b23a9 100644 --- a/sdk/rust/oak_tests/src/lib.rs +++ b/sdk/rust/oak_tests/src/lib.rs @@ -526,7 +526,10 @@ where // Create a new channel to hold the request message. let (mut req_write_half, req_read_half) = RUNTIME.write().expect(RUNTIME_MISSING).new_channel(); - req_write_half.write_message(req_msg); + let status = req_write_half.write_message(req_msg); + if status != oak::OakStatus::OK.value() as u32 { + panic!("could not write message (status: {})", status); + } // Create a new channel for responses to arrive on and also attach that to the message. let (rsp_write_half, mut rsp_read_half) = RUNTIME.write().expect(RUNTIME_MISSING).new_channel(); @@ -543,10 +546,13 @@ where .expect(RUNTIME_MISSING) .grpc_channel() .expect("no gRPC notification channel setup"); - grpc_channel + let status = grpc_channel .write() .expect("corrupt gRPC channel ref") .write_message(notify_msg); + if status != oak::OakStatus::OK.value() as u32 { + panic!("could not write message (status: {})", status); + } // Read the serialized, encapsulated response. loop { @@ -561,7 +567,7 @@ where std::thread::sleep(std::time::Duration::from_millis(100)); continue; } else { - panic!(format!("failed to read from response channel: {}", e)); + panic!("failed to read from response channel: {}", e); } } Ok(r) => r,