diff --git a/examples/grpc-reflection.graphql b/examples/grpc-reflection.graphql index cb79a65933..71a567d5de 100644 --- a/examples/grpc-reflection.graphql +++ b/examples/grpc-reflection.graphql @@ -2,7 +2,7 @@ schema @server(port: 8000) @upstream(baseURL: "http://localhost:50051", httpCache: 42, batch: {delay: 10}) - @link(src: "http://localhost:50051", type: Grpc) { + @link(src: "http://localhost:50051", type: Grpc, headers: [{key: "authorization", value: "Bearer 123"}]) { query: Query } diff --git a/generated/.tailcallrc.graphql b/generated/.tailcallrc.graphql index 2eb5cf2714..76c7f0aa08 100644 --- a/generated/.tailcallrc.graphql +++ b/generated/.tailcallrc.graphql @@ -202,6 +202,10 @@ The @link directive allows you to import external resources, such as configurati will be later used by `@grpc` directive –. """ directive @link( + """ + Custom headers for gRPC reflection server. + """ + headers: [KeyValue] """ The id of the link. It is used to reference the link in the schema. """ diff --git a/generated/.tailcallrc.schema.json b/generated/.tailcallrc.schema.json index 2261b63111..8d72151d21 100644 --- a/generated/.tailcallrc.schema.json +++ b/generated/.tailcallrc.schema.json @@ -808,6 +808,16 @@ "description": "The @link directive allows you to import external resources, such as configuration – which will be merged into the config importing it –, or a .proto file – which will be later used by `@grpc` directive –.", "type": "object", "properties": { + "headers": { + "description": "Custom headers for gRPC reflection server.", + "type": [ + "array", + "null" + ], + "items": { + "$ref": "#/definitions/KeyValue" + } + }, "id": { "description": "The id of the link. It is used to reference the link in the schema.", "type": [ diff --git a/src/core/config/link.rs b/src/core/config/link.rs index a71003b331..2675f44567 100644 --- a/src/core/config/link.rs +++ b/src/core/config/link.rs @@ -2,6 +2,7 @@ use serde::{Deserialize, Serialize}; use tailcall_macros::DirectiveDefinition; use super::super::is_default; +use super::KeyValue; #[derive( Default, @@ -57,6 +58,10 @@ pub struct Link { /// The type of the link. It can be `Config`, or `Protobuf`. #[serde(default, skip_serializing_if = "is_default", rename = "type")] pub type_of: LinkType, + /// + /// Custom headers for gRPC reflection server. + #[serde(default, skip_serializing_if = "is_default")] + pub headers: Option>, /// Additional metadata pertaining to the linked resource. #[serde(default, skip_serializing_if = "is_default")] pub meta: Option, diff --git a/src/core/config/reader.rs b/src/core/config/reader.rs index 46eee6107f..a4ff2cf742 100644 --- a/src/core/config/reader.rs +++ b/src/core/config/reader.rs @@ -123,7 +123,10 @@ impl ConfigReader { }) } LinkType::Grpc => { - let meta = self.proto_reader.fetch(link.src.as_str()).await?; + let meta = self + .proto_reader + .fetch(link.src.as_str(), link.headers.clone()) + .await?; for m in meta { extensions.add_proto(m); diff --git a/src/core/generator/generator.rs b/src/core/generator/generator.rs index a03c752661..0e27e2f688 100644 --- a/src/core/generator/generator.rs +++ b/src/core/generator/generator.rs @@ -90,6 +90,7 @@ impl Generator { id: None, src: metadata.path.to_owned(), type_of: LinkType::Protobuf, + headers: None, meta: None, }); Ok(config) diff --git a/src/core/grpc/data_loader_request.rs b/src/core/grpc/data_loader_request.rs index b8b02c3f35..a02111a5dd 100644 --- a/src/core/grpc/data_loader_request.rs +++ b/src/core/grpc/data_loader_request.rs @@ -73,6 +73,7 @@ mod tests { id: None, src: test_file.to_string(), type_of: LinkType::Protobuf, + headers: None, meta: None, }]); let method = GrpcMethod { diff --git a/src/core/grpc/protobuf.rs b/src/core/grpc/protobuf.rs index df28845dbd..ca0012b46b 100644 --- a/src/core/grpc/protobuf.rs +++ b/src/core/grpc/protobuf.rs @@ -266,6 +266,7 @@ pub mod tests { id: Some(id.clone()), src: path.to_string(), type_of: LinkType::Protobuf, + headers: None, meta: None, }]); diff --git a/src/core/grpc/request_template.rs b/src/core/grpc/request_template.rs index c96bd1a2f5..6b7e95242f 100644 --- a/src/core/grpc/request_template.rs +++ b/src/core/grpc/request_template.rs @@ -160,6 +160,7 @@ mod tests { id: Some(id.clone()), src: test_file.to_string(), type_of: LinkType::Protobuf, + headers: None, meta: None, }]); let method = GrpcMethod { diff --git a/src/core/proto_reader/fetch.rs b/src/core/proto_reader/fetch.rs index 87663885d0..6ddc57b68a 100644 --- a/src/core/proto_reader/fetch.rs +++ b/src/core/proto_reader/fetch.rs @@ -1,3 +1,5 @@ +use std::str::FromStr; + use anyhow::{Context, Result}; use base64::prelude::BASE64_STANDARD; use base64::Engine; @@ -9,7 +11,7 @@ use serde::{Deserialize, Serialize}; use serde_json::json; use crate::core::blueprint::GrpcMethod; -use crate::core::config::ConfigReaderContext; +use crate::core::config::{ConfigReaderContext, KeyValue}; use crate::core::grpc::protobuf::ProtobufSet; use crate::core::grpc::request_template::RequestBody; use crate::core::grpc::RequestTemplate; @@ -72,11 +74,16 @@ struct ReflectionResponse { pub struct GrpcReflection { server_reflection_method: GrpcMethod, url: String, + headers: Option>, target_runtime: TargetRuntime, } impl GrpcReflection { - pub fn new>(url: T, target_runtime: TargetRuntime) -> Self { + pub fn new>( + url: T, + headers: Option>, + target_runtime: TargetRuntime, + ) -> Self { let server_reflection_method = GrpcMethod { package: "grpc.reflection.v1alpha".to_string(), service: "ServerReflection".to_string(), @@ -85,6 +92,7 @@ impl GrpcReflection { Self { server_reflection_method, url: url.as_ref().to_string(), + headers, target_runtime, } } @@ -135,16 +143,28 @@ impl GrpcReflection { ) .as_str(), ); + + let mut headers = vec![]; + if let Some(custom_headers) = &self.headers { + for header in custom_headers { + headers.push(( + HeaderName::from_str(&header.key)?, + Mustache::parse(header.value.as_str()), + )); + } + } + headers.push(( + HeaderName::from_static("content-type"), + Mustache::parse("application/grpc+proto"), + )); + let body_ = Some(RequestBody { + mustache: Some(Mustache::parse(body.to_string().as_str())), + value: Default::default(), + }); let req_template = RequestTemplate { url: Mustache::parse(url.as_str()), - headers: vec![( - HeaderName::from_static("content-type"), - Mustache::parse("application/grpc+proto"), - )], - body: Some(RequestBody { - mustache: Some(Mustache::parse(body.to_string().as_str())), - value: Default::default(), - }), + headers, + body: body_, operation: operation.clone(), operation_type: Default::default(), }; @@ -230,6 +250,7 @@ mod grpc_fetch { let grpc_reflection = GrpcReflection::new( format!("http://localhost:{}", server.port()), + None, crate::core::runtime::test::init(None), ); @@ -258,6 +279,7 @@ mod grpc_fetch { let grpc_reflection = GrpcReflection::new( format!("http://localhost:{}", server.port()), + None, crate::core::runtime::test::init(None), ); @@ -290,7 +312,7 @@ mod grpc_fetch { let runtime = crate::core::runtime::test::init(None); let grpc_reflection = - GrpcReflection::new(format!("http://localhost:{}", server.port()), runtime); + GrpcReflection::new(format!("http://localhost:{}", server.port()), None, runtime); let resp = grpc_reflection.list_all_files().await?; @@ -322,7 +344,7 @@ mod grpc_fetch { let runtime = crate::core::runtime::test::init(None); let grpc_reflection = - GrpcReflection::new(format!("http://localhost:{}", server.port()), runtime); + GrpcReflection::new(format!("http://localhost:{}", server.port()), None, runtime); let resp = grpc_reflection.list_all_files().await; @@ -349,7 +371,7 @@ mod grpc_fetch { let runtime = crate::core::runtime::test::init(None); let grpc_reflection = - GrpcReflection::new(format!("http://localhost:{}", server.port()), runtime); + GrpcReflection::new(format!("http://localhost:{}", server.port()), None, runtime); let result = grpc_reflection.get_by_service("nonexistent.Service").await; assert!(result.is_err()); @@ -358,4 +380,42 @@ mod grpc_fetch { Ok(()) } + + #[tokio::test] + async fn test_custom_headers_resp_list_all() -> Result<()> { + let server = start_mock_server(); + + let http_reflection_service_not_found = server.mock(|when, then| { + when.method(httpmock::Method::POST) + .path("/grpc.reflection.v1alpha.ServerReflection/ServerReflectionInfo") + .header("authorization", "Bearer 123"); + then.status(200).body(get_fake_resp()); + }); + + let runtime = crate::core::runtime::test::init(None); + + let grpc_reflection = GrpcReflection::new( + format!("http://localhost:{}", server.port()), + Some(vec![KeyValue { + key: "authorization".to_string(), + value: "Bearer 123".to_string(), + }]), + runtime, + ); + + let resp = grpc_reflection.list_all_files().await?; + + assert_eq!( + [ + "news.NewsService".to_string(), + "grpc.reflection.v1alpha.ServerReflection".to_string() + ] + .to_vec(), + resp + ); + + http_reflection_service_not_found.assert(); + + Ok(()) + } } diff --git a/src/core/proto_reader/reader.rs b/src/core/proto_reader/reader.rs index 1630be0ade..81c65cc16a 100644 --- a/src/core/proto_reader/reader.rs +++ b/src/core/proto_reader/reader.rs @@ -8,6 +8,7 @@ use futures_util::FutureExt; use prost_reflect::prost_types::{FileDescriptorProto, FileDescriptorSet}; use protox::file::{FileResolver, GoogleFileResolver}; +use crate::core::config::KeyValue; use crate::core::proto_reader::fetch::GrpcReflection; use crate::core::resource_reader::{Cached, ResourceReader}; use crate::core::runtime::TargetRuntime; @@ -31,8 +32,16 @@ impl ProtoReader { } /// Fetches proto files from a grpc server (grpc reflection) - pub async fn fetch>(&self, url: T) -> anyhow::Result> { - let grpc_reflection = Arc::new(GrpcReflection::new(url.as_ref(), self.runtime.clone())); + pub async fn fetch>( + &self, + url: T, + headers: Option>, + ) -> anyhow::Result> { + let grpc_reflection = Arc::new(GrpcReflection::new( + url.as_ref(), + headers, + self.runtime.clone(), + )); let mut proto_metadata = vec![]; let service_list = grpc_reflection.list_all_files().await?; diff --git a/tailcall-upstream-grpc/src/main.rs b/tailcall-upstream-grpc/src/main.rs index c42b75aec3..b20f68d18f 100644 --- a/tailcall-upstream-grpc/src/main.rs +++ b/tailcall-upstream-grpc/src/main.rs @@ -14,8 +14,9 @@ use opentelemetry_otlp::WithExportConfig; use opentelemetry_sdk::propagation::TraceContextPropagator; use opentelemetry_sdk::{runtime, Resource}; use tonic::metadata::MetadataMap; +use tonic::service::interceptor::InterceptedService; use tonic::transport::Server as TonicServer; -use tonic::{Response, Status}; +use tonic::{Request, Response, Status}; use tonic_tracing_opentelemetry::middleware::server; use tower::make::Shared; use tracing_subscriber::layer::SubscriberExt; @@ -215,6 +216,14 @@ fn init_tracer() -> Result<(), Error> { Ok(()) } +/// Intercepts the request and checks if the token is valid. +fn intercept(req: Request<()>) -> Result, Status> { + match req.metadata().get("authorization") { + Some(token) if token == "Bearer 123" => Ok(req), + _ => Err(Status::permission_denied("Unauthorized")), + } +} + #[tokio::main] async fn main() -> Result<(), Error> { if std::env::var("HONEYCOMB_API_KEY").is_ok() { @@ -234,7 +243,7 @@ async fn main() -> Result<(), Error> { let tonic_service = TonicServer::builder() .layer(server::OtelGrpcLayer::default()) .add_service(NewsServiceServer::new(news_service)) - .add_service(service) + .add_service(InterceptedService::new(service, intercept)) .into_service(); let make_svc = Shared::new(tonic_service); println!("Server listening on grpc://{}", addr);