From c6b2a17e405a6754926c4f726f646efe9ef2bd2a Mon Sep 17 00:00:00 2001 From: Alessandro Passaro Date: Wed, 8 Nov 2023 14:26:45 +0000 Subject: [PATCH] Introduce Prefetch trait (#595) Introduce a new `Prefetch` trait to abstract how `S3Filesystem` fetches object data from an `ObjectClient`. While this change does not introduce any functional change, this abstraction will be used to implement optional object data caching. The existing `Prefetcher` struct has been adapted to implement the new `Prefetch` trait. The main changes are: * it is generic on the `ObjectPartStream` (previously `ObjectPartFeed`), rather than using dynamic dispatch, * it does not own an `ObjectClient` instance, instead one is required when initiating a `prefetch` request, * the logic to spawn a new task for each `GetObject` request and handle the object body parts returned was moved into `ObjectPartStream`. Signed-off-by: Alessandro Passaro --- mountpoint-s3/examples/fs_benchmark.rs | 7 +- mountpoint-s3/examples/prefetch_benchmark.rs | 6 +- mountpoint-s3/src/fs.rs | 90 ++-- mountpoint-s3/src/fuse.rs | 30 +- mountpoint-s3/src/main.rs | 95 +++- mountpoint-s3/src/prefetch.rs | 491 ++++++++---------- mountpoint-s3/src/prefetch/feed.rs | 128 ----- mountpoint-s3/src/prefetch/part_queue.rs | 16 +- mountpoint-s3/src/prefetch/part_stream.rs | 271 ++++++++++ mountpoint-s3/src/prefetch/task.rs | 74 +++ mountpoint-s3/tests/common/mod.rs | 12 +- mountpoint-s3/tests/fs.rs | 21 +- mountpoint-s3/tests/fuse_tests/mod.rs | 116 +++-- .../tests/fuse_tests/prefetch_test.rs | 8 +- mountpoint-s3/tests/fuse_tests/read_test.rs | 12 +- mountpoint-s3/tests/reftests/harness.rs | 9 +- 16 files changed, 841 insertions(+), 545 deletions(-) delete mode 100644 mountpoint-s3/src/prefetch/feed.rs create mode 100644 mountpoint-s3/src/prefetch/part_stream.rs create mode 100644 mountpoint-s3/src/prefetch/task.rs diff --git a/mountpoint-s3/examples/fs_benchmark.rs b/mountpoint-s3/examples/fs_benchmark.rs index 45fa23d51..5a77e766d 100644 --- a/mountpoint-s3/examples/fs_benchmark.rs +++ b/mountpoint-s3/examples/fs_benchmark.rs @@ -1,13 +1,13 @@ use clap::{Arg, ArgAction, Command}; use fuser::{BackgroundSession, MountOption, Session}; use mountpoint_s3::fuse::S3FuseFilesystem; +use mountpoint_s3::prefetch::default_prefetch; use mountpoint_s3::S3FilesystemConfig; use mountpoint_s3_client::config::{EndpointConfig, S3ClientConfig}; use mountpoint_s3_client::S3CrtClient; use mountpoint_s3_crt::common::rust_log_adapter::RustLogAdapter; use std::{ - fs::File, - fs::OpenOptions, + fs::{File, OpenOptions}, io::{self, BufRead, BufReader}, time::Instant, }; @@ -164,8 +164,9 @@ fn mount_file_system(bucket_name: &str, region: &str, throughput_target_gbps: Op bucket_name, mountpoint.to_str().unwrap() ); + let prefetcher = default_prefetch(runtime, Default::default()); let session = Session::new( - S3FuseFilesystem::new(client, runtime, bucket_name, &Default::default(), filesystem_config), + S3FuseFilesystem::new(client, prefetcher, bucket_name, &Default::default(), filesystem_config), mountpoint, &options, ) diff --git a/mountpoint-s3/examples/prefetch_benchmark.rs b/mountpoint-s3/examples/prefetch_benchmark.rs index 38f5b5a51..737eb3d2e 100644 --- a/mountpoint-s3/examples/prefetch_benchmark.rs +++ b/mountpoint-s3/examples/prefetch_benchmark.rs @@ -4,7 +4,7 @@ use std::time::Instant; use clap::{Arg, Command}; use futures::executor::{block_on, ThreadPool}; -use mountpoint_s3::prefetch::Prefetcher; +use mountpoint_s3::prefetch::{default_prefetch, Prefetch, PrefetchResult}; 
use mountpoint_s3_client::config::{EndpointConfig, S3ClientConfig}; use mountpoint_s3_client::types::ETag; use mountpoint_s3_client::S3CrtClient; @@ -80,12 +80,12 @@ fn main() { for i in 0..iterations.unwrap_or(1) { let runtime = ThreadPool::builder().pool_size(1).create().unwrap(); - let manager = Prefetcher::new(client.clone(), runtime, Default::default()); + let manager = default_prefetch(runtime, Default::default()); let received_size = Arc::new(AtomicU64::new(0)); let start = Instant::now(); - let mut request = manager.get(bucket, key, size, ETag::for_tests()); + let mut request = manager.prefetch(client.clone(), bucket, key, size, ETag::for_tests()); block_on(async { loop { let offset = received_size.load(Ordering::SeqCst); diff --git a/mountpoint-s3/src/fs.rs b/mountpoint-s3/src/fs.rs index 31d833ef3..661b8b34a 100644 --- a/mountpoint-s3/src/fs.rs +++ b/mountpoint-s3/src/fs.rs @@ -1,6 +1,5 @@ //! FUSE file system types and operations, not tied to the _fuser_ library bindings. -use futures::task::Spawn; use nix::unistd::{getgid, getuid}; use std::collections::HashMap; use std::ffi::{OsStr, OsString}; @@ -15,7 +14,7 @@ use mountpoint_s3_client::types::ETag; use mountpoint_s3_client::ObjectClient; use crate::inode::{Inode, InodeError, InodeKind, LookedUp, ReaddirHandle, Superblock, WriteHandle}; -use crate::prefetch::{PrefetchGetObject, PrefetchReadError, Prefetcher, PrefetcherConfig}; +use crate::prefetch::{Prefetch, PrefetchReadError, PrefetchResult}; use crate::prefix::Prefix; use crate::sync::atomic::{AtomicI64, AtomicU64, Ordering}; use crate::sync::{Arc, AsyncMutex, AsyncRwLock}; @@ -49,30 +48,54 @@ impl DirHandle { } #[derive(Debug)] -struct FileHandle { +struct FileHandle +where + Client: ObjectClient + Send + Sync + 'static, + Prefetcher: Prefetch, +{ inode: Inode, full_key: String, object_size: u64, - typ: FileHandleType, + typ: FileHandleType, } -#[derive(Debug)] -enum FileHandleType { +enum FileHandleType +where + Client: ObjectClient + Send + Sync + 'static, + Prefetcher: Prefetch, +{ Read { - request: AsyncMutex>>, + request: AsyncMutex>>, etag: ETag, }, Write(AsyncMutex>), } -impl FileHandleType { +impl std::fmt::Debug for FileHandleType +where + Client: ObjectClient + Send + Sync + 'static + std::fmt::Debug, + Prefetcher: Prefetch, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Read { request: _, etag } => f.debug_struct("Read").field("etag", etag).finish(), + Self::Write(arg0) => f.debug_tuple("Write").field(arg0).finish(), + } + } +} + +impl FileHandleType +where + Client: ObjectClient + Send + Sync, + Prefetcher: Prefetch, +{ async fn new_write_handle( lookup: &LookedUp, ino: InodeNo, flags: i32, pid: u32, - fs: &S3Filesystem, - ) -> Result, Error> { + fs: &S3Filesystem, + ) -> Result, Error> { // We can't support O_SYNC writes because they require the data to go to stable storage // at `write` time, but we only commit a PUT at `close` time. 
if flags & (libc::O_SYNC | libc::O_DSYNC) != 0 { @@ -96,7 +119,7 @@ impl FileHandleType { Ok(handle) } - async fn new_read_handle(lookup: &LookedUp) -> Result, Error> { + async fn new_read_handle(lookup: &LookedUp) -> Result, Error> { if !lookup.stat.is_readable { return Err(err!( libc::EACCES, @@ -302,8 +325,6 @@ pub struct S3FilesystemConfig { pub dir_mode: u16, /// File permissions pub file_mode: u16, - /// Prefetcher configuration - pub prefetcher_config: PrefetcherConfig, /// Allow delete pub allow_delete: bool, /// Storage class to be used for new object uploads @@ -322,7 +343,6 @@ impl Default for S3FilesystemConfig { gid, dir_mode: 0o755, file_mode: 0o644, - prefetcher_config: PrefetcherConfig::default(), allow_delete: false, storage_class: None, } @@ -330,31 +350,38 @@ impl Default for S3FilesystemConfig { } #[derive(Debug)] -pub struct S3Filesystem { +pub struct S3Filesystem +where + Client: ObjectClient + Send + Sync + 'static, + Prefetcher: Prefetch, +{ config: S3FilesystemConfig, client: Arc, superblock: Superblock, - prefetcher: Prefetcher, + prefetcher: Prefetcher, uploader: Uploader, bucket: String, #[allow(unused)] prefix: Prefix, next_handle: AtomicU64, dir_handles: AsyncRwLock>>, - file_handles: AsyncRwLock>>>, + file_handles: AsyncRwLock>>>, } -impl S3Filesystem +impl S3Filesystem where Client: ObjectClient + Send + Sync + 'static, - Runtime: Spawn + Send + Sync, + Prefetcher: Prefetch, { - pub fn new(client: Client, runtime: Runtime, bucket: &str, prefix: &Prefix, config: S3FilesystemConfig) -> Self { - let superblock = Superblock::new(bucket, prefix, config.cache_config.clone()); - + pub fn new( + client: Client, + prefetcher: Prefetcher, + bucket: &str, + prefix: &Prefix, + config: S3FilesystemConfig, + ) -> Self { let client = Arc::new(client); - - let prefetcher = Prefetcher::new(client.clone(), runtime, config.prefetcher_config); + let superblock = Superblock::new(bucket, prefix, config.cache_config.clone()); let uploader = Uploader::new(client.clone(), config.storage_class.to_owned()); Self { @@ -429,10 +456,10 @@ pub trait ReadReplier { fn error(self, error: Error) -> Self::Replied; } -impl S3Filesystem +impl S3Filesystem where Client: ObjectClient + Send + Sync + 'static, - Runtime: Spawn + Send + Sync, + Prefetcher: Prefetch, { pub async fn init(&self, config: &mut KernelConfig) -> Result<(), libc::c_int> { let _ = config.add_capabilities(fuser::consts::FUSE_DO_READDIRPLUS); @@ -608,10 +635,13 @@ where }; if request.is_none() { - *request = Some( - self.prefetcher - .get(&self.bucket, &handle.full_key, handle.object_size, file_etag), - ); + *request = Some(self.prefetcher.prefetch( + self.client.clone(), + &self.bucket, + &handle.full_key, + handle.object_size, + file_etag, + )); } match request.as_mut().unwrap().read(offset as u64, size as usize).await { diff --git a/mountpoint-s3/src/fuse.rs b/mountpoint-s3/src/fuse.rs index f65fa6ab5..324e32e75 100644 --- a/mountpoint-s3/src/fuse.rs +++ b/mountpoint-s3/src/fuse.rs @@ -1,7 +1,7 @@ //! Links _fuser_ method calls into Mountpoint's filesystem code in [crate::fs]. 
use futures::executor::block_on; -use futures::task::Spawn; +use mountpoint_s3_client::ObjectClient; use std::ffi::OsStr; use std::path::Path; use std::time::SystemTime; @@ -11,6 +11,7 @@ use tracing::{instrument, Instrument}; use crate::fs::{ self, DirectoryEntry, DirectoryReplier, InodeNo, ReadReplier, S3Filesystem, S3FilesystemConfig, ToErrno, }; +use crate::prefetch::Prefetch; use crate::prefix::Prefix; #[cfg(target_os = "macos")] use fuser::ReplyXTimes; @@ -18,7 +19,6 @@ use fuser::{ Filesystem, KernelConfig, ReplyAttr, ReplyBmap, ReplyCreate, ReplyData, ReplyEmpty, ReplyEntry, ReplyIoctl, ReplyLock, ReplyLseek, ReplyOpen, ReplyWrite, ReplyXattr, Request, TimeOrNow, }; -use mountpoint_s3_client::ObjectClient; pub mod session; @@ -48,26 +48,36 @@ macro_rules! fuse_unsupported { /// This is just a thin wrapper around [S3Filesystem] that implements the actual `fuser` protocol, /// so that we can test our actual filesystem implementation without having actual FUSE in the loop. -pub struct S3FuseFilesystem { - fs: S3Filesystem, +pub struct S3FuseFilesystem +where + Client: ObjectClient + Send + Sync + 'static, + Prefetcher: Prefetch, +{ + fs: S3Filesystem, } -impl S3FuseFilesystem +impl S3FuseFilesystem where Client: ObjectClient + Send + Sync + 'static, - Runtime: Spawn + Send + Sync, + Prefetcher: Prefetch, { - pub fn new(client: Client, runtime: Runtime, bucket: &str, prefix: &Prefix, config: S3FilesystemConfig) -> Self { - let fs = S3Filesystem::new(client, runtime, bucket, prefix, config); + pub fn new( + client: Client, + prefetcher: Prefetcher, + bucket: &str, + prefix: &Prefix, + config: S3FilesystemConfig, + ) -> Self { + let fs = S3Filesystem::new(client, prefetcher, bucket, prefix, config); Self { fs } } } -impl Filesystem for S3FuseFilesystem +impl Filesystem for S3FuseFilesystem where Client: ObjectClient + Send + Sync + 'static, - Runtime: Spawn + Send + Sync, + Prefetcher: Prefetch, { #[instrument(level="warn", skip_all, fields(req=_req.unique()))] fn init(&self, _req: &Request<'_>, config: &mut KernelConfig) -> Result<(), libc::c_int> { diff --git a/mountpoint-s3/src/main.rs b/mountpoint-s3/src/main.rs index 800638d9e..9ec735a6b 100644 --- a/mountpoint-s3/src/main.rs +++ b/mountpoint-s3/src/main.rs @@ -15,6 +15,7 @@ use mountpoint_s3::fuse::S3FuseFilesystem; use mountpoint_s3::instance::InstanceInfo; use mountpoint_s3::logging::{init_logging, LoggingConfig}; use mountpoint_s3::metrics; +use mountpoint_s3::prefetch::{default_prefetch, Prefetch}; use mountpoint_s3::prefix::Prefix; use mountpoint_s3_client::config::{AddressingStyle, EndpointConfig, S3ClientAuthConfig, S3ClientConfig}; use mountpoint_s3_client::error::ObjectClientError; @@ -288,6 +289,35 @@ impl CliArgs { format!("bucket {}", self.bucket_name) } } + + fn fuse_session_config(&self) -> FuseSessionConfig { + let fs_name = String::from("mountpoint-s3"); + let mut options = vec![ + MountOption::DefaultPermissions, + MountOption::FSName(fs_name), + MountOption::NoAtime, + ]; + if self.read_only { + options.push(MountOption::RO); + } + if self.auto_unmount { + options.push(MountOption::AutoUnmount); + } + if self.allow_root { + options.push(MountOption::AllowRoot); + } + if self.allow_other { + options.push(MountOption::AllowOther); + } + + let mount_point = self.mount_point.to_owned(); + let max_threads = self.max_threads as usize; + FuseSessionConfig { + mount_point, + options, + max_threads, + } + } } fn main() -> anyhow::Result<()> { @@ -425,6 +455,7 @@ fn mount(args: CliArgs) -> anyhow::Result { 
validate_mount_point(&args.mount_point)?; let bucket_description = args.bucket_description(); + let fuse_config = args.fuse_session_config(); // Placeholder region will be filled in by [create_client_for_bucket] let endpoint_config = EndpointConfig::new("PLACEHOLDER") @@ -502,6 +533,8 @@ fn mount(args: CliArgs) -> anyhow::Result { filesystem_config.storage_class = args.storage_class; filesystem_config.allow_delete = args.allow_delete; + let prefetcher_config = Default::default(); + #[cfg(feature = "caching")] { use mountpoint_s3::fs::CacheConfig; @@ -517,41 +550,53 @@ fn mount(args: CliArgs) -> anyhow::Result { } } - let fs = S3FuseFilesystem::new(client, runtime, &args.bucket_name, &prefix, filesystem_config); - - let fs_name = String::from("mountpoint-s3"); - let mut options = vec![ - MountOption::DefaultPermissions, - MountOption::FSName(fs_name), - MountOption::NoAtime, - ]; - if args.read_only { - options.push(MountOption::RO); - } - if args.auto_unmount { - options.push(MountOption::AutoUnmount); - } - if args.allow_root { - options.push(MountOption::AllowRoot); - } - if args.allow_other { - options.push(MountOption::AllowOther); - } - - let session = Session::new(fs, &args.mount_point, &options).context("Failed to create FUSE session")?; + let prefetcher = default_prefetch(runtime, prefetcher_config); + create_filesystem( + client, + prefetcher, + &args.bucket_name, + &prefix, + filesystem_config, + fuse_config, + &bucket_description, + ) +} - let max_threads = args.max_threads as usize; - let session = FuseSession::new(session, max_threads).context("Failed to start FUSE session")?; +fn create_filesystem( + client: Client, + prefetcher: Prefetcher, + bucket_name: &str, + prefix: &Prefix, + filesystem_config: S3FilesystemConfig, + fuse_session_config: FuseSessionConfig, + bucket_description: &str, +) -> Result +where + Client: ObjectClient + Send + Sync + 'static, + Prefetcher: Prefetch + Send + Sync + 'static, +{ + let fs = S3FuseFilesystem::new(client, prefetcher, bucket_name, prefix, filesystem_config); + let session = Session::new(fs, &fuse_session_config.mount_point, &fuse_session_config.options) + .context("Failed to create FUSE session")?; + let session = FuseSession::new(session, fuse_session_config.max_threads).context("Failed to start FUSE session")?; tracing::info!( "successfully mounted {} at {}", bucket_description, - args.mount_point.display() + fuse_session_config.mount_point.display() ); Ok(session) } +/// Configuration for a FUSE background session. +#[derive(Debug)] +struct FuseSessionConfig { + pub mount_point: PathBuf, + pub options: Vec, + pub max_threads: usize, +} + /// Create a client for a bucket in the given region and send a ListObjectsV2 request to validate /// that it's accessible. If no region is provided, attempt to infer it by first sending a /// ListObjectsV2 to the default region. diff --git a/mountpoint-s3/src/prefetch.rs b/mountpoint-s3/src/prefetch.rs index 550aa229b..017deb6e2 100644 --- a/mountpoint-s3/src/prefetch.rs +++ b/mountpoint-s3/src/prefetch.rs @@ -7,32 +7,82 @@ //! we increase the size of the GetObject requests up to some maximum. If the reader ever makes a //! non-sequential read, we abandon the prefetching and start again with the minimum request size. 
-mod feed;
 mod part;
 mod part_queue;
+mod part_stream;
 mod seek_window;
+mod task;
 
 use std::collections::VecDeque;
 use std::fmt::Debug;
 use std::time::Duration;
 
-use futures::future::RemoteHandle;
-use futures::task::{Spawn, SpawnExt};
+use async_trait::async_trait;
+use futures::task::Spawn;
 use metrics::{counter, histogram};
 use mountpoint_s3_client::error::{GetObjectError, ObjectClientError};
 use mountpoint_s3_client::types::ETag;
 use mountpoint_s3_client::ObjectClient;
 use thiserror::Error;
-use tracing::{debug_span, error, trace, Instrument};
+use tracing::trace;
 
 use crate::checksums::{ChecksummedBytes, IntegrityError};
-use crate::prefetch::feed::{ClientPartFeed, ObjectPartFeed};
-use crate::prefetch::part::Part;
-use crate::prefetch::part_queue::{unbounded_part_queue, PartQueue};
+use crate::prefetch::part_stream::{ClientPartStream, ObjectPartStream, RequestRange};
 use crate::prefetch::seek_window::SeekWindow;
+use crate::prefetch::task::RequestTask;
 use crate::sync::Arc;
 
-type TaskError<Client> = ObjectClientError<GetObjectError, <Client as ObjectClient>::ClientError>;
+/// Generic interface to handle reading data from an object.
+pub trait Prefetch {
+    type PrefetchResult<Client: ObjectClient + Send + Sync + 'static>: PrefetchResult<Client>;
+
+    /// Start a new prefetch request to the specified object.
+    fn prefetch<Client>(
+        &self,
+        client: Arc<Client>,
+        bucket: &str,
+        key: &str,
+        size: u64,
+        etag: ETag,
+    ) -> Self::PrefetchResult<Client>
+    where
+        Client: ObjectClient + Send + Sync + 'static;
+}
+
+/// Result of a prefetch request. Allows callers to read object data.
+#[async_trait]
+pub trait PrefetchResult<Client: ObjectClient>: Send + Sync {
+    /// Read some bytes from the object. This function will always return exactly `size` bytes,
+    /// except at the end of the object where it will return however many bytes are left (including
+    /// possibly 0 bytes).
+    async fn read(
+        &mut self,
+        offset: u64,
+        length: usize,
+    ) -> Result<ChecksummedBytes, PrefetchReadError<Client::ClientError>>;
+}
+
+#[derive(Debug, Error)]
+pub enum PrefetchReadError<E: std::error::Error> {
+    #[error("get object request failed")]
+    GetRequestFailed(#[source] ObjectClientError<GetObjectError, E>),
+
+    #[error("get request terminated unexpectedly")]
+    GetRequestTerminatedUnexpectedly,
+
+    #[error("integrity check failed")]
+    Integrity(#[from] IntegrityError),
+}
+
+pub type DefaultPrefetcher<Runtime> = Prefetcher<ClientPartStream<Runtime>>;
+
+pub fn default_prefetch<Runtime>(runtime: Runtime, prefetcher_config: PrefetcherConfig) -> DefaultPrefetcher<Runtime>
+where
+    Runtime: Spawn + Send + Sync + 'static,
+{
+    let part_stream = ClientPartStream::new(runtime);
+    Prefetcher::new(part_stream, prefetcher_config)
+}
 
 #[derive(Debug, Clone, Copy)]
 pub struct PrefetcherConfig {
@@ -80,55 +130,63 @@ impl Default for PrefetcherConfig {
 
 /// A [Prefetcher] creates and manages prefetching GetObject requests to objects.
 #[derive(Debug)]
-pub struct Prefetcher<Client, Runtime> {
-    inner: Arc<PrefetcherInner<Client, Runtime>>,
-}
-
-struct PrefetcherInner<Client, Runtime> {
-    part_feed: Arc<dyn ObjectPartFeed<Client> + Send + Sync>,
+pub struct Prefetcher<Stream> {
+    part_stream: Arc<Stream>,
     config: PrefetcherConfig,
-    runtime: Runtime,
 }
 
-impl<Client, Runtime> Debug for PrefetcherInner<Client, Runtime> {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        f.debug_struct("PrefetcherInner").field("config", &self.config).finish()
+impl<Stream> Prefetcher<Stream>
+where
+    Stream: ObjectPartStream,
+{
+    /// Create a new [Prefetcher] from the given [ObjectPartStream] instance.
+    pub fn new(part_stream: Stream, config: PrefetcherConfig) -> Self {
+        let part_stream = Arc::new(part_stream);
+        Self { part_stream, config }
     }
 }
 
-impl<Client, Runtime> Prefetcher<Client, Runtime>
+impl<Stream> Prefetch for Prefetcher<Stream>
 where
-    Client: ObjectClient + Send + Sync + 'static,
-    Runtime: Spawn,
+    Stream: ObjectPartStream + Send + Sync + 'static,
 {
-    /// Create a new [Prefetcher] that will make requests to the given client.
-    pub fn new(client: Arc<Client>, runtime: Runtime, config: PrefetcherConfig) -> Self {
-        let part_feed = Arc::new(ClientPartFeed::new(client));
-        let inner = PrefetcherInner {
-            part_feed,
-            config,
-            runtime,
-        };
+    type PrefetchResult<Client: ObjectClient + Send + Sync + 'static> = PrefetchGetObject<Stream, Client>;
 
-        Self { inner: Arc::new(inner) }
-    }
-
-    /// Start a new get request to the specified object.
-    pub fn get(&self, bucket: &str, key: &str, size: u64, etag: ETag) -> PrefetchGetObject<Client, Runtime> {
-        PrefetchGetObject::new(self.inner.clone(), bucket, key, size, etag)
+    fn prefetch<Client>(
+        &self,
+        client: Arc<Client>,
+        bucket: &str,
+        key: &str,
+        size: u64,
+        etag: ETag,
+    ) -> Self::PrefetchResult<Client>
+    where
+        Client: ObjectClient + Send + Sync + 'static,
+    {
+        PrefetchGetObject::new(
+            client.clone(),
+            self.part_stream.clone(),
+            self.config,
+            bucket,
+            key,
+            size,
+            etag,
+        )
     }
 }
 
 /// A GetObject request that divides the desired range of the object into chunks that it prefetches
 /// in a way that maximizes throughput from S3.
 #[derive(Debug)]
-pub struct PrefetchGetObject<Client: ObjectClient, Runtime> {
-    inner: Arc<PrefetcherInner<Client, Runtime>>,
+pub struct PrefetchGetObject<Stream: ObjectPartStream, Client: ObjectClient> {
+    client: Arc<Client>,
+    part_stream: Arc<Stream>,
+    config: PrefetcherConfig,
     // Invariant: the offset of the first byte in this task's part queue is always
     // self.next_sequential_read_offset.
-    current_task: Option<RequestTask<TaskError<Client>>>,
+    current_task: Option<RequestTask<Client::ClientError>>,
     // Currently we only ever spawn at most one future task (see [spawn_next_request])
-    future_tasks: VecDeque<RequestTask<TaskError<Client>>>,
+    future_tasks: VecDeque<RequestTask<Client::ClientError>>,
     // Invariant: the offset of the last byte in this window is always
     // self.next_sequential_read_offset - 1.
     backward_seek_window: SeekWindow,
@@ -143,37 +201,20 @@ pub struct PrefetchGetObject<Client: ObjectClient, Runtime> {
     etag: ETag,
 }
 
-impl<Client, Runtime> PrefetchGetObject<Client, Runtime>
+#[async_trait]
+impl<Stream, Client> PrefetchResult<Client> for PrefetchGetObject<Stream, Client>
 where
+    Stream: ObjectPartStream + Send + Sync + 'static,
     Client: ObjectClient + Send + Sync + 'static,
-    Runtime: Spawn,
 {
-    /// Create and spawn a new prefetching request for an object
-    fn new(inner: Arc<PrefetcherInner<Client, Runtime>>, bucket: &str, key: &str, size: u64, etag: ETag) -> Self {
-        PrefetchGetObject {
-            inner: inner.clone(),
-            current_task: None,
-            future_tasks: Default::default(),
-            backward_seek_window: SeekWindow::new(inner.config.max_backward_seek_distance as usize),
-            preferred_part_size: 128 * 1024,
-            next_request_size: inner.config.first_request_size,
-            next_sequential_read_offset: 0,
-            next_request_offset: 0,
-            bucket: bucket.to_owned(),
-            key: key.to_owned(),
-            size,
-            etag,
-        }
-    }
-
     /// Read some bytes from the object. This function will always return exactly `size` bytes,
     /// except at the end of the object where it will return however many bytes are left (including
     /// possibly 0 bytes).
- pub async fn read( + async fn read( &mut self, offset: u64, length: usize, - ) -> Result>> { + ) -> Result> { trace!( offset, length, @@ -223,7 +264,7 @@ where trace!(offset, length, "read beyond object size"); break; }; - debug_assert!(current_task.remaining > 0); + debug_assert!(current_task.remaining() > 0); let part = match current_task.read(to_read as usize).await { Err(e) => { @@ -261,11 +302,45 @@ where Ok(response) } +} + +impl PrefetchGetObject +where + Stream: ObjectPartStream, + Client: ObjectClient + Send + Sync + 'static, +{ + /// Create and spawn a new prefetching request for an object + fn new( + client: Arc, + part_stream: Arc, + config: PrefetcherConfig, + bucket: &str, + key: &str, + size: u64, + etag: ETag, + ) -> Self { + PrefetchGetObject { + client, + part_stream, + config, + current_task: None, + future_tasks: Default::default(), + backward_seek_window: SeekWindow::new(config.max_backward_seek_distance as usize), + preferred_part_size: 128 * 1024, + next_sequential_read_offset: 0, + next_request_size: config.first_request_size, + next_request_offset: 0, + bucket: bucket.to_owned(), + key: key.to_owned(), + size, + etag, + } + } /// Runs on every read to prepare and spawn any requests our prefetching logic requires fn prepare_requests(&mut self) { let current_task = self.current_task.as_ref(); - if current_task.map(|task| task.remaining == 0).unwrap_or(true) { + if current_task.map(|task| task.remaining() == 0).unwrap_or(true) { // There's no current task, or the current task is finished. Prepare the next request. if let Some(next_task) = self.future_tasks.pop_front() { self.current_task = Some(next_task); @@ -275,7 +350,7 @@ where } else if current_task .map(|task| { // Don't trigger prefetch if we're in a fake task created by backward streaming - task.is_streaming() && task.remaining <= task.total_size / 2 + task.is_streaming() && task.remaining() <= task.total_size() / 2 }) .unwrap_or(false) && self.future_tasks.is_empty() @@ -289,64 +364,38 @@ where } /// Spawn the next required request - fn spawn_next_request(&mut self) -> Option>> { + fn spawn_next_request(&mut self) -> Option> { let start = self.next_request_offset; - let end = (start + self.next_request_size as u64).min(self.size); - if start >= self.size { return None; } - let size = end - start; - let range = start..end; - - let (part_queue, part_queue_producer) = unbounded_part_queue(); - - trace!(?range, size, "spawning request"); - - let request_task = { - let feed = self.inner.part_feed.clone(); - let preferred_part_size = self.preferred_part_size; - let bucket = self.bucket.to_owned(); - let key = self.key.to_owned(); - let etag = self.etag.clone(); - let span = debug_span!("prefetch", range=?range); - - async move { - feed.get_object_parts(&bucket, &key, range, etag, preferred_part_size, part_queue_producer) - .await - } - .instrument(span) - }; + let range = RequestRange::new(self.size as usize, start, self.next_request_size); + let task = self.part_stream.spawn_get_object_request( + &self.client, + &self.bucket, + &self.key, + self.etag.clone(), + range, + self.preferred_part_size, + ); // [read] will reset these if the reader stops making sequential requests - self.next_request_offset += size; - self.next_request_size = self.get_next_request_size(); - - let task_handle = self.inner.runtime.spawn_with_handle(request_task).unwrap(); + self.next_request_offset += task.total_size() as u64; + self.next_request_size = self.get_next_request_size(task.total_size()); - Some(RequestTask { - task_handle: 
Some(task_handle), - total_size: size as usize, - remaining: size as usize, - start_offset: start, - part_queue, - }) + Some(task) } /// Suggest next request size. /// The next request size is the current request size multiplied by sequential prefetch multiplier. - fn get_next_request_size(&self) -> usize { + fn get_next_request_size(&self, request_size: usize) -> usize { // TODO: this logic doesn't work well right now in the case where part_size < // first_request_size and sequential_prefetch_multiplier = 1. It ends up just repeatedly // shrinking the request size until it reaches 1. But this isn't a configuration we // currently expect to ever run in (part_size will always be >= 5MB for MPU reasons, and a // prefetcher with multiplier 1 is not very good). - let next_request_size = (self.next_request_size * self.inner.config.sequential_prefetch_multiplier) - .min(self.inner.config.max_request_size); - self.inner - .part_feed - .get_aligned_request_size(self.next_request_offset, next_request_size) + (request_size * self.config.sequential_prefetch_multiplier).min(self.config.max_request_size) } /// Reset this prefetch request to a new offset, clearing any existing tasks queued. @@ -354,15 +403,15 @@ where self.current_task = None; self.future_tasks.drain(..); self.backward_seek_window.clear(); - self.next_request_size = self.inner.config.first_request_size; self.next_sequential_read_offset = offset; + self.next_request_size = self.config.first_request_size; self.next_request_offset = offset; } /// Try to seek within the current inflight requests without restarting them. Returns true if /// the seek succeeded, in which case self.next_sequential_read_offset will be updated to the /// new offset. If this returns false, the prefetcher is in an unknown state and must be reset. - async fn try_seek(&mut self, offset: u64) -> Result>> { + async fn try_seek(&mut self, offset: u64) -> Result> { assert_ne!(offset, self.next_sequential_read_offset); trace!(from = self.next_sequential_read_offset, to = offset, "trying to seek"); if offset > self.next_sequential_read_offset { @@ -372,27 +421,27 @@ where } } - async fn try_seek_forward(&mut self, offset: u64) -> Result>> { + async fn try_seek_forward(&mut self, offset: u64) -> Result> { assert!(offset > self.next_sequential_read_offset); let total_seek_distance = offset - self.next_sequential_read_offset; let Some(current_task) = self.current_task.as_mut() else { // Can't seek if there's no requests in flight at all return Ok(false); }; - let future_remaining = self.future_tasks.iter().map(|task| task.remaining).sum::() as u64; + let future_remaining = self.future_tasks.iter().map(|task| task.remaining()).sum::() as u64; if total_seek_distance - >= (current_task.remaining as u64 + future_remaining).min(self.inner.config.max_forward_seek_distance) + >= (current_task.remaining() as u64 + future_remaining).min(self.config.max_forward_seek_distance) { // TODO maybe adjust the next_request_size somehow if we were still within // max_forward_seek_distance, so that strides > first_request_size can still get // prefetched. 
- trace!(?current_task.remaining, ?future_remaining, "seek failed: not enough inflight data"); + trace!(current_task_remaining=?current_task.remaining(), ?future_remaining, "seek failed: not enough inflight data"); return Ok(false); } // Jump ahead to the right request - if total_seek_distance >= current_task.remaining as u64 { - self.next_sequential_read_offset += current_task.remaining as u64; + if total_seek_distance >= current_task.remaining() as u64 { + self.next_sequential_read_offset += current_task.remaining() as u64; self.current_task = None; while let Some(next_request) = self.future_tasks.pop_front() { if next_request.end_offset() > offset { @@ -427,7 +476,7 @@ where Ok(true) } - fn try_seek_backward(&mut self, offset: u64) -> Result>> { + fn try_seek_backward(&mut self, offset: u64) -> Result> { assert!(offset < self.next_sequential_read_offset); let backwards_length_needed = self.next_sequential_read_offset - offset; let Some(parts) = self.backward_seek_window.read_back(backwards_length_needed as usize) else { @@ -437,17 +486,7 @@ where // We're going to create a new fake "request" that contains the parts we read out of the // window. That sounds a bit hacky, but it keeps all the read logic simple rather than // needing separate paths for backwards seeks vs others. - let (part_queue, part_queue_producer) = unbounded_part_queue(); - for part in parts { - part_queue_producer.push(Ok(part)); - } - let request = RequestTask { - task_handle: None, - remaining: backwards_length_needed as usize, - start_offset: offset, - total_size: backwards_length_needed as usize, - part_queue, - }; + let request = RequestTask::from_parts(parts, offset); if let Some(current_task) = self.current_task.take() { self.future_tasks.push_front(current_task); } @@ -460,49 +499,6 @@ where } } -/// A single GetObject request submitted to the S3 client -#[derive(Debug)] -struct RequestTask { - /// Handle on the task/future. The future is cancelled when handle is dropped. This is None if - /// the request is fake (created by seeking backwards in the stream) - task_handle: Option>, - remaining: usize, - start_offset: u64, - total_size: usize, - part_queue: PartQueue, -} - -impl RequestTask { - async fn read(&mut self, length: usize) -> Result> { - let part = self.part_queue.read(length).await?; - debug_assert!(part.len() <= self.remaining); - self.remaining -= part.len(); - Ok(part) - } - - fn end_offset(&self) -> u64 { - self.start_offset + self.total_size as u64 - } - - /// Some requests aren't actually streaming data (they're fake, created by backwards seeks), and - /// shouldn't be counted for prefetcher progress. 
- fn is_streaming(&self) -> bool { - self.task_handle.is_some() - } -} - -#[derive(Debug, Error)] -pub enum PrefetchReadError { - #[error("get request failed")] - GetRequestFailed(#[source] E), - - #[error("get request terminated unexpectedly")] - GetRequestTerminatedUnexpectedly, - - #[error("integrity check failed")] - Integrity(#[from] IntegrityError), -} - #[cfg(test)] mod tests { // It's convenient to write test constants like "1 * 1024 * 1024" for symmetry @@ -510,6 +506,7 @@ mod tests { use super::*; use futures::executor::{block_on, ThreadPool}; + use mountpoint_s3_client::error::{GetObjectError, ObjectClientError}; use mountpoint_s3_client::failure_client::{countdown_failure_client, RequestFailureMap}; use mountpoint_s3_client::mock_client::{ramp_bytes, MockClient, MockClientConfig, MockClientError, MockObject}; use proptest::proptest; @@ -518,9 +515,6 @@ mod tests { use std::collections::HashMap; use test_case::test_case; - const KB: usize = 1024; - const MB: usize = 1024 * 1024; - #[derive(Debug, Arbitrary)] struct TestConfig { #[proptest(strategy = "16usize..1*1024*1024")] @@ -537,18 +531,28 @@ mod tests { max_backward_seek_distance: u64, } - fn run_sequential_read_test(size: u64, read_size: usize, test_config: TestConfig) { + fn default_stream() -> ClientPartStream { + let runtime = ThreadPool::builder().pool_size(1).create().unwrap(); + ClientPartStream::new(runtime) + } + + fn run_sequential_read_test( + part_stream: Stream, + size: u64, + read_size: usize, + test_config: TestConfig, + ) { let config = MockClientConfig { bucket: "test-bucket".to_string(), part_size: test_config.client_part_size, }; - let client = MockClient::new(config); + let client = Arc::new(MockClient::new(config)); let object = MockObject::ramp(0xaa, size as usize, ETag::for_tests()); let etag = object.etag(); client.add_object("hello", object); - let test_config = PrefetcherConfig { + let prefetcher_config = PrefetcherConfig { first_request_size: test_config.first_request_size, max_request_size: test_config.max_request_size, sequential_prefetch_multiplier: test_config.sequential_prefetch_multiplier, @@ -556,10 +560,9 @@ mod tests { max_forward_seek_distance: test_config.max_forward_seek_distance, max_backward_seek_distance: test_config.max_backward_seek_distance, }; - let runtime = ThreadPool::builder().pool_size(1).create().unwrap(); - let prefetcher = Prefetcher::new(Arc::new(client), runtime, test_config); - let mut request = prefetcher.get("test-bucket", "hello", size, etag); + let prefetcher = Prefetcher::new(part_stream, prefetcher_config); + let mut request = prefetcher.prefetch(client, "test-bucket", "hello", size, etag); let mut next_offset = 0; loop { @@ -585,7 +588,7 @@ mod tests { max_forward_seek_distance: 16 * 1024 * 1024, max_backward_seek_distance: 2 * 1024 * 1024, }; - run_sequential_read_test(1024 * 1024 + 111, 1024 * 1024, config); + run_sequential_read_test(default_stream(), 1024 * 1024 + 111, 1024 * 1024, config); } #[test] @@ -598,7 +601,7 @@ mod tests { max_forward_seek_distance: 16 * 1024 * 1024, max_backward_seek_distance: 2 * 1024 * 1024, }; - run_sequential_read_test(16 * 1024 * 1024 + 111, 1024 * 1024, config); + run_sequential_read_test(default_stream(), 16 * 1024 * 1024 + 111, 1024 * 1024, config); } #[test] @@ -611,10 +614,11 @@ mod tests { max_forward_seek_distance: 16 * 1024 * 1024, max_backward_seek_distance: 2 * 1024 * 1024, }; - run_sequential_read_test(256 * 1024 * 1024 + 111, 1024 * 1024, config); + run_sequential_read_test(default_stream(), 256 * 1024 * 1024 + 
111, 1024 * 1024, config); } - fn fail_sequential_read_test( + fn fail_sequential_read_test( + part_stream: Stream, size: u64, read_size: usize, test_config: TestConfig, @@ -632,16 +636,15 @@ mod tests { let client = countdown_failure_client(client, get_failures, HashMap::new(), HashMap::new(), HashMap::new()); - let test_config = PrefetcherConfig { + let prefetcher_config = PrefetcherConfig { first_request_size: test_config.first_request_size, max_request_size: test_config.max_request_size, sequential_prefetch_multiplier: test_config.sequential_prefetch_multiplier, ..Default::default() }; - let runtime = ThreadPool::builder().pool_size(1).create().unwrap(); - let prefetcher = Prefetcher::new(Arc::new(client), runtime, test_config); - let mut request = prefetcher.get("test-bucket", "hello", size, etag); + let prefetcher = Prefetcher::new(part_stream, prefetcher_config); + let mut request = prefetcher.prefetch(Arc::new(client), "test-bucket", "hello", size, etag); let mut next_offset = 0; loop { @@ -682,50 +685,7 @@ mod tests { ))), ); - fail_sequential_read_test(1024 * 1024 + 111, 1024 * 1024, config, get_failures); - } - - #[test_case(256 * KB, 256 * KB, 8, 100 * MB, 8 * MB, 2 * MB; "next request size is smaller than part size")] - #[test_case(7 * MB, 256 * KB, 8, 100 * MB, 8 * MB, 1 * MB; "next request size is remaining bytes in the part")] - #[test_case(9 * MB, (2 * MB) + 11, 11, 100 * MB, 9 * MB, 18 * MB; "next request size is trimmed to part boundaries")] - #[test_case(8 * MB, 2 * MB, 8, 100 * MB, 8 * MB, 16 * MB; "next request size is multiple of the part size")] - #[test_case(8 * MB, 2 * MB, 100, 20 * MB, 8 * MB, 16 * MB; "max request size is trimmed to part boundaries")] - #[test_case(8 * MB, 2 * MB, 100, 24 * MB, 8 * MB, 24 * MB; "max request size is multiple of the part size")] - #[test_case(8 * MB, 2 * MB, 8, 3 * MB, 8 * MB, 3 * MB; "max request size is less than part size")] - fn test_get_next_request_size( - next_request_offset: usize, - current_request_size: usize, - prefetch_multiplier: usize, - max_request_size: usize, - part_size: usize, - expected_size: usize, - ) { - let object_size = 50 * 1024 * 1024; - - let config = MockClientConfig { - bucket: "test-bucket".to_string(), - part_size, - }; - let client = MockClient::new(config); - - let test_config = PrefetcherConfig { - first_request_size: 256 * 1024, - sequential_prefetch_multiplier: prefetch_multiplier, - max_request_size, - read_timeout: Duration::from_secs(60), - max_forward_seek_distance: 16 * 1024 * 1024, - max_backward_seek_distance: 2 * 1024 * 1024, - }; - let runtime = ThreadPool::builder().pool_size(1).create().unwrap(); - let prefetcher = Prefetcher::new(Arc::new(client), runtime, test_config); - let etag = ETag::for_tests(); - - let mut request = prefetcher.get("test-bucket", "hello", object_size, etag); - - request.next_request_offset = next_request_offset as u64; - request.next_request_size = current_request_size; - let next_request_size = request.get_next_request_size(); - assert_eq!(next_request_size, expected_size); + fail_sequential_read_test(default_stream(), 1024 * 1024 + 111, 1024 * 1024, config, get_failures); } proptest! 
{ @@ -735,13 +695,13 @@ mod tests { read_size in 1usize..1 * 1024 * 1024, config: TestConfig, ) { - run_sequential_read_test(size, read_size, config); + run_sequential_read_test(default_stream(), size, read_size, config); } #[test] fn proptest_sequential_read_small_read_size(size in 1u64..1 * 1024 * 1024, read_factor in 1usize..10, config: TestConfig) { let read_size = (size as usize / read_factor).max(1); - run_sequential_read_test(size, read_size, config); + run_sequential_read_test(default_stream(), size, read_size, config); } } @@ -757,21 +717,26 @@ mod tests { max_forward_seek_distance: 1, max_backward_seek_distance: 18668, }; - run_sequential_read_test(object_size, read_size, config); + run_sequential_read_test(default_stream(), object_size, read_size, config); } - fn run_random_read_test(object_size: u64, reads: Vec<(u64, usize)>, test_config: TestConfig) { + fn run_random_read_test( + part_stream: Stream, + object_size: u64, + reads: Vec<(u64, usize)>, + test_config: TestConfig, + ) { let config = MockClientConfig { bucket: "test-bucket".to_string(), part_size: test_config.client_part_size, }; - let client = MockClient::new(config); + let client = Arc::new(MockClient::new(config)); let object = MockObject::ramp(0xaa, object_size as usize, ETag::for_tests()); let etag = object.etag(); client.add_object("hello", object); - let test_config = PrefetcherConfig { + let prefetcher_config = PrefetcherConfig { first_request_size: test_config.first_request_size, max_request_size: test_config.max_request_size, sequential_prefetch_multiplier: test_config.sequential_prefetch_multiplier, @@ -779,10 +744,9 @@ mod tests { max_backward_seek_distance: test_config.max_backward_seek_distance, ..Default::default() }; - let runtime = ThreadPool::builder().pool_size(1).create().unwrap(); - let prefetcher = Prefetcher::new(Arc::new(client), runtime, test_config); - let mut request = prefetcher.get("test-bucket", "hello", object_size, etag); + let prefetcher = Prefetcher::new(part_stream, prefetcher_config); + let mut request = prefetcher.prefetch(client, "test-bucket", "hello", object_size, etag); for (offset, length) in reads { assert!(offset < object_size); @@ -826,7 +790,7 @@ mod tests { config: TestConfig, ) { let (object_size, reads) = reads; - run_random_read_test(object_size, reads, config); + run_random_read_test(default_stream(), object_size, reads, config); } } @@ -842,7 +806,7 @@ mod tests { max_forward_seek_distance: 16 * 1024 * 1024, max_backward_seek_distance: 2 * 1024 * 1024, }; - run_random_read_test(object_size, reads, config); + run_random_read_test(default_stream(), object_size, reads, config); } #[test] @@ -857,7 +821,7 @@ mod tests { max_forward_seek_distance: 16 * 1024 * 1024, max_backward_seek_distance: 2 * 1024 * 1024, }; - run_random_read_test(object_size, reads, config); + run_random_read_test(default_stream(), object_size, reads, config); } #[test] @@ -872,7 +836,7 @@ mod tests { max_forward_seek_distance: 2260662, max_backward_seek_distance: 2369799, }; - run_random_read_test(object_size, reads, config); + run_random_read_test(default_stream(), object_size, reads, config); } #[test] @@ -887,7 +851,7 @@ mod tests { max_forward_seek_distance: 2810651, max_backward_seek_distance: 3531090, }; - run_random_read_test(object_size, reads, config); + run_random_read_test(default_stream(), object_size, reads, config); } #[test_case(0, 25; "no first read")] @@ -902,22 +866,23 @@ mod tests { bucket: "test-bucket".to_string(), part_size, }; - let client = MockClient::new(config); + let 
client = Arc::new(MockClient::new(config)); let object = MockObject::ramp(0xaa, OBJECT_SIZE, ETag::for_tests()); let etag = object.etag(); client.add_object("hello", object); - let test_config = PrefetcherConfig { + let prefetcher_config = PrefetcherConfig { first_request_size: FIRST_REQUEST_SIZE, ..Default::default() }; - let runtime = ThreadPool::builder().pool_size(1).create().unwrap(); - let prefetcher = Prefetcher::new(Arc::new(client), runtime, test_config); + + let prefetcher = Prefetcher::new(default_stream(), prefetcher_config); // Try every possible seek from first_read_size for offset in first_read_size + 1..OBJECT_SIZE { - let mut request = prefetcher.get("test-bucket", "hello", OBJECT_SIZE as u64, etag.clone()); + let mut request = + prefetcher.prefetch(client.clone(), "test-bucket", "hello", OBJECT_SIZE as u64, etag.clone()); if first_read_size > 0 { let _first_read = block_on(request.read(0, first_read_size)).unwrap(); } @@ -939,22 +904,22 @@ mod tests { bucket: "test-bucket".to_string(), part_size, }; - let client = MockClient::new(config); + let client = Arc::new(MockClient::new(config)); let object = MockObject::ramp(0xaa, OBJECT_SIZE, ETag::for_tests()); let etag = object.etag(); client.add_object("hello", object); - let test_config = PrefetcherConfig { + let prefetcher_config = PrefetcherConfig { first_request_size: FIRST_REQUEST_SIZE, ..Default::default() }; - let runtime = ThreadPool::builder().pool_size(1).create().unwrap(); - let prefetcher = Prefetcher::new(Arc::new(client), runtime, test_config); + let prefetcher = Prefetcher::new(default_stream(), prefetcher_config); // Try every possible seek from first_read_size for offset in 0..first_read_size { - let mut request = prefetcher.get("test-bucket", "hello", OBJECT_SIZE as u64, etag.clone()); + let mut request = + prefetcher.prefetch(client.clone(), "test-bucket", "hello", OBJECT_SIZE as u64, etag.clone()); if first_read_size > 0 { let _first_read = block_on(request.read(0, first_read_size)).unwrap(); } @@ -968,7 +933,7 @@ mod tests { #[cfg(feature = "shuttle")] mod shuttle_tests { use super::*; - use futures::task::{FutureObj, SpawnError}; + use futures::task::{FutureObj, Spawn, SpawnError}; use shuttle::future::block_on; use shuttle::rand::Rng; use shuttle::{check_pct, check_random}; @@ -995,13 +960,13 @@ mod tests { bucket: "test-bucket".to_string(), part_size, }; - let client = MockClient::new(config); + let client = Arc::new(MockClient::new(config)); let object = MockObject::ramp(0xaa, object_size as usize, ETag::for_tests()); let file_etag = object.etag(); client.add_object("hello", object); - let test_config = PrefetcherConfig { + let prefetcher_config = PrefetcherConfig { first_request_size, max_request_size, sequential_prefetch_multiplier, @@ -1010,9 +975,8 @@ mod tests { ..Default::default() }; - let prefetcher = Prefetcher::new(Arc::new(client), ShuttleRuntime, test_config); - - let mut request = prefetcher.get("test-bucket", "hello", object_size, file_etag); + let prefetcher = Prefetcher::new(ClientPartStream::new(ShuttleRuntime), prefetcher_config); + let mut request = prefetcher.prefetch(client, "test-bucket", "hello", object_size, file_etag); let mut next_offset = 0; loop { @@ -1052,13 +1016,13 @@ mod tests { bucket: "test-bucket".to_string(), part_size, }; - let client = MockClient::new(config); + let client = Arc::new(MockClient::new(config)); let object = MockObject::ramp(0xaa, object_size as usize, ETag::for_tests()); let file_etag = object.etag(); client.add_object("hello", object); - let 
test_config = PrefetcherConfig { + let prefetcher_config = PrefetcherConfig { first_request_size, max_request_size, sequential_prefetch_multiplier, @@ -1067,9 +1031,8 @@ mod tests { ..Default::default() }; - let prefetcher = Prefetcher::new(Arc::new(client), ShuttleRuntime, test_config); - - let mut request = prefetcher.get("test-bucket", "hello", object_size, file_etag); + let prefetcher = Prefetcher::new(ClientPartStream::new(ShuttleRuntime), prefetcher_config); + let mut request = prefetcher.prefetch(client, "test-bucket", "hello", object_size, file_etag); let num_reads = rng.gen_range(10usize..50); for _ in 0..num_reads { diff --git a/mountpoint-s3/src/prefetch/feed.rs b/mountpoint-s3/src/prefetch/feed.rs deleted file mode 100644 index b6067cafe..000000000 --- a/mountpoint-s3/src/prefetch/feed.rs +++ /dev/null @@ -1,128 +0,0 @@ -use std::{fmt::Debug, ops::Range, sync::Arc}; - -use async_trait::async_trait; -use bytes::Bytes; -use futures::{pin_mut, StreamExt}; -use mountpoint_s3_client::{ - error::{GetObjectError, ObjectClientError}, - types::ETag, - ObjectClient, -}; -use mountpoint_s3_crt::checksums::crc32c; -use tracing::{error, trace}; - -use crate::checksums::ChecksummedBytes; -use crate::prefetch::{part::Part, part_queue::PartQueueProducer}; - -/// A generic interface to retrieve data from objects in a S3-like store. -#[async_trait] -pub trait ObjectPartFeed { - /// Get the content of an object in fixed size parts. The parts are pushed to the provided `part_sink` - /// and are guaranteed to be contiguous and in the correct order. Callers need to specify a preferred - /// size for the parts, but implementations are allowed to ignore it. - async fn get_object_parts( - &self, - bucket: &str, - key: &str, - range: Range, - if_match: ETag, - preferred_part_size: usize, - part_sink: PartQueueProducer>, - ); - - /// Adjust the size of a request to align to optimal part boundaries for this client. - fn get_aligned_request_size(&self, offset: u64, preferred_size: usize) -> usize; -} - -/// [ObjectPartFeed] implementation which delegates retrieving object data to a [Client]. -#[derive(Debug)] -pub struct ClientPartFeed { - client: Arc, -} - -impl ClientPartFeed { - pub fn new(client: Arc) -> Self { - Self { client } - } -} - -#[async_trait] -impl ObjectPartFeed for ClientPartFeed -where - Client: ObjectClient + Send + Sync + 'static, -{ - async fn get_object_parts( - &self, - bucket: &str, - key: &str, - range: Range, - if_match: ETag, - preferred_part_size: usize, - part_queue_producer: PartQueueProducer>, - ) { - assert!(preferred_part_size > 0); - let get_object_result = match self.client.get_object(bucket, key, Some(range), Some(if_match)).await { - Ok(get_object_result) => get_object_result, - Err(e) => { - error!(error=?e, "GetObject request failed"); - part_queue_producer.push(Err(e)); - return; - } - }; - - pin_mut!(get_object_result); - loop { - match get_object_result.next().await { - Some(Ok((offset, body))) => { - trace!(offset, length = body.len(), "received GetObject part"); - // pre-split the body into multiple parts as suggested by preferred part size - // in order to avoid validating checksum on large parts at read. - let mut body: Bytes = body.into(); - let mut curr_offset = offset; - loop { - let chunk_size = preferred_part_size.min(body.len()); - if chunk_size == 0 { - break; - } - let chunk = body.split_to(chunk_size); - // S3 doesn't provide checksum for us if the request range is not aligned to - // object part boundaries, so we're computing our own checksum here. 
- let checksum = crc32c::checksum(&chunk); - let checksum_bytes = ChecksummedBytes::new(chunk, checksum); - let part = Part::new(key, curr_offset, checksum_bytes); - curr_offset += part.len() as u64; - part_queue_producer.push(Ok(part)); - } - } - Some(Err(e)) => { - error!(error=?e, "GetObject body part failed"); - part_queue_producer.push(Err(e)); - break; - } - None => break, - } - } - trace!("request finished"); - } - - fn get_aligned_request_size(&self, offset: u64, preferred_length: usize) -> usize { - // If the request size is bigger than a part size we will try to align it to part boundaries. - let part_alignment = self.client.part_size().unwrap_or(8 * 1024 * 1024); - let offset_in_part = (offset % part_alignment as u64) as usize; - if offset_in_part != 0 { - // if the offset is not at the start of the part we will drain all the bytes from that part first - let remaining_in_part = part_alignment - offset_in_part; - preferred_length.min(remaining_in_part) - } else { - // if the request size is smaller than the part size, just return that value - if preferred_length < part_alignment { - preferred_length - } else { - // if it exceeds part boundaries, trim it to the part boundaries - let request_boundary = offset + preferred_length as u64; - let remainder = (request_boundary % part_alignment as u64) as usize; - preferred_length - remainder - } - } - } -} diff --git a/mountpoint-s3/src/prefetch/part_queue.rs b/mountpoint-s3/src/prefetch/part_queue.rs index 07a353bab..386b92c35 100644 --- a/mountpoint-s3/src/prefetch/part_queue.rs +++ b/mountpoint-s3/src/prefetch/part_queue.rs @@ -11,20 +11,20 @@ use crate::sync::AsyncMutex; /// A queue of [Part]s where the first part can be partially read from if the reader doesn't want /// the entire part in one shot. #[derive(Debug)] -pub struct PartQueue { +pub struct PartQueue { current_part: AsyncMutex>, - receiver: Receiver>, + receiver: Receiver>>, failed: AtomicBool, } /// Producer side of the queue of [Part]s. #[derive(Debug)] -pub struct PartQueueProducer { - sender: Sender>, +pub struct PartQueueProducer { + sender: Sender>>, } /// Creates an unbounded [PartQueue] and its related [PartQueueProducer]. 
-pub fn unbounded_part_queue() -> (PartQueue, PartQueueProducer) { +pub fn unbounded_part_queue() -> (PartQueue, PartQueueProducer) { let (sender, receiver) = unbounded(); let part_queue = PartQueue { current_part: AsyncMutex::new(None), @@ -55,14 +55,14 @@ impl PartQueue { } else { // Do `try_recv` first so we can track whether the read is starved or not if let Ok(part) = self.receiver.try_recv() { - part.map_err(|e| PrefetchReadError::GetRequestFailed(e)) + part } else { let start = Instant::now(); let part = self.receiver.recv().await; metrics::histogram!("prefetch.part_queue_starved_us", start.elapsed().as_micros() as f64); match part { Err(RecvError) => Err(PrefetchReadError::GetRequestTerminatedUnexpectedly), - Ok(part) => part.map_err(|e| PrefetchReadError::GetRequestFailed(e)), + Ok(part) => part, } } }; @@ -88,7 +88,7 @@ impl PartQueue { impl PartQueueProducer { /// Push a new [Part] onto the back of the queue - pub fn push(&self, part: Result) { + pub fn push(&self, part: Result>) { // Unbounded channel will never actually block let send_result = self.sender.send_blocking(part); if send_result.is_err() { diff --git a/mountpoint-s3/src/prefetch/part_stream.rs b/mountpoint-s3/src/prefetch/part_stream.rs new file mode 100644 index 000000000..a7205022a --- /dev/null +++ b/mountpoint-s3/src/prefetch/part_stream.rs @@ -0,0 +1,271 @@ +use std::{fmt::Debug, ops::Range}; + +use bytes::Bytes; +use futures::task::SpawnExt; +use futures::{pin_mut, task::Spawn, StreamExt}; +use mountpoint_s3_client::{types::ETag, ObjectClient}; +use mountpoint_s3_crt::checksums::crc32c; +use tracing::{debug_span, error, trace, Instrument}; + +use crate::checksums::ChecksummedBytes; +use crate::prefetch::part::Part; +use crate::prefetch::part_queue::unbounded_part_queue; +use crate::prefetch::task::RequestTask; +use crate::prefetch::PrefetchReadError; + +/// A generic interface to retrieve data from objects in a S3-like store. +pub trait ObjectPartStream { + /// Spawns a request to get the content of an object. The object data will be retrieved in fixed size + /// parts and can then be consumed using [RequestTask::read]. Callers need to specify a preferred + /// size for the parts, but implementations are allowed to ignore it. + fn spawn_get_object_request( + &self, + client: &Client, + bucket: &str, + key: &str, + if_match: ETag, + range: RequestRange, + preferred_part_size: usize, + ) -> RequestTask + where + Client: ObjectClient + Clone + Send + Sync + 'static; +} + +/// The range of a [ObjectPartStream::spawn_get_object_request] request. +/// Includes the total size of the object. +#[derive(Clone, Copy)] +pub struct RequestRange { + object_size: usize, + offset: u64, + size: usize, +} + +impl RequestRange { + pub fn new(object_size: usize, offset: u64, size: usize) -> Self { + let size = size.min(object_size.saturating_sub(offset as usize)); + Self { + object_size, + offset, + size, + } + } + + pub fn len(&self) -> usize { + self.size + } + + pub fn is_empty(&self) -> bool { + self.size == 0 + } + + pub fn object_size(&self) -> usize { + self.object_size + } + + pub fn start(&self) -> u64 { + self.offset + } + + pub fn end(&self) -> u64 { + self.offset + self.size as u64 + } + + /// Trim the start of this range at the given `start_offset`. + /// Note `start_offset` is clamped to the original range. 
+ pub fn trim_start(&self, start_offset: u64) -> Self { + let end = self.end(); + let offset = start_offset.clamp(self.offset, end); + let size = end.saturating_sub(offset) as usize; + Self { + object_size: self.object_size, + offset, + size, + } + } + + /// Trim the end of this range at the given `end_offset`. + /// Note `end_offset` is clamped to the original range. + pub fn trim_end(&self, end_offset: u64) -> Self { + let end = end_offset.clamp(self.offset, self.end()); + let size = end.saturating_sub(self.offset) as usize; + Self { + object_size: self.object_size, + offset: self.offset, + size, + } + } +} + +impl From for Range { + fn from(val: RequestRange) -> Self { + val.start()..val.end() + } +} + +impl Debug for RequestRange { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}..{} out of {}", self.start(), self.end(), self.object_size()) + } +} + +/// [ObjectPartStream] implementation which delegates retrieving object data to a [Client]. +#[derive(Debug)] +pub struct ClientPartStream { + runtime: Runtime, +} + +impl ClientPartStream +where + Runtime: Spawn, +{ + pub fn new(runtime: Runtime) -> Self { + Self { runtime } + } +} + +impl ObjectPartStream for ClientPartStream +where + Runtime: Spawn, +{ + fn spawn_get_object_request( + &self, + client: &Client, + bucket: &str, + key: &str, + if_match: ETag, + range: RequestRange, + preferred_part_size: usize, + ) -> RequestTask + where + Client: ObjectClient + Clone + Send + Sync + 'static, + { + assert!(preferred_part_size > 0); + let request_range = get_aligned_request_range(range, client.part_size().unwrap_or(8 * 1024 * 1024)); + let start = request_range.start(); + let size = request_range.len(); + + let (part_queue, part_queue_producer) = unbounded_part_queue(); + trace!(range=?request_range, "spawning request"); + + let request_task = { + let client = client.clone(); + let bucket = bucket.to_owned(); + let key = key.to_owned(); + let span = debug_span!("prefetch", range=?request_range); + + async move { + let get_object_result = match client + .get_object(&bucket, &key, Some(request_range.into()), Some(if_match)) + .await + { + Ok(get_object_result) => get_object_result, + Err(e) => { + error!(error=?e, "GetObject request failed"); + part_queue_producer.push(Err(PrefetchReadError::GetRequestFailed(e))); + return; + } + }; + + pin_mut!(get_object_result); + loop { + match get_object_result.next().await { + Some(Ok((offset, body))) => { + trace!(offset, length = body.len(), "received GetObject part"); + // pre-split the body into multiple parts as suggested by preferred part size + // in order to avoid validating checksum on large parts at read. + let mut body: Bytes = body.into(); + let mut curr_offset = offset; + loop { + let chunk_size = preferred_part_size.min(body.len()); + if chunk_size == 0 { + break; + } + let chunk = body.split_to(chunk_size); + // S3 doesn't provide checksum for us if the request range is not aligned to + // object part boundaries, so we're computing our own checksum here. 
+                                let checksum = crc32c::checksum(&chunk);
+                                let checksum_bytes = ChecksummedBytes::new(chunk, checksum);
+                                let part = Part::new(&key, curr_offset, checksum_bytes);
+                                curr_offset += part.len() as u64;
+                                part_queue_producer.push(Ok(part));
+                            }
+                        }
+                        Some(Err(e)) => {
+                            error!(error=?e, "GetObject body part failed");
+                            part_queue_producer.push(Err(PrefetchReadError::GetRequestFailed(e)));
+                            break;
+                        }
+                        None => break,
+                    }
+                }
+                trace!("request finished");
+            }
+            .instrument(span)
+        };
+
+        let task_handle = self.runtime.spawn_with_handle(request_task).unwrap();
+
+        RequestTask::from_handle(task_handle, size, start, part_queue)
+    }
+}
+
+fn get_aligned_request_range(range: RequestRange, part_alignment: usize) -> RequestRange {
+    let object_size = range.object_size();
+    let offset = range.start();
+    let preferred_length = range.len();
+
+    // If the request size is bigger than a part size we will try to align it to part boundaries.
+    let offset_in_part = (offset % part_alignment as u64) as usize;
+    let size = if offset_in_part != 0 {
+        // if the offset is not at the start of the part we will drain all the bytes from that part first
+        let remaining_in_part = part_alignment - offset_in_part;
+        preferred_length.min(remaining_in_part)
+    } else {
+        // if the request size is smaller than the part size, just return that value
+        if preferred_length < part_alignment {
+            preferred_length
+        } else {
+            // if it exceeds part boundaries, trim it to the part boundaries
+            let request_boundary = offset + preferred_length as u64;
+            let remainder = (request_boundary % part_alignment as u64) as usize;
+            preferred_length - remainder
+        }
+    };
+    RequestRange::new(object_size, offset, size)
+}
+
+#[cfg(test)]
+mod tests {
+    // It's convenient to write test constants like "1 * 1024 * 1024" for symmetry
+    #![allow(clippy::identity_op)]
+
+    use super::*;
+
+    use test_case::test_case;
+
+    const KB: usize = 1024;
+    const MB: usize = 1024 * 1024;
+
+    #[test_case(256 * KB, 256 * KB, 8, 100 * MB, 8 * MB, 2 * MB; "next request size is smaller than part size")]
+    #[test_case(7 * MB, 256 * KB, 8, 100 * MB, 8 * MB, 1 * MB; "next request size is remaining bytes in the part")]
+    #[test_case(9 * MB, (2 * MB) + 11, 11, 100 * MB, 9 * MB, 18 * MB; "next request size is trimmed to part boundaries")]
+    #[test_case(8 * MB, 2 * MB, 8, 100 * MB, 8 * MB, 16 * MB; "next request size is multiple of the part size")]
+    #[test_case(8 * MB, 2 * MB, 100, 20 * MB, 8 * MB, 16 * MB; "max request size is trimmed to part boundaries")]
+    #[test_case(8 * MB, 2 * MB, 100, 24 * MB, 8 * MB, 24 * MB; "max request size is multiple of the part size")]
+    #[test_case(8 * MB, 2 * MB, 8, 3 * MB, 8 * MB, 3 * MB; "max request size is less than part size")]
+    fn test_get_aligned_request_range(
+        next_request_offset: usize,
+        current_request_size: usize,
+        prefetch_multiplier: usize,
+        max_request_size: usize,
+        part_size: usize,
+        expected_size: usize,
+    ) {
+        let object_size = 50 * 1024 * 1024;
+        let request_size = (current_request_size * prefetch_multiplier).min(max_request_size);
+        let range = RequestRange::new(object_size, next_request_offset as u64, request_size);
+
+        let aligned_range = get_aligned_request_range(range, part_size);
+        assert_eq!(aligned_range.len(), expected_size);
+    }
+}
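
A worked example of the alignment logic above (hypothetical test code, not part of this patch; it assumes only the `RequestRange` and `get_aligned_request_range` definitions just introduced):

    #[test]
    fn alignment_worked_example() {
        const MB: usize = 1024 * 1024;
        // Aligned offset: a 20 MiB request starting at 16 MiB is trimmed to
        // 16 MiB so that it ends on an 8 MiB part boundary (16 + 16 = 32).
        let range = RequestRange::new(50 * MB, (16 * MB) as u64, 20 * MB);
        assert_eq!(get_aligned_request_range(range, 8 * MB).len(), 16 * MB);

        // Unaligned offset: the request is first trimmed to the bytes left in
        // the current part (8 - 4 = 4 MiB), so the next request starts aligned.
        let range = RequestRange::new(50 * MB, (4 * MB) as u64, 20 * MB);
        assert_eq!(get_aligned_request_range(range, 8 * MB).len(), 4 * MB);
    }
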
diff --git a/mountpoint-s3/src/prefetch/task.rs b/mountpoint-s3/src/prefetch/task.rs
new file mode 100644
index 000000000..27c6fde7f
--- /dev/null
+++ b/mountpoint-s3/src/prefetch/task.rs
@@ -0,0 +1,74 @@
+use futures::future::RemoteHandle;
+
+use crate::prefetch::part::Part;
+use crate::prefetch::part_queue::{unbounded_part_queue, PartQueue};
+use crate::prefetch::PrefetchReadError;
+
+/// A single GetObject request submitted to the S3 client
+#[derive(Debug)]
+pub struct RequestTask<E: std::error::Error> {
+    /// Handle on the task/future. The future is cancelled when handle is dropped. This is None if
+    /// the request is fake (created by seeking backwards in the stream)
+    task_handle: Option<RemoteHandle<()>>,
+    remaining: usize,
+    start_offset: u64,
+    total_size: usize,
+    part_queue: PartQueue<E>,
+}
+
+impl<E: std::error::Error + Send + Sync> RequestTask<E> {
+    pub fn from_handle(task_handle: RemoteHandle<()>, size: usize, offset: u64, part_queue: PartQueue<E>) -> Self {
+        Self {
+            task_handle: Some(task_handle),
+            remaining: size,
+            start_offset: offset,
+            total_size: size,
+            part_queue,
+        }
+    }
+
+    pub fn from_parts(parts: impl IntoIterator<Item = Part>, offset: u64) -> Self {
+        let mut size = 0;
+        let (part_queue, part_queue_producer) = unbounded_part_queue();
+        for part in parts {
+            size += part.len();
+            part_queue_producer.push(Ok(part));
+        }
+        Self {
+            task_handle: None,
+            remaining: size,
+            start_offset: offset,
+            total_size: size,
+            part_queue,
+        }
+    }
+
+    pub async fn read(&mut self, length: usize) -> Result<Part, PrefetchReadError<E>> {
+        let part = self.part_queue.read(length).await?;
+        debug_assert!(part.len() <= self.remaining);
+        self.remaining -= part.len();
+        Ok(part)
+    }
+
+    pub fn start_offset(&self) -> u64 {
+        self.start_offset
+    }
+
+    pub fn end_offset(&self) -> u64 {
+        self.start_offset + self.total_size as u64
+    }
+
+    pub fn total_size(&self) -> usize {
+        self.total_size
+    }
+
+    pub fn remaining(&self) -> usize {
+        self.remaining
+    }
+
+    /// Some requests aren't actually streaming data (they're fake, created by backwards seeks), and
+    /// shouldn't be counted for prefetcher progress.
+    pub fn is_streaming(&self) -> bool {
+        self.task_handle.is_some()
+    }
+}
diff --git a/mountpoint-s3/tests/common/mod.rs b/mountpoint-s3/tests/common/mod.rs
index 3a78b7b93..a33158c1e 100644
--- a/mountpoint-s3/tests/common/mod.rs
+++ b/mountpoint-s3/tests/common/mod.rs
@@ -1,6 +1,7 @@
 use fuser::{FileAttr, FileType};
 use futures::executor::ThreadPool;
 use mountpoint_s3::fs::{self, DirectoryEntry, DirectoryReplier, ReadReplier, ToErrno};
+use mountpoint_s3::prefetch::{default_prefetch, DefaultPrefetcher};
 use mountpoint_s3::prefix::Prefix;
 use mountpoint_s3::{S3Filesystem, S3FilesystemConfig};
 use mountpoint_s3_client::mock_client::{MockClient, MockClientConfig};
@@ -9,18 +10,20 @@ use mountpoint_s3_crt::common::rust_log_adapter::RustLogAdapter;
 use std::collections::VecDeque;
 use std::sync::Arc;
 
+pub type TestS3Filesystem<Client> = S3Filesystem<Client, DefaultPrefetcher<ThreadPool>>;
+
 pub fn make_test_filesystem(
     bucket: &str,
     prefix: &Prefix,
     config: S3FilesystemConfig,
-) -> (Arc<MockClient>, S3Filesystem<Arc<MockClient>, ThreadPool>) {
+) -> (Arc<MockClient>, TestS3Filesystem<Arc<MockClient>>) {
     let client_config = MockClientConfig {
         bucket: bucket.to_string(),
         part_size: 1024 * 1024,
     };
     let client = Arc::new(MockClient::new(client_config));
-    let fs = make_test_filesystem_with_client(Arc::clone(&client), bucket, prefix, config);
+    let fs = make_test_filesystem_with_client(client.clone(), bucket, prefix, config);
     (client, fs)
 }
 
@@ -29,12 +32,13 @@ pub fn make_test_filesystem_with_client<Client>(
     bucket: &str,
     prefix: &Prefix,
     config: S3FilesystemConfig,
-) -> S3Filesystem<Client, ThreadPool>
+) -> TestS3Filesystem<Client>
 where
     Client: ObjectClient + Send + Sync + 'static,
 {
     let runtime = ThreadPool::builder().pool_size(1).create().unwrap();
-    S3Filesystem::new(client, runtime, bucket, prefix, config)
+    let prefetcher = default_prefetch(runtime, Default::default());
+    S3Filesystem::new(client, prefetcher, bucket, prefix, config)
 }
 
 #[track_caller]
diff --git a/mountpoint-s3/tests/fs.rs b/mountpoint-s3/tests/fs.rs
index 8345d756e..e4a48f7ef 100644
--- a/mountpoint-s3/tests/fs.rs
+++ b/mountpoint-s3/tests/fs.rs
@@ -704,7 +704,12 @@ async fn test_upload_aborted_on_write_failure() {
         Default::default(),
         put_failures,
     );
-    let fs = make_test_filesystem_with_client(failure_client, BUCKET_NAME, &Default::default(), Default::default());
+    let fs = make_test_filesystem_with_client(
+        Arc::new(failure_client),
+        BUCKET_NAME,
+        &Default::default(),
+        Default::default(),
+    );
 
     let mode = libc::S_IFREG | libc::S_IRWXU; // regular file + 0700 permissions
     let dentry = fs.mknod(FUSE_ROOT_INODE, FILE_NAME.as_ref(), mode, 0, 0).await.unwrap();
@@ -775,7 +780,12 @@ async fn test_upload_aborted_on_fsync_failure() {
         Default::default(),
         put_failures,
     );
-    let fs = make_test_filesystem_with_client(failure_client, BUCKET_NAME, &Default::default(), Default::default());
+    let fs = make_test_filesystem_with_client(
+        Arc::new(failure_client),
+        BUCKET_NAME,
+        &Default::default(),
+        Default::default(),
+    );
 
     let mode = libc::S_IFREG | libc::S_IRWXU; // regular file + 0700 permissions
     let dentry = fs.mknod(FUSE_ROOT_INODE, FILE_NAME.as_ref(), mode, 0, 0).await.unwrap();
@@ -831,7 +841,12 @@ async fn test_upload_aborted_on_release_failure() {
         Default::default(),
         put_failures,
     );
-    let fs = make_test_filesystem_with_client(failure_client, BUCKET_NAME, &Default::default(), Default::default());
+    let fs = make_test_filesystem_with_client(
+        Arc::new(failure_client),
+        BUCKET_NAME,
+        &Default::default(),
+        Default::default(),
+    );
 
     let mode = libc::S_IFREG | libc::S_IRWXU; // regular file + 0700 permissions
     let dentry = fs.mknod(FUSE_ROOT_INODE, FILE_NAME.as_ref(), mode, 0, 0).await.unwrap();
diff --git a/mountpoint-s3/tests/fuse_tests/mod.rs b/mountpoint-s3/tests/fuse_tests/mod.rs
index e0d6f69c7..aab0860ff 100644
--- a/mountpoint-s3/tests/fuse_tests/mod.rs
+++ b/mountpoint-s3/tests/fuse_tests/mod.rs
@@ -14,15 +14,19 @@ mod write_test;
 
 use std::ffi::OsStr;
 use std::fs::ReadDir;
+use std::path::Path;
+use std::sync::Arc;
 
 use aws_sdk_s3::primitives::ByteStream;
 use aws_sdk_sts::config::Region;
 use fuser::{BackgroundSession, MountOption, Session};
 use futures::Future;
 use mountpoint_s3::fuse::S3FuseFilesystem;
+use mountpoint_s3::prefetch::{Prefetch, PrefetcherConfig};
 use mountpoint_s3::prefix::Prefix;
 use mountpoint_s3::S3FilesystemConfig;
 use mountpoint_s3_client::types::PutObjectParams;
+use mountpoint_s3_client::ObjectClient;
 use rand::RngCore;
 use rand_chacha::rand_core::OsRng;
 use tempfile::TempDir;
@@ -59,6 +63,7 @@ pub type TestClientBox = Box<dyn TestClient>;
 pub struct TestSessionConfig {
     pub part_size: usize,
     pub filesystem_config: S3FilesystemConfig,
+    pub prefetcher_config: PrefetcherConfig,
 }
 
 impl Default for TestSessionConfig {
@@ -66,16 +71,47 @@ impl Default for TestSessionConfig {
         Self {
             part_size: 8 * 1024 * 1024,
             filesystem_config: Default::default(),
+            prefetcher_config: Default::default(),
         }
     }
 }
 
+fn create_fuse_session<Client, Prefetcher>(
+    client: Client,
+    prefetcher: Prefetcher,
+    bucket: &str,
+    prefix: &str,
+    mount_dir: &Path,
+    filesystem_config: S3FilesystemConfig,
+) -> BackgroundSession
+where
+    Client: ObjectClient + Send + Sync + 'static,
+    Prefetcher: Prefetch + Send + Sync + 'static,
+{
+    let options = vec![
+        MountOption::DefaultPermissions,
+        MountOption::FSName("mountpoint-s3".to_string()),
+        MountOption::NoAtime,
+        MountOption::AutoUnmount,
+        MountOption::AllowOther,
+    ];
+
+    let prefix = Prefix::new(prefix).expect("valid prefix");
prefix"); + let session = Session::new( + S3FuseFilesystem::new(client, prefetcher, bucket, &prefix, filesystem_config), + mount_dir, + &options, + ) + .unwrap(); + + BackgroundSession::new(session).unwrap() +} + mod mock_session { use super::*; - use std::sync::Arc; - use futures::executor::ThreadPool; + use mountpoint_s3::prefetch::default_prefetch; use mountpoint_s3_client::mock_client::{MockClient, MockClientConfig, MockObject}; /// Create a FUSE mount backed by a mock object client that does not talk to S3 @@ -94,39 +130,28 @@ mod mock_session { part_size: test_config.part_size, }; let client = Arc::new(MockClient::new(client_config)); - - let options = vec![ - MountOption::DefaultPermissions, - MountOption::FSName("mountpoint-s3".to_string()), - MountOption::NoAtime, - MountOption::AutoUnmount, - MountOption::AllowOther, - ]; - let runtime = ThreadPool::builder().pool_size(1).create().unwrap(); - - let prefix = Prefix::new(&prefix).expect("valid prefix"); - let session = Session::new( - S3FuseFilesystem::new( - Arc::clone(&client), - runtime, - bucket, - &prefix, - test_config.filesystem_config, - ), + let prefetcher = default_prefetch(runtime, test_config.prefetcher_config); + let session = create_fuse_session( + client.clone(), + prefetcher, + bucket, + &prefix, mount_dir.path(), - &options, - ) - .unwrap(); + test_config.filesystem_config, + ); + let test_client = create_test_client(client, &prefix); - let session = BackgroundSession::new(session).unwrap(); + (mount_dir, session, test_client) + } + fn create_test_client(client: Arc, prefix: &str) -> TestClientBox { let test_client = MockTestClient { - prefix: prefix.to_string(), + prefix: prefix.to_owned(), client, }; - (mount_dir, session, Box::new(test_client)) + Box::new(test_client) } struct MockTestClient { @@ -203,6 +228,7 @@ mod s3_session { use aws_sdk_s3::primitives::ByteStream; use aws_sdk_s3::types::{ChecksumAlgorithm, GlacierJobParameters, RestoreRequest, Tier}; use aws_sdk_s3::Client; + use mountpoint_s3::prefetch::default_prefetch; use mountpoint_s3_client::config::{EndpointConfig, S3ClientConfig}; use mountpoint_s3_client::S3CrtClient; @@ -218,33 +244,29 @@ mod s3_session { .endpoint_config(EndpointConfig::new(®ion)); let client = S3CrtClient::new(client_config).unwrap(); let runtime = client.event_loop_group(); - - let options = vec![ - MountOption::DefaultPermissions, - MountOption::FSName("mountpoint-s3".to_string()), - MountOption::NoAtime, - MountOption::AutoUnmount, - MountOption::AllowOther, - ]; - - let prefix = Prefix::new(&prefix).expect("valid prefix"); - let session = Session::new( - S3FuseFilesystem::new(client, runtime, &bucket, &prefix, test_config.filesystem_config), + let prefetcher = default_prefetch(runtime, test_config.prefetcher_config); + let session = create_fuse_session( + client, + prefetcher, + &bucket, + &prefix, mount_dir.path(), - &options, - ) - .unwrap(); + test_config.filesystem_config, + ); + let test_client = create_test_client(®ion, &bucket, &prefix); - let session = BackgroundSession::new(session).unwrap(); + (mount_dir, session, test_client) + } - let sdk_client = tokio_block_on(async { get_test_sdk_client(®ion).await }); + fn create_test_client(region: &str, bucket: &str, prefix: &str) -> TestClientBox { + let sdk_client = tokio_block_on(async { get_test_sdk_client(region).await }); let test_client = SDKTestClient { - prefix: prefix.to_string(), - bucket, + prefix: prefix.to_owned(), + bucket: bucket.to_owned(), sdk_client, }; - (mount_dir, session, Box::new(test_client)) + 
     }
 
     async fn get_test_sdk_client(region: &str) -> aws_sdk_s3::Client {
diff --git a/mountpoint-s3/tests/fuse_tests/prefetch_test.rs b/mountpoint-s3/tests/fuse_tests/prefetch_test.rs
index 5592715eb..39d81e5b1 100644
--- a/mountpoint-s3/tests/fuse_tests/prefetch_test.rs
+++ b/mountpoint-s3/tests/fuse_tests/prefetch_test.rs
@@ -1,6 +1,5 @@
 use fuser::BackgroundSession;
 use mountpoint_s3::prefetch::PrefetcherConfig;
-use mountpoint_s3::S3FilesystemConfig;
 use std::fs::{File, OpenOptions};
 use std::io::Read;
 use tempfile::TempDir;
@@ -60,15 +59,10 @@ where
         ..Default::default()
     };
 
-    let filesystem_config = S3FilesystemConfig {
-        prefetcher_config,
-        ..Default::default()
-    };
-
     let (mount_point, _session, mut test_client) = creator_fn(
         prefix,
         TestSessionConfig {
-            filesystem_config,
+            prefetcher_config,
             ..Default::default()
         },
     );
diff --git a/mountpoint-s3/tests/fuse_tests/read_test.rs b/mountpoint-s3/tests/fuse_tests/read_test.rs
index 2bc5e9d49..170d45b8e 100644
--- a/mountpoint-s3/tests/fuse_tests/read_test.rs
+++ b/mountpoint-s3/tests/fuse_tests/read_test.rs
@@ -66,14 +66,10 @@ fn basic_read_test_s3() {
     basic_read_test(crate::fuse_tests::s3_session::new, "basic_read_test");
 }
 
-#[test]
-fn basic_read_test_mock() {
-    basic_read_test(crate::fuse_tests::mock_session::new, "");
-}
-
-#[test]
-fn basic_read_test_mock_prefix() {
-    basic_read_test(crate::fuse_tests::mock_session::new, "basic_read_test");
+#[test_case("")]
+#[test_case("basic_read_test")]
+fn basic_read_test_mock(prefix: &str) {
+    basic_read_test(crate::fuse_tests::mock_session::new, prefix);
 }
 
 #[derive(PartialEq)]
diff --git a/mountpoint-s3/tests/reftests/harness.rs b/mountpoint-s3/tests/reftests/harness.rs
index f2ff52718..d64fe0e58 100644
--- a/mountpoint-s3/tests/reftests/harness.rs
+++ b/mountpoint-s3/tests/reftests/harness.rs
@@ -5,18 +5,17 @@ use std::sync::Arc;
 use std::time::Duration;
 
 use fuser::FileType;
-use futures::executor::ThreadPool;
 use futures::future::{BoxFuture, FutureExt};
 use mountpoint_s3::fs::{self, CacheConfig, InodeNo, ReadReplier, ToErrno, FUSE_ROOT_INODE};
 use mountpoint_s3::prefix::Prefix;
-use mountpoint_s3::{S3Filesystem, S3FilesystemConfig};
+use mountpoint_s3::S3FilesystemConfig;
 use mountpoint_s3_client::mock_client::{MockClient, MockObject};
 use mountpoint_s3_client::ObjectClient;
 use proptest::prelude::*;
 use proptest_derive::Arbitrary;
 use tracing::{debug, trace};
 
-use crate::common::{make_test_filesystem, DirectoryReply};
+use crate::common::{make_test_filesystem, DirectoryReply, TestS3Filesystem};
 use crate::reftests::generators::{flatten_tree, gen_tree, FileContent, FileSize, Name, TreeNode, ValidName};
 use crate::reftests::reference::{File, Node, Reference};
 
@@ -164,7 +163,7 @@ impl InflightWrites {
 pub struct Harness {
     readdir_limit: usize, // max number of entries that a readdir will return; 0 means no limit
     reference: Reference,
-    fs: S3Filesystem<Arc<MockClient>, ThreadPool>,
+    fs: TestS3Filesystem<Arc<MockClient>>,
     client: Arc<MockClient>,
     bucket: String,
     inflight_writes: InflightWrites,
@@ -173,7 +172,7 @@ pub struct Harness {
 impl Harness {
     /// Create a new test harness
     pub fn new(
-        fs: S3Filesystem<Arc<MockClient>, ThreadPool>,
+        fs: TestS3Filesystem<Arc<MockClient>>,
         client: Arc<MockClient>,
         reference: Reference,
         bucket: &str,
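
As a sketch of how the new abstraction composes, here is a hypothetical `ObjectPartStream` implementation (not part of this patch; `FixedPartSizeStream` and its fields are illustrative only). It wraps `ClientPartStream` and overrides the caller's preferred part size, which the trait documentation explicitly permits:

    use futures::task::Spawn;
    use mountpoint_s3_client::{types::ETag, ObjectClient};

    use crate::prefetch::part_stream::{ClientPartStream, ObjectPartStream, RequestRange};
    use crate::prefetch::task::RequestTask;

    /// Splits object data at a fixed part size, ignoring the caller's preference.
    pub struct FixedPartSizeStream<Runtime> {
        inner: ClientPartStream<Runtime>,
        part_size: usize,
    }

    impl<Runtime: Spawn> ObjectPartStream for FixedPartSizeStream<Runtime> {
        fn spawn_get_object_request<Client>(
            &self,
            client: &Client,
            bucket: &str,
            key: &str,
            if_match: ETag,
            range: RequestRange,
            _preferred_part_size: usize,
        ) -> RequestTask<Client::ClientError>
        where
            Client: ObjectClient + Clone + Send + Sync + 'static,
        {
            // Delegate to the client-backed stream, but force our own part size.
            self.inner
                .spawn_get_object_request(client, bucket, key, if_match, range, self.part_size)
        }
    }
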