[7.x] [ML] fix machine learning job close/kill race condition (#71656) #71750

Merged 1 commit on Apr 15, 2021
@@ -143,18 +143,11 @@ public void writeToJob(InputStream inputStream, AnalysisRegistry analysisRegistr
handler);
}

@Override
public void close() {
close(false, null);
}

/**
* Closes job this communicator is encapsulating.
*
* @param restart Whether the job should be restarted by persistent tasks
* @param reason The reason for closing the job
*/
public void close(boolean restart, String reason) {
@Override
public void close() {
Future<?> future = autodetectWorkerExecutor.submit(() -> {
checkProcessIsAlive();
try {
@@ -166,7 +159,7 @@ public void close(boolean restart, String reason) {
}
autodetectResultProcessor.awaitCompletion();
} finally {
onFinishHandler.accept(restart ? new ElasticsearchException(reason) : null, true);
onFinishHandler.accept(null, true);
}
LOGGER.info("[{}] job closed", job.getId());
return null;
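
Read together, the two hunks above remove the close(boolean restart, String reason) overload: close() is now the only close path, and the finish handler is always invoked with a null error instead of a synthetic ElasticsearchException that would make persistent tasks restart the job. A minimal standalone model of that shape (a sketch only, not the Elasticsearch class; the executor and onFinishHandler names simply mirror the diff context):

// Minimal standalone model of the simplified close path shown above.
// Not the Elasticsearch class; it only mirrors the shape of the diff:
// close() runs on the worker executor, waits for the work to finish, and
// always hands the finish handler a null error (no forced "restart" exception).
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.function.BiConsumer;

class CommunicatorModel implements AutoCloseable {
    private final ExecutorService workerExecutor = Executors.newSingleThreadExecutor();
    private final BiConsumer<Exception, Boolean> onFinishHandler;

    CommunicatorModel(BiConsumer<Exception, Boolean> onFinishHandler) {
        this.onFinishHandler = onFinishHandler;
    }

    @Override
    public void close() {
        Future<?> future = workerExecutor.submit(() -> {
            try {
                // ... flush and close the native process here ...
            } finally {
                // Single close path: never synthesises an exception to force a restart.
                onFinishHandler.accept(null, true);
            }
            return null;
        });
        try {
            future.get();                 // block until the close work has run
        } catch (Exception e) {
            throw new RuntimeException(e);
        } finally {
            workerExecutor.shutdown();
        }
    }

    public static void main(String[] args) {
        try (CommunicatorModel c = new CommunicatorModel(
                (error, finished) -> System.out.println("finished=" + finished + ", error=" + error))) {
            // nothing to do; closing is the interesting part
        }
    }
}

The next hunks apply the matching change to what appears to be the process manager that drives this communicator.
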
@@ -164,7 +164,7 @@ public synchronized void closeAllJobsOnThisNode(String reason) {
logger.info("Closing [{}] jobs, because [{}]", numJobs, reason);

for (ProcessContext process : processByAllocation.values()) {
closeJob(process.getJobTask(), false, reason);
closeJob(process.getJobTask(), reason);
}
}
}
@@ -179,11 +179,11 @@ public void killProcess(JobTask jobTask, boolean awaitCompletion, String reason)
ProcessContext processContext = processByAllocation.remove(jobTask.getAllocationId());
if (processContext != null) {
processContext.newKillBuilder()
.setAwaitCompletion(awaitCompletion)
.setFinish(true)
.setReason(reason)
.setShouldFinalizeJob(upgradeInProgress == false && resetInProgress == false)
.kill();
.setAwaitCompletion(awaitCompletion)
.setFinish(true)
.setReason(reason)
.setShouldFinalizeJob(upgradeInProgress == false && resetInProgress == false)
.kill();
} else {
// If the process is missing but the task exists this is most likely
// due to 2 reasons. The first is because the job went into the failed
@@ -536,6 +536,18 @@ protected void doRun() {

try {
if (createProcessAndSetRunning(processContext, job, params, closeHandler)) {
if (processContext.getJobTask().isClosing()) {
logger.debug("Aborted opening job [{}] as it is being closed", job.getId());
closeProcessAndTask(processContext, jobTask, "job is already closing");
return;
}
// It is possible that a `kill` request came in before the communicator was set
// This means that the kill was not handled appropriately and we continued down this execution path
if (processContext.shouldBeKilled()) {
logger.debug("Aborted opening job [{}] as it is being killed", job.getId());
processContext.killIt();
return;
}
processContext.getAutodetectCommunicator().restoreState(params.modelSnapshot());
setJobState(jobTask, JobState.OPENED);
}
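
The block added after createProcessAndSetRunning(...) is the heart of the race fix: once the native process exists, the opening thread re-checks whether a close or kill request arrived while it was starting, and aborts the open instead of blindly marking the job OPENED. A toy, self-contained interleaving of that window (hypothetical names; the latches only force the problematic ordering deterministically, whereas the real code relies on the process context's lock and the kill bookkeeping shown further down):

// Toy reproduction of the window the new guard covers: a kill arrives after the
// native process has been created but before the job is marked OPENED.
// Not Elasticsearch code; "opener" stands in for doRun(), "killer" for killProcess().
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;

class OpenKillRaceDemo {
    public static void main(String[] args) throws InterruptedException {
        AtomicBoolean killRequested = new AtomicBoolean(false);
        CountDownLatch processCreated = new CountDownLatch(1);
        CountDownLatch killDelivered = new CountDownLatch(1);
        StringBuilder finalState = new StringBuilder();

        Thread opener = new Thread(() -> {
            // createProcessAndSetRunning(...) succeeded
            processCreated.countDown();
            try {
                killDelivered.await();                  // pause exactly inside the race window
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
            // --- the added guard: re-check before declaring the job OPENED ---
            if (killRequested.get()) {
                finalState.append("aborted open, kill honoured");
                return;
            }
            finalState.append("opened (kill would have been lost)");
        });

        Thread killer = new Thread(() -> {
            try {
                processCreated.await();                 // kill arrives while the job is still opening
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
            killRequested.set(true);
            killDelivered.countDown();
        });

        opener.start();
        killer.start();
        opener.join();
        killer.join();
        System.out.println(finalState);                 // prints: aborted open, kill honoured
    }
}
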
@@ -716,25 +728,9 @@ private Consumer<String> onProcessCrash(JobTask jobTask) {
};
}

/**
* Stop the running job and mark it as finished.
*
* @param jobTask The job to stop
* @param restart Whether the job should be restarted by persistent tasks
* @param reason The reason for closing the job
*/
public void closeJob(JobTask jobTask, boolean restart, String reason) {
private void closeProcessAndTask(ProcessContext processContext, JobTask jobTask, String reason) {
String jobId = jobTask.getJobId();
long allocationId = jobTask.getAllocationId();
logger.debug("Attempting to close job [{}], because [{}]", jobId, reason);
// don't remove the process context immediately, because we need to ensure
// it is reachable to enable killing a job while it is closing
ProcessContext processContext = processByAllocation.get(allocationId);
if (processContext == null) {
logger.debug("Cannot close job [{}] as it has already been closed or is closing", jobId);
return;
}

processContext.tryLock();
try {
if (processContext.setDying() == false) {
@@ -755,7 +751,7 @@ public void closeJob(JobTask jobTask, boolean restart, String reason) {
logger.debug("Job [{}] is being closed before its process is started", jobId);
jobTask.markAsCompleted();
} else {
communicator.close(restart, reason);
communicator.close();
}

processByAllocation.remove(allocationId);
@@ -781,6 +777,26 @@ public void closeJob(JobTask jobTask, boolean restart, String reason) {
}
}

/**
* Stop the running job and mark it as finished.
*
* @param jobTask The job to stop
* @param reason The reason for closing the job
*/
public void closeJob(JobTask jobTask, String reason) {
String jobId = jobTask.getJobId();
long allocationId = jobTask.getAllocationId();
logger.debug("Attempting to close job [{}], because [{}]", jobId, reason);
// don't remove the process context immediately, because we need to ensure
// it is reachable to enable killing a job while it is closing
ProcessContext processContext = processByAllocation.get(allocationId);
if (processContext == null) {
logger.debug("Cannot close job [{}] as it has already been closed or is closing", jobId);
return;
}
closeProcessAndTask(processContext, jobTask, reason);
}

int numberOfOpenJobs() {
return (int) processByAllocation.values().stream()
.filter(p -> p.getState() != ProcessContext.ProcessStateName.DYING)
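
Taken together, these hunks split the old public closeJob(jobTask, restart, reason) into a thin public closeJob(jobTask, reason) that merely resolves the ProcessContext, plus a private closeProcessAndTask(...) that does the actual work, so the open path above can close a context it already holds without a second lookup. A stripped-down skeleton of that structure (hypothetical, simplified types; not the real manager):

// Skeleton of the refactor: closeJob(...) only resolves the context and delegates,
// so the job opener can call the shared closeProcessAndTask(...) with the context
// it already holds when it discovers the task is closing.
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

class ProcessManagerSkeleton {
    private final Map<Long, Object> processByAllocation = new ConcurrentHashMap<>();

    // public entry point: allocation id -> context lookup, then delegate
    public void closeJob(long allocationId, String reason) {
        Object processContext = processByAllocation.get(allocationId);
        if (processContext == null) {
            System.out.println("already closed or closing: " + reason);
            return;
        }
        closeProcessAndTask(processContext, allocationId, reason);
    }

    // shared internals: also callable from the open path when it sees isClosing()
    private void closeProcessAndTask(Object processContext, long allocationId, String reason) {
        // ... lock the context, mark it dying, close the communicator, finalize the job ...
        processByAllocation.remove(allocationId);
        System.out.println("closed allocation " + allocationId + " because: " + reason);
    }

    public static void main(String[] args) {
        ProcessManagerSkeleton manager = new ProcessManagerSkeleton();
        manager.processByAllocation.put(7L, new Object());
        manager.closeJob(7L, "user request");   // normal close
        manager.closeJob(7L, "user request");   // second close is a no-op
    }
}
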
@@ -28,6 +28,7 @@ final class ProcessContext {
private final JobTask jobTask;
private volatile AutodetectCommunicator autodetectCommunicator;
private volatile ProcessState state;
private volatile KillBuilder latestKillRequest = null;

ProcessContext(JobTask jobTask) {
this.jobTask = jobTask;
@@ -46,6 +47,17 @@ private void setAutodetectCommunicator(AutodetectCommunicator autodetectCommunic
this.autodetectCommunicator = autodetectCommunicator;
}

boolean shouldBeKilled() {
return latestKillRequest != null;
}

void killIt() {
if (latestKillRequest == null) {
throw new IllegalArgumentException("Unable to kill job as previous request is not completed");
}
latestKillRequest.kill();
}

ProcessStateName getState() {
return state.getName();
}
@@ -117,6 +129,7 @@ KillBuilder setShouldFinalizeJob(boolean shouldFinalizeJob) {

void kill() {
if (autodetectCommunicator == null) {
latestKillRequest = this;
return;
}
String jobId = jobTask.getJobId();
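
This file adds the bookkeeping that makes a too-early kill survivable: if kill() runs before the communicator has been set, the KillBuilder parks itself in the volatile latestKillRequest field, and the open path later replays it through shouldBeKilled()/killIt(). A simplified, self-contained sketch of that hand-off (a Runnable stands in for the KillBuilder, and the exception type differs from the real class):

// Sketch of the deferred-kill bookkeeping added to ProcessContext (simplified,
// hypothetical types). If a kill arrives before the communicator exists, the
// request is parked in a volatile field; the opener later replays it via killIt().
class ProcessContextSketch {
    private volatile Object communicator;           // stands in for AutodetectCommunicator
    private volatile Runnable latestKillRequest;    // parked kill, if any

    void setCommunicator(Object communicator) {
        this.communicator = communicator;
    }

    void kill(String reason) {
        if (communicator == null) {
            // Too early to kill anything: remember the request instead of dropping it.
            latestKillRequest = () -> System.out.println("killing process, reason: " + reason);
            return;
        }
        System.out.println("killing process, reason: " + reason);
    }

    boolean shouldBeKilled() {
        return latestKillRequest != null;
    }

    void killIt() {
        if (latestKillRequest == null) {
            throw new IllegalStateException("no pending kill request");
        }
        latestKillRequest.run();
    }

    public static void main(String[] args) {
        ProcessContextSketch context = new ProcessContextSketch();
        context.kill("user cancelled during startup");   // arrives before the communicator is set
        context.setCommunicator(new Object());            // open path finishes creating the process
        if (context.shouldBeKilled()) {                    // the re-check added to doRun()
            context.killIt();                              // replays the parked kill
        }
    }
}
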
@@ -51,7 +51,7 @@ public boolean isClosing() {

public void closeJob(String reason) {
isClosing = true;
autodetectProcessManager.closeJob(this, false, reason);
autodetectProcessManager.closeJob(this, reason);
}

void setAutodetectProcessManager(AutodetectProcessManager autodetectProcessManager) {
@@ -319,7 +319,7 @@ public void testOpenJob_exceedMaxNumJobs() {
jobTask = mock(JobTask.class);
when(jobTask.getAllocationId()).thenReturn(2L);
when(jobTask.getJobId()).thenReturn("baz");
manager.closeJob(jobTask, false, null);
manager.closeJob(jobTask, null);
assertEquals(2, manager.numberOfOpenJobs());
manager.openJob(jobTask, clusterState, (e1, b) -> {});
assertEquals(3, manager.numberOfOpenJobs());
@@ -372,7 +372,7 @@ public void testCloseJob() {

// job is created
assertEquals(1, manager.numberOfOpenJobs());
manager.closeJob(jobTask, false, null);
manager.closeJob(jobTask, null);
assertEquals(0, manager.numberOfOpenJobs());
}

@@ -384,7 +384,7 @@ public void testCanCloseClosingJob() throws Exception {
// the middle of the AutodetectProcessManager.close() method
Thread.yield();
return null;
}).when(autodetectCommunicator).close(anyBoolean(), anyString());
}).when(autodetectCommunicator).close();
AutodetectProcessManager manager = createSpyManager();
assertEquals(0, manager.numberOfOpenJobs());

@@ -397,12 +397,12 @@ public void testCanCloseClosingJob() throws Exception {
assertEquals(1, manager.numberOfOpenJobs());

// Close the job in a separate thread
Thread closeThread = new Thread(() -> manager.closeJob(jobTask, false, "in separate thread"));
Thread closeThread = new Thread(() -> manager.closeJob(jobTask, "in separate thread"));
closeThread.start();
Thread.yield();

// Also close the job in the current thread, so that we have two simultaneous close requests
manager.closeJob(jobTask, false, "in main test thread");
manager.closeJob(jobTask, "in main test thread");

// The 10 second timeout here is usually far in excess of what is required. In the vast
// majority of cases the other thread will exit within a few milliseconds. However, it
@@ -427,7 +427,7 @@ public void testCanKillClosingJob() throws Exception {
closeInterruptedLatch.countDown();
}
return null;
}).when(autodetectCommunicator).close(anyBoolean(), anyString());
}).when(autodetectCommunicator).close();
doAnswer(invocationOnMock -> {
killLatch.countDown();
return null;
@@ -442,7 +442,7 @@ public void testCanKillClosingJob() throws Exception {
mock(DataLoadParams.class), (dataCounts1, e) -> {});

// Close the job in a separate thread so that it can simulate taking a long time to close
Thread closeThread = new Thread(() -> manager.closeJob(jobTask, false, null));
Thread closeThread = new Thread(() -> manager.closeJob(jobTask, null));
closeThread.start();
assertTrue(closeStartedLatch.await(3, TimeUnit.SECONDS));

@@ -509,7 +509,7 @@ public void testCloseThrows() {

// let the communicator throw, simulating a problem with the underlying
// autodetect, e.g. a crash
doThrow(Exception.class).when(autodetectCommunicator).close(anyBoolean(), anyString());
doThrow(Exception.class).when(autodetectCommunicator).close();

// create a jobtask
JobTask jobTask = mock(JobTask.class);
@@ -521,7 +521,7 @@ public void testCloseThrows() {
verify(manager).setJobState(any(), eq(JobState.OPENED));
// job is created
assertEquals(1, manager.numberOfOpenJobs());
expectThrows(ElasticsearchException.class, () -> manager.closeJob(jobTask, false, null));
expectThrows(ElasticsearchException.class, () -> manager.closeJob(jobTask, null));
assertEquals(0, manager.numberOfOpenJobs());

verify(manager).setJobState(any(), eq(JobState.FAILED), any());
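
The test changes are mechanical: every stub and invocation that used close(anyBoolean(), anyString()) now targets the no-argument close(). A small sketch of the updated Mockito stubbing pattern (assumes Mockito on the classpath; Communicator is a hypothetical stand-in for AutodetectCommunicator):

// How the test stubbing changes once close() no longer takes arguments: the
// doAnswer/doThrow hooks move from close(anyBoolean(), anyString()) to close().
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;

class CloseStubbingSketch {
    interface Communicator extends AutoCloseable {
        void close();  // matches the new no-arg signature
    }

    public static void main(String[] args) {
        Communicator slowToClose = mock(Communicator.class);
        doAnswer(invocation -> {
            Thread.yield();              // simulate a close that takes a while, as the test does
            return null;
        }).when(slowToClose).close();

        Communicator crashingOnClose = mock(Communicator.class);
        doThrow(new RuntimeException("simulated autodetect crash")).when(crashingOnClose).close();

        slowToClose.close();             // runs the stubbed answer
        try {
            crashingOnClose.close();
        } catch (RuntimeException e) {
            System.out.println("close threw as stubbed: " + e.getMessage());
        }
    }
}
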
@@ -10,6 +10,7 @@
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionListenerResponseHandler;
@@ -27,6 +28,7 @@
import org.elasticsearch.common.logging.LoggerMessageFormat;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.discovery.MasterNotDiscoveredException;
import org.elasticsearch.index.IndexNotFoundException;
import org.elasticsearch.persistent.PersistentTasksCustomMetadata;
import org.elasticsearch.persistent.PersistentTasksCustomMetadata.PersistentTask;
import org.elasticsearch.persistent.PersistentTasksService;
@@ -286,7 +288,9 @@ private ActionListener<Response> waitForStopListener(Request request, ActionList

ActionListener<Response> onStopListener = ActionListener.wrap(
waitResponse -> transformConfigManager.refresh(ActionListener.wrap(r -> listener.onResponse(waitResponse), e -> {
logger.warn("Could not refresh state, state information might be outdated", e);
if ((ExceptionsHelper.unwrapCause(e) instanceof IndexNotFoundException) == false) {
logger.warn("Could not refresh state, state information might be outdated", e);
}
listener.onResponse(waitResponse);
})),
listener::onFailure
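
The last hunk, in what appears to be the transform stop action, stops warning about a failed state refresh when the unwrapped cause is an IndexNotFoundException; the stop response is still returned to the caller either way. A plain-Java analogue of that guard (hypothetical exception type and logging; the real code uses ExceptionsHelper.unwrapCause and the action's logger):

// Plain-Java analogue of the logging guard added above: unwrap the root cause and
// stay quiet when it is the "index not found" case.
class RefreshWarningGuard {
    static class IndexNotFoundException extends RuntimeException {
        IndexNotFoundException(String index) { super("no such index [" + index + "]"); }
    }

    // Walk to the root cause; the real code uses ExceptionsHelper.unwrapCause instead.
    static Throwable unwrapCause(Throwable t) {
        Throwable cause = t;
        while (cause.getCause() != null && cause.getCause() != cause) {
            cause = cause.getCause();
        }
        return cause;
    }

    static void onRefreshFailure(Exception e) {
        if ((unwrapCause(e) instanceof IndexNotFoundException) == false) {
            System.err.println("Could not refresh state, state information might be outdated: " + e);
        }
        // Either way, the stop response is still delivered to the caller.
    }

    public static void main(String[] args) {
        onRefreshFailure(new RuntimeException(new IndexNotFoundException("transform-state-index"))); // silent
        onRefreshFailure(new RuntimeException("node disconnected"));                                  // warns
    }
}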