Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] Allow analytics process define its own progress phases #55763

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,11 @@ public String getStateDocId(String jobId) {
return jobId + STATE_DOC_ID_SUFFIX;
}

@Override
public List<String> getProgressPhases() {
return Collections.singletonList("analyzing");
}

public static String extractJobIdFromStateDoc(String stateDocId) {
int suffixIndex = stateDocId.lastIndexOf(STATE_DOC_ID_SUFFIX);
return suffixIndex <= 0 ? null : stateDocId.substring(0, suffixIndex);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@ public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable {
*/
String getStateDocId(String jobId);

/**
* Returns the progress phases the analysis goes through in order
*/
List<String> getProgressPhases();

/**
* Summarizes information about the fields that is necessary for analysis to generate
* the parameters needed for the process configuration.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,11 @@ public String getStateDocId(String jobId) {
throw new UnsupportedOperationException("Outlier detection does not support state");
}

@Override
public List<String> getProgressPhases() {
return Collections.singletonList("analyzing");
}

public enum Method {
LOF, LDOF, DISTANCE_KTH_NN, DISTANCE_KNN;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,11 @@ public String getStateDocId(String jobId) {
return jobId + STATE_DOC_ID_SUFFIX;
}

@Override
public List<String> getProgressPhases() {
return Collections.singletonList("analyzing");
}

public static String extractJobIdFromStateDoc(String stateDocId) {
int suffixIndex = stateDocId.lastIndexOf(STATE_DOC_ID_SUFFIX);
return suffixIndex <= 0 ? null : stateDocId.substring(0, suffixIndex);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState;
import org.elasticsearch.xpack.core.ml.dataframe.stats.AnalysisStats;
import org.elasticsearch.xpack.core.ml.dataframe.stats.Fields;
import org.elasticsearch.xpack.core.ml.dataframe.stats.common.MemoryUsage;
import org.elasticsearch.xpack.core.ml.dataframe.stats.classification.ClassificationStats;
import org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCounts;
import org.elasticsearch.xpack.core.ml.dataframe.stats.common.MemoryUsage;
import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierDetectionStats;
import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStats;
import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
Expand All @@ -55,7 +55,6 @@
import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask;
import org.elasticsearch.xpack.ml.dataframe.StoredProgress;
import org.elasticsearch.xpack.ml.dataframe.stats.ProgressTracker;
import org.elasticsearch.xpack.ml.dataframe.stats.StatsHolder;
import org.elasticsearch.xpack.ml.utils.persistence.MlParserUtils;

import java.util.ArrayList;
Expand Down Expand Up @@ -105,25 +104,20 @@ protected void taskOperation(GetDataFrameAnalyticsStatsAction.Request request, D
ActionListener<QueryPage<Stats>> listener) {
logger.debug("Get stats for running task [{}]", task.getParams().getId());

ActionListener<StatsHolder> statsHolderListener = ActionListener.wrap(
statsHolder -> {
ActionListener<Void> reindexingProgressListener = ActionListener.wrap(
aVoid -> {
Stats stats = buildStats(
task.getParams().getId(),
statsHolder.getProgressTracker().report(),
statsHolder.getDataCountsTracker().report(task.getParams().getId()),
statsHolder.getMemoryUsage(),
statsHolder.getAnalysisStats()
task.getStatsHolder().getProgressTracker().report(),
task.getStatsHolder().getDataCountsTracker().report(task.getParams().getId()),
task.getStatsHolder().getMemoryUsage(),
task.getStatsHolder().getAnalysisStats()
);
listener.onResponse(new QueryPage<>(Collections.singletonList(stats), 1,
GetDataFrameAnalyticsAction.Response.RESULTS_FIELD));
}, listener::onFailure
);

ActionListener<Void> reindexingProgressListener = ActionListener.wrap(
aVoid -> statsHolderListener.onResponse(task.getStatsHolder()),
listener::onFailure
);

task.updateReindexTaskProgress(reindexingProgressListener);
}

Expand All @@ -138,7 +132,7 @@ protected void doExecute(Task task, GetDataFrameAnalyticsStatsAction.Request req
.collect(Collectors.toList());
request.setExpandedIds(expandedIds);
ActionListener<GetDataFrameAnalyticsStatsAction.Response> runningTasksStatsListener = ActionListener.wrap(
runningTasksStatsResponse -> gatherStatsForStoppedTasks(request.getExpandedIds(), runningTasksStatsResponse,
runningTasksStatsResponse -> gatherStatsForStoppedTasks(getResponse.getResources().results(), runningTasksStatsResponse,
ActionListener.wrap(
finalResponse -> {

Expand All @@ -163,20 +157,20 @@ protected void doExecute(Task task, GetDataFrameAnalyticsStatsAction.Request req
executeAsyncWithOrigin(client, ML_ORIGIN, GetDataFrameAnalyticsAction.INSTANCE, getRequest, getResponseListener);
}

void gatherStatsForStoppedTasks(List<String> expandedIds, GetDataFrameAnalyticsStatsAction.Response runningTasksResponse,
void gatherStatsForStoppedTasks(List<DataFrameAnalyticsConfig> configs, GetDataFrameAnalyticsStatsAction.Response runningTasksResponse,
ActionListener<GetDataFrameAnalyticsStatsAction.Response> listener) {
List<String> stoppedTasksIds = determineStoppedTasksIds(expandedIds, runningTasksResponse.getResponse().results());
if (stoppedTasksIds.isEmpty()) {
List<DataFrameAnalyticsConfig> stoppedConfigs = determineStoppedConfigs(configs, runningTasksResponse.getResponse().results());
if (stoppedConfigs.isEmpty()) {
listener.onResponse(runningTasksResponse);
return;
}

AtomicInteger counter = new AtomicInteger(stoppedTasksIds.size());
AtomicArray<Stats> jobStats = new AtomicArray<>(stoppedTasksIds.size());
for (int i = 0; i < stoppedTasksIds.size(); i++) {
AtomicInteger counter = new AtomicInteger(stoppedConfigs.size());
AtomicArray<Stats> jobStats = new AtomicArray<>(stoppedConfigs.size());
for (int i = 0; i < stoppedConfigs.size(); i++) {
final int slot = i;
String jobId = stoppedTasksIds.get(i);
searchStats(jobId, ActionListener.wrap(
DataFrameAnalyticsConfig config = stoppedConfigs.get(i);
searchStats(config, ActionListener.wrap(
stats -> {
jobStats.set(slot, stats);
if (counter.decrementAndGet() == 0) {
Expand All @@ -192,23 +186,24 @@ void gatherStatsForStoppedTasks(List<String> expandedIds, GetDataFrameAnalyticsS
}
}

static List<String> determineStoppedTasksIds(List<String> expandedIds, List<Stats> runningTasksStats) {
static List<DataFrameAnalyticsConfig> determineStoppedConfigs(List<DataFrameAnalyticsConfig> configs, List<Stats> runningTasksStats) {
Set<String> startedTasksIds = runningTasksStats.stream().map(Stats::getId).collect(Collectors.toSet());
return expandedIds.stream().filter(id -> startedTasksIds.contains(id) == false).collect(Collectors.toList());
return configs.stream().filter(config -> startedTasksIds.contains(config.getId()) == false).collect(Collectors.toList());
}

private void searchStats(String configId, ActionListener<Stats> listener) {
logger.debug("[{}] Gathering stats for stopped task", configId);
private void searchStats(DataFrameAnalyticsConfig config, ActionListener<Stats> listener) {
logger.debug("[{}] Gathering stats for stopped task", config.getId());

RetrievedStatsHolder retrievedStatsHolder = new RetrievedStatsHolder();
RetrievedStatsHolder retrievedStatsHolder = new RetrievedStatsHolder(
ProgressTracker.fromZeroes(config.getAnalysis().getProgressPhases()).report());

MultiSearchRequest multiSearchRequest = new MultiSearchRequest();
multiSearchRequest.add(buildStoredProgressSearch(configId));
multiSearchRequest.add(buildStatsDocSearch(configId, DataCounts.TYPE_VALUE));
multiSearchRequest.add(buildStatsDocSearch(configId, MemoryUsage.TYPE_VALUE));
multiSearchRequest.add(buildStatsDocSearch(configId, OutlierDetectionStats.TYPE_VALUE));
multiSearchRequest.add(buildStatsDocSearch(configId, ClassificationStats.TYPE_VALUE));
multiSearchRequest.add(buildStatsDocSearch(configId, RegressionStats.TYPE_VALUE));
multiSearchRequest.add(buildStoredProgressSearch(config.getId()));
multiSearchRequest.add(buildStatsDocSearch(config.getId(), DataCounts.TYPE_VALUE));
multiSearchRequest.add(buildStatsDocSearch(config.getId(), MemoryUsage.TYPE_VALUE));
multiSearchRequest.add(buildStatsDocSearch(config.getId(), OutlierDetectionStats.TYPE_VALUE));
multiSearchRequest.add(buildStatsDocSearch(config.getId(), ClassificationStats.TYPE_VALUE));
multiSearchRequest.add(buildStatsDocSearch(config.getId(), RegressionStats.TYPE_VALUE));

executeAsyncWithOrigin(client, ML_ORIGIN, MultiSearchAction.INSTANCE, multiSearchRequest, ActionListener.wrap(
multiSearchResponse -> {
Expand All @@ -220,7 +215,7 @@ private void searchStats(String configId, ActionListener<Stats> listener) {
logger.error(
new ParameterizedMessage(
"[{}] Item failure encountered during multi search for request [indices={}, source={}]: {}",
configId, itemRequest.indices(), itemRequest.source(), itemResponse.getFailureMessage()),
config.getId(), itemRequest.indices(), itemRequest.source(), itemResponse.getFailureMessage()),
itemResponse.getFailure());
listener.onFailure(ExceptionsHelper.serverError(itemResponse.getFailureMessage(), itemResponse.getFailure()));
return;
Expand All @@ -229,13 +224,13 @@ private void searchStats(String configId, ActionListener<Stats> listener) {
if (hits.length == 0) {
// Not found
} else if (hits.length == 1) {
parseHit(hits[0], configId, retrievedStatsHolder);
parseHit(hits[0], config.getId(), retrievedStatsHolder);
} else {
throw ExceptionsHelper.serverError("Found [" + hits.length + "] hits when just one was requested");
}
}
}
listener.onResponse(buildStats(configId,
listener.onResponse(buildStats(config.getId(),
retrievedStatsHolder.progress.get(),
retrievedStatsHolder.dataCounts,
retrievedStatsHolder.memoryUsage,
Expand Down Expand Up @@ -322,9 +317,13 @@ private GetDataFrameAnalyticsStatsAction.Response.Stats buildStats(String concre

private static class RetrievedStatsHolder {

private volatile StoredProgress progress = new StoredProgress(new ProgressTracker().report());
private volatile StoredProgress progress;
private volatile DataCounts dataCounts;
private volatile MemoryUsage memoryUsage;
private volatile AnalysisStats analysisStats;

private RetrievedStatsHolder(List<PhaseProgress> defaultProgress) {
progress = new StoredProgress(defaultProgress);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,11 @@ public void execute(DataFrameAnalyticsTask task, DataFrameAnalyticsState current
// With config in hand, determine action to take
ActionListener<DataFrameAnalyticsConfig> configListener = ActionListener.wrap(
config -> {
// At this point we have the config at hand and we can reset the progress tracker
// to use the analyses phases. We preserve reindexing progress as if reindexing was
// finished it will not be reset.
task.getStatsHolder().resetProgressTrackerPreservingReindexingProgress(config.getAnalysis().getProgressPhases());

switch(currentState) {
// If we are STARTED, it means the job was started because the start API was called.
// We should determine the job's starting state based on its previous progress.
Expand Down Expand Up @@ -217,7 +222,6 @@ private void reindexDataframeAndStartAnalysis(DataFrameAnalyticsTask task, DataF
return;
}
task.setReindexingTaskId(null);
task.setReindexingFinished();
auditor.info(
config.getId(),
Messages.getMessage(Messages.DATA_FRAME_ANALYTICS_AUDIT_FINISHED_REINDEXING, config.getDest().getIndex(),
Expand Down Expand Up @@ -296,6 +300,7 @@ private void startAnalytics(DataFrameAnalyticsTask task, DataFrameAnalyticsConfi
task.markAsCompleted();
return;
}

final ParentTaskAssigningClient parentTaskClient = new ParentTaskAssigningClient(client, task.getParentTaskId());
// Update state to ANALYZING and start process
ActionListener<DataFrameDataExtractorFactory> dataExtractorFactoryListener = ActionListener.wrap(
Expand Down Expand Up @@ -327,8 +332,8 @@ private void startAnalytics(DataFrameAnalyticsTask task, DataFrameAnalyticsConfi

ActionListener<RefreshResponse> refreshListener = ActionListener.wrap(
refreshResponse -> {
// Ensure we mark reindexing is finished for the case we are recovering a task that had finished reindexing
task.setReindexingFinished();
// Now we can ensure reindexing progress is complete
task.getStatsHolder().getProgressTracker().updateReindexingProgress(100);

// TODO This could fail with errors. In that case we get stuck with the copied index.
// We could delete the index in case of failure or we could try building the factory before reindexing
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,9 @@ public class DataFrameAnalyticsTask extends AllocatedPersistentTask implements S
private final StartDataFrameAnalyticsAction.TaskParams taskParams;
@Nullable
private volatile Long reindexingTaskId;
private volatile boolean isReindexingFinished;
private volatile boolean isStopping;
private volatile boolean isMarkAsCompletedCalled;
private final StatsHolder statsHolder = new StatsHolder();
private final StatsHolder statsHolder;

public DataFrameAnalyticsTask(long id, String type, String action, TaskId parentTask, Map<String, String> headers,
Client client, ClusterService clusterService, DataFrameAnalyticsManager analyticsManager,
Expand All @@ -81,6 +80,7 @@ public DataFrameAnalyticsTask(long id, String type, String action, TaskId parent
this.analyticsManager = Objects.requireNonNull(analyticsManager);
this.auditor = Objects.requireNonNull(auditor);
this.taskParams = Objects.requireNonNull(taskParams);
this.statsHolder = new StatsHolder(taskParams.getProgressOnStart());
}

public StartDataFrameAnalyticsAction.TaskParams getParams() {
Expand All @@ -92,10 +92,6 @@ public void setReindexingTaskId(Long reindexingTaskId) {
this.reindexingTaskId = reindexingTaskId;
}

public void setReindexingFinished() {
isReindexingFinished = true;
}

public boolean isStopping() {
return isStopping;
}
Expand Down Expand Up @@ -222,7 +218,7 @@ public void updateReindexTaskProgress(ActionListener<Void> listener) {
// We set reindexing progress at least to 1 for a running process to be able to
// distinguish a job that is running for the first time against a job that is restarting.
reindexTaskProgress -> {
statsHolder.getProgressTracker().reindexingPercent.set(Math.max(1, reindexTaskProgress));
statsHolder.getProgressTracker().updateReindexingProgress(Math.max(1, reindexTaskProgress));
listener.onResponse(null);
},
listener::onFailure
Expand All @@ -232,9 +228,7 @@ public void updateReindexTaskProgress(ActionListener<Void> listener) {
private void getReindexTaskProgress(ActionListener<Integer> listener) {
TaskId reindexTaskId = getReindexTaskId();
if (reindexTaskId == null) {
// The task is not present which means either it has not started yet or it finished.
// We keep track of whether the task has finished so we can use that to tell whether the progress 100.
listener.onResponse(isReindexingFinished ? 100 : 0);
listener.onResponse(statsHolder.getProgressTracker().getReindexingProgressPercent());
return;
}

Expand All @@ -250,8 +244,7 @@ private void getReindexTaskProgress(ActionListener<Integer> listener) {
error -> {
if (ExceptionsHelper.unwrapCause(error) instanceof ResourceNotFoundException) {
// The task is not present which means either it has not started yet or it finished.
// We keep track of whether the task has finished so we can use that to tell whether the progress 100.
listener.onResponse(isReindexingFinished ? 100 : 0);
listener.onResponse(statsHolder.getProgressTracker().getReindexingProgressPercent());
} else {
listener.onFailure(error);
}
Expand Down Expand Up @@ -366,17 +359,10 @@ public static StartingState determineStartingState(String jobId, List<PhaseProgr
LOGGER.debug("[{}] Last incomplete progress [{}, {}]", jobId, lastIncompletePhase.getPhase(),
lastIncompletePhase.getProgressPercent());

switch (lastIncompletePhase.getPhase()) {
case ProgressTracker.REINDEXING:
return lastIncompletePhase.getProgressPercent() == 0 ? StartingState.FIRST_TIME : StartingState.RESUMING_REINDEXING;
case ProgressTracker.LOADING_DATA:
case ProgressTracker.ANALYZING:
case ProgressTracker.WRITING_RESULTS:
return StartingState.RESUMING_ANALYZING;
default:
LOGGER.warn("[{}] Unexpected progress phase [{}]", jobId, lastIncompletePhase.getPhase());
return StartingState.FIRST_TIME;
if (ProgressTracker.REINDEXING.equals(lastIncompletePhase.getPhase())) {
return lastIncompletePhase.getProgressPercent() == 0 ? StartingState.FIRST_TIME : StartingState.RESUMING_REINDEXING;
}
return StartingState.RESUMING_ANALYZING;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ private void writeDataRows(DataFrameDataExtractor dataExtractor, AnalyticsProces
}
}
rowsProcessed += rows.get().size();
progressTracker.loadingDataPercent.set(rowsProcessed >= totalRows ? 100 : (int) (rowsProcessed * 100.0 / totalRows));
progressTracker.updateLoadingDataProgress(rowsProcessed >= totalRows ? 100 : (int) (rowsProcessed * 100.0 / totalRows));
}
}
}
Expand Down
Loading