diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 00000000..c9c25fab --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,114 @@ +name: Rust + +on: + push: + # Run jobs when commits are pushed to + # develop or release-like branches: + branches: + - develop + - release* + pull_request: + # Run jobs for any PR that wants to merge + # to develop: + branches: + - develop + +env: + CARGO_TERM_COLOR: always + +jobs: + build: + name: Check Code + runs-on: ubuntu-latest + steps: + - name: Checkout sources + uses: actions/checkout@v2.3.4 + + - name: Install Rust stable toolchain + uses: actions-rs/toolchain@v1.0.7 + with: + profile: minimal + toolchain: stable + override: true + + - name: Rust Cache + uses: Swatinem/rust-cache@v1.3.0 + + - name: Build + uses: actions-rs/cargo@v1.0.3 + with: + command: check + args: --all-targets + + fmt: + name: Run rustfmt + runs-on: ubuntu-latest + steps: + - name: Checkout sources + uses: actions/checkout@v2.3.4 + + - name: Install Rust stable toolchain + uses: actions-rs/toolchain@v1.0.7 + with: + profile: minimal + toolchain: stable + override: true + components: clippy, rustfmt + + - name: Rust Cache + uses: Swatinem/rust-cache@v1.3.0 + + - name: Cargo fmt + uses: actions-rs/cargo@v1.0.3 + with: + command: fmt + args: --all -- --check + + docs: + name: Check Documentation + runs-on: ubuntu-latest + steps: + - name: Checkout sources + uses: actions/checkout@v2.3.4 + + - name: Install Rust stable toolchain + uses: actions-rs/toolchain@v1.0.7 + with: + profile: minimal + toolchain: stable + override: true + + - name: Rust Cache + uses: Swatinem/rust-cache@v1.3.0 + + - name: Check internal documentation links + run: RUSTDOCFLAGS="--deny broken_intra_doc_links" cargo doc --verbose --workspace --no-deps --document-private-items + + tests: + name: Run tests + runs-on: ubuntu-latest + steps: + - name: Checkout sources + uses: actions/checkout@v2.3.4 + + - name: Install Rust stable toolchain + uses: actions-rs/toolchain@v1.0.7 + with: + profile: minimal + toolchain: stable + override: true + + - name: Rust Cache + uses: Swatinem/rust-cache@v1.3.0 + + - name: Cargo build + uses: actions-rs/cargo@v1.0.3 + with: + command: build + args: --workspace + + - name: Cargo test + uses: actions-rs/cargo@v1.0.3 + with: + command: test + diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml deleted file mode 100644 index 21325699..00000000 --- a/.github/workflows/rust.yml +++ /dev/null @@ -1,20 +0,0 @@ -name: Rust - -on: - push: - branches: [ develop ] - pull_request: - branches: [ develop ] - -env: - CARGO_TERM_COLOR: always - -jobs: - build: - - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v2 - - name: Build - run: cargo build --examples diff --git a/examples/autobahn_client.rs b/examples/autobahn_client.rs index 8b8a9f0b..ecb1832a 100644 --- a/examples/autobahn_client.rs +++ b/examples/autobahn_client.rs @@ -15,7 +15,7 @@ // See https://github.com/crossbario/autobahn-testsuite for details. use futures::io::{BufReader, BufWriter}; -use soketto::{BoxedError, connection, handshake}; +use soketto::{connection, handshake, BoxedError}; use std::str::FromStr; use tokio::net::TcpStream; use tokio_util::compat::{Compat, TokioAsyncReadCompatExt}; @@ -24,76 +24,76 @@ const SOKETTO_VERSION: &str = env!("CARGO_PKG_VERSION"); #[tokio::main] async fn main() -> Result<(), BoxedError> { - let n = num_of_cases().await?; - for i in 1 ..= n { - if let Err(e) = run_case(i).await { - log::error!("case {}: {:?}", i, e) - } - } - update_report().await?; - Ok(()) + let n = num_of_cases().await?; + for i in 1..=n { + if let Err(e) = run_case(i).await { + log::error!("case {}: {:?}", i, e) + } + } + update_report().await?; + Ok(()) } async fn num_of_cases() -> Result { - let socket = TcpStream::connect("127.0.0.1:9001").await?; - let mut client = new_client(socket, "/getCaseCount"); - assert!(matches!(client.handshake().await?, handshake::ServerResponse::Accepted {..})); - let (_, mut receiver) = client.into_builder().finish(); - let mut data = Vec::new(); - let kind = receiver.receive_data(&mut data).await?; - assert!(kind.is_text()); - let num = usize::from_str(std::str::from_utf8(&data)?)?; - log::info!("{} cases to run", num); - Ok(num) + let socket = TcpStream::connect("127.0.0.1:9001").await?; + let mut client = new_client(socket, "/getCaseCount"); + assert!(matches!(client.handshake().await?, handshake::ServerResponse::Accepted { .. })); + let (_, mut receiver) = client.into_builder().finish(); + let mut data = Vec::new(); + let kind = receiver.receive_data(&mut data).await?; + assert!(kind.is_text()); + let num = usize::from_str(std::str::from_utf8(&data)?)?; + log::info!("{} cases to run", num); + Ok(num) } async fn run_case(n: usize) -> Result<(), BoxedError> { - log::info!("running case {}", n); - let resource = format!("/runCase?case={}&agent=soketto-{}", n, SOKETTO_VERSION); - let socket = TcpStream::connect("127.0.0.1:9001").await?; - let mut client = new_client(socket, &resource); - assert!(matches!(client.handshake().await?, handshake::ServerResponse::Accepted {..})); - let (mut sender, mut receiver) = client.into_builder().finish(); - let mut message = Vec::new(); - loop { - message.clear(); - match receiver.receive_data(&mut message).await { - Ok(soketto::Data::Binary(n)) => { - assert_eq!(n, message.len()); - sender.send_binary_mut(&mut message).await?; - sender.flush().await? - } - Ok(soketto::Data::Text(n)) => { - assert_eq!(n, message.len()); - sender.send_text(std::str::from_utf8(&message)?).await?; - sender.flush().await? - } - Err(connection::Error::Closed) => return Ok(()), - Err(e) => return Err(e.into()) - } - } + log::info!("running case {}", n); + let resource = format!("/runCase?case={}&agent=soketto-{}", n, SOKETTO_VERSION); + let socket = TcpStream::connect("127.0.0.1:9001").await?; + let mut client = new_client(socket, &resource); + assert!(matches!(client.handshake().await?, handshake::ServerResponse::Accepted { .. })); + let (mut sender, mut receiver) = client.into_builder().finish(); + let mut message = Vec::new(); + loop { + message.clear(); + match receiver.receive_data(&mut message).await { + Ok(soketto::Data::Binary(n)) => { + assert_eq!(n, message.len()); + sender.send_binary_mut(&mut message).await?; + sender.flush().await? + } + Ok(soketto::Data::Text(n)) => { + assert_eq!(n, message.len()); + sender.send_text(std::str::from_utf8(&message)?).await?; + sender.flush().await? + } + Err(connection::Error::Closed) => return Ok(()), + Err(e) => return Err(e.into()), + } + } } async fn update_report() -> Result<(), BoxedError> { - log::info!("requesting report generation"); - let resource = format!("/updateReports?agent=soketto-{}", SOKETTO_VERSION); - let socket = TcpStream::connect("127.0.0.1:9001").await?; - let mut client = new_client(socket, &resource); - assert!(matches!(client.handshake().await?, handshake::ServerResponse::Accepted {..})); - client.into_builder().finish().0.close().await?; - Ok(()) + log::info!("requesting report generation"); + let resource = format!("/updateReports?agent=soketto-{}", SOKETTO_VERSION); + let socket = TcpStream::connect("127.0.0.1:9001").await?; + let mut client = new_client(socket, &resource); + assert!(matches!(client.handshake().await?, handshake::ServerResponse::Accepted { .. })); + client.into_builder().finish().0.close().await?; + Ok(()) } #[cfg(not(feature = "deflate"))] fn new_client(socket: TcpStream, path: &str) -> handshake::Client<'_, BufReader>>> { - handshake::Client::new(BufReader::new(BufWriter::new(socket.compat())), "127.0.0.1:9001", path) + handshake::Client::new(BufReader::new(BufWriter::new(socket.compat())), "127.0.0.1:9001", path) } #[cfg(feature = "deflate")] fn new_client(socket: TcpStream, path: &str) -> handshake::Client<'_, BufReader>>> { - let socket = BufReader::with_capacity(8 * 1024, BufWriter::with_capacity(64 * 1024, socket.compat())); - let mut client = handshake::Client::new(socket, "127.0.0.1:9001", path); - let deflate = soketto::extension::deflate::Deflate::new(soketto::Mode::Client); - client.add_extension(Box::new(deflate)); - client + let socket = BufReader::with_capacity(8 * 1024, BufWriter::with_capacity(64 * 1024, socket.compat())); + let mut client = handshake::Client::new(socket, "127.0.0.1:9001", path); + let deflate = soketto::extension::deflate::Deflate::new(soketto::Mode::Client); + client.add_extension(Box::new(deflate)); + client } diff --git a/examples/autobahn_server.rs b/examples/autobahn_server.rs index c0bdc8ad..893a020b 100644 --- a/examples/autobahn_server.rs +++ b/examples/autobahn_server.rs @@ -15,62 +15,62 @@ // See https://github.com/crossbario/autobahn-testsuite for details. use futures::io::{BufReader, BufWriter}; -use soketto::{BoxedError, connection, handshake}; +use soketto::{connection, handshake, BoxedError}; use tokio::net::{TcpListener, TcpStream}; -use tokio_util::compat::{Compat, TokioAsyncReadCompatExt}; use tokio_stream::{wrappers::TcpListenerStream, StreamExt}; +use tokio_util::compat::{Compat, TokioAsyncReadCompatExt}; #[tokio::main] async fn main() -> Result<(), BoxedError> { - let listener = TcpListener::bind("127.0.0.1:9001").await?; - let mut incoming = TcpListenerStream::new(listener); - while let Some(socket) = incoming.next().await { - let mut server = new_server(socket?); - let key = { - let req = server.receive_request().await?; - req.key() - }; - let accept = handshake::server::Response::Accept { key, protocol: None }; - server.send_response(&accept).await?; - let (mut sender, mut receiver) = server.into_builder().finish(); - let mut message = Vec::new(); - loop { - message.clear(); - match receiver.receive_data(&mut message).await { - Ok(soketto::Data::Binary(n)) => { - assert_eq!(n, message.len()); - sender.send_binary_mut(&mut message).await?; - sender.flush().await? - } - Ok(soketto::Data::Text(n)) => { - assert_eq!(n, message.len()); - if let Ok(txt) = std::str::from_utf8(&message) { - sender.send_text(txt).await?; - sender.flush().await? - } else { - break - } - } - Err(connection::Error::Closed) => break, - Err(e) => { - log::error!("connection error: {}", e); - break - } - } - } - } - Ok(()) + let listener = TcpListener::bind("127.0.0.1:9001").await?; + let mut incoming = TcpListenerStream::new(listener); + while let Some(socket) = incoming.next().await { + let mut server = new_server(socket?); + let key = { + let req = server.receive_request().await?; + req.key() + }; + let accept = handshake::server::Response::Accept { key, protocol: None }; + server.send_response(&accept).await?; + let (mut sender, mut receiver) = server.into_builder().finish(); + let mut message = Vec::new(); + loop { + message.clear(); + match receiver.receive_data(&mut message).await { + Ok(soketto::Data::Binary(n)) => { + assert_eq!(n, message.len()); + sender.send_binary_mut(&mut message).await?; + sender.flush().await? + } + Ok(soketto::Data::Text(n)) => { + assert_eq!(n, message.len()); + if let Ok(txt) = std::str::from_utf8(&message) { + sender.send_text(txt).await?; + sender.flush().await? + } else { + break; + } + } + Err(connection::Error::Closed) => break, + Err(e) => { + log::error!("connection error: {}", e); + break; + } + } + } + } + Ok(()) } #[cfg(not(feature = "deflate"))] fn new_server<'a>(socket: TcpStream) -> handshake::Server<'a, BufReader>>> { - handshake::Server::new(BufReader::new(BufWriter::new(socket.compat()))) + handshake::Server::new(BufReader::new(BufWriter::new(socket.compat()))) } #[cfg(feature = "deflate")] fn new_server<'a>(socket: TcpStream) -> handshake::Server<'a, BufReader>>> { - let socket = BufReader::with_capacity(8 * 1024, BufWriter::with_capacity(16 * 1024, socket.compat())); - let mut server = handshake::Server::new(socket); - let deflate = soketto::extension::deflate::Deflate::new(soketto::Mode::Server); - server.add_extension(Box::new(deflate)); - server + let socket = BufReader::with_capacity(8 * 1024, BufWriter::with_capacity(16 * 1024, socket.compat())); + let mut server = handshake::Server::new(socket); + let deflate = soketto::extension::deflate::Deflate::new(soketto::Mode::Server); + server.add_extension(Box::new(deflate)); + server } diff --git a/rustfmt.toml b/rustfmt.toml new file mode 100644 index 00000000..c699603f --- /dev/null +++ b/rustfmt.toml @@ -0,0 +1,4 @@ +hard_tabs = true +max_width = 120 +use_small_heuristics = "Max" +edition = "2018" diff --git a/src/base.rs b/src/base.rs index d5bfcd82..5e80ed1f 100644 --- a/src/base.rs +++ b/src/base.rs @@ -31,89 +31,89 @@ pub(crate) const MAX_CTRL_BODY_SIZE: u64 = 125; /// Operation codes defined in [RFC 6455](https://tools.ietf.org/html/rfc6455#section-5.2). #[derive(Debug, Eq, PartialEq, PartialOrd, Ord, Hash, Clone, Copy)] pub enum OpCode { - /// A continuation frame of a fragmented message. - Continue, - /// A text data frame. - Text, - /// A binary data frame. - Binary, - /// A close control frame. - Close, - /// A ping control frame. - Ping, - /// A pong control frame. - Pong, - /// A reserved op code. - Reserved3, - /// A reserved op code. - Reserved4, - /// A reserved op code. - Reserved5, - /// A reserved op code. - Reserved6, - /// A reserved op code. - Reserved7, - /// A reserved op code. - Reserved11, - /// A reserved op code. - Reserved12, - /// A reserved op code. - Reserved13, - /// A reserved op code. - Reserved14, - /// A reserved op code. - Reserved15 + /// A continuation frame of a fragmented message. + Continue, + /// A text data frame. + Text, + /// A binary data frame. + Binary, + /// A close control frame. + Close, + /// A ping control frame. + Ping, + /// A pong control frame. + Pong, + /// A reserved op code. + Reserved3, + /// A reserved op code. + Reserved4, + /// A reserved op code. + Reserved5, + /// A reserved op code. + Reserved6, + /// A reserved op code. + Reserved7, + /// A reserved op code. + Reserved11, + /// A reserved op code. + Reserved12, + /// A reserved op code. + Reserved13, + /// A reserved op code. + Reserved14, + /// A reserved op code. + Reserved15, } impl OpCode { - /// Is this a control opcode? - pub fn is_control(self) -> bool { - if let OpCode::Close | OpCode::Ping | OpCode::Pong = self { - true - } else { - false - } - } - - /// Is this opcode reserved? - pub fn is_reserved(self) -> bool { - match self { - OpCode::Reserved3 - | OpCode::Reserved4 - | OpCode::Reserved5 - | OpCode::Reserved6 - | OpCode::Reserved7 - | OpCode::Reserved11 - | OpCode::Reserved12 - | OpCode::Reserved13 - | OpCode::Reserved14 - | OpCode::Reserved15 => true, - _ => false - } - } + /// Is this a control opcode? + pub fn is_control(self) -> bool { + if let OpCode::Close | OpCode::Ping | OpCode::Pong = self { + true + } else { + false + } + } + + /// Is this opcode reserved? + pub fn is_reserved(self) -> bool { + match self { + OpCode::Reserved3 + | OpCode::Reserved4 + | OpCode::Reserved5 + | OpCode::Reserved6 + | OpCode::Reserved7 + | OpCode::Reserved11 + | OpCode::Reserved12 + | OpCode::Reserved13 + | OpCode::Reserved14 + | OpCode::Reserved15 => true, + _ => false, + } + } } impl fmt::Display for OpCode { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - OpCode::Continue => f.write_str("Continue"), - OpCode::Text => f.write_str("Text"), - OpCode::Binary => f.write_str("Binary"), - OpCode::Close => f.write_str("Close"), - OpCode::Ping => f.write_str("Ping"), - OpCode::Pong => f.write_str("Pong"), - OpCode::Reserved3 => f.write_str("Reserved:3"), - OpCode::Reserved4 => f.write_str("Reserved:4"), - OpCode::Reserved5 => f.write_str("Reserved:5"), - OpCode::Reserved6 => f.write_str("Reserved:6"), - OpCode::Reserved7 => f.write_str("Reserved:7"), - OpCode::Reserved11 => f.write_str("Reserved:11"), - OpCode::Reserved12 => f.write_str("Reserved:12"), - OpCode::Reserved13 => f.write_str("Reserved:13"), - OpCode::Reserved14 => f.write_str("Reserved:14"), - OpCode::Reserved15 => f.write_str("Reserved:15") - } - } + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + OpCode::Continue => f.write_str("Continue"), + OpCode::Text => f.write_str("Text"), + OpCode::Binary => f.write_str("Binary"), + OpCode::Close => f.write_str("Close"), + OpCode::Ping => f.write_str("Ping"), + OpCode::Pong => f.write_str("Pong"), + OpCode::Reserved3 => f.write_str("Reserved:3"), + OpCode::Reserved4 => f.write_str("Reserved:4"), + OpCode::Reserved5 => f.write_str("Reserved:5"), + OpCode::Reserved6 => f.write_str("Reserved:6"), + OpCode::Reserved7 => f.write_str("Reserved:7"), + OpCode::Reserved11 => f.write_str("Reserved:11"), + OpCode::Reserved12 => f.write_str("Reserved:12"), + OpCode::Reserved13 => f.write_str("Reserved:13"), + OpCode::Reserved14 => f.write_str("Reserved:14"), + OpCode::Reserved15 => f.write_str("Reserved:15"), + } + } } /// Error returned by `OpCode::try_from` if an unknown opcode @@ -122,60 +122,60 @@ impl fmt::Display for OpCode { pub struct UnknownOpCode(()); impl fmt::Display for UnknownOpCode { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.write_str("unknown opcode") - } + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str("unknown opcode") + } } impl std::error::Error for UnknownOpCode {} impl TryFrom for OpCode { - type Error = UnknownOpCode; - - fn try_from(val: u8) -> Result { - match val { - 0 => Ok(OpCode::Continue), - 1 => Ok(OpCode::Text), - 2 => Ok(OpCode::Binary), - 3 => Ok(OpCode::Reserved3), - 4 => Ok(OpCode::Reserved4), - 5 => Ok(OpCode::Reserved5), - 6 => Ok(OpCode::Reserved6), - 7 => Ok(OpCode::Reserved7), - 8 => Ok(OpCode::Close), - 9 => Ok(OpCode::Ping), - 10 => Ok(OpCode::Pong), - 11 => Ok(OpCode::Reserved11), - 12 => Ok(OpCode::Reserved12), - 13 => Ok(OpCode::Reserved13), - 14 => Ok(OpCode::Reserved14), - 15 => Ok(OpCode::Reserved15), - _ => Err(UnknownOpCode(())) - } - } + type Error = UnknownOpCode; + + fn try_from(val: u8) -> Result { + match val { + 0 => Ok(OpCode::Continue), + 1 => Ok(OpCode::Text), + 2 => Ok(OpCode::Binary), + 3 => Ok(OpCode::Reserved3), + 4 => Ok(OpCode::Reserved4), + 5 => Ok(OpCode::Reserved5), + 6 => Ok(OpCode::Reserved6), + 7 => Ok(OpCode::Reserved7), + 8 => Ok(OpCode::Close), + 9 => Ok(OpCode::Ping), + 10 => Ok(OpCode::Pong), + 11 => Ok(OpCode::Reserved11), + 12 => Ok(OpCode::Reserved12), + 13 => Ok(OpCode::Reserved13), + 14 => Ok(OpCode::Reserved14), + 15 => Ok(OpCode::Reserved15), + _ => Err(UnknownOpCode(())), + } + } } impl From for u8 { - fn from(opcode: OpCode) -> u8 { - match opcode { - OpCode::Continue => 0, - OpCode::Text => 1, - OpCode::Binary => 2, - OpCode::Close => 8, - OpCode::Ping => 9, - OpCode::Pong => 10, - OpCode::Reserved3 => 3, - OpCode::Reserved4 => 4, - OpCode::Reserved5 => 5, - OpCode::Reserved6 => 6, - OpCode::Reserved7 => 7, - OpCode::Reserved11 => 11, - OpCode::Reserved12 => 12, - OpCode::Reserved13 => 13, - OpCode::Reserved14 => 14, - OpCode::Reserved15 => 15 - } - } + fn from(opcode: OpCode) -> u8 { + match opcode { + OpCode::Continue => 0, + OpCode::Text => 1, + OpCode::Binary => 2, + OpCode::Close => 8, + OpCode::Ping => 9, + OpCode::Pong => 10, + OpCode::Reserved3 => 3, + OpCode::Reserved4 => 4, + OpCode::Reserved5 => 5, + OpCode::Reserved6 => 6, + OpCode::Reserved7 => 7, + OpCode::Reserved11 => 11, + OpCode::Reserved12 => 12, + OpCode::Reserved13 => 13, + OpCode::Reserved14 => 14, + OpCode::Reserved15 => 15, + } + } } // Frame header /////////////////////////////////////////////////////////////////////////////////// @@ -183,132 +183,126 @@ impl From for u8 { /// A websocket base frame header, i.e. everything but the payload. #[derive(Debug, Clone)] pub struct Header { - fin: bool, - rsv1: bool, - rsv2: bool, - rsv3: bool, - masked: bool, - opcode: OpCode, - mask: u32, - payload_len: usize + fin: bool, + rsv1: bool, + rsv2: bool, + rsv3: bool, + masked: bool, + opcode: OpCode, + mask: u32, + payload_len: usize, } impl fmt::Display for Header { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "({} (fin {}) (rsv {}{}{}) (mask ({} {:x})) (len {}))", - self.opcode, - self.fin as u8, - self.rsv1 as u8, - self.rsv2 as u8, - self.rsv3 as u8, - self.masked as u8, - self.mask, - self.payload_len) - } + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "({} (fin {}) (rsv {}{}{}) (mask ({} {:x})) (len {}))", + self.opcode, + self.fin as u8, + self.rsv1 as u8, + self.rsv2 as u8, + self.rsv3 as u8, + self.masked as u8, + self.mask, + self.payload_len + ) + } } impl Header { - /// Create a new frame header with a given [`OpCode`]. - pub fn new(oc: OpCode) -> Self { - Header { - fin: true, - rsv1: false, - rsv2: false, - rsv3: false, - masked: false, - opcode: oc, - mask: 0, - payload_len: 0 - } - } - - /// Is the `fin` flag set? - pub fn is_fin(&self) -> bool { - self.fin - } - - /// Set the `fin` flag. - pub fn set_fin(&mut self, fin: bool) -> &mut Self { - self.fin = fin; - self - } - - /// Is the `rsv1` flag set? - pub fn is_rsv1(&self) -> bool { - self.rsv1 - } - - /// Set the `rsv1` flag. - pub fn set_rsv1(&mut self, rsv1: bool) -> &mut Self { - self.rsv1 = rsv1; - self - } - - /// Is the `rsv2` flag set? - pub fn is_rsv2(&self) -> bool { - self.rsv2 - } - - /// Set the `rsv2` flag. - pub fn set_rsv2(&mut self, rsv2: bool) -> &mut Self { - self.rsv2 = rsv2; - self - } - - /// Is the `rsv3` flag set? - pub fn is_rsv3(&self) -> bool { - self.rsv3 - } - - /// Set the `rsv3` flag. - pub fn set_rsv3(&mut self, rsv3: bool) -> &mut Self { - self.rsv3 = rsv3; - self - } - - /// Is the `masked` flag set? - pub fn is_masked(&self) -> bool { - self.masked - } - - /// Set the `masked` flag. - pub fn set_masked(&mut self, masked: bool) -> &mut Self { - self.masked = masked; - self - } - - /// Get the `opcode`. - pub fn opcode(&self) -> OpCode { - self.opcode - } - - /// Set the `opcode` - pub fn set_opcode(&mut self, opcode: OpCode) -> &mut Self { - self.opcode = opcode; - self - } - - /// Get the `mask`. - pub fn mask(&self) -> u32 { - self.mask - } - - /// Set the `mask` - pub fn set_mask(&mut self, mask: u32) -> &mut Self { - self.mask = mask; - self - } - - /// Get the payload length. - pub fn payload_len(&self) -> usize { - self.payload_len - } - - /// Set the payload length. - pub fn set_payload_len(&mut self, len: usize) -> &mut Self { - self.payload_len = len; - self - } + /// Create a new frame header with a given [`OpCode`]. + pub fn new(oc: OpCode) -> Self { + Header { fin: true, rsv1: false, rsv2: false, rsv3: false, masked: false, opcode: oc, mask: 0, payload_len: 0 } + } + + /// Is the `fin` flag set? + pub fn is_fin(&self) -> bool { + self.fin + } + + /// Set the `fin` flag. + pub fn set_fin(&mut self, fin: bool) -> &mut Self { + self.fin = fin; + self + } + + /// Is the `rsv1` flag set? + pub fn is_rsv1(&self) -> bool { + self.rsv1 + } + + /// Set the `rsv1` flag. + pub fn set_rsv1(&mut self, rsv1: bool) -> &mut Self { + self.rsv1 = rsv1; + self + } + + /// Is the `rsv2` flag set? + pub fn is_rsv2(&self) -> bool { + self.rsv2 + } + + /// Set the `rsv2` flag. + pub fn set_rsv2(&mut self, rsv2: bool) -> &mut Self { + self.rsv2 = rsv2; + self + } + + /// Is the `rsv3` flag set? + pub fn is_rsv3(&self) -> bool { + self.rsv3 + } + + /// Set the `rsv3` flag. + pub fn set_rsv3(&mut self, rsv3: bool) -> &mut Self { + self.rsv3 = rsv3; + self + } + + /// Is the `masked` flag set? + pub fn is_masked(&self) -> bool { + self.masked + } + + /// Set the `masked` flag. + pub fn set_masked(&mut self, masked: bool) -> &mut Self { + self.masked = masked; + self + } + + /// Get the `opcode`. + pub fn opcode(&self) -> OpCode { + self.opcode + } + + /// Set the `opcode` + pub fn set_opcode(&mut self, opcode: OpCode) -> &mut Self { + self.opcode = opcode; + self + } + + /// Get the `mask`. + pub fn mask(&self) -> u32 { + self.mask + } + + /// Set the `mask` + pub fn set_mask(&mut self, mask: u32) -> &mut Self { + self.mask = mask; + self + } + + /// Get the payload length. + pub fn payload_len(&self) -> usize { + self.payload_len + } + + /// Set the payload length. + pub fn set_payload_len(&mut self, len: usize) -> &mut Self { + self.payload_len = len; + self + } } // Base codec ////////////////////////////////////////////////////////////////////////////////////. @@ -326,420 +320,404 @@ const EIGHT_EXT: u8 = 127; /// [base]: https://tools.ietf.org/html/rfc6455#section-5.2 #[derive(Debug, Clone)] pub struct Codec { - /// Maximum size of payload data per frame. - max_data_size: usize, - /// Bits reserved by an extension. - reserved_bits: u8, - /// Scratch buffer used during header encoding. - header_buffer: [u8; MAX_HEADER_SIZE] + /// Maximum size of payload data per frame. + max_data_size: usize, + /// Bits reserved by an extension. + reserved_bits: u8, + /// Scratch buffer used during header encoding. + header_buffer: [u8; MAX_HEADER_SIZE], } impl Default for Codec { - fn default() -> Self { - Codec { - max_data_size: 256 * 1024 * 1024, - reserved_bits: 0, - header_buffer: [0; MAX_HEADER_SIZE] - } - } + fn default() -> Self { + Codec { max_data_size: 256 * 1024 * 1024, reserved_bits: 0, header_buffer: [0; MAX_HEADER_SIZE] } + } } impl Codec { - /// Create a new base frame codec. - /// - /// The codec will support decoding payload lengths up to 256 MiB - /// (use `set_max_data_size` to change this value). - pub fn new() -> Self { - Codec::default() - } - - /// Get the configured maximum payload length. - pub fn max_data_size(&self) -> usize { - self.max_data_size - } - - /// Limit the maximum size of payload data to `size` bytes. - pub fn set_max_data_size(&mut self, size: usize) -> &mut Self { - self.max_data_size = size; - self - } - - /// The reserved bits currently configured. - pub fn reserved_bits(&self) -> (bool, bool, bool) { - let r = self.reserved_bits; - (r & 4 == 4, r & 2 == 2, r & 1 == 1) - } - - /// Add to the reserved bits in use. - pub fn add_reserved_bits(&mut self, bits: (bool, bool, bool)) -> &mut Self { - let (r1, r2, r3) = bits; - self.reserved_bits |= (r1 as u8) << 2 | (r2 as u8) << 1 | r3 as u8; - self - } - - /// Reset the reserved bits. - pub fn clear_reserved_bits(&mut self) { - self.reserved_bits = 0 - } - - /// Decode a websocket frame header. - pub fn decode_header(&self, bytes: &[u8]) -> Result, Error> { - if bytes.len() < 2 { - return Ok(Parsing::NeedMore(2 - bytes.len())) - } - - let first = bytes[0]; - let second = bytes[1]; - let mut offset = 2; - - let fin = first & 0x80 != 0; - let opcode = OpCode::try_from(first & 0xF)?; - - if opcode.is_reserved() { - return Err(Error::ReservedOpCode) - } - - if opcode.is_control() && !fin { - return Err(Error::FragmentedControl) - } - - let mut header = Header::new(opcode); - header.set_fin(fin); - - let rsv1 = first & 0x40 != 0; - if rsv1 && (self.reserved_bits & 4 == 0) { - return Err(Error::InvalidReservedBit(1)) - } - header.set_rsv1(rsv1); - - let rsv2 = first & 0x20 != 0; - if rsv2 && (self.reserved_bits & 2 == 0) { - return Err(Error::InvalidReservedBit(2)) - } - header.set_rsv2(rsv2); - - let rsv3 = first & 0x10 != 0; - if rsv3 && (self.reserved_bits & 1 == 0) { - return Err(Error::InvalidReservedBit(3)) - } - header.set_rsv3(rsv3); - header.set_masked(second & 0x80 != 0); - - let len: u64 = match second & 0x7F { - TWO_EXT => { - if bytes.len() < offset + 2 { - return Ok(Parsing::NeedMore(offset + 2 - bytes.len())) - } - let len = u16::from_be_bytes([bytes[offset], bytes[offset + 1]]); - offset += 2; - u64::from(len) - } - EIGHT_EXT => { - if bytes.len() < offset + 8 { - return Ok(Parsing::NeedMore(offset + 8 - bytes.len())) - } - let mut b = [0; 8]; - b.copy_from_slice(&bytes[offset .. offset + 8]); - offset += 8; - u64::from_be_bytes(b) - } - n => u64::from(n) - }; - - if len > MAX_CTRL_BODY_SIZE && header.opcode().is_control() { - return Err(Error::InvalidControlFrameLen) - } - - let len: usize = - if len > as_u64(self.max_data_size) { - return Err(Error::PayloadTooLarge { - actual: len, - maximum: as_u64(self.max_data_size) - }) - } else { - len as usize - }; - - header.set_payload_len(len); - - if header.is_masked() { - if bytes.len() < offset + 4 { - return Ok(Parsing::NeedMore(offset + 4 - bytes.len())) - } - let mut b = [0; 4]; - b.copy_from_slice(&bytes[offset .. offset + 4]); - offset += 4; - header.set_mask(u32::from_be_bytes(b)); - } - - Ok(Parsing::Done { value: header, offset }) - } - - /// Encode a websocket frame header. - pub fn encode_header(&mut self, header: &Header) -> &[u8] { - let mut offset = 0; - - let mut first_byte = 0_u8; - if header.is_fin() { - first_byte |= 0x80 - } - if header.is_rsv1() { - first_byte |= 0x40 - } - if header.is_rsv2() { - first_byte |= 0x20 - } - if header.is_rsv3() { - first_byte |= 0x10 - } - - let opcode: u8 = header.opcode().into(); - first_byte |= opcode; - - self.header_buffer[offset] = first_byte; - offset += 1; - - let mut second_byte = 0_u8; - if header.is_masked() { - second_byte |= 0x80 - } - - let len = header.payload_len(); - - if len < usize::from(TWO_EXT) { - second_byte |= len as u8; - self.header_buffer[offset] = second_byte; - offset += 1; - } else if len <= usize::from(u16::max_value()) { - second_byte |= TWO_EXT; - self.header_buffer[offset] = second_byte; - offset += 1; - self.header_buffer[offset .. offset + 2].copy_from_slice(&(len as u16).to_be_bytes()); - offset += 2; - } else { - second_byte |= EIGHT_EXT; - self.header_buffer[offset] = second_byte; - offset += 1; - self.header_buffer[offset .. offset + 8].copy_from_slice(&as_u64(len).to_be_bytes()); - offset += 8; - } - - if header.is_masked() { - self.header_buffer[offset .. offset + 4].copy_from_slice(&header.mask().to_be_bytes()); - offset += 4; - } - - &self.header_buffer[.. offset] - } - - /// Use the given header's mask and apply it to the data. - pub fn apply_mask(header: &Header, data: &mut [u8]) { - if header.is_masked() { - let mask = header.mask().to_be_bytes(); - for (byte, &key) in data.iter_mut().zip(mask.iter().cycle()) { - *byte ^= key; - } - } - } + /// Create a new base frame codec. + /// + /// The codec will support decoding payload lengths up to 256 MiB + /// (use `set_max_data_size` to change this value). + pub fn new() -> Self { + Codec::default() + } + + /// Get the configured maximum payload length. + pub fn max_data_size(&self) -> usize { + self.max_data_size + } + + /// Limit the maximum size of payload data to `size` bytes. + pub fn set_max_data_size(&mut self, size: usize) -> &mut Self { + self.max_data_size = size; + self + } + + /// The reserved bits currently configured. + pub fn reserved_bits(&self) -> (bool, bool, bool) { + let r = self.reserved_bits; + (r & 4 == 4, r & 2 == 2, r & 1 == 1) + } + + /// Add to the reserved bits in use. + pub fn add_reserved_bits(&mut self, bits: (bool, bool, bool)) -> &mut Self { + let (r1, r2, r3) = bits; + self.reserved_bits |= (r1 as u8) << 2 | (r2 as u8) << 1 | r3 as u8; + self + } + + /// Reset the reserved bits. + pub fn clear_reserved_bits(&mut self) { + self.reserved_bits = 0 + } + + /// Decode a websocket frame header. + pub fn decode_header(&self, bytes: &[u8]) -> Result, Error> { + if bytes.len() < 2 { + return Ok(Parsing::NeedMore(2 - bytes.len())); + } + + let first = bytes[0]; + let second = bytes[1]; + let mut offset = 2; + + let fin = first & 0x80 != 0; + let opcode = OpCode::try_from(first & 0xF)?; + + if opcode.is_reserved() { + return Err(Error::ReservedOpCode); + } + + if opcode.is_control() && !fin { + return Err(Error::FragmentedControl); + } + + let mut header = Header::new(opcode); + header.set_fin(fin); + + let rsv1 = first & 0x40 != 0; + if rsv1 && (self.reserved_bits & 4 == 0) { + return Err(Error::InvalidReservedBit(1)); + } + header.set_rsv1(rsv1); + + let rsv2 = first & 0x20 != 0; + if rsv2 && (self.reserved_bits & 2 == 0) { + return Err(Error::InvalidReservedBit(2)); + } + header.set_rsv2(rsv2); + + let rsv3 = first & 0x10 != 0; + if rsv3 && (self.reserved_bits & 1 == 0) { + return Err(Error::InvalidReservedBit(3)); + } + header.set_rsv3(rsv3); + header.set_masked(second & 0x80 != 0); + + let len: u64 = match second & 0x7F { + TWO_EXT => { + if bytes.len() < offset + 2 { + return Ok(Parsing::NeedMore(offset + 2 - bytes.len())); + } + let len = u16::from_be_bytes([bytes[offset], bytes[offset + 1]]); + offset += 2; + u64::from(len) + } + EIGHT_EXT => { + if bytes.len() < offset + 8 { + return Ok(Parsing::NeedMore(offset + 8 - bytes.len())); + } + let mut b = [0; 8]; + b.copy_from_slice(&bytes[offset..offset + 8]); + offset += 8; + u64::from_be_bytes(b) + } + n => u64::from(n), + }; + + if len > MAX_CTRL_BODY_SIZE && header.opcode().is_control() { + return Err(Error::InvalidControlFrameLen); + } + + let len: usize = if len > as_u64(self.max_data_size) { + return Err(Error::PayloadTooLarge { actual: len, maximum: as_u64(self.max_data_size) }); + } else { + len as usize + }; + + header.set_payload_len(len); + + if header.is_masked() { + if bytes.len() < offset + 4 { + return Ok(Parsing::NeedMore(offset + 4 - bytes.len())); + } + let mut b = [0; 4]; + b.copy_from_slice(&bytes[offset..offset + 4]); + offset += 4; + header.set_mask(u32::from_be_bytes(b)); + } + + Ok(Parsing::Done { value: header, offset }) + } + + /// Encode a websocket frame header. + pub fn encode_header(&mut self, header: &Header) -> &[u8] { + let mut offset = 0; + + let mut first_byte = 0_u8; + if header.is_fin() { + first_byte |= 0x80 + } + if header.is_rsv1() { + first_byte |= 0x40 + } + if header.is_rsv2() { + first_byte |= 0x20 + } + if header.is_rsv3() { + first_byte |= 0x10 + } + + let opcode: u8 = header.opcode().into(); + first_byte |= opcode; + + self.header_buffer[offset] = first_byte; + offset += 1; + + let mut second_byte = 0_u8; + if header.is_masked() { + second_byte |= 0x80 + } + + let len = header.payload_len(); + + if len < usize::from(TWO_EXT) { + second_byte |= len as u8; + self.header_buffer[offset] = second_byte; + offset += 1; + } else if len <= usize::from(u16::max_value()) { + second_byte |= TWO_EXT; + self.header_buffer[offset] = second_byte; + offset += 1; + self.header_buffer[offset..offset + 2].copy_from_slice(&(len as u16).to_be_bytes()); + offset += 2; + } else { + second_byte |= EIGHT_EXT; + self.header_buffer[offset] = second_byte; + offset += 1; + self.header_buffer[offset..offset + 8].copy_from_slice(&as_u64(len).to_be_bytes()); + offset += 8; + } + + if header.is_masked() { + self.header_buffer[offset..offset + 4].copy_from_slice(&header.mask().to_be_bytes()); + offset += 4; + } + + &self.header_buffer[..offset] + } + + /// Use the given header's mask and apply it to the data. + pub fn apply_mask(header: &Header, data: &mut [u8]) { + if header.is_masked() { + let mask = header.mask().to_be_bytes(); + for (byte, &key) in data.iter_mut().zip(mask.iter().cycle()) { + *byte ^= key; + } + } + } } /// Error cases the base frame decoder may encounter. #[non_exhaustive] #[derive(Debug)] pub enum Error { - /// An I/O error has been encountered. - Io(io::Error), - /// Some unknown opcode number has been decoded. - UnknownOpCode, - /// The opcode decoded is reserved. - ReservedOpCode, - /// A fragmented control frame (fin bit not set) has been decoded. - FragmentedControl, - /// A control frame with an invalid length code has been decoded. - InvalidControlFrameLen, - /// The reserved bit is invalid. - InvalidReservedBit(u8), - /// The payload length of a frame exceeded the configured maximum. - PayloadTooLarge { actual: u64, maximum: u64 } + /// An I/O error has been encountered. + Io(io::Error), + /// Some unknown opcode number has been decoded. + UnknownOpCode, + /// The opcode decoded is reserved. + ReservedOpCode, + /// A fragmented control frame (fin bit not set) has been decoded. + FragmentedControl, + /// A control frame with an invalid length code has been decoded. + InvalidControlFrameLen, + /// The reserved bit is invalid. + InvalidReservedBit(u8), + /// The payload length of a frame exceeded the configured maximum. + PayloadTooLarge { actual: u64, maximum: u64 }, } impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - Error::Io(e) => - write!(f, "i/o error: {}", e), - Error::UnknownOpCode => - f.write_str("unknown opcode"), - Error::ReservedOpCode => - f.write_str("reserved opcode"), - Error::FragmentedControl => - f.write_str("fragmented control frame"), - Error::InvalidControlFrameLen => - f.write_str("invalid control frame length"), - Error::InvalidReservedBit(n) => - write!(f, "invalid reserved bit: {}", n), - Error::PayloadTooLarge { actual, maximum } => - write!(f, "payload too large: len = {}, maximum = {}", actual, maximum) - } - } + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Error::Io(e) => write!(f, "i/o error: {}", e), + Error::UnknownOpCode => f.write_str("unknown opcode"), + Error::ReservedOpCode => f.write_str("reserved opcode"), + Error::FragmentedControl => f.write_str("fragmented control frame"), + Error::InvalidControlFrameLen => f.write_str("invalid control frame length"), + Error::InvalidReservedBit(n) => write!(f, "invalid reserved bit: {}", n), + Error::PayloadTooLarge { actual, maximum } => { + write!(f, "payload too large: len = {}, maximum = {}", actual, maximum) + } + } + } } impl std::error::Error for Error { - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - match self { - Error::Io(e) => Some(e), - Error::UnknownOpCode - | Error::ReservedOpCode - | Error::FragmentedControl - | Error::InvalidControlFrameLen - | Error::InvalidReservedBit(_) - | Error::PayloadTooLarge {..} - => None - } - } + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + Error::Io(e) => Some(e), + Error::UnknownOpCode + | Error::ReservedOpCode + | Error::FragmentedControl + | Error::InvalidControlFrameLen + | Error::InvalidReservedBit(_) + | Error::PayloadTooLarge { .. } => None, + } + } } impl From for Error { - fn from(e: io::Error) -> Self { - Error::Io(e) - } + fn from(e: io::Error) -> Self { + Error::Io(e) + } } impl From for Error { - fn from(_: UnknownOpCode) -> Self { - Error::UnknownOpCode - } + fn from(_: UnknownOpCode) -> Self { + Error::UnknownOpCode + } } - // Tests ////////////////////////////////////////////////////////////////////////////////////////// #[cfg(test)] mod test { - use crate::Parsing; - use quickcheck::QuickCheck; - use super::{OpCode, Codec, Error}; - - #[test] - fn decode_partial_header() { - let partial_header: &[u8] = &[0x89]; - assert!(matches! { - Codec::new().decode_header(partial_header), - Ok(Parsing::NeedMore(1)) - }) - } - - #[test] - fn decode_partial_len() { - let partial_length_1: &[u8] = &[0x89, 0xFE, 0x01]; - assert!(matches! { - Codec::new().decode_header(partial_length_1), - Ok(Parsing::NeedMore(1)) - }); - let partial_length_2: &[u8] = &[0x89, 0xFF, 0x01, 0x02, 0x03, 0x04]; - assert!(matches! { - Codec::new().decode_header(partial_length_2), - Ok(Parsing::NeedMore(4)) - }) - } - - #[test] - fn decode_partial_mask() { - let partial_mask: &[u8] = &[0x82, 0xFE, 0x01, 0x02, 0x00, 0x00]; - assert!(matches! { - Codec::new().decode_header(partial_mask), - Ok(Parsing::NeedMore(2)) - }) - } - - #[test] - fn decode_partial_payload() { - let partial_payload: &mut [u8] = &mut [0x82, 0x85, 0x01, 0x02, 0x03, 0x04, 0x00, 0x00]; - if let Ok(Parsing::Done { value, offset }) = Codec::new().decode_header(partial_payload) { - assert_eq!(3, value.payload_len() - (partial_payload.len() - offset)) - } else { - assert!(false) - } - } - - #[test] - fn decode_invalid_control_payload_len() { - // Payload on control frame must be 125 bytes or less. 2nd byte must be 0xFD or less. - let ctrl_payload_len : &[u8] = &[0x89, 0xFE, 0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]; - assert!(matches! { - Codec::new().decode_header(ctrl_payload_len), - Err(Error::InvalidControlFrameLen) - }) - } - - /// Checking that rsv1, rsv2, and rsv3 bit set returns error. - #[test] - fn decode_reserved() { - // rsv1, rsv2, and rsv3. - let reserved = [0x90, 0xa0, 0xc0]; - for res in &reserved { - let mut buf = [0; 2]; - buf[0] |= *res; - assert!(matches! { - Codec::new().decode_header(&buf), - Err(Error::InvalidReservedBit(_)) - }) - } - } - - /// Checking that a control frame, where fin bit is 0, returns an error. - #[test] - fn decode_fragmented_control() { - let second_bytes = [8, 9, 10]; - for sb in &second_bytes { - let mut buf = [0; 2]; - buf[0] |= *sb; - assert!(matches! { - Codec::new().decode_header(&buf), - Err(Error::FragmentedControl) - }) - } - } - - /// Checking that reserved opcodes return an error. - #[test] - fn decode_reserved_opcodes() { - let reserved = [3, 4, 5, 6, 7, 11, 12, 13, 14, 15]; - for res in &reserved { - let mut buf = [0; 2]; - buf[0] |= 0x80 | *res; - assert!(matches! { - Codec::new().decode_header(&buf), - Err(Error::ReservedOpCode) - }) - } - } - - #[test] - fn decode_ping_no_data() { - let ping_no_data: &mut [u8] = &mut [0x89, 0x80, 0x00, 0x00, 0x00, 0x01]; - let c = Codec::new(); - if let Ok(Parsing::Done { value: header, .. }) = c.decode_header(ping_no_data) { - assert!(header.is_fin()); - assert!(!header.is_rsv1()); - assert!(!header.is_rsv2()); - assert!(!header.is_rsv3()); - assert!(header.opcode() == OpCode::Ping); - assert!(header.payload_len() == 0) - } else { - assert!(false) - } - } - - #[test] - fn reserved_bits() { - fn property(bits: (bool, bool, bool)) -> bool { - let mut c = Codec::new(); - assert_eq!((false, false, false), c.reserved_bits()); - c.add_reserved_bits(bits); - bits == c.reserved_bits() - } - QuickCheck::new().quickcheck(property as fn((bool, bool, bool)) -> bool) - } + use super::{Codec, Error, OpCode}; + use crate::Parsing; + use quickcheck::QuickCheck; + + #[test] + fn decode_partial_header() { + let partial_header: &[u8] = &[0x89]; + assert!(matches! { + Codec::new().decode_header(partial_header), + Ok(Parsing::NeedMore(1)) + }) + } + + #[test] + fn decode_partial_len() { + let partial_length_1: &[u8] = &[0x89, 0xFE, 0x01]; + assert!(matches! { + Codec::new().decode_header(partial_length_1), + Ok(Parsing::NeedMore(1)) + }); + let partial_length_2: &[u8] = &[0x89, 0xFF, 0x01, 0x02, 0x03, 0x04]; + assert!(matches! { + Codec::new().decode_header(partial_length_2), + Ok(Parsing::NeedMore(4)) + }) + } + + #[test] + fn decode_partial_mask() { + let partial_mask: &[u8] = &[0x82, 0xFE, 0x01, 0x02, 0x00, 0x00]; + assert!(matches! { + Codec::new().decode_header(partial_mask), + Ok(Parsing::NeedMore(2)) + }) + } + + #[test] + fn decode_partial_payload() { + let partial_payload: &mut [u8] = &mut [0x82, 0x85, 0x01, 0x02, 0x03, 0x04, 0x00, 0x00]; + if let Ok(Parsing::Done { value, offset }) = Codec::new().decode_header(partial_payload) { + assert_eq!(3, value.payload_len() - (partial_payload.len() - offset)) + } else { + assert!(false) + } + } + + #[test] + fn decode_invalid_control_payload_len() { + // Payload on control frame must be 125 bytes or less. 2nd byte must be 0xFD or less. + let ctrl_payload_len: &[u8] = &[0x89, 0xFE, 0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]; + assert!(matches! { + Codec::new().decode_header(ctrl_payload_len), + Err(Error::InvalidControlFrameLen) + }) + } + + /// Checking that rsv1, rsv2, and rsv3 bit set returns error. + #[test] + fn decode_reserved() { + // rsv1, rsv2, and rsv3. + let reserved = [0x90, 0xa0, 0xc0]; + for res in &reserved { + let mut buf = [0; 2]; + buf[0] |= *res; + assert!(matches! { + Codec::new().decode_header(&buf), + Err(Error::InvalidReservedBit(_)) + }) + } + } + + /// Checking that a control frame, where fin bit is 0, returns an error. + #[test] + fn decode_fragmented_control() { + let second_bytes = [8, 9, 10]; + for sb in &second_bytes { + let mut buf = [0; 2]; + buf[0] |= *sb; + assert!(matches! { + Codec::new().decode_header(&buf), + Err(Error::FragmentedControl) + }) + } + } + + /// Checking that reserved opcodes return an error. + #[test] + fn decode_reserved_opcodes() { + let reserved = [3, 4, 5, 6, 7, 11, 12, 13, 14, 15]; + for res in &reserved { + let mut buf = [0; 2]; + buf[0] |= 0x80 | *res; + assert!(matches! { + Codec::new().decode_header(&buf), + Err(Error::ReservedOpCode) + }) + } + } + + #[test] + fn decode_ping_no_data() { + let ping_no_data: &mut [u8] = &mut [0x89, 0x80, 0x00, 0x00, 0x00, 0x01]; + let c = Codec::new(); + if let Ok(Parsing::Done { value: header, .. }) = c.decode_header(ping_no_data) { + assert!(header.is_fin()); + assert!(!header.is_rsv1()); + assert!(!header.is_rsv2()); + assert!(!header.is_rsv3()); + assert!(header.opcode() == OpCode::Ping); + assert!(header.payload_len() == 0) + } else { + assert!(false) + } + } + + #[test] + fn reserved_bits() { + fn property(bits: (bool, bool, bool)) -> bool { + let mut c = Codec::new(); + assert_eq!((false, false, false), c.reserved_bits()); + c.add_reserved_bits(bits); + bits == c.reserved_bits() + } + QuickCheck::new().quickcheck(property as fn((bool, bool, bool)) -> bool) + } } - diff --git a/src/connection.rs b/src/connection.rs index 1e1110fc..06033cbb 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -9,10 +9,18 @@ //! A persistent websocket connection after the handshake phase, represented //! as a [`Sender`] and [`Receiver`] pair. -use bytes::{Buf, BytesMut}; -use crate::{Storage, Parsing, base::{self, Header, MAX_HEADER_SIZE, OpCode}, extension::Extension}; use crate::data::{ByteSlice125, Data, Incoming}; -use futures::{io::{ReadHalf, WriteHalf}, lock::BiLock, prelude::*}; +use crate::{ + base::{self, Header, OpCode, MAX_HEADER_SIZE}, + extension::Extension, + Parsing, Storage, +}; +use bytes::{Buf, BytesMut}; +use futures::{ + io::{ReadHalf, WriteHalf}, + lock::BiLock, + prelude::*, +}; use std::{fmt, io, str}; /// Accumulated max. size of a complete message. @@ -24,24 +32,24 @@ const MAX_FRAME_SIZE: usize = MAX_MESSAGE_SIZE; /// Is the connection used by a client or server? #[derive(Copy, Clone, Debug, PartialEq, Eq)] pub enum Mode { - /// Client-side of a connection (implies masking of payload data). - Client, - /// Server-side of a connection. - Server + /// Client-side of a connection (implies masking of payload data). + Client, + /// Server-side of a connection. + Server, } impl Mode { - pub fn is_client(self) -> bool { - if let Mode::Client = self { - true - } else { - false - } - } - - pub fn is_server(self) -> bool { - !self.is_client() - } + pub fn is_client(self) -> bool { + if let Mode::Client = self { + true + } else { + false + } + } + + pub fn is_server(self) -> bool { + !self.is_client() + } } /// Connection ID. @@ -49,37 +57,37 @@ impl Mode { struct Id(u32); impl fmt::Display for Id { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{:08x}", self.0) - } + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{:08x}", self.0) + } } /// The sending half of a connection. #[derive(Debug)] pub struct Sender { - id: Id, - mode: Mode, - codec: base::Codec, - writer: BiLock>, - mask_buffer: Vec, - extensions: BiLock>>, - has_extensions: bool + id: Id, + mode: Mode, + codec: base::Codec, + writer: BiLock>, + mask_buffer: Vec, + extensions: BiLock>>, + has_extensions: bool, } /// The receiving half of a connection. #[derive(Debug)] pub struct Receiver { - id: Id, - mode: Mode, - codec: base::Codec, - reader: ReadHalf, - writer: BiLock>, - extensions: BiLock>>, - has_extensions: bool, - buffer: BytesMut, - ctrl_buffer: BytesMut, - max_message_size: usize, - is_closed: bool + id: Id, + mode: Mode, + codec: base::Codec, + reader: ReadHalf, + writer: BiLock>, + extensions: BiLock>>, + has_extensions: bool, + buffer: BytesMut, + ctrl_buffer: BytesMut, + max_message_size: usize, + is_closed: bool, } /// A connection builder. @@ -89,486 +97,506 @@ pub struct Receiver { /// connection. #[derive(Debug)] pub struct Builder { - id: Id, - mode: Mode, - socket: T, - codec: base::Codec, - extensions: Vec>, - buffer: BytesMut, - max_message_size: usize + id: Id, + mode: Mode, + socket: T, + codec: base::Codec, + extensions: Vec>, + buffer: BytesMut, + max_message_size: usize, } impl Builder { - /// Create a new `Builder` from the given async I/O resource and mode. - /// - /// **Note**: Use this type only after a successful [handshake][0]. - /// You can either use this crate's [handshake functionality][1] - /// or perform the handshake by some other means. - /// - /// [0]: https://tools.ietf.org/html/rfc6455#section-4 - /// [1]: crate::handshake - pub fn new(socket: T, mode: Mode) -> Self { - let mut codec = base::Codec::default(); - codec.set_max_data_size(MAX_FRAME_SIZE); - Builder { - id: Id(rand::random()), - mode, - socket, - codec, - extensions: Vec::new(), - buffer: BytesMut::new(), - max_message_size: MAX_MESSAGE_SIZE - } - } - - /// Set a custom buffer to use. - pub fn set_buffer(&mut self, b: BytesMut) { - self.buffer = b - } - - /// Add extensions to use with this connection. - /// - /// Only enabled extensions will be considered. - pub fn add_extensions(&mut self, extensions: I) - where - I: IntoIterator> - { - for e in extensions.into_iter().filter(|e| e.is_enabled()) { - log::debug!("{}: using extension: {}", self.id, e.name()); - self.codec.add_reserved_bits(e.reserved_bits()); - self.extensions.push(e) - } - } - - /// Set the maximum size of a complete message. - /// - /// Message fragments will be buffered and concatenated up to this value, - /// i.e. the sum of all message frames payload lengths will not be greater - /// than this maximum. However, extensions may increase the total message - /// size further, e.g. by decompressing the payload data. - pub fn set_max_message_size(&mut self, max: usize) { - self.max_message_size = max - } - - /// Set the maximum size of a single websocket frame payload. - pub fn set_max_frame_size(&mut self, max: usize) { - self.codec.set_max_data_size(max); - } - - /// Create a configured [`Sender`]/[`Receiver`] pair. - pub fn finish(self) -> (Sender, Receiver) { - let (rhlf, whlf) = self.socket.split(); - let (wrt1, wrt2) = BiLock::new(whlf); - let has_extensions = !self.extensions.is_empty(); - let (ext1, ext2) = BiLock::new(self.extensions); - - let recv = Receiver { - id: self.id, - mode: self.mode, - reader: rhlf, - writer: wrt1, - codec: self.codec.clone(), - extensions: ext1, - has_extensions, - buffer: self.buffer, - ctrl_buffer: BytesMut::new(), - max_message_size: self.max_message_size, - is_closed: false - }; - - let send = Sender { - id: self.id, - mode: self.mode, - writer: wrt2, - mask_buffer: Vec::new(), - codec: self.codec, - extensions: ext2, - has_extensions - }; - - (send, recv) - } + /// Create a new `Builder` from the given async I/O resource and mode. + /// + /// **Note**: Use this type only after a successful [handshake][0]. + /// You can either use this crate's [handshake functionality][1] + /// or perform the handshake by some other means. + /// + /// [0]: https://tools.ietf.org/html/rfc6455#section-4 + /// [1]: crate::handshake + pub fn new(socket: T, mode: Mode) -> Self { + let mut codec = base::Codec::default(); + codec.set_max_data_size(MAX_FRAME_SIZE); + Builder { + id: Id(rand::random()), + mode, + socket, + codec, + extensions: Vec::new(), + buffer: BytesMut::new(), + max_message_size: MAX_MESSAGE_SIZE, + } + } + + /// Set a custom buffer to use. + pub fn set_buffer(&mut self, b: BytesMut) { + self.buffer = b + } + + /// Add extensions to use with this connection. + /// + /// Only enabled extensions will be considered. + pub fn add_extensions(&mut self, extensions: I) + where + I: IntoIterator>, + { + for e in extensions.into_iter().filter(|e| e.is_enabled()) { + log::debug!("{}: using extension: {}", self.id, e.name()); + self.codec.add_reserved_bits(e.reserved_bits()); + self.extensions.push(e) + } + } + + /// Set the maximum size of a complete message. + /// + /// Message fragments will be buffered and concatenated up to this value, + /// i.e. the sum of all message frames payload lengths will not be greater + /// than this maximum. However, extensions may increase the total message + /// size further, e.g. by decompressing the payload data. + pub fn set_max_message_size(&mut self, max: usize) { + self.max_message_size = max + } + + /// Set the maximum size of a single websocket frame payload. + pub fn set_max_frame_size(&mut self, max: usize) { + self.codec.set_max_data_size(max); + } + + /// Create a configured [`Sender`]/[`Receiver`] pair. + pub fn finish(self) -> (Sender, Receiver) { + let (rhlf, whlf) = self.socket.split(); + let (wrt1, wrt2) = BiLock::new(whlf); + let has_extensions = !self.extensions.is_empty(); + let (ext1, ext2) = BiLock::new(self.extensions); + + let recv = Receiver { + id: self.id, + mode: self.mode, + reader: rhlf, + writer: wrt1, + codec: self.codec.clone(), + extensions: ext1, + has_extensions, + buffer: self.buffer, + ctrl_buffer: BytesMut::new(), + max_message_size: self.max_message_size, + is_closed: false, + }; + + let send = Sender { + id: self.id, + mode: self.mode, + writer: wrt2, + mask_buffer: Vec::new(), + codec: self.codec, + extensions: ext2, + has_extensions, + }; + + (send, recv) + } } impl Receiver { - /// Receive the next websocket message. - /// - /// The received frames forming the complete message will be appended to - /// the given `message` argument. The returned [`Incoming`] value describes - /// the type of data that was received, e.g. binary or textual data. - /// - /// Interleaved PONG frames are returned immediately as `Data::Pong` - /// values. If PONGs are not expected or uninteresting, - /// [`Receiver::receive_data`] may be used instead which skips over PONGs - /// and considers only application payload data. - pub async fn receive(&mut self, message: &mut Vec) -> Result, Error> { - let mut first_fragment_opcode = None; - let mut length: usize = 0; - let message_len = message.len(); - loop { - if self.is_closed { - log::debug!("{}: cannot receive, connection is closed", self.id); - return Err(Error::Closed); - } - - self.ctrl_buffer.clear(); - let mut header = self.receive_header().await?; - log::trace!("{}: recv: {}", self.id, header); - - // Handle control frames: PING, PONG and CLOSE. - if header.opcode().is_control() { - self.read_buffer(&header).await?; - self.ctrl_buffer = self.buffer.split_to(header.payload_len()); - base::Codec::apply_mask(&header, &mut self.ctrl_buffer); - if header.opcode() == OpCode::Pong { - return Ok(Incoming::Pong(&self.ctrl_buffer[..])); - } - if let Some(close_reason) = self.on_control(&header).await? { - log::trace!("{}: recv, incoming CLOSE: {:?}", self.id, close_reason); - return Ok(Incoming::Closed(close_reason)); - } - continue; - } - - length = length.saturating_add(header.payload_len()); - - // Check if total message does not exceed maximum. - if length > self.max_message_size { - log::warn!("{}: accumulated message length exceeds maximum", self.id); - return Err(Error::MessageTooLarge { current: length, maximum: self.max_message_size }) - } - - // Get the frame's payload data bytes from buffer or socket. - { - let old_msg_len = message.len(); - - let bytes_to_read = { - let required = header.payload_len(); - let buffered = self.buffer.len(); - - if buffered == 0 { - required - } else if required > buffered { - message.extend_from_slice(&self.buffer); - self.buffer.clear(); - required - buffered - } else { - message.extend_from_slice(&self.buffer.split_to(required)); - 0 - } - }; - - if bytes_to_read > 0 { - let n = message.len(); - message.resize(n + bytes_to_read, 0u8); - self.reader.read_exact(&mut message[n ..]).await? - } - - debug_assert_eq!(header.payload_len(), message.len() - old_msg_len); - - base::Codec::apply_mask(&header, &mut message[old_msg_len ..]); - } - - match (header.is_fin(), header.opcode()) { - (false, OpCode::Continue) => { // Intermediate message fragment. - if first_fragment_opcode.is_none() { - log::debug!("{}: continue frame while not processing message fragments", self.id); - return Err(Error::UnexpectedOpCode(OpCode::Continue)) - } - continue - } - (false, oc) => { // Initial message fragment. - if first_fragment_opcode.is_some() { - log::debug!("{}: initial fragment while processing a fragmented message", self.id); - return Err(Error::UnexpectedOpCode(oc)) - } - first_fragment_opcode = Some(oc); - self.decode_with_extensions(&mut header, message).await?; - continue - } - (true, OpCode::Continue) => { // Last message fragment. - if let Some(oc) = first_fragment_opcode.take() { - header.set_payload_len(message.len()); - log::trace!("{}: last fragment: total length = {} bytes", self.id, message.len()); - self.decode_with_extensions(&mut header, message).await?; - header.set_opcode(oc); - } else { - log::debug!("{}: last continue frame while not processing message fragments", self.id); - return Err(Error::UnexpectedOpCode(OpCode::Continue)) - } - } - (true, oc) => { // Regular non-fragmented message. - if first_fragment_opcode.is_some() { - log::debug!("{}: regular message while processing fragmented message", self.id); - return Err(Error::UnexpectedOpCode(oc)) - } - self.decode_with_extensions(&mut header, message).await? - } - } - - let num_bytes = message.len() - message_len; - - if header.opcode() == OpCode::Text { - return Ok(Incoming::Data(Data::Text(num_bytes))) - } else { - return Ok(Incoming::Data(Data::Binary(num_bytes))) - } - } - } - - /// Receive the next websocket message, skipping over control frames. - pub async fn receive_data(&mut self, message: &mut Vec) -> Result { - loop { - if let Incoming::Data(d) = self.receive(message).await? { - return Ok(d) - } - } - } - - /// Read the next frame header. - async fn receive_header(&mut self) -> Result { - loop { - match self.codec.decode_header(&self.buffer)? { - Parsing::Done { value: header, offset } => { - debug_assert!(offset <= MAX_HEADER_SIZE); - self.buffer.advance(offset); - return Ok(header) - } - Parsing::NeedMore(n) => { - crate::read(&mut self.reader, &mut self.buffer, n).await? - } - } - } - } - - /// Read the complete payload data into the read buffer. - async fn read_buffer(&mut self, header: &Header) -> Result<(), Error> { - if header.payload_len() <= self.buffer.len() { - return Ok(()) - } - let i = self.buffer.len(); - let d = header.payload_len() - i; - self.buffer.resize(i + d, 0u8); - self.reader.read_exact(&mut self.buffer[i ..]).await?; - Ok(()) - } - - /// Answer incoming control frames. - /// `PING`: replied to immediately with a `PONG` - /// `PONG`: no action - /// `CLOSE`: replied to immediately with a `CLOSE`; returns the [`CloseReason`] - /// All other [`OpCode`]s return [`Error::UnexpectedOpCode`] - async fn on_control(&mut self, header: &Header) -> Result, Error> { - match header.opcode() { - OpCode::Ping => { - let mut answer = Header::new(OpCode::Pong); - let mut unused = Vec::new(); - let mut data = Storage::Unique(&mut self.ctrl_buffer); - write(self.id, self.mode, &mut self.codec, &mut self.writer, &mut answer, &mut data, &mut unused).await?; - self.flush().await?; - Ok(None) - } - OpCode::Pong => Ok(None), - OpCode::Close => { - log::trace!("{}: Acknowledging CLOSE to sender", self.id); - self.is_closed = true; - let (mut header, reason) = close_answer(&self.ctrl_buffer)?; - // Write back a Close frame - let mut unused = Vec::new(); - if let Some(CloseReason { code, .. }) = reason { - let mut data = code.to_be_bytes(); - let mut data = Storage::Unique(&mut data); - let _ = write(self.id, self.mode, &mut self.codec, &mut self.writer, &mut header, &mut data, &mut unused).await; - } else { - let mut data = Storage::Unique(&mut []); - let _ = write(self.id, self.mode, &mut self.codec, &mut self.writer, &mut header, &mut data, &mut unused).await; - } - self.flush().await?; - self.writer.lock().await.close().await?; - Ok(reason) - } - OpCode::Binary - | OpCode::Text - | OpCode::Continue - | OpCode::Reserved3 - | OpCode::Reserved4 - | OpCode::Reserved5 - | OpCode::Reserved6 - | OpCode::Reserved7 - | OpCode::Reserved11 - | OpCode::Reserved12 - | OpCode::Reserved13 - | OpCode::Reserved14 - | OpCode::Reserved15 => Err(Error::UnexpectedOpCode(header.opcode())) - } - } - - /// Apply all extensions to the given header and the internal message buffer. - async fn decode_with_extensions(&mut self, header: &mut Header, message: &mut Vec) -> Result<(), Error> { - if !self.has_extensions { - return Ok(()) - } - for e in self.extensions.lock().await.iter_mut() { - log::trace!("{}: decoding with extension: {}", self.id, e.name()); - e.decode(header, message).map_err(Error::Extension)? - } - Ok(()) - } - - /// Flush the socket buffer. - async fn flush(&mut self) -> Result<(), Error> { - log::trace!("{}: Receiver flushing connection", self.id); - if self.is_closed { - return Ok(()) - } - self.writer.lock().await.flush().await.or(Err(Error::Closed)) - } + /// Receive the next websocket message. + /// + /// The received frames forming the complete message will be appended to + /// the given `message` argument. The returned [`Incoming`] value describes + /// the type of data that was received, e.g. binary or textual data. + /// + /// Interleaved PONG frames are returned immediately as `Data::Pong` + /// values. If PONGs are not expected or uninteresting, + /// [`Receiver::receive_data`] may be used instead which skips over PONGs + /// and considers only application payload data. + pub async fn receive(&mut self, message: &mut Vec) -> Result, Error> { + let mut first_fragment_opcode = None; + let mut length: usize = 0; + let message_len = message.len(); + loop { + if self.is_closed { + log::debug!("{}: cannot receive, connection is closed", self.id); + return Err(Error::Closed); + } + + self.ctrl_buffer.clear(); + let mut header = self.receive_header().await?; + log::trace!("{}: recv: {}", self.id, header); + + // Handle control frames: PING, PONG and CLOSE. + if header.opcode().is_control() { + self.read_buffer(&header).await?; + self.ctrl_buffer = self.buffer.split_to(header.payload_len()); + base::Codec::apply_mask(&header, &mut self.ctrl_buffer); + if header.opcode() == OpCode::Pong { + return Ok(Incoming::Pong(&self.ctrl_buffer[..])); + } + if let Some(close_reason) = self.on_control(&header).await? { + log::trace!("{}: recv, incoming CLOSE: {:?}", self.id, close_reason); + return Ok(Incoming::Closed(close_reason)); + } + continue; + } + + length = length.saturating_add(header.payload_len()); + + // Check if total message does not exceed maximum. + if length > self.max_message_size { + log::warn!("{}: accumulated message length exceeds maximum", self.id); + return Err(Error::MessageTooLarge { current: length, maximum: self.max_message_size }); + } + + // Get the frame's payload data bytes from buffer or socket. + { + let old_msg_len = message.len(); + + let bytes_to_read = { + let required = header.payload_len(); + let buffered = self.buffer.len(); + + if buffered == 0 { + required + } else if required > buffered { + message.extend_from_slice(&self.buffer); + self.buffer.clear(); + required - buffered + } else { + message.extend_from_slice(&self.buffer.split_to(required)); + 0 + } + }; + + if bytes_to_read > 0 { + let n = message.len(); + message.resize(n + bytes_to_read, 0u8); + self.reader.read_exact(&mut message[n..]).await? + } + + debug_assert_eq!(header.payload_len(), message.len() - old_msg_len); + + base::Codec::apply_mask(&header, &mut message[old_msg_len..]); + } + + match (header.is_fin(), header.opcode()) { + (false, OpCode::Continue) => { + // Intermediate message fragment. + if first_fragment_opcode.is_none() { + log::debug!("{}: continue frame while not processing message fragments", self.id); + return Err(Error::UnexpectedOpCode(OpCode::Continue)); + } + continue; + } + (false, oc) => { + // Initial message fragment. + if first_fragment_opcode.is_some() { + log::debug!("{}: initial fragment while processing a fragmented message", self.id); + return Err(Error::UnexpectedOpCode(oc)); + } + first_fragment_opcode = Some(oc); + self.decode_with_extensions(&mut header, message).await?; + continue; + } + (true, OpCode::Continue) => { + // Last message fragment. + if let Some(oc) = first_fragment_opcode.take() { + header.set_payload_len(message.len()); + log::trace!("{}: last fragment: total length = {} bytes", self.id, message.len()); + self.decode_with_extensions(&mut header, message).await?; + header.set_opcode(oc); + } else { + log::debug!("{}: last continue frame while not processing message fragments", self.id); + return Err(Error::UnexpectedOpCode(OpCode::Continue)); + } + } + (true, oc) => { + // Regular non-fragmented message. + if first_fragment_opcode.is_some() { + log::debug!("{}: regular message while processing fragmented message", self.id); + return Err(Error::UnexpectedOpCode(oc)); + } + self.decode_with_extensions(&mut header, message).await? + } + } + + let num_bytes = message.len() - message_len; + + if header.opcode() == OpCode::Text { + return Ok(Incoming::Data(Data::Text(num_bytes))); + } else { + return Ok(Incoming::Data(Data::Binary(num_bytes))); + } + } + } + + /// Receive the next websocket message, skipping over control frames. + pub async fn receive_data(&mut self, message: &mut Vec) -> Result { + loop { + if let Incoming::Data(d) = self.receive(message).await? { + return Ok(d); + } + } + } + + /// Read the next frame header. + async fn receive_header(&mut self) -> Result { + loop { + match self.codec.decode_header(&self.buffer)? { + Parsing::Done { value: header, offset } => { + debug_assert!(offset <= MAX_HEADER_SIZE); + self.buffer.advance(offset); + return Ok(header); + } + Parsing::NeedMore(n) => crate::read(&mut self.reader, &mut self.buffer, n).await?, + } + } + } + + /// Read the complete payload data into the read buffer. + async fn read_buffer(&mut self, header: &Header) -> Result<(), Error> { + if header.payload_len() <= self.buffer.len() { + return Ok(()); + } + let i = self.buffer.len(); + let d = header.payload_len() - i; + self.buffer.resize(i + d, 0u8); + self.reader.read_exact(&mut self.buffer[i..]).await?; + Ok(()) + } + + /// Answer incoming control frames. + /// `PING`: replied to immediately with a `PONG` + /// `PONG`: no action + /// `CLOSE`: replied to immediately with a `CLOSE`; returns the [`CloseReason`] + /// All other [`OpCode`]s return [`Error::UnexpectedOpCode`] + async fn on_control(&mut self, header: &Header) -> Result, Error> { + match header.opcode() { + OpCode::Ping => { + let mut answer = Header::new(OpCode::Pong); + let mut unused = Vec::new(); + let mut data = Storage::Unique(&mut self.ctrl_buffer); + write(self.id, self.mode, &mut self.codec, &mut self.writer, &mut answer, &mut data, &mut unused) + .await?; + self.flush().await?; + Ok(None) + } + OpCode::Pong => Ok(None), + OpCode::Close => { + log::trace!("{}: Acknowledging CLOSE to sender", self.id); + self.is_closed = true; + let (mut header, reason) = close_answer(&self.ctrl_buffer)?; + // Write back a Close frame + let mut unused = Vec::new(); + if let Some(CloseReason { code, .. }) = reason { + let mut data = code.to_be_bytes(); + let mut data = Storage::Unique(&mut data); + let _ = write( + self.id, + self.mode, + &mut self.codec, + &mut self.writer, + &mut header, + &mut data, + &mut unused, + ) + .await; + } else { + let mut data = Storage::Unique(&mut []); + let _ = write( + self.id, + self.mode, + &mut self.codec, + &mut self.writer, + &mut header, + &mut data, + &mut unused, + ) + .await; + } + self.flush().await?; + self.writer.lock().await.close().await?; + Ok(reason) + } + OpCode::Binary + | OpCode::Text + | OpCode::Continue + | OpCode::Reserved3 + | OpCode::Reserved4 + | OpCode::Reserved5 + | OpCode::Reserved6 + | OpCode::Reserved7 + | OpCode::Reserved11 + | OpCode::Reserved12 + | OpCode::Reserved13 + | OpCode::Reserved14 + | OpCode::Reserved15 => Err(Error::UnexpectedOpCode(header.opcode())), + } + } + + /// Apply all extensions to the given header and the internal message buffer. + async fn decode_with_extensions(&mut self, header: &mut Header, message: &mut Vec) -> Result<(), Error> { + if !self.has_extensions { + return Ok(()); + } + for e in self.extensions.lock().await.iter_mut() { + log::trace!("{}: decoding with extension: {}", self.id, e.name()); + e.decode(header, message).map_err(Error::Extension)? + } + Ok(()) + } + + /// Flush the socket buffer. + async fn flush(&mut self) -> Result<(), Error> { + log::trace!("{}: Receiver flushing connection", self.id); + if self.is_closed { + return Ok(()); + } + self.writer.lock().await.flush().await.or(Err(Error::Closed)) + } } impl Sender { - /// Send a text value over the websocket connection. - pub async fn send_text(&mut self, data: impl AsRef) -> Result<(), Error> { - let mut header = Header::new(OpCode::Text); - self.send_frame(&mut header, &mut Storage::Shared(data.as_ref().as_bytes())).await - } - - /// Send a text value over the websocket connection. - /// - /// This method performs one copy fewer than [`Sender::send_text`]. - pub async fn send_text_owned(&mut self, data: String) -> Result<(), Error> { - let mut header = Header::new(OpCode::Text); - self.send_frame(&mut header, &mut Storage::Owned(data.into_bytes())).await - } - - /// Send some binary data over the websocket connection. - pub async fn send_binary(&mut self, data: impl AsRef<[u8]>) -> Result<(), Error> { - let mut header = Header::new(OpCode::Binary); - self.send_frame(&mut header, &mut Storage::Shared(data.as_ref())).await - } - - /// Send some binary data over the websocket connection. - /// - /// This method performs one copy fewer than [`Sender::send_binary`]. - /// The `data` buffer may be modified by this method, e.g. if masking is necessary. - pub async fn send_binary_mut(&mut self, mut data: impl AsMut<[u8]>) -> Result<(), Error> { - let mut header = Header::new(OpCode::Binary); - self.send_frame(&mut header, &mut Storage::Unique(data.as_mut())).await - } - - /// Ping the remote end. - pub async fn send_ping(&mut self, data: ByteSlice125<'_>) -> Result<(), Error> { - let mut header = Header::new(OpCode::Ping); - self.write(&mut header, &mut Storage::Shared(data.as_ref())).await - } - - /// Send an unsolicited Pong to the remote. - pub async fn send_pong(&mut self, data: ByteSlice125<'_>) -> Result<(), Error> { - let mut header = Header::new(OpCode::Pong); - self.write(&mut header, &mut Storage::Shared(data.as_ref())).await - } - - /// Flush the socket buffer. - pub async fn flush(&mut self) -> Result<(), Error> { - log::trace!("{}: Sender flushing connection", self.id); - self.writer.lock().await.flush().await.or(Err(Error::Closed)) - } - - /// Send a close message and close the connection. - pub async fn close(&mut self) -> Result<(), Error> { - log::trace!("{}: closing connection", self.id); - let mut header = Header::new(OpCode::Close); - let code = 1000_u16.to_be_bytes(); // 1000 = normal closure - self.write(&mut header, &mut Storage::Shared(&code[..])).await?; - self.flush().await?; - self.writer.lock().await.close().await.or(Err(Error::Closed)) - } - - /// Send arbitrary websocket frames. - /// - /// Before sending, extensions will be applied to header and payload data. - async fn send_frame(&mut self, header: &mut Header, data: &mut Storage<'_>) -> Result<(), Error> { - if !self.has_extensions { - return self.write(header, data).await - } - - for e in self.extensions.lock().await.iter_mut() { - log::trace!("{}: encoding with extension: {}", self.id, e.name()); - e.encode(header, data).map_err(Error::Extension)? - } - - self.write(header, data).await - } - - /// Write final header and payload data to socket. - /// - /// The data will be masked if necessary. - /// No extensions will be applied to header and payload data. - async fn write(&mut self, header: &mut Header, data: &mut Storage<'_>) -> Result<(), Error> { - write(self.id, self.mode, &mut self.codec, &mut self.writer, header, data, &mut self.mask_buffer).await - } + /// Send a text value over the websocket connection. + pub async fn send_text(&mut self, data: impl AsRef) -> Result<(), Error> { + let mut header = Header::new(OpCode::Text); + self.send_frame(&mut header, &mut Storage::Shared(data.as_ref().as_bytes())).await + } + + /// Send a text value over the websocket connection. + /// + /// This method performs one copy fewer than [`Sender::send_text`]. + pub async fn send_text_owned(&mut self, data: String) -> Result<(), Error> { + let mut header = Header::new(OpCode::Text); + self.send_frame(&mut header, &mut Storage::Owned(data.into_bytes())).await + } + + /// Send some binary data over the websocket connection. + pub async fn send_binary(&mut self, data: impl AsRef<[u8]>) -> Result<(), Error> { + let mut header = Header::new(OpCode::Binary); + self.send_frame(&mut header, &mut Storage::Shared(data.as_ref())).await + } + + /// Send some binary data over the websocket connection. + /// + /// This method performs one copy fewer than [`Sender::send_binary`]. + /// The `data` buffer may be modified by this method, e.g. if masking is necessary. + pub async fn send_binary_mut(&mut self, mut data: impl AsMut<[u8]>) -> Result<(), Error> { + let mut header = Header::new(OpCode::Binary); + self.send_frame(&mut header, &mut Storage::Unique(data.as_mut())).await + } + + /// Ping the remote end. + pub async fn send_ping(&mut self, data: ByteSlice125<'_>) -> Result<(), Error> { + let mut header = Header::new(OpCode::Ping); + self.write(&mut header, &mut Storage::Shared(data.as_ref())).await + } + + /// Send an unsolicited Pong to the remote. + pub async fn send_pong(&mut self, data: ByteSlice125<'_>) -> Result<(), Error> { + let mut header = Header::new(OpCode::Pong); + self.write(&mut header, &mut Storage::Shared(data.as_ref())).await + } + + /// Flush the socket buffer. + pub async fn flush(&mut self) -> Result<(), Error> { + log::trace!("{}: Sender flushing connection", self.id); + self.writer.lock().await.flush().await.or(Err(Error::Closed)) + } + + /// Send a close message and close the connection. + pub async fn close(&mut self) -> Result<(), Error> { + log::trace!("{}: closing connection", self.id); + let mut header = Header::new(OpCode::Close); + let code = 1000_u16.to_be_bytes(); // 1000 = normal closure + self.write(&mut header, &mut Storage::Shared(&code[..])).await?; + self.flush().await?; + self.writer.lock().await.close().await.or(Err(Error::Closed)) + } + + /// Send arbitrary websocket frames. + /// + /// Before sending, extensions will be applied to header and payload data. + async fn send_frame(&mut self, header: &mut Header, data: &mut Storage<'_>) -> Result<(), Error> { + if !self.has_extensions { + return self.write(header, data).await; + } + + for e in self.extensions.lock().await.iter_mut() { + log::trace!("{}: encoding with extension: {}", self.id, e.name()); + e.encode(header, data).map_err(Error::Extension)? + } + + self.write(header, data).await + } + + /// Write final header and payload data to socket. + /// + /// The data will be masked if necessary. + /// No extensions will be applied to header and payload data. + async fn write(&mut self, header: &mut Header, data: &mut Storage<'_>) -> Result<(), Error> { + write(self.id, self.mode, &mut self.codec, &mut self.writer, header, data, &mut self.mask_buffer).await + } } /// Write header and payload data to socket. -async fn write - ( id: Id - , mode: Mode - , codec: &mut base::Codec - , writer: &mut BiLock> - , header: &mut Header - , data: &mut Storage<'_> - , mask_buffer: &mut Vec - ) -> Result<(), Error> -{ - if mode.is_client() { - header.set_masked(true); - header.set_mask(rand::random()); - } - header.set_payload_len(data.as_ref().len()); - - log::trace!("{}: send: {}", id, header); - - let header_bytes = codec.encode_header(&header); - let mut w = writer.lock().await; - w.write_all(&header_bytes).await.or(Err(Error::Closed))?; - - if !header.is_masked() { - return w.write_all(data.as_ref()).await.or(Err(Error::Closed)) - } - - match data { - Storage::Shared(slice) => { - mask_buffer.clear(); - mask_buffer.extend_from_slice(slice); - base::Codec::apply_mask(header, mask_buffer); - w.write_all(mask_buffer).await.or(Err(Error::Closed)) - } - Storage::Unique(slice) => { - base::Codec::apply_mask(header, slice); - w.write_all(slice).await.or(Err(Error::Closed)) - } - Storage::Owned(ref mut bytes) => { - base::Codec::apply_mask(header, bytes); - w.write_all(bytes).await.or(Err(Error::Closed)) - } - } +async fn write( + id: Id, + mode: Mode, + codec: &mut base::Codec, + writer: &mut BiLock>, + header: &mut Header, + data: &mut Storage<'_>, + mask_buffer: &mut Vec, +) -> Result<(), Error> { + if mode.is_client() { + header.set_masked(true); + header.set_mask(rand::random()); + } + header.set_payload_len(data.as_ref().len()); + + log::trace!("{}: send: {}", id, header); + + let header_bytes = codec.encode_header(&header); + let mut w = writer.lock().await; + w.write_all(&header_bytes).await.or(Err(Error::Closed))?; + + if !header.is_masked() { + return w.write_all(data.as_ref()).await.or(Err(Error::Closed)); + } + + match data { + Storage::Shared(slice) => { + mask_buffer.clear(); + mask_buffer.extend_from_slice(slice); + base::Codec::apply_mask(header, mask_buffer); + w.write_all(mask_buffer).await.or(Err(Error::Closed)) + } + Storage::Unique(slice) => { + base::Codec::apply_mask(header, slice); + w.write_all(slice).await.or(Err(Error::Closed)) + } + Storage::Owned(ref mut bytes) => { + base::Codec::apply_mask(header, bytes); + w.write_all(bytes).await.or(Err(Error::Closed)) + } + } } /// Create a close frame based on the given data. The close frame is echoed back /// to the sender. fn close_answer(data: &[u8]) -> Result<(Header, Option), Error> { - let answer = Header::new(OpCode::Close); - if data.len() < 2 { - return Ok((answer, None)); - } - // Check that the reason string is properly encoded - let descr = std::str::from_utf8(&data[2..])?.into(); - let code = u16::from_be_bytes([data[0], data[1]]); - let reason = CloseReason { code, descr: Some(descr) }; - - // Status codes are defined in - // https://tools.ietf.org/html/rfc6455#section-7.4.1 and - // https://mailarchive.ietf.org/arch/msg/hybi/P_1vbD9uyHl63nbIIbFxKMfSwcM/ - match code { + let answer = Header::new(OpCode::Close); + if data.len() < 2 { + return Ok((answer, None)); + } + // Check that the reason string is properly encoded + let descr = std::str::from_utf8(&data[2..])?.into(); + let code = u16::from_be_bytes([data[0], data[1]]); + let reason = CloseReason { code, descr: Some(descr) }; + + // Status codes are defined in + // https://tools.ietf.org/html/rfc6455#section-7.4.1 and + // https://mailarchive.ietf.org/arch/msg/hybi/P_1vbD9uyHl63nbIIbFxKMfSwcM/ + match code { | 1000 ..= 1003 | 1007 ..= 1011 | 1012 // Service Restart @@ -586,82 +614,75 @@ fn close_answer(data: &[u8]) -> Result<(Header, Option), Error> { #[non_exhaustive] #[derive(Debug)] pub enum Error { - /// An I/O error was encountered. - Io(io::Error), - /// The base codec errored. - Codec(base::Error), - /// An extension produced an error while encoding or decoding. - Extension(crate::BoxedError), - /// An unexpected opcode was encountered. - UnexpectedOpCode(OpCode), - /// A close reason was not correctly UTF-8 encoded. - Utf8(str::Utf8Error), - /// The total message payload data size exceeds the configured maximum. - MessageTooLarge { current: usize, maximum: usize }, - /// The connection is closed. - Closed, + /// An I/O error was encountered. + Io(io::Error), + /// The base codec errored. + Codec(base::Error), + /// An extension produced an error while encoding or decoding. + Extension(crate::BoxedError), + /// An unexpected opcode was encountered. + UnexpectedOpCode(OpCode), + /// A close reason was not correctly UTF-8 encoded. + Utf8(str::Utf8Error), + /// The total message payload data size exceeds the configured maximum. + MessageTooLarge { current: usize, maximum: usize }, + /// The connection is closed. + Closed, } /// Reason for closing the connection. #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct CloseReason { - pub code: u16, - pub descr: Option, + pub code: u16, + pub descr: Option, } impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - Error::Io(e) => - write!(f, "i/o error: {}", e), - Error::Codec(e) => - write!(f, "codec error: {}", e), - Error::Extension(e) => - write!(f, "extension error: {}", e), - Error::UnexpectedOpCode(c) => - write!(f, "unexpected opcode: {}", c), - Error::Utf8(e) => - write!(f, "utf-8 error: {}", e), - Error::MessageTooLarge { current, maximum } => - write!(f, "message too large: len >= {}, maximum = {}", current, maximum), - Error::Closed => f.write_str("connection closed") - } - } + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Error::Io(e) => write!(f, "i/o error: {}", e), + Error::Codec(e) => write!(f, "codec error: {}", e), + Error::Extension(e) => write!(f, "extension error: {}", e), + Error::UnexpectedOpCode(c) => write!(f, "unexpected opcode: {}", c), + Error::Utf8(e) => write!(f, "utf-8 error: {}", e), + Error::MessageTooLarge { current, maximum } => { + write!(f, "message too large: len >= {}, maximum = {}", current, maximum) + } + Error::Closed => f.write_str("connection closed"), + } + } } impl std::error::Error for Error { - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - match self { - Error::Io(e) => Some(e), - Error::Codec(e) => Some(e), - Error::Extension(e) => Some(&**e), - Error::Utf8(e) => Some(e), - Error::UnexpectedOpCode(_) - | Error::MessageTooLarge {..} - | Error::Closed - => None - } - } + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + Error::Io(e) => Some(e), + Error::Codec(e) => Some(e), + Error::Extension(e) => Some(&**e), + Error::Utf8(e) => Some(e), + Error::UnexpectedOpCode(_) | Error::MessageTooLarge { .. } | Error::Closed => None, + } + } } impl From for Error { - fn from(e: io::Error) -> Self { - if e.kind() == io::ErrorKind::UnexpectedEof { - Error::Closed - } else { - Error::Io(e) - } - } + fn from(e: io::Error) -> Self { + if e.kind() == io::ErrorKind::UnexpectedEof { + Error::Closed + } else { + Error::Io(e) + } + } } impl From for Error { - fn from(e: str::Utf8Error) -> Self { - Error::Utf8(e) - } + fn from(e: str::Utf8Error) -> Self { + Error::Utf8(e) + } } impl From for Error { - fn from(e: base::Error) -> Self { - Error::Codec(e) - } + fn from(e: base::Error) -> Self { + Error::Codec(e) + } } diff --git a/src/data.rs b/src/data.rs index 827dac27..7c8299d6 100644 --- a/src/data.rs +++ b/src/data.rs @@ -15,70 +15,86 @@ use crate::connection::CloseReason; /// Data received from the remote end. #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum Incoming<'a> { - /// Text or binary data. - Data(Data), - /// Data sent with a PONG control frame. - Pong(&'a [u8]), - /// The other end closed the connection. - Closed(CloseReason), + /// Text or binary data. + Data(Data), + /// Data sent with a PONG control frame. + Pong(&'a [u8]), + /// The other end closed the connection. + Closed(CloseReason), } impl Incoming<'_> { - /// Is this text or binary data? - pub fn is_data(&self) -> bool { - if let Incoming::Data(_) = self { true } else { false } - } - - /// Is this a PONG? - pub fn is_pong(&self) -> bool { - if let Incoming::Pong(_) = self { true } else { false } - } - - /// Is this text data? - pub fn is_text(&self) -> bool { - if let Incoming::Data(d) = self { - d.is_text() - } else { - false - } - } - - /// Is this binary data? - pub fn is_binary(&self) -> bool { - if let Incoming::Data(d) = self { - d.is_binary() - } else { - false - } - } + /// Is this text or binary data? + pub fn is_data(&self) -> bool { + if let Incoming::Data(_) = self { + true + } else { + false + } + } + + /// Is this a PONG? + pub fn is_pong(&self) -> bool { + if let Incoming::Pong(_) = self { + true + } else { + false + } + } + + /// Is this text data? + pub fn is_text(&self) -> bool { + if let Incoming::Data(d) = self { + d.is_text() + } else { + false + } + } + + /// Is this binary data? + pub fn is_binary(&self) -> bool { + if let Incoming::Data(d) = self { + d.is_binary() + } else { + false + } + } } #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum Data { - /// Textual data (number of bytes). - Text(usize), - /// Binary data (number of bytes). - Binary(usize) + /// Textual data (number of bytes). + Text(usize), + /// Binary data (number of bytes). + Binary(usize), } impl Data { - /// Is this text data? - pub fn is_text(&self) -> bool { - if let Data::Text(_) = self { true } else { false } - } - - /// Is this binary data? - pub fn is_binary(&self) -> bool { - if let Data::Binary(_) = self { true } else { false } - } - - /// The length of data (number of bytes). - pub fn len(&self) -> usize { - match self { - Data::Text(n) => *n, - Data::Binary(n) => *n - } - } + /// Is this text data? + pub fn is_text(&self) -> bool { + if let Data::Text(_) = self { + true + } else { + false + } + } + + /// Is this binary data? + pub fn is_binary(&self) -> bool { + if let Data::Binary(_) = self { + true + } else { + false + } + } + + /// The length of data (number of bytes). + pub fn len(&self) -> usize { + match self { + Data::Text(n) => *n, + Data::Binary(n) => *n, + } + } } /// Wrapper type which restricts the length of its byte slice to 125 bytes. @@ -90,27 +106,27 @@ pub struct ByteSlice125<'a>(&'a [u8]); pub struct SliceTooLarge(()); impl fmt::Display for SliceTooLarge { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.write_str("Slice larger than 125 bytes") - } + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str("Slice larger than 125 bytes") + } } impl std::error::Error for SliceTooLarge {} impl<'a> TryFrom<&'a [u8]> for ByteSlice125<'a> { - type Error = SliceTooLarge; - - fn try_from(value: &'a [u8]) -> Result { - if value.len() > 125 { - Err(SliceTooLarge(())) - } else { - Ok(ByteSlice125(value)) - } - } + type Error = SliceTooLarge; + + fn try_from(value: &'a [u8]) -> Result { + if value.len() > 125 { + Err(SliceTooLarge(())) + } else { + Ok(ByteSlice125(value)) + } + } } impl AsRef<[u8]> for ByteSlice125<'_> { - fn as_ref(&self) -> &[u8] { - self.0 - } + fn as_ref(&self) -> &[u8] { + self.0 + } } diff --git a/src/extension.rs b/src/extension.rs index 47428a73..636d6d33 100644 --- a/src/extension.rs +++ b/src/extension.rs @@ -14,7 +14,7 @@ #[cfg(feature = "deflate")] pub mod deflate; -use crate::{BoxedError, Storage, base::Header}; +use crate::{base::Header, BoxedError, Storage}; use std::{borrow::Cow, fmt}; /// A websocket extension as per RFC 6455, section 9. @@ -47,111 +47,104 @@ use std::{borrow::Cow, fmt}; /// potentially enabled. Enabled extensions can then be used for further base /// frame processing. pub trait Extension: std::fmt::Debug { - /// Is this extension enabled? - fn is_enabled(&self) -> bool; + /// Is this extension enabled? + fn is_enabled(&self) -> bool; - /// The name of this extension. - fn name(&self) -> &str; + /// The name of this extension. + fn name(&self) -> &str; - /// The parameters this extension wants to send for negotiation. - fn params(&self) -> &[Param]; + /// The parameters this extension wants to send for negotiation. + fn params(&self) -> &[Param]; - /// Configure this extension with the parameters received from negotiation. - fn configure(&mut self, params: &[Param]) -> Result<(), BoxedError>; + /// Configure this extension with the parameters received from negotiation. + fn configure(&mut self, params: &[Param]) -> Result<(), BoxedError>; - /// Encode a frame, given as frame header and payload data. - fn encode(&mut self, header: &mut Header, data: &mut Storage) -> Result<(), BoxedError>; + /// Encode a frame, given as frame header and payload data. + fn encode(&mut self, header: &mut Header, data: &mut Storage) -> Result<(), BoxedError>; - /// Decode a frame. - /// - /// The frame header is given, as well as the accumulated payload data, i.e. - /// the concatenated payload data of all message fragments. - fn decode(&mut self, header: &mut Header, data: &mut Vec) -> Result<(), BoxedError>; + /// Decode a frame. + /// + /// The frame header is given, as well as the accumulated payload data, i.e. + /// the concatenated payload data of all message fragments. + fn decode(&mut self, header: &mut Header, data: &mut Vec) -> Result<(), BoxedError>; - /// The reserved bits this extension uses. - fn reserved_bits(&self) -> (bool, bool, bool) { - (false, false, false) - } + /// The reserved bits this extension uses. + fn reserved_bits(&self) -> (bool, bool, bool) { + (false, false, false) + } } impl Extension for Box { - fn is_enabled(&self) -> bool { - (**self).is_enabled() - } + fn is_enabled(&self) -> bool { + (**self).is_enabled() + } - fn name(&self) -> &str { - (**self).name() - } + fn name(&self) -> &str { + (**self).name() + } - fn params(&self) -> &[Param] { - (**self).params() - } + fn params(&self) -> &[Param] { + (**self).params() + } - fn configure(&mut self, params: &[Param]) -> Result<(), BoxedError> { - (**self).configure(params) - } + fn configure(&mut self, params: &[Param]) -> Result<(), BoxedError> { + (**self).configure(params) + } - fn encode(&mut self, header: &mut Header, data: &mut Storage) -> Result<(), BoxedError> { - (**self).encode(header, data) - } + fn encode(&mut self, header: &mut Header, data: &mut Storage) -> Result<(), BoxedError> { + (**self).encode(header, data) + } - fn decode(&mut self, header: &mut Header, data: &mut Vec) -> Result<(), BoxedError> { - (**self).decode(header, data) - } + fn decode(&mut self, header: &mut Header, data: &mut Vec) -> Result<(), BoxedError> { + (**self).decode(header, data) + } - fn reserved_bits(&self) -> (bool, bool, bool) { - (**self).reserved_bits() - } + fn reserved_bits(&self) -> (bool, bool, bool) { + (**self).reserved_bits() + } } /// Extension parameter (used for negotiation). #[derive(Debug, Clone, PartialEq, Eq)] pub struct Param<'a> { - name: Cow<'a, str>, - value: Option> + name: Cow<'a, str>, + value: Option>, } impl<'a> fmt::Display for Param<'a> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - if let Some(v) = &self.value { - write!(f, "{} = {}", self.name, v) - } else { - write!(f, "{}", self.name) - } - } + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + if let Some(v) = &self.value { + write!(f, "{} = {}", self.name, v) + } else { + write!(f, "{}", self.name) + } + } } impl<'a> Param<'a> { - /// Create a new parameter with the given name. - pub fn new(name: impl Into>) -> Self{ - Param { - name: name.into(), - value: None - } - } - - /// Access the parameter name. - pub fn name(&self) -> &str { - &self.name - } - - /// Access the optional parameter value. - pub fn value(&self) -> Option<&str> { - self.value.as_ref().map(|v| v.as_ref()) - } - - /// Set the parameter to the given value. - pub fn set_value(&mut self, value: Option>>) -> &mut Self { - self.value = value.map(Into::into); - self - } - - /// Turn this parameter into one that owns its values. - pub fn acquire(self) -> Param<'static> { - Param { - name: Cow::Owned(self.name.into_owned()), - value: self.value.map(|v| Cow::Owned(v.into_owned())) - } - } + /// Create a new parameter with the given name. + pub fn new(name: impl Into>) -> Self { + Param { name: name.into(), value: None } + } + + /// Access the parameter name. + pub fn name(&self) -> &str { + &self.name + } + + /// Access the optional parameter value. + pub fn value(&self) -> Option<&str> { + self.value.as_ref().map(|v| v.as_ref()) + } + + /// Set the parameter to the given value. + pub fn set_value(&mut self, value: Option>>) -> &mut Self { + self.value = value.map(Into::into); + self + } + + /// Turn this parameter into one that owns its values. + pub fn acquire(self) -> Param<'static> { + Param { name: Cow::Owned(self.name.into_owned()), value: self.value.map(|v| Cow::Owned(v.into_owned())) } + } } - diff --git a/src/extension/deflate.rs b/src/extension/deflate.rs index d3f4aad5..f3e807e2 100644 --- a/src/extension/deflate.rs +++ b/src/extension/deflate.rs @@ -11,15 +11,18 @@ //! [rfc7692]: https://tools.ietf.org/html/rfc7692 use crate::{ - as_u64, - BoxedError, - Storage, - base::{Header, OpCode}, - connection::Mode, - extension::{Extension, Param} + as_u64, + base::{Header, OpCode}, + connection::Mode, + extension::{Extension, Param}, + BoxedError, Storage, +}; +use flate2::{write::DeflateDecoder, Compress, Compression, FlushCompress, Status}; +use std::{ + convert::TryInto, + io::{self, Write}, + mem, }; -use flate2::{Compress, Compression, FlushCompress, Status, write::DeflateDecoder}; -use std::{convert::TryInto, io::{self, Write}, mem}; const SERVER_NO_CONTEXT_TAKEOVER: &str = "server_no_context_takeover"; const SERVER_MAX_WINDOW_BITS: &str = "server_max_window_bits"; @@ -33,290 +36,288 @@ const CLIENT_MAX_WINDOW_BITS: &str = "client_max_window_bits"; /// default, which is 15 and will ask for no context takeover during handshake. #[derive(Debug)] pub struct Deflate { - mode: Mode, - enabled: bool, - buffer: Vec, - params: Vec>, - our_max_window_bits: u8, - their_max_window_bits: u8, - await_last_fragment: bool + mode: Mode, + enabled: bool, + buffer: Vec, + params: Vec>, + our_max_window_bits: u8, + their_max_window_bits: u8, + await_last_fragment: bool, } impl Deflate { - /// Create a new deflate extension either on client or server side. - pub fn new(mode: Mode) -> Self { - let params = match mode { - Mode::Server => Vec::new(), - Mode::Client => { - let mut params = Vec::new(); - params.push(Param::new(SERVER_NO_CONTEXT_TAKEOVER)); - params.push(Param::new(CLIENT_NO_CONTEXT_TAKEOVER)); - params.push(Param::new(CLIENT_MAX_WINDOW_BITS)); - params - } - }; - Deflate { - mode, - enabled: false, - buffer: Vec::new(), - params, - our_max_window_bits: 15, - their_max_window_bits: 15, - await_last_fragment: false - } - } + /// Create a new deflate extension either on client or server side. + pub fn new(mode: Mode) -> Self { + let params = match mode { + Mode::Server => Vec::new(), + Mode::Client => { + let mut params = Vec::new(); + params.push(Param::new(SERVER_NO_CONTEXT_TAKEOVER)); + params.push(Param::new(CLIENT_NO_CONTEXT_TAKEOVER)); + params.push(Param::new(CLIENT_MAX_WINDOW_BITS)); + params + } + }; + Deflate { + mode, + enabled: false, + buffer: Vec::new(), + params, + our_max_window_bits: 15, + their_max_window_bits: 15, + await_last_fragment: false, + } + } - /// Set the server's max. window bits. - /// - /// The value must be within 9 ..= 15. - /// The extension must be in client mode. - /// - /// By including this parameter, a client limits the LZ77 sliding window - /// size that the server will use to compress messages. A server accepts - /// by including the "server_max_window_bits" extension parameter in the - /// response with the same or smaller value as the offer. - pub fn set_max_server_window_bits(&mut self, max: u8) { - assert!(self.mode == Mode::Client, "setting max. server window bits requires client mode"); - assert!(max > 8 && max <= 15, "max. server window bits have to be within 9 ..= 15"); - self.their_max_window_bits = max; // upper bound of the server's window - let mut p = Param::new(SERVER_MAX_WINDOW_BITS); - p.set_value(Some(max.to_string())); - self.params.push(p) - } + /// Set the server's max. window bits. + /// + /// The value must be within 9 ..= 15. + /// The extension must be in client mode. + /// + /// By including this parameter, a client limits the LZ77 sliding window + /// size that the server will use to compress messages. A server accepts + /// by including the "server_max_window_bits" extension parameter in the + /// response with the same or smaller value as the offer. + pub fn set_max_server_window_bits(&mut self, max: u8) { + assert!(self.mode == Mode::Client, "setting max. server window bits requires client mode"); + assert!(max > 8 && max <= 15, "max. server window bits have to be within 9 ..= 15"); + self.their_max_window_bits = max; // upper bound of the server's window + let mut p = Param::new(SERVER_MAX_WINDOW_BITS); + p.set_value(Some(max.to_string())); + self.params.push(p) + } - /// Set the client's max. window bits. - /// - /// The value must be within 9 ..= 15. - /// The extension must be in client mode. - /// - /// The parameter informs the server that even if it doesn't include the - /// "client_max_window_bits" extension parameter in the response with a - /// value greater than the one in the negotiation offer or if it doesn't - /// include the extension parameter at all, the client is not going to - /// use an LZ77 sliding window size greater than one given here. - /// The server may also respond with a smaller value which allows the client - /// to reduce its sliding window even more. - pub fn set_max_client_window_bits(&mut self, max: u8) { - assert!(self.mode == Mode::Client, "setting max. client window bits requires client mode"); - assert!(max > 8 && max <= 15, "max. client window bits have to be within 9 ..= 15"); - self.our_max_window_bits = max; // upper bound of the client's window - if let Some(p) = self.params.iter_mut().find(|p| p.name() == CLIENT_MAX_WINDOW_BITS) { - p.set_value(Some(max.to_string())); - } else { - let mut p = Param::new(CLIENT_MAX_WINDOW_BITS); - p.set_value(Some(max.to_string())); - self.params.push(p) - } - } + /// Set the client's max. window bits. + /// + /// The value must be within 9 ..= 15. + /// The extension must be in client mode. + /// + /// The parameter informs the server that even if it doesn't include the + /// "client_max_window_bits" extension parameter in the response with a + /// value greater than the one in the negotiation offer or if it doesn't + /// include the extension parameter at all, the client is not going to + /// use an LZ77 sliding window size greater than one given here. + /// The server may also respond with a smaller value which allows the client + /// to reduce its sliding window even more. + pub fn set_max_client_window_bits(&mut self, max: u8) { + assert!(self.mode == Mode::Client, "setting max. client window bits requires client mode"); + assert!(max > 8 && max <= 15, "max. client window bits have to be within 9 ..= 15"); + self.our_max_window_bits = max; // upper bound of the client's window + if let Some(p) = self.params.iter_mut().find(|p| p.name() == CLIENT_MAX_WINDOW_BITS) { + p.set_value(Some(max.to_string())); + } else { + let mut p = Param::new(CLIENT_MAX_WINDOW_BITS); + p.set_value(Some(max.to_string())); + self.params.push(p) + } + } - fn set_their_max_window_bits(&mut self, p: &Param, expected: Option) -> Result<(), ()> { - if let Some(Ok(v)) = p.value().map(|s| s.parse::()) { - if v < 8 || v > 15 { - log::debug!("invalid {}: {} (expected range: 8 ..= 15)", p.name(), v); - return Err(()) - } - if let Some(x) = expected { - if v > x { - log::debug!("invalid {}: {} (expected: {} <= {})", p.name(), v, v, x); - return Err(()) - } - } - self.their_max_window_bits = std::cmp::max(9, v); - } - Ok(()) - } + fn set_their_max_window_bits(&mut self, p: &Param, expected: Option) -> Result<(), ()> { + if let Some(Ok(v)) = p.value().map(|s| s.parse::()) { + if v < 8 || v > 15 { + log::debug!("invalid {}: {} (expected range: 8 ..= 15)", p.name(), v); + return Err(()); + } + if let Some(x) = expected { + if v > x { + log::debug!("invalid {}: {} (expected: {} <= {})", p.name(), v, v, x); + return Err(()); + } + } + self.their_max_window_bits = std::cmp::max(9, v); + } + Ok(()) + } } impl Extension for Deflate { - fn name(&self) -> &str { - "permessage-deflate" - } + fn name(&self) -> &str { + "permessage-deflate" + } - fn is_enabled(&self) -> bool { - self.enabled - } + fn is_enabled(&self) -> bool { + self.enabled + } - fn params(&self) -> &[Param] { - &self.params - } + fn params(&self) -> &[Param] { + &self.params + } - fn configure(&mut self, params: &[Param]) -> Result<(), BoxedError> { - match self.mode { - Mode::Server => { - self.params.clear(); - for p in params { - log::trace!("configure server with: {}", p); - match p.name() { - CLIENT_MAX_WINDOW_BITS => - if self.set_their_max_window_bits(&p, None).is_err() { - // we just accept the client's offer as is => no need to reply - return Ok(()) - } - SERVER_MAX_WINDOW_BITS => { - if let Some(Ok(v)) = p.value().map(|s| s.parse::()) { - // The RFC allows 8 to 15 bits, but due to zlib limitations we - // only support 9 to 15. - if v < 9 || v > 15 { - log::debug!("unacceptable server_max_window_bits: {}", v); - return Ok(()) - } - let mut x = Param::new(SERVER_MAX_WINDOW_BITS); - x.set_value(Some(v.to_string())); - self.params.push(x); - self.our_max_window_bits = v; - } else { - log::debug!("invalid server_max_window_bits: {:?}", p.value()); - return Ok(()) - } - } - CLIENT_NO_CONTEXT_TAKEOVER => - self.params.push(Param::new(CLIENT_NO_CONTEXT_TAKEOVER)), - SERVER_NO_CONTEXT_TAKEOVER => - self.params.push(Param::new(SERVER_NO_CONTEXT_TAKEOVER)), - _ => { - log::debug!("{}: unknown parameter: {}", self.name(), p.name()); - return Ok(()) - } - } - } - } - Mode::Client => { - let mut server_no_context_takeover = false; - for p in params { - log::trace!("configure client with: {}", p); - match p.name() { - SERVER_NO_CONTEXT_TAKEOVER => server_no_context_takeover = true, - CLIENT_NO_CONTEXT_TAKEOVER => {} // must be supported - SERVER_MAX_WINDOW_BITS => { - let expected = Some(self.their_max_window_bits); - if self.set_their_max_window_bits(&p, expected).is_err() { - return Ok(()) - } - } - CLIENT_MAX_WINDOW_BITS => - if let Some(Ok(v)) = p.value().map(|s| s.parse::()) { - if v < 8 || v > 15 { - log::debug!("unacceptable client_max_window_bits: {}", v); - return Ok(()) - } - use std::cmp::{min, max}; - // Due to zlib limitations we have to use 9 as a lower bound - // here, even if the server allowed us to go down to 8 bits. - self.our_max_window_bits = min(self.our_max_window_bits, max(9, v)); - } - _ => { - log::debug!("{}: unknown parameter: {}", self.name(), p.name()); - return Ok(()) - } - } - } - if !server_no_context_takeover { - log::debug!("{}: server did not confirm no context takeover", self.name()); - return Ok(()) - } - } - } - self.enabled = true; - Ok(()) - } + fn configure(&mut self, params: &[Param]) -> Result<(), BoxedError> { + match self.mode { + Mode::Server => { + self.params.clear(); + for p in params { + log::trace!("configure server with: {}", p); + match p.name() { + CLIENT_MAX_WINDOW_BITS => { + if self.set_their_max_window_bits(&p, None).is_err() { + // we just accept the client's offer as is => no need to reply + return Ok(()); + } + } + SERVER_MAX_WINDOW_BITS => { + if let Some(Ok(v)) = p.value().map(|s| s.parse::()) { + // The RFC allows 8 to 15 bits, but due to zlib limitations we + // only support 9 to 15. + if v < 9 || v > 15 { + log::debug!("unacceptable server_max_window_bits: {}", v); + return Ok(()); + } + let mut x = Param::new(SERVER_MAX_WINDOW_BITS); + x.set_value(Some(v.to_string())); + self.params.push(x); + self.our_max_window_bits = v; + } else { + log::debug!("invalid server_max_window_bits: {:?}", p.value()); + return Ok(()); + } + } + CLIENT_NO_CONTEXT_TAKEOVER => self.params.push(Param::new(CLIENT_NO_CONTEXT_TAKEOVER)), + SERVER_NO_CONTEXT_TAKEOVER => self.params.push(Param::new(SERVER_NO_CONTEXT_TAKEOVER)), + _ => { + log::debug!("{}: unknown parameter: {}", self.name(), p.name()); + return Ok(()); + } + } + } + } + Mode::Client => { + let mut server_no_context_takeover = false; + for p in params { + log::trace!("configure client with: {}", p); + match p.name() { + SERVER_NO_CONTEXT_TAKEOVER => server_no_context_takeover = true, + CLIENT_NO_CONTEXT_TAKEOVER => {} // must be supported + SERVER_MAX_WINDOW_BITS => { + let expected = Some(self.their_max_window_bits); + if self.set_their_max_window_bits(&p, expected).is_err() { + return Ok(()); + } + } + CLIENT_MAX_WINDOW_BITS => { + if let Some(Ok(v)) = p.value().map(|s| s.parse::()) { + if v < 8 || v > 15 { + log::debug!("unacceptable client_max_window_bits: {}", v); + return Ok(()); + } + use std::cmp::{max, min}; + // Due to zlib limitations we have to use 9 as a lower bound + // here, even if the server allowed us to go down to 8 bits. + self.our_max_window_bits = min(self.our_max_window_bits, max(9, v)); + } + } + _ => { + log::debug!("{}: unknown parameter: {}", self.name(), p.name()); + return Ok(()); + } + } + } + if !server_no_context_takeover { + log::debug!("{}: server did not confirm no context takeover", self.name()); + return Ok(()); + } + } + } + self.enabled = true; + Ok(()) + } - fn reserved_bits(&self) -> (bool, bool, bool) { - (true, false, false) - } + fn reserved_bits(&self) -> (bool, bool, bool) { + (true, false, false) + } - fn decode(&mut self, header: &mut Header, data: &mut Vec) -> Result<(), BoxedError> { - if data.is_empty() { - return Ok(()) - } + fn decode(&mut self, header: &mut Header, data: &mut Vec) -> Result<(), BoxedError> { + if data.is_empty() { + return Ok(()); + } - match header.opcode() { - OpCode::Binary | OpCode::Text if header.is_rsv1() => { - if !header.is_fin() { - self.await_last_fragment = true; - log::trace!("deflate: not decoding {}; awaiting last fragment", header); - return Ok(()) - } - log::trace!("deflate: decoding {}", header) - } - OpCode::Continue if header.is_fin() && self.await_last_fragment => { - self.await_last_fragment = false; - log::trace!("deflate: decoding {}", header) - } - _ => { - log::trace!("deflate: not decoding {}", header); - return Ok(()) - } - } + match header.opcode() { + OpCode::Binary | OpCode::Text if header.is_rsv1() => { + if !header.is_fin() { + self.await_last_fragment = true; + log::trace!("deflate: not decoding {}; awaiting last fragment", header); + return Ok(()); + } + log::trace!("deflate: decoding {}", header) + } + OpCode::Continue if header.is_fin() && self.await_last_fragment => { + self.await_last_fragment = false; + log::trace!("deflate: decoding {}", header) + } + _ => { + log::trace!("deflate: not decoding {}", header); + return Ok(()); + } + } - // Restore LEN and NLEN: - data.extend_from_slice(&[0, 0, 0xFF, 0xFF]); // cf. RFC 7692, 7.2.2 + // Restore LEN and NLEN: + data.extend_from_slice(&[0, 0, 0xFF, 0xFF]); // cf. RFC 7692, 7.2.2 - self.buffer.clear(); - let mut decoder = DeflateDecoder::new(&mut self.buffer); - decoder.write_all(&data)?; - decoder.finish()?; - mem::swap(data, &mut self.buffer); + self.buffer.clear(); + let mut decoder = DeflateDecoder::new(&mut self.buffer); + decoder.write_all(&data)?; + decoder.finish()?; + mem::swap(data, &mut self.buffer); - header.set_rsv1(false); - header.set_payload_len(data.len()); + header.set_rsv1(false); + header.set_payload_len(data.len()); - Ok(()) - } + Ok(()) + } - fn encode(&mut self, header: &mut Header, data: &mut Storage) -> Result<(), BoxedError> { - if data.as_ref().is_empty() { - return Ok(()) - } + fn encode(&mut self, header: &mut Header, data: &mut Storage) -> Result<(), BoxedError> { + if data.as_ref().is_empty() { + return Ok(()); + } - if let OpCode::Binary | OpCode::Text = header.opcode() { - log::trace!("deflate: encoding {}", header) - } else { - log::trace!("deflate: not encoding {}", header); - return Ok(()) - } + if let OpCode::Binary | OpCode::Text = header.opcode() { + log::trace!("deflate: encoding {}", header) + } else { + log::trace!("deflate: not encoding {}", header); + return Ok(()); + } - self.buffer.clear(); - self.buffer.reserve(data.as_ref().len()); + self.buffer.clear(); + self.buffer.reserve(data.as_ref().len()); - let mut encoder = - Compress::new_with_window_bits(Compression::fast(), false, self.our_max_window_bits); + let mut encoder = Compress::new_with_window_bits(Compression::fast(), false, self.our_max_window_bits); - // Compress all input bytes. - while encoder.total_in() < as_u64(data.as_ref().len()) { - let i: usize = encoder.total_in().try_into()?; - match encoder.compress_vec(&data.as_ref()[i ..], &mut self.buffer, FlushCompress::None)? { - Status::BufError => self.buffer.reserve(4096), - Status::Ok => continue, - Status::StreamEnd => break - } - } + // Compress all input bytes. + while encoder.total_in() < as_u64(data.as_ref().len()) { + let i: usize = encoder.total_in().try_into()?; + match encoder.compress_vec(&data.as_ref()[i..], &mut self.buffer, FlushCompress::None)? { + Status::BufError => self.buffer.reserve(4096), + Status::Ok => continue, + Status::StreamEnd => break, + } + } - // We need to append an empty deflate block if not there yet (RFC 7692, 7.2.1). - while !self.buffer.ends_with(&[0, 0, 0xFF, 0xFF]) { - self.buffer.reserve(5); // Make sure there is room for the trailing end bytes. - match encoder.compress_vec(&[], &mut self.buffer, FlushCompress::Sync)? { - Status::Ok => continue, - Status::BufError => continue, // more capacity is reserved above - Status::StreamEnd => break - } - } + // We need to append an empty deflate block if not there yet (RFC 7692, 7.2.1). + while !self.buffer.ends_with(&[0, 0, 0xFF, 0xFF]) { + self.buffer.reserve(5); // Make sure there is room for the trailing end bytes. + match encoder.compress_vec(&[], &mut self.buffer, FlushCompress::Sync)? { + Status::Ok => continue, + Status::BufError => continue, // more capacity is reserved above + Status::StreamEnd => break, + } + } - // If we still have not seen the empty deflate block appended, something is wrong. - if !self.buffer.ends_with(&[0, 0, 0xFF, 0xFF]) { - log::error!("missing 00 00 FF FF"); - return Err(io::Error::new(io::ErrorKind::Other, "missing 00 00 FF FF").into()) - } + // If we still have not seen the empty deflate block appended, something is wrong. + if !self.buffer.ends_with(&[0, 0, 0xFF, 0xFF]) { + log::error!("missing 00 00 FF FF"); + return Err(io::Error::new(io::ErrorKind::Other, "missing 00 00 FF FF").into()); + } - self.buffer.truncate(self.buffer.len() - 4); // Remove 00 00 FF FF; cf. RFC 7692, 7.2.1 + self.buffer.truncate(self.buffer.len() - 4); // Remove 00 00 FF FF; cf. RFC 7692, 7.2.1 - if let Storage::Owned(d) = data { - mem::swap(d, &mut self.buffer) - } else { - *data = Storage::Owned(mem::take(&mut self.buffer)) - } - header.set_rsv1(true); - header.set_payload_len(data.as_ref().len()); - Ok(()) - } + if let Storage::Owned(d) = data { + mem::swap(d, &mut self.buffer) + } else { + *data = Storage::Owned(mem::take(&mut self.buffer)) + } + header.set_rsv1(true); + header.set_payload_len(data.as_ref().len()); + Ok(()) + } } - diff --git a/src/handshake.rs b/src/handshake.rs index 761654c4..433430a2 100644 --- a/src/handshake.rs +++ b/src/handshake.rs @@ -13,12 +13,12 @@ pub mod client; pub mod server; +use crate::extension::{Extension, Param}; use bytes::BytesMut; -use crate::extension::{Param, Extension}; use std::{fmt, io, str}; pub use client::{Client, ServerResponse}; -pub use server::{Server, ClientRequest}; +pub use server::{ClientRequest, Server}; // Defined in RFC 6455 and used to generate the `Sec-WebSocket-Accept` header // in the server handshake response. @@ -33,194 +33,178 @@ const SEC_WEBSOCKET_PROTOCOL: &str = "Sec-WebSocket-Protocol"; /// Check a set of headers contains a specific one. fn expect_ascii_header(headers: &[httparse::Header], name: &str, ours: &str) -> Result<(), Error> { - enum State { - Init, // Start state - Name, // Header name found - Match // Header value matches - } - - headers.iter() - .filter(|h| h.name.eq_ignore_ascii_case(name)) - .fold(Ok(State::Init), |result, header| { - if let Ok(State::Match) = result { - return result - } - if str::from_utf8(header.value)? - .split(',') - .any(|v| v.trim().eq_ignore_ascii_case(ours)) - { - return Ok(State::Match) - } - Ok(State::Name) - }) - .and_then(|state| { - match state { - State::Init => Err(Error::HeaderNotFound(name.into())), - State::Name => Err(Error::UnexpectedHeader(name.into())), - State::Match => Ok(()) - } - }) + enum State { + Init, // Start state + Name, // Header name found + Match, // Header value matches + } + + headers + .iter() + .filter(|h| h.name.eq_ignore_ascii_case(name)) + .fold(Ok(State::Init), |result, header| { + if let Ok(State::Match) = result { + return result; + } + if str::from_utf8(header.value)?.split(',').any(|v| v.trim().eq_ignore_ascii_case(ours)) { + return Ok(State::Match); + } + Ok(State::Name) + }) + .and_then(|state| match state { + State::Init => Err(Error::HeaderNotFound(name.into())), + State::Name => Err(Error::UnexpectedHeader(name.into())), + State::Match => Ok(()), + }) } /// Pick the first header with the given name and apply the given closure to it. fn with_first_header<'a, F, R>(headers: &[httparse::Header<'a>], name: &str, f: F) -> Result where - F: Fn(&'a [u8]) -> Result + F: Fn(&'a [u8]) -> Result, { - if let Some(h) = headers.iter().find(|h| h.name.eq_ignore_ascii_case(name)) { - f(h.value) - } else { - Err(Error::HeaderNotFound(name.into())) - } + if let Some(h) = headers.iter().find(|h| h.name.eq_ignore_ascii_case(name)) { + f(h.value) + } else { + Err(Error::HeaderNotFound(name.into())) + } } // Configure all extensions with parsed parameters. fn configure_extensions(extensions: &mut [Box], line: &str) -> Result<(), Error> { - for e in line.split(',') { - let mut ext_parts = e.split(';'); - if let Some(name) = ext_parts.next() { - let name = name.trim(); - if let Some(ext) = extensions.iter_mut().find(|x| x.name().eq_ignore_ascii_case(name)) { - let mut params = Vec::new(); - for p in ext_parts { - let mut key_value = p.split('='); - if let Some(key) = key_value.next().map(str::trim) { - let val = key_value.next().map(|v| v.trim().trim_matches('"')); - let mut p = Param::new(key); - p.set_value(val); - params.push(p) - } - } - ext.configure(¶ms).map_err(Error::Extension)? - } - } - } - Ok(()) + for e in line.split(',') { + let mut ext_parts = e.split(';'); + if let Some(name) = ext_parts.next() { + let name = name.trim(); + if let Some(ext) = extensions.iter_mut().find(|x| x.name().eq_ignore_ascii_case(name)) { + let mut params = Vec::new(); + for p in ext_parts { + let mut key_value = p.split('='); + if let Some(key) = key_value.next().map(str::trim) { + let val = key_value.next().map(|v| v.trim().trim_matches('"')); + let mut p = Param::new(key); + p.set_value(val); + params.push(p) + } + } + ext.configure(¶ms).map_err(Error::Extension)? + } + } + } + Ok(()) } // Write all extensions to the given buffer. fn append_extensions<'a, I>(extensions: I, bytes: &mut BytesMut) where - I: IntoIterator> + I: IntoIterator>, { - let mut iter = extensions.into_iter().peekable(); - - if iter.peek().is_some() { - bytes.extend_from_slice(b"\r\nSec-WebSocket-Extensions: ") - } - - while let Some(e) = iter.next() { - bytes.extend_from_slice(e.name().as_bytes()); - for p in e.params() { - bytes.extend_from_slice(b"; "); - bytes.extend_from_slice(p.name().as_bytes()); - if let Some(v) = p.value() { - bytes.extend_from_slice(b"="); - bytes.extend_from_slice(v.as_bytes()) - } - } - if iter.peek().is_some() { - bytes.extend_from_slice(b", ") - } - } + let mut iter = extensions.into_iter().peekable(); + + if iter.peek().is_some() { + bytes.extend_from_slice(b"\r\nSec-WebSocket-Extensions: ") + } + + while let Some(e) = iter.next() { + bytes.extend_from_slice(e.name().as_bytes()); + for p in e.params() { + bytes.extend_from_slice(b"; "); + bytes.extend_from_slice(p.name().as_bytes()); + if let Some(v) = p.value() { + bytes.extend_from_slice(b"="); + bytes.extend_from_slice(v.as_bytes()) + } + } + if iter.peek().is_some() { + bytes.extend_from_slice(b", ") + } + } } /// Enumeration of possible handshake errors. #[non_exhaustive] #[derive(Debug)] pub enum Error { - /// An I/O error has been encountered. - Io(io::Error), - /// An HTTP version =/= 1.1 was encountered. - UnsupportedHttpVersion, - /// An incomplete HTTP request. - IncompleteHttpRequest, - /// The value of the `Sec-WebSocket-Key` header is of unexpected length. - SecWebSocketKeyInvalidLength(usize), - /// The handshake request was not a GET request. - InvalidRequestMethod, - /// An HTTP header has not been present. - HeaderNotFound(String), - /// An HTTP header value was not expected. - UnexpectedHeader(String), - /// The Sec-WebSocket-Accept header value did not match. - InvalidSecWebSocketAccept, - /// The server returned an extension we did not ask for. - UnsolicitedExtension, - /// The server returned a protocol we did not ask for. - UnsolicitedProtocol, - /// An extension produced an error while encoding or decoding. - Extension(crate::BoxedError), - /// The HTTP entity could not be parsed successfully. - Http(crate::BoxedError), - /// UTF-8 decoding failed. - Utf8(str::Utf8Error) + /// An I/O error has been encountered. + Io(io::Error), + /// An HTTP version =/= 1.1 was encountered. + UnsupportedHttpVersion, + /// An incomplete HTTP request. + IncompleteHttpRequest, + /// The value of the `Sec-WebSocket-Key` header is of unexpected length. + SecWebSocketKeyInvalidLength(usize), + /// The handshake request was not a GET request. + InvalidRequestMethod, + /// An HTTP header has not been present. + HeaderNotFound(String), + /// An HTTP header value was not expected. + UnexpectedHeader(String), + /// The Sec-WebSocket-Accept header value did not match. + InvalidSecWebSocketAccept, + /// The server returned an extension we did not ask for. + UnsolicitedExtension, + /// The server returned a protocol we did not ask for. + UnsolicitedProtocol, + /// An extension produced an error while encoding or decoding. + Extension(crate::BoxedError), + /// The HTTP entity could not be parsed successfully. + Http(crate::BoxedError), + /// UTF-8 decoding failed. + Utf8(str::Utf8Error), } impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - Error::Io(e) => - write!(f, "i/o error: {}", e), - Error::UnsupportedHttpVersion => - f.write_str("http version was not 1.1"), - Error::IncompleteHttpRequest => - f.write_str("http request was incomplete"), - Error::SecWebSocketKeyInvalidLength(len) => - write!(f, "Sec-WebSocket-Key header was {} bytes long, expected 24", len), - Error::InvalidRequestMethod => - f.write_str("handshake was not a GET request"), - Error::HeaderNotFound(name) => - write!(f, "header {} not found", name), - Error::UnexpectedHeader(name) => - write!(f, "header {} had an unexpected value", name), - Error::InvalidSecWebSocketAccept => - f.write_str("websocket key mismatch"), - Error::UnsolicitedExtension => - f.write_str("unsolicited extension returned"), - Error::UnsolicitedProtocol => - f.write_str("unsolicited protocol returned"), - Error::Extension(e) => - write!(f, "extension error: {}", e), - Error::Http(e) => - write!(f, "http parser error: {}", e), - Error::Utf8(e) => - write!(f, "utf-8 decoding error: {}", e) - } - } + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Error::Io(e) => write!(f, "i/o error: {}", e), + Error::UnsupportedHttpVersion => f.write_str("http version was not 1.1"), + Error::IncompleteHttpRequest => f.write_str("http request was incomplete"), + Error::SecWebSocketKeyInvalidLength(len) => { + write!(f, "Sec-WebSocket-Key header was {} bytes long, expected 24", len) + } + Error::InvalidRequestMethod => f.write_str("handshake was not a GET request"), + Error::HeaderNotFound(name) => write!(f, "header {} not found", name), + Error::UnexpectedHeader(name) => write!(f, "header {} had an unexpected value", name), + Error::InvalidSecWebSocketAccept => f.write_str("websocket key mismatch"), + Error::UnsolicitedExtension => f.write_str("unsolicited extension returned"), + Error::UnsolicitedProtocol => f.write_str("unsolicited protocol returned"), + Error::Extension(e) => write!(f, "extension error: {}", e), + Error::Http(e) => write!(f, "http parser error: {}", e), + Error::Utf8(e) => write!(f, "utf-8 decoding error: {}", e), + } + } } impl std::error::Error for Error { - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - match self { - Error::Io(e) => Some(e), - Error::Extension(e) => Some(&**e), - Error::Http(e) => Some(&**e), - Error::Utf8(e) => Some(e), - Error::UnsupportedHttpVersion - | Error::IncompleteHttpRequest - | Error::SecWebSocketKeyInvalidLength(_) - | Error::InvalidRequestMethod - | Error::HeaderNotFound(_) - | Error::UnexpectedHeader(_) - | Error::InvalidSecWebSocketAccept - | Error::UnsolicitedExtension - | Error::UnsolicitedProtocol - => None - } - } + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + Error::Io(e) => Some(e), + Error::Extension(e) => Some(&**e), + Error::Http(e) => Some(&**e), + Error::Utf8(e) => Some(e), + Error::UnsupportedHttpVersion + | Error::IncompleteHttpRequest + | Error::SecWebSocketKeyInvalidLength(_) + | Error::InvalidRequestMethod + | Error::HeaderNotFound(_) + | Error::UnexpectedHeader(_) + | Error::InvalidSecWebSocketAccept + | Error::UnsolicitedExtension + | Error::UnsolicitedProtocol => None, + } + } } impl From for Error { - fn from(e: io::Error) -> Self { - Error::Io(e) - } + fn from(e: io::Error) -> Self { + Error::Io(e) + } } impl From for Error { - fn from(e: str::Utf8Error) -> Self { - Error::Utf8(e) - } + fn from(e: str::Utf8Error) -> Self { + Error::Utf8(e) + } } /// Owned value of the `Sec-WebSocket-Key` header. @@ -238,31 +222,31 @@ pub type WebSocketKey = [u8; 24]; #[cfg(test)] mod tests { - use super::expect_ascii_header; - - #[test] - fn header_match() { - let headers = &[ - httparse::Header { name: "foo", value: b"a,b,c,d" }, - httparse::Header { name: "foo", value: b"x" }, - httparse::Header { name: "foo", value: b"y, z, a" }, - httparse::Header { name: "bar", value: b"xxx" }, - httparse::Header { name: "bar", value: b"sdfsdf 423 42 424" }, - httparse::Header { name: "baz", value: b"123" } - ]; - - assert!(expect_ascii_header(headers, "foo", "a").is_ok()); - assert!(expect_ascii_header(headers, "foo", "b").is_ok()); - assert!(expect_ascii_header(headers, "foo", "c").is_ok()); - assert!(expect_ascii_header(headers, "foo", "d").is_ok()); - assert!(expect_ascii_header(headers, "foo", "x").is_ok()); - assert!(expect_ascii_header(headers, "foo", "y").is_ok()); - assert!(expect_ascii_header(headers, "foo", "z").is_ok()); - assert!(expect_ascii_header(headers, "foo", "a").is_ok()); - assert!(expect_ascii_header(headers, "bar", "xxx").is_ok()); - assert!(expect_ascii_header(headers, "bar", "sdfsdf 423 42 424").is_ok()); - assert!(expect_ascii_header(headers, "baz", "123").is_ok()); - assert!(expect_ascii_header(headers, "baz", "???").is_err()); - assert!(expect_ascii_header(headers, "???", "x").is_err()); - } + use super::expect_ascii_header; + + #[test] + fn header_match() { + let headers = &[ + httparse::Header { name: "foo", value: b"a,b,c,d" }, + httparse::Header { name: "foo", value: b"x" }, + httparse::Header { name: "foo", value: b"y, z, a" }, + httparse::Header { name: "bar", value: b"xxx" }, + httparse::Header { name: "bar", value: b"sdfsdf 423 42 424" }, + httparse::Header { name: "baz", value: b"123" }, + ]; + + assert!(expect_ascii_header(headers, "foo", "a").is_ok()); + assert!(expect_ascii_header(headers, "foo", "b").is_ok()); + assert!(expect_ascii_header(headers, "foo", "c").is_ok()); + assert!(expect_ascii_header(headers, "foo", "d").is_ok()); + assert!(expect_ascii_header(headers, "foo", "x").is_ok()); + assert!(expect_ascii_header(headers, "foo", "y").is_ok()); + assert!(expect_ascii_header(headers, "foo", "z").is_ok()); + assert!(expect_ascii_header(headers, "foo", "a").is_ok()); + assert!(expect_ascii_header(headers, "bar", "xxx").is_ok()); + assert!(expect_ascii_header(headers, "bar", "sdfsdf 423 42 424").is_ok()); + assert!(expect_ascii_header(headers, "baz", "123").is_ok()); + assert!(expect_ascii_header(headers, "baz", "???").is_err()); + assert!(expect_ascii_header(headers, "???", "x").is_err()); + } } diff --git a/src/handshake/client.rs b/src/handshake/client.rs index 37b215d3..a7b6cbd2 100644 --- a/src/handshake/client.rs +++ b/src/handshake/client.rs @@ -10,244 +10,231 @@ //! //! [handshake]: https://tools.ietf.org/html/rfc6455#section-4 -use bytes::{Buf, BytesMut}; -use crate::{Parsing, extension::Extension}; +use super::{ + append_extensions, configure_extensions, expect_ascii_header, with_first_header, Error, WebSocketKey, KEY, + MAX_NUM_HEADERS, SEC_WEBSOCKET_EXTENSIONS, SEC_WEBSOCKET_PROTOCOL, +}; use crate::connection::{self, Mode}; +use crate::{extension::Extension, Parsing}; +use bytes::{Buf, BytesMut}; use futures::prelude::*; use sha1::{Digest, Sha1}; use std::{mem, str}; -use super::{ - WebSocketKey, - Error, - KEY, - MAX_NUM_HEADERS, - SEC_WEBSOCKET_EXTENSIONS, - SEC_WEBSOCKET_PROTOCOL, - append_extensions, - configure_extensions, - expect_ascii_header, - with_first_header -}; const BLOCK_SIZE: usize = 8 * 1024; /// Websocket client handshake. #[derive(Debug)] pub struct Client<'a, T> { - /// The underlying async I/O resource. - socket: T, - /// The HTTP host to send the handshake to. - host: &'a str, - /// The HTTP host ressource. - resource: &'a str, - /// The HTTP origin header. - origin: Option<&'a str>, - /// A buffer holding the base-64 encoded request nonce. - nonce: WebSocketKey, - /// The protocols to include in the handshake. - protocols: Vec<&'a str>, - /// The extensions the client wishes to include in the request. - extensions: Vec>, - /// Encoding/decoding buffer. - buffer: BytesMut + /// The underlying async I/O resource. + socket: T, + /// The HTTP host to send the handshake to. + host: &'a str, + /// The HTTP host ressource. + resource: &'a str, + /// The HTTP origin header. + origin: Option<&'a str>, + /// A buffer holding the base-64 encoded request nonce. + nonce: WebSocketKey, + /// The protocols to include in the handshake. + protocols: Vec<&'a str>, + /// The extensions the client wishes to include in the request. + extensions: Vec>, + /// Encoding/decoding buffer. + buffer: BytesMut, } impl<'a, T: AsyncRead + AsyncWrite + Unpin> Client<'a, T> { - /// Create a new client handshake for some host and resource. - pub fn new(socket: T, host: &'a str, resource: &'a str) -> Self { - Client { - socket, - host, - resource, - origin: None, - nonce: [0; 24], - protocols: Vec::new(), - extensions: Vec::new(), - buffer: BytesMut::new() - } - } - - /// Override the buffer to use for request/response handling. - pub fn set_buffer(&mut self, b: BytesMut) -> &mut Self { - self.buffer = b; - self - } - - /// Extract the buffer. - pub fn take_buffer(&mut self) -> BytesMut { - mem::take(&mut self.buffer) - } - - /// Set the handshake origin header. - pub fn set_origin(&mut self, o: &'a str) -> &mut Self { - self.origin = Some(o); - self - } - - /// Add a protocol to be included in the handshake. - pub fn add_protocol(&mut self, p: &'a str) -> &mut Self { - self.protocols.push(p); - self - } - - /// Add an extension to be included in the handshake. - pub fn add_extension(&mut self, e: Box) -> &mut Self { - self.extensions.push(e); - self - } - - /// Get back all extensions. - pub fn drain_extensions(&mut self) -> impl Iterator> + '_ { - self.extensions.drain(..) - } - - /// Initiate client handshake request to server and get back the response. - pub async fn handshake(&mut self) -> Result { - self.buffer.clear(); - self.encode_request(); - self.socket.write_all(&self.buffer).await?; - self.socket.flush().await?; - self.buffer.clear(); - - loop { - crate::read(&mut self.socket, &mut self.buffer, BLOCK_SIZE).await?; - if let Parsing::Done { value, offset } = self.decode_response()? { - self.buffer.advance(offset); - return Ok(value) - } - } - } - - /// Turn this handshake into a [`connection::Builder`]. - pub fn into_builder(mut self) -> connection::Builder { - let mut builder = connection::Builder::new(self.socket, Mode::Client); - builder.set_buffer(self.buffer); - builder.add_extensions(self.extensions.drain(..)); - builder - } - - /// Get out the inner socket of the client. - pub fn into_inner(self) -> T { - self.socket - } - - /// Encode the client handshake as a request, ready to be sent to the server. - fn encode_request(&mut self) { - let nonce: [u8; 16] = rand::random(); - base64::encode_config_slice(&nonce, base64::STANDARD, &mut self.nonce); - self.buffer.extend_from_slice(b"GET "); - self.buffer.extend_from_slice(self.resource.as_bytes()); - self.buffer.extend_from_slice(b" HTTP/1.1"); - self.buffer.extend_from_slice(b"\r\nHost: "); - self.buffer.extend_from_slice(self.host.as_bytes()); - self.buffer.extend_from_slice(b"\r\nUpgrade: websocket\r\nConnection: Upgrade"); - self.buffer.extend_from_slice(b"\r\nSec-WebSocket-Key: "); - self.buffer.extend_from_slice(&self.nonce); - if let Some(o) = &self.origin { - self.buffer.extend_from_slice(b"\r\nOrigin: "); - self.buffer.extend_from_slice(o.as_bytes()) - } - if let Some((last, prefix)) = self.protocols.split_last() { - self.buffer.extend_from_slice(b"\r\nSec-WebSocket-Protocol: "); - for p in prefix { - self.buffer.extend_from_slice(p.as_bytes()); - self.buffer.extend_from_slice(b",") - } - self.buffer.extend_from_slice(last.as_bytes()) - } - append_extensions(&self.extensions, &mut self.buffer); - self.buffer.extend_from_slice(b"\r\nSec-WebSocket-Version: 13\r\n\r\n") - } - - /// Decode the server response to this client request. - fn decode_response(&mut self) -> Result, Error> { - let mut header_buf = [httparse::EMPTY_HEADER; MAX_NUM_HEADERS]; - let mut response = httparse::Response::new(&mut header_buf); - - let offset = match response.parse(self.buffer.as_ref()) { - Ok(httparse::Status::Complete(off)) => off, - Ok(httparse::Status::Partial) => return Ok(Parsing::NeedMore(())), - Err(e) => return Err(Error::Http(Box::new(e))) - }; - - if response.version != Some(1) { - return Err(Error::UnsupportedHttpVersion) - } - - match response.code { - Some(101) => (), - Some(code@(301 ..= 303)) | Some(code@307) | Some(code@308) => { // redirect response - let location = with_first_header(response.headers, "Location", |loc| { - Ok(String::from(std::str::from_utf8(loc)?)) - })?; - let response = ServerResponse::Redirect { status_code: code, location }; - return Ok(Parsing::Done { value: response, offset }) - } - other => { - let response = ServerResponse::Rejected { status_code: other.unwrap_or(0) }; - return Ok(Parsing::Done { value: response, offset }) - } - } - - expect_ascii_header(response.headers, "Upgrade", "websocket")?; - expect_ascii_header(response.headers, "Connection", "upgrade")?; - - with_first_header(&response.headers, "Sec-WebSocket-Accept", |theirs| { - let mut digest = Sha1::new(); - digest.update(&self.nonce); - digest.update(KEY); - let ours = base64::encode(&digest.finalize()); - if ours.as_bytes() != theirs { - return Err(Error::InvalidSecWebSocketAccept) - } - Ok(()) - })?; - - // Parse `Sec-WebSocket-Extensions` headers. - - for h in response.headers.iter() - .filter(|h| h.name.eq_ignore_ascii_case(SEC_WEBSOCKET_EXTENSIONS)) - { - configure_extensions(&mut self.extensions, std::str::from_utf8(h.value)?)? - } - - // Match `Sec-WebSocket-Protocol` header. - - let mut selected_proto = None; - if let Some(tp) = response.headers.iter() - .find(|h| h.name.eq_ignore_ascii_case(SEC_WEBSOCKET_PROTOCOL)) - { - if let Some(&p) = self.protocols.iter().find(|x| x.as_bytes() == tp.value) { - selected_proto = Some(String::from(p)) - } else { - return Err(Error::UnsolicitedProtocol) - } - } - - let response = ServerResponse::Accepted { protocol: selected_proto }; - Ok(Parsing::Done { value: response, offset }) - } + /// Create a new client handshake for some host and resource. + pub fn new(socket: T, host: &'a str, resource: &'a str) -> Self { + Client { + socket, + host, + resource, + origin: None, + nonce: [0; 24], + protocols: Vec::new(), + extensions: Vec::new(), + buffer: BytesMut::new(), + } + } + + /// Override the buffer to use for request/response handling. + pub fn set_buffer(&mut self, b: BytesMut) -> &mut Self { + self.buffer = b; + self + } + + /// Extract the buffer. + pub fn take_buffer(&mut self) -> BytesMut { + mem::take(&mut self.buffer) + } + + /// Set the handshake origin header. + pub fn set_origin(&mut self, o: &'a str) -> &mut Self { + self.origin = Some(o); + self + } + + /// Add a protocol to be included in the handshake. + pub fn add_protocol(&mut self, p: &'a str) -> &mut Self { + self.protocols.push(p); + self + } + + /// Add an extension to be included in the handshake. + pub fn add_extension(&mut self, e: Box) -> &mut Self { + self.extensions.push(e); + self + } + + /// Get back all extensions. + pub fn drain_extensions(&mut self) -> impl Iterator> + '_ { + self.extensions.drain(..) + } + + /// Initiate client handshake request to server and get back the response. + pub async fn handshake(&mut self) -> Result { + self.buffer.clear(); + self.encode_request(); + self.socket.write_all(&self.buffer).await?; + self.socket.flush().await?; + self.buffer.clear(); + + loop { + crate::read(&mut self.socket, &mut self.buffer, BLOCK_SIZE).await?; + if let Parsing::Done { value, offset } = self.decode_response()? { + self.buffer.advance(offset); + return Ok(value); + } + } + } + + /// Turn this handshake into a [`connection::Builder`]. + pub fn into_builder(mut self) -> connection::Builder { + let mut builder = connection::Builder::new(self.socket, Mode::Client); + builder.set_buffer(self.buffer); + builder.add_extensions(self.extensions.drain(..)); + builder + } + + /// Get out the inner socket of the client. + pub fn into_inner(self) -> T { + self.socket + } + + /// Encode the client handshake as a request, ready to be sent to the server. + fn encode_request(&mut self) { + let nonce: [u8; 16] = rand::random(); + base64::encode_config_slice(&nonce, base64::STANDARD, &mut self.nonce); + self.buffer.extend_from_slice(b"GET "); + self.buffer.extend_from_slice(self.resource.as_bytes()); + self.buffer.extend_from_slice(b" HTTP/1.1"); + self.buffer.extend_from_slice(b"\r\nHost: "); + self.buffer.extend_from_slice(self.host.as_bytes()); + self.buffer.extend_from_slice(b"\r\nUpgrade: websocket\r\nConnection: Upgrade"); + self.buffer.extend_from_slice(b"\r\nSec-WebSocket-Key: "); + self.buffer.extend_from_slice(&self.nonce); + if let Some(o) = &self.origin { + self.buffer.extend_from_slice(b"\r\nOrigin: "); + self.buffer.extend_from_slice(o.as_bytes()) + } + if let Some((last, prefix)) = self.protocols.split_last() { + self.buffer.extend_from_slice(b"\r\nSec-WebSocket-Protocol: "); + for p in prefix { + self.buffer.extend_from_slice(p.as_bytes()); + self.buffer.extend_from_slice(b",") + } + self.buffer.extend_from_slice(last.as_bytes()) + } + append_extensions(&self.extensions, &mut self.buffer); + self.buffer.extend_from_slice(b"\r\nSec-WebSocket-Version: 13\r\n\r\n") + } + + /// Decode the server response to this client request. + fn decode_response(&mut self) -> Result, Error> { + let mut header_buf = [httparse::EMPTY_HEADER; MAX_NUM_HEADERS]; + let mut response = httparse::Response::new(&mut header_buf); + + let offset = match response.parse(self.buffer.as_ref()) { + Ok(httparse::Status::Complete(off)) => off, + Ok(httparse::Status::Partial) => return Ok(Parsing::NeedMore(())), + Err(e) => return Err(Error::Http(Box::new(e))), + }; + + if response.version != Some(1) { + return Err(Error::UnsupportedHttpVersion); + } + + match response.code { + Some(101) => (), + Some(code @ (301..=303)) | Some(code @ 307) | Some(code @ 308) => { + // redirect response + let location = + with_first_header(response.headers, "Location", |loc| Ok(String::from(std::str::from_utf8(loc)?)))?; + let response = ServerResponse::Redirect { status_code: code, location }; + return Ok(Parsing::Done { value: response, offset }); + } + other => { + let response = ServerResponse::Rejected { status_code: other.unwrap_or(0) }; + return Ok(Parsing::Done { value: response, offset }); + } + } + + expect_ascii_header(response.headers, "Upgrade", "websocket")?; + expect_ascii_header(response.headers, "Connection", "upgrade")?; + + with_first_header(&response.headers, "Sec-WebSocket-Accept", |theirs| { + let mut digest = Sha1::new(); + digest.update(&self.nonce); + digest.update(KEY); + let ours = base64::encode(&digest.finalize()); + if ours.as_bytes() != theirs { + return Err(Error::InvalidSecWebSocketAccept); + } + Ok(()) + })?; + + // Parse `Sec-WebSocket-Extensions` headers. + + for h in response.headers.iter().filter(|h| h.name.eq_ignore_ascii_case(SEC_WEBSOCKET_EXTENSIONS)) { + configure_extensions(&mut self.extensions, std::str::from_utf8(h.value)?)? + } + + // Match `Sec-WebSocket-Protocol` header. + + let mut selected_proto = None; + if let Some(tp) = response.headers.iter().find(|h| h.name.eq_ignore_ascii_case(SEC_WEBSOCKET_PROTOCOL)) { + if let Some(&p) = self.protocols.iter().find(|x| x.as_bytes() == tp.value) { + selected_proto = Some(String::from(p)) + } else { + return Err(Error::UnsolicitedProtocol); + } + } + + let response = ServerResponse::Accepted { protocol: selected_proto }; + Ok(Parsing::Done { value: response, offset }) + } } /// Handshake response received from the server. #[derive(Debug)] pub enum ServerResponse { - /// The server has accepted our request. - Accepted { - /// The protocol (if any) the server has selected. - protocol: Option - }, - /// The server is redirecting us to some other location. - Redirect { - /// The HTTP response status code. - status_code: u16, - /// The location URL we should go to. - location: String - }, - /// The server rejected our request. - Rejected { - /// HTTP response status code. - status_code: u16 - } + /// The server has accepted our request. + Accepted { + /// The protocol (if any) the server has selected. + protocol: Option, + }, + /// The server is redirecting us to some other location. + Redirect { + /// The HTTP response status code. + status_code: u16, + /// The location URL we should go to. + location: String, + }, + /// The server rejected our request. + Rejected { + /// HTTP response status code. + status_code: u16, + }, } - diff --git a/src/handshake/server.rs b/src/handshake/server.rs index 2234a369..a4ba6b13 100644 --- a/src/handshake/server.rs +++ b/src/handshake/server.rs @@ -10,351 +10,334 @@ //! //! [handshake]: https://tools.ietf.org/html/rfc6455#section-4 -use bytes::BytesMut; -use crate::extension::Extension; +use super::{ + append_extensions, configure_extensions, expect_ascii_header, with_first_header, Error, WebSocketKey, KEY, + MAX_NUM_HEADERS, SEC_WEBSOCKET_EXTENSIONS, SEC_WEBSOCKET_PROTOCOL, +}; use crate::connection::{self, Mode}; +use crate::extension::Extension; +use bytes::BytesMut; use futures::prelude::*; use sha1::{Digest, Sha1}; use std::{mem, str}; -use super::{ - WebSocketKey, - Error, - KEY, - MAX_NUM_HEADERS, - SEC_WEBSOCKET_EXTENSIONS, - SEC_WEBSOCKET_PROTOCOL, - append_extensions, - configure_extensions, - expect_ascii_header, - with_first_header -}; // Most HTTP servers default to 8KB limit on headers const MAX_HEADERS_SIZE: usize = 8 * 1024; const BLOCK_SIZE: usize = 8 * 1024; -const SOKETTO_VERSION: &str = env!("CARGO_PKG_VERSION"); /// Websocket handshake client. #[derive(Debug)] pub struct Server<'a, T> { - socket: T, - /// Protocols the server supports. - protocols: Vec<&'a str>, - /// Extensions the server supports. - extensions: Vec>, - /// Encoding/decoding buffer. - buffer: BytesMut + socket: T, + /// Protocols the server supports. + protocols: Vec<&'a str>, + /// Extensions the server supports. + extensions: Vec>, + /// Encoding/decoding buffer. + buffer: BytesMut, } impl<'a, T: AsyncRead + AsyncWrite + Unpin> Server<'a, T> { - /// Create a new server handshake. - pub fn new(socket: T) -> Self { - Server { - socket, - protocols: Vec::new(), - extensions: Vec::new(), - buffer: BytesMut::new() - } - } - - /// Override the buffer to use for request/response handling. - pub fn set_buffer(&mut self, b: BytesMut) -> &mut Self { - self.buffer = b; - self - } - - /// Extract the buffer. - pub fn take_buffer(&mut self) -> BytesMut { - mem::take(&mut self.buffer) - } - - /// Add a protocol the server supports. - pub fn add_protocol(&mut self, p: &'a str) -> &mut Self { - self.protocols.push(p); - self - } - - /// Add an extension the server supports. - pub fn add_extension(&mut self, e: Box) -> &mut Self { - self.extensions.push(e); - self - } - - /// Get back all extensions. - pub fn drain_extensions(&mut self) -> impl Iterator> + '_ { - self.extensions.drain(..) - } - - /// Await an incoming client handshake request. - pub async fn receive_request(&mut self) -> Result, Error> { - self.buffer.clear(); - - let mut skip = 0; - - loop { - crate::read(&mut self.socket, &mut self.buffer, BLOCK_SIZE).await?; - - let limit = std::cmp::min(self.buffer.len(), MAX_HEADERS_SIZE); - - // We don't expect body, so can search for the CRLF headers tail from - // the end of the buffer. - if self.buffer[skip..limit].windows(4).rev().any(|w| w == b"\r\n\r\n") { - break; - } - - // Give up if we've reached the limit. We could emit a specific error here, - // but httparse will produce meaningful error for us regardless. - if limit == MAX_HEADERS_SIZE { - break; - } - - // Skip bytes that did not contain CRLF in the next iteration. - // If we only read a partial CRLF sequence, we would miss it if we skipped the full buffer - // length, hence backing off the full 4 bytes. - skip = self.buffer.len().saturating_sub(4); - } - - self.decode_request() - } - - /// Respond to the client. - pub async fn send_response(&mut self, r: &Response<'_>) -> Result<(), Error> { - self.buffer.clear(); - self.encode_response(r); - self.socket.write_all(&self.buffer).await?; - self.socket.flush().await?; - self.buffer.clear(); - Ok(()) - } - - /// Turn this handshake into a [`connection::Builder`]. - pub fn into_builder(mut self) -> connection::Builder { - let mut builder = connection::Builder::new(self.socket, Mode::Server); - builder.set_buffer(self.buffer); - builder.add_extensions(self.extensions.drain(..)); - builder - } - - /// Get out the inner socket of the server. - pub fn into_inner(self) -> T { - self.socket - } - - // Decode client handshake request. - fn decode_request(&mut self) -> Result { - let mut header_buf = [httparse::EMPTY_HEADER; MAX_NUM_HEADERS]; - let mut request = httparse::Request::new(&mut header_buf); - - match request.parse(self.buffer.as_ref()) { - Ok(httparse::Status::Complete(_)) => (), - Ok(httparse::Status::Partial) => return Err(Error::IncompleteHttpRequest), - Err(e) => return Err(Error::Http(Box::new(e))) - }; - if request.method != Some("GET") { - return Err(Error::InvalidRequestMethod) - } - if request.version != Some(1) { - return Err(Error::UnsupportedHttpVersion) - } - - let host = with_first_header(&request.headers, "Host", Ok)?; - - expect_ascii_header(request.headers, "Upgrade", "websocket")?; - expect_ascii_header(request.headers, "Connection", "upgrade")?; - expect_ascii_header(request.headers, "Sec-WebSocket-Version", "13")?; - - let origin = request.headers.iter().find_map(|h| { - if h.name.eq_ignore_ascii_case("Origin") { - Some(h.value) - } else { - None - } - }); - let headers = RequestHeaders { host, origin }; - - let ws_key = with_first_header(&request.headers, "Sec-WebSocket-Key", |k| { - use std::convert::TryFrom; - - WebSocketKey::try_from(k).map_err(|_| Error::SecWebSocketKeyInvalidLength(k.len())) - })?; - - for h in request.headers.iter() - .filter(|h| h.name.eq_ignore_ascii_case(SEC_WEBSOCKET_EXTENSIONS)) - { - configure_extensions(&mut self.extensions, std::str::from_utf8(h.value)?)? - } - - let mut protocols = Vec::new(); - for p in request.headers.iter() - .filter(|h| h.name.eq_ignore_ascii_case(SEC_WEBSOCKET_PROTOCOL)) - { - if let Some(&p) = self.protocols.iter().find(|x| x.as_bytes() == p.value) { - protocols.push(p) - } - } - - let path = request.path.unwrap_or("/"); - - Ok(ClientRequest { ws_key, protocols, path, headers }) - } - - // Encode server handshake response. - fn encode_response(&mut self, response: &Response<'_>) { - match response { - Response::Accept { key, protocol } => { - let mut key_buf = [0; 32]; - let accept_value = { - let mut digest = Sha1::new(); - digest.update(key); - digest.update(KEY); - let d = digest.finalize(); - let n = base64::encode_config_slice(&d, base64::STANDARD, &mut key_buf); - &key_buf[.. n] - }; - self.buffer.extend_from_slice(b"HTTP/1.1 101 Switching Protocols"); - self.buffer.extend_from_slice(b"\r\nServer: soketto-"); - self.buffer.extend_from_slice(SOKETTO_VERSION.as_bytes()); - self.buffer.extend_from_slice(b"\r\nUpgrade: websocket\r\nConnection: upgrade"); - self.buffer.extend_from_slice(b"\r\nSec-WebSocket-Accept: "); - self.buffer.extend_from_slice(accept_value); - if let Some(p) = protocol { - self.buffer.extend_from_slice(b"\r\nSec-WebSocket-Protocol: "); - self.buffer.extend_from_slice(p.as_bytes()) - } - append_extensions(self.extensions.iter().filter(|e| e.is_enabled()), &mut self.buffer); - self.buffer.extend_from_slice(b"\r\n\r\n") - } - Response::Reject { status_code } => { - self.buffer.extend_from_slice(b"HTTP/1.1 "); - let (_, s, reason) = - if let Ok(i) = STATUSCODES.binary_search_by_key(status_code, |(n, _, _)| *n) { - STATUSCODES[i] - } else { - (500, "500", "Internal Server Error") - }; - self.buffer.extend_from_slice(s.as_bytes()); - self.buffer.extend_from_slice(b" "); - self.buffer.extend_from_slice(reason.as_bytes()); - self.buffer.extend_from_slice(b"\r\n\r\n") - } - } - } + /// Create a new server handshake. + pub fn new(socket: T) -> Self { + Server { socket, protocols: Vec::new(), extensions: Vec::new(), buffer: BytesMut::new() } + } + + /// Override the buffer to use for request/response handling. + pub fn set_buffer(&mut self, b: BytesMut) -> &mut Self { + self.buffer = b; + self + } + + /// Extract the buffer. + pub fn take_buffer(&mut self) -> BytesMut { + mem::take(&mut self.buffer) + } + + /// Add a protocol the server supports. + pub fn add_protocol(&mut self, p: &'a str) -> &mut Self { + self.protocols.push(p); + self + } + + /// Add an extension the server supports. + pub fn add_extension(&mut self, e: Box) -> &mut Self { + self.extensions.push(e); + self + } + + /// Get back all extensions. + pub fn drain_extensions(&mut self) -> impl Iterator> + '_ { + self.extensions.drain(..) + } + + /// Await an incoming client handshake request. + pub async fn receive_request(&mut self) -> Result, Error> { + self.buffer.clear(); + + let mut skip = 0; + + loop { + crate::read(&mut self.socket, &mut self.buffer, BLOCK_SIZE).await?; + + let limit = std::cmp::min(self.buffer.len(), MAX_HEADERS_SIZE); + + // We don't expect body, so can search for the CRLF headers tail from + // the end of the buffer. + if self.buffer[skip..limit].windows(4).rev().any(|w| w == b"\r\n\r\n") { + break; + } + + // Give up if we've reached the limit. We could emit a specific error here, + // but httparse will produce meaningful error for us regardless. + if limit == MAX_HEADERS_SIZE { + break; + } + + // Skip bytes that did not contain CRLF in the next iteration. + // If we only read a partial CRLF sequence, we would miss it if we skipped the full buffer + // length, hence backing off the full 4 bytes. + skip = self.buffer.len().saturating_sub(4); + } + + self.decode_request() + } + + /// Respond to the client. + pub async fn send_response(&mut self, r: &Response<'_>) -> Result<(), Error> { + self.buffer.clear(); + self.encode_response(r); + self.socket.write_all(&self.buffer).await?; + self.socket.flush().await?; + self.buffer.clear(); + Ok(()) + } + + /// Turn this handshake into a [`connection::Builder`]. + pub fn into_builder(mut self) -> connection::Builder { + let mut builder = connection::Builder::new(self.socket, Mode::Server); + builder.set_buffer(self.buffer); + builder.add_extensions(self.extensions.drain(..)); + builder + } + + /// Get out the inner socket of the server. + pub fn into_inner(self) -> T { + self.socket + } + + // Decode client handshake request. + fn decode_request(&mut self) -> Result { + let mut header_buf = [httparse::EMPTY_HEADER; MAX_NUM_HEADERS]; + let mut request = httparse::Request::new(&mut header_buf); + + match request.parse(self.buffer.as_ref()) { + Ok(httparse::Status::Complete(_)) => (), + Ok(httparse::Status::Partial) => return Err(Error::IncompleteHttpRequest), + Err(e) => return Err(Error::Http(Box::new(e))), + }; + if request.method != Some("GET") { + return Err(Error::InvalidRequestMethod); + } + if request.version != Some(1) { + return Err(Error::UnsupportedHttpVersion); + } + + let host = with_first_header(&request.headers, "Host", Ok)?; + + expect_ascii_header(request.headers, "Upgrade", "websocket")?; + expect_ascii_header(request.headers, "Connection", "upgrade")?; + expect_ascii_header(request.headers, "Sec-WebSocket-Version", "13")?; + + let origin = + request.headers.iter().find_map( + |h| { + if h.name.eq_ignore_ascii_case("Origin") { + Some(h.value) + } else { + None + } + }, + ); + let headers = RequestHeaders { host, origin }; + + let ws_key = with_first_header(&request.headers, "Sec-WebSocket-Key", |k| { + use std::convert::TryFrom; + + WebSocketKey::try_from(k).map_err(|_| Error::SecWebSocketKeyInvalidLength(k.len())) + })?; + + for h in request.headers.iter().filter(|h| h.name.eq_ignore_ascii_case(SEC_WEBSOCKET_EXTENSIONS)) { + configure_extensions(&mut self.extensions, std::str::from_utf8(h.value)?)? + } + + let mut protocols = Vec::new(); + for p in request.headers.iter().filter(|h| h.name.eq_ignore_ascii_case(SEC_WEBSOCKET_PROTOCOL)) { + if let Some(&p) = self.protocols.iter().find(|x| x.as_bytes() == p.value) { + protocols.push(p) + } + } + + let path = request.path.unwrap_or("/"); + + Ok(ClientRequest { ws_key, protocols, path, headers }) + } + + // Encode server handshake response. + fn encode_response(&mut self, response: &Response<'_>) { + match response { + Response::Accept { key, protocol } => { + let mut key_buf = [0; 32]; + let accept_value = { + let mut digest = Sha1::new(); + digest.update(key); + digest.update(KEY); + let d = digest.finalize(); + let n = base64::encode_config_slice(&d, base64::STANDARD, &mut key_buf); + &key_buf[..n] + }; + self.buffer.extend_from_slice( + concat![ + "HTTP/1.1 101 Switching Protocols", + "\r\nServer: soketto-", + env!("CARGO_PKG_VERSION"), + "\r\nUpgrade: websocket", + "\r\nConnection: upgrade", + "\r\nSec-WebSocket-Accept: ", + ] + .as_bytes(), + ); + self.buffer.extend_from_slice(accept_value); + if let Some(p) = protocol { + self.buffer.extend_from_slice(b"\r\nSec-WebSocket-Protocol: "); + self.buffer.extend_from_slice(p.as_bytes()) + } + append_extensions(self.extensions.iter().filter(|e| e.is_enabled()), &mut self.buffer); + self.buffer.extend_from_slice(b"\r\n\r\n") + } + Response::Reject { status_code } => { + self.buffer.extend_from_slice(b"HTTP/1.1 "); + let (_, reason) = if let Ok(i) = STATUSCODES.binary_search_by_key(status_code, |(n, _)| *n) { + STATUSCODES[i] + } else { + (500, "500 Internal Server Error") + }; + self.buffer.extend_from_slice(reason.as_bytes()); + self.buffer.extend_from_slice(b"\r\n\r\n") + } + } + } } /// Handshake request received from the client. #[derive(Debug)] pub struct ClientRequest<'a> { - ws_key: WebSocketKey, - protocols: Vec<&'a str>, - path: &'a str, - headers: RequestHeaders<'a>, + ws_key: WebSocketKey, + protocols: Vec<&'a str>, + path: &'a str, + headers: RequestHeaders<'a>, } /// Select HTTP headers sent by the client. #[derive(Debug, Copy, Clone)] pub struct RequestHeaders<'a> { - /// The [`Host`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Host) header. - pub host: &'a [u8], - /// The [`Origin`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Origin) header, if provided. - pub origin: Option<&'a [u8]>, + /// The [`Host`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Host) header. + pub host: &'a [u8], + /// The [`Origin`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Origin) header, if provided. + pub origin: Option<&'a [u8]>, } impl<'a> ClientRequest<'a> { - /// The `Sec-WebSocket-Key` header nonce value. - pub fn key(&self) -> WebSocketKey { - self.ws_key - } - - /// The protocols the client is proposing. - pub fn protocols(&self) -> impl Iterator { - self.protocols.iter().cloned() - } - - /// The path the client is requesting. - pub fn path(&self) -> &str { - self.path - } - - /// Select HTTP headers sent by the client. - pub fn headers(&self) -> RequestHeaders { - self.headers - } + /// The `Sec-WebSocket-Key` header nonce value. + pub fn key(&self) -> WebSocketKey { + self.ws_key + } + + /// The protocols the client is proposing. + pub fn protocols(&self) -> impl Iterator { + self.protocols.iter().cloned() + } + + /// The path the client is requesting. + pub fn path(&self) -> &str { + self.path + } + + /// Select HTTP headers sent by the client. + pub fn headers(&self) -> RequestHeaders { + self.headers + } } /// Handshake response the server sends back to the client. #[derive(Debug)] pub enum Response<'a> { - /// The server accepts the handshake request. - Accept { - key: WebSocketKey, - protocol: Option<&'a str> - }, - /// The server rejects the handshake request. - Reject { - status_code: u16 - } + /// The server accepts the handshake request. + Accept { key: WebSocketKey, protocol: Option<&'a str> }, + /// The server rejects the handshake request. + Reject { status_code: u16 }, } /// Known status codes and their reason phrases. -const STATUSCODES: &[(u16, &str, &str)] = &[ - (100, "100", "Continue"), - (101, "101", "Switching Protocols"), - (102, "102", "Processing"), - (200, "200", "OK"), - (201, "201", "Created"), - (202, "202", "Accepted"), - (203, "203", "Non Authoritative Information"), - (204, "204", "No Content"), - (205, "205", "Reset Content"), - (206, "206", "Partial Content"), - (207, "207", "Multi-Status"), - (208, "208", "Already Reported"), - (226, "226", "IM Used"), - (300, "300", "Multiple Choices"), - (301, "301", "Moved Permanently"), - (302, "302", "Found"), - (303, "303", "See Other"), - (304, "304", "Not Modified"), - (305, "305", "Use Proxy"), - (307, "307", "Temporary Redirect"), - (308, "308", "Permanent Redirect"), - (400, "400", "Bad Request"), - (401, "401", "Unauthorized"), - (402, "402", "Payment Required"), - (403, "403", "Forbidden"), - (404, "404", "Not Found"), - (405, "405", "Method Not Allowed"), - (406, "406", "Not Acceptable"), - (407, "407", "Proxy Authentication Required"), - (408, "408", "Request Timeout"), - (409, "409", "Conflict"), - (410, "410", "Gone"), - (411, "411", "Length Required"), - (412, "412", "Precondition Failed"), - (413, "413", "Payload Too Large"), - (414, "414", "URI Too Long"), - (415, "415", "Unsupported Media Type"), - (416, "416", "Range Not Satisfiable"), - (417, "417", "Expectation Failed"), - (418, "418", "I'm a teapot"), - (421, "421", "Misdirected Request"), - (422, "422", "Unprocessable Entity"), - (423, "423", "Locked"), - (424, "424", "Failed Dependency"), - (426, "426", "Upgrade Required"), - (428, "428", "Precondition Required"), - (429, "429", "Too Many Requests"), - (431, "431", "Request Header Fields Too Large"), - (451, "451", "Unavailable For Legal Reasons"), - (500, "500", "Internal Server Error"), - (501, "501", "Not Implemented"), - (502, "502", "Bad Gateway"), - (503, "503", "Service Unavailable"), - (504, "504", "Gateway Timeout"), - (505, "505", "HTTP Version Not Supported"), - (506, "506", "Variant Also Negotiates"), - (507, "507", "Insufficient Storage"), - (508, "508", "Loop Detected"), - (510, "510", "Not Extended"), - (511, "511", "Network Authentication Required") +const STATUSCODES: &[(u16, &str)] = &[ + (100, "100 Continue"), + (101, "101 Switching Protocols"), + (102, "102 Processing"), + (200, "200 OK"), + (201, "201 Created"), + (202, "202 Accepted"), + (203, "203 Non Authoritative Information"), + (204, "204 No Content"), + (205, "205 Reset Content"), + (206, "206 Partial Content"), + (207, "207 Multi-Status"), + (208, "208 Already Reported"), + (226, "226 IM Used"), + (300, "300 Multiple Choices"), + (301, "301 Moved Permanently"), + (302, "302 Found"), + (303, "303 See Other"), + (304, "304 Not Modified"), + (305, "305 Use Proxy"), + (307, "307 Temporary Redirect"), + (308, "308 Permanent Redirect"), + (400, "400 Bad Request"), + (401, "401 Unauthorized"), + (402, "402 Payment Required"), + (403, "403 Forbidden"), + (404, "404 Not Found"), + (405, "405 Method Not Allowed"), + (406, "406 Not Acceptable"), + (407, "407 Proxy Authentication Required"), + (408, "408 Request Timeout"), + (409, "409 Conflict"), + (410, "410 Gone"), + (411, "411 Length Required"), + (412, "412 Precondition Failed"), + (413, "413 Payload Too Large"), + (414, "414 URI Too Long"), + (415, "415 Unsupported Media Type"), + (416, "416 Range Not Satisfiable"), + (417, "417 Expectation Failed"), + (418, "418 I'm a teapot"), + (421, "421 Misdirected Request"), + (422, "422 Unprocessable Entity"), + (423, "423 Locked"), + (424, "424 Failed Dependency"), + (426, "426 Upgrade Required"), + (428, "428 Precondition Required"), + (429, "429 Too Many Requests"), + (431, "431 Request Header Fields Too Large"), + (451, "451 Unavailable For Legal Reasons"), + (500, "500 Internal Server Error"), + (501, "501 Not Implemented"), + (502, "502 Bad Gateway"), + (503, "503 Service Unavailable"), + (504, "504 Gateway Timeout"), + (505, "505 HTTP Version Not Supported"), + (506, "506 Variant Also Negotiates"), + (507, "507 Insufficient Storage"), + (508, "508 Loop Detected"), + (510, "510 Not Extended"), + (511, "511 Network Authentication Required"), ]; diff --git a/src/lib.rs b/src/lib.rs index b2cc7514..e1665b1c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -112,10 +112,10 @@ #![forbid(unsafe_code)] pub mod base; +pub mod connection; pub mod data; pub mod extension; pub mod handshake; -pub mod connection; use bytes::BytesMut; use futures::io::{AsyncRead, AsyncReadExt}; @@ -129,58 +129,57 @@ pub type BoxedError = Box; /// A parsing result. #[derive(Debug, Clone)] pub enum Parsing { - /// Parsing completed. - Done { - /// The parsed value. - value: T, - /// The offset into the byte slice that has been consumed. - offset: usize - }, - /// Parsing is incomplete and needs more data. - NeedMore(N) + /// Parsing completed. + Done { + /// The parsed value. + value: T, + /// The offset into the byte slice that has been consumed. + offset: usize, + }, + /// Parsing is incomplete and needs more data. + NeedMore(N), } /// A buffer type used for implementing `Extension`s. #[derive(Debug)] pub enum Storage<'a> { - /// A read-only shared byte slice. - Shared(&'a [u8]), - /// A mutable byte slice. - Unique(&'a mut [u8]), - /// An owned byte buffer. - Owned(Vec) + /// A read-only shared byte slice. + Shared(&'a [u8]), + /// A mutable byte slice. + Unique(&'a mut [u8]), + /// An owned byte buffer. + Owned(Vec), } impl AsRef<[u8]> for Storage<'_> { - fn as_ref(&self) -> &[u8] { - match self { - Storage::Shared(d) => d, - Storage::Unique(d) => d, - Storage::Owned(b) => b.as_ref() - } - } + fn as_ref(&self) -> &[u8] { + match self { + Storage::Shared(d) => d, + Storage::Unique(d) => d, + Storage::Owned(b) => b.as_ref(), + } + } } /// Helper function to allow casts from `usize` to `u64` only on platforms /// where the sizes are guaranteed to fit. #[cfg(any(target_pointer_width = "32", target_pointer_width = "64"))] const fn as_u64(a: usize) -> u64 { - a as u64 + a as u64 } /// Fill the buffer from the given `AsyncRead` impl with up to `max` bytes. async fn read(reader: &mut R, dest: &mut BytesMut, max: usize) -> io::Result<()> where - R: AsyncRead + Unpin + R: AsyncRead + Unpin, { - let i = dest.len(); - dest.resize(i + max, 0u8); - let n = reader.read(&mut dest[i ..]).await?; - dest.truncate(i + n); - if n == 0 { - return Err(io::ErrorKind::UnexpectedEof.into()) - } - log::trace!("read {} bytes", n); - Ok(()) + let i = dest.len(); + dest.resize(i + max, 0u8); + let n = reader.read(&mut dest[i..]).await?; + dest.truncate(i + n); + if n == 0 { + return Err(io::ErrorKind::UnexpectedEof.into()); + } + log::trace!("read {} bytes", n); + Ok(()) } -