Commit a4e0087

Add handler for initial TransactionOpen response instead of erroring
flyingsilverfin committed Oct 22, 2024
1 parent 2cf962a commit a4e0087
Showing 5 changed files with 76 additions and 52 deletions.
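
At a high level, the change pre-registers a response handler for the transaction-open request ID before the transaction's dispatch loop starts, so the server's initial TransactionOpen reply is routed to that handler (and used to refresh the latency estimate) instead of hitting the "unknown request ID" error path. A minimal sketch of the idea, with simplified stand-in types (Dispatcher, Handler, register and dispatch below are illustrative, not the driver's actual API):

use std::collections::HashMap;
use std::sync::Arc;

type RequestId = u64;
type Handler = Arc<dyn Fn(String) + Send + Sync>;

// Illustrative dispatcher: response callbacks keyed by request ID.
struct Dispatcher {
    callbacks: HashMap<RequestId, Handler>,
}

impl Dispatcher {
    // Register the handler *before* the dispatch loop runs, so the very first
    // response (the TransactionOpen reply) already has a receiver.
    fn register(&mut self, req_id: RequestId, handler: Handler) {
        self.callbacks.insert(req_id, handler);
    }

    fn dispatch(&mut self, req_id: RequestId, response: String) {
        match self.callbacks.remove(&req_id) {
            Some(handler) => handler(response),
            // Before this commit, the initial TransactionOpen reply fell
            // through to an error path like this one.
            None => eprintln!("no handler registered for request {req_id}"),
        }
    }
}

fn main() {
    let mut dispatcher = Dispatcher { callbacks: HashMap::new() };
    dispatcher.register(1, Arc::new(|resp| println!("open response: {resp}")));
    dispatcher.dispatch(1, "TransactionOpen".to_owned()); // handled, not an error
}
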
7 changes: 3 additions & 4 deletions rust/src/connection/message.rs
@@ -17,7 +17,7 @@
* under the License.
*/

use std::{collections::HashMap, time::Duration};
use std::time::Duration;

use tokio::sync::mpsc::UnboundedSender;
use tonic::Streaming;
@@ -26,7 +26,6 @@ use uuid::Uuid;

use crate::{
answer::{
concept_document,
concept_document::{ConceptDocumentHeader, Node},
concept_row::ConceptRowHeader,
},
@@ -99,8 +98,8 @@ pub(super) enum Response {
DatabaseRuleSchema {
schema: String,
},
TransactionOpen {
request_id: RequestID,
TransactionStream {
open_request_id: RequestID,
request_sink: UnboundedSender<transaction::Client>,
response_source: Streaming<transaction::Server>,
},
24 changes: 24 additions & 0 deletions rust/src/connection/network/transmitter/response_sink.rs
@@ -17,6 +17,8 @@
* under the License.
*/

use std::{fmt, fmt::Formatter, sync::Arc};

use crossbeam::channel::Sender as SyncSender;
use itertools::Either;
use log::{debug, error};
@@ -30,11 +32,28 @@ use crate::{

#[derive(Debug)]
pub(super) enum ResponseSink<T> {
ImmediateOneShot(ImmediateHandler<Result<T>>),
AsyncOneShot(AsyncOneshotSender<Result<T>>),
BlockingOneShot(SyncSender<Result<T>>),
Streamed(UnboundedSender<StreamResponse<T>>),
}

pub(super) struct ImmediateHandler<T> {
pub(super) handler: Arc<dyn Fn(T) -> () + Sync + Send>,
}

impl<T> ImmediateHandler<T> {
pub(super) fn run(&self, value: T) {
(self.handler)(value)
}
}

impl<T> fmt::Debug for ImmediateHandler<T> {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
write!(f, "Immediate handler")
}
}

pub(super) enum StreamResponse<T> {
Result(Result<T>),
Continue(RequestID),
@@ -43,6 +62,10 @@ pub(super) enum StreamResponse<T> {
impl<T> ResponseSink<T> {
pub(super) fn finish(self, response: Result<T>) {
let result = match self {
Self::ImmediateOneShot(handler) => {
handler.run(response);
Ok(())
}
Self::AsyncOneShot(sink) => sink.send(response).map_err(|_| InternalError::SendError.into()),
Self::BlockingOneShot(sink) => sink.send(response).map_err(Error::from),
Self::Streamed(sink) => sink.send(StreamResponse::Result(response)).map_err(Error::from),
@@ -83,6 +106,7 @@ impl<T> ResponseSink<T> {
Self::AsyncOneShot(sink) => sink.send(Err(error.into())).ok(),
Self::BlockingOneShot(sink) => sink.send(Err(error.into())).ok(),
Self::Streamed(sink) => sink.send(StreamResponse::Result(Err(error.into()))).ok(),
Self::ImmediateOneShot(_) => None,
};
}
}
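
The new ImmediateOneShot variant wraps a callback that is invoked directly when its response arrives, rather than forwarding the result over a oneshot or streaming channel. A small usage sketch of the same shape, assuming a simplified Result alias and a String payload in place of the driver's types:

use std::sync::Arc;

type Result<T> = std::result::Result<T, String>;

// Mirrors the shape of ImmediateHandler: an Arc'd callback that is run
// directly when the response arrives, instead of being sent over a channel.
struct ImmediateHandler<T> {
    handler: Arc<dyn Fn(T) + Sync + Send>,
}

impl<T> ImmediateHandler<T> {
    fn run(&self, value: T) {
        (self.handler)(value)
    }
}

fn main() {
    let on_open: ImmediateHandler<Result<String>> = ImmediateHandler {
        handler: Arc::new(|result: Result<String>| match result {
            Ok(payload) => println!("open response: {payload}"),
            Err(err) => eprintln!("open failed: {err}"),
        }),
    };
    on_open.run(Ok("TransactionOpen".to_owned()));
}
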
2 changes: 1 addition & 1 deletion rust/src/connection/network/transmitter/rpc.rs
@@ -152,7 +152,7 @@ impl RPCTransmitter {
let req = transaction_request.into_proto();
let req_id = RequestID::from(req.req_id.clone());
let (request_sink, response_source) = rpc.transaction(req).await?;
Ok(Response::TransactionOpen { request_id: req_id, request_sink, response_source })
Ok(Response::TransactionStream { open_request_id: req_id, request_sink, response_source })
}

Request::UsersAll => rpc.users_all(request.try_into_proto()?).await.map(Response::from_proto),
16 changes: 14 additions & 2 deletions rust/src/connection/network/transmitter/transaction.rs
@@ -45,10 +45,11 @@ use tokio::{
};
use tonic::Streaming;
use typedb_protocol::transaction::{self, res_part::ResPart, server::Server, stream_signal::res_part::State};
use uuid::Uuid;

#[cfg(feature = "sync")]
use super::oneshot_blocking as oneshot;
use super::response_sink::{ResponseSink, StreamResponse};
use super::response_sink::{ImmediateHandler, ResponseSink, StreamResponse};
use crate::{
common::{
box_promise,
@@ -60,6 +61,7 @@ use crate::{
message::{QueryResponse, Request, Response, TransactionRequest, TransactionResponse},
network::proto::{FromProto, IntoProto, TryFromProto},
runtime::BackgroundRuntime,
server_connection::LatencyTracker,
},
Error,
};
@@ -85,6 +87,8 @@ impl TransactionTransmitter {
background_runtime: Arc<BackgroundRuntime>,
request_sink: UnboundedSender<transaction::Client>,
response_source: Streaming<transaction::Server>,
initial_request_id: RequestID,
initial_response_handler: Arc<dyn Fn(Result<TransactionResponse>) -> () + Sync + Send>,
) -> Self {
let callback_handler_sink = background_runtime.callback_handler_sink();
let (buffer_sink, buffer_source) = unbounded_async();
@@ -102,6 +106,8 @@
callback_handler_sink,
shutdown_sink.clone(),
shutdown_source,
initial_request_id,
initial_response_handler,
));
Self { request_sink: buffer_sink, is_open, error, on_close_register_sink, shutdown_sink, background_runtime }
}
@@ -230,14 +236,20 @@ impl TransactionTransmitter {
callback_handler_sink: Sender<(Callback, AsyncOneshotSender<()>)>,
shutdown_sink: UnboundedSender<()>,
shutdown_signal: UnboundedReceiver<()>,
initial_request_id: RequestID,
initial_response_handler: Arc<dyn Fn(Result<TransactionResponse>) + Sync + Send>,
) {
let collector = ResponseCollector {
let mut collector = ResponseCollector {
callbacks: Default::default(),
is_open,
error,
on_close: Default::default(),
callback_handler_sink,
};
collector.register(
initial_request_id,
ResponseSink::ImmediateOneShot(ImmediateHandler { handler: initial_response_handler }),
);
tokio::spawn(Self::dispatch_loop(
queue_source,
request_sink,
79 changes: 34 additions & 45 deletions rust/src/connection/server_connection.rs
@@ -33,7 +33,7 @@ use uuid::Uuid;
use crate::{
common::{address::Address, RequestID},
connection::{
message::{Request, Response, TransactionRequest},
message::{Request, Response, TransactionRequest, TransactionResponse},
network::transmitter::{RPCTransmitter, TransactionTransmitter},
runtime::BackgroundRuntime,
TransactionStream,
@@ -187,44 +187,6 @@ impl ServerConnection {
Ok(())
}

// #[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
// pub(crate) async fn open_session(
// &self,
// database_name: String,
// session_type: SessionType,
// options: Options,
// ) -> Result<SessionInfo> {
// let start = Instant::now();
// match self.request(Request::SessionOpen { database_name, session_type, options }).await? {
// Response::SessionOpen { session_id, server_duration } => {
// let (on_close_register_sink, on_close_register_source) = unbounded_async();
// let (pulse_shutdown_sink, pulse_shutdown_source) = unbounded_async();
// self.open_sessions.lock().unwrap().insert(session_id.clone(), pulse_shutdown_sink);
// self.background_runtime.spawn(session_pulse(
// session_id.clone(),
// self.request_transmitter.clone(),
// on_close_register_source,
// self.background_runtime.callback_handler_sink(),
// pulse_shutdown_source,
// ));
// Ok(SessionInfo {
// session_id,
// network_latency: start.elapsed().saturating_sub(server_duration),
// on_close_register_sink,
// })
// }
// other => Err(InternalError::UnexpectedResponseType { response_type: format!("{other:?}") }.into()),
// }
// }
//
// pub(crate) fn close_session(&self, session_id: SessionID) -> Result {
// if let Some(sink) = self.open_sessions.lock().unwrap().remove(&session_id) {
// sink.send(()).ok();
// }
// self.request_blocking(Request::SessionClose { session_id })?;
// Ok(())
// }

#[cfg_attr(feature = "sync", maybe_async::must_be_sync)]
pub(crate) async fn open_transaction(
&self,
@@ -233,6 +195,28 @@
options: Options,
) -> crate::Result<TransactionStream> {
let network_latency = self.latency_tracker.current_latency();

let open_request_start = Instant::now();
let tracker = self.latency_tracker.clone();
let initial_response_handler = Arc::new(move |result: crate::common::Result<TransactionResponse>| {
match result {
Ok(TransactionResponse::Open { server_duration_millis }) => {
let open_latency =
Instant::now().duration_since(open_request_start).as_millis() as u64 - server_duration_millis;
tracker.update_latency(open_latency)
}
Err(_) => {
// ignore, error will manifest elsewhere
}
Ok(TransactionResponse::Commit)
| Ok(TransactionResponse::Rollback)
| Ok(TransactionResponse::Query(_))
| Ok(TransactionResponse::Close) => {
panic!("Unexpected - transaction open response was not TransactionOpen.")
}
}
});

match self
.request(Request::Transaction(TransactionRequest::Open {
database: database_name.to_owned(),
Expand All @@ -242,9 +226,14 @@ impl ServerConnection {
}))
.await?
{
Response::TransactionOpen { request_id, request_sink, response_source } => {
let transmitter =
TransactionTransmitter::new(self.background_runtime.clone(), request_sink, response_source);
Response::TransactionStream { open_request_id: request_id, request_sink, response_source } => {
let transmitter = TransactionTransmitter::new(
self.background_runtime.clone(),
request_sink,
response_source,
request_id.clone(),
initial_response_handler,
);
let transmitter_shutdown_sink = transmitter.shutdown_sink().clone();
let transaction_stream = TransactionStream::new(transaction_type, options, transmitter);
self.transaction_shutdown_senders.lock().unwrap().insert(request_id, transmitter_shutdown_sink);
@@ -323,7 +312,7 @@ impl fmt::Debug for ServerConnection {
}

#[derive(Debug, Clone)]
struct LatencyTracker {
pub(crate) struct LatencyTracker {
latency_millis: Arc<AtomicU64>,
}

@@ -332,11 +321,11 @@ impl LatencyTracker {
Self { latency_millis: Arc::new(AtomicU64::new(initial_latency.as_millis() as u64)) }
}

pub(crate) fn update_latency(&self, latency_millis: Duration) {
pub(crate) fn update_latency(&self, latency_millis: u64) {
let previous_latency = self.latency_millis.load(Ordering::Relaxed);
// TODO: this is a strange but simple averaging scheme
// it might actually be useful as it weights the recent measurement the same as the entire history
self.latency_millis.store((latency_millis.as_millis() as u64 + previous_latency) / 2, Ordering::Relaxed);
self.latency_millis.store((latency_millis + previous_latency) / 2, Ordering::Relaxed);
}

fn current_latency(&self) -> Duration {
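
The open handler estimates network latency as the client-observed round trip minus the server-reported processing time, then feeds that into LatencyTracker, whose averaging halves the weight of each older sample. A worked sketch of that behaviour with illustrative numbers (Tracker below mirrors the shape of LatencyTracker but is not the driver's type):

use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;

// Simplified stand-in for the averaging in LatencyTracker: every update
// averages the new sample with the stored value, so the latest sample has
// weight 1/2, the previous one 1/4, and so on.
struct Tracker {
    latency_millis: Arc<AtomicU64>,
}

impl Tracker {
    fn update_latency(&self, latency_millis: u64) {
        let previous = self.latency_millis.load(Ordering::Relaxed);
        self.latency_millis.store((latency_millis + previous) / 2, Ordering::Relaxed);
    }
}

fn main() {
    let tracker = Tracker { latency_millis: Arc::new(AtomicU64::new(100)) };
    // Suppose opening took 64 ms on the client while the server reports 24 ms
    // of processing: the estimated network latency is 40 ms.
    let open_latency = 64u64.saturating_sub(24);
    tracker.update_latency(open_latency); // stored value: (40 + 100) / 2 = 70
    tracker.update_latency(40);           // stored value: (40 + 70) / 2 = 55
    assert_eq!(tracker.latency_millis.load(Ordering::Relaxed), 55);
}
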
