diff --git a/tests/integration_tests/tests/origin.rs b/tests/integration_tests/tests/origin.rs
new file mode 100644
index 000000000..17bbc9cdd
--- /dev/null
+++ b/tests/integration_tests/tests/origin.rs
@@ -0,0 +1,100 @@
+use futures::future::BoxFuture;
+use futures_util::FutureExt;
+use integration_tests::pb::test_client;
+use integration_tests::pb::{test_server, Input, Output};
+use std::task::Context;
+use std::task::Poll;
+use std::time::Duration;
+use tokio::sync::oneshot;
+use tonic::codegen::http::Request;
+use tonic::{
+ transport::{Endpoint, Server},
+ Response, Status,
+};
+use tower::Layer;
+use tower::Service;
+
+#[tokio::test]
+async fn writes_origin_header() {
+ struct Svc;
+
+ #[tonic::async_trait]
+ impl test_server::Test for Svc {
+ async fn unary_call(
+ &self,
+ _req: tonic::Request,
+ ) -> Result, Status> {
+ Ok(Response::new(Output {}))
+ }
+ }
+
+ let svc = test_server::TestServer::new(Svc);
+
+ let (tx, rx) = oneshot::channel::<()>();
+
+ let jh = tokio::spawn(async move {
+ Server::builder()
+ .layer(OriginLayer {})
+ .add_service(svc)
+ .serve_with_shutdown("127.0.0.1:1442".parse().unwrap(), rx.map(drop))
+ .await
+ .unwrap();
+ });
+
+ tokio::time::sleep(Duration::from_millis(100)).await;
+
+ let channel = Endpoint::from_static("http://127.0.0.1:1442")
+ .origin("https://docs.rs".parse().expect("valid uri"))
+ .connect()
+ .await
+ .unwrap();
+
+ let mut client = test_client::TestClient::new(channel);
+
+ match client.unary_call(Input {}).await {
+ Ok(_) => {}
+ Err(status) => panic!("{}", status.message()),
+ }
+
+ tx.send(()).unwrap();
+
+ jh.await.unwrap();
+}
+
+#[derive(Clone)]
+struct OriginLayer {}
+
+impl Layer for OriginLayer {
+ type Service = OriginService;
+
+ fn layer(&self, inner: S) -> Self::Service {
+ OriginService { inner }
+ }
+}
+
+#[derive(Clone)]
+struct OriginService {
+ inner: S,
+}
+
+impl Service> for OriginService
+where
+ T: Service>,
+ T::Future: Send + 'static,
+ T::Error: Into>,
+{
+ type Response = T::Response;
+ type Error = Box;
+ type Future = BoxFuture<'static, Result>;
+
+ fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> {
+ self.inner.poll_ready(cx).map_err(Into::into)
+ }
+
+ fn call(&mut self, req: Request) -> Self::Future {
+ assert_eq!(req.uri().host(), Some("docs.rs"));
+ let fut = self.inner.call(req);
+
+ Box::pin(async move { fut.await.map_err(Into::into) })
+ }
+}
diff --git a/tonic/src/transport/channel/endpoint.rs b/tonic/src/transport/channel/endpoint.rs
index 5d08e22f3..32ce858b6 100644
--- a/tonic/src/transport/channel/endpoint.rs
+++ b/tonic/src/transport/channel/endpoint.rs
@@ -24,6 +24,7 @@ use tower::make::MakeConnection;
#[derive(Clone)]
pub struct Endpoint {
pub(crate) uri: Uri,
+ pub(crate) origin: Option,
pub(crate) user_agent: Option,
pub(crate) timeout: Option,
pub(crate) concurrency_limit: Option,
@@ -106,6 +107,25 @@ impl Endpoint {
.map_err(|_| Error::new_invalid_user_agent())
}
+ /// Set a custom origin.
+ ///
+ /// Override the `origin`, mainly useful when you are reaching a Server/LoadBalancer
+ /// which serves multiple services at the same time.
+ /// It will play the role of SNI (Server Name Indication).
+ ///
+ /// ```
+ /// # use tonic::transport::Endpoint;
+ /// # let mut builder = Endpoint::from_static("https://proxy.com");
+ /// builder.origin("https://example.com".parse().expect("http://example.com must be a valid URI"));
+ /// // origin: "https://example.com"
+ /// ```
+ pub fn origin(self, origin: Uri) -> Self {
+ Endpoint {
+ origin: Some(origin),
+ ..self
+ }
+ }
+
/// Apply a timeout to each request.
///
/// ```
@@ -395,6 +415,7 @@ impl From for Endpoint {
fn from(uri: Uri) -> Self {
Self {
uri,
+ origin: None,
user_agent: None,
concurrency_limit: None,
rate_limit: None,
diff --git a/tonic/src/transport/service/add_origin.rs b/tonic/src/transport/service/add_origin.rs
index 50d5b6d96..b706ee995 100644
--- a/tonic/src/transport/service/add_origin.rs
+++ b/tonic/src/transport/service/add_origin.rs
@@ -1,4 +1,6 @@
use futures_core::future::BoxFuture;
+use http::uri::Authority;
+use http::uri::Scheme;
use http::{Request, Uri};
use std::task::{Context, Poll};
use tower_service::Service;
@@ -6,12 +8,21 @@ use tower_service::Service;
#[derive(Debug)]
pub(crate) struct AddOrigin {
inner: T,
- origin: Uri,
+ scheme: Option,
+ authority: Option,
}
impl AddOrigin {
pub(crate) fn new(inner: T, origin: Uri) -> Self {
- Self { inner, origin }
+ let http::uri::Parts {
+ scheme, authority, ..
+ } = origin.into_parts();
+
+ Self {
+ inner,
+ scheme,
+ authority,
+ }
}
}
@@ -30,24 +41,24 @@ where
}
fn call(&mut self, req: Request) -> Self::Future {
- // Split the request into the head and the body.
- let (mut head, body) = req.into_parts();
-
- // Split the request URI into parts.
- let mut uri: http::uri::Parts = head.uri.into();
- let set_uri = self.origin.clone().into_parts();
-
- if set_uri.scheme.is_none() || set_uri.authority.is_none() {
+ if self.scheme.is_none() || self.authority.is_none() {
let err = crate::transport::Error::new_invalid_uri();
return Box::pin(async move { Err::(err.into()) });
}
- // Update the URI parts, setting hte scheme and authority
- uri.scheme = Some(set_uri.scheme.expect("expected scheme"));
- uri.authority = Some(set_uri.authority.expect("expected authority"));
+ // Split the request into the head and the body.
+ let (mut head, body) = req.into_parts();
// Update the the request URI
- head.uri = http::Uri::from_parts(uri).expect("valid uri");
+ head.uri = {
+ // Split the request URI into parts.
+ let mut uri: http::uri::Parts = head.uri.into();
+ // Update the URI parts, setting hte scheme and authority
+ uri.scheme = self.scheme.clone();
+ uri.authority = self.authority.clone();
+
+ http::Uri::from_parts(uri).expect("valid uri")
+ };
let request = Request::from_parts(head, body);
diff --git a/tonic/src/transport/service/connection.rs b/tonic/src/transport/service/connection.rs
index 3aee2681c..a12f2d6a9 100644
--- a/tonic/src/transport/service/connection.rs
+++ b/tonic/src/transport/service/connection.rs
@@ -55,7 +55,11 @@ impl Connection {
}
let stack = ServiceBuilder::new()
- .layer_fn(|s| AddOrigin::new(s, endpoint.uri.clone()))
+ .layer_fn(|s| {
+ let origin = endpoint.origin.as_ref().unwrap_or(&endpoint.uri).clone();
+
+ AddOrigin::new(s, origin)
+ })
.layer_fn(|s| UserAgent::new(s, endpoint.user_agent.clone()))
.layer_fn(|s| GrpcTimeout::new(s, endpoint.timeout))
.option_layer(endpoint.concurrency_limit.map(ConcurrencyLimitLayer::new))