diff --git a/sdk/rust/oak/src/grpc/mod.rs b/sdk/rust/oak/src/grpc/mod.rs index 29d07494411..87eefc2aeb7 100644 --- a/sdk/rust/oak/src/grpc/mod.rs +++ b/sdk/rust/oak/src/grpc/mod.rs @@ -20,7 +20,7 @@ pub use crate::proto::code::Code; use crate::{proto, OakError, OakStatus, ReadHandle}; use log::{error, info, warn}; use proto::grpc_encap::{GrpcRequest, GrpcResponse}; -use protobuf::{Message, ProtobufEnum}; +use protobuf::ProtobufEnum; mod invocation; pub use invocation::Invocation; @@ -36,21 +36,6 @@ pub fn build_status(code: Code, msg: &str) -> proto::status::Status { status } -impl crate::io::Encodable for GrpcRequest { - fn encode(&self) -> std::result::Result { - let bytes = self.write_to_bytes()?; - let handles = Vec::new(); - Ok(crate::io::Message { bytes, handles }) - } -} - -impl crate::io::Decodable for GrpcRequest { - fn decode(message: &crate::io::Message) -> std::result::Result { - let value = protobuf::parse_from_bytes(&message.bytes)?; - Ok(value) - } -} - /// Channel-holding object that encapsulates response messages into /// `GrpcResponse` wrapper messages and writes serialized versions to a send /// channel. @@ -58,21 +43,6 @@ pub struct ChannelResponseWriter { sender: crate::io::Sender, } -impl crate::io::Encodable for GrpcResponse { - fn encode(&self) -> std::result::Result { - let bytes = self.write_to_bytes()?; - let handles = Vec::new(); - Ok(crate::io::Message { bytes, handles }) - } -} - -impl crate::io::Decodable for GrpcResponse { - fn decode(message: &crate::io::Message) -> std::result::Result { - let value = protobuf::parse_from_bytes(&message.bytes)?; - Ok(value) - } -} - /// Indicate whether a write method should leave the current gRPC method /// invocation open or close it. #[derive(PartialEq, Clone, Debug)] diff --git a/sdk/rust/oak/src/io/decodable.rs b/sdk/rust/oak/src/io/decodable.rs index faec6dce49b..c615de6ca14 100644 --- a/sdk/rust/oak/src/io/decodable.rs +++ b/sdk/rust/oak/src/io/decodable.rs @@ -21,3 +21,15 @@ use crate::OakError; pub trait Decodable: Sized { fn decode(message: &Message) -> Result; } + +impl Decodable for T { + fn decode(message: &Message) -> Result { + if !message.handles.is_empty() { + return Err( + protobuf::ProtobufError::WireError(protobuf::error::WireError::Other).into(), + ); + } + let value = protobuf::parse_from_bytes(&message.bytes)?; + Ok(value) + } +} diff --git a/sdk/rust/oak/src/io/encodable.rs b/sdk/rust/oak/src/io/encodable.rs index 74302880ade..b375d80cca2 100644 --- a/sdk/rust/oak/src/io/encodable.rs +++ b/sdk/rust/oak/src/io/encodable.rs @@ -21,3 +21,11 @@ use crate::OakError; pub trait Encodable { fn encode(&self) -> Result; } + +impl Encodable for T { + fn encode(&self) -> Result { + let bytes = self.write_to_bytes()?; + let handles = Vec::new(); + Ok(crate::io::Message { bytes, handles }) + } +}