Skip to content

Commit

Permalink
pw_rpc: Move on_completed_ before invoking it
Browse files Browse the repository at this point in the history
Move the on_completed_ callback to a local variable while holding the
RPC lock to fix a race condition.

Bug: b/234876851
Change-Id: Ib64cd6b7542198dd0c0a5c08f162f60bf0756001
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/125212
Commit-Queue: Auto-Submit <[email protected]>
Pigweed-Auto-Submit: Wyatt Hepler <[email protected]>
Reviewed-by: Alexei Frolov <[email protected]>
  • Loading branch information
255 authored and CQ Bot Account committed Dec 23, 2022
1 parent 5052698 commit c4a481b
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 13 deletions.
22 changes: 9 additions & 13 deletions pw_rpc/public/pw_rpc/internal/client_call.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,12 @@ class UnaryResponseClientCall : public ClientCall {

void HandleCompleted(ConstByteSpan response, Status status)
PW_UNLOCK_FUNCTION(rpc_lock()) {
const bool invoke_callback = on_completed_ != nullptr;
UnregisterAndMarkClosed();

auto on_completed_local = std::move(on_completed_);
rpc_lock().unlock();
if (invoke_callback) {
on_completed_(response, status);

if (on_completed_local) {
on_completed_local(response, status);
}
}

Expand Down Expand Up @@ -120,7 +120,6 @@ class UnaryResponseClientCall : public ClientCall {

void set_on_completed(Function<void(ConstByteSpan, Status)>&& on_completed)
PW_LOCKS_EXCLUDED(rpc_lock()) {
// TODO(b/234876851): Ensure on_completed_ is properly guarded.
LockGuard lock(rpc_lock());
set_on_completed_locked(std::move(on_completed));
}
Expand All @@ -134,7 +133,7 @@ class UnaryResponseClientCall : public ClientCall {
private:
using internal::ClientCall::set_on_next; // Not used in unary response calls.

Function<void(ConstByteSpan, Status)> on_completed_;
Function<void(ConstByteSpan, Status)> on_completed_ PW_GUARDED_BY(rpc_lock());
};

// Stream response client calls only receive the status in their on_completed
Expand Down Expand Up @@ -162,14 +161,12 @@ class StreamResponseClientCall : public ClientCall {
}

void HandleCompleted(Status status) PW_UNLOCK_FUNCTION(rpc_lock()) {
const bool invoke_callback = on_completed_ != nullptr;

UnregisterAndMarkClosed();
auto on_completed_local = std::move(on_completed_);
rpc_lock().unlock();

// TODO(b/234876851): Ensure on_completed_ is properly guarded.
if (invoke_callback) {
on_completed_(status);
if (on_completed_local) {
on_completed_local(status);
}
}

Expand Down Expand Up @@ -203,7 +200,6 @@ class StreamResponseClientCall : public ClientCall {

void set_on_completed(Function<void(Status)>&& on_completed)
PW_LOCKS_EXCLUDED(rpc_lock()) {
// TODO(b/234876851): Ensure on_completed_ is properly guarded.
LockGuard lock(rpc_lock());
set_on_completed_locked(std::move(on_completed));
}
Expand All @@ -214,7 +210,7 @@ class StreamResponseClientCall : public ClientCall {
}

private:
Function<void(Status)> on_completed_;
Function<void(Status)> on_completed_ PW_GUARDED_BY(rpc_lock());
};

} // namespace pw::rpc::internal
19 changes: 19 additions & 0 deletions pw_rpc/raw/client_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,10 @@ class TestUnaryCall : public internal::UnaryResponseClientCall {
set_on_error_locked([this](Status status) { error = status; });
}

void clear_on_completed() PW_EXCLUSIVE_LOCKS_REQUIRED(internal::rpc_lock()) {
set_on_completed_locked(nullptr);
}

const char* payload;
std::optional<Status> completed;
std::optional<Status> error;
Expand All @@ -112,6 +116,21 @@ TEST(Client, ProcessPacket_InvokesUnaryCallbacks) {
ASSERT_NE(call.payload, nullptr);
EXPECT_STREQ(call.payload, "you nary?!?");
EXPECT_EQ(call.completed, OkStatus());
EXPECT_FALSE(call.active());
}

TEST(Client, ProcessPacket_NoCallbackSet) {
RawClientTestContext context;
internal::rpc_lock().lock();
TestUnaryCall call = MakeCall<UnaryMethod, TestUnaryCall>(context);
call.clear_on_completed();
call.SendInitialClientRequest({});

ASSERT_NE(call.completed, OkStatus());

context.server().SendResponse<UnaryMethod>(as_bytes(span("you nary?!?")),
OkStatus());
EXPECT_FALSE(call.active());
}

TEST(Client, ProcessPacket_InvokesStreamCallbacks) {
Expand Down

0 comments on commit c4a481b

Please sign in to comment.