Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] Use feature reset API in ML REST test cleanup #71552

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,59 +11,37 @@
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.action.ingest.DeletePipelineRequest;
import org.elasticsearch.client.core.PageParams;
import org.elasticsearch.client.ml.CloseJobRequest;
import org.elasticsearch.client.ml.DeleteDataFrameAnalyticsRequest;
import org.elasticsearch.client.ml.DeleteDatafeedRequest;
import org.elasticsearch.client.ml.DeleteJobRequest;
import org.elasticsearch.client.ml.DeleteTrainedModelRequest;
import org.elasticsearch.client.ml.GetDataFrameAnalyticsRequest;
import org.elasticsearch.client.ml.GetDataFrameAnalyticsResponse;
import org.elasticsearch.client.ml.GetDatafeedRequest;
import org.elasticsearch.client.ml.GetDatafeedResponse;
import org.elasticsearch.client.ml.GetJobRequest;
import org.elasticsearch.client.ml.GetJobResponse;
import org.elasticsearch.client.ml.GetTrainedModelsRequest;
import org.elasticsearch.client.feature.ResetFeaturesRequest;
import org.elasticsearch.client.ml.GetTrainedModelsStatsRequest;
import org.elasticsearch.client.ml.StopDataFrameAnalyticsRequest;
import org.elasticsearch.client.ml.StopDatafeedRequest;
import org.elasticsearch.client.ml.datafeed.DatafeedConfig;
import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.client.ml.job.config.Job;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Collections;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
* Cleans up and ML resources created during tests
* Cleans up ML resources created during tests
*/
public class MlTestStateCleaner {

private static final Set<String> NOT_DELETED_TRAINED_MODELS = Collections.singleton("lang_ident_model_1");
private final Logger logger;
private final MachineLearningClient mlClient;
private final RestHighLevelClient client;

public MlTestStateCleaner(Logger logger, RestHighLevelClient client) {
this.logger = logger;
this.mlClient = client.machineLearning();
this.client = client;
}

public void clearMlMetadata() throws IOException {
deleteAllTrainedModels();
deleteAllDatafeeds();
deleteAllJobs();
deleteAllDataFrameAnalytics();
deleteAllTrainedModelIngestPipelines();
// This resets all features, not just ML, but they should have been getting reset between tests anyway so it shouldn't matter
client.features().resetFeatures(new ResetFeaturesRequest(), RequestOptions.DEFAULT);
}

@SuppressWarnings("unchecked")
private void deleteAllTrainedModels() throws IOException {
Set<String> pipelinesWithModels = mlClient.getTrainedModelsStats(
private void deleteAllTrainedModelIngestPipelines() throws IOException {
Set<String> pipelinesWithModels = client.machineLearning().getTrainedModelsStats(
new GetTrainedModelsStatsRequest("_all").setPageParams(new PageParams(0, 10_000)), RequestOptions.DEFAULT
).getTrainedModelStats()
.stream()
Expand All @@ -86,95 +64,5 @@ private void deleteAllTrainedModels() throws IOException {
logger.warn(() -> new ParameterizedMessage("failed to delete pipeline [{}]", pipelineId), ex);
}
}

mlClient.getTrainedModels(
GetTrainedModelsRequest.getAllTrainedModelConfigsRequest().setPageParams(new PageParams(0, 10_000)),
RequestOptions.DEFAULT)
.getTrainedModels()
.stream()
.filter(trainedModelConfig -> NOT_DELETED_TRAINED_MODELS.contains(trainedModelConfig.getModelId()) == false)
.forEach(config -> {
try {
mlClient.deleteTrainedModel(new DeleteTrainedModelRequest(config.getModelId()), RequestOptions.DEFAULT);
} catch (IOException ex) {
throw new UncheckedIOException(ex);
}
});
}

private void deleteAllDatafeeds() throws IOException {
stopAllDatafeeds();

GetDatafeedResponse getDatafeedResponse = mlClient.getDatafeed(GetDatafeedRequest.getAllDatafeedsRequest(), RequestOptions.DEFAULT);
for (DatafeedConfig datafeed : getDatafeedResponse.datafeeds()) {
mlClient.deleteDatafeed(new DeleteDatafeedRequest(datafeed.getId()), RequestOptions.DEFAULT);
}
}

private void stopAllDatafeeds() {
StopDatafeedRequest stopAllDatafeedsRequest = StopDatafeedRequest.stopAllDatafeedsRequest();
try {
mlClient.stopDatafeed(stopAllDatafeedsRequest, RequestOptions.DEFAULT);
} catch (Exception e1) {
logger.warn("failed to stop all datafeeds. Forcing stop", e1);
try {
stopAllDatafeedsRequest.setForce(true);
mlClient.stopDatafeed(stopAllDatafeedsRequest, RequestOptions.DEFAULT);
} catch (Exception e2) {
logger.warn("Force-closing all data feeds failed", e2);
}
throw new RuntimeException("Had to resort to force-stopping datafeeds, something went wrong?", e1);
}
}

private void deleteAllJobs() throws IOException {
closeAllJobs();

GetJobResponse getJobResponse = mlClient.getJob(GetJobRequest.getAllJobsRequest(), RequestOptions.DEFAULT);
for (Job job : getJobResponse.jobs()) {
mlClient.deleteJob(new DeleteJobRequest(job.getId()), RequestOptions.DEFAULT);
}
}

private void closeAllJobs() {
CloseJobRequest closeAllJobsRequest = CloseJobRequest.closeAllJobsRequest();
try {
mlClient.closeJob(closeAllJobsRequest, RequestOptions.DEFAULT);
} catch (Exception e1) {
logger.warn("failed to close all jobs. Forcing closed", e1);
closeAllJobsRequest.setForce(true);
try {
mlClient.closeJob(closeAllJobsRequest, RequestOptions.DEFAULT);
} catch (Exception e2) {
logger.warn("Force-closing all jobs failed", e2);
}
throw new RuntimeException("Had to resort to force-closing jobs, something went wrong?", e1);
}
}

private void deleteAllDataFrameAnalytics() throws IOException {
stopAllDataFrameAnalytics();

GetDataFrameAnalyticsResponse getDataFrameAnalyticsResponse =
mlClient.getDataFrameAnalytics(GetDataFrameAnalyticsRequest.getAllDataFrameAnalyticsRequest(), RequestOptions.DEFAULT);
for (DataFrameAnalyticsConfig config : getDataFrameAnalyticsResponse.getAnalytics()) {
mlClient.deleteDataFrameAnalytics(new DeleteDataFrameAnalyticsRequest(config.getId()), RequestOptions.DEFAULT);
}
}

private void stopAllDataFrameAnalytics() {
StopDataFrameAnalyticsRequest stopAllRequest = new StopDataFrameAnalyticsRequest("*");
try {
mlClient.stopDataFrameAnalytics(stopAllRequest, RequestOptions.DEFAULT);
} catch (Exception e1) {
logger.warn("failed to stop all data frame analytics. Will proceed to force-stopping", e1);
stopAllRequest.setForce(true);
try {
mlClient.stopDataFrameAnalytics(stopAllRequest, RequestOptions.DEFAULT);
} catch (Exception e2) {
logger.warn("force-stopping all data frame analytics failed", e2);
}
throw new RuntimeException("Had to resort to force-stopping data frame analytics, something went wrong?", e1);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,36 +22,46 @@

public class SetResetModeActionRequest extends AcknowledgedRequest<SetResetModeActionRequest> implements ToXContentObject {
public static SetResetModeActionRequest enabled() {
return new SetResetModeActionRequest(true);
return new SetResetModeActionRequest(true, false);
}

public static SetResetModeActionRequest disabled() {
return new SetResetModeActionRequest(false);
public static SetResetModeActionRequest disabled(boolean deleteMetadata) {
return new SetResetModeActionRequest(false, deleteMetadata);
}

private final boolean enabled;
private final boolean deleteMetadata;

private static final ParseField ENABLED = new ParseField("enabled");
private static final ParseField DELETE_METADATA = new ParseField("delete_metadata");
public static final ConstructingObjectParser<SetResetModeActionRequest, Void> PARSER =
new ConstructingObjectParser<>("set_reset_mode_action_request", a -> new SetResetModeActionRequest((Boolean)a[0]));
new ConstructingObjectParser<>("set_reset_mode_action_request",
a -> new SetResetModeActionRequest((Boolean)a[0], (Boolean)a[1]));

static {
PARSER.declareBoolean(ConstructingObjectParser.constructorArg(), ENABLED);
PARSER.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), DELETE_METADATA);
}

SetResetModeActionRequest(boolean enabled) {
SetResetModeActionRequest(boolean enabled, Boolean deleteMetadata) {
this.enabled = enabled;
this.deleteMetadata = deleteMetadata != null && deleteMetadata;
}

public SetResetModeActionRequest(StreamInput in) throws IOException {
super(in);
this.enabled = in.readBoolean();
this.deleteMetadata = in.readBoolean();
}

public boolean isEnabled() {
return enabled;
}

public boolean shouldDeleteMetadata() {
return deleteMetadata;
}

@Override
public ActionRequestValidationException validate() {
return null;
Expand All @@ -61,11 +71,12 @@ public ActionRequestValidationException validate() {
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeBoolean(enabled);
out.writeBoolean(deleteMetadata);
}

@Override
public int hashCode() {
return Objects.hash(enabled);
return Objects.hash(enabled, deleteMetadata);
}

@Override
Expand All @@ -77,13 +88,17 @@ public boolean equals(Object obj) {
return false;
}
SetResetModeActionRequest other = (SetResetModeActionRequest) obj;
return Objects.equals(enabled, other.enabled);
return Objects.equals(enabled, other.enabled)
&& Objects.equals(deleteMetadata, other.deleteMetadata);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
builder.startObject();
builder.field(ENABLED.getPreferredName(), enabled);
if (enabled == false) {
builder.field(DELETE_METADATA.getPreferredName(), deleteMetadata);
}
builder.endObject();
return builder;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ public final class MlTasks {
public static final PersistentTasksCustomMetadata.Assignment AWAITING_UPGRADE =
new PersistentTasksCustomMetadata.Assignment(null,
"persistent task cannot be assigned while upgrade mode is enabled.");
public static final PersistentTasksCustomMetadata.Assignment RESET_IN_PROGRESS =
new PersistentTasksCustomMetadata.Assignment(null,
"persistent task will not be assigned as a feature reset is in progress.");

private MlTasks() {
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ public class SetResetModeActionRequestTests extends AbstractSerializingTestCase<

@Override
protected SetResetModeActionRequest createTestInstance() {
return new SetResetModeActionRequest(randomBoolean());
boolean enabled = randomBoolean();
return new SetResetModeActionRequest(enabled, enabled == false && randomBoolean());
}

@Override
Expand Down
Loading