From 75b3953259ac8293f7eea8617c4c06fd619bf190 Mon Sep 17 00:00:00 2001 From: clarkohw Date: Thu, 19 Sep 2024 14:31:06 -0400 Subject: [PATCH 1/2] initial implemenation of hyper/axum integration --- Cargo.toml | 8 ++ ratchet_axum/Cargo.toml | 24 ++++++ ratchet_axum/src/lib.rs | 162 ++++++++++++++++++++++++++++++++++++ ratchet_rs/Cargo.toml | 10 ++- ratchet_rs/examples/axum.rs | 40 +++++++++ 5 files changed, 243 insertions(+), 1 deletion(-) create mode 100644 ratchet_axum/Cargo.toml create mode 100644 ratchet_axum/src/lib.rs create mode 100644 ratchet_rs/examples/axum.rs diff --git a/Cargo.toml b/Cargo.toml index d35d747..a0bf60e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,6 +6,7 @@ members = [ "ratchet_deflate", "ratchet_ext", "ratchet_fixture", + "ratchet_axum", "ratchet_rs/autobahn/client", "ratchet_rs/autobahn/server", "ratchet_rs/autobahn/split_client", @@ -43,3 +44,10 @@ flate2 = { version = "1.0", default-features = false } anyhow = "1.0" serde_json = "1.0" tracing-subscriber = "0.3.18" +hyper = "1.4.1" +axum = "0.7.5" +axum-core = "0.4.3" +eyre = "0.6.12" +hyper-util = "0.1.0" +pin-project = "1.1.5" +async-trait = "0.1.79" diff --git a/ratchet_axum/Cargo.toml b/ratchet_axum/Cargo.toml new file mode 100644 index 0000000..b363a34 --- /dev/null +++ b/ratchet_axum/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "ratchet_axum" +description = "Axum Integration for Ratchet" +readme = "README.md" +repository = "https://github.com/swimos/ratchet/" +version.workspace = true +edition.workspace = true +authors.workspace = true +license.workspace = true +categories.workspace = true + + +[dependencies] +ratchet_core = { version = "1.0.3", path = "../ratchet_core" } +ratchet_ext = { version = "1.0.3", path = "../ratchet_ext" } +hyper = { workspace = true } +axum-core = { workspace = true } +eyre = { workspace = true} +hyper-util = { workspace = true , features = ["tokio"]} +pin-project = { workspace = true } +async-trait = { workspace = true } +base64 = "0.22.1" +http = "1.1.0" +sha1 = "0.10.1" diff --git a/ratchet_axum/src/lib.rs b/ratchet_axum/src/lib.rs new file mode 100644 index 0000000..4dfaff4 --- /dev/null +++ b/ratchet_axum/src/lib.rs @@ -0,0 +1,162 @@ +// Port of hyper_tunstenite for fastwebsockets. +// https://github.com/de-vri-es/hyper-tungstenite-rs +// +// Copyright 2021, Maarten de Vries maarten@de-vri.es +// BSD 2-Clause "Simplified" License +// +// Copyright 2023 Divy Srivastava +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//todo missing docs + +#![deny( + // missing_docs, + missing_copy_implementations, + missing_debug_implementations, + trivial_numeric_casts, + unstable_features, + unused_must_use, + unused_mut, + unused_imports, + unused_import_braces +)] + +use std::pin::Pin; +use std::task::Context; +use std::task::Poll; + +use axum_core::body::Body; +use base64; +use base64::engine::general_purpose::STANDARD; +use base64::Engine; +use eyre::Report; +use http::HeaderMap; +use hyper::Response; +use hyper_util::rt::TokioIo; +use pin_project::pin_project; +use sha1::Digest; +use sha1::Sha1; + +type Error = Report; + +#[derive(Debug)] +pub struct IncomingUpgrade { + key: String, + headers: HeaderMap, + on_upgrade: hyper::upgrade::OnUpgrade, + pub permessage_deflate: bool, +} + +impl IncomingUpgrade { + pub fn upgrade(self) -> Result<(Response, UpgradeFut), Error> { + let mut builder = Response::builder() + .status(hyper::StatusCode::SWITCHING_PROTOCOLS) + .header(hyper::header::CONNECTION, "upgrade") + .header(hyper::header::UPGRADE, "websocket") + .header("Sec-WebSocket-Accept", self.key); + + if self.permessage_deflate { + builder = builder.header("Sec-WebSocket-Extensions", "permessage-deflate"); + } + + let response = builder + .body(Body::default()) + .expect("bug: failed to build response"); + + let stream = UpgradeFut { + inner: self.on_upgrade, + headers: self.headers, + }; + + Ok((response, stream)) + } +} + +#[async_trait::async_trait] +impl axum_core::extract::FromRequestParts for IncomingUpgrade +where + S: Sync, +{ + type Rejection = hyper::StatusCode; + + async fn from_request_parts( + parts: &mut http::request::Parts, + _state: &S, + ) -> Result { + let key = parts + .headers + .get("Sec-WebSocket-Key") + .ok_or(hyper::StatusCode::BAD_REQUEST)?; + if parts + .headers + .get("Sec-WebSocket-Version") + .map(|v| v.as_bytes()) + != Some(b"13".as_slice()) + { + return Err(hyper::StatusCode::BAD_REQUEST); + } + + let permessage_deflate = parts + .headers + .get("Sec-WebSocket-Extensions") + .map(|val| { + val.to_str() + .unwrap_or_default() + .to_lowercase() + .contains("permessage-deflate") + }) + .unwrap_or(false); + + let on_upgrade = parts + .extensions + .remove::() + .ok_or(hyper::StatusCode::BAD_REQUEST)?; + Ok(Self { + on_upgrade, + key: sec_websocket_protocol(key.as_bytes()), + headers: parts.headers.clone(), + permessage_deflate, + }) + } +} + +/// A future that resolves to a websocket stream when the associated HTTP upgrade completes. +#[pin_project] +#[derive(Debug)] +pub struct UpgradeFut { + #[pin] + inner: hyper::upgrade::OnUpgrade, + pub headers: HeaderMap, +} + +impl std::future::Future for UpgradeFut { + type Output = Result, Error>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let this = self.project(); + let upgraded = match this.inner.poll(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(x) => x, + }; + Poll::Ready(upgraded.map(|u| TokioIo::new(u)).map_err(|e| e.into())) + } +} + +fn sec_websocket_protocol(key: &[u8]) -> String { + let mut sha1 = Sha1::new(); + sha1.update(key); + sha1.update(b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"); // magic string + let result = sha1.finalize(); + STANDARD.encode(&result[..]) +} diff --git a/ratchet_rs/Cargo.toml b/ratchet_rs/Cargo.toml index dd198e3..54fedc2 100644 --- a/ratchet_rs/Cargo.toml +++ b/ratchet_rs/Cargo.toml @@ -15,11 +15,15 @@ default = [] deflate = ["ratchet_deflate"] split = ["ratchet_core/split"] fixture = ["ratchet_core/fixture"] +with_axum = ["ratchet_axum", "axum", "hyper"] [dependencies] +axum = { workspace = true, optional = true } +hyper = { workspace = true, features = ["http1", "server", "client"], optional = true } ratchet_core = { version = "1.0.3", path = "../ratchet_core" } ratchet_ext = { version = "1.0.3", path = "../ratchet_ext" } ratchet_deflate = { version = "1.0.3", path = "../ratchet_deflate", optional = true } +ratchet_axum = { version = "1.0.3", path = "../ratchet_axum", optional = true } tracing-subscriber = { workspace = true, features = ["env-filter"] } log = { workspace = true } @@ -49,4 +53,8 @@ name = "client" required-features = ["split"] [[example]] -name = "server" \ No newline at end of file +name = "server" + +[[example]] +name = "axum" +required-features = ["with_axum"] diff --git a/ratchet_rs/examples/axum.rs b/ratchet_rs/examples/axum.rs new file mode 100644 index 0000000..28898eb --- /dev/null +++ b/ratchet_rs/examples/axum.rs @@ -0,0 +1,40 @@ +use axum::{response::IntoResponse, routing::get, Router}; +use bytes::BytesMut; +use ratchet_axum::{IncomingUpgrade, UpgradeFut}; +use ratchet_core::{Message, NegotiatedExtension, NoExt, PayloadType, Role, WebSocketConfig}; + +#[tokio::main] +async fn main() { + let app = Router::new().route("/", get(ws_handler)); + + let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap(); + axum::serve(listener, app).await.unwrap(); +} + +async fn handle_client(fut: UpgradeFut) { + let io = fut.await.unwrap(); + let mut websocket = ratchet_rs::WebSocket::from_upgraded( + WebSocketConfig::default(), + io, + NegotiatedExtension::from(NoExt), + BytesMut::new(), + Role::Server, + ); + let mut buf = BytesMut::new(); + + loop { + match websocket.read(&mut buf).await.unwrap() { + Message::Text => { + websocket.write(&mut buf, PayloadType::Text).await.unwrap(); + buf.clear(); + } + _ => break, + } + } +} + +async fn ws_handler(incoming_upgrade: IncomingUpgrade) -> impl IntoResponse { + let (response, fut) = incoming_upgrade.upgrade().unwrap(); + tokio::task::spawn(async move { handle_client(fut).await }); + response +} From 6bb97f00edc939872b4a79a9faa42ac3844d7f97 Mon Sep 17 00:00:00 2001 From: clarkohw Date: Fri, 20 Sep 2024 16:05:36 -0400 Subject: [PATCH 2/2] quick PR feedback --- Cargo.toml | 6 +- ratchet_axum/Cargo.toml | 20 ++++-- {ratchet_rs => ratchet_axum}/examples/axum.rs | 6 +- ratchet_axum/src/lib.rs | 64 +++++++++---------- ratchet_rs/Cargo.toml | 6 -- 5 files changed, 50 insertions(+), 52 deletions(-) rename {ratchet_rs => ratchet_axum}/examples/axum.rs (82%) diff --git a/Cargo.toml b/Cargo.toml index a0bf60e..4bc5b18 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,12 +1,12 @@ [workspace] resolver = "2" members = [ - "ratchet_rs", + "ratchet_axum", "ratchet_core", "ratchet_deflate", "ratchet_ext", "ratchet_fixture", - "ratchet_axum", + "ratchet_rs", "ratchet_rs/autobahn/client", "ratchet_rs/autobahn/server", "ratchet_rs/autobahn/split_client", @@ -47,7 +47,7 @@ tracing-subscriber = "0.3.18" hyper = "1.4.1" axum = "0.7.5" axum-core = "0.4.3" -eyre = "0.6.12" hyper-util = "0.1.0" pin-project = "1.1.5" async-trait = "0.1.79" +sha1 = "0.10.4" diff --git a/ratchet_axum/Cargo.toml b/ratchet_axum/Cargo.toml index b363a34..e9a4cb8 100644 --- a/ratchet_axum/Cargo.toml +++ b/ratchet_axum/Cargo.toml @@ -3,7 +3,7 @@ name = "ratchet_axum" description = "Axum Integration for Ratchet" readme = "README.md" repository = "https://github.com/swimos/ratchet/" -version.workspace = true +version = "0.1.0" edition.workspace = true authors.workspace = true license.workspace = true @@ -11,14 +11,20 @@ categories.workspace = true [dependencies] -ratchet_core = { version = "1.0.3", path = "../ratchet_core" } -ratchet_ext = { version = "1.0.3", path = "../ratchet_ext" } +ratchet_rs = { version = "1.0.3", path = "../ratchet_rs" } hyper = { workspace = true } axum-core = { workspace = true } -eyre = { workspace = true} hyper-util = { workspace = true , features = ["tokio"]} pin-project = { workspace = true } async-trait = { workspace = true } -base64 = "0.22.1" -http = "1.1.0" -sha1 = "0.10.1" +base64 = { workspace = true } +http = { workspace = true } +sha1 = { workspace = true } + +[dev-dependencies] +axum = { workspace = true } +tokio = { workspace = true, features = ["full"] } +bytes = { workspace = true } + +[[example]] +name = "axum" diff --git a/ratchet_rs/examples/axum.rs b/ratchet_axum/examples/axum.rs similarity index 82% rename from ratchet_rs/examples/axum.rs rename to ratchet_axum/examples/axum.rs index 28898eb..8235fca 100644 --- a/ratchet_rs/examples/axum.rs +++ b/ratchet_axum/examples/axum.rs @@ -1,7 +1,7 @@ use axum::{response::IntoResponse, routing::get, Router}; use bytes::BytesMut; -use ratchet_axum::{IncomingUpgrade, UpgradeFut}; -use ratchet_core::{Message, NegotiatedExtension, NoExt, PayloadType, Role, WebSocketConfig}; +use ratchet_axum::{UpgradeFut, WebSocketUpgrade}; +use ratchet_rs::{Message, NegotiatedExtension, NoExt, PayloadType, Role, WebSocketConfig}; #[tokio::main] async fn main() { @@ -33,7 +33,7 @@ async fn handle_client(fut: UpgradeFut) { } } -async fn ws_handler(incoming_upgrade: IncomingUpgrade) -> impl IntoResponse { +async fn ws_handler(incoming_upgrade: WebSocketUpgrade) -> impl IntoResponse { let (response, fut) = incoming_upgrade.upgrade().unwrap(); tokio::task::spawn(async move { handle_client(fut).await }); response diff --git a/ratchet_axum/src/lib.rs b/ratchet_axum/src/lib.rs index 4dfaff4..fc28323 100644 --- a/ratchet_axum/src/lib.rs +++ b/ratchet_axum/src/lib.rs @@ -1,23 +1,3 @@ -// Port of hyper_tunstenite for fastwebsockets. -// https://github.com/de-vri-es/hyper-tungstenite-rs -// -// Copyright 2021, Maarten de Vries maarten@de-vri.es -// BSD 2-Clause "Simplified" License -// -// Copyright 2023 Divy Srivastava -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - //todo missing docs #![deny( @@ -40,7 +20,6 @@ use axum_core::body::Body; use base64; use base64::engine::general_purpose::STANDARD; use base64::Engine; -use eyre::Report; use http::HeaderMap; use hyper::Response; use hyper_util::rt::TokioIo; @@ -48,26 +27,33 @@ use pin_project::pin_project; use sha1::Digest; use sha1::Sha1; -type Error = Report; +const HEADER_CONNECTION: &str = "upgrade"; +const HEADER_UPGRADE: &str = "websocket"; +const WEBSOCKET_VERSION: &[u8] = b"13"; + +type Error = hyper::Error; #[derive(Debug)] -pub struct IncomingUpgrade { +pub struct WebSocketUpgrade { key: String, headers: HeaderMap, on_upgrade: hyper::upgrade::OnUpgrade, pub permessage_deflate: bool, } -impl IncomingUpgrade { +impl WebSocketUpgrade { pub fn upgrade(self) -> Result<(Response, UpgradeFut), Error> { let mut builder = Response::builder() .status(hyper::StatusCode::SWITCHING_PROTOCOLS) - .header(hyper::header::CONNECTION, "upgrade") - .header(hyper::header::UPGRADE, "websocket") - .header("Sec-WebSocket-Accept", self.key); + .header(hyper::header::CONNECTION, HEADER_CONNECTION) + .header(hyper::header::UPGRADE, HEADER_UPGRADE) + .header(hyper::header::SEC_WEBSOCKET_ACCEPT, self.key); if self.permessage_deflate { - builder = builder.header("Sec-WebSocket-Extensions", "permessage-deflate"); + builder = builder.header( + hyper::header::SEC_WEBSOCKET_EXTENSIONS, + "permessage-deflate", + ); } let response = builder @@ -81,10 +67,20 @@ impl IncomingUpgrade { Ok((response, stream)) } + + // pub fn upgrade_2(self, f: F) -> Response + // where + // F: FnOnce(UpgradedServer, E>) -> Fut, + // Fut: Future, + // E: Extension, + // { + // + // + // } } #[async_trait::async_trait] -impl axum_core::extract::FromRequestParts for IncomingUpgrade +impl axum_core::extract::FromRequestParts for WebSocketUpgrade where S: Sync, { @@ -96,20 +92,21 @@ where ) -> Result { let key = parts .headers - .get("Sec-WebSocket-Key") + .get(http::header::SEC_WEBSOCKET_KEY) .ok_or(hyper::StatusCode::BAD_REQUEST)?; + if parts .headers - .get("Sec-WebSocket-Version") + .get(http::header::SEC_WEBSOCKET_VERSION) .map(|v| v.as_bytes()) - != Some(b"13".as_slice()) + != Some(WEBSOCKET_VERSION) { return Err(hyper::StatusCode::BAD_REQUEST); } let permessage_deflate = parts .headers - .get("Sec-WebSocket-Extensions") + .get(http::header::SEC_WEBSOCKET_EXTENSIONS) .map(|val| { val.to_str() .unwrap_or_default() @@ -122,6 +119,7 @@ where .extensions .remove::() .ok_or(hyper::StatusCode::BAD_REQUEST)?; + Ok(Self { on_upgrade, key: sec_websocket_protocol(key.as_bytes()), diff --git a/ratchet_rs/Cargo.toml b/ratchet_rs/Cargo.toml index 54fedc2..38fe70b 100644 --- a/ratchet_rs/Cargo.toml +++ b/ratchet_rs/Cargo.toml @@ -15,7 +15,6 @@ default = [] deflate = ["ratchet_deflate"] split = ["ratchet_core/split"] fixture = ["ratchet_core/fixture"] -with_axum = ["ratchet_axum", "axum", "hyper"] [dependencies] axum = { workspace = true, optional = true } @@ -23,7 +22,6 @@ hyper = { workspace = true, features = ["http1", "server", "client"], optional = ratchet_core = { version = "1.0.3", path = "../ratchet_core" } ratchet_ext = { version = "1.0.3", path = "../ratchet_ext" } ratchet_deflate = { version = "1.0.3", path = "../ratchet_deflate", optional = true } -ratchet_axum = { version = "1.0.3", path = "../ratchet_axum", optional = true } tracing-subscriber = { workspace = true, features = ["env-filter"] } log = { workspace = true } @@ -54,7 +52,3 @@ required-features = ["split"] [[example]] name = "server" - -[[example]] -name = "axum" -required-features = ["with_axum"]