diff --git a/presto-main/src/main/java/io/prestosql/execution/QueryStateMachine.java b/presto-main/src/main/java/io/prestosql/execution/QueryStateMachine.java index e134226df5921..13b1b5176a0b6 100644 --- a/presto-main/src/main/java/io/prestosql/execution/QueryStateMachine.java +++ b/presto-main/src/main/java/io/prestosql/execution/QueryStateMachine.java @@ -998,7 +998,7 @@ private static boolean isScheduled(Optional rootStage) } return getAllStages(rootStage).stream() .map(StageInfo::getState) - .allMatch(state -> (state == StageState.RUNNING) || state.isDone()); + .allMatch(state -> state == StageState.RUNNING || state == StageState.FLUSHING || state.isDone()); } public Optional getFailureInfo() diff --git a/presto-main/src/main/java/io/prestosql/execution/SqlStageExecution.java b/presto-main/src/main/java/io/prestosql/execution/SqlStageExecution.java index 0f75569cafdef..169ad7cd4821e 100644 --- a/presto-main/src/main/java/io/prestosql/execution/SqlStageExecution.java +++ b/presto-main/src/main/java/io/prestosql/execution/SqlStageExecution.java @@ -89,6 +89,8 @@ public final class SqlStageExecution @GuardedBy("this") private final Set finishedTasks = newConcurrentHashSet(); @GuardedBy("this") + private final Set flushingTasks = newConcurrentHashSet(); + @GuardedBy("this") private final Set tasksWithFinalInfo = newConcurrentHashSet(); @GuardedBy("this") private final AtomicBoolean splitsScheduled = new AtomicBoolean(); @@ -225,6 +227,9 @@ public synchronized void schedulingComplete() if (getAllTasks().stream().anyMatch(task -> getState() == StageState.RUNNING)) { stateMachine.transitionToRunning(); } + if (isFlushing()) { + stateMachine.transitionToFlushing(); + } if (finishedTasks.containsAll(allTasks)) { stateMachine.transitionToFinished(); } @@ -505,14 +510,21 @@ else if (taskState == TaskState.ABORTED) { // A task should only be in the aborted state if the STAGE is done (ABORTED or FAILED) stateMachine.transitionToFailed(new PrestoException(GENERIC_INTERNAL_ERROR, "A task is in the ABORTED state but stage is " + stageState)); } + else if (taskState == TaskState.FLUSHING) { + flushingTasks.add(taskStatus.getTaskId()); + } else if (taskState == TaskState.FINISHED) { finishedTasks.add(taskStatus.getTaskId()); + flushingTasks.remove(taskStatus.getTaskId()); } - if (stageState == StageState.SCHEDULED || stageState == StageState.RUNNING) { + if (stageState == StageState.SCHEDULED || stageState == StageState.RUNNING || stageState == StageState.FLUSHING) { if (taskState == TaskState.RUNNING) { stateMachine.transitionToRunning(); } + if (isFlushing()) { + stateMachine.transitionToFlushing(); + } if (finishedTasks.containsAll(allTasks)) { stateMachine.transitionToFinished(); } @@ -524,6 +536,13 @@ else if (taskState == TaskState.FINISHED) { } } + private synchronized boolean isFlushing() + { + // to transition to flushing, there must be at least one flushing task, and all others must be flushing or finished. + return !flushingTasks.isEmpty() + && allTasks.stream().allMatch(taskId -> finishedTasks.contains(taskId) || flushingTasks.contains(taskId)); + } + private synchronized void updateFinalTaskInfo(TaskInfo finalTaskInfo) { tasksWithFinalInfo.add(finalTaskInfo.getTaskStatus().getTaskId()); diff --git a/presto-main/src/main/java/io/prestosql/execution/SqlTaskExecution.java b/presto-main/src/main/java/io/prestosql/execution/SqlTaskExecution.java index 78b4801636322..394d1b963793e 100644 --- a/presto-main/src/main/java/io/prestosql/execution/SqlTaskExecution.java +++ b/presto-main/src/main/java/io/prestosql/execution/SqlTaskExecution.java @@ -641,6 +641,7 @@ private synchronized void checkTaskCompletion() // are there still pages in the output buffer if (!outputBuffer.isFinished()) { + taskStateMachine.transitionToFlushing(); return; } diff --git a/presto-main/src/main/java/io/prestosql/execution/StageState.java b/presto-main/src/main/java/io/prestosql/execution/StageState.java index aed57b6cef5d0..57d837cced771 100644 --- a/presto-main/src/main/java/io/prestosql/execution/StageState.java +++ b/presto-main/src/main/java/io/prestosql/execution/StageState.java @@ -43,6 +43,11 @@ public enum StageState * Stage is running. */ RUNNING(false, false), + /** + * Stage has finished executing and output being consumed. + * In this state, at-least one of the tasks is flushing and the non-flushing tasks are finished + */ + FLUSHING(false, false), /** * Stage has finished executing and all output has been consumed. */ @@ -99,6 +104,7 @@ public boolean canScheduleMoreTasks() case SCHEDULING_SPLITS: case SCHEDULED: case RUNNING: + case FLUSHING: case FINISHED: case CANCELED: // no more workers will be added to the query diff --git a/presto-main/src/main/java/io/prestosql/execution/StageStateMachine.java b/presto-main/src/main/java/io/prestosql/execution/StageStateMachine.java index 6acb783719624..0961155df1895 100644 --- a/presto-main/src/main/java/io/prestosql/execution/StageStateMachine.java +++ b/presto-main/src/main/java/io/prestosql/execution/StageStateMachine.java @@ -56,6 +56,7 @@ import static io.prestosql.execution.StageState.CANCELED; import static io.prestosql.execution.StageState.FAILED; import static io.prestosql.execution.StageState.FINISHED; +import static io.prestosql.execution.StageState.FLUSHING; import static io.prestosql.execution.StageState.PLANNED; import static io.prestosql.execution.StageState.RUNNING; import static io.prestosql.execution.StageState.SCHEDULED; @@ -161,7 +162,12 @@ public synchronized boolean transitionToScheduled() public boolean transitionToRunning() { - return stageState.setIf(RUNNING, currentState -> currentState != RUNNING && !currentState.isDone()); + return stageState.setIf(RUNNING, currentState -> currentState != RUNNING && currentState != FLUSHING && !currentState.isDone()); + } + + public boolean transitionToFlushing() + { + return stageState.setIf(FLUSHING, currentState -> currentState != FLUSHING && !currentState.isDone()); } public boolean transitionToFinished() @@ -253,7 +259,7 @@ public BasicStageStats getBasicStageStats(Supplier> taskInfos // information, the stage could finish, and the task states would // never be visible. StageState state = stageState.get(); - boolean isScheduled = (state == RUNNING) || state.isDone(); + boolean isScheduled = state == RUNNING || state == FLUSHING || state.isDone(); List taskInfos = ImmutableList.copyOf(taskInfosSupplier.get()); diff --git a/presto-main/src/main/java/io/prestosql/execution/StageStats.java b/presto-main/src/main/java/io/prestosql/execution/StageStats.java index de0c639db164d..a6963bfc58cb8 100644 --- a/presto-main/src/main/java/io/prestosql/execution/StageStats.java +++ b/presto-main/src/main/java/io/prestosql/execution/StageStats.java @@ -32,6 +32,7 @@ import java.util.Set; import static com.google.common.base.Preconditions.checkArgument; +import static io.prestosql.execution.StageState.FLUSHING; import static io.prestosql.execution.StageState.RUNNING; import static java.lang.Math.min; import static java.util.Objects.requireNonNull; @@ -422,7 +423,7 @@ public List getOperatorSummaries() public BasicStageStats toBasicStageStats(StageState stageState) { - boolean isScheduled = (stageState == RUNNING) || stageState.isDone(); + boolean isScheduled = stageState == RUNNING || stageState == FLUSHING || stageState.isDone(); OptionalDouble progressPercentage = OptionalDouble.empty(); if (isScheduled && totalDrivers != 0) { diff --git a/presto-main/src/main/java/io/prestosql/execution/TaskState.java b/presto-main/src/main/java/io/prestosql/execution/TaskState.java index 0b2120b1e82c1..9610592166228 100644 --- a/presto-main/src/main/java/io/prestosql/execution/TaskState.java +++ b/presto-main/src/main/java/io/prestosql/execution/TaskState.java @@ -30,6 +30,12 @@ public enum TaskState * Task is running. */ RUNNING(false), + /** + * Task has finished executing and output is left to be consumed. + * In this state, there will be no new drivers, the existing drivers have finished + * and the output buffer of the task is at-least in a 'no-more-pages' state. + */ + FLUSHING(false), /** * Task has finished executing and all output has been consumed. */ diff --git a/presto-main/src/main/java/io/prestosql/execution/TaskStateMachine.java b/presto-main/src/main/java/io/prestosql/execution/TaskStateMachine.java index 56f4f7c2a6c79..5b143fab4dc00 100644 --- a/presto-main/src/main/java/io/prestosql/execution/TaskStateMachine.java +++ b/presto-main/src/main/java/io/prestosql/execution/TaskStateMachine.java @@ -26,6 +26,8 @@ import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.util.concurrent.Futures.immediateFuture; +import static io.prestosql.execution.TaskState.FLUSHING; +import static io.prestosql.execution.TaskState.RUNNING; import static io.prestosql.execution.TaskState.TERMINAL_TASK_STATES; import static java.util.Objects.requireNonNull; @@ -80,6 +82,11 @@ public LinkedBlockingQueue getFailureCauses() return failureCauses; } + public void transitionToFlushing() + { + taskState.setIf(FLUSHING, currentState -> currentState == RUNNING); + } + public void finished() { transitionToDoneState(TaskState.FINISHED); diff --git a/presto-main/src/main/java/io/prestosql/execution/scheduler/AllAtOnceExecutionSchedule.java b/presto-main/src/main/java/io/prestosql/execution/scheduler/AllAtOnceExecutionSchedule.java index 82bd9ff53c567..3ab1e072533d2 100644 --- a/presto-main/src/main/java/io/prestosql/execution/scheduler/AllAtOnceExecutionSchedule.java +++ b/presto-main/src/main/java/io/prestosql/execution/scheduler/AllAtOnceExecutionSchedule.java @@ -43,6 +43,7 @@ import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.collect.Iterables.getOnlyElement; +import static io.prestosql.execution.StageState.FLUSHING; import static io.prestosql.execution.StageState.RUNNING; import static io.prestosql.execution.StageState.SCHEDULED; import static java.util.Objects.requireNonNull; @@ -71,7 +72,7 @@ public Set getStagesToSchedule() { for (Iterator iterator = schedulingStages.iterator(); iterator.hasNext(); ) { StageState state = iterator.next().getState(); - if (state == SCHEDULED || state == RUNNING || state.isDone()) { + if (state == SCHEDULED || state == RUNNING || state == FLUSHING || state.isDone()) { iterator.remove(); } } diff --git a/presto-main/src/main/java/io/prestosql/execution/scheduler/PhasedExecutionSchedule.java b/presto-main/src/main/java/io/prestosql/execution/scheduler/PhasedExecutionSchedule.java index 495d7c98e7370..b7a756ce75916 100644 --- a/presto-main/src/main/java/io/prestosql/execution/scheduler/PhasedExecutionSchedule.java +++ b/presto-main/src/main/java/io/prestosql/execution/scheduler/PhasedExecutionSchedule.java @@ -50,6 +50,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static io.prestosql.execution.StageState.FLUSHING; import static io.prestosql.execution.StageState.RUNNING; import static io.prestosql.execution.StageState.SCHEDULED; import static io.prestosql.sql.planner.plan.ExchangeNode.Scope.LOCAL; @@ -92,7 +93,7 @@ private void removeCompletedStages() { for (Iterator stageIterator = activeSources.iterator(); stageIterator.hasNext(); ) { StageState state = stageIterator.next().getState(); - if (state == SCHEDULED || state == RUNNING || state.isDone()) { + if (state == SCHEDULED || state == RUNNING || state == FLUSHING || state.isDone()) { stageIterator.remove(); } } diff --git a/presto-main/src/main/java/io/prestosql/execution/scheduler/SqlQueryScheduler.java b/presto-main/src/main/java/io/prestosql/execution/scheduler/SqlQueryScheduler.java index 49620ccc2f4ab..522e239dfcec6 100644 --- a/presto-main/src/main/java/io/prestosql/execution/scheduler/SqlQueryScheduler.java +++ b/presto-main/src/main/java/io/prestosql/execution/scheduler/SqlQueryScheduler.java @@ -90,6 +90,7 @@ import static io.prestosql.execution.StageState.CANCELED; import static io.prestosql.execution.StageState.FAILED; import static io.prestosql.execution.StageState.FINISHED; +import static io.prestosql.execution.StageState.FLUSHING; import static io.prestosql.execution.StageState.RUNNING; import static io.prestosql.execution.StageState.SCHEDULED; import static io.prestosql.execution.scheduler.SourcePartitionedScheduler.newSourcePartitionedSchedulerAsStageScheduler; @@ -420,7 +421,7 @@ else if (partitioningHandle.equals(SCALED_WRITER_DISTRIBUTION)) { } Set childStages = childStagesBuilder.build(); stage.addStateChangeListener(newState -> { - if (newState.isDone()) { + if (newState == FLUSHING || newState.isDone()) { childStages.forEach(SqlStageExecution::cancel); } }); @@ -595,7 +596,7 @@ else if (!result.getBlocked().isDone()) { for (SqlStageExecution stage : stages.values()) { StageState state = stage.getState(); - if (state != SCHEDULED && state != RUNNING && !state.isDone()) { + if (state != SCHEDULED && state != RUNNING && state != FLUSHING && !state.isDone()) { throw new PrestoException(GENERIC_INTERNAL_ERROR, format("Scheduling is complete, but stage %s is in state %s", stage.getStageId(), state)); } } diff --git a/presto-main/src/main/java/io/prestosql/server/testing/TestingPrestoServer.java b/presto-main/src/main/java/io/prestosql/server/testing/TestingPrestoServer.java index 7b7fd7c94cb2b..dca135827bf68 100644 --- a/presto-main/src/main/java/io/prestosql/server/testing/TestingPrestoServer.java +++ b/presto-main/src/main/java/io/prestosql/server/testing/TestingPrestoServer.java @@ -364,6 +364,11 @@ public Plan getQueryPlan(QueryId queryId) return queryManager.getQueryPlan(queryId); } + public QueryInfo getFullQueryInfo(QueryId queryId) + { + return queryManager.getFullQueryInfo(queryId); + } + public void addFinalQueryInfoListener(QueryId queryId, StateChangeListener stateChangeListener) { queryManager.addFinalQueryInfoListener(queryId, stateChangeListener); diff --git a/presto-main/src/test/java/io/prestosql/execution/TestSqlTask.java b/presto-main/src/test/java/io/prestosql/execution/TestSqlTask.java index 2cef01b5fa278..c367db68f9e77 100644 --- a/presto-main/src/test/java/io/prestosql/execution/TestSqlTask.java +++ b/presto-main/src/test/java/io/prestosql/execution/TestSqlTask.java @@ -137,18 +137,18 @@ public void testSimpleQuery() { SqlTask sqlTask = createInitialTask(); - TaskInfo taskInfo = sqlTask.updateTask(TEST_SESSION, + assertEquals(sqlTask.getTaskStatus().getState(), TaskState.RUNNING); + sqlTask.updateTask(TEST_SESSION, Optional.of(PLAN_FRAGMENT), ImmutableList.of(new TaskSource(TABLE_SCAN_NODE_ID, ImmutableSet.of(SPLIT), true)), createInitialEmptyOutputBuffers(PARTITIONED).withBuffer(OUT, 0).withNoMoreBufferIds(), OptionalInt.empty()); - assertEquals(taskInfo.getTaskStatus().getState(), TaskState.RUNNING); - taskInfo = sqlTask.getTaskInfo(); - assertEquals(taskInfo.getTaskStatus().getState(), TaskState.RUNNING); + TaskInfo taskInfo = sqlTask.getTaskInfo(TaskState.RUNNING).get(1, SECONDS); + assertEquals(taskInfo.getTaskStatus().getState(), TaskState.FLUSHING); BufferResult results = sqlTask.getTaskResults(OUT, 0, DataSize.of(1, MEGABYTE)).get(); - assertEquals(results.isBufferComplete(), false); + assertFalse(results.isBufferComplete()); assertEquals(results.getSerializedPages().size(), 1); assertEquals(results.getSerializedPages().get(0).getPositionCount(), 1); @@ -202,15 +202,15 @@ public void testAbort() { SqlTask sqlTask = createInitialTask(); - TaskInfo taskInfo = sqlTask.updateTask(TEST_SESSION, + assertEquals(sqlTask.getTaskStatus().getState(), TaskState.RUNNING); + sqlTask.updateTask(TEST_SESSION, Optional.of(PLAN_FRAGMENT), ImmutableList.of(new TaskSource(TABLE_SCAN_NODE_ID, ImmutableSet.of(SPLIT), true)), createInitialEmptyOutputBuffers(PARTITIONED).withBuffer(OUT, 0).withNoMoreBufferIds(), OptionalInt.empty()); - assertEquals(taskInfo.getTaskStatus().getState(), TaskState.RUNNING); - taskInfo = sqlTask.getTaskInfo(); - assertEquals(taskInfo.getTaskStatus().getState(), TaskState.RUNNING); + TaskInfo taskInfo = sqlTask.getTaskInfo(TaskState.RUNNING).get(1, SECONDS); + assertEquals(taskInfo.getTaskStatus().getState(), TaskState.FLUSHING); sqlTask.abortTaskResults(OUT); diff --git a/presto-main/src/test/java/io/prestosql/execution/TestSqlTaskExecution.java b/presto-main/src/test/java/io/prestosql/execution/TestSqlTaskExecution.java index b73eabbb199cf..b2482466acf52 100644 --- a/presto-main/src/test/java/io/prestosql/execution/TestSqlTaskExecution.java +++ b/presto-main/src/test/java/io/prestosql/execution/TestSqlTaskExecution.java @@ -88,6 +88,9 @@ import static io.prestosql.SessionTestUtils.TEST_SESSION; import static io.prestosql.block.BlockAssertions.createStringSequenceBlock; import static io.prestosql.block.BlockAssertions.createStringsBlock; +import static io.prestosql.execution.TaskState.FINISHED; +import static io.prestosql.execution.TaskState.FLUSHING; +import static io.prestosql.execution.TaskState.RUNNING; import static io.prestosql.execution.TaskTestUtils.TABLE_SCAN_NODE_ID; import static io.prestosql.execution.TaskTestUtils.createTestSplitMonitor; import static io.prestosql.execution.buffer.BufferState.OPEN; @@ -177,7 +180,7 @@ public void testSimple(PipelineExecutionStrategy executionStrategy) // // test body - assertEquals(taskStateMachine.getState(), TaskState.RUNNING); + assertEquals(taskStateMachine.getState(), RUNNING); switch (executionStrategy) { case UNGROUPED_EXECUTION: @@ -277,9 +280,9 @@ public void testSimple(PipelineExecutionStrategy executionStrategy) throw new UnsupportedOperationException(); } + assertEquals(taskStateMachine.getStateChange(RUNNING).get(10, SECONDS), FLUSHING); outputBufferConsumer.abort(); // complete the task by calling abort on it - TaskState taskState = taskStateMachine.getStateChange(TaskState.RUNNING).get(10, SECONDS); - assertEquals(taskState, TaskState.FINISHED); + assertEquals(taskStateMachine.getStateChange(FLUSHING).get(10, SECONDS), FINISHED); } finally { taskExecutor.stop(); @@ -428,7 +431,7 @@ public void testComplex(PipelineExecutionStrategy executionStrategy) // // test body - assertEquals(taskStateMachine.getState(), TaskState.RUNNING); + assertEquals(taskStateMachine.getState(), RUNNING); switch (executionStrategy) { case UNGROUPED_EXECUTION: @@ -579,9 +582,9 @@ public void testComplex(PipelineExecutionStrategy executionStrategy) throw new UnsupportedOperationException(); } + assertEquals(taskStateMachine.getStateChange(RUNNING).get(10, SECONDS), FLUSHING); outputBufferConsumer.abort(); // complete the task by calling abort on it - TaskState taskState = taskStateMachine.getStateChange(TaskState.RUNNING).get(10, SECONDS); - assertEquals(taskState, TaskState.FINISHED); + assertEquals(taskStateMachine.getStateChange(FLUSHING).get(10, SECONDS), FINISHED); } finally { taskExecutor.stop(); diff --git a/presto-main/src/test/java/io/prestosql/execution/TestSqlTaskManager.java b/presto-main/src/test/java/io/prestosql/execution/TestSqlTaskManager.java index 3a5d285da9d38..e284863624e26 100644 --- a/presto-main/src/test/java/io/prestosql/execution/TestSqlTaskManager.java +++ b/presto-main/src/test/java/io/prestosql/execution/TestSqlTaskManager.java @@ -54,9 +54,11 @@ import static io.prestosql.execution.buffer.OutputBuffers.createInitialEmptyOutputBuffers; import static io.prestosql.testing.TestingSession.testSessionBuilder; import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; import static org.testng.Assert.assertNotEquals; import static org.testng.Assert.assertNotNull; import static org.testng.Assert.assertNull; +import static org.testng.Assert.assertTrue; @Test public class TestSqlTaskManager @@ -110,21 +112,20 @@ public void testSimpleQuery() { try (SqlTaskManager sqlTaskManager = createSqlTaskManager(new TaskManagerConfig())) { TaskId taskId = TASK_ID; - TaskInfo taskInfo = createTask(sqlTaskManager, taskId, ImmutableSet.of(SPLIT), createInitialEmptyOutputBuffers(PARTITIONED).withBuffer(OUT, 0).withNoMoreBufferIds()); - assertEquals(taskInfo.getTaskStatus().getState(), TaskState.RUNNING); + createTask(sqlTaskManager, taskId, ImmutableSet.of(SPLIT), createInitialEmptyOutputBuffers(PARTITIONED).withBuffer(OUT, 0).withNoMoreBufferIds()); - taskInfo = sqlTaskManager.getTaskInfo(taskId); - assertEquals(taskInfo.getTaskStatus().getState(), TaskState.RUNNING); + TaskInfo taskInfo = sqlTaskManager.getTaskInfo(taskId, TaskState.RUNNING).get(1, TimeUnit.SECONDS); + assertEquals(taskInfo.getTaskStatus().getState(), TaskState.FLUSHING); BufferResult results = sqlTaskManager.getTaskResults(taskId, OUT, 0, DataSize.of(1, Unit.MEGABYTE)).get(); - assertEquals(results.isBufferComplete(), false); + assertFalse(results.isBufferComplete()); assertEquals(results.getSerializedPages().size(), 1); assertEquals(results.getSerializedPages().get(0).getPositionCount(), 1); for (boolean moreResults = true; moreResults; moreResults = !results.isBufferComplete()) { results = sqlTaskManager.getTaskResults(taskId, OUT, results.getToken() + results.getSerializedPages().size(), DataSize.of(1, Unit.MEGABYTE)).get(); } - assertEquals(results.isBufferComplete(), true); + assertTrue(results.isBufferComplete()); assertEquals(results.getSerializedPages().size(), 0); // complete the task by calling abort on it @@ -190,11 +191,10 @@ public void testAbortResults() { try (SqlTaskManager sqlTaskManager = createSqlTaskManager(new TaskManagerConfig())) { TaskId taskId = TASK_ID; - TaskInfo taskInfo = createTask(sqlTaskManager, taskId, ImmutableSet.of(SPLIT), createInitialEmptyOutputBuffers(PARTITIONED).withBuffer(OUT, 0).withNoMoreBufferIds()); - assertEquals(taskInfo.getTaskStatus().getState(), TaskState.RUNNING); + createTask(sqlTaskManager, taskId, ImmutableSet.of(SPLIT), createInitialEmptyOutputBuffers(PARTITIONED).withBuffer(OUT, 0).withNoMoreBufferIds()); - taskInfo = sqlTaskManager.getTaskInfo(taskId); - assertEquals(taskInfo.getTaskStatus().getState(), TaskState.RUNNING); + TaskInfo taskInfo = sqlTaskManager.getTaskInfo(taskId, TaskState.RUNNING).get(1, TimeUnit.SECONDS); + assertEquals(taskInfo.getTaskStatus().getState(), TaskState.FLUSHING); sqlTaskManager.abortTaskResults(taskId, OUT); diff --git a/presto-main/src/test/java/io/prestosql/execution/TestStageStateMachine.java b/presto-main/src/test/java/io/prestosql/execution/TestStageStateMachine.java index 13ba72498a907..a6740469487df 100644 --- a/presto-main/src/test/java/io/prestosql/execution/TestStageStateMachine.java +++ b/presto-main/src/test/java/io/prestosql/execution/TestStageStateMachine.java @@ -81,6 +81,9 @@ public void testBasicStateChanges() assertTrue(stateMachine.transitionToRunning()); assertState(stateMachine, StageState.RUNNING); + assertTrue(stateMachine.transitionToFlushing()); + assertState(stateMachine, StageState.FLUSHING); + assertTrue(stateMachine.transitionToFinished()); assertState(stateMachine, StageState.FINISHED); } @@ -99,6 +102,10 @@ public void testPlanned() assertTrue(stateMachine.transitionToRunning()); assertState(stateMachine, StageState.RUNNING); + stateMachine = createStageStateMachine(); + assertTrue(stateMachine.transitionToFlushing()); + assertState(stateMachine, StageState.FLUSHING); + stateMachine = createStageStateMachine(); assertTrue(stateMachine.transitionToFinished()); assertState(stateMachine, StageState.FINISHED); @@ -134,6 +141,11 @@ public void testScheduling() assertTrue(stateMachine.transitionToRunning()); assertState(stateMachine, StageState.RUNNING); + stateMachine = createStageStateMachine(); + stateMachine.transitionToScheduling(); + assertTrue(stateMachine.transitionToFlushing()); + assertState(stateMachine, StageState.FLUSHING); + stateMachine = createStageStateMachine(); stateMachine.transitionToScheduling(); assertTrue(stateMachine.transitionToFinished()); @@ -171,6 +183,9 @@ public void testScheduled() assertTrue(stateMachine.transitionToRunning()); assertState(stateMachine, StageState.RUNNING); + assertTrue(stateMachine.transitionToFlushing()); + assertState(stateMachine, StageState.FLUSHING); + stateMachine = createStageStateMachine(); stateMachine.transitionToScheduled(); assertTrue(stateMachine.transitionToFinished()); @@ -208,6 +223,11 @@ public void testRunning() assertFalse(stateMachine.transitionToRunning()); assertState(stateMachine, StageState.RUNNING); + assertTrue(stateMachine.transitionToFlushing()); + assertState(stateMachine, StageState.FLUSHING); + + stateMachine = createStageStateMachine(); + stateMachine.transitionToRunning(); assertTrue(stateMachine.transitionToFinished()); assertState(stateMachine, StageState.FINISHED); @@ -227,6 +247,46 @@ public void testRunning() assertState(stateMachine, StageState.CANCELED); } + @Test + public void testFlushing() + { + StageStateMachine stateMachine = createStageStateMachine(); + assertTrue(stateMachine.transitionToFlushing()); + assertState(stateMachine, StageState.FLUSHING); + + assertFalse(stateMachine.transitionToScheduling()); + assertState(stateMachine, StageState.FLUSHING); + + assertFalse(stateMachine.transitionToScheduled()); + assertState(stateMachine, StageState.FLUSHING); + + assertFalse(stateMachine.transitionToRunning()); + assertState(stateMachine, StageState.FLUSHING); + + assertFalse(stateMachine.transitionToFlushing()); + assertState(stateMachine, StageState.FLUSHING); + + stateMachine = createStageStateMachine(); + stateMachine.transitionToFlushing(); + assertTrue(stateMachine.transitionToFinished()); + assertState(stateMachine, StageState.FINISHED); + + stateMachine = createStageStateMachine(); + stateMachine.transitionToFlushing(); + assertTrue(stateMachine.transitionToFailed(FAILED_CAUSE)); + assertState(stateMachine, StageState.FAILED); + + stateMachine = createStageStateMachine(); + stateMachine.transitionToFlushing(); + assertTrue(stateMachine.transitionToAborted()); + assertState(stateMachine, StageState.ABORTED); + + stateMachine = createStageStateMachine(); + stateMachine.transitionToFlushing(); + assertTrue(stateMachine.transitionToCanceled()); + assertState(stateMachine, StageState.CANCELED); + } + @Test public void testFinished() { @@ -278,6 +338,9 @@ private static void assertFinalState(StageStateMachine stateMachine, StageState assertFalse(stateMachine.transitionToRunning()); assertState(stateMachine, expectedState); + assertFalse(stateMachine.transitionToFlushing()); + assertState(stateMachine, expectedState); + assertFalse(stateMachine.transitionToFinished()); assertState(stateMachine, expectedState); diff --git a/presto-tests/src/test/java/io/prestosql/execution/TestFlushingStageState.java b/presto-tests/src/test/java/io/prestosql/execution/TestFlushingStageState.java new file mode 100644 index 0000000000000..64f2121f6bda6 --- /dev/null +++ b/presto-tests/src/test/java/io/prestosql/execution/TestFlushingStageState.java @@ -0,0 +1,78 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.execution; + +import com.google.common.collect.ImmutableMap; +import io.airlift.units.Duration; +import io.prestosql.spi.QueryId; +import io.prestosql.testing.DistributedQueryRunner; +import io.prestosql.tests.tpch.TpchQueryRunnerBuilder; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import static io.prestosql.SessionTestUtils.TEST_SESSION; +import static io.prestosql.execution.QueryState.RUNNING; +import static io.prestosql.execution.StageState.CANCELED; +import static io.prestosql.execution.StageState.FLUSHING; +import static io.prestosql.execution.TestQueryRunnerUtil.createQuery; +import static io.prestosql.execution.TestQueryRunnerUtil.waitForQueryState; +import static io.prestosql.testing.assertions.Assert.assertEventually; +import static java.util.concurrent.TimeUnit.SECONDS; +import static org.testng.Assert.assertEquals; + +public class TestFlushingStageState +{ + private DistributedQueryRunner queryRunner; + + @BeforeClass + public void setup() + throws Exception + { + queryRunner = TpchQueryRunnerBuilder.builder().buildWithoutCatalogs(); + queryRunner.createCatalog("tpch", "tpch", ImmutableMap.of("tpch.splits-per-node", "10000")); + } + + @Test(timeOut = 30_000) + public void testFlushingState() + throws Exception + { + QueryId queryId = createQuery(queryRunner, TEST_SESSION, "SELECT * FROM tpch.sf1000.lineitem limit 1"); + waitForQueryState(queryRunner, queryId, RUNNING); + + // wait for the query to finish producing results, but don't poll them + assertEventually( + new Duration(10, SECONDS), + () -> assertEquals(queryRunner.getCoordinator().getFullQueryInfo(queryId).getOutputStage().get().getState(), FLUSHING)); + + // wait for the sub stages to go to cancelled state + assertEventually( + new Duration(10, SECONDS), + () -> assertEquals(queryRunner.getCoordinator().getFullQueryInfo(queryId).getOutputStage().get().getSubStages().get(0).getState(), CANCELED)); + + QueryInfo queryInfo = queryRunner.getCoordinator().getFullQueryInfo(queryId); + assertEquals(queryInfo.getState(), RUNNING); + assertEquals(queryInfo.getOutputStage().get().getState(), FLUSHING); + assertEquals(queryInfo.getOutputStage().get().getSubStages().size(), 1); + assertEquals(queryInfo.getOutputStage().get().getSubStages().get(0).getState(), CANCELED); + } + + @AfterClass(alwaysRun = true) + public void tearDown() + { + if (queryRunner != null) { + queryRunner.close(); + } + } +}