diff --git a/oak/proto/oak_api.proto b/oak/proto/oak_api.proto index df16dbfb9ea..0f85433b6b4 100644 --- a/oak/proto/oak_api.proto +++ b/oak/proto/oak_api.proto @@ -42,6 +42,8 @@ enum OakStatus { ERR_TERMINATED = 9; // Channel has no messages available to read. ERR_CHANNEL_EMPTY = 10; + // The node does not have sufficient permissions to perform the requested operation. + ERR_PERMISSION_DENIED = 11; } // Single byte values used to indicate the read status of a channel on the diff --git a/oak/server/rust/oak_runtime/src/runtime/channel.rs b/oak/server/rust/oak_runtime/src/runtime/channel.rs index c407e287986..521fc2a0760 100644 --- a/oak/server/rust/oak_runtime/src/runtime/channel.rs +++ b/oak/server/rust/oak_runtime/src/runtime/channel.rs @@ -62,7 +62,7 @@ pub struct Channel { /// This is set at channel creation time and does not change after that. /// /// See https://github.com/project-oak/oak/blob/master/docs/concepts.md#labels - label: oak_abi::label::Label, + pub label: oak_abi::label::Label, } /// A reference to a [`Channel`]. Each [`Handle`] has an implicit direction such that it is only diff --git a/oak/server/rust/oak_runtime/src/runtime/mod.rs b/oak/server/rust/oak_runtime/src/runtime/mod.rs index eb7824f2a85..ddf4987dfe2 100644 --- a/oak/server/rust/oak_runtime/src/runtime/mod.rs +++ b/oak/server/rust/oak_runtime/src/runtime/mod.rs @@ -44,8 +44,6 @@ struct Node { /// This is set at node creation time and does not change after that. /// /// See https://github.com/project-oak/oak/blob/master/docs/concepts.md#labels - // TODO(#630): Remove exception when label tracking is implemented. - #[allow(dead_code)] label: oak_abi::label::Label, /// A [`HashSet`] containing all the handles associated with this Node. @@ -155,7 +153,7 @@ impl Runtime { // will prevent additional nodes from starting to wait again, because `wait_on_channels` // will return immediately with `OakStatus::ErrTerminated`. let instances: Vec<_> = { - let mut nodes = self.nodes.write().unwrap(); + let mut nodes = self.nodes.write().expect("could not acquire lock on nodes"); self.terminating.store(true, SeqCst); nodes @@ -205,10 +203,8 @@ impl Runtime { return; } - let nodes = self.nodes.read().unwrap(); - let node = nodes - .get(&node_id) - .expect("Invalid node_id passed into track_handles_in_node!"); + let nodes = self.nodes.read().expect("could not acquire lock on nodes"); + let node = nodes.get(&node_id).expect("invalid node_id"); let mut tracked_handles = node.handles.lock().unwrap(); for handle in handles { @@ -223,12 +219,13 @@ impl Runtime { return Ok(()); } - let nodes = self.nodes.read().unwrap(); + let nodes = self.nodes.read().expect("could not acquire lock on nodes"); // Lookup the node_id in the runtime's nodes hashmap - let node = nodes - .get(&node_id) - .expect("Invalid node_id passed into validate_handle_access!"); - let tracked_handles = node.handles.lock().unwrap(); + let node = nodes.get(&node_id).expect("invalid node_id"); + let tracked_handles = node + .handles + .lock() + .expect("could not acquire lock on tracked handles"); // Check the handle exists in the handles associated with a node, otherwise // return None. @@ -255,12 +252,13 @@ impl Runtime { return Ok(()); } - let nodes = self.nodes.read().unwrap(); - let node = nodes - .get(&node_id) - .expect("Invalid node_id passed into filter_optional_handles!"); + let nodes = self.nodes.read().expect("could not acquire lock on nodes"); + let node = nodes.get(&node_id).expect("invalid node_id"); - let tracked_handles = node.handles.lock().unwrap(); + let tracked_handles = node + .handles + .lock() + .expect("could not acquire lock on node handles"); for optional_handle in handles { if let Some(handle) = optional_handle { // Check handle is accessible by the node. @@ -276,6 +274,66 @@ impl Runtime { Ok(()) } + fn validate_can_read_from_channel( + &self, + node_id: NodeId, + channel_handle: Handle, + ) -> Result<(), OakStatus> { + let nodes = self.nodes.read().expect("could not acquire lock on nodes"); + let node = nodes.get(&node_id).expect("invalid node_id"); + let node_label = &node.label; + + let channel_label = self.channels.with_channel( + self.channels.get_reader_channel(channel_handle)?, + |channel| Ok(channel.label.clone()), + )?; + + if channel_label.flows_to(node_label) { + Ok(()) + } else { + Err(OakStatus::ErrPermissionDenied) + } + } + + fn validate_can_read_from_channels( + &self, + node_id: NodeId, + channel_handles: I, + ) -> Result<(), OakStatus> + where + I: IntoIterator, + { + if channel_handles.into_iter().all(|channel_handle| { + self.validate_can_read_from_channel(node_id, channel_handle) + .is_ok() + }) { + Ok(()) + } else { + Err(OakStatus::ErrPermissionDenied) + } + } + + fn validate_can_write_to_channel( + &self, + node_id: NodeId, + channel_handle: Handle, + ) -> Result<(), OakStatus> { + let nodes = self.nodes.read().expect("could not acquire lock on nodes"); + let node = nodes.get(&node_id).expect("invalid node_id"); + let node_label = &node.label; + + let channel_label = self.channels.with_channel( + self.channels.get_writer_channel(channel_handle)?, + |channel| Ok(channel.label.clone()), + )?; + + if node_label.flows_to(&channel_label) { + Ok(()) + } else { + Err(OakStatus::ErrPermissionDenied) + } + } + /// Creates a new [`Channel`] and returns a `(writer handle, reader handle)` pair. pub fn new_channel(&self, node_id: NodeId, label: &oak_abi::label::Label) -> (Handle, Handle) { let (writer, reader) = self.channels.new_channel(label); @@ -321,6 +379,10 @@ impl Runtime { readers: &[Option], ) -> Result, OakStatus> { self.validate_handles_access(node_id, readers.iter())?; + self.validate_can_read_from_channels( + node_id, + readers.iter().filter_map(|x| x.as_ref()).copied(), + )?; let thread = thread::current(); while !self.is_terminating() { @@ -380,6 +442,7 @@ impl Runtime { msg: Message, ) -> Result<(), OakStatus> { self.validate_handle_access(node_id, reference)?; + self.validate_can_write_to_channel(node_id, reference)?; self.channels.with_channel(self.channels.get_writer_channel(reference)?, |channel|{ if channel.is_orphan() { @@ -445,6 +508,7 @@ impl Runtime { reference: Handle, ) -> Result, OakStatus> { self.validate_handle_access(node_id, reference)?; + self.validate_can_read_from_channel(node_id, reference)?; self.channels .with_channel(self.channels.get_reader_channel(reference)?, |channel| { let mut messages = channel.messages.write().unwrap(); @@ -474,6 +538,7 @@ impl Runtime { reference: Handle, ) -> Result { self.validate_handle_access(node_id, reference)?; + self.validate_can_read_from_channel(node_id, reference)?; self.channels .with_channel(self.channels.get_reader_channel(reference)?, |channel| { let messages = channel.messages.read().unwrap(); @@ -502,6 +567,7 @@ impl Runtime { handles_capacity: usize, ) -> Result, OakStatus> { self.validate_handle_access(node_id, reference)?; + self.validate_can_read_from_channel(node_id, reference)?; let result = self.channels .with_channel(self.channels.get_reader_channel(reference)?, |channel| { let mut messages = channel.messages.write().unwrap(); @@ -548,6 +614,7 @@ impl Runtime { reference: Handle, ) -> Result { self.validate_handle_access(node_id, reference)?; + self.validate_can_read_from_channel(node_id, reference)?; { let readers = self.channels.readers.read().unwrap(); if readers.contains_key(&reference) { diff --git a/sdk/rust/oak/src/io/mod.rs b/sdk/rust/oak/src/io/mod.rs index e7ce80dfbd3..7f875c7265b 100644 --- a/sdk/rust/oak/src/io/mod.rs +++ b/sdk/rust/oak/src/io/mod.rs @@ -68,5 +68,8 @@ pub fn error_from_nonok_status(status: OakStatus) -> io::Error { OakStatus::ErrInternal => io::Error::new(io::ErrorKind::Other, "Internal error"), OakStatus::ErrTerminated => io::Error::new(io::ErrorKind::Other, "Node terminated"), OakStatus::ErrChannelEmpty => io::Error::new(io::ErrorKind::UnexpectedEof, "Channel empty"), + OakStatus::ErrPermissionDenied => { + io::Error::new(io::ErrorKind::PermissionDenied, "Permission denied") + } } }