From dfbc93d104ba54e58259a326f44445df98ede86d Mon Sep 17 00:00:00 2001 From: Carlos O'Ryan Date: Tue, 17 Mar 2020 17:34:57 -0400 Subject: [PATCH] fix: fix racy conditions during shutdown (googleapis/google-cloud-cpp-common#250) * bug: fix race condition during shutdown We were checking if the completion queue was shutdown and then scheduling work, but not atomically. --- google/cloud/completion_queue.cc | 6 +-- google/cloud/completion_queue.h | 5 +-- google/cloud/completion_queue_test.cc | 38 ++++++++++++++++--- .../cloud/internal/async_read_stream_impl.h | 31 ++++++--------- .../cloud/internal/completion_queue_impl.cc | 20 ---------- google/cloud/internal/completion_queue_impl.h | 26 ++++++++++++- 6 files changed, 72 insertions(+), 54 deletions(-) diff --git a/google/cloud/completion_queue.cc b/google/cloud/completion_queue.cc index dd43050a08293..3153bcf84bf9a 100644 --- a/google/cloud/completion_queue.cc +++ b/google/cloud/completion_queue.cc @@ -90,10 +90,8 @@ google::cloud::future> CompletionQueue::MakeDeadlineTimer( std::chrono::system_clock::time_point deadline) { auto op = std::make_shared(impl_->CreateAlarm()); - void* tag = impl_->RegisterOperation(op); - if (tag != nullptr) { - op->Set(impl_->cq(), deadline, tag); - } + impl_->StartOperation( + op, [&](void* tag) { op->Set(impl_->cq(), deadline, tag); }); return op->GetFuture(); } diff --git a/google/cloud/completion_queue.h b/google/cloud/completion_queue.h index e5ca7c2192efc..d6e781cfed0d3 100644 --- a/google/cloud/completion_queue.h +++ b/google/cloud/completion_queue.h @@ -113,10 +113,9 @@ class CompletionQueue { std::unique_ptr context) { auto op = std::make_shared>(); - void* tag = impl_->RegisterOperation(op); - if (tag != nullptr) { + impl_->StartOperation(op, [&](void* tag) { op->Start(async_call, std::move(context), request, &impl_->cq(), tag); - } + }); return op->GetFuture(); } diff --git a/google/cloud/completion_queue_test.cc b/google/cloud/completion_queue_test.cc index f427e7bcd5e8e..04f4b18ba04c5 100644 --- a/google/cloud/completion_queue_test.cc +++ b/google/cloud/completion_queue_test.cc @@ -333,11 +333,14 @@ TEST(CompletionQueueTest, RunAsync) { // Sets up a timer that reschedules itself and verifies we can shut down // cleanly whether we call `CancelAll()` on the queue first or not. namespace { -void RunAndReschedule(CompletionQueue& cq, bool ok) { +using TimerFuture = future>; + +void RunAndReschedule(CompletionQueue& cq, bool ok, + std::chrono::seconds duration) { if (ok) { - cq.MakeRelativeTimer(std::chrono::seconds(1)) - .then([&cq](future> - result) { RunAndReschedule(cq, result.get().ok()); }); + cq.MakeRelativeTimer(duration).then([&cq, duration](TimerFuture result) { + RunAndReschedule(cq, result.get().ok(), duration); + }); } } } // namespace @@ -346,17 +349,40 @@ TEST(CompletionQueueTest, ShutdownWithReschedulingTimer) { CompletionQueue cq; std::thread t([&cq] { cq.Run(); }); - RunAndReschedule(cq, /*ok=*/true); + RunAndReschedule(cq, /*ok=*/true, std::chrono::seconds(1)); cq.Shutdown(); t.join(); } +TEST(CompletionQueueTest, ShutdownWithFastReschedulingTimer) { + auto constexpr kThreadCount = 32; + auto constexpr kTimerCount = 100; + CompletionQueue cq; + std::vector threads(kThreadCount); + std::generate_n(threads.begin(), threads.size(), + [&cq] { return std::thread([&cq] { cq.Run(); }); }); + + for (int i = 0; i != kTimerCount; ++i) { + RunAndReschedule(cq, /*ok=*/true, std::chrono::seconds(0)); + } + + promise wait; + cq.MakeRelativeTimer(std::chrono::milliseconds(1)).then([&wait](TimerFuture) { + wait.set_value(); + }); + wait.get_future().get(); + cq.Shutdown(); + for (auto& t : threads) { + t.join(); + } +} + TEST(CompletionQueueTest, CancelAndShutdownWithReschedulingTimer) { CompletionQueue cq; std::thread t([&cq] { cq.Run(); }); - RunAndReschedule(cq, /*ok=*/true); + RunAndReschedule(cq, /*ok=*/true, std::chrono::seconds(1)); cq.CancelAll(); cq.Shutdown(); diff --git a/google/cloud/internal/async_read_stream_impl.h b/google/cloud/internal/async_read_stream_impl.h index 4eadf0c7e92db..0e021eb4a70bf 100644 --- a/google/cloud/internal/async_read_stream_impl.h +++ b/google/cloud/internal/async_read_stream_impl.h @@ -157,15 +157,14 @@ class AsyncReadStreamImpl context_ = std::move(context); cq_ = std::move(cq); auto callback = std::make_shared(this->shared_from_this()); - void* tag = cq_->RegisterOperation(std::move(callback)); - // @note If `tag == nullptr` the `CompletionQueue` has been `Shutdown()`. - // We leave `reader_` null in this case; other methods must make the - // same `tag != nullptr` check prior to accessing `reader_`. This is - // safe since `Shutdown()` cannot be undone. - if (tag != nullptr) { + cq_->StartOperation(std::move(callback), [&](void* tag) { + // @note If the the `CompletionQueue` has been `Shutdown()` this lambda is + // never called. We leave `reader_` null in this case; other methods + // must make the same `tag != nullptr` check prior to accessing + // `reader_`. This is safe since `Shutdown()` cannot be undone. reader_ = async_call(context_.get(), request, &cq_->cq()); reader_->StartCall(tag); - } + }); } /// Cancel the current streaming read RPC. @@ -202,10 +201,8 @@ class AsyncReadStreamImpl auto callback = std::make_shared(this->shared_from_this()); auto response = &callback->response; - void* tag = cq_->RegisterOperation(std::move(callback)); - if (tag != nullptr) { - reader_->Read(response, tag); - } + cq_->StartOperation(std::move(callback), + [&](void* tag) { reader_->Read(response, tag); }); } /// Handle the result of a `Read()` call. @@ -252,10 +249,8 @@ class AsyncReadStreamImpl auto callback = std::make_shared(this->shared_from_this()); auto status = &callback->status; - void* tag = cq_->RegisterOperation(std::move(callback)); - if (tag != nullptr) { - reader_->Finish(status, tag); - } + cq_->StartOperation(std::move(callback), + [&](void* tag) { reader_->Finish(status, tag); }); } /// Handle the result of a Finish() request. @@ -292,10 +287,8 @@ class AsyncReadStreamImpl auto callback = std::make_shared(this->shared_from_this()); auto response = &callback->response; - void* tag = cq_->RegisterOperation(std::move(callback)); - if (tag != nullptr) { - reader_->Read(response, tag); - } + cq_->StartOperation(std::move(callback), + [&](void* tag) { reader_->Read(response, tag); }); } /// Handle the result of a Discard() call. diff --git a/google/cloud/internal/completion_queue_impl.cc b/google/cloud/internal/completion_queue_impl.cc index 8e368b2cd09c4..1e64fc089c4f0 100644 --- a/google/cloud/internal/completion_queue_impl.cc +++ b/google/cloud/internal/completion_queue_impl.cc @@ -73,26 +73,6 @@ std::unique_ptr CompletionQueueImpl::CreateAlarm() const { return google::cloud::internal::make_unique(); } -void* CompletionQueueImpl::RegisterOperation( - std::shared_ptr op) { - void* tag = op.get(); - std::unique_lock lk(mu_); - if (shutdown_) { - lk.unlock(); - op->Notify(/*ok=*/false); - return nullptr; - } - auto ins = - pending_ops_.emplace(reinterpret_cast(tag), std::move(op)); - // After this point we no longer need the lock, so release it. - lk.unlock(); - if (ins.second) { - return tag; - } - google::cloud::internal::ThrowRuntimeError( - "assertion failure: insertion should succeed"); -} - std::shared_ptr CompletionQueueImpl::FindOperation( void* tag) { std::lock_guard lk(mu_); diff --git a/google/cloud/internal/completion_queue_impl.h b/google/cloud/internal/completion_queue_impl.h index 64ddeb5f8fe91..d432a2c3dad70 100644 --- a/google/cloud/internal/completion_queue_impl.h +++ b/google/cloud/internal/completion_queue_impl.h @@ -247,8 +247,30 @@ class CompletionQueueImpl { /// The underlying gRPC completion queue. grpc::CompletionQueue& cq() { return cq_; } - /// Add a new asynchronous operation to the completion queue. - void* RegisterOperation(std::shared_ptr op); + /// Atomically add a new operation to the completion queue and start it. + template ::value, + int>::type = 0> + void StartOperation(std::shared_ptr op, + Callable&& start) { + void* tag = op.get(); + std::unique_lock lk(mu_); + if (shutdown_) { + lk.unlock(); + op->Notify(/*ok=*/false); + return; + } + auto ins = pending_ops_.emplace(reinterpret_cast(tag), + std::move(op)); + if (ins.second) { + start(tag); + lk.unlock(); + return; + } + google::cloud::internal::ThrowRuntimeError( + "assertion failure: insertion should succeed"); + } protected: /// Return the asynchronous operation associated with @p tag.