From cacf896928349234660ec88ddef39c391b63d3ff Mon Sep 17 00:00:00 2001 From: link2xt Date: Sun, 22 Sep 2024 20:27:15 +0000 Subject: [PATCH] feat: implement IMAP COMPRESS --- .github/workflows/ci.yml | 6 ++ Cargo.toml | 7 +- src/extensions/compress.rs | 168 +++++++++++++++++++++++++++++++++++++ src/extensions/mod.rs | 3 + src/lib.rs | 6 ++ src/shared_stream.rs | 157 ++++++++++++++++++++++++++++++++++ 6 files changed, 345 insertions(+), 2 deletions(-) create mode 100644 src/extensions/compress.rs create mode 100644 src/shared_stream.rs diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ac24f85..5e520ce 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -31,6 +31,12 @@ jobs: - name: check tokio run: cargo check --workspace --all-targets --no-default-features --features runtime-tokio + - name: check compress feature with tokio + run: cargo check --workspace --all-targets --no-default-features --features runtime-tokio,compress + + - name: check compress feature with async-std + run: cargo check --workspace --all-targets --no-default-features --features runtime-async-std,compress + - name: check async-std examples working-directory: examples run: cargo check --workspace --all-targets --no-default-features --features runtime-async-std diff --git a/Cargo.toml b/Cargo.toml index 592df25..8a70b8d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,12 +20,14 @@ is-it-maintained-open-issues = { repository = "async-email/async-imap" } [features] default = ["runtime-async-std"] +compress = ["async-compression"] -runtime-async-std = ["async-std"] -runtime-tokio = ["tokio"] +runtime-async-std = ["async-std", "async-compression?/futures-io"] +runtime-tokio = ["tokio", "async-compression?/tokio"] [dependencies] async-channel = "2.0.0" +async-compression = { git = "https://github.com/link2xt/async-compression.git", default-features = false, features = ["deflate"], optional = true, branch = "link2xt/miniz_oxide-consumes-all-input" } async-std = { version = "1.8.0", default-features = false, features = ["std", "unstable"], optional = true } base64 = "0.21" bytes = "1" @@ -35,6 +37,7 @@ imap-proto = "0.16.4" log = "0.4.8" nom = "7.0" once_cell = "1.8.0" +pin-project = "1" pin-utils = "0.1.0-alpha.4" self_cell = "1.0.1" stop-token = "0.7" diff --git a/src/extensions/compress.rs b/src/extensions/compress.rs new file mode 100644 index 0000000..9980589 --- /dev/null +++ b/src/extensions/compress.rs @@ -0,0 +1,168 @@ +//! IMAP COMPRESS extension specified in [RFC4978](https://www.rfc-editor.org/rfc/rfc4978.html). + +use std::fmt; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use pin_project::pin_project; + +use crate::client::Session; +use crate::error::Result; +use crate::imap_stream::ImapStream; +use crate::shared_stream::SharedStream; +use crate::types::IdGenerator; +use crate::Connection; + +#[cfg(feature = "runtime-async-std")] +use async_std::io::{BufReader, Read, Write}; +#[cfg(feature = "runtime-tokio")] +use tokio::io::{AsyncRead as Read, AsyncWrite as Write, BufReader, ReadBuf}; + +#[cfg(feature = "runtime-tokio")] +use async_compression::tokio::bufread::DeflateDecoder; +#[cfg(feature = "runtime-tokio")] +use async_compression::tokio::write::DeflateEncoder; + +#[cfg(feature = "runtime-async-std")] +use async_compression::futures::bufread::DeflateDecoder; +#[cfg(feature = "runtime-async-std")] +use async_compression::futures::write::DeflateEncoder; + +/// IMAP stream +#[derive(Debug)] +#[pin_project] +pub struct DeflateStream { + /// Shared stream reference to allow direct access + /// to the underlying stream. + stream: SharedStream, + + #[pin] + decoder: DeflateDecoder>>, + + #[pin] + encoder: DeflateEncoder>, +} + +impl DeflateStream { + pub(crate) fn new(stream: T) -> Self { + let stream = SharedStream::new(stream); + let decoder = DeflateDecoder::new(BufReader::new(stream.clone())); + let encoder = DeflateEncoder::new(stream.clone()); + Self { + stream, + decoder, + encoder, + } + } + + /// Runs provided function while holding the lock on the underlying stream. + /// + /// This allows to access the underlying stream while ensuring + /// that no data is read from the stream or written into the stream at the same time. + pub fn with_lock(&self, f: impl FnOnce(Pin<&mut T>) -> R) -> R { + self.stream.with_lock(f) + } +} + +#[cfg(feature = "runtime-tokio")] +impl Read for DeflateStream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + self.project().decoder.poll_read(cx, buf) + } +} + +#[cfg(feature = "runtime-async-std")] +impl Read for DeflateStream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + self.project().decoder.poll_read(cx, buf) + } +} + +#[cfg(feature = "runtime-tokio")] +impl Write for DeflateStream { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> Poll> { + self.project().encoder.poll_write(cx, buf) + } + + fn poll_flush( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + self.project().encoder.poll_flush(cx) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + self.project().encoder.poll_shutdown(cx) + } +} + +#[cfg(feature = "runtime-async-std")] +impl Write for DeflateStream { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> Poll> { + self.project().encoder.poll_write(cx, buf) + } + + fn poll_flush( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + self.project().encoder.poll_flush(cx) + } + + fn poll_close( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + self.project().encoder.poll_close(cx) + } +} + +impl Session { + /// Runs `COMPRESS DEFLATE` command. + pub async fn compress(self, f: F) -> Result> + where + S: Read + Write + Unpin + fmt::Debug, + F: FnOnce(DeflateStream) -> S, + { + let Self { + mut conn, + unsolicited_responses_tx, + unsolicited_responses, + } = self; + conn.run_command_and_check_ok("COMPRESS DEFLATE", Some(unsolicited_responses_tx.clone())) + .await?; + + let stream = conn.into_inner(); + let deflate_stream = DeflateStream::new(stream); + let stream = ImapStream::new(f(deflate_stream)); + let conn = Connection { + stream, + request_ids: IdGenerator::new(), + }; + let session = Session { + conn, + unsolicited_responses_tx, + unsolicited_responses, + }; + Ok(session) + } +} diff --git a/src/extensions/mod.rs b/src/extensions/mod.rs index 25fdaba..56aaa05 100644 --- a/src/extensions/mod.rs +++ b/src/extensions/mod.rs @@ -1,4 +1,7 @@ //! Implementations of various IMAP extensions. +#[cfg(feature = "compress")] +pub mod compress; + pub mod idle; pub mod quota; diff --git a/src/lib.rs b/src/lib.rs index 4f33ffa..c5d836f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -95,6 +95,12 @@ mod imap_stream; mod parse; pub mod types; +#[cfg(feature = "compress")] +pub use crate::extensions::compress::DeflateStream; + +#[cfg(feature = "compress")] +mod shared_stream; + pub use crate::authenticator::Authenticator; pub use crate::client::*; diff --git a/src/shared_stream.rs b/src/shared_stream.rs new file mode 100644 index 0000000..86ca576 --- /dev/null +++ b/src/shared_stream.rs @@ -0,0 +1,157 @@ +use std::pin::Pin; +use std::sync::{Arc, Mutex}; +use std::task::{Context, Poll}; + +#[cfg(feature = "runtime-async-std")] +use async_std::io::{Read, Write}; +#[cfg(feature = "runtime-tokio")] +use tokio::io::{AsyncRead as Read, AsyncWrite as Write, ReadBuf}; + +#[cfg(feature = "runtime-tokio")] +#[derive(Debug)] +pub(crate) struct SharedStream { + inner: Arc>, + is_write_vectored: bool, +} + +#[cfg(feature = "runtime-async-std")] +#[derive(Debug)] +pub(crate) struct SharedStream { + inner: Arc>, +} + +#[cfg(feature = "runtime-tokio")] +impl Clone for SharedStream { + fn clone(&self) -> Self { + Self { + inner: Arc::clone(&self.inner), + is_write_vectored: self.is_write_vectored, + } + } +} + +#[cfg(feature = "runtime-async-std")] +impl Clone for SharedStream { + fn clone(&self) -> Self { + Self { + inner: Arc::clone(&self.inner), + } + } +} + +#[cfg(feature = "runtime-tokio")] +impl SharedStream +where + T: Read + Write, +{ + pub(crate) fn new(stream: T) -> SharedStream { + let is_write_vectored = stream.is_write_vectored(); + + let inner = Arc::new(Mutex::new(stream)); + + Self { + inner, + is_write_vectored, + } + } +} + +#[cfg(feature = "runtime-async-std")] +impl SharedStream +where + T: Read + Write, +{ + pub(crate) fn new(stream: T) -> SharedStream { + let inner = Arc::new(Mutex::new(stream)); + + Self { inner } + } +} + +impl SharedStream { + pub(crate) fn with_lock(&self, f: impl FnOnce(Pin<&mut T>) -> R) -> R { + let mut guard = self.inner.lock().unwrap(); + let stream = Pin::new(&mut *guard); + f(stream) + } +} + +#[cfg(feature = "runtime-tokio")] +impl Read for SharedStream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + self.with_lock(|stream| stream.poll_read(cx, buf)) + } +} + +#[cfg(feature = "runtime-async-std")] +impl Read for SharedStream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + self.with_lock(|stream| stream.poll_read(cx, buf)) + } +} + +#[cfg(feature = "runtime-tokio")] +impl Write for SharedStream { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.with_lock(|stream| stream.poll_write(cx, buf)) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.with_lock(|stream| stream.poll_flush(cx)) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.with_lock(|stream| stream.poll_shutdown(cx)) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll> { + self.with_lock(|stream| stream.poll_write_vectored(cx, bufs)) + } + + fn is_write_vectored(&self) -> bool { + self.is_write_vectored + } +} + +#[cfg(feature = "runtime-async-std")] +impl Write for SharedStream { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.with_lock(|stream| stream.poll_write(cx, buf)) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.with_lock(|stream| stream.poll_flush(cx)) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.with_lock(|stream| stream.poll_close(cx)) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[async_std::io::IoSlice<'_>], + ) -> Poll> { + self.with_lock(|stream| stream.poll_write_vectored(cx, bufs)) + } +}