diff --git a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorAction.java b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorAction.java new file mode 100644 index 0000000000..9fa10a39c6 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorAction.java @@ -0,0 +1,16 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.connector; + +import org.opensearch.action.ActionType; +import org.opensearch.action.update.UpdateResponse; + +public class MLUpdateConnectorAction extends ActionType { + public static final MLUpdateConnectorAction INSTANCE = new MLUpdateConnectorAction(); + public static final String NAME = "cluster:admin/opensearch/ml/connectors/update"; + + private MLUpdateConnectorAction() { super(NAME, UpdateResponse::new);} +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequest.java new file mode 100644 index 0000000000..ced3646d13 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequest.java @@ -0,0 +1,83 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.connector; + +import lombok.Builder; +import lombok.Getter; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentParser; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Map; + +import static org.opensearch.action.ValidateActions.addValidationError; + +@Getter +public class MLUpdateConnectorRequest extends ActionRequest { + String connectorId; + Map updateContent; + + @Builder + public MLUpdateConnectorRequest(String connectorId, Map updateContent) { + this.connectorId = connectorId; + this.updateContent = updateContent; + } + + public MLUpdateConnectorRequest(StreamInput in) throws IOException { + super(in); + this.connectorId = in.readString(); + this.updateContent = in.readMap(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(this.connectorId); + out.writeMap(this.getUpdateContent()); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException exception = null; + + if (this.connectorId == null) { + exception = addValidationError("ML connector id can't be null", exception); + } + + return exception; + } + + public static MLUpdateConnectorRequest parse(XContentParser parser, String connectorId) throws IOException { + Map dataAsMap = null; + dataAsMap = parser.map(); + + return MLUpdateConnectorRequest.builder().connectorId(connectorId).updateContent(dataAsMap).build(); + } + + public static MLUpdateConnectorRequest fromActionRequest(ActionRequest actionRequest) { + if (actionRequest instanceof MLUpdateConnectorRequest) { + return (MLUpdateConnectorRequest) actionRequest; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); + OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionRequest.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLUpdateConnectorRequest(input); + } + } catch (IOException e) { + throw new UncheckedIOException("failed to parse ActionRequest into MLUpdateConnectorRequest", e); + } + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequestTests.java b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequestTests.java new file mode 100644 index 0000000000..e017009983 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequestTests.java @@ -0,0 +1,128 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.connector; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.rest.RestRequest; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Map; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; +import static org.mockito.Mockito.when; + +public class MLUpdateConnectorRequestTests { + private String connectorId; + private Map updateContent; + private MLUpdateConnectorRequest mlUpdateConnectorRequest; + + @Mock + XContentParser parser; + + @Before + public void setUp() { + MockitoAnnotations.openMocks(this); + this.connectorId = "test-connector_id"; + this.updateContent = Map.of("description", "new description"); + mlUpdateConnectorRequest = MLUpdateConnectorRequest.builder() + .connectorId(connectorId) + .updateContent(updateContent) + .build(); + } + + @Test + public void writeTo_Success() throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + mlUpdateConnectorRequest.writeTo(bytesStreamOutput); + MLUpdateConnectorRequest parsedUpdateRequest = new MLUpdateConnectorRequest(bytesStreamOutput.bytes().streamInput()); + assertEquals(connectorId, parsedUpdateRequest.getConnectorId()); + assertEquals(updateContent, parsedUpdateRequest.getUpdateContent()); + } + + @Test + public void validate_Success() { + assertNull(mlUpdateConnectorRequest.validate()); + } + + @Test + public void validate_Exception_NullConnectorId() { + MLUpdateConnectorRequest updateConnectorRequest = MLUpdateConnectorRequest.builder().build(); + Exception exception = updateConnectorRequest.validate(); + + assertEquals("Validation Failed: 1: ML connector id can't be null;", exception.getMessage()); + } + + @Test + public void parse_success() throws IOException { + RestRequest.Method method = RestRequest.Method.POST; + final Map updatefields = Map.of("version", "new version", "description", "new description"); + when(parser.map()).thenReturn(updatefields); + + MLUpdateConnectorRequest updateConnectorRequest = MLUpdateConnectorRequest.parse(parser, connectorId); + assertEquals(updateConnectorRequest.getConnectorId(), connectorId); + assertEquals(updateConnectorRequest.getUpdateContent(), updatefields); + } + + @Test + public void fromActionRequest_Success() { + MLUpdateConnectorRequest mlUpdateConnectorRequest = MLUpdateConnectorRequest.builder() + .connectorId(connectorId) + .updateContent(updateContent) + .build(); + assertSame(MLUpdateConnectorRequest.fromActionRequest(mlUpdateConnectorRequest), mlUpdateConnectorRequest); + } + + @Test + public void fromActionRequest_Success_fromActionRequest() { + MLUpdateConnectorRequest mlUpdateConnectorRequest = MLUpdateConnectorRequest.builder() + .connectorId(connectorId) + .updateContent(updateContent) + .build(); + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + mlUpdateConnectorRequest.writeTo(out); + } + }; + MLUpdateConnectorRequest request = MLUpdateConnectorRequest.fromActionRequest(actionRequest); + assertNotSame(request, mlUpdateConnectorRequest); + assertEquals(mlUpdateConnectorRequest.getConnectorId(), request.getConnectorId()); + assertEquals(mlUpdateConnectorRequest.getUpdateContent(), request.getUpdateContent()); + } + + @Test(expected = UncheckedIOException.class) + public void fromActionRequest_IOException() { + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IOException(); + } + }; + MLUpdateConnectorRequest.fromActionRequest(actionRequest); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/UpdateConnectorTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/UpdateConnectorTransportAction.java new file mode 100644 index 0000000000..d8a1d88a01 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/UpdateConnectorTransportAction.java @@ -0,0 +1,146 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.connector; + +import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.action.update.UpdateRequest; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.client.Client; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.ml.common.MLModel; +import org.opensearch.ml.common.exception.MLValidationException; +import org.opensearch.ml.common.transport.connector.MLUpdateConnectorAction; +import org.opensearch.ml.common.transport.connector.MLUpdateConnectorRequest; +import org.opensearch.ml.helper.ConnectorAccessControlHelper; +import org.opensearch.ml.model.MLModelManager; +import org.opensearch.search.SearchHit; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +import lombok.AccessLevel; +import lombok.experimental.FieldDefaults; +import lombok.extern.log4j.Log4j2; + +@Log4j2 +@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) +public class UpdateConnectorTransportAction extends HandledTransportAction { + Client client; + + ConnectorAccessControlHelper connectorAccessControlHelper; + MLModelManager mlModelManager; + + @Inject + public UpdateConnectorTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + ConnectorAccessControlHelper connectorAccessControlHelper, + MLModelManager mlModelManager + ) { + super(MLUpdateConnectorAction.NAME, transportService, actionFilters, MLUpdateConnectorRequest::new); + this.client = client; + this.connectorAccessControlHelper = connectorAccessControlHelper; + this.mlModelManager = mlModelManager; + } + + @Override + protected void doExecute(Task task, ActionRequest request, ActionListener listener) { + MLUpdateConnectorRequest mlUpdateConnectorAction = MLUpdateConnectorRequest.fromActionRequest(request); + String connectorId = mlUpdateConnectorAction.getConnectorId(); + UpdateRequest updateRequest = new UpdateRequest(ML_CONNECTOR_INDEX, connectorId); + updateRequest.doc(mlUpdateConnectorAction.getUpdateContent()); + updateRequest.docAsUpsert(true); + + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + connectorAccessControlHelper.validateConnectorAccess(client, connectorId, ActionListener.wrap(hasPermission -> { + if (Boolean.TRUE.equals(hasPermission)) { + updateUndeployedConnector(connectorId, updateRequest, listener, context); + } else { + listener + .onFailure( + new IllegalArgumentException("You don't have permission to update the connector, connector id: " + connectorId) + ); + } + }, exception -> { + log.error("Permission denied: Unable to update the connector with ID {}. Details: {}", connectorId, exception); + listener.onFailure(exception); + })); + } catch (Exception e) { + log.error("Failed to update ML connector for connector id {}. Details {}:", connectorId, e); + listener.onFailure(e); + } + } + + private void updateUndeployedConnector( + String connectorId, + UpdateRequest updateRequest, + ActionListener listener, + ThreadContext.StoredContext context + ) { + SearchRequest searchRequest = new SearchRequest(ML_MODEL_INDEX); + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery(); + boolQueryBuilder.must(QueryBuilders.matchQuery(MLModel.CONNECTOR_ID_FIELD, connectorId)); + boolQueryBuilder.must(QueryBuilders.idsQuery().addIds(mlModelManager.getAllModelIds())); + sourceBuilder.query(boolQueryBuilder); + searchRequest.source(sourceBuilder); + + client.search(searchRequest, ActionListener.wrap(searchResponse -> { + SearchHit[] searchHits = searchResponse.getHits().getHits(); + if (searchHits.length == 0) { + client.update(updateRequest, getUpdateResponseListener(connectorId, listener, context)); + } else { + log.error(searchHits.length + " models are still using this connector, please undeploy the models first!"); + listener + .onFailure( + new MLValidationException( + searchHits.length + " models are still using this connector, please undeploy the models first!" + ) + ); + } + }, e -> { + if (e instanceof IndexNotFoundException) { + client.update(updateRequest, getUpdateResponseListener(connectorId, listener, context)); + return; + } + log.error("Failed to update ML connector: " + connectorId, e); + listener.onFailure(e); + + })); + } + + private ActionListener getUpdateResponseListener( + String connectorId, + ActionListener actionListener, + ThreadContext.StoredContext context + ) { + return ActionListener.runBefore(ActionListener.wrap(updateResponse -> { + if (updateResponse != null && updateResponse.getResult() != DocWriteResponse.Result.UPDATED) { + log.info("Failed to update the connector with ID: {}", connectorId); + actionListener.onResponse(updateResponse); + return; + } + log.info("Successfully updated the connector with ID: {}", connectorId); + actionListener.onResponse(updateResponse); + }, exception -> { + log.error("Failed to update ML connector with ID {}. Details: {}", connectorId, exception); + actionListener.onFailure(exception); + }), context::restore); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 35d36c20f0..70e87eb860 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -40,6 +40,7 @@ import org.opensearch.ml.action.connector.GetConnectorTransportAction; import org.opensearch.ml.action.connector.SearchConnectorTransportAction; import org.opensearch.ml.action.connector.TransportCreateConnectorAction; +import org.opensearch.ml.action.connector.UpdateConnectorTransportAction; import org.opensearch.ml.action.deploy.TransportDeployModelAction; import org.opensearch.ml.action.deploy.TransportDeployModelOnNodeAction; import org.opensearch.ml.action.execute.TransportExecuteTaskAction; @@ -91,6 +92,7 @@ import org.opensearch.ml.common.transport.connector.MLConnectorGetAction; import org.opensearch.ml.common.transport.connector.MLConnectorSearchAction; import org.opensearch.ml.common.transport.connector.MLCreateConnectorAction; +import org.opensearch.ml.common.transport.connector.MLUpdateConnectorAction; import org.opensearch.ml.common.transport.deploy.MLDeployModelAction; import org.opensearch.ml.common.transport.deploy.MLDeployModelOnNodeAction; import org.opensearch.ml.common.transport.execute.MLExecuteTaskAction; @@ -163,6 +165,7 @@ import org.opensearch.ml.rest.RestMLTrainAndPredictAction; import org.opensearch.ml.rest.RestMLTrainingAction; import org.opensearch.ml.rest.RestMLUndeployModelAction; +import org.opensearch.ml.rest.RestMLUpdateConnectorAction; import org.opensearch.ml.rest.RestMLUpdateModelGroupAction; import org.opensearch.ml.rest.RestMLUploadModelChunkAction; import org.opensearch.ml.rest.RestMemoryCreateConversationAction; @@ -289,12 +292,12 @@ public class MachineLearningPlugin extends Plugin implements ActionPlugin, Searc new ActionHandler<>(MLConnectorGetAction.INSTANCE, GetConnectorTransportAction.class), new ActionHandler<>(MLConnectorDeleteAction.INSTANCE, DeleteConnectorTransportAction.class), new ActionHandler<>(MLConnectorSearchAction.INSTANCE, SearchConnectorTransportAction.class), - new ActionHandler<>(CreateConversationAction.INSTANCE, CreateConversationTransportAction.class), new ActionHandler<>(GetConversationsAction.INSTANCE, GetConversationsTransportAction.class), new ActionHandler<>(CreateInteractionAction.INSTANCE, CreateInteractionTransportAction.class), new ActionHandler<>(GetInteractionsAction.INSTANCE, GetInteractionsTransportAction.class), - new ActionHandler<>(DeleteConversationAction.INSTANCE, DeleteConversationTransportAction.class) + new ActionHandler<>(DeleteConversationAction.INSTANCE, DeleteConversationTransportAction.class), + new ActionHandler<>(MLUpdateConnectorAction.INSTANCE, UpdateConnectorTransportAction.class) ); } @@ -538,12 +541,12 @@ public List getRestHandlers( RestMLGetConnectorAction restMLGetConnectorAction = new RestMLGetConnectorAction(); RestMLDeleteConnectorAction restMLDeleteConnectorAction = new RestMLDeleteConnectorAction(); RestMLSearchConnectorAction restMLSearchConnectorAction = new RestMLSearchConnectorAction(); - RestMemoryCreateConversationAction restCreateConversationAction = new RestMemoryCreateConversationAction(); RestMemoryGetConversationsAction restListConversationsAction = new RestMemoryGetConversationsAction(); RestMemoryCreateInteractionAction restCreateInteractionAction = new RestMemoryCreateInteractionAction(); RestMemoryGetInteractionsAction restListInteractionsAction = new RestMemoryGetInteractionsAction(); RestMemoryDeleteConversationAction restDeleteConversationAction = new RestMemoryDeleteConversationAction(); + RestMLUpdateConnectorAction restMLUpdateConnectorAction = new RestMLUpdateConnectorAction(mlFeatureEnabledSetting); return ImmutableList .of( restMLStatsAction, @@ -575,7 +578,8 @@ public List getRestHandlers( restListConversationsAction, restCreateInteractionAction, restListInteractionsAction, - restDeleteConversationAction + restDeleteConversationAction, + restMLUpdateConnectorAction ); } diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateConnectorAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateConnectorAction.java new file mode 100644 index 0000000000..a74ed27ecc --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateConnectorAction.java @@ -0,0 +1,78 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; +import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG; +import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_CONNECTOR_ID; +import static org.opensearch.ml.utils.RestActionUtils.getParameterId; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.transport.connector.MLUpdateConnectorAction; +import org.opensearch.ml.common.transport.connector.MLUpdateConnectorRequest; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; + +public class RestMLUpdateConnectorAction extends BaseRestHandler { + private static final String ML_UPDATE_CONNECTOR_ACTION = "ml_update_connector_action"; + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + + public RestMLUpdateConnectorAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) { + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; + } + + @Override + public String getName() { + return ML_UPDATE_CONNECTOR_ACTION; + } + + @Override + public List routes() { + return ImmutableList + .of( + new Route( + RestRequest.Method.POST, + String.format(Locale.ROOT, "%s/connectors/_update/{%s}", ML_BASE_URI, PARAMETER_CONNECTOR_ID) + ) + ); + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + MLUpdateConnectorRequest mlUpdateConnectorRequest = getRequest(request); + return restChannel -> client + .execute(MLUpdateConnectorAction.INSTANCE, mlUpdateConnectorRequest, new RestToXContentListener<>(restChannel)); + } + + @VisibleForTesting + private MLUpdateConnectorRequest getRequest(RestRequest request) throws IOException { + if (!mlFeatureEnabledSetting.isRemoteInferenceEnabled()) { + throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG); + } + + if (!request.hasContent()) { + throw new IOException("Failed to update connector: Request body is empty"); + } + + String connectorId = getParameterId(request, PARAMETER_CONNECTOR_ID); + + XContentParser parser = request.contentParser(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + + return MLUpdateConnectorRequest.parse(parser, connectorId); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/action/connector/TransportUpdateConnectorActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/connector/TransportUpdateConnectorActionTests.java new file mode 100644 index 0000000000..fc6020474a --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/action/connector/TransportUpdateConnectorActionTests.java @@ -0,0 +1,346 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.connector; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX; +import static org.opensearch.ml.utils.TestHelper.clusterSetting; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import org.apache.lucene.search.TotalHits; +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.SearchResponseSections; +import org.opensearch.action.search.ShardSearchFailure; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.update.UpdateRequest; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.index.Index; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.ml.common.transport.connector.MLUpdateConnectorRequest; +import org.opensearch.ml.helper.ConnectorAccessControlHelper; +import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.utils.TestHelper; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.search.aggregations.InternalAggregations; +import org.opensearch.tasks.Task; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +import com.google.common.collect.ImmutableList; + +public class TransportUpdateConnectorActionTests extends OpenSearchTestCase { + + private UpdateConnectorTransportAction transportUpdateConnectorAction; + + @Mock + private ConnectorAccessControlHelper connectorAccessControlHelper; + + @Mock + private Task task; + + @Mock + private Client client; + + @Mock + private ThreadPool threadPool; + + @Mock + private ClusterService clusterService; + + @Mock + private TransportService transportService; + + @Mock + private ActionFilters actionFilters; + + @Mock + private MLUpdateConnectorRequest updateRequest; + + @Mock + private UpdateResponse updateResponse; + + @Mock + ActionListener actionListener; + + @Mock + MLModelManager mlModelManager; + + ThreadContext threadContext; + + private Settings settings; + + private ShardId shardId; + + private SearchResponse searchResponse; + + private static final List TRUSTED_CONNECTOR_ENDPOINTS_REGEXES = ImmutableList + .of("^https://runtime\\.sagemaker\\..*\\.amazonaws\\.com/.*$", "^https://api\\.openai\\.com/.*$", "^https://api\\.cohere\\.ai/.*$"); + + @Before + public void setup() throws IOException { + MockitoAnnotations.openMocks(this); + settings = Settings + .builder() + .putList(ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX.getKey(), TRUSTED_CONNECTOR_ENDPOINTS_REGEXES) + .build(); + ClusterSettings clusterSettings = clusterSetting( + settings, + ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX, + ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED + ); + + Settings settings = Settings.builder().put(ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED.getKey(), true).build(); + threadContext = new ThreadContext(settings); + when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + String connector_id = "test_connector_id"; + Map updateContent = Map.of("version", "2", "description", "updated description"); + when(updateRequest.getConnectorId()).thenReturn(connector_id); + when(updateRequest.getUpdateContent()).thenReturn(updateContent); + + SearchHits hits = new SearchHits(new SearchHit[] {}, new TotalHits(0, TotalHits.Relation.EQUAL_TO), Float.NaN); + SearchResponseSections searchSections = new SearchResponseSections(hits, InternalAggregations.EMPTY, null, false, false, null, 1); + searchResponse = new SearchResponse( + searchSections, + null, + 1, + 1, + 0, + 11, + ShardSearchFailure.EMPTY_ARRAY, + SearchResponse.Clusters.EMPTY + ); + + transportUpdateConnectorAction = new UpdateConnectorTransportAction( + transportService, + actionFilters, + client, + connectorAccessControlHelper, + mlModelManager + ); + + when(mlModelManager.getAllModelIds()).thenReturn(new String[] {}); + shardId = new ShardId(new Index("indexName", "uuid"), 1); + updateResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.UPDATED); + } + + public void test_execute_connectorAccessControl_success() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(true); + return null; + }).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(searchResponse); + return null; + }).when(client).search(any(SearchRequest.class), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(updateResponse); + return null; + }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + + transportUpdateConnectorAction.doExecute(task, updateRequest, actionListener); + verify(actionListener).onResponse(updateResponse); + } + + public void test_execute_connectorAccessControl_NoPermission() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(false); + return null; + }).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class)); + + transportUpdateConnectorAction.doExecute(task, updateRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "You don't have permission to update the connector, connector id: test_connector_id", + argumentCaptor.getValue().getMessage() + ); + } + + public void test_execute_connectorAccessControl_AccessError() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onFailure(new RuntimeException("Connector Access Control Error")); + return null; + }).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class)); + + transportUpdateConnectorAction.doExecute(task, updateRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Connector Access Control Error", argumentCaptor.getValue().getMessage()); + } + + public void test_execute_connectorAccessControl_Exception() { + doThrow(new RuntimeException("exception in access control")) + .when(connectorAccessControlHelper) + .validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class)); + + transportUpdateConnectorAction.doExecute(task, updateRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("exception in access control", argumentCaptor.getValue().getMessage()); + } + + public void test_execute_UpdateWrongStatus() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(true); + return null; + }).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(searchResponse); + return null; + }).when(client).search(any(SearchRequest.class), isA(ActionListener.class)); + + UpdateResponse updateResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.CREATED); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(updateResponse); + return null; + }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + + transportUpdateConnectorAction.doExecute(task, updateRequest, actionListener); + verify(actionListener).onResponse(updateResponse); + } + + public void test_execute_UpdateException() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(true); + return null; + }).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(searchResponse); + return null; + }).when(client).search(any(SearchRequest.class), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException("update document failure")); + return null; + }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + + transportUpdateConnectorAction.doExecute(task, updateRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("update document failure", argumentCaptor.getValue().getMessage()); + } + + public void test_execute_SearchResponseNotEmpty() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(true); + return null; + }).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(noneEmptySearchResponse()); + return null; + }).when(client).search(any(SearchRequest.class), isA(ActionListener.class)); + + transportUpdateConnectorAction.doExecute(task, updateRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("1 models are still using this connector, please undeploy the models first!", argumentCaptor.getValue().getMessage()); + } + + public void test_execute_SearchResponseError() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(true); + return null; + }).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onFailure(new RuntimeException("Error in Search Request")); + return null; + }).when(client).search(any(SearchRequest.class), isA(ActionListener.class)); + + transportUpdateConnectorAction.doExecute(task, updateRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(RuntimeException.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Error in Search Request", argumentCaptor.getValue().getMessage()); + } + + public void test_execute_SearchIndexNotFoundError() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(true); + return null; + }).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onFailure(new IndexNotFoundException("Index not found!")); + return null; + }).when(client).search(any(SearchRequest.class), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(updateResponse); + return null; + }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + + transportUpdateConnectorAction.doExecute(task, updateRequest, actionListener); + verify(actionListener).onResponse(updateResponse); + } + + private SearchResponse noneEmptySearchResponse() throws IOException { + String modelContent = "{\"name\":\"Remote_Model\",\"algorithm\":\"Remote\",\"version\":1,\"connector_id\":\"test_id\"}"; + SearchHit model = SearchHit.fromXContent(TestHelper.parser(modelContent)); + SearchHits hits = new SearchHits(new SearchHit[] { model }, new TotalHits(1, TotalHits.Relation.EQUAL_TO), Float.NaN); + SearchResponseSections searchSections = new SearchResponseSections(hits, InternalAggregations.EMPTY, null, false, false, null, 1); + SearchResponse searchResponse = new SearchResponse( + searchSections, + null, + 1, + 1, + 0, + 11, + ShardSearchFailure.EMPTY_ARRAY, + SearchResponse.Clusters.EMPTY + ); + + return searchResponse; + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeployModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeployModelActionTests.java index eee7f39e87..eff2f2d69f 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeployModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeployModelActionTests.java @@ -129,6 +129,7 @@ private RestRequest getRestRequest() { .withParams(params) .withContent(new BytesArray(requestContent), XContentType.JSON) .build(); + return request; } } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateConnectorActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateConnectorActionTests.java new file mode 100644 index 0000000000..814402fb66 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateConnectorActionTests.java @@ -0,0 +1,181 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.rules.ExpectedException; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.client.node.NodeClient; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.Strings; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.transport.connector.MLUpdateConnectorAction; +import org.opensearch.ml.common.transport.connector.MLUpdateConnectorRequest; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; + +import com.google.gson.Gson; + +public class RestMLUpdateConnectorActionTests extends OpenSearchTestCase { + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + private RestMLUpdateConnectorAction restMLUpdateConnectorAction; + private NodeClient client; + private ThreadPool threadPool; + + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + + @Mock + RestChannel channel; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); + client = spy(new NodeClient(Settings.EMPTY, threadPool)); + + when(mlFeatureEnabledSetting.isRemoteInferenceEnabled()).thenReturn(true); + restMLUpdateConnectorAction = new RestMLUpdateConnectorAction(mlFeatureEnabledSetting); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + return null; + }).when(client).execute(eq(MLUpdateConnectorAction.INSTANCE), any(), any()); + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + threadPool.shutdown(); + client.close(); + } + + public void testConstructor() { + RestMLUpdateConnectorAction updateConnectorAction = new RestMLUpdateConnectorAction(mlFeatureEnabledSetting); + assertNotNull(updateConnectorAction); + } + + public void testGetName() { + String actionName = restMLUpdateConnectorAction.getName(); + assertFalse(Strings.isNullOrEmpty(actionName)); + assertEquals("ml_update_connector_action", actionName); + } + + public void testRoutes() { + List routes = restMLUpdateConnectorAction.routes(); + assertNotNull(routes); + assertFalse(routes.isEmpty()); + RestHandler.Route route = routes.get(0); + assertEquals(RestRequest.Method.POST, route.getMethod()); + assertEquals("/_plugins/_ml/connectors/_update/{connector_id}", route.getPath()); + } + + public void testUpdateConnectorRequest() throws Exception { + RestRequest request = getRestRequest(); + restMLUpdateConnectorAction.handleRequest(request, channel, client); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLUpdateConnectorRequest.class); + verify(client, times(1)).execute(eq(MLUpdateConnectorAction.INSTANCE), argumentCaptor.capture(), any()); + MLUpdateConnectorRequest updateConnectorRequest = argumentCaptor.getValue(); + assertEquals("test_connectorId", updateConnectorRequest.getConnectorId()); + assertEquals("This is test description", updateConnectorRequest.getUpdateContent().get("description")); + assertEquals("2", updateConnectorRequest.getUpdateContent().get("version")); + } + + public void testUpdateConnectorRequestWithEmptyContent() throws Exception { + exceptionRule.expect(IOException.class); + exceptionRule.expectMessage("Failed to update connector: Request body is empty"); + RestRequest request = getRestRequestWithEmptyContent(); + restMLUpdateConnectorAction.handleRequest(request, channel, client); + } + + public void testUpdateConnectorRequestWithNullConnectorId() throws Exception { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Request should contain connector_id"); + RestRequest request = getRestRequestWithNullConnectorId(); + restMLUpdateConnectorAction.handleRequest(request, channel, client); + } + + public void testPrepareRequestFeatureDisabled() throws Exception { + exceptionRule.expect(IllegalStateException.class); + exceptionRule.expectMessage(REMOTE_INFERENCE_DISABLED_ERR_MSG); + + when(mlFeatureEnabledSetting.isRemoteInferenceEnabled()).thenReturn(false); + RestRequest request = getRestRequest(); + restMLUpdateConnectorAction.handleRequest(request, channel, client); + } + + private RestRequest getRestRequest() { + RestRequest.Method method = RestRequest.Method.POST; + final Map updateContent = Map.of("version", "2", "description", "This is test description"); + String requestContent = new Gson().toJson(updateContent).toString(); + Map params = new HashMap<>(); + params.put("connector_id", "test_connectorId"); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withMethod(method) + .withPath("/_plugins/_ml/connectors/_update/{connector_id}") + .withParams(params) + .withContent(new BytesArray(requestContent), XContentType.JSON) + .build(); + return request; + } + + private RestRequest getRestRequestWithEmptyContent() { + RestRequest.Method method = RestRequest.Method.POST; + Map params = new HashMap<>(); + params.put("connector_id", "test_connectorId"); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withMethod(method) + .withPath("/_plugins/_ml/connectors/_update/{connector_id}") + .withParams(params) + .withContent(new BytesArray(""), XContentType.JSON) + .build(); + return request; + } + + private RestRequest getRestRequestWithNullConnectorId() { + RestRequest.Method method = RestRequest.Method.POST; + final Map updateContent = Map.of("version", "2", "description", "This is test description"); + String requestContent = new Gson().toJson(updateContent).toString(); + Map params = new HashMap<>(); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withMethod(method) + .withPath("/_plugins/_ml/connectors/_update/{connector_id}") + .withParams(params) + .withContent(new BytesArray(requestContent), XContentType.JSON) + .build(); + return request; + } + +}