Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move the logic of encoding mpp packet from exchange writer to mpp tunnel set #6644

Merged
merged 4 commits into from
Jan 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <DataTypes/DataTypesNumber.h>
#include <Flash/Coprocessor/CHBlockChunkCodec.h>
#include <Flash/Coprocessor/ExecutionSummaryCollector.h>
#include <Flash/Mpp/MPPTunnelSetHelper.h>
#include <Interpreters/Context.h>
#include <Storages/DeltaMerge/ScanContext.h>
#include <Storages/StorageDisaggregated.h>
Expand Down Expand Up @@ -58,8 +59,9 @@ bool equalSummaries(const ExecutionSummary & left, const ExecutionSummary & righ

struct MockWriter
{
explicit MockWriter(PacketQueuePtr queue_)
: queue(queue_)
MockWriter(DAGContext & dag_context, PacketQueuePtr queue_)
: result_field_types(dag_context.result_field_types)
, queue(queue_)
{}

static ExecutionSummary mockExecutionSummary()
Expand All @@ -81,9 +83,9 @@ struct MockWriter
return summary;
}

void partitionWrite(TrackedMppDataPacketPtr &&, uint16_t) { FAIL() << "cannot reach here."; }
void broadcastOrPassThroughWrite(TrackedMppDataPacketPtr && packet)
void broadcastOrPassThroughWrite(Blocks & blocks)
{
auto packet = MPPTunnelSetHelper::toPacket(blocks, result_field_types);
++total_packets;
if (!packet->packet.chunks().empty())
total_bytes += packet->packet.ByteSizeLong();
Expand Down Expand Up @@ -117,6 +119,8 @@ struct MockWriter
}
uint16_t getPartitionNum() const { return 1; }

std::vector<tipb::FieldType> result_field_types;

PacketQueuePtr queue;
bool add_summary = false;
size_t total_packets = 0;
Expand Down Expand Up @@ -441,7 +445,7 @@ class TestTiRemoteBlockInputStream : public testing::Test
{
PacketQueuePtr queue_ptr = std::make_shared<PacketQueue>(1000);
std::vector<Block> source_blocks;
auto writer = std::make_shared<MockWriter>(queue_ptr);
auto writer = std::make_shared<MockWriter>(*dag_context_ptr, queue_ptr);
prepareQueue(writer, source_blocks, empty_last_packet);
queue_ptr->finish();

Expand All @@ -458,7 +462,7 @@ class TestTiRemoteBlockInputStream : public testing::Test
{
PacketQueuePtr queue_ptr = std::make_shared<PacketQueue>(1000);
std::vector<Block> source_blocks;
auto writer = std::make_shared<MockWriter>(queue_ptr);
auto writer = std::make_shared<MockWriter>(*dag_context_ptr, queue_ptr);
prepareQueueV2(writer, source_blocks, empty_last_packet);
queue_ptr->finish();
auto receiver_stream = makeExchangeReceiverInputStream(queue_ptr);
Expand Down
20 changes: 5 additions & 15 deletions dbms/src/Flash/Mpp/BroadcastOrPassThroughWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,13 @@ BroadcastOrPassThroughWriter<ExchangeWriterPtr>::BroadcastOrPassThroughWriter(
{
rows_in_blocks = 0;
RUNTIME_CHECK(dag_context.encode_type == tipb::EncodeType::TypeCHBlock);
chunk_codec_stream = std::make_unique<CHBlockChunkCodec>()->newCodecStream(dag_context.result_field_types);
}

template <class ExchangeWriterPtr>
void BroadcastOrPassThroughWriter<ExchangeWriterPtr>::flush()
{
if (rows_in_blocks > 0)
encodeThenWriteBlocks();
writeBlocks();
}

template <class ExchangeWriterPtr>
Expand All @@ -55,27 +54,18 @@ void BroadcastOrPassThroughWriter<ExchangeWriterPtr>::write(const Block & block)
}

if (static_cast<Int64>(rows_in_blocks) > batch_send_min_limit)
encodeThenWriteBlocks();
writeBlocks();
}

template <class ExchangeWriterPtr>
void BroadcastOrPassThroughWriter<ExchangeWriterPtr>::encodeThenWriteBlocks()
void BroadcastOrPassThroughWriter<ExchangeWriterPtr>::writeBlocks()
{
if (unlikely(blocks.empty()))
return;

auto tracked_packet = std::make_shared<TrackedMppDataPacket>();
while (!blocks.empty())
{
const auto & block = blocks.back();
chunk_codec_stream->encode(block, 0, block.rows());
blocks.pop_back();
tracked_packet->addChunk(chunk_codec_stream->getString());
chunk_codec_stream->clear();
}
assert(blocks.empty());
writer->broadcastOrPassThroughWrite(blocks);
blocks.clear();
rows_in_blocks = 0;
writer->broadcastOrPassThroughWrite(std::move(tracked_packet));
}

template class BroadcastOrPassThroughWriter<MPPTunnelSetPtr>;
Expand Down
3 changes: 1 addition & 2 deletions dbms/src/Flash/Mpp/BroadcastOrPassThroughWriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,13 @@ class BroadcastOrPassThroughWriter : public DAGResponseWriter
void flush() override;

private:
void encodeThenWriteBlocks();
void writeBlocks();

private:
Int64 batch_send_min_limit;
ExchangeWriterPtr writer;
std::vector<Block> blocks;
size_t rows_in_blocks;
std::unique_ptr<ChunkCodecStream> chunk_codec_stream;
};

} // namespace DB
46 changes: 7 additions & 39 deletions dbms/src/Flash/Mpp/FineGrainedShuffleWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ FineGrainedShuffleWriter<ExchangeWriterPtr>::FineGrainedShuffleWriter(
partition_num = writer_->getPartitionNum();
RUNTIME_CHECK(partition_num > 0);
RUNTIME_CHECK(dag_context.encode_type == tipb::EncodeType::TypeCHBlock);
chunk_codec_stream = std::make_unique<CHBlockChunkCodec>()->newCodecStream(dag_context.result_field_types);
}

template <class ExchangeWriterPtr>
Expand Down Expand Up @@ -110,7 +109,6 @@ void FineGrainedShuffleWriter<ExchangeWriterPtr>::initScatterColumns()
template <class ExchangeWriterPtr>
void FineGrainedShuffleWriter<ExchangeWriterPtr>::batchWriteFineGrainedShuffle()
{
auto tracked_packets = HashBaseWriterHelper::createPackets(partition_num);
if (likely(!blocks.empty()))
{
assert(rows_in_blocks > 0);
Expand All @@ -128,46 +126,16 @@ void FineGrainedShuffleWriter<ExchangeWriterPtr>::batchWriteFineGrainedShuffle()
size_t part_id = 0;
for (size_t bucket_idx = 0; bucket_idx < num_bucket; bucket_idx += fine_grained_shuffle_stream_count, ++part_id)
{
for (uint64_t stream_idx = 0; stream_idx < fine_grained_shuffle_stream_count; ++stream_idx)
{
// assemble scatter columns into a block
MutableColumns columns;
columns.reserve(num_columns);
for (size_t col_id = 0; col_id < num_columns; ++col_id)
columns.emplace_back(std::move(scattered[col_id][bucket_idx + stream_idx]));
auto block = header.cloneWithColumns(std::move(columns));

// encode into packet
chunk_codec_stream->encode(block, 0, block.rows());
tracked_packets[part_id]->addChunk(chunk_codec_stream->getString());
tracked_packets[part_id]->getPacket().add_stream_ids(stream_idx);
chunk_codec_stream->clear();

// disassemble the block back to scatter columns
columns = block.mutateColumns();
for (size_t col_id = 0; col_id < num_columns; ++col_id)
{
columns[col_id]->popBack(columns[col_id]->size()); // clear column
scattered[col_id][bucket_idx + stream_idx] = std::move(columns[col_id]);
}
}
writer->fineGrainedShuffleWrite(
header,
scattered,
bucket_idx,
fine_grained_shuffle_stream_count,
num_columns,
part_id);
}
rows_in_blocks = 0;
}

writePackets(tracked_packets);
}

template <class ExchangeWriterPtr>
void FineGrainedShuffleWriter<ExchangeWriterPtr>::writePackets(TrackedMppDataPacketPtrs & packets)
{
for (size_t part_id = 0; part_id < packets.size(); ++part_id)
{
auto & packet = packets[part_id];
assert(packet);
if (likely(packet->getPacket().chunks_size() > 0))
writer->partitionWrite(std::move(packet), part_id);
}
}

template class FineGrainedShuffleWriter<MPPTunnelSetPtr>;
Expand Down
3 changes: 0 additions & 3 deletions dbms/src/Flash/Mpp/FineGrainedShuffleWriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,6 @@ class FineGrainedShuffleWriter : public DAGResponseWriter
private:
void batchWriteFineGrainedShuffle();

void writePackets(TrackedMppDataPacketPtrs & packets);

void initScatterColumns();

private:
Expand All @@ -52,7 +50,6 @@ class FineGrainedShuffleWriter : public DAGResponseWriter
TiDB::TiDBCollators collators;
size_t rows_in_blocks = 0;
uint16_t partition_num;
std::unique_ptr<ChunkCodecStream> chunk_codec_stream;
UInt64 fine_grained_shuffle_stream_count;
UInt64 fine_grained_shuffle_batch_size;

Expand Down
9 changes: 0 additions & 9 deletions dbms/src/Flash/Mpp/HashBaseWriterHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,15 +142,6 @@ void scatterColumns(const Block & input_block,
}
}

DB::TrackedMppDataPacketPtrs createPackets(size_t partition_num)
{
DB::TrackedMppDataPacketPtrs tracked_packets;
tracked_packets.reserve(partition_num);
for (size_t i = 0; i < partition_num; ++i)
tracked_packets.emplace_back(std::make_shared<TrackedMppDataPacket>());
return tracked_packets;
}

void scatterColumnsForFineGrainedShuffle(const Block & block,
const std::vector<Int64> & partition_col_ids,
const TiDB::TiDBCollators & collators,
Expand Down
2 changes: 0 additions & 2 deletions dbms/src/Flash/Mpp/HashBaseWriterHelper.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,6 @@ void computeHash(size_t rows,
std::vector<String> & partition_key_containers,
WeakHash32 & hash);

DB::TrackedMppDataPacketPtrs createPackets(size_t partition_num);

void scatterColumns(const Block & input_block,
const std::vector<Int64> & partition_col_ids,
const TiDB::TiDBCollators & collators,
Expand Down
38 changes: 18 additions & 20 deletions dbms/src/Flash/Mpp/HashPartitionWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,13 @@ HashPartitionWriter<ExchangeWriterPtr>::HashPartitionWriter(
partition_num = writer_->getPartitionNum();
RUNTIME_CHECK(partition_num > 0);
RUNTIME_CHECK(dag_context.encode_type == tipb::EncodeType::TypeCHBlock);
chunk_codec_stream = std::make_unique<CHBlockChunkCodec>()->newCodecStream(dag_context.result_field_types);
}

template <class ExchangeWriterPtr>
void HashPartitionWriter<ExchangeWriterPtr>::flush()
{
if (rows_in_blocks > 0)
partitionAndEncodeThenWriteBlocks();
partitionAndWriteBlocks();
}

template <class ExchangeWriterPtr>
Expand All @@ -62,22 +61,23 @@ void HashPartitionWriter<ExchangeWriterPtr>::write(const Block & block)
}

if (static_cast<Int64>(rows_in_blocks) > batch_send_min_limit)
partitionAndEncodeThenWriteBlocks();
partitionAndWriteBlocks();
}

template <class ExchangeWriterPtr>
void HashPartitionWriter<ExchangeWriterPtr>::partitionAndEncodeThenWriteBlocks()
void HashPartitionWriter<ExchangeWriterPtr>::partitionAndWriteBlocks()
{
auto tracked_packets = HashBaseWriterHelper::createPackets(partition_num);
std::vector<Blocks> partition_blocks;
partition_blocks.resize(partition_num);

if (!blocks.empty())
{
assert(rows_in_blocks > 0);

HashBaseWriterHelper::materializeBlocks(blocks);
Block dest_block = blocks[0].cloneEmpty();
std::vector<String> partition_key_containers(collators.size());

Block header = blocks[0].cloneEmpty();
while (!blocks.empty())
{
const auto & block = blocks.back();
Expand All @@ -87,32 +87,30 @@ void HashPartitionWriter<ExchangeWriterPtr>::partitionAndEncodeThenWriteBlocks()

for (size_t part_id = 0; part_id < partition_num; ++part_id)
{
Block dest_block = header.cloneEmpty();
dest_block.setColumns(std::move(dest_tbl_cols[part_id]));
size_t dest_block_rows = dest_block.rows();
if (dest_block_rows > 0)
{
chunk_codec_stream->encode(dest_block, 0, dest_block_rows);
tracked_packets[part_id]->addChunk(chunk_codec_stream->getString());
chunk_codec_stream->clear();
}
if (dest_block.rows() > 0)
partition_blocks[part_id].push_back(std::move(dest_block));
}
}
assert(blocks.empty());
rows_in_blocks = 0;
}

writePackets(tracked_packets);
writePartitionBlocks(partition_blocks);
}

template <class ExchangeWriterPtr>
void HashPartitionWriter<ExchangeWriterPtr>::writePackets(TrackedMppDataPacketPtrs & packets)
void HashPartitionWriter<ExchangeWriterPtr>::writePartitionBlocks(std::vector<Blocks> & partition_blocks)
{
for (size_t part_id = 0; part_id < packets.size(); ++part_id)
for (size_t part_id = 0; part_id < partition_num; ++part_id)
{
auto & packet = packets[part_id];
assert(packet);
if (likely(packet->getPacket().chunks_size() > 0))
writer->partitionWrite(std::move(packet), part_id);
auto & blocks = partition_blocks[part_id];
if (likely(!blocks.empty()))
{
writer->partitionWrite(blocks, part_id);
blocks.clear();
}
}
}

Expand Down
5 changes: 2 additions & 3 deletions dbms/src/Flash/Mpp/HashPartitionWriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ class HashPartitionWriter : public DAGResponseWriter
void flush() override;

private:
void partitionAndEncodeThenWriteBlocks();
void partitionAndWriteBlocks();

void writePackets(TrackedMppDataPacketPtrs & packets);
void writePartitionBlocks(std::vector<Blocks> & partition_blocks);

private:
Int64 batch_send_min_limit;
Expand All @@ -49,7 +49,6 @@ class HashPartitionWriter : public DAGResponseWriter
TiDB::TiDBCollators collators;
size_t rows_in_blocks;
uint16_t partition_num;
std::unique_ptr<ChunkCodecStream> chunk_codec_stream;
};

} // namespace DB
2 changes: 1 addition & 1 deletion dbms/src/Flash/Mpp/MPPTask.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ void MPPTask::run()

void MPPTask::registerTunnels(const mpp::DispatchTaskRequest & task_request)
{
auto tunnel_set_local = std::make_shared<MPPTunnelSet>(log->identifier());
auto tunnel_set_local = std::make_shared<MPPTunnelSet>(*dag_context, log->identifier());
std::chrono::seconds timeout(task_request.timeout());
const auto & exchange_sender = dag_req.root_executor().exchange_sender();

Expand Down
Loading