From d62a06f2baa7280cbbeca5fcbdddf80e9c2b0856 Mon Sep 17 00:00:00 2001 From: xiaoxmeng Date: Thu, 25 Jan 2024 22:55:12 -0800 Subject: [PATCH] [native] Add http request active check when fetch data from output buffer manager Add http request active check when fetch data from output buffer manager in Velox. The active check is based on whether the http callstate has been destroyed or the associated request has expired. This is to avoid arbitrary output buffer to load data into a destination buffer which has set notify but the associated client request has expired. This helps to accelerate the shuffle for query with scale writer which uses arbitrary output buffer. Unit test is added to verify this behavior. --- .../presto_cpp/main/PrestoTask.h | 16 ++++ .../presto_cpp/main/TaskManager.cpp | 23 ++++-- .../main/tests/QueryContextCacheTest.cpp | 5 ++ .../presto_cpp/main/tests/TaskManagerTest.cpp | 81 +++++++++++++++++-- 4 files changed, 113 insertions(+), 12 deletions(-) diff --git a/presto-native-execution/presto_cpp/main/PrestoTask.h b/presto-native-execution/presto_cpp/main/PrestoTask.h index e159bdf2247b..96b3d84b9b42 100644 --- a/presto-native-execution/presto_cpp/main/PrestoTask.h +++ b/presto-native-execution/presto_cpp/main/PrestoTask.h @@ -14,6 +14,7 @@ #pragma once #include +#include "presto_cpp/main/http/HttpServer.h" #include "presto_cpp/main/types/PrestoTaskId.h" #include "presto_cpp/presto_protocol/presto_protocol.h" #include "velox/exec/Task.h" @@ -59,10 +60,25 @@ struct Result { struct ResultRequest { PromiseHolderWeakPtr> promise; + std::weak_ptr state; protocol::TaskId taskId; int64_t bufferId; int64_t token; protocol::DataSize maxSize; + + ResultRequest( + PromiseHolderWeakPtr> _promise, + std::weak_ptr _state, + protocol::TaskId _taskId, + int64_t _bufferId, + int64_t _token, + protocol::DataSize _maxSize) + : promise(std::move(_promise)), + state(std::move(_state)), + taskId(_taskId), + bufferId(_bufferId), + token(_token), + maxSize(_maxSize) {} }; struct PrestoTask { diff --git a/presto-native-execution/presto_cpp/main/TaskManager.cpp b/presto-native-execution/presto_cpp/main/TaskManager.cpp index 21dbbcfbbb85..7481a77db005 100644 --- a/presto-native-execution/presto_cpp/main/TaskManager.cpp +++ b/presto-native-execution/presto_cpp/main/TaskManager.cpp @@ -98,6 +98,7 @@ std::unique_ptr createCompleteResult(long token) { void getData( PromiseHolderPtr> promiseHolder, + std::weak_ptr stateHolder, const TaskId& taskId, long destination, long token, @@ -154,6 +155,13 @@ void getData( RECORD_METRIC_VALUE( kCounterPartitionedOutputBufferGetDataLatencyMs, getCurrentTimeMs() - startMs); + }, + [stateHolder]() { + auto state = stateHolder.lock(); + if (state == nullptr) { + return false; + } + return !state->requestExpired(); }); if (!bufferFound) { @@ -382,6 +390,7 @@ void TaskManager::getDataForResultRequests( << ", sequence " << resultRequest->token; getData( resultRequest->promise.lock(), + resultRequest->state, resultRequest->taskId, resultRequest->bufferId, resultRequest->token, @@ -831,6 +840,7 @@ folly::Future> TaskManager::getResults( if (prestoTask->task->state() == exec::kRunning) { getData( promiseHolder, + folly::to_weak_ptr(state), taskId, destination, token, @@ -852,12 +862,13 @@ folly::Future> TaskManager::getResults( keepPromiseAlive(promiseHolder, state); - auto request = std::make_unique(); - request->promise = folly::to_weak_ptr(promiseHolder); - request->taskId = taskId; - request->bufferId = destination; - request->token = token; - request->maxSize = maxSize; + auto request = std::make_unique( + folly::to_weak_ptr(promiseHolder), + folly::to_weak_ptr(state), + taskId, + destination, + token, + maxSize); prestoTask->resultRequests.insert({destination, std::move(request)}); return std::move(future) .via(httpSrvCpuExecutor_) diff --git a/presto-native-execution/presto_cpp/main/tests/QueryContextCacheTest.cpp b/presto-native-execution/presto_cpp/main/tests/QueryContextCacheTest.cpp index 0edd4ed845ee..216402924f49 100644 --- a/presto-native-execution/presto_cpp/main/tests/QueryContextCacheTest.cpp +++ b/presto-native-execution/presto_cpp/main/tests/QueryContextCacheTest.cpp @@ -35,6 +35,11 @@ void verifyQueryCtxCache( } // namespace class QueryContextCacheTest : public testing::Test { + protected: + static void SetUpTestCase() { + memory::MemoryManager::testingSetInstance({}); + } + void SetUp() override { FLAGS_velox_memory_leak_check_enabled = true; } diff --git a/presto-native-execution/presto_cpp/main/tests/TaskManagerTest.cpp b/presto-native-execution/presto_cpp/main/tests/TaskManagerTest.cpp index 02af7db62a2b..970937f9f3ae 100644 --- a/presto-native-execution/presto_cpp/main/tests/TaskManagerTest.cpp +++ b/presto-native-execution/presto_cpp/main/tests/TaskManagerTest.cpp @@ -15,6 +15,7 @@ #include #include #include +#include "folly/experimental/EventCount.h" #include "presto_cpp/main/PrestoExchangeSource.h" #include "presto_cpp/main/TaskResource.h" #include "presto_cpp/main/tests/HttpServerWrapper.h" @@ -27,6 +28,7 @@ #include "velox/dwio/common/WriterFactory.h" #include "velox/dwio/common/tests/utils/BatchMaker.h" #include "velox/exec/Exchange.h" +#include "velox/exec/Values.h" #include "velox/exec/tests/utils/PlanBuilder.h" #include "velox/exec/tests/utils/QueryAssertions.h" #include "velox/exec/tests/utils/TempDirectoryPath.h" @@ -146,6 +148,7 @@ class TaskManagerTest : public testing::Test { public: static void SetUpTestCase() { memory::MemoryManager::testingSetInstance({}); + common::testutil::TestValue::enable(); } protected: @@ -709,6 +712,70 @@ TEST_F(TaskManagerTest, fecthFromFinishedTask) { ASSERT_TRUE(newResult.value()->complete); } +DEBUG_ONLY_TEST_F(TaskManagerTest, fecthFromArbitraryOutput) { + // Block output until the first fetch destination becomes inactive. + folly::EventCount outputWait; + std::atomic outputWaitFlag{false}; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::Values::getOutput", + std::function( + [&](const velox::exec::Values* values) { + outputWait.await([&]() { return outputWaitFlag.load(); }); + })); + + const std::vector batches = makeVectors(1, 1'000); + auto planFragment = exec::test::PlanBuilder() + .values(batches) + .partitionedOutputArbitrary({"c0", "c1"}) + .planFragment(); + const protocol::TaskId taskId = "source.0.0.1.0"; + const auto taskInfo = createOrUpdateTask(taskId, {}, planFragment); + + const protocol::Duration longWait("10s"); + const auto maxSize = protocol::DataSize("1024MB"); + auto expiredRequestState = http::CallbackRequestHandlerState::create(); + auto consumeCompleted = false; + // Consume from destination 0 to simulate the case that the http request has + // expired while destination has notify setup. + auto expiredResultWait = taskManager_->getResults( + taskId, 0, 0, maxSize, protocol::Duration("1s"), expiredRequestState); + // Reset the http request to simulate the case that it has expired. + expiredRequestState.reset(); + + // Unblock output. + outputWaitFlag = true; + outputWait.notifyAll(); + + // Consuming from destination 1 and expect get result. + auto requestState = http::CallbackRequestHandlerState::create(); + const auto result = + taskManager_ + ->getResults( + taskId, 1, 0, maxSize, protocol::Duration("10s"), requestState) + .getVia(folly::EventBaseManager::get()->getEventBase()); + ASSERT_FALSE(result->complete); + ASSERT_FALSE(result->data->empty()); + ASSERT_EQ(result->sequence, 0); + ASSERT_EQ(result->nextSequence, 1); + + // Check the expired result hasn't fetched any data after timeout. + const auto expriredResult = + std::move(expiredResultWait) + .getVia(folly::EventBaseManager::get()->getEventBase()); + ASSERT_FALSE(expriredResult->complete); + ASSERT_TRUE(expriredResult->data->empty()); + ASSERT_EQ(expriredResult->sequence, 0); + ASSERT_EQ(expriredResult->nextSequence, 0); + + // Close destinations and triggers the task closure. + taskManager_->abortResults(taskId, 0); + taskManager_->abortResults(taskId, 1); + + auto prestoTask = taskManager_->tasks().at(taskId); + ASSERT_TRUE(waitForTaskStateChange( + prestoTask->task.get(), TaskState::kFinished, 3'000'000)); +} + TEST_F(TaskManagerTest, taskCleanupWithPendingResultData) { // Trigger old task cleanup immediately. taskManager_->setOldTaskCleanUpMs(0); @@ -1103,12 +1170,14 @@ TEST_F(TaskManagerTest, getDataOnAbortedTask) { }); auto promiseHolder = std::make_shared>>( std::move(promise)); - auto request = std::make_unique(); - request->promise = folly::to_weak_ptr(promiseHolder); - request->taskId = scanTaskId; - request->bufferId = 0; - request->token = token; - request->maxSize = protocol::DataSize("32MB"); + auto requestState = http::CallbackRequestHandlerState::create(); + auto request = std::make_unique( + folly::to_weak_ptr(promiseHolder), + folly::to_weak_ptr(requestState), + scanTaskId, + 0, + token, + protocol::DataSize("32MB")); prestoTask->resultRequests.insert({0, std::move(request)}); prestoTask->task = createDummyExecTask(scanTaskId, planFragment); taskManager_->getDataForResultRequests(prestoTask->resultRequests);