Skip to content

Commit

Permalink
Open cloud connection with address translation, treat advertised addr…
Browse files Browse the repository at this point in the history
…esses as names
  • Loading branch information
dmitrii-ubskii committed Apr 8, 2024
1 parent 8153146 commit f6f4c6f
Show file tree
Hide file tree
Showing 9 changed files with 132 additions and 59 deletions.
2 changes: 1 addition & 1 deletion c/src/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ pub extern "C" fn replica_info_drop(replica_info: *mut ReplicaInfo) {
/// Retrieves the address of the server hosting this replica
#[no_mangle]
pub extern "C" fn replica_info_get_address(replica_info: *const ReplicaInfo) -> *mut c_char {
release_string(borrow(replica_info).address.to_string())
release_string(borrow(replica_info).server_name.to_string())
}

/// Checks whether this is the primary replica of the raft cluster.
Expand Down
6 changes: 3 additions & 3 deletions rust/src/common/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use std::{error::Error as StdError, fmt};
use tonic::{Code, Status};
use typeql::error_messages;

use super::{address::Address, RequestID};
use super::RequestID;

error_messages! { ConnectionError
code: "CXN", type: "Connection Error",
Expand Down Expand Up @@ -84,8 +84,8 @@ error_messages! { InternalError
3: "Unexpected request type for remote procedure call: {request_type}.",
UnexpectedResponseType { response_type: String } =
4: "Unexpected response type for remote procedure call: {response_type}.",
UnknownConnectionAddress { address: Address } =
5: "Received unrecognized address from the server: {address}.",
UnknownConnection { name: String } =
5: "Received unrecognized node ID from the server: {name}.",
EnumOutOfBounds { value: i32, enum_name: &'static str } =
6: "Value '{value}' is out of bounds for enum '{enum_name}'.",
}
Expand Down
6 changes: 3 additions & 3 deletions rust/src/common/info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@ use std::time::Duration;

use tokio::sync::mpsc::UnboundedSender;

use super::{address::Address, SessionID};
use super::SessionID;
use crate::common::Callback;

#[derive(Clone, Debug)]
pub(crate) struct SessionInfo {
pub(crate) address: Address,
pub(crate) server_name: String,
pub(crate) session_id: SessionID,
pub(crate) network_latency: Duration,
pub(crate) on_close_register_sink: UnboundedSender<Callback>,
Expand All @@ -42,7 +42,7 @@ pub(crate) struct DatabaseInfo {
#[derive(Debug)]
pub struct ReplicaInfo {
/// The address of the server hosting this replica
pub address: Address,
pub server_name: String,
/// Whether this is the primary replica of the raft cluster.
pub is_primary: bool,
/// Whether this is the preferred replica of the raft cluster.
Expand Down
136 changes: 103 additions & 33 deletions rust/src/connection/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ use crate::{
/// A connection to a TypeDB server which serves as the starting point for all interaction.
#[derive(Clone)]
pub struct Connection {
server_connections: HashMap<Address, ServerConnection>,
server_addresses: HashMap<String, Address>,
server_connections: HashMap<String, ServerConnection>,
background_runtime: Arc<BackgroundRuntime>,
username: Option<String>,
is_cloud: bool,
Expand All @@ -76,18 +77,23 @@ impl Connection {
/// Connection::new_core("127.0.0.1:1729")
/// ```
pub fn new_core(address: impl AsRef<str>) -> Result<Self> {
let name = address.as_ref().to_owned();
let address: Address = address.as_ref().parse()?;
let background_runtime = Arc::new(BackgroundRuntime::new()?);
let mut server_connection = ServerConnection::new_core(background_runtime.clone(), address)?;
let address = server_connection
let mut server_connection = ServerConnection::new_core(background_runtime.clone(), name, address.clone())?;

let advertised_name = server_connection
.servers_all()?
.into_iter()
.exactly_one()
.map_err(|e| ConnectionError::ServerConnectionFailedStatusError { error: e.to_string() })?;
server_connection.set_address(address.clone());
.map_err(|e| ConnectionError::ServerConnectionFailedStatusError { error: e.to_string() })?
.to_string();
server_connection.set_name(advertised_name.clone());

match server_connection.validate() {
Ok(()) => Ok(Self {
server_connections: [(address, server_connection)].into(),
server_connections: [(advertised_name.clone(), server_connection)].into(),
server_addresses: [(advertised_name, address)].into(),
background_runtime,
username: None,
is_cloud: false,
Expand Down Expand Up @@ -124,11 +130,68 @@ impl Connection {
let init_addresses = init_addresses.iter().map(|addr| addr.as_ref().parse()).try_collect()?;
let addresses = Self::fetch_current_addresses(background_runtime.clone(), init_addresses, credential.clone())?;

let server_connections: HashMap<Address, ServerConnection> = addresses
Self::new_cloud_impl(
background_runtime,
addresses.into_iter().map(|addr| (addr.to_string(), addr)).collect(),
credential,
)
}

/// Creates a new TypeDB Cloud connection.
///
/// # Arguments
///
/// * `addresses` -- Translation map from addresses received from the TypeDB server(s) to
/// addresses to be used by the driver for connection
/// * `credential` -- User credential and TLS encryption setting
///
/// # Examples
///
/// ```rust
/// Connection::new_cloud(
/// &["localhost:11729", "localhost:21729", "localhost:31729"],
/// Credential::with_tls(
/// "admin",
/// "password",
/// Some(&PathBuf::from(
/// std::env::var("ROOT_CA")
/// .expect("ROOT_CA environment variable needs to be set for cloud tests to run"),
/// )),
/// )?,
/// )
/// ```
pub fn new_cloud_address_map<T: AsRef<str> + Sync, U: AsRef<str> + Sync>(
addresses: HashMap<T, U>,
credential: Credential,
) -> Result<Self> {
let background_runtime = Arc::new(BackgroundRuntime::new()?);

let server_addresses: HashMap<_, _> = addresses
.into_iter()
.map(|address| {
ServerConnection::new_cloud(background_runtime.clone(), address.clone(), credential.clone())
.map(|server_connection| (address, server_connection))
.map(|(name, address)| {
let name = name.as_ref().to_owned();
address.as_ref().parse::<Address>().map(|address| (name, address))
})
.try_collect()?;

Self::new_cloud_impl(background_runtime, server_addresses, credential)
}

fn new_cloud_impl(
background_runtime: Arc<BackgroundRuntime>,
server_addresses: HashMap<String, Address>,
credential: Credential,
) -> Result<Self> {
let server_connections: HashMap<String, ServerConnection> = server_addresses
.iter()
.map(|(name, address)| {
ServerConnection::new_cloud(
background_runtime.clone(),
name.clone(),
address.clone(),
credential.clone(),
)
.map(|server_connection| (name.clone(), server_connection))
})
.try_collect()?;

Expand All @@ -140,6 +203,7 @@ impl Connection {
})?
} else {
Ok(Self {
server_addresses,
server_connections,
background_runtime,
username: Some(credential.username().to_string()),
Expand All @@ -154,8 +218,12 @@ impl Connection {
credential: Credential,
) -> Result<HashSet<Address>> {
for address in &addresses {
let server_connection =
ServerConnection::new_cloud(background_runtime.clone(), address.clone(), credential.clone());
let server_connection = ServerConnection::new_cloud(
background_runtime.clone(),
address.to_string(),
address.clone(),
credential.clone(),
);
match server_connection {
Ok(server_connection) => match server_connection.servers_all() {
Ok(servers) => return Ok(servers.into_iter().collect()),
Expand Down Expand Up @@ -215,14 +283,14 @@ impl Connection {
self.server_connections.len()
}

pub(crate) fn addresses(&self) -> impl Iterator<Item = &Address> {
self.server_connections.keys()
pub(crate) fn server_names(&self) -> impl Iterator<Item = &str> {
self.server_connections.keys().map(|name| name.as_str())
}

pub(crate) fn connection(&self, address: &Address) -> Result<&ServerConnection> {
pub(crate) fn connection(&self, address: &str) -> Result<&ServerConnection> {
self.server_connections
.get(address)
.ok_or_else(|| InternalError::UnknownConnectionAddress { address: address.clone() }.into())
.ok_or_else(|| InternalError::UnknownConnection { name: address.to_owned() }.into())
}

pub(crate) fn connections(&self) -> impl Iterator<Item = &ServerConnection> + '_ {
Expand All @@ -234,9 +302,7 @@ impl Connection {
}

pub(crate) fn unable_to_connect_error(&self) -> Error {
Error::Connection(ConnectionError::ServerConnectionFailedStatusError {
error: self.addresses().map(Address::to_string).collect::<Vec<_>>().join(", "),
})
Error::Connection(ConnectionError::ServerConnectionFailedStatusError { error: self.server_names().join(", ") })
}
}

Expand All @@ -248,22 +314,26 @@ impl fmt::Debug for Connection {

#[derive(Clone)]
pub(crate) struct ServerConnection {
address: Address,
name: String,
background_runtime: Arc<BackgroundRuntime>,
open_sessions: Arc<Mutex<HashMap<SessionID, UnboundedSender<()>>>>,
request_transmitter: Arc<RPCTransmitter>,
}

impl ServerConnection {
fn new_core(background_runtime: Arc<BackgroundRuntime>, address: Address) -> Result<Self> {
let request_transmitter = Arc::new(RPCTransmitter::start_core(address.clone(), &background_runtime)?);
Ok(Self { address, background_runtime, open_sessions: Default::default(), request_transmitter })
fn new_core(background_runtime: Arc<BackgroundRuntime>, name: String, address: Address) -> Result<Self> {
let request_transmitter = Arc::new(RPCTransmitter::start_core(address, &background_runtime)?);
Ok(Self { name, background_runtime, open_sessions: Default::default(), request_transmitter })
}

fn new_cloud(background_runtime: Arc<BackgroundRuntime>, address: Address, credential: Credential) -> Result<Self> {
let request_transmitter =
Arc::new(RPCTransmitter::start_cloud(address.clone(), credential, &background_runtime)?);
Ok(Self { address, background_runtime, open_sessions: Default::default(), request_transmitter })
fn new_cloud(
background_runtime: Arc<BackgroundRuntime>,
name: String,
address: Address,
credential: Credential,
) -> Result<Self> {
let request_transmitter = Arc::new(RPCTransmitter::start_cloud(address, credential, &background_runtime)?);
Ok(Self { name, background_runtime, open_sessions: Default::default(), request_transmitter })
}

pub(crate) fn validate(&self) -> Result {
Expand All @@ -273,12 +343,12 @@ impl ServerConnection {
}
}

fn set_address(&mut self, address: Address) {
self.address = address;
fn set_name(&mut self, name: String) {
self.name = name;
}

pub(crate) fn address(&self) -> &Address {
&self.address
pub(crate) fn name(&self) -> &str {
&self.name
}

#[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
Expand Down Expand Up @@ -392,7 +462,7 @@ impl ServerConnection {
pulse_shutdown_source,
));
Ok(SessionInfo {
address: self.address.clone(),
server_name: self.name.clone(),
session_id,
network_latency: start.elapsed().saturating_sub(server_duration),
on_close_register_sink,
Expand Down Expand Up @@ -507,7 +577,7 @@ impl ServerConnection {
impl fmt::Debug for ServerConnection {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("ServerConnection")
.field("address", &self.address)
.field("name", &self.name)
.field("open_sessions", &self.open_sessions)
.finish()
}
Expand Down
2 changes: 1 addition & 1 deletion rust/src/connection/network/proto/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ impl TryFromProto<DatabaseProto> for DatabaseInfo {
impl TryFromProto<ReplicaProto> for ReplicaInfo {
fn try_from_proto(proto: ReplicaProto) -> Result<Self> {
Ok(Self {
address: proto.address.parse()?,
server_name: proto.address,
is_primary: proto.primary,
is_preferred: proto.preferred,
term: proto.term,
Expand Down
26 changes: 14 additions & 12 deletions rust/src/database/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ use log::{debug, error};

use crate::{
common::{
address::Address,
error::ConnectionError,
info::{DatabaseInfo, ReplicaInfo},
Error, Result,
Expand Down Expand Up @@ -183,11 +182,11 @@ impl Database {
{
let replicas = self.replicas.read().unwrap().clone();
for replica in replicas {
match task(replica.database.clone(), self.connection.connection(&replica.address)?.clone()).await {
match task(replica.database.clone(), self.connection.connection(&replica.server_name)?.clone()).await {
Err(Error::Connection(
ConnectionError::ServerConnectionFailedStatusError { .. } | ConnectionError::ConnectionFailed,
)) => {
debug!("Unable to connect to {}. Attempting next server.", replica.address);
debug!("Unable to connect to {}. Attempting next server.", replica.server_name);
}
res => return res,
}
Expand All @@ -205,8 +204,11 @@ impl Database {
if let Some(replica) = self.primary_replica() { replica } else { self.seek_primary_replica().await? };

for _ in 0..Self::PRIMARY_REPLICA_TASK_MAX_RETRIES {
match task(primary_replica.database.clone(), self.connection.connection(&primary_replica.address)?.clone())
.await
match task(
primary_replica.database.clone(),
self.connection.connection(&primary_replica.server_name)?.clone(),
)
.await
{
Err(Error::Connection(
ConnectionError::CloudReplicaNotPrimary
Expand Down Expand Up @@ -260,8 +262,8 @@ impl fmt::Debug for Database {
/// The metadata and state of an individual raft replica of a database.
#[derive(Clone)]
pub(super) struct Replica {
/// Retrieves address of the server hosting this replica
address: Address,
/// Retrieves the name of the server hosting this replica
server_name: String,
/// Retrieves the database name for which this is a replica
database_name: String,
/// Checks whether this is the primary replica of the raft cluster.
Expand All @@ -277,7 +279,7 @@ pub(super) struct Replica {
impl Replica {
fn new(name: String, metadata: ReplicaInfo, server_connection: ServerConnection) -> Self {
Self {
address: metadata.address,
server_name: metadata.server_name,
database_name: name.clone(),
is_primary: metadata.is_primary,
term: metadata.term,
Expand All @@ -291,15 +293,15 @@ impl Replica {
.replicas
.into_iter()
.map(|replica| {
let server_connection = connection.connection(&replica.address)?.clone();
let server_connection = connection.connection(&replica.server_name)?.clone();
Ok(Self::new(database_info.name.clone(), replica, server_connection))
})
.try_collect()
}

fn to_info(&self) -> ReplicaInfo {
ReplicaInfo {
address: self.address.clone(),
server_name: self.server_name.clone(),
is_primary: self.is_primary,
is_preferred: self.is_preferred,
term: self.term,
Expand All @@ -322,7 +324,7 @@ impl Replica {
error!(
"Failed to fetch replica info for database '{}' from {}. Attempting next server.",
name,
server_connection.address()
server_connection.name()
);
}
Err(err) => return Err(err),
Expand All @@ -335,7 +337,7 @@ impl Replica {
impl fmt::Debug for Replica {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Replica")
.field("address", &self.address)
.field("address", &self.server_name)
.field("database_name", &self.database_name)
.field("is_primary", &self.is_primary)
.field("term", &self.term)
Expand Down
Loading

0 comments on commit f6f4c6f

Please sign in to comment.