diff --git a/api/src/config.rs b/api/src/config.rs index c626afa1dbd..257bfb2624e 100644 --- a/api/src/config.rs +++ b/api/src/config.rs @@ -917,12 +917,6 @@ pub struct MirrorConfig { /// HTTP request headers to be passed to mirror server. #[serde(default)] pub headers: HashMap, - /// Whether the authorization process is through mirror, default to false. - /// true: authorization through mirror, e.g. Using normal registry as mirror. - /// false: authorization through original registry, - /// e.g. when using Dragonfly server as mirror, authorization through it may affect performance. - #[serde(default)] - pub auth_through: bool, /// Interval for mirror health checking, in seconds. #[serde(default = "default_check_interval")] pub health_check_interval: u64, @@ -936,7 +930,6 @@ impl Default for MirrorConfig { Self { host: String::new(), headers: HashMap::new(), - auth_through: false, health_check_interval: 5, failure_limit: 5, ping_url: String::new(), @@ -1825,7 +1818,6 @@ mod tests { [[backend.oss.mirrors]] host = "http://127.0.0.1:65001" ping_url = "http://127.0.0.1:65001/ping" - auth_through = true health_check_interval = 10 failure_limit = 10 "#; @@ -1859,7 +1851,6 @@ mod tests { let mirror = &oss.mirrors[0]; assert_eq!(mirror.host, "http://127.0.0.1:65001"); assert_eq!(mirror.ping_url, "http://127.0.0.1:65001/ping"); - assert!(mirror.auth_through); assert!(mirror.headers.is_empty()); assert_eq!(mirror.health_check_interval, 10); assert_eq!(mirror.failure_limit, 10); @@ -1891,7 +1882,6 @@ mod tests { [[backend.registry.mirrors]] host = "http://127.0.0.1:65001" ping_url = "http://127.0.0.1:65001/ping" - auth_through = true health_check_interval = 10 failure_limit = 10 "#; @@ -1927,7 +1917,6 @@ mod tests { let mirror = ®istry.mirrors[0]; assert_eq!(mirror.host, "http://127.0.0.1:65001"); assert_eq!(mirror.ping_url, "http://127.0.0.1:65001/ping"); - assert!(mirror.auth_through); assert!(mirror.headers.is_empty()); assert_eq!(mirror.health_check_interval, 10); assert_eq!(mirror.failure_limit, 10); diff --git a/docs/nydusd.md b/docs/nydusd.md index 4a92bc7ab59..069eed77471 100644 --- a/docs/nydusd.md +++ b/docs/nydusd.md @@ -297,9 +297,6 @@ Currently, the mirror mode is only tested in the registry backend, and in theory { // Mirror server URL (including scheme), e.g. Dragonfly dfdaemon server URL "host": "http://dragonfly1.io:65001", - // true: Send the authorization request to the mirror e.g. another docker registry. - // false: Authorization request won't be relayed by the mirror e.g. Dragonfly. - "auth_through": false, // Headers for mirror server "headers": { // For Dragonfly dfdaemon server URL, we need to specify "X-Dragonfly-Registry" (including scheme). diff --git a/misc/configs/nydusd-blob-cache-entry-configuration-v2.toml b/misc/configs/nydusd-blob-cache-entry-configuration-v2.toml index 25c1952e6df..ab434f60733 100644 --- a/misc/configs/nydusd-blob-cache-entry-configuration-v2.toml +++ b/misc/configs/nydusd-blob-cache-entry-configuration-v2.toml @@ -57,8 +57,6 @@ host = "http://127.0.0.1:65001" ping_url = "http://127.0.0.1:65001/ping" # HTTP request headers to be passed to mirror server. # headers = -# Whether the authorization process is through mirror, default to false. -auth_through = true # Interval for mirror health checking, in seconds. health_check_interval = 5 # Maximum number of failures before marking a mirror as unusable. @@ -108,8 +106,6 @@ host = "http://127.0.0.1:65001" ping_url = "http://127.0.0.1:65001/ping" # HTTP request headers to be passed to mirror server. # headers = -# Whether the authorization process is through mirror, default to false. -auth_through = true # Interval for mirror health checking, in seconds. health_check_interval = 5 # Maximum number of failures before marking a mirror as unusable. diff --git a/misc/configs/nydusd-blob-cache-entry.toml b/misc/configs/nydusd-blob-cache-entry.toml index fea1c445fa8..dcab4707dfc 100644 --- a/misc/configs/nydusd-blob-cache-entry.toml +++ b/misc/configs/nydusd-blob-cache-entry.toml @@ -62,8 +62,6 @@ host = "http://127.0.0.1:65001" ping_url = "http://127.0.0.1:65001/ping" # HTTP request headers to be passed to mirror server. # headers = -# Whether the authorization process is through mirror, default to false. -auth_through = true # Interval for mirror health checking, in seconds. health_check_interval = 5 # Maximum number of failures before marking a mirror as unusable. @@ -113,8 +111,6 @@ host = "http://127.0.0.1:65001" ping_url = "http://127.0.0.1:65001/ping" # HTTP request headers to be passed to mirror server. # headers = -# Whether the authorization process is through mirror, default to false. -auth_through = true # Interval for mirror health checking, in seconds. health_check_interval = 5 # Maximum number of failures before marking a mirror as unusable. diff --git a/misc/configs/nydusd-config-v2.toml b/misc/configs/nydusd-config-v2.toml index ca281901ebd..ed33ec77c4d 100644 --- a/misc/configs/nydusd-config-v2.toml +++ b/misc/configs/nydusd-config-v2.toml @@ -55,8 +55,6 @@ host = "http://127.0.0.1:65001" ping_url = "http://127.0.0.1:65001/ping" # HTTP request headers to be passed to mirror server. # headers = -# Whether the authorization process is through mirror, default to false. -auth_through = true # Interval for mirror health checking, in seconds. health_check_interval = 5 # Maximum number of failures before marking a mirror as unusable. @@ -106,8 +104,6 @@ host = "http://127.0.0.1:65001" ping_url = "http://127.0.0.1:65001/ping" # HTTP request headers to be passed to mirror server. # headers = -# Whether the authorization process is through mirror, default to false. -auth_through = true # Interval for mirror health checking, in seconds. health_check_interval = 5 # Maximum number of failures before marking a mirror as unusable. diff --git a/storage/src/backend/connection.rs b/storage/src/backend/connection.rs index e324e4f4aa2..6b6b2e69e43 100644 --- a/storage/src/backend/connection.rs +++ b/storage/src/backend/connection.rs @@ -403,14 +403,17 @@ impl Connection { } else { mirror_cloned.config.ping_url.clone() }; - info!("Mirror health checking url: {}", mirror_health_url); + info!( + "[mirror] start health check, ping url: {}", + mirror_health_url + ); let client = Client::new(); loop { // Try to recover the mirror server when it is unavailable. if !mirror_cloned.status.load(Ordering::Relaxed) { info!( - "Mirror server {} unhealthy, try to recover", + "[mirror] server unhealthy, try to recover: {}", mirror_cloned.config.host ); @@ -422,14 +425,17 @@ impl Connection { // If the response status is less than StatusCode::INTERNAL_SERVER_ERROR, // the mirror server is recovered. if resp.status() < StatusCode::INTERNAL_SERVER_ERROR { - info!("Mirror server {} recovered", mirror_cloned.config.host); + info!( + "[mirror] server recovered: {}", + mirror_cloned.config.host + ); mirror_cloned.failed_times.store(0, Ordering::Relaxed); mirror_cloned.status.store(true, Ordering::Relaxed); } }) .map_err(|e| { warn!( - "Mirror server {} is not recovered: {}", + "[mirror] failed to recover server: {}, {}", mirror_cloned.config.host, e ); }); @@ -448,13 +454,6 @@ impl Connection { self.shutdown.store(true, Ordering::Release); } - /// If the auth_through is enable, all requests are send to the mirror server. - /// If the auth_through disabled, e.g. P2P/Dragonfly, we try to avoid sending - /// non-authorization request to the mirror server, which causes performance loss. - /// requesting_auth means this request is to get authorization from a server, - /// which must be a non-authorization request. - /// IOW, only the requesting_auth is false and the headers contain authorization token, - /// we send this request to mirror. #[allow(clippy::too_many_arguments)] pub fn call( &self, @@ -464,8 +463,6 @@ impl Connection { data: Option>, headers: &mut HeaderMap, catch_status: bool, - // This means the request is dedicated to authorization. - requesting_auth: bool, ) -> ConnectionResult { if self.shutdown.load(Ordering::Acquire) { return Err(ConnectionError::Disconnected); @@ -524,27 +521,10 @@ impl Connection { } } + let mut mirror_enabled = false; if !self.mirrors.is_empty() { - let mut fallback_due_auth = false; + mirror_enabled = true; for mirror in self.mirrors.iter() { - // With configuration `auth_through` disabled, we should not intend to send authentication - // request to mirror. Mainly because mirrors like P2P/Dragonfly has a poor performance when - // relaying non-data requests. But it's still possible that ever returned token is expired. - // So mirror might still respond us with status code UNAUTHORIZED, which should be handle - // by sending authentication request to the original registry. - // - // - For non-authentication request with token in request header, handle is as usual requests to registry. - // This request should already take token in header. - // - For authentication request - // 1. auth_through is disabled(false): directly pass below mirror translations and jump to original registry handler. - // 2. auth_through is enabled(true): try to get authenticated from mirror and should also handle status code UNAUTHORIZED. - if !mirror.config.auth_through - && (!headers.contains_key(HEADER_AUTHORIZATION) || requesting_auth) - { - fallback_due_auth = true; - break; - } - if mirror.status.load(Ordering::Relaxed) { let data_cloned = data.as_ref().cloned(); @@ -556,7 +536,7 @@ impl Connection { } let current_url = mirror.mirror_url(url)?; - debug!("mirror server url {}", current_url); + debug!("[mirror] replace to: {}", current_url); let result = self.call_inner( &self.client, @@ -578,14 +558,14 @@ impl Connection { } Err(err) => { warn!( - "request mirror server failed, mirror: {:?}, error: {:?}", - mirror, err + "[mirror] request failed, server: {:?}, {:?}", + mirror.config.host, err ); mirror.failed_times.fetch_add(1, Ordering::Relaxed); if mirror.failed_times.load(Ordering::Relaxed) >= mirror.failure_limit { warn!( - "reach to failure limit {}, disable mirror: {:?}", + "[mirror] exceed failure limit {}, server disabled: {:?}", mirror.failure_limit, mirror ); mirror.status.store(false, Ordering::Relaxed); @@ -598,9 +578,10 @@ impl Connection { headers.remove(HeaderName::from_str(key).unwrap()); } } - if !fallback_due_auth { - warn!("Request to all mirror server failed, fallback to original server."); - } + } + + if mirror_enabled { + warn!("[mirror] request all servers failed, fallback to original server."); } self.call_inner( diff --git a/storage/src/backend/http_proxy.rs b/storage/src/backend/http_proxy.rs index 8fd31df77a5..c1324fbef78 100644 --- a/storage/src/backend/http_proxy.rs +++ b/storage/src/backend/http_proxy.rs @@ -214,7 +214,6 @@ impl BlobReader for HttpProxyReader { None, &mut HeaderMap::new(), true, - false, ) .map(|resp| resp.headers().to_owned()) .map_err(|e| HttpProxyError::RemoteRequest(e).into()) @@ -255,15 +254,7 @@ impl BlobReader for HttpProxyReader { .map_err(|e| HttpProxyError::ConstructHeader(format!("{}", e)))?, ); let mut resp = connection - .call::<&[u8]>( - Method::GET, - uri.as_str(), - None, - None, - &mut headers, - true, - false, - ) + .call::<&[u8]>(Method::GET, uri.as_str(), None, None, &mut headers, true) .map_err(HttpProxyError::RemoteRequest)?; Ok(resp diff --git a/storage/src/backend/object_storage.rs b/storage/src/backend/object_storage.rs index c7a617b2aaa..7c2b8ba655c 100644 --- a/storage/src/backend/object_storage.rs +++ b/storage/src/backend/object_storage.rs @@ -89,15 +89,7 @@ where let resp = self .connection - .call::<&[u8]>( - Method::HEAD, - url.as_str(), - None, - None, - &mut headers, - true, - false, - ) + .call::<&[u8]>(Method::HEAD, url.as_str(), None, None, &mut headers, true) .map_err(ObjectStorageError::Request)?; let content_length = resp .headers() @@ -136,15 +128,7 @@ where // Safe because the the call() is a synchronous operation. let mut resp = self .connection - .call::<&[u8]>( - Method::GET, - url.as_str(), - None, - None, - &mut headers, - true, - false, - ) + .call::<&[u8]>(Method::GET, url.as_str(), None, None, &mut headers, true) .map_err(ObjectStorageError::Request)?; Ok(resp .copy_to(&mut buf) diff --git a/storage/src/backend/registry.rs b/storage/src/backend/registry.rs index 79df3bfaf85..737a03454ff 100644 --- a/storage/src/backend/registry.rs +++ b/storage/src/backend/registry.rs @@ -7,11 +7,11 @@ use std::collections::HashMap; use std::error::Error; use std::io::{Read, Result}; use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::{Arc, RwLock}; +use std::sync::{Arc, Once, RwLock}; use std::time::{Duration, SystemTime, UNIX_EPOCH}; use std::{fmt, thread}; -use arc_swap::ArcSwapOption; +use arc_swap::{ArcSwap, ArcSwapOption}; use base64::Engine; use reqwest::blocking::Response; pub use reqwest::header::HeaderMap; @@ -36,6 +36,8 @@ const REDIRECTED_STATUS_CODE: [StatusCode; 2] = [ StatusCode::TEMPORARY_REDIRECT, ]; +const REGISTRY_DEFAULT_TOKEN_EXPIRATION: u64 = 10 * 60; // in seconds + /// Error codes related to registry storage backend operations. #[derive(Debug)] pub enum RegistryError { @@ -116,13 +118,15 @@ impl HashCache { #[derive(Clone, serde::Deserialize)] struct TokenResponse { + /// Registry token string. token: String, + /// Registry token period of validity, in seconds. #[serde(default = "default_expires_in")] expires_in: u64, } fn default_expires_in() -> u64 { - 10 * 60 + REGISTRY_DEFAULT_TOKEN_EXPIRATION } #[derive(Debug)] @@ -189,8 +193,8 @@ struct RegistryState { // Example: RwLock", "">> cached_redirect: HashCache, - // The expiration time of the token, which is obtained from the registry server. - refresh_token_time: ArcSwapOption, + // The epoch timestamp of token expiration, which is obtained from the registry server. + token_expired_at: ArcSwapOption, // Cache bearer auth for refreshing token. cached_bearer_auth: ArcSwapOption, } @@ -235,7 +239,7 @@ impl RegistryState { } /// Request registry authentication server to get bearer token - fn get_token(&self, auth: BearerAuth, connection: &Arc) -> Result { + fn get_token(&self, auth: BearerAuth, connection: &Arc) -> Result { // The information needed for getting token needs to be placed both in // the query and in the body to be compatible with different registry // implementations, which have been tested on these platforms: @@ -267,7 +271,6 @@ impl RegistryState { Some(ReqBody::Form(form)), &mut headers, true, - true, ) .map_err(|e| einval!(format!("registry auth server request failed {:?}", e)))?; let ret: TokenResponse = token_resp.json().map_err(|e| { @@ -277,7 +280,7 @@ impl RegistryState { )) })?; if let Ok(now_timestamp) = SystemTime::now().duration_since(UNIX_EPOCH) { - self.refresh_token_time + self.token_expired_at .store(Some(Arc::new(now_timestamp.as_secs() + ret.expires_in))); debug!( "cached bearer auth, next time: {}", @@ -288,7 +291,7 @@ impl RegistryState { // Cache bearer auth for refreshing token. self.cached_bearer_auth.store(Some(Arc::new(auth))); - Ok(ret.token) + Ok(ret) } fn get_auth_header(&self, auth: Auth, connection: &Arc) -> Result { @@ -300,7 +303,7 @@ impl RegistryState { .ok_or_else(|| einval!("invalid auth config")), Auth::Bearer(auth) => { let token = self.get_token(auth, connection)?; - Ok(format!("Bearer {}", token)) + Ok(format!("Bearer {}", token.token)) } } } @@ -364,11 +367,75 @@ impl RegistryState { } } +#[derive(Clone)] +struct First { + inner: Arc>, +} + +impl First { + fn new() -> Self { + First { + inner: Arc::new(ArcSwap::new(Arc::new(Once::new()))), + } + } + + fn once(&self, f: F) + where + F: FnOnce(), + { + self.inner.load().call_once(f) + } + + fn renew(&self) { + self.inner.store(Arc::new(Once::new())); + } + + fn handle(&self, handle: &mut F) -> Option> + where + F: FnMut() -> BackendResult, + { + let mut ret = None; + // Call once twice to ensure the subsequent requests use the new + // Once instance after renew happens. + for _ in 0..=1 { + self.once(|| { + ret = Some(handle().map_err(|err| { + // Replace the Once instance so that we can retry it when + // the handle call failed. + self.renew(); + err + })); + }); + if ret.is_some() { + break; + } + } + ret + } + + /// When invoking concurrently, only one of the handle methods will be executed first, + /// then subsequent handle methods will be allowed to execute concurrently. + /// + /// Nydusd uses a registry backend which generates a surge of blob requests without + /// auth tokens on initial startup, this caused mirror backends (e.g. dragonfly) + /// to process very slowly. The method implements waiting for the first blob request + /// to complete before making other blob requests, this ensures the first request + /// caches a valid registry auth token, and subsequent concurrent blob requests can + /// reuse the cached token. + fn handle_force(&self, handle: &mut F) -> BackendResult + where + F: FnMut() -> BackendResult, + { + self.handle(handle).unwrap_or_else(handle) + } +} + struct RegistryReader { blob_id: String, connection: Arc, state: Arc, metrics: Arc, + first: First, } impl RegistryReader { @@ -422,22 +489,14 @@ impl RegistryReader { if let Some(data) = data { return self .connection - .call( - method, - url, - None, - Some(data), - &mut headers, - catch_status, - false, - ) + .call(method, url, None, Some(data), &mut headers, catch_status) .map_err(RegistryError::Request); } // Try to request registry server with `authorization` header let mut resp = self .connection - .call::<&[u8]>(method.clone(), url, None, None, &mut headers, false, false) + .call::<&[u8]>(method.clone(), url, None, None, &mut headers, false) .map_err(RegistryError::Request)?; if resp.status() == StatusCode::UNAUTHORIZED { if headers.contains_key(HEADER_AUTHORIZATION) { @@ -452,7 +511,7 @@ impl RegistryReader { resp = self .connection - .call::<&[u8]>(method.clone(), url, None, None, &mut headers, false, false) + .call::<&[u8]>(method.clone(), url, None, None, &mut headers, false) .map_err(RegistryError::Request)?; }; @@ -472,7 +531,7 @@ impl RegistryReader { // Try to request registry server with `authorization` header again let resp = self .connection - .call(method, url, None, data, &mut headers, catch_status, false) + .call(method, url, None, data, &mut headers, catch_status) .map_err(RegistryError::Request)?; let status = resp.status(); @@ -528,7 +587,6 @@ impl RegistryReader { None, &mut headers, false, - false, ) .map_err(RegistryError::Request)?; @@ -613,7 +671,6 @@ impl RegistryReader { None, &mut headers, true, - false, ) .map_err(RegistryError::Request); match resp_ret { @@ -641,14 +698,20 @@ impl RegistryReader { impl BlobReader for RegistryReader { fn blob_size(&self) -> BackendResult { - let url = format!("/blobs/sha256:{}", self.blob_id); - let url = self - .state - .url(&url, &[]) - .map_err(|e| RegistryError::Url(url, e))?; - - let resp = - match self.request::<&[u8]>(Method::HEAD, url.as_str(), None, HeaderMap::new(), true) { + self.first.handle_force(&mut || -> BackendResult { + let url = format!("/blobs/sha256:{}", self.blob_id); + let url = self + .state + .url(&url, &[]) + .map_err(|e| RegistryError::Url(url, e))?; + + let resp = match self.request::<&[u8]>( + Method::HEAD, + url.as_str(), + None, + HeaderMap::new(), + true, + ) { Ok(res) => res, Err(RegistryError::Request(ConnectionError::Common(e))) if self.state.needs_fallback_http(&e) => @@ -665,21 +728,26 @@ impl BlobReader for RegistryReader { return Err(BackendError::Registry(e)); } }; - let content_length = resp - .headers() - .get(CONTENT_LENGTH) - .ok_or_else(|| RegistryError::Common("invalid content length".to_string()))?; - - Ok(content_length - .to_str() - .map_err(|err| RegistryError::Common(format!("invalid content length: {:?}", err)))? - .parse::() - .map_err(|err| RegistryError::Common(format!("invalid content length: {:?}", err)))?) + let content_length = resp + .headers() + .get(CONTENT_LENGTH) + .ok_or_else(|| RegistryError::Common("invalid content length".to_string()))?; + + Ok(content_length + .to_str() + .map_err(|err| RegistryError::Common(format!("invalid content length: {:?}", err)))? + .parse::() + .map_err(|err| { + RegistryError::Common(format!("invalid content length: {:?}", err)) + })?) + }) } fn try_read(&self, buf: &mut [u8], offset: u64) -> BackendResult { - self._try_read(buf, offset, true) - .map_err(BackendError::Registry) + self.first.handle_force(&mut || -> BackendResult { + self._try_read(buf, offset, true) + .map_err(BackendError::Registry) + }) } fn metrics(&self) -> &BackendMetrics { @@ -696,6 +764,7 @@ pub struct Registry { connection: Arc, state: Arc, metrics: Arc, + first: First, } impl Registry { @@ -741,25 +810,19 @@ impl Registry { blob_url_scheme: config.blob_url_scheme.clone(), blob_redirected_host: config.blob_redirected_host.clone(), cached_redirect: HashCache::new(), - refresh_token_time: ArcSwapOption::new(None), + token_expired_at: ArcSwapOption::new(None), cached_bearer_auth: ArcSwapOption::new(None), }); - let mirrors = connection.mirrors.clone(); - let registry = Registry { connection, state, metrics: BackendMetrics::new(id, "registry"), + first: First::new(), }; - for mirror in mirrors.iter() { - if !mirror.config.auth_through { - registry.start_refresh_token_thread(); - info!("Refresh token thread started."); - break; - } - } + registry.start_refresh_token_thread(); + info!("Refresh token thread started."); Ok(registry) } @@ -794,30 +857,39 @@ impl Registry { fn start_refresh_token_thread(&self) { let conn = self.connection.clone(); let state = self.state.clone(); - // The default refresh token internal is 10 minutes. - let refresh_check_internal = 10 * 60; + // FIXME: we'd better allow users to specify the expiration time. + let mut refresh_interval = REGISTRY_DEFAULT_TOKEN_EXPIRATION; thread::spawn(move || { loop { if let Ok(now_timestamp) = SystemTime::now().duration_since(UNIX_EPOCH) { - if let Some(next_refresh_timestamp) = state.refresh_token_time.load().as_deref() - { - // If the token will expire in next refresh check internal, get new token now. - // Add 20 seconds to handle critical cases. - if now_timestamp.as_secs() + refresh_check_internal + 20 - >= *next_refresh_timestamp - { + if let Some(token_expired_at) = state.token_expired_at.load().as_deref() { + // If the token will expire within the next refresh interval, + // refresh it immediately. + if now_timestamp.as_secs() + refresh_interval >= *token_expired_at { if let Some(cached_bearer_auth) = state.cached_bearer_auth.load().as_deref() { if let Ok(token) = state.get_token(cached_bearer_auth.to_owned(), &conn) { - let new_cached_auth = format!("Bearer {}", token); - info!("Authorization token for registry has been refreshed."); - // Refresh authorization token + let new_cached_auth = format!("Bearer {}", token.token); + debug!( + "[refresh_token_thread] registry token has been refreshed" + ); + // Refresh cached token. state .cached_auth .set(&state.cached_auth.get(), new_cached_auth); + // Reset refresh interval according to real expiration time, + // and advance 20s to handle the unexpected cases. + refresh_interval = token + .expires_in + .checked_sub(20) + .unwrap_or(token.expires_in); + } else { + error!( + "[refresh_token_thread] failed to refresh registry token" + ); } } } @@ -827,7 +899,7 @@ impl Registry { if conn.shutdown.load(Ordering::Acquire) { break; } - thread::sleep(Duration::from_secs(refresh_check_internal)); + thread::sleep(Duration::from_secs(refresh_interval)); if conn.shutdown.load(Ordering::Acquire) { break; } @@ -851,6 +923,7 @@ impl BlobBackend for Registry { state: self.state.clone(), connection: self.connection.clone(), metrics: self.metrics.clone(), + first: self.first.clone(), })) } } @@ -919,7 +992,7 @@ mod tests { blob_redirected_host: "oss.alibaba-inc.com".to_string(), cached_auth: Default::default(), cached_redirect: Default::default(), - refresh_token_time: ArcSwapOption::new(None), + token_expired_at: ArcSwapOption::new(None), cached_bearer_auth: ArcSwapOption::new(None), }; @@ -971,4 +1044,60 @@ mod tests { assert_eq!(trim(Some(" te st ".to_owned())), Some("te st".to_owned())); assert_eq!(trim(Some("te st".to_owned())), Some("te st".to_owned())); } + + #[test] + #[allow(clippy::redundant_clone)] + fn test_first_basically() { + let first = First::new(); + let mut val = 0; + first.once(|| { + val += 1; + }); + assert_eq!(val, 1); + + first.clone().once(|| { + val += 1; + }); + assert_eq!(val, 1); + + first.renew(); + first.clone().once(|| { + val += 1; + }); + assert_eq!(val, 2); + } + + #[test] + #[allow(clippy::redundant_clone)] + fn test_first_concurrently() { + let val = Arc::new(ArcSwap::new(Arc::new(0))); + let first = First::new(); + + let mut handlers = Vec::new(); + for _ in 0..100 { + let val_cloned = val.clone(); + let first_cloned = first.clone(); + handlers.push(std::thread::spawn(move || { + let _ = first_cloned.handle(&mut || -> BackendResult<()> { + let val = val_cloned.load(); + let ret = if *val.as_ref() == 0 { + std::thread::sleep(std::time::Duration::from_secs(2)); + Err(BackendError::Registry(RegistryError::Common(String::from( + "network error", + )))) + } else { + Ok(()) + }; + val_cloned.store(Arc::new(val.as_ref() + 1)); + ret + }); + })); + } + + for handler in handlers { + handler.join().unwrap(); + } + + assert_eq!(*val.load().as_ref(), 2); + } }