Properly support task batches in ReservedStateUpdateTaskExecutor #116353

Closed · wants to merge 16 commits (showing changes from 7 commits)
@@ -22,18 +22,21 @@
import org.elasticsearch.cluster.metadata.ReservedStateMetadata;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.Randomness;
import org.elasticsearch.common.Priority;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.Strings;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.reservedstate.action.ReservedClusterSettingsAction;
import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.test.ESTestCase;
import org.junit.Before;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardCopyOption;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
@@ -357,7 +360,7 @@ private void assertClusterStateNotSaved(CountDownLatch savedClusterState, Atomic
updateClusterSettings(Settings.builder().put("search.allow_expensive_queries", "false"));
}

-    public void testErrorSaved() throws Exception {
+    public void testErrorNotSaved() throws Exception {
internalCluster().setBootstrapMasterNodeIndex(0);
logger.info("--> start data node / non master node");
String dataNode = internalCluster().startNode(Settings.builder().put(dataOnlyNode()).put("discovery.initial_state_timeout", "1s"));
@@ -381,6 +384,40 @@ public void testErrorSaved() throws Exception {
assertClusterStateNotSaved(savedClusterState.v1(), savedClusterState.v2());
}

public void testLastSettingsInBatchApplied() throws Exception {
internalCluster().setBootstrapMasterNodeIndex(0);
logger.info("--> start data node / non master node");
String dataNode = internalCluster().startNode(Settings.builder().put(dataOnlyNode()).put("discovery.initial_state_timeout", "1s"));
FileSettingsService dataFileSettingsService = internalCluster().getInstance(FileSettingsService.class, dataNode);

assertFalse(dataFileSettingsService.watching());

logger.info("--> start master node");
final String masterNode = internalCluster().startMasterOnlyNode();
assertMasterNode(internalCluster().nonMasterClient(), masterNode);
var savedClusterState = setupClusterStateListener(masterNode);

FileSettingsService masterFileSettingsService = internalCluster().getInstance(FileSettingsService.class, masterNode);

assertTrue(masterFileSettingsService.watching());
assertFalse(dataFileSettingsService.watching());

final var masterNodeClusterService = internalCluster().getCurrentMasterNodeInstance(ClusterService.class);
final var barrier = new CyclicBarrier(2);
masterNodeClusterService.createTaskQueue("block", Priority.NORMAL, batchExecutionContext -> {
safeAwait(barrier);
safeAwait(barrier);
batchExecutionContext.taskContexts().forEach(c -> c.success(() -> {}));
return batchExecutionContext.initialState();
}).submitTask("block", ESTestCase::fail, null);

safeAwait(barrier);
writeJSONFile(masterNode, testJSON, versionCounter, logger); // Valid but skipped
writeJSONFile(masterNode, testJSON43mb, versionCounter, logger); // The last valid setting
Contributor:

Could we sometimes try 3? I have a hunch that might cover some logic that the 2 case doesn't.

Could we give them different version values? And not always apply them in ascending order?

Could we wait for the master to pick up one file before writing the next?

Contributor Author:

Good ideas!

Waiting for the master to pick up one file before writing the next would exercise existing functionality, rather than my changes, but it's obviously an important case to test.

Contributor Author:

What is the expected behaviour when the versions are not ascending?

Should ReservedStateUpdateTaskExecutor be attempting them in descending version order rather than just iterating backward through the batch list?

Contributor:

The guiding principle is that the expected outcome should be the same whether the tasks are executed in a single batch or one-at-a-time. So yes AIUI I think we should attempt them in descending version order.

Contributor Author (@prdoyle, Nov 28, 2024):

Ah of course, that makes perfect sense. I'll make that change (on the assumption that updates with lower version numbers are ignored if processed one-per-batch).

Contributor Author:

Hmm, I found this javadoc for TaskContext:

A task to be executed, along with callbacks for the executor to record the outcome of this task's execution. The executor must call exactly one of these methods for every task in its batch.

It seems we can't simply skip the intervening tasks after all?

Contributor:

You have to call success (or onFailure) to record the outcome of the task, but that doesn't mean you have to actually do anything else with the task. According to the guiding principle, skipping a task because we processed a newer one counts as success right?

Contributor Author (@prdoyle, Nov 28, 2024):

Alright makes sense. So perhaps:

  1. Scan the tasks to identify the one that would have run last, had the tasks all been run individually instead of in a batch, taking into consideration task ordering, version numbers, and the ReservedStateVersionCheck mode.
  2. Try that task. If it succeeds, call success on that one and all prior ones, and possibly on all following ones?
  3. If it fails, call onFailure, remove it from the candidates list, and loop back to step 1.

Contributor:

Yeah, that sounds right, although maybe it'd be simpler just to sort the list of tasks in the right order?

Contributor Author:

I was afraid I couldn't determine the right order before I know whether the first fails. But now that I think about it, if the first task succeeds, the rest are irrelevant, so I could sort them on the assumption that each one fails.
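The sorting idea discussed above might look like the following sketch. The `VersionedTask` record and `attemptOrder` helper are hypothetical stand-ins for illustration, not part of the PR:

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

// Hypothetical stand-in for a reserved-state update task's version metadata.
record VersionedTask(String name, long version) {}

class BatchOrdering {
    // Order the batch so the highest version is attempted first; because the
    // sort is stable, tasks with equal versions keep their submission order.
    static List<VersionedTask> attemptOrder(List<VersionedTask> batch) {
        List<VersionedTask> sorted = new ArrayList<>(batch);
        sorted.sort(Comparator.comparingLong(VersionedTask::version).reversed());
        return sorted;
    }
}
```

Sorting up front, as suggested, assumes each attempted task may fail, so the full fall-back order must be decided before the first attempt.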

safeAwait(barrier);
assertClusterStateSaveOK(savedClusterState.v1(), savedClusterState.v2(), "43mb");
}
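The CyclicBarrier trick in the test above (parking a task so that later submissions accumulate behind it into a single batch) can be sketched in isolation. `BarrierGate` and its log format are illustrative only, not part of the test:

```java
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;

// A worker thread blocks on a two-party barrier, so work queued behind it
// accumulates until the caller releases it: first await signals "parked",
// second await releases the batch.
class BarrierGate {
    static String runGated() {
        CyclicBarrier barrier = new CyclicBarrier(2);
        StringBuilder log = new StringBuilder();
        Thread worker = new Thread(() -> {
            try {
                barrier.await();          // 1: tell the caller we are now blocking
                barrier.await();          // 2: wait until the caller releases us
                log.append("batch");      // stand-in for processing the batch
            } catch (InterruptedException | BrokenBarrierException e) {
                throw new IllegalStateException(e);
            }
        });
        try {
            worker.start();
            barrier.await();              // matches 1: worker is now parked
            log.append("queued;");        // stand-in for writing the settings files
            barrier.await();              // matches 2: let the batch run
            worker.join();
        } catch (InterruptedException | BrokenBarrierException e) {
            throw new IllegalStateException(e);
        }
        return log.toString();
    }
}
```

The two `safeAwait(barrier)` calls in the test play the same two roles: the first confirms the blocking task holds the queue, the second releases it after both files are written.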

public void testErrorCanRecoverOnRestart() throws Exception {
internalCluster().setBootstrapMasterNodeIndex(0);
logger.info("--> start data node / non master node");
@@ -14,15 +14,14 @@
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.cluster.ClusterState;
-import org.elasticsearch.cluster.SimpleBatchedExecutor;
+import org.elasticsearch.cluster.ClusterStateTaskExecutor;
import org.elasticsearch.cluster.routing.RerouteService;
import org.elasticsearch.common.Priority;
import org.elasticsearch.core.Tuple;

/**
* Reserved cluster state update task executor
*/
-public class ReservedStateUpdateTaskExecutor extends SimpleBatchedExecutor<ReservedStateUpdateTask, Void> {
+public class ReservedStateUpdateTaskExecutor implements ClusterStateTaskExecutor<ReservedStateUpdateTask> {

private static final Logger logger = LogManager.getLogger(ReservedStateUpdateTaskExecutor.class);

@@ -34,17 +33,38 @@ public ReservedStateUpdateTaskExecutor(RerouteService rerouteService) {
}

     @Override
-    public Tuple<ClusterState, Void> executeTask(ReservedStateUpdateTask task, ClusterState clusterState) {
-        return Tuple.tuple(task.execute(clusterState), null);
-    }
-
-    @Override
-    public void taskSucceeded(ReservedStateUpdateTask task, Void unused) {
-        task.listener().onResponse(ActionResponse.Empty.INSTANCE);
+    public final ClusterState execute(BatchExecutionContext<ReservedStateUpdateTask> batchExecutionContext) throws Exception {
+        var initState = batchExecutionContext.initialState();
+        var taskContexts = batchExecutionContext.taskContexts();
+        if (taskContexts.isEmpty()) {
+            return initState;
+        }
+
+        // Only the last update is relevant; the others can be skipped.
+        // However, if that last update task fails, we should fall back to the preceding one.
+        for (var iterator = taskContexts.listIterator(taskContexts.size()); iterator.hasPrevious();) {
+            var taskContext = iterator.previous();
+            ClusterState clusterState = initState;
+            try (var ignored = taskContext.captureResponseHeaders()) {
+                var task = taskContext.getTask();
+                clusterState = task.execute(clusterState);
+                taskContext.success(() -> task.listener().onResponse(ActionResponse.Empty.INSTANCE));
+                logger.debug("Update task succeeded");
+                return clusterState;
+            } catch (Exception e) {
+                taskContext.onFailure(e);
+                if (iterator.hasPrevious()) {
+                    logger.warn("Update task failed; will try the previous update task");
+                }
+            }
+        }
+
+        logger.warn("All {} update tasks failed; returning initial state", taskContexts.size());
+        return initState;
     }
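The fall-back loop above can be distilled into a minimal generic sketch. The `applyNewestFirst` helper is hypothetical; the real executor also records success or failure on each task context and captures response headers:

```java
import java.util.List;
import java.util.ListIterator;
import java.util.function.UnaryOperator;

// Minimal model of the batch loop: walk the tasks from newest to oldest,
// apply each one to the *initial* state, and stop at the first success.
class FallbackExecutor {
    static <S> S applyNewestFirst(S initial, List<UnaryOperator<S>> tasks) {
        for (ListIterator<UnaryOperator<S>> it = tasks.listIterator(tasks.size()); it.hasPrevious();) {
            UnaryOperator<S> task = it.previous();
            try {
                return task.apply(initial); // first task to succeed wins
            } catch (RuntimeException e) {
                // this task failed; fall back to the preceding one
            }
        }
        return initial; // every task failed: keep the initial state
    }
}
```

Note that each attempt starts from the initial state, not from a partially applied one, mirroring how `clusterState` is reset to `initState` on every iteration above.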

     @Override
-    public void clusterStatePublished() {
+    public final void clusterStatePublished(ClusterState newClusterState) {
         rerouteService.reroute(
             "reroute after saving and reserving part of the cluster state",
             Priority.NORMAL,
@@ -54,4 +74,5 @@
             )
         );
     }
+
 }
@@ -10,6 +10,7 @@
package org.elasticsearch.reservedstate.service;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.cluster.ClusterName;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.ClusterStateAckListener;
@@ -39,6 +40,7 @@
import org.mockito.ArgumentMatchers;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashSet;
import java.util.List;
@@ -63,7 +65,9 @@
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
@@ -279,6 +283,103 @@ public void success(Runnable onPublicationSuccess) {
verify(rerouteService, times(1)).reroute(anyString(), any(), any());
}

public void testLastUpdateIsApplied() throws Exception {
ClusterName clusterName = new ClusterName("test");
ReservedStateUpdateTask realTask = new ReservedStateUpdateTask(
clusterName.value(),
null,
ReservedStateVersionCheck.HIGHER_VERSION_ONLY,
Map.of(),
Set.of(),
errorState -> fail("Unexpected error"),
ActionListener.noop()
);
var updates = mockUpdateSequence(2, clusterName, realTask);
ClusterState state0 = ClusterState.builder(clusterName).version(1000).build();
ClusterState newState = new ReservedStateUpdateTaskExecutor(mock(RerouteService.class)).execute(
new ClusterStateTaskExecutor.BatchExecutionContext<>(state0, updates.taskContexts(), () -> null)
);

assertThat("State should be the final state", newState, sameInstance(updates.states().get(updates.states().size() - 1)));

// Only process the final task; the intermediate ones can be skipped
verify(updates.tasks().get(0), times(0)).execute(any());
verify(updates.tasks().get(1), times(1)).execute(any());
}

public void testLastSuccessfulUpdateIsApplied() throws Exception {
ClusterName clusterName = new ClusterName("test");
ReservedStateUpdateTask realTask = new ReservedStateUpdateTask(
clusterName.value(),
null,
ReservedStateVersionCheck.HIGHER_VERSION_ONLY,
Map.of(),
Set.of(),
errorState -> fail("Unexpected error"),
ActionListener.noop()
) {
@Override
ActionListener<ActionResponse.Empty> listener() {
var superListener = super.listener();
return new ActionListener<>() {
@Override
public void onResponse(ActionResponse.Empty empty) {
superListener.onResponse(empty);
}

@Override
public void onFailure(Exception e) {
superListener.onFailure(e);
}
};
}
};

var updates = mockUpdateSequence(3, clusterName, realTask);

// Inject an error in the last update
reset(updates.tasks().get(2));
doThrow(UnsupportedOperationException.class).when(updates.tasks().get(2)).execute(any());

ClusterState state0 = ClusterState.builder(clusterName).version(1000).build();
ClusterState newState = new ReservedStateUpdateTaskExecutor(mock(RerouteService.class)).execute(
new ClusterStateTaskExecutor.BatchExecutionContext<>(state0, updates.taskContexts(), () -> null)
);

assertThat("State should be the last successful state", newState, sameInstance(updates.states().get(1)));

// Only process the final task; the intermediate ones can be skipped
verify(updates.tasks().get(2), times(1)).execute(any()); // Tried the last one, it failed
verify(updates.tasks().get(1), times(1)).execute(any()); // Tried the second-last one, it succeeded
verify(updates.tasks().get(0), times(0)).execute(any()); // Didn't bother trying the first one
}

/**
* @param tasks Mockito spies configured to return a specific state
* @param states the corresponding states returned by {@link #tasks}
*/
private record MockUpdateSequence(List<ReservedStateUpdateTask> tasks, List<ClusterState> states) {
public List<TestTaskContext<ReservedStateUpdateTask>> taskContexts() {
return tasks.stream().map(TestTaskContext::new).toList();
}
}

/**
* @return a sequence of updates that bump the version starting from 1001.
*/
private MockUpdateSequence mockUpdateSequence(int quantity, ClusterName clusterName, ReservedStateUpdateTask realTask) {
List<ReservedStateUpdateTask> tasks = new ArrayList<>(quantity);
List<ClusterState> states = new ArrayList<>(quantity);
for (int i = 0; i < quantity; i++) {
ClusterState state = ClusterState.builder(clusterName).version(1001 + i).build();
ReservedStateUpdateTask task = spy(realTask);
doReturn(state).when(task).execute(any());
tasks.add(task);
states.add(state);
}
return new MockUpdateSequence(tasks, states);
}

public void testUpdateErrorState() {
ClusterService clusterService = mock(ClusterService.class);
ClusterState state = ClusterState.builder(new ClusterName("test")).build();
@@ -400,7 +501,7 @@ public TransformState transform(Object source, TransformState prevState) throws
var chunk = new ReservedStateChunk(Map.of("one", "two", "maker", "three"), new ReservedStateVersion(2L, BuildVersion.current()));
var orderedHandlers = List.of(exceptionThrower.name(), newStateMaker.name());

-        // We submit a task with two handler, one will cause an exception, the other will create a new state.
+        // We submit a task with two handlers, one will cause an exception, the other will create a new state.
// When we fail to update the metadata because of version, we ensure that the returned state is equal to the
// original state by pointer reference to avoid cluster state update task to run.
ReservedStateUpdateTask task = new ReservedStateUpdateTask(