Skip to content

Commit

Permalink
[native] Add http request active check when fetch data from output bu…
Browse files Browse the repository at this point in the history
…ffer manager

Add an HTTP request active check when fetching data from the output buffer manager
in Velox. A request is considered active only if its HTTP callback state has not been
destroyed and the associated request has not expired. This prevents an arbitrary output
buffer from loading data into a destination buffer that has a notify callback set but whose
associated client request has already expired. This helps accelerate the shuffle for queries
with scale writers, which use the arbitrary output buffer.

Unit test is added to verify this behavior.
  • Loading branch information
xiaoxmeng committed Jan 26, 2024
1 parent 2aa0615 commit d62a06f
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 12 deletions.
16 changes: 16 additions & 0 deletions presto-native-execution/presto_cpp/main/PrestoTask.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#pragma once

#include <memory>
#include "presto_cpp/main/http/HttpServer.h"
#include "presto_cpp/main/types/PrestoTaskId.h"
#include "presto_cpp/presto_protocol/presto_protocol.h"
#include "velox/exec/Task.h"
Expand Down Expand Up @@ -59,10 +60,25 @@ struct Result {

/// A pending request to fetch results from one output buffer of a task.
/// Both 'promise' and 'state' are held weakly, so this request does not keep
/// the promise holder or the HTTP request state alive on its own.
struct ResultRequest {
  /// Weak reference to the holder of the promise for the requested Result.
  PromiseHolderWeakPtr<std::unique_ptr<Result>> promise;
  /// Weak reference to the state of the HTTP request that issued this fetch.
  /// The request is treated as inactive once this state has been destroyed
  /// or reports that the request has expired.
  std::weak_ptr<http::CallbackRequestHandlerState> state;
  /// Id of the task to fetch results from.
  protocol::TaskId taskId;
  /// Id of the destination (output buffer) to fetch from.
  int64_t bufferId;
  /// Sequence number the fetch starts from.
  int64_t token;
  /// Upper bound on the amount of data to return.
  protocol::DataSize maxSize;

  /// All by-value arguments are sink parameters and are moved into the
  /// corresponding members rather than copied a second time.
  ResultRequest(
      PromiseHolderWeakPtr<std::unique_ptr<Result>> _promise,
      std::weak_ptr<http::CallbackRequestHandlerState> _state,
      protocol::TaskId _taskId,
      int64_t _bufferId,
      int64_t _token,
      protocol::DataSize _maxSize)
      : promise(std::move(_promise)),
        state(std::move(_state)),
        taskId(std::move(_taskId)),
        bufferId(_bufferId),
        token(_token),
        maxSize(std::move(_maxSize)) {}
};

struct PrestoTask {
Expand Down
23 changes: 17 additions & 6 deletions presto-native-execution/presto_cpp/main/TaskManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ std::unique_ptr<Result> createCompleteResult(long token) {

void getData(
PromiseHolderPtr<std::unique_ptr<Result>> promiseHolder,
std::weak_ptr<http::CallbackRequestHandlerState> stateHolder,
const TaskId& taskId,
long destination,
long token,
Expand Down Expand Up @@ -154,6 +155,13 @@ void getData(
RECORD_METRIC_VALUE(
kCounterPartitionedOutputBufferGetDataLatencyMs,
getCurrentTimeMs() - startMs);
},
[stateHolder]() {
auto state = stateHolder.lock();
if (state == nullptr) {
return false;
}
return !state->requestExpired();
});

if (!bufferFound) {
Expand Down Expand Up @@ -382,6 +390,7 @@ void TaskManager::getDataForResultRequests(
<< ", sequence " << resultRequest->token;
getData(
resultRequest->promise.lock(),
resultRequest->state,
resultRequest->taskId,
resultRequest->bufferId,
resultRequest->token,
Expand Down Expand Up @@ -831,6 +840,7 @@ folly::Future<std::unique_ptr<Result>> TaskManager::getResults(
if (prestoTask->task->state() == exec::kRunning) {
getData(
promiseHolder,
folly::to_weak_ptr(state),
taskId,
destination,
token,
Expand All @@ -852,12 +862,13 @@ folly::Future<std::unique_ptr<Result>> TaskManager::getResults(

keepPromiseAlive(promiseHolder, state);

auto request = std::make_unique<ResultRequest>();
request->promise = folly::to_weak_ptr(promiseHolder);
request->taskId = taskId;
request->bufferId = destination;
request->token = token;
request->maxSize = maxSize;
auto request = std::make_unique<ResultRequest>(
folly::to_weak_ptr(promiseHolder),
folly::to_weak_ptr(state),
taskId,
destination,
token,
maxSize);
prestoTask->resultRequests.insert({destination, std::move(request)});
return std::move(future)
.via(httpSrvCpuExecutor_)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ void verifyQueryCtxCache(
} // namespace

class QueryContextCacheTest : public testing::Test {
protected:
static void SetUpTestCase() {
memory::MemoryManager::testingSetInstance({});
}

void SetUp() override {
FLAGS_velox_memory_leak_check_enabled = true;
}
Expand Down
81 changes: 75 additions & 6 deletions presto-native-execution/presto_cpp/main/tests/TaskManagerTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <folly/executors/ThreadedExecutor.h>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "folly/experimental/EventCount.h"
#include "presto_cpp/main/PrestoExchangeSource.h"
#include "presto_cpp/main/TaskResource.h"
#include "presto_cpp/main/tests/HttpServerWrapper.h"
Expand All @@ -27,6 +28,7 @@
#include "velox/dwio/common/WriterFactory.h"
#include "velox/dwio/common/tests/utils/BatchMaker.h"
#include "velox/exec/Exchange.h"
#include "velox/exec/Values.h"
#include "velox/exec/tests/utils/PlanBuilder.h"
#include "velox/exec/tests/utils/QueryAssertions.h"
#include "velox/exec/tests/utils/TempDirectoryPath.h"
Expand Down Expand Up @@ -146,6 +148,7 @@ class TaskManagerTest : public testing::Test {
public:
static void SetUpTestCase() {
memory::MemoryManager::testingSetInstance({});
common::testutil::TestValue::enable();
}

protected:
Expand Down Expand Up @@ -709,6 +712,70 @@ TEST_F(TaskManagerTest, fecthFromFinishedTask) {
ASSERT_TRUE(newResult.value()->complete);
}

// Verifies that an arbitrary output buffer does not hand data to a destination
// whose HTTP request state has already been destroyed: the data produced after
// the first request went away must be delivered to the second, still active,
// destination instead.
DEBUG_ONLY_TEST_F(TaskManagerTest, fecthFromArbitraryOutput) {
  // Block output until the first fetch destination becomes inactive.
  folly::EventCount outputWait;
  std::atomic<bool> outputWaitFlag{false};
  SCOPED_TESTVALUE_SET(
      "facebook::velox::exec::Values::getOutput",
      std::function<void(const velox::exec::Values*)>(
          [&](const velox::exec::Values* /*unused*/) {
            outputWait.await([&]() { return outputWaitFlag.load(); });
          }));

  const std::vector<RowVectorPtr> batches = makeVectors(1, 1'000);
  auto planFragment = exec::test::PlanBuilder()
                          .values(batches)
                          .partitionedOutputArbitrary({"c0", "c1"})
                          .planFragment();
  const protocol::TaskId taskId = "source.0.0.1.0";
  const auto taskInfo = createOrUpdateTask(taskId, {}, planFragment);

  const protocol::Duration longWait("10s");
  const auto maxSize = protocol::DataSize("1024MB");
  auto expiredRequestState = http::CallbackRequestHandlerState::create();
  // Consume from destination 0 to simulate the case where the http request
  // has expired while the destination still has its notify callback set.
  auto expiredResultWait = taskManager_->getResults(
      taskId, 0, 0, maxSize, protocol::Duration("1s"), expiredRequestState);
  // Drop our reference to the http request state to simulate its expiration.
  expiredRequestState.reset();

  // Unblock output.
  outputWaitFlag = true;
  outputWait.notifyAll();

  // Consume from destination 1 and expect to get the produced data.
  auto requestState = http::CallbackRequestHandlerState::create();
  const auto result =
      taskManager_->getResults(taskId, 1, 0, maxSize, longWait, requestState)
          .getVia(folly::EventBaseManager::get()->getEventBase());
  ASSERT_FALSE(result->complete);
  ASSERT_FALSE(result->data->empty());
  ASSERT_EQ(result->sequence, 0);
  ASSERT_EQ(result->nextSequence, 1);

  // Check the expired request hasn't fetched any data after its timeout.
  const auto expiredResult =
      std::move(expiredResultWait)
          .getVia(folly::EventBaseManager::get()->getEventBase());
  ASSERT_FALSE(expiredResult->complete);
  ASSERT_TRUE(expiredResult->data->empty());
  ASSERT_EQ(expiredResult->sequence, 0);
  ASSERT_EQ(expiredResult->nextSequence, 0);

  // Close both destinations, which triggers the task closure.
  taskManager_->abortResults(taskId, 0);
  taskManager_->abortResults(taskId, 1);

  auto prestoTask = taskManager_->tasks().at(taskId);
  ASSERT_TRUE(waitForTaskStateChange(
      prestoTask->task.get(), TaskState::kFinished, 3'000'000));
}

TEST_F(TaskManagerTest, taskCleanupWithPendingResultData) {
// Trigger old task cleanup immediately.
taskManager_->setOldTaskCleanUpMs(0);
Expand Down Expand Up @@ -1103,12 +1170,14 @@ TEST_F(TaskManagerTest, getDataOnAbortedTask) {
});
auto promiseHolder = std::make_shared<PromiseHolder<std::unique_ptr<Result>>>(
std::move(promise));
auto request = std::make_unique<ResultRequest>();
request->promise = folly::to_weak_ptr(promiseHolder);
request->taskId = scanTaskId;
request->bufferId = 0;
request->token = token;
request->maxSize = protocol::DataSize("32MB");
auto requestState = http::CallbackRequestHandlerState::create();
auto request = std::make_unique<ResultRequest>(
folly::to_weak_ptr(promiseHolder),
folly::to_weak_ptr(requestState),
scanTaskId,
0,
token,
protocol::DataSize("32MB"));
prestoTask->resultRequests.insert({0, std::move(request)});
prestoTask->task = createDummyExecTask(scanTaskId, planFragment);
taskManager_->getDataForResultRequests(prestoTask->resultRequests);
Expand Down

0 comments on commit d62a06f

Please sign in to comment.