From 12162f595ec6cdf96b1ddfac3ddc5b0906f54a94 Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Fri, 23 Aug 2019 17:31:36 +0300 Subject: [PATCH] [7.x][ML] Improve progress reporting for DF analytics (#45856) Previously, the stats API reported a progress percentage for DF analytics tasks that are running and are in the `reindexing` or `analyzing` state. This means that when the task is `stopped` there is no progress reported. Thus, one cannot distinguish between a task that never ran and one that completed. In addition, there are blind spots in the progress reporting. In particular, we do not account for when data is loaded into the process. We also do not account for when results are written. This commit addresses the above issues. It changes progress to be a list of objects, each one describing the phase and its progress as a percentage. We currently have 4 phases: reindexing, loading_data, analyzing, writing_results. When the task stops, progress is persisted as a document in the state index. The stats API now reports progress from in-memory if the task is running, or returns the persisted document (if there is one). 
--- .../ml/dataframe/DataFrameAnalyticsStats.java | 24 +-- .../client/ml/dataframe/PhaseProgress.java | 91 ++++++++++ .../client/MachineLearningIT.java | 9 +- .../DataFrameAnalyticsStatsTests.java | 17 +- .../ml/dataframe/PhaseProgressTests.java | 46 +++++ .../apis/get-dfanalytics-stats.asciidoc | 20 ++- .../GetDataFrameAnalyticsStatsAction.java | 94 +++++++++- .../xpack/core/ml/utils/PhaseProgress.java | 83 +++++++++ ...rameAnalyticsStatsActionResponseTests.java | 9 +- .../core/ml/utils/PhaseProgressTests.java | 34 ++++ ...NativeDataFrameAnalyticsIntegTestCase.java | 39 ++++- .../OutlierDetectionWithMissingFieldsIT.java | 4 + .../integration/RunDataFrameAnalyticsIT.java | 24 +++ ...ansportDeleteDataFrameAnalyticsAction.java | 73 +++++++- ...sportGetDataFrameAnalyticsStatsAction.java | 147 ++++++++++------ ...ransportStartDataFrameAnalyticsAction.java | 161 +++++++++++++++++- .../dataframe/DataFrameAnalyticsManager.java | 17 +- .../xpack/ml/dataframe/StoredProgress.java | 60 +++++++ .../dataframe/process/AnalyticsProcess.java | 6 + .../process/AnalyticsProcessConfig.java | 4 + .../process/AnalyticsProcessManager.java | 28 ++- .../process/AnalyticsResultProcessor.java | 23 ++- .../process/NativeAnalyticsProcess.java | 39 ++++- .../NativeAnalyticsProcessFactory.java | 2 +- .../NativeMemoryUsageEstimationProcess.java | 5 + .../ml/dataframe/StoredProgressTests.java | 37 ++++ .../AnalyticsResultProcessorTests.java | 22 ++- 27 files changed, 989 insertions(+), 129 deletions(-) create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/PhaseProgress.java create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/PhaseProgressTests.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/PhaseProgress.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/utils/PhaseProgressTests.java create mode 100644 
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/StoredProgress.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/StoredProgressTests.java diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStats.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStats.java index 4e04204e65021..bfef47727f631 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStats.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStats.java @@ -28,6 +28,7 @@ import org.elasticsearch.common.xcontent.XContentParser; import java.io.IOException; +import java.util.List; import java.util.Objects; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; @@ -42,17 +43,18 @@ public static DataFrameAnalyticsStats fromXContent(XContentParser parser) throws static final ParseField ID = new ParseField("id"); static final ParseField STATE = new ParseField("state"); static final ParseField FAILURE_REASON = new ParseField("failure_reason"); - static final ParseField PROGRESS_PERCENT = new ParseField("progress_percent"); + static final ParseField PROGRESS = new ParseField("progress"); static final ParseField NODE = new ParseField("node"); static final ParseField ASSIGNMENT_EXPLANATION = new ParseField("assignment_explanation"); + @SuppressWarnings("unchecked") private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>("data_frame_analytics_stats", true, args -> new DataFrameAnalyticsStats( (String) args[0], (DataFrameAnalyticsState) args[1], (String) args[2], - (Integer) args[3], + (List) args[3], (NodeAttributes) args[4], (String) args[5])); @@ -65,7 +67,7 @@ public static DataFrameAnalyticsStats fromXContent(XContentParser parser) throws throw new IllegalArgumentException("Unsupported 
token [" + p.currentToken() + "]"); }, STATE, ObjectParser.ValueType.STRING); PARSER.declareString(optionalConstructorArg(), FAILURE_REASON); - PARSER.declareInt(optionalConstructorArg(), PROGRESS_PERCENT); + PARSER.declareObjectArray(optionalConstructorArg(), PhaseProgress.PARSER, PROGRESS); PARSER.declareObject(optionalConstructorArg(), NodeAttributes.PARSER, NODE); PARSER.declareString(optionalConstructorArg(), ASSIGNMENT_EXPLANATION); } @@ -73,17 +75,17 @@ public static DataFrameAnalyticsStats fromXContent(XContentParser parser) throws private final String id; private final DataFrameAnalyticsState state; private final String failureReason; - private final Integer progressPercent; + private final List progress; private final NodeAttributes node; private final String assignmentExplanation; public DataFrameAnalyticsStats(String id, DataFrameAnalyticsState state, @Nullable String failureReason, - @Nullable Integer progressPercent, @Nullable NodeAttributes node, + @Nullable List progress, @Nullable NodeAttributes node, @Nullable String assignmentExplanation) { this.id = id; this.state = state; this.failureReason = failureReason; - this.progressPercent = progressPercent; + this.progress = progress; this.node = node; this.assignmentExplanation = assignmentExplanation; } @@ -100,8 +102,8 @@ public String getFailureReason() { return failureReason; } - public Integer getProgressPercent() { - return progressPercent; + public List getProgress() { + return progress; } public NodeAttributes getNode() { @@ -121,14 +123,14 @@ public boolean equals(Object o) { return Objects.equals(id, other.id) && Objects.equals(state, other.state) && Objects.equals(failureReason, other.failureReason) - && Objects.equals(progressPercent, other.progressPercent) + && Objects.equals(progress, other.progress) && Objects.equals(node, other.node) && Objects.equals(assignmentExplanation, other.assignmentExplanation); } @Override public int hashCode() { - return Objects.hash(id, state, failureReason, 
progressPercent, node, assignmentExplanation); + return Objects.hash(id, state, failureReason, progress, node, assignmentExplanation); } @Override @@ -137,7 +139,7 @@ public String toString() { .add("id", id) .add("state", state) .add("failureReason", failureReason) - .add("progressPercent", progressPercent) + .add("progress", progress) .add("node", node) .add("assignmentExplanation", assignmentExplanation) .toString(); diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/PhaseProgress.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/PhaseProgress.java new file mode 100644 index 0000000000000..21842efc7dfe6 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/PhaseProgress.java @@ -0,0 +1,91 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.elasticsearch.client.ml.dataframe; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.inject.internal.ToStringBuilder; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Objects; + +/** + * A class that describes a phase and its progress as a percentage + */ +public class PhaseProgress implements ToXContentObject { + + static final ParseField PHASE = new ParseField("phase"); + static final ParseField PROGRESS_PERCENT = new ParseField("progress_percent"); + + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>("phase_progress", + true, a -> new PhaseProgress((String) a[0], (int) a[1])); + + static { + PARSER.declareString(ConstructingObjectParser.constructorArg(), PHASE); + PARSER.declareInt(ConstructingObjectParser.constructorArg(), PROGRESS_PERCENT); + } + + private final String phase; + private final int progressPercent; + + public PhaseProgress(String phase, int progressPercent) { + this.phase = Objects.requireNonNull(phase); + this.progressPercent = progressPercent; + } + + public String getPhase() { + return phase; + } + + public int getProgressPercent() { + return progressPercent; + } + + @Override + public int hashCode() { + return Objects.hash(phase, progressPercent); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + PhaseProgress that = (PhaseProgress) o; + return Objects.equals(phase, that.phase) && progressPercent == that.progressPercent; + } + + @Override + public String toString() { + return new ToStringBuilder(getClass()) + .add(PHASE.getPreferredName(), phase) + .add(PROGRESS_PERCENT.getPreferredName(), progressPercent) + .toString(); + } + + @Override + public XContentBuilder 
toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(PhaseProgress.PHASE.getPreferredName(), phase); + builder.field(PhaseProgress.PROGRESS_PERCENT.getPreferredName(), progressPercent); + builder.endObject(); + return builder; + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java index dd374dc52568c..d8398f2895f98 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java @@ -123,6 +123,7 @@ import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsState; import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsStats; import org.elasticsearch.client.ml.dataframe.OutlierDetection; +import org.elasticsearch.client.ml.dataframe.PhaseProgress; import org.elasticsearch.client.ml.dataframe.QueryConfig; import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric; @@ -1405,11 +1406,17 @@ public void testGetDataFrameAnalyticsStats() throws Exception { assertThat(stats.getId(), equalTo(configId)); assertThat(stats.getState(), equalTo(DataFrameAnalyticsState.STOPPED)); assertNull(stats.getFailureReason()); - assertNull(stats.getProgressPercent()); assertNull(stats.getNode()); assertNull(stats.getAssignmentExplanation()); assertThat(statsResponse.getNodeFailures(), hasSize(0)); assertThat(statsResponse.getTaskFailures(), hasSize(0)); + List progress = stats.getProgress(); + assertThat(progress, is(notNullValue())); + assertThat(progress.size(), equalTo(4)); + assertThat(progress.get(0), equalTo(new PhaseProgress("reindexing", 0))); + assertThat(progress.get(1), equalTo(new PhaseProgress("loading_data", 0))); + 
assertThat(progress.get(2), equalTo(new PhaseProgress("analyzing", 0))); + assertThat(progress.get(3), equalTo(new PhaseProgress("writing_results", 0))); } public void testStartDataFrameAnalyticsConfig() throws Exception { diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStatsTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStatsTests.java index fad02eac161c7..f8eddd36bc6d9 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStatsTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStatsTests.java @@ -24,6 +24,8 @@ import org.elasticsearch.test.ESTestCase; import java.io.IOException; +import java.util.ArrayList; +import java.util.List; import static org.elasticsearch.test.AbstractXContentTestCase.xContentTester; @@ -44,11 +46,20 @@ public static DataFrameAnalyticsStats randomDataFrameAnalyticsStats() { randomAlphaOfLengthBetween(1, 10), randomFrom(DataFrameAnalyticsState.values()), randomBoolean() ? null : randomAlphaOfLength(10), - randomBoolean() ? null : randomIntBetween(0, 100), + randomBoolean() ? null : createRandomProgress(), randomBoolean() ? null : NodeAttributesTests.createRandom(), randomBoolean() ? 
null : randomAlphaOfLengthBetween(1, 20)); } + private static List createRandomProgress() { + int progressPhaseCount = randomIntBetween(3, 7); + List progress = new ArrayList<>(progressPhaseCount); + for (int i = 0; i < progressPhaseCount; i++) { + progress.add(new PhaseProgress(randomAlphaOfLength(20), randomIntBetween(0, 100))); + } + return progress; + } + public static void toXContent(DataFrameAnalyticsStats stats, XContentBuilder builder) throws IOException { builder.startObject(); builder.field(DataFrameAnalyticsStats.ID.getPreferredName(), stats.getId()); @@ -56,8 +67,8 @@ public static void toXContent(DataFrameAnalyticsStats stats, XContentBuilder bui if (stats.getFailureReason() != null) { builder.field(DataFrameAnalyticsStats.FAILURE_REASON.getPreferredName(), stats.getFailureReason()); } - if (stats.getProgressPercent() != null) { - builder.field(DataFrameAnalyticsStats.PROGRESS_PERCENT.getPreferredName(), stats.getProgressPercent()); + if (stats.getProgress() != null) { + builder.field(DataFrameAnalyticsStats.PROGRESS.getPreferredName(), stats.getProgress()); } if (stats.getNode() != null) { builder.field(DataFrameAnalyticsStats.NODE.getPreferredName(), stats.getNode()); diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/PhaseProgressTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/PhaseProgressTests.java new file mode 100644 index 0000000000000..0281285112aa1 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/PhaseProgressTests.java @@ -0,0 +1,46 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe; + +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; + +public class PhaseProgressTests extends AbstractXContentTestCase { + + public static PhaseProgress createRandom() { + return new PhaseProgress(randomAlphaOfLength(20), randomIntBetween(0, 100)); + } + + @Override + protected PhaseProgress createTestInstance() { + return createRandom(); + } + + @Override + protected PhaseProgress doParseInstance(XContentParser parser) throws IOException { + return PhaseProgress.PARSER.apply(parser, null); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } +} diff --git a/docs/reference/ml/df-analytics/apis/get-dfanalytics-stats.asciidoc b/docs/reference/ml/df-analytics/apis/get-dfanalytics-stats.asciidoc index 018d53a2c5e89..b1a8d4c194b64 100644 --- a/docs/reference/ml/df-analytics/apis/get-dfanalytics-stats.asciidoc +++ b/docs/reference/ml/df-analytics/apis/get-dfanalytics-stats.asciidoc @@ -99,7 +99,25 @@ The API returns the following results: "data_frame_analytics": [ { "id": "loganalytics", - "state": "stopped" + "state": "stopped", + "progress": [ + { + "phase": "reindexing", + "progress_percent": 0 + }, + { + "phase": "loading_data", + "progress_percent": 0 + }, + { + "phase": "analyzing", + "progress_percent": 0 + }, + { + "phase": "writing_results", + "progress_percent": 0 + } + ] } ] } diff --git 
a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsAction.java index fb67cb0f965b0..6712c1f8ecf23 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsAction.java @@ -6,6 +6,7 @@ package org.elasticsearch.xpack.core.ml.action; import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.Version; import org.elasticsearch.action.ActionRequestBuilder; import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.ActionType; @@ -28,8 +29,10 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.core.ml.utils.PhaseProgress; import java.io.IOException; +import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Map; @@ -154,19 +157,23 @@ public static class Stats implements ToXContentObject, Writeable { private final DataFrameAnalyticsState state; @Nullable private final String failureReason; - @Nullable - private final Integer progressPercentage; + + /** + * The progress is described as a list of each phase and its completeness percentage. 
+ */ + private final List progress; + @Nullable private final DiscoveryNode node; @Nullable private final String assignmentExplanation; - public Stats(String id, DataFrameAnalyticsState state, @Nullable String failureReason, @Nullable Integer progressPercentage, + public Stats(String id, DataFrameAnalyticsState state, @Nullable String failureReason, List progress, @Nullable DiscoveryNode node, @Nullable String assignmentExplanation) { this.id = Objects.requireNonNull(id); this.state = Objects.requireNonNull(state); this.failureReason = failureReason; - this.progressPercentage = progressPercentage; + this.progress = Objects.requireNonNull(progress); this.node = node; this.assignmentExplanation = assignmentExplanation; } @@ -175,11 +182,47 @@ public Stats(StreamInput in) throws IOException { id = in.readString(); state = DataFrameAnalyticsState.fromStream(in); failureReason = in.readOptionalString(); - progressPercentage = in.readOptionalInt(); + if (in.getVersion().before(Version.V_7_4_0)) { + progress = readProgressFromLegacy(state, in); + } else { + progress = in.readList(PhaseProgress::new); + } node = in.readOptionalWriteable(DiscoveryNode::new); assignmentExplanation = in.readOptionalString(); } + private static List readProgressFromLegacy(DataFrameAnalyticsState state, StreamInput in) throws IOException { + Integer legacyProgressPercent = in.readOptionalInt(); + if (legacyProgressPercent == null) { + return Collections.emptyList(); + } + + int reindexingProgress = 0; + int loadingDataProgress = 0; + int analyzingProgress = 0; + switch (state) { + case ANALYZING: + reindexingProgress = 100; + loadingDataProgress = 100; + analyzingProgress = legacyProgressPercent; + break; + case REINDEXING: + reindexingProgress = legacyProgressPercent; + break; + case STARTED: + case STOPPED: + case STOPPING: + default: + return null; + } + + return Arrays.asList( + new PhaseProgress("reindexing", reindexingProgress), + new PhaseProgress("loading_data", loadingDataProgress), + 
new PhaseProgress("analyzing", analyzingProgress), + new PhaseProgress("writing_results", 0)); + } + public String getId() { return id; } @@ -188,6 +231,10 @@ public DataFrameAnalyticsState getState() { return state; } + public List getProgress() { + return progress; + } + @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { // TODO: Have callers wrap the content with an object as they choose rather than forcing it upon them @@ -204,8 +251,8 @@ public XContentBuilder toUnwrappedXContent(XContentBuilder builder) throws IOExc if (failureReason != null) { builder.field("failure_reason", failureReason); } - if (progressPercentage != null) { - builder.field("progress_percent", progressPercentage); + if (progress != null) { + builder.field("progress", progress); } if (node != null) { builder.startObject("node"); @@ -232,14 +279,43 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(id); state.writeTo(out); out.writeOptionalString(failureReason); - out.writeOptionalInt(progressPercentage); + if (out.getVersion().before(Version.V_7_4_0)) { + writeProgressToLegacy(out); + } else { + out.writeList(progress); + } out.writeOptionalWriteable(node); out.writeOptionalString(assignmentExplanation); } + private void writeProgressToLegacy(StreamOutput out) throws IOException { + String targetPhase = null; + switch (state) { + case ANALYZING: + targetPhase = "analyzing"; + break; + case REINDEXING: + targetPhase = "reindexing"; + break; + case STARTED: + case STOPPED: + case STOPPING: + default: + break; + } + + Integer legacyProgressPercent = null; + for (PhaseProgress phaseProgress : progress) { + if (phaseProgress.getPhase().equals(targetPhase)) { + legacyProgressPercent = phaseProgress.getProgressPercent(); + } + } + out.writeOptionalInt(legacyProgressPercent); + } + @Override public int hashCode() { - return Objects.hash(id, state, failureReason, progressPercentage, node, assignmentExplanation); + 
return Objects.hash(id, state, failureReason, progress, node, assignmentExplanation); } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/PhaseProgress.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/PhaseProgress.java new file mode 100644 index 0000000000000..0f9617bceb10e --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/PhaseProgress.java @@ -0,0 +1,83 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.utils; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Objects; + +/** + * A class that describes a phase and its progress as a percentage + */ +public class PhaseProgress implements ToXContentObject, Writeable { + + public static final ParseField PHASE = new ParseField("phase"); + public static final ParseField PROGRESS_PERCENT = new ParseField("progress_percent"); + + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>("phase_progress", + true, a -> new PhaseProgress((String) a[0], (int) a[1])); + + static { + PARSER.declareString(ConstructingObjectParser.constructorArg(), PHASE); + PARSER.declareInt(ConstructingObjectParser.constructorArg(), PROGRESS_PERCENT); + } + + private final String phase; + private final int progressPercent; + + public PhaseProgress(String phase, int 
progressPercent) { + this.phase = Objects.requireNonNull(phase); + this.progressPercent = progressPercent; + } + + public PhaseProgress(StreamInput in) throws IOException { + phase = in.readString(); + progressPercent = in.readVInt(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(phase); + out.writeVInt(progressPercent); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(PHASE.getPreferredName(), phase); + builder.field(PROGRESS_PERCENT.getPreferredName(), progressPercent); + builder.endObject(); + return builder; + } + + public String getPhase() { + return phase; + } + + public int getProgressPercent() { + return progressPercent; + } + + @Override + public int hashCode() { + return Objects.hash(phase, progressPercent); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + PhaseProgress that = (PhaseProgress) o; + return Objects.equals(phase, that.phase) && progressPercent == that.progressPercent; + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsActionResponseTests.java index 5a88f2ea52eab..8ada940252139 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsActionResponseTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsActionResponseTests.java @@ -11,9 +11,11 @@ import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction.Response; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfigTests; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; +import 
org.elasticsearch.xpack.core.ml.utils.PhaseProgress; import java.util.ArrayList; import java.util.List; +import java.util.stream.IntStream; public class GetDataFrameAnalyticsStatsActionResponseTests extends AbstractWireSerializingTestCase { @@ -22,10 +24,13 @@ protected Response createTestInstance() { int listSize = randomInt(10); List analytics = new ArrayList<>(listSize); for (int j = 0; j < listSize; j++) { - Integer progressPercentage = randomBoolean() ? null : randomIntBetween(0, 100); String failureReason = randomBoolean() ? null : randomAlphaOfLength(10); + int progressSize = randomIntBetween(2, 5); + List progress = new ArrayList<>(progressSize); + IntStream.of(progressSize).forEach(progressIndex -> progress.add( + new PhaseProgress(randomAlphaOfLength(10), randomIntBetween(0, 100)))); Response.Stats stats = new Response.Stats(DataFrameAnalyticsConfigTests.randomValidId(), - randomFrom(DataFrameAnalyticsState.values()), failureReason, progressPercentage, null, randomAlphaOfLength(20)); + randomFrom(DataFrameAnalyticsState.values()), failureReason, progress, null, randomAlphaOfLength(20)); analytics.add(stats); } return new Response(new QueryPage<>(analytics, analytics.size(), GetDataFrameAnalyticsAction.Response.RESULTS_FIELD)); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/utils/PhaseProgressTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/utils/PhaseProgressTests.java new file mode 100644 index 0000000000000..71834329342fd --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/utils/PhaseProgressTests.java @@ -0,0 +1,34 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.utils; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractSerializingTestCase; + +import java.io.IOException; + +public class PhaseProgressTests extends AbstractSerializingTestCase { + + @Override + protected PhaseProgress createTestInstance() { + return createRandom(); + } + + public static PhaseProgress createRandom() { + return new PhaseProgress(randomAlphaOfLength(10), randomIntBetween(0, 100)); + } + + @Override + protected PhaseProgress doParseInstance(XContentParser parser) throws IOException { + return PhaseProgress.PARSER.apply(parser, null); + } + + @Override + protected Writeable.Reader instanceReader() { + return PhaseProgress::new; + } +} diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java index 24045c1549151..333811bcdb711 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java @@ -5,11 +5,11 @@ */ package org.elasticsearch.xpack.ml.integration; +import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.support.master.AcknowledgedResponse; import org.elasticsearch.common.Nullable; -import org.elasticsearch.common.Strings; import org.elasticsearch.common.unit.TimeValue; -import org.elasticsearch.common.xcontent.json.JsonXContent; +import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.xpack.core.ml.action.DeleteDataFrameAnalyticsAction; import 
org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction; @@ -22,14 +22,16 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; +import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex; +import org.elasticsearch.xpack.core.ml.utils.PhaseProgress; +import org.elasticsearch.xpack.ml.action.TransportStartDataFrameAnalyticsAction; -import java.io.IOException; import java.util.ArrayList; import java.util.List; -import java.util.Map; import java.util.concurrent.TimeUnit; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; /** * Base class of ML integration tests that use a native data_frame_analytics process @@ -46,7 +48,8 @@ protected void cleanUpResources() { private void cleanUpAnalytics() { for (DataFrameAnalyticsConfig config : analytics) { try { - deleteAnalytics(config.getId()); + assertThat(deleteAnalytics(config.getId()).isAcknowledged(), is(true)); + assertThat(searchStoredProgress(config.getId()).getHits().getTotalHits().value, equalTo(0L)); } catch (Exception e) { // ignore } @@ -100,10 +103,6 @@ protected List getAnalyticsStat return response.getResponse().results(); } - protected static String createJsonRecord(Map keyValueMap) throws IOException { - return Strings.toString(JsonXContent.contentBuilder().map(keyValueMap)) + "\n"; - } - protected static DataFrameAnalyticsConfig buildOutlierDetectionAnalytics(String id, String[] sourceIndex, String destIndex, @Nullable String resultsField) { DataFrameAnalyticsConfig.Builder configBuilder = new DataFrameAnalyticsConfig.Builder(); @@ -121,6 +120,28 @@ protected void assertState(String id, DataFrameAnalyticsState state) { assertThat(stats.get(0).getState(), equalTo(state)); } + protected void 
assertProgress(String id, int reindexing, int loadingData, int analyzing, int writingResults) { + List stats = getAnalyticsStats(id); + List progress = stats.get(0).getProgress(); + assertThat(stats.size(), equalTo(1)); + assertThat(stats.get(0).getId(), equalTo(id)); + assertThat(progress.size(), equalTo(4)); + assertThat(progress.get(0).getPhase(), equalTo("reindexing")); + assertThat(progress.get(1).getPhase(), equalTo("loading_data")); + assertThat(progress.get(2).getPhase(), equalTo("analyzing")); + assertThat(progress.get(3).getPhase(), equalTo("writing_results")); + assertThat(progress.get(0).getProgressPercent(), equalTo(reindexing)); + assertThat(progress.get(1).getProgressPercent(), equalTo(loadingData)); + assertThat(progress.get(2).getProgressPercent(), equalTo(analyzing)); + assertThat(progress.get(3).getProgressPercent(), equalTo(writingResults)); + } + + protected SearchResponse searchStoredProgress(String id) { + return client().prepareSearch(AnomalyDetectorsIndex.jobStateIndexPattern()) + .setQuery(QueryBuilders.idsQuery().addIds(TransportStartDataFrameAnalyticsAction.DataFrameAnalyticsTask.progressDocId(id))) + .get(); + } + protected static DataFrameAnalyticsConfig buildRegressionAnalytics(String id, String[] sourceIndex, String destIndex, @Nullable String resultsField, String dependentVariable) { DataFrameAnalyticsConfig.Builder configBuilder = new DataFrameAnalyticsConfig.Builder(); diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/OutlierDetectionWithMissingFieldsIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/OutlierDetectionWithMissingFieldsIT.java index 79f3af3164a94..89782f18c0c72 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/OutlierDetectionWithMissingFieldsIT.java +++ 
b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/OutlierDetectionWithMissingFieldsIT.java @@ -73,6 +73,7 @@ public void testMissingFields() throws Exception { putAnalytics(config); assertState(id, DataFrameAnalyticsState.STOPPED); + assertProgress(id, 0, 0, 0, 0); startAnalytics(id); waitUntilAnalyticsIsStopped(id); @@ -99,5 +100,8 @@ public void testMissingFields() throws Exception { assertThat(destDoc.containsKey("ml"), is(false)); } } + + assertProgress(id, 100, 100, 100, 100); + assertThat(searchStoredProgress(id).getHits().getTotalHits().value, equalTo(1L)); } } diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RunDataFrameAnalyticsIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RunDataFrameAnalyticsIT.java index eb99135b418e5..3dfa83470f507 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RunDataFrameAnalyticsIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RunDataFrameAnalyticsIT.java @@ -78,6 +78,7 @@ public void testOutlierDetectionWithFewDocuments() throws Exception { putAnalytics(config); assertState(id, DataFrameAnalyticsState.STOPPED); + assertProgress(id, 0, 0, 0, 0); startAnalytics(id); waitUntilAnalyticsIsStopped(id); @@ -113,6 +114,9 @@ public void testOutlierDetectionWithFewDocuments() throws Exception { } } assertThat(scoreOfOutlier, is(greaterThan(scoreOfNonOutlier))); + + assertProgress(id, 100, 100, 100, 100); + assertThat(searchStoredProgress(id).getHits().getTotalHits().value, equalTo(1L)); } public void testOutlierDetectionWithEnoughDocumentsToScroll() throws Exception { @@ -143,6 +147,7 @@ public void testOutlierDetectionWithEnoughDocumentsToScroll() throws Exception { putAnalytics(config); assertState(id, DataFrameAnalyticsState.STOPPED); + 
assertProgress(id, 0, 0, 0, 0); startAnalytics(id); waitUntilAnalyticsIsStopped(id); @@ -156,6 +161,9 @@ public void testOutlierDetectionWithEnoughDocumentsToScroll() throws Exception { .setTrackTotalHits(true) .setQuery(QueryBuilders.existsQuery("custom_ml.outlier_score")).get(); assertThat(searchResponse.getHits().getTotalHits().value, equalTo((long) docCount)); + + assertProgress(id, 100, 100, 100, 100); + assertThat(searchStoredProgress(id).getHits().getTotalHits().value, equalTo(1L)); } public void testOutlierDetectionWithMoreFieldsThanDocValueFieldLimit() throws Exception { @@ -201,6 +209,7 @@ public void testOutlierDetectionWithMoreFieldsThanDocValueFieldLimit() throws Ex putAnalytics(config); assertState(id, DataFrameAnalyticsState.STOPPED); + assertProgress(id, 0, 0, 0, 0); startAnalytics(id); waitUntilAnalyticsIsStopped(id); @@ -224,6 +233,9 @@ public void testOutlierDetectionWithMoreFieldsThanDocValueFieldLimit() throws Ex double outlierScore = (double) resultsObject.get("outlier_score"); assertThat(outlierScore, allOf(greaterThanOrEqualTo(0.0), lessThanOrEqualTo(1.0))); } + + assertProgress(id, 100, 100, 100, 100); + assertThat(searchStoredProgress(id).getHits().getTotalHits().value, equalTo(1L)); } @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/43960") @@ -312,6 +324,7 @@ public void testOutlierDetectionWithMultipleSourceIndices() throws Exception { putAnalytics(config); assertState(id, DataFrameAnalyticsState.STOPPED); + assertProgress(id, 0, 0, 0, 0); startAnalytics(id); waitUntilAnalyticsIsStopped(id); @@ -325,6 +338,9 @@ public void testOutlierDetectionWithMultipleSourceIndices() throws Exception { .setTrackTotalHits(true) .setQuery(QueryBuilders.existsQuery("ml.outlier_score")).get(); assertThat(searchResponse.getHits().getTotalHits().value, equalTo((long) bulkRequestBuilder.numberOfActions())); + + assertProgress(id, 100, 100, 100, 100); + assertThat(searchStoredProgress(id).getHits().getTotalHits().value, equalTo(1L)); } 
public void testOutlierDetectionWithPreExistingDestIndex() throws Exception { @@ -358,6 +374,7 @@ public void testOutlierDetectionWithPreExistingDestIndex() throws Exception { putAnalytics(config); assertState(id, DataFrameAnalyticsState.STOPPED); + assertProgress(id, 0, 0, 0, 0); startAnalytics(id); waitUntilAnalyticsIsStopped(id); @@ -371,6 +388,9 @@ public void testOutlierDetectionWithPreExistingDestIndex() throws Exception { .setTrackTotalHits(true) .setQuery(QueryBuilders.existsQuery("ml.outlier_score")).get(); assertThat(searchResponse.getHits().getTotalHits().value, equalTo((long) bulkRequestBuilder.numberOfActions())); + + assertProgress(id, 100, 100, 100, 100); + assertThat(searchStoredProgress(id).getHits().getTotalHits().value, equalTo(1L)); } public void testRegressionWithNumericFeatureAndFewDocuments() throws Exception { @@ -406,6 +426,7 @@ public void testRegressionWithNumericFeatureAndFewDocuments() throws Exception { putAnalytics(config); assertState(id, DataFrameAnalyticsState.STOPPED); + assertProgress(id, 0, 0, 0, 0); startAnalytics(id); waitUntilAnalyticsIsStopped(id); @@ -438,6 +459,9 @@ public void testRegressionWithNumericFeatureAndFewDocuments() throws Exception { } } assertThat(resultsWithPrediction, greaterThan(0)); + + assertProgress(id, 100, 100, 100, 100); + assertThat(searchStoredProgress(id).getHits().getTotalHits().value, equalTo(1L)); } public void testModelMemoryLimitLowerThanEstimatedMemoryUsage() { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteDataFrameAnalyticsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteDataFrameAnalyticsAction.java index 1165ae175256d..eee9d5b69c749 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteDataFrameAnalyticsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteDataFrameAnalyticsAction.java @@ -5,15 +5,20 @@ */ package 
org.elasticsearch.xpack.ml.action; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.DocWriteResponse; +import org.elasticsearch.action.bulk.BulkItemResponse; import org.elasticsearch.action.delete.DeleteAction; import org.elasticsearch.action.delete.DeleteRequest; import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.action.support.master.AcknowledgedResponse; import org.elasticsearch.action.support.master.TransportMasterNodeAction; import org.elasticsearch.client.Client; +import org.elasticsearch.client.ParentTaskAssigningClient; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.block.ClusterBlockException; import org.elasticsearch.cluster.block.ClusterBlockLevel; @@ -21,7 +26,14 @@ import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.inject.Inject; import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.index.reindex.AbstractBulkByScrollRequest; +import org.elasticsearch.index.reindex.BulkByScrollResponse; +import org.elasticsearch.index.reindex.DeleteByQueryAction; +import org.elasticsearch.index.reindex.DeleteByQueryRequest; import org.elasticsearch.persistent.PersistentTasksCustomMetaData; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.tasks.TaskId; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.ml.MlTasks; @@ -30,7 +42,9 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import 
org.elasticsearch.xpack.ml.dataframe.persistence.DataFrameAnalyticsConfigProvider; import org.elasticsearch.xpack.ml.process.MlMemoryTracker; +import org.elasticsearch.xpack.ml.utils.MlIndicesUtils; import java.io.IOException; @@ -45,18 +59,22 @@ public class TransportDeleteDataFrameAnalyticsAction extends TransportMasterNodeAction { + private static final Logger LOGGER = LogManager.getLogger(TransportDeleteDataFrameAnalyticsAction.class); + private final Client client; private final MlMemoryTracker memoryTracker; + private final DataFrameAnalyticsConfigProvider configProvider; @Inject public TransportDeleteDataFrameAnalyticsAction(TransportService transportService, ClusterService clusterService, ThreadPool threadPool, ActionFilters actionFilters, IndexNameExpressionResolver indexNameExpressionResolver, Client client, - MlMemoryTracker memoryTracker) { + MlMemoryTracker memoryTracker, DataFrameAnalyticsConfigProvider configProvider) { super(DeleteDataFrameAnalyticsAction.NAME, transportService, clusterService, threadPool, actionFilters, DeleteDataFrameAnalyticsAction.Request::new, indexNameExpressionResolver); this.client = client; this.memoryTracker = memoryTracker; + this.configProvider = configProvider; } @Override @@ -72,6 +90,12 @@ protected AcknowledgedResponse read(StreamInput in) throws IOException { @Override protected void masterOperation(DeleteDataFrameAnalyticsAction.Request request, ClusterState state, ActionListener listener) { + throw new UnsupportedOperationException("The task parameter is required"); + } + + @Override + protected void masterOperation(Task task, DeleteDataFrameAnalyticsAction.Request request, ClusterState state, + ActionListener listener) { String id = request.getId(); PersistentTasksCustomMetaData tasks = state.getMetaData().custom(PersistentTasksCustomMetaData.TYPE); DataFrameAnalyticsState taskState = MlTasks.getDataFrameAnalyticsState(id, tasks); @@ -81,25 +105,70 @@ protected void 
masterOperation(DeleteDataFrameAnalyticsAction.Request request, C return; } + TaskId taskId = new TaskId(clusterService.localNode().getId(), task.getId()); + ParentTaskAssigningClient parentTaskClient = new ParentTaskAssigningClient(client, taskId); + // We clean up the memory tracker on delete because there is no stop; the task stops by itself memoryTracker.removeDataFrameAnalyticsJob(id); + // Step 2. Delete the config + ActionListener deleteStateHandler = ActionListener.wrap( + bulkByScrollResponse -> { + if (bulkByScrollResponse.isTimedOut()) { + LOGGER.warn("[{}] DeleteByQuery for state timed out", id); + } + if (bulkByScrollResponse.getBulkFailures().isEmpty() == false) { + LOGGER.warn("[{}] {} failures and {} conflicts encountered while running DeleteByQuery for state", id, + bulkByScrollResponse.getBulkFailures().size(), bulkByScrollResponse.getVersionConflicts()); + for (BulkItemResponse.Failure failure : bulkByScrollResponse.getBulkFailures()) { + LOGGER.warn("[{}] DBQ failure: {}", id, failure); + } + } + deleteConfig(parentTaskClient, id, listener); + }, + listener::onFailure + ); + + // Step 1. Delete state + ActionListener configListener = ActionListener.wrap( + config -> deleteState(parentTaskClient, id, deleteStateHandler), + listener::onFailure + ); + + // Step 1.
Get the config to check if it exists + configProvider.get(id, configListener); + } + + private void deleteConfig(ParentTaskAssigningClient parentTaskClient, String id, ActionListener listener) { DeleteRequest deleteRequest = new DeleteRequest(AnomalyDetectorsIndex.configIndexName()); deleteRequest.id(DataFrameAnalyticsConfig.documentId(id)); deleteRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - executeAsyncWithOrigin(client, ML_ORIGIN, DeleteAction.INSTANCE, deleteRequest, ActionListener.wrap( + executeAsyncWithOrigin(parentTaskClient, ML_ORIGIN, DeleteAction.INSTANCE, deleteRequest, ActionListener.wrap( deleteResponse -> { if (deleteResponse.getResult() == DocWriteResponse.Result.NOT_FOUND) { listener.onFailure(ExceptionsHelper.missingDataFrameAnalytics(id)); return; } assert deleteResponse.getResult() == DocWriteResponse.Result.DELETED; + LOGGER.info("[{}] Deleted", id); listener.onResponse(new AcknowledgedResponse(true)); }, listener::onFailure )); } + private void deleteState(ParentTaskAssigningClient parentTaskClient, String analyticsId, + ActionListener listener) { + DeleteByQueryRequest request = new DeleteByQueryRequest(AnomalyDetectorsIndex.jobStateIndexPattern()); + request.setQuery(QueryBuilders.idsQuery().addIds( + TransportStartDataFrameAnalyticsAction.DataFrameAnalyticsTask.progressDocId(analyticsId))); + request.setIndicesOptions(MlIndicesUtils.addIgnoreUnavailable(IndicesOptions.lenientExpandOpen())); + request.setSlices(AbstractBulkByScrollRequest.AUTO_SLICES); + request.setAbortOnVersionConflict(false); + request.setRefresh(true); + executeAsyncWithOrigin(parentTaskClient, ML_ORIGIN, DeleteByQueryAction.INSTANCE, request, listener); + } + @Override protected ClusterBlockException checkBlock(DeleteDataFrameAnalyticsAction.Request request, ClusterState state) { return state.blocks().globalBlockedException(ClusterBlockLevel.METADATA_WRITE); diff --git 
a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDataFrameAnalyticsStatsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDataFrameAnalyticsStatsAction.java index 884f741013b59..875a0a8f44749 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDataFrameAnalyticsStatsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDataFrameAnalyticsStatsAction.java @@ -7,24 +7,32 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.elasticsearch.ResourceNotFoundException; +import org.apache.logging.log4j.message.ParameterizedMessage; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.FailedNodeException; import org.elasticsearch.action.TaskOperationFailure; -import org.elasticsearch.action.admin.cluster.node.tasks.get.GetTaskRequest; +import org.elasticsearch.action.search.MultiSearchAction; +import org.elasticsearch.action.search.MultiSearchRequest; +import org.elasticsearch.action.search.MultiSearchResponse; +import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.action.support.tasks.TransportTasksAction; import org.elasticsearch.client.Client; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.service.ClusterService; -import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.inject.Inject; -import org.elasticsearch.index.reindex.BulkByScrollTask; +import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentFactory; +import 
org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.persistent.PersistentTasksCustomMetaData; +import org.elasticsearch.search.SearchHit; import org.elasticsearch.tasks.Task; -import org.elasticsearch.tasks.TaskId; -import org.elasticsearch.tasks.TaskResult; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.action.util.QueryPage; @@ -35,9 +43,14 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState; +import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.core.ml.utils.PhaseProgress; import org.elasticsearch.xpack.ml.action.TransportStartDataFrameAnalyticsAction.DataFrameAnalyticsTask; -import org.elasticsearch.xpack.ml.dataframe.process.AnalyticsProcessManager; +import org.elasticsearch.xpack.ml.dataframe.StoredProgress; +import java.io.IOException; +import java.io.InputStream; import java.util.ArrayList; import java.util.Collections; import java.util.Comparator; @@ -55,16 +68,14 @@ public class TransportGetDataFrameAnalyticsStatsAction private static final Logger LOGGER = LogManager.getLogger(TransportGetDataFrameAnalyticsStatsAction.class); private final Client client; - private final AnalyticsProcessManager analyticsProcessManager; @Inject public TransportGetDataFrameAnalyticsStatsAction(TransportService transportService, ClusterService clusterService, Client client, - ActionFilters actionFilters, AnalyticsProcessManager analyticsProcessManager) { + ActionFilters actionFilters) { super(GetDataFrameAnalyticsStatsAction.NAME, clusterService, transportService, 
actionFilters, GetDataFrameAnalyticsStatsAction.Request::new, GetDataFrameAnalyticsStatsAction.Response::new, in -> new QueryPage<>(in, GetDataFrameAnalyticsStatsAction.Response.Stats::new), ThreadPool.Names.MANAGEMENT); this.client = client; - this.analyticsProcessManager = analyticsProcessManager; } @Override @@ -86,7 +97,7 @@ protected void taskOperation(GetDataFrameAnalyticsStatsAction.Request request, D ActionListener> listener) { LOGGER.debug("Get stats for running task [{}]", task.getParams().getId()); - ActionListener progressListener = ActionListener.wrap( + ActionListener> progressListener = ActionListener.wrap( progress -> { Stats stats = buildStats(task.getParams().getId(), progress); listener.onResponse(new QueryPage<>(Collections.singletonList(stats), 1, @@ -94,38 +105,14 @@ protected void taskOperation(GetDataFrameAnalyticsStatsAction.Request request, D }, listener::onFailure ); - ClusterState clusterState = clusterService.state(); - PersistentTasksCustomMetaData tasks = clusterState.getMetaData().custom(PersistentTasksCustomMetaData.TYPE); - DataFrameAnalyticsState analyticsState = MlTasks.getDataFrameAnalyticsState(task.getParams().getId(), tasks); - - // For a running task we report the progress associated with its current state - if (analyticsState == DataFrameAnalyticsState.REINDEXING) { - getReindexTaskProgress(task, progressListener); - } else { - progressListener.onResponse(analyticsProcessManager.getProgressPercent(task.getAllocationId())); - } - } - - private void getReindexTaskProgress(DataFrameAnalyticsTask task, ActionListener listener) { - TaskId reindexTaskId = new TaskId(clusterService.localNode().getId(), task.getReindexingTaskId()); - GetTaskRequest getTaskRequest = new GetTaskRequest(); - getTaskRequest.setTaskId(reindexTaskId); - client.admin().cluster().getTask(getTaskRequest, ActionListener.wrap( - taskResponse -> { - TaskResult taskResult = taskResponse.getTask(); - BulkByScrollTask.Status taskStatus = (BulkByScrollTask.Status) 
taskResult.getTask().getStatus(); - int progress = taskStatus.getTotal() == 0 ? 100 : (int) (taskStatus.getCreated() * 100.0 / taskStatus.getTotal()); - listener.onResponse(progress); + ActionListener reindexingProgressListener = ActionListener.wrap( + aVoid -> { + progressListener.onResponse(task.getProgressTracker().report()); }, - error -> { - if (error instanceof ResourceNotFoundException) { - // The task has either not started yet or has finished, thus it is better to respond null and not show progress at all - listener.onResponse(null); - } else { - listener.onFailure(error); - } - } - )); + listener::onFailure + ); + + task.updateReindexTaskProgress(reindexingProgressListener); } @Override @@ -166,12 +153,27 @@ protected void doExecute(Task task, GetDataFrameAnalyticsStatsAction.Request req void gatherStatsForStoppedTasks(List expandedIds, GetDataFrameAnalyticsStatsAction.Response runningTasksResponse, ActionListener listener) { List stoppedTasksIds = determineStoppedTasksIds(expandedIds, runningTasksResponse.getResponse().results()); - List stoppedTasksStats = stoppedTasksIds.stream().map(this::buildStatsForStoppedTask).collect(Collectors.toList()); - List allTasksStats = new ArrayList<>(runningTasksResponse.getResponse().results()); - allTasksStats.addAll(stoppedTasksStats); - Collections.sort(allTasksStats, Comparator.comparing(Stats::getId)); - listener.onResponse(new GetDataFrameAnalyticsStatsAction.Response(new QueryPage<>( - allTasksStats, allTasksStats.size(), GetDataFrameAnalyticsAction.Response.RESULTS_FIELD))); + if (stoppedTasksIds.isEmpty()) { + listener.onResponse(runningTasksResponse); + return; + } + + searchStoredProgresses(stoppedTasksIds, ActionListener.wrap( + storedProgresses -> { + List stoppedStats = new ArrayList<>(stoppedTasksIds.size()); + for (int i = 0; i < stoppedTasksIds.size(); i++) { + String configId = stoppedTasksIds.get(i); + StoredProgress storedProgress = storedProgresses.get(i); + stoppedStats.add(buildStats(configId, 
storedProgress.get())); + } + List allTasksStats = new ArrayList<>(runningTasksResponse.getResponse().results()); + allTasksStats.addAll(stoppedStats); + Collections.sort(allTasksStats, Comparator.comparing(Stats::getId)); + listener.onResponse(new GetDataFrameAnalyticsStatsAction.Response(new QueryPage<>( + allTasksStats, allTasksStats.size(), GetDataFrameAnalyticsAction.Response.RESULTS_FIELD))); + }, + listener::onFailure + )); } static List determineStoppedTasksIds(List expandedIds, List runningTasksStats) { @@ -179,11 +181,52 @@ static List determineStoppedTasksIds(List expandedIds, List startedTasksIds.contains(id) == false).collect(Collectors.toList()); } - private GetDataFrameAnalyticsStatsAction.Response.Stats buildStatsForStoppedTask(String concreteAnalyticsId) { - return buildStats(concreteAnalyticsId, null); + private void searchStoredProgresses(List configIds, ActionListener> listener) { + MultiSearchRequest multiSearchRequest = new MultiSearchRequest(); + for (String configId : configIds) { + SearchRequest searchRequest = new SearchRequest(AnomalyDetectorsIndex.jobStateIndexPattern()); + searchRequest.indicesOptions(IndicesOptions.lenientExpandOpen()); + searchRequest.source().size(1); + searchRequest.source().query(QueryBuilders.idsQuery().addIds(DataFrameAnalyticsTask.progressDocId(configId))); + multiSearchRequest.add(searchRequest); + } + + executeAsyncWithOrigin(client, ML_ORIGIN, MultiSearchAction.INSTANCE, multiSearchRequest, ActionListener.wrap( + multiSearchResponse -> { + List progresses = new ArrayList<>(configIds.size()); + for (MultiSearchResponse.Item itemResponse : multiSearchResponse.getResponses()) { + if (itemResponse.isFailure()) { + listener.onFailure(ExceptionsHelper.serverError(itemResponse.getFailureMessage(), itemResponse.getFailure())); + return; + } else { + SearchHit[] hits = itemResponse.getResponse().getHits().getHits(); + if (hits.length == 0) { + progresses.add(new StoredProgress(new 
DataFrameAnalyticsTask.ProgressTracker().report())); + } else { + progresses.add(parseStoredProgress(hits[0])); + } + } + } + listener.onResponse(progresses); + }, + e -> listener.onFailure(ExceptionsHelper.serverError("Error searching for stored progresses", e)) + )); + } + + private StoredProgress parseStoredProgress(SearchHit hit) { + BytesReference source = hit.getSourceRef(); + try (InputStream stream = source.streamInput(); + XContentParser parser = XContentFactory.xContent(XContentType.JSON) + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, stream)) { + StoredProgress storedProgress = StoredProgress.PARSER.apply(parser, null); + return storedProgress; + } catch (IOException e) { + LOGGER.error(new ParameterizedMessage("failed to parse progress from doc with id [{}]", hit.getId()), e); + return new StoredProgress(Collections.emptyList()); + } } - private GetDataFrameAnalyticsStatsAction.Response.Stats buildStats(String concreteAnalyticsId, @Nullable Integer progressPercent) { + private GetDataFrameAnalyticsStatsAction.Response.Stats buildStats(String concreteAnalyticsId, List progress) { ClusterState clusterState = clusterService.state(); PersistentTasksCustomMetaData tasks = clusterState.getMetaData().custom(PersistentTasksCustomMetaData.TYPE); PersistentTasksCustomMetaData.PersistentTask analyticsTask = MlTasks.getDataFrameAnalyticsTask(concreteAnalyticsId, tasks); @@ -200,6 +243,6 @@ private GetDataFrameAnalyticsStatsAction.Response.Stats buildStats(String concre assignmentExplanation = analyticsTask.getAssignment().getExplanation(); } return new GetDataFrameAnalyticsStatsAction.Response.Stats( - concreteAnalyticsId, analyticsState, failureReason, progressPercent, node, assignmentExplanation); + concreteAnalyticsId, analyticsState, failureReason, progress, node, assignmentExplanation); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java
b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java index 70f92478df4a5..48cfa58afb008 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java @@ -15,10 +15,14 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksRequest; import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksResponse; +import org.elasticsearch.action.admin.cluster.node.tasks.get.GetTaskRequest; +import org.elasticsearch.action.index.IndexAction; +import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.search.SearchAction; import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.IndicesOptions; +import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.action.support.master.AcknowledgedResponse; import org.elasticsearch.action.support.master.TransportMasterNodeAction; import org.elasticsearch.client.Client; @@ -34,7 +38,10 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.json.JsonXContent; import org.elasticsearch.index.IndexNotFoundException; +import org.elasticsearch.index.reindex.BulkByScrollTask; import org.elasticsearch.license.LicenseUtils; import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.persistent.AllocatedPersistentTask; @@ -45,6 +52,7 @@ import org.elasticsearch.persistent.PersistentTasksService; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.TaskId; +import 
org.elasticsearch.tasks.TaskResult; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.ClientHelper; @@ -52,6 +60,7 @@ import org.elasticsearch.xpack.core.ml.MlMetadata; import org.elasticsearch.xpack.core.ml.MlTasks; import org.elasticsearch.xpack.core.ml.action.EstimateMemoryUsageAction; +import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction; import org.elasticsearch.xpack.core.ml.action.PutDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; @@ -59,10 +68,13 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState; import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.core.ml.utils.PhaseProgress; +import org.elasticsearch.xpack.core.watcher.watch.Payload; import org.elasticsearch.xpack.ml.MachineLearning; import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsManager; import org.elasticsearch.xpack.ml.dataframe.MappingsMerger; import org.elasticsearch.xpack.ml.dataframe.SourceDestValidator; +import org.elasticsearch.xpack.ml.dataframe.StoredProgress; import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFactory; import org.elasticsearch.xpack.ml.dataframe.persistence.DataFrameAnalyticsConfigProvider; import org.elasticsearch.xpack.ml.job.JobNodeSelector; @@ -70,12 +82,16 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Predicate; +import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; +import static 
org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; import static org.elasticsearch.xpack.core.ml.MlTasks.AWAITING_UPGRADE; import static org.elasticsearch.xpack.ml.MachineLearning.MAX_OPEN_JOBS_PER_NODE; @@ -369,7 +385,9 @@ public static class DataFrameAnalyticsTask extends AllocatedPersistentTask imple private final StartDataFrameAnalyticsAction.TaskParams taskParams; @Nullable private volatile Long reindexingTaskId; + private volatile boolean isReindexingFinished; private volatile boolean isStopping; + private final ProgressTracker progressTracker = new ProgressTracker(); public DataFrameAnalyticsTask(long id, String type, String action, TaskId parentTask, Map headers, Client client, ClusterService clusterService, DataFrameAnalyticsManager analyticsManager, @@ -389,22 +407,50 @@ public void setReindexingTaskId(Long reindexingTaskId) { this.reindexingTaskId = reindexingTaskId; } - @Nullable - public Long getReindexingTaskId() { - return reindexingTaskId; + public void setReindexingFinished() { + isReindexingFinished = true; } public boolean isStopping() { return isStopping; } + public ProgressTracker getProgressTracker() { + return progressTracker; + } + @Override protected void onCancelled() { stop(getReasonCancelled(), TimeValue.ZERO); } + @Override + public void markAsCompleted() { + persistProgress(() -> super.markAsCompleted()); + } + + @Override + public void markAsFailed(Exception e) { + persistProgress(() -> super.markAsFailed(e)); + } + public void stop(String reason, TimeValue timeout) { isStopping = true; + + ActionListener reindexProgressListener = ActionListener.wrap( + aVoid -> doStop(reason, timeout), + e -> { + LOGGER.error(new ParameterizedMessage("[{}] Error updating reindexing progress", taskParams.getId()), e); + // We should log the error but it shouldn't stop us from stopping the task + doStop(reason, timeout); + } + ); + + // We need to update reindexing progress before we cancel the task + 
updateReindexTaskProgress(reindexProgressListener); + } + + private void doStop(String reason, TimeValue timeout) { if (reindexingTaskId != null) { cancelReindexingTask(reason, timeout); } @@ -440,10 +486,115 @@ public void updateState(DataFrameAnalyticsState state, @Nullable String reason) DataFrameAnalyticsTaskState newTaskState = new DataFrameAnalyticsTaskState(state, getAllocationId(), reason); updatePersistentTaskState(newTaskState, ActionListener.wrap( updatedTask -> LOGGER.info("[{}] Successfully update task state to [{}]", getParams().getId(), state), - e -> LOGGER.error(new ParameterizedMessage("[{}] Could not update task state to [{}]", - getParams().getId(), state), e) + e -> LOGGER.error(new ParameterizedMessage("[{}] Could not update task state to [{}] with reason [{}]", + getParams().getId(), state, reason), e) + )); + } + + public void updateReindexTaskProgress(ActionListener listener) { + TaskId reindexTaskId = getReindexTaskId(); + if (reindexTaskId == null) { + // The task is not present which means either it has not started yet or it finished. + // We keep track of whether the task has finished so we can use that to tell whether the progress is 100. + if (isReindexingFinished) { + progressTracker.reindexingPercent.set(100); + } + listener.onResponse(null); + return; + } + + GetTaskRequest getTaskRequest = new GetTaskRequest(); + getTaskRequest.setTaskId(reindexTaskId); + client.admin().cluster().getTask(getTaskRequest, ActionListener.wrap( + taskResponse -> { + TaskResult taskResult = taskResponse.getTask(); + BulkByScrollTask.Status taskStatus = (BulkByScrollTask.Status) taskResult.getTask().getStatus(); + int progress = taskStatus.getTotal() == 0 ? 
0 : (int) (taskStatus.getCreated() * 100.0 / taskStatus.getTotal()); + progressTracker.reindexingPercent.set(progress); + listener.onResponse(null); + }, + error -> { + if (error instanceof ResourceNotFoundException) { + // The task is not present which means either it has not started yet or it finished. + // We keep track of whether the task has finished so we can use that to tell whether the progress is 100. + if (isReindexingFinished) { + progressTracker.reindexingPercent.set(100); + } + listener.onResponse(null); + } else { + listener.onFailure(error); + } + } )); } + + @Nullable + private TaskId getReindexTaskId() { + try { + return new TaskId(clusterService.localNode().getId(), reindexingTaskId); + } catch (NullPointerException e) { + // This may happen if there is no reindexing task id set which means we either never started the task yet or we're finished + return null; + } + } + + private void persistProgress(Runnable runnable) { + GetDataFrameAnalyticsStatsAction.Request getStatsRequest = new GetDataFrameAnalyticsStatsAction.Request(taskParams.getId()); + executeAsyncWithOrigin(client, ML_ORIGIN, GetDataFrameAnalyticsStatsAction.INSTANCE, getStatsRequest, ActionListener.wrap( + statsResponse -> { + GetDataFrameAnalyticsStatsAction.Response.Stats stats = statsResponse.getResponse().results().get(0); + IndexRequest indexRequest = new IndexRequest(AnomalyDetectorsIndex.jobStateIndexWriteAlias()); + indexRequest.id(progressDocId(taskParams.getId())); + indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + try (XContentBuilder jsonBuilder = JsonXContent.contentBuilder()) { + new StoredProgress(stats.getProgress()).toXContent(jsonBuilder, Payload.XContent.EMPTY_PARAMS); + indexRequest.source(jsonBuilder); + } + executeAsyncWithOrigin(client, ML_ORIGIN, IndexAction.INSTANCE, indexRequest, ActionListener.wrap( + indexResponse -> { + LOGGER.debug("[{}] Successfully indexed progress document", taskParams.getId()); + runnable.run(); + }, + indexError -> 
{ + LOGGER.error(new ParameterizedMessage( + "[{}] cannot persist progress as an error occurred while indexing", taskParams.getId()), indexError); + runnable.run(); + } + )); + }, + e -> { + LOGGER.error(new ParameterizedMessage( + "[{}] cannot persist progress as an error occurred while retrieving stats", taskParams.getId()), e); + runnable.run(); + } + )); + } + + public static String progressDocId(String id) { + return "data_frame_analytics-" + id + "-progress"; + } + + public static class ProgressTracker { + + public static final String REINDEXING = "reindexing"; + public static final String LOADING_DATA = "loading_data"; + public static final String ANALYZING = "analyzing"; + public static final String WRITING_RESULTS = "writing_results"; + + public final AtomicInteger reindexingPercent = new AtomicInteger(0); + public final AtomicInteger loadingDataPercent = new AtomicInteger(0); + public final AtomicInteger analyzingPercent = new AtomicInteger(0); + public final AtomicInteger writingResultsPercent = new AtomicInteger(0); + + public List report() { + return Arrays.asList( + new PhaseProgress(REINDEXING, reindexingPercent.get()), + new PhaseProgress(LOADING_DATA, loadingDataPercent.get()), + new PhaseProgress(ANALYZING, analyzingPercent.get()), + new PhaseProgress(WRITING_RESULTS, writingResultsPercent.get()) + ); + } + } } static List verifyIndicesPrimaryShardsAreActive(ClusterState clusterState, String... 
indexNames) { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java index 7206376334a36..4b73a91886443 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java @@ -126,6 +126,10 @@ private void reindexDataframeAndStartAnalysis(DataFrameAnalyticsTask task, DataF // Reindexing is complete; start analytics ActionListener refreshListener = ActionListener.wrap( refreshResponse -> { + if (task.isStopping()) { + LOGGER.debug("[{}] Stopping before starting analytics process", config.getId()); + return; + } task.setReindexingTaskId(null); startAnalytics(task, config, false); }, @@ -134,12 +138,18 @@ private void reindexDataframeAndStartAnalysis(DataFrameAnalyticsTask task, DataF // Refresh to ensure copied index is fully searchable ActionListener reindexCompletedListener = ActionListener.wrap( - bulkResponse -> + bulkResponse -> { + if (task.isStopping()) { + LOGGER.debug("[{}] Stopping before refreshing destination index", config.getId()); + return; + } + task.setReindexingFinished(); ClientHelper.executeAsyncWithOrigin(client, ClientHelper.ML_ORIGIN, RefreshAction.INSTANCE, new RefreshRequest(config.getDest().getIndex()), - refreshListener), + refreshListener); + }, error -> task.updateState(DataFrameAnalyticsState.FAILED, error.getMessage()) ); @@ -187,6 +197,9 @@ private void reindexDataframeAndStartAnalysis(DataFrameAnalyticsTask task, DataF } private void startAnalytics(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config, boolean isTaskRestarting) { + // Ensure we mark reindexing is finished for the case we are recovering a task that had finished reindexing + task.setReindexingFinished(); + // Update state to ANALYZING and start process 
ActionListener dataExtractorFactoryListener = ActionListener.wrap( dataExtractorFactory -> { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/StoredProgress.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/StoredProgress.java new file mode 100644 index 0000000000000..9c08a0b3012e1 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/StoredProgress.java @@ -0,0 +1,60 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.dataframe; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.ml.utils.PhaseProgress; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +public class StoredProgress implements ToXContentObject { + + private static final ParseField PROGRESS = new ParseField("progress"); + + @SuppressWarnings("unchecked") + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + PROGRESS.getPreferredName(), true, a -> new StoredProgress((List) a[0])); + + static { + PARSER.declareObjectArray(ConstructingObjectParser.constructorArg(), PhaseProgress.PARSER, PROGRESS); + } + + private final List progress; + + public StoredProgress(List progress) { + this.progress = Objects.requireNonNull(progress); + } + + public List get() { + return progress; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(PROGRESS.getPreferredName(), progress); + builder.endObject(); + return 
builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || o.getClass().equals(getClass()) == false) return false; + StoredProgress that = (StoredProgress) o; + return Objects.equals(progress, that.progress); + } + + @Override + public int hashCode() { + return Objects.hash(progress); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcess.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcess.java index 6a2ea283b4440..24b03000b21b4 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcess.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcess.java @@ -31,4 +31,10 @@ public interface AnalyticsProcess extends NativeProcess { * a SIGPIPE */ void consumeAndCloseOutputStream(); + + /** + * + * @return the process config + */ + AnalyticsProcessConfig getConfig(); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessConfig.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessConfig.java index 70a2e213fb6ca..5093404812afe 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessConfig.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessConfig.java @@ -43,6 +43,10 @@ public AnalyticsProcessConfig(long rows, int cols, ByteSizeValue memoryLimit, in this.analysis = Objects.requireNonNull(analysis); } + public long rows() { + return rows; + } + public int cols() { return cols; } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java index e9ef10e848eb4..e94dbf4747b2a 
100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java @@ -11,7 +11,6 @@ import org.elasticsearch.action.admin.indices.refresh.RefreshAction; import org.elasticsearch.action.admin.indices.refresh.RefreshRequest; import org.elasticsearch.client.Client; -import org.elasticsearch.common.Nullable; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; @@ -31,7 +30,6 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.ExecutorService; -import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Consumer; public class AnalyticsProcessManager { @@ -90,14 +88,15 @@ private void processData(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig c Consumer finishHandler) { try { + ProcessContext processContext = processContextByAllocation.get(task.getAllocationId()); writeHeaderRecord(dataExtractor, process); - writeDataRows(dataExtractor, process); + writeDataRows(dataExtractor, process, task.getProgressTracker()); process.writeEndOfDataMessage(); process.flushStream(); LOGGER.info("[{}] Waiting for result processor to complete", config.getId()); resultProcessor.awaitForCompletion(); - processContextByAllocation.get(task.getAllocationId()).setFailureReason(resultProcessor.getFailure()); + processContext.setFailureReason(resultProcessor.getFailure()); refreshDest(config); LOGGER.info("[{}] Result processor has completed", config.getId()); @@ -122,12 +121,16 @@ private void processData(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig c } } - private void writeDataRows(DataFrameDataExtractor dataExtractor, AnalyticsProcess process) throws IOException { + private void writeDataRows(DataFrameDataExtractor 
dataExtractor, AnalyticsProcess process, + DataFrameAnalyticsTask.ProgressTracker progressTracker) throws IOException { // The extra fields are for the doc hash and the control field (should be an empty string) String[] record = new String[dataExtractor.getFieldNames().size() + 2]; // The value of the control field should be an empty string for data frame rows record[record.length - 1] = ""; + long totalRows = process.getConfig().rows(); + long rowsProcessed = 0; + while (dataExtractor.hasNext()) { Optional> rows = dataExtractor.next(); if (rows.isPresent()) { @@ -139,6 +142,8 @@ private void writeDataRows(DataFrameDataExtractor dataExtractor, AnalyticsProces process.writeRecord(record); } } + rowsProcessed += rows.get().size(); + progressTracker.loadingDataPercent.set(rowsProcessed >= totalRows ? 100 : (int) (rowsProcessed * 100.0 / totalRows)); } } } @@ -179,12 +184,6 @@ private Consumer onProcessCrash(DataFrameAnalyticsTask task) { }; } - @Nullable - public Integer getProgressPercent(long allocationId) { - ProcessContext processContext = processContextByAllocation.get(allocationId); - return processContext == null ? 
null : processContext.progressPercent.get(); - } - private void refreshDest(DataFrameAnalyticsConfig config) { ClientHelper.executeWithHeaders(config.getHeaders(), ClientHelper.ML_ORIGIN, client, () -> client.execute(RefreshAction.INSTANCE, new RefreshRequest(config.getDest().getIndex())).actionGet()); @@ -222,7 +221,6 @@ class ProcessContext { private volatile AnalyticsProcess process; private volatile DataFrameDataExtractor dataExtractor; private volatile AnalyticsResultProcessor resultProcessor; - private final AtomicInteger progressPercent = new AtomicInteger(0); private volatile boolean processKilled; private volatile String failureReason; @@ -238,10 +236,6 @@ public boolean isProcessKilled() { return processKilled; } - void setProgressPercent(int progressPercent) { - this.progressPercent.set(progressPercent); - } - private synchronized void setFailureReason(String failureReason) { // Only set the new reason if there isn't one already as we want to keep the first reason if (failureReason != null) { @@ -282,7 +276,7 @@ private synchronized boolean startProcess(DataFrameDataExtractorFactory dataExtr process = createProcess(task, createProcessConfig(config, dataExtractor)); DataFrameRowsJoiner dataFrameRowsJoiner = new DataFrameRowsJoiner(config.getId(), client, dataExtractorFactory.newExtractor(true)); - resultProcessor = new AnalyticsResultProcessor(id, dataFrameRowsJoiner, this::isProcessKilled, this::setProgressPercent); + resultProcessor = new AnalyticsResultProcessor(id, dataFrameRowsJoiner, this::isProcessKilled, task.getProgressTracker()); return true; } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java index 8a4f134de9a2b..30c063324b15a 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java +++ 
b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java @@ -9,13 +9,13 @@ import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; import org.elasticsearch.common.Nullable; +import org.elasticsearch.xpack.ml.action.TransportStartDataFrameAnalyticsAction.DataFrameAnalyticsTask.ProgressTracker; import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult; import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults; import java.util.Iterator; import java.util.Objects; import java.util.concurrent.CountDownLatch; -import java.util.function.Consumer; import java.util.function.Supplier; public class AnalyticsResultProcessor { @@ -25,16 +25,16 @@ public class AnalyticsResultProcessor { private final String dataFrameAnalyticsId; private final DataFrameRowsJoiner dataFrameRowsJoiner; private final Supplier isProcessKilled; - private final Consumer progressConsumer; + private final ProgressTracker progressTracker; private final CountDownLatch completionLatch = new CountDownLatch(1); private volatile String failure; public AnalyticsResultProcessor(String dataFrameAnalyticsId, DataFrameRowsJoiner dataFrameRowsJoiner, Supplier isProcessKilled, - Consumer progressConsumer) { + ProgressTracker progressTracker) { this.dataFrameAnalyticsId = Objects.requireNonNull(dataFrameAnalyticsId); this.dataFrameRowsJoiner = Objects.requireNonNull(dataFrameRowsJoiner); this.isProcessKilled = Objects.requireNonNull(isProcessKilled); - this.progressConsumer = Objects.requireNonNull(progressConsumer); + this.progressTracker = Objects.requireNonNull(progressTracker); } @Nullable @@ -52,12 +52,25 @@ public void awaitForCompletion() { } public void process(AnalyticsProcess process) { + long totalRows = process.getConfig().rows(); + LOGGER.info("Total rows = {}", totalRows); + long processedRows = 0; + // TODO When java 9 features can be used, we will not need the local variable 
here try (DataFrameRowsJoiner resultsJoiner = dataFrameRowsJoiner) { Iterator iterator = process.readAnalyticsResults(); while (iterator.hasNext()) { AnalyticsResult result = iterator.next(); processResult(result, resultsJoiner); + if (result.getRowResults() != null) { + processedRows++; + progressTracker.writingResultsPercent.set(processedRows >= totalRows ? 100 : (int) (processedRows * 100.0 / totalRows)); + } + } + if (isProcessKilled.get() == false) { + // This means we completed successfully so we need to set the progress to 100. + // This is because due to skipped rows, it is possible the processed rows will not reach the total rows. + progressTracker.writingResultsPercent.set(100); } } catch (Exception e) { if (isProcessKilled.get()) { @@ -79,7 +92,7 @@ private void processResult(AnalyticsResult result, DataFrameRowsJoiner resultsJo } Integer progressPercent = result.getProgressPercent(); if (progressPercent != null) { - progressConsumer.accept(progressPercent); + progressTracker.analyzingPercent.set(progressPercent); } } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/NativeAnalyticsProcess.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/NativeAnalyticsProcess.java index abff4c863c3af..644751a47ccce 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/NativeAnalyticsProcess.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/NativeAnalyticsProcess.java @@ -6,21 +6,54 @@ package org.elasticsearch.xpack.ml.dataframe.process; import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult; +import org.elasticsearch.xpack.ml.process.ProcessResultsParser; +import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.nio.file.Path; +import java.util.Iterator; import java.util.List; +import java.util.Objects; import java.util.function.Consumer; public class 
NativeAnalyticsProcess extends AbstractNativeAnalyticsProcess { private static final String NAME = "analytics"; - protected NativeAnalyticsProcess(String jobId, InputStream logStream, OutputStream processInStream, - InputStream processOutStream, OutputStream processRestoreStream, int numberOfFields, - List filesToDelete, Consumer onProcessCrash) { + private final ProcessResultsParser resultsParser = new ProcessResultsParser<>(AnalyticsResult.PARSER); + private final AnalyticsProcessConfig config; + + protected NativeAnalyticsProcess(String jobId, InputStream logStream, OutputStream processInStream, InputStream processOutStream, + OutputStream processRestoreStream, int numberOfFields, List filesToDelete, + Consumer onProcessCrash, AnalyticsProcessConfig config) { super(NAME, AnalyticsResult.PARSER, jobId, logStream, processInStream, processOutStream, processRestoreStream, numberOfFields, filesToDelete, onProcessCrash); + this.config = Objects.requireNonNull(config); + } + + @Override + public String getName() { + return NAME; + } + + @Override + public void persistState() { + // Nothing to persist + } + + @Override + public void writeEndOfDataMessage() throws IOException { + new AnalyticsControlMessageWriter(recordWriter(), numberOfFields()).writeEndOfData(); + } + + @Override + public Iterator readAnalyticsResults() { + return resultsParser.parseResults(processOutStream()); + } + + @Override + public AnalyticsProcessConfig getConfig() { + return config; } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/NativeAnalyticsProcessFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/NativeAnalyticsProcessFactory.java index c41510019ba17..6aad810959f4d 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/NativeAnalyticsProcessFactory.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/NativeAnalyticsProcessFactory.java @@ 
-64,7 +64,7 @@ public NativeAnalyticsProcess createAnalyticsProcess(String jobId, AnalyticsProc NativeAnalyticsProcess analyticsProcess = new NativeAnalyticsProcess(jobId, processPipes.getLogStream().get(), processPipes.getProcessInStream().get(), processPipes.getProcessOutStream().get(), null, numberOfFields, - filesToDelete, onProcessCrash); + filesToDelete, onProcessCrash, analyticsProcessConfig); try { analyticsProcess.start(executorService); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/NativeMemoryUsageEstimationProcess.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/NativeMemoryUsageEstimationProcess.java index 55c9ec7dbbd71..02bd188fc8328 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/NativeMemoryUsageEstimationProcess.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/NativeMemoryUsageEstimationProcess.java @@ -24,4 +24,9 @@ protected NativeMemoryUsageEstimationProcess(String jobId, InputStream logStream super(NAME, MemoryUsageEstimationResult.PARSER, jobId, logStream, processInStream, processOutStream, processRestoreStream, numberOfFields, filesToDelete, onProcessCrash); } + + @Override + public AnalyticsProcessConfig getConfig() { + throw new UnsupportedOperationException(); + } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/StoredProgressTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/StoredProgressTests.java new file mode 100644 index 0000000000000..572ca816f81e6 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/StoredProgressTests.java @@ -0,0 +1,37 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.ml.dataframe; + +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; +import org.elasticsearch.xpack.core.ml.utils.PhaseProgress; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +public class StoredProgressTests extends AbstractXContentTestCase { + + @Override + protected StoredProgress doParseInstance(XContentParser parser) throws IOException { + return StoredProgress.PARSER.apply(parser, null); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + @Override + protected StoredProgress createTestInstance() { + int phaseCount = randomIntBetween(3, 7); + List progress = new ArrayList<>(phaseCount); + for (int i = 0; i < phaseCount; i++) { + progress.add(new PhaseProgress(randomAlphaOfLength(10), randomIntBetween(0, 100))); + } + return new StoredProgress(progress); + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java index 097437ce8a40f..6b4e54e19ff91 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java @@ -5,7 +5,10 @@ */ package org.elasticsearch.xpack.ml.dataframe.process; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis; +import org.elasticsearch.xpack.ml.action.TransportStartDataFrameAnalyticsAction.DataFrameAnalyticsTask.ProgressTracker; import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult; import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults; import 
org.junit.Before; @@ -28,8 +31,7 @@ public class AnalyticsResultProcessorTests extends ESTestCase { private AnalyticsProcess process; private DataFrameRowsJoiner dataFrameRowsJoiner; - private int progressPercent; - + private ProgressTracker progressTracker = new ProgressTracker(); @Before @SuppressWarnings("unchecked") @@ -39,6 +41,7 @@ public void setUpMocks() { } public void testProcess_GivenNoResults() { + givenDataFrameRows(0); givenProcessResults(Collections.emptyList()); AnalyticsResultProcessor resultProcessor = createResultProcessor(); @@ -50,6 +53,7 @@ public void testProcess_GivenNoResults() { } public void testProcess_GivenEmptyResults() { + givenDataFrameRows(2); givenProcessResults(Arrays.asList(new AnalyticsResult(null, 50), new AnalyticsResult(null, 100))); AnalyticsResultProcessor resultProcessor = createResultProcessor(); @@ -58,10 +62,11 @@ public void testProcess_GivenEmptyResults() { verify(dataFrameRowsJoiner).close(); Mockito.verifyNoMoreInteractions(dataFrameRowsJoiner); - assertThat(progressPercent, equalTo(100)); + assertThat(progressTracker.writingResultsPercent.get(), equalTo(100)); } public void testProcess_GivenRowResults() { + givenDataFrameRows(2); RowResults rowResults1 = mock(RowResults.class); RowResults rowResults2 = mock(RowResults.class); givenProcessResults(Arrays.asList(new AnalyticsResult(rowResults1, 50), new AnalyticsResult(rowResults2, 100))); @@ -74,15 +79,20 @@ public void testProcess_GivenRowResults() { inOrder.verify(dataFrameRowsJoiner).processRowResults(rowResults1); inOrder.verify(dataFrameRowsJoiner).processRowResults(rowResults2); - assertThat(progressPercent, equalTo(100)); + assertThat(progressTracker.writingResultsPercent.get(), equalTo(100)); } private void givenProcessResults(List results) { when(process.readAnalyticsResults()).thenReturn(results.iterator()); } + private void givenDataFrameRows(int rows) { + AnalyticsProcessConfig config = new AnalyticsProcessConfig( + rows, 1, ByteSizeValue.ZERO, 1, "ml", 
Collections.emptySet(), mock(DataFrameAnalysis.class)); + when(process.getConfig()).thenReturn(config); + } + private AnalyticsResultProcessor createResultProcessor() { - return new AnalyticsResultProcessor(JOB_ID, dataFrameRowsJoiner, () -> false, - progressPercent -> this.progressPercent = progressPercent); + return new AnalyticsResultProcessor(JOB_ID, dataFrameRowsJoiner, () -> false, progressTracker); } }