diff --git a/crates/rpc/src/lib.rs b/crates/rpc/src/lib.rs index 6a4a5a4cc1..b1da1705c6 100644 --- a/crates/rpc/src/lib.rs +++ b/crates/rpc/src/lib.rs @@ -895,42 +895,82 @@ mod tests { assert!(!status.is_success()); } + enum Api { + HttpOnly, + WebsocketOnly, + Both, + } + + impl Api { + fn has_websocket(&self) -> bool { + matches!(self, Self::WebsocketOnly | Self::Both) + } + + fn has_http(&self) -> bool { + matches!(self, Self::HttpOnly | Self::Both) + } + } + #[rustfmt::skip] #[rstest::rstest] - #[case::root_api ("/", "v06/starknet_api_openrpc.json", &[])] - #[case::root_trace("/", "v06/starknet_trace_api_openrpc.json", &[])] - #[case::root_write("/", "v06/starknet_write_api.json", &[])] + #[case::root_api("/", "v06/starknet_api_openrpc.json", &[], Api::HttpOnly)] + #[case::root_api_websocket("/ws", "v06/starknet_api_openrpc.json", &[], Api::WebsocketOnly)] + #[case::root_trace("/", "v06/starknet_trace_api_openrpc.json", &[], Api::HttpOnly)] + #[case::root_trace_websocket("/ws", "v06/starknet_trace_api_openrpc.json", &[], Api::WebsocketOnly)] + #[case::root_write("/", "v06/starknet_write_api.json", &[], Api::HttpOnly)] + #[case::root_write_websocket("/ws", "v06/starknet_write_api.json", &[], Api::WebsocketOnly)] // get_transaction_status is now part of the official spec, so we are phasing it out. - #[case::root_pathfinder("/", "pathfinder_rpc_api.json", &["pathfinder_version", "pathfinder_getTransactionStatus"])] - - #[case::v0_8_api("/rpc/v0_8", "v08/starknet_api_openrpc.json", &[])] - #[case::v0_8_executables("/rpc/v0_8", "v08/starknet_executables.json", &[])] - #[case::v0_8_trace("/rpc/v0_8", "v08/starknet_trace_api_openrpc.json", &[])] - #[case::v0_8_write("/rpc/v0_8", "v08/starknet_write_api.json", &[])] - + #[case::root_pathfinder("/", "pathfinder_rpc_api.json", &["pathfinder_version", "pathfinder_getTransactionStatus"], Api::HttpOnly)] + #[case::root_pathfinder_websocket("/ws", "pathfinder_rpc_api.json", &["pathfinder_version", "pathfinder_getTransactionStatus"], Api::WebsocketOnly)] + + #[case::v0_8_api("/rpc/v0_8", "v08/starknet_api_openrpc.json", &[], Api::Both)] + #[case::v0_8_executables("/rpc/v0_8", "v08/starknet_executables.json", &[], Api::Both)] + #[case::v0_8_trace("/rpc/v0_8", "v08/starknet_trace_api_openrpc.json", &[], Api::Both)] + #[case::v0_8_write("/rpc/v0_8", "v08/starknet_write_api.json", &[], Api::Both)] + #[case::v0_8_websocket( + "/rpc/v0_8", + "v08/starknet_ws_api.json", + // "starknet_subscription*" methods are in fact notifications + &[ + "starknet_subscriptionNewHeads", + "starknet_subscriptionPendingTransactions", + "starknet_subscriptionTransactionStatus", + "starknet_subscriptionEvents", + "starknet_subscriptionReorg" + ], + Api::WebsocketOnly)] // get_transaction_status is now part of the official spec, so we are phasing it out. - #[case::v0_8_pathfinder("/rpc/v0_8", "pathfinder_rpc_api.json", &["pathfinder_version", "pathfinder_getTransactionStatus"])] - - #[case::v0_7_api ("/rpc/v0_7", "v07/starknet_api_openrpc.json", &[])] - #[case::v0_7_trace("/rpc/v0_7", "v07/starknet_trace_api_openrpc.json", &[])] - #[case::v0_7_write("/rpc/v0_7", "v07/starknet_write_api.json", &[])] + #[case::v0_8_pathfinder("/rpc/v0_8", "pathfinder_rpc_api.json", &["pathfinder_version", "pathfinder_getTransactionStatus"], Api::Both)] + + #[case::v0_7_api("/rpc/v0_7", "v07/starknet_api_openrpc.json", &[], Api::HttpOnly)] + #[case::v0_7_api_websocket("/ws/rpc/v0_7", "v07/starknet_api_openrpc.json", &[], Api::WebsocketOnly)] + #[case::v0_7_trace("/rpc/v0_7", "v07/starknet_trace_api_openrpc.json", &[], Api::HttpOnly)] + #[case::v0_7_trace_websocket("/ws/rpc/v0_7", "v07/starknet_trace_api_openrpc.json", &[], Api::WebsocketOnly)] + #[case::v0_7_write("/rpc/v0_7", "v07/starknet_write_api.json", &[], Api::HttpOnly)] + #[case::v0_7_write_websocket("/ws/rpc/v0_7", "v07/starknet_write_api.json", &[], Api::WebsocketOnly)] // get_transaction_status is now part of the official spec, so we are phasing it out. - #[case::v0_7_pathfinder("/rpc/v0_7", "pathfinder_rpc_api.json", &["pathfinder_version", "pathfinder_getTransactionStatus"])] - - #[case::v0_6_api ("/rpc/v0_6", "v06/starknet_api_openrpc.json", &[])] - #[case::v0_6_trace("/rpc/v0_6", "v06/starknet_trace_api_openrpc.json", &[])] - #[case::v0_6_write("/rpc/v0_6", "v06/starknet_write_api.json", &[])] + #[case::v0_7_pathfinder("/rpc/v0_7", "pathfinder_rpc_api.json", &["pathfinder_version", "pathfinder_getTransactionStatus"], Api::HttpOnly)] + #[case::v0_7_pathfinder_websocket("/ws/rpc/v0_7", "pathfinder_rpc_api.json", &["pathfinder_version", "pathfinder_getTransactionStatus"], Api::WebsocketOnly)] + + #[case::v0_6_api("/rpc/v0_6", "v06/starknet_api_openrpc.json", &[], Api::HttpOnly)] + #[case::v0_6_api_websocket("/ws/rpc/v0_6", "v06/starknet_api_openrpc.json", &[], Api::WebsocketOnly)] + #[case::v0_6_trace("/rpc/v0_6", "v06/starknet_trace_api_openrpc.json", &[], Api::HttpOnly)] + #[case::v0_6_trace_websocket("/ws/rpc/v0_6", "v06/starknet_trace_api_openrpc.json", &[], Api::WebsocketOnly)] + #[case::v0_6_write("/rpc/v0_6", "v06/starknet_write_api.json", &[], Api::HttpOnly)] + #[case::v0_6_write_websocket("/ws/rpc/v0_6", "v06/starknet_write_api.json", &[], Api::WebsocketOnly)] // get_transaction_status is now part of the official spec, so we are phasing it out. - #[case::v0_6_pathfinder("/rpc/v0_6", "pathfinder_rpc_api.json", &["pathfinder_version", "pathfinder_getTransactionStatus"])] + #[case::v0_6_pathfinder("/rpc/v0_6", "pathfinder_rpc_api.json", &["pathfinder_version", "pathfinder_getTransactionStatus"], Api::HttpOnly)] + #[case::v0_6_pathfinder_websocket("/ws/rpc/v0_6", "pathfinder_rpc_api.json", &["pathfinder_version", "pathfinder_getTransactionStatus"], Api::WebsocketOnly)] - #[case::pathfinder("/rpc/pathfinder/v0.1", "pathfinder_rpc_api.json", &[])] - #[case::pathfinder("/rpc/pathfinder/v0_1", "pathfinder_rpc_api.json", &[])] + #[case::pathfinder("/rpc/pathfinder/v0.1", "pathfinder_rpc_api.json", &[], Api::HttpOnly)] + #[case::pathfinder("/ws/rpc/pathfinder/v0_1", "pathfinder_rpc_api.json", &[], Api::WebsocketOnly)] #[tokio::test] async fn rpc_routing( #[case] route: &'static str, #[case] specification: std::path::PathBuf, #[case] exclude: &[&'static str], + #[case] api: Api, ) { let specification = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR")) .join("..") @@ -958,70 +998,135 @@ mod tests { methods.retain(|x| !exclude.contains(x)); let addr: SocketAddr = "127.0.0.1:0".parse().unwrap(); - let context = RpcContext::for_tests(); + let mut context = RpcContext::for_tests(); + if api.has_websocket() { + let (_, rx_pending) = tokio::sync::watch::channel(Default::default()); + + context = context.with_websockets(context::WebsocketContext::new( + std::num::NonZeroUsize::new(10).unwrap(), + std::num::NonZeroUsize::new(10).unwrap(), + rx_pending, + )); + } let (_jh, addr) = RpcServer::new(addr, context, RpcVersion::V07) .spawn() .await .unwrap(); - let url = format!("http://{addr}{route}"); - let client = reqwest::Client::new(); - let method_not_found = json!(-32601); - let mut failures = Vec::new(); - for method in methods { - let request = json!({ - "jsonrpc": "2.0", - "method": method, - "id": 0, - }); - - let res: serde_json::Value = client - .post(url.clone()) - .json(&request) - .send() - .await - .unwrap() - .json() - .await - .unwrap(); + if api.has_http() { + let url = format!("http://{addr}{route}"); + let client = reqwest::Client::new(); + let mut failures: Vec<&&str> = Vec::new(); + + for method in &methods { + let request = json!({ + "jsonrpc": "2.0", + "method": method, + "id": 0, + }); + + let res: serde_json::Value = client + .post(url.clone()) + .json(&request) + .send() + .await + .unwrap() + .json() + .await + .unwrap(); - if res["error"]["code"] == method_not_found { - failures.push(method); + if res["error"]["code"] == method_not_found { + failures.push(method); + } } - } - if !failures.is_empty() { - panic!("{failures:#?} were not found"); - } + if !failures.is_empty() { + panic!("{failures:#?} were not found"); + } - // Check that excluded methods are indeed not present. - failures.clear(); - for excluded in exclude { - let request = json!({ - "jsonrpc": "2.0", - "method": excluded, - "id": 0, - }); - - let res: serde_json::Value = client - .post(url.clone()) - .json(&request) - .send() - .await - .unwrap() - .json() - .await - .unwrap(); + // Check that excluded methods are indeed not present. + failures.clear(); + for excluded in exclude { + let request = json!({ + "jsonrpc": "2.0", + "method": excluded, + "id": 0, + }); + + let res: serde_json::Value = client + .post(url.clone()) + .json(&request) + .send() + .await + .unwrap() + .json() + .await + .unwrap(); + + if res["error"]["code"] != method_not_found { + failures.push(excluded); + } + } - if res["error"]["code"] != method_not_found { - failures.push(excluded); + if !failures.is_empty() { + panic!("{failures:#?} were marked as excluded but are actually present"); } } - if !failures.is_empty() { - panic!("{failures:#?} were marked as excluded but are actually present"); + if api.has_websocket() { + use tokio_tungstenite::tungstenite::Message; + use tokio_tungstenite::tungstenite::client::IntoClientRequest; + use futures::{SinkExt, StreamExt}; + + let request = format!("ws://{addr}{route}").into_client_request().unwrap(); + let (mut stream, _) = tokio_tungstenite::connect_async(request).await.unwrap(); + + let mut failures: Vec<&&str> = Vec::new(); + for method in &methods { + let request = json!({ + "jsonrpc": "2.0", + "method": method, + "id": 0, + }); + + stream.send(Message::Text(request.to_string())).await.unwrap(); + let res = stream.next().await.unwrap().unwrap(); + let res: serde_json::Value = serde_json::from_str(&res.to_string()).unwrap(); + + if res["error"]["code"] == method_not_found { + failures.push(method); + } + } + + if !failures.is_empty() { + panic!("{failures:#?} were not found"); + } + + // Check that excluded methods are indeed not present. + failures.clear(); + for excluded in exclude { + let request = json!({ + "jsonrpc": "2.0", + "method": excluded, + "id": 0, + }); + + stream.send(Message::Text(request.to_string())).await.unwrap(); + let res = stream.next().await.unwrap().unwrap(); + let res: serde_json::Value = serde_json::from_str(&res.to_string()).unwrap(); + + if res["error"]["code"] != method_not_found { + failures.push(excluded); + } + } + + if !failures.is_empty() { + panic!("{failures:#?} were marked as excluded but are actually present"); + } } + + } }