diff --git a/dbms/src/Common/UniThreadPool.cpp b/dbms/src/Common/UniThreadPool.cpp
index 122c31c4c94..5c512feed33 100644
--- a/dbms/src/Common/UniThreadPool.cpp
+++ b/dbms/src/Common/UniThreadPool.cpp
@@ -90,6 +90,13 @@ void ThreadPoolImpl<Thread>::setQueueSize(size_t value)
     jobs.reserve(queue_size);
 }
 
+template <typename Thread>
+size_t ThreadPoolImpl<Thread>::getQueueSize() const
+{
+    std::lock_guard lock(mutex);
+    return queue_size;
+}
+
 template <typename Thread>
 template <typename ReturnType>
@@ -202,7 +209,7 @@ template <typename Thread>
 std::future<void> ThreadPoolImpl<Thread>::scheduleWithFuture(Job job, uint64_t wait_timeout_us)
 {
     auto task = std::make_shared<std::packaged_task<void()>>(std::move(job));
-    scheduleOrThrow([task]() { (*task)(); }, 0, wait_timeout_us);
+    scheduleImpl<void>([task]() { (*task)(); }, /*priority*/ 0, wait_timeout_us);
     return task->get_future();
 }
 
diff --git a/dbms/src/Common/UniThreadPool.h b/dbms/src/Common/UniThreadPool.h
index 3a9ad956b37..b751f6f9e63 100644
--- a/dbms/src/Common/UniThreadPool.h
+++ b/dbms/src/Common/UniThreadPool.h
@@ -74,16 +74,19 @@ class ThreadPoolImpl
     void scheduleOrThrowOnError(Job job, ssize_t priority = 0);
 
     /// Similar to scheduleOrThrowOnError(...). Wait for specified amount of time and schedule a job or return false.
+    /// If wait_microseconds is zero, the call does not wait at all.
     bool trySchedule(Job job, ssize_t priority = 0, uint64_t wait_microseconds = 0) noexcept;
 
     /// Similar to scheduleOrThrowOnError(...). Wait for specified amount of time and schedule a job or throw an exception.
+    /// If wait_microseconds is zero, the call does not wait at all.
     void scheduleOrThrow(
         Job job,
         ssize_t priority = 0,
         uint64_t wait_microseconds = 0,
         bool propagate_opentelemetry_tracing_context = true);
 
-    /// Wrap job with std::packaged_task and returns a std::future object to check result of the job.
+    /// Wraps the job with std::packaged_task and returns a std::future object to check if the task has finished or thrown an exception.
+    /// If wait_timeout_us is zero, the call does not wait at all.
     std::future<void> scheduleWithFuture(Job job, uint64_t wait_timeout_us = 0);
 
     /// Wait for all currently active jobs to be done.
@@ -107,6 +110,7 @@ class ThreadPoolImpl
     void setMaxFreeThreads(size_t value);
     void setQueueSize(size_t value);
     size_t getMaxThreads() const;
+    size_t getQueueSize() const;
 
     std::unique_ptr<ThreadPoolWaitGroup<Thread>> waitGroup()
     {
diff --git a/dbms/src/IO/tests/gtest_io_thread.cpp b/dbms/src/IO/tests/gtest_io_thread.cpp
index d5bc1cc10e9..51849122a8f 100644
--- a/dbms/src/IO/tests/gtest_io_thread.cpp
+++ b/dbms/src/IO/tests/gtest_io_thread.cpp
@@ -146,4 +146,49 @@ TEST(IOThreadPool, TaskChain)
     buildReadTasksForWNs(a, b, c, d);
 }
 
+TEST(IOThreadPool, WaitTimeout)
+{
+    auto & thread_pool = BuildReadTaskPool::get();
+    const auto queue_size = thread_pool.getQueueSize();
+    std::atomic<bool> stop_flag{false};
+    IOPoolHelper::FutureContainer futures(Logger::get());
+    auto loop_until_stop = [&]() {
+        while (!stop_flag)
+            std::this_thread::sleep_for(std::chrono::seconds(1));
+    };
+    for (size_t i = 0; i < queue_size; ++i)
+    {
+        auto f = thread_pool.scheduleWithFuture(loop_until_stop);
+        futures.add(std::move(f));
+    }
+    ASSERT_EQ(thread_pool.active(), queue_size);
+
+    auto try_result = thread_pool.trySchedule(loop_until_stop);
+    ASSERT_FALSE(try_result);
+
+    try
+    {
+        auto f = thread_pool.scheduleWithFuture(loop_until_stop);
+        futures.add(std::move(f));
+        FAIL() << "Should throw an exception.";
+    }
+    catch (Exception & e)
+    {
+        ASSERT_TRUE(e.message().starts_with("Cannot schedule a task: no free thread (timeout=0)"));
+    }
+
+    try
+    {
+        auto f = thread_pool.scheduleWithFuture(loop_until_stop, 10000);
+        futures.add(std::move(f));
+        FAIL() << "Should throw an exception.";
+    }
+    catch (Exception & e)
+    {
+        ASSERT_TRUE(e.message().starts_with("Cannot schedule a task: no free thread (timeout=10000)"));
+    }
+
+    stop_flag.store(true);
+    futures.getAllResults();
+}
 } // namespace DB::tests
diff --git a/dbms/src/Storages/StorageDisaggregated.h b/dbms/src/Storages/StorageDisaggregated.h
index 70a7753e23c..c63efd91534 100644
--- a/dbms/src/Storages/StorageDisaggregated.h
+++ b/dbms/src/Storages/StorageDisaggregated.h
@@ -143,6 +143,9 @@ class StorageDisaggregated : public IStorage
         DAGExpressionAnalyzer & analyzer);
     tipb::Executor buildTableScanTiPB();
 
+    size_t getBuildTaskRPCTimeout() const;
+    size_t getBuildTaskIOThreadPoolTimeout() const;
+
 private:
     Context & context;
     const TiDBTableScan & table_scan;
diff --git a/dbms/src/Storages/StorageDisaggregatedRemote.cpp b/dbms/src/Storages/StorageDisaggregatedRemote.cpp
index d920fc6dbd9..5c5b5d02353 100644
--- a/dbms/src/Storages/StorageDisaggregatedRemote.cpp
+++ b/dbms/src/Storages/StorageDisaggregatedRemote.cpp
@@ -217,7 +217,8 @@ DM::SegmentReadTasks StorageDisaggregated::buildReadTask(
     for (const auto & cop_task : batch_cop_tasks)
     {
         auto f = BuildReadTaskForWNPool::get().scheduleWithFuture(
-            [&] { buildReadTaskForWriteNode(db_context, scan_context, cop_task, output_lock, output_seg_tasks); });
+            [&] { buildReadTaskForWriteNode(db_context, scan_context, cop_task, output_lock, output_seg_tasks); },
+            getBuildTaskIOThreadPoolTimeout());
         futures.add(std::move(f));
     }
     futures.getAllResults();
@@ -246,7 +247,7 @@ void StorageDisaggregated::buildReadTaskForWriteNode(
     pingcap::kv::RpcCall rpc(cluster->rpc_client, req->address());
     disaggregated::EstablishDisaggTaskResponse resp;
     grpc::ClientContext client_context;
-    rpc.setClientContext(client_context, db_context.getSettingsRef().disagg_build_task_timeout);
+    rpc.setClientContext(client_context, getBuildTaskRPCTimeout());
     auto status = rpc.call(&client_context, *req, &resp);
     if (status.error_code() == grpc::StatusCode::DEADLINE_EXCEEDED)
         throw Exception(
@@ -364,17 +365,19 @@ void StorageDisaggregated::buildReadTaskForWriteNode(
     IOPoolHelper::FutureContainer futures(log, resp.tables().size());
     for (const auto & serialized_physical_table : resp.tables())
     {
-        auto f = BuildReadTaskForWNTablePool::get().scheduleWithFuture([&] {
-            buildReadTaskForWriteNodeTable(
-                db_context,
-                scan_context,
-                snapshot_id,
-                resp.store_id(),
-                req->address(),
-                serialized_physical_table,
-                output_lock,
-                output_seg_tasks);
-        });
+        auto f = BuildReadTaskForWNTablePool::get().scheduleWithFuture(
+            [&] {
+                buildReadTaskForWriteNodeTable(
+                    db_context,
+                    scan_context,
+                    snapshot_id,
+                    resp.store_id(),
+                    req->address(),
+                    serialized_physical_table,
+                    output_lock,
+                    output_seg_tasks);
+            },
+            getBuildTaskIOThreadPoolTimeout());
         futures.add(std::move(f));
     }
     futures.getAllResults();
@@ -395,7 +398,6 @@ void StorageDisaggregated::buildReadTaskForWriteNodeTable(
     RUNTIME_CHECK_MSG(parse_ok, "Failed to deserialize RemotePhysicalTable from response");
     auto table_tracing_logger = log->getChild(
         fmt::format("store_id={} keyspace={} table_id={}", store_id, table.keyspace_id(), table.table_id()));
-    auto disagg_build_task_timeout_us = db_context.getSettingsRef().disagg_build_task_timeout * 1000000;
 
     IOPoolHelper::FutureContainer futures(log, table.segments().size());
     for (const auto & remote_seg : table.segments())
@@ -415,7 +417,7 @@ void StorageDisaggregated::buildReadTaskForWriteNodeTable(
                 std::lock_guard lock(output_lock);
                 output_seg_tasks.push_back(seg_read_task);
             },
-            disagg_build_task_timeout_us);
+            getBuildTaskIOThreadPoolTimeout());
         futures.add(std::move(f));
     }
     futures.getAllResults();
@@ -704,4 +706,14 @@ void StorageDisaggregated::buildRemoteSegmentSourceOps(
         group_builder.getCurProfileInfos());
 }
 
+size_t StorageDisaggregated::getBuildTaskRPCTimeout() const
+{
+    return context.getSettingsRef().disagg_build_task_timeout;
+}
+
+size_t StorageDisaggregated::getBuildTaskIOThreadPoolTimeout() const
+{
+    return context.getSettingsRef().disagg_build_task_timeout * 1000000;
+}
+
 } // namespace DB
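
Two notes follow the patch. First, on the scheduleWithFuture contract exercised above: scheduling failures (queue full, wait timeout expired) surface as an exception thrown by the schedule call itself, which is exactly what the WaitTimeout test catches, while the returned std::future reports only the job's own completion or exception. The stand-alone sketch below is illustrative rather than TiFlash code; the plain std::thread is a stand-in for a pool worker, but the std::packaged_task wrapping and the exception propagation through the future match the pattern in the diff.

#include <functional>
#include <future>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <thread>

using Job = std::function<void()>;

int main()
{
    // Wrap the job the same way scheduleWithFuture does: the shared_ptr keeps
    // the packaged_task alive until a worker has executed it.
    Job job = [] { throw std::runtime_error("boom"); };
    auto task = std::make_shared<std::packaged_task<void()>>(std::move(job));
    std::future<void> f = task->get_future();

    // A plain std::thread stands in for the pool worker that runs the
    // `[task]() { (*task)(); }` closure from the diff.
    std::thread worker([task] { (*task)(); });

    try
    {
        f.get(); // rethrows the exception captured by the packaged_task
    }
    catch (const std::exception & e)
    {
        std::cout << "task failed: " << e.what() << '\n';
    }
    worker.join();
    return 0;
}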
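
Second, on units in the two accessors added to StorageDisaggregatedRemote.cpp: getBuildTaskRPCTimeout() returns the disagg_build_task_timeout setting unchanged (the setting is in seconds, as the * 1000000 conversion implies), while getBuildTaskIOThreadPoolTimeout() multiplies it by 1000000 because scheduleWithFuture takes wait_timeout_us in microseconds. A minimal sketch of that relationship; the value 30 is illustrative, not the setting's actual default.

#include <chrono>
#include <cstdio>

int main()
{
    // disagg_build_task_timeout is configured in seconds; 30 is illustrative.
    const long long disagg_build_task_timeout = 30;

    // getBuildTaskRPCTimeout(): the setting as-is, in seconds, used for the
    // gRPC client context deadline.
    const long long rpc_timeout_sec = disagg_build_task_timeout;

    // getBuildTaskIOThreadPoolTimeout(): the same setting converted to
    // microseconds, matching the `* 1000000` in the diff.
    const long long pool_timeout_us = static_cast<long long>(
        std::chrono::duration_cast<std::chrono::microseconds>(
            std::chrono::seconds(disagg_build_task_timeout))
            .count());

    std::printf("rpc timeout: %llds, pool timeout: %lldus\n", rpc_timeout_sec, pool_timeout_us);
    return 0;
}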