From 91fcf57de20cc63d13e2d11b360fab56b3265e68 Mon Sep 17 00:00:00 2001 From: Ben Cherry Date: Thu, 31 Oct 2024 14:10:26 -0700 Subject: [PATCH] RPC updates (#476) * Check RPC version * invocation data, duration * Duration * fmt * params * data --- examples/rpc/src/main.rs | 120 +++++++++--------- livekit-ffi/src/server/participant.rs | 67 +++++----- livekit-ffi/src/server/requests.rs | 2 +- livekit/src/prelude.rs | 3 +- livekit/src/room/mod.rs | 7 +- .../src/room/participant/local_participant.rs | 85 ++++++------- livekit/src/room/participant/rpc.rs | 39 ++++++ livekit/src/rtc_engine/mod.rs | 6 +- livekit/src/rtc_engine/rtc_session.rs | 6 +- 9 files changed, 180 insertions(+), 155 deletions(-) diff --git a/examples/rpc/src/main.rs b/examples/rpc/src/main.rs index 152f380e..11e0694b 100644 --- a/examples/rpc/src/main.rs +++ b/examples/rpc/src/main.rs @@ -78,71 +78,73 @@ async fn main() -> Result<(), Box> { } async fn register_receiver_methods(greeters_room: &Arc, math_genius_room: &Arc) { - greeters_room.local_participant().register_rpc_method( - "arrival".to_string(), - |_, caller_identity, payload, _| { - Box::pin(async move { - println!( - "[{}] [Greeter] Oh {} arrived and said \"{}\"", - elapsed_time(), - caller_identity, - payload - ); - sleep(Duration::from_secs(2)).await; - Ok("Welcome and have a wonderful day!".to_string()) - }) - }, - ); - - math_genius_room.local_participant().register_rpc_method("square-root".to_string(), |_, caller_identity, payload, response_timeout_ms| { + greeters_room.local_participant().register_rpc_method("arrival".to_string(), |data| { Box::pin(async move { - let json_data: Value = serde_json::from_str(&payload).unwrap(); - let number = json_data["number"].as_f64().unwrap(); println!( - "[{}] [Math Genius] I guess {} wants the square root of {}. I've only got {} seconds to respond but I think I can pull it off.", + "[{}] [Greeter] Oh {} arrived and said \"{}\"", elapsed_time(), - caller_identity, - number, - response_timeout_ms.as_secs() + data.caller_identity, + data.payload ); - - println!("[{}] [Math Genius] *doing math*…", elapsed_time()); sleep(Duration::from_secs(2)).await; - - let result = number.sqrt(); - println!("[{}] [Math Genius] Aha! It's {}", elapsed_time(), result); - Ok(json!({"result": result}).to_string()) + Ok("Welcome and have a wonderful day!".to_string()) }) }); math_genius_room.local_participant().register_rpc_method( - "divide".to_string(), - |_, caller_identity, payload, _| { + "square-root".to_string(), + |data| { Box::pin(async move { - let json_data: Value = serde_json::from_str(&payload).unwrap(); - let dividend = json_data["dividend"].as_i64().unwrap(); - let divisor = json_data["divisor"].as_i64().unwrap(); + let json_data: Value = serde_json::from_str(&data.payload).unwrap(); + let number = json_data["number"].as_f64().unwrap(); println!( - "[{}] [Math Genius] {} wants me to divide {} by {}.", + "[{}] [Math Genius] I guess {} wants the square root of {}. I've only got {} seconds to respond but I think I can pull it off.", elapsed_time(), - caller_identity, - dividend, - divisor + data.caller_identity, + number, + data.response_timeout.as_secs() ); - let result = dividend / divisor; - println!("[{}] [Math Genius] The result is {}", elapsed_time(), result); + println!("[{}] [Math Genius] *doing math*…", elapsed_time()); + sleep(Duration::from_secs(2)).await; + + let result = number.sqrt(); + println!("[{}] [Math Genius] Aha! It's {}", elapsed_time(), result); Ok(json!({"result": result}).to_string()) }) }, ); + + math_genius_room.local_participant().register_rpc_method("divide".to_string(), |data| { + Box::pin(async move { + let json_data: Value = serde_json::from_str(&data.payload).unwrap(); + let dividend = json_data["dividend"].as_i64().unwrap(); + let divisor = json_data["divisor"].as_i64().unwrap(); + println!( + "[{}] [Math Genius] {} wants me to divide {} by {}.", + elapsed_time(), + data.caller_identity, + dividend, + divisor + ); + + let result = dividend / divisor; + println!("[{}] [Math Genius] The result is {}", elapsed_time(), result); + Ok(json!({"result": result}).to_string()) + }) + }); } async fn perform_greeting(room: &Arc) -> Result<(), Box> { println!("[{}] Letting the greeter know that I've arrived", elapsed_time()); match room .local_participant() - .perform_rpc("greeter".to_string(), "arrival".to_string(), "Hello".to_string(), None) + .perform_rpc(PerformRpcData { + destination_identity: "greeter".to_string(), + method: "arrival".to_string(), + payload: "Hello".to_string(), + ..Default::default() + }) .await { Ok(response) => { @@ -157,12 +159,12 @@ async fn perform_square_root(room: &Arc) -> Result<(), Box { @@ -180,12 +182,12 @@ async fn perform_quantum_hypergeometric_series( println!("[{}] What's the quantum hypergeometric series of 42?", elapsed_time()); match room .local_participant() - .perform_rpc( - "math-genius".to_string(), - "quantum-hypergeometric-series".to_string(), - json!({"number": 42}).to_string(), - None, - ) + .perform_rpc(PerformRpcData { + destination_identity: "math-genius".to_string(), + method: "quantum-hypergeometric-series".to_string(), + payload: json!({"number": 42}).to_string(), + ..Default::default() + }) .await { Ok(response) => { @@ -207,12 +209,12 @@ async fn perform_division(room: &Arc) -> Result<(), Box { diff --git a/livekit-ffi/src/server/participant.rs b/livekit-ffi/src/server/participant.rs index fdf54bd6..afa4bade 100644 --- a/livekit-ffi/src/server/participant.rs +++ b/livekit-ffi/src/server/participant.rs @@ -51,12 +51,15 @@ impl FfiParticipant { let handle = server.async_runtime.spawn(async move { let result = local - .perform_rpc( - request.destination_identity.to_string(), - request.method, - request.payload, - request.response_timeout_ms, - ) + .perform_rpc(PerformRpcData { + destination_identity: request.destination_identity.to_string(), + method: request.method, + payload: request.payload, + response_timeout: request + .response_timeout_ms + .map(|ms| Duration::from_millis(ms as u64)) + .unwrap_or(PerformRpcData::default().response_timeout), + }) .await; let callback = proto::PerformRpcCallback { @@ -91,34 +94,27 @@ impl FfiParticipant { let local_participant_handle = self.handle.clone(); let room: Arc = self.room.clone(); - local.register_rpc_method( - method.clone(), - move |request_id, caller_identity, payload, response_timeout| { - Box::pin({ - let room = room.clone(); - let method = method.clone(); - async move { - forward_rpc_method_invocation( - server, - room, - local_participant_handle, - method, - request_id, - caller_identity, - payload, - response_timeout, - ) - .await - } - }) - }, - ); + local.register_rpc_method(method.clone(), move |data| { + Box::pin({ + let room = room.clone(); + let method = method.clone(); + async move { + forward_rpc_method_invocation( + server, + room, + local_participant_handle, + method, + data, + ) + .await + } + }) + }); Ok(proto::RegisterRpcMethodResponse {}) } pub fn unregister_rpc_method( &self, - server: &'static FfiServer, request: proto::UnregisterRpcMethodRequest, ) -> FfiResult { let local = match &self.participant { @@ -139,10 +135,7 @@ async fn forward_rpc_method_invocation( room: Arc, local_participant_handle: FfiHandleId, method: String, - request_id: String, - caller_identity: ParticipantIdentity, - payload: String, - response_timeout: Duration, + data: RpcInvocationData, ) -> Result { let (tx, rx) = oneshot::channel(); let invocation_id = server.next_id(); @@ -152,10 +145,10 @@ async fn forward_rpc_method_invocation( local_participant_handle: local_participant_handle as u64, invocation_id, method, - request_id, - caller_identity: caller_identity.into(), - payload, - response_timeout_ms: response_timeout.as_millis() as u32, + request_id: data.request_id, + caller_identity: data.caller_identity.into(), + payload: data.payload, + response_timeout_ms: data.response_timeout.as_millis() as u32, }, )); diff --git a/livekit-ffi/src/server/requests.rs b/livekit-ffi/src/server/requests.rs index 6af1fd66..5c4aa32f 100644 --- a/livekit-ffi/src/server/requests.rs +++ b/livekit-ffi/src/server/requests.rs @@ -815,7 +815,7 @@ fn on_unregister_rpc_method( ) -> FfiResult { let ffi_participant = server.retrieve_handle::(request.local_participant_handle)?.clone(); - return ffi_participant.unregister_rpc_method(server, request); + return ffi_participant.unregister_rpc_method(request); } fn on_rpc_method_invocation_response( diff --git a/livekit/src/prelude.rs b/livekit/src/prelude.rs index a7b86d62..d86e328c 100644 --- a/livekit/src/prelude.rs +++ b/livekit/src/prelude.rs @@ -15,7 +15,8 @@ pub use crate::{ id::*, participant::{ - ConnectionQuality, LocalParticipant, Participant, RemoteParticipant, RpcError, RpcErrorCode, + ConnectionQuality, LocalParticipant, Participant, PerformRpcData, RemoteParticipant, + RpcError, RpcErrorCode, RpcInvocationData, }, publication::{LocalTrackPublication, RemoteTrackPublication, TrackPublication}, track::{ diff --git a/livekit/src/room/mod.rs b/livekit/src/room/mod.rs index 28d51902..7e38cbff 100644 --- a/livekit/src/room/mod.rs +++ b/livekit/src/room/mod.rs @@ -261,7 +261,7 @@ pub struct RpcRequest { pub id: String, pub method: String, pub payload: String, - pub response_timeout_ms: u32, + pub response_timeout: Duration, pub version: u32, } @@ -689,7 +689,7 @@ impl RoomSession { request_id, method, payload, - response_timeout_ms, + response_timeout, version, } => { if caller_identity.is_none() { @@ -702,7 +702,8 @@ impl RoomSession { request_id, method, payload, - response_timeout_ms, + response_timeout, + version, ) .await; } diff --git a/livekit/src/room/participant/local_participant.rs b/livekit/src/room/participant/local_participant.rs index 8c459676..037c3170 100644 --- a/livekit/src/room/participant/local_participant.rs +++ b/livekit/src/room/participant/local_participant.rs @@ -19,7 +19,7 @@ use crate::{ e2ee::EncryptionType, options::{self, compute_video_encodings, video_layers_from_encodings, TrackPublishOptions}, prelude::*, - room::participant::rpc::{RpcError, RpcErrorCode, MAX_PAYLOAD_BYTES}, + room::participant::rpc::{RpcError, RpcErrorCode, RpcInvocationData, MAX_PAYLOAD_BYTES}, rtc_engine::{EngineError, RtcEngine}, ChatMessage, DataPacket, RpcAck, RpcRequest, RpcResponse, SipDTMF, Transcription, }; @@ -36,12 +36,7 @@ use semver::Version; use tokio::sync::oneshot; type RpcHandler = Arc< - dyn Fn( - String, // request_id - ParticipantIdentity, // caller_identity - String, // payload - Duration, // response_timeout_ms - ) -> Pin> + Send>> + dyn Fn(RpcInvocationData) -> Pin> + Send>> + Send + Sync, >; @@ -72,7 +67,6 @@ impl RpcState { } } } - struct LocalInfo { events: LocalEvents, encryption_type: EncryptionType, @@ -517,7 +511,7 @@ impl LocalParticipant { id: rpc_request.id, method: rpc_request.method, payload: rpc_request.payload, - response_timeout_ms: rpc_request.response_timeout_ms, + response_timeout_ms: rpc_request.response_timeout.as_millis() as u32, version: rpc_request.version, ..Default::default() }; @@ -643,17 +637,10 @@ impl LocalParticipant { self.inner.info.read().kind } - pub async fn perform_rpc( - &self, - destination_identity: String, - method: String, - payload: String, - response_timeout_ms: Option, - ) -> Result { - let response_timeout = Duration::from_millis(response_timeout_ms.unwrap_or(10000) as u64); + pub async fn perform_rpc(&self, data: PerformRpcData) -> Result { let max_round_trip_latency = Duration::from_millis(2000); - if payload.len() > MAX_PAYLOAD_BYTES { + if data.payload.len() > MAX_PAYLOAD_BYTES { return Err(RpcError::built_in(RpcErrorCode::RequestPayloadTooLarge, None)); } @@ -675,11 +662,11 @@ impl LocalParticipant { match self .publish_rpc_request(RpcRequest { - destination_identity: destination_identity.clone(), + destination_identity: data.destination_identity.clone(), id: id.clone(), - method: method.clone(), - payload: payload.clone(), - response_timeout_ms: (response_timeout - max_round_trip_latency).as_millis() as u32, + method: data.method.clone(), + payload: data.payload.clone(), + response_timeout: data.response_timeout - max_round_trip_latency, version: 1, }) .await @@ -709,7 +696,7 @@ impl LocalParticipant { } // Wait for response timout - let response = match tokio::time::timeout(response_timeout, response_rx).await { + let response = match tokio::time::timeout(data.response_timeout, response_rx).await { Err(_) => { self.local.rpc_state.lock().pending_responses.remove(&id); return Err(RpcError::built_in(RpcErrorCode::ResponseTimeout, None)); @@ -736,12 +723,7 @@ impl LocalParticipant { pub fn register_rpc_method( &self, method: String, - handler: impl Fn( - String, - ParticipantIdentity, - String, - Duration, - ) -> Pin> + Send>> + handler: impl Fn(RpcInvocationData) -> Pin> + Send>> + Send + Sync + 'static, @@ -785,7 +767,8 @@ impl LocalParticipant { request_id: String, method: String, payload: String, - response_timeout_ms: u32, + response_timeout: Duration, + version: u32, ) { if let Err(e) = self .publish_rpc_ack(RpcAck { @@ -797,32 +780,36 @@ impl LocalParticipant { log::error!("Failed to publish RPC ACK: {:?}", e); } - let handler = self.local.rpc_state.lock().handlers.get(&method).cloned(); - let caller_identity_2 = caller_identity.clone(); let request_id_2 = request_id.clone(); - let response = match handler { - Some(handler) => { - match tokio::task::spawn(async move { - handler( - request_id.clone(), - caller_identity.clone(), - payload.clone(), - Duration::from_millis(response_timeout_ms as u64), - ) + let response = if version != 1 { + Err(RpcError::built_in(RpcErrorCode::UnsupportedVersion, None)) + } else { + let handler = self.local.rpc_state.lock().handlers.get(&method).cloned(); + + match handler { + Some(handler) => { + match tokio::task::spawn(async move { + handler(RpcInvocationData { + request_id: request_id.clone(), + caller_identity: caller_identity.clone(), + payload: payload.clone(), + response_timeout, + }) + .await + }) .await - }) - .await - { - Ok(result) => result, - Err(e) => { - log::error!("RPC method handler returned an error: {:?}", e); - Err(RpcError::built_in(RpcErrorCode::ApplicationError, None)) + { + Ok(result) => result, + Err(e) => { + log::error!("RPC method handler returned an error: {:?}", e); + Err(RpcError::built_in(RpcErrorCode::ApplicationError, None)) + } } } + None => Err(RpcError::built_in(RpcErrorCode::UnsupportedMethod, None)), } - None => Err(RpcError::built_in(RpcErrorCode::UnsupportedMethod, None)), }; let (payload, error) = match response { diff --git a/livekit/src/room/participant/rpc.rs b/livekit/src/room/participant/rpc.rs index 3ddb0f54..34e043ef 100644 --- a/livekit/src/room/participant/rpc.rs +++ b/livekit/src/room/participant/rpc.rs @@ -2,7 +2,44 @@ // // SPDX-License-Identifier: Apache-2.0 +use crate::room::participant::ParticipantIdentity; use livekit_protocol::RpcError as RpcError_Proto; +use std::time::Duration; + +/// Parameters for performing an RPC call +#[derive(Debug, Clone)] +pub struct PerformRpcData { + pub destination_identity: String, + pub method: String, + pub payload: String, + pub response_timeout: Duration, +} + +impl Default for PerformRpcData { + fn default() -> Self { + Self { + destination_identity: Default::default(), + method: Default::default(), + payload: Default::default(), + response_timeout: Duration::from_secs(10), + } + } +} + +/// Data passed to method handler for incoming RPC invocations +/// +/// Attributes: +/// request_id (String): The unique request ID. Will match at both sides of the call, useful for debugging or logging. +/// caller_identity (ParticipantIdentity): The unique participant identity of the caller. +/// payload (String): The payload of the request. User-definable format, typically JSON. +/// response_timeout (Duration): The maximum time the caller will wait for a response. +#[derive(Debug, Clone)] +pub struct RpcInvocationData { + pub request_id: String, + pub caller_identity: ParticipantIdentity, + pub payload: String, + pub response_timeout: Duration, +} /// Specialized error handling for RPC methods. /// @@ -60,6 +97,7 @@ pub enum RpcErrorCode { RecipientNotFound = 1401, RequestPayloadTooLarge = 1402, UnsupportedServer = 1403, + UnsupportedVersion = 1404, } impl RpcErrorCode { @@ -76,6 +114,7 @@ impl RpcErrorCode { Self::RecipientNotFound => "Recipient not found", Self::RequestPayloadTooLarge => "Request payload too large", Self::UnsupportedServer => "RPC not supported by server", + Self::UnsupportedVersion => "Unsupported RPC version", } } } diff --git a/livekit/src/rtc_engine/mod.rs b/livekit/src/rtc_engine/mod.rs index d848571b..4b3c4493 100644 --- a/livekit/src/rtc_engine/mod.rs +++ b/livekit/src/rtc_engine/mod.rs @@ -118,7 +118,7 @@ pub enum EngineEvent { request_id: String, method: String, payload: String, - response_timeout_ms: u32, + response_timeout: Duration, version: u32, }, RpcResponse { @@ -487,7 +487,7 @@ impl EngineInner { request_id, method, payload, - response_timeout_ms, + response_timeout, version, } => { let _ = self.engine_tx.send(EngineEvent::RpcRequest { @@ -495,7 +495,7 @@ impl EngineInner { request_id, method, payload, - response_timeout_ms, + response_timeout, version, }); } diff --git a/livekit/src/rtc_engine/rtc_session.rs b/livekit/src/rtc_engine/rtc_session.rs index 5377be51..fa01a69f 100644 --- a/livekit/src/rtc_engine/rtc_session.rs +++ b/livekit/src/rtc_engine/rtc_session.rs @@ -101,7 +101,7 @@ pub enum SessionEvent { request_id: String, method: String, payload: String, - response_timeout_ms: u32, + response_timeout: Duration, version: u32, }, RpcResponse { @@ -689,7 +689,9 @@ impl SessionInner { request_id: rpc_request.id.clone(), method: rpc_request.method.clone(), payload: rpc_request.payload.clone(), - response_timeout_ms: rpc_request.response_timeout_ms, + response_timeout: Duration::from_millis( + rpc_request.response_timeout_ms as u64, + ), version: rpc_request.version, }); }