diff --git a/glide-core/src/retry_strategies.rs b/glide-core/src/retry_strategies.rs index dbe5683347..d851cb63dd 100644 --- a/glide-core/src/retry_strategies.rs +++ b/glide-core/src/retry_strategies.rs @@ -56,6 +56,7 @@ pub(crate) fn get_exponential_backoff( } #[cfg(feature = "socket-layer")] +#[allow(dead_code)] pub(crate) fn get_fixed_interval_backoff( fixed_interval: u32, number_of_retries: u32, diff --git a/glide-core/src/socket_listener.rs b/glide-core/src/socket_listener.rs index 50445c881d..6177c1ee7c 100644 --- a/glide-core/src/socket_listener.rs +++ b/glide-core/src/socket_listener.rs @@ -11,11 +11,10 @@ use crate::connection_request::ConnectionRequest; use crate::errors::{error_message, error_type, RequestErrorType}; use crate::response; use crate::response::Response; -use crate::retry_strategies::get_fixed_interval_backoff; use bytes::Bytes; use directories::BaseDirs; -use dispose::{Disposable, Dispose}; use logger_core::{log_debug, log_error, log_info, log_trace, log_warn}; +use once_cell::sync::Lazy; use protobuf::{Chars, Message}; use redis::cluster_routing::{ MultipleNodeRoutingInfo, Route, RoutingInfo, SingleNodeRoutingInfo, SlotAddr, @@ -23,18 +22,18 @@ use redis::cluster_routing::{ use redis::cluster_routing::{ResponsePolicy, Routable}; use redis::{Cmd, PushInfo, RedisError, ScanStateRC, Value}; use std::cell::Cell; +use std::collections::HashSet; use std::rc::Rc; +use std::sync::{Arc, RwLock}; use std::{env, str}; use std::{io, thread}; use thiserror::Error; -use tokio::io::ErrorKind::AddrInUse; use tokio::net::{UnixListener, UnixStream}; use tokio::runtime::Builder; use tokio::sync::mpsc; use tokio::sync::mpsc::{channel, Sender}; use tokio::sync::Mutex; use tokio::task; -use tokio_retry::Retry; use tokio_util::task::LocalPoolHandle; use ClosingReason::*; use PipeListeningResult::*; @@ -53,19 +52,6 @@ pub const ZSET: &str = "zset"; pub const HASH: &str = "hash"; pub const STREAM: &str = "stream"; -/// struct containing all objects needed to bind to a socket and clean it. -struct SocketListener { - socket_path: String, - cleanup_socket: bool, -} - -impl Dispose for SocketListener { - fn dispose(self) { - if self.cleanup_socket { - close_socket(&self.socket_path); - } - } -} /// struct containing all objects needed to read from a unix stream. struct UnixStreamListener { @@ -734,108 +720,6 @@ async fn listen_on_client_stream(socket: UnixStream) { log_trace("client closing", "closing connection"); } -enum SocketCreationResult { - // Socket creation was successful, returned a socket listener. - Created(UnixListener), - // There's an existing a socket listener. - PreExisting, - // Socket creation failed with an error. - Err(io::Error), -} - -impl SocketListener { - fn new(socket_path: String) -> Self { - SocketListener { - socket_path, - // Don't cleanup the socket resources unless we know that the socket is in use, and owned by this listener. - cleanup_socket: false, - } - } - - /// Return true if it's possible to connect to socket. - async fn socket_is_available(&self) -> bool { - if UnixStream::connect(&self.socket_path).await.is_ok() { - return true; - } - - let retry_strategy = get_fixed_interval_backoff(10, 3); - - let action = || async { - UnixStream::connect(&self.socket_path) - .await - .map(|_| ()) - .map_err(|_| ()) - }; - let result = Retry::spawn(retry_strategy.get_iterator(), action).await; - result.is_ok() - } - - async fn get_socket_listener(&self) -> SocketCreationResult { - const RETRY_COUNT: u8 = 3; - let mut retries = RETRY_COUNT; - while retries > 0 { - match UnixListener::bind(self.socket_path.clone()) { - Ok(listener) => { - return SocketCreationResult::Created(listener); - } - Err(err) if err.kind() == AddrInUse => { - if self.socket_is_available().await { - return SocketCreationResult::PreExisting; - } else { - // socket file might still exist, even if nothing is listening on it. - close_socket(&self.socket_path); - retries -= 1; - continue; - } - } - Err(err) => { - return SocketCreationResult::Err(err); - } - } - } - SocketCreationResult::Err(io::Error::new( - io::ErrorKind::Other, - "Failed to connect to socket", - )) - } - - pub(crate) async fn listen_on_socket(&mut self, init_callback: InitCallback) - where - InitCallback: FnOnce(Result) + Send + 'static, - { - // Bind to socket - let listener = match self.get_socket_listener().await { - SocketCreationResult::Created(listener) => listener, - SocketCreationResult::Err(err) => { - log_info("listen_on_socket", format!("failed with error: {err}")); - init_callback(Err(err.to_string())); - return; - } - SocketCreationResult::PreExisting => { - init_callback(Ok(self.socket_path.clone())); - return; - } - }; - - self.cleanup_socket = true; - init_callback(Ok(self.socket_path.clone())); - let local_set_pool = LocalPoolHandle::new(num_cpus::get()); - loop { - match listener.accept().await { - Ok((stream, _addr)) => { - local_set_pool.spawn_pinned(move || listen_on_client_stream(stream)); - } - Err(err) => { - log_debug( - "listen_on_socket", - format!("Socket closed with error: `{err}`"), - ); - return; - } - } - } - } -} #[derive(Debug)] /// Enum describing the reason that a socket listener stopped listening on a socket. @@ -924,23 +808,94 @@ pub fn start_socket_listener_internal( init_callback: InitCallback, socket_path: Option, ) where - InitCallback: FnOnce(Result) + Send + 'static, + InitCallback: FnOnce(Result) + Send + Clone + 'static, { + static INITIALIZED_SOCKETS: Lazy>>> = Lazy::new(|| { + Arc::new(RwLock::new(HashSet::new())) + }); + + let socket_path = socket_path.unwrap_or_else(get_socket_path); + + { + // Optimize for already initialized + let initialized_sockets = INITIALIZED_SOCKETS.read().expect("Failed to acquire sockets db read guard"); + if initialized_sockets.contains(&socket_path) { + init_callback(Ok(socket_path.clone())); + return; + } + } + + // Retry with write lock, will be dropped upon the function completion + let mut sockets_write_guard = INITIALIZED_SOCKETS.write().expect("Failed to acquire sockets db write guard"); + if sockets_write_guard.contains(&socket_path) { + init_callback(Ok(socket_path.clone())); + return; + } + + let (tx, rx) = std::sync::mpsc::channel(); + let socket_path_cloned = socket_path.clone(); + let init_callback_cloned = init_callback.clone(); + let tx_cloned = tx.clone(); thread::Builder::new() - .name("socket_listener_thread".to_string()) - .spawn(move || { + .name("socket_listener_thread".to_string()) + .spawn(move || { + let init_result = { + let runtime = Builder::new_current_thread().enable_all().build(); - match runtime { - Ok(runtime) => { - let mut listener = Disposable::new(SocketListener::new( - socket_path.unwrap_or_else(get_socket_path), - )); - runtime.block_on(listener.listen_on_socket(init_callback)); + if let Err(err) = runtime { + log_error("listen_on_socket", format!("Error failed to create a new tokio thread: {err}")); + return Err(err); + } + + runtime.unwrap().block_on(async move { + let listener_socket = UnixListener::bind(socket_path_cloned.clone()); + if let Err(err) = listener_socket { + log_error("listen_on_socket", format!("Error failed to bind listening socket: {err}")); + return Err(err); } - Err(err) => init_callback(Err(err.to_string())), - }; - }) - .expect("Thread spawn failed. Cannot report error because callback was moved."); + let listener_socket = listener_socket.unwrap(); + + // signal initialization success + init_callback(Ok(socket_path_cloned.clone())); + let _ = tx.send(true); + + let local_set_pool = LocalPoolHandle::new(num_cpus::get()); + loop { + match listener_socket.accept().await { + Ok((stream, _addr)) => { + local_set_pool.spawn_pinned(move || listen_on_client_stream(stream)); + } + Err(err) => { + log_error( + "listen_on_socket", + format!("Error accepting connection: `{err}`"), + ); + break; + } + } + } + // no more listening on socket - update the sockets db + let _ = std::fs::remove_file(socket_path_cloned.clone()); + let mut sockets_write_guard = INITIALIZED_SOCKETS.write().expect("Failed to acquire sockets db write guard"); + sockets_write_guard.remove(&socket_path_cloned); // ensure socket file removal + Ok(()) + }) + }; + + if let Err(err) = init_result { + init_callback_cloned(Err(err.to_string())); + let _ = tx_cloned.send(false); + } + Ok(()) + }) + .expect("Thread spawn failed. Cannot report error because callback was moved."); + + // wait for thread initialization signaling, callback invocation is done in the thread + let _ = rx.recv().map(|res| { + if res { + sockets_write_guard.insert(socket_path); + } + }); } /// Creates a new thread with a main loop task listening on the socket for new connections. @@ -950,7 +905,7 @@ pub fn start_socket_listener_internal( /// * `init_callback` - called when the socket listener fails to initialize, with the reason for the failure. pub fn start_socket_listener(init_callback: InitCallback) where - InitCallback: FnOnce(Result) + Send + 'static, + InitCallback: FnOnce(Result) + Send + Clone + 'static, { start_socket_listener_internal(init_callback, None); } diff --git a/glide-core/tests/test_socket_listener.rs b/glide-core/tests/test_socket_listener.rs index a242eb80d1..35c7cc9e32 100644 --- a/glide-core/tests/test_socket_listener.rs +++ b/glide-core/tests/test_socket_listener.rs @@ -519,7 +519,7 @@ mod socket_listener { #[timeout(SHORT_STANDALONE_TEST_TIMEOUT)] fn test_working_after_socket_listener_was_dropped() { let socket_path = - get_socket_path_from_name("test_working_after_socket_listener_was_dropped".to_string()); + get_socket_path_from_name(format!("{}_test_working_after_socket_listener_was_dropped", std::process::id())); close_socket(&socket_path); // create a socket listener and drop it, to simulate a panic in a previous iteration. Builder::new_current_thread() @@ -528,6 +528,8 @@ mod socket_listener { .unwrap() .block_on(async { let _ = UnixListener::bind(socket_path.clone()).unwrap(); + // UDS sockets require explicit removal of the socket file + close_socket(&socket_path); }); const CALLBACK_INDEX: u32 = 99; @@ -555,7 +557,7 @@ mod socket_listener { #[timeout(SHORT_STANDALONE_TEST_TIMEOUT)] fn test_multiple_listeners_competing_for_the_socket() { let socket_path = get_socket_path_from_name( - "test_multiple_listeners_competing_for_the_socket".to_string(), + format!("{}_test_multiple_listeners_competing_for_the_socket", std::process::id()), ); close_socket(&socket_path); let server = Arc::new(RedisServer::new(ServerType::Tcp { tls: false }));