[ML] clear job size estimate cache when feature is reset (#74494)
Since the feature reset API clears out the `.ml-*` indices, it effectively deletes the machine learning jobs as well.

However, because the regular delete-job API path is not followed, jobs that no longer exist could still have memory estimates cached on the master node. These would never be cleared until the elected master node changed.

This commit causes feature reset to:
 - wait for all in-flight refresh requests to finish (there should usually be none, as all assignments have been cancelled)
 - clear the cached map of memory estimates held on the master node
 - allow new refreshes again once the cache has been cleared (a simplified sketch of this flow follows below)
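A minimal sketch of the coordination behind these steps, assuming hypothetical names (`MemoryEstimateCache`, `expectedPhase`, `refresh`, `awaitAndClear`) rather than the actual `MlMemoryTracker` fields and methods shown in the diff below:

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Phaser;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * Illustration of the await-and-clear pattern: in-flight refreshes register with a Phaser,
 * the reset path waits for them to drain, clears the cache, then bumps an expected-phase
 * counter so that new refreshes are accepted again.
 */
class MemoryEstimateCache {

    // One party (the cache itself) stays registered for the whole lifetime, so the phaser
    // is never terminated when the last refresh deregisters.
    private final Phaser stopPhaser = new Phaser(1);
    private final AtomicInteger expectedPhase = new AtomicInteger(0);
    private final Map<String, Long> memoryEstimateByJob = new ConcurrentHashMap<>();

    /** Refresh path; returns false if a reset is in progress and the refresh must back off. */
    boolean refresh(String jobId, long estimateBytes) {
        // register() returns the phase this caller joined; a mismatch with the expected phase
        // means a reset has advanced the phaser and we must not touch the cache.
        if (stopPhaser.register() != expectedPhase.get()) {
            stopPhaser.arriveAndDeregister();
            return false;
        }
        try {
            memoryEstimateByJob.put(jobId, estimateBytes);   // stands in for the real refresh work
            return true;
        } finally {
            stopPhaser.arriveAndDeregister();                // lets a waiting reset proceed
        }
    }

    /** Reset path: wait for in-flight refreshes, wipe the cache, then allow refreshes again. */
    void awaitAndClear() {
        // Blocks until every registered refresh has arrived; the phase advances, so refreshes
        // that race with the reset are rejected by the check above.
        int newPhase = stopPhaser.arriveAndAwaitAdvance();
        assert newPhase > 0;
        memoryEstimateByJob.clear();
        expectedPhase.incrementAndGet();   // re-arm: new refreshes are accepted again
    }
}
```

The key property is that a reset waits for refreshes that are already running, while refreshes that race with the reset see a phase mismatch and back off until the cache has been cleared.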
benwtrent authored Jun 24, 2021
1 parent a5af44d commit c37184c
Showing 3 changed files with 110 additions and 20 deletions.
@@ -1312,9 +1312,12 @@ public String getFeatureDescription() {
public void cleanUpFeature(
ClusterService clusterService,
Client client,
ActionListener<ResetFeatureStateResponse.ResetFeatureStateStatus> finalListener) {
ActionListener<ResetFeatureStateResponse.ResetFeatureStateStatus> finalListener
) {
logger.info("Starting machine learning feature reset");

final Map<String, Boolean> results = new ConcurrentHashMap<>();

ActionListener<ResetFeatureStateResponse.ResetFeatureStateStatus> unsetResetModeListener = ActionListener.wrap(
success -> client.execute(SetResetModeAction.INSTANCE, SetResetModeActionRequest.disabled(true), ActionListener.wrap(
resetSuccess -> finalListener.onResponse(success),
@@ -1328,22 +1331,33 @@ public void cleanUpFeature(
);
})
),
failure -> client.execute(SetResetModeAction.INSTANCE, SetResetModeActionRequest.disabled(false), ActionListener.wrap(
resetSuccess -> finalListener.onFailure(failure),
resetFailure -> {
logger.error("failed to disable reset mode after state clean up failure", resetFailure);
finalListener.onFailure(failure);
})
)
failure -> {
logger.error("failed to reset machine learning", failure);
client.execute(SetResetModeAction.INSTANCE, SetResetModeActionRequest.disabled(false), ActionListener.wrap(
resetSuccess -> finalListener.onFailure(failure),
resetFailure -> {
logger.error("failed to disable reset mode after state clean up failure", resetFailure);
finalListener.onFailure(failure);
})
);
}
);

Map<String, Boolean> results = new ConcurrentHashMap<>();

ActionListener<ListTasksResponse> afterWaitingForTasks = ActionListener.wrap(
listTasksResponse -> {
listTasksResponse.rethrowFailures("Waiting for indexing requests for .ml-* indices");
if (results.values().stream().allMatch(b -> b)) {
// Call into the original listener to clean up the indices
if (memoryTracker.get() != null) {
memoryTracker.get().awaitAndClear(ActionListener.wrap(
cacheCleared -> SystemIndexPlugin.super.cleanUpFeature(clusterService, client, unsetResetModeListener),
clearFailed -> {
logger.error("failed to clear memory tracker cache via machine learning reset feature API", clearFailed);
SystemIndexPlugin.super.cleanUpFeature(clusterService, client, unsetResetModeListener);
}
));
return;
}
// Call into the original listener to clean up the indices and then clear ml memory cache
SystemIndexPlugin.super.cleanUpFeature(clusterService, client, unsetResetModeListener);
} else {
final List<String> failedComponents = results.entrySet().stream()
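The `cleanUpFeature` change above chains listeners so that the memory tracker cache is cleared first (continuing even if clearing fails), the `.ml-*` indices are then cleaned up, and reset mode is switched off again before the caller's listener completes, with the original failure propagated if anything went wrong. A rough sketch of that final chaining step, using an invented `Listener` callback type instead of Elasticsearch's `ActionListener` (all names here are illustrative, not the plugin's API):

```java
final class ResetModeChaining {

    interface Listener<T> {
        void onResponse(T result);
        void onFailure(Exception e);
    }

    /** Async stand-in for executing SetResetModeAction with reset mode disabled. */
    interface ResetModeClient {
        void disableResetMode(Listener<Void> listener);
    }

    static <T> Listener<T> wrapWithResetModeCleanup(ResetModeClient client, Listener<T> finalListener) {
        return new Listener<T>() {
            @Override
            public void onResponse(T success) {
                // Clean-up succeeded: disable reset mode, then report the original success.
                client.disableResetMode(new Listener<Void>() {
                    @Override public void onResponse(Void ignored) { finalListener.onResponse(success); }
                    @Override public void onFailure(Exception e) { finalListener.onFailure(e); }
                });
            }

            @Override
            public void onFailure(Exception failure) {
                // Clean-up failed: still try to disable reset mode, but always propagate the
                // original failure; a failure to disable reset mode is only logged in the real code.
                client.disableResetMode(new Listener<Void>() {
                    @Override public void onResponse(Void ignored) { finalListener.onFailure(failure); }
                    @Override public void onFailure(Exception e) { finalListener.onFailure(failure); }
                });
            }
        };
    }
}
```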
@@ -372,6 +372,36 @@ public AutoscalingDeciderResult scale(Settings configuration, AutoscalingDecider
return scaleUpFromZero(waitingAnomalyJobs, waitingAnalyticsJobs, reasonBuilder);
}

// We don't need to check anything as there are no tasks
// This is a quick path to downscale.
// simply return `0` for scale down if delay is satisfied
if (anomalyDetectionTasks.isEmpty() && dataframeAnalyticsTasks.isEmpty()) {
long msLeftToScale = msLeftToDownScale(configuration);
if (msLeftToScale > 0) {
return new AutoscalingDeciderResult(
context.currentCapacity(),
reasonBuilder
.setSimpleReason(
String.format(
Locale.ROOT,
"Passing currently perceived capacity as down scale delay has not been satisfied; configured delay [%s]"
+ "last detected scale down event [%s]. Will request scale down in approximately [%s]",
DOWN_SCALE_DELAY.get(configuration).getStringRep(),
XContentElasticsearchExtension.DEFAULT_DATE_PRINTER.print(scaleDownDetected),
TimeValue.timeValueMillis(msLeftToScale).getStringRep()
)
)
.build());
}
return new AutoscalingDeciderResult(
AutoscalingCapacity.ZERO,
reasonBuilder
.setRequiredCapacity(AutoscalingCapacity.ZERO)
.setSimpleReason("Requesting scale down as tier and/or node size could be smaller")
.build()
);
}

if (mlMemoryTracker.isRecentlyRefreshed(memoryTrackingStale) == false) {
logger.debug(() -> new ParameterizedMessage(
"view of job memory is stale given duration [{}]. Not attempting to make scaling decision",
@@ -521,15 +551,11 @@ public AutoscalingDeciderResult scale(Settings configuration, AutoscalingDecider
}
}

final long now = timeSupplier.get();
if (newScaleDownCheck()) {
scaleDownDetected = now;
}
TimeValue downScaleDelay = DOWN_SCALE_DELAY.get(configuration);
long msLeftToScale = downScaleDelay.millis() - (now - scaleDownDetected);
long msLeftToScale = msLeftToDownScale(configuration);
if (msLeftToScale <= 0) {
return scaleDownDecision.get();
}
TimeValue downScaleDelay = DOWN_SCALE_DELAY.get(configuration);
logger.debug(() -> new ParameterizedMessage(
"not scaling down as the current scale down delay [{}] is not satisfied." +
" The last time scale down was detected [{}]. Calculated scaled down capacity [{}] ",
@@ -542,7 +568,7 @@ public AutoscalingDeciderResult scale(Settings configuration, AutoscalingDecider
.setSimpleReason(
String.format(
Locale.ROOT,
"Passing currently perceived capacity as down scale delay has not be satisfied; configured delay [%s]"
"Passing currently perceived capacity as down scale delay has not been satisfied; configured delay [%s]"
+ "last detected scale down event [%s]. Will request scale down in approximately [%s]",
downScaleDelay.getStringRep(),
XContentElasticsearchExtension.DEFAULT_DATE_PRINTER.print(scaleDownDetected),
@@ -835,6 +861,15 @@ Optional<AutoscalingDeciderResult> checkForScaleDown(List<NodeLoad> nodeLoads,
return Optional.empty();
}

private long msLeftToDownScale(Settings configuration) {
final long now = timeSupplier.get();
if (newScaleDownCheck()) {
scaleDownDetected = now;
}
TimeValue downScaleDelay = DOWN_SCALE_DELAY.get(configuration);
return downScaleDelay.millis() - (now - scaleDownDetected);
}

@Override
public String name() {
return NAME;
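The `msLeftToDownScale` helper factored out above measures how much of the configured `DOWN_SCALE_DELAY` is still outstanding since a scale-down condition was first detected. A small worked example with made-up values (the one-hour delay and the timestamps are illustrative, not defaults read from the code):

```java
// Illustrative arithmetic only; the names follow the diff above.
long downScaleDelayMs = 3_600_000L;            // DOWN_SCALE_DELAY, e.g. configured to 1h
long now = 1_624_500_000_000L;                 // timeSupplier.get()
long scaleDownDetected = now - 2_700_000L;     // scale down first detected 45 minutes ago

long msLeftToScale = downScaleDelayMs - (now - scaleDownDetected);
// msLeftToScale == 900_000 (15 minutes): the delay is not yet satisfied, so the decider keeps the
// current capacity; once this drops to <= 0 the calculated scale-down result is returned instead.
```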
@@ -8,6 +8,7 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.LocalNodeMasterListener;
@@ -42,6 +43,7 @@
import java.util.TreeMap;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Phaser;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;

/**
@@ -72,6 +74,7 @@ public class MlMemoryTracker implements LocalNodeMasterListener {
private final JobResultsProvider jobResultsProvider;
private final DataFrameAnalyticsConfigProvider configProvider;
private final Phaser stopPhaser;
private volatile AtomicInteger phase = new AtomicInteger(0);
private volatile boolean isMaster;
private volatile Instant lastUpdateTime;
private volatile Duration reassignmentRecheckInterval;
@@ -115,6 +118,39 @@ public void onMaster() {
public void offMaster() {
isMaster = false;
logger.trace("ML memory tracker off master");
clear();
}

public void awaitAndClear(ActionListener<Void> listener) {
// We never terminate the phaser
logger.trace("awaiting and clearing memory tracker");
assert stopPhaser.isTerminated() == false;
// If there are no registered parties or no unarrived parties then there is a flaw
// in the register/arrive/unregister logic in another method that uses the phaser
assert stopPhaser.getRegisteredParties() > 0;
assert stopPhaser.getUnarrivedParties() > 0;
threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME).execute(
() -> {
try {
// We await all current refreshes to complete; this increments the "current phase" and prevents
// further interaction while we clear contents
int newPhase = stopPhaser.arriveAndAwaitAdvance();
assert newPhase > 0;
clear();
phase.incrementAndGet();
logger.trace("completed awaiting and clearing memory tracker");
listener.onResponse(null);
} catch (Exception e) {
logger.warn("failed to wait for all refresh requests to complete", e);
listener.onFailure(e);
}
}
);

}

private void clear() {
logger.trace("clearing ML Memory tracker contents");
for (Map<String, Long> memoryRequirementByJob : memoryRequirementByTaskName.values()) {
memoryRequirementByJob.clear();
}
@@ -325,6 +361,7 @@ void refresh(PersistentTasksCustomMetadata persistentTasks, ActionListener<Void>
for (ActionListener<Void> listener : fullRefreshCompletionListeners) {
listener.onFailure(e);
}
logger.warn("ML memory tracker last update failed and listeners called", e);
// It's critical that we empty out the current listener list on
// error otherwise subsequent retries to refresh will be ignored
fullRefreshCompletionListeners.clear();
@@ -401,9 +438,13 @@ public void refreshAnomalyDetectorJobMemory(String jobId, ActionListener<Long> l
}

// The phaser prevents searches being started after the memory tracker's stop() method has returned
if (stopPhaser.register() != 0) {
// Phases above 0 mean we've been stopped, so don't do any operations that involve external interaction
// Note: `phase` is incremented if cache is reset via the feature reset API
if (stopPhaser.register() != phase.get()) {
// Phases not equal to `phase` mean we've been stopped, so don't do any operations that involve external interaction
stopPhaser.arriveAndDeregister();
logger.info(
() -> new ParameterizedMessage("[{}] not refreshing anomaly detector memory as node is shutting down", jobId)
);
listener.onFailure(new EsRejectedExecutionException("Couldn't run ML memory update - node is shutting down"));
return;
}
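For reference, a tiny standalone demonstration of the `java.util.concurrent.Phaser` behaviour that the guard in `refreshAnomalyDetectorJobMemory` relies on: `register()` returns the phase a caller joins, and `arriveAndAwaitAdvance()` waits for every in-flight party and then advances the phase, which is why a refresh that races with a reset sees a number that no longer matches the expected `phase` counter and backs off. This is a simplified example, not the tracker's code:

```java
import java.util.concurrent.Phaser;

public class PhaserPhasesDemo {
    public static void main(String[] args) {
        Phaser stopPhaser = new Phaser(1);                   // one permanent party, never terminated

        int joinedPhase = stopPhaser.register();             // a "refresh" joins and sees the current phase
        System.out.println(joinedPhase);                     // 0 -> matches the expected phase, refresh proceeds
        stopPhaser.arriveAndDeregister();                    // the refresh finishes

        int newPhase = stopPhaser.arriveAndAwaitAdvance();   // the "reset": waits for in-flight refreshes
        System.out.println(newPhase);                        // 1 -> register() now returns 1, so refreshes racing
                                                             // with the reset are rejected until the expected
                                                             // phase counter is bumped to match again
    }
}
```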
