Skip to content

Commit

Permalink
[Enhancement] Enhance validation for create connector API
Browse files Browse the repository at this point in the history
This PR addresses the first part of this enhancement "Validate if connector payload has all the required fields. If not provided, throw the illegal argument exception".
Validation of fields description, parameters, credential, and request_body are missing. That validations are added in this fix.
Added new test cases correspong to these validations and fixed all failing test cases because of these new validations.

Partially Resolves #1382

Signed-off-by: Abdul Muneer Kolarkunnu <[email protected]>
  • Loading branch information
akolarkunnu committed Dec 16, 2024
1 parent 1d30671 commit ff5392b
Show file tree
Hide file tree
Showing 5 changed files with 215 additions and 169 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,13 @@ public ConnectorAction(
String postProcessFunction
) {
if (actionType == null) {
throw new IllegalArgumentException("action type can't null");
throw new IllegalArgumentException("action type can't be null");
}
if (url == null) {
throw new IllegalArgumentException("url can't null");
throw new IllegalArgumentException("url can't be null");
}
if (method == null) {
throw new IllegalArgumentException("method can't null");
throw new IllegalArgumentException("method can't be null");
}
this.actionType = actionType;
this.method = method;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ public MLCreateConnectorInput(
if (protocol == null) {
throw new IllegalArgumentException("Connector protocol is null");
}
if (credential == null || credential.isEmpty()) {
throw new IllegalArgumentException("Connector credential is null or empty list");
}
}
this.name = name;
this.description = description;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,16 @@

package org.opensearch.ml.common.connector;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThrows;
import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.isValidActionInModelPrediction;

import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.XContentType;
Expand All @@ -27,130 +26,124 @@
import org.opensearch.search.SearchModule;

public class ConnectorActionTest {
@Rule
public ExpectedException exceptionRule = ExpectedException.none();

// Shared test data for the class
private static final ConnectorAction.ActionType TEST_ACTION_TYPE = ConnectorAction.ActionType.PREDICT;
private static final String TEST_METHOD_POST = "post";
private static final String TEST_METHOD_HTTP = "http";
private static final String TEST_REQUEST_BODY = "{\"input\": \"${parameters.input}\"}";
private static final String URL = "https://test.com";

@Test
public void constructor_NullActionType() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("action type can't null");
ConnectorAction.ActionType actionType = null;
String method = "post";
String url = "https://test.com";
new ConnectorAction(actionType, method, url, null, null, null, null);
Throwable exception = assertThrows(
IllegalArgumentException.class,
() -> new ConnectorAction(null, TEST_METHOD_POST, URL, null, TEST_REQUEST_BODY, null, null)
);
assertEquals("action type can't be null", exception.getMessage());

}

@Test
public void constructor_NullUrl() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("url can't null");
ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT;
String method = "post";
String url = null;
new ConnectorAction(actionType, method, url, null, null, null, null);
Throwable exception = assertThrows(
IllegalArgumentException.class,
() -> new ConnectorAction(TEST_ACTION_TYPE, TEST_METHOD_POST, null, null, TEST_REQUEST_BODY, null, null)
);
assertEquals("url can't be null", exception.getMessage());
}

@Test
public void constructor_NullMethod() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("method can't null");
ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT;
String method = null;
String url = "https://test.com";
new ConnectorAction(actionType, method, url, null, null, null, null);
Throwable exception = assertThrows(
IllegalArgumentException.class,
() -> new ConnectorAction(TEST_ACTION_TYPE, null, URL, null, TEST_REQUEST_BODY, null, null)
);
assertEquals("method can't be null", exception.getMessage());
}

@Test
public void writeTo_NullValue() throws IOException {
ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT;
String method = "http";
String url = "https://test.com";
ConnectorAction action = new ConnectorAction(actionType, method, url, null, null, null, null);
ConnectorAction action = new ConnectorAction(TEST_ACTION_TYPE, TEST_METHOD_HTTP, URL, null, TEST_REQUEST_BODY, null, null);
BytesStreamOutput output = new BytesStreamOutput();
action.writeTo(output);
ConnectorAction action2 = new ConnectorAction(output.bytes().streamInput());
Assert.assertEquals(action, action2);
assertEquals(action, action2);
}

@Test
public void writeTo() throws IOException {
ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT;
String method = "http";
String url = "https://test.com";
Map<String, String> headers = new HashMap<>();
headers.put("key1", "value1");
String requestBody = "{\"input\": \"${parameters.input}\"}";
String preProcessFunction = MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT;
String postProcessFunction = MLPostProcessFunction.OPENAI_EMBEDDING;

ConnectorAction action = new ConnectorAction(
actionType,
method,
url,
TEST_ACTION_TYPE,
TEST_METHOD_HTTP,
URL,
headers,
requestBody,
TEST_REQUEST_BODY,
preProcessFunction,
postProcessFunction
);
BytesStreamOutput output = new BytesStreamOutput();
action.writeTo(output);
ConnectorAction action2 = new ConnectorAction(output.bytes().streamInput());
Assert.assertEquals(action, action2);
assertEquals(action, action2);
}

@Test
public void toXContent_NullValue() throws IOException {
ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT;
String method = "http";
String url = "https://test.com";
ConnectorAction action = new ConnectorAction(actionType, method, url, null, null, null, null);
ConnectorAction action = new ConnectorAction(TEST_ACTION_TYPE, TEST_METHOD_HTTP, URL, null, TEST_REQUEST_BODY, null, null);

XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
action.toXContent(builder, ToXContent.EMPTY_PARAMS);
String content = TestHelper.xContentBuilderToString(builder);
Assert.assertEquals("{\"action_type\":\"PREDICT\",\"method\":\"http\",\"url\":\"https://test.com\"}", content);
String expctedContent = """
{"action_type":"PREDICT","method":"http","url":"https://test.com",\
"request_body":"{\\"input\\": \\"${parameters.input}\\"}"}\
""";
assertEquals(expctedContent, content);
}

@Test
public void toXContent() throws IOException {
ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT;
String method = "http";
String url = "https://test.com";
Map<String, String> headers = new HashMap<>();
headers.put("key1", "value1");
String requestBody = "{\"input\": \"${parameters.input}\"}";
String preProcessFunction = MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT;
String postProcessFunction = MLPostProcessFunction.OPENAI_EMBEDDING;

ConnectorAction action = new ConnectorAction(
actionType,
method,
url,
TEST_ACTION_TYPE,
TEST_METHOD_HTTP,
URL,
headers,
requestBody,
TEST_REQUEST_BODY,
preProcessFunction,
postProcessFunction
);

XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
action.toXContent(builder, ToXContent.EMPTY_PARAMS);
String content = TestHelper.xContentBuilderToString(builder);
Assert
.assertEquals(
"{\"action_type\":\"PREDICT\",\"method\":\"http\",\"url\":\"https://test.com\","
+ "\"headers\":{\"key1\":\"value1\"},\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\","
+ "\"pre_process_function\":\"connector.pre_process.openai.embedding\","
+ "\"post_process_function\":\"connector.post_process.openai.embedding\"}",
content
);
String expctedContent = """
{"action_type":"PREDICT","method":"http","url":"https://test.com","headers":{"key1":"value1"},\
"request_body":"{\\"input\\": \\"${parameters.input}\\"}",\
"pre_process_function":"connector.pre_process.openai.embedding",\
"post_process_function":"connector.post_process.openai.embedding"}\
""";
assertEquals(expctedContent, content);
}

@Test
public void parse() throws IOException {
String jsonStr = "{\"action_type\":\"PREDICT\",\"method\":\"http\",\"url\":\"https://test.com\","
+ "\"headers\":{\"key1\":\"value1\"},\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\","
+ "\"pre_process_function\":\"connector.pre_process.openai.embedding\","
+ "\"post_process_function\":\"connector.post_process.openai.embedding\"}";
String jsonStr = """
{"action_type":"PREDICT","method":"http","url":"https://test.com","headers":{"key1":"value1"},\
"request_body":"{\\"input\\": \\"${parameters.input}\\"}",\
"pre_process_function":"connector.pre_process.openai.embedding",\
"post_process_function":"connector.post_process.openai.embedding"}"\
""";
XContentParser parser = XContentType.JSON
.xContent()
.createParser(
Expand All @@ -160,24 +153,23 @@ public void parse() throws IOException {
);
parser.nextToken();
ConnectorAction action = ConnectorAction.parse(parser);
Assert.assertEquals("http", action.getMethod());
Assert.assertEquals(ConnectorAction.ActionType.PREDICT, action.getActionType());
Assert.assertEquals("https://test.com", action.getUrl());
Assert.assertEquals("{\"input\": \"${parameters.input}\"}", action.getRequestBody());
Assert.assertEquals("connector.pre_process.openai.embedding", action.getPreProcessFunction());
Assert.assertEquals("connector.post_process.openai.embedding", action.getPostProcessFunction());
assertEquals(TEST_METHOD_HTTP, action.getMethod());
assertEquals(ConnectorAction.ActionType.PREDICT, action.getActionType());
assertEquals(URL, action.getUrl());
assertEquals(TEST_REQUEST_BODY, action.getRequestBody());
assertEquals("connector.pre_process.openai.embedding", action.getPreProcessFunction());
assertEquals("connector.post_process.openai.embedding", action.getPostProcessFunction());
}

@Test
public void test_wrongActionType() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("Wrong Action Type");
ConnectorAction.ActionType.from("badAction");
Throwable exception = assertThrows(IllegalArgumentException.class, () -> { ConnectorAction.ActionType.from("badAction"); });
assertEquals("Wrong Action Type of badAction", exception.getMessage());
}

@Test
public void test_invalidActionInModelPrediction() {
ConnectorAction.ActionType actionType = ConnectorAction.ActionType.from("execute");
Assert.assertEquals(isValidActionInModelPrediction(actionType), false);
assertEquals(isValidActionInModelPrediction(actionType), false);
}
}
Loading

0 comments on commit ff5392b

Please sign in to comment.