Skip to content

Commit

Permalink
split to path and query getter
Browse files Browse the repository at this point in the history
  • Loading branch information
chrislearn committed Aug 11, 2023
1 parent 7bac3e2 commit 0bbfc15
Showing 1 changed file with 46 additions and 22 deletions.
68 changes: 46 additions & 22 deletions crates/proxy/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ use tokio::io::copy_bidirectional;
type HyperRequest = hyper::Request<ReqBody>;
type HyperResponse = hyper::Response<ResBody>;

/// Encode url path. This can be used when build your custom url rest getter.
/// Encode url path. This can be used when build your custom url path getter.
#[inline]
pub fn encode_url_path(path: &str) -> String {
pub(crate) fn encode_url_path(path: &str) -> String {
path.split('/')
.map(|s| utf8_percent_encode(s, CONTROLS).to_string())
.collect::<Vec<_>>()
Expand Down Expand Up @@ -79,21 +79,21 @@ where
}
}

/// Url rest getter. You can use this to get the rest of the url.
pub type UrlRestGetter = Box<dyn Fn(&Request, &Depot) -> String + Send + Sync + 'static>;
/// Url part getter. You can use this to get the proxied url path or query.
pub type UrlPartGetter = Box<dyn Fn(&Request, &Depot) -> Option<String> + Send + Sync + 'static>;

/// Default url rest getter.
pub fn default_url_rest_getter(req: &Request, _depot: &Depot) -> String {
/// Default url path getter. This getter will get the url path from request wildcard param, like `<*rest>`, `<**rest>`.
pub fn default_url_path_getter(req: &Request, _depot: &Depot) -> Option<String> {
let param = req.params().iter().find(|(key, _)| key.starts_with('*'));
let mut rest = if let Some((_, rest)) = param {
encode_url_path(rest)
if let Some((_, rest)) = param {
Some(encode_url_path(rest))
} else {
"".into()
};
if let Some(query) = req.uri().query() {
rest = format!("{}?{}", rest, query);
None
}
rest
}
/// Default url query getter. This getter just return the query string from request uri.
pub fn default_url_query_getter(req: &Request, _depot: &Depot) -> Option<String> {
req.uri().query().map(Into::into)
}

/// Proxy
Expand All @@ -104,8 +104,10 @@ pub struct Proxy<U> {
pub upstreams: U,
/// [`Client`] for proxy.
pub client: Client,
/// Url rest getter.
pub url_rest_getter: UrlRestGetter,
/// Url path getter.
pub url_path_getter: UrlPartGetter,
/// Url query getter.
pub url_query_getter: UrlPartGetter,
}

impl<U> Proxy<U>
Expand All @@ -118,25 +120,37 @@ where
Proxy {
upstreams,
client: Client::new(),
url_rest_getter: Box::new(default_url_rest_getter),
url_path_getter: Box::new(default_url_path_getter),
url_query_getter: Box::new(default_url_query_getter),
}
}
/// Create new `Proxy` with upstreams list and [`Client`].
pub fn with_client(upstreams: U, client: Client) -> Self {
Proxy {
upstreams,
client,
url_rest_getter: Box::new(default_url_rest_getter),
url_path_getter: Box::new(default_url_path_getter),
url_query_getter: Box::new(default_url_query_getter),
}
}

/// Set url rest getter.
/// Set url path getter.
#[inline]
pub fn url_rest_getter<G>(mut self, url_rest_getter: G) -> Self
pub fn url_path_getter<G>(mut self, url_path_getter: G) -> Self
where
G: Fn(&Request, &Depot) -> String + Send + Sync + 'static,
G: Fn(&Request, &Depot) -> Option<String> + Send + Sync + 'static,
{
self.url_rest_getter = Box::new(url_rest_getter);
self.url_path_getter = Box::new(url_path_getter);
self
}

/// Set url query getter.
#[inline]
pub fn url_query_getter<G>(mut self, url_query_getter: G) -> Self
where
G: Fn(&Request, &Depot) -> Option<String> + Send + Sync + 'static,
{
self.url_query_getter = Box::new(url_query_getter);
self
}

Expand Down Expand Up @@ -170,7 +184,17 @@ where
return Err(Error::other("upstreams is empty"));
}

let rest = (self.url_rest_getter)(req, depot);
let path = (self.url_path_getter)(req, depot).unwrap_or_default();
let query = (self.url_query_getter)(req, depot);
let rest = if let Some(query) = query {
if query.starts_with('?') {
format!("{}{}", path, query)
} else {
format!("{}?{}", path, query)
}
} else {
path
};
let forward_url = if upstream.ends_with('/') && rest.starts_with('/') {
format!("{}{}", upstream.trim_end_matches('/'), rest)
} else if upstream.ends_with('/') || rest.starts_with('/') {
Expand Down

0 comments on commit 0bbfc15

Please sign in to comment.