Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhancement] Enhance validation for create connector API #3260

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,13 @@ public ConnectorAction(
String postProcessFunction
) {
if (actionType == null) {
throw new IllegalArgumentException("action type can't null");
throw new IllegalArgumentException("action type can't be null");
}
if (url == null) {
throw new IllegalArgumentException("url can't null");
throw new IllegalArgumentException("url can't be null");
}
if (method == null) {
throw new IllegalArgumentException("method can't null");
throw new IllegalArgumentException("method can't be null");
}
this.actionType = actionType;
this.method = method;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ public MLCreateConnectorInput(
if (protocol == null) {
throw new IllegalArgumentException("Connector protocol is null");
}
if (credential == null || credential.isEmpty()) {
throw new IllegalArgumentException("Connector credential is null or empty list");
}
}
this.name = name;
this.description = description;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,16 @@

package org.opensearch.ml.common.connector;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThrows;
import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.isValidActionInModelPrediction;

import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.XContentType;
Expand All @@ -27,130 +26,124 @@
import org.opensearch.search.SearchModule;

public class ConnectorActionTest {
@Rule
public ExpectedException exceptionRule = ExpectedException.none();

// Shared test data for the class
private static final ConnectorAction.ActionType TEST_ACTION_TYPE = ConnectorAction.ActionType.PREDICT;
private static final String TEST_METHOD_POST = "post";
private static final String TEST_METHOD_HTTP = "http";
private static final String TEST_REQUEST_BODY = "{\"input\": \"${parameters.input}\"}";
private static final String URL = "https://test.com";

@Test
public void constructor_NullActionType() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("action type can't null");
ConnectorAction.ActionType actionType = null;
String method = "post";
String url = "https://test.com";
new ConnectorAction(actionType, method, url, null, null, null, null);
Throwable exception = assertThrows(
IllegalArgumentException.class,
() -> new ConnectorAction(null, TEST_METHOD_POST, URL, null, TEST_REQUEST_BODY, null, null)
);
assertEquals("action type can't be null", exception.getMessage());

}

@Test
public void constructor_NullUrl() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("url can't null");
ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT;
String method = "post";
String url = null;
new ConnectorAction(actionType, method, url, null, null, null, null);
Throwable exception = assertThrows(
IllegalArgumentException.class,
() -> new ConnectorAction(TEST_ACTION_TYPE, TEST_METHOD_POST, null, null, TEST_REQUEST_BODY, null, null)
);
assertEquals("url can't be null", exception.getMessage());
}

@Test
public void constructor_NullMethod() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("method can't null");
ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT;
String method = null;
String url = "https://test.com";
new ConnectorAction(actionType, method, url, null, null, null, null);
Throwable exception = assertThrows(
IllegalArgumentException.class,
() -> new ConnectorAction(TEST_ACTION_TYPE, null, URL, null, TEST_REQUEST_BODY, null, null)
);
assertEquals("method can't be null", exception.getMessage());
}

@Test
public void writeTo_NullValue() throws IOException {
ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT;
String method = "http";
String url = "https://test.com";
ConnectorAction action = new ConnectorAction(actionType, method, url, null, null, null, null);
ConnectorAction action = new ConnectorAction(TEST_ACTION_TYPE, TEST_METHOD_HTTP, URL, null, TEST_REQUEST_BODY, null, null);
BytesStreamOutput output = new BytesStreamOutput();
action.writeTo(output);
ConnectorAction action2 = new ConnectorAction(output.bytes().streamInput());
Assert.assertEquals(action, action2);
assertEquals(action, action2);
}

@Test
public void writeTo() throws IOException {
ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT;
String method = "http";
String url = "https://test.com";
Map<String, String> headers = new HashMap<>();
headers.put("key1", "value1");
String requestBody = "{\"input\": \"${parameters.input}\"}";
String preProcessFunction = MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT;
String postProcessFunction = MLPostProcessFunction.OPENAI_EMBEDDING;

ConnectorAction action = new ConnectorAction(
actionType,
method,
url,
TEST_ACTION_TYPE,
TEST_METHOD_HTTP,
URL,
headers,
requestBody,
TEST_REQUEST_BODY,
preProcessFunction,
postProcessFunction
);
BytesStreamOutput output = new BytesStreamOutput();
action.writeTo(output);
ConnectorAction action2 = new ConnectorAction(output.bytes().streamInput());
Assert.assertEquals(action, action2);
assertEquals(action, action2);
}

@Test
public void toXContent_NullValue() throws IOException {
ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT;
String method = "http";
String url = "https://test.com";
ConnectorAction action = new ConnectorAction(actionType, method, url, null, null, null, null);
ConnectorAction action = new ConnectorAction(TEST_ACTION_TYPE, TEST_METHOD_HTTP, URL, null, TEST_REQUEST_BODY, null, null);

XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
action.toXContent(builder, ToXContent.EMPTY_PARAMS);
String content = TestHelper.xContentBuilderToString(builder);
Assert.assertEquals("{\"action_type\":\"PREDICT\",\"method\":\"http\",\"url\":\"https://test.com\"}", content);
String expctedContent = """
{"action_type":"PREDICT","method":"http","url":"https://test.com",\
"request_body":"{\\"input\\": \\"${parameters.input}\\"}"}\
""";
assertEquals(expctedContent, content);
}

@Test
public void toXContent() throws IOException {
ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT;
String method = "http";
String url = "https://test.com";
Map<String, String> headers = new HashMap<>();
headers.put("key1", "value1");
String requestBody = "{\"input\": \"${parameters.input}\"}";
String preProcessFunction = MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT;
String postProcessFunction = MLPostProcessFunction.OPENAI_EMBEDDING;

ConnectorAction action = new ConnectorAction(
actionType,
method,
url,
TEST_ACTION_TYPE,
TEST_METHOD_HTTP,
URL,
headers,
requestBody,
TEST_REQUEST_BODY,
preProcessFunction,
postProcessFunction
);

XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
action.toXContent(builder, ToXContent.EMPTY_PARAMS);
String content = TestHelper.xContentBuilderToString(builder);
Assert
.assertEquals(
"{\"action_type\":\"PREDICT\",\"method\":\"http\",\"url\":\"https://test.com\","
+ "\"headers\":{\"key1\":\"value1\"},\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\","
+ "\"pre_process_function\":\"connector.pre_process.openai.embedding\","
+ "\"post_process_function\":\"connector.post_process.openai.embedding\"}",
content
);
String expctedContent = """
{"action_type":"PREDICT","method":"http","url":"https://test.com","headers":{"key1":"value1"},\
"request_body":"{\\"input\\": \\"${parameters.input}\\"}",\
"pre_process_function":"connector.pre_process.openai.embedding",\
"post_process_function":"connector.post_process.openai.embedding"}\
""";
assertEquals(expctedContent, content);
}

@Test
public void parse() throws IOException {
String jsonStr = "{\"action_type\":\"PREDICT\",\"method\":\"http\",\"url\":\"https://test.com\","
+ "\"headers\":{\"key1\":\"value1\"},\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\","
+ "\"pre_process_function\":\"connector.pre_process.openai.embedding\","
+ "\"post_process_function\":\"connector.post_process.openai.embedding\"}";
String jsonStr = """
{"action_type":"PREDICT","method":"http","url":"https://test.com","headers":{"key1":"value1"},\
"request_body":"{\\"input\\": \\"${parameters.input}\\"}",\
"pre_process_function":"connector.pre_process.openai.embedding",\
"post_process_function":"connector.post_process.openai.embedding"}"\
""";
XContentParser parser = XContentType.JSON
.xContent()
.createParser(
Expand All @@ -160,24 +153,23 @@ public void parse() throws IOException {
);
parser.nextToken();
ConnectorAction action = ConnectorAction.parse(parser);
Assert.assertEquals("http", action.getMethod());
Assert.assertEquals(ConnectorAction.ActionType.PREDICT, action.getActionType());
Assert.assertEquals("https://test.com", action.getUrl());
Assert.assertEquals("{\"input\": \"${parameters.input}\"}", action.getRequestBody());
Assert.assertEquals("connector.pre_process.openai.embedding", action.getPreProcessFunction());
Assert.assertEquals("connector.post_process.openai.embedding", action.getPostProcessFunction());
assertEquals(TEST_METHOD_HTTP, action.getMethod());
assertEquals(ConnectorAction.ActionType.PREDICT, action.getActionType());
assertEquals(URL, action.getUrl());
assertEquals(TEST_REQUEST_BODY, action.getRequestBody());
assertEquals("connector.pre_process.openai.embedding", action.getPreProcessFunction());
assertEquals("connector.post_process.openai.embedding", action.getPostProcessFunction());
}

@Test
public void test_wrongActionType() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("Wrong Action Type");
ConnectorAction.ActionType.from("badAction");
Throwable exception = assertThrows(IllegalArgumentException.class, () -> { ConnectorAction.ActionType.from("badAction"); });
assertEquals("Wrong Action Type of badAction", exception.getMessage());
}

@Test
public void test_invalidActionInModelPrediction() {
ConnectorAction.ActionType actionType = ConnectorAction.ActionType.from("execute");
Assert.assertEquals(isValidActionInModelPrediction(actionType), false);
assertEquals(isValidActionInModelPrediction(actionType), false);
}
}
Loading
Loading