Skip to content

Commit

Permalink
Fix the buggy Future and Promise implementations (#299)
Browse files Browse the repository at this point in the history
Fixes #298

### Motivation

Currently, `Future` and `Promise` are implemented manually by
managing condition variables. However, the condition variable
sometimes behaves incorrectly on macOS, while the existing `future`
and `promise` from the C++ standard library work well.

### Modifications

Redesign `Future` and `Promise` based on the utilities in the standard
`<future>` header. In addition, fix the possible race condition when
`addListener` is called after `setValue` or `setFailed`:
- Thread 1: call `setValue`, swap out the existing listeners, and call them
  one by one outside the lock.
- Thread 2: call `addListener`, detect that `complete_` is true, and call the
  listener directly.

As a result, the previous listeners and the new listener are called
concurrently in threads 1 and 2.

This patch fixes the problem by adding a future to wait until all listeners
that were added before completion have finished.

### Verifications

Ran the reproduction code in #298 10 times and found that it never failed or
hung.

Co-authored-by: Zike Yang <[email protected]>

---------

Co-authored-by: Zike Yang <[email protected]>
  • Loading branch information
BewareMyPower and RobertIndie authored Jul 5, 2023
1 parent 804f87b commit 20f6fa0
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 121 deletions.
2 changes: 1 addition & 1 deletion lib/BinaryProtoLookupService.cc
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ void BinaryProtoLookupService::handlePartitionMetadataLookup(const std::string&
}

uint64_t BinaryProtoLookupService::newRequestId() {
Lock lock(mutex_);
std::lock_guard<std::mutex> lock(mutex_);
return ++requestIdGenerator_;
}

Expand Down
201 changes: 86 additions & 115 deletions lib/Future.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,162 +19,133 @@
#ifndef LIB_FUTURE_H_
#define LIB_FUTURE_H_

#include <condition_variable>
#include <atomic>
#include <chrono>
#include <functional>
#include <future>
#include <list>
#include <memory>
#include <mutex>

using Lock = std::unique_lock<std::mutex>;
#include <thread>
#include <utility>

namespace pulsar {

template <typename Result, typename Type>
struct InternalState {
std::mutex mutex;
std::condition_variable condition;
Result result;
Type value;
bool complete;

std::list<typename std::function<void(Result, const Type&)> > listeners;
};

template <typename Result, typename Type>
class Future {
class InternalState {
public:
typedef std::function<void(Result, const Type&)> ListenerCallback;

Future& addListener(ListenerCallback callback) {
InternalState<Result, Type>* state = state_.get();
Lock lock(state->mutex);

if (state->complete) {
lock.unlock();
callback(state->result, state->value);
} else {
state->listeners.push_back(callback);
}
using Listener = std::function<void(Result, const Type &)>;
using Pair = std::pair<Result, Type>;
using Lock = std::unique_lock<std::mutex>;

return *this;
}
// NOTE: Add the constructor explicitly just to be compatible with GCC 4.8
InternalState() {}

Result get(Type& result) {
InternalState<Result, Type>* state = state_.get();
Lock lock(state->mutex);
void addListener(Listener listener) {
Lock lock{mutex_};
listeners_.emplace_back(listener);
lock.unlock();

if (!state->complete) {
// Wait for result
while (!state->complete) {
state->condition.wait(lock);
}
if (completed()) {
Type value;
Result result = get(value);
triggerListeners(result, value);
}

result = state->value;
return state->result;
}

template <typename Duration>
bool get(Result& res, Type& value, Duration d) {
InternalState<Result, Type>* state = state_.get();
Lock lock(state->mutex);

if (!state->complete) {
// Wait for result
while (!state->complete) {
if (!state->condition.wait_for(lock, d, [&state] { return state->complete; })) {
// Timeout while waiting for the future to complete
return false;
}
}
bool complete(Result result, const Type &value) {
bool expected = false;
if (!completed_.compare_exchange_strong(expected, true)) {
return false;
}

value = state->value;
res = state->result;
triggerListeners(result, value);
promise_.set_value(std::make_pair(result, value));
return true;
}

private:
typedef std::shared_ptr<InternalState<Result, Type> > InternalStatePtr;
Future(InternalStatePtr state) : state_(state) {}
bool completed() const noexcept { return completed_; }

std::shared_ptr<InternalState<Result, Type> > state_;

template <typename U, typename V>
friend class Promise;
};
Result get(Type &result) {
const auto &pair = future_.get();
result = pair.second;
return pair.first;
}

template <typename Result, typename Type>
class Promise {
public:
Promise() : state_(std::make_shared<InternalState<Result, Type> >()) {}
// Only public for test
void triggerListeners(Result result, const Type &value) {
while (true) {
Lock lock{mutex_};
if (listeners_.empty()) {
return;
}

bool setValue(const Type& value) const {
static Result DEFAULT_RESULT;
InternalState<Result, Type>* state = state_.get();
Lock lock(state->mutex);
bool expected = false;
if (!listenerRunning_.compare_exchange_strong(expected, true)) {
// There is another thread that polled a listener that is running, skip polling and release
// the lock. Here we wait for some time to avoid busy waiting.
std::this_thread::sleep_for(std::chrono::milliseconds(1));
continue;
}
auto listener = std::move(listeners_.front());
listeners_.pop_front();
lock.unlock();

if (state->complete) {
return false;
listener(result, value);
listenerRunning_ = false;
}
}

state->value = value;
state->result = DEFAULT_RESULT;
state->complete = true;
private:
std::atomic_bool completed_{false};
std::promise<Pair> promise_;
std::shared_future<Pair> future_{promise_.get_future()};

decltype(state->listeners) listeners;
listeners.swap(state->listeners);
std::list<Listener> listeners_;
mutable std::mutex mutex_;
std::atomic_bool listenerRunning_{false};
};

lock.unlock();
template <typename Result, typename Type>
using InternalStatePtr = std::shared_ptr<InternalState<Result, Type>>;

for (auto& callback : listeners) {
callback(DEFAULT_RESULT, value);
}
template <typename Result, typename Type>
class Future {
public:
using Listener = typename InternalState<Result, Type>::Listener;

state->condition.notify_all();
return true;
Future &addListener(Listener listener) {
state_->addListener(listener);
return *this;
}

bool setFailed(Result result) const {
static Type DEFAULT_VALUE;
InternalState<Result, Type>* state = state_.get();
Lock lock(state->mutex);
Result get(Type &result) { return state_->get(result); }

if (state->complete) {
return false;
}
private:
InternalStatePtr<Result, Type> state_;

state->result = result;
state->complete = true;
Future(InternalStatePtr<Result, Type> state) : state_(state) {}

decltype(state->listeners) listeners;
listeners.swap(state->listeners);
template <typename U, typename V>
friend class Promise;
};

lock.unlock();
template <typename Result, typename Type>
class Promise {
public:
Promise() : state_(std::make_shared<InternalState<Result, Type>>()) {}

for (auto& callback : listeners) {
callback(result, DEFAULT_VALUE);
}
bool setValue(const Type &value) const { return state_->complete({}, value); }

state->condition.notify_all();
return true;
}
bool setFailed(Result result) const { return state_->complete(result, {}); }

bool isComplete() const {
InternalState<Result, Type>* state = state_.get();
Lock lock(state->mutex);
return state->complete;
}
bool isComplete() const { return state_->completed(); }

Future<Result, Type> getFuture() const { return Future<Result, Type>(state_); }
Future<Result, Type> getFuture() const { return Future<Result, Type>{state_}; }

private:
typedef std::function<void(Result, const Type&)> ListenerCallback;
std::shared_ptr<InternalState<Result, Type> > state_;
const InternalStatePtr<Result, Type> state_;
};

class Void {};

} /* namespace pulsar */
} // namespace pulsar

#endif /* LIB_FUTURE_H_ */
#endif
6 changes: 3 additions & 3 deletions lib/stats/ProducerStatsImpl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ void ProducerStatsImpl::flushAndReset(const boost::system::error_code& ec) {
return;
}

Lock lock(mutex_);
std::unique_lock<std::mutex> lock(mutex_);
std::ostringstream oss;
oss << *this;
numMsgsSent_ = 0;
Expand All @@ -86,7 +86,7 @@ void ProducerStatsImpl::flushAndReset(const boost::system::error_code& ec) {
}

void ProducerStatsImpl::messageSent(const Message& msg) {
Lock lock(mutex_);
std::lock_guard<std::mutex> lock(mutex_);
numMsgsSent_++;
totalMsgsSent_++;
numBytesSent_ += msg.getLength();
Expand All @@ -96,7 +96,7 @@ void ProducerStatsImpl::messageSent(const Message& msg) {
void ProducerStatsImpl::messageReceived(Result res, const boost::posix_time::ptime& publishTime) {
boost::posix_time::ptime currentTime = boost::posix_time::microsec_clock::universal_time();
double diffInMicros = (currentTime - publishTime).total_microseconds();
Lock lock(mutex_);
std::lock_guard<std::mutex> lock(mutex_);
totalLatencyAccumulator_(diffInMicros);
latencyAccumulator_(diffInMicros);
sendMap_[res] += 1; // Value will automatically be initialized to 0 in the constructor
Expand Down
4 changes: 2 additions & 2 deletions tests/BasicEndToEndTest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ TEST(BasicEndToEndTest, testBatchMessages) {
}

void resendMessage(Result r, const MessageId msgId, Producer producer) {
Lock lock(mutex_);
std::unique_lock<std::mutex> lock(mutex_);
if (r != ResultOk) {
LOG_DEBUG("globalResendMessageCount" << globalResendMessageCount);
if (++globalResendMessageCount >= 3) {
Expand Down Expand Up @@ -993,7 +993,7 @@ TEST(BasicEndToEndTest, testResendViaSendCallback) {
// 3 seconds
std::this_thread::sleep_for(std::chrono::microseconds(3 * 1000 * 1000));
producer.close();
Lock lock(mutex_);
std::lock_guard<std::mutex> lock(mutex_);
ASSERT_GE(globalResendMessageCount, 3);
}

Expand Down
27 changes: 27 additions & 0 deletions tests/PromiseTest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
#include <vector>

#include "lib/Future.h"
#include "lib/LogUtils.h"

DECLARE_LOG_OBJECT()

using namespace pulsar;

Expand Down Expand Up @@ -84,3 +87,27 @@ TEST(PromiseTest, testListeners) {
ASSERT_EQ(results, (std::vector<int>(2, 0)));
ASSERT_EQ(values, (std::vector<std::string>(2, "hello")));
}

// Verifies that InternalState::triggerListeners runs queued listeners one at a time, even when it is
// invoked concurrently from two threads.
TEST(PromiseTest, testTriggerListeners) {
    InternalState<int, int> state;
    // Queue two listeners that each block for one second. The state has not been completed, so
    // addListener only stores them; they run when triggerListeners is called below.
    state.addListener([](int, const int&) {
        LOG_INFO("Start task 1...");
        std::this_thread::sleep_for(std::chrono::seconds(1));
        LOG_INFO("Finish task 1...");
    });
    state.addListener([](int, const int&) {
        LOG_INFO("Start task 2...");
        std::this_thread::sleep_for(std::chrono::seconds(1));
        LOG_INFO("Finish task 2...");
    });

    // steady_clock is monotonic, so the measured interval cannot be skewed by wall-clock adjustments.
    const auto start = std::chrono::steady_clock::now();
    auto future1 = std::async(std::launch::async, [&state] { state.triggerListeners(0, 0); });
    auto future2 = std::async(std::launch::async, [&state] { state.triggerListeners(0, 0); });
    future1.wait();
    future2.wait();
    const auto elapsed =
        std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now() - start)
            .count();
    // If the listeners ran one after another, the two one-second sleeps take at least 2000 ms in
    // total; if they overlapped, the elapsed time would be roughly 1000 ms. The value is truncated to
    // whole milliseconds, so compare with >= rather than > to avoid a spurious failure when the
    // scheduling overhead is below one millisecond.
    ASSERT_GE(elapsed, 2000) << "elapsed: " << elapsed << "ms";
}

0 comments on commit 20f6fa0

Please sign in to comment.