diff --git a/crates/proxy/src/lib.rs b/crates/proxy/src/lib.rs index e6ebdb247..1cb3fa59b 100644 --- a/crates/proxy/src/lib.rs +++ b/crates/proxy/src/lib.rs @@ -24,9 +24,9 @@ use tokio::io::copy_bidirectional; type HyperRequest = hyper::Request; type HyperResponse = hyper::Response; -/// Encode url path. This can be used when build your custom url rest getter. +/// Encode url path. This can be used when build your custom url path getter. #[inline] -pub fn encode_url_path(path: &str) -> String { +pub(crate) fn encode_url_path(path: &str) -> String { path.split('/') .map(|s| utf8_percent_encode(s, CONTROLS).to_string()) .collect::>() @@ -79,21 +79,21 @@ where } } -/// Url rest getter. You can use this to get the rest of the url. -pub type UrlRestGetter = Box String + Send + Sync + 'static>; +/// Url part getter. You can use this to get the proxied url path or query. +pub type UrlPartGetter = Box Option + Send + Sync + 'static>; -/// Default url rest getter. -pub fn default_url_rest_getter(req: &Request, _depot: &Depot) -> String { +/// Default url path getter. This getter will get the url path from request wildcard param, like `<*rest>`, `<**rest>`. +pub fn default_url_path_getter(req: &Request, _depot: &Depot) -> Option { let param = req.params().iter().find(|(key, _)| key.starts_with('*')); - let mut rest = if let Some((_, rest)) = param { - encode_url_path(rest) + if let Some((_, rest)) = param { + Some(encode_url_path(rest)) } else { - "".into() - }; - if let Some(query) = req.uri().query() { - rest = format!("{}?{}", rest, query); + None } - rest +} +/// Default url query getter. This getter just return the query string from request uri. +pub fn default_url_query_getter(req: &Request, _depot: &Depot) -> Option { + req.uri().query().map(Into::into) } /// Proxy @@ -104,8 +104,10 @@ pub struct Proxy { pub upstreams: U, /// [`Client`] for proxy. pub client: Client, - /// Url rest getter. - pub url_rest_getter: UrlRestGetter, + /// Url path getter. + pub url_path_getter: UrlPartGetter, + /// Url query getter. + pub url_query_getter: UrlPartGetter, } impl Proxy @@ -118,7 +120,8 @@ where Proxy { upstreams, client: Client::new(), - url_rest_getter: Box::new(default_url_rest_getter), + url_path_getter: Box::new(default_url_path_getter), + url_query_getter: Box::new(default_url_query_getter), } } /// Create new `Proxy` with upstreams list and [`Client`]. @@ -126,17 +129,28 @@ where Proxy { upstreams, client, - url_rest_getter: Box::new(default_url_rest_getter), + url_path_getter: Box::new(default_url_path_getter), + url_query_getter: Box::new(default_url_query_getter), } } - /// Set url rest getter. + /// Set url path getter. #[inline] - pub fn url_rest_getter(mut self, url_rest_getter: G) -> Self + pub fn url_path_getter(mut self, url_path_getter: G) -> Self where - G: Fn(&Request, &Depot) -> String + Send + Sync + 'static, + G: Fn(&Request, &Depot) -> Option + Send + Sync + 'static, { - self.url_rest_getter = Box::new(url_rest_getter); + self.url_path_getter = Box::new(url_path_getter); + self + } + + /// Set url query getter. + #[inline] + pub fn url_query_getter(mut self, url_query_getter: G) -> Self + where + G: Fn(&Request, &Depot) -> Option + Send + Sync + 'static, + { + self.url_query_getter = Box::new(url_query_getter); self } @@ -170,7 +184,17 @@ where return Err(Error::other("upstreams is empty")); } - let rest = (self.url_rest_getter)(req, depot); + let path = (self.url_path_getter)(req, depot).unwrap_or_default(); + let query = (self.url_query_getter)(req, depot); + let rest = if let Some(query) = query { + if query.starts_with('?') { + format!("{}{}", path, query) + } else { + format!("{}?{}", path, query) + } + } else { + path + }; let forward_url = if upstream.ends_with('/') && rest.starts_with('/') { format!("{}{}", upstream.trim_end_matches('/'), rest) } else if upstream.ends_with('/') || rest.starts_with('/') {