Skip to content

Commit

Permalink
feat!: add SourceStream implementation for AsyncRead (#126)
Browse files Browse the repository at this point in the history
  • Loading branch information
aschey authored Nov 10, 2024
1 parent 6ae481b commit bac8c64
Show file tree
Hide file tree
Showing 10 changed files with 266 additions and 12 deletions.
6 changes: 6 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ reqwest-rustls = ["reqwest", "reqwest/rustls-tls"]
reqwest-middleware = ["reqwest", "dep:reqwest-middleware"]
temp-storage = ["dep:tempfile"]
open-dal = ["dep:opendal", "dep:pin-project-lite", "tokio-util/compat"]
async-read = ["tokio-util/io"]

[dev-dependencies]
rodio = { version = "0.20.1", default-features = false, features = [
Expand Down Expand Up @@ -104,6 +105,11 @@ name = "s3"
required-features = ["open-dal"]
doc-scrape-examples = true

[[example]]
name = "stdin"
required-features = ["async-read"]
doc-scrape-examples = true

[package.metadata.docs.rs]
all-features = true
rustdoc-args = ["--cfg", "docsrs"]
Expand Down
18 changes: 12 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ transports and storage implementations.
cargo add stream-download
```

## Features
## Feature Flags

- `http` - adds an HTTP-based implementation of the
[`SourceStream`](https://docs.rs/stream-download/latest/stream_download/source/trait.SourceStream.html)
Expand All @@ -45,6 +45,8 @@ cargo add stream-download
`reqwest` feature.
- `open-dal` - adds a `SourceStream` implementation that uses
[Apache OpenDAL](https://crates.io/crates/opendal) as the backend.
- `async-read` - adds a `SourceStream` implementation for any type implementing
[`AsyncRead`](https://docs.rs/tokio/latest/tokio/io/trait.AsyncRead.html).
- `temp-storage` - adds a temporary file-based storage backend (enabled by
default).

Expand Down Expand Up @@ -95,11 +97,15 @@ See [examples](https://github.com/aschey/stream-download-rs/tree/main/examples).

Transports implement the
[`SourceStream`](https://docs.rs/stream-download/latest/stream_download/source/trait.SourceStream.html)
trait. Two types of transports are provided out of the box -
[`http`](https://docs.rs/stream-download/latest/stream_download/http) for
typical HTTP-based sources and
[`open_dal`](https://docs.rs/stream-download/latest/stream_download/open_dal)
which is more complex, but supports a large variety of services.
trait. A few types of transports are provided out of the box:

- [`http`](https://docs.rs/stream-download/latest/stream_download/http) for
typical HTTP-based sources.
- [`open_dal`](https://docs.rs/stream-download/latest/stream_download/open_dal)
which is more complex, but supports a large variety of services.
- [`async_read`](https://docs.rs/stream-download/latest/stream_download/async_read)
for any source implementing
[`AsyncRead`](https://docs.rs/tokio/latest/tokio/io/trait.AsyncRead.html).

Only `http` is enabled by default. You can provide a custom transport by
implementing `SourceStream` yourself.
Expand Down
39 changes: 39 additions & 0 deletions examples/stdin.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
use std::error::Error;
use std::io::IsTerminal;

use stream_download::async_read::AsyncReadStreamParams;
use stream_download::storage::temp::TempStorageProvider;
use stream_download::{Settings, StreamDownload};
use tracing::metadata::LevelFilter;
use tracing_subscriber::EnvFilter;

#[tokio::main]
async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
tracing_subscriber::fmt()
.with_env_filter(EnvFilter::default().add_directive(LevelFilter::INFO.into()))
.with_line_number(true)
.with_file(true)
.init();

if std::io::stdin().is_terminal() {
Err("Pipe in an input stream. Ex: cat ./assets/music.mp3 | cargo run --example=stdin")?;
}

let reader = StreamDownload::new_async_read(
AsyncReadStreamParams::new(tokio::io::stdin()),
TempStorageProvider::new(),
Settings::default(),
)
.await?;

let handle = tokio::task::spawn_blocking(move || {
let (_stream, handle) = rodio::OutputStream::try_default()?;
let sink = rodio::Sink::try_new(&handle)?;
sink.append(rodio::Decoder::new(reader)?);
sink.sleep_until_end();

Ok::<_, Box<dyn Error + Send + Sync>>(())
});
handle.await??;
Ok(())
}
115 changes: 115 additions & 0 deletions src/async_read.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
//! A [`SourceStream`] adapter for any source that implements [`AsyncRead`].

use std::convert::Infallible;
use std::io;
use std::pin::Pin;

use bytes::Bytes;
use futures::Stream;
use tokio::io::AsyncRead;
use tokio_util::io::ReaderStream;

use crate::source::SourceStream;

/// Construction parameters for an [`AsyncReadStream`].
///
/// Holds the source to read from together with an optional, caller-supplied content length.
#[derive(Debug)]
pub struct AsyncReadStreamParams<T> {
    stream: T,
    content_length: Option<u64>,
}

impl<T> AsyncReadStreamParams<T> {
    /// Builds a new [`AsyncReadStreamParams`] wrapping `stream`.
    ///
    /// The content length starts out as [`None`].
    pub fn new(stream: T) -> Self {
        let content_length = None;
        Self {
            stream,
            content_length,
        }
    }

    /// Overrides the content length reported for the stream.
    ///
    /// A generic [`AsyncRead`] source cannot report its own length, so the value has to be
    /// supplied here; if this is never called it stays [`None`]. Accepts either a plain `u64` or
    /// an `Option<u64>`.
    #[must_use]
    pub fn content_length<L>(mut self, content_length: L) -> Self
    where
        L: Into<Option<u64>>,
    {
        self.content_length = content_length.into();
        self
    }
}

/// A [`SourceStream`] built on top of any reader implementing [`AsyncRead`].
///
/// The reader is wrapped in a [`ReaderStream`]; the content length is whatever the caller
/// supplied at construction time.
#[derive(Debug)]
pub struct AsyncReadStream<T> {
    stream: ReaderStream<T>,
    content_length: Option<u64>,
}

impl<T> AsyncReadStream<T>
where
    T: AsyncRead + Send + Sync + Unpin + 'static,
{
    /// Wraps `stream` in a new [`AsyncReadStream`].
    ///
    /// `content_length` may be a plain `u64` or an `Option<u64>`.
    pub fn new<L>(stream: T, content_length: L) -> Self
    where
        L: Into<Option<u64>>,
    {
        let stream = ReaderStream::new(stream);
        let content_length = content_length.into();
        Self {
            stream,
            content_length,
        }
    }
}

impl<T> SourceStream for AsyncReadStream<T>
where
T: AsyncRead + Send + Sync + Unpin + 'static,
{
type Params = AsyncReadStreamParams<T>;

type StreamCreationError = Infallible;

async fn create(params: Self::Params) -> Result<Self, Self::StreamCreationError> {
Ok(Self::new(params.stream, params.content_length))
}

fn content_length(&self) -> Option<u64> {
self.content_length
}

fn supports_seek(&self) -> bool {
false
}

async fn seek_range(&mut self, _start: u64, _end: Option<u64>) -> io::Result<()> {
Err(io::Error::new(
io::ErrorKind::Unsupported,
"seek unsupported",
))
}

async fn reconnect(&mut self, _current_position: u64) -> io::Result<()> {
Err(io::Error::new(
io::ErrorKind::Unsupported,
"reconnect unsupported",
))
}
}

/// Forwards polling to the wrapped [`ReaderStream`], yielding its items unchanged.
impl<T> Stream for AsyncReadStream<T>
where
    T: AsyncRead + Unpin,
{
    type Item = io::Result<Bytes>;

    fn poll_next(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        // Mutable access through the pin relies on `Self: Unpin`; the inner stream is then
        // re-pinned for the delegated poll.
        let this = self.get_mut();
        Pin::new(&mut this.stream).poll_next(cx)
    }
}
4 changes: 4 additions & 0 deletions src/http/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,10 @@ impl<C: Client> SourceStream for HttpStream<C> {
Ok(())
}
}

/// Always `true`: this stream reports that seeking is supported.
fn supports_seek(&self) -> bool {
true
}
}

/// HTTP range header key
Expand Down
51 changes: 48 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,16 @@ use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};

use educe::Educe;
pub use settings::*;
use source::handle::SourceHandle;
use source::{DecodeError, Source, SourceStream};
use storage::StorageProvider;
use tap::{Tap, TapFallible};
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, instrument, trace};

#[cfg(feature = "async-read")]
pub mod async_read;
#[cfg(feature = "http")]
pub mod http;
#[cfg(feature = "open-dal")]
Expand All @@ -34,8 +38,6 @@ mod settings;
pub mod source;
pub mod storage;

pub use settings::*;

/// A handle that can be used to interact with the stream remotely.
#[derive(Debug, Clone)]
pub struct StreamHandle {
Expand Down Expand Up @@ -175,7 +177,6 @@ impl<P: StorageProvider> StreamDownload<P> {
Self::new(url, storage_provider, settings).await
}

#[cfg(feature = "open-dal")]
/// Creates a new [`StreamDownload`] that uses an `OpenDAL` resource.
/// See the [`open_dal`] documentation for more details.
///
Expand Down Expand Up @@ -216,6 +217,7 @@ impl<P: StorageProvider> StreamDownload<P> {
/// Ok(())
/// }
/// ```
#[cfg(feature = "open-dal")]
pub async fn new_open_dal(
params: open_dal::OpenDalStreamParams,
storage_provider: P,
Expand All @@ -224,6 +226,49 @@ impl<P: StorageProvider> StreamDownload<P> {
Self::new(params, storage_provider, settings).await
}

/// Creates a new [`StreamDownload`] that uses an [`AsyncRead`][tokio::io::AsyncRead] resource.
///
/// This is a convenience wrapper: the arguments are forwarded unchanged to
/// [`StreamDownload::new`] with [`async_read::AsyncReadStream`] as the source type.
///
/// # Errors
///
/// Returns a [`StreamInitializationError`] if [`StreamDownload::new`] fails.
///
/// # Example reading from `stdin`
///
/// ```no_run
/// use std::error::Error;
/// use std::io::{self, Read};
/// use std::result::Result;
///
/// use stream_download::async_read::AsyncReadStreamParams;
/// use stream_download::storage::temp::TempStorageProvider;
/// use stream_download::{Settings, StreamDownload};
///
/// #[tokio::main]
/// async fn main() -> Result<(), Box<dyn Error>> {
///     let mut reader = StreamDownload::new_async_read(
///         AsyncReadStreamParams::new(tokio::io::stdin()),
///         TempStorageProvider::new(),
///         Settings::default(),
///     )
///     .await?;
///
///     tokio::task::spawn_blocking(move || {
///         let mut buf = Vec::new();
///         reader.read_to_end(&mut buf)?;
///         Ok::<_, io::Error>(())
///     })
///     .await??;
///     Ok(())
/// }
/// ```
#[cfg(feature = "async-read")]
pub async fn new_async_read<T>(
params: async_read::AsyncReadStreamParams<T>,
storage_provider: P,
settings: Settings<async_read::AsyncReadStream<T>>,
) -> Result<Self, StreamInitializationError<async_read::AsyncReadStream<T>>>
where
T: tokio::io::AsyncRead + Send + Sync + Unpin + 'static,
{
Self::new(params, storage_provider, settings).await
}

/// Creates a new [`StreamDownload`] that accesses a remote resource at the given URL.
///
/// # Example
Expand Down
5 changes: 5 additions & 0 deletions src/open_dal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ impl OpenDalStreamParams {

/// Sets the chunk size for the [`OpenDalStream`].
/// The default value is 4096.
#[must_use]
pub fn chunk_size(mut self, chunk_size: NonZeroUsize) -> Self {
self.chunk_size = chunk_size.get();
self
Expand Down Expand Up @@ -172,6 +173,10 @@ impl SourceStream for OpenDalStream {
async fn reconnect(&mut self, current_position: u64) -> io::Result<()> {
self.seek_range(current_position, None).await
}

/// Always `true`: this stream reports that seeking is supported.
fn supports_seek(&self) -> bool {
true
}
}

impl Stream for OpenDalStream {
Expand Down
2 changes: 1 addition & 1 deletion src/settings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ impl<S> Settings<S> {
}

/// If set to `true`, this will cause the stream download task to automatically cancel when the
/// [`crate::StreamDownload`] instance is dropped.
/// [`StreamDownload`][crate::StreamDownload] instance is dropped.
#[must_use]
pub fn cancel_on_drop(self, cancel_on_drop: bool) -> Self {
Self {
Expand Down
19 changes: 17 additions & 2 deletions src/source/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
//! Provides the [`SourceStream`] trait which abstracts over the transport used to
//! stream remote content.
use std::convert::Infallible;
use std::error::Error;
use std::fmt::Debug;
use std::future;
Expand Down Expand Up @@ -65,6 +66,10 @@ pub trait SourceStream:
&mut self,
current_position: u64,
) -> impl Future<Output = Result<(), io::Error>> + Send;

/// Returns whether seeking is supported in the stream.
/// If this method returns `false`, [`SourceStream::seek_range`] will never be invoked.
fn supports_seek(&self) -> bool;
}

/// Trait for decoding extra error information asynchronously.
Expand All @@ -75,6 +80,13 @@ pub trait DecodeError: Error + Send + Sized {
}
}

/// [`Infallible`] has no values, so there is never an error to decode.
impl DecodeError for Infallible {
    async fn decode_error(self) -> String {
        // `Infallible` is uninhabited, so this body can never run; the empty match states that
        // directly instead of returning a placeholder value.
        match self {}
    }
}

#[derive(PartialEq, Eq)]
enum DownloadStatus {
Continue,
Expand Down Expand Up @@ -173,7 +185,7 @@ where

async fn handle_seek(&mut self, stream: &mut S, position: &Option<u64>) -> io::Result<()> {
let position = position.expect("seek_tx dropped");
if self.should_seek(position)? {
if self.should_seek(stream, position)? {
debug!("seek position not yet downloaded");

if self.prefetch_complete {
Expand Down Expand Up @@ -357,7 +369,10 @@ where
Ok(new_position)
}

fn should_seek(&mut self, position: u64) -> io::Result<bool> {
fn should_seek(&mut self, stream: &S, position: u64) -> io::Result<bool> {
if !stream.supports_seek() {
return Ok(false);
}
Ok(if let Some(range) = self.downloaded.get(position) {
!range.contains(&self.writer.stream_position()?)
} else {
Expand Down
Loading

0 comments on commit bac8c64

Please sign in to comment.