Skip to content

Commit

Permalink
fix: fix race condition when reporting stream errors (#127)
Browse files Browse the repository at this point in the history
  • Loading branch information
aschey authored Nov 11, 2024
1 parent bac8c64 commit 594d41d
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 56 deletions.
33 changes: 3 additions & 30 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,13 @@
use std::fmt::Debug;
use std::future::{self, Future};
use std::io::{self, Read, Seek, SeekFrom};
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};

use educe::Educe;
pub use settings::*;
use source::handle::SourceHandle;
use source::{DecodeError, Source, SourceStream};
use storage::StorageProvider;
use tap::{Tap, TapFallible};
use tap::Tap;
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, instrument, trace};

Expand Down Expand Up @@ -72,7 +70,6 @@ impl StreamHandle {
pub struct StreamDownload<P: StorageProvider> {
output_reader: P::Reader,
handle: SourceHandle,
download_status: DownloadStatus,
download_task_cancellation_token: CancellationToken,
cancel_on_drop: bool,
}
Expand Down Expand Up @@ -415,20 +412,10 @@ impl<P: StorageProvider> StreamDownload<P> {
let mut source = Source::new(writer, content_length, settings, cancellation_token.clone());
let handle = source.source_handle();

let download_status = DownloadStatus::default();
tokio::spawn({
let download_status = download_status.clone();
let cancellation_token = cancellation_token.clone();
async move {
if source
.download(stream)
.await
.tap_err(|e| error!("Error downloading stream: {e}"))
.is_err()
{
download_status.set_failed();
source.signal_download_complete();
}
source.download(stream).await;
cancellation_token.cancel();
debug!("download task finished");
}
Expand All @@ -437,7 +424,6 @@ impl<P: StorageProvider> StreamDownload<P> {
Ok(Self {
output_reader: reader,
handle,
download_status,
download_task_cancellation_token: cancellation_token,
cancel_on_drop,
})
Expand Down Expand Up @@ -476,7 +462,7 @@ impl<P: StorageProvider> StreamDownload<P> {
}

fn check_for_failure(&self) -> io::Result<()> {
if self.download_status.is_failed() {
if self.handle.is_failed() {
Err(io::Error::new(
io::ErrorKind::Other,
"stream failed to download",
Expand Down Expand Up @@ -600,19 +586,6 @@ impl<P: StorageProvider> Seek for StreamDownload<P> {
}
}

#[derive(Default, Clone, Debug)]
struct DownloadStatus(Arc<AtomicBool>);

impl DownloadStatus {
fn set_failed(&self) {
self.0.store(true, Ordering::SeqCst);
}

fn is_failed(&self) -> bool {
self.0.load(Ordering::SeqCst)
}
}

pub(crate) trait WrapIoResult {
fn wrap_err(self, msg: &str) -> Self;
}
Expand Down
18 changes: 18 additions & 0 deletions src/source/handle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use tracing::{debug, error};
#[derive(Debug, Clone)]
pub(crate) struct SourceHandle {
pub(super) downloaded: Downloaded,
pub(super) download_status: DownloadStatus,
pub(super) requested_position: RequestedPosition,
pub(super) position_reached: PositionReached,
pub(super) content_length: Option<u64>,
Expand Down Expand Up @@ -56,6 +57,10 @@ impl SourceHandle {
pub(crate) fn content_length(&self) -> Option<u64> {
self.content_length
}

pub(crate) fn is_failed(&self) -> bool {
self.download_status.is_failed()
}
}

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -174,3 +179,16 @@ impl NotifyRead {
self.notify.notified().await;
}
}

#[derive(Default, Clone, Debug)]
pub(super) struct DownloadStatus(Arc<AtomicBool>);

impl DownloadStatus {
pub(super) fn set_failed(&self) {
self.0.store(true, Ordering::SeqCst);
}

fn is_failed(&self) -> bool {
self.0.load(Ordering::SeqCst)
}
}
70 changes: 44 additions & 26 deletions src/source/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ use std::time::{Duration, Instant};

use bytes::{BufMut, Bytes, BytesMut};
use futures::{Future, Stream, StreamExt, TryStream};
use handle::{Downloaded, NotifyRead, PositionReached, RequestedPosition, SourceHandle};
use handle::{
DownloadStatus, Downloaded, NotifyRead, PositionReached, RequestedPosition, SourceHandle,
};
use tap::TapFallible;
use tokio::sync::mpsc;
use tokio::time::timeout;
Expand Down Expand Up @@ -88,14 +90,15 @@ impl DecodeError for Infallible {
}

#[derive(PartialEq, Eq)]
enum DownloadStatus {
enum DownloadAction {
Continue,
Complete,
}

pub(crate) struct Source<S: SourceStream, W: StorageWriter> {
writer: W,
downloaded: Downloaded,
download_status: DownloadStatus,
requested_position: RequestedPosition,
position_reached: PositionReached,
notify_read: NotifyRead,
Expand Down Expand Up @@ -126,6 +129,7 @@ where
Self {
writer,
downloaded: Downloaded::default(),
download_status: DownloadStatus::default(),
requested_position: RequestedPosition::default(),
position_reached: PositionReached::default(),
notify_read: NotifyRead::default(),
Expand All @@ -144,7 +148,17 @@ where
}

#[instrument(skip_all)]
pub(crate) async fn download(&mut self, mut stream: S) -> io::Result<()> {
pub(crate) async fn download(&mut self, mut stream: S) {
let res = self.download_inner(&mut stream).await;

if let Err(e) = res {
error!("download failed: {e:?}");
self.download_status.set_failed();
}
self.signal_download_complete();
}

async fn download_inner(&mut self, stream: &mut S) -> io::Result<()> {
debug!("starting file download");
let download_start = std::time::Instant::now();

Expand All @@ -154,17 +168,17 @@ where
let next_chunk = timeout(self.retry_timeout, stream.next());
tokio::select! {
position = self.seek_rx.recv() => {
self.handle_seek(&mut stream, &position).await?;
self.handle_seek(stream, &position).await?;
},
bytes = next_chunk => {
let Ok(bytes) = bytes else {
self.handle_reconnect(&mut stream).await?;
self.handle_reconnect(stream).await?;
continue;
};
if self
.handle_bytes(&mut stream, bytes, download_start)
.handle_bytes(stream, bytes, download_start)
.await?
== DownloadStatus::Complete
== DownloadAction::Complete
{
debug!(
download_duration = format!("{:?}", download_start.elapsed()),
Expand All @@ -179,7 +193,7 @@ where
}
};
}
self.report_download_complete(&stream, download_start)?;
self.report_download_complete(stream, download_start)?;
Ok(())
}

Expand Down Expand Up @@ -227,7 +241,7 @@ where
bytes: Option<Bytes>,
start_position: u64,
download_start: Instant,
) -> io::Result<DownloadStatus> {
) -> io::Result<DownloadAction> {
let Some(bytes) = bytes else {
self.prefetch_complete = true;
debug!("file shorter than prefetch length, download finished");
Expand Down Expand Up @@ -259,37 +273,39 @@ where
}

self.report_prefetch_progress(stream, stream_position, download_start, written);
Ok(DownloadStatus::Continue)
Ok(DownloadAction::Continue)
}

async fn finish_or_find_next_gap(&mut self, stream: &mut S) -> io::Result<DownloadStatus> {
if let Some(content_length) = self.content_length {
let gap = self.get_download_gap(content_length);
if let Some(gap) = gap {
debug!(
missing = format!("{gap:?}"),
"downloading missing stream chunk"
);
self.seek(stream, gap.start, Some(gap.end)).await?;
return Ok(DownloadStatus::Continue);
async fn finish_or_find_next_gap(&mut self, stream: &mut S) -> io::Result<DownloadAction> {
if stream.supports_seek() {
if let Some(content_length) = self.content_length {
let gap = self.get_download_gap(content_length);
if let Some(gap) = gap {
debug!(
missing = format!("{gap:?}"),
"downloading missing stream chunk"
);
self.seek(stream, gap.start, Some(gap.end)).await?;
return Ok(DownloadAction::Continue);
}
}
}
self.writer.flush()?;
self.signal_download_complete();
Ok(DownloadStatus::Complete)
Ok(DownloadAction::Complete)
}

async fn handle_bytes(
&mut self,
stream: &mut S,
bytes: Option<Result<Bytes, S::Error>>,
download_start: Instant,
) -> io::Result<DownloadStatus> {
) -> io::Result<DownloadAction> {
let bytes = match bytes.transpose() {
Ok(bytes) => bytes,
Err(e) => {
error!("Error fetching chunk from stream: {e:?}");
return Ok(DownloadStatus::Continue);
return Ok(DownloadAction::Continue);
}
};

Expand All @@ -316,7 +332,7 @@ where
let new_position = self.write(bytes).await?;
self.report_downloading_progress(stream, new_position, download_start, bytes_len)?;

Ok(DownloadStatus::Continue)
Ok(DownloadAction::Continue)
}

async fn write(&mut self, bytes: Bytes) -> io::Result<u64> {
Expand Down Expand Up @@ -391,7 +407,7 @@ where
self.downloaded.next_gap(&range)
}

pub(crate) fn signal_download_complete(&self) {
fn signal_download_complete(&self) {
self.position_reached.notify_stream_done();
}

Expand Down Expand Up @@ -444,7 +460,8 @@ where
self.report_progress(stream, StreamState {
current_position: pos,
elapsed: download_start.elapsed(),
current_chunk: self.downloaded.get(pos - 1).expect(""),
// ensure no subtraction overflow
current_chunk: self.downloaded.get(pos.max(1) - 1).unwrap_or_default(),
phase: StreamPhase::Complete,
});
Ok(())
Expand All @@ -453,6 +470,7 @@ where
pub(crate) fn source_handle(&self) -> SourceHandle {
SourceHandle {
downloaded: self.downloaded.clone(),
download_status: self.download_status.clone(),
requested_position: self.requested_position.clone(),
notify_read: self.notify_read.clone(),
position_reached: self.position_reached.clone(),
Expand Down

0 comments on commit 594d41d

Please sign in to comment.