Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

.*: refine streaming writer and exchange writer #6186

Merged
merged 27 commits into from
Nov 9, 2022
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 1 addition & 11 deletions dbms/src/Flash/Coprocessor/StreamWriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,7 @@ struct StreamWriter
explicit StreamWriter(::grpc::ServerWriter<::coprocessor::BatchResponse> * writer_)
: writer(writer_)
{}
void write(mpp::MPPDataPacket &)
{
throw Exception("StreamWriter::write(mpp::MPPDataPacket &) do not support writing MPPDataPacket!");
}
void write(mpp::MPPDataPacket &, [[maybe_unused]] uint16_t)
{
throw Exception("StreamWriter::write(mpp::MPPDataPacket &, [[maybe_unused]] uint16_t) do not support writing MPPDataPacket!");
}
void write(tipb::SelectResponse & response, [[maybe_unused]] uint16_t id = 0)
void write(tipb::SelectResponse & response)
{
::coprocessor::BatchResponse resp;
if (!response.SerializeToString(resp.mutable_data()))
Expand All @@ -59,8 +51,6 @@ struct StreamWriter
if (!writer->Write(resp))
throw Exception("Failed to write resp");
}
// a helper function
uint16_t getPartitionNum() { return 0; }
};

using StreamWriterPtr = std::shared_ptr<StreamWriter>;
ywqzzy marked this conversation as resolved.
Show resolved Hide resolved
Expand Down
14 changes: 4 additions & 10 deletions dbms/src/Flash/Coprocessor/tests/gtest_streaming_writer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,19 +91,15 @@ class TestStreamingWriter : public testing::Test
std::unique_ptr<DAGContext> dag_context_ptr;
};

using MockStreamWriterChecker = std::function<void(tipb::SelectResponse &, uint16_t)>;
using MockStreamWriterChecker = std::function<void(tipb::SelectResponse &)>;

struct MockStreamWriter
{
explicit MockStreamWriter(MockStreamWriterChecker checker_)
: checker(checker_)
{}

void write(mpp::MPPDataPacket &) { FAIL() << "cannot reach here."; }
void write(mpp::MPPDataPacket &, uint16_t) { FAIL() << "cannot reach here."; }
void write(tipb::SelectResponse & response, uint16_t part_id) { checker(response, part_id); }
void write(tipb::SelectResponse & response) { checker(response, 0); }
uint16_t getPartitionNum() const { return 1; }
void write(tipb::SelectResponse & response) { checker(response); }

private:
MockStreamWriterChecker checker;
Expand Down Expand Up @@ -137,8 +133,7 @@ try

// 2. Build MockStreamWriter.
std::vector<tipb::SelectResponse> write_report;
auto checker = [&write_report](tipb::SelectResponse & response, uint16_t part_id) {
ASSERT_EQ(part_id, 0);
auto checker = [&write_report](tipb::SelectResponse & response) {
write_report.emplace_back(std::move(response));
};
auto mock_writer = std::make_shared<MockStreamWriter>(checker);
Expand Down Expand Up @@ -195,8 +190,7 @@ try
dag_context_ptr->encode_type = encode_type;

std::vector<tipb::SelectResponse> write_report;
auto checker = [&write_report](tipb::SelectResponse & response, uint16_t part_id) {
ASSERT_EQ(part_id, 0);
auto checker = [&write_report](tipb::SelectResponse & response) {
write_report.emplace_back(std::move(response));
};
auto mock_writer = std::make_shared<MockStreamWriter>(checker);
Expand Down
39 changes: 19 additions & 20 deletions dbms/src/Flash/Mpp/BroadcastOrPassThroughWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,32 +69,31 @@ template <class StreamWriterPtr>
template <bool send_exec_summary_at_last>
void BroadcastOrPassThroughWriter<StreamWriterPtr>::encodeThenWriteBlocks()
{
TrackedMppDataPacket tracked_packet(current_memory_tracker);
if constexpr (send_exec_summary_at_last)
{
TrackedSelectResp response;
summary_collector.addExecuteSummaries(response.getResponse(), /*delta_mode=*/false);
tracked_packet.serializeByResponse(response.getResponse());
}
if (blocks.empty())
if (!blocks.empty())
{
if constexpr (send_exec_summary_at_last)
auto tracked_packet = std::make_shared<TrackedMppDataPacket>();
while (!blocks.empty())
{
writer->write(tracked_packet.getPacket());
const auto & block = blocks.back();
chunk_codec_stream->encode(block, 0, block.rows());
blocks.pop_back();
tracked_packet->addChunk(chunk_codec_stream->getString());
chunk_codec_stream->clear();
}
return;
assert(blocks.empty());
rows_in_blocks = 0;
writer->write(tracked_packet);
}
while (!blocks.empty())

if constexpr (send_exec_summary_at_last)
{
const auto & block = blocks.back();
chunk_codec_stream->encode(block, 0, block.rows());
blocks.pop_back();
tracked_packet.addChunk(chunk_codec_stream->getString());
chunk_codec_stream->clear();
TrackedSelectResp response;
summary_collector.addExecuteSummaries(response.getResponse(), /*delta_mode=*/false);
auto tracked_packet = std::make_shared<TrackedMppDataPacket>();
tracked_packet->serializeByResponse(response.getResponse());
// only send to one tunnel.
writer->write(tracked_packet, 0);
}
assert(blocks.empty());
rows_in_blocks = 0;
writer->write(tracked_packet.getPacket());
}
SeaRise marked this conversation as resolved.
Show resolved Hide resolved

template class BroadcastOrPassThroughWriter<MPPTunnelSetPtr>;
Expand Down
18 changes: 10 additions & 8 deletions dbms/src/Flash/Mpp/FineGrainedShuffleWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ template <class StreamWriterPtr>
template <bool send_exec_summary_at_last>
void FineGrainedShuffleWriter<StreamWriterPtr>::batchWriteFineGrainedShuffle()
{
std::vector<TrackedMppDataPacket> tracked_packets(partition_num);
auto tracked_packets = HashBaseWriterHelper::createPackets(partition_num);

if (!blocks.empty())
{
Expand Down Expand Up @@ -126,9 +126,9 @@ void FineGrainedShuffleWriter<StreamWriterPtr>::batchWriteFineGrainedShuffle()
if (dest_block_rows > 0)
{
chunk_codec_stream->encode(dest_block, 0, dest_block_rows);
tracked_packets[part_id].addChunk(chunk_codec_stream->getString());
tracked_packets[part_id]->addChunk(chunk_codec_stream->getString());
chunk_codec_stream->clear();
tracked_packets[part_id].packet.add_stream_ids(stream_idx);
tracked_packets[part_id]->getPacket().add_stream_ids(stream_idx);
}
}
}
Expand All @@ -139,7 +139,7 @@ void FineGrainedShuffleWriter<StreamWriterPtr>::batchWriteFineGrainedShuffle()

template <class StreamWriterPtr>
template <bool send_exec_summary_at_last>
void FineGrainedShuffleWriter<StreamWriterPtr>::writePackets(std::vector<TrackedMppDataPacket> & packets)
void FineGrainedShuffleWriter<StreamWriterPtr>::writePackets(const std::vector<TrackedMppDataPacketPtr> & packets)
{
size_t part_id = 0;

Expand All @@ -149,15 +149,17 @@ void FineGrainedShuffleWriter<StreamWriterPtr>::writePackets(std::vector<Tracked
summary_collector.addExecuteSummaries(response, /*delta_mode=*/false);
/// Sending the response to only one node, default the first one.
assert(!packets.empty());
packets[0].serializeByResponse(response);
writer->write(packets[0].getPacket(), 0);
assert(packets[0]);
packets[0]->serializeByResponse(response);
writer->write(packets[0], 0);
part_id = 1;
}

for (; part_id < packets.size(); ++part_id)
{
auto & packet = packets[part_id].getPacket();
if (packet.chunks_size() > 0)
const auto & packet = packets[part_id];
assert(packet);
if (packet->getPacket().chunks_size() > 0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

likely?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated

writer->write(packet, part_id);
}
}
Expand Down
2 changes: 1 addition & 1 deletion dbms/src/Flash/Mpp/FineGrainedShuffleWriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class FineGrainedShuffleWriter : public DAGResponseWriter
void batchWriteFineGrainedShuffle();

template <bool send_exec_summary_at_last>
void writePackets(std::vector<TrackedMppDataPacket> & packets);
void writePackets(const std::vector<TrackedMppDataPacketPtr> & packets);

bool should_send_exec_summary_at_last;
StreamWriterPtr writer;
Expand Down
9 changes: 9 additions & 0 deletions dbms/src/Flash/Mpp/HashBaseWriterHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,4 +77,13 @@ void computeHash(const Block & input_block,
}
}
}

std::vector<TrackedMppDataPacketPtr> createPackets(size_t partition_num)
{
std::vector<TrackedMppDataPacketPtr> tracked_packets;
tracked_packets.reserve(partition_num);
for (size_t i = 0; i < partition_num; ++i)
tracked_packets.emplace_back(std::make_shared<TrackedMppDataPacket>());
return tracked_packets;
}
} // namespace DB::HashBaseWriterHelper
3 changes: 3 additions & 0 deletions dbms/src/Flash/Mpp/HashBaseWriterHelper.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#pragma once

#include <Core/Block.h>
#include <Flash/Mpp/TrackedMppDataPacket.h>
#include <Storages/Transaction/Collator.h>

namespace DB::HashBaseWriterHelper
Expand All @@ -29,4 +30,6 @@ void computeHash(const Block & input_block,
std::vector<String> & partition_key_containers,
const std::vector<Int64> & partition_col_ids,
std::vector<std::vector<MutableColumnPtr>> & result_columns);

std::vector<TrackedMppDataPacketPtr> createPackets(size_t partition_num);
} // namespace DB::HashBaseWriterHelper
18 changes: 10 additions & 8 deletions dbms/src/Flash/Mpp/HashPartitionWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#include <Common/TiFlashException.h>
#include <Flash/Coprocessor/CHBlockChunkCodec.h>
#include <Flash/Mpp/HashBaseWriterHelper.h>
#include <Flash/Mpp/HashParitionWriter.h>
#include <Flash/Mpp/HashPartitionWriter.h>
#include <Flash/Mpp/MPPTunnelSet.h>

namespace DB
Expand Down Expand Up @@ -76,7 +76,7 @@ template <class StreamWriterPtr>
template <bool send_exec_summary_at_last>
void HashPartitionWriter<StreamWriterPtr>::partitionAndEncodeThenWriteBlocks()
{
std::vector<TrackedMppDataPacket> tracked_packets(partition_num);
auto tracked_packets = HashBaseWriterHelper::createPackets(partition_num);

if (!blocks.empty())
{
Expand All @@ -100,7 +100,7 @@ void HashPartitionWriter<StreamWriterPtr>::partitionAndEncodeThenWriteBlocks()
if (dest_block_rows > 0)
{
chunk_codec_stream->encode(dest_block, 0, dest_block_rows);
tracked_packets[part_id].addChunk(chunk_codec_stream->getString());
tracked_packets[part_id]->addChunk(chunk_codec_stream->getString());
chunk_codec_stream->clear();
}
}
Expand All @@ -114,7 +114,7 @@ void HashPartitionWriter<StreamWriterPtr>::partitionAndEncodeThenWriteBlocks()

template <class StreamWriterPtr>
template <bool send_exec_summary_at_last>
void HashPartitionWriter<StreamWriterPtr>::writePackets(std::vector<TrackedMppDataPacket> & packets)
void HashPartitionWriter<StreamWriterPtr>::writePackets(const std::vector<TrackedMppDataPacketPtr> & packets)
{
size_t part_id = 0;

Expand All @@ -124,15 +124,17 @@ void HashPartitionWriter<StreamWriterPtr>::writePackets(std::vector<TrackedMppDa
summary_collector.addExecuteSummaries(response, /*delta_mode=*/false);
/// Sending the response to only one node, default the first one.
assert(!packets.empty());
packets[0].serializeByResponse(response);
writer->write(packets[0].getPacket(), 0);
assert(packets[0]);
packets[0]->serializeByResponse(response);
writer->write(packets[0], 0);
part_id = 1;
}

for (; part_id < packets.size(); ++part_id)
{
auto & packet = packets[part_id].getPacket();
if (packet.chunks_size() > 0)
const auto & packet = packets[part_id];
assert(packet);
if (packet->getPacket().chunks_size() > 0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated

writer->write(packet, part_id);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class HashPartitionWriter : public DAGResponseWriter
void partitionAndEncodeThenWriteBlocks();

template <bool send_exec_summary_at_last>
void writePackets(std::vector<TrackedMppDataPacket> & packets);
void writePackets(const std::vector<TrackedMppDataPacketPtr> & packets);

Int64 batch_send_min_limit;
bool should_send_exec_summary_at_last;
Expand Down
4 changes: 2 additions & 2 deletions dbms/src/Flash/Mpp/MPPTunnel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ void MPPTunnel::close(const String & reason, bool wait_sender_finish)
}

// TODO: consider to hold a buffer
void MPPTunnel::write(const mpp::MPPDataPacket & data)
void MPPTunnel::write(const TrackedMppDataPacketPtr & data)
{
LOG_TRACE(log, "ready to write");
{
Expand All @@ -144,7 +144,7 @@ void MPPTunnel::write(const mpp::MPPDataPacket & data)

if (tunnel_sender->push(data))
{
connection_profile_info.bytes += data.ByteSizeLong();
connection_profile_info.bytes += data->getPacket().ByteSizeLong();
connection_profile_info.packets += 1;
return;
}
Expand Down
14 changes: 7 additions & 7 deletions dbms/src/Flash/Mpp/MPPTunnel.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,6 @@ enum class TunnelSenderMode
ASYNC_GRPC // Using async grpc writer
};

using TrackedMppDataPacketPtr = std::shared_ptr<DB::TrackedMppDataPacket>;

/// TunnelSender is responsible for consuming data from Tunnel's internal send_queue and do the actual sending work
/// After TunnelSender finishes its work, either normally or abnormally, set ConsumerState to inform Tunnel
class TunnelSender : private boost::noncopyable
Expand All @@ -74,9 +72,10 @@ class TunnelSender : private boost::noncopyable
{
}

virtual bool push(const mpp::MPPDataPacket & data)
virtual bool push(const TrackedMppDataPacketPtr & data)
{
return send_queue.push(std::make_shared<TrackedMppDataPacket>(data, getMemoryTracker())) == MPMCQueueResult::OK;
data->switchMemTracker(getMemoryTracker());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there is no need to switch the memory tracker here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because the memory tracker of ExchangeWriter and MPPTunnel are the same?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can remove the memory tracker of MPPTunnel?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the memory trackers in ExchangeWriter and MPPTunnel are the same.

Because the memory tracker of ExchangeWriter and MPPTunnel are the same?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can remove the memory tracker of MPPTunnel?

Yes, just keeping the memory tracker in TunnelSender is enough.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, updated.

return send_queue.push(data) == MPMCQueueResult::OK;
}

virtual void cancelWith(const String & reason)
Expand Down Expand Up @@ -176,9 +175,10 @@ class AsyncTunnelSender : public TunnelSender
, queue(queue_size, func)
{}

bool push(const mpp::MPPDataPacket & data) override
bool push(const TrackedMppDataPacketPtr & data) override
{
return queue.push(std::make_shared<TrackedMppDataPacket>(data, getMemoryTracker()));
data->switchMemTracker(getMemoryTracker());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto

return queue.push(data);
}

bool finish() override
Expand Down Expand Up @@ -275,7 +275,7 @@ class MPPTunnel : private boost::noncopyable
const String & id() const { return tunnel_id; }

// write a single packet to the tunnel's send queue, it will block if tunnel is not ready.
void write(const mpp::MPPDataPacket & data);
void write(const TrackedMppDataPacketPtr & data);

// finish the writing, and wait until the sender finishes.
void writeDone();
Expand Down
Loading