Skip to content

Commit

Permalink
Glide-core UDS Socket Handling Rework:
Browse files Browse the repository at this point in the history
1. Introduced a user-land mechanism for ensuring singleton behavior of the socket, rather than relying on OS-specific semantics. This addresses the issue where macOS and Linux report different errors when the socket path already exists.

2. Simplified the implementation by removing unnecessary abstractions, including redundant connection retry logic.

Signed-off-by: ikolomi <[email protected]>
  • Loading branch information
ikolomi committed Oct 21, 2024
1 parent 6a3a33f commit 99f91f7
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 135 deletions.
1 change: 1 addition & 0 deletions glide-core/src/retry_strategies.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ pub(crate) fn get_exponential_backoff(
}

#[cfg(feature = "socket-layer")]
#[allow(dead_code)]
pub(crate) fn get_fixed_interval_backoff(
fixed_interval: u32,
number_of_retries: u32,
Expand Down
221 changes: 88 additions & 133 deletions glide-core/src/socket_listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,30 +11,29 @@ use crate::connection_request::ConnectionRequest;
use crate::errors::{error_message, error_type, RequestErrorType};
use crate::response;
use crate::response::Response;
use crate::retry_strategies::get_fixed_interval_backoff;
use bytes::Bytes;
use directories::BaseDirs;
use dispose::{Disposable, Dispose};
use logger_core::{log_debug, log_error, log_info, log_trace, log_warn};
use once_cell::sync::Lazy;
use protobuf::{Chars, Message};
use redis::cluster_routing::{
MultipleNodeRoutingInfo, Route, RoutingInfo, SingleNodeRoutingInfo, SlotAddr,
};
use redis::cluster_routing::{ResponsePolicy, Routable};
use redis::{Cmd, PushInfo, RedisError, ScanStateRC, Value};
use std::cell::Cell;
use std::collections::HashSet;
use std::rc::Rc;
use std::sync::{Arc, RwLock};
use std::{env, str};
use std::{io, thread};
use thiserror::Error;
use tokio::io::ErrorKind::AddrInUse;
use tokio::net::{UnixListener, UnixStream};
use tokio::runtime::Builder;
use tokio::sync::mpsc;
use tokio::sync::mpsc::{channel, Sender};
use tokio::sync::Mutex;
use tokio::task;
use tokio_retry::Retry;
use tokio_util::task::LocalPoolHandle;
use ClosingReason::*;
use PipeListeningResult::*;
Expand All @@ -53,19 +52,6 @@ pub const ZSET: &str = "zset";
pub const HASH: &str = "hash";
pub const STREAM: &str = "stream";

/// State needed to bind a listener to a Unix socket path and to clean that
/// socket up when the listener is disposed.
struct SocketListener {
/// Path of the Unix socket that is bound to, and probed with connection attempts.
socket_path: String,
/// When `true`, disposing this listener calls `close_socket` on `socket_path`.
/// Initialized to `false` by `new` and set to `true` by `listen_on_socket`
/// only after this listener has successfully bound the socket.
cleanup_socket: bool,
}

/// `Dispose` implementation: calls `close_socket` on the listener's path, but
/// only when the listener has been flagged as responsible for the socket.
impl Dispose for SocketListener {
fn dispose(self) {
// Leave the socket untouched unless `cleanup_socket` was set, so a socket
// this listener did not bind is not closed by it.
if self.cleanup_socket {
close_socket(&self.socket_path);
}
}
}

/// struct containing all objects needed to read from a unix stream.
struct UnixStreamListener {
Expand Down Expand Up @@ -734,108 +720,6 @@ async fn listen_on_client_stream(socket: UnixStream) {
log_trace("client closing", "closing connection");
}

/// Outcome of an attempt to obtain a listening socket for a socket path.
enum SocketCreationResult {
// Socket creation was successful; carries the newly bound socket listener.
Created(UnixListener),
// The address is already in use and a connection to it succeeded, i.e.
// there is an existing socket listener.
PreExisting,
// Socket creation failed with an error.
Err(io::Error),
}

impl SocketListener {
/// Creates a listener for `socket_path`. The listener starts with
/// `cleanup_socket` unset, so disposing it will not close the socket.
fn new(socket_path: String) -> Self {
SocketListener {
socket_path,
// Don't cleanup the socket resources unless we know that the socket is in use, and owned by this listener.
cleanup_socket: false,
}
}

/// Return true if it's possible to connect to socket.
///
/// Tries one immediate connection first; if that fails, retries using the
/// strategy produced by `get_fixed_interval_backoff(10, 3)`.
/// NOTE(review): the unit of the `10` interval is not visible here — confirm
/// against `get_fixed_interval_backoff`.
async fn socket_is_available(&self) -> bool {
if UnixStream::connect(&self.socket_path).await.is_ok() {
return true;
}

let retry_strategy = get_fixed_interval_backoff(10, 3);

// Only success/failure matters: both the connected stream and the error
// are discarded.
let action = || async {
UnixStream::connect(&self.socket_path)
.await
.map(|_| ())
.map_err(|_| ())
};
let result = Retry::spawn(retry_strategy.get_iterator(), action).await;
result.is_ok()
}

/// Tries to bind a `UnixListener` to `socket_path`.
///
/// Returns:
/// * `Created` - the bind succeeded.
/// * `PreExisting` - the bind failed with `AddrInUse` and a connection to the
///   socket succeeded.
/// * `Err` - the bind failed with any other error, or the `AddrInUse` /
///   not-connectable case repeated `RETRY_COUNT` times.
async fn get_socket_listener(&self) -> SocketCreationResult {
const RETRY_COUNT: u8 = 3;
let mut retries = RETRY_COUNT;
// `retries` is only decremented in the `AddrInUse` branch below; every other
// branch returns immediately.
while retries > 0 {
match UnixListener::bind(self.socket_path.clone()) {
Ok(listener) => {
return SocketCreationResult::Created(listener);
}
Err(err) if err.kind() == AddrInUse => {
if self.socket_is_available().await {
return SocketCreationResult::PreExisting;
} else {
// socket file might still exist, even if nothing is listening on it.
close_socket(&self.socket_path);
retries -= 1;
continue;
}
}
Err(err) => {
return SocketCreationResult::Err(err);
}
}
}
// All retries were used up without binding or finding a live listener.
SocketCreationResult::Err(io::Error::new(
io::ErrorKind::Other,
"Failed to connect to socket",
))
}

/// Obtains the socket and, if this listener created it, accepts client
/// connections until `accept` fails.
///
/// `init_callback` is invoked once on every path: with `Ok(socket_path)` when
/// the socket was created here or already existed, and with `Err(message)`
/// when obtaining the socket failed. When the socket already existed or
/// could not be obtained, the function returns right after the callback
/// without accepting connections.
pub(crate) async fn listen_on_socket<InitCallback>(&mut self, init_callback: InitCallback)
where
InitCallback: FnOnce(Result<String, String>) + Send + 'static,
{
// Bind to socket
let listener = match self.get_socket_listener().await {
SocketCreationResult::Created(listener) => listener,
SocketCreationResult::Err(err) => {
log_info("listen_on_socket", format!("failed with error: {err}"));
init_callback(Err(err.to_string()));
return;
}
SocketCreationResult::PreExisting => {
init_callback(Ok(self.socket_path.clone()));
return;
}
};

// The socket was bound by this listener, so it becomes responsible for
// closing it on dispose.
self.cleanup_socket = true;
init_callback(Ok(self.socket_path.clone()));
// Pool sized by `num_cpus::get()`; each accepted stream is handled by
// `listen_on_client_stream` spawned onto it.
let local_set_pool = LocalPoolHandle::new(num_cpus::get());
loop {
match listener.accept().await {
Ok((stream, _addr)) => {
local_set_pool.spawn_pinned(move || listen_on_client_stream(stream));
}
Err(err) => {
// An accept error ends the listening loop and the function.
log_debug(
"listen_on_socket",
format!("Socket closed with error: `{err}`"),
);
return;
}
}
}
}
}

#[derive(Debug)]
/// Enum describing the reason that a socket listener stopped listening on a socket.
Expand Down Expand Up @@ -924,23 +808,94 @@ pub fn start_socket_listener_internal<InitCallback>(
init_callback: InitCallback,
socket_path: Option<String>,
) where
InitCallback: FnOnce(Result<String, String>) + Send + 'static,
InitCallback: FnOnce(Result<String, String>) + Send + Clone + 'static,
{
static INITIALIZED_SOCKETS: Lazy<Arc<RwLock<HashSet<String>>>> = Lazy::new(|| {
Arc::new(RwLock::new(HashSet::new()))
});

let socket_path = socket_path.unwrap_or_else(get_socket_path);

{
// Optimize for already initialized
let initialized_sockets = INITIALIZED_SOCKETS.read().expect("Failed to acquire sockets db read guard");
if initialized_sockets.contains(&socket_path) {
init_callback(Ok(socket_path.clone()));
return;
}
}

// Retry with write lock, will be dropped upon the function completion
let mut sockets_write_guard = INITIALIZED_SOCKETS.write().expect("Failed to acquire sockets db write guard");
if sockets_write_guard.contains(&socket_path) {
init_callback(Ok(socket_path.clone()));
return;
}

let (tx, rx) = std::sync::mpsc::channel();
let socket_path_cloned = socket_path.clone();
let init_callback_cloned = init_callback.clone();
let tx_cloned = tx.clone();
thread::Builder::new()
.name("socket_listener_thread".to_string())
.spawn(move || {
.name("socket_listener_thread".to_string())
.spawn(move || {
let init_result = {

let runtime = Builder::new_current_thread().enable_all().build();
match runtime {
Ok(runtime) => {
let mut listener = Disposable::new(SocketListener::new(
socket_path.unwrap_or_else(get_socket_path),
));
runtime.block_on(listener.listen_on_socket(init_callback));
if let Err(err) = runtime {
log_error("listen_on_socket", format!("Error failed to create a new tokio thread: {err}"));
return Err(err);
}

runtime.unwrap().block_on(async move {
let listener_socket = UnixListener::bind(socket_path_cloned.clone());
if let Err(err) = listener_socket {
log_error("listen_on_socket", format!("Error failed to bind listening socket: {err}"));
return Err(err);
}
Err(err) => init_callback(Err(err.to_string())),
};
})
.expect("Thread spawn failed. Cannot report error because callback was moved.");
let listener_socket = listener_socket.unwrap();

// signal initialization success
init_callback(Ok(socket_path_cloned.clone()));
let _ = tx.send(true);

let local_set_pool = LocalPoolHandle::new(num_cpus::get());
loop {
match listener_socket.accept().await {
Ok((stream, _addr)) => {
local_set_pool.spawn_pinned(move || listen_on_client_stream(stream));
}
Err(err) => {
log_error(
"listen_on_socket",
format!("Error accepting connection: `{err}`"),
);
break;
}
}
}
// no more listening on socket - update the sockets db
let _ = std::fs::remove_file(socket_path_cloned.clone());
let mut sockets_write_guard = INITIALIZED_SOCKETS.write().expect("Failed to acquire sockets db write guard");
sockets_write_guard.remove(&socket_path_cloned); // ensure socket file removal
Ok(())
})
};

if let Err(err) = init_result {
init_callback_cloned(Err(err.to_string()));
let _ = tx_cloned.send(false);
}
Ok(())
})
.expect("Thread spawn failed. Cannot report error because callback was moved.");

// wait for thread initialization signaling, callback invocation is done in the thread
let _ = rx.recv().map(|res| {
if res {
sockets_write_guard.insert(socket_path);
}
});
}

/// Creates a new thread with a main loop task listening on the socket for new connections.
Expand All @@ -950,7 +905,7 @@ pub fn start_socket_listener_internal<InitCallback>(
/// * `init_callback` - called once initialization completes: with `Ok(socket_path)` when the socket listener is ready, or with `Err(reason)` when it fails to initialize.
pub fn start_socket_listener<InitCallback>(init_callback: InitCallback)
where
InitCallback: FnOnce(Result<String, String>) + Send + 'static,
InitCallback: FnOnce(Result<String, String>) + Send + Clone + 'static,
{
start_socket_listener_internal(init_callback, None);
}
6 changes: 4 additions & 2 deletions glide-core/tests/test_socket_listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,7 @@ mod socket_listener {
#[timeout(SHORT_STANDALONE_TEST_TIMEOUT)]
fn test_working_after_socket_listener_was_dropped() {
let socket_path =
get_socket_path_from_name("test_working_after_socket_listener_was_dropped".to_string());
get_socket_path_from_name(format!("{}_test_working_after_socket_listener_was_dropped", std::process::id()));
close_socket(&socket_path);
// create a socket listener and drop it, to simulate a panic in a previous iteration.
Builder::new_current_thread()
Expand All @@ -528,6 +528,8 @@ mod socket_listener {
.unwrap()
.block_on(async {
let _ = UnixListener::bind(socket_path.clone()).unwrap();
// UDS sockets require explicit removal of the socket file
close_socket(&socket_path);
});

const CALLBACK_INDEX: u32 = 99;
Expand Down Expand Up @@ -555,7 +557,7 @@ mod socket_listener {
#[timeout(SHORT_STANDALONE_TEST_TIMEOUT)]
fn test_multiple_listeners_competing_for_the_socket() {
let socket_path = get_socket_path_from_name(
"test_multiple_listeners_competing_for_the_socket".to_string(),
format!("{}_test_multiple_listeners_competing_for_the_socket", std::process::id()),
);
close_socket(&socket_path);
let server = Arc::new(RedisServer::new(ServerType::Tcp { tls: false }));
Expand Down

0 comments on commit 99f91f7

Please sign in to comment.