From ec3ddd2ff85aa6adf1db855d1556b1d0f117ce33 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9s=20Testi?= <34962582+andytesti@users.noreply.github.com> Date: Fri, 19 Jul 2024 04:52:08 -0300 Subject: [PATCH] Add support for nested gRPC callouts. (#240) Signed-off-by: andytesti --- src/dispatcher.rs | 77 ++++++++++++++++++++++++++--------------------- 1 file changed, 42 insertions(+), 35 deletions(-) diff --git a/src/dispatcher.rs b/src/dispatcher.rs index 4d1f16fa..671ef0f3 100644 --- a/src/dispatcher.rs +++ b/src/dispatcher.rs @@ -453,7 +453,8 @@ impl Dispatcher { } fn on_grpc_receive(&self, token_id: u32, response_size: usize) { - if let Some(context_id) = self.grpc_callouts.borrow_mut().remove(&token_id) { + let context_id = self.grpc_callouts.borrow_mut().remove(&token_id); + if let Some(context_id) = context_id { if let Some(http_stream) = self.http_streams.borrow_mut().get_mut(&context_id) { self.active_id.set(context_id); hostcalls::set_effective_context(context_id).unwrap(); @@ -467,24 +468,26 @@ impl Dispatcher { hostcalls::set_effective_context(context_id).unwrap(); root.on_grpc_call_response(token_id, 0, response_size); } - } else if let Some(context_id) = self.grpc_streams.borrow_mut().get(&token_id) { - let context_id = *context_id; - if let Some(http_stream) = self.http_streams.borrow_mut().get_mut(&context_id) { - self.active_id.set(context_id); - hostcalls::set_effective_context(context_id).unwrap(); - http_stream.on_grpc_stream_message(token_id, response_size); - } else if let Some(stream) = self.streams.borrow_mut().get_mut(&context_id) { - self.active_id.set(context_id); - hostcalls::set_effective_context(context_id).unwrap(); - stream.on_grpc_stream_message(token_id, response_size); - } else if let Some(root) = self.roots.borrow_mut().get_mut(&context_id) { - self.active_id.set(context_id); - hostcalls::set_effective_context(context_id).unwrap(); - root.on_grpc_stream_message(token_id, response_size); - } } else { - // TODO: change back to a panic once underlying issue is fixed. - trace!("on_grpc_receive_initial_metadata: invalid token_id"); + let context_id = self.grpc_streams.borrow().get(&token_id).cloned(); + if let Some(context_id) = context_id { + if let Some(http_stream) = self.http_streams.borrow_mut().get_mut(&context_id) { + self.active_id.set(context_id); + hostcalls::set_effective_context(context_id).unwrap(); + http_stream.on_grpc_stream_message(token_id, response_size); + } else if let Some(stream) = self.streams.borrow_mut().get_mut(&context_id) { + self.active_id.set(context_id); + hostcalls::set_effective_context(context_id).unwrap(); + stream.on_grpc_stream_message(token_id, response_size); + } else if let Some(root) = self.roots.borrow_mut().get_mut(&context_id) { + self.active_id.set(context_id); + hostcalls::set_effective_context(context_id).unwrap(); + root.on_grpc_stream_message(token_id, response_size); + } + } else { + // TODO: change back to a panic once underlying issue is fixed. + trace!("on_grpc_receive_initial_metadata: invalid token_id"); + } } } @@ -514,7 +517,8 @@ impl Dispatcher { } fn on_grpc_close(&self, token_id: u32, status_code: u32) { - if let Some(context_id) = self.grpc_callouts.borrow_mut().remove(&token_id) { + let context_id = self.grpc_callouts.borrow_mut().remove(&token_id); + if let Some(context_id) = context_id { if let Some(http_stream) = self.http_streams.borrow_mut().get_mut(&context_id) { self.active_id.set(context_id); hostcalls::set_effective_context(context_id).unwrap(); @@ -528,23 +532,26 @@ impl Dispatcher { hostcalls::set_effective_context(context_id).unwrap(); root.on_grpc_call_response(token_id, status_code, 0); } - } else if let Some(context_id) = self.grpc_streams.borrow_mut().remove(&token_id) { - if let Some(http_stream) = self.http_streams.borrow_mut().get_mut(&context_id) { - self.active_id.set(context_id); - hostcalls::set_effective_context(context_id).unwrap(); - http_stream.on_grpc_stream_close(token_id, status_code) - } else if let Some(stream) = self.streams.borrow_mut().get_mut(&context_id) { - self.active_id.set(context_id); - hostcalls::set_effective_context(context_id).unwrap(); - stream.on_grpc_stream_close(token_id, status_code) - } else if let Some(root) = self.roots.borrow_mut().get_mut(&context_id) { - self.active_id.set(context_id); - hostcalls::set_effective_context(context_id).unwrap(); - root.on_grpc_stream_close(token_id, status_code) - } } else { - // TODO: change back to a panic once underlying issue is fixed. - trace!("on_grpc_close: invalid token_id, a non-connected stream has closed"); + let context_id = self.grpc_streams.borrow_mut().remove(&token_id); + if let Some(context_id) = context_id { + if let Some(http_stream) = self.http_streams.borrow_mut().get_mut(&context_id) { + self.active_id.set(context_id); + hostcalls::set_effective_context(context_id).unwrap(); + http_stream.on_grpc_stream_close(token_id, status_code) + } else if let Some(stream) = self.streams.borrow_mut().get_mut(&context_id) { + self.active_id.set(context_id); + hostcalls::set_effective_context(context_id).unwrap(); + stream.on_grpc_stream_close(token_id, status_code) + } else if let Some(root) = self.roots.borrow_mut().get_mut(&context_id) { + self.active_id.set(context_id); + hostcalls::set_effective_context(context_id).unwrap(); + root.on_grpc_stream_close(token_id, status_code) + } + } else { + // TODO: change back to a panic once underlying issue is fixed. + trace!("on_grpc_close: invalid token_id, a non-connected stream has closed"); + } } } }