Skip to content

Commit

Permalink
[ML] Use feature reset API in ML REST test cleanup (#71552)
Browse files Browse the repository at this point in the history
Now that we have a feature reset API, we should use
this for cleaning up in between tests instead of running
lots of bespoke cleanup code.

During testing of this change we found we need to
delete custom cluster state as part of the reset process,
so this PR also implements that.

Additionally, we no longer assign persistent tasks
while a feature reset is in progress.
  • Loading branch information
droberts195 authored Apr 13, 2021
1 parent fb1921c commit c436458
Show file tree
Hide file tree
Showing 21 changed files with 220 additions and 289 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,59 +11,37 @@
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.action.ingest.DeletePipelineRequest;
import org.elasticsearch.client.core.PageParams;
import org.elasticsearch.client.ml.CloseJobRequest;
import org.elasticsearch.client.ml.DeleteDataFrameAnalyticsRequest;
import org.elasticsearch.client.ml.DeleteDatafeedRequest;
import org.elasticsearch.client.ml.DeleteJobRequest;
import org.elasticsearch.client.ml.DeleteTrainedModelRequest;
import org.elasticsearch.client.ml.GetDataFrameAnalyticsRequest;
import org.elasticsearch.client.ml.GetDataFrameAnalyticsResponse;
import org.elasticsearch.client.ml.GetDatafeedRequest;
import org.elasticsearch.client.ml.GetDatafeedResponse;
import org.elasticsearch.client.ml.GetJobRequest;
import org.elasticsearch.client.ml.GetJobResponse;
import org.elasticsearch.client.ml.GetTrainedModelsRequest;
import org.elasticsearch.client.feature.ResetFeaturesRequest;
import org.elasticsearch.client.ml.GetTrainedModelsStatsRequest;
import org.elasticsearch.client.ml.StopDataFrameAnalyticsRequest;
import org.elasticsearch.client.ml.StopDatafeedRequest;
import org.elasticsearch.client.ml.datafeed.DatafeedConfig;
import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.client.ml.job.config.Job;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Collections;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
* Cleans up and ML resources created during tests
* Cleans up ML resources created during tests
*/
public class MlTestStateCleaner {

private static final Set<String> NOT_DELETED_TRAINED_MODELS = Collections.singleton("lang_ident_model_1");
private final Logger logger;
private final MachineLearningClient mlClient;
private final RestHighLevelClient client;

public MlTestStateCleaner(Logger logger, RestHighLevelClient client) {
this.logger = logger;
this.mlClient = client.machineLearning();
this.client = client;
}

public void clearMlMetadata() throws IOException {
deleteAllTrainedModels();
deleteAllDatafeeds();
deleteAllJobs();
deleteAllDataFrameAnalytics();
deleteAllTrainedModelIngestPipelines();
// This resets all features, not just ML, but they should have been getting reset between tests anyway so it shouldn't matter
client.features().resetFeatures(new ResetFeaturesRequest(), RequestOptions.DEFAULT);
}

@SuppressWarnings("unchecked")
private void deleteAllTrainedModels() throws IOException {
Set<String> pipelinesWithModels = mlClient.getTrainedModelsStats(
private void deleteAllTrainedModelIngestPipelines() throws IOException {
Set<String> pipelinesWithModels = client.machineLearning().getTrainedModelsStats(
new GetTrainedModelsStatsRequest("_all").setPageParams(new PageParams(0, 10_000)), RequestOptions.DEFAULT
).getTrainedModelStats()
.stream()
Expand All @@ -86,95 +64,5 @@ private void deleteAllTrainedModels() throws IOException {
logger.warn(() -> new ParameterizedMessage("failed to delete pipeline [{}]", pipelineId), ex);
}
}

mlClient.getTrainedModels(
GetTrainedModelsRequest.getAllTrainedModelConfigsRequest().setPageParams(new PageParams(0, 10_000)),
RequestOptions.DEFAULT)
.getTrainedModels()
.stream()
.filter(trainedModelConfig -> NOT_DELETED_TRAINED_MODELS.contains(trainedModelConfig.getModelId()) == false)
.forEach(config -> {
try {
mlClient.deleteTrainedModel(new DeleteTrainedModelRequest(config.getModelId()), RequestOptions.DEFAULT);
} catch (IOException ex) {
throw new UncheckedIOException(ex);
}
});
}

/**
* Stops all datafeeds, then fetches every datafeed and deletes them one at a time.
*
* @throws IOException declared by this method; propagated from the get/delete client calls
*/
private void deleteAllDatafeeds() throws IOException {
// Stop first; stopAllDatafeeds() throws a RuntimeException if the initial stop request fails.
stopAllDatafeeds();

GetDatafeedResponse getDatafeedResponse = mlClient.getDatafeed(GetDatafeedRequest.getAllDatafeedsRequest(), RequestOptions.DEFAULT);
for (DatafeedConfig datafeed : getDatafeedResponse.datafeeds()) {
mlClient.deleteDatafeed(new DeleteDatafeedRequest(datafeed.getId()), RequestOptions.DEFAULT);
}
}

/**
* Sends a single request to stop all datafeeds.
* If that request throws, the failure is logged, the same request is re-sent with force enabled,
* and a RuntimeException wrapping the first failure is thrown whether or not the forced attempt succeeded.
*/
private void stopAllDatafeeds() {
StopDatafeedRequest stopAllDatafeedsRequest = StopDatafeedRequest.stopAllDatafeedsRequest();
try {
mlClient.stopDatafeed(stopAllDatafeedsRequest, RequestOptions.DEFAULT);
} catch (Exception e1) {
logger.warn("failed to stop all datafeeds. Forcing stop", e1);
try {
// Retry the same request object, now flagged as a forced stop.
stopAllDatafeedsRequest.setForce(true);
mlClient.stopDatafeed(stopAllDatafeedsRequest, RequestOptions.DEFAULT);
} catch (Exception e2) {
// A failure of the forced attempt is only logged; e1 is the exception that gets reported below.
logger.warn("Force-closing all data feeds failed", e2);
}
// Needing the forced stop is itself treated as an error, with the original failure as the cause.
throw new RuntimeException("Had to resort to force-stopping datafeeds, something went wrong?", e1);
}
}

/**
* Closes all jobs, then fetches every job and deletes them one at a time.
*
* @throws IOException declared by this method; propagated from the get/delete client calls
*/
private void deleteAllJobs() throws IOException {
// Close first; closeAllJobs() throws a RuntimeException if the initial close request fails.
closeAllJobs();

GetJobResponse getJobResponse = mlClient.getJob(GetJobRequest.getAllJobsRequest(), RequestOptions.DEFAULT);
for (Job job : getJobResponse.jobs()) {
mlClient.deleteJob(new DeleteJobRequest(job.getId()), RequestOptions.DEFAULT);
}
}

/**
* Sends a single request to close all jobs.
* If that request throws, the failure is logged, the same request is re-sent with force enabled,
* and a RuntimeException wrapping the first failure is thrown whether or not the forced attempt succeeded.
*/
private void closeAllJobs() {
CloseJobRequest closeAllJobsRequest = CloseJobRequest.closeAllJobsRequest();
try {
mlClient.closeJob(closeAllJobsRequest, RequestOptions.DEFAULT);
} catch (Exception e1) {
logger.warn("failed to close all jobs. Forcing closed", e1);
// Reuse the same request object, now flagged as a forced close.
closeAllJobsRequest.setForce(true);
try {
mlClient.closeJob(closeAllJobsRequest, RequestOptions.DEFAULT);
} catch (Exception e2) {
// A failure of the forced attempt is only logged; e1 is the exception that gets reported below.
logger.warn("Force-closing all jobs failed", e2);
}
// Needing the forced close is itself treated as an error, with the original failure as the cause.
throw new RuntimeException("Had to resort to force-closing jobs, something went wrong?", e1);
}
}

/**
* Stops all data frame analytics, then fetches every analytics config and deletes them one at a time.
*
* @throws IOException declared by this method; propagated from the get/delete client calls
*/
private void deleteAllDataFrameAnalytics() throws IOException {
// Stop first; stopAllDataFrameAnalytics() throws a RuntimeException if the initial stop request fails.
stopAllDataFrameAnalytics();

GetDataFrameAnalyticsResponse getDataFrameAnalyticsResponse =
mlClient.getDataFrameAnalytics(GetDataFrameAnalyticsRequest.getAllDataFrameAnalyticsRequest(), RequestOptions.DEFAULT);
for (DataFrameAnalyticsConfig config : getDataFrameAnalyticsResponse.getAnalytics()) {
mlClient.deleteDataFrameAnalytics(new DeleteDataFrameAnalyticsRequest(config.getId()), RequestOptions.DEFAULT);
}
}

/**
* Sends a single request to stop data frame analytics.
* If that request throws, the failure is logged, the same request is re-sent with force enabled,
* and a RuntimeException wrapping the first failure is thrown whether or not the forced attempt succeeded.
*/
private void stopAllDataFrameAnalytics() {
// NOTE(review): "*" is presumably a wildcard matching every analytics id, as the log messages below suggest - confirm.
StopDataFrameAnalyticsRequest stopAllRequest = new StopDataFrameAnalyticsRequest("*");
try {
mlClient.stopDataFrameAnalytics(stopAllRequest, RequestOptions.DEFAULT);
} catch (Exception e1) {
logger.warn("failed to stop all data frame analytics. Will proceed to force-stopping", e1);
// Reuse the same request object, now flagged as a forced stop.
stopAllRequest.setForce(true);
try {
mlClient.stopDataFrameAnalytics(stopAllRequest, RequestOptions.DEFAULT);
} catch (Exception e2) {
// A failure of the forced attempt is only logged; e1 is the exception that gets reported below.
logger.warn("force-stopping all data frame analytics failed", e2);
}
// Needing the forced stop is itself treated as an error, with the original failure as the cause.
throw new RuntimeException("Had to resort to force-stopping data frame analytics, something went wrong?", e1);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,36 +22,46 @@

public class SetResetModeActionRequest extends AcknowledgedRequest<SetResetModeActionRequest> implements ToXContentObject {
public static SetResetModeActionRequest enabled() {
return new SetResetModeActionRequest(true);
return new SetResetModeActionRequest(true, false);
}

public static SetResetModeActionRequest disabled() {
return new SetResetModeActionRequest(false);
public static SetResetModeActionRequest disabled(boolean deleteMetadata) {
return new SetResetModeActionRequest(false, deleteMetadata);
}

private final boolean enabled;
private final boolean deleteMetadata;

private static final ParseField ENABLED = new ParseField("enabled");
private static final ParseField DELETE_METADATA = new ParseField("delete_metadata");
public static final ConstructingObjectParser<SetResetModeActionRequest, Void> PARSER =
new ConstructingObjectParser<>("set_reset_mode_action_request", a -> new SetResetModeActionRequest((Boolean)a[0]));
new ConstructingObjectParser<>("set_reset_mode_action_request",
a -> new SetResetModeActionRequest((Boolean)a[0], (Boolean)a[1]));

static {
PARSER.declareBoolean(ConstructingObjectParser.constructorArg(), ENABLED);
PARSER.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), DELETE_METADATA);
}

SetResetModeActionRequest(boolean enabled) {
SetResetModeActionRequest(boolean enabled, Boolean deleteMetadata) {
this.enabled = enabled;
this.deleteMetadata = deleteMetadata != null && deleteMetadata;
}

public SetResetModeActionRequest(StreamInput in) throws IOException {
super(in);
this.enabled = in.readBoolean();
this.deleteMetadata = in.readBoolean();
}

public boolean isEnabled() {
return enabled;
}

public boolean shouldDeleteMetadata() {
return deleteMetadata;
}

@Override
public ActionRequestValidationException validate() {
return null;
Expand All @@ -61,11 +71,12 @@ public ActionRequestValidationException validate() {
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeBoolean(enabled);
out.writeBoolean(deleteMetadata);
}

@Override
public int hashCode() {
return Objects.hash(enabled);
return Objects.hash(enabled, deleteMetadata);
}

@Override
Expand All @@ -77,13 +88,17 @@ public boolean equals(Object obj) {
return false;
}
SetResetModeActionRequest other = (SetResetModeActionRequest) obj;
return Objects.equals(enabled, other.enabled);
return Objects.equals(enabled, other.enabled)
&& Objects.equals(deleteMetadata, other.deleteMetadata);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
builder.startObject();
builder.field(ENABLED.getPreferredName(), enabled);
if (enabled == false) {
builder.field(DELETE_METADATA.getPreferredName(), deleteMetadata);
}
builder.endObject();
return builder;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ public final class MlTasks {
public static final PersistentTasksCustomMetadata.Assignment AWAITING_UPGRADE =
new PersistentTasksCustomMetadata.Assignment(null,
"persistent task cannot be assigned while upgrade mode is enabled.");
public static final PersistentTasksCustomMetadata.Assignment RESET_IN_PROGRESS =
new PersistentTasksCustomMetadata.Assignment(null,
"persistent task will not be assigned as a feature reset is in progress.");

private MlTasks() {
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ public class SetResetModeActionRequestTests extends AbstractSerializingTestCase<

@Override
protected SetResetModeActionRequest createTestInstance() {
return new SetResetModeActionRequest(randomBoolean());
boolean enabled = randomBoolean();
return new SetResetModeActionRequest(enabled, enabled == false && randomBoolean());
}

@Override
Expand Down
Loading

0 comments on commit c436458

Please sign in to comment.