Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] populate time fields for connectors on return (#2922) #3035

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,9 @@ public abstract class AbstractConnector implements Connector {
protected User owner;
@Setter
protected AccessMode access;
@Setter
protected Instant createdTime;
@Setter
protected Instant lastUpdateTime;
@Setter
protected ConnectorClientConfig connectorClientConfig;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.time.Instant;
import java.util.List;
import java.util.Map;
import java.util.Optional;
Expand Down Expand Up @@ -44,6 +45,10 @@ public interface Connector extends ToXContentObject, Writeable {

String getProtocol();

void setCreatedTime(Instant createdTime);

void setLastUpdateTime(Instant lastUpdateTime);

User getOwner();

void setOwner(User user);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,22 +61,22 @@ public void toXContent_InternalConnector() throws IOException {
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
mlModel.toXContent(builder, EMPTY_PARAMS);
String mlModelContent = TestHelper.xContentBuilderToString(builder);
assertEquals(
"{\"name\":\"test_model_name\",\"model_group_id\":\"test_group_id\","
+ "\"algorithm\":\"REMOTE\",\"model_version\":\"1.0.0\",\"description\":\"test model\","
+ "\"connector\":{\"name\":\"test_connector_name\",\"version\":\"1\","
+ "\"description\":\"this is a test connector\",\"protocol\":\"http\","
+ "\"parameters\":{\"input\":\"test input value\"},\"credential\":{\"key\":\"test_key_value\"},"
+ "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\","
+ "\"headers\":{\"api_key\":\"${credential.key}\"},"
+ "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\","
+ "\"pre_process_function\":\"connector.pre_process.openai.embedding\","
+ "\"post_process_function\":\"connector.post_process.openai.embedding\"}],"
+ "\"backend_roles\":[\"role1\",\"role2\"],\"access\":\"public\","
+ "\"client_config\":{\"max_connection\":30,\"connection_timeout\":30000,\"read_timeout\":30000,"
+ "\"retry_backoff_millis\":10,\"retry_timeout_seconds\":10,\"max_retry_times\":-1,\"retry_backoff_policy\":\"constant\"}}}",
mlModelContent
);

String expectedConnectorResponse = "{\"name\":\"test_model_name\",\"model_group_id\":\"test_group_id\","
+ "\"algorithm\":\"REMOTE\",\"model_version\":\"1.0.0\",\"description\":\"test model\","
+ "\"connector\":{\"name\":\"test_connector_name\",\"version\":\"1\","
+ "\"description\":\"this is a test connector\",\"protocol\":\"http\","
+ "\"parameters\":{\"input\":\"test input value\"},\"credential\":{\"key\":\"test_key_value\"},"
+ "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\","
+ "\"headers\":{\"api_key\":\"${credential.key}\"},"
+ "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\","
+ "\"pre_process_function\":\"connector.pre_process.openai.embedding\","
+ "\"post_process_function\":\"connector.post_process.openai.embedding\"}],"
+ "\"backend_roles\":[\"role1\",\"role2\"],\"access\":\"public\","
+ "\"client_config\":{\"max_connection\":30,\"connection_timeout\":30000,\"read_timeout\":30000,"
+ "\"retry_backoff_millis\":10,\"retry_timeout_seconds\":10,\"max_retry_times\":-1,\"retry_backoff_policy\":\"constant\"}}}";

assertEquals(expectedConnectorResponse, mlModelContent);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,12 @@ public void toXContent() throws IOException {
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
connector.toXContent(builder, ToXContent.EMPTY_PARAMS);
String content = TestHelper.xContentBuilderToString(builder);

Assert.assertEquals(TEST_CONNECTOR_JSON_STRING, content);
}

@Test
public void constructor_Parser() throws IOException {

XContentParser parser = XContentType.JSON
.xContent()
.createParser(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,21 +58,21 @@ public void toXContentTest() throws IOException {
mlConnectorGetResponse.toXContent(builder, ToXContent.EMPTY_PARAMS);
assertNotNull(builder);
String jsonStr = builder.toString();
assertEquals(
"{\"name\":\"test_connector_name\",\"version\":\"1\","
+ "\"description\":\"this is a test connector\",\"protocol\":\"http\","
+ "\"parameters\":{\"input\":\"test input value\"},\"credential\":{\"key\":\"test_key_value\"},"
+ "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\","
+ "\"headers\":{\"api_key\":\"${credential.key}\"},"
+ "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\","
+ "\"pre_process_function\":\"connector.pre_process.openai.embedding\","
+ "\"post_process_function\":\"connector.post_process.openai.embedding\"}],"
+ "\"backend_roles\":[\"role1\",\"role2\"],\"access\":\"public\","
+ "\"client_config\":{\"max_connection\":30,"
+ "\"connection_timeout\":30000,\"read_timeout\":30000,"
+ "\"retry_backoff_millis\":10,\"retry_timeout_seconds\":10,\"max_retry_times\":-1,\"retry_backoff_policy\":\"constant\"}}",
jsonStr
);

String expectedControllerResponse = "{\"name\":\"test_connector_name\",\"version\":\"1\","
+ "\"description\":\"this is a test connector\",\"protocol\":\"http\","
+ "\"parameters\":{\"input\":\"test input value\"},\"credential\":{\"key\":\"test_key_value\"},"
+ "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\","
+ "\"headers\":{\"api_key\":\"${credential.key}\"},"
+ "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\","
+ "\"pre_process_function\":\"connector.pre_process.openai.embedding\","
+ "\"post_process_function\":\"connector.post_process.openai.embedding\"}],"
+ "\"backend_roles\":[\"role1\",\"role2\"],\"access\":\"public\","
+ "\"client_config\":{\"max_connection\":30,"
+ "\"connection_timeout\":30000,\"read_timeout\":30000,"
+ "\"retry_backoff_millis\":10,\"retry_timeout_seconds\":10,\"max_retry_times\":-1,\"retry_backoff_policy\":\"constant\"}}";

assertEquals(expectedControllerResponse, jsonStr);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ public class MLUpdateModelInputTest {
+ "{\"name\":\"test\",\"version\":\"1\",\"protocol\":\"http\",\"parameters\":{\"param1\":\"value1\"},\"credential\":"
+ "{\"api_key\":\"credential_value\"},\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":"
+ "\"https://api.openai.com/v1/chat/completions\",\"headers\":{\"Authorization\":\"Bearer ${credential.api_key}\"},\"request_body\":"
+ "\"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"messages\\\": ${parameters.messages} }\"}]},\"connector_id\":"
+ "\"test-connector_id\",\"last_updated_time\":1}";
+ "\"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"messages\\\": ${parameters.messages} }\"}]},"
+ "\"connector_id\":\"test-connector_id\",\"last_updated_time\":1}";

private final String expectedOutputStr = "{\"model_id\":null,\"name\":\"name\",\"description\":\"description\",\"model_group_id\":"
+ "\"modelGroupId\",\"is_enabled\":false,\"rate_limiter\":"
Expand Down Expand Up @@ -152,6 +152,7 @@ public void readInputStreamSuccessWithNullFields() throws IOException {
@Test
public void testToXContent() throws Exception {
String jsonStr = serializationWithToXContent(updateModelInput);

assertEquals(expectedOutputStrForUpdateRequestDoc, jsonStr);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX;

import java.time.Instant;
import java.util.HashSet;
import java.util.List;

Expand Down Expand Up @@ -134,6 +135,10 @@ private void indexConnector(Connector connector, ActionListener<MLCreateConnecto
listener.onResponse(response);
}, listener::onFailure);

Instant currentTime = Instant.now();
connector.setCreatedTime(currentTime);
connector.setLastUpdateTime(currentTime);

IndexRequest indexRequest = new IndexRequest(ML_CONNECTOR_INDEX);
indexRequest.source(connector.toXContent(XContentBuilder.builder(XContentType.JSON.xContent()), ToXContent.EMPTY_PARAMS));
indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX;

import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
Expand Down Expand Up @@ -93,6 +94,9 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Update
if (Boolean.TRUE.equals(hasPermission)) {
connector.update(mlUpdateConnectorAction.getUpdateContent(), mlEngine::encrypt);
connector.validateConnectorURL(trustedConnectorEndpointsRegex);

connector.setLastUpdateTime(Instant.now());

UpdateRequest updateRequest = new UpdateRequest(ML_CONNECTOR_INDEX, connectorId);
updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
updateRequest.doc(connector.toXContent(XContentBuilder.builder(XContentType.JSON.xContent()), ToXContent.EMPTY_PARAMS));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import java.io.IOException;
import java.nio.file.Path;
import java.time.Instant;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -202,6 +203,117 @@ public void setup() throws IOException {
}).when(connectorAccessControlHelper).getConnector(any(Client.class), any(String.class), isA(ActionListener.class));
}

@Test
public void testUpdateConnectorDoesNotUpdateHttpConnectorTimeFields() {
HttpConnector connector = HttpConnector
.builder()
.name("test")
.protocol("http")
.version("1")
.credential(Map.of("api_key", "credential_value"))
.parameters(Map.of("param1", "value1"))
.actions(
Arrays
.asList(
ConnectorAction
.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
.method("POST")
.url("https://api.openai.com/v1/chat/completions")
.headers(Map.of("Authorization", "Bearer ${credential.api_key}"))
.requestBody("{ \"model\": \"${parameters.model}\", \"messages\": ${parameters.messages} }")
.build()
)
)
.build();

assertNull(connector.getCreatedTime());
assertNull(connector.getLastUpdateTime());

doReturn(true).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(Connector.class));

doAnswer(invocation -> {
ActionListener<Connector> listener = invocation.getArgument(2);
listener.onResponse(connector);
return null;
}).when(connectorAccessControlHelper).getConnector(any(Client.class), any(String.class), isA(ActionListener.class));

doAnswer(invocation -> {
ActionListener<SearchResponse> actionListener = invocation.getArgument(1);
actionListener.onResponse(searchResponse);
return null;
}).when(client).search(any(SearchRequest.class), isA(ActionListener.class));

doAnswer(invocation -> {
ActionListener<UpdateResponse> listener = invocation.getArgument(1);
listener.onResponse(updateResponse);
return null;
}).when(client).update(any(UpdateRequest.class), isA(ActionListener.class));

updateConnectorTransportAction.doExecute(task, updateRequest, actionListener);

assertNull(connector.getCreatedTime());
assertNotNull(connector.getLastUpdateTime());
}

@Test
public void testUpdateConnectorUpdatesHttpConnectorTimeFields() {
HttpConnector connector = HttpConnector
.builder()
.name("test")
.protocol("http")
.version("1")
.credential(Map.of("api_key", "credential_value"))
.parameters(Map.of("param1", "value1"))
.actions(
Arrays
.asList(
ConnectorAction
.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
.method("POST")
.url("https://api.openai.com/v1/chat/completions")
.headers(Map.of("Authorization", "Bearer ${credential.api_key}"))
.requestBody("{ \"model\": \"${parameters.model}\", \"messages\": ${parameters.messages} }")
.build()
)
)
.build();

Instant testInitialTime = Instant.now();
connector.setCreatedTime(testInitialTime);
connector.setLastUpdateTime(testInitialTime);

assert (connector.getCreatedTime().toEpochMilli() == connector.getLastUpdateTime().toEpochMilli());

doReturn(true).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(Connector.class));

doAnswer(invocation -> {
ActionListener<Connector> listener = invocation.getArgument(2);
listener.onResponse(connector);
return null;
}).when(connectorAccessControlHelper).getConnector(any(Client.class), any(String.class), isA(ActionListener.class));

doAnswer(invocation -> {
ActionListener<SearchResponse> actionListener = invocation.getArgument(1);
actionListener.onResponse(searchResponse);
return null;
}).when(client).search(any(SearchRequest.class), isA(ActionListener.class));

doAnswer(invocation -> {
ActionListener<UpdateResponse> listener = invocation.getArgument(1);
listener.onResponse(updateResponse);
return null;
}).when(client).update(any(UpdateRequest.class), isA(ActionListener.class));

updateConnectorTransportAction.doExecute(task, updateRequest, actionListener);

assertTrue(
"Last update time must be bigger than the creation time",
connector.getLastUpdateTime().toEpochMilli() >= connector.getCreatedTime().toEpochMilli()
);
}

@Test
public void testExecuteConnectorAccessControlSuccess() {
doReturn(true).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(Connector.class));
Expand Down
Loading