From d3cbc3e8b170e9932ad5390f34db46e872ecbe0b Mon Sep 17 00:00:00 2001 From: Hendrik Date: Thu, 30 Nov 2023 21:44:00 +0100 Subject: [PATCH] Extract authentication into function --- src/presentation/talks_ws.rs | 113 ++++++++++++++++++++--------------- 1 file changed, 65 insertions(+), 48 deletions(-) diff --git a/src/presentation/talks_ws.rs b/src/presentation/talks_ws.rs index 546adf8..69d67e2 100644 --- a/src/presentation/talks_ws.rs +++ b/src/presentation/talks_ws.rs @@ -65,7 +65,66 @@ async fn talks_ws_connection( ) -> Result<(), String> { let mut updates = services.talks.register_for_updates(); - let response = match receive(&mut socket).await? { + let (user_id, capabilities) = authenticate(&mut socket, &services).await?; + + for talk in services + .talks + .get_all_talks() + .await + .map_err(|error| error.to_string())? + { + send(&mut socket, &Update::AddTalk(talk)).await?; + } + + loop { + select! { + command = receive(&mut socket) => { + let command = command?; + services + .talks + .trigger(user_id, &capabilities, command) + .await + .map_err(|error| error.to_string())?; + }, + update = updates.recv() => { + let update = update.map_err(|error| error.to_string())?; + send(&mut socket, &update).await?; + }, + } + } +} + +async fn send(socket: &mut WebSocket, message: &impl Serialize) -> Result<(), String> { + socket + .send(Message::Text(to_string(message).unwrap())) + .await + .map_err(|error| error.to_string()) +} + +async fn receive(socket: &mut WebSocket) -> Result { + let Message::Text(message) = socket + .recv() + .await + .ok_or_else(|| "closed".to_string())? + .map_err(|error| error.to_string())? + else { + return Err("expected text message".to_string()); + }; + from_str(&message).map_err(|error| error.to_string()) +} + +async fn authenticate( + socket: &mut WebSocket, + services: &Arc< + Services< + impl AuthenticationService + Send + Sync, + impl CalendarService + Send + Sync, + impl TalksService + Send + Sync, + impl TeamsService + Send + Sync, + >, + >, +) -> Result<(i64, HashSet), String> { + let response = match receive(socket).await? { AuthenticationCommand::Register { name, team, @@ -101,13 +160,14 @@ async fn talks_ws_connection( Response::UnknownToken => "unknown token", Response::Success { .. } => panic!("should not be reached"), }; - return send(&mut socket, &AuthenticationResponse::AuthenticationError { + send(socket, &AuthenticationResponse::AuthenticationError { reason: reason.to_string(), - }).await; + }).await?; + return Err(format!("authentication error: {reason}")); }; send( - &mut socket, + socket, &AuthenticationResponse::AuthenticationSuccess { user_id: user_id as usize, capabilities: capabilities.clone(), @@ -116,50 +176,7 @@ async fn talks_ws_connection( ) .await?; - for talk in services - .talks - .get_all_talks() - .await - .map_err(|error| error.to_string())? - { - send(&mut socket, &Update::AddTalk(talk)).await?; - } - - loop { - select! { - command = receive(&mut socket) => { - let command = command?; - services - .talks - .trigger(user_id, &capabilities, command) - .await - .map_err(|error| error.to_string())?; - }, - update = updates.recv() => { - let update = update.map_err(|error| error.to_string())?; - send(&mut socket, &update).await?; - }, - } - } -} - -async fn send(socket: &mut WebSocket, message: &impl Serialize) -> Result<(), String> { - socket - .send(Message::Text(to_string(message).unwrap())) - .await - .map_err(|error| error.to_string()) -} - -async fn receive(socket: &mut WebSocket) -> Result { - let Message::Text(message) = socket - .recv() - .await - .ok_or_else(|| "closed".to_string())? - .map_err(|error| error.to_string())? - else { - return Err("expected text message".to_string()); - }; - from_str(&message).map_err(|error| error.to_string()) + Ok((user_id, capabilities)) } #[derive(Clone, Debug, Deserialize)]