diff --git a/oak/server/rust/oak_runtime/src/node/logger.rs b/oak/server/rust/oak_runtime/src/node/logger.rs index 6b360c005af..217ef5f9c8e 100644 --- a/oak/server/rust/oak_runtime/src/node/logger.rs +++ b/oak/server/rust/oak_runtime/src/node/logger.rs @@ -87,10 +87,7 @@ fn logger( // in case a Wasm node wants to emit remaining messages. We will return once the channel is // closed. - // TODO(#646): Temporarily don't wait for messages when terminating. Renable when channels - // track their channels and make sure all channels are closed. - // let _ = runtime.wait_on_channels(&[Some(&reader)]); - runtime.wait_on_channels(node_id, &[Some(reader)])?; + let _ = runtime.wait_on_channels(node_id, &[Some(reader)]); if let Some(message) = runtime.channel_read(node_id, reader)? { match LogMessage::decode(&*message.data) { diff --git a/oak/server/rust/oak_runtime/src/runtime/channel.rs b/oak/server/rust/oak_runtime/src/runtime/channel.rs index 04049650873..c407e287986 100644 --- a/oak/server/rust/oak_runtime/src/runtime/channel.rs +++ b/oak/server/rust/oak_runtime/src/runtime/channel.rs @@ -20,6 +20,7 @@ use std::sync::atomic::Ordering::SeqCst; use std::sync::{Arc, Mutex, RwLock, Weak}; use std::thread::{Thread, ThreadId}; +use log::debug; use rand::RngCore; use oak_abi::OakStatus; @@ -111,6 +112,12 @@ impl Channel { self.readers.load(SeqCst) == 0 || self.writers.load(SeqCst) == 0 } + /// Thread safe method that returns true when there is no longer at least one reader or + /// writer. + pub fn has_no_reference(&self) -> bool { + self.readers.load(SeqCst) == 0 && self.writers.load(SeqCst) == 0 + } + /// Insert the given `thread` reference into `thread_id` slot of the HashMap of waiting /// channels attached to an underlying channel. This allows the channel to wake up any waiting /// channels by calling `thread::unpark` on all the threads it knows about. @@ -228,23 +235,41 @@ impl ChannelMapping { /// operations, and the underlying [`Channel`] may become orphaned. pub fn remove_reference(&self, reference: Handle) -> Result<(), OakStatus> { if let Ok(channel_id) = self.get_writer_channel(reference) { - self.with_channel(channel_id, |channel| { + { + let mut channels = self.channels.write().unwrap(); + let channel = channels + .get(&channel_id) + .expect("remove_reference: Handle is invalid!"); channel.remove_writer(); - Ok(()) - })?; + if channel.has_no_reference() { + channels.remove(&channel_id); + debug!("remove_reference: deallocating channel {:?}", channel_id); + } + } - let mut writers = self.writers.write().unwrap(); - writers.remove(&reference); + { + let mut writers = self.writers.write().unwrap(); + writers.remove(&reference); + } } if let Ok(channel_id) = self.get_reader_channel(reference) { - self.with_channel(channel_id, |channel| { + { + let mut channels = self.channels.write().unwrap(); + let channel = channels + .get(&channel_id) + .expect("remove_reference: Handle is invalid!"); channel.remove_reader(); - Ok(()) - })?; + if channel.has_no_reference() { + channels.remove(&channel_id); + debug!("remove_reference: deallocating channel {:?}", channel_id); + } + } - let mut readers = self.readers.write().unwrap(); - readers.remove(&reference); + { + let mut readers = self.readers.write().unwrap(); + readers.remove(&reference); + } } Ok(()) diff --git a/oak/server/rust/oak_runtime/src/runtime/mod.rs b/oak/server/rust/oak_runtime/src/runtime/mod.rs index 8faa94c2cec..10112f5f327 100644 --- a/oak/server/rust/oak_runtime/src/runtime/mod.rs +++ b/oak/server/rust/oak_runtime/src/runtime/mod.rs @@ -559,6 +559,15 @@ impl Runtime { /// Close a [`Handle`], potentially orphaning the underlying [`channel::Channel`]. pub fn channel_close(&self, node_id: NodeId, reference: Handle) -> Result<(), OakStatus> { self.validate_handle_access(node_id, reference)?; + + if node_id != RUNTIME_NODE_ID { + // Remove handle from the nodes available handles + let nodes = self.nodes.read().unwrap(); + let node = nodes.get(&node_id).expect("channel_close: No such node_id"); + let mut handles = node.handles.lock().unwrap(); + handles.remove(&reference); + } + self.channels.remove_reference(reference) } @@ -569,8 +578,32 @@ impl Runtime { /// Remove a [`Node`] by [`NodeId`] from the [`Runtime`]. pub fn remove_node_id(&self, node_id: NodeId) { + { + // Close any remaining handles + let remaining_handles: Vec<_> = { + let nodes = self.nodes.read().unwrap(); + let node = nodes + .get(&node_id) + .expect("remove_node_id: No such node_id"); + let handles = node.handles.lock().unwrap(); + handles.iter().copied().collect() + }; + + debug!( + "remove_node_id: node_id {:?} had open channels on exit: {:?}", + node_id, remaining_handles + ); + + for handle in remaining_handles { + self.channel_close(node_id, handle) + .expect("remove_node_id: Unable to close hanging channel!"); + } + } + let mut nodes = self.nodes.write().unwrap(); - nodes.remove(&node_id); + nodes + .remove(&node_id) + .expect("remove_node_id: Node didn't exist!"); } }