Skip to content

Commit

Permalink
chore: pass reply_builder explicitly to pubsub module
Browse files Browse the repository at this point in the history
Also, deprecate `reply_builder()` access method.

Signed-off-by: Roman Gershman <[email protected]>
  • Loading branch information
romange committed Oct 30, 2024
1 parent daf8604 commit a753c23
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 91 deletions.
2 changes: 1 addition & 1 deletion src/facade/conn_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class ConnectionContext {
return protocol_;
}

SinkReplyBuilder* reply_builder() {
SinkReplyBuilder* reply_builder_old() {
return rbuilder_.get();
}

Expand Down
4 changes: 2 additions & 2 deletions src/facade/dragonfly_connection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -725,7 +725,7 @@ void Connection::HandleRequests() {
// down and return with an error accordingly.
if (http_res && socket_->IsOpen()) {
cc_.reset(service_->CreateContext(socket_.get(), this));
reply_builder_ = cc_->reply_builder();
reply_builder_ = cc_->reply_builder_old();

if (*http_res) {
VLOG(1) << "HTTP1.1 identified";
Expand Down Expand Up @@ -811,7 +811,7 @@ std::pair<std::string, std::string> Connection::GetClientInfoBeforeAfterTid() co
string_view phase_name = PHASE_NAMES[phase_];

if (cc_) {
DCHECK(cc_->reply_builder() && reply_builder_);
DCHECK(reply_builder_);
string cc_info = service_->GetContextInfo(cc_.get()).Format();
if (reply_builder_->IsSendActive())
phase_name = "send";
Expand Down
146 changes: 72 additions & 74 deletions src/server/conn_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,17 @@ namespace dfly {
using namespace std;
using namespace facade;

static void SendSubscriptionChangedResponse(string_view action, std::optional<string_view> topic,
unsigned count, RedisReplyBuilder* rb) {
rb->StartCollection(3, RedisReplyBuilder::CollectionType::PUSH);
rb->SendBulkString(action);
if (topic.has_value())
rb->SendBulkString(topic.value());
else
rb->SendNull();
rb->SendLong(count);
}

StoredCmd::StoredCmd(const CommandId* cid, ArgSlice args, facade::ReplyMode mode)
: cid_{cid}, buffer_{}, sizes_(args.size()), reply_mode_{mode} {
size_t total_size = 0;
Expand Down Expand Up @@ -137,61 +148,13 @@ void ConnectionContext::ChangeMonitor(bool start) {
EnableMonitoring(start);
}

vector<unsigned> ChangeSubscriptions(bool pattern, CmdArgList args, bool to_add, bool to_reply,
ConnectionContext* conn) {
vector<unsigned> result(to_reply ? args.size() : 0, 0);

auto& conn_state = conn->conn_state;
if (!to_add && !conn_state.subscribe_info)
return result;

if (!conn_state.subscribe_info) {
DCHECK(to_add);

conn_state.subscribe_info.reset(new ConnectionState::SubscribeInfo);
conn->subscriptions++;
}

auto& sinfo = *conn->conn_state.subscribe_info.get();
auto& local_store = pattern ? sinfo.patterns : sinfo.channels;

int32_t tid = util::ProactorBase::me()->GetPoolIndex();
DCHECK_GE(tid, 0);

ChannelStoreUpdater csu{pattern, to_add, conn, uint32_t(tid)};

// Gather all the channels we need to subscribe to / remove.
size_t i = 0;
for (string_view channel : args) {
if (to_add && local_store.emplace(channel).second)
csu.Record(channel);
else if (!to_add && local_store.erase(channel) > 0)
csu.Record(channel);

if (to_reply)
result[i++] = sinfo.SubscriptionCount();
}

csu.Apply();

// Important to reset conn_state.subscribe_info only after all references to it were
// removed.
if (!to_add && conn_state.subscribe_info->IsEmpty()) {
conn_state.subscribe_info.reset();
DCHECK_GE(conn->subscriptions, 1u);
conn->subscriptions--;
}

return result;
}

void ConnectionContext::ChangeSubscription(bool to_add, bool to_reply, CmdArgList args) {
vector<unsigned> result = ChangeSubscriptions(false, args, to_add, to_reply, this);
void ConnectionContext::ChangeSubscription(bool to_add, bool to_reply, CmdArgList args,
facade::RedisReplyBuilder* rb) {
vector<unsigned> result = ChangeSubscriptions(args, false, to_add, to_reply);

if (to_reply) {
for (size_t i = 0; i < result.size(); ++i) {
const char* action[2] = {"unsubscribe", "subscribe"};
auto rb = static_cast<RedisReplyBuilder*>(reply_builder());
rb->StartCollection(3, RedisReplyBuilder::CollectionType::PUSH);
rb->SendBulkString(action[to_add]);
rb->SendBulkString(ArgS(args, i));
Expand All @@ -200,53 +163,41 @@ void ConnectionContext::ChangeSubscription(bool to_add, bool to_reply, CmdArgLis
}
}

void ConnectionContext::ChangePSubscription(bool to_add, bool to_reply, CmdArgList args) {
vector<unsigned> result = ChangeSubscriptions(true, args, to_add, to_reply, this);
void ConnectionContext::ChangePSubscription(bool to_add, bool to_reply, CmdArgList args,
facade::RedisReplyBuilder* rb) {
vector<unsigned> result = ChangeSubscriptions(args, true, to_add, to_reply);

if (to_reply) {
const char* action[2] = {"punsubscribe", "psubscribe"};
if (result.size() == 0) {
return SendSubscriptionChangedResponse(action[to_add], std::nullopt, 0);
return SendSubscriptionChangedResponse(action[to_add], std::nullopt, 0, rb);
}

for (size_t i = 0; i < result.size(); ++i) {
SendSubscriptionChangedResponse(action[to_add], ArgS(args, i), result[i]);
SendSubscriptionChangedResponse(action[to_add], ArgS(args, i), result[i], rb);
}
}
}

void ConnectionContext::UnsubscribeAll(bool to_reply) {
void ConnectionContext::UnsubscribeAll(bool to_reply, facade::RedisReplyBuilder* rb) {
if (to_reply && (!conn_state.subscribe_info || conn_state.subscribe_info->channels.empty())) {
return SendSubscriptionChangedResponse("unsubscribe", std::nullopt, 0);
return SendSubscriptionChangedResponse("unsubscribe", std::nullopt, 0, rb);
}
StringVec channels(conn_state.subscribe_info->channels.begin(),
conn_state.subscribe_info->channels.end());
CmdArgVec arg_vec(channels.begin(), channels.end());
ChangeSubscription(false, to_reply, CmdArgList{arg_vec});
ChangeSubscription(false, to_reply, CmdArgList{arg_vec}, rb);
}

void ConnectionContext::PUnsubscribeAll(bool to_reply) {
void ConnectionContext::PUnsubscribeAll(bool to_reply, facade::RedisReplyBuilder* rb) {
if (to_reply && (!conn_state.subscribe_info || conn_state.subscribe_info->patterns.empty())) {
return SendSubscriptionChangedResponse("punsubscribe", std::nullopt, 0);
return SendSubscriptionChangedResponse("punsubscribe", std::nullopt, 0, rb);
}

StringVec patterns(conn_state.subscribe_info->patterns.begin(),
conn_state.subscribe_info->patterns.end());
CmdArgVec arg_vec(patterns.begin(), patterns.end());
ChangePSubscription(false, to_reply, CmdArgList{arg_vec});
}

void ConnectionContext::SendSubscriptionChangedResponse(string_view action,
std::optional<string_view> topic,
unsigned count) {
auto rb = static_cast<RedisReplyBuilder*>(reply_builder());
rb->StartCollection(3, RedisReplyBuilder::CollectionType::PUSH);
rb->SendBulkString(action);
if (topic.has_value())
rb->SendBulkString(topic.value());
else
rb->SendNull();
rb->SendLong(count);
ChangePSubscription(false, to_reply, CmdArgList{arg_vec}, rb);
}

size_t ConnectionState::ExecInfo::UsedMemory() const {
Expand All @@ -269,6 +220,53 @@ size_t ConnectionContext::UsedMemory() const {
return facade::ConnectionContext::UsedMemory() + dfly::HeapSize(conn_state);
}

vector<unsigned> ConnectionContext::ChangeSubscriptions(CmdArgList channels, bool pattern,
bool to_add, bool to_reply) {
vector<unsigned> result(to_reply ? channels.size() : 0, 0);

if (!to_add && !conn_state.subscribe_info)
return result;

if (!conn_state.subscribe_info) {
DCHECK(to_add);

conn_state.subscribe_info.reset(new ConnectionState::SubscribeInfo);
subscriptions++;
}

auto& sinfo = *conn_state.subscribe_info.get();
auto& local_store = pattern ? sinfo.patterns : sinfo.channels;

int32_t tid = util::ProactorBase::me()->GetPoolIndex();
DCHECK_GE(tid, 0);

ChannelStoreUpdater csu{pattern, to_add, this, uint32_t(tid)};

// Gather all the channels we need to subscribe to / remove.
size_t i = 0;
for (string_view channel : channels) {
if (to_add && local_store.emplace(channel).second)
csu.Record(channel);
else if (!to_add && local_store.erase(channel) > 0)
csu.Record(channel);

if (to_reply)
result[i++] = sinfo.SubscriptionCount();
}

csu.Apply();

// Important to reset conn_state.subscribe_info only after all references to it were
// removed.
if (!to_add && conn_state.subscribe_info->IsEmpty()) {
conn_state.subscribe_info.reset();
DCHECK_GE(subscriptions, 1u);
subscriptions--;
}

return result;
}

void ConnectionState::ExecInfo::Clear() {
DCHECK(!preborrowed_interpreter); // Must have been released properly
state = EXEC_INACTIVE;
Expand Down
15 changes: 9 additions & 6 deletions src/server/conn_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -292,10 +292,13 @@ class ConnectionContext : public facade::ConnectionContext {
return conn_state.db_index;
}

void ChangeSubscription(bool to_add, bool to_reply, CmdArgList args);
void ChangePSubscription(bool to_add, bool to_reply, CmdArgList args);
void UnsubscribeAll(bool to_reply);
void PUnsubscribeAll(bool to_reply);
void ChangeSubscription(bool to_add, bool to_reply, CmdArgList args,
facade::RedisReplyBuilder* rb);

void ChangePSubscription(bool to_add, bool to_reply, CmdArgList args,
facade::RedisReplyBuilder* rb);
void UnsubscribeAll(bool to_reply, facade::RedisReplyBuilder* rb);
void PUnsubscribeAll(bool to_reply, facade::RedisReplyBuilder* rb);
void ChangeMonitor(bool start); // either start or stop monitor on a given connection

size_t UsedMemory() const override;
Expand All @@ -317,8 +320,8 @@ class ConnectionContext : public facade::ConnectionContext {
monitor = enable;
}

void SendSubscriptionChangedResponse(std::string_view action,
std::optional<std::string_view> topic, unsigned count);
std::vector<unsigned> ChangeSubscriptions(CmdArgList channels, bool pattern, bool to_add,
bool to_reply);
};

} // namespace dfly
17 changes: 9 additions & 8 deletions src/server/main_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2364,7 +2364,8 @@ void Service::Subscribe(CmdArgList args, Transaction* tx, SinkReplyBuilder* buil
if (cluster::IsClusterEnabled()) {
return builder->SendError("SUBSCRIBE is not supported in cluster mode yet");
}
cntx->ChangeSubscription(true /*add*/, true /* reply*/, std::move(args));
cntx->ChangeSubscription(true /*add*/, true /* reply*/, std::move(args),
static_cast<RedisReplyBuilder*>(builder));
}

void Service::Unsubscribe(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder,
Expand All @@ -2373,9 +2374,9 @@ void Service::Unsubscribe(CmdArgList args, Transaction* tx, SinkReplyBuilder* bu
return builder->SendError("UNSUBSCRIBE is not supported in cluster mode yet");
}
if (args.size() == 0) {
cntx->UnsubscribeAll(true);
cntx->UnsubscribeAll(true, static_cast<RedisReplyBuilder*>(builder));
} else {
cntx->ChangeSubscription(false, true, args);
cntx->ChangeSubscription(false, true, args, static_cast<RedisReplyBuilder*>(builder));
}
}

Expand All @@ -2384,7 +2385,7 @@ void Service::PSubscribe(CmdArgList args, Transaction* tx, SinkReplyBuilder* bui
if (cluster::IsClusterEnabled()) {
return builder->SendError("PSUBSCRIBE is not supported in cluster mode yet");
}
cntx->ChangePSubscription(true, true, args);
cntx->ChangePSubscription(true, true, args, static_cast<RedisReplyBuilder*>(builder));
}

void Service::PUnsubscribe(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder,
Expand All @@ -2393,9 +2394,9 @@ void Service::PUnsubscribe(CmdArgList args, Transaction* tx, SinkReplyBuilder* b
return builder->SendError("PUNSUBSCRIBE is not supported in cluster mode yet");
}
if (args.size() == 0) {
cntx->PUnsubscribeAll(true);
cntx->PUnsubscribeAll(true, static_cast<RedisReplyBuilder*>(builder));
} else {
cntx->ChangePSubscription(false, true, args);
cntx->ChangePSubscription(false, true, args, static_cast<RedisReplyBuilder*>(builder));
}
}

Expand Down Expand Up @@ -2653,12 +2654,12 @@ void Service::OnConnectionClose(facade::ConnectionContext* cntx) {

if (conn_state.subscribe_info) { // Clean-ups related to PUBSUB
if (!conn_state.subscribe_info->channels.empty()) {
server_cntx->UnsubscribeAll(false);
server_cntx->UnsubscribeAll(false, nullptr);
}

if (conn_state.subscribe_info) {
DCHECK(!conn_state.subscribe_info->patterns.empty());
server_cntx->PUnsubscribeAll(false);
server_cntx->PUnsubscribeAll(false, nullptr);
}

DCHECK(!conn_state.subscribe_info);
Expand Down

0 comments on commit a753c23

Please sign in to comment.