[ML] Check licence when datafeeds use cross cluster search (#31247)
This change prevents a datafeed using cross cluster search from starting if the remote cluster
does not have x-pack installed and a sufficient license. The check is made only when starting a 
datafeed.
davidkyle authored Jun 13, 2018
1 parent 7199d5f commit 88f44a9
Showing 6 changed files with 493 additions and 59 deletions.
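
For context before the diffs: the new MlRemoteLicenseChecker helpers used in this commit rely on the cross cluster search naming convention, in which a remote index is qualified with a cluster alias, as in "cluster_two:my_index". The sketch below shows how such helpers can be implemented. Only the isRemoteIndex logic appears verbatim in this commit (it is moved out of DatafeedNodeSelector); the bodies of containsRemoteIndex, remoteIndices, and remoteClusterNames are assumptions inferred from their call sites in the diff.

import java.util.List;
import java.util.stream.Collectors;

// Hypothetical stand-in for the helpers on MlRemoteLicenseChecker; the
// method names match the diff, the implementations are assumptions.
public final class RemoteIndexHelpers {

    // Verbatim logic from this commit: an index is remote when it is
    // qualified with a cluster alias, e.g. "cluster_two:my_index".
    static boolean isRemoteIndex(String index) {
        return index.indexOf(':') != -1;
    }

    // Assumed: true if any of the datafeed's indices is remote.
    static boolean containsRemoteIndex(List<String> indices) {
        return indices.stream().anyMatch(RemoteIndexHelpers::isRemoteIndex);
    }

    // Assumed: only the remote indices, as used in the error message.
    static List<String> remoteIndices(List<String> indices) {
        return indices.stream()
                .filter(RemoteIndexHelpers::isRemoteIndex)
                .collect(Collectors.toList());
    }

    // Assumed: the distinct cluster aliases, i.e. the text before ':'.
    static List<String> remoteClusterNames(List<String> indices) {
        return indices.stream()
                .filter(RemoteIndexHelpers::isRemoteIndex)
                .map(index -> index.substring(0, index.indexOf(':')))
                .distinct()
                .collect(Collectors.toList());
    }
}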
DatafeedConfigTests.java
@@ -40,7 +40,6 @@
 
 import java.io.IOException;
 import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
 import java.util.TimeZone;
@@ -193,11 +192,11 @@ public void testDefaultQueryDelay() {
 
     public void testDefaultQueryDelay() {
         DatafeedConfig.Builder feedBuilder1 = new DatafeedConfig.Builder("datafeed1", "job1");
-        feedBuilder1.setIndices(Arrays.asList("foo"));
+        feedBuilder1.setIndices(Collections.singletonList("foo"));
         DatafeedConfig.Builder feedBuilder2 = new DatafeedConfig.Builder("datafeed2", "job1");
-        feedBuilder2.setIndices(Arrays.asList("foo"));
+        feedBuilder2.setIndices(Collections.singletonList("foo"));
         DatafeedConfig.Builder feedBuilder3 = new DatafeedConfig.Builder("datafeed3", "job2");
-        feedBuilder3.setIndices(Arrays.asList("foo"));
+        feedBuilder3.setIndices(Collections.singletonList("foo"));
         DatafeedConfig feed1 = feedBuilder1.build();
         DatafeedConfig feed2 = feedBuilder2.build();
         DatafeedConfig feed3 = feedBuilder3.build();
@@ -208,19 +207,19 @@ public void testDefaultQueryDelay() {
         assertThat(feed1.getQueryDelay(), not(equalTo(feed3.getQueryDelay())));
     }
 
-    public void testCheckValid_GivenNullIndices() throws IOException {
+    public void testCheckValid_GivenNullIndices() {
         DatafeedConfig.Builder conf = new DatafeedConfig.Builder("datafeed1", "job1");
         expectThrows(IllegalArgumentException.class, () -> conf.setIndices(null));
     }
 
-    public void testCheckValid_GivenEmptyIndices() throws IOException {
+    public void testCheckValid_GivenEmptyIndices() {
         DatafeedConfig.Builder conf = new DatafeedConfig.Builder("datafeed1", "job1");
         conf.setIndices(Collections.emptyList());
         ElasticsearchException e = ESTestCase.expectThrows(ElasticsearchException.class, conf::build);
         assertEquals(Messages.getMessage(Messages.DATAFEED_CONFIG_INVALID_OPTION_VALUE, "indices", "[]"), e.getMessage());
     }
 
-    public void testCheckValid_GivenIndicesContainsOnlyNulls() throws IOException {
+    public void testCheckValid_GivenIndicesContainsOnlyNulls() {
         List<String> indices = new ArrayList<>();
         indices.add(null);
         indices.add(null);
@@ -230,7 +229,7 @@ public void testCheckValid_GivenIndicesContainsOnlyNulls() throws IOException {
         assertEquals(Messages.getMessage(Messages.DATAFEED_CONFIG_INVALID_OPTION_VALUE, "indices", "[null, null]"), e.getMessage());
     }
 
-    public void testCheckValid_GivenIndicesContainsOnlyEmptyStrings() throws IOException {
+    public void testCheckValid_GivenIndicesContainsOnlyEmptyStrings() {
         List<String> indices = new ArrayList<>();
         indices.add("");
         indices.add("");
@@ -240,27 +239,27 @@ public void testCheckValid_GivenIndicesContainsOnlyEmptyStrings() throws IOException {
         assertEquals(Messages.getMessage(Messages.DATAFEED_CONFIG_INVALID_OPTION_VALUE, "indices", "[, ]"), e.getMessage());
     }
 
-    public void testCheckValid_GivenNegativeQueryDelay() throws IOException {
+    public void testCheckValid_GivenNegativeQueryDelay() {
         DatafeedConfig.Builder conf = new DatafeedConfig.Builder("datafeed1", "job1");
         IllegalArgumentException e = ESTestCase.expectThrows(IllegalArgumentException.class,
                 () -> conf.setQueryDelay(TimeValue.timeValueMillis(-10)));
         assertEquals("query_delay cannot be less than 0. Value = -10", e.getMessage());
     }
 
-    public void testCheckValid_GivenZeroFrequency() throws IOException {
+    public void testCheckValid_GivenZeroFrequency() {
         DatafeedConfig.Builder conf = new DatafeedConfig.Builder("datafeed1", "job1");
         IllegalArgumentException e = ESTestCase.expectThrows(IllegalArgumentException.class, () -> conf.setFrequency(TimeValue.ZERO));
         assertEquals("frequency cannot be less or equal than 0. Value = 0s", e.getMessage());
     }
 
-    public void testCheckValid_GivenNegativeFrequency() throws IOException {
+    public void testCheckValid_GivenNegativeFrequency() {
         DatafeedConfig.Builder conf = new DatafeedConfig.Builder("datafeed1", "job1");
         IllegalArgumentException e = ESTestCase.expectThrows(IllegalArgumentException.class,
                 () -> conf.setFrequency(TimeValue.timeValueMinutes(-1)));
         assertEquals("frequency cannot be less or equal than 0. Value = -1", e.getMessage());
     }
 
-    public void testCheckValid_GivenNegativeScrollSize() throws IOException {
+    public void testCheckValid_GivenNegativeScrollSize() {
         DatafeedConfig.Builder conf = new DatafeedConfig.Builder("datafeed1", "job1");
         ElasticsearchException e = ESTestCase.expectThrows(ElasticsearchException.class, () -> conf.setScrollSize(-1000));
         assertEquals(Messages.getMessage(Messages.DATAFEED_CONFIG_INVALID_OPTION_VALUE, "scroll_size", -1000L), e.getMessage());
@@ -414,7 +413,7 @@ public void testDefaultFrequency_GivenNegative() {
 
     public void testDefaultFrequency_GivenNoAggregations() {
         DatafeedConfig.Builder datafeedBuilder = new DatafeedConfig.Builder("feed", "job");
-        datafeedBuilder.setIndices(Arrays.asList("my_index"));
+        datafeedBuilder.setIndices(Collections.singletonList("my_index"));
         DatafeedConfig datafeed = datafeedBuilder.build();
 
         assertEquals(TimeValue.timeValueMinutes(1), datafeed.defaultFrequency(TimeValue.timeValueSeconds(1)));
TransportStartDatafeedAction.java
@@ -43,10 +43,12 @@
 import org.elasticsearch.persistent.PersistentTasksExecutor;
 import org.elasticsearch.persistent.PersistentTasksService;
 import org.elasticsearch.xpack.ml.MachineLearning;
+import org.elasticsearch.xpack.ml.datafeed.MlRemoteLicenseChecker;
 import org.elasticsearch.xpack.ml.datafeed.DatafeedManager;
 import org.elasticsearch.xpack.ml.datafeed.DatafeedNodeSelector;
 import org.elasticsearch.xpack.ml.datafeed.extractor.DataExtractorFactory;
 
+import java.util.List;
 import java.util.Map;
 import java.util.function.Predicate;
 
@@ -111,40 +113,65 @@ protected void masterOperation(StartDatafeedAction.Request request, ClusterState state,
                                    ActionListener<StartDatafeedAction.Response> listener) {
         StartDatafeedAction.DatafeedParams params = request.getParams();
         if (licenseState.isMachineLearningAllowed()) {
-            ActionListener<PersistentTasksCustomMetaData.PersistentTask<StartDatafeedAction.DatafeedParams>> finalListener =
+
+            ActionListener<PersistentTasksCustomMetaData.PersistentTask<StartDatafeedAction.DatafeedParams>> waitForTaskListener =
                     new ActionListener<PersistentTasksCustomMetaData.PersistentTask<StartDatafeedAction.DatafeedParams>>() {
-                @Override
-                public void onResponse(PersistentTasksCustomMetaData.PersistentTask<StartDatafeedAction.DatafeedParams> persistentTask) {
-                    waitForDatafeedStarted(persistentTask.getId(), params, listener);
-                }
+                        @Override
+                        public void onResponse(PersistentTasksCustomMetaData.PersistentTask<StartDatafeedAction.DatafeedParams>
+                                                       persistentTask) {
+                            waitForDatafeedStarted(persistentTask.getId(), params, listener);
+                        }
 
-                @Override
-                public void onFailure(Exception e) {
-                    if (e instanceof ResourceAlreadyExistsException) {
-                        logger.debug("datafeed already started", e);
-                        e = new ElasticsearchStatusException("cannot start datafeed [" + params.getDatafeedId() +
-                                "] because it has already been started", RestStatus.CONFLICT);
-                    }
-                    listener.onFailure(e);
-                }
-            };
+                        @Override
+                        public void onFailure(Exception e) {
+                            if (e instanceof ResourceAlreadyExistsException) {
+                                logger.debug("datafeed already started", e);
+                                e = new ElasticsearchStatusException("cannot start datafeed [" + params.getDatafeedId() +
+                                        "] because it has already been started", RestStatus.CONFLICT);
+                            }
+                            listener.onFailure(e);
+                        }
+                    };
 
             // Verify data extractor factory can be created, then start persistent task
             MlMetadata mlMetadata = MlMetadata.getMlMetadata(state);
             PersistentTasksCustomMetaData tasks = state.getMetaData().custom(PersistentTasksCustomMetaData.TYPE);
             validate(params.getDatafeedId(), mlMetadata, tasks);
             DatafeedConfig datafeed = mlMetadata.getDatafeed(params.getDatafeedId());
             Job job = mlMetadata.getJobs().get(datafeed.getJobId());
-            DataExtractorFactory.create(client, datafeed, job, ActionListener.wrap(
-                    dataExtractorFactory ->
-                            persistentTasksService.sendStartRequest(MLMetadataField.datafeedTaskId(params.getDatafeedId()),
-                                    StartDatafeedAction.TASK_NAME, params, finalListener)
-                    , listener::onFailure));
+
+            if (MlRemoteLicenseChecker.containsRemoteIndex(datafeed.getIndices())) {
+                MlRemoteLicenseChecker remoteLicenseChecker = new MlRemoteLicenseChecker(client);
+                remoteLicenseChecker.checkRemoteClusterLicenses(MlRemoteLicenseChecker.remoteClusterNames(datafeed.getIndices()),
+                        ActionListener.wrap(
+                                response -> {
+                                    if (response.isViolated()) {
+                                        listener.onFailure(createUnlicensedError(datafeed.getId(), response));
+                                    } else {
+                                        createDataExtractor(job, datafeed, params, waitForTaskListener);
+                                    }
+                                },
+                                e -> listener.onFailure(createUnknownLicenseError(datafeed.getId(),
+                                        MlRemoteLicenseChecker.remoteIndices(datafeed.getIndices()), e))
+                        ));
+            } else {
+                createDataExtractor(job, datafeed, params, waitForTaskListener);
+            }
         } else {
             listener.onFailure(LicenseUtils.newComplianceException(XPackField.MACHINE_LEARNING));
         }
     }
 
+    private void createDataExtractor(Job job, DatafeedConfig datafeed, StartDatafeedAction.DatafeedParams params,
+                                     ActionListener<PersistentTasksCustomMetaData.PersistentTask<StartDatafeedAction.DatafeedParams>>
+                                             listener) {
+        DataExtractorFactory.create(client, datafeed, job, ActionListener.wrap(
+                dataExtractorFactory ->
+                        persistentTasksService.sendStartRequest(MLMetadataField.datafeedTaskId(params.getDatafeedId()),
+                                StartDatafeedAction.TASK_NAME, params, listener)
+                , listener::onFailure));
+    }
+
     @Override
     protected ClusterBlockException checkBlock(StartDatafeedAction.Request request, ClusterState state) {
         // We only delegate here to PersistentTasksService, but if there is a metadata writeblock,
@@ -158,28 +185,29 @@ private void waitForDatafeedStarted(String taskId, StartDatafeedAction.DatafeedParams params,
         DatafeedPredicate predicate = new DatafeedPredicate();
         persistentTasksService.waitForPersistentTaskCondition(taskId, predicate, params.getTimeout(),
                 new PersistentTasksService.WaitForPersistentTaskListener<StartDatafeedAction.DatafeedParams>() {
-            @Override
-            public void onResponse(PersistentTasksCustomMetaData.PersistentTask<StartDatafeedAction.DatafeedParams> persistentTask) {
-                if (predicate.exception != null) {
-                    // We want to return to the caller without leaving an unassigned persistent task, to match
-                    // what would have happened if the error had been detected in the "fast fail" validation
-                    cancelDatafeedStart(persistentTask, predicate.exception, listener);
-                } else {
-                    listener.onResponse(new StartDatafeedAction.Response(true));
-                }
-            }
+                    @Override
+                    public void onResponse(PersistentTasksCustomMetaData.PersistentTask<StartDatafeedAction.DatafeedParams>
+                                                   persistentTask) {
+                        if (predicate.exception != null) {
+                            // We want to return to the caller without leaving an unassigned persistent task, to match
+                            // what would have happened if the error had been detected in the "fast fail" validation
+                            cancelDatafeedStart(persistentTask, predicate.exception, listener);
+                        } else {
+                            listener.onResponse(new StartDatafeedAction.Response(true));
+                        }
+                    }
 
-            @Override
-            public void onFailure(Exception e) {
-                listener.onFailure(e);
-            }
+                    @Override
+                    public void onFailure(Exception e) {
+                        listener.onFailure(e);
+                    }
 
-            @Override
-            public void onTimeout(TimeValue timeout) {
-                listener.onFailure(new ElasticsearchException("Starting datafeed ["
-                        + params.getDatafeedId() + "] timed out after [" + timeout + "]"));
-            }
-        });
+                    @Override
+                    public void onTimeout(TimeValue timeout) {
+                        listener.onFailure(new ElasticsearchException("Starting datafeed ["
+                                + params.getDatafeedId() + "] timed out after [" + timeout + "]"));
+                    }
+                });
     }
 
     private void cancelDatafeedStart(PersistentTasksCustomMetaData.PersistentTask<StartDatafeedAction.DatafeedParams> persistentTask,
@@ -203,6 +231,25 @@ public void onFailure(Exception e) {
         );
     }
 
+    private ElasticsearchStatusException createUnlicensedError(String datafeedId,
+                                                               MlRemoteLicenseChecker.LicenseViolation licenseViolation) {
+        String message = "Cannot start datafeed [" + datafeedId + "] as it is configured to use "
+                + "indices on a remote cluster [" + licenseViolation.get().getClusterName()
+                + "] that is not licensed for Machine Learning. "
+                + MlRemoteLicenseChecker.buildErrorMessage(licenseViolation.get());
+
+        return new ElasticsearchStatusException(message, RestStatus.BAD_REQUEST);
+    }
+
+    private ElasticsearchStatusException createUnknownLicenseError(String datafeedId, List<String> remoteIndices,
+                                                                   Exception cause) {
+        String message = "Cannot start datafeed [" + datafeedId + "] as it is configured to use"
+                + " indices on a remote cluster " + remoteIndices
+                + " but the license type could not be verified";
+
+        return new ElasticsearchStatusException(message, RestStatus.BAD_REQUEST, new Exception(cause.getMessage()));
+    }
+
     public static class StartDatafeedPersistentTasksExecutor extends PersistentTasksExecutor<StartDatafeedAction.DatafeedParams> {
         private final DatafeedManager datafeedManager;
         private final IndexNameExpressionResolver resolver;
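
A side note on createUnknownLicenseError above: the original cause is not attached directly; only its message is copied into a fresh Exception (new Exception(cause.getMessage())), which discards the remote exception's type and stack trace, presumably so the client sees a plain message rather than an arbitrary remote exception class. A minimal, self-contained demonstration of the difference (plain Java, illustrative names only):

public class CauseFlatteningDemo {
    public static void main(String[] args) {
        Exception remoteFailure = new IllegalStateException("remote cluster unreachable");

        // Attaching the cause directly preserves its type and stack trace.
        RuntimeException direct = new RuntimeException("license check failed", remoteFailure);

        // Flattening, as createUnknownLicenseError does, keeps only the message.
        RuntimeException flattened =
                new RuntimeException("license check failed", new Exception(remoteFailure.getMessage()));

        System.out.println(direct.getCause().getClass());       // class java.lang.IllegalStateException
        System.out.println(flattened.getCause().getClass());    // class java.lang.Exception
        System.out.println(flattened.getCause().getMessage());  // remote cluster unreachable
    }
}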
DatafeedNodeSelector.java
@@ -91,7 +91,7 @@ private AssignmentFailure verifyIndicesActive(DatafeedConfig datafeed) {
         List<String> indices = datafeed.getIndices();
         for (String index : indices) {
 
-            if (isRemoteIndex(index)) {
+            if (MlRemoteLicenseChecker.isRemoteIndex(index)) {
                 // We cannot verify remote indices
                 continue;
             }
@@ -122,10 +122,6 @@ private AssignmentFailure verifyIndicesActive(DatafeedConfig datafeed) {
         return null;
     }
 
-    private boolean isRemoteIndex(String index) {
-        return index.indexOf(':') != -1;
-    }
-
     private static class AssignmentFailure {
         private final String reason;
         private final boolean isCriticalForTaskCreation;
Diffs for the remaining three changed files are not shown.