diff --git a/cpp/src/arrow/filesystem/s3fs.cc b/cpp/src/arrow/filesystem/s3fs.cc index 1940f4dd4df91..60824b20d1ed1 100644 --- a/cpp/src/arrow/filesystem/s3fs.cc +++ b/cpp/src/arrow/filesystem/s3fs.cc @@ -79,11 +79,13 @@ #include "arrow/util/future.h" #include "arrow/util/logging.h" #include "arrow/util/optional.h" +#include "arrow/util/task_group.h" #include "arrow/util/thread_pool.h" #include "arrow/util/windows_fixup.h" namespace arrow { +using internal::TaskGroup; using internal::Uri; namespace fs { @@ -1142,26 +1144,17 @@ struct TreeWalker : public std::enable_shared_from_this { recursion_handler_(std::move(recursion_handler)) {} private: - std::mutex mutex_; - Future<> future_; - std::atomic num_in_flight_; + std::shared_ptr task_group_; Status DoWalk() { - future_ = decltype(future_)::Make(); - num_in_flight_ = 0; + task_group_ = + TaskGroup::MakeThreaded(io_context_.executor(), io_context_.stop_token()); WalkChild(base_dir_, /*nesting_depth=*/0); // When this returns, ListObjectsV2 tasks either have finished or will exit early - return future_.status(); + return task_group_->Finish(); } - bool is_finished() const { return future_.is_finished(); } - - void ListObjectsFinished(Status st) { - const auto in_flight = --num_in_flight_; - if (!st.ok() || !in_flight) { - future_.MarkFinished(std::move(st)); - } - } + bool ok() const { return task_group_->ok(); } struct ListObjectsV2Handler { std::shared_ptr walker; @@ -1169,56 +1162,42 @@ struct TreeWalker : public std::enable_shared_from_this { int32_t nesting_depth; S3Model::ListObjectsV2Request req; - void operator()(const Result& result) { + Status operator()(const Result& result) { // Serialize calls to operation-specific handlers - std::unique_lock guard(walker->mutex_); - if (walker->is_finished()) { + if (!walker->ok()) { // Early exit: avoid executing handlers if DoWalk() returned - return; + return Status::OK(); } if (!result.ok()) { - HandleError(result.status()); - return; + return result.status(); } const auto& outcome = *result; if (!outcome.IsSuccess()) { - Status st = walker->error_handler_(outcome.GetError()); - HandleError(std::move(st)); - return; + return walker->error_handler_(outcome.GetError()); } - HandleResult(outcome.GetResult()); + return HandleResult(outcome.GetResult()); } - void SpawnListObjectsV2() { + Status SpawnListObjectsV2() { auto walker = this->walker; auto req = this->req; - auto maybe_fut = walker->io_context_.executor()->Submit( - walker->io_context_.stop_token(), - [walker, req]() { return walker->client_->ListObjectsV2(req); }); - if (!maybe_fut.ok()) { - HandleError(maybe_fut.status()); - return; - } - maybe_fut->AddCallback(*this); + auto cb = *this; + walker->task_group_->Append([walker, req, cb]() mutable { + Result result = + walker->client_->ListObjectsV2(req); + return cb(result); + }); + return Status::OK(); } - void HandleError(Status status) { walker->ListObjectsFinished(std::move(status)); } - - void HandleResult(const S3Model::ListObjectsV2Result& result) { + Status HandleResult(const S3Model::ListObjectsV2Result& result) { bool recurse = result.GetCommonPrefixes().size() > 0; if (recurse) { - auto maybe_recurse = walker->recursion_handler_(nesting_depth + 1); - if (!maybe_recurse.ok()) { - walker->ListObjectsFinished(maybe_recurse.status()); - return; - } - recurse &= *maybe_recurse; - } - Status st = walker->result_handler_(prefix, result); - if (!st.ok()) { - walker->ListObjectsFinished(std::move(st)); - return; + ARROW_ASSIGN_OR_RAISE(auto maybe_recurse, + walker->recursion_handler_(nesting_depth + 1)); + recurse &= maybe_recurse; } + RETURN_NOT_OK(walker->result_handler_(prefix, result)); if (recurse) { walker->WalkChildren(result, nesting_depth + 1); } @@ -1227,10 +1206,9 @@ struct TreeWalker : public std::enable_shared_from_this { if (result.GetIsTruncated()) { DCHECK(!result.GetNextContinuationToken().empty()); req.SetContinuationToken(result.GetNextContinuationToken()); - SpawnListObjectsV2(); - } else { - walker->ListObjectsFinished(Status::OK()); + RETURN_NOT_OK(SpawnListObjectsV2()); } + return Status::OK(); } void Start() { @@ -1246,7 +1224,6 @@ struct TreeWalker : public std::enable_shared_from_this { void WalkChild(std::string key, int32_t nesting_depth) { ListObjectsV2Handler handler{shared_from_this(), std::move(key), nesting_depth, {}}; - ++num_in_flight_; handler.Start(); }