diff --git a/dbms/src/Flash/Coprocessor/DAGResponseWriter.h b/dbms/src/Flash/Coprocessor/DAGResponseWriter.h index abff5d51c0a..c3a33b06027 100644 --- a/dbms/src/Flash/Coprocessor/DAGResponseWriter.h +++ b/dbms/src/Flash/Coprocessor/DAGResponseWriter.h @@ -32,12 +32,12 @@ class DAGResponseWriter virtual void prepare(const Block &){}; virtual void write(const Block & block) = 0; - // For async writer, `isReadyForWrite` need to be called before calling `write`. + // For async writer, `isWritable` need to be called before calling `write`. // ``` - // while (!isReadyForWrite()) {} + // while (!isWritable()) {} // write(block); // ``` - virtual bool isReadyForWrite() const { throw Exception("Unsupport"); } + virtual bool isWritable() const { throw Exception("Unsupport"); } /// flush cached blocks for batch writer virtual void flush() = 0; diff --git a/dbms/src/Flash/Coprocessor/StreamWriter.h b/dbms/src/Flash/Coprocessor/StreamWriter.h index 8cf498694e5..f4a1d7a618a 100644 --- a/dbms/src/Flash/Coprocessor/StreamWriter.h +++ b/dbms/src/Flash/Coprocessor/StreamWriter.h @@ -55,7 +55,7 @@ struct StreamWriter if (!writer->Write(resp)) throw Exception("Failed to write resp"); } - bool isReadyForWrite() const { throw Exception("Unsupport async write"); } + bool isWritable() const { throw Exception("Unsupport async write"); } }; using StreamWriterPtr = std::shared_ptr; diff --git a/dbms/src/Flash/Coprocessor/StreamingDAGResponseWriter.cpp b/dbms/src/Flash/Coprocessor/StreamingDAGResponseWriter.cpp index abea17b9f0e..4120fc13012 100644 --- a/dbms/src/Flash/Coprocessor/StreamingDAGResponseWriter.cpp +++ b/dbms/src/Flash/Coprocessor/StreamingDAGResponseWriter.cpp @@ -69,9 +69,9 @@ void StreamingDAGResponseWriter::flush() } template -bool StreamingDAGResponseWriter::isReadyForWrite() const +bool StreamingDAGResponseWriter::isWritable() const { - return writer->isReadyForWrite(); + return writer->isWritable(); } template diff --git a/dbms/src/Flash/Coprocessor/StreamingDAGResponseWriter.h b/dbms/src/Flash/Coprocessor/StreamingDAGResponseWriter.h index 6d2780f4d1e..7bb42bed63e 100644 --- a/dbms/src/Flash/Coprocessor/StreamingDAGResponseWriter.h +++ b/dbms/src/Flash/Coprocessor/StreamingDAGResponseWriter.h @@ -37,7 +37,7 @@ class StreamingDAGResponseWriter : public DAGResponseWriter Int64 batch_send_min_limit_, DAGContext & dag_context_); void write(const Block & block) override; - bool isReadyForWrite() const override; + bool isWritable() const override; void flush() override; private: diff --git a/dbms/src/Flash/Coprocessor/tests/gtest_streaming_writer.cpp b/dbms/src/Flash/Coprocessor/tests/gtest_streaming_writer.cpp index 4f9e7309769..d3294c34e2a 100644 --- a/dbms/src/Flash/Coprocessor/tests/gtest_streaming_writer.cpp +++ b/dbms/src/Flash/Coprocessor/tests/gtest_streaming_writer.cpp @@ -96,7 +96,7 @@ struct MockStreamWriter {} void write(tipb::SelectResponse & response) { checker(response); } - bool isReadyForWrite() const { throw Exception("Unsupport async write"); } + bool isWritable() const { throw Exception("Unsupport async write"); } private: MockStreamWriterChecker checker; diff --git a/dbms/src/Flash/Coprocessor/tests/gtest_ti_remote_block_inputstream.cpp b/dbms/src/Flash/Coprocessor/tests/gtest_ti_remote_block_inputstream.cpp index b40e9a01a4a..3e36af990ad 100644 --- a/dbms/src/Flash/Coprocessor/tests/gtest_ti_remote_block_inputstream.cpp +++ b/dbms/src/Flash/Coprocessor/tests/gtest_ti_remote_block_inputstream.cpp @@ -149,7 +149,7 @@ struct MockWriter queue->push(tracked_packet); } static uint16_t getPartitionNum() { return 1; } - static bool isReadyForWrite() { throw Exception("Unsupport async write"); } + static bool isWritable() { throw Exception("Unsupport async write"); } std::vector result_field_types; diff --git a/dbms/src/Flash/Mpp/BroadcastOrPassThroughWriter.cpp b/dbms/src/Flash/Mpp/BroadcastOrPassThroughWriter.cpp index d6b3fac6d2c..de1cf6fb8f9 100644 --- a/dbms/src/Flash/Mpp/BroadcastOrPassThroughWriter.cpp +++ b/dbms/src/Flash/Mpp/BroadcastOrPassThroughWriter.cpp @@ -70,9 +70,9 @@ void BroadcastOrPassThroughWriter::flush() } template -bool BroadcastOrPassThroughWriter::isReadyForWrite() const +bool BroadcastOrPassThroughWriter::isWritable() const { - return writer->isReadyForWrite(); + return writer->isWritable(); } template diff --git a/dbms/src/Flash/Mpp/BroadcastOrPassThroughWriter.h b/dbms/src/Flash/Mpp/BroadcastOrPassThroughWriter.h index 812eeb7c70b..e70c3ec9f3e 100644 --- a/dbms/src/Flash/Mpp/BroadcastOrPassThroughWriter.h +++ b/dbms/src/Flash/Mpp/BroadcastOrPassThroughWriter.h @@ -37,7 +37,7 @@ class BroadcastOrPassThroughWriter : public DAGResponseWriter tipb::CompressionMode compression_mode_, tipb::ExchangeType exchange_type_); void write(const Block & block) override; - bool isReadyForWrite() const override; + bool isWritable() const override; void flush() override; private: diff --git a/dbms/src/Flash/Mpp/ExchangeReceiver.cpp b/dbms/src/Flash/Mpp/ExchangeReceiver.cpp index b1b42935739..af991163d35 100644 --- a/dbms/src/Flash/Mpp/ExchangeReceiver.cpp +++ b/dbms/src/Flash/Mpp/ExchangeReceiver.cpp @@ -824,29 +824,7 @@ DecodeDetail ExchangeReceiverBase::decodeChunks( } template -ReceiveResult ExchangeReceiverBase::receive(size_t stream_id) -{ - return receive( - stream_id, - [&](size_t stream_id, RecvMsgPtr & recv_msg) { - return grpc_recv_queue[stream_id].pop(recv_msg); - }); -} - -template -ReceiveResult ExchangeReceiverBase::nonBlockingReceive(size_t stream_id) -{ - return receive( - stream_id, - [&](size_t stream_id, RecvMsgPtr & recv_msg) { - return grpc_recv_queue[stream_id].tryPop(recv_msg); - }); -} - -template -ReceiveResult ExchangeReceiverBase::receive( - size_t stream_id, - std::function recv_func) +void ExchangeReceiverBase::verifyStreamId(size_t stream_id) const { if (unlikely(stream_id >= grpc_recv_queue.size())) { @@ -854,9 +832,12 @@ ReceiveResult ExchangeReceiverBase::receive( LOG_ERROR(exc_log, err_msg); throw Exception(err_msg); } +} - RecvMsgPtr recv_msg; - switch (recv_func(stream_id, recv_msg)) +template +ReceiveResult ExchangeReceiverBase::toReceiveResult(MPMCQueueResult result, RecvMsgPtr && recv_msg) +{ + switch (result) { case MPMCQueueResult::OK: assert(recv_msg); @@ -868,6 +849,24 @@ ReceiveResult ExchangeReceiverBase::receive( } } +template +ReceiveResult ExchangeReceiverBase::receive(size_t stream_id) +{ + verifyStreamId(stream_id); + RecvMsgPtr recv_msg; + auto res = grpc_recv_queue[stream_id].pop(recv_msg); + return toReceiveResult(res, std::move(recv_msg)); +} + +template +ReceiveResult ExchangeReceiverBase::tryReceive(size_t stream_id) +{ + // verifyStreamId has been called in `ExchangeReceiverSourceOp`. + RecvMsgPtr recv_msg; + auto res = grpc_recv_queue[stream_id].tryPop(recv_msg); + return toReceiveResult(res, std::move(recv_msg)); +} + template ExchangeReceiverResult ExchangeReceiverBase::toExchangeReceiveResult( ReceiveResult & recv_result, diff --git a/dbms/src/Flash/Mpp/ExchangeReceiver.h b/dbms/src/Flash/Mpp/ExchangeReceiver.h index 5e838cb7745..c3581d7fa1f 100644 --- a/dbms/src/Flash/Mpp/ExchangeReceiver.h +++ b/dbms/src/Flash/Mpp/ExchangeReceiver.h @@ -123,7 +123,7 @@ class ExchangeReceiverBase void close(); ReceiveResult receive(size_t stream_id); - ReceiveResult nonBlockingReceive(size_t stream_id); + ReceiveResult tryReceive(size_t stream_id); ExchangeReceiverResult toExchangeReceiveResult( ReceiveResult & recv_result, @@ -145,6 +145,8 @@ class ExchangeReceiverBase MemoryTracker * getMemoryTracker() const { return mem_tracker.get(); } std::atomic * getDataSizeInQueue() { return &data_size_in_queue; } + void verifyStreamId(size_t stream_id) const; + private: std::shared_ptr mem_tracker; using Request = typename RPCContext::Request; @@ -187,11 +189,8 @@ class ExchangeReceiverBase const RecvMsgPtr & recv_msg, std::unique_ptr & decoder_ptr); - ReceiveResult receive( - size_t stream_id, - std::function recv_func); + inline ReceiveResult toReceiveResult(MPMCQueueResult result, RecvMsgPtr && recv_msg); -private: void prepareMsgChannels(); void prepareGRPCReceiveQueue(); void addLocalConnectionNum(); @@ -213,6 +212,7 @@ class ExchangeReceiverBase return !disaggregated_dispatch_reqs.empty(); } +private: std::shared_ptr rpc_context; const tipb::ExchangeReceiver pb_exchange_receiver; diff --git a/dbms/src/Flash/Mpp/FineGrainedShuffleWriter.cpp b/dbms/src/Flash/Mpp/FineGrainedShuffleWriter.cpp index 70e6ba5cfbc..391e4d73258 100644 --- a/dbms/src/Flash/Mpp/FineGrainedShuffleWriter.cpp +++ b/dbms/src/Flash/Mpp/FineGrainedShuffleWriter.cpp @@ -96,9 +96,9 @@ void FineGrainedShuffleWriter::flush() } template -bool FineGrainedShuffleWriter::isReadyForWrite() const +bool FineGrainedShuffleWriter::isWritable() const { - return writer->isReadyForWrite(); + return writer->isWritable(); } template diff --git a/dbms/src/Flash/Mpp/FineGrainedShuffleWriter.h b/dbms/src/Flash/Mpp/FineGrainedShuffleWriter.h index cc94f6ebf58..f1e9d578ec9 100644 --- a/dbms/src/Flash/Mpp/FineGrainedShuffleWriter.h +++ b/dbms/src/Flash/Mpp/FineGrainedShuffleWriter.h @@ -45,7 +45,7 @@ class FineGrainedShuffleWriter : public DAGResponseWriter tipb::CompressionMode compression_mode_); void prepare(const Block & sample_block) override; void write(const Block & block) override; - bool isReadyForWrite() const override; + bool isWritable() const override; void flush() override; private: diff --git a/dbms/src/Flash/Mpp/HashPartitionWriter.cpp b/dbms/src/Flash/Mpp/HashPartitionWriter.cpp index 9646e6a1ad4..c714e037976 100644 --- a/dbms/src/Flash/Mpp/HashPartitionWriter.cpp +++ b/dbms/src/Flash/Mpp/HashPartitionWriter.cpp @@ -92,9 +92,9 @@ void HashPartitionWriter::flush() } template -bool HashPartitionWriter::isReadyForWrite() const +bool HashPartitionWriter::isWritable() const { - return writer->isReadyForWrite(); + return writer->isWritable(); } template diff --git a/dbms/src/Flash/Mpp/HashPartitionWriter.h b/dbms/src/Flash/Mpp/HashPartitionWriter.h index 12a22705437..a3ba528324a 100644 --- a/dbms/src/Flash/Mpp/HashPartitionWriter.h +++ b/dbms/src/Flash/Mpp/HashPartitionWriter.h @@ -38,7 +38,7 @@ class HashPartitionWriter : public DAGResponseWriter MPPDataPacketVersion data_codec_version_, tipb::CompressionMode compression_mode_); void write(const Block & block) override; - bool isReadyForWrite() const override; + bool isWritable() const override; void flush() override; private: diff --git a/dbms/src/Flash/Mpp/LocalRequestHandler.h b/dbms/src/Flash/Mpp/LocalRequestHandler.h index 346b3e162b2..a15e34d5eec 100644 --- a/dbms/src/Flash/Mpp/LocalRequestHandler.h +++ b/dbms/src/Flash/Mpp/LocalRequestHandler.h @@ -40,9 +40,9 @@ struct LocalRequestHandler return channel_writer.write(source_index, tracked_packet); } - bool isReadyForWrite() const + bool isWritable() const { - return channel_writer.isReadyForWrite(); + return channel_writer.isWritable(); } void writeDone(bool meet_error, const String & local_err_msg) const diff --git a/dbms/src/Flash/Mpp/MPPTunnel.cpp b/dbms/src/Flash/Mpp/MPPTunnel.cpp index 3663c04ade0..1dd07783435 100644 --- a/dbms/src/Flash/Mpp/MPPTunnel.cpp +++ b/dbms/src/Flash/Mpp/MPPTunnel.cpp @@ -125,7 +125,7 @@ MPPTunnel::~MPPTunnel() void MPPTunnel::close(const String & reason, bool wait_sender_finish) { { - std::unique_lock lk(mu); + std::lock_guard lk(mu); switch (status) { case TunnelStatus::Unconnected: @@ -151,6 +151,7 @@ void MPPTunnel::close(const String & reason, bool wait_sender_finish) RUNTIME_ASSERT(false, log, "Unsupported tunnel status: {}", static_cast(status)); } } + if (wait_sender_finish) waitForSenderFinish(false); } @@ -352,7 +353,7 @@ void MPPTunnel::waitUntilConnectedOrFinished(std::unique_lock & lk) throw Exception(fmt::format("MPPTunnel {} can not be connected because MPPTask is cancelled", tunnel_id)); } -bool MPPTunnel::isReadyForWrite() const +bool MPPTunnel::isWritable() const { std::unique_lock lk(mu); switch (status) @@ -371,7 +372,7 @@ bool MPPTunnel::isReadyForWrite() const } case TunnelStatus::Connected: RUNTIME_CHECK_MSG(tunnel_sender != nullptr, "write to tunnel {} which is already closed.", tunnel_id); - return tunnel_sender->isReadyForWrite(); + return tunnel_sender->isWritable(); default: // Returns true directly for TunnelStatus::WaitingForSenderFinish and TunnelStatus::Finished, // and then handled by `forceWrite`. diff --git a/dbms/src/Flash/Mpp/MPPTunnel.h b/dbms/src/Flash/Mpp/MPPTunnel.h index c5029579f48..22ef2ce0c4b 100644 --- a/dbms/src/Flash/Mpp/MPPTunnel.h +++ b/dbms/src/Flash/Mpp/MPPTunnel.h @@ -111,7 +111,7 @@ class TunnelSender : private boost::noncopyable virtual bool finish() = 0; - virtual bool isReadyForWrite() const = 0; + virtual bool isWritable() const = 0; void consumerFinish(const String & err_msg); String getConsumerFinishMsg() @@ -204,7 +204,7 @@ class SyncTunnelSender : public TunnelSender return send_queue.finish(); } - bool isReadyForWrite() const override + bool isWritable() const override { return !send_queue.isFull(); } @@ -246,7 +246,7 @@ class AsyncTunnelSender : public TunnelSender return queue.finish(); } - bool isReadyForWrite() const override + bool isWritable() const override { return !queue.isFull(); } @@ -326,14 +326,14 @@ class LocalTunnelSenderV2 : public TunnelSender return true; } - bool isReadyForWrite() const override + bool isWritable() const override { if constexpr (local_only) - return local_request_handler.isReadyForWrite(); + return local_request_handler.isWritable(); else { std::lock_guard lock(mu); - return local_request_handler.isReadyForWrite(); + return local_request_handler.isWritable(); } } @@ -428,7 +428,7 @@ class LocalTunnelSenderV1 : public TunnelSender return send_queue.finish(); } - bool isReadyForWrite() const override + bool isWritable() const override { return !send_queue.isFull(); } @@ -504,11 +504,11 @@ class MPPTunnel : private boost::noncopyable // forceWrite write a single packet to the tunnel's send queue without blocking, // and need to call isReadForWrite first. // ``` - // while (!isReadyForWrite()) {} + // while (!isWritable()) {} // forceWrite(std::move(data)); // ``` void forceWrite(TrackedMppDataPacketPtr && data); - bool isReadyForWrite() const; + bool isWritable() const; // finish the writing, and wait until the sender finishes. void writeDone(); diff --git a/dbms/src/Flash/Mpp/MPPTunnelSet.cpp b/dbms/src/Flash/Mpp/MPPTunnelSet.cpp index 9aef733173a..ce468cb2611 100644 --- a/dbms/src/Flash/Mpp/MPPTunnelSet.cpp +++ b/dbms/src/Flash/Mpp/MPPTunnelSet.cpp @@ -68,11 +68,11 @@ void MPPTunnelSetBase::forceWrite(tipb::SelectResponse & response, size_ } template -bool MPPTunnelSetBase::isReadyForWrite() const +bool MPPTunnelSetBase::isWritable() const { for (const auto & tunnel : tunnels) { - if (!tunnel->isReadyForWrite()) + if (!tunnel->isWritable()) return false; } return true; diff --git a/dbms/src/Flash/Mpp/MPPTunnelSet.h b/dbms/src/Flash/Mpp/MPPTunnelSet.h index d4ba13b89eb..1af903a912c 100644 --- a/dbms/src/Flash/Mpp/MPPTunnelSet.h +++ b/dbms/src/Flash/Mpp/MPPTunnelSet.h @@ -67,7 +67,7 @@ class MPPTunnelSetBase : private boost::noncopyable const std::vector & getTunnels() const { return tunnels; } - bool isReadyForWrite() const; + bool isWritable() const; bool isLocal(size_t index) const; diff --git a/dbms/src/Flash/Mpp/MPPTunnelSetWriter.h b/dbms/src/Flash/Mpp/MPPTunnelSetWriter.h index 209d3c5b605..c2613a8db70 100644 --- a/dbms/src/Flash/Mpp/MPPTunnelSetWriter.h +++ b/dbms/src/Flash/Mpp/MPPTunnelSetWriter.h @@ -63,7 +63,7 @@ class MPPTunnelSetWriterBase : private boost::noncopyable uint16_t getPartitionNum() const { return mpp_tunnel_set->getPartitionNum(); } - virtual bool isReadyForWrite() const = 0; + virtual bool isWritable() const = 0; protected: virtual void writeToTunnel(TrackedMppDataPacketPtr && data, size_t index) = 0; @@ -85,8 +85,8 @@ class SyncMPPTunnelSetWriter : public MPPTunnelSetWriterBase : MPPTunnelSetWriterBase(mpp_tunnel_set_, result_field_types_, req_id) {} - // For sync writer, `isReadyForWrite` will not be called, so an exception is thrown here. - bool isReadyForWrite() const override { throw Exception("Unsupport sync writer"); } + // For sync writer, `isWritable` will not be called, so an exception is thrown here. + bool isWritable() const override { throw Exception("Unsupport sync writer"); } protected: void writeToTunnel(TrackedMppDataPacketPtr && data, size_t index) override; @@ -104,7 +104,7 @@ class AsyncMPPTunnelSetWriter : public MPPTunnelSetWriterBase : MPPTunnelSetWriterBase(mpp_tunnel_set_, result_field_types_, req_id) {} - bool isReadyForWrite() const override { return mpp_tunnel_set->isReadyForWrite(); } + bool isWritable() const override { return mpp_tunnel_set->isWritable(); } protected: void writeToTunnel(TrackedMppDataPacketPtr && data, size_t index) override; diff --git a/dbms/src/Flash/Mpp/ReceiverChannelWriter.cpp b/dbms/src/Flash/Mpp/ReceiverChannelWriter.cpp index 32eae252b3f..dfa2a4a48cc 100644 --- a/dbms/src/Flash/Mpp/ReceiverChannelWriter.cpp +++ b/dbms/src/Flash/Mpp/ReceiverChannelWriter.cpp @@ -106,7 +106,7 @@ bool ReceiverChannelWriter::writeNonFineGrain( return success; } -bool ReceiverChannelWriter::isReadyForWrite() const +bool ReceiverChannelWriter::isWritable() const { for (const auto & msg_channel : *msg_channels) { diff --git a/dbms/src/Flash/Mpp/ReceiverChannelWriter.h b/dbms/src/Flash/Mpp/ReceiverChannelWriter.h index 82409f862f6..a066d477217 100644 --- a/dbms/src/Flash/Mpp/ReceiverChannelWriter.h +++ b/dbms/src/Flash/Mpp/ReceiverChannelWriter.h @@ -76,7 +76,7 @@ class ReceiverChannelWriter : public ReceiverChannelBase return success; } - bool isReadyForWrite() const; + bool isWritable() const; private: using WriteToChannelFunc = std::function; diff --git a/dbms/src/Flash/Mpp/tests/gtest_mpp_exchange_writer.cpp b/dbms/src/Flash/Mpp/tests/gtest_mpp_exchange_writer.cpp index bbd43d92539..e0638f528c2 100644 --- a/dbms/src/Flash/Mpp/tests/gtest_mpp_exchange_writer.cpp +++ b/dbms/src/Flash/Mpp/tests/gtest_mpp_exchange_writer.cpp @@ -228,7 +228,7 @@ struct MockExchangeWriter // make only part 0 use local tunnel return index == 0; } - bool isReadyForWrite() const { throw Exception("Unsupport async write"); } + bool isWritable() const { throw Exception("Unsupport async write"); } private: MockExchangeWriterChecker checker; diff --git a/dbms/src/Flash/Mpp/tests/gtest_mpptunnel.cpp b/dbms/src/Flash/Mpp/tests/gtest_mpptunnel.cpp index f7516618a03..dba31002083 100644 --- a/dbms/src/Flash/Mpp/tests/gtest_mpptunnel.cpp +++ b/dbms/src/Flash/Mpp/tests/gtest_mpptunnel.cpp @@ -797,7 +797,7 @@ TEST_F(TestMPPTunnel, SyncTunnelForceWrite) mpp_tunnel_ptr->connectSync(writer_ptr.get()); GTEST_ASSERT_EQ(getTunnelConnectedFlag(mpp_tunnel_ptr), true); - ASSERT_TRUE(mpp_tunnel_ptr->isReadyForWrite()); + ASSERT_TRUE(mpp_tunnel_ptr->isWritable()); mpp_tunnel_ptr->forceWrite(newDataPacket("First")); mpp_tunnel_ptr->writeDone(); GTEST_ASSERT_EQ(getTunnelFinishedFlag(mpp_tunnel_ptr), true); @@ -814,7 +814,7 @@ TEST_F(TestMPPTunnel, AsyncTunnelForceWrite) GTEST_ASSERT_EQ(getTunnelConnectedFlag(mpp_tunnel_ptr), true); std::thread t(&MockAsyncCallData::run, call_data.get()); - ASSERT_TRUE(mpp_tunnel_ptr->isReadyForWrite()); + ASSERT_TRUE(mpp_tunnel_ptr->isWritable()); mpp_tunnel_ptr->forceWrite(newDataPacket("First")); mpp_tunnel_ptr->writeDone(); GTEST_ASSERT_EQ(getTunnelFinishedFlag(mpp_tunnel_ptr), true); @@ -831,7 +831,7 @@ TEST_F(TestMPPTunnel, LocalTunnelForceWrite) GTEST_ASSERT_EQ(getTunnelConnectedFlag(mpp_tunnel_ptr), true); std::thread t(&MockExchangeReceiver::receiveAll, receiver.get()); - ASSERT_TRUE(mpp_tunnel_ptr->isReadyForWrite()); + ASSERT_TRUE(mpp_tunnel_ptr->isWritable()); mpp_tunnel_ptr->forceWrite(newDataPacket("First")); mpp_tunnel_ptr->writeDone(); GTEST_ASSERT_EQ(getTunnelFinishedFlag(mpp_tunnel_ptr), true); @@ -841,7 +841,7 @@ TEST_F(TestMPPTunnel, LocalTunnelForceWrite) GTEST_ASSERT_EQ(receiver->getReceivedMsgs().back()->packet->getPacket().data(), "First"); } -TEST_F(TestMPPTunnel, isReadyForWriteTimeout) +TEST_F(TestMPPTunnel, isWritableTimeout) try { timeout = std::chrono::seconds(1); @@ -849,7 +849,7 @@ try Stopwatch stop_watch{CLOCK_MONOTONIC_COARSE}; while (stop_watch.elapsedSeconds() < 3 * timeout.count()) { - ASSERT_FALSE(mpp_tunnel_ptr->isReadyForWrite()); + ASSERT_FALSE(mpp_tunnel_ptr->isWritable()); } GTEST_FAIL(); } diff --git a/dbms/src/Flash/Pipeline/Exec/PipelineExec.cpp b/dbms/src/Flash/Pipeline/Exec/PipelineExec.cpp index a1501dc3883..5626a9350f3 100644 --- a/dbms/src/Flash/Pipeline/Exec/PipelineExec.cpp +++ b/dbms/src/Flash/Pipeline/Exec/PipelineExec.cpp @@ -46,6 +46,20 @@ namespace DB return (op_status); \ } +PipelineExec::PipelineExec( + SourceOpPtr && source_op_, + TransformOps && transform_ops_, + SinkOpPtr && sink_op_) + : source_op(std::move(source_op_)) + , transform_ops(std::move(transform_ops_)) + , sink_op(std::move(sink_op_)) +{ + addOperatorIfAwaitable(sink_op); + for (auto it = transform_ops.rbegin(); it != transform_ops.rend(); ++it) // NOLINT(modernize-loop-convert) + addOperatorIfAwaitable(*it); + addOperatorIfAwaitable(source_op); +} + void PipelineExec::executePrefix() { sink_op->operatePrefix(); @@ -99,9 +113,7 @@ OperatorStatus PipelineExec::executeImpl() } // try fetch block from transform_ops and source_op. -OperatorStatus PipelineExec::fetchBlock( - Block & block, - size_t & start_transform_op_index) +OperatorStatus PipelineExec::fetchBlock(Block & block, size_t & start_transform_op_index) { auto op_status = sink_op->prepare(); HANDLE_OP_STATUS(sink_op, op_status, OperatorStatus::NEED_INPUT); @@ -148,17 +160,26 @@ OperatorStatus PipelineExec::await() } OperatorStatus PipelineExec::awaitImpl() { - auto op_status = sink_op->await(); - HANDLE_OP_STATUS(sink_op, op_status, OperatorStatus::NEED_INPUT); - for (auto it = transform_ops.rbegin(); it != transform_ops.rend(); ++it) // NOLINT(modernize-loop-convert) + for (auto & awaitable : awaitables) { - // If the transform_op returns `NEED_INPUT`, - // we need to call the upstream transform_op until a transform_op returns something other than `NEED_INPUT`. - op_status = (*it)->await(); - HANDLE_OP_STATUS((*it), op_status, OperatorStatus::NEED_INPUT); + auto op_status = awaitable->await(); + switch (op_status) + { + // If NEED_INPUT is returned, continue checking the next operator. + case OperatorStatus::NEED_INPUT: + break; + // For the io status, the operator needs to be filled in io_op for later use in executeIO. + case OperatorStatus::IO: + assert(!io_op); + assert(awaitable); + io_op.emplace(awaitable); + // For unexpected status, an immediate return is required. + default: + return op_status; + } } - op_status = source_op->await(); - HANDLE_LAST_OP_STATUS(source_op, op_status); + // await must eventually return HAS_OUTPUT. + return OperatorStatus::HAS_OUTPUT; } #undef HANDLE_OP_STATUS diff --git a/dbms/src/Flash/Pipeline/Exec/PipelineExec.h b/dbms/src/Flash/Pipeline/Exec/PipelineExec.h index 86fdbe75610..4d182e94898 100644 --- a/dbms/src/Flash/Pipeline/Exec/PipelineExec.h +++ b/dbms/src/Flash/Pipeline/Exec/PipelineExec.h @@ -29,11 +29,7 @@ class PipelineExec : private boost::noncopyable PipelineExec( SourceOpPtr && source_op_, TransformOps && transform_ops_, - SinkOpPtr && sink_op_) - : source_op(std::move(source_op_)) - , transform_ops(std::move(transform_ops_)) - , sink_op(std::move(sink_op_)) - {} + SinkOpPtr && sink_op_); void executePrefix(); void executeSuffix(); @@ -45,21 +41,33 @@ class PipelineExec : private boost::noncopyable OperatorStatus await(); private: - OperatorStatus executeImpl(); + inline OperatorStatus executeImpl(); - OperatorStatus executeIOImpl(); + inline OperatorStatus executeIOImpl(); - OperatorStatus awaitImpl(); + inline OperatorStatus awaitImpl(); - OperatorStatus fetchBlock( - Block & block, - size_t & start_transform_op_index); + inline OperatorStatus fetchBlock(Block & block, size_t & start_transform_op_index); + + // Put the operator that has implemented the `awaitImpl` into the awaitables. + // In order to avoid calling the virtual function Operator::await too much in await, + // only the operator that needs await will implement `awaitImpl` and `isAwaitable`, + // and then it will be called in PipelineExec::await. + template + inline void addOperatorIfAwaitable(const OperatorPtr & op) + { + if (op->isAwaitable()) + awaitables.push_back(op.get()); + } private: SourceOpPtr source_op; TransformOps transform_ops; SinkOpPtr sink_op; + // hold the operators that awaitable. + std::vector awaitables; + // hold the operator which is ready for executing io. std::optional io_op; }; diff --git a/dbms/src/Flash/Pipeline/Schedule/Events/Event.cpp b/dbms/src/Flash/Pipeline/Schedule/Events/Event.cpp index 924ee084653..417cbb75e07 100644 --- a/dbms/src/Flash/Pipeline/Schedule/Events/Event.cpp +++ b/dbms/src/Flash/Pipeline/Schedule/Events/Event.cpp @@ -165,7 +165,9 @@ void Event::onTaskFinish(const TaskProfileInfo & task_profile_info) log, "remaining_tasks must >= 0, but actual value is {}", remaining_tasks); - LOG_DEBUG(log, "one task finished, {} tasks remaining", remaining_tasks); +#ifndef NDEBUG + LOG_TRACE(log, "one task finished, {} tasks remaining", remaining_tasks); +#endif // !NDEBUG if (0 == remaining_tasks) finish(); } @@ -210,7 +212,9 @@ void Event::switchStatus(EventStatus from, EventStatus to) magic_enum::enum_name(from), magic_enum::enum_name(to), magic_enum::enum_name(status.load())); - LOG_DEBUG(log, "switch status: {} --> {}", magic_enum::enum_name(from), magic_enum::enum_name(to)); +#ifndef NDEBUG + LOG_TRACE(log, "switch status: {} --> {}", magic_enum::enum_name(from), magic_enum::enum_name(to)); +#endif // !NDEBUG } void Event::assertStatus(EventStatus expect) diff --git a/dbms/src/Flash/Pipeline/Schedule/Reactor/Spinner.cpp b/dbms/src/Flash/Pipeline/Schedule/Reactor/Spinner.cpp deleted file mode 100644 index f52b38b7301..00000000000 --- a/dbms/src/Flash/Pipeline/Schedule/Reactor/Spinner.cpp +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright 2023 PingCAP, Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include -#include - -namespace DB -{ -bool Spinner::awaitAndCollectReadyTask(TaskPtr && task) -{ - assert(task); - TRACE_MEMORY(task); - auto status = task->await(); - switch (status) - { - case ExecTaskStatus::WAITING: - return false; - case ExecTaskStatus::RUNNING: - task->profile_info.elapsedAwaitTime(); - cpu_tasks.push_back(std::move(task)); - return true; - case ExecTaskStatus::IO: - task->profile_info.elapsedAwaitTime(); - io_tasks.push_back(std::move(task)); - return true; - case FINISH_STATUS: - task->profile_info.elapsedAwaitTime(); - FINALIZE_TASK(task); - return true; - default: - UNEXPECTED_STATUS(logger, status); - } -} - -void Spinner::submitReadyTasks() -{ - if (cpu_tasks.empty() && io_tasks.empty()) - { - tryYield(); - return; - } - - task_scheduler.submitToCPUTaskThreadPool(cpu_tasks); - cpu_tasks.clear(); - - task_scheduler.submitToIOTaskThreadPool(io_tasks); - io_tasks.clear(); - - spin_count = 0; -} - -void Spinner::tryYield() -{ - ++spin_count; - - if (spin_count != 0 && spin_count % 64 == 0) - { - sched_yield(); - if (spin_count == 640) - { - spin_count = 0; - sched_yield(); - } - } -} -} // namespace DB diff --git a/dbms/src/Flash/Pipeline/Schedule/Reactor/Spinner.h b/dbms/src/Flash/Pipeline/Schedule/Reactor/Spinner.h deleted file mode 100644 index 95daceb7a2b..00000000000 --- a/dbms/src/Flash/Pipeline/Schedule/Reactor/Spinner.h +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright 2023 PingCAP, Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include - -#include - -namespace DB -{ -class TaskScheduler; - -/// Used for batch calling task.await and submitting the tasks that have been removed from the waiting state to task thread pools. -/// When there is no non-waiting state task for a long time, it will try to let the current thread rest for a period of time to give the CPU to other threads. -class Spinner -{ -public: - Spinner(TaskScheduler & task_scheduler_, const LoggerPtr & logger_) - : task_scheduler(task_scheduler_) - , logger(logger_->getChild("Spinner")) - {} - - // return true if the task is not in waiting status. - bool awaitAndCollectReadyTask(TaskPtr && task); - - void submitReadyTasks(); - - void tryYield(); - -private: - TaskScheduler & task_scheduler; - - LoggerPtr logger; - - int16_t spin_count = 0; - - std::vector cpu_tasks; - std::vector io_tasks; -}; -} // namespace DB diff --git a/dbms/src/Flash/Pipeline/Schedule/Reactor/WaitReactor.cpp b/dbms/src/Flash/Pipeline/Schedule/Reactor/WaitReactor.cpp index 47ad6e45662..38ded3e4bf2 100644 --- a/dbms/src/Flash/Pipeline/Schedule/Reactor/WaitReactor.cpp +++ b/dbms/src/Flash/Pipeline/Schedule/Reactor/WaitReactor.cpp @@ -25,12 +25,78 @@ namespace DB { WaitReactor::WaitReactor(TaskScheduler & scheduler_) : scheduler(scheduler_) - , spinner{scheduler, logger} { GET_METRIC(tiflash_pipeline_scheduler, type_waiting_tasks_count).Set(0); thread = std::thread(&WaitReactor::loop, this); } +bool WaitReactor::awaitAndCollectReadyTask(TaskPtr && task) +{ + assert(task); + task->startTraceMemory(); + auto status = task->await(); + switch (status) + { + case ExecTaskStatus::WAITING: + task->endTraceMemory(); + return false; + case ExecTaskStatus::RUNNING: + task->profile_info.elapsedAwaitTime(); + task->endTraceMemory(); + cpu_tasks.push_back(std::move(task)); + return true; + case ExecTaskStatus::IO: + task->profile_info.elapsedAwaitTime(); + task->endTraceMemory(); + io_tasks.push_back(std::move(task)); + return true; + case FINISH_STATUS: + task->profile_info.elapsedAwaitTime(); + task->finalize(); + task->endTraceMemory(); + task.reset(); + return true; + default: + UNEXPECTED_STATUS(logger, status); + } +} + +void WaitReactor::submitReadyTasks() +{ + if (cpu_tasks.empty() && io_tasks.empty()) + { + tryYield(); + return; + } + + scheduler.submitToCPUTaskThreadPool(cpu_tasks); + cpu_tasks.clear(); + + scheduler.submitToIOTaskThreadPool(io_tasks); + io_tasks.clear(); + + spin_count = 0; +} + +void WaitReactor::tryYield() +{ + ++spin_count; + + if (spin_count != 0 && spin_count % 64 == 0) + { +#if defined(__x86_64__) + _mm_pause(); +#else + sched_yield(); +#endif + if (spin_count == 640) + { + spin_count = 0; + sched_yield(); + } + } +} + void WaitReactor::finish() { waiting_task_list.finish(); @@ -65,15 +131,14 @@ void WaitReactor::react(std::list & local_waiting_tasks) { for (auto task_it = local_waiting_tasks.begin(); task_it != local_waiting_tasks.end();) { - if (spinner.awaitAndCollectReadyTask(std::move(*task_it))) + if (awaitAndCollectReadyTask(std::move(*task_it))) task_it = local_waiting_tasks.erase(task_it); else ++task_it; - ASSERT_MEMORY_TRACKER } GET_METRIC(tiflash_pipeline_scheduler, type_waiting_tasks_count).Set(local_waiting_tasks.size()); - spinner.submitReadyTasks(); + submitReadyTasks(); } void WaitReactor::loop() @@ -89,7 +154,6 @@ void WaitReactor::doLoop() { setThreadName("WaitReactor"); LOG_INFO(logger, "start wait reactor loop"); - ASSERT_MEMORY_TRACKER std::list local_waiting_tasks; while (takeFromWaitingTaskList(local_waiting_tasks)) diff --git a/dbms/src/Flash/Pipeline/Schedule/Reactor/WaitReactor.h b/dbms/src/Flash/Pipeline/Schedule/Reactor/WaitReactor.h index 3d13adcbbed..52bf73e491d 100644 --- a/dbms/src/Flash/Pipeline/Schedule/Reactor/WaitReactor.h +++ b/dbms/src/Flash/Pipeline/Schedule/Reactor/WaitReactor.h @@ -15,7 +15,6 @@ #pragma once #include -#include #include #include @@ -26,6 +25,8 @@ namespace DB { class TaskScheduler; +/// Used for batch calling task.await and submitting the tasks that have been removed from the waiting state to task thread pools. +/// When there is no non-waiting state task for a long time, it will try to let the current thread rest for a period of time to give the CPU to other threads. class WaitReactor { public: @@ -47,19 +48,27 @@ class WaitReactor // Get the incremental tasks from waiting_task_list. // return false if waiting_task_list is empty and has finished. - bool takeFromWaitingTaskList(std::list & local_waiting_tasks); + inline bool takeFromWaitingTaskList(std::list & local_waiting_tasks); - void react(std::list & local_waiting_tasks); + inline void react(std::list & local_waiting_tasks); -private: - WaitingTaskList waiting_task_list; + inline bool awaitAndCollectReadyTask(TaskPtr && task); + + inline void submitReadyTasks(); + inline void tryYield(); + +private: LoggerPtr logger = Logger::get(); TaskScheduler & scheduler; - Spinner spinner; - std::thread thread; + + WaitingTaskList waiting_task_list; + + int16_t spin_count = 0; + std::vector cpu_tasks; + std::vector io_tasks; }; } // namespace DB diff --git a/dbms/src/Flash/Pipeline/Schedule/Tasks/EventTask.cpp b/dbms/src/Flash/Pipeline/Schedule/Tasks/EventTask.cpp index 3f1a1c50f66..d5093ab1fab 100644 --- a/dbms/src/Flash/Pipeline/Schedule/Tasks/EventTask.cpp +++ b/dbms/src/Flash/Pipeline/Schedule/Tasks/EventTask.cpp @@ -12,10 +12,36 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include +#include #include +#include namespace DB { +namespace FailPoints +{ +extern const char random_pipeline_model_task_run_failpoint[]; +extern const char random_pipeline_model_cancel_failpoint[]; +} // namespace FailPoints + +#define EXECUTE(function) \ + fiu_do_on(FailPoints::random_pipeline_model_cancel_failpoint, exec_status.cancel()); \ + if unlikely (exec_status.isCancelled()) \ + return ExecTaskStatus::CANCELLED; \ + try \ + { \ + auto status = (function()); \ + FAIL_POINT_TRIGGER_EXCEPTION(FailPoints::random_pipeline_model_task_run_failpoint); \ + return status; \ + } \ + catch (...) \ + { \ + LOG_WARNING(log, "error occurred and cancel the query"); \ + exec_status.onErrorOccurred(std::current_exception()); \ + return ExecTaskStatus::ERROR; \ + } + EventTask::EventTask( PipelineExecutorStatus & exec_status_, const EventPtr & event_) @@ -53,16 +79,19 @@ void EventTask::finalizeImpl() ExecTaskStatus EventTask::executeImpl() { - return doTaskAction([&] { return doExecuteImpl(); }); + EXECUTE(doExecuteImpl); } ExecTaskStatus EventTask::executeIOImpl() { - return doTaskAction([&] { return doExecuteIOImpl(); }); + EXECUTE(doExecuteIOImpl); } ExecTaskStatus EventTask::awaitImpl() { - return doTaskAction([&] { return doAwaitImpl(); }); + EXECUTE(doAwaitImpl); } + +#undef EXECUTE + } // namespace DB diff --git a/dbms/src/Flash/Pipeline/Schedule/Tasks/EventTask.h b/dbms/src/Flash/Pipeline/Schedule/Tasks/EventTask.h index 0905cd15adc..592a8b73124 100644 --- a/dbms/src/Flash/Pipeline/Schedule/Tasks/EventTask.h +++ b/dbms/src/Flash/Pipeline/Schedule/Tasks/EventTask.h @@ -14,18 +14,12 @@ #pragma once -#include -#include #include #include -#include namespace DB { -namespace FailPoints -{ -extern const char random_pipeline_model_task_run_failpoint[]; -} // namespace FailPoints +class PipelineExecutorStatus; // The base class of event related task. class EventTask : public Task @@ -53,27 +47,6 @@ class EventTask : public Task void finalizeImpl() override; virtual void doFinalizeImpl(){}; -private: - template - ExecTaskStatus doTaskAction(Action && action) - { - if (unlikely(exec_status.isCancelled())) - return ExecTaskStatus::CANCELLED; - - try - { - auto status = action(); - FAIL_POINT_TRIGGER_EXCEPTION(FailPoints::random_pipeline_model_task_run_failpoint); - return status; - } - catch (...) - { - LOG_WARNING(log, "error occurred and cancel the query"); - exec_status.onErrorOccurred(std::current_exception()); - return ExecTaskStatus::ERROR; - } - } - private: PipelineExecutorStatus & exec_status; EventPtr event; diff --git a/dbms/src/Flash/Pipeline/Schedule/Tasks/PipelineTask.cpp b/dbms/src/Flash/Pipeline/Schedule/Tasks/PipelineTask.cpp index 2d03e2b17af..9c28d439af7 100644 --- a/dbms/src/Flash/Pipeline/Schedule/Tasks/PipelineTask.cpp +++ b/dbms/src/Flash/Pipeline/Schedule/Tasks/PipelineTask.cpp @@ -34,7 +34,7 @@ PipelineTask::PipelineTask( void PipelineTask::doFinalizeImpl() { - RUNTIME_CHECK(pipeline_exec); + assert(pipeline_exec); pipeline_exec->executeSuffix(); pipeline_exec.reset(); } @@ -62,7 +62,7 @@ void PipelineTask::doFinalizeImpl() ExecTaskStatus PipelineTask::doExecuteImpl() { - RUNTIME_CHECK(pipeline_exec); + assert(pipeline_exec); auto op_status = pipeline_exec->execute(); switch (op_status) { @@ -78,7 +78,7 @@ ExecTaskStatus PipelineTask::doExecuteImpl() ExecTaskStatus PipelineTask::doExecuteIOImpl() { - RUNTIME_CHECK(pipeline_exec); + assert(pipeline_exec); auto op_status = pipeline_exec->executeIO(); switch (op_status) { @@ -97,7 +97,7 @@ ExecTaskStatus PipelineTask::doExecuteIOImpl() ExecTaskStatus PipelineTask::doAwaitImpl() { - RUNTIME_CHECK(pipeline_exec); + assert(pipeline_exec); auto op_status = pipeline_exec->await(); switch (op_status) { diff --git a/dbms/src/Flash/Pipeline/Schedule/Tasks/Task.cpp b/dbms/src/Flash/Pipeline/Schedule/Tasks/Task.cpp index 17801c94701..c7167400d47 100644 --- a/dbms/src/Flash/Pipeline/Schedule/Tasks/Task.cpp +++ b/dbms/src/Flash/Pipeline/Schedule/Tasks/Task.cpp @@ -29,7 +29,7 @@ extern const char random_pipeline_model_task_construct_failpoint[]; namespace { // TODO supports more detailed status transfer metrics, such as from waiting to running. -void addToStatusMetrics(ExecTaskStatus to) +ALWAYS_INLINE void addToStatusMetrics(ExecTaskStatus to) { #define M(expect_status, metric_name) \ case (expect_status): \ @@ -56,15 +56,10 @@ void addToStatusMetrics(ExecTaskStatus to) } } // namespace -#define CHECK_FINISHED \ - if unlikely (task_status == ExecTaskStatus::FINISHED \ - || task_status == ExecTaskStatus::ERROR \ - || task_status == ExecTaskStatus::CANCELLED) \ - return task_status; - Task::Task() : log(Logger::get()) - , mem_tracker(nullptr) + , mem_tracker_holder(nullptr) + , mem_tracker_ptr(nullptr) { FAIL_POINT_TRIGGER_EXCEPTION(FailPoints::random_pipeline_model_task_construct_failpoint); GET_METRIC(tiflash_pipeline_task_change_to_status, type_to_init).Increment(); @@ -72,8 +67,10 @@ Task::Task() Task::Task(MemoryTrackerPtr mem_tracker_, const String & req_id) : log(Logger::get(req_id)) - , mem_tracker(std::move(mem_tracker_)) + , mem_tracker_holder(std::move(mem_tracker_)) + , mem_tracker_ptr(mem_tracker_holder.get()) { + assert(mem_tracker_holder.get() == mem_tracker_ptr); FAIL_POINT_TRIGGER_EXCEPTION(FailPoints::random_pipeline_model_task_construct_failpoint); GET_METRIC(tiflash_pipeline_task_change_to_status, type_to_init).Increment(); } @@ -88,11 +85,17 @@ Task::~Task() magic_enum::enum_name(task_status)); } +#define CHECK_FINISHED \ + if unlikely (task_status == ExecTaskStatus::FINISHED \ + || task_status == ExecTaskStatus::ERROR \ + || task_status == ExecTaskStatus::CANCELLED) \ + return task_status; + ExecTaskStatus Task::execute() { CHECK_FINISHED - assert(getMemTracker().get() == current_memory_tracker); - assertNormalStatus(ExecTaskStatus::RUNNING); + assert(mem_tracker_ptr == current_memory_tracker); + assert(task_status == ExecTaskStatus::RUNNING || task_status == ExecTaskStatus::INIT); switchStatus(executeImpl()); return task_status; } @@ -100,8 +103,8 @@ ExecTaskStatus Task::execute() ExecTaskStatus Task::executeIO() { CHECK_FINISHED - assert(getMemTracker().get() == current_memory_tracker); - assertNormalStatus(ExecTaskStatus::IO); + assert(mem_tracker_ptr == current_memory_tracker); + assert(task_status == ExecTaskStatus::IO || task_status == ExecTaskStatus::INIT); switchStatus(executeIOImpl()); return task_status; } @@ -109,8 +112,8 @@ ExecTaskStatus Task::executeIO() ExecTaskStatus Task::await() { CHECK_FINISHED - assert(getMemTracker().get() == current_memory_tracker); - assertNormalStatus(ExecTaskStatus::WAITING); + assert(mem_tracker_ptr == current_memory_tracker); + assert(task_status == ExecTaskStatus::WAITING || task_status == ExecTaskStatus::INIT); switchStatus(awaitImpl()); return task_status; } @@ -125,7 +128,9 @@ void Task::finalize() switchStatus(ExecTaskStatus::FINALIZE); finalizeImpl(); +#ifndef NDEBUG LOG_TRACE(log, "task finalize with profile info: {}", profile_info.toJson()); +#endif // !NDEBUG } #undef CHECK_FINISHED @@ -134,20 +139,11 @@ void Task::switchStatus(ExecTaskStatus to) { if (task_status != to) { +#ifndef NDEBUG LOG_TRACE(log, "switch status: {} --> {}", magic_enum::enum_name(task_status), magic_enum::enum_name(to)); +#endif // !NDEBUG addToStatusMetrics(to); task_status = to; } } - -void Task::assertNormalStatus(ExecTaskStatus expect) -{ - RUNTIME_ASSERT( - task_status == expect || task_status == ExecTaskStatus::INIT, - log, - "actual status is {}, but expect status are {} and {}", - magic_enum::enum_name(task_status), - magic_enum::enum_name(expect), - magic_enum::enum_name(ExecTaskStatus::INIT)); -} } // namespace DB diff --git a/dbms/src/Flash/Pipeline/Schedule/Tasks/Task.h b/dbms/src/Flash/Pipeline/Schedule/Tasks/Task.h index de2717e182b..ef8af9c5536 100644 --- a/dbms/src/Flash/Pipeline/Schedule/Tasks/Task.h +++ b/dbms/src/Flash/Pipeline/Schedule/Tasks/Task.h @@ -53,11 +53,6 @@ class Task virtual ~Task(); - MemoryTrackerPtr getMemTracker() const - { - return mem_tracker; - } - ExecTaskStatus execute(); ExecTaskStatus executeIO(); @@ -68,6 +63,18 @@ class Task // `TaskHelper::FINALIZE_TASK` can help this. void finalize(); + ALWAYS_INLINE void startTraceMemory() + { + assert(nullptr == current_memory_tracker); + assert(0 == CurrentMemoryTracker::getLocalDeltaMemory()); + current_memory_tracker = mem_tracker_ptr; + } + ALWAYS_INLINE void endTraceMemory() + { + CurrentMemoryTracker::submitLocalDeltaMemory(); + current_memory_tracker = nullptr; + } + public: LoggerPtr log; @@ -81,9 +88,7 @@ class Task virtual void finalizeImpl() {} private: - void switchStatus(ExecTaskStatus to); - - void assertNormalStatus(ExecTaskStatus expect); + inline void switchStatus(ExecTaskStatus to); public: TaskProfileInfo profile_info; @@ -92,7 +97,10 @@ class Task size_t mlfq_level{0}; protected: - MemoryTrackerPtr mem_tracker; + // To ensure that the memory tracker will not be destructed prematurely and prevent crashes due to accessing invalid memory tracker pointers. + MemoryTrackerPtr mem_tracker_holder; + // To reduce the overheads of `mem_tracker.get()` + MemoryTracker * mem_tracker_ptr; ExecTaskStatus task_status{ExecTaskStatus::INIT}; }; diff --git a/dbms/src/Flash/Pipeline/Schedule/Tasks/TaskHelper.h b/dbms/src/Flash/Pipeline/Schedule/Tasks/TaskHelper.h index 767debbeb08..c421d5c60cf 100644 --- a/dbms/src/Flash/Pipeline/Schedule/Tasks/TaskHelper.h +++ b/dbms/src/Flash/Pipeline/Schedule/Tasks/TaskHelper.h @@ -23,17 +23,6 @@ namespace DB { -// Hold the shared_ptr of memory tracker. -// To avoid the current_memory_tracker being an illegal pointer. -#define TRACE_MEMORY(task) \ - assert(nullptr == current_memory_tracker); \ - auto memory_tracker = (task)->getMemTracker(); \ - MemoryTrackerSetter memory_tracker_setter{true, memory_tracker.get()}; - -#define ASSERT_MEMORY_TRACKER \ - assert(nullptr == current_memory_tracker); \ - assert(0 == CurrentMemoryTracker::getLocalDeltaMemory()); - #define FINISH_STATUS \ ExecTaskStatus::FINISHED : case ExecTaskStatus::ERROR : case ExecTaskStatus::CANCELLED diff --git a/dbms/src/Flash/Pipeline/Schedule/Tasks/TaskProfileInfo.cpp b/dbms/src/Flash/Pipeline/Schedule/Tasks/TaskProfileInfo.cpp deleted file mode 100644 index ac29d3a66b4..00000000000 --- a/dbms/src/Flash/Pipeline/Schedule/Tasks/TaskProfileInfo.cpp +++ /dev/null @@ -1,62 +0,0 @@ -// Copyright 2023 PingCAP, Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include - -namespace DB -{ -void TaskProfileInfo::startTimer() -{ - stopwatch.start(); -} - -UInt64 TaskProfileInfo::elapsedFromPrev() -{ - return stopwatch.elapsedFromLastTime(); -} - -void TaskProfileInfo::addCPUExecuteTime(UInt64 value) -{ - cpu_execute_time_ns += value; -} - -void TaskProfileInfo::elapsedCPUPendingTime() -{ - cpu_pending_time_ns += elapsedFromPrev(); -} - -void TaskProfileInfo::addIOExecuteTime(UInt64 value) -{ - io_execute_time_ns += value; -} - -void TaskProfileInfo::elapsedIOPendingTime() -{ - io_pending_time_ns += elapsedFromPrev(); -} - -void TaskProfileInfo::elapsedAwaitTime() -{ - await_time_ns += elapsedFromPrev(); -} - -void QueryProfileInfo::merge(const TaskProfileInfo & task_profile_info) -{ - cpu_execute_time_ns += task_profile_info.getCPUExecuteTimeNs(); - cpu_pending_time_ns += task_profile_info.getCPUPendingTimeNs(); - io_execute_time_ns += task_profile_info.getIOExecuteTimeNs(); - io_pending_time_ns += task_profile_info.getIOPendingTimeNs(); - await_time_ns += task_profile_info.getAwaitTimeNs(); -} -} // namespace DB diff --git a/dbms/src/Flash/Pipeline/Schedule/Tasks/TaskProfileInfo.h b/dbms/src/Flash/Pipeline/Schedule/Tasks/TaskProfileInfo.h index e87e18f0b00..7e39eee5f4e 100644 --- a/dbms/src/Flash/Pipeline/Schedule/Tasks/TaskProfileInfo.h +++ b/dbms/src/Flash/Pipeline/Schedule/Tasks/TaskProfileInfo.h @@ -25,13 +25,13 @@ template class ProfileInfo { public: - UInt64 getCPUExecuteTimeNs() const { return cpu_execute_time_ns; } - UInt64 getCPUPendingTimeNs() const { return cpu_pending_time_ns; } - UInt64 getIOExecuteTimeNs() const { return io_execute_time_ns; } - UInt64 getIOPendingTimeNs() const { return io_pending_time_ns; } - UInt64 getAwaitTimeNs() const { return await_time_ns; } + ALWAYS_INLINE UInt64 getCPUExecuteTimeNs() const { return cpu_execute_time_ns; } + ALWAYS_INLINE UInt64 getCPUPendingTimeNs() const { return cpu_pending_time_ns; } + ALWAYS_INLINE UInt64 getIOExecuteTimeNs() const { return io_execute_time_ns; } + ALWAYS_INLINE UInt64 getIOPendingTimeNs() const { return io_pending_time_ns; } + ALWAYS_INLINE UInt64 getAwaitTimeNs() const { return await_time_ns; } - String toJson() const + ALWAYS_INLINE String toJson() const { return fmt::format( R"({{"cpu_execute_time_ns":{},"cpu_pending_time_ns":{},"io_execute_time_ns":{},"io_pending_time_ns":{},"await_time_ns":{}}})", @@ -53,19 +53,40 @@ class ProfileInfo class TaskProfileInfo : public ProfileInfo { public: - void startTimer(); + ALWAYS_INLINE void startTimer() + { + stopwatch.start(); + } - UInt64 elapsedFromPrev(); + ALWAYS_INLINE UInt64 elapsedFromPrev() + { + return stopwatch.elapsedFromLastTime(); + } - void addCPUExecuteTime(UInt64 value); + ALWAYS_INLINE void addCPUExecuteTime(UInt64 value) + { + cpu_execute_time_ns += value; + } - void elapsedCPUPendingTime(); + ALWAYS_INLINE void elapsedCPUPendingTime() + { + cpu_pending_time_ns += elapsedFromPrev(); + } - void addIOExecuteTime(UInt64 value); + ALWAYS_INLINE void addIOExecuteTime(UInt64 value) + { + io_execute_time_ns += value; + } - void elapsedIOPendingTime(); + ALWAYS_INLINE void elapsedIOPendingTime() + { + io_pending_time_ns += elapsedFromPrev(); + } - void elapsedAwaitTime(); + ALWAYS_INLINE void elapsedAwaitTime() + { + await_time_ns += elapsedFromPrev(); + } private: Stopwatch stopwatch{CLOCK_MONOTONIC_COARSE}; @@ -74,6 +95,13 @@ class TaskProfileInfo : public ProfileInfo class QueryProfileInfo : public ProfileInfo { public: - void merge(const TaskProfileInfo & task_profile_info); + ALWAYS_INLINE void merge(const TaskProfileInfo & task_profile_info) + { + cpu_execute_time_ns += task_profile_info.getCPUExecuteTimeNs(); + cpu_pending_time_ns += task_profile_info.getCPUPendingTimeNs(); + io_execute_time_ns += task_profile_info.getIOExecuteTimeNs(); + io_pending_time_ns += task_profile_info.getIOPendingTimeNs(); + await_time_ns += task_profile_info.getAwaitTimeNs(); + } }; } // namespace DB diff --git a/dbms/src/Flash/Pipeline/Schedule/ThreadPool/TaskThreadPool.cpp b/dbms/src/Flash/Pipeline/Schedule/ThreadPool/TaskThreadPool.cpp index 6d86ab6eb8d..b5b367aba3d 100644 --- a/dbms/src/Flash/Pipeline/Schedule/ThreadPool/TaskThreadPool.cpp +++ b/dbms/src/Flash/Pipeline/Schedule/ThreadPool/TaskThreadPool.cpp @@ -71,7 +71,6 @@ void TaskThreadPool::doLoop(size_t thread_no) auto thread_logger = logger->getChild(thread_no_str); setThreadName(thread_no_str.c_str()); LOG_INFO(thread_logger, "start loop"); - ASSERT_MEMORY_TRACKER TaskPtr task; while (likely(task_queue->take(task))) @@ -79,7 +78,6 @@ void TaskThreadPool::doLoop(size_t thread_no) metrics.decPendingTask(); handleTask(task); assert(!task); - ASSERT_MEMORY_TRACKER } LOG_INFO(thread_logger, "loop finished"); @@ -89,7 +87,7 @@ template void TaskThreadPool::handleTask(TaskPtr & task) { assert(task); - TRACE_MEMORY(task); + task->startTraceMemory(); metrics.incExecutingTask(); metrics.elapsedPendingTime(task); @@ -114,16 +112,21 @@ void TaskThreadPool::handleTask(TaskPtr & task) switch (status) { case ExecTaskStatus::RUNNING: + task->endTraceMemory(); scheduler.submitToCPUTaskThreadPool(std::move(task)); break; case ExecTaskStatus::IO: + task->endTraceMemory(); scheduler.submitToIOTaskThreadPool(std::move(task)); break; case ExecTaskStatus::WAITING: + task->endTraceMemory(); scheduler.submitToWaitReactor(std::move(task)); break; case FINISH_STATUS: - FINALIZE_TASK(task); + task->finalize(); + task->endTraceMemory(); + task.reset(); break; default: UNEXPECTED_STATUS(task->log, status); diff --git a/dbms/src/Operators/AggregateRestoreSourceOp.h b/dbms/src/Operators/AggregateRestoreSourceOp.h index 3c3fdd99614..436fcf07ebc 100644 --- a/dbms/src/Operators/AggregateRestoreSourceOp.h +++ b/dbms/src/Operators/AggregateRestoreSourceOp.h @@ -41,6 +41,8 @@ class AggregateRestoreSourceOp : public SourceOp OperatorStatus awaitImpl() override; + bool isAwaitable() const override { return true; } + private: AggregateContextPtr agg_context; SharedAggregateRestorerPtr restorer; diff --git a/dbms/src/Operators/CoprocessorReaderSourceOp.h b/dbms/src/Operators/CoprocessorReaderSourceOp.h index 74fa7b0dbf0..f7343a366fb 100644 --- a/dbms/src/Operators/CoprocessorReaderSourceOp.h +++ b/dbms/src/Operators/CoprocessorReaderSourceOp.h @@ -39,6 +39,7 @@ class CoprocessorReaderSourceOp : public SourceOp protected: OperatorStatus readImpl(Block & block) override; OperatorStatus awaitImpl() override; + bool isAwaitable() const override { return true; } private: Block popFromBlockQueue(); diff --git a/dbms/src/Operators/ExchangeReceiverSourceOp.cpp b/dbms/src/Operators/ExchangeReceiverSourceOp.cpp index 19bf83213a2..03e3045cfb3 100644 --- a/dbms/src/Operators/ExchangeReceiverSourceOp.cpp +++ b/dbms/src/Operators/ExchangeReceiverSourceOp.cpp @@ -92,7 +92,7 @@ OperatorStatus ExchangeReceiverSourceOp::awaitImpl() { if (!block_queue.empty() || recv_res) return OperatorStatus::HAS_OUTPUT; - recv_res.emplace(exchange_receiver->nonBlockingReceive(stream_id)); + recv_res.emplace(exchange_receiver->tryReceive(stream_id)); switch (recv_res->recv_status) { case ReceiveStatus::ok: diff --git a/dbms/src/Operators/ExchangeReceiverSourceOp.h b/dbms/src/Operators/ExchangeReceiverSourceOp.h index de0d68b21e1..f1aab0c6c19 100644 --- a/dbms/src/Operators/ExchangeReceiverSourceOp.h +++ b/dbms/src/Operators/ExchangeReceiverSourceOp.h @@ -33,6 +33,7 @@ class ExchangeReceiverSourceOp : public SourceOp , exchange_receiver(exchange_receiver_) , stream_id(stream_id_) { + exchange_receiver->verifyStreamId(stream_id); setHeader(Block(getColumnWithTypeAndName(toNamesAndTypes(exchange_receiver->getOutputSchema())))); decoder_ptr = std::make_unique(getHeader(), 8192); } @@ -49,6 +50,8 @@ class ExchangeReceiverSourceOp : public SourceOp OperatorStatus awaitImpl() override; + bool isAwaitable() const override { return true; } + private: Block popFromBlockQueue(); diff --git a/dbms/src/Operators/ExchangeSenderSinkOp.cpp b/dbms/src/Operators/ExchangeSenderSinkOp.cpp index 83c90c2d32a..cf561867d84 100644 --- a/dbms/src/Operators/ExchangeSenderSinkOp.cpp +++ b/dbms/src/Operators/ExchangeSenderSinkOp.cpp @@ -62,12 +62,12 @@ OperatorStatus ExchangeSenderSinkOp::writeImpl(Block && block) OperatorStatus ExchangeSenderSinkOp::prepareImpl() { - return writer->isReadyForWrite() ? OperatorStatus::NEED_INPUT : OperatorStatus::WAITING; + return writer->isWritable() ? OperatorStatus::NEED_INPUT : OperatorStatus::WAITING; } OperatorStatus ExchangeSenderSinkOp::awaitImpl() { - return writer->isReadyForWrite() ? OperatorStatus::NEED_INPUT : OperatorStatus::WAITING; + return writer->isWritable() ? OperatorStatus::NEED_INPUT : OperatorStatus::WAITING; } } // namespace DB diff --git a/dbms/src/Operators/ExchangeSenderSinkOp.h b/dbms/src/Operators/ExchangeSenderSinkOp.h index 25bddc4c020..ba1928e97bf 100644 --- a/dbms/src/Operators/ExchangeSenderSinkOp.h +++ b/dbms/src/Operators/ExchangeSenderSinkOp.h @@ -47,6 +47,8 @@ class ExchangeSenderSinkOp : public SinkOp OperatorStatus awaitImpl() override; + bool isAwaitable() const override { return true; } + private: std::unique_ptr writer; size_t total_rows = 0; diff --git a/dbms/src/Operators/GetResultSinkOp.h b/dbms/src/Operators/GetResultSinkOp.h index 51fe96ee978..eba098dae9e 100644 --- a/dbms/src/Operators/GetResultSinkOp.h +++ b/dbms/src/Operators/GetResultSinkOp.h @@ -45,6 +45,8 @@ class GetResultSinkOp : public SinkOp OperatorStatus awaitImpl() override; + bool isAwaitable() const override { return true; } + private: ResultQueuePtr result_queue; std::optional t_block; diff --git a/dbms/src/Operators/HashJoinProbeTransformOp.h b/dbms/src/Operators/HashJoinProbeTransformOp.h index c90a2e28a71..84a1ea82f27 100644 --- a/dbms/src/Operators/HashJoinProbeTransformOp.h +++ b/dbms/src/Operators/HashJoinProbeTransformOp.h @@ -42,6 +42,8 @@ class HashJoinProbeTransformOp : public TransformOp OperatorStatus awaitImpl() override; + bool isAwaitable() const override { return true; } + void transformHeaderImpl(Block & header_) override; void operateSuffix() override; diff --git a/dbms/src/Operators/Operator.cpp b/dbms/src/Operators/Operator.cpp index 3876a228c17..58cee269cc9 100644 --- a/dbms/src/Operators/Operator.cpp +++ b/dbms/src/Operators/Operator.cpp @@ -33,7 +33,9 @@ extern const char random_pipeline_model_cancel_failpoint[]; OperatorStatus Operator::await() { - CHECK_IS_CANCELLED + // `exec_status.is_cancelled` has been checked by `EventTask`. + // If `exec_status.is_cancelled` is checked here, the overhead of `exec_status.is_cancelled` will be amplified by the high frequency of `await` calls. + // TODO collect operator profile info here. auto op_status = awaitImpl(); #ifndef NDEBUG diff --git a/dbms/src/Operators/Operator.h b/dbms/src/Operators/Operator.h index 3e3089dfe63..0eb57f33edc 100644 --- a/dbms/src/Operators/Operator.h +++ b/dbms/src/Operators/Operator.h @@ -57,14 +57,16 @@ class Operator {} virtual ~Operator() = default; - // running status may return are NEED_INPUT and HAS_OUTPUT here. - OperatorStatus await(); - virtual OperatorStatus awaitImpl() { throw Exception("Unsupport"); } // running status may return are NEED_INPUT and HAS_OUTPUT here. OperatorStatus executeIO(); virtual OperatorStatus executeIOImpl() { throw Exception("Unsupport"); } + // running status may return are NEED_INPUT and HAS_OUTPUT here. + OperatorStatus await(); + virtual OperatorStatus awaitImpl() { throw Exception("Unsupport"); } + virtual bool isAwaitable() const { return false; } + // These two methods are used to set state, log and etc, and should not perform calculation logic. virtual void operatePrefix() {} virtual void operateSuffix() {} @@ -103,8 +105,6 @@ class SourceOp : public Operator // because there are many operators that need an empty block as input, such as JoinProbe and WindowFunction. OperatorStatus read(Block & block); virtual OperatorStatus readImpl(Block & block) = 0; - - OperatorStatus awaitImpl() override { return OperatorStatus::HAS_OUTPUT; } }; using SourceOpPtr = std::unique_ptr; using SourceOps = std::vector; @@ -133,8 +133,6 @@ class TransformOp : public Operator transformHeaderImpl(header_); setHeader(header_); } - - OperatorStatus awaitImpl() override { return OperatorStatus::NEED_INPUT; } }; using TransformOpPtr = std::unique_ptr; using TransformOps = std::vector; @@ -151,8 +149,6 @@ class SinkOp : public Operator OperatorStatus write(Block && block); virtual OperatorStatus writeImpl(Block && block) = 0; - - OperatorStatus awaitImpl() override { return OperatorStatus::NEED_INPUT; } }; using SinkOpPtr = std::unique_ptr; } // namespace DB diff --git a/dbms/src/Operators/UnorderedSourceOp.h b/dbms/src/Operators/UnorderedSourceOp.h index 81420a8c94b..f952079a80c 100644 --- a/dbms/src/Operators/UnorderedSourceOp.h +++ b/dbms/src/Operators/UnorderedSourceOp.h @@ -70,6 +70,7 @@ class UnorderedSourceOp : public SourceOp protected: OperatorStatus readImpl(Block & block) override; OperatorStatus awaitImpl() override; + bool isAwaitable() const override { return true; } private: void addReadTaskPoolToScheduler()