From 61012a5dd601d86da77e8fae17c92a8c1ceb4ccc Mon Sep 17 00:00:00 2001 From: "opensearch-trigger-bot[bot]" <98922864+opensearch-trigger-bot[bot]@users.noreply.github.com> Date: Wed, 11 Oct 2023 15:17:26 -0700 Subject: [PATCH] fix update connector API (#1484) (#1494) * fix update connector API Signed-off-by: Yaliang Wu * fix ut Signed-off-by: Yaliang Wu --------- Signed-off-by: Yaliang Wu (cherry picked from commit 2f3f39e8f8b010e85797533fc4fa3a7df9e1f6b3) Co-authored-by: Yaliang Wu --- .../ml/common/connector/Connector.java | 13 +- .../ml/common/connector/HttpConnector.java | 33 ++++ .../connector/MLCreateConnectorInput.java | 27 ++-- .../connector/MLUpdateConnectorRequest.java | 18 ++- .../MLUpdateConnectorRequestTests.java | 31 ++-- .../DeleteConnectorTransportAction.java | 11 +- .../UpdateConnectorTransportAction.java | 43 +++++- .../helper/ConnectorAccessControlHelper.java | 57 ++++--- .../ml/rest/RestMLUpdateConnectorAction.java | 7 +- .../opensearch/ml/utils/MLExceptionUtils.java | 2 +- .../DeleteConnectorTransportActionTests.java | 14 +- .../TransportUpdateConnectorActionTests.java | 144 ++++++++++++------ .../RestMLUpdateConnectorActionTests.java | 8 +- 13 files changed, 275 insertions(+), 133 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/connector/Connector.java b/common/src/main/java/org/opensearch/ml/common/connector/Connector.java index 419e460c95..15b7456f3c 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/Connector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/Connector.java @@ -30,17 +30,7 @@ import org.opensearch.ml.common.AccessMode; import org.opensearch.ml.common.MLCommonsClassLoader; import org.opensearch.ml.common.output.model.ModelTensor; - -import java.io.IOException; -import java.security.AccessController; -import java.security.PrivilegedActionException; -import java.security.PrivilegedExceptionAction; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.function.Function; -import java.util.regex.Matcher; -import java.util.regex.Pattern; +import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.common.utils.StringUtils.gson; @@ -80,6 +70,7 @@ public interface Connector extends ToXContentObject, Writeable { void writeTo(StreamOutput out) throws IOException; + void update(MLCreateConnectorInput updateContent, Function function); void parseResponse(T orElse, List modelTensors, boolean b) throws IOException; diff --git a/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java b/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java index aee56907e1..0926694b44 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java @@ -34,6 +34,7 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.AccessMode; +import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; @Log4j2 @NoArgsConstructor @@ -248,6 +249,38 @@ public void writeTo(StreamOutput out) throws IOException { } } + @Override + public void update(MLCreateConnectorInput updateContent, Function function) { + if (updateContent.getName() != null) { + this.name = updateContent.getName(); + } + if (updateContent.getDescription() != null) { + this.description = updateContent.getDescription(); + } + if (updateContent.getVersion() != null) { + this.version = updateContent.getVersion(); + } + if (updateContent.getProtocol() != null) { + this.protocol = updateContent.getProtocol(); + } + if (updateContent.getParameters() != null && updateContent.getParameters().size() > 0) { + this.parameters = updateContent.getParameters(); + } + if (updateContent.getCredential() != null && updateContent.getCredential().size() > 0) { + this.credential = updateContent.getCredential(); + encrypt(function); + } + if (updateContent.getActions() != null) { + this.actions = updateContent.getActions(); + } + if (updateContent.getBackendRoles() != null) { + this.backendRoles = updateContent.getBackendRoles(); + } + if (updateContent.getAccess() != null) { + this.access = updateContent.getAccess(); + } + } + @Override public T createPredictPayload(Map parameters) { Optional predictAction = findPredictAction(); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java index 9d5f7c88de..9d9879daec 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java @@ -56,6 +56,7 @@ public class MLCreateConnectorInput implements ToXContentObject, Writeable { private Boolean addAllBackendRoles; private AccessMode access; private boolean dryRun = false; + private boolean updateConnector = false; @Builder(toBuilder = true) public MLCreateConnectorInput(String name, @@ -68,9 +69,10 @@ public MLCreateConnectorInput(String name, List backendRoles, Boolean addAllBackendRoles, AccessMode access, - boolean dryRun + boolean dryRun, + boolean updateConnector ) { - if (!dryRun) { + if (!dryRun && !updateConnector) { if (name == null) { throw new IllegalArgumentException("Connector name is null"); } @@ -92,9 +94,14 @@ public MLCreateConnectorInput(String name, this.addAllBackendRoles = addAllBackendRoles; this.access = access; this.dryRun = dryRun; + this.updateConnector = updateConnector; } public static MLCreateConnectorInput parse(XContentParser parser) throws IOException { + return parse(parser, false); + } + + public static MLCreateConnectorInput parse(XContentParser parser, boolean updateConnector) throws IOException { String name = null; String description = null; String version = null; @@ -159,7 +166,7 @@ public static MLCreateConnectorInput parse(XContentParser parser) throws IOExcep break; } } - return new MLCreateConnectorInput(name, description, version, protocol, parameters, credential, actions, backendRoles, addAllBackendRoles, access, dryRun); + return new MLCreateConnectorInput(name, description, version, protocol, parameters, credential, actions, backendRoles, addAllBackendRoles, access, dryRun, updateConnector); } @Override @@ -201,10 +208,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws @Override public void writeTo(StreamOutput output) throws IOException { - output.writeString(name); + output.writeOptionalString(name); output.writeOptionalString(description); - output.writeString(version); - output.writeString(protocol); + output.writeOptionalString(version); + output.writeOptionalString(protocol); if (parameters != null) { output.writeBoolean(true); output.writeMap(parameters, StreamOutput::writeString, StreamOutput::writeString); @@ -240,13 +247,14 @@ public void writeTo(StreamOutput output) throws IOException { output.writeBoolean(false); } output.writeBoolean(dryRun); + output.writeBoolean(updateConnector); } public MLCreateConnectorInput(StreamInput input) throws IOException { - name = input.readString(); + name = input.readOptionalString(); description = input.readOptionalString(); - version = input.readString(); - protocol = input.readString(); + version = input.readOptionalString(); + protocol = input.readOptionalString(); if (input.readBoolean()) { parameters = input.readMap(s -> s.readString(), s -> s.readString()); } @@ -268,5 +276,6 @@ public MLCreateConnectorInput(StreamInput input) throws IOException { this.access = input.readEnum(AccessMode.class); } dryRun = input.readBoolean(); + updateConnector = input.readBoolean(); } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequest.java index ced3646d13..089180cdc5 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequest.java @@ -19,17 +19,16 @@ import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.UncheckedIOException; -import java.util.Map; import static org.opensearch.action.ValidateActions.addValidationError; @Getter public class MLUpdateConnectorRequest extends ActionRequest { String connectorId; - Map updateContent; + MLCreateConnectorInput updateContent; @Builder - public MLUpdateConnectorRequest(String connectorId, Map updateContent) { + public MLUpdateConnectorRequest(String connectorId, MLCreateConnectorInput updateContent) { this.connectorId = connectorId; this.updateContent = updateContent; } @@ -37,14 +36,14 @@ public MLUpdateConnectorRequest(String connectorId, Map updateCo public MLUpdateConnectorRequest(StreamInput in) throws IOException { super(in); this.connectorId = in.readString(); - this.updateContent = in.readMap(); + this.updateContent = new MLCreateConnectorInput(in); } @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); out.writeString(this.connectorId); - out.writeMap(this.getUpdateContent()); + this.updateContent.writeTo(out); } @Override @@ -55,14 +54,17 @@ public ActionRequestValidationException validate() { exception = addValidationError("ML connector id can't be null", exception); } + if (updateContent == null) { + exception = addValidationError("Update connector content can't be null", exception); + } + return exception; } public static MLUpdateConnectorRequest parse(XContentParser parser, String connectorId) throws IOException { - Map dataAsMap = null; - dataAsMap = parser.map(); + MLCreateConnectorInput updateContent = MLCreateConnectorInput.parse(parser, true); - return MLUpdateConnectorRequest.builder().connectorId(connectorId).updateContent(dataAsMap).build(); + return MLUpdateConnectorRequest.builder().connectorId(connectorId).updateContent(updateContent).build(); } public static MLUpdateConnectorRequest fromActionRequest(ActionRequest actionRequest) { diff --git a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequestTests.java b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequestTests.java index e017009983..44e970f95c 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequestTests.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequestTests.java @@ -7,38 +7,37 @@ import org.junit.Before; import org.junit.Test; -import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.rest.RestRequest; +import org.opensearch.search.SearchModule; import java.io.IOException; import java.io.UncheckedIOException; -import java.util.Map; +import java.util.Collections; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotSame; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; -import static org.mockito.Mockito.when; +import static org.junit.Assert.assertTrue; public class MLUpdateConnectorRequestTests { private String connectorId; - private Map updateContent; + private MLCreateConnectorInput updateContent; private MLUpdateConnectorRequest mlUpdateConnectorRequest; - @Mock - XContentParser parser; - @Before public void setUp() { MockitoAnnotations.openMocks(this); this.connectorId = "test-connector_id"; - this.updateContent = Map.of("description", "new description"); + this.updateContent = MLCreateConnectorInput.builder().description("new description").updateConnector(true).build(); mlUpdateConnectorRequest = MLUpdateConnectorRequest.builder() .connectorId(connectorId) .updateContent(updateContent) @@ -64,18 +63,20 @@ public void validate_Exception_NullConnectorId() { MLUpdateConnectorRequest updateConnectorRequest = MLUpdateConnectorRequest.builder().build(); Exception exception = updateConnectorRequest.validate(); - assertEquals("Validation Failed: 1: ML connector id can't be null;", exception.getMessage()); + assertEquals("Validation Failed: 1: ML connector id can't be null;2: Update connector content can't be null;", exception.getMessage()); } @Test public void parse_success() throws IOException { - RestRequest.Method method = RestRequest.Method.POST; - final Map updatefields = Map.of("version", "new version", "description", "new description"); - when(parser.map()).thenReturn(updatefields); - + String jsonStr = "{\"version\":\"new version\",\"description\":\"new description\"}"; + XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, + Collections.emptyList()).getNamedXContents()), null, jsonStr); + parser.nextToken(); MLUpdateConnectorRequest updateConnectorRequest = MLUpdateConnectorRequest.parse(parser, connectorId); assertEquals(updateConnectorRequest.getConnectorId(), connectorId); - assertEquals(updateConnectorRequest.getUpdateContent(), updatefields); + assertTrue(updateConnectorRequest.getUpdateContent().isUpdateConnector()); + assertEquals("new version", updateConnectorRequest.getUpdateContent().getVersion()); + assertEquals("new description", updateConnectorRequest.getUpdateContent().getDescription()); } @Test diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/DeleteConnectorTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/DeleteConnectorTransportAction.java index 276f29259e..856fccb848 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/connector/DeleteConnectorTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/DeleteConnectorTransportAction.java @@ -8,6 +8,10 @@ import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + import org.opensearch.action.ActionRequest; import org.opensearch.action.DocWriteResponse; import org.opensearch.action.delete.DeleteRequest; @@ -77,11 +81,16 @@ protected void doExecute(Task task, ActionRequest request, ActionListener modelIds = new ArrayList<>(); + for (SearchHit hit : searchHits) { + modelIds.add(hit.getId()); + } actionListener .onFailure( new MLValidationException( searchHits.length - + " models are still using this connector, please delete or update the models first!" + + " models are still using this connector, please delete or update the models first: " + + Arrays.toString(modelIds.toArray(new String[0])) ) ); } diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/UpdateConnectorTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/UpdateConnectorTransportAction.java index d8a1d88a01..86e4afe56f 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/connector/UpdateConnectorTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/UpdateConnectorTransportAction.java @@ -7,6 +7,11 @@ import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; import org.opensearch.action.ActionRequest; import org.opensearch.action.DocWriteResponse; @@ -16,9 +21,14 @@ import org.opensearch.action.update.UpdateRequest; import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.QueryBuilders; @@ -26,6 +36,7 @@ import org.opensearch.ml.common.exception.MLValidationException; import org.opensearch.ml.common.transport.connector.MLUpdateConnectorAction; import org.opensearch.ml.common.transport.connector.MLUpdateConnectorRequest; +import org.opensearch.ml.engine.MLEngine; import org.opensearch.ml.helper.ConnectorAccessControlHelper; import org.opensearch.ml.model.MLModelManager; import org.opensearch.search.SearchHit; @@ -38,12 +49,14 @@ import lombok.extern.log4j.Log4j2; @Log4j2 -@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) +@FieldDefaults(level = AccessLevel.PRIVATE) public class UpdateConnectorTransportAction extends HandledTransportAction { Client client; ConnectorAccessControlHelper connectorAccessControlHelper; MLModelManager mlModelManager; + MLEngine mlEngine; + volatile List trustedConnectorEndpointsRegex; @Inject public UpdateConnectorTransportAction( @@ -51,25 +64,35 @@ public UpdateConnectorTransportAction( ActionFilters actionFilters, Client client, ConnectorAccessControlHelper connectorAccessControlHelper, - MLModelManager mlModelManager + MLModelManager mlModelManager, + Settings settings, + ClusterService clusterService, + MLEngine mlEngine ) { super(MLUpdateConnectorAction.NAME, transportService, actionFilters, MLUpdateConnectorRequest::new); this.client = client; this.connectorAccessControlHelper = connectorAccessControlHelper; this.mlModelManager = mlModelManager; + this.mlEngine = mlEngine; + trustedConnectorEndpointsRegex = ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX.get(settings); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX, it -> trustedConnectorEndpointsRegex = it); } @Override protected void doExecute(Task task, ActionRequest request, ActionListener listener) { MLUpdateConnectorRequest mlUpdateConnectorAction = MLUpdateConnectorRequest.fromActionRequest(request); String connectorId = mlUpdateConnectorAction.getConnectorId(); - UpdateRequest updateRequest = new UpdateRequest(ML_CONNECTOR_INDEX, connectorId); - updateRequest.doc(mlUpdateConnectorAction.getUpdateContent()); - updateRequest.docAsUpsert(true); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - connectorAccessControlHelper.validateConnectorAccess(client, connectorId, ActionListener.wrap(hasPermission -> { + connectorAccessControlHelper.getConnector(client, connectorId, ActionListener.wrap(connector -> { + boolean hasPermission = connectorAccessControlHelper.validateConnectorAccess(client, connector); if (Boolean.TRUE.equals(hasPermission)) { + connector.update(mlUpdateConnectorAction.getUpdateContent(), mlEngine::encrypt); + connector.validateConnectorURL(trustedConnectorEndpointsRegex); + UpdateRequest updateRequest = new UpdateRequest(ML_CONNECTOR_INDEX, connectorId); + updateRequest.doc(connector.toXContent(XContentBuilder.builder(XContentType.JSON.xContent()), ToXContent.EMPTY_PARAMS)); updateUndeployedConnector(connectorId, updateRequest, listener, context); } else { listener @@ -107,10 +130,16 @@ private void updateUndeployedConnector( client.update(updateRequest, getUpdateResponseListener(connectorId, listener, context)); } else { log.error(searchHits.length + " models are still using this connector, please undeploy the models first!"); + List modelIds = new ArrayList<>(); + for (SearchHit hit : searchHits) { + modelIds.add(hit.getId()); + } listener .onFailure( new MLValidationException( - searchHits.length + " models are still using this connector, please undeploy the models first!" + searchHits.length + + " models are still using this connector, please undeploy the models first: " + + Arrays.toString(modelIds.toArray(new String[0])) ) ); } diff --git a/plugin/src/main/java/org/opensearch/ml/helper/ConnectorAccessControlHelper.java b/plugin/src/main/java/org/opensearch/ml/helper/ConnectorAccessControlHelper.java index 09fd80818f..b2c912b9d5 100644 --- a/plugin/src/main/java/org/opensearch/ml/helper/ConnectorAccessControlHelper.java +++ b/plugin/src/main/java/org/opensearch/ml/helper/ConnectorAccessControlHelper.java @@ -64,35 +64,48 @@ public void validateConnectorAccess(Client client, String connectorId, ActionLis listener.onResponse(true); return; } - GetRequest getRequest = new GetRequest().index(CommonValue.ML_CONNECTOR_INDEX).id(connectorId); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { ActionListener wrappedListener = ActionListener.runBefore(listener, context::restore); - client.get(getRequest, ActionListener.wrap(r -> { - if (r != null && r.isExists()) { - try ( - XContentParser parser = MLNodeUtils - .createXContentParserFromRegistry(NamedXContentRegistry.EMPTY, r.getSourceAsBytesRef()) - ) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - Connector connector = Connector.createConnector(parser); - boolean hasPermission = hasPermission(user, connector); - wrappedListener.onResponse(hasPermission); - } catch (Exception e) { - log.error("Failed to parse connector:" + connectorId); - wrappedListener.onFailure(e); - } - } else { - wrappedListener.onFailure(new MLResourceNotFoundException("Fail to find connector:" + connectorId)); - } - }, e -> { - log.error("Fail to get connector", e); - wrappedListener.onFailure(new IllegalStateException("Fail to get connector:" + connectorId)); - })); + getConnector(client, connectorId, ActionListener.wrap(connector -> { + boolean hasPermission = hasPermission(user, connector); + wrappedListener.onResponse(hasPermission); + }, e -> { wrappedListener.onFailure(e); })); } catch (Exception e) { log.error("Failed to validate Access for connector:" + connectorId, e); listener.onFailure(e); } + } + + public boolean validateConnectorAccess(Client client, Connector connector) { + User user = RestActionUtils.getUserContext(client); + if (isAdmin(user) || accessControlNotEnabled(user)) { + return true; + } + return hasPermission(user, connector); + } + public void getConnector(Client client, String connectorId, ActionListener listener) { + GetRequest getRequest = new GetRequest().index(CommonValue.ML_CONNECTOR_INDEX).id(connectorId); + client.get(getRequest, ActionListener.wrap(r -> { + if (r != null && r.isExists()) { + try ( + XContentParser parser = MLNodeUtils + .createXContentParserFromRegistry(NamedXContentRegistry.EMPTY, r.getSourceAsBytesRef()) + ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + Connector connector = Connector.createConnector(parser); + listener.onResponse(connector); + } catch (Exception e) { + log.error("Failed to parse connector:" + connectorId); + listener.onFailure(e); + } + } else { + listener.onFailure(new MLResourceNotFoundException("Fail to find connector:" + connectorId)); + } + }, e -> { + log.error("Fail to get connector", e); + listener.onFailure(new IllegalStateException("Fail to get connector:" + connectorId)); + })); } public boolean skipConnectorAccessControl(User user) { diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateConnectorAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateConnectorAction.java index 334aa2766a..2957740354 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateConnectorAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateConnectorAction.java @@ -44,12 +44,7 @@ public String getName() { @Override public List routes() { return ImmutableList - .of( - new Route( - RestRequest.Method.POST, - String.format(Locale.ROOT, "%s/connectors/_update/{%s}", ML_BASE_URI, PARAMETER_CONNECTOR_ID) - ) - ); + .of(new Route(RestRequest.Method.PUT, String.format(Locale.ROOT, "%s/connectors/{%s}", ML_BASE_URI, PARAMETER_CONNECTOR_ID))); } @Override diff --git a/plugin/src/main/java/org/opensearch/ml/utils/MLExceptionUtils.java b/plugin/src/main/java/org/opensearch/ml/utils/MLExceptionUtils.java index d3866ad299..6f051615e3 100644 --- a/plugin/src/main/java/org/opensearch/ml/utils/MLExceptionUtils.java +++ b/plugin/src/main/java/org/opensearch/ml/utils/MLExceptionUtils.java @@ -21,7 +21,7 @@ public class MLExceptionUtils { public static final String REMOTE_INFERENCE_DISABLED_ERR_MSG = "Remote Inference is currently disabled. To enable it, update the setting \"plugins.ml_commons.remote_inference_enabled\" to true."; public static final String UPDATE_CONNECTOR_DISABLED_ERR_MSG = - "Update connector API is currently disabled. To enable it, update the setting \"plugins.ml_commons.update_connector_enabled\" to true."; + "Update connector API is currently disabled. To enable it, update the setting \"plugins.ml_commons.update_connector.enabled\" to true."; public static String getRootCauseMessage(final Throwable throwable) { String message = ExceptionUtils.getRootCauseMessage(throwable); diff --git a/plugin/src/test/java/org/opensearch/ml/action/connector/DeleteConnectorTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/connector/DeleteConnectorTransportActionTests.java index e3a0cdc058..977ae66603 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/connector/DeleteConnectorTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/connector/DeleteConnectorTransportActionTests.java @@ -44,6 +44,7 @@ import org.opensearch.ml.common.exception.MLValidationException; import org.opensearch.ml.common.transport.connector.MLConnectorDeleteRequest; import org.opensearch.ml.helper.ConnectorAccessControlHelper; +import org.opensearch.ml.utils.TestHelper; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.search.aggregations.InternalAggregations; @@ -180,7 +181,7 @@ public void testDeleteConnector_BlockedByModel() throws IOException { ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLValidationException.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals( - "1 models are still using this connector, please delete or update the models first!", + "1 models are still using this connector, please delete or update the models first: [model_ID]", argumentCaptor.getValue().getMessage() ); } @@ -291,8 +292,17 @@ private SearchResponse getEmptySearchResponse() { return searchResponse; } - private SearchResponse getNonEmptySearchResponse() { + private SearchResponse getNonEmptySearchResponse() throws IOException { SearchHit[] hits = new SearchHit[1]; + String modelContent = "{\n" + + " \"created_time\": 1684981986069,\n" + + " \"last_updated_time\": 1684981986069,\n" + + " \"_id\": \"model_ID\",\n" + + " \"name\": \"test_model\",\n" + + " \"description\": \"This is an example description\"\n" + + " }"; + SearchHit model = SearchHit.fromXContent(TestHelper.parser(modelContent)); + hits[0] = model; SearchHits searchHits = new SearchHits(hits, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1.0f); SearchResponseSections searchSections = new SearchResponseSections( searchHits, diff --git a/plugin/src/test/java/org/opensearch/ml/action/connector/TransportUpdateConnectorActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/connector/TransportUpdateConnectorActionTests.java index fc6020474a..bb3d5ecebd 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/connector/TransportUpdateConnectorActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/connector/TransportUpdateConnectorActionTests.java @@ -7,17 +7,16 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.isA; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.doThrow; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; +import static org.mockito.Mockito.*; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX; import static org.opensearch.ml.utils.TestHelper.clusterSetting; import java.io.IOException; +import java.nio.file.Path; +import java.util.Arrays; import java.util.List; -import java.util.Map; +import java.util.UUID; import org.apache.lucene.search.TotalHits; import org.junit.Before; @@ -41,7 +40,14 @@ import org.opensearch.core.index.Index; import org.opensearch.core.index.shard.ShardId; import org.opensearch.index.IndexNotFoundException; +import org.opensearch.ml.common.connector.Connector; +import org.opensearch.ml.common.connector.ConnectorAction; +import org.opensearch.ml.common.connector.HttpConnector; +import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; import org.opensearch.ml.common.transport.connector.MLUpdateConnectorRequest; +import org.opensearch.ml.engine.MLEngine; +import org.opensearch.ml.engine.encryptor.Encryptor; +import org.opensearch.ml.engine.encryptor.EncryptorImpl; import org.opensearch.ml.helper.ConnectorAccessControlHelper; import org.opensearch.ml.model.MLModelManager; import org.opensearch.ml.utils.TestHelper; @@ -54,6 +60,7 @@ import org.opensearch.transport.TransportService; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; public class TransportUpdateConnectorActionTests extends OpenSearchTestCase { @@ -100,6 +107,8 @@ public class TransportUpdateConnectorActionTests extends OpenSearchTestCase { private SearchResponse searchResponse; + private MLEngine mlEngine; + private static final List TRUSTED_CONNECTOR_ENDPOINTS_REGEXES = ImmutableList .of("^https://runtime\\.sagemaker\\..*\\.amazonaws\\.com/.*$", "^https://api\\.openai\\.com/.*$", "^https://api\\.cohere\\.ai/.*$"); @@ -122,7 +131,12 @@ public void setup() throws IOException { when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); String connector_id = "test_connector_id"; - Map updateContent = Map.of("version", "2", "description", "updated description"); + MLCreateConnectorInput updateContent = MLCreateConnectorInput + .builder() + .updateConnector(true) + .version("2") + .description("updated description") + .build(); when(updateRequest.getConnectorId()).thenReturn(connector_id); when(updateRequest.getUpdateContent()).thenReturn(updateContent); @@ -139,25 +153,56 @@ public void setup() throws IOException { SearchResponse.Clusters.EMPTY ); + Encryptor encryptor = new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="); + mlEngine = new MLEngine(Path.of("/tmp/test" + UUID.randomUUID()), encryptor); + transportUpdateConnectorAction = new UpdateConnectorTransportAction( transportService, actionFilters, client, connectorAccessControlHelper, - mlModelManager + mlModelManager, + settings, + clusterService, + mlEngine ); when(mlModelManager.getAllModelIds()).thenReturn(new String[] {}); shardId = new ShardId(new Index("indexName", "uuid"), 1); updateResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.UPDATED); - } - public void test_execute_connectorAccessControl_success() { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(true); + ActionListener listener = invocation.getArgument(2); + Connector connector = HttpConnector + .builder() + .name("test") + .protocol("http") + .version("1") + .credential(ImmutableMap.of("api_key", "credential_value")) + .parameters(ImmutableMap.of("param1", "value1")) + .actions( + Arrays + .asList( + ConnectorAction + .builder() + .actionType(ConnectorAction.ActionType.PREDICT) + .method("POST") + .url("https://api.openai.com/v1/chat/completions") + .headers(ImmutableMap.of("Authorization", "Bearer ${credential.api_key}")) + .requestBody("{ \"model\": \"${parameters.model}\", \"messages\": ${parameters.messages} }") + .build() + ) + ) + .build(); + // Connector connector = mock(HttpConnector.class); + // doNothing().when(connector).update(any(), any()); + listener.onResponse(connector); return null; - }).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class)); + }).when(connectorAccessControlHelper).getConnector(any(Client.class), any(String.class), isA(ActionListener.class)); + } + + public void test_execute_connectorAccessControl_success() { + doReturn(true).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(Connector.class)); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); @@ -176,11 +221,7 @@ public void test_execute_connectorAccessControl_success() { } public void test_execute_connectorAccessControl_NoPermission() { - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(false); - return null; - }).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class)); + doReturn(false).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(Connector.class)); transportUpdateConnectorAction.doExecute(task, updateRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -192,11 +233,9 @@ public void test_execute_connectorAccessControl_NoPermission() { } public void test_execute_connectorAccessControl_AccessError() { - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onFailure(new RuntimeException("Connector Access Control Error")); - return null; - }).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class)); + doThrow(new RuntimeException("Connector Access Control Error")) + .when(connectorAccessControlHelper) + .validateConnectorAccess(any(Client.class), any(Connector.class)); transportUpdateConnectorAction.doExecute(task, updateRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -207,7 +246,7 @@ public void test_execute_connectorAccessControl_AccessError() { public void test_execute_connectorAccessControl_Exception() { doThrow(new RuntimeException("exception in access control")) .when(connectorAccessControlHelper) - .validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class)); + .validateConnectorAccess(any(Client.class), any(Connector.class)); transportUpdateConnectorAction.doExecute(task, updateRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); @@ -216,11 +255,7 @@ public void test_execute_connectorAccessControl_Exception() { } public void test_execute_UpdateWrongStatus() { - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(true); - return null; - }).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class)); + doReturn(true).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(Connector.class)); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); @@ -240,11 +275,7 @@ public void test_execute_UpdateWrongStatus() { } public void test_execute_UpdateException() { - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(true); - return null; - }).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class)); + doReturn(true).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(Connector.class)); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); @@ -265,11 +296,7 @@ public void test_execute_UpdateException() { } public void test_execute_SearchResponseNotEmpty() { - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(true); - return null; - }).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class)); + doReturn(true).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(Connector.class)); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); @@ -280,15 +307,13 @@ public void test_execute_SearchResponseNotEmpty() { transportUpdateConnectorAction.doExecute(task, updateRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); - assertEquals("1 models are still using this connector, please undeploy the models first!", argumentCaptor.getValue().getMessage()); + assertTrue( + argumentCaptor.getValue().getMessage().contains("1 models are still using this connector, please undeploy the models first") + ); } public void test_execute_SearchResponseError() { - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(true); - return null; - }).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class)); + doReturn(true).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(Connector.class)); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); @@ -303,11 +328,36 @@ public void test_execute_SearchResponseError() { } public void test_execute_SearchIndexNotFoundError() { + doReturn(true).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(Connector.class)); + doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(true); + ActionListener listener = invocation.getArgument(2); + Connector connector = HttpConnector + .builder() + .name("test") + .protocol("http") + .version("1") + .credential(ImmutableMap.of("api_key", "credential_value")) + .parameters(ImmutableMap.of("param1", "value1")) + .actions( + Arrays + .asList( + ConnectorAction + .builder() + .actionType(ConnectorAction.ActionType.PREDICT) + .method("POST") + .url("https://api.openai.com/v1/chat/completions") + .headers(ImmutableMap.of("Authorization", "Bearer ${credential.api_key}")) + .requestBody("{ \"model\": \"${parameters.model}\", \"messages\": ${parameters.messages} }") + .build() + ) + ) + .build(); + // Connector connector = mock(HttpConnector.class); + // doNothing().when(connector).update(any(), any()); + listener.onResponse(connector); return null; - }).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(String.class), isA(ActionListener.class)); + }).when(connectorAccessControlHelper).getConnector(any(Client.class), any(String.class), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateConnectorActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateConnectorActionTests.java index b057190d98..3bc5a5e940 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateConnectorActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateConnectorActionTests.java @@ -99,8 +99,8 @@ public void testRoutes() { assertNotNull(routes); assertFalse(routes.isEmpty()); RestHandler.Route route = routes.get(0); - assertEquals(RestRequest.Method.POST, route.getMethod()); - assertEquals("/_plugins/_ml/connectors/_update/{connector_id}", route.getPath()); + assertEquals(RestRequest.Method.PUT, route.getMethod()); + assertEquals("/_plugins/_ml/connectors/{connector_id}", route.getPath()); } public void testUpdateConnectorRequest() throws Exception { @@ -110,8 +110,8 @@ public void testUpdateConnectorRequest() throws Exception { verify(client, times(1)).execute(eq(MLUpdateConnectorAction.INSTANCE), argumentCaptor.capture(), any()); MLUpdateConnectorRequest updateConnectorRequest = argumentCaptor.getValue(); assertEquals("test_connectorId", updateConnectorRequest.getConnectorId()); - assertEquals("This is test description", updateConnectorRequest.getUpdateContent().get("description")); - assertEquals("2", updateConnectorRequest.getUpdateContent().get("version")); + assertEquals("This is test description", updateConnectorRequest.getUpdateContent().getDescription()); + assertEquals("2", updateConnectorRequest.getUpdateContent().getVersion()); } public void testUpdateConnectorRequestWithEmptyContent() throws Exception {