Skip to content

Commit

Permalink
Switch to using task group instead of manually tracking tasks in flig…
Browse files Browse the repository at this point in the history
…ht. This makes the walker thread safe so we can get rid of the mutex
  • Loading branch information
westonpace committed Mar 29, 2021
1 parent 7a35945 commit 524f5e0
Showing 1 changed file with 28 additions and 51 deletions.
79 changes: 28 additions & 51 deletions cpp/src/arrow/filesystem/s3fs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,13 @@
#include "arrow/util/future.h"
#include "arrow/util/logging.h"
#include "arrow/util/optional.h"
#include "arrow/util/task_group.h"
#include "arrow/util/thread_pool.h"
#include "arrow/util/windows_fixup.h"

namespace arrow {

using internal::TaskGroup;
using internal::Uri;

namespace fs {
Expand Down Expand Up @@ -1142,83 +1144,60 @@ struct TreeWalker : public std::enable_shared_from_this<TreeWalker> {
recursion_handler_(std::move(recursion_handler)) {}

private:
std::mutex mutex_;
Future<> future_;
std::atomic<int32_t> num_in_flight_;
std::shared_ptr<TaskGroup> task_group_;

Status DoWalk() {
future_ = decltype(future_)::Make();
num_in_flight_ = 0;
task_group_ =
TaskGroup::MakeThreaded(io_context_.executor(), io_context_.stop_token());
WalkChild(base_dir_, /*nesting_depth=*/0);
// When this returns, ListObjectsV2 tasks either have finished or will exit early
return future_.status();
return task_group_->Finish();
}

bool is_finished() const { return future_.is_finished(); }

void ListObjectsFinished(Status st) {
const auto in_flight = --num_in_flight_;
if (!st.ok() || !in_flight) {
future_.MarkFinished(std::move(st));
}
}
bool ok() const { return task_group_->ok(); }

struct ListObjectsV2Handler {
std::shared_ptr<TreeWalker> walker;
std::string prefix;
int32_t nesting_depth;
S3Model::ListObjectsV2Request req;

void operator()(const Result<S3Model::ListObjectsV2Outcome>& result) {
Status operator()(const Result<S3Model::ListObjectsV2Outcome>& result) {
// Serialize calls to operation-specific handlers
std::unique_lock<std::mutex> guard(walker->mutex_);
if (walker->is_finished()) {
if (!walker->ok()) {
// Early exit: avoid executing handlers if DoWalk() returned
return;
return Status::OK();
}
if (!result.ok()) {
HandleError(result.status());
return;
return result.status();
}
const auto& outcome = *result;
if (!outcome.IsSuccess()) {
Status st = walker->error_handler_(outcome.GetError());
HandleError(std::move(st));
return;
return walker->error_handler_(outcome.GetError());
}
HandleResult(outcome.GetResult());
return HandleResult(outcome.GetResult());
}

void SpawnListObjectsV2() {
Status SpawnListObjectsV2() {
auto walker = this->walker;
auto req = this->req;
auto maybe_fut = walker->io_context_.executor()->Submit(
walker->io_context_.stop_token(),
[walker, req]() { return walker->client_->ListObjectsV2(req); });
if (!maybe_fut.ok()) {
HandleError(maybe_fut.status());
return;
}
maybe_fut->AddCallback(*this);
auto cb = *this;
walker->task_group_->Append([walker, req, cb]() mutable {
Result<S3Model::ListObjectsV2Outcome> result =
walker->client_->ListObjectsV2(req);
return cb(result);
});
return Status::OK();
}

void HandleError(Status status) { walker->ListObjectsFinished(std::move(status)); }

void HandleResult(const S3Model::ListObjectsV2Result& result) {
Status HandleResult(const S3Model::ListObjectsV2Result& result) {
bool recurse = result.GetCommonPrefixes().size() > 0;
if (recurse) {
auto maybe_recurse = walker->recursion_handler_(nesting_depth + 1);
if (!maybe_recurse.ok()) {
walker->ListObjectsFinished(maybe_recurse.status());
return;
}
recurse &= *maybe_recurse;
}
Status st = walker->result_handler_(prefix, result);
if (!st.ok()) {
walker->ListObjectsFinished(std::move(st));
return;
ARROW_ASSIGN_OR_RAISE(auto maybe_recurse,
walker->recursion_handler_(nesting_depth + 1));
recurse &= maybe_recurse;
}
RETURN_NOT_OK(walker->result_handler_(prefix, result));
if (recurse) {
walker->WalkChildren(result, nesting_depth + 1);
}
Expand All @@ -1227,10 +1206,9 @@ struct TreeWalker : public std::enable_shared_from_this<TreeWalker> {
if (result.GetIsTruncated()) {
DCHECK(!result.GetNextContinuationToken().empty());
req.SetContinuationToken(result.GetNextContinuationToken());
SpawnListObjectsV2();
} else {
walker->ListObjectsFinished(Status::OK());
RETURN_NOT_OK(SpawnListObjectsV2());
}
return Status::OK();
}

void Start() {
Expand All @@ -1246,7 +1224,6 @@ struct TreeWalker : public std::enable_shared_from_this<TreeWalker> {

void WalkChild(std::string key, int32_t nesting_depth) {
ListObjectsV2Handler handler{shared_from_this(), std::move(key), nesting_depth, {}};
++num_in_flight_;
handler.Start();
}

Expand Down

0 comments on commit 524f5e0

Please sign in to comment.