Skip to content

Commit

Permalink
Add feature enable setting for controller index (opensearch-project#2652
Browse files Browse the repository at this point in the history
)

Signed-off-by: b4sjoo <[email protected]>
  • Loading branch information
b4sjoo authored Sep 17, 2024
1 parent 0d26931 commit a72181a
Show file tree
Hide file tree
Showing 22 changed files with 269 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import static org.opensearch.ml.common.FunctionName.REMOTE;
import static org.opensearch.ml.common.FunctionName.TEXT_EMBEDDING;
import static org.opensearch.ml.common.utils.StringUtils.getErrorMessage;
import static org.opensearch.ml.utils.MLExceptionUtils.CONTROLLER_DISABLED_ERR_MSG;

import java.util.ArrayList;
import java.util.Arrays;
Expand Down Expand Up @@ -51,6 +52,7 @@
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.ml.model.MLModelCacheHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;
Expand All @@ -68,6 +70,7 @@ public class CreateControllerTransportAction extends HandledTransportAction<Acti
ClusterService clusterService;
MLModelCacheHelper mlModelCacheHelper;
ModelAccessControlHelper modelAccessControlHelper;
private MLFeatureEnabledSetting mlFeatureEnabledSetting;

@Inject
public CreateControllerTransportAction(
Expand All @@ -78,7 +81,8 @@ public CreateControllerTransportAction(
ClusterService clusterService,
ModelAccessControlHelper modelAccessControlHelper,
MLModelCacheHelper mlModelCacheHelper,
MLModelManager mlModelManager
MLModelManager mlModelManager,
MLFeatureEnabledSetting mlFeatureEnabledSetting
) {
super(MLCreateControllerAction.NAME, transportService, actionFilters, MLCreateControllerRequest::new);
this.mlIndicesHandler = mlIndicesHandler;
Expand All @@ -87,6 +91,7 @@ public CreateControllerTransportAction(
this.clusterService = clusterService;
this.mlModelCacheHelper = mlModelCacheHelper;
this.modelAccessControlHelper = modelAccessControlHelper;
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
}

@Override
Expand All @@ -98,6 +103,9 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLCrea
String[] excludes = new String[] { MLModel.MODEL_CONTENT_FIELD, MLModel.OLD_MODEL_CONTENT_FIELD };

try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
if (!mlFeatureEnabledSetting.isControllerEnabled()) {
throw new IllegalStateException(CONTROLLER_DISABLED_ERR_MSG);
}
ActionListener<MLCreateControllerResponse> wrappedListener = ActionListener.runBefore(actionListener, context::restore);
mlModelManager.getModel(modelId, null, excludes, ActionListener.wrap(mlModel -> {
FunctionName functionName = mlModel.getAlgorithm();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import static org.opensearch.action.support.WriteRequest.RefreshPolicy.IMMEDIATE;
import static org.opensearch.ml.common.CommonValue.ML_CONTROLLER_INDEX;
import static org.opensearch.ml.common.utils.StringUtils.getErrorMessage;
import static org.opensearch.ml.utils.MLExceptionUtils.CONTROLLER_DISABLED_ERR_MSG;

import java.util.ArrayList;
import java.util.Arrays;
Expand Down Expand Up @@ -41,6 +42,7 @@
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.ml.model.MLModelCacheHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;
Expand All @@ -58,6 +60,7 @@ public class DeleteControllerTransportAction extends HandledTransportAction<Acti
MLModelManager mlModelManager;
MLModelCacheHelper mlModelCacheHelper;
ModelAccessControlHelper modelAccessControlHelper;
private MLFeatureEnabledSetting mlFeatureEnabledSetting;

@Inject
public DeleteControllerTransportAction(
Expand All @@ -68,7 +71,8 @@ public DeleteControllerTransportAction(
ClusterService clusterService,
MLModelManager mlModelManager,
MLModelCacheHelper mlModelCacheHelper,
ModelAccessControlHelper modelAccessControlHelper
ModelAccessControlHelper modelAccessControlHelper,
MLFeatureEnabledSetting mlFeatureEnabledSetting
) {
super(MLControllerDeleteAction.NAME, transportService, actionFilters, MLControllerDeleteRequest::new);
this.client = client;
Expand All @@ -77,6 +81,7 @@ public DeleteControllerTransportAction(
this.mlModelManager = mlModelManager;
this.mlModelCacheHelper = mlModelCacheHelper;
this.modelAccessControlHelper = modelAccessControlHelper;
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
}

@Override
Expand All @@ -86,6 +91,9 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Delete
User user = RestActionUtils.getUserContext(client);
String[] excludes = new String[] { MLModel.MODEL_CONTENT_FIELD, MLModel.OLD_MODEL_CONTENT_FIELD };
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
if (!mlFeatureEnabledSetting.isControllerEnabled()) {
throw new IllegalStateException(CONTROLLER_DISABLED_ERR_MSG);
}
ActionListener<DeleteResponse> wrappedListener = ActionListener.runBefore(actionListener, context::restore);
mlModelManager.getModel(modelId, null, excludes, ActionListener.wrap(mlModel -> {
Boolean isHidden = mlModel.getIsHidden();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.CommonValue.ML_CONTROLLER_INDEX;
import static org.opensearch.ml.common.utils.StringUtils.getErrorMessage;
import static org.opensearch.ml.utils.MLExceptionUtils.CONTROLLER_DISABLED_ERR_MSG;
import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry;
import static org.opensearch.ml.utils.RestActionUtils.getFetchSourceContext;

Expand All @@ -33,6 +34,7 @@
import org.opensearch.ml.common.transport.controller.MLControllerGetResponse;
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.search.fetch.subphase.FetchSourceContext;
import org.opensearch.tasks.Task;
Expand All @@ -50,6 +52,7 @@ public class GetControllerTransportAction extends HandledTransportAction<ActionR
ClusterService clusterService;
MLModelManager mlModelManager;
ModelAccessControlHelper modelAccessControlHelper;
private MLFeatureEnabledSetting mlFeatureEnabledSetting;

@Inject
public GetControllerTransportAction(
Expand All @@ -59,14 +62,16 @@ public GetControllerTransportAction(
NamedXContentRegistry xContentRegistry,
ClusterService clusterService,
MLModelManager mlModelManager,
ModelAccessControlHelper modelAccessControlHelper
ModelAccessControlHelper modelAccessControlHelper,
MLFeatureEnabledSetting mlFeatureEnabledSetting
) {
super(MLControllerGetAction.NAME, transportService, actionFilters, MLControllerGetRequest::new);
this.client = client;
this.xContentRegistry = xContentRegistry;
this.clusterService = clusterService;
this.mlModelManager = mlModelManager;
this.modelAccessControlHelper = modelAccessControlHelper;
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
}

@Override
Expand All @@ -79,6 +84,9 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLCont
String[] excludes = new String[] { MLModel.MODEL_CONTENT_FIELD, MLModel.OLD_MODEL_CONTENT_FIELD };

try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
if (!mlFeatureEnabledSetting.isControllerEnabled()) {
throw new IllegalStateException(CONTROLLER_DISABLED_ERR_MSG);
}
ActionListener<MLControllerGetResponse> wrappedListener = ActionListener.runBefore(actionListener, context::restore);
client.get(getRequest, ActionListener.wrap(r -> {
if (r != null && r.isExists()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import static org.opensearch.ml.common.FunctionName.REMOTE;
import static org.opensearch.ml.common.FunctionName.TEXT_EMBEDDING;
import static org.opensearch.ml.common.utils.StringUtils.getErrorMessage;
import static org.opensearch.ml.utils.MLExceptionUtils.CONTROLLER_DISABLED_ERR_MSG;

import java.util.ArrayList;
import java.util.Arrays;
Expand Down Expand Up @@ -46,6 +47,7 @@
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.ml.model.MLModelCacheHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;
Expand All @@ -62,6 +64,7 @@ public class UpdateControllerTransportAction extends HandledTransportAction<Acti
MLModelCacheHelper mlModelCacheHelper;
ClusterService clusterService;
ModelAccessControlHelper modelAccessControlHelper;
private MLFeatureEnabledSetting mlFeatureEnabledSetting;

@Inject
public UpdateControllerTransportAction(
Expand All @@ -71,14 +74,16 @@ public UpdateControllerTransportAction(
ClusterService clusterService,
ModelAccessControlHelper modelAccessControlHelper,
MLModelCacheHelper mlModelCacheHelper,
MLModelManager mlModelManager
MLModelManager mlModelManager,
MLFeatureEnabledSetting mlFeatureEnabledSetting
) {
super(MLUpdateControllerAction.NAME, transportService, actionFilters, MLUpdateControllerRequest::new);
this.client = client;
this.mlModelManager = mlModelManager;
this.clusterService = clusterService;
this.mlModelCacheHelper = mlModelCacheHelper;
this.modelAccessControlHelper = modelAccessControlHelper;
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
}

@Override
Expand All @@ -90,6 +95,9 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Update
String[] excludes = new String[] { MLModel.MODEL_CONTENT_FIELD, MLModel.OLD_MODEL_CONTENT_FIELD };

try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
if (!mlFeatureEnabledSetting.isControllerEnabled()) {
throw new IllegalStateException(CONTROLLER_DISABLED_ERR_MSG);
}
ActionListener<UpdateResponse> wrappedListener = ActionListener.runBefore(actionListener, context::restore);
mlModelManager.getModel(modelId, null, excludes, ActionListener.wrap(mlModel -> {
FunctionName functionName = mlModel.getAlgorithm();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_REGISTER_MODEL_TASKS_PER_NODE;
import static org.opensearch.ml.stats.ActionName.REGISTER;
import static org.opensearch.ml.stats.MLActionLevelStat.ML_ACTION_REQUEST_COUNT;
import static org.opensearch.ml.utils.MLExceptionUtils.CONTROLLER_DISABLED_ERR_MSG;
import static org.opensearch.ml.utils.MLExceptionUtils.logException;
import static org.opensearch.ml.utils.MLNodeUtils.checkOpenCircuitBreaker;
import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry;
Expand Down Expand Up @@ -1254,6 +1255,9 @@ public synchronized void updateModelCache(String modelId, ActionListener<String>
*/
public synchronized void deployControllerWithDeployedModel(String modelId, ActionListener<String> listener) {
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
if (!mlFeatureEnabledSetting.isControllerEnabled()) {
throw new IllegalStateException(CONTROLLER_DISABLED_ERR_MSG);
}
if (!modelCacheHelper.isModelDeployed(modelId)) {
throw new OpenSearchStatusException(
"The model of this model controller has not deployed yet, please deploy the model first.",
Expand Down Expand Up @@ -1423,6 +1427,9 @@ private synchronized void deployControllerWithDeployingModel(
* @param mlModel ml model
*/
public void deployControllerWithDeployingModel(MLModel mlModel, Integer eligibleNodeCount) {
if (!mlFeatureEnabledSetting.isControllerEnabled()) {
throw new IllegalStateException(CONTROLLER_DISABLED_ERR_MSG);
}
if (mlModel.getModelState() != MLModelState.DEPLOYING) {
throw new OpenSearchStatusException(
"This method should only be called when model is in DEPLOYING state, but the model is in state: " + mlModel.getModelState(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -755,10 +755,10 @@ public List<RestHandler> getRestHandlers(
RestMemorySearchInteractionsAction restSearchInteractionsAction = new RestMemorySearchInteractionsAction();
RestMemoryGetConversationAction restGetConversationAction = new RestMemoryGetConversationAction();
RestMemoryGetInteractionAction restGetInteractionAction = new RestMemoryGetInteractionAction();
RestMLCreateControllerAction restMLCreateControllerAction = new RestMLCreateControllerAction();
RestMLGetControllerAction restMLGetControllerAction = new RestMLGetControllerAction();
RestMLUpdateControllerAction restMLUpdateControllerAction = new RestMLUpdateControllerAction();
RestMLDeleteControllerAction restMLDeleteControllerAction = new RestMLDeleteControllerAction();
RestMLCreateControllerAction restMLCreateControllerAction = new RestMLCreateControllerAction(mlFeatureEnabledSetting);
RestMLGetControllerAction restMLGetControllerAction = new RestMLGetControllerAction(mlFeatureEnabledSetting);
RestMLUpdateControllerAction restMLUpdateControllerAction = new RestMLUpdateControllerAction(mlFeatureEnabledSetting);
RestMLDeleteControllerAction restMLDeleteControllerAction = new RestMLDeleteControllerAction(mlFeatureEnabledSetting);
RestMLGetAgentAction restMLGetAgentAction = new RestMLGetAgentAction(mlFeatureEnabledSetting);
RestMLDeleteAgentAction restMLDeleteAgentAction = new RestMLDeleteAgentAction(mlFeatureEnabledSetting);
RestMemoryUpdateConversationAction restMemoryUpdateConversationAction = new RestMemoryUpdateConversationAction();
Expand Down Expand Up @@ -969,7 +969,8 @@ public List<Setting<?>> getSettings() {
MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_COMPLETED_REGEX,
MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_CANCELLED_REGEX,
MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_CANCELLING_REGEX,
MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_EXPIRED_REGEX
MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_EXPIRED_REGEX,
MLCommonsSettings.ML_COMMONS_CONTROLLER_ENABLED
);
return settings;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI;
import static org.opensearch.ml.utils.MLExceptionUtils.CONTROLLER_DISABLED_ERR_MSG;
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID;
import static org.opensearch.ml.utils.RestActionUtils.getParameterId;

Expand All @@ -20,6 +21,7 @@
import org.opensearch.ml.common.controller.MLController;
import org.opensearch.ml.common.transport.controller.MLCreateControllerAction;
import org.opensearch.ml.common.transport.controller.MLCreateControllerRequest;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.action.RestToXContentListener;
Expand All @@ -29,11 +31,14 @@
public class RestMLCreateControllerAction extends BaseRestHandler {

public final static String ML_CREATE_CONTROLLER_ACTION = "ml_create_controller_action";
private final MLFeatureEnabledSetting mlFeatureEnabledSetting;

/**
* Constructor
*/
public RestMLCreateControllerAction() {}
public RestMLCreateControllerAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) {
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
}

@Override
public String getName() {
Expand Down Expand Up @@ -61,6 +66,10 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client
* @return MLCreateControllerRequest
*/
private MLCreateControllerRequest getRequest(RestRequest request) throws IOException {
if (!mlFeatureEnabledSetting.isControllerEnabled()) {
throw new IllegalStateException(CONTROLLER_DISABLED_ERR_MSG);
}

if (!request.hasContent()) {
throw new OpenSearchParseException("Create model controller request has empty body");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.rest;

import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI;
import static org.opensearch.ml.utils.MLExceptionUtils.CONTROLLER_DISABLED_ERR_MSG;
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_ID;

import java.io.IOException;
Expand All @@ -15,6 +16,7 @@
import org.opensearch.client.node.NodeClient;
import org.opensearch.ml.common.transport.controller.MLControllerDeleteAction;
import org.opensearch.ml.common.transport.controller.MLControllerDeleteRequest;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.action.RestToXContentListener;
Expand All @@ -25,9 +27,13 @@
* This class consists of the REST handler to delete ML Model.
*/
public class RestMLDeleteControllerAction extends BaseRestHandler {

private static final String ML_DELETE_CONTROLLER_ACTION = "ml_delete_controller_action";
private final MLFeatureEnabledSetting mlFeatureEnabledSetting;

public void RestMLDeleteControllerAction() {}
public RestMLDeleteControllerAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) {
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
}

@Override
public String getName() {
Expand All @@ -42,6 +48,9 @@ public List<Route> routes() {

@Override
protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException {
if (!mlFeatureEnabledSetting.isControllerEnabled()) {
throw new IllegalStateException(CONTROLLER_DISABLED_ERR_MSG);
}
String modelId = request.param(PARAMETER_MODEL_ID);

MLControllerDeleteRequest mlControllerDeleteRequest = new MLControllerDeleteRequest(modelId);
Expand Down
Loading

0 comments on commit a72181a

Please sign in to comment.