From 012a5432d292121d7bc7857b4ec1fd43331d1c21 Mon Sep 17 00:00:00 2001 From: SeaRise Date: Mon, 16 Jan 2023 13:00:56 +0800 Subject: [PATCH 1/4] refine --- .../gtest_ti_remote_block_inputstream.cpp | 16 +++-- .../Mpp/BroadcastOrPassThroughWriter.cpp | 20 ++---- .../Flash/Mpp/BroadcastOrPassThroughWriter.h | 3 +- .../Flash/Mpp/FineGrainedShuffleWriter.cpp | 46 ++---------- dbms/src/Flash/Mpp/FineGrainedShuffleWriter.h | 3 - dbms/src/Flash/Mpp/HashBaseWriterHelper.cpp | 9 --- dbms/src/Flash/Mpp/HashBaseWriterHelper.h | 2 - dbms/src/Flash/Mpp/HashPartitionWriter.cpp | 34 +++++---- dbms/src/Flash/Mpp/HashPartitionWriter.h | 5 +- dbms/src/Flash/Mpp/MPPTask.cpp | 2 +- dbms/src/Flash/Mpp/MPPTunnelSet.cpp | 54 +++++++++++--- dbms/src/Flash/Mpp/MPPTunnelSet.h | 23 ++++-- dbms/src/Flash/Mpp/MPPTunnelSetHelper.cpp | 72 +++++++++++++++++++ dbms/src/Flash/Mpp/MPPTunnelSetHelper.h | 32 +++++++++ .../Mpp/tests/gtest_mpp_exchange_writer.cpp | 46 +++++++++--- 15 files changed, 246 insertions(+), 121 deletions(-) create mode 100644 dbms/src/Flash/Mpp/MPPTunnelSetHelper.cpp create mode 100644 dbms/src/Flash/Mpp/MPPTunnelSetHelper.h diff --git a/dbms/src/Flash/Coprocessor/tests/gtest_ti_remote_block_inputstream.cpp b/dbms/src/Flash/Coprocessor/tests/gtest_ti_remote_block_inputstream.cpp index dbd30a2606a..7e7860509f5 100644 --- a/dbms/src/Flash/Coprocessor/tests/gtest_ti_remote_block_inputstream.cpp +++ b/dbms/src/Flash/Coprocessor/tests/gtest_ti_remote_block_inputstream.cpp @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -58,8 +59,9 @@ bool equalSummaries(const ExecutionSummary & left, const ExecutionSummary & righ struct MockWriter { - explicit MockWriter(PacketQueuePtr queue_) - : queue(queue_) + MockWriter(DAGContext & dag_context, PacketQueuePtr queue_) + : result_field_types(dag_context.result_field_types) + , queue(queue_) {} static ExecutionSummary mockExecutionSummary() @@ -81,9 +83,9 @@ struct MockWriter return summary; } - void partitionWrite(TrackedMppDataPacketPtr &&, uint16_t) { FAIL() << "cannot reach here."; } - void broadcastOrPassThroughWrite(TrackedMppDataPacketPtr && packet) + void broadcastOrPassThroughWrite(Blocks & blocks) { + auto packet = MPPTunnelSetHelper::toPacket(blocks, result_field_types); ++total_packets; if (!packet->packet.chunks().empty()) total_bytes += packet->packet.ByteSizeLong(); @@ -117,6 +119,8 @@ struct MockWriter } uint16_t getPartitionNum() const { return 1; } + std::vector result_field_types; + PacketQueuePtr queue; bool add_summary = false; size_t total_packets = 0; @@ -441,7 +445,7 @@ class TestTiRemoteBlockInputStream : public testing::Test { PacketQueuePtr queue_ptr = std::make_shared(1000); std::vector source_blocks; - auto writer = std::make_shared(queue_ptr); + auto writer = std::make_shared(*dag_context_ptr, queue_ptr); prepareQueue(writer, source_blocks, empty_last_packet); queue_ptr->finish(); @@ -458,7 +462,7 @@ class TestTiRemoteBlockInputStream : public testing::Test { PacketQueuePtr queue_ptr = std::make_shared(1000); std::vector source_blocks; - auto writer = std::make_shared(queue_ptr); + auto writer = std::make_shared(*dag_context_ptr, queue_ptr); prepareQueueV2(writer, source_blocks, empty_last_packet); queue_ptr->finish(); auto receiver_stream = makeExchangeReceiverInputStream(queue_ptr); diff --git a/dbms/src/Flash/Mpp/BroadcastOrPassThroughWriter.cpp b/dbms/src/Flash/Mpp/BroadcastOrPassThroughWriter.cpp index f8fea4c1c9e..11dae1896d8 100644 --- a/dbms/src/Flash/Mpp/BroadcastOrPassThroughWriter.cpp +++ b/dbms/src/Flash/Mpp/BroadcastOrPassThroughWriter.cpp @@ -31,14 +31,13 @@ BroadcastOrPassThroughWriter::BroadcastOrPassThroughWriter( { rows_in_blocks = 0; RUNTIME_CHECK(dag_context.encode_type == tipb::EncodeType::TypeCHBlock); - chunk_codec_stream = std::make_unique()->newCodecStream(dag_context.result_field_types); } template void BroadcastOrPassThroughWriter::flush() { if (rows_in_blocks > 0) - encodeThenWriteBlocks(); + writeBlocks(); } template @@ -55,27 +54,18 @@ void BroadcastOrPassThroughWriter::write(const Block & block) } if (static_cast(rows_in_blocks) > batch_send_min_limit) - encodeThenWriteBlocks(); + writeBlocks(); } template -void BroadcastOrPassThroughWriter::encodeThenWriteBlocks() +void BroadcastOrPassThroughWriter::writeBlocks() { if (unlikely(blocks.empty())) return; - auto tracked_packet = std::make_shared(); - while (!blocks.empty()) - { - const auto & block = blocks.back(); - chunk_codec_stream->encode(block, 0, block.rows()); - blocks.pop_back(); - tracked_packet->addChunk(chunk_codec_stream->getString()); - chunk_codec_stream->clear(); - } - assert(blocks.empty()); + writer->broadcastOrPassThroughWrite(blocks); + blocks.clear(); rows_in_blocks = 0; - writer->broadcastOrPassThroughWrite(std::move(tracked_packet)); } template class BroadcastOrPassThroughWriter; diff --git a/dbms/src/Flash/Mpp/BroadcastOrPassThroughWriter.h b/dbms/src/Flash/Mpp/BroadcastOrPassThroughWriter.h index 322d68541b3..a272ec7f1a4 100644 --- a/dbms/src/Flash/Mpp/BroadcastOrPassThroughWriter.h +++ b/dbms/src/Flash/Mpp/BroadcastOrPassThroughWriter.h @@ -35,14 +35,13 @@ class BroadcastOrPassThroughWriter : public DAGResponseWriter void flush() override; private: - void encodeThenWriteBlocks(); + void writeBlocks(); private: Int64 batch_send_min_limit; ExchangeWriterPtr writer; std::vector blocks; size_t rows_in_blocks; - std::unique_ptr chunk_codec_stream; }; } // namespace DB diff --git a/dbms/src/Flash/Mpp/FineGrainedShuffleWriter.cpp b/dbms/src/Flash/Mpp/FineGrainedShuffleWriter.cpp index b5bb5852c5e..ad6ac0d6faf 100644 --- a/dbms/src/Flash/Mpp/FineGrainedShuffleWriter.cpp +++ b/dbms/src/Flash/Mpp/FineGrainedShuffleWriter.cpp @@ -42,7 +42,6 @@ FineGrainedShuffleWriter::FineGrainedShuffleWriter( partition_num = writer_->getPartitionNum(); RUNTIME_CHECK(partition_num > 0); RUNTIME_CHECK(dag_context.encode_type == tipb::EncodeType::TypeCHBlock); - chunk_codec_stream = std::make_unique()->newCodecStream(dag_context.result_field_types); } template @@ -110,7 +109,6 @@ void FineGrainedShuffleWriter::initScatterColumns() template void FineGrainedShuffleWriter::batchWriteFineGrainedShuffle() { - auto tracked_packets = HashBaseWriterHelper::createPackets(partition_num); if (likely(!blocks.empty())) { assert(rows_in_blocks > 0); @@ -128,46 +126,16 @@ void FineGrainedShuffleWriter::batchWriteFineGrainedShuffle() size_t part_id = 0; for (size_t bucket_idx = 0; bucket_idx < num_bucket; bucket_idx += fine_grained_shuffle_stream_count, ++part_id) { - for (uint64_t stream_idx = 0; stream_idx < fine_grained_shuffle_stream_count; ++stream_idx) - { - // assemble scatter columns into a block - MutableColumns columns; - columns.reserve(num_columns); - for (size_t col_id = 0; col_id < num_columns; ++col_id) - columns.emplace_back(std::move(scattered[col_id][bucket_idx + stream_idx])); - auto block = header.cloneWithColumns(std::move(columns)); - - // encode into packet - chunk_codec_stream->encode(block, 0, block.rows()); - tracked_packets[part_id]->addChunk(chunk_codec_stream->getString()); - tracked_packets[part_id]->getPacket().add_stream_ids(stream_idx); - chunk_codec_stream->clear(); - - // disassemble the block back to scatter columns - columns = block.mutateColumns(); - for (size_t col_id = 0; col_id < num_columns; ++col_id) - { - columns[col_id]->popBack(columns[col_id]->size()); // clear column - scattered[col_id][bucket_idx + stream_idx] = std::move(columns[col_id]); - } - } + writer->fineGrainedShuffleWrite( + header, + scattered, + bucket_idx, + fine_grained_shuffle_stream_count, + num_columns, + part_id); } rows_in_blocks = 0; } - - writePackets(tracked_packets); -} - -template -void FineGrainedShuffleWriter::writePackets(TrackedMppDataPacketPtrs & packets) -{ - for (size_t part_id = 0; part_id < packets.size(); ++part_id) - { - auto & packet = packets[part_id]; - assert(packet); - if (likely(packet->getPacket().chunks_size() > 0)) - writer->partitionWrite(std::move(packet), part_id); - } } template class FineGrainedShuffleWriter; diff --git a/dbms/src/Flash/Mpp/FineGrainedShuffleWriter.h b/dbms/src/Flash/Mpp/FineGrainedShuffleWriter.h index 6b2db46770c..e7b5e7603df 100644 --- a/dbms/src/Flash/Mpp/FineGrainedShuffleWriter.h +++ b/dbms/src/Flash/Mpp/FineGrainedShuffleWriter.h @@ -41,8 +41,6 @@ class FineGrainedShuffleWriter : public DAGResponseWriter private: void batchWriteFineGrainedShuffle(); - void writePackets(TrackedMppDataPacketPtrs & packets); - void initScatterColumns(); private: @@ -52,7 +50,6 @@ class FineGrainedShuffleWriter : public DAGResponseWriter TiDB::TiDBCollators collators; size_t rows_in_blocks = 0; uint16_t partition_num; - std::unique_ptr chunk_codec_stream; UInt64 fine_grained_shuffle_stream_count; UInt64 fine_grained_shuffle_batch_size; diff --git a/dbms/src/Flash/Mpp/HashBaseWriterHelper.cpp b/dbms/src/Flash/Mpp/HashBaseWriterHelper.cpp index 993688bd79d..5698f27424e 100644 --- a/dbms/src/Flash/Mpp/HashBaseWriterHelper.cpp +++ b/dbms/src/Flash/Mpp/HashBaseWriterHelper.cpp @@ -142,15 +142,6 @@ void scatterColumns(const Block & input_block, } } -DB::TrackedMppDataPacketPtrs createPackets(size_t partition_num) -{ - DB::TrackedMppDataPacketPtrs tracked_packets; - tracked_packets.reserve(partition_num); - for (size_t i = 0; i < partition_num; ++i) - tracked_packets.emplace_back(std::make_shared()); - return tracked_packets; -} - void scatterColumnsForFineGrainedShuffle(const Block & block, const std::vector & partition_col_ids, const TiDB::TiDBCollators & collators, diff --git a/dbms/src/Flash/Mpp/HashBaseWriterHelper.h b/dbms/src/Flash/Mpp/HashBaseWriterHelper.h index 579e288d2cc..5684e59177d 100644 --- a/dbms/src/Flash/Mpp/HashBaseWriterHelper.h +++ b/dbms/src/Flash/Mpp/HashBaseWriterHelper.h @@ -39,8 +39,6 @@ void computeHash(size_t rows, std::vector & partition_key_containers, WeakHash32 & hash); -DB::TrackedMppDataPacketPtrs createPackets(size_t partition_num); - void scatterColumns(const Block & input_block, const std::vector & partition_col_ids, const TiDB::TiDBCollators & collators, diff --git a/dbms/src/Flash/Mpp/HashPartitionWriter.cpp b/dbms/src/Flash/Mpp/HashPartitionWriter.cpp index 7d48b43484b..945453619ea 100644 --- a/dbms/src/Flash/Mpp/HashPartitionWriter.cpp +++ b/dbms/src/Flash/Mpp/HashPartitionWriter.cpp @@ -38,14 +38,13 @@ HashPartitionWriter::HashPartitionWriter( partition_num = writer_->getPartitionNum(); RUNTIME_CHECK(partition_num > 0); RUNTIME_CHECK(dag_context.encode_type == tipb::EncodeType::TypeCHBlock); - chunk_codec_stream = std::make_unique()->newCodecStream(dag_context.result_field_types); } template void HashPartitionWriter::flush() { if (rows_in_blocks > 0) - partitionAndEncodeThenWriteBlocks(); + partitionAndWriteBlocks(); } template @@ -62,20 +61,20 @@ void HashPartitionWriter::write(const Block & block) } if (static_cast(rows_in_blocks) > batch_send_min_limit) - partitionAndEncodeThenWriteBlocks(); + partitionAndWriteBlocks(); } template -void HashPartitionWriter::partitionAndEncodeThenWriteBlocks() +void HashPartitionWriter::partitionAndWriteBlocks() { - auto tracked_packets = HashBaseWriterHelper::createPackets(partition_num); + std::vector partition_blocks; + partition_blocks.resize(partition_num); if (!blocks.empty()) { assert(rows_in_blocks > 0); HashBaseWriterHelper::materializeBlocks(blocks); - Block dest_block = blocks[0].cloneEmpty(); std::vector partition_key_containers(collators.size()); while (!blocks.empty()) @@ -87,32 +86,31 @@ void HashPartitionWriter::partitionAndEncodeThenWriteBlocks() for (size_t part_id = 0; part_id < partition_num; ++part_id) { + Block dest_block = blocks[0].cloneEmpty(); dest_block.setColumns(std::move(dest_tbl_cols[part_id])); size_t dest_block_rows = dest_block.rows(); if (dest_block_rows > 0) - { - chunk_codec_stream->encode(dest_block, 0, dest_block_rows); - tracked_packets[part_id]->addChunk(chunk_codec_stream->getString()); - chunk_codec_stream->clear(); - } + partition_blocks[part_id].push_back(std::move(dest_block)); } } assert(blocks.empty()); rows_in_blocks = 0; } - writePackets(tracked_packets); + writePartitionBlocks(partition_blocks); } template -void HashPartitionWriter::writePackets(TrackedMppDataPacketPtrs & packets) +void HashPartitionWriter::writePartitionBlocks(std::vector & partition_blocks) { - for (size_t part_id = 0; part_id < packets.size(); ++part_id) + for (size_t part_id = 0; part_id < partition_num; ++part_id) { - auto & packet = packets[part_id]; - assert(packet); - if (likely(packet->getPacket().chunks_size() > 0)) - writer->partitionWrite(std::move(packet), part_id); + auto & blocks = partition_blocks[part_id]; + if (likely(blocks.size() > 0)) + { + writer->partitionWrite(blocks, part_id); + blocks.clear(); + } } } diff --git a/dbms/src/Flash/Mpp/HashPartitionWriter.h b/dbms/src/Flash/Mpp/HashPartitionWriter.h index deebcd3dce7..f90dc4ddb7f 100644 --- a/dbms/src/Flash/Mpp/HashPartitionWriter.h +++ b/dbms/src/Flash/Mpp/HashPartitionWriter.h @@ -37,9 +37,9 @@ class HashPartitionWriter : public DAGResponseWriter void flush() override; private: - void partitionAndEncodeThenWriteBlocks(); + void partitionAndWriteBlocks(); - void writePackets(TrackedMppDataPacketPtrs & packets); + void writePartitionBlocks(std::vector & partition_blocks); private: Int64 batch_send_min_limit; @@ -49,7 +49,6 @@ class HashPartitionWriter : public DAGResponseWriter TiDB::TiDBCollators collators; size_t rows_in_blocks; uint16_t partition_num; - std::unique_ptr chunk_codec_stream; }; } // namespace DB diff --git a/dbms/src/Flash/Mpp/MPPTask.cpp b/dbms/src/Flash/Mpp/MPPTask.cpp index 5f251348d67..81988dfd795 100644 --- a/dbms/src/Flash/Mpp/MPPTask.cpp +++ b/dbms/src/Flash/Mpp/MPPTask.cpp @@ -154,7 +154,7 @@ void MPPTask::run() void MPPTask::registerTunnels(const mpp::DispatchTaskRequest & task_request) { - auto tunnel_set_local = std::make_shared(log->identifier()); + auto tunnel_set_local = std::make_shared(*dag_context, log->identifier()); std::chrono::seconds timeout(task_request.timeout()); const auto & exchange_sender = dag_req.root_executor().exchange_sender(); diff --git a/dbms/src/Flash/Mpp/MPPTunnelSet.cpp b/dbms/src/Flash/Mpp/MPPTunnelSet.cpp index cff6e72e7e0..55d0f7ab9be 100644 --- a/dbms/src/Flash/Mpp/MPPTunnelSet.cpp +++ b/dbms/src/Flash/Mpp/MPPTunnelSet.cpp @@ -14,7 +14,9 @@ #include #include +#include #include +#include #include #include #include @@ -26,8 +28,7 @@ namespace void checkPacketSize(size_t size) { static constexpr size_t max_packet_size = 1u << 31; - if (unlikely(size >= max_packet_size)) - throw Exception(fmt::format("Packet is too large to send, size : {}", size)); + RUNTIME_CHECK(size < max_packet_size, fmt::format("Packet is too large to send, size : {}", size)); } TrackedMppDataPacketPtr serializePacket(const tipb::SelectResponse & response) @@ -39,6 +40,12 @@ TrackedMppDataPacketPtr serializePacket(const tipb::SelectResponse & response) } } // namespace +template +MPPTunnelSetBase::MPPTunnelSetBase(DAGContext & dag_context, const String & req_id) + : log(Logger::get(req_id)) + , result_field_types(dag_context.result_field_types) +{} + template void MPPTunnelSetBase::sendExecutionSummary(const tipb::SelectResponse & response) { @@ -56,21 +63,50 @@ void MPPTunnelSetBase::write(tipb::SelectResponse & response) } template -void MPPTunnelSetBase::broadcastOrPassThroughWrite(TrackedMppDataPacketPtr && packet) +void MPPTunnelSetBase::broadcastOrPassThroughWrite(Blocks & blocks) { - checkPacketSize(packet->getPacket().ByteSizeLong()); RUNTIME_CHECK(!tunnels.empty()); + auto tracked_packet = MPPTunnelSetHelper::toPacket(blocks, result_field_types); + checkPacketSize(tracked_packet->getPacket().ByteSizeLong()); + // TODO avoid copy packet for broadcast. for (size_t i = 1; i < tunnels.size(); ++i) - tunnels[i]->write(packet->copy()); - tunnels[0]->write(std::move(packet)); + tunnels[i]->write(tracked_packet->copy()); + tunnels[0]->write(std::move(tracked_packet)); +} + +template +void MPPTunnelSetBase::partitionWrite(Blocks & blocks, int16_t partition_id) +{ + auto tracked_packet = MPPTunnelSetHelper::toPacket(blocks, result_field_types); + if (likely(tracked_packet->getPacket().chunks_size() > 0)) + { + checkPacketSize(tracked_packet->getPacket().ByteSizeLong()); + tunnels[partition_id]->write(std::move(tracked_packet)); + } } template -void MPPTunnelSetBase::partitionWrite(TrackedMppDataPacketPtr && packet, int16_t partition_id) +void MPPTunnelSetBase::fineGrainedShuffleWrite( + const Block & header, + std::vector & scattered, + size_t bucket_idx, + UInt64 fine_grained_shuffle_stream_count, + size_t num_columns, + int16_t partition_id) { - checkPacketSize(packet->getPacket().ByteSizeLong()); - tunnels[partition_id]->write(std::move(packet)); + auto tracked_packet = MPPTunnelSetHelper::toFineGrainedPacket( + header, + scattered, + bucket_idx, + fine_grained_shuffle_stream_count, + num_columns, + result_field_types); + if (likely(tracked_packet->getPacket().chunks_size() > 0)) + { + checkPacketSize(tracked_packet->getPacket().ByteSizeLong()); + tunnels[partition_id]->write(std::move(tracked_packet)); + } } template diff --git a/dbms/src/Flash/Mpp/MPPTunnelSet.h b/dbms/src/Flash/Mpp/MPPTunnelSet.h index 86c96355d0e..ac0431eef0a 100644 --- a/dbms/src/Flash/Mpp/MPPTunnelSet.h +++ b/dbms/src/Flash/Mpp/MPPTunnelSet.h @@ -24,27 +24,36 @@ #ifdef __clang__ #pragma clang diagnostic pop #endif + #include namespace DB { +class DAGContext; + template class MPPTunnelSetBase : private boost::noncopyable { public: using TunnelPtr = std::shared_ptr; - explicit MPPTunnelSetBase(const String & req_id) - : log(Logger::get(req_id)) - {} + MPPTunnelSetBase(DAGContext & dag_context, const String & req_id); // this is a root mpp writing. void write(tipb::SelectResponse & response); // this is a broadcast or pass through writing. - void broadcastOrPassThroughWrite(TrackedMppDataPacketPtr && packet); + void broadcastOrPassThroughWrite(Blocks & blocks); // this is a partition writing. - void partitionWrite(TrackedMppDataPacketPtr && packet, int16_t partition_id); + void partitionWrite(Blocks & blocks, int16_t partition_id); + // this is a fine grained shuffle writing. + void fineGrainedShuffleWrite( + const Block & header, + std::vector & scattered, + size_t bucket_idx, + UInt64 fine_grained_shuffle_stream_count, + size_t num_columns, + int16_t partition_id); /// this is a execution summary writing. - /// for both broadcast writing and partition writing, only + /// for both broadcast writing and partition/fine grained shuffle writing, only /// return meaningful execution summary for the first tunnel, /// because in TiDB, it does not know enough information /// about the execution details for the mpp query, it just @@ -74,6 +83,8 @@ class MPPTunnelSetBase : private boost::noncopyable std::unordered_map receiver_task_id_to_index_map; const LoggerPtr log; + std::vector result_field_types; + int external_thread_cnt = 0; }; diff --git a/dbms/src/Flash/Mpp/MPPTunnelSetHelper.cpp b/dbms/src/Flash/Mpp/MPPTunnelSetHelper.cpp new file mode 100644 index 00000000000..ef80177a105 --- /dev/null +++ b/dbms/src/Flash/Mpp/MPPTunnelSetHelper.cpp @@ -0,0 +1,72 @@ +// Copyright 2022 PingCAP, Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +namespace DB::MPPTunnelSetHelper +{ +TrackedMppDataPacketPtr toPacket(Blocks & blocks, const std::vector & field_types) +{ + CHBlockChunkCodec codec; + auto codec_stream = codec.newCodecStream(field_types); + auto tracked_packet = std::make_shared(); + while (!blocks.empty()) + { + const auto & block = blocks.back(); + codec_stream->encode(block, 0, block.rows()); + blocks.pop_back(); + tracked_packet->addChunk(codec_stream->getString()); + codec_stream->clear(); + } + return tracked_packet; +} + +TrackedMppDataPacketPtr toFineGrainedPacket( + const Block & header, + std::vector & scattered, + size_t bucket_idx, + UInt64 fine_grained_shuffle_stream_count, + size_t num_columns, + const std::vector & field_types) +{ + CHBlockChunkCodec codec; + auto codec_stream = codec.newCodecStream(field_types); + auto tracked_packet = std::make_shared(); + for (uint64_t stream_idx = 0; stream_idx < fine_grained_shuffle_stream_count; ++stream_idx) + { + // assemble scatter columns into a block + MutableColumns columns; + columns.reserve(num_columns); + for (size_t col_id = 0; col_id < num_columns; ++col_id) + columns.emplace_back(std::move(scattered[col_id][bucket_idx + stream_idx])); + auto block = header.cloneWithColumns(std::move(columns)); + + // encode into packet + codec_stream->encode(block, 0, block.rows()); + tracked_packet->addChunk(codec_stream->getString()); + tracked_packet->getPacket().add_stream_ids(stream_idx); + codec_stream->clear(); + + // disassemble the block back to scatter columns + columns = block.mutateColumns(); + for (size_t col_id = 0; col_id < num_columns; ++col_id) + { + columns[col_id]->popBack(columns[col_id]->size()); // clear column + scattered[col_id][bucket_idx + stream_idx] = std::move(columns[col_id]); + } + } + return tracked_packet; +} +} // namespace DB::HashBaseWriterHelper diff --git a/dbms/src/Flash/Mpp/MPPTunnelSetHelper.h b/dbms/src/Flash/Mpp/MPPTunnelSetHelper.h new file mode 100644 index 00000000000..d4414cad5c4 --- /dev/null +++ b/dbms/src/Flash/Mpp/MPPTunnelSetHelper.h @@ -0,0 +1,32 @@ +// Copyright 2022 PingCAP, Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +namespace DB::MPPTunnelSetHelper +{ +TrackedMppDataPacketPtr toPacket(Blocks & blocks, const std::vector & field_types); + +TrackedMppDataPacketPtr toFineGrainedPacket( + const Block & header, + std::vector & scattered, + size_t bucket_idx, + UInt64 fine_grained_shuffle_stream_count, + size_t num_columns, + const std::vector & field_types); +} // namespace DB::HashBaseWriterHelper diff --git a/dbms/src/Flash/Mpp/tests/gtest_mpp_exchange_writer.cpp b/dbms/src/Flash/Mpp/tests/gtest_mpp_exchange_writer.cpp index 95660c67bc4..08c53b51aa5 100644 --- a/dbms/src/Flash/Mpp/tests/gtest_mpp_exchange_writer.cpp +++ b/dbms/src/Flash/Mpp/tests/gtest_mpp_exchange_writer.cpp @@ -16,6 +16,7 @@ #include #include #include +#include #include #include #include @@ -104,6 +105,7 @@ class TestMPPExchangeWriter : public testing::Test return block; } + Context context; std::vector part_col_ids; TiDB::TiDBCollators part_col_collators; @@ -115,14 +117,41 @@ using MockExchangeWriterChecker = std::function & scattered, + size_t bucket_idx, + uint16_t fine_grained_shuffle_stream_count, + size_t num_columns, + int16_t part_id) + { + auto tracked_packet = MPPTunnelSetHelper::toFineGrainedPacket( + header, + scattered, + bucket_idx, + fine_grained_shuffle_stream_count, + num_columns, + result_field_types); + checker(tracked_packet, part_id); + } + void write(tipb::SelectResponse &) { FAIL() << "cannot reach here, only consider CH Block format"; } void sendExecutionSummary(const tipb::SelectResponse & response) { @@ -135,6 +164,7 @@ struct MockExchangeWriter private: MockExchangeWriterChecker checker; uint16_t part_num; + std::vector result_field_types; }; // Input block data is distributed uniform. @@ -160,7 +190,7 @@ try // batchWriteFineGrainedShuffle() only called once, so will only be one packet for each partition. ASSERT_TRUE(res.second); }; - auto mock_writer = std::make_shared(checker, part_num); + auto mock_writer = std::make_shared(checker, part_num, *dag_context_ptr); // 3. Start to write. auto dag_writer = std::make_shared>>( @@ -219,7 +249,7 @@ try auto checker = [&write_report](const TrackedMppDataPacketPtr & packet, uint16_t part_id) { write_report[part_id].emplace_back(packet); }; - auto mock_writer = std::make_shared(checker, part_num); + auto mock_writer = std::make_shared(checker, part_num, *dag_context_ptr); // 3. Start to write. auto dag_writer = std::make_shared>>( @@ -281,7 +311,7 @@ try auto checker = [&write_report](const TrackedMppDataPacketPtr & packet, uint16_t part_id) { write_report[part_id].emplace_back(packet); }; - auto mock_writer = std::make_shared(checker, part_num); + auto mock_writer = std::make_shared(checker, part_num, *dag_context_ptr); // 3. Start to write. auto dag_writer = std::make_shared>>( @@ -335,7 +365,7 @@ try ASSERT_EQ(part_id, 0); write_report.emplace_back(packet); }; - auto mock_writer = std::make_shared(checker, 1); + auto mock_writer = std::make_shared(checker, 1, *dag_context_ptr); // 3. Start to write. auto dag_writer = std::make_shared>>( From a221d239d4aded17cb64f666ee21db4b7cb86e6b Mon Sep 17 00:00:00 2001 From: SeaRise Date: Mon, 16 Jan 2023 13:01:34 +0800 Subject: [PATCH 2/4] fmt --- dbms/src/Flash/Mpp/MPPTunnelSetHelper.cpp | 4 ++-- dbms/src/Flash/Mpp/MPPTunnelSetHelper.h | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dbms/src/Flash/Mpp/MPPTunnelSetHelper.cpp b/dbms/src/Flash/Mpp/MPPTunnelSetHelper.cpp index ef80177a105..d17f52b8284 100644 --- a/dbms/src/Flash/Mpp/MPPTunnelSetHelper.cpp +++ b/dbms/src/Flash/Mpp/MPPTunnelSetHelper.cpp @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include #include +#include namespace DB::MPPTunnelSetHelper { @@ -69,4 +69,4 @@ TrackedMppDataPacketPtr toFineGrainedPacket( } return tracked_packet; } -} // namespace DB::HashBaseWriterHelper +} // namespace DB::MPPTunnelSetHelper diff --git a/dbms/src/Flash/Mpp/MPPTunnelSetHelper.h b/dbms/src/Flash/Mpp/MPPTunnelSetHelper.h index d4414cad5c4..38bddcef962 100644 --- a/dbms/src/Flash/Mpp/MPPTunnelSetHelper.h +++ b/dbms/src/Flash/Mpp/MPPTunnelSetHelper.h @@ -29,4 +29,4 @@ TrackedMppDataPacketPtr toFineGrainedPacket( UInt64 fine_grained_shuffle_stream_count, size_t num_columns, const std::vector & field_types); -} // namespace DB::HashBaseWriterHelper +} // namespace DB::MPPTunnelSetHelper From 1bdd906f6f26a6e30ada29e55c6f38c4fc0d5523 Mon Sep 17 00:00:00 2001 From: SeaRise Date: Mon, 16 Jan 2023 13:06:36 +0800 Subject: [PATCH 3/4] fix --- dbms/src/Flash/Mpp/HashPartitionWriter.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dbms/src/Flash/Mpp/HashPartitionWriter.cpp b/dbms/src/Flash/Mpp/HashPartitionWriter.cpp index 945453619ea..d901ca4151c 100644 --- a/dbms/src/Flash/Mpp/HashPartitionWriter.cpp +++ b/dbms/src/Flash/Mpp/HashPartitionWriter.cpp @@ -77,6 +77,7 @@ void HashPartitionWriter::partitionAndWriteBlocks() HashBaseWriterHelper::materializeBlocks(blocks); std::vector partition_key_containers(collators.size()); + Block header = blocks[0].cloneEmpty(); while (!blocks.empty()) { const auto & block = blocks.back(); @@ -86,7 +87,7 @@ void HashPartitionWriter::partitionAndWriteBlocks() for (size_t part_id = 0; part_id < partition_num; ++part_id) { - Block dest_block = blocks[0].cloneEmpty(); + Block dest_block = header.cloneEmpty(); dest_block.setColumns(std::move(dest_tbl_cols[part_id])); size_t dest_block_rows = dest_block.rows(); if (dest_block_rows > 0) From 640d686e95573dd62233b1223490053b7f7467c6 Mon Sep 17 00:00:00 2001 From: SeaRise Date: Mon, 16 Jan 2023 13:11:46 +0800 Subject: [PATCH 4/4] refine --- dbms/src/Flash/Mpp/HashPartitionWriter.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/dbms/src/Flash/Mpp/HashPartitionWriter.cpp b/dbms/src/Flash/Mpp/HashPartitionWriter.cpp index d901ca4151c..7ae30a1b4e7 100644 --- a/dbms/src/Flash/Mpp/HashPartitionWriter.cpp +++ b/dbms/src/Flash/Mpp/HashPartitionWriter.cpp @@ -89,8 +89,7 @@ void HashPartitionWriter::partitionAndWriteBlocks() { Block dest_block = header.cloneEmpty(); dest_block.setColumns(std::move(dest_tbl_cols[part_id])); - size_t dest_block_rows = dest_block.rows(); - if (dest_block_rows > 0) + if (dest_block.rows() > 0) partition_blocks[part_id].push_back(std::move(dest_block)); } } @@ -107,7 +106,7 @@ void HashPartitionWriter::writePartitionBlocks(std::vector 0)) + if (likely(!blocks.empty())) { writer->partitionWrite(blocks, part_id); blocks.clear();