Gold-180 - fix: improper nodejs task scheduling, use mpsc for multi_send #10

Open. Wants to merge 5 commits into base: dev. Showing changes from 2 commits.
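Summary of the change (a reading of the diff below, since the PR carries no description): `multi_send_with_header` previously created one tokio `oneshot` channel per destination, then spawned one task per receiver, each of which scheduled its own blocking task and its own Node.js event-loop callback. This PR routes every per-destination `SendResult` through a single unbounded mpsc channel, aggregates the results in one task, and schedules the completion callback on the Node.js thread once per multi-send. A minimal, self-contained sketch of that pattern in plain tokio (the stubbed send result and counts are illustrative, not the PR's code):

```rust
use tokio::sync::mpsc;

#[tokio::main]
async fn main() {
    // One mpsc channel shared by all sends, replacing one oneshot per destination.
    let (tx, mut rx) = mpsc::unbounded_channel::<Result<(), String>>();

    for i in 0..3u32 {
        let tx = tx.clone();
        tokio::spawn(async move {
            // Stand-in for the real socket send.
            let result = if i == 1 { Err(format!("send {i} failed")) } else { Ok(()) };
            tx.send(result).ok();
        });
    }
    // Drop the original sender: once every clone is dropped, recv() yields None.
    drop(tx);

    // Aggregate every result, then hand control back (to JS, in the real code) once.
    let mut errors = Vec::new();
    while let Some(result) = rx.recv().await {
        if let Err(e) = result {
            errors.push(e);
        }
    }
    println!("multi-send completed with {} error(s): {:?}", errors.len(), errors);
}
```

In the actual change, the last `tx` clone is dropped implicitly when the final spawned send task finishes, which is what lets the aggregation loop terminate instead of hanging.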
84 changes: 38 additions & 46 deletions shardus_net/src/lib.rs
@@ -300,13 +300,13 @@ pub fn multi_send_with_header(mut cx: FunctionContext) -> JsResult<JsUndefined>
let header_js_string: String = cx.argument::<JsString>(3)?.value(cx) as String;
let data_js_string: String = cx.argument::<JsString>(4)?.value(cx) as String;
let complete_cb = cx.argument::<JsFunction>(5)?.root(cx);
- let await_processing = cx.argument::<JsBoolean>(6)?.value(cx); // this flag lets us skip the processing on the stats and the callback
+ let schedule_complete_callback = cx.argument::<JsBoolean>(6)?.value(cx); // when false, skip the stats processing and the completion callback

let shardus_net_sender = cx.this().get::<JsBox<Arc<ShardusNetSender>>, _, _>(cx, "_sender")?;
let stats_incrementers = cx.this().get::<JsBox<Incrementers>, _, _>(cx, "_stats_incrementers")?;

let this = cx.this().root(cx);
- let channel = cx.channel();
+ let nodejs_thread_channel = cx.channel();

for _ in 0..ports.len() {
stats_incrementers.increment_outstanding_sends();
@@ -322,51 +322,42 @@ pub fn multi_send_with_header(mut cx: FunctionContext) -> JsResult<JsUndefined>

let data = data_js_string.into_bytes().to_vec();

- // Create oneshot channels for each host-port pair
- let mut senders = Vec::with_capacity(hosts.len());
- let mut receivers = Vec::with_capacity(hosts.len());
-
- // should a check be added to see if ports.len == hosts.len
- for _ in 0..hosts.len() {
- let (sender, receiver) = oneshot::channel::<SendResult>();
- senders.push(sender);
- receivers.push(receiver);
- }
-
- let complete_cb = Arc::new(complete_cb);
- let this = Arc::new(this);
-
- // Handle the responses asynchronously
- for receiver in receivers {
- let channel = channel.clone();
- let complete_cb = complete_cb.clone();
- let this = this.clone();
-
- RUNTIME.spawn(async move {
- let result = receiver.await.expect("Complete send tx dropped before notify");
-
- if await_processing {
- RUNTIME.spawn_blocking(move || {
- channel.send(move |mut cx| {
- let cx = &mut cx;
- let stats = this.to_inner(cx).get::<JsBox<RefCell<Stats>>, _, _>(cx, "_stats")?;
- (**stats).borrow_mut().decrement_outstanding_sends();
-
- let this = cx.undefined();
-
- if let Err(err) = result {
- let error = cx.string(format!("{:?}", err));
- complete_cb.to_inner(cx).call(cx, this, [error.upcast()])?;
- } else {
- complete_cb.to_inner(cx).call(cx, this, [])?;
- }
-
- Ok(())
- });
- });
- }
- });
- }
+ let complete_cb = Arc::new(complete_cb);
+ let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<SendResult>();
+ RUNTIME.spawn(async move {
+ let mut results = Vec::new();
+
+ // recv will return None when all tx are dropped
+ // So this'll not hang forever.
+ while let Some(result) = rx.recv().await {
+ results.push(result);
+ }
+
+ if schedule_complete_callback {
+ nodejs_thread_channel.send(move |mut cx| {
+ let cx = &mut cx;
+
+ let js_arr = cx.empty_array();
+ for i in 0..results.len() {
+ let stats = this.to_inner(cx).get::<JsBox<RefCell<Stats>>, _, _>(cx, "_stats")?;
+ (**stats).borrow_mut().decrement_outstanding_sends();
+ if let Err(err) = &results[i] {
+ let err = cx.string(format!("{:?}", err));
+ js_arr.set(cx, i as u32, err)?;
+ }
+ }
+
+ let undef = cx.undefined();
+
+ complete_cb.to_inner(cx).call(cx, undef, [js_arr.upcast()])?;
+
+ Ok(())
+ });
+ }
+
+ });

let mut addresses = Vec::new();
for (host, port) in hosts.iter().zip(ports.iter()) {
@@ -384,7 +375,7 @@ pub fn multi_send_with_header(mut cx: FunctionContext) -> JsResult<JsUndefined>
}

// Send each address with its corresponding sender
- shardus_net_sender.multi_send_with_header(addresses, header_version, header, data, senders);
+ shardus_net_sender.multi_send_with_header(addresses, header_version, header, data, tx);

Ok(cx.undefined())
}
@@ -577,6 +568,7 @@ fn get_sender_address(mut cx: FunctionContext) -> JsResult<JsObject> {

#[neon::main]
fn main(mut cx: ModuleContext) -> NeonResult<()> {

cx.export_function("Sn", create_shardus_net)?;

cx.export_function("setLoggingEnabled", set_logging_enabled)?;
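Two JS-facing consequences of the `lib.rs` changes above are easy to miss. First, the scheduling fix: the old path wrapped each Neon `channel.send` in `RUNTIME.spawn_blocking`, so a multi-send to N peers queued N blocking-pool tasks and N event-loop callbacks; the new path issues a single `channel.send` per multi-send. Second, the callback contract changed: the callback now runs once and receives one JS array in which only failed sends occupy a slot, and because `results` is filled as results arrive on the mpsc receiver, the array index reflects completion order rather than the order of the addresses passed in. A small sketch of that aggregation step, using `Vec<Option<String>>` as a stand-in for the Neon `JsArray` (`build_error_slots` is a hypothetical helper, not code from this PR):

```rust
// Stand-in for shardus_net_sender::SendResult.
type SendResult = Result<(), String>;

// Mirrors the closure handed to nodejs_thread_channel.send: only failures
// populate their slot, yielding a sparse array like the JsArray in the diff.
fn build_error_slots(results: &[SendResult]) -> Vec<Option<String>> {
    results
        .iter()
        .map(|r| r.as_ref().err().map(|e| format!("{e:?}")))
        .collect()
}

fn main() {
    // Results arrive in completion order, not address order.
    let results: Vec<SendResult> = vec![Ok(()), Err("timeout".into()), Ok(())];
    let slots = build_error_slots(&results);
    assert_eq!(slots, vec![None, Some(String::from("\"timeout\"")), None]);
}
```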
33 changes: 22 additions & 11 deletions shardus_net/src/shardus_net_sender.rs
@@ -2,7 +2,6 @@ use super::runtime::RUNTIME;
use crate::header::header_types::Header;
use crate::header_factory::{header_serialize_factory, wrap_serialized_message};
use crate::message::Message;
- use crate::oneshot::Sender;
use crate::shardus_crypto;
use log::error;
#[cfg(debug)]
@@ -31,9 +30,14 @@ pub enum SenderError {

pub type SendResult = Result<(), SenderError>;

+ pub enum Transmitter<T> {
+ Oneshot(tokio::sync::oneshot::Sender<T>),
+ Mpsc(UnboundedSender<T>),
+ }
+
pub struct ShardusNetSender {
key_pair: crypto::KeyPair,
- send_channel: UnboundedSender<(SocketAddr, Vec<u8>, Sender<SendResult>)>,
+ send_channel: UnboundedSender<(SocketAddr, Vec<u8>, Transmitter<SendResult>)>,
evict_socket_channel: UnboundedSender<SocketAddr>,
}

@@ -53,38 +57,38 @@ impl ShardusNetSender {
}

// send: send data to a socket address without a header
- pub fn send(&self, address: SocketAddr, data: String, complete_tx: Sender<SendResult>) {
+ pub fn send(&self, address: SocketAddr, data: String, complete_tx: tokio::sync::oneshot::Sender<SendResult>) {
let data = data.into_bytes();
self.send_channel
- .send((address, data, complete_tx))
+ .send((address, data, Transmitter::Oneshot(complete_tx)))
.expect("Unexpected! Failed to send data to channel. Sender task must have been dropped.");
}

// send_with_header: send data to a socket address with a header and signature
- pub fn send_with_header(&self, address: SocketAddr, header_version: u8, mut header: Header, data: Vec<u8>, complete_tx: Sender<SendResult>) {
+ pub fn send_with_header(&self, address: SocketAddr, header_version: u8, mut header: Header, data: Vec<u8>, complete_tx: tokio::sync::oneshot::Sender<SendResult>) {
let compressed_data = header.compress(data);
header.set_message_length(compressed_data.len() as u32);
let serialized_header = header_serialize_factory(header_version, header).expect("Failed to serialize header");
let mut message = Message::new_unsigned(header_version, serialized_header, compressed_data);
message.sign(shardus_crypto::get_shardus_crypto_instance(), &self.key_pair);
let serialized_message = wrap_serialized_message(message.serialize());
self.send_channel
- .send((address, serialized_message, complete_tx))
+ .send((address, serialized_message, Transmitter::Oneshot(complete_tx)))
.expect("Unexpected! Failed to send data with header to channel. Sender task must have been dropped.");
}

// multi_send_with_header: send data to multiple socket addresses with a single header and signature
- pub fn multi_send_with_header(&self, addresses: Vec<SocketAddr>, header_version: u8, mut header: Header, data: Vec<u8>, senders: Vec<Sender<SendResult>>) {
+ pub fn multi_send_with_header(&self, addresses: Vec<SocketAddr>, header_version: u8, mut header: Header, data: Vec<u8>, complete_tx: tokio::sync::mpsc::UnboundedSender<SendResult>) {
let compressed_data = header.compress(data);
header.set_message_length(compressed_data.len() as u32);
let serialized_header = header_serialize_factory(header_version, header).expect("Failed to serialize header");
let mut message = Message::new_unsigned(header_version, serialized_header.clone(), compressed_data.clone());
message.sign(shardus_crypto::get_shardus_crypto_instance(), &self.key_pair);
let serialized_message = wrap_serialized_message(message.serialize());

- for (address, sender) in addresses.into_iter().zip(senders.into_iter()) {
+ for address in addresses {
self.send_channel
- .send((address, serialized_message.clone(), sender))
+ .send((address, serialized_message.clone(), Transmitter::Mpsc(complete_tx.clone())))
.expect("Failed to send data with header to channel");
}
}
@@ -111,7 +115,7 @@ impl ShardusNetSender {
});
}

- fn spawn_sender(send_channel_rx: UnboundedReceiver<(SocketAddr, Vec<u8>, Sender<SendResult>)>, connections: Arc<Mutex<dyn ConnectionCache + Send>>) {
+ fn spawn_sender(send_channel_rx: UnboundedReceiver<(SocketAddr, Vec<u8>, Transmitter<SendResult>)>, connections: Arc<Mutex<dyn ConnectionCache + Send>>) {
RUNTIME.spawn(async move {
let mut send_channel_rx = send_channel_rx;

@@ -123,7 +127,14 @@

RUNTIME.spawn(async move {
let result = connection.send(data).await;
- complete_tx.send(result).ok();
+ match complete_tx {
+ Transmitter::Oneshot(complete_tx) => {
+ complete_tx.send(result).ok().expect("Failed to send result to oneshot rx")
+ }
+ Transmitter::Mpsc(complete_tx) => {
+ complete_tx.send(result).ok().expect("Failed to send result to mpsc rx, rx might have been dropped")
+ }
+ }
});
}

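The `Transmitter` enum introduced above is what lets the sender keep a single internal queue: a one-off `send` or `send_with_header` completes through a oneshot sender, while `multi_send_with_header` shares one mpsc sender, cloned per destination, across all sends. Keeping a clone alive per in-flight send is also the termination condition the aggregation loop in `lib.rs` relies on: the receiver only sees `None` after the last clone is dropped. A self-contained sketch of the dispatch under that design (the `notify` helper is illustrative; the real code matches on `complete_tx` inside `spawn_sender`):

```rust
use tokio::sync::{mpsc, oneshot};

// Either completion mechanism can ride the same internal channel payload.
enum Transmitter<T> {
    Oneshot(oneshot::Sender<T>),
    Mpsc(mpsc::UnboundedSender<T>),
}

fn notify<T>(tx: Transmitter<T>, result: T) {
    match tx {
        // Consumed by the send; fails only if the receiver was dropped.
        Transmitter::Oneshot(tx) => {
            tx.send(result).ok();
        }
        // Clonable; every destination funnels into the same receiver.
        Transmitter::Mpsc(tx) => {
            tx.send(result).ok();
        }
    }
}

#[tokio::main]
async fn main() {
    let (tx, rx) = oneshot::channel();
    notify(Transmitter::Oneshot(tx), "done");
    assert_eq!(rx.await.unwrap(), "done");

    let (tx, mut rx) = mpsc::unbounded_channel();
    for _ in 0..2 {
        notify(Transmitter::Mpsc(tx.clone()), "sent");
    }
    drop(tx);
    let mut received = 0;
    while rx.recv().await.is_some() {
        received += 1;
    }
    assert_eq!(received, 2);
}
```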
4 changes: 2 additions & 2 deletions src/index.ts
@@ -139,7 +139,7 @@ export const Sn = (opts: SnOpts) => {
version: number
headerData: CombinedHeader
},
- awaitProcessing: boolean = true
+ callbackEnabled: boolean = true
) => {
return new Promise<{ success: boolean; error?: string }>((resolve, reject) => {
const stringifiedData = jsonStringify(augData, opts.customStringifier)
@@ -168,7 +168,7 @@ export const Sn = (opts: SnOpts) => {
stringifiedHeader,
stringifiedData,
sendCallback,
- awaitProcessing
+ callbackEnabled
)
} else {
if (logFlags.net_verbose) console.log('send_with_header')