Skip to content

Commit

Permalink
Add changes to channel that are needed for select op (#9084)
Browse files Browse the repository at this point in the history
  • Loading branch information
abhinavarora authored Mar 14, 2018
1 parent a4b0e4a commit 41894da
Show file tree
Hide file tree
Showing 2 changed files with 255 additions and 7 deletions.
109 changes: 109 additions & 0 deletions paddle/fluid/framework/channel.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,43 @@ limitations under the License. */
#pragma once

#include <stddef.h> // for size_t
#include <condition_variable>
#include <typeindex>
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace framework {

enum class ChannelAction {
SEND = 0,
RECEIVE = 1,
CLOSE = 2,
};

// Channel is the abstract class of buffered and un-buffered channels.
template <typename T>
class Channel {
public:
virtual bool CanSend() = 0;
virtual bool CanReceive() = 0;
virtual bool Send(T*) = 0;
virtual bool Receive(T*) = 0;
virtual size_t Cap() = 0;
virtual void Lock() = 0;

virtual void Unlock() = 0;
virtual bool IsClosed() = 0;
virtual void Close() = 0;
virtual ~Channel() {}

virtual void AddToSendQ(const void* referrer, T* data,
std::shared_ptr<std::condition_variable_any> cond,
std::function<bool(ChannelAction)> cb) = 0;
virtual void AddToReceiveQ(const void* referrer, T* data,
std::shared_ptr<std::condition_variable_any> cond,
std::function<bool(ChannelAction)> cb) = 0;
virtual void RemoveFromSendQ(const void* referrer) = 0;
virtual void RemoveFromReceiveQ(const void* referrer) = 0;
};

// Forward declaration of channel implementations.
Expand Down Expand Up @@ -80,6 +100,27 @@ class ChannelHolder {
return channel != nullptr ? channel->Receive(data) : false;
}

bool IsClosed() {
if (IsInitialized()) {
return holder_->IsClosed();
}
return false;
}

bool CanSend() {
if (IsInitialized()) {
return holder_->CanSend();
}
return false;
}

bool CanReceive() {
if (IsInitialized()) {
return holder_->CanReceive();
}
return false;
}

void close() {
if (IsInitialized()) holder_->Close();
}
Expand All @@ -97,6 +138,50 @@ class ChannelHolder {
if (IsInitialized()) holder_->Unlock();
}

template <typename T>
void AddToSendQ(const void* referrer, T* data,
std::shared_ptr<std::condition_variable_any> cond,
std::function<bool(ChannelAction)> cb) {
if (IsInitialized()) {
Channel<T>* channel = static_cast<Channel<T>*>(holder_->Ptr());
if (channel != nullptr) {
channel->AddToSendQ(referrer, data, cond, cb);
}
}
}

template <typename T>
void AddToReceiveQ(const void* referrer, T* data,
std::shared_ptr<std::condition_variable_any> cond,
std::function<bool(ChannelAction)> cb) {
if (IsInitialized()) {
Channel<T>* channel = static_cast<Channel<T>*>(holder_->Ptr());
if (channel != nullptr) {
channel->AddToReceiveQ(referrer, data, cond, cb);
}
}
}

template <typename T>
void RemoveFromSendQ(const void* referrer) {
if (IsInitialized()) {
Channel<T>* channel = static_cast<Channel<T>*>(holder_->Ptr());
if (channel != nullptr) {
channel->RemoveFromSendQ(referrer);
}
}
}

template <typename T>
void RemoveFromReceiveQ(const void* referrer) {
if (IsInitialized()) {
Channel<T>* channel = static_cast<Channel<T>*>(holder_->Ptr());
if (channel != nullptr) {
channel->RemoveFromReceiveQ(referrer);
}
}
}

inline bool IsInitialized() const { return holder_ != nullptr; }

inline const std::type_index Type() {
Expand All @@ -113,6 +198,9 @@ class ChannelHolder {
virtual ~Placeholder() {}
virtual const std::type_index Type() const = 0;
virtual void* Ptr() const = 0;
virtual bool IsClosed() = 0;
virtual bool CanSend() = 0;
virtual bool CanReceive() = 0;
virtual void Close() = 0;
virtual void Lock() = 0;
virtual void Unlock() = 0;
Expand All @@ -129,6 +217,27 @@ class ChannelHolder {

virtual void* Ptr() const { return static_cast<void*>(channel_.get()); }

virtual bool IsClosed() {
if (channel_) {
return channel_->IsClosed();
}
return false;
}

virtual bool CanSend() {
if (channel_) {
return channel_->CanSend();
}
return false;
}

virtual bool CanReceive() {
if (channel_) {
return channel_->CanReceive();
}
return false;
}

virtual void Close() {
if (channel_) channel_->Close();
}
Expand Down
153 changes: 146 additions & 7 deletions paddle/fluid/framework/channel_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,32 +29,50 @@ class ChannelImpl : public paddle::framework::Channel<T> {
friend void paddle::framework::CloseChannel<T>(Channel<T> *);

public:
virtual bool CanSend();
virtual bool CanReceive();
virtual bool Send(T *);
virtual bool Receive(T *);
virtual size_t Cap() { return cap_; }
virtual void Lock();
virtual void Unlock();
virtual bool IsClosed();
virtual void Close();

ChannelImpl(size_t);
virtual ~ChannelImpl();

virtual void AddToSendQ(const void *referrer, T *data,
std::shared_ptr<std::condition_variable_any> cond,
std::function<bool(ChannelAction)> cb);
virtual void AddToReceiveQ(const void *referrer, T *data,
std::shared_ptr<std::condition_variable_any> cond,
std::function<bool(ChannelAction)> cb);

virtual void RemoveFromSendQ(const void *referrer);
virtual void RemoveFromReceiveQ(const void *referrer);

private:
struct QueueMessage {
T *data;
std::condition_variable_any cond;
std::shared_ptr<std::condition_variable_any> cond;
bool chan_closed = false;
bool completed = false;
const void *referrer; // TODO(thuan): figure out better way to do this
std::function<bool(ChannelAction)> callback;

QueueMessage(T *item) : data(item) {}
QueueMessage(T *item)
: data(item), cond(std::make_shared<std::condition_variable_any>()) {}

QueueMessage(T *item, std::shared_ptr<std::condition_variable_any> cond)
: data(item), cond(cond) {}

void Wait(std::unique_lock<std::recursive_mutex> &lock) {
cond.wait(lock, [this]() { return completed; });
cond->wait(lock, [this]() { return completed; });
}

void Notify() {
completed = true;
cond.notify_all();
cond->notify_all();
}
};

Expand Down Expand Up @@ -87,6 +105,18 @@ ChannelImpl<T>::ChannelImpl(size_t capacity)
PADDLE_ENFORCE_GE(capacity, 0);
}

template <typename T>
bool ChannelImpl<T>::CanSend() {
std::lock_guard<std::recursive_mutex> lock{mu_};
return !closed_ && (!recvq.empty() || buf_.size() < cap_);
}

template <typename T>
bool ChannelImpl<T>::CanReceive() {
std::lock_guard<std::recursive_mutex> lock{mu_};
return !(closed_ && buf_.empty()) && (!sendq.empty() || buf_.size() > 0);
}

template <typename T>
bool ChannelImpl<T>::Send(T *item) {
send_ctr++;
Expand All @@ -105,7 +135,24 @@ bool ChannelImpl<T>::Send(T *item) {
std::shared_ptr<QueueMessage> m = recvq.front();
recvq.pop_front();
// Do the data transfer
*(m->data) = std::move(*item);
// We will do this data transfer if either of the following
// cases are true
// 1. callback == nullptr // This means it was a regular channel send
// 2. callback returns true
bool do_send = true;
if (m->callback != nullptr) do_send = m->callback(ChannelAction::SEND);
if (do_send)
*(m->data) = std::move(*item);
else
// We cannot do the data transfer because
// this QueueMessage was added by Select
// and some other case was executed.
// So call the Send function again.
// We do not care about notifying other
// because they would have been notified
// by the executed select case.
return Send(item);

// Wake up the blocked process and unlock
m->Notify();
lock.unlock();
Expand Down Expand Up @@ -150,7 +197,25 @@ bool ChannelImpl<T>::Receive(T *item) {
std::shared_ptr<QueueMessage> m = sendq.front();
sendq.pop_front();
// Do the data transfer
*item = std::move(*(m->data));
// We will do this data transfer if either of the following
// cases are true
// 1. callback == nullptr // This means it was a regular channel send
// 2. callback returns true
bool do_receive = true;
if (m->callback != nullptr)
do_receive = m->callback(ChannelAction::RECEIVE);
if (do_receive)
*item = std::move(*(m->data));
else
// We cannot do the data transfer because
// this QueueMessage was added by Select
// and some other case was executed.
// So call the Receive function again.
// We do not care about notifying other
// because they would have been notified
// by the executed select case.
return Receive(item);

// Wake up the blocked process and unlock
m->Notify();
lock.unlock();
Expand Down Expand Up @@ -186,6 +251,12 @@ void ChannelImpl<T>::Unlock() {
mu_.unlock();
}

template <typename T>
bool ChannelImpl<T>::IsClosed() {
std::lock_guard<std::recursive_mutex> lock{mu_};
return closed_;
}

template <typename T>
void ChannelImpl<T>::Close() {
std::unique_lock<std::recursive_mutex> lock{mu_};
Expand All @@ -203,6 +274,12 @@ void ChannelImpl<T>::Close() {
std::shared_ptr<QueueMessage> m = recvq.front();
recvq.pop_front();
m->chan_closed = true;

// Execute callback function (if any)
if (m->callback != nullptr) {
m->callback(ChannelAction::CLOSE);
}

m->Notify();
}

Expand All @@ -211,10 +288,72 @@ void ChannelImpl<T>::Close() {
std::shared_ptr<QueueMessage> m = sendq.front();
sendq.pop_front();
m->chan_closed = true;

// Execute callback function (if any)
if (m->callback != nullptr) {
m->callback(ChannelAction::CLOSE);
}

m->Notify();
}
}

template <typename T>
void ChannelImpl<T>::AddToSendQ(
const void *referrer, T *data,
std::shared_ptr<std::condition_variable_any> cond,
std::function<bool(ChannelAction)> cb) {
std::lock_guard<std::recursive_mutex> lock{mu_};
auto m = std::make_shared<QueueMessage>(data, cond);
m->referrer = referrer;
m->callback = cb;
sendq.push_back(m);
}

template <typename T>
void ChannelImpl<T>::AddToReceiveQ(
const void *referrer, T *data,
std::shared_ptr<std::condition_variable_any> cond,
std::function<bool(ChannelAction)> cb) {
std::lock_guard<std::recursive_mutex> lock{mu_};
auto m = std::make_shared<QueueMessage>(data, cond);
m->referrer = referrer;
m->callback = cb;
recvq.push_back(m);
}

template <typename T>
void ChannelImpl<T>::RemoveFromSendQ(const void *referrer) {
std::lock_guard<std::recursive_mutex> lock{mu_};

for (auto it = sendq.begin(); it != sendq.end();) {
std::shared_ptr<QueueMessage> sendMsg = (std::shared_ptr<QueueMessage>)*it;

if (sendMsg->referrer == referrer) {
it = sendq.erase(it);
send_ctr--;
} else {
++it;
}
}
}

template <typename T>
void ChannelImpl<T>::RemoveFromReceiveQ(const void *referrer) {
std::lock_guard<std::recursive_mutex> lock{mu_};

for (auto it = recvq.begin(); it != recvq.end();) {
std::shared_ptr<QueueMessage> recvMsg = (std::shared_ptr<QueueMessage>)*it;

if (recvMsg->referrer == referrer) {
it = recvq.erase(it);
recv_ctr--;
} else {
++it;
}
}
}

template <typename T>
ChannelImpl<T>::~ChannelImpl() {
Close();
Expand Down

0 comments on commit 41894da

Please sign in to comment.