[ML] fix machine learning job close/kill race condition (#71656) (#71750)

If a machine learning job is killed while it is attempting to open, there is a race condition that can prevent it from closing.

This is most evident during the `reset_feature` API call, which kills the jobs and then quickly calls close to wait for the persistent tasks to complete.

But if this happens while a job is being assigned to a node, there is a window in which the process continues to start even though we have already attempted to kill and close it.

This commit locks the process context on `kill` and sets the job to `closing`. This way, if the process context is already locked (because the process is starting), we won't attempt the kill until the process has fully started.

Setting the job to `closing` allows the starting process to exit early if the `kill` command has already completed (i.e. before the communicator was created).
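
As a rough, self-contained illustration of this approach (not the actual Elasticsearch classes), the Java sketch below models both sides of the race: the kill path takes the same lock as process startup and, if the communicator does not exist yet, only records the request, while the opening thread re-checks for a pending close or kill once the process is running and aborts instead of leaving it alive. The names `KillOpenRaceSketch`, `pendingKill`, and `killAction` are invented for the example; they loosely correspond to `latestKillRequest`, `shouldBeKilled()`, and `JobTask.isClosing()` in the diff below.

```java
import java.util.concurrent.locks.ReentrantLock;

/**
 * Simplified model of the close/kill race handling described above.
 * Illustrative only; names loosely mirror the real ProcessContext/JobTask.
 */
public class KillOpenRaceSketch {

    static class JobTask {
        // Set by the close path so an in-flight open can bail out early.
        volatile boolean closing;

        boolean isClosing() {
            return closing;
        }
    }

    static class ProcessContext {
        private final ReentrantLock lock = new ReentrantLock();
        private final JobTask jobTask;
        private volatile Object communicator;   // stands in for the autodetect communicator
        private volatile Runnable pendingKill;  // stands in for latestKillRequest

        ProcessContext(JobTask jobTask) {
            this.jobTask = jobTask;
        }

        // Kill path: take the same lock as startup so the two can never interleave.
        // If the communicator does not exist yet, only record the request; the
        // opening thread is responsible for honouring it once the process is up.
        void kill(Runnable killAction) {
            lock.lock();
            try {
                if (communicator == null) {
                    pendingKill = killAction;
                    return;
                }
                killAction.run();   // process fully started, kill it immediately
            } finally {
                lock.unlock();
            }
        }

        // Open path: start the process under the lock, then re-check whether a
        // close or kill arrived while we were starting, and abort if so.
        void open() {
            lock.lock();
            try {
                communicator = new Object();   // "native process started"
            } finally {
                lock.unlock();
            }
            if (jobTask.isClosing()) {
                System.out.println("aborting open: job is already closing");
                communicator = null;
                return;
            }
            if (pendingKill != null) {
                System.out.println("aborting open: honouring deferred kill");
                pendingKill.run();
                communicator = null;
            }
        }
    }

    public static void main(String[] args) {
        JobTask task = new JobTask();
        ProcessContext context = new ProcessContext(task);

        // Simulate the race: the kill request arrives before the communicator exists,
        // so it is only recorded...
        context.kill(() -> System.out.println("native process killed"));

        // ...and when the opening thread finishes starting the process it notices the
        // deferred kill and aborts, instead of leaving an orphaned process running.
        context.open();
    }
}
```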

closes #71646
benwtrent authored Apr 15, 2021
1 parent a41e0e2 commit f6e1a8f
Showing 6 changed files with 71 additions and 45 deletions.
@@ -143,18 +143,11 @@ public void writeToJob(InputStream inputStream, AnalysisRegistry analysisRegistr
handler);
}

@Override
public void close() {
close(false, null);
}

/**
* Closes job this communicator is encapsulating.
*
* @param restart Whether the job should be restarted by persistent tasks
* @param reason The reason for closing the job
*/
public void close(boolean restart, String reason) {
@Override
public void close() {
Future<?> future = autodetectWorkerExecutor.submit(() -> {
checkProcessIsAlive();
try {
@@ -166,7 +159,7 @@ public void close(boolean restart, String reason) {
}
autodetectResultProcessor.awaitCompletion();
} finally {
onFinishHandler.accept(restart ? new ElasticsearchException(reason) : null, true);
onFinishHandler.accept(null, true);
}
LOGGER.info("[{}] job closed", job.getId());
return null;
@@ -164,7 +164,7 @@ public synchronized void closeAllJobsOnThisNode(String reason) {
logger.info("Closing [{}] jobs, because [{}]", numJobs, reason);

for (ProcessContext process : processByAllocation.values()) {
closeJob(process.getJobTask(), false, reason);
closeJob(process.getJobTask(), reason);
}
}
}
@@ -179,11 +179,11 @@ public void killProcess(JobTask jobTask, boolean awaitCompletion, String reason)
ProcessContext processContext = processByAllocation.remove(jobTask.getAllocationId());
if (processContext != null) {
processContext.newKillBuilder()
.setAwaitCompletion(awaitCompletion)
.setFinish(true)
.setReason(reason)
.setShouldFinalizeJob(upgradeInProgress == false && resetInProgress == false)
.kill();
.setAwaitCompletion(awaitCompletion)
.setFinish(true)
.setReason(reason)
.setShouldFinalizeJob(upgradeInProgress == false && resetInProgress == false)
.kill();
} else {
// If the process is missing but the task exists this is most likely
// due to 2 reasons. The first is because the job went into the failed
@@ -536,6 +536,18 @@ protected void doRun() {

try {
if (createProcessAndSetRunning(processContext, job, params, closeHandler)) {
if (processContext.getJobTask().isClosing()) {
logger.debug("Aborted opening job [{}] as it is being closed", job.getId());
closeProcessAndTask(processContext, jobTask, "job is already closing");
return;
}
// It is possible that a `kill` request came in before the communicator was set
// This means that the kill was not handled appropriately and we continued down this execution path
if (processContext.shouldBeKilled()) {
logger.debug("Aborted opening job [{}] as it is being killed", job.getId());
processContext.killIt();
return;
}
processContext.getAutodetectCommunicator().restoreState(params.modelSnapshot());
setJobState(jobTask, JobState.OPENED);
}
@@ -716,25 +728,9 @@ private Consumer<String> onProcessCrash(JobTask jobTask) {
};
}

/**
* Stop the running job and mark it as finished.
*
* @param jobTask The job to stop
* @param restart Whether the job should be restarted by persistent tasks
* @param reason The reason for closing the job
*/
public void closeJob(JobTask jobTask, boolean restart, String reason) {
private void closeProcessAndTask(ProcessContext processContext, JobTask jobTask, String reason) {
String jobId = jobTask.getJobId();
long allocationId = jobTask.getAllocationId();
logger.debug("Attempting to close job [{}], because [{}]", jobId, reason);
// don't remove the process context immediately, because we need to ensure
// it is reachable to enable killing a job while it is closing
ProcessContext processContext = processByAllocation.get(allocationId);
if (processContext == null) {
logger.debug("Cannot close job [{}] as it has already been closed or is closing", jobId);
return;
}

processContext.tryLock();
try {
if (processContext.setDying() == false) {
@@ -755,7 +751,7 @@ public void closeJob(JobTask jobTask, boolean restart, String reason) {
logger.debug("Job [{}] is being closed before its process is started", jobId);
jobTask.markAsCompleted();
} else {
communicator.close(restart, reason);
communicator.close();
}

processByAllocation.remove(allocationId);
@@ -781,6 +777,26 @@ public void closeJob(JobTask jobTask, boolean restart, String reason) {
}
}

/**
* Stop the running job and mark it as finished.
*
* @param jobTask The job to stop
* @param reason The reason for closing the job
*/
public void closeJob(JobTask jobTask, String reason) {
String jobId = jobTask.getJobId();
long allocationId = jobTask.getAllocationId();
logger.debug("Attempting to close job [{}], because [{}]", jobId, reason);
// don't remove the process context immediately, because we need to ensure
// it is reachable to enable killing a job while it is closing
ProcessContext processContext = processByAllocation.get(allocationId);
if (processContext == null) {
logger.debug("Cannot close job [{}] as it has already been closed or is closing", jobId);
return;
}
closeProcessAndTask(processContext, jobTask, reason);
}

int numberOfOpenJobs() {
return (int) processByAllocation.values().stream()
.filter(p -> p.getState() != ProcessContext.ProcessStateName.DYING)
@@ -28,6 +28,7 @@ final class ProcessContext {
private final JobTask jobTask;
private volatile AutodetectCommunicator autodetectCommunicator;
private volatile ProcessState state;
private volatile KillBuilder latestKillRequest = null;

ProcessContext(JobTask jobTask) {
this.jobTask = jobTask;
@@ -46,6 +47,17 @@ private void setAutodetectCommunicator(AutodetectCommunicator autodetectCommunic
this.autodetectCommunicator = autodetectCommunicator;
}

boolean shouldBeKilled() {
return latestKillRequest != null;
}

void killIt() {
if (latestKillRequest == null) {
throw new IllegalArgumentException("Unable to kill job as previous request is not completed");
}
latestKillRequest.kill();
}

ProcessStateName getState() {
return state.getName();
}
@@ -117,6 +129,7 @@ KillBuilder setShouldFinalizeJob(boolean shouldFinalizeJob) {

void kill() {
if (autodetectCommunicator == null) {
latestKillRequest = this;
return;
}
String jobId = jobTask.getJobId();
@@ -51,7 +51,7 @@ public boolean isClosing() {

public void closeJob(String reason) {
isClosing = true;
autodetectProcessManager.closeJob(this, false, reason);
autodetectProcessManager.closeJob(this, reason);
}

void setAutodetectProcessManager(AutodetectProcessManager autodetectProcessManager) {
@@ -319,7 +319,7 @@ public void testOpenJob_exceedMaxNumJobs() {
jobTask = mock(JobTask.class);
when(jobTask.getAllocationId()).thenReturn(2L);
when(jobTask.getJobId()).thenReturn("baz");
manager.closeJob(jobTask, false, null);
manager.closeJob(jobTask, null);
assertEquals(2, manager.numberOfOpenJobs());
manager.openJob(jobTask, clusterState, (e1, b) -> {});
assertEquals(3, manager.numberOfOpenJobs());
@@ -372,7 +372,7 @@ public void testCloseJob() {

// job is created
assertEquals(1, manager.numberOfOpenJobs());
manager.closeJob(jobTask, false, null);
manager.closeJob(jobTask, null);
assertEquals(0, manager.numberOfOpenJobs());
}

@@ -384,7 +384,7 @@ public void testCanCloseClosingJob() throws Exception {
// the middle of the AutodetectProcessManager.close() method
Thread.yield();
return null;
}).when(autodetectCommunicator).close(anyBoolean(), anyString());
}).when(autodetectCommunicator).close();
AutodetectProcessManager manager = createSpyManager();
assertEquals(0, manager.numberOfOpenJobs());

@@ -397,12 +397,12 @@ public void testCanCloseClosingJob() throws Exception {
assertEquals(1, manager.numberOfOpenJobs());

// Close the job in a separate thread
Thread closeThread = new Thread(() -> manager.closeJob(jobTask, false, "in separate thread"));
Thread closeThread = new Thread(() -> manager.closeJob(jobTask, "in separate thread"));
closeThread.start();
Thread.yield();

// Also close the job in the current thread, so that we have two simultaneous close requests
manager.closeJob(jobTask, false, "in main test thread");
manager.closeJob(jobTask, "in main test thread");

// The 10 second timeout here is usually far in excess of what is required. In the vast
// majority of cases the other thread will exit within a few milliseconds. However, it
Expand All @@ -427,7 +427,7 @@ public void testCanKillClosingJob() throws Exception {
closeInterruptedLatch.countDown();
}
return null;
}).when(autodetectCommunicator).close(anyBoolean(), anyString());
}).when(autodetectCommunicator).close();
doAnswer(invocationOnMock -> {
killLatch.countDown();
return null;
Expand All @@ -442,7 +442,7 @@ public void testCanKillClosingJob() throws Exception {
mock(DataLoadParams.class), (dataCounts1, e) -> {});

// Close the job in a separate thread so that it can simulate taking a long time to close
Thread closeThread = new Thread(() -> manager.closeJob(jobTask, false, null));
Thread closeThread = new Thread(() -> manager.closeJob(jobTask, null));
closeThread.start();
assertTrue(closeStartedLatch.await(3, TimeUnit.SECONDS));

@@ -509,7 +509,7 @@ public void testCloseThrows() {

// let the communicator throw, simulating a problem with the underlying
// autodetect, e.g. a crash
doThrow(Exception.class).when(autodetectCommunicator).close(anyBoolean(), anyString());
doThrow(Exception.class).when(autodetectCommunicator).close();

// create a jobtask
JobTask jobTask = mock(JobTask.class);
Expand All @@ -521,7 +521,7 @@ public void testCloseThrows() {
verify(manager).setJobState(any(), eq(JobState.OPENED));
// job is created
assertEquals(1, manager.numberOfOpenJobs());
expectThrows(ElasticsearchException.class, () -> manager.closeJob(jobTask, false, null));
expectThrows(ElasticsearchException.class, () -> manager.closeJob(jobTask, null));
assertEquals(0, manager.numberOfOpenJobs());

verify(manager).setJobState(any(), eq(JobState.FAILED), any());
@@ -10,6 +10,7 @@
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionListenerResponseHandler;
@@ -27,6 +28,7 @@
import org.elasticsearch.common.logging.LoggerMessageFormat;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.discovery.MasterNotDiscoveredException;
import org.elasticsearch.index.IndexNotFoundException;
import org.elasticsearch.persistent.PersistentTasksCustomMetadata;
import org.elasticsearch.persistent.PersistentTasksCustomMetadata.PersistentTask;
import org.elasticsearch.persistent.PersistentTasksService;
@@ -286,7 +288,9 @@ private ActionListener<Response> waitForStopListener(Request request, ActionList

ActionListener<Response> onStopListener = ActionListener.wrap(
waitResponse -> transformConfigManager.refresh(ActionListener.wrap(r -> listener.onResponse(waitResponse), e -> {
logger.warn("Could not refresh state, state information might be outdated", e);
if ((ExceptionsHelper.unwrapCause(e) instanceof IndexNotFoundException) == false) {
logger.warn("Could not refresh state, state information might be outdated", e);
}
listener.onResponse(waitResponse);
})),
listener::onFailure