Skip to content
This repository has been archived by the owner on Dec 8, 2021. It is now read-only.

fix: fix racy conditions during shutdown #250

Merged
merged 5 commits into from
Mar 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file modified ci/test-api/google_cloud_cpp_common.expected.abi.dump.gz
Binary file not shown.
Binary file modified ci/test-api/google_cloud_cpp_grpc_utils.expected.abi.dump.gz
Binary file not shown.
6 changes: 2 additions & 4 deletions google/cloud/completion_queue.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,8 @@ google::cloud::future<StatusOr<std::chrono::system_clock::time_point>>
CompletionQueue::MakeDeadlineTimer(
std::chrono::system_clock::time_point deadline) {
auto op = std::make_shared<AsyncTimerFuture>(impl_->CreateAlarm());
void* tag = impl_->RegisterOperation(op);
if (tag != nullptr) {
op->Set(impl_->cq(), deadline, tag);
}
impl_->StartOperation(
op, [&](void* tag) { op->Set(impl_->cq(), deadline, tag); });
return op->GetFuture();
}

Expand Down
5 changes: 2 additions & 3 deletions google/cloud/completion_queue.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,9 @@ class CompletionQueue {
std::unique_ptr<grpc::ClientContext> context) {
auto op =
std::make_shared<internal::AsyncUnaryRpcFuture<Request, Response>>();
void* tag = impl_->RegisterOperation(op);
if (tag != nullptr) {
impl_->StartOperation(op, [&](void* tag) {
op->Start(async_call, std::move(context), request, &impl_->cq(), tag);
}
});
return op->GetFuture();
}

Expand Down
38 changes: 32 additions & 6 deletions google/cloud/completion_queue_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -333,11 +333,14 @@ TEST(CompletionQueueTest, RunAsync) {
// Sets up a timer that reschedules itself and verifies we can shut down
// cleanly whether we call `CancelAll()` on the queue first or not.
namespace {
void RunAndReschedule(CompletionQueue& cq, bool ok) {
using TimerFuture = future<StatusOr<std::chrono::system_clock::time_point>>;

void RunAndReschedule(CompletionQueue& cq, bool ok,
std::chrono::seconds duration) {
if (ok) {
cq.MakeRelativeTimer(std::chrono::seconds(1))
.then([&cq](future<StatusOr<std::chrono::system_clock::time_point>>
result) { RunAndReschedule(cq, result.get().ok()); });
cq.MakeRelativeTimer(duration).then([&cq, duration](TimerFuture result) {
RunAndReschedule(cq, result.get().ok(), duration);
});
}
}
} // namespace
Expand All @@ -346,17 +349,40 @@ TEST(CompletionQueueTest, ShutdownWithReschedulingTimer) {
CompletionQueue cq;
std::thread t([&cq] { cq.Run(); });

RunAndReschedule(cq, /*ok=*/true);
RunAndReschedule(cq, /*ok=*/true, std::chrono::seconds(1));

cq.Shutdown();
t.join();
}

TEST(CompletionQueueTest, ShutdownWithFastReschedulingTimer) {
auto constexpr kThreadCount = 32;
auto constexpr kTimerCount = 100;
CompletionQueue cq;
std::vector<std::thread> threads(kThreadCount);
std::generate_n(threads.begin(), threads.size(),
[&cq] { return std::thread([&cq] { cq.Run(); }); });

for (int i = 0; i != kTimerCount; ++i) {
RunAndReschedule(cq, /*ok=*/true, std::chrono::seconds(0));
}

promise<void> wait;
cq.MakeRelativeTimer(std::chrono::milliseconds(1)).then([&wait](TimerFuture) {
wait.set_value();
});
wait.get_future().get();
cq.Shutdown();
for (auto& t : threads) {
t.join();
}
}

TEST(CompletionQueueTest, CancelAndShutdownWithReschedulingTimer) {
CompletionQueue cq;
std::thread t([&cq] { cq.Run(); });

RunAndReschedule(cq, /*ok=*/true);
RunAndReschedule(cq, /*ok=*/true, std::chrono::seconds(1));

cq.CancelAll();
cq.Shutdown();
Expand Down
31 changes: 12 additions & 19 deletions google/cloud/internal/async_read_stream_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -157,15 +157,14 @@ class AsyncReadStreamImpl
context_ = std::move(context);
cq_ = std::move(cq);
auto callback = std::make_shared<NotifyStart>(this->shared_from_this());
void* tag = cq_->RegisterOperation(std::move(callback));
// @note If `tag == nullptr` the `CompletionQueue` has been `Shutdown()`.
// We leave `reader_` null in this case; other methods must make the
// same `tag != nullptr` check prior to accessing `reader_`. This is
// safe since `Shutdown()` cannot be undone.
if (tag != nullptr) {
cq_->StartOperation(std::move(callback), [&](void* tag) {
// @note If the the `CompletionQueue` has been `Shutdown()` this lambda is
// never called. We leave `reader_` null in this case; other methods
// must make the same `tag != nullptr` check prior to accessing
// `reader_`. This is safe since `Shutdown()` cannot be undone.
reader_ = async_call(context_.get(), request, &cq_->cq());
reader_->StartCall(tag);
}
});
}

/// Cancel the current streaming read RPC.
Expand Down Expand Up @@ -202,10 +201,8 @@ class AsyncReadStreamImpl

auto callback = std::make_shared<NotifyRead>(this->shared_from_this());
auto response = &callback->response;
void* tag = cq_->RegisterOperation(std::move(callback));
if (tag != nullptr) {
reader_->Read(response, tag);
}
cq_->StartOperation(std::move(callback),
[&](void* tag) { reader_->Read(response, tag); });
}

/// Handle the result of a `Read()` call.
Expand Down Expand Up @@ -252,10 +249,8 @@ class AsyncReadStreamImpl

auto callback = std::make_shared<NotifyFinish>(this->shared_from_this());
auto status = &callback->status;
void* tag = cq_->RegisterOperation(std::move(callback));
if (tag != nullptr) {
reader_->Finish(status, tag);
}
cq_->StartOperation(std::move(callback),
[&](void* tag) { reader_->Finish(status, tag); });
}

/// Handle the result of a Finish() request.
Expand Down Expand Up @@ -292,10 +287,8 @@ class AsyncReadStreamImpl

auto callback = std::make_shared<NotifyDiscard>(this->shared_from_this());
auto response = &callback->response;
void* tag = cq_->RegisterOperation(std::move(callback));
if (tag != nullptr) {
reader_->Read(response, tag);
}
cq_->StartOperation(std::move(callback),
[&](void* tag) { reader_->Read(response, tag); });
}

/// Handle the result of a Discard() call.
Expand Down
20 changes: 0 additions & 20 deletions google/cloud/internal/completion_queue_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,26 +73,6 @@ std::unique_ptr<grpc::Alarm> CompletionQueueImpl::CreateAlarm() const {
return google::cloud::internal::make_unique<grpc::Alarm>();
}

void* CompletionQueueImpl::RegisterOperation(
std::shared_ptr<AsyncGrpcOperation> op) {
void* tag = op.get();
std::unique_lock<std::mutex> lk(mu_);
if (shutdown_) {
lk.unlock();
op->Notify(/*ok=*/false);
return nullptr;
}
auto ins =
pending_ops_.emplace(reinterpret_cast<std::intptr_t>(tag), std::move(op));
// After this point we no longer need the lock, so release it.
lk.unlock();
if (ins.second) {
return tag;
}
google::cloud::internal::ThrowRuntimeError(
"assertion failure: insertion should succeed");
}

std::shared_ptr<AsyncGrpcOperation> CompletionQueueImpl::FindOperation(
void* tag) {
std::lock_guard<std::mutex> lk(mu_);
Expand Down
26 changes: 24 additions & 2 deletions google/cloud/internal/completion_queue_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -247,8 +247,30 @@ class CompletionQueueImpl {
/// The underlying gRPC completion queue.
grpc::CompletionQueue& cq() { return cq_; }

/// Add a new asynchronous operation to the completion queue.
void* RegisterOperation(std::shared_ptr<AsyncGrpcOperation> op);
/// Atomically add a new operation to the completion queue and start it.
template <typename Callable,
typename std::enable_if<
google::cloud::internal::is_invocable<Callable, void*>::value,
int>::type = 0>
void StartOperation(std::shared_ptr<AsyncGrpcOperation> op,
Callable&& start) {
void* tag = op.get();
std::unique_lock<std::mutex> lk(mu_);
if (shutdown_) {
lk.unlock();
op->Notify(/*ok=*/false);
return;
}
auto ins = pending_ops_.emplace(reinterpret_cast<std::intptr_t>(tag),
std::move(op));
if (ins.second) {
start(tag);
lk.unlock();
return;
}
google::cloud::internal::ThrowRuntimeError(
"assertion failure: insertion should succeed");
}

protected:
/// Return the asynchronous operation associated with @p tag.
Expand Down