Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: pass reply_builder explicitly to pubsub module #4021

Merged
merged 1 commit into from
Nov 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 1 addition & 8 deletions src/facade/conn_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,10 @@ class ConnectionContext {
return protocol_;
}

SinkReplyBuilder* reply_builder() {
SinkReplyBuilder* reply_builder_old() {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do you called this _old?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so that we won't use it. I will remove it in follow up PRs

return rbuilder_.get();
}

// Allows receiving the output data from the commands called from scripts.
SinkReplyBuilder* Inject(SinkReplyBuilder* new_i) {
SinkReplyBuilder* res = rbuilder_.release();
rbuilder_.reset(new_i);
return res;
}

virtual size_t UsedMemory() const;

// connection state / properties.
Expand Down
4 changes: 2 additions & 2 deletions src/facade/dragonfly_connection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -725,7 +725,7 @@ void Connection::HandleRequests() {
// down and return with an error accordingly.
if (http_res && socket_->IsOpen()) {
cc_.reset(service_->CreateContext(socket_.get(), this));
reply_builder_ = cc_->reply_builder();
reply_builder_ = cc_->reply_builder_old();

if (*http_res) {
VLOG(1) << "HTTP1.1 identified";
Expand Down Expand Up @@ -811,7 +811,7 @@ std::pair<std::string, std::string> Connection::GetClientInfoBeforeAfterTid() co
string_view phase_name = PHASE_NAMES[phase_];

if (cc_) {
DCHECK(cc_->reply_builder() && reply_builder_);
DCHECK(reply_builder_);
string cc_info = service_->GetContextInfo(cc_.get()).Format();
if (reply_builder_->IsSendActive())
phase_name = "send";
Expand Down
151 changes: 73 additions & 78 deletions src/server/conn_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,17 @@ namespace dfly {
using namespace std;
using namespace facade;

static void SendSubscriptionChangedResponse(string_view action, std::optional<string_view> topic,
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

moved here from below

unsigned count, RedisReplyBuilder* rb) {
rb->StartCollection(3, RedisReplyBuilder::CollectionType::PUSH);
rb->SendBulkString(action);
if (topic.has_value())
rb->SendBulkString(topic.value());
else
rb->SendNull();
rb->SendLong(count);
}

StoredCmd::StoredCmd(const CommandId* cid, ArgSlice args, facade::ReplyMode mode)
: cid_{cid}, buffer_{}, sizes_(args.size()), reply_mode_{mode} {
size_t total_size = 0;
Expand Down Expand Up @@ -98,8 +109,7 @@ ConnectionContext::ConnectionContext(::io::Sink* stream, facade::Connection* own
}
}

ConnectionContext::ConnectionContext(const ConnectionContext* owner, Transaction* tx,
facade::CapturingReplyBuilder* crb)
ConnectionContext::ConnectionContext(const ConnectionContext* owner, Transaction* tx)
: facade::ConnectionContext(nullptr, nullptr), transaction{tx} {
if (owner) {
acl_commands = owner->acl_commands;
Expand All @@ -115,8 +125,6 @@ ConnectionContext::ConnectionContext(const ConnectionContext* owner, Transaction
conn_state.db_index = owner->conn_state.db_index;
conn_state.squashing_info = {owner};
}
auto* prev_reply_builder = Inject(crb);
CHECK_EQ(prev_reply_builder, nullptr);
}

void ConnectionContext::ChangeMonitor(bool start) {
Expand All @@ -137,61 +145,13 @@ void ConnectionContext::ChangeMonitor(bool start) {
EnableMonitoring(start);
}

vector<unsigned> ChangeSubscriptions(bool pattern, CmdArgList args, bool to_add, bool to_reply,
ConnectionContext* conn) {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

moved below to be a method of ConnectionContext

vector<unsigned> result(to_reply ? args.size() : 0, 0);

auto& conn_state = conn->conn_state;
if (!to_add && !conn_state.subscribe_info)
return result;

if (!conn_state.subscribe_info) {
DCHECK(to_add);

conn_state.subscribe_info.reset(new ConnectionState::SubscribeInfo);
conn->subscriptions++;
}

auto& sinfo = *conn->conn_state.subscribe_info.get();
auto& local_store = pattern ? sinfo.patterns : sinfo.channels;

int32_t tid = util::ProactorBase::me()->GetPoolIndex();
DCHECK_GE(tid, 0);

ChannelStoreUpdater csu{pattern, to_add, conn, uint32_t(tid)};

// Gather all the channels we need to subscribe to / remove.
size_t i = 0;
for (string_view channel : args) {
if (to_add && local_store.emplace(channel).second)
csu.Record(channel);
else if (!to_add && local_store.erase(channel) > 0)
csu.Record(channel);

if (to_reply)
result[i++] = sinfo.SubscriptionCount();
}

csu.Apply();

// Important to reset conn_state.subscribe_info only after all references to it were
// removed.
if (!to_add && conn_state.subscribe_info->IsEmpty()) {
conn_state.subscribe_info.reset();
DCHECK_GE(conn->subscriptions, 1u);
conn->subscriptions--;
}

return result;
}

void ConnectionContext::ChangeSubscription(bool to_add, bool to_reply, CmdArgList args) {
vector<unsigned> result = ChangeSubscriptions(false, args, to_add, to_reply, this);
void ConnectionContext::ChangeSubscription(bool to_add, bool to_reply, CmdArgList args,
facade::RedisReplyBuilder* rb) {
vector<unsigned> result = ChangeSubscriptions(args, false, to_add, to_reply);

if (to_reply) {
for (size_t i = 0; i < result.size(); ++i) {
const char* action[2] = {"unsubscribe", "subscribe"};
auto rb = static_cast<RedisReplyBuilder*>(reply_builder());
rb->StartCollection(3, RedisReplyBuilder::CollectionType::PUSH);
rb->SendBulkString(action[to_add]);
rb->SendBulkString(ArgS(args, i));
Expand All @@ -200,53 +160,41 @@ void ConnectionContext::ChangeSubscription(bool to_add, bool to_reply, CmdArgLis
}
}

void ConnectionContext::ChangePSubscription(bool to_add, bool to_reply, CmdArgList args) {
vector<unsigned> result = ChangeSubscriptions(true, args, to_add, to_reply, this);
void ConnectionContext::ChangePSubscription(bool to_add, bool to_reply, CmdArgList args,
facade::RedisReplyBuilder* rb) {
vector<unsigned> result = ChangeSubscriptions(args, true, to_add, to_reply);

if (to_reply) {
const char* action[2] = {"punsubscribe", "psubscribe"};
if (result.size() == 0) {
return SendSubscriptionChangedResponse(action[to_add], std::nullopt, 0);
return SendSubscriptionChangedResponse(action[to_add], std::nullopt, 0, rb);
}

for (size_t i = 0; i < result.size(); ++i) {
SendSubscriptionChangedResponse(action[to_add], ArgS(args, i), result[i]);
SendSubscriptionChangedResponse(action[to_add], ArgS(args, i), result[i], rb);
}
}
}

void ConnectionContext::UnsubscribeAll(bool to_reply) {
void ConnectionContext::UnsubscribeAll(bool to_reply, facade::RedisReplyBuilder* rb) {
if (to_reply && (!conn_state.subscribe_info || conn_state.subscribe_info->channels.empty())) {
return SendSubscriptionChangedResponse("unsubscribe", std::nullopt, 0);
return SendSubscriptionChangedResponse("unsubscribe", std::nullopt, 0, rb);
}
StringVec channels(conn_state.subscribe_info->channels.begin(),
conn_state.subscribe_info->channels.end());
CmdArgVec arg_vec(channels.begin(), channels.end());
ChangeSubscription(false, to_reply, CmdArgList{arg_vec});
ChangeSubscription(false, to_reply, CmdArgList{arg_vec}, rb);
}

void ConnectionContext::PUnsubscribeAll(bool to_reply) {
void ConnectionContext::PUnsubscribeAll(bool to_reply, facade::RedisReplyBuilder* rb) {
if (to_reply && (!conn_state.subscribe_info || conn_state.subscribe_info->patterns.empty())) {
return SendSubscriptionChangedResponse("punsubscribe", std::nullopt, 0);
return SendSubscriptionChangedResponse("punsubscribe", std::nullopt, 0, rb);
}

StringVec patterns(conn_state.subscribe_info->patterns.begin(),
conn_state.subscribe_info->patterns.end());
CmdArgVec arg_vec(patterns.begin(), patterns.end());
ChangePSubscription(false, to_reply, CmdArgList{arg_vec});
}

void ConnectionContext::SendSubscriptionChangedResponse(string_view action,
std::optional<string_view> topic,
unsigned count) {
auto rb = static_cast<RedisReplyBuilder*>(reply_builder());
rb->StartCollection(3, RedisReplyBuilder::CollectionType::PUSH);
rb->SendBulkString(action);
if (topic.has_value())
rb->SendBulkString(topic.value());
else
rb->SendNull();
rb->SendLong(count);
ChangePSubscription(false, to_reply, CmdArgList{arg_vec}, rb);
}

size_t ConnectionState::ExecInfo::UsedMemory() const {
Expand All @@ -269,6 +217,53 @@ size_t ConnectionContext::UsedMemory() const {
return facade::ConnectionContext::UsedMemory() + dfly::HeapSize(conn_state);
}

vector<unsigned> ConnectionContext::ChangeSubscriptions(CmdArgList channels, bool pattern,
bool to_add, bool to_reply) {
vector<unsigned> result(to_reply ? channels.size() : 0, 0);
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it was a free function, now it's a method.


if (!to_add && !conn_state.subscribe_info)
return result;

if (!conn_state.subscribe_info) {
DCHECK(to_add);

conn_state.subscribe_info.reset(new ConnectionState::SubscribeInfo);
subscriptions++;
}

auto& sinfo = *conn_state.subscribe_info.get();
auto& local_store = pattern ? sinfo.patterns : sinfo.channels;

int32_t tid = util::ProactorBase::me()->GetPoolIndex();
DCHECK_GE(tid, 0);

ChannelStoreUpdater csu{pattern, to_add, this, uint32_t(tid)};

// Gather all the channels we need to subscribe to / remove.
size_t i = 0;
for (string_view channel : channels) {
if (to_add && local_store.emplace(channel).second)
csu.Record(channel);
else if (!to_add && local_store.erase(channel) > 0)
csu.Record(channel);

if (to_reply)
result[i++] = sinfo.SubscriptionCount();
}

csu.Apply();

// Important to reset conn_state.subscribe_info only after all references to it were
// removed.
if (!to_add && conn_state.subscribe_info->IsEmpty()) {
conn_state.subscribe_info.reset();
DCHECK_GE(subscriptions, 1u);
subscriptions--;
}

return result;
}

void ConnectionState::ExecInfo::Clear() {
DCHECK(!preborrowed_interpreter); // Must have been released properly
state = EXEC_INACTIVE;
Expand Down
19 changes: 10 additions & 9 deletions src/server/conn_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -267,9 +267,7 @@ struct ConnectionState {
class ConnectionContext : public facade::ConnectionContext {
public:
ConnectionContext(::io::Sink* stream, facade::Connection* owner, dfly::acl::UserCredentials cred);

ConnectionContext(const ConnectionContext* owner, Transaction* tx,
facade::CapturingReplyBuilder* crb);
ConnectionContext(const ConnectionContext* owner, Transaction* tx);

struct DebugInfo {
uint32_t shards_count = 0;
Expand All @@ -292,10 +290,13 @@ class ConnectionContext : public facade::ConnectionContext {
return conn_state.db_index;
}

void ChangeSubscription(bool to_add, bool to_reply, CmdArgList args);
void ChangePSubscription(bool to_add, bool to_reply, CmdArgList args);
void UnsubscribeAll(bool to_reply);
void PUnsubscribeAll(bool to_reply);
void ChangeSubscription(bool to_add, bool to_reply, CmdArgList args,
facade::RedisReplyBuilder* rb);

void ChangePSubscription(bool to_add, bool to_reply, CmdArgList args,
facade::RedisReplyBuilder* rb);
void UnsubscribeAll(bool to_reply, facade::RedisReplyBuilder* rb);
void PUnsubscribeAll(bool to_reply, facade::RedisReplyBuilder* rb);
void ChangeMonitor(bool start); // either start or stop monitor on a given connection

size_t UsedMemory() const override;
Expand All @@ -317,8 +318,8 @@ class ConnectionContext : public facade::ConnectionContext {
monitor = enable;
}

void SendSubscriptionChangedResponse(std::string_view action,
std::optional<std::string_view> topic, unsigned count);
std::vector<unsigned> ChangeSubscriptions(CmdArgList channels, bool pattern, bool to_add,
bool to_reply);
};

} // namespace dfly
3 changes: 1 addition & 2 deletions src/server/debugcmd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ void DoPopulateBatch(string_view type, string_view prefix, size_t val_size, bool

absl::InlinedVector<string_view, 5> args_view;
facade::CapturingReplyBuilder crb;
ConnectionContext local_cntx{cntx, stub_tx.get(), &crb};
ConnectionContext local_cntx{cntx, stub_tx.get()};

absl::InsecureBitGen gen;
for (unsigned i = 0; i < batch.sz; ++i) {
Expand All @@ -175,7 +175,6 @@ void DoPopulateBatch(string_view type, string_view prefix, size_t val_size, bool
sf->service().InvokeCmd(cid, args_span, &crb, &local_cntx);
}

local_cntx.Inject(nullptr);
local_tx->UnlockMulti();
}

Expand Down
5 changes: 1 addition & 4 deletions src/server/journal/executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,14 @@ template <typename... Ts> journal::ParsedEntry::CmdData BuildFromParts(Ts... par
} // namespace

JournalExecutor::JournalExecutor(Service* service)
: service_{service},
reply_builder_{facade::ReplyMode::NONE},
conn_context_{nullptr, nullptr, &reply_builder_} {
: service_{service}, reply_builder_{facade::ReplyMode::NONE}, conn_context_{nullptr, nullptr} {
conn_context_.is_replicating = true;
conn_context_.journal_emulated = true;
conn_context_.skip_acl_validation = true;
conn_context_.ns = &namespaces.GetDefaultNamespace();
}

JournalExecutor::~JournalExecutor() {
conn_context_.Inject(nullptr);
}

void JournalExecutor::Execute(DbIndex dbid, absl::Span<journal::ParsedEntry::CmdData> cmds) {
Expand Down
Loading
Loading