Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[8.15] Add LTR License Check on PUT for Enterprise Licensing (#111248) #111460

Merged
merged 1 commit into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.VersionedNamedWriteable;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.license.License;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xpack.core.ml.MlConfigVersion;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
Expand All @@ -22,6 +24,7 @@
import java.util.Arrays;

import static org.elasticsearch.action.ValidateActions.addValidationError;
import static org.elasticsearch.xpack.core.ml.MachineLearningField.ML_API_FEATURE;

public interface InferenceConfig extends NamedXContentObject, VersionedNamedWriteable {

Expand Down Expand Up @@ -114,4 +117,12 @@ default ElasticsearchStatusException incompatibleUpdateException(String updateNa
updateName
);
}

default License.OperationMode getMinLicenseSupported() {
return ML_API_FEATURE.getMinimumOperationMode();
}

default License.OperationMode getMinLicenseSupportedForAction(RestRequest.Method method) {
return getMinLicenseSupported();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.index.query.Rewriteable;
import org.elasticsearch.license.License;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.xcontent.ObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentBuilder;
Expand Down Expand Up @@ -226,6 +228,14 @@ public TransportVersion getMinimalSupportedTransportVersion() {
return MIN_SUPPORTED_TRANSPORT_VERSION;
}

@Override
public License.OperationMode getMinLicenseSupportedForAction(RestRequest.Method method) {
if (method == RestRequest.Method.PUT) {
return License.OperationMode.ENTERPRISE;
}
return super.getMinLicenseSupportedForAction(method);
}

@Override
public LearningToRankConfig rewrite(QueryRewriteContext ctx) throws IOException {
if (this.featureExtractorBuilders.isEmpty()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.license.License;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.NamedXContentRegistry;
Expand All @@ -36,6 +38,7 @@
import java.util.stream.Stream;

import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
import static org.hamcrest.Matchers.is;

public class LearningToRankConfigTests extends InferenceConfigItemTestCase<LearningToRankConfig> {
private boolean lenient;
Expand Down Expand Up @@ -140,6 +143,16 @@ public void testDuplicateFeatureNames() {
expectThrows(IllegalArgumentException.class, () -> builder.build());
}

public void testLicenseSupport_ForPutAction_RequiresEnterprise() {
var config = randomLearningToRankConfig();
assertThat(config.getMinLicenseSupportedForAction(RestRequest.Method.PUT), is(License.OperationMode.ENTERPRISE));
}

public void testLicenseSupport_ForGetAction_RequiresPlatinum() {
var config = randomLearningToRankConfig();
assertThat(config.getMinLicenseSupportedForAction(RestRequest.Method.GET), is(License.OperationMode.PLATINUM));
}

@Override
protected NamedXContentRegistry xContentRegistry() {
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
package org.elasticsearch.xpack.ml.action;

import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ElasticsearchSecurityException;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.TransportVersion;
Expand Down Expand Up @@ -37,6 +38,7 @@
import org.elasticsearch.license.License;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.tasks.Task;
Expand Down Expand Up @@ -143,61 +145,7 @@ protected void masterOperation(
// NOTE: hasModelDefinition is false if we don't parse it. But, if the fully parsed model was already provided, continue
boolean hasModelDefinition = config.getModelDefinition() != null;
if (hasModelDefinition) {
try {
config.getModelDefinition().getTrainedModel().validate();
} catch (ElasticsearchException ex) {
finalResponseListener.onFailure(
ExceptionsHelper.badRequestException("Definition for [{}] has validation failures.", ex, config.getModelId())
);
return;
}

TrainedModelType trainedModelType = TrainedModelType.typeFromTrainedModel(config.getModelDefinition().getTrainedModel());
if (trainedModelType == null) {
finalResponseListener.onFailure(
ExceptionsHelper.badRequestException(
"Unknown trained model definition class [{}]",
config.getModelDefinition().getTrainedModel().getName()
)
);
return;
}

if (config.getModelType() == null) {
// Set the model type from the definition
config = new TrainedModelConfig.Builder(config).setModelType(trainedModelType).build();
} else if (trainedModelType != config.getModelType()) {
finalResponseListener.onFailure(
ExceptionsHelper.badRequestException(
"{} [{}] does not match the model definition type [{}]",
TrainedModelConfig.MODEL_TYPE.getPreferredName(),
config.getModelType(),
trainedModelType
)
);
return;
}

if (config.getInferenceConfig().isTargetTypeSupported(config.getModelDefinition().getTrainedModel().targetType()) == false) {
finalResponseListener.onFailure(
ExceptionsHelper.badRequestException(
"Model [{}] inference config type [{}] does not support definition target type [{}]",
config.getModelId(),
config.getInferenceConfig().getName(),
config.getModelDefinition().getTrainedModel().targetType()
)
);
return;
}

TransportVersion minCompatibilityVersion = config.getModelDefinition().getTrainedModel().getMinimalCompatibilityVersion();
if (state.getMinTransportVersion().before(minCompatibilityVersion)) {
finalResponseListener.onFailure(
ExceptionsHelper.badRequestException(
"Cannot create model [{}] while cluster upgrade is in progress.",
config.getModelId()
)
);
if (validateModelDefinition(config, state, licenseState, finalResponseListener) == false) {
return;
}
}
Expand Down Expand Up @@ -507,6 +455,85 @@ private void checkTagsAgainstModelIds(List<String> tags, ActionListener<Void> li
);
}

public static boolean validateModelDefinition(
TrainedModelConfig config,
ClusterState state,
XPackLicenseState licenseState,
ActionListener<Response> finalResponseListener
) {
try {
config.getModelDefinition().getTrainedModel().validate();
} catch (ElasticsearchException ex) {
finalResponseListener.onFailure(
ExceptionsHelper.badRequestException("Definition for [{}] has validation failures.", ex, config.getModelId())
);
return false;
}

TrainedModelType trainedModelType = TrainedModelType.typeFromTrainedModel(config.getModelDefinition().getTrainedModel());
if (trainedModelType == null) {
finalResponseListener.onFailure(
ExceptionsHelper.badRequestException(
"Unknown trained model definition class [{}]",
config.getModelDefinition().getTrainedModel().getName()
)
);
return false;
}

var configModelType = config.getModelType();
if (configModelType == null) {
// Set the model type from the definition
config = new TrainedModelConfig.Builder(config).setModelType(trainedModelType).build();
} else if (trainedModelType != configModelType) {
finalResponseListener.onFailure(
ExceptionsHelper.badRequestException(
"{} [{}] does not match the model definition type [{}]",
TrainedModelConfig.MODEL_TYPE.getPreferredName(),
configModelType,
trainedModelType
)
);
return false;
}

var inferenceConfig = config.getInferenceConfig();
if (inferenceConfig.isTargetTypeSupported(config.getModelDefinition().getTrainedModel().targetType()) == false) {
finalResponseListener.onFailure(
ExceptionsHelper.badRequestException(
"Model [{}] inference config type [{}] does not support definition target type [{}]",
config.getModelId(),
config.getInferenceConfig().getName(),
config.getModelDefinition().getTrainedModel().targetType()
)
);
return false;
}

var minLicenseSupported = inferenceConfig.getMinLicenseSupportedForAction(RestRequest.Method.PUT);
if (licenseState.isAllowedByLicense(minLicenseSupported) == false) {
finalResponseListener.onFailure(
new ElasticsearchSecurityException(
"Model of type [{}] requires [{}] license level",
RestStatus.FORBIDDEN,
config.getInferenceConfig().getName(),
minLicenseSupported
)
);
return false;
}

TransportVersion minCompatibilityVersion = config.getModelDefinition().getTrainedModel().getMinimalCompatibilityVersion();
if (state.getMinTransportVersion().before(minCompatibilityVersion)) {
finalResponseListener.onFailure(
ExceptionsHelper.badRequestException("Cannot create model [{}] while cluster upgrade is in progress.", config.getModelId())
);
return false;
}

return true;
}

@Override
protected ClusterBlockException checkBlock(Request request, ClusterState state) {
return state.blocks().globalBlockedException(ClusterBlockLevel.METADATA_WRITE);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,21 @@
package org.elasticsearch.xpack.ml.action;

import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ElasticsearchSecurityException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.admin.cluster.node.tasks.list.ListTasksResponse;
import org.elasticsearch.action.admin.cluster.node.tasks.list.TransportListTasksAction;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.license.License;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.license.internal.XPackLicenseStatus;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.TestThreadPool;
Expand All @@ -35,11 +40,13 @@
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfigTests;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInputTests;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelType;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfigTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearningToRankConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ModelPackageConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ModelPackageConfigTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfigTests;
Expand All @@ -50,6 +57,7 @@
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfigTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfigTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextSimilarityConfigTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.langident.LangIdentNeuralNetwork;
import org.junit.After;
import org.junit.Before;
import org.mockito.ArgumentCaptor;
Expand All @@ -60,10 +68,12 @@
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;

import static org.elasticsearch.xpack.ml.utils.TaskRetrieverTests.getTaskInfoListOfOne;
import static org.elasticsearch.xpack.ml.utils.TaskRetrieverTests.mockClientWithTasksResponse;
import static org.elasticsearch.xpack.ml.utils.TaskRetrieverTests.mockListTasksClient;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.same;
Expand All @@ -73,6 +83,7 @@
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

public class TransportPutTrainedModelActionTests extends ESTestCase {
private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
Expand Down Expand Up @@ -273,6 +284,56 @@ public void testVerifyMlNodesAndModelArchitectures_GivenArchitecturesMatch_ThenT
ensureNoWarnings();
}

public void testValidateModelDefinition_FailsWhenLicenseIsNotSupported() throws IOException {
ModelPackageConfig packageConfig = ModelPackageConfigTests.randomModulePackageConfig();

TrainedModelConfig.Builder trainedModelConfigBuilder = new TrainedModelConfig.Builder().setModelId(
"." + packageConfig.getPackagedModelId()
).setInput(TrainedModelInputTests.createRandomInput());

TransportPutTrainedModelAction.setTrainedModelConfigFieldsFromPackagedModel(
trainedModelConfigBuilder,
packageConfig,
xContentRegistry()
);

var mockTrainedModelDefinition = mock(TrainedModelDefinition.class);
when(mockTrainedModelDefinition.getTrainedModel()).thenReturn(mock(LangIdentNeuralNetwork.class));
var trainedModelConfig = trainedModelConfigBuilder.setLicenseLevel("basic").build();

var mockModelInferenceConfig = spy(new LearningToRankConfig(1, List.of(), Map.of()));
when(mockModelInferenceConfig.isTargetTypeSupported(any())).thenReturn(true);

var mockTrainedModelConfig = spy(trainedModelConfig);
when(mockTrainedModelConfig.getModelType()).thenReturn(TrainedModelType.LANG_IDENT);
when(mockTrainedModelConfig.getModelDefinition()).thenReturn(mockTrainedModelDefinition);
when(mockTrainedModelConfig.getInferenceConfig()).thenReturn(mockModelInferenceConfig);

ActionListener<PutTrainedModelAction.Response> responseListener = ActionListener.wrap(
response -> fail("Expected exception, but got response: " + response),
exception -> {
assertThat(exception, instanceOf(ElasticsearchSecurityException.class));
assertThat(exception.getMessage(), is("Model of type [learning_to_rank] requires [ENTERPRISE] license level"));
}
);

var mockClusterState = mock(ClusterState.class);

AtomicInteger currentTime = new AtomicInteger(100);
var mockXPackLicenseStatus = new XPackLicenseStatus(License.OperationMode.BASIC, true, "");
var mockLicenseState = new XPackLicenseState(currentTime::get, mockXPackLicenseStatus);

assertThat(
TransportPutTrainedModelAction.validateModelDefinition(
mockTrainedModelConfig,
mockClusterState,
mockLicenseState,
responseListener
),
is(false)
);
}

private static void prepareGetTrainedModelResponse(Client client, List<TrainedModelConfig> trainedModels) {
doAnswer(invocationOnMock -> {
@SuppressWarnings("unchecked")
Expand Down