From f8c190fa21d59a1e478ce59bac19f88e0fe1e975 Mon Sep 17 00:00:00 2001 From: slinkydeveloper Date: Thu, 15 Jun 2023 15:43:23 +0200 Subject: [PATCH] Implement Ingress service --- Cargo.lock | 5 +- src/ingress_grpc/Cargo.toml | 1 + src/ingress_grpc/build.rs | 7 +- src/ingress_grpc/src/handler.rs | 71 +++++--- src/ingress_grpc/src/lib.rs | 15 ++ src/ingress_grpc/src/options.rs | 2 +- src/ingress_grpc/src/pb.rs | 27 +++ .../src/protocol/connect_adapter.rs | 154 +++++++++++++++++- src/ingress_grpc/src/protocol/mod.rs | 71 +++++--- .../src/protocol/tonic_adapter.rs | 2 +- src/ingress_grpc/src/reflection.rs | 58 ++++--- src/ingress_grpc/src/server.rs | 79 ++++++++- 12 files changed, 399 insertions(+), 93 deletions(-) create mode 100644 src/ingress_grpc/src/pb.rs diff --git a/Cargo.lock b/Cargo.lock index b6b87fe72a..51e4564934 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2720,9 +2720,9 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.17.1" +version = "1.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b7e5500299e16ebb147ae15a00a942af264cf3688f47923b8fc2cd5858f23ad3" +checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d" [[package]] name = "oorandom" @@ -3621,6 +3621,7 @@ dependencies = [ "http", "http-body", "hyper", + "once_cell", "opentelemetry", "opentelemetry-http", "pin-project", diff --git a/src/ingress_grpc/Cargo.toml b/src/ingress_grpc/Cargo.toml index c2db83cc73..5593c4b2d8 100644 --- a/src/ingress_grpc/Cargo.toml +++ b/src/ingress_grpc/Cargo.toml @@ -52,6 +52,7 @@ drain = { workspace = true } arc-swap = { workspace = true } thiserror = { workspace = true } schemars = { workspace = true, optional = true } +once_cell = "1.18" [dev-dependencies] hyper = { workspace = true, features = ["client"] } diff --git a/src/ingress_grpc/build.rs b/src/ingress_grpc/build.rs index c5897b0149..c2c467e184 100644 --- a/src/ingress_grpc/build.rs +++ b/src/ingress_grpc/build.rs @@ -12,9 +12,14 @@ fn main() -> std::io::Result<()> { .compile_protos( &[ "proto/grpc/reflection/v1alpha/reflection.proto", + "proto/dev/restate/services.proto", "tests/proto/greeter.proto", ], - &["proto/grpc/reflection/v1alpha", "tests/proto"], + &[ + "proto/grpc/reflection/v1alpha", + "proto/dev/restate", + "tests/proto", + ], )?; Ok(()) } diff --git a/src/ingress_grpc/src/handler.rs b/src/ingress_grpc/src/handler.rs index ac22f574a4..7de0d65413 100644 --- a/src/ingress_grpc/src/handler.rs +++ b/src/ingress_grpc/src/handler.rs @@ -1,6 +1,8 @@ use super::options::JsonOptions; +use super::pb::grpc::reflection::{ + server_reflection_server::ServerReflection, server_reflection_server::ServerReflectionServer, +}; use super::protocol::{BoxBody, Protocol}; -use super::reflection::{ServerReflection, ServerReflectionServer}; use super::*; use std::sync::Arc; @@ -12,6 +14,7 @@ use http::{Request, Response, StatusCode}; use http_body::Body; use hyper::Body as HyperBody; use opentelemetry::trace::{SpanContext, TraceContextExt}; +use prost::Message; use restate_common::types::{IngressId, ServiceInvocationResponseSink, SpanRelation}; use restate_service_metadata::MethodDescriptorRegistry; use tokio::sync::Semaphore; @@ -87,7 +90,7 @@ impl Service where InvocationFactory: ServiceInvocationFactory + Clone + Send + 'static, - MethodRegistry: MethodDescriptorRegistry, + MethodRegistry: MethodDescriptorRegistry + Clone + Send + 'static, ReflectionService: ServerReflection, { type Response = Response; @@ -144,6 +147,8 @@ where let method_name = path_parts.remove(2).to_string(); let service_name = path_parts.remove(1).to_string(); + // --- Special Restate services + // Reflections if ServerReflectionServer::::NAME == service_name { return self .reflection_server @@ -155,20 +160,6 @@ where .boxed(); } - // Find the service method descriptor - let descriptor = if let Some(desc) = self - .method_registry - .resolve_method_descriptor(&service_name, &method_name) - { - desc - } else { - debug!("{}/{} not found", service_name, method_name); - return ok(protocol.encode_status(Status::not_found(format!( - "{service_name}/{method_name} not found" - )))) - .boxed(); - }; - // Encapsulate in this closure the remaining part of the processing let ingress_id = self.ingress_id; let invocation_factory = self.invocation_factory.clone(); @@ -197,12 +188,30 @@ where let ingress_span_context = ingress_span.context().span().span_context().clone(); async move { + let mut service_name = req_headers.service_name; + let mut method_name = req_headers.method_name; + let mut req_payload = req_payload; + let mut response_sink = Some(ServiceInvocationResponseSink::Ingress(ingress_id)); + let mut wait_response = true; + + // Ingress built-in service + if is_ingress_invoke(&service_name, &method_name) { + let invoke_request = pb::restate::services::InvokeRequest::decode(req_payload) + .map_err(|e| Status::invalid_argument(e.to_string()))?; + + service_name = invoke_request.service; + method_name = invoke_request.method; + req_payload = invoke_request.argument; + response_sink = None; + wait_response = false; + } + // Create the service_invocation let (service_invocation, service_invocation_span) = match invocation_factory.create( - &req_headers.service_name, - &req_headers.method_name, + &service_name, + &method_name, req_payload, - Some(ServiceInvocationResponseSink::Ingress(ingress_id)), + response_sink, SpanRelation::Parent(ingress_span_context) ) { Ok(i) => i, @@ -222,8 +231,22 @@ where // https://docs.rs/tracing/latest/tracing/struct.Span.html#in-asynchronous-code let enter_service_invocation_span = service_invocation_span.enter(); - // More trace info - trace!(restate.invocation.request_headers = ?req_headers); + // Ingress built-in service just sends a fire and forget and closes + if !wait_response { + let sid = service_invocation.id.to_string(); + + if dispatcher_command_sender.send(Command::fire_and_forget( + service_invocation + )).is_err() { + debug!("Ingress dispatcher is closed while there is still an invocation in flight."); + return Err(Status::unavailable("Unavailable")); + } + return Ok( + pb::restate::services::InvokeResponse { + sid, + }.encode_to_vec().into() + ) + } // Send the service invocation let (service_invocation_command, response_rx) = @@ -259,7 +282,7 @@ where let result_fut = protocol.handle_request( service_name, method_name, - descriptor, + self.method_registry.clone(), self.json.clone(), req, ingress_request_handler, @@ -283,3 +306,7 @@ fn span_relation(request_span: &SpanContext) -> SpanRelation { SpanRelation::None } } + +fn is_ingress_invoke(service_name: &str, method_name: &str) -> bool { + "dev.restate.Ingress" == service_name && "Invoke" == method_name +} diff --git a/src/ingress_grpc/src/lib.rs b/src/ingress_grpc/src/lib.rs index 06397a21ed..33d07a042f 100644 --- a/src/ingress_grpc/src/lib.rs +++ b/src/ingress_grpc/src/lib.rs @@ -1,6 +1,7 @@ mod dispatcher; mod handler; mod options; +mod pb; mod protocol; mod reflection; mod server; @@ -223,9 +224,23 @@ mod mocks { pub(super) fn test_descriptor_registry() -> InMemoryMethodDescriptorRegistry { let registry = InMemoryMethodDescriptorRegistry::default(); registry.register(greeter_service_descriptor()); + registry.register(ingress_service_descriptor()); registry } + pub(super) fn ingress_service_descriptor() -> ServiceDescriptor { + crate::pb::DESCRIPTOR_POOL + .get_service_by_name("dev.restate.Ingress") + .unwrap() + } + + pub(super) fn ingress_invoke_method_descriptor() -> MethodDescriptor { + ingress_service_descriptor() + .methods() + .find(|m| m.name() == "Invoke") + .unwrap() + } + pub(super) fn greeter_service_descriptor() -> ServiceDescriptor { test_descriptor_pool() .services() diff --git a/src/ingress_grpc/src/options.rs b/src/ingress_grpc/src/options.rs index b98030817e..338ea3dad9 100644 --- a/src/ingress_grpc/src/options.rs +++ b/src/ingress_grpc/src/options.rs @@ -1,4 +1,4 @@ -use super::reflection::ServerReflection; +use super::pb::grpc::reflection::server_reflection_server::ServerReflection; use super::HyperServerIngress; use super::*; diff --git a/src/ingress_grpc/src/pb.rs b/src/ingress_grpc/src/pb.rs new file mode 100644 index 0000000000..4bf9120a8f --- /dev/null +++ b/src/ingress_grpc/src/pb.rs @@ -0,0 +1,27 @@ +use once_cell::sync::Lazy; +use prost_reflect::DescriptorPool; +use std::convert::AsRef; + +pub(crate) mod grpc { + pub(crate) mod reflection { + #![allow(warnings)] + #![allow(clippy::all)] + #![allow(unknown_lints)] + include!(concat!(env!("OUT_DIR"), "/grpc.reflection.v1alpha.rs")); + } +} +pub(crate) mod restate { + pub(crate) mod services { + #![allow(warnings)] + #![allow(clippy::all)] + #![allow(unknown_lints)] + include!(concat!(env!("OUT_DIR"), "/dev.restate.rs")); + } +} + +pub(crate) static DESCRIPTOR_POOL: Lazy = Lazy::new(|| { + DescriptorPool::decode( + include_bytes!(concat!(env!("OUT_DIR"), "/file_descriptor_set.bin")).as_ref(), + ) + .expect("The built-in descriptor pool should be valid") +}); diff --git a/src/ingress_grpc/src/protocol/connect_adapter.rs b/src/ingress_grpc/src/protocol/connect_adapter.rs index 336c200272..6f0f9a6faf 100644 --- a/src/ingress_grpc/src/protocol/connect_adapter.rs +++ b/src/ingress_grpc/src/protocol/connect_adapter.rs @@ -1,18 +1,19 @@ use bytes::Buf; +use content_type::ConnectContentType; use http::header::{CONTENT_ENCODING, CONTENT_TYPE}; use http::request::Parts; use http::{Method, Request, Response, StatusCode}; use hyper::Body; use prost_reflect::{DeserializeOptions, DynamicMessage, MethodDescriptor, SerializeOptions}; +use restate_service_metadata::MethodDescriptorRegistry; use serde::Serialize; use tonic::{Code, Status}; use tracing::warn; -use content_type::ConnectContentType; - -pub(super) async fn decode_request( +pub(super) async fn decode_request( request: Request, method_descriptor: &MethodDescriptor, + method_registry: MethodRegistry, deserialize_options: DeserializeOptions, ) -> Result<(ConnectContentType, DynamicMessage), Response> { let (mut parts, body) = request.into_parts(); @@ -51,6 +52,7 @@ pub(super) async fn decode_request( let msg = match content_type::read_message( content_type, method_descriptor.input(), + method_registry, deserialize_options, body, ) { @@ -130,7 +132,9 @@ pub(super) mod content_type { use bytes::{Buf, BufMut, BytesMut}; use http::HeaderValue; use prost::Message; - use prost_reflect::MessageDescriptor; + use prost_reflect::{MessageDescriptor, Value}; + use serde::de::IntoDeserializer; + use serde::Deserialize; use tower::BoxError; const APPLICATION_JSON: &str = "application/json"; @@ -162,13 +166,22 @@ pub(super) mod content_type { None } - pub(super) fn read_message( + pub(super) fn read_message( content_type: ConnectContentType, msg_desc: MessageDescriptor, + method_registry: MethodRegistry, deserialize_options: DeserializeOptions, payload_buf: impl Buf + Sized, ) -> Result { match content_type { + ConnectContentType::Json if msg_desc.full_name() == "dev.restate.InvokeRequest" => { + read_json_invoke_request( + msg_desc, + method_registry, + deserialize_options, + payload_buf, + ) + } ConnectContentType::Json => { let mut deser = serde_json::Deserializer::from_reader(payload_buf.reader()); let dynamic_message = DynamicMessage::deserialize_with_options( @@ -183,6 +196,49 @@ pub(super) mod content_type { } } + // TODO this is a bit of a awful hack to parse the InvokeRequest.argument as json object. + // We'll improve this with a JsonMapper interface, as described here: + // https://github.com/restatedev/restate/issues/43 + pub(super) fn read_json_invoke_request( + invoke_request_msg_desc: MessageDescriptor, + method_registry: MethodRegistry, + deserialize_options: DeserializeOptions, + payload_buf: impl Buf + Sized, + ) -> Result { + #[derive(Deserialize)] + struct InvokeRequestAdapter { + service: String, + method: String, + argument: serde_json::Value, + } + + let adapter: InvokeRequestAdapter = serde_json::from_reader(payload_buf.reader())?; + + let descriptor = method_registry + .resolve_method_descriptor(&adapter.service, &adapter.method) + // TODO this error propagation is not great + .ok_or_else(|| { + Status::not_found(format!("{}/{} not found", adapter.service, adapter.method)) + })?; + + let argument_dynamic_message = DynamicMessage::deserialize_with_options( + descriptor.input(), + adapter.argument.into_deserializer(), + &deserialize_options, + )?; + + let mut invoke_req_msg = DynamicMessage::new(invoke_request_msg_desc); + invoke_req_msg.set_field_by_name("service", Value::String(adapter.service)); + invoke_req_msg.set_field_by_name("method", Value::String(adapter.method)); + // TODO can skip this serialization by implementing prost::Message on InvokeRequestAdapter + invoke_req_msg.set_field_by_name( + "argument", + Value::Bytes(argument_dynamic_message.encode_to_vec().into()), + ); + + Ok(invoke_req_msg) + } + pub(super) fn write_message( content_type: ConnectContentType, serialize_options: SerializeOptions, @@ -208,7 +264,7 @@ pub(super) mod content_type { #[cfg(test)] mod tests { use super::*; - use crate::mocks::greeter_get_count_method_descriptor; + use crate::mocks::{greeter_get_count_method_descriptor, test_descriptor_registry}; use bytes::Bytes; use http::HeaderValue; @@ -232,6 +288,7 @@ pub(super) mod content_type { read_message( ConnectContentType::Json, greeter_get_count_method_descriptor().input(), + test_descriptor_registry(), DeserializeOptions::default(), Bytes::from("{}"), ) @@ -377,7 +434,10 @@ pub(super) mod status { mod tests { use super::*; - use crate::mocks::{greeter_greet_method_descriptor, pb}; + use crate::mocks::{ + greeter_greet_method_descriptor, ingress_invoke_method_descriptor, pb, + test_descriptor_registry, + }; use bytes::Bytes; use http::StatusCode; use http_body::Body; @@ -395,6 +455,7 @@ mod tests { .body(json!({"person": "Francesco"}).to_string().into()) .unwrap(), &greeter_greet_method_descriptor(), + test_descriptor_registry(), DeserializeOptions::default(), ) .await @@ -427,6 +488,7 @@ mod tests { ) .unwrap(), &greeter_greet_method_descriptor(), + test_descriptor_registry(), DeserializeOptions::default(), ) .await @@ -443,6 +505,82 @@ mod tests { assert_eq!(ct, ConnectContentType::Protobuf); } + #[test(tokio::test)] + async fn decode_invoke_json() { + let json_payload = json!({ + "service": "greeter.Greeter", + "method": "Greet", + "argument": { + "person": "Francesco" + } + }); + + let (ct, request_payload) = decode_request( + Request::builder() + .uri("http://localhost/dev.restate.Ingress/Invoke") + .method(Method::POST) + .header(CONTENT_TYPE, "application/json") + .body(json_payload.to_string().into()) + .unwrap(), + &ingress_invoke_method_descriptor(), + test_descriptor_registry(), + DeserializeOptions::default(), + ) + .await + .unwrap(); + + assert_eq!( + request_payload + .transcode_to::() + .unwrap(), + crate::pb::restate::services::InvokeRequest { + service: "greeter.Greeter".to_string(), + method: "Greet".to_string(), + argument: pb::GreetingRequest { + person: "Francesco".to_string(), + } + .encode_to_vec() + .into(), + } + ); + assert_eq!(ct, ConnectContentType::Json); + } + + #[test(tokio::test)] + async fn decode_invoke_protobuf() { + let invoke_request = crate::pb::restate::services::InvokeRequest { + service: "greeter.Greeter".to_string(), + method: "Greet".to_string(), + argument: pb::GreetingRequest { + person: "Francesco".to_string(), + } + .encode_to_vec() + .into(), + }; + + let (ct, request_payload) = decode_request( + Request::builder() + .uri("http://localhost/dev.restate.Ingress/Invoke") + .method(Method::POST) + .header(CONTENT_TYPE, "application/protobuf") + .body(invoke_request.encode_to_vec().into()) + .unwrap(), + &greeter_greet_method_descriptor(), + test_descriptor_registry(), + DeserializeOptions::default(), + ) + .await + .unwrap(); + + assert_eq!( + request_payload + .transcode_to::() + .unwrap(), + invoke_request + ); + assert_eq!(ct, ConnectContentType::Protobuf); + } + #[test(tokio::test)] async fn decode_wrong_http_method() { let err_response = decode_request( @@ -453,6 +591,7 @@ mod tests { .body(json!({"person": "Francesco"}).to_string().into()) .unwrap(), &greeter_greet_method_descriptor(), + test_descriptor_registry(), DeserializeOptions::default(), ) .await @@ -471,6 +610,7 @@ mod tests { .body("person: Francesco".to_string().into()) .unwrap(), &greeter_greet_method_descriptor(), + test_descriptor_registry(), DeserializeOptions::default(), ) .await diff --git a/src/ingress_grpc/src/protocol/mod.rs b/src/ingress_grpc/src/protocol/mod.rs index 248281dec1..8964f1bf3b 100644 --- a/src/ingress_grpc/src/protocol/mod.rs +++ b/src/ingress_grpc/src/protocol/mod.rs @@ -15,10 +15,12 @@ use opentelemetry::propagation::TextMapPropagator; use opentelemetry::sdk::propagation::TraceContextPropagator; use prost::Message; use prost_reflect::MethodDescriptor; +use restate_service_metadata::MethodDescriptorRegistry; use tonic::server::Grpc; use tonic::Status; use tower::{BoxError, Layer, Service}; use tower_utils::service_fn_once; +use tracing::debug; pub(crate) enum Protocol { // Use tonic (gRPC or gRPC-Web) @@ -50,19 +52,32 @@ impl Protocol { } } - pub(crate) async fn handle_request( + pub(crate) async fn handle_request( self, service_name: String, method_name: String, - descriptor: MethodDescriptor, + method_registry: MethodRegistry, json: JsonOptions, req: Request, - handler_fn: H, + handler_fn: Handler, ) -> Result, BoxError> where - H: FnOnce(IngressRequest) -> F + Clone + Send + 'static, - F: Future + Send, + MethodRegistry: MethodDescriptorRegistry, + Handler: FnOnce(IngressRequest) -> HandlerFut + Send + 'static, + HandlerFut: Future + Send, { + // Find the service method descriptor + let descriptor = if let Some(desc) = + method_registry.resolve_method_descriptor(&service_name, &method_name) + { + desc + } else { + debug!("{}/{} not found", service_name, method_name); + return Ok(self.encode_status(Status::not_found(format!( + "{service_name}/{method_name} not found" + )))); + }; + // Extract tracing context if any let tracing_context = TraceContextPropagator::new() .extract(&opentelemetry_http::HeaderExtractor(req.headers())); @@ -78,6 +93,7 @@ impl Protocol { Protocol::Connect => Ok(Self::handle_connect_request( ingress_request_headers, descriptor, + method_registry, json, req, handler_fn, @@ -86,16 +102,14 @@ impl Protocol { } } - async fn handle_tonic_request( + async fn handle_tonic_request( ingress_request_headers: IngressRequestHeaders, req: Request, - handler_fn: H, + handler_fn: Handler, ) -> Result, BoxError> where - // TODO Clone bound is not needed, - // remove it once https://github.com/hyperium/tonic/issues/1290 is released - H: FnOnce(IngressRequest) -> F + Clone + Send + 'static, - F: Future + Send, + Handler: FnOnce(IngressRequest) -> HandlerFut + Send + 'static, + HandlerFut: Future + Send, { // Why FnOnce and service_fn_once are safe here? // @@ -129,24 +143,30 @@ impl Protocol { .map(|res| res.map(to_box_body)) } - async fn handle_connect_request( + async fn handle_connect_request( ingress_request_headers: IngressRequestHeaders, descriptor: MethodDescriptor, + method_registry: MethodRegistry, json: JsonOptions, req: Request, - handler_fn: H, + handler_fn: Handler, ) -> Response where - H: FnOnce(IngressRequest) -> F + Send + 'static, - F: Future + Send, + MethodRegistry: MethodDescriptorRegistry, + Handler: FnOnce(IngressRequest) -> HandlerFut + Send + 'static, + HandlerFut: Future + Send, { - let (content_type, request_message) = - match connect_adapter::decode_request(req, &descriptor, json.to_deserialize_options()) - .await - { - Ok(req) => req, - Err(error_res) => return error_res.map(to_box_body), - }; + let (content_type, request_message) = match connect_adapter::decode_request( + req, + &descriptor, + method_registry, + json.to_deserialize_options(), + ) + .await + { + Ok(req) => req, + Err(error_res) => return error_res.map(to_box_body), + }; let ingress_request_body = Bytes::from(request_message.encode_to_vec()); let response = match handler_fn((ingress_request_headers, ingress_request_body)).await { @@ -188,8 +208,10 @@ mod tests { use serde_json::json; fn greeter_service_fn(ingress_req: IngressRequest) -> Ready { - let person = pb::GreetingRequest::decode(ingress_req.1).unwrap().person; - ok(pb::GreetingResponse { + let person = mocks::pb::GreetingRequest::decode(ingress_req.1) + .unwrap() + .person; + ok(mocks::pb::GreetingResponse { greeting: format!("Hello {person}"), } .encode_to_vec() @@ -218,6 +240,7 @@ mod tests { Context::default(), ), greeter_greet_method_descriptor(), + test_descriptor_registry(), JsonOptions::default(), request, greeter_service_fn, diff --git a/src/ingress_grpc/src/protocol/tonic_adapter.rs b/src/ingress_grpc/src/protocol/tonic_adapter.rs index 8159e3a16d..61d71efa44 100644 --- a/src/ingress_grpc/src/protocol/tonic_adapter.rs +++ b/src/ingress_grpc/src/protocol/tonic_adapter.rs @@ -28,7 +28,7 @@ impl TonicUnaryServiceAdapter { impl UnaryService for TonicUnaryServiceAdapter where - H: FnOnce(IngressRequest) -> F + Clone + Send, + H: FnOnce(IngressRequest) -> F + Send, F: Future + Send, { type Response = Bytes; diff --git a/src/ingress_grpc/src/reflection.rs b/src/ingress_grpc/src/reflection.rs index c665f7462a..1987d6ade5 100644 --- a/src/ingress_grpc/src/reflection.rs +++ b/src/ingress_grpc/src/reflection.rs @@ -9,6 +9,10 @@ use std::pin::Pin; use std::sync::Arc; use std::task::{ready, Context, Poll}; +use crate::pb::grpc::reflection::server_reflection_request::MessageRequest; +use crate::pb::grpc::reflection::server_reflection_response::MessageResponse; +use crate::pb::grpc::reflection::server_reflection_server::*; +use crate::pb::grpc::reflection::*; use arc_swap::{ArcSwap, Guard}; use bytes::Bytes; use futures::{Stream, StreamExt}; @@ -18,19 +22,9 @@ use prost_reflect::{DescriptorPool, FileDescriptor, ServiceDescriptor}; use tonic::{Request, Response, Status, Streaming}; use tracing::{debug, trace}; -mod pb { - #![allow(warnings)] - #![allow(clippy::all)] - #![allow(unknown_lints)] - include!(concat!(env!("OUT_DIR"), "/grpc.reflection.v1alpha.rs")); -} - -pub use pb::server_reflection_server::ServerReflection; -pub use pb::server_reflection_server::ServerReflectionServer; - #[derive(Debug, Clone, Default)] struct ReflectionServiceState { - service_names: Vec, + service_names: Vec, // The usize here is used for reference count files: HashMap, @@ -63,7 +57,7 @@ impl ReflectionServiceState { // This insert retains the order self.service_names.insert( insert_index, - pb::ServiceResponse { + ServiceResponse { name: service_name.clone(), }, ); @@ -123,11 +117,31 @@ impl ReflectionServiceState { } } -#[derive(Default, Clone)] +#[derive(Clone)] pub struct ReflectionRegistry { reflection_service_state: Arc>, } +impl Default for ReflectionRegistry { + fn default() -> Self { + let mut registry = Self { + reflection_service_state: Default::default(), + }; + registry + .register_new_services( + "self_ingress".to_string(), + vec![ + "grpc.reflection.v1alpha.ServerReflection".to_string(), + "dev.restate.Ingress".to_string(), + ], + crate::pb::DESCRIPTOR_POOL.clone(), + ) + .expect("Registering self_ingress in the reflections should not fail"); + + registry + } +} + #[derive(Debug, thiserror::Error)] pub enum RegistrationError { #[error("missing expected field {0} in descriptor")] @@ -338,18 +352,13 @@ fn extract_name( } pub struct ReflectionServiceStream { - request_stream: Streaming, + request_stream: Streaming, state: Guard>, } impl ReflectionServiceStream { - fn handle_request( - &self, - request: &pb::server_reflection_request::MessageRequest, - ) -> Result { - use pb::server_reflection_request::*; - use pb::server_reflection_response::*; - use pb::*; + fn handle_request(&self, request: &MessageRequest) -> Result { + use crate::pb::grpc::reflection::*; Ok(match request { MessageRequest::FileByFilename(f) => self @@ -389,12 +398,9 @@ impl ReflectionServiceStream { } impl Stream for ReflectionServiceStream { - type Item = Result; + type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - use pb::server_reflection_response::*; - use pb::*; - let request = match ready!(self.request_stream.poll_next_unpin(cx)) { Some(Ok(req)) => req, _ => return Poll::Ready(None), @@ -430,7 +436,7 @@ impl ServerReflection for ReflectionRegistry { async fn server_reflection_info( &self, - request: Request>, + request: Request>, ) -> Result, Status> { Ok(Response::new(ReflectionServiceStream { request_stream: request.into_inner(), diff --git a/src/ingress_grpc/src/server.rs b/src/ingress_grpc/src/server.rs index 387b39cde1..0fc5a8e01f 100644 --- a/src/ingress_grpc/src/server.rs +++ b/src/ingress_grpc/src/server.rs @@ -1,7 +1,7 @@ use super::options::JsonOptions; +use super::pb::grpc::reflection::server_reflection_server::ServerReflection; use super::*; -use crate::reflection::ServerReflection; use codederror::CodedError; use futures::FutureExt; use restate_common::types::IngressId; @@ -189,7 +189,7 @@ mod tests { let cmd_fut = tokio::spawn(async move { let (service_invocation, response_tx) = cmd_fut.await.unwrap().unwrap().into_inner(); response_tx - .send(Ok(pb::GreetingResponse { + .send(Ok(mocks::pb::GreetingResponse { greeting: "Igal".to_string(), } .encode_to_vec() @@ -218,7 +218,8 @@ mod tests { "greeter.Greeter" ); assert_eq!(service_invocation.method_name, "Greet"); - let greeting_req = pb::GreetingRequest::decode(&mut service_invocation.argument).unwrap(); + let greeting_req = + mocks::pb::GreetingRequest::decode(&mut service_invocation.argument).unwrap(); assert_eq!(&greeting_req.person, "Francesco"); // Read the http_response_future @@ -239,9 +240,67 @@ mod tests { ingress_handle.await.unwrap().unwrap(); } + #[test(tokio::test)] + async fn test_ingress_service_http_connect_call() { + let (drain, address, cmd_fut, ingress_handle) = bootstrap_test().await; + let cmd_fut = tokio::spawn(async move { + let (service_invocation, response_tx) = cmd_fut.await.unwrap().unwrap().into_inner(); + assert!(response_tx.is_closed()); + service_invocation + }); + + // Send the request + let json_payload = json!({ + "service": "greeter.Greeter", + "method": "Greet", + "argument": { + "person": "Francesco" + } + }); + let http_response = hyper::Client::new() + .request( + hyper::Request::post(format!("http://{address}/dev.restate.Ingress/Invoke")) + .header(CONTENT_TYPE, "application/json") + .body(Body::from(serde_json::to_vec(&json_payload).unwrap())) + .unwrap(), + ) + .await + .unwrap(); + assert_eq!(http_response.status(), StatusCode::OK); + + // Get the function invocation and assert on it + let mut service_invocation = cmd_fut.await.unwrap(); + assert_eq!( + service_invocation.id.service_id.service_name, + "greeter.Greeter" + ); + assert_eq!(service_invocation.method_name, "Greet"); + let greeting_req = + mocks::pb::GreetingRequest::decode(&mut service_invocation.argument).unwrap(); + assert_eq!(&greeting_req.person, "Francesco"); + assert!(service_invocation.response_sink.is_none()); + + // Read the http_response_future + let (_, response_body) = http_response.into_parts(); + let response_bytes = hyper::body::to_bytes(response_body).await.unwrap(); + let response_json_value: serde_json::Value = + serde_json::from_slice(&response_bytes).unwrap(); + let sid: ServiceInvocationId = response_json_value + .get("sid") + .unwrap() + .as_str() + .unwrap() + .parse() + .unwrap(); + assert_eq!(sid.service_id.service_name, "greeter.Greeter"); + + drain.drain().await; + ingress_handle.await.unwrap().unwrap(); + } + #[test(tokio::test)] async fn test_grpc_call() { - let expected_greeting_response = pb::GreetingResponse { + let expected_greeting_response = mocks::pb::GreetingResponse { greeting: "Igal".to_string(), }; let encoded_greeting_response = Bytes::from(expected_greeting_response.encode_to_vec()); @@ -253,12 +312,13 @@ mod tests { service_invocation }); - let mut client = pb::greeter_client::GreeterClient::connect(format!("http://{address}")) - .await - .unwrap(); + let mut client = + mocks::pb::greeter_client::GreeterClient::connect(format!("http://{address}")) + .await + .unwrap(); let response = client - .greet(pb::GreetingRequest { + .greet(mocks::pb::GreetingRequest { person: "Francesco".to_string(), }) .await @@ -270,7 +330,8 @@ mod tests { "greeter.Greeter" ); assert_eq!(service_invocation.method_name, "Greet"); - let greeting_req = pb::GreetingRequest::decode(&mut service_invocation.argument).unwrap(); + let greeting_req = + mocks::pb::GreetingRequest::decode(&mut service_invocation.argument).unwrap(); assert_eq!(&greeting_req.person, "Francesco"); // Read the http_response_future