Skip to content

Commit

Permalink
populate time fields for connectors on return
Browse files Browse the repository at this point in the history
fixes opensearch-project#2890 Currently any class that extends the AbstractConnector class has the fields createdTime and lastUpdatedTime set to null. The solution was instantiating the fields in the constructor of the AbstractConnector class, as well updating it within the HTTPConnector class whenever an update happens. Many tests were modified to catch the time fields being populated as such there will be many differences on the string in order to get around the timing issue when doing tests.

Signed-off-by: Brian Flores <[email protected]>
  • Loading branch information
brianf-aws committed Sep 11, 2024
1 parent 20842b5 commit 33ea2c8
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,18 @@ public abstract class AbstractConnector implements Connector {
protected User owner;
@Setter
protected AccessMode access;
@Setter
protected Instant createdTime;
@Setter
protected Instant lastUpdateTime;
@Setter
protected ConnectorClientConfig connectorClientConfig;

public AbstractConnector() {
this.createdTime = Instant.now();
this.lastUpdateTime = Instant.now();
}

protected Map<String, String> createDecryptedHeaders(Map<String, String> headers) {
if (headers == null) {
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,7 @@ public void writeTo(StreamOutput out) throws IOException {

@Override
public void update(MLCreateConnectorInput updateContent, Function<String, String> function) {
this.setLastUpdateTime(Instant.now());
if (updateContent.getName() != null) {
this.name = updateContent.getName();
}
Expand Down
39 changes: 23 additions & 16 deletions common/src/test/java/org/opensearch/ml/common/RemoteModelTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import java.io.IOException;

import org.json.JSONObject;
import org.junit.Assert;
import org.junit.Test;
import org.opensearch.common.io.stream.BytesStreamOutput;
Expand Down Expand Up @@ -61,22 +62,28 @@ public void toXContent_InternalConnector() throws IOException {
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
mlModel.toXContent(builder, EMPTY_PARAMS);
String mlModelContent = TestHelper.xContentBuilderToString(builder);
assertEquals(
"{\"name\":\"test_model_name\",\"model_group_id\":\"test_group_id\","
+ "\"algorithm\":\"REMOTE\",\"model_version\":\"1.0.0\",\"description\":\"test model\","
+ "\"connector\":{\"name\":\"test_connector_name\",\"version\":\"1\","
+ "\"description\":\"this is a test connector\",\"protocol\":\"http\","
+ "\"parameters\":{\"input\":\"test input value\"},\"credential\":{\"key\":\"test_key_value\"},"
+ "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\","
+ "\"headers\":{\"api_key\":\"${credential.key}\"},"
+ "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\","
+ "\"pre_process_function\":\"connector.pre_process.openai.embedding\","
+ "\"post_process_function\":\"connector.post_process.openai.embedding\"}],"
+ "\"backend_roles\":[\"role1\",\"role2\"],\"access\":\"public\","
+ "\"client_config\":{\"max_connection\":30,\"connection_timeout\":30000,\"read_timeout\":30000,"
+ "\"retry_backoff_millis\":10,\"retry_timeout_seconds\":10,\"max_retry_times\":-1,\"retry_backoff_policy\":\"constant\"}}}",
mlModelContent
);

JSONObject mlModelJsonObject = new JSONObject(mlModelContent);
long mlModelCreatedTime = mlModelJsonObject.getJSONObject("connector").getLong("created_time");
long mlModelLastUpdatedTime = mlModelJsonObject.getJSONObject("connector").getLong("last_updated_time");

String expectedConnectorResponseFormat = "{\"name\":\"test_model_name\",\"model_group_id\":\"test_group_id\","
+ "\"algorithm\":\"REMOTE\",\"model_version\":\"1.0.0\",\"description\":\"test model\","
+ "\"connector\":{\"name\":\"test_connector_name\",\"version\":\"1\","
+ "\"description\":\"this is a test connector\",\"protocol\":\"http\","
+ "\"parameters\":{\"input\":\"test input value\"},\"credential\":{\"key\":\"test_key_value\"},"
+ "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\","
+ "\"headers\":{\"api_key\":\"${credential.key}\"},"
+ "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\","
+ "\"pre_process_function\":\"connector.pre_process.openai.embedding\","
+ "\"post_process_function\":\"connector.post_process.openai.embedding\"}],"
+ "\"backend_roles\":[\"role1\",\"role2\"],\"access\":\"public\","
+ "\"created_time\":%d,\"last_updated_time\":%d,"
+ "\"client_config\":{\"max_connection\":30,\"connection_timeout\":30000,\"read_timeout\":30000,"
+ "\"retry_backoff_millis\":10,\"retry_timeout_seconds\":10,\"max_retry_times\":-1,\"retry_backoff_policy\":\"constant\"}}}";

String expectedConnectorResponse = String.format(expectedConnectorResponseFormat, mlModelCreatedTime, mlModelLastUpdatedTime);
assertEquals(expectedConnectorResponse, mlModelContent);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT;

import java.io.IOException;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
Expand All @@ -17,6 +18,7 @@
import java.util.Map;
import java.util.function.Function;

import org.json.JSONObject;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
Expand All @@ -41,7 +43,7 @@ public class HttpConnectorTest {
Function<String, String> encryptFunction;
Function<String, String> decryptFunction;

String TEST_CONNECTOR_JSON_STRING = "{\"name\":\"test_connector_name\",\"version\":\"1\","
String TEST_CONNECTOR_JSON_STRING_FORMAT = "{\"name\":\"test_connector_name\",\"version\":\"1\","
+ "\"description\":\"this is a test connector\",\"protocol\":\"http\","
+ "\"parameters\":{\"input\":\"test input value\"},\"credential\":{\"key\":\"test_key_value\"},"
+ "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\","
Expand All @@ -50,6 +52,7 @@ public class HttpConnectorTest {
+ "\"pre_process_function\":\"connector.pre_process.openai.embedding\","
+ "\"post_process_function\":\"connector.post_process.openai.embedding\"}],"
+ "\"backend_roles\":[\"role1\",\"role2\"],\"access\":\"public\","
+ "\"created_time\":%d,\"last_updated_time\":%d,"
+ "\"client_config\":{\"max_connection\":30,\"connection_timeout\":30000,\"read_timeout\":30000,"
+ "\"retry_backoff_millis\":10,\"retry_timeout_seconds\":10,\"max_retry_times\":-1,\"retry_backoff_policy\":\"constant\"}}";

Expand Down Expand Up @@ -85,18 +88,27 @@ public void toXContent() throws IOException {
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
connector.toXContent(builder, ToXContent.EMPTY_PARAMS);
String content = TestHelper.xContentBuilderToString(builder);
Assert.assertEquals(TEST_CONNECTOR_JSON_STRING, content);

JSONObject connectorJsonObject = new JSONObject(content);
long connectorCreatedTime = connectorJsonObject.getLong("created_time");
long connectorLastUpdatedTime = connectorJsonObject.getLong("last_updated_time");

String testConnectorString = String.format(TEST_CONNECTOR_JSON_STRING_FORMAT, connectorCreatedTime, connectorLastUpdatedTime);

Assert.assertEquals(testConnectorString, content);
}

@Test
public void constructor_Parser() throws IOException {
long testEpochTime = Instant.now().toEpochMilli();
String testConnectorString = String.format(TEST_CONNECTOR_JSON_STRING_FORMAT, testEpochTime, testEpochTime);

XContentParser parser = XContentType.JSON
.xContent()
.createParser(
new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()),
null,
TEST_CONNECTOR_JSON_STRING
testConnectorString
);
parser.nextToken();

Expand All @@ -112,6 +124,8 @@ public void constructor_Parser() throws IOException {
Assert.assertEquals(ConnectorAction.ActionType.PREDICT, connector.getActions().get(0).getActionType());
Assert.assertEquals("POST", connector.getActions().get(0).getMethod());
Assert.assertEquals("https://test.com", connector.getActions().get(0).getUrl());
Assert.assertEquals(testEpochTime, connector.getCreatedTime().toEpochMilli());
Assert.assertEquals(testEpochTime, connector.getLastUpdateTime().toEpochMilli());
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import java.io.IOException;
import java.io.UncheckedIOException;

import org.json.JSONObject;
import org.junit.Before;
import org.junit.Test;
import org.opensearch.common.io.stream.BytesStreamOutput;
Expand Down Expand Up @@ -58,21 +59,27 @@ public void toXContentTest() throws IOException {
mlConnectorGetResponse.toXContent(builder, ToXContent.EMPTY_PARAMS);
assertNotNull(builder);
String jsonStr = builder.toString();
assertEquals(
"{\"name\":\"test_connector_name\",\"version\":\"1\","
+ "\"description\":\"this is a test connector\",\"protocol\":\"http\","
+ "\"parameters\":{\"input\":\"test input value\"},\"credential\":{\"key\":\"test_key_value\"},"
+ "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\","
+ "\"headers\":{\"api_key\":\"${credential.key}\"},"
+ "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\","
+ "\"pre_process_function\":\"connector.pre_process.openai.embedding\","
+ "\"post_process_function\":\"connector.post_process.openai.embedding\"}],"
+ "\"backend_roles\":[\"role1\",\"role2\"],\"access\":\"public\","
+ "\"client_config\":{\"max_connection\":30,"
+ "\"connection_timeout\":30000,\"read_timeout\":30000,"
+ "\"retry_backoff_millis\":10,\"retry_timeout_seconds\":10,\"max_retry_times\":-1,\"retry_backoff_policy\":\"constant\"}}",
jsonStr
);

JSONObject connectorJsonObject = new JSONObject(jsonStr);
long connectorCreatedTime = connectorJsonObject.getLong("created_time");
long connectorLastUpdatedTime = connectorJsonObject.getLong("last_updated_time");

String expectedControllerResponseFormat = "{\"name\":\"test_connector_name\",\"version\":\"1\","
+ "\"description\":\"this is a test connector\",\"protocol\":\"http\","
+ "\"parameters\":{\"input\":\"test input value\"},\"credential\":{\"key\":\"test_key_value\"},"
+ "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\","
+ "\"headers\":{\"api_key\":\"${credential.key}\"},"
+ "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\","
+ "\"pre_process_function\":\"connector.pre_process.openai.embedding\","
+ "\"post_process_function\":\"connector.post_process.openai.embedding\"}],"
+ "\"backend_roles\":[\"role1\",\"role2\"],\"access\":\"public\","
+ "\"created_time\":%d,\"last_updated_time\":%d,"
+ "\"client_config\":{\"max_connection\":30,"
+ "\"connection_timeout\":30000,\"read_timeout\":30000,"
+ "\"retry_backoff_millis\":10,\"retry_timeout_seconds\":10,\"max_retry_times\":-1,\"retry_backoff_policy\":\"constant\"}}";

String expectedControllerResponse = String.format(expectedControllerResponseFormat, connectorCreatedTime, connectorLastUpdatedTime);
assertEquals(expectedControllerResponse, jsonStr);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;

import org.json.JSONObject;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
Expand Down Expand Up @@ -55,7 +56,7 @@ public class MLUpdateModelInputTest {
+ "\"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"messages\\\": ${parameters.messages} }\"}]},\"connector_id\":"
+ "\"test-connector_id\",\"connector\":{\"description\":\"updated description\",\"version\":\"1\"},\"last_updated_time\":1}";

private final String expectedOutputStrForUpdateRequestDoc =
private final String expectedOutputStrForUpdateRequestDocFormat =
"{\"model_id\":\"test-model_id\",\"name\":\"name\",\"description\":\"description\",\"model_version\":"
+ "\"2\",\"model_group_id\":\"modelGroupId\",\"is_enabled\":false,\"rate_limiter\":"
+ "{\"limit\":\"1\",\"unit\":\"MILLISECONDS\"},\"model_config\":"
Expand All @@ -64,8 +65,9 @@ public class MLUpdateModelInputTest {
+ "{\"name\":\"test\",\"version\":\"1\",\"protocol\":\"http\",\"parameters\":{\"param1\":\"value1\"},\"credential\":"
+ "{\"api_key\":\"credential_value\"},\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":"
+ "\"https://api.openai.com/v1/chat/completions\",\"headers\":{\"Authorization\":\"Bearer ${credential.api_key}\"},\"request_body\":"
+ "\"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"messages\\\": ${parameters.messages} }\"}]},\"connector_id\":"
+ "\"test-connector_id\",\"last_updated_time\":1}";
+ "\"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"messages\\\": ${parameters.messages} }\"}],"
+ "\"created_time\":%d,\"last_updated_time\":%d},"
+ "\"connector_id\":\"test-connector_id\",\"last_updated_time\":1}";

private final String expectedOutputStr = "{\"model_id\":null,\"name\":\"name\",\"description\":\"description\",\"model_group_id\":"
+ "\"modelGroupId\",\"is_enabled\":false,\"rate_limiter\":"
Expand Down Expand Up @@ -152,7 +154,14 @@ public void readInputStreamSuccessWithNullFields() throws IOException {
@Test
public void testToXContent() throws Exception {
String jsonStr = serializationWithToXContent(updateModelInput);
assertEquals(expectedOutputStrForUpdateRequestDoc, jsonStr);

JSONObject connectorJsonObject = new JSONObject(jsonStr);
long connectorCreatedTime = connectorJsonObject.getJSONObject("connector").getLong("created_time");
long connectorLastUpdatedTime = connectorJsonObject.getJSONObject("connector").getLong("last_updated_time");

String expectedStringResponse = String
.format(expectedOutputStrForUpdateRequestDocFormat, connectorCreatedTime, connectorLastUpdatedTime);
assertEquals(expectedStringResponse, jsonStr);
}

@Test
Expand Down

0 comments on commit 33ea2c8

Please sign in to comment.