diff --git a/common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java b/common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java index aac9a1acad..4849f79c93 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java @@ -65,7 +65,9 @@ public abstract class AbstractConnector implements Connector { protected User owner; @Setter protected AccessMode access; + @Setter protected Instant createdTime; + @Setter protected Instant lastUpdateTime; @Setter protected ConnectorClientConfig connectorClientConfig; diff --git a/common/src/main/java/org/opensearch/ml/common/connector/Connector.java b/common/src/main/java/org/opensearch/ml/common/connector/Connector.java index c808f6628c..0a37641144 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/Connector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/Connector.java @@ -12,6 +12,7 @@ import java.security.AccessController; import java.security.PrivilegedActionException; import java.security.PrivilegedExceptionAction; +import java.time.Instant; import java.util.List; import java.util.Map; import java.util.Optional; @@ -44,6 +45,10 @@ public interface Connector extends ToXContentObject, Writeable { String getProtocol(); + void setCreatedTime(Instant createdTime); + + void setLastUpdateTime(Instant lastUpdateTime); + User getOwner(); void setOwner(User user); diff --git a/common/src/test/java/org/opensearch/ml/common/RemoteModelTests.java b/common/src/test/java/org/opensearch/ml/common/RemoteModelTests.java index acb05c9b66..b09553e7d9 100644 --- a/common/src/test/java/org/opensearch/ml/common/RemoteModelTests.java +++ b/common/src/test/java/org/opensearch/ml/common/RemoteModelTests.java @@ -61,22 +61,22 @@ public void toXContent_InternalConnector() throws IOException { XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); mlModel.toXContent(builder, EMPTY_PARAMS); String mlModelContent = TestHelper.xContentBuilderToString(builder); - assertEquals( - "{\"name\":\"test_model_name\",\"model_group_id\":\"test_group_id\"," - + "\"algorithm\":\"REMOTE\",\"model_version\":\"1.0.0\",\"description\":\"test model\"," - + "\"connector\":{\"name\":\"test_connector_name\",\"version\":\"1\"," - + "\"description\":\"this is a test connector\",\"protocol\":\"http\"," - + "\"parameters\":{\"input\":\"test input value\"},\"credential\":{\"key\":\"test_key_value\"}," - + "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\"," - + "\"headers\":{\"api_key\":\"${credential.key}\"}," - + "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," - + "\"pre_process_function\":\"connector.pre_process.openai.embedding\"," - + "\"post_process_function\":\"connector.post_process.openai.embedding\"}]," - + "\"backend_roles\":[\"role1\",\"role2\"],\"access\":\"public\"," - + "\"client_config\":{\"max_connection\":30,\"connection_timeout\":30000,\"read_timeout\":30000," - + "\"retry_backoff_millis\":10,\"retry_timeout_seconds\":10,\"max_retry_times\":-1,\"retry_backoff_policy\":\"constant\"}}}", - mlModelContent - ); + + String expectedConnectorResponse = "{\"name\":\"test_model_name\",\"model_group_id\":\"test_group_id\"," + + "\"algorithm\":\"REMOTE\",\"model_version\":\"1.0.0\",\"description\":\"test model\"," + + "\"connector\":{\"name\":\"test_connector_name\",\"version\":\"1\"," + + "\"description\":\"this is a test connector\",\"protocol\":\"http\"," + + "\"parameters\":{\"input\":\"test input value\"},\"credential\":{\"key\":\"test_key_value\"}," + + "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\"," + + "\"headers\":{\"api_key\":\"${credential.key}\"}," + + "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," + + "\"pre_process_function\":\"connector.pre_process.openai.embedding\"," + + "\"post_process_function\":\"connector.post_process.openai.embedding\"}]," + + "\"backend_roles\":[\"role1\",\"role2\"],\"access\":\"public\"," + + "\"client_config\":{\"max_connection\":30,\"connection_timeout\":30000,\"read_timeout\":30000," + + "\"retry_backoff_millis\":10,\"retry_timeout_seconds\":10,\"max_retry_times\":-1,\"retry_backoff_policy\":\"constant\"}}}"; + + assertEquals(expectedConnectorResponse, mlModelContent); } @Test diff --git a/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java b/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java index 0115ac1376..16bbc76bfa 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java @@ -85,12 +85,12 @@ public void toXContent() throws IOException { XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); connector.toXContent(builder, ToXContent.EMPTY_PARAMS); String content = TestHelper.xContentBuilderToString(builder); + Assert.assertEquals(TEST_CONNECTOR_JSON_STRING, content); } @Test public void constructor_Parser() throws IOException { - XContentParser parser = XContentType.JSON .xContent() .createParser( diff --git a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLConnectorGetResponseTests.java b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLConnectorGetResponseTests.java index 70cbd64bc2..0623357b41 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLConnectorGetResponseTests.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLConnectorGetResponseTests.java @@ -58,21 +58,21 @@ public void toXContentTest() throws IOException { mlConnectorGetResponse.toXContent(builder, ToXContent.EMPTY_PARAMS); assertNotNull(builder); String jsonStr = builder.toString(); - assertEquals( - "{\"name\":\"test_connector_name\",\"version\":\"1\"," - + "\"description\":\"this is a test connector\",\"protocol\":\"http\"," - + "\"parameters\":{\"input\":\"test input value\"},\"credential\":{\"key\":\"test_key_value\"}," - + "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\"," - + "\"headers\":{\"api_key\":\"${credential.key}\"}," - + "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," - + "\"pre_process_function\":\"connector.pre_process.openai.embedding\"," - + "\"post_process_function\":\"connector.post_process.openai.embedding\"}]," - + "\"backend_roles\":[\"role1\",\"role2\"],\"access\":\"public\"," - + "\"client_config\":{\"max_connection\":30," - + "\"connection_timeout\":30000,\"read_timeout\":30000," - + "\"retry_backoff_millis\":10,\"retry_timeout_seconds\":10,\"max_retry_times\":-1,\"retry_backoff_policy\":\"constant\"}}", - jsonStr - ); + + String expectedControllerResponse = "{\"name\":\"test_connector_name\",\"version\":\"1\"," + + "\"description\":\"this is a test connector\",\"protocol\":\"http\"," + + "\"parameters\":{\"input\":\"test input value\"},\"credential\":{\"key\":\"test_key_value\"}," + + "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\"," + + "\"headers\":{\"api_key\":\"${credential.key}\"}," + + "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," + + "\"pre_process_function\":\"connector.pre_process.openai.embedding\"," + + "\"post_process_function\":\"connector.post_process.openai.embedding\"}]," + + "\"backend_roles\":[\"role1\",\"role2\"],\"access\":\"public\"," + + "\"client_config\":{\"max_connection\":30," + + "\"connection_timeout\":30000,\"read_timeout\":30000," + + "\"retry_backoff_millis\":10,\"retry_timeout_seconds\":10,\"max_retry_times\":-1,\"retry_backoff_policy\":\"constant\"}}"; + + assertEquals(expectedControllerResponse, jsonStr); } @Test diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelInputTest.java index 9014f0ec49..46e89d0aa6 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelInputTest.java @@ -64,8 +64,8 @@ public class MLUpdateModelInputTest { + "{\"name\":\"test\",\"version\":\"1\",\"protocol\":\"http\",\"parameters\":{\"param1\":\"value1\"},\"credential\":" + "{\"api_key\":\"credential_value\"},\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":" + "\"https://api.openai.com/v1/chat/completions\",\"headers\":{\"Authorization\":\"Bearer ${credential.api_key}\"},\"request_body\":" - + "\"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"messages\\\": ${parameters.messages} }\"}]},\"connector_id\":" - + "\"test-connector_id\",\"last_updated_time\":1}"; + + "\"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"messages\\\": ${parameters.messages} }\"}]}," + + "\"connector_id\":\"test-connector_id\",\"last_updated_time\":1}"; private final String expectedOutputStr = "{\"model_id\":null,\"name\":\"name\",\"description\":\"description\",\"model_group_id\":" + "\"modelGroupId\",\"is_enabled\":false,\"rate_limiter\":" @@ -152,6 +152,7 @@ public void readInputStreamSuccessWithNullFields() throws IOException { @Test public void testToXContent() throws Exception { String jsonStr = serializationWithToXContent(updateModelInput); + assertEquals(expectedOutputStrForUpdateRequestDoc, jsonStr); } diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java index 4cadcc936a..92b087f686 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java @@ -8,6 +8,7 @@ import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX; +import java.time.Instant; import java.util.HashSet; import java.util.List; @@ -134,6 +135,10 @@ private void indexConnector(Connector connector, ActionListener { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(connector); + return null; + }).when(connectorAccessControlHelper).getConnector(any(Client.class), any(String.class), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(searchResponse); + return null; + }).when(client).search(any(SearchRequest.class), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(updateResponse); + return null; + }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + + updateConnectorTransportAction.doExecute(task, updateRequest, actionListener); + + assertNull(connector.getCreatedTime()); + assertNotNull(connector.getLastUpdateTime()); + } + + @Test + public void testUpdateConnectorUpdatesHttpConnectorTimeFields() { + HttpConnector connector = HttpConnector + .builder() + .name("test") + .protocol("http") + .version("1") + .credential(Map.of("api_key", "credential_value")) + .parameters(Map.of("param1", "value1")) + .actions( + Arrays + .asList( + ConnectorAction + .builder() + .actionType(ConnectorAction.ActionType.PREDICT) + .method("POST") + .url("https://api.openai.com/v1/chat/completions") + .headers(Map.of("Authorization", "Bearer ${credential.api_key}")) + .requestBody("{ \"model\": \"${parameters.model}\", \"messages\": ${parameters.messages} }") + .build() + ) + ) + .build(); + + Instant testInitialTime = Instant.now(); + connector.setCreatedTime(testInitialTime); + connector.setLastUpdateTime(testInitialTime); + + assert (connector.getCreatedTime().toEpochMilli() == connector.getLastUpdateTime().toEpochMilli()); + + doReturn(true).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(Connector.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(connector); + return null; + }).when(connectorAccessControlHelper).getConnector(any(Client.class), any(String.class), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(searchResponse); + return null; + }).when(client).search(any(SearchRequest.class), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(updateResponse); + return null; + }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + + updateConnectorTransportAction.doExecute(task, updateRequest, actionListener); + + assertTrue( + "Last update time must be bigger than the creation time", + connector.getLastUpdateTime().toEpochMilli() >= connector.getCreatedTime().toEpochMilli() + ); + } + @Test public void testExecuteConnectorAccessControlSuccess() { doReturn(true).when(connectorAccessControlHelper).validateConnectorAccess(any(Client.class), any(Connector.class));