From f1d1af16663a9107924baf626b5bca3561772004 Mon Sep 17 00:00:00 2001 From: Abdul Muneer Kolarkunnu Date: Fri, 6 Dec 2024 22:07:20 +0530 Subject: [PATCH] [Enhancement] Enhance validation for create connector API This PR addresses the first part of this enhancement "Validate if connector payload has all the required fields. If not provided, throw the illegal argument exception". Validation of fields description, parameters, credential, and request_body are missing. That validations are added in this fix. Added new test cases correspong to these validations and fixed all failing test cases because of these new validations. Partially Resolves #1382 Signed-off-by: Abdul Muneer Kolarkunnu --- .../ml/common/connector/ConnectorAction.java | 3 + .../connector/MLCreateConnectorInput.java | 9 + .../common/connector/ConnectorActionTest.java | 122 +++++---- .../MLCreateConnectorInputTests.java | 251 ++++++++++++------ .../algorithms/remote/ConnectorUtils.java | 2 +- .../algorithms/remote/ConnectorUtilsTest.java | 5 +- .../TransportCreateConnectorActionTests.java | 9 +- 7 files changed, 271 insertions(+), 130 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java b/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java index 4a7555d69b..aed6288629 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java @@ -63,6 +63,9 @@ public ConnectorAction( if (method == null) { throw new IllegalArgumentException("method can't null"); } + if (requestBody == null) { + throw new IllegalArgumentException("request body can't null"); + } this.actionType = actionType; this.method = method; this.url = url; diff --git a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java index 697f27494f..17af851714 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java @@ -93,6 +93,15 @@ public MLCreateConnectorInput( if (protocol == null) { throw new IllegalArgumentException("Connector protocol is null"); } + if (description == null) { + throw new IllegalArgumentException("Connector description is null"); + } + if (parameters == null || parameters.isEmpty()) { + throw new IllegalArgumentException("Connector parameters is null or empty list"); + } + if (credential == null || credential.isEmpty()) { + throw new IllegalArgumentException("Connector credential is null or empty list"); + } } this.name = name; this.description = description; diff --git a/common/src/test/java/org/opensearch/ml/common/connector/ConnectorActionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/ConnectorActionTest.java index 1539b9b432..c4af406ecf 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/ConnectorActionTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/ConnectorActionTest.java @@ -5,6 +5,8 @@ package org.opensearch.ml.common.connector; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.isValidActionInModelPrediction; import java.io.IOException; @@ -12,10 +14,7 @@ import java.util.HashMap; import java.util.Map; -import org.junit.Assert; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.XContentType; @@ -27,37 +26,54 @@ import org.opensearch.search.SearchModule; public class ConnectorActionTest { - @Rule - public ExpectedException exceptionRule = ExpectedException.none(); @Test public void constructor_NullActionType() { - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("action type can't null"); - ConnectorAction.ActionType actionType = null; - String method = "post"; - String url = "https://test.com"; - new ConnectorAction(actionType, method, url, null, null, null, null); + Throwable exception = assertThrows(IllegalArgumentException.class, () -> { + ConnectorAction.ActionType actionType = null; + String method = "post"; + String url = "https://test.com"; + String requestBody = "{\"input\": \"${parameters.input}\"}"; + new ConnectorAction(actionType, method, url, null, requestBody, null, null); + }); + assertEquals("action type can't null", exception.getMessage()); + } @Test public void constructor_NullUrl() { - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("url can't null"); - ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT; - String method = "post"; - String url = null; - new ConnectorAction(actionType, method, url, null, null, null, null); + Throwable exception = assertThrows(IllegalArgumentException.class, () -> { + ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT; + String method = "post"; + String url = null; + String requestBody = "{\"input\": \"${parameters.input}\"}"; + new ConnectorAction(actionType, method, url, null, requestBody, null, null); + }); + assertEquals("url can't null", exception.getMessage()); } @Test public void constructor_NullMethod() { - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("method can't null"); - ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT; - String method = null; - String url = "https://test.com"; - new ConnectorAction(actionType, method, url, null, null, null, null); + Throwable exception = assertThrows(IllegalArgumentException.class, () -> { + ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT; + String method = null; + String url = "https://test.com"; + String requestBody = "{\"input\": \"${parameters.input}\"}"; + new ConnectorAction(actionType, method, url, null, requestBody, null, null); + }); + assertEquals("method can't null", exception.getMessage()); + } + + @Test + public void constructor_NullRequestBody() { + Throwable exception = assertThrows(IllegalArgumentException.class, () -> { + ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT; + String method = "post"; + String url = "https://test.com"; + String requestBody = null; + new ConnectorAction(actionType, method, url, null, requestBody, null, null); + }); + assertEquals("request body can't null", exception.getMessage()); } @Test @@ -65,11 +81,12 @@ public void writeTo_NullValue() throws IOException { ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT; String method = "http"; String url = "https://test.com"; - ConnectorAction action = new ConnectorAction(actionType, method, url, null, null, null, null); + String requestBody = "{\"input\": \"${parameters.input}\"}"; + ConnectorAction action = new ConnectorAction(actionType, method, url, null, requestBody, null, null); BytesStreamOutput output = new BytesStreamOutput(); action.writeTo(output); ConnectorAction action2 = new ConnectorAction(output.bytes().streamInput()); - Assert.assertEquals(action, action2); + assertEquals(action, action2); } @Test @@ -95,7 +112,7 @@ public void writeTo() throws IOException { BytesStreamOutput output = new BytesStreamOutput(); action.writeTo(output); ConnectorAction action2 = new ConnectorAction(output.bytes().streamInput()); - Assert.assertEquals(action, action2); + assertEquals(action, action2); } @Test @@ -103,12 +120,17 @@ public void toXContent_NullValue() throws IOException { ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT; String method = "http"; String url = "https://test.com"; - ConnectorAction action = new ConnectorAction(actionType, method, url, null, null, null, null); + String requestBody = "{\"input\": \"${parameters.input}\"}"; + ConnectorAction action = new ConnectorAction(actionType, method, url, null, requestBody, null, null); XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); action.toXContent(builder, ToXContent.EMPTY_PARAMS); String content = TestHelper.xContentBuilderToString(builder); - Assert.assertEquals("{\"action_type\":\"PREDICT\",\"method\":\"http\",\"url\":\"https://test.com\"}", content); + String expctedContent = """ + {"action_type":"PREDICT","method":"http","url":"https://test.com",\ + "request_body":"{\\"input\\": \\"${parameters.input}\\"}"}\ + """; + assertEquals(expctedContent, content); } @Test @@ -135,22 +157,23 @@ public void toXContent() throws IOException { XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); action.toXContent(builder, ToXContent.EMPTY_PARAMS); String content = TestHelper.xContentBuilderToString(builder); - Assert - .assertEquals( - "{\"action_type\":\"PREDICT\",\"method\":\"http\",\"url\":\"https://test.com\"," - + "\"headers\":{\"key1\":\"value1\"},\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," - + "\"pre_process_function\":\"connector.pre_process.openai.embedding\"," - + "\"post_process_function\":\"connector.post_process.openai.embedding\"}", - content - ); + String expctedContent = """ + {"action_type":"PREDICT","method":"http","url":"https://test.com","headers":{"key1":"value1"},\ + "request_body":"{\\"input\\": \\"${parameters.input}\\"}",\ + "pre_process_function":"connector.pre_process.openai.embedding",\ + "post_process_function":"connector.post_process.openai.embedding"}\ + """; + assertEquals(expctedContent, content); } @Test public void parse() throws IOException { - String jsonStr = "{\"action_type\":\"PREDICT\",\"method\":\"http\",\"url\":\"https://test.com\"," - + "\"headers\":{\"key1\":\"value1\"},\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," - + "\"pre_process_function\":\"connector.pre_process.openai.embedding\"," - + "\"post_process_function\":\"connector.post_process.openai.embedding\"}"; + String jsonStr = """ + {"action_type":"PREDICT","method":"http","url":"https://test.com","headers":{"key1":"value1"},\ + "request_body":"{\\"input\\": \\"${parameters.input}\\"}",\ + "pre_process_function":"connector.pre_process.openai.embedding",\ + "post_process_function":"connector.post_process.openai.embedding"}"\ + """; XContentParser parser = XContentType.JSON .xContent() .createParser( @@ -160,24 +183,23 @@ public void parse() throws IOException { ); parser.nextToken(); ConnectorAction action = ConnectorAction.parse(parser); - Assert.assertEquals("http", action.getMethod()); - Assert.assertEquals(ConnectorAction.ActionType.PREDICT, action.getActionType()); - Assert.assertEquals("https://test.com", action.getUrl()); - Assert.assertEquals("{\"input\": \"${parameters.input}\"}", action.getRequestBody()); - Assert.assertEquals("connector.pre_process.openai.embedding", action.getPreProcessFunction()); - Assert.assertEquals("connector.post_process.openai.embedding", action.getPostProcessFunction()); + assertEquals("http", action.getMethod()); + assertEquals(ConnectorAction.ActionType.PREDICT, action.getActionType()); + assertEquals("https://test.com", action.getUrl()); + assertEquals("{\"input\": \"${parameters.input}\"}", action.getRequestBody()); + assertEquals("connector.pre_process.openai.embedding", action.getPreProcessFunction()); + assertEquals("connector.post_process.openai.embedding", action.getPostProcessFunction()); } @Test public void test_wrongActionType() { - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("Wrong Action Type"); - ConnectorAction.ActionType.from("badAction"); + Throwable exception = assertThrows(IllegalArgumentException.class, () -> { ConnectorAction.ActionType.from("badAction"); }); + assertEquals("Wrong Action Type of badAction", exception.getMessage()); } @Test public void test_invalidActionInModelPrediction() { ConnectorAction.ActionType actionType = ConnectorAction.ActionType.from("execute"); - Assert.assertEquals(isValidActionInModelPrediction(actionType), false); + assertEquals(isValidActionInModelPrediction(actionType), false); } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java index 28e597e186..a3b44f1321 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java @@ -8,6 +8,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import java.io.IOException; @@ -19,9 +20,7 @@ import java.util.function.Consumer; import org.junit.Before; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.opensearch.Version; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.settings.Settings; @@ -46,20 +45,19 @@ public class MLCreateConnectorInputTests { private MLCreateConnectorInput mlCreateConnectorInput; private MLCreateConnectorInput mlCreateDryRunConnectorInput; - @Rule - public final ExpectedException exceptionRule = ExpectedException.none(); - private final String expectedInputStr = "{\"name\":\"test_connector_name\"," - + "\"description\":\"this is a test connector\",\"version\":\"1\",\"protocol\":\"http\"," - + "\"parameters\":{\"input\":\"test input value\"},\"credential\":{\"key\":\"test_key_value\"}," - + "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\"," - + "\"headers\":{\"api_key\":\"${credential.key}\"}," - + "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," - + "\"pre_process_function\":\"connector.pre_process.openai.embedding\"," - + "\"post_process_function\":\"connector.post_process.openai.embedding\"}]," - + "\"backend_roles\":[\"role1\",\"role2\"],\"add_all_backend_roles\":false," - + "\"access_mode\":\"PUBLIC\",\"client_config\":{\"max_connection\":20," - + "\"connection_timeout\":10000,\"read_timeout\":10000," - + "\"retry_backoff_millis\":10,\"retry_timeout_seconds\":10,\"max_retry_times\":-1,\"retry_backoff_policy\":\"constant\"}}"; + private final String expectedInputStr = """ + {"name":"test_connector_name","description":"this is a test connector","version":"1","protocol":"http",\ + "parameters":{"input":"test input value"},"credential":{"key":"test_key_value"},\ + "actions":[{"action_type":"PREDICT","method":"POST","url":"https://test.com",\ + "headers":{"api_key":"${credential.key}"},\ + "request_body":"{\\"input\\": \\"${parameters.input}\\"}",\ + "pre_process_function":"connector.pre_process.openai.embedding",\ + "post_process_function":"connector.post_process.openai.embedding"}],\ + "backend_roles":["role1","role2"],"add_all_backend_roles":false,\ + "access_mode":"PUBLIC","client_config":{"max_connection":20,\ + "connection_timeout":10000,"read_timeout":10000,\ + "retry_backoff_millis":10,"retry_timeout_seconds":10,"max_retry_times":-1,"retry_backoff_policy":"constant"}}\ + """; @Before public void setUp() { @@ -102,59 +100,162 @@ public void setUp() { @Test public void constructorMLCreateConnectorInput_NullName() { - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("Connector name is null"); - MLCreateConnectorInput - .builder() - .name(null) - .description("this is a test connector") - .version("1") - .protocol("http") - .parameters(Map.of("input", "test input value")) - .credential(Map.of("key", "test_key_value")) - .actions(List.of()) - .access(AccessMode.PUBLIC) - .backendRoles(Arrays.asList("role1", "role2")) - .addAllBackendRoles(false) - .build(); + Throwable exception = assertThrows(IllegalArgumentException.class, () -> { + MLCreateConnectorInput + .builder() + .name(null) + .description("this is a test connector") + .version("1") + .protocol("http") + .parameters(Map.of("input", "test input value")) + .credential(Map.of("key", "test_key_value")) + .actions(List.of()) + .access(AccessMode.PUBLIC) + .backendRoles(Arrays.asList("role1", "role2")) + .addAllBackendRoles(false) + .build(); + }); + assertEquals("Connector name is null", exception.getMessage()); } @Test public void constructorMLCreateConnectorInput_NullVersion() { - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("Connector version is null"); - MLCreateConnectorInput - .builder() - .name("test_connector_name") - .description("this is a test connector") - .version(null) - .protocol("http") - .parameters(Map.of("input", "test input value")) - .credential(Map.of("key", "test_key_value")) - .actions(List.of()) - .access(AccessMode.PUBLIC) - .backendRoles(Arrays.asList("role1", "role2")) - .addAllBackendRoles(false) - .build(); + Throwable exception = assertThrows(IllegalArgumentException.class, () -> { + MLCreateConnectorInput + .builder() + .name("test_connector_name") + .description("this is a test connector") + .version(null) + .protocol("http") + .parameters(Map.of("input", "test input value")) + .credential(Map.of("key", "test_key_value")) + .actions(List.of()) + .access(AccessMode.PUBLIC) + .backendRoles(Arrays.asList("role1", "role2")) + .addAllBackendRoles(false) + .build(); + }); + assertEquals("Connector version is null", exception.getMessage()); } @Test public void constructorMLCreateConnectorInput_NullProtocol() { - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("Connector protocol is null"); - MLCreateConnectorInput - .builder() - .name("test_connector_name") - .description("this is a test connector") - .version("1") - .protocol(null) - .parameters(Map.of("input", "test input value")) - .credential(Map.of("key", "test_key_value")) - .actions(List.of()) - .access(AccessMode.PUBLIC) - .backendRoles(Arrays.asList("role1", "role2")) - .addAllBackendRoles(false) - .build(); + Throwable exception = assertThrows(IllegalArgumentException.class, () -> { + MLCreateConnectorInput + .builder() + .name("test_connector_name") + .description("this is a test connector") + .version("1") + .protocol(null) + .parameters(Map.of("input", "test input value")) + .credential(Map.of("key", "test_key_value")) + .actions(List.of()) + .access(AccessMode.PUBLIC) + .backendRoles(Arrays.asList("role1", "role2")) + .addAllBackendRoles(false) + .build(); + }); + assertEquals("Connector protocol is null", exception.getMessage()); + } + + @Test + public void constructorMLCreateConnectorInput_NullDescription() { + Throwable exception = assertThrows(IllegalArgumentException.class, () -> { + MLCreateConnectorInput + .builder() + .name("test_connector_name") + .description(null) + .version("1") + .protocol("http") + .parameters(Map.of("input", "test input value")) + .credential(Map.of("key", "test_key_value")) + .actions(List.of()) + .access(AccessMode.PUBLIC) + .backendRoles(Arrays.asList("role1", "role2")) + .addAllBackendRoles(false) + .build(); + }); + assertEquals("Connector description is null", exception.getMessage()); + } + + @Test + public void constructorMLCreateConnectorInput_NullParameters() { + Throwable exception = assertThrows(IllegalArgumentException.class, () -> { + MLCreateConnectorInput + .builder() + .name("test_connector_name") + .description("this is a test connector") + .version("1") + .protocol("http") + .parameters(null) + .credential(Map.of("key", "test_key_value")) + .actions(List.of()) + .access(AccessMode.PUBLIC) + .backendRoles(Arrays.asList("role1", "role2")) + .addAllBackendRoles(false) + .build(); + }); + assertEquals("Connector parameters is null or empty list", exception.getMessage()); + } + + @Test + public void constructorMLCreateConnectorInput_EmptyParameters() { + Throwable exception = assertThrows(IllegalArgumentException.class, () -> { + MLCreateConnectorInput + .builder() + .name("test_connector_name") + .description("this is a test connector") + .version("1") + .protocol("http") + .parameters(Map.of()) + .credential(Map.of("key", "test_key_value")) + .actions(List.of()) + .access(AccessMode.PUBLIC) + .backendRoles(Arrays.asList("role1", "role2")) + .addAllBackendRoles(false) + .build(); + }); + assertEquals("Connector parameters is null or empty list", exception.getMessage()); + } + + @Test + public void constructorMLCreateConnectorInput_NullCredential() { + Throwable exception = assertThrows(IllegalArgumentException.class, () -> { + MLCreateConnectorInput + .builder() + .name("test_connector_name") + .description("this is a test connector") + .version("1") + .protocol("http") + .parameters(Map.of("input", "test input value")) + .credential(null) + .actions(List.of()) + .access(AccessMode.PUBLIC) + .backendRoles(Arrays.asList("role1", "role2")) + .addAllBackendRoles(false) + .build(); + }); + assertEquals("Connector credential is null or empty list", exception.getMessage()); + } + + @Test + public void constructorMLCreateConnectorInput_EmptyCredential() { + Throwable exception = assertThrows(IllegalArgumentException.class, () -> { + MLCreateConnectorInput + .builder() + .name("test_connector_name") + .description("this is a test connector") + .version("1") + .protocol("http") + .parameters(Map.of("input", "test input value")) + .credential(Map.of()) + .actions(List.of()) + .access(AccessMode.PUBLIC) + .backendRoles(Arrays.asList("role1", "role2")) + .addAllBackendRoles(false) + .build(); + }); + assertEquals("Connector credential is null or empty list", exception.getMessage()); } @Test @@ -187,16 +288,15 @@ public void testParse() throws Exception { @Test public void testParse_ArrayParameter() throws Exception { - String expectedInputStr = "{\"name\":\"test_connector_name\"," - + "\"description\":\"this is a test connector\",\"version\":\"1\",\"protocol\":\"http\"," - + "\"parameters\":{\"input\":[\"test input value\"]},\"credential\":{\"key\":\"test_key_value\"}," - + "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\"," - + "\"headers\":{\"api_key\":\"${credential.key}\"}," - + "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," - + "\"pre_process_function\":\"connector.pre_process.openai.embedding\"," - + "\"post_process_function\":\"connector.post_process.openai.embedding\"}]," - + "\"backend_roles\":[\"role1\",\"role2\"],\"add_all_backend_roles\":false," - + "\"access_mode\":\"PUBLIC\"}"; + String expectedInputStr = """ + {"name":"test_connector_name","description":"this is a test connector","version":"1",\ + "protocol":"http","parameters":{"input":["test input value"]},"credential":{"key":"test_key_value"},\ + "actions":[{"action_type":"PREDICT","method":"POST","url":"https://test.com",\ + "headers":{"api_key":"${credential.key}"},"request_body":"{\\"input\\": \\"${parameters.input}\\"}",\ + "pre_process_function":"connector.pre_process.openai.embedding",\ + "post_process_function":"connector.post_process.openai.embedding"}],\ + "backend_roles":["role1","role2"],"add_all_backend_roles":false,"access_mode":"PUBLIC"};\ + """; testParseFromJsonString(expectedInputStr, parsedInput -> { assertEquals("test_connector_name", parsedInput.getName()); assertEquals(1, parsedInput.getParameters().size()); @@ -223,8 +323,11 @@ public void readInputStream_SuccessWithNullFields() throws IOException { MLCreateConnectorInput mlCreateMinimalConnectorInput = MLCreateConnectorInput .builder() .name("test_connector_name") + .description("this is a test connector") .version("1") .protocol("http") + .parameters(Map.of("input", "test input value")) + .credential(Map.of("key", "test_key_value")) .build(); readInputStream(mlCreateMinimalConnectorInput, parsedInput -> { assertEquals(mlCreateMinimalConnectorInput.getName(), parsedInput.getName()); @@ -258,10 +361,8 @@ public void testParse_MissingNameField_ShouldThrowException() throws IOException String jsonMissingName = "{\"description\":\"this is a test connector\",\"version\":\"1\",\"protocol\":\"http\"}"; XContentParser parser = createParser(jsonMissingName); - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("Connector name is null"); - - MLCreateConnectorInput.parse(parser); + Throwable exception = assertThrows(IllegalArgumentException.class, () -> { MLCreateConnectorInput.parse(parser); }); + assertEquals("Connector name is null", exception.getMessage()); } @Test diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java index f2c93ef5fd..a0206d7036 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java @@ -336,7 +336,7 @@ public static ConnectorAction createConnectorAction(Connector connector, Connect // Initialize the default method and requestBody String method = "POST"; - String requestBody = null; + String requestBody = "{}"; String url = ""; switch (getRemoteServerFromURL(predictEndpoint)) { diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java index 335dc95245..cb73e18e1f 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java @@ -344,7 +344,7 @@ public void testGetTask_createBatchStatusActionForOpenAI() { assertEquals(ConnectorAction.ActionType.BATCH_PREDICT_STATUS, result.getActionType()); assertEquals("GET", result.getMethod()); assertEquals("https://api.openai.com/v1/batches/${parameters.id}", result.getUrl()); - assertNull(result.getRequestBody()); + assertEquals("{}", result.getRequestBody()); assertTrue(result.getHeaders().containsKey("Authorization")); } @@ -355,6 +355,7 @@ public void testGetTask_createCancelBatchActionForBedrock() { .name("test") .protocol("http") .version("1") + .description("this is a test connector") .credential(Map.of("api_key", "credential_value")) .parameters(Map.of("param1", "value1")) .actions( @@ -384,6 +385,6 @@ public void testGetTask_createCancelBatchActionForBedrock() { "https://bedrock.${parameters.region}.amazonaws.com/model-invocation-job/${parameters.processedJobArn}/stop", result.getUrl() ); - assertNull(result.getRequestBody()); + assertEquals("{}", result.getRequestBody()); } } diff --git a/plugin/src/test/java/org/opensearch/ml/action/connector/TransportCreateConnectorActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/connector/TransportCreateConnectorActionTests.java index e16400bc56..33052e40d9 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/connector/TransportCreateConnectorActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/connector/TransportCreateConnectorActionTests.java @@ -133,6 +133,7 @@ public void setup() { .builder() .actionType(ConnectorAction.ActionType.PREDICT) .method("POST") + .requestBody("{ \"inputText\": \"${parameters.inputText}\" }") .url("https://${parameters.endpoint}/v1/completions") .build() ); @@ -142,6 +143,7 @@ public void setup() { input = MLCreateConnectorInput .builder() .name("test_name") + .description("this is a test connector") .version("1") .actions(actions) .parameters(parameters) @@ -447,21 +449,24 @@ public void test_execute_URL_notMatchingExpression_exception() { .actionType(ConnectorAction.ActionType.PREDICT) .method("POST") .url("https://${parameters.endpoint}/v1/completions") + .requestBody("{ \"inputText\": \"${parameters.inputText}\" }") .build() ); + Map parameters = ImmutableMap.of("endpoint", "api.openai1.com"); + Map credential = ImmutableMap.of("access_key", "mockKey", "secret_key", "mockSecret"); MLCreateConnectorInput mlCreateConnectorInput = MLCreateConnectorInput .builder() .name(randomAlphaOfLength(5)) .description(randomAlphaOfLength(10)) .version("1") .protocol(ConnectorProtocols.HTTP) + .parameters(parameters) + .credential(credential) .actions(actions) .build(); MLCreateConnectorRequest request = new MLCreateConnectorRequest(mlCreateConnectorInput); - Map parameters = ImmutableMap.of("endpoint", "api.openai1.com"); - mlCreateConnectorInput.setParameters(parameters); TransportCreateConnectorAction action = new TransportCreateConnectorAction( transportService, actionFilters,