diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteForecastAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteForecastAction.java index b6888d6d606fe..57d082e64bfd1 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteForecastAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteForecastAction.java @@ -19,7 +19,8 @@ import org.elasticsearch.action.support.HandledTransportAction; import org.elasticsearch.action.support.master.AcknowledgedResponse; import org.elasticsearch.client.Client; -import org.elasticsearch.cluster.metadata.Metadata; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.Strings; import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.inject.Inject; @@ -36,13 +37,16 @@ import org.elasticsearch.index.reindex.DeleteByQueryAction; import org.elasticsearch.index.reindex.DeleteByQueryRequest; import org.elasticsearch.index.reindex.ScrollableHitSource; +import org.elasticsearch.persistent.PersistentTasksCustomMetadata; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.SearchHits; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.tasks.Task; import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.ml.MlTasks; import org.elasticsearch.xpack.core.ml.action.DeleteForecastAction; +import org.elasticsearch.xpack.core.ml.job.config.JobState; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex; import org.elasticsearch.xpack.core.ml.job.results.Forecast; @@ -55,6 +59,7 @@ import java.io.InputStream; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collection; import java.util.EnumSet; import java.util.HashSet; import java.util.List; @@ -71,21 +76,28 @@ public class TransportDeleteForecastAction extends HandledTransportAction DELETABLE_STATUSES = EnumSet.of(ForecastRequestStatus.FINISHED, ForecastRequestStatus.FAILED); @Inject - public TransportDeleteForecastAction(TransportService transportService, ActionFilters actionFilters, Client client) { + public TransportDeleteForecastAction(TransportService transportService, + ActionFilters actionFilters, + Client client, + ClusterService clusterService) { super(DeleteForecastAction.NAME, transportService, actionFilters, DeleteForecastAction.Request::new); this.client = client; + this.clusterService = clusterService; } @Override protected void doExecute(Task task, DeleteForecastAction.Request request, ActionListener listener) { final String jobId = request.getJobId(); - final String forecastsExpression = request.getForecastId(); + + String forecastsExpression = request.getForecastId(); + final String[] forecastIds = Strings.tokenizeToStringArray(forecastsExpression, ","); ActionListener forecastStatsHandler = ActionListener.wrap( searchResponse -> deleteForecasts(searchResponse, request, listener), e -> listener.onFailure(new ElasticsearchException("An error occurred while searching forecasts to delete", e))); @@ -95,10 +107,8 @@ protected void doExecute(Task task, DeleteForecastAction.Request request, Action BoolQueryBuilder builder = QueryBuilders.boolQuery(); BoolQueryBuilder innerBool = QueryBuilders.boolQuery().must( QueryBuilders.termQuery(Result.RESULT_TYPE.getPreferredName(), ForecastRequestStats.RESULT_TYPE_VALUE)); - - if (Metadata.ALL.equals(request.getForecastId()) == false) { - Set forcastIds = new HashSet<>(Arrays.asList(Strings.tokenizeToStringArray(forecastsExpression, ","))); - innerBool.must(QueryBuilders.termsQuery(Forecast.FORECAST_ID.getPreferredName(), forcastIds)); + if (Strings.isAllOrWildcard(forecastIds) == false) { + innerBool.must(QueryBuilders.termsQuery(Forecast.FORECAST_ID.getPreferredName(), new HashSet<>(Arrays.asList(forecastIds)))); } source.query(builder.filter(innerBool)); @@ -109,6 +119,17 @@ protected void doExecute(Task task, DeleteForecastAction.Request request, Action executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, forecastStatsHandler); } + static void validateForecastState(Collection forecastsToDelete, JobState jobState, String jobId) { + List badStatusForecasts = forecastsToDelete.stream() + .filter((f) -> DELETABLE_STATUSES.contains(f.getStatus()) == false) + .map(ForecastRequestStats::getForecastId) + .collect(Collectors.toList()); + if (badStatusForecasts.size() > 0 && JobState.OPENED.equals(jobState)) { + throw ExceptionsHelper.conflictStatusException( + Messages.getMessage(Messages.REST_CANNOT_DELETE_FORECAST_IN_CURRENT_STATE, badStatusForecasts, jobId)); + } + } + private void deleteForecasts(SearchResponse searchResponse, DeleteForecastAction.Request request, ActionListener listener) { @@ -122,7 +143,7 @@ private void deleteForecasts(SearchResponse searchResponse, } if (forecastsToDelete.isEmpty()) { - if (Metadata.ALL.equals(request.getForecastId()) && + if (Strings.isAllOrWildcard(new String[]{request.getForecastId()}) && request.isAllowNoForecasts()) { listener.onResponse(new AcknowledgedResponse(true)); } else { @@ -131,13 +152,13 @@ private void deleteForecasts(SearchResponse searchResponse, } return; } - List badStatusForecasts = forecastsToDelete.stream() - .filter((f) -> !DELETABLE_STATUSES.contains(f.getStatus())) - .map(ForecastRequestStats::getForecastId).collect(Collectors.toList()); - if (badStatusForecasts.size() > 0) { - listener.onFailure( - ExceptionsHelper.conflictStatusException( - Messages.getMessage(Messages.REST_CANNOT_DELETE_FORECAST_IN_CURRENT_STATE, badStatusForecasts, jobId))); + final ClusterState state = clusterService.state(); + PersistentTasksCustomMetadata persistentTasks = state.metadata().custom(PersistentTasksCustomMetadata.TYPE); + JobState jobState = MlTasks.getJobState(jobId, persistentTasks); + try { + validateForecastState(forecastsToDelete, jobState, jobId); + } catch (ElasticsearchException ex) { + listener.onFailure(ex); return; } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportDeleteForecastActionTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportDeleteForecastActionTests.java new file mode 100644 index 0000000000000..4658b9482e01f --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportDeleteForecastActionTests.java @@ -0,0 +1,75 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.action; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.ml.job.config.JobState; +import org.elasticsearch.xpack.core.ml.job.results.ForecastRequestStats; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; + + +public class TransportDeleteForecastActionTests extends ESTestCase { + + private static final int TEST_RUNS = 10; + + public void testValidateForecastStateWithAllFailedFinished() { + for (int i = 0; i < TEST_RUNS; ++i) { + List forecastRequestStats = Stream.generate( + () -> createForecastStats(randomFrom( + ForecastRequestStats.ForecastRequestStatus.FAILED, + ForecastRequestStats.ForecastRequestStatus.FINISHED + ))) + .limit(randomInt(10)) + .collect(Collectors.toList()); + + // This should not throw. + TransportDeleteForecastAction.validateForecastState( + forecastRequestStats, + randomFrom(JobState.values()), + randomAlphaOfLength(10)); + } + } + + public void testValidateForecastStateWithSomeFailedFinished() { + for (int i = 0; i < TEST_RUNS; ++i) { + List forecastRequestStats = Stream.generate( + () -> createForecastStats(randomFrom( + ForecastRequestStats.ForecastRequestStatus.values() + ))) + .limit(randomInt(10)) + .collect(Collectors.toList()); + + forecastRequestStats.add(createForecastStats(ForecastRequestStats.ForecastRequestStatus.STARTED)); + + { + JobState jobState = randomFrom(JobState.CLOSED, JobState.CLOSING, JobState.FAILED); + try { + TransportDeleteForecastAction.validateForecastState(forecastRequestStats, jobState, randomAlphaOfLength(10)); + } catch (Exception ex) { + fail("Should not have thrown: " + ex.getMessage()); + } + } + { + JobState jobState = JobState.OPENED; + expectThrows( + ElasticsearchStatusException.class, + () -> TransportDeleteForecastAction.validateForecastState(forecastRequestStats, jobState, randomAlphaOfLength(10)) + ); + } + } + } + + + private static ForecastRequestStats createForecastStats(ForecastRequestStats.ForecastRequestStatus status) { + ForecastRequestStats forecastRequestStats = new ForecastRequestStats(randomAlphaOfLength(10), randomAlphaOfLength(10)); + forecastRequestStats.setStatus(status); + return forecastRequestStats; + } + +}