diff --git a/dbms/src/Common/GRPCQueue.h b/dbms/src/Common/GRPCQueue.h index 497dab3a4ad..d0de1e329cf 100644 --- a/dbms/src/Common/GRPCQueue.h +++ b/dbms/src/Common/GRPCQueue.h @@ -138,6 +138,7 @@ class GRPCSendQueue } bool isWritable() const { return send_queue.isWritable(); } + void notifyNextPipelineWriter() { send_queue.notifyNextPipelineWriter(); } void registerPipeReadTask(TaskPtr && task) { send_queue.registerPipeReadTask(std::move(task)); } void registerPipeWriteTask(TaskPtr && task) { send_queue.registerPipeWriteTask(std::move(task)); } @@ -299,6 +300,7 @@ class GRPCRecvQueue } bool isWritable() const { return recv_queue.isWritable(); } + void notifyNextPipelineWriter() { return recv_queue.notifyNextPipelineWriter(); } void registerPipeReadTask(TaskPtr && task) { recv_queue.registerPipeReadTask(std::move(task)); } void registerPipeWriteTask(TaskPtr && task) { recv_queue.registerPipeWriteTask(std::move(task)); } diff --git a/dbms/src/Common/LooseBoundedMPMCQueue.h b/dbms/src/Common/LooseBoundedMPMCQueue.h index bb9920c9aa3..6611d9bf37e 100644 --- a/dbms/src/Common/LooseBoundedMPMCQueue.h +++ b/dbms/src/Common/LooseBoundedMPMCQueue.h @@ -220,6 +220,8 @@ class LooseBoundedMPMCQueue return !isFullWithoutLock(); } + void notifyNextPipelineWriter() { pipe_writer_cv.notifyOne(); } + MPMCQueueStatus getStatus() const { std::lock_guard lock(mu); diff --git a/dbms/src/Flash/Coprocessor/DAGResponseWriter.h b/dbms/src/Flash/Coprocessor/DAGResponseWriter.h index b9a756e365d..5a58707c2ef 100644 --- a/dbms/src/Flash/Coprocessor/DAGResponseWriter.h +++ b/dbms/src/Flash/Coprocessor/DAGResponseWriter.h @@ -29,7 +29,13 @@ class DAGResponseWriter DAGResponseWriter(Int64 records_per_chunk_, DAGContext & dag_context_); /// prepared with sample block virtual void prepare(const Block &){}; - virtual void write(const Block & block) = 0; + void write(const Block & block) + { + if (!doWrite(block)) + { + notifyNextPipelineWriter(); + } + } // For async writer, `waitForWritable` need 
to be called before calling `write`. // ``` @@ -40,10 +46,23 @@ class DAGResponseWriter virtual WaitResult waitForWritable() const { throw Exception("Unsupport"); } /// flush cached blocks for batch writer - virtual void flush() = 0; + void flush() + { + if (!doFlush()) + { + notifyNextPipelineWriter(); + } + } + virtual ~DAGResponseWriter() = default; protected: + // returns true if the call actually wrote the data out + virtual bool doWrite(const Block & block) = 0; + // returns true if the call actually flushed any data + virtual bool doFlush() = 0; + virtual void notifyNextPipelineWriter() = 0; + Int64 records_per_chunk; DAGContext & dag_context; }; diff --git a/dbms/src/Flash/Coprocessor/StreamWriter.h b/dbms/src/Flash/Coprocessor/StreamWriter.h index 41383ee49c2..061ebd75cd9 100644 --- a/dbms/src/Flash/Coprocessor/StreamWriter.h +++ b/dbms/src/Flash/Coprocessor/StreamWriter.h @@ -59,6 +59,7 @@ struct CopStreamWriter throw Exception("Failed to write resp"); } static WaitResult waitForWritable() { throw Exception("Unsupport async write"); } + static void notifyNextPipelineWriter() {} }; struct BatchCopStreamWriter @@ -83,6 +84,7 @@ struct BatchCopStreamWriter throw Exception("Failed to write resp"); } static WaitResult waitForWritable() { throw Exception("Unsupport async write"); } + static void notifyNextPipelineWriter() {} }; using CopStreamWriterPtr = std::shared_ptr; diff --git a/dbms/src/Flash/Coprocessor/StreamingDAGResponseWriter.cpp b/dbms/src/Flash/Coprocessor/StreamingDAGResponseWriter.cpp index a6f39cb25dc..2c8939bf138 100644 --- a/dbms/src/Flash/Coprocessor/StreamingDAGResponseWriter.cpp +++ b/dbms/src/Flash/Coprocessor/StreamingDAGResponseWriter.cpp @@ -61,10 +61,14 @@ StreamingDAGResponseWriter::StreamingDAGResponseWriter( } template -void StreamingDAGResponseWriter::flush() +bool StreamingDAGResponseWriter::doFlush() { if (rows_in_blocks > 0) + { encodeThenWriteBlocks(); + return true; + } + return false; } template @@ -74,7 +78,13 @@ WaitResult 
StreamingDAGResponseWriter::waitForWritable() const } template -void StreamingDAGResponseWriter::write(const Block & block) +void StreamingDAGResponseWriter::notifyNextPipelineWriter() +{ + return writer->notifyNextPipelineWriter(); +} + +template +bool StreamingDAGResponseWriter::doWrite(const Block & block) { RUNTIME_CHECK_MSG( block.columns() == dag_context.result_field_types.size(), @@ -87,14 +97,17 @@ void StreamingDAGResponseWriter::write(const Block & block) } if (static_cast(rows_in_blocks) > batch_send_min_limit) + { encodeThenWriteBlocks(); + return true; + } + return false; } template void StreamingDAGResponseWriter::encodeThenWriteBlocks() { - if (unlikely(blocks.empty())) - return; + assert(!blocks.empty()); TrackedSelectResp response; response.setEncodeType(dag_context.encode_type); diff --git a/dbms/src/Flash/Coprocessor/StreamingDAGResponseWriter.h b/dbms/src/Flash/Coprocessor/StreamingDAGResponseWriter.h index 61ca9a71517..4c0ebd15bc5 100644 --- a/dbms/src/Flash/Coprocessor/StreamingDAGResponseWriter.h +++ b/dbms/src/Flash/Coprocessor/StreamingDAGResponseWriter.h @@ -36,9 +36,12 @@ class StreamingDAGResponseWriter : public DAGResponseWriter Int64 records_per_chunk_, Int64 batch_send_min_limit_, DAGContext & dag_context_); - void write(const Block & block) override; WaitResult waitForWritable() const override; - void flush() override; + +protected: + bool doWrite(const Block & block) override; + bool doFlush() override; + void notifyNextPipelineWriter() override; private: void encodeThenWriteBlocks(); diff --git a/dbms/src/Flash/Coprocessor/UnaryDAGResponseWriter.cpp b/dbms/src/Flash/Coprocessor/UnaryDAGResponseWriter.cpp index b616cc05ecb..0803064bfc1 100644 --- a/dbms/src/Flash/Coprocessor/UnaryDAGResponseWriter.cpp +++ b/dbms/src/Flash/Coprocessor/UnaryDAGResponseWriter.cpp @@ -71,7 +71,7 @@ void UnaryDAGResponseWriter::appendWarningsToDAGResponse() dag_response->set_warning_count(dag_context.getWarningCount()); } -void 
UnaryDAGResponseWriter::flush() +bool UnaryDAGResponseWriter::doFlush() { if (current_records_num > 0) { @@ -86,9 +86,10 @@ void UnaryDAGResponseWriter::flush() throw TiFlashException( "DAG response is too big, please check config about region size or region merge scheduler", Errors::Coprocessor::Internal); + return true; } -void UnaryDAGResponseWriter::write(const Block & block) +bool UnaryDAGResponseWriter::doWrite(const Block & block) { if (block.columns() != dag_context.result_field_types.size()) throw TiFlashException("Output column size mismatch with field type size", Errors::Coprocessor::Internal); @@ -116,5 +117,6 @@ void UnaryDAGResponseWriter::write(const Block & block) row_index = upper; } } + return true; } } // namespace DB diff --git a/dbms/src/Flash/Coprocessor/UnaryDAGResponseWriter.h b/dbms/src/Flash/Coprocessor/UnaryDAGResponseWriter.h index 2d718fa1104..f5480011f72 100644 --- a/dbms/src/Flash/Coprocessor/UnaryDAGResponseWriter.h +++ b/dbms/src/Flash/Coprocessor/UnaryDAGResponseWriter.h @@ -33,11 +33,14 @@ class UnaryDAGResponseWriter : public DAGResponseWriter public: UnaryDAGResponseWriter(tipb::SelectResponse * response_, Int64 records_per_chunk_, DAGContext & dag_context_); - void write(const Block & block) override; - void flush() override; void encodeChunkToDAGResponse(); void appendWarningsToDAGResponse(); +protected: + bool doWrite(const Block & block) override; + bool doFlush() override; + void notifyNextPipelineWriter() override{}; + private: tipb::SelectResponse * dag_response; std::unique_ptr chunk_codec_stream; diff --git a/dbms/src/Flash/Coprocessor/tests/gtest_streaming_writer.cpp b/dbms/src/Flash/Coprocessor/tests/gtest_streaming_writer.cpp index a3351f294ca..5dcc8249df2 100644 --- a/dbms/src/Flash/Coprocessor/tests/gtest_streaming_writer.cpp +++ b/dbms/src/Flash/Coprocessor/tests/gtest_streaming_writer.cpp @@ -94,6 +94,7 @@ struct MockStreamWriter void write(tipb::SelectResponse & response) { checker(response); } static WaitResult 
waitForWritable() { throw Exception("Unsupport async write"); } + static void notifyNextPipelineWriter() {} private: MockStreamWriterChecker checker; diff --git a/dbms/src/Flash/Coprocessor/tests/gtest_ti_remote_block_inputstream.cpp b/dbms/src/Flash/Coprocessor/tests/gtest_ti_remote_block_inputstream.cpp index dafd6f8ce93..af067ffdf39 100644 --- a/dbms/src/Flash/Coprocessor/tests/gtest_ti_remote_block_inputstream.cpp +++ b/dbms/src/Flash/Coprocessor/tests/gtest_ti_remote_block_inputstream.cpp @@ -148,6 +148,7 @@ struct MockWriter } static uint16_t getPartitionNum() { return 1; } static WaitResult waitForWritable() { throw Exception("Unsupport async write"); } + static void notifyNextPipelineWriter() {} std::vector result_field_types; diff --git a/dbms/src/Flash/Mpp/BroadcastOrPassThroughWriter.cpp b/dbms/src/Flash/Mpp/BroadcastOrPassThroughWriter.cpp index 1b33e73019a..c15a670ed02 100644 --- a/dbms/src/Flash/Mpp/BroadcastOrPassThroughWriter.cpp +++ b/dbms/src/Flash/Mpp/BroadcastOrPassThroughWriter.cpp @@ -44,6 +44,8 @@ BroadcastOrPassThroughWriter::BroadcastOrPassThroughWriter( switch (data_codec_version) { case MPPDataPacketV0: + if (batch_send_min_limit <= 0) + batch_send_min_limit = 1; break; case MPPDataPacketV1: default: @@ -64,10 +66,14 @@ BroadcastOrPassThroughWriter::BroadcastOrPassThroughWriter( } template -void BroadcastOrPassThroughWriter::flush() +bool BroadcastOrPassThroughWriter::doFlush() { if (rows_in_blocks > 0) + { writeBlocks(); + return true; + } + return false; } template @@ -77,7 +83,13 @@ WaitResult BroadcastOrPassThroughWriter::waitForWritable() co } template -void BroadcastOrPassThroughWriter::write(const Block & block) +void BroadcastOrPassThroughWriter::notifyNextPipelineWriter() +{ + writer->notifyNextPipelineWriter(); +} + +template +bool BroadcastOrPassThroughWriter::doWrite(const Block & block) { RUNTIME_CHECK(!block.info.selective); RUNTIME_CHECK_MSG( @@ -90,15 +102,18 @@ void BroadcastOrPassThroughWriter::write(const Block & block) 
blocks.push_back(block); } - if (static_cast(rows_in_blocks) > batch_send_min_limit) + if (static_cast(rows_in_blocks) >= batch_send_min_limit) + { writeBlocks(); + return true; + } + return false; } template void BroadcastOrPassThroughWriter::writeBlocks() { - if unlikely (blocks.empty()) - return; + assert(!blocks.empty()); // check schema if (!expected_types.empty()) diff --git a/dbms/src/Flash/Mpp/BroadcastOrPassThroughWriter.h b/dbms/src/Flash/Mpp/BroadcastOrPassThroughWriter.h index be615c4c21c..12e9668c903 100644 --- a/dbms/src/Flash/Mpp/BroadcastOrPassThroughWriter.h +++ b/dbms/src/Flash/Mpp/BroadcastOrPassThroughWriter.h @@ -36,9 +36,12 @@ class BroadcastOrPassThroughWriter : public DAGResponseWriter MPPDataPacketVersion data_codec_version_, tipb::CompressionMode compression_mode_, tipb::ExchangeType exchange_type_); - void write(const Block & block) override; WaitResult waitForWritable() const override; - void flush() override; + +protected: + bool doWrite(const Block & block) override; + bool doFlush() override; + void notifyNextPipelineWriter() override; private: void writeBlocks(); diff --git a/dbms/src/Flash/Mpp/FineGrainedShuffleWriter.cpp b/dbms/src/Flash/Mpp/FineGrainedShuffleWriter.cpp index 37387f7e23c..81941b00f2d 100644 --- a/dbms/src/Flash/Mpp/FineGrainedShuffleWriter.cpp +++ b/dbms/src/Flash/Mpp/FineGrainedShuffleWriter.cpp @@ -90,10 +90,20 @@ void FineGrainedShuffleWriter::prepare(const Block & sample_b } template -void FineGrainedShuffleWriter::flush() +bool FineGrainedShuffleWriter::doFlush() { if (rows_in_blocks > 0) + { batchWriteFineGrainedShuffle(); + return true; + } + return false; +} + +template +void FineGrainedShuffleWriter::notifyNextPipelineWriter() +{ + writer->notifyNextPipelineWriter(); } template @@ -103,7 +113,7 @@ WaitResult FineGrainedShuffleWriter::waitForWritable() const } template -void FineGrainedShuffleWriter::write(const Block & block) +bool FineGrainedShuffleWriter::doWrite(const Block & block) { 
RUNTIME_CHECK_MSG(prepared, "FineGrainedShuffleWriter should be prepared before writing."); RUNTIME_CHECK_MSG( @@ -124,7 +134,11 @@ void FineGrainedShuffleWriter::write(const Block & block) if (blocks.size() == fine_grained_shuffle_stream_count || static_cast(rows_in_blocks) >= batch_send_row_limit) + { batchWriteFineGrainedShuffle(); + return true; + } + return false; } template @@ -148,8 +162,7 @@ template template void FineGrainedShuffleWriter::batchWriteFineGrainedShuffleImpl() { - if (blocks.empty()) - return; + assert(!blocks.empty()); { assert(rows_in_blocks > 0); diff --git a/dbms/src/Flash/Mpp/FineGrainedShuffleWriter.h b/dbms/src/Flash/Mpp/FineGrainedShuffleWriter.h index 5bdb5a52e77..144659ad9f8 100644 --- a/dbms/src/Flash/Mpp/FineGrainedShuffleWriter.h +++ b/dbms/src/Flash/Mpp/FineGrainedShuffleWriter.h @@ -39,9 +39,12 @@ class FineGrainedShuffleWriter : public DAGResponseWriter MPPDataPacketVersion data_codec_version_, tipb::CompressionMode compression_mode_); void prepare(const Block & sample_block) override; - void write(const Block & block) override; WaitResult waitForWritable() const override; - void flush() override; + +protected: + bool doWrite(const Block & block) override; + bool doFlush() override; + void notifyNextPipelineWriter() override; private: void batchWriteFineGrainedShuffle(); diff --git a/dbms/src/Flash/Mpp/HashPartitionWriter.cpp b/dbms/src/Flash/Mpp/HashPartitionWriter.cpp index fd501015663..5422fc1fe07 100644 --- a/dbms/src/Flash/Mpp/HashPartitionWriter.cpp +++ b/dbms/src/Flash/Mpp/HashPartitionWriter.cpp @@ -52,6 +52,8 @@ HashPartitionWriter::HashPartitionWriter( switch (data_codec_version) { case MPPDataPacketV0: + if (batch_send_min_limit <= 0) + batch_send_min_limit = 1; break; case MPPDataPacketV1: default: @@ -72,10 +74,10 @@ HashPartitionWriter::HashPartitionWriter( } template -void HashPartitionWriter::flush() +bool HashPartitionWriter::doFlush() { if (0 == rows_in_blocks) - return; + return false; switch 
(data_codec_version) { @@ -91,6 +93,13 @@ void HashPartitionWriter::flush() break; } } + return true; +} + +template +void HashPartitionWriter::notifyNextPipelineWriter() +{ + writer->notifyNextPipelineWriter(); } template @@ -100,7 +109,7 @@ WaitResult HashPartitionWriter::waitForWritable() const } template -void HashPartitionWriter::writeImplV1(const Block & block) +bool HashPartitionWriter::writeImplV1(const Block & block) { size_t rows = 0; if (block.info.selective) @@ -116,11 +125,15 @@ void HashPartitionWriter::writeImplV1(const Block & block) } if (static_cast(rows_in_blocks) >= batch_send_min_limit || mem_size_in_blocks >= MAX_BATCH_SEND_MIN_LIMIT_MEM_SIZE) + { partitionAndWriteBlocksV1(); + return true; + } + return false; } template -void HashPartitionWriter::writeImpl(const Block & block) +bool HashPartitionWriter::writeImpl(const Block & block) { size_t rows = 0; if (block.info.selective) @@ -133,12 +146,16 @@ void HashPartitionWriter::writeImpl(const Block & block) rows_in_blocks += rows; blocks.push_back(block); } - if (static_cast(rows_in_blocks) > batch_send_min_limit) + if (static_cast(rows_in_blocks) >= batch_send_min_limit) + { partitionAndWriteBlocks(); + return true; + } + return false; } template -void HashPartitionWriter::write(const Block & block) +bool HashPartitionWriter::doWrite(const Block & block) { RUNTIME_CHECK_MSG( block.columns() == dag_context.result_field_types.size(), @@ -228,8 +245,7 @@ void HashPartitionWriter::partitionAndWriteBlocksV1() template void HashPartitionWriter::partitionAndWriteBlocks() { - if unlikely (blocks.empty()) - return; + assert(!blocks.empty()); std::vector partition_blocks; partition_blocks.resize(partition_num); @@ -282,11 +298,8 @@ void HashPartitionWriter::writePartitionBlocks(std::vectorpartitionWrite(blocks, part_id); - blocks.clear(); - } + writer->partitionWrite(blocks, part_id); + blocks.clear(); } } diff --git a/dbms/src/Flash/Mpp/HashPartitionWriter.h b/dbms/src/Flash/Mpp/HashPartitionWriter.h 
index 8e36d28234d..102df7dc283 100644 --- a/dbms/src/Flash/Mpp/HashPartitionWriter.h +++ b/dbms/src/Flash/Mpp/HashPartitionWriter.h @@ -37,13 +37,17 @@ class HashPartitionWriter : public DAGResponseWriter DAGContext & dag_context_, MPPDataPacketVersion data_codec_version_, tipb::CompressionMode compression_mode_); - void write(const Block & block) override; WaitResult waitForWritable() const override; - void flush() override; + +protected: + bool doWrite(const Block & block) override; + bool doFlush() override; + void notifyNextPipelineWriter() override; + private: - void writeImpl(const Block & block); - void writeImplV1(const Block & block); + bool writeImpl(const Block & block); + bool writeImplV1(const Block & block); void partitionAndWriteBlocks(); void partitionAndWriteBlocksV1(); diff --git a/dbms/src/Flash/Mpp/LocalRequestHandler.h b/dbms/src/Flash/Mpp/LocalRequestHandler.h index a6422d79880..e598f67bae8 100644 --- a/dbms/src/Flash/Mpp/LocalRequestHandler.h +++ b/dbms/src/Flash/Mpp/LocalRequestHandler.h @@ -41,6 +41,7 @@ struct LocalRequestHandler } bool isWritable() const { return msg_queue->isWritable(); } + void notifyNextPipelineWriter() const { return msg_queue->notifyNextPipelineWriter(); } void registerPipeReadTask(TaskPtr && task) const { msg_queue->registerPipeReadTask(std::move(task)); } void registerPipeWriteTask(TaskPtr && task) const { msg_queue->registerPipeWriteTask(std::move(task)); } diff --git a/dbms/src/Flash/Mpp/MPPTunnel.h b/dbms/src/Flash/Mpp/MPPTunnel.h index 4c2421437e4..e60077403a2 100644 --- a/dbms/src/Flash/Mpp/MPPTunnel.h +++ b/dbms/src/Flash/Mpp/MPPTunnel.h @@ -119,6 +119,7 @@ class TunnelSender virtual bool finish() = 0; virtual bool isWritable() const = 0; + virtual void notifyNextPipelineWriter() = 0; void consumerFinish(const String & err_msg); String getConsumerFinishMsg() { return consumer_state.getMsg(); } @@ -197,6 +198,8 @@ class SyncTunnelSender : public TunnelSender bool isWritable() const override { return 
send_queue.isWritable(); } + void notifyNextPipelineWriter() override { send_queue.notifyNextPipelineWriter(); } + void registerTask(TaskPtr && task) override { send_queue.registerPipeWriteTask(std::move(task)); } private: @@ -249,6 +252,8 @@ class AsyncTunnelSender : public TunnelSender bool isWritable() const override { return queue.isWritable(); } + void notifyNextPipelineWriter() override { queue.notifyNextPipelineWriter(); } + void cancelWith(const String & reason) override { queue.cancelWith(reason); } const String & getCancelReason() const { return queue.getCancelReason(); } @@ -319,6 +324,17 @@ class LocalTunnelSenderV2 : public TunnelSender } } + void notifyNextPipelineWriter() override + { + if constexpr (local_only) + local_request_handler.notifyNextPipelineWriter(); + else + { + std::lock_guard lock(mu); + local_request_handler.notifyNextPipelineWriter(); + } + } + void registerTask(TaskPtr && task) override { if constexpr (local_only) @@ -423,6 +439,7 @@ class LocalTunnelSenderV1 : public TunnelSender bool finish() override { return send_queue.finish(); } bool isWritable() const override { return send_queue.isWritable(); } + void notifyNextPipelineWriter() override { send_queue.notifyNextPipelineWriter(); } void registerTask(TaskPtr && task) override { send_queue.registerPipeWriteTask(std::move(task)); } @@ -502,6 +519,12 @@ class MPPTunnel : private boost::noncopyable WaitResult waitForWritable() const; void forceWrite(TrackedMppDataPacketPtr && data); + void notifyNextPipelineWriter() + { + assert(tunnel_sender != nullptr); + tunnel_sender->notifyNextPipelineWriter(); + } + // finish the writing, and wait until the sender finishes. 
void writeDone(); diff --git a/dbms/src/Flash/Mpp/MPPTunnelSet.cpp b/dbms/src/Flash/Mpp/MPPTunnelSet.cpp index 47f7fe2299c..d2fb8c01c58 100644 --- a/dbms/src/Flash/Mpp/MPPTunnelSet.cpp +++ b/dbms/src/Flash/Mpp/MPPTunnelSet.cpp @@ -78,6 +78,15 @@ WaitResult MPPTunnelSetBase::waitForWritable() const return WaitResult::Ready; } +template +void MPPTunnelSetBase::notifyNextPipelineWriter() const +{ + for (const auto & tunnel : tunnels) + { + tunnel->notifyNextPipelineWriter(); + } +} + template void MPPTunnelSetBase::registerTunnel(const MPPTaskId & receiver_task_id, const TunnelPtr & tunnel) { diff --git a/dbms/src/Flash/Mpp/MPPTunnelSet.h b/dbms/src/Flash/Mpp/MPPTunnelSet.h index a57f57b3ac8..78fd1018724 100644 --- a/dbms/src/Flash/Mpp/MPPTunnelSet.h +++ b/dbms/src/Flash/Mpp/MPPTunnelSet.h @@ -61,6 +61,7 @@ class MPPTunnelSetBase : private boost::noncopyable const std::vector & getTunnels() const { return tunnels; } WaitResult waitForWritable() const; + void notifyNextPipelineWriter() const; bool isLocal(size_t index) const; diff --git a/dbms/src/Flash/Mpp/MPPTunnelSetHelper.cpp b/dbms/src/Flash/Mpp/MPPTunnelSetHelper.cpp index 6a044b8a010..aae8a325281 100644 --- a/dbms/src/Flash/Mpp/MPPTunnelSetHelper.cpp +++ b/dbms/src/Flash/Mpp/MPPTunnelSetHelper.cpp @@ -34,12 +34,13 @@ TrackedMppDataPacketPtr ToPacket( }; auto && res = codec.encode(std::move(part_columns), method); - if unlikely (res.empty()) - return nullptr; auto tracked_packet = std::make_shared(version); - tracked_packet->addChunk(std::move(res)); - original_size += codec.original_size; + if likely (!res.empty()) + { + tracked_packet->addChunk(std::move(res)); + original_size += codec.original_size; + } return tracked_packet; } @@ -67,9 +68,6 @@ TrackedMppDataPacketPtr ToPacket( TrackedMppDataPacketPtr ToPacketV0(Blocks & blocks, const std::vector & field_types) { - if (blocks.empty()) - return nullptr; - CHBlockChunkCodec codec; auto codec_stream = codec.newCodecStream(field_types); auto tracked_packet = 
std::make_shared(MPPDataPacketV0); diff --git a/dbms/src/Flash/Mpp/MPPTunnelSetWriter.cpp b/dbms/src/Flash/Mpp/MPPTunnelSetWriter.cpp index f75a08c1883..072fcc2bee3 100644 --- a/dbms/src/Flash/Mpp/MPPTunnelSetWriter.cpp +++ b/dbms/src/Flash/Mpp/MPPTunnelSetWriter.cpp @@ -365,8 +365,7 @@ void MPPTunnelSetWriterBase::passThroughWrite( void MPPTunnelSetWriterBase::partitionWrite(Blocks & blocks, int16_t partition_id) { auto && tracked_packet = MPPTunnelSetHelper::ToPacketV0(blocks, result_field_types); - if (!tracked_packet) - return; + assert(tracked_packet); auto packet_bytes = tracked_packet->getPacket().ByteSizeLong(); checkPacketSize(packet_bytes); writeToTunnel(std::move(tracked_packet), partition_id); @@ -392,8 +391,7 @@ void MPPTunnelSetWriterBase::partitionWrite( size_t original_size = 0; auto tracked_packet = MPPTunnelSetHelper::ToPacket(header, std::move(part_columns), version, compression_method, original_size); - if (!tracked_packet) - return; + assert(tracked_packet); auto packet_bytes = tracked_packet->getPacket().ByteSizeLong(); checkPacketSize(packet_bytes); diff --git a/dbms/src/Flash/Mpp/MPPTunnelSetWriter.h b/dbms/src/Flash/Mpp/MPPTunnelSetWriter.h index 1af6730f108..03d77bf6681 100644 --- a/dbms/src/Flash/Mpp/MPPTunnelSetWriter.h +++ b/dbms/src/Flash/Mpp/MPPTunnelSetWriter.h @@ -70,6 +70,7 @@ class MPPTunnelSetWriterBase : private boost::noncopyable uint16_t getPartitionNum() const { return mpp_tunnel_set->getPartitionNum(); } virtual WaitResult waitForWritable() const = 0; + virtual void notifyNextPipelineWriter() const = 0; protected: virtual void writeToTunnel(TrackedMppDataPacketPtr && data, size_t index) = 0; @@ -93,6 +94,7 @@ class SyncMPPTunnelSetWriter : public MPPTunnelSetWriterBase // For sync writer, `waitForWritable` will not be called, so an exception is thrown here. 
WaitResult waitForWritable() const override { throw Exception("Unsupport sync writer"); } + void notifyNextPipelineWriter() const override {} protected: void writeToTunnel(TrackedMppDataPacketPtr && data, size_t index) override; @@ -111,6 +113,7 @@ class AsyncMPPTunnelSetWriter : public MPPTunnelSetWriterBase {} WaitResult waitForWritable() const override { return mpp_tunnel_set->waitForWritable(); } + void notifyNextPipelineWriter() const override { mpp_tunnel_set->notifyNextPipelineWriter(); } protected: void writeToTunnel(TrackedMppDataPacketPtr && data, size_t index) override; diff --git a/dbms/src/Flash/Mpp/ReceivedMessageQueue.h b/dbms/src/Flash/Mpp/ReceivedMessageQueue.h index c975e4d2ab1..e160dfcc49c 100644 --- a/dbms/src/Flash/Mpp/ReceivedMessageQueue.h +++ b/dbms/src/Flash/Mpp/ReceivedMessageQueue.h @@ -98,6 +98,7 @@ class ReceivedMessageQueue } bool isWritable() const { return grpc_recv_queue.isWritable(); } + void notifyNextPipelineWriter() { grpc_recv_queue.notifyNextPipelineWriter(); } void registerPipeReadTask(TaskPtr && task) { grpc_recv_queue.registerPipeReadTask(std::move(task)); } void registerPipeWriteTask(TaskPtr && task) { grpc_recv_queue.registerPipeWriteTask(std::move(task)); } diff --git a/dbms/src/Flash/Mpp/tests/gtest_mpp_exchange_writer.cpp b/dbms/src/Flash/Mpp/tests/gtest_mpp_exchange_writer.cpp index 84e51cb7151..00ae84c93a2 100644 --- a/dbms/src/Flash/Mpp/tests/gtest_mpp_exchange_writer.cpp +++ b/dbms/src/Flash/Mpp/tests/gtest_mpp_exchange_writer.cpp @@ -169,6 +169,7 @@ struct MockExchangeWriter return index == 0; } static WaitResult waitForWritable() { throw Exception("Unsupport async write"); } + static void notifyNextPipelineWriter() {} private: MockExchangeWriterChecker checker; diff --git a/dbms/src/Flash/Mpp/tests/gtest_trigger_pipeline_writer.cpp b/dbms/src/Flash/Mpp/tests/gtest_trigger_pipeline_writer.cpp new file mode 100644 index 00000000000..f4c1db357ba --- /dev/null +++ 
b/dbms/src/Flash/Mpp/tests/gtest_trigger_pipeline_writer.cpp @@ -0,0 +1,142 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace DB +{ +namespace tests +{ +class MockPipelineTriggerWriter : public DAGResponseWriter +{ +public: + MockPipelineTriggerWriter(Int64 records_per_chunk, DAGContext & dag_context) + : DAGResponseWriter(records_per_chunk, dag_context) + , rng(dev()) + {} + bool doWrite(const Block &) override + { + std::uniform_int_distribution dist; + auto next = dist(rng); + return next % 3 == 1; + } + bool doFlush() override + { + std::uniform_int_distribution dist; + auto next = dist(rng); + return next % 3 == 0; + } + void notifyNextPipelineWriter() override { writer_notify_count--; } + WaitResult waitForWritable() const override + { + std::uniform_int_distribution dist; + auto next = dist(rng); + if (next % 3 == 1) + { + writer_notify_count++; + return WaitResult::WaitForNotify; + } + return WaitResult::Ready; + } + Int64 getPipelineNotifyCount() const { return writer_notify_count; } + +private: + mutable Int64 writer_notify_count = 0; + std::random_device dev; + mutable std::mt19937 rng; +}; + +class TestTriggerPipelineWriter : public testing::Test +{ +protected: + void SetUp() override + { + dag_context_ptr = std::make_unique(1024); + dag_context_ptr->encode_type = tipb::EncodeType::TypeCHBlock; + 
dag_context_ptr->kind = DAGRequestKind::MPP; + dag_context_ptr->is_root_mpp_task = false; + dag_context_ptr->result_field_types = makeFields(); + } + +public: + TestTriggerPipelineWriter() = default; + + // Return 10 Int64 column. + static std::vector makeFields() + { + std::vector fields(10); + for (int i = 0; i < 10; ++i) + { + fields[i].set_tp(TiDB::TypeLongLong); + fields[i].set_flag(TiDB::ColumnFlagNotNull); + } + return fields; + } + + // Return a block with **rows** and 10 Int64 column. + static Block prepareRandomBlock(size_t rows) + { + Block block; + for (size_t i = 0; i < 10; ++i) + { + DataTypePtr int64_data_type = std::make_shared(); + auto int64_column = ColumnGenerator::instance().generate({rows, "Int64", RANDOM}).column; + block.insert( + ColumnWithTypeAndName{std::move(int64_column), int64_data_type, String("col") + std::to_string(i)}); + } + return block; + } + + std::unique_ptr dag_context_ptr; +}; + +TEST_F(TestTriggerPipelineWriter, testPipelineWriter) +try +{ + const size_t block_rows = 1024; + // 1. Build Block. + auto block = prepareRandomBlock(block_rows); + + for (int test_index = 0; test_index < 100; test_index++) + { + // 2. Build MockWriter. + auto mock_writer = std::make_shared(-1, *dag_context_ptr); + + // 3. write something + for (int i = 0; i < 100; i++) + { + mock_writer->waitForWritable(); + mock_writer->write(block); + } + + // 4. flush + mock_writer->waitForWritable(); + mock_writer->flush(); + + // 5. 
check results, note redundant notify is allowed + ASSERT_EQ(true, mock_writer->getPipelineNotifyCount() <= 0); + } +} +CATCH + +} // namespace tests +} // namespace DB diff --git a/dbms/src/Operators/ExchangeSenderSinkOp.h b/dbms/src/Operators/ExchangeSenderSinkOp.h index 77ed8532f98..e8a3a20ce68 100644 --- a/dbms/src/Operators/ExchangeSenderSinkOp.h +++ b/dbms/src/Operators/ExchangeSenderSinkOp.h @@ -26,9 +26,9 @@ class ExchangeSenderSinkOp : public SinkOp ExchangeSenderSinkOp( PipelineExecutorContext & exec_context_, const String & req_id, - std::unique_ptr && writer) + std::unique_ptr && writer_) : SinkOp(exec_context_, req_id) - , writer(std::move(writer)) + , writer(std::move(writer_)) {} String getName() const override { return "ExchangeSenderSinkOp"; }