From 729e7dc045b5c9cfde06d532a8aa24a8b0e7dbf0 Mon Sep 17 00:00:00 2001 From: 4t145 Date: Thu, 30 Nov 2023 10:04:53 +0800 Subject: [PATCH 1/4] remove debug macros --- tardis/src/cluster/cluster_hashmap.rs | 1 - tardis/src/cluster/cluster_publish.rs | 1 - 2 files changed, 2 deletions(-) diff --git a/tardis/src/cluster/cluster_hashmap.rs b/tardis/src/cluster/cluster_hashmap.rs index c8d338f1..ab833a3c 100644 --- a/tardis/src/cluster/cluster_hashmap.rs +++ b/tardis/src/cluster/cluster_hashmap.rs @@ -52,7 +52,6 @@ where self.map.write().await.insert(key.clone(), value.clone()); let event = CshmEvent::::Insert(vec![(key, value)]); let json = TardisJson.obj_to_json(&event)?; - dbg!(&json); let _result = publish_event_no_response(self.event_name(), json, ClusterEventTarget::Broadcast).await; Ok(()) } diff --git a/tardis/src/cluster/cluster_publish.rs b/tardis/src/cluster/cluster_publish.rs index a7f9d38e..d176f3e1 100644 --- a/tardis/src/cluster/cluster_publish.rs +++ b/tardis/src/cluster/cluster_publish.rs @@ -113,7 +113,6 @@ pub async fn publish_event_with_listener( listener: S, ) -> TardisResult { let node_id = local_node_id().await.to_string(); - dbg!(&node_id); let event = event.into(); let target = target.into(); let target_debug = format!("{target:?}"); From 2cf8df2d391deed43e00cae02c290051ca145b20 Mon Sep 17 00:00:00 2001 From: 4t145 Date: Sat, 2 Dec 2023 22:27:56 +0800 Subject: [PATCH 2/4] add unit test and remove dbg macros --- tardis/src/cluster/cluster_broadcast.rs | 4 +- tardis/src/cluster/cluster_hashmap.rs | 47 +++++++++--- tardis/src/cluster/cluster_processor.rs | 19 ++--- tardis/src/cluster/cluster_receive.rs | 22 +++--- tardis/src/lib.rs | 2 +- tardis/src/utils/tardis_static.rs | 8 +- tardis/src/web/ws_client.rs | 26 ++++--- tardis/tests/test_cluster.rs | 98 ++++++++++++++++++++++++- 8 files changed, 173 insertions(+), 53 deletions(-) diff --git a/tardis/src/cluster/cluster_broadcast.rs b/tardis/src/cluster/cluster_broadcast.rs index 664b32af..8aa28fdd 100644 --- a/tardis/src/cluster/cluster_broadcast.rs +++ b/tardis/src/cluster/cluster_broadcast.rs @@ -45,7 +45,7 @@ where ident: ident.into(), local_broadcast_channel: sender, }); - + tracing::trace!("[Tardis.Cluster] create broadcast channel: {}", cluster_chan.event_name()); let subscriber = BroadcastChannelSubscriber { channel: Arc::downgrade(&cluster_chan), event_name: cluster_chan.event_name(), @@ -97,7 +97,7 @@ where async fn subscribe(&self, message_req: TardisClusterMessageReq) -> TardisResult> { if let Ok(message) = serde_json::from_value(message_req.msg) { if let Some(chan) = self.channel.upgrade() { - let _ = chan.send(message); + let _ = chan.local_broadcast_channel.send(message); } else { unsubscribe(&self.event_name()).await; } diff --git a/tardis/src/cluster/cluster_hashmap.rs b/tardis/src/cluster/cluster_hashmap.rs index ab833a3c..61e8da92 100644 --- a/tardis/src/cluster/cluster_hashmap.rs +++ b/tardis/src/cluster/cluster_hashmap.rs @@ -22,6 +22,7 @@ use super::{ pub struct ClusterStaticHashMap { pub map: Arc>>, pub ident: &'static str, + pub cluster_sync: bool, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -40,8 +41,19 @@ where Self { map: Arc::new(RwLock::new(HashMap::new())), ident, + cluster_sync: true, } } + pub fn new_standalone(ident: &'static str) -> Self { + Self { + map: Arc::new(RwLock::new(HashMap::new())), + ident, + cluster_sync: false, + } + } + pub fn is_cluster(&self) -> bool { + self.cluster_sync + } pub fn event_name(&self) -> String { format!("tardis/hashmap/{ident}", ident = self.ident) } @@ -50,9 +62,11 @@ where } pub async fn insert(&self, key: K, value: V) -> TardisResult<()> { self.map.write().await.insert(key.clone(), value.clone()); - let event = CshmEvent::::Insert(vec![(key, value)]); - let json = TardisJson.obj_to_json(&event)?; - let _result = publish_event_no_response(self.event_name(), json, ClusterEventTarget::Broadcast).await; + if self.is_cluster() { + let event = CshmEvent::::Insert(vec![(key, value)]); + let json = TardisJson.obj_to_json(&event)?; + let _result = publish_event_no_response(self.event_name(), json, ClusterEventTarget::Broadcast).await; + } Ok(()) } pub async fn batch_insert(&self, pairs: Vec<(K, V)>) -> TardisResult<()> { @@ -62,16 +76,20 @@ where wg.insert(key.clone(), value.clone()); } } - let event = CshmEvent::::Insert(pairs); - let json = TardisJson.obj_to_json(&event)?; - let _result = publish_event_no_response(self.event_name(), json, ClusterEventTarget::Broadcast).await; + if self.is_cluster() { + let event = CshmEvent::::Insert(pairs); + let json = TardisJson.obj_to_json(&event)?; + let _result = publish_event_no_response(self.event_name(), json, ClusterEventTarget::Broadcast).await; + } Ok(()) } pub async fn remove(&self, key: K) -> TardisResult<()> { self.map.write().await.remove(&key); - let event = CshmEvent::::Remove { keys: vec![key] }; - let json = TardisJson.obj_to_json(&event)?; - let _result = publish_event_no_response(self.event_name(), json, ClusterEventTarget::Broadcast).await; + if self.is_cluster() { + let event = CshmEvent::::Remove { keys: vec![key] }; + let json = TardisJson.obj_to_json(&event)?; + let _result = publish_event_no_response(self.event_name(), json, ClusterEventTarget::Broadcast).await; + } Ok(()) } pub async fn batch_remove(&self, keys: Vec) -> TardisResult<()> { @@ -81,9 +99,11 @@ where wg.remove(key); } } - let event = CshmEvent::::Remove { keys }; - let json = TardisJson.obj_to_json(&event)?; - let _result = publish_event_no_response(self.event_name(), json, ClusterEventTarget::Broadcast).await; + if self.is_cluster() { + let event = CshmEvent::::Remove { keys }; + let json = TardisJson.obj_to_json(&event)?; + let _result = publish_event_no_response(self.event_name(), json, ClusterEventTarget::Broadcast).await; + } Ok(()) } pub async fn get(&self, key: K) -> TardisResult> { @@ -94,6 +114,9 @@ where } } async fn get_remote(&self, key: K) -> TardisResult> { + if !self.is_cluster() { + return Ok(None); + } let peer_count = peer_count().await; if peer_count == 0 { return Ok(None); diff --git a/tardis/src/cluster/cluster_processor.rs b/tardis/src/cluster/cluster_processor.rs index e0c1208d..f76b344e 100644 --- a/tardis/src/cluster/cluster_processor.rs +++ b/tardis/src/cluster/cluster_processor.rs @@ -28,15 +28,14 @@ use async_trait::async_trait; pub const CLUSTER_NODE_WHOAMI: &str = "__cluster_node_who_am_i__"; pub const EVENT_PING: &str = "tardis/ping"; pub const CLUSTER_MESSAGE_CACHE_SIZE: usize = 10000; -pub const WHOIAM_TIMEOUT: Duration = Duration::from_secs(30); +pub const WHOAMI_TIMEOUT: Duration = Duration::from_secs(30); type StaticCowStr = Cow<'static, str>; -// static LOCAL_NODE_ID_SETTER: OnceLock = OnceLock::new(); -// static LOCAL_SOCKET_ADDR: OnceLock = OnceLock::new(); + tardis_static! { pub async set local_socket_addr: SocketAddr; pub async set local_node_id: String; - pub async set responsor_dispatcher: mpsc::Sender; + pub async set responser_dispatcher: mpsc::Sender; pub(crate) cache_nodes: Arc>>; subscribers: Arc>>>; } @@ -190,8 +189,8 @@ async fn init_node(cluster_server: &TardisWebServer, access_addr: SocketAddr) -> info!("[Tardis.Cluster] Initializing node"); set_local_node_id(TardisFuns::field.nanoid()); set_local_socket_addr(access_addr); - debug!("[Tardis.Cluster] Initializing response dispathcer"); - set_responsor_dispatcher(init_response_dispatcher()); + debug!("[Tardis.Cluster] Initializing response dispatcher"); + set_responser_dispatcher(init_response_dispatcher()); debug!("[Tardis.Cluster] Register exchange route"); cluster_server.add_route(ClusterAPI).await; @@ -240,7 +239,9 @@ pub async fn refresh_nodes(active_nodes: &HashSet) -> TardisResult<( let mut table = String::new(); for (k, v) in cache_nodes.iter() { use std::fmt::Write; - writeln!(&mut table, "{k:20} | {v:40} ").expect("shouldn't fail"); + if matches!(k, ClusterRemoteNodeKey::NodeId(_)) { + writeln!(&mut table, "{k:20} | {v:40} ").expect("shouldn't fail"); + } } trace!("[Tardis.Cluster] cache nodes table \n{table}"); Ok(()) @@ -259,7 +260,7 @@ async fn add_remote_node(socket_addr: SocketAddr) -> TardisResult(&message) { Ok(message_resp) => { - if let Err(error) = responsor_dispatcher().await.send(message_resp).await { + if let Err(error) = responser_dispatcher().await.send(message_resp).await { error!("[Tardis.Cluster] [Client] response message {message}: {error}"); } } @@ -270,7 +271,7 @@ async fn add_remote_node(socket_addr: SocketAddr) -> TardisResult bool + Send + Sync>), } tardis_static! { - responsor_subscribers: RwLock>; + responser_subscribers: RwLock>; } pub async fn listen_reply(strategy: S, id: String) -> S::Reply { @@ -33,11 +33,11 @@ pub(crate) fn init_response_dispatcher() -> mpsc::Sender { tokio::spawn(async move { - if let Some(ResponseFn::Once(f)) = responsor_subscribers().write().await.remove(&id) { + if let Some(ResponseFn::Once(f)) = responser_subscribers().write().await.remove(&id) { f(resp) } }); @@ -46,7 +46,7 @@ pub(crate) fn init_response_dispatcher() -> mpsc::Sender Config { + async fn retrieve_config() -> Config { config().clone() } tardis_static! { async async_config: Config = async { wait_other_async().await; - retrive_config().await + retrieve_config().await }; } tardis_static! { - config_defualt: Config; + config_default: Config; } } diff --git a/tardis/src/web/ws_client.rs b/tardis/src/web/ws_client.rs index c7f9451f..06a8e189 100644 --- a/tardis/src/web/ws_client.rs +++ b/tardis/src/web/ws_client.rs @@ -101,19 +101,23 @@ impl TardisWSClient { // let reply = ws_tx.clone(); // let (outbound_tx, mut outbound_rx) = mpsc::unbounded_channel::(); - let (outbound_quene_tx, mut outbound_quene_rx) = mpsc::unbounded_channel::(); + let (outbound_queue_tx, mut outbound_queue_rx) = mpsc::unbounded_channel::(); - // there should be two quene: - // 1. out to client quene - // 2. client to remote quene + // there should be two queue: + // 1. out to client queue + // 2. client to remote queue // outbound side let ob_handle = { let url = url.clone(); tokio::spawn(async move { - while let Some(message) = outbound_quene_rx.recv().await { + while let Some(message) = outbound_queue_rx.recv().await { if let Err(e) = ws_tx.send(message).await { - debug!("[Tardis.WSClient] client: {url} error when send to websocket: {e}") + warn!("[Tardis.WSClient] client: {url} error when send to websocket: {e}"); + match e { + tokio_tungstenite::tungstenite::Error::ConnectionClosed | tokio_tungstenite::tungstenite::Error::AlreadyClosed => break, + _ => {} + } // websocket was closed } } @@ -124,7 +128,7 @@ impl TardisWSClient { let ib_handle = { let on_message = on_message.clone(); - let outbound_quene_tx = outbound_quene_tx.clone(); + let outbound_queue_tx = outbound_queue_tx.clone(); let url = url.clone(); tokio::spawn(async move { // stream would be owned by one single task and @@ -135,13 +139,13 @@ impl TardisWSClient { Ok(message) => { trace!("[Tardis.WSClient] WS receive: {}", message); let fut_response = on_message(message); - let outbound_quene_tx = outbound_quene_tx.clone(); + let outbound_queue_tx = outbound_queue_tx.clone(); let url = url.clone(); tokio::spawn(async move { if let Some(resp) = fut_response.await { trace!("[Tardis.WSClient] WS send: {}", resp); - if let Err(e) = outbound_quene_tx.send(resp) { - debug!("[Tardis.WSClient] client: {url} error when send to outbound message quene: {e}") + if let Err(e) = outbound_queue_tx.send(resp) { + debug!("[Tardis.WSClient] client: {url} error when send to outbound message queue: {e}") // outbound channel was closed } } @@ -163,7 +167,7 @@ impl TardisWSClient { drop(permit) }); - Ok(outbound_quene_tx) + Ok(outbound_queue_tx) } pub async fn send_obj(&self, msg: &E) -> TardisResult<()> { diff --git a/tardis/tests/test_cluster.rs b/tardis/tests/test_cluster.rs index eba78807..badc438a 100644 --- a/tardis/tests/test_cluster.rs +++ b/tardis/tests/test_cluster.rs @@ -2,7 +2,10 @@ use std::{ borrow::Cow, env, path::Path, - sync::atomic::{AtomicUsize, Ordering}, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, time::Duration, }; @@ -12,11 +15,14 @@ use serde_json::{json, Value}; use tardis::{ basic::{result::TardisResult, tracing::TardisTracing}, cluster::{ - cluster_processor::{self, ClusterEventTarget, TardisClusterMessageReq, TardisClusterSubscriber}, + cluster_broadcast::ClusterBroadcastChannel, + cluster_hashmap::ClusterStaticHashMap, + cluster_processor::{self, subscribe, ClusterEventTarget, TardisClusterMessageReq, TardisClusterSubscriber}, cluster_publish::publish_event_one_response, }, config::config_dto::{CacheModuleConfig, ClusterConfig, FrameworkConfig, LogConfig, TardisConfig, WebServerCommonConfig, WebServerConfig, WebServerModuleConfig}, consts::IP_LOCALHOST, + tardis_static, test::test_container::TardisTestContainer, TardisFuns, }; @@ -106,6 +112,9 @@ async fn invoke_node(cluster_url: &str, node_id: &str, program: &Path) -> Tardis } async fn start_node(cluster_url: String, node_id: &str) -> TardisResult<()> { + subscribe(map().clone()).await; + // subscribe + broadcast(); cluster_processor::set_local_node_id(format!("node_{node_id}")); let port = portpicker::pick_unused_port().unwrap(); TardisTracing::initializer().with_fmt_layer().with_env_layer().init(); @@ -143,9 +152,25 @@ async fn start_node(cluster_url: String, node_id: &str) -> TardisResult<()> { } TardisFuns::web_server().start().await?; sleep(Duration::from_secs(1)).await; - + let task = { + let node_id = node_id.to_string(); + { + let node_id = node_id.to_string(); + tokio::spawn(async move { + let mut receiver = broadcast().subscribe(); + while let Ok(msg) = receiver.recv().await { + println!("node[{node_id}]/broadcast: {msg}"); + bc_recv_count().fetch_add(1, Ordering::SeqCst); + } + }); + } + tokio::spawn(async move { + test_broadcast(&node_id).await; + }) + }; test_ping(node_id).await?; test_echo(node_id).await?; + test_hash_map(node_id).await?; if node_id == "1" { sleep(Duration::from_secs(1)).await; @@ -154,6 +179,8 @@ async fn start_node(cluster_url: String, node_id: &str) -> TardisResult<()> { } else { sleep(Duration::from_secs(10)).await; } + let result = tokio::join!(task); + result.0.unwrap(); Ok(()) } @@ -227,3 +254,68 @@ async fn test_echo(node_id: &str) -> TardisResult<()> { } Ok(()) } + +tardis_static! { + pub map: ClusterStaticHashMap = ClusterStaticHashMap::new("test"); + broadcast: Arc> = ClusterBroadcastChannel::new("test_channel", 100); + bc_recv_count: AtomicUsize = AtomicUsize::new(0); +} +async fn test_hash_map(node_id: &str) -> TardisResult<()> { + match node_id { + "1" => { + map().insert("item1".to_string(), "from_node1".to_string()).await?; + let value = map().get("item1".to_string()).await?; + assert_eq!(value, Some("from_node1".to_string())); + } + "2" => { + map().insert("item2".to_string(), "from_node2".to_string()).await?; + loop { + tokio::time::sleep(Duration::from_secs(1)).await; + let value = map().get("item1".to_string()).await?; + if value.is_some() { + assert_eq!(value, Some("from_node1".to_string())); + break; + } + } + let value = map().get("item2".to_string()).await?; + assert_eq!(value, Some("from_node2".to_string())); + tokio::time::sleep(Duration::from_secs(5)).await; + map().remove("item2".to_string()).await?; + let value = map().get("item2".to_string()).await?; + assert_eq!(value, None); + } + "3" => {} + _ => {} + } + Ok(()) +} + +async fn test_broadcast(node_id: &str) { + tokio::time::sleep(Duration::from_secs(6)).await; + match node_id { + "1" => { + broadcast().send("message1-1".to_string()); + broadcast().send("message1-2".to_string()); + } + "2" => { + broadcast().send("message2-1".to_string()); + broadcast().send("message2-2".to_string()); + } + "3" => { + broadcast().send("message3-1".to_string()); + broadcast().send("message3-2".to_string()); + } + _ => {} + } + let result = tokio::time::timeout(Duration::from_secs(20), async move { + loop { + if bc_recv_count().load(Ordering::SeqCst) == 6 { + break; + } else { + tokio::task::yield_now().await; + } + } + }) + .await; + assert!(result.is_ok()); +} From 7a9babdeca59be96809484a0fbc8c8edbe57374f Mon Sep 17 00:00:00 2001 From: 4t145 Date: Mon, 4 Dec 2023 16:49:13 +0800 Subject: [PATCH 3/4] clippy and fix --- tardis/src/cluster/cluster_hashmap.rs | 75 ++++++++++++++++++- tardis/src/cluster/cluster_processor.rs | 5 +- tardis/src/web/ws_processor.rs | 51 ++++++------- .../src/web/ws_processor/cluster_protocol.rs | 45 +---------- 4 files changed, 103 insertions(+), 73 deletions(-) diff --git a/tardis/src/cluster/cluster_hashmap.rs b/tardis/src/cluster/cluster_hashmap.rs index 61e8da92..34818e13 100644 --- a/tardis/src/cluster/cluster_hashmap.rs +++ b/tardis/src/cluster/cluster_hashmap.rs @@ -1,6 +1,7 @@ use std::{ borrow::Cow, collections::HashMap, + fmt, sync::Arc, time::{Duration, Instant}, }; @@ -18,11 +19,18 @@ use super::{ }; // Cshm = ClusterStaticHashMap -#[derive(Debug, Clone)] +#[derive(Clone)] pub struct ClusterStaticHashMap { pub map: Arc>>, pub ident: &'static str, pub cluster_sync: bool, + pub modify_handler: Arc>>, +} + +impl fmt::Debug for ClusterStaticHashMap { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ClusterStaticHashMap").field("ident", &self.ident).field("cluster_sync", &self.cluster_sync).finish() + } } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -30,6 +38,41 @@ enum CshmEvent { Insert(Vec<(K, V)>), Remove { keys: Vec }, Get { key: K }, + Modify { key: K, mapper: String, modify: Value }, +} + +pub struct ClusterStaticHashMapBuilder { + ident: &'static str, + cluster_sync: bool, + modify_handler: HashMap>, + _phantom: std::marker::PhantomData<(K, V)>, +} + +impl ClusterStaticHashMapBuilder { + pub fn new(ident: &'static str) -> Self { + Self { + ident, + cluster_sync: true, + modify_handler: HashMap::new(), + _phantom: std::marker::PhantomData, + } + } + pub fn sync(mut self, cluster_sync: bool) -> Self { + self.cluster_sync = cluster_sync; + self + } + pub fn modify_handler(mut self, mapper: &'static str, handler: impl Fn(&mut V, &Value) + Send + Sync + 'static) -> Self { + self.modify_handler.insert(mapper.to_string(), Box::new(handler)); + self + } + pub fn build(self) -> ClusterStaticHashMap { + ClusterStaticHashMap { + map: Arc::new(RwLock::new(HashMap::new())), + ident: self.ident, + cluster_sync: self.cluster_sync, + modify_handler: Arc::new(self.modify_handler), + } + } } impl ClusterStaticHashMap @@ -37,11 +80,16 @@ where K: Send + Sync + 'static + Clone + serde::Serialize + serde::de::DeserializeOwned + Hash + Eq, V: Send + Sync + 'static + Clone + serde::Serialize + serde::de::DeserializeOwned, { + pub fn builder(ident: &'static str) -> ClusterStaticHashMapBuilder { + ClusterStaticHashMapBuilder::new(ident) + } + pub fn new(ident: &'static str) -> Self { Self { map: Arc::new(RwLock::new(HashMap::new())), ident, cluster_sync: true, + modify_handler: Arc::new(HashMap::new()), } } pub fn new_standalone(ident: &'static str) -> Self { @@ -49,6 +97,7 @@ where map: Arc::new(RwLock::new(HashMap::new())), ident, cluster_sync: false, + modify_handler: Arc::new(HashMap::new()), } } pub fn is_cluster(&self) -> bool { @@ -148,6 +197,21 @@ where } Ok(None) } + pub async fn modify(&self, key: K, mapper: &'static str, modify: Value) -> TardisResult<()> { + let mapper = mapper.to_string(); + let mut wg = self.map.write().await; + if let Some(v) = wg.get_mut(&key) { + if let Some(handler) = self.modify_handler.get(&mapper) { + handler(v, &modify); + } + } + if self.is_cluster() { + let event = CshmEvent::::Modify { key, mapper, modify }; + let json = TardisJson.obj_to_json(&event)?; + let _result = publish_event_no_response(self.event_name(), json, ClusterEventTarget::Broadcast).await; + } + Ok(()) + } } #[async_trait::async_trait] @@ -178,6 +242,15 @@ where let value = rg.get(&key); Ok(Some(TardisJson.obj_to_json(&value)?)) } + CshmEvent::Modify { key, mapper, modify } => { + let mut wg = self.map.write().await; + if let Some(v) = wg.get_mut(&key) { + if let Some(handler) = self.modify_handler.get(&mapper) { + handler(v, &modify); + } + } + Ok(None) + } } } fn event_name(&self) -> Cow<'static, str> { diff --git a/tardis/src/cluster/cluster_processor.rs b/tardis/src/cluster/cluster_processor.rs index f76b344e..d497813e 100644 --- a/tardis/src/cluster/cluster_processor.rs +++ b/tardis/src/cluster/cluster_processor.rs @@ -21,7 +21,8 @@ use crate::config::config_dto::FrameworkConfig; use crate::tardis_static; use crate::web::web_server::TardisWebServer; use crate::web::ws_client::TardisWSClient; -use crate::web::ws_processor::cluster_protocol::Avatar; +use crate::web::ws_processor::ws_insts_mapping_avatars; +// use crate::web::ws_processor::cluster_protocol::Avatar; use crate::{basic::result::TardisResult, TardisFuns}; use async_trait::async_trait; @@ -198,7 +199,7 @@ async fn init_node(cluster_server: &TardisWebServer, access_addr: SocketAddr) -> subscribe(EventPing).await; #[cfg(feature = "web-server")] { - subscribe(Avatar).await; + subscribe(ws_insts_mapping_avatars().clone()).await; } info!("[Tardis.Cluster] Initialized node"); diff --git a/tardis/src/web/ws_processor.rs b/tardis/src/web/ws_processor.rs index 228b25fd..a834d966 100644 --- a/tardis/src/web/ws_processor.rs +++ b/tardis/src/web/ws_processor.rs @@ -9,10 +9,11 @@ use lru::LruCache; use poem::web::websocket::{BoxWebSocketUpgraded, CloseCode, Message, WebSocket}; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; -use tokio::sync::{Mutex, RwLock}; +use tokio::sync::Mutex; use tracing::trace; use tracing::warn; +use crate::cluster::cluster_hashmap::ClusterStaticHashMap; use crate::{tardis_static, TardisFuns}; pub const WS_SYSTEM_EVENT_INFO: &str = "__sys_info__"; @@ -27,7 +28,19 @@ pub const WS_SENDER_CACHE_SIZE: NonZeroUsize = unsafe { NonZeroUsize::new_unchec tardis_static! { // Websocket instance Id -> Avatars - ws_insts_mapping_avatars: Arc>>>; + // ws_insts_mapping_avatars: Arc>>>; + pub ws_insts_mapping_avatars: ClusterStaticHashMap> = ClusterStaticHashMap::>::builder("tardis/avatar") + .modify_handler("del_avatar", |v, modify| { + if let Some(del) = modify.as_str() { + v.retain(|value| *value != del); + } + }) + .modify_handler("add_avatar", |v, modify| { + if let Some(add) = modify.as_str() { + v.push(add.to_string() ); + } + }) + .build(); } lazy_static! { // Single instance reply guard @@ -141,12 +154,10 @@ where let mut inner_receiver = inner_sender.subscribe(); websocket .on_upgrade(move |socket| async move { - // corresponed to the current ws connection + // corresponded to the current ws connection let inst_id = TardisFuns::field.nanoid(); let current_receive_inst_id = inst_id.clone(); - { - ws_insts_mapping_avatars().write().await.insert(inst_id.clone(), avatars); - } + let _ = ws_insts_mapping_avatars().insert(inst_id.clone(), avatars).await; let (mut ws_sink, mut ws_stream) = socket.split(); let insts_in_send = ws_insts_mapping_avatars().clone(); @@ -156,7 +167,7 @@ where match message { Message::Text(text) => { let msg_id = TardisFuns::field.nanoid(); - let Some(current_avatars) = insts_in_send.read().await.get(&inst_id).cloned() else { + let Ok(Some(current_avatars)) = insts_in_send.get(inst_id.clone()).await else { warn!("[Tardis.WebServer] insts_in_send of inst_id {inst_id} not found"); continue; }; @@ -228,33 +239,18 @@ where ws_send_error_to_channel(&text, "spec_inst_id is not specified", &avatar_self, &inst_id, &inner_sender); continue; }; - let mut write_locked = insts_in_send.write().await; - let Some(inst) = write_locked.get_mut(&spec_inst_id) else { - ws_send_error_to_channel(&text, "spec_inst_id not found", &avatar_self, &inst_id, &inner_sender); - continue; - }; - inst.push(new_avatar.to_string()); - drop(write_locked); - trace!("[Tardis.WebServer] WS message add avatar {}:{} to {}", msg_id, new_avatar, spec_inst_id); - + trace!("[Tardis.WebServer] WS message add avatar {}:{} to {}", msg_id, &new_avatar, &spec_inst_id); + let _ = insts_in_send.modify(spec_inst_id, "add_avatar", json!(new_avatar)).await; continue; } else if req_msg.event == Some(WS_SYSTEM_EVENT_AVATAR_DEL.to_string()) { let Some(del_avatar) = req_msg.msg.as_str() else { ws_send_error_to_channel(&text, "msg is not a string", &avatar_self, &inst_id, &inner_sender); continue; }; - let mut write_locked = insts_in_send.write().await; - let Some(inst) = write_locked.get_mut(&inst_id) else { - ws_send_error_to_channel(&text, "spec_inst_id not found", &avatar_self, &inst_id, &inner_sender); - continue; - }; - inst.retain(|value| *value != del_avatar); - drop(write_locked); + let _ = insts_in_send.modify(inst_id.clone(), "del_avatar", json!(del_avatar)).await; trace!("[Tardis.WebServer] WS message delete avatar {},{} to {}", msg_id, del_avatar, &inst_id); continue; } - - // Normal process if let Some(resp_msg) = process_fun(req_msg.clone(), ext.clone()).await { trace!( "[Tardis.WebServer] WS message send to channel: {},{} to {:?} ignore {:?}", @@ -275,7 +271,7 @@ where echo: false, }; ws_send_to_channel(send_msg, &inner_sender); - } + }; } } } @@ -297,11 +293,10 @@ where }); let reply_once_guard = REPLY_ONCE_GUARD.clone(); - let insts_in_receive = ws_insts_mapping_avatars().clone(); tokio::spawn(async move { while let Ok(mgr_message) = inner_receiver.recv().await { - let Some(current_avatars) = ({ insts_in_receive.read().await.get(¤t_receive_inst_id).cloned() }) else { + let Ok(Some(current_avatars)) = ({ ws_insts_mapping_avatars().get(current_receive_inst_id.clone()).await }) else { warn!("[Tardis.WebServer] Instance id {current_receive_inst_id} not found"); continue; }; diff --git a/tardis/src/web/ws_processor/cluster_protocol.rs b/tardis/src/web/ws_processor/cluster_protocol.rs index da6c81ad..9109a33f 100644 --- a/tardis/src/web/ws_processor/cluster_protocol.rs +++ b/tardis/src/web/ws_processor/cluster_protocol.rs @@ -1,47 +1,8 @@ -use serde::{Deserialize, Serialize}; -use serde_json::Value; -use std::{borrow::Cow, collections::HashMap, sync::Arc}; +use std::sync::Arc; -use crate::{ - basic::result::TardisResult, - cluster::{ - cluster_broadcast::ClusterBroadcastChannel, - cluster_processor::{TardisClusterMessageReq, TardisClusterSubscriber}, - }, -}; +use crate::cluster::cluster_broadcast::ClusterBroadcastChannel; -use super::{ws_insts_mapping_avatars, TardisWebsocketMgrMessage, WsBroadcastSender}; - -pub const EVENT_AVATAR: &str = "tardis/avatar"; - -pub(crate) struct Avatar; - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub(crate) enum AvatarMessage { - Sync { table: HashMap> }, -} - -#[async_trait::async_trait] -impl TardisClusterSubscriber for Avatar { - fn event_name(&self) -> Cow<'static, str> { - EVENT_AVATAR.into() - } - - async fn subscribe(&self, message_req: TardisClusterMessageReq) -> TardisResult> { - // let from_node = message_req.req_node_id; - if let Ok(message) = serde_json::from_value(message_req.msg) { - match message { - AvatarMessage::Sync { table } => { - let mut routes = ws_insts_mapping_avatars().write().await; - for (k, v) in table { - routes.insert(k, v); - } - } - } - } - Ok(None) - } -} +use super::{TardisWebsocketMgrMessage, WsBroadcastSender}; impl WsBroadcastSender for ClusterBroadcastChannel { fn subscribe(&self) -> tokio::sync::broadcast::Receiver { From d8c3f2f0e69bc1a3a5ef328a3f1af737511f8441 Mon Sep 17 00:00:00 2001 From: 4t145 Date: Mon, 4 Dec 2023 18:23:03 +0800 Subject: [PATCH 4/4] fix test --- tardis/src/web/ws_processor.rs | 15 ++++++++++++--- tardis/tests/test_websocket.rs | 6 ++++-- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/tardis/src/web/ws_processor.rs b/tardis/src/web/ws_processor.rs index a834d966..f922fe58 100644 --- a/tardis/src/web/ws_processor.rs +++ b/tardis/src/web/ws_processor.rs @@ -10,8 +10,8 @@ use poem::web::websocket::{BoxWebSocketUpgraded, CloseCode, Message, WebSocket}; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; use tokio::sync::Mutex; -use tracing::trace; use tracing::warn; +use tracing::{debug, trace}; use crate::cluster::cluster_hashmap::ClusterStaticHashMap; use crate::{tardis_static, TardisFuns}; @@ -161,6 +161,7 @@ where let (mut ws_sink, mut ws_stream) = socket.split(); let insts_in_send = ws_insts_mapping_avatars().clone(); + debug!("[Tardis.WebServer] WS message receive: new connection {inst_id}"); tokio::spawn(async move { // message inbound while let Some(Ok(message)) = ws_stream.next().await { @@ -184,7 +185,7 @@ where }; match TardisFuns::json.str_to_obj::(&text) { Err(_) => { - ws_send_error_to_channel(&text, "message not illegal", &avatar_self, &inst_id, &inner_sender); + ws_send_error_to_channel(&text, "message illegal", &avatar_self, &inst_id, &inner_sender); break; } Ok(req_msg) => { @@ -211,7 +212,7 @@ where "[Tardis.WebServer] can't serialize {struct_name}, error: {error}", struct_name = stringify!(TardisWebsocketInstInfo) ); - ws_send_error_to_channel(&text, "message not illegal", &avatar_self, &inst_id, &inner_sender); + ws_send_error_to_channel(&text, "message illegal", &avatar_self, &inst_id, &inner_sender); }) else { break; @@ -239,10 +240,18 @@ where ws_send_error_to_channel(&text, "spec_inst_id is not specified", &avatar_self, &inst_id, &inner_sender); continue; }; + let Ok(Some(_)) = insts_in_send.get(spec_inst_id.clone()).await else { + ws_send_error_to_channel(&text, "spec_inst_id not found", &avatar_self, &inst_id, &inner_sender); + continue; + }; trace!("[Tardis.WebServer] WS message add avatar {}:{} to {}", msg_id, &new_avatar, &spec_inst_id); let _ = insts_in_send.modify(spec_inst_id, "add_avatar", json!(new_avatar)).await; continue; } else if req_msg.event == Some(WS_SYSTEM_EVENT_AVATAR_DEL.to_string()) { + let Ok(Some(_)) = insts_in_send.get(inst_id.clone()).await else { + ws_send_error_to_channel(&text, "spec_inst_id not found", &avatar_self, &inst_id, &inner_sender); + continue; + }; let Some(del_avatar) = req_msg.msg.as_str() else { ws_send_error_to_channel(&text, "msg is not a string", &avatar_self, &inst_id, &inner_sender); continue; diff --git a/tardis/tests/test_websocket.rs b/tardis/tests/test_websocket.rs index 99049f71..c719cac2 100644 --- a/tardis/tests/test_websocket.rs +++ b/tardis/tests/test_websocket.rs @@ -10,6 +10,7 @@ use poem::web::websocket::{BoxWebSocketUpgraded, WebSocket}; use poem_openapi::param::Path; use serde_json::json; use tardis::basic::result::TardisResult; +use tardis::cluster::cluster_processor::set_local_node_id; use tardis::consts::IP_LOCALHOST; use tardis::web::web_server::{TardisWebServer, WebServerModule}; use tardis::web::ws_client::TardisWebSocketMessageExt; @@ -31,6 +32,7 @@ lazy_static! { async fn test_websocket() -> TardisResult<()> { env::set_var("RUST_LOG", "info,tardis=trace"); TardisFuns::init_log()?; + set_local_node_id("test".into()); let serv = TardisWebServer::init_simple(IP_LOCALHOST, 8080).unwrap(); serv.add_route(WebServerModule::from(Api).with_ws(100)).await; serv.start().await?; @@ -47,11 +49,11 @@ async fn test_normal() -> TardisResult<()> { static SUB_COUNTER: AtomicUsize = AtomicUsize::new(0); static NON_SUB_COUNTER: AtomicUsize = AtomicUsize::new(0); - // message not illegal test + // message illegal test let error_client_a = TardisFuns::ws_client("ws://127.0.0.1:8080/ws/broadcast/gerror/a", move |msg| async move { if let Message::Text(msg) = msg { println!("client_not_found recv:{}", msg); - assert_eq!(msg, r#"{"msg":"message not illegal","event":"__sys_error__"}"#); + assert_eq!(msg, r#"{"msg":"message illegal","event":"__sys_error__"}"#); ERROR_COUNTER.fetch_add(1, Ordering::SeqCst); } None