From 8dbd68ea28de56a3a279e4853975ded411a1cbae Mon Sep 17 00:00:00 2001 From: TrungBui59 Date: Wed, 22 Nov 2023 13:45:24 -0500 Subject: [PATCH] Adding checkstyle and apply Signed-off-by: TrungBui59 --- .../org/opensearch/ml/common/AccessMode.java | 4 +- .../org/opensearch/ml/common/CommonValue.java | 571 +++++++++--------- .../ml/common/MLCommonsClassLoader.java | 40 +- .../org/opensearch/ml/common/MLModel.java | 150 ++--- .../opensearch/ml/common/MLModelGroup.java | 86 +-- .../java/org/opensearch/ml/common/MLTask.java | 60 +- .../ml/common/annotation/ExecuteInput.java | 4 +- .../ml/common/annotation/ExecuteOutput.java | 4 +- .../ml/common/annotation/InputDataSet.java | 4 +- .../ml/common/annotation/MLAlgoOutput.java | 4 +- .../ml/common/annotation/MLAlgoParameter.java | 4 +- .../ml/common/annotation/MLInput.java | 4 +- .../common/connector/AbstractConnector.java | 25 +- .../ml/common/connector/AwsConnector.java | 45 +- .../ml/common/connector/Connector.java | 34 +- .../ml/common/connector/ConnectorAction.java | 36 +- .../common/connector/ConnectorProtocols.java | 9 +- .../ml/common/connector/HttpConnector.java | 56 +- .../connector/MLPostProcessFunction.java | 29 +- .../common/conversation/ActionConstants.java | 4 +- .../common/conversation/ConversationMeta.java | 33 +- .../ConversationalIndexConstants.java | 10 +- .../ml/common/conversation/Interaction.java | 62 +- .../ml/common/dataframe/ColumnMeta.java | 10 +- .../ml/common/dataframe/ColumnType.java | 14 +- .../ml/common/dataframe/ColumnValue.java | 6 +- .../common/dataframe/ColumnValueBuilder.java | 30 +- .../common/dataframe/ColumnValueReader.java | 2 +- .../ml/common/dataframe/DataFrameBuilder.java | 27 +- .../ml/common/dataframe/DefaultDataFrame.java | 60 +- .../ml/common/dataframe/FloatValue.java | 7 +- .../ml/common/dataframe/LongValue.java | 7 +- .../opensearch/ml/common/dataframe/Row.java | 49 +- .../ml/common/dataframe/ShortValue.java | 7 +- .../common/dataset/DataFrameInputDataset.java | 4 +- .../ml/common/dataset/MLInputDataset.java | 2 +- .../dataset/SearchQueryInputDataset.java | 8 +- .../common/dataset/TextDocsInputDataSet.java | 21 +- .../remote/RemoteInferenceInputDataSet.java | 15 +- .../ml/common/exception/ExecuteException.java | 16 +- .../exception/MLLimitExceededException.java | 2 +- .../ml/common/input/InputHelper.java | 90 ++- .../opensearch/ml/common/input/MLInput.java | 46 +- .../AnomalyLocalizationInput.java | 33 +- .../MetricsCorrelationInput.java | 25 +- .../LocalSampleCalculatorInput.java | 29 +- .../ml/common/input/nlp/TextDocsMLInput.java | 27 +- .../ad/AnomalyDetectionLibSVMParams.java | 36 +- .../parameter/clustering/KMeansParams.java | 29 +- .../clustering/RCFSummarizeParams.java | 23 +- .../input/parameter/rcf/BatchRCFParams.java | 40 +- .../input/parameter/rcf/FitRCFParams.java | 56 +- .../regression/LinearRegressionParams.java | 64 +- .../regression/LogisticRegressionParams.java | 47 +- .../parameter/sample/SampleAlgoParams.java | 21 +- .../input/remote/RemoteInferenceMLInput.java | 12 +- .../ml/common/model/MLModelConfig.java | 7 +- .../ml/common/model/MLModelState.java | 2 +- .../model/MetricsCorrelationModelConfig.java | 16 +- .../model/TextEmbeddingModelConfig.java | 48 +- .../opensearch/ml/common/output/MLOutput.java | 7 +- .../ml/common/output/MLPredictionOutput.java | 13 +- .../ml/common/output/MLTrainingOutput.java | 9 +- .../AnomalyLocalizationOutput.java | 24 +- .../metrics_correlation/MCorrModelTensor.java | 7 +- .../MCorrModelTensors.java | 29 +- .../MetricsCorrelationOutput.java | 19 +- .../LocalSampleCalculatorOutput.java | 9 +- .../common/output/model/MLResultDataType.java | 1 + .../output/model/ModelResultFilter.java | 30 +- .../ml/common/output/model/ModelTensor.java | 42 +- .../output/model/ModelTensorOutput.java | 18 +- .../ml/common/output/model/ModelTensors.java | 29 +- .../output/sample/SampleAlgoOutput.java | 12 +- .../opensearch/ml/common/package-info.java | 2 +- .../ml/common/transport/MLTaskRequest.java | 9 +- .../ml/common/transport/MLTaskResponse.java | 18 +- .../connector/MLConnectorDeleteAction.java | 4 +- .../connector/MLConnectorDeleteRequest.java | 22 +- .../connector/MLConnectorGetAction.java | 4 +- .../connector/MLConnectorGetRequest.java | 20 +- .../connector/MLConnectorGetResponse.java | 16 +- .../connector/MLCreateConnectorInput.java | 63 +- .../connector/MLCreateConnectorRequest.java | 20 +- .../connector/MLCreateConnectorResponse.java | 14 +- .../connector/MLUpdateConnectorAction.java | 4 +- .../connector/MLUpdateConnectorRequest.java | 20 +- .../transport/deploy/MLDeployModelInput.java | 21 +- .../deploy/MLDeployModelNodeRequest.java | 7 +- .../deploy/MLDeployModelNodeResponse.java | 11 +- .../deploy/MLDeployModelNodesRequest.java | 5 +- .../deploy/MLDeployModelNodesResponse.java | 6 +- .../deploy/MLDeployModelRequest.java | 32 +- .../deploy/MLDeployModelResponse.java | 16 +- .../execute/MLExecuteTaskRequest.java | 33 +- .../execute/MLExecuteTaskResponse.java | 20 +- .../transport/forward/MLForwardInput.java | 23 +- .../transport/forward/MLForwardRequest.java | 28 +- .../transport/forward/MLForwardResponse.java | 19 +- .../transport/model/MLModelDeleteAction.java | 4 +- .../transport/model/MLModelDeleteRequest.java | 22 +- .../transport/model/MLModelGetAction.java | 4 +- .../transport/model/MLModelGetRequest.java | 28 +- .../transport/model/MLModelGetResponse.java | 21 +- .../transport/model/MLUpdateModelInput.java | 30 +- .../transport/model/MLUpdateModelRequest.java | 32 +- .../model_group/MLModelGroupDeleteAction.java | 4 +- .../MLModelGroupDeleteRequest.java | 22 +- .../MLRegisterModelGroupInput.java | 39 +- .../MLRegisterModelGroupRequest.java | 26 +- .../MLRegisterModelGroupResponse.java | 16 +- .../model_group/MLUpdateModelGroupInput.java | 39 +- .../MLUpdateModelGroupRequest.java | 26 +- .../MLUpdateModelGroupResponse.java | 5 +- .../ml/common/transport/package-info.java | 2 +- .../prediction/MLPredictionTaskRequest.java | 12 +- .../register/MLRegisterModelInput.java | 106 +++- .../register/MLRegisterModelRequest.java | 26 +- .../register/MLRegisterModelResponse.java | 19 +- .../common/transport/sync/MLSyncUpInput.java | 31 +- .../transport/sync/MLSyncUpNodeRequest.java | 7 +- .../transport/sync/MLSyncUpNodeResponse.java | 18 +- .../transport/sync/MLSyncUpNodesRequest.java | 5 +- .../transport/sync/MLSyncUpNodesResponse.java | 6 +- .../transport/sync/MLSyncUpResponse.java | 7 +- .../transport/task/MLTaskDeleteAction.java | 4 +- .../transport/task/MLTaskDeleteRequest.java | 22 +- .../transport/task/MLTaskGetAction.java | 4 +- .../transport/task/MLTaskGetRequest.java | 21 +- .../transport/task/MLTaskGetResponse.java | 18 +- .../transport/task/MLTaskSearchAction.java | 1 - .../training/MLTrainingTaskRequest.java | 28 +- .../undeploy/MLUndeployModelInput.java | 15 +- .../undeploy/MLUndeployModelNodeRequest.java | 5 +- .../undeploy/MLUndeployModelNodeResponse.java | 18 +- .../undeploy/MLUndeployModelNodesRequest.java | 5 +- .../MLUndeployModelNodesResponse.java | 12 +- .../undeploy/MLUndeployModelsRequest.java | 30 +- .../undeploy/MLUndeployModelsResponse.java | 5 +- .../MLRegisterModelMetaAction.java | 2 +- .../MLRegisterModelMetaInput.java | 110 +++- .../MLRegisterModelMetaRequest.java | 28 +- .../MLRegisterModelMetaResponse.java | 7 +- .../MLUploadModelChunkAction.java | 1 - .../upload_chunk/MLUploadModelChunkInput.java | 12 +- .../MLUploadModelChunkRequest.java | 29 +- .../MLUploadModelChunkResponse.java | 12 +- .../ml/common/utils/StringUtils.java | 17 +- .../ml/common/MLCommonsClassLoaderTests.java | 80 +-- .../ml/common/MLModelGroupTest.java | 80 ++- .../opensearch/ml/common/MLModelTests.java | 86 +-- .../org/opensearch/ml/common/MLTaskTests.java | 19 +- .../ml/common/RemoteModelTests.java | 155 ++--- .../org/opensearch/ml/common/TestHelper.java | 21 +- .../ml/common/connector/AwsConnectorTest.java | 83 ++- .../common/connector/ConnectorActionTest.java | 59 +- .../ml/common/connector/ConnectorTest.java | 45 +- .../common/connector/HttpConnectorTest.java | 111 ++-- .../connector/MLPostProcessFunctionTest.java | 10 +- .../connector/MLPreProcessFunctionTest.java | 4 +- .../ml/common/dataframe/BooleanValueTest.java | 12 +- .../ml/common/dataframe/ColumnMetaTest.java | 13 +- .../ml/common/dataframe/ColumnTypeTest.java | 4 +- .../dataframe/ColumnValueBuilderTest.java | 14 +- .../dataframe/ColumnValueReaderTest.java | 8 +- .../ml/common/dataframe/ColumnValueTest.java | 4 +- .../dataframe/DataFrameBuilderTest.java | 38 +- .../dataframe/DefaultDataFrameTest.java | 98 ++- .../ml/common/dataframe/DoubleValueTest.java | 6 +- .../ml/common/dataframe/FloatValueTest.java | 10 +- .../ml/common/dataframe/IntValueTest.java | 10 +- .../ml/common/dataframe/LongValueTest.java | 14 +- .../ml/common/dataframe/NullValueTest.java | 12 +- .../ml/common/dataframe/RowTest.java | 124 +++- .../ml/common/dataframe/ShortValueTest.java | 16 +- .../ml/common/dataframe/StringValueTest.java | 10 +- .../dataset/DataFrameInputDatasetTest.java | 17 +- .../dataset/SearchQueryInputDatasetTest.java | 14 +- .../dataset/TextDocsInputDataSetTest.java | 10 +- .../RemoteInferenceInputDataSetTest.java | 12 +- .../ml/common/exception/MLExceptionTest.java | 4 +- .../MLLimitExceededExceptionTest.java | 5 +- .../MLResourceNotFoundExceptionTest.java | 4 +- .../exception/MLValidationExceptionTest.java | 5 +- .../ml/common/input/MLInputTest.java | 85 +-- .../AnomalyLocalizationInputTests.java | 145 +++-- .../MetricsCorrelationInputTest.java | 26 +- .../LocalSampleCalculatorInputTest.java | 19 +- .../common/input/nlp/TextDocsMLInputTest.java | 46 +- .../ad/AnomalyDetectionLibSVMParamsTest.java | 29 +- .../clustering/KMeansParamsTest.java | 22 +- .../clustering/RCFSummarizeParamsTest.java | 22 +- .../parameter/rcf/BatchRCFParamsTest.java | 18 +- .../input/parameter/rcf/FitRCFParamsTest.java | 33 +- .../LinearRegressionParamsTest.java | 74 +-- .../LogisticRegressionParamsTest.java | 54 +- .../sample/SampleAlgoParamsTest.java | 12 +- .../remote/RemoteInferenceMLInputTest.java | 19 +- .../ml/common/model/MLModelFormatTests.java | 4 +- .../ml/common/model/MLModelStateTests.java | 5 +- .../MetricsCorrelationModelConfigTests.java | 32 +- .../model/TextEmbeddingModelConfigTests.java | 47 +- .../common/output/MLPredictionOutputTest.java | 33 +- .../common/output/MLTrainingOutputTest.java | 10 +- .../AnomalyLocalizationOutputTests.java | 4 +- .../MCorrModelTensorTest.java | 27 +- .../MCorrModelTensorsTest.java | 39 +- .../MetricsCorrelationOutputTest.java | 39 +- .../LocalSampleCalculatorOutputTest.java | 16 +- .../output/model/ModelResultFilterTest.java | 30 +- .../output/model/ModelTensorOutputTest.java | 29 +- .../common/output/model/ModelTensorTest.java | 85 +-- .../common/output/model/ModelTensorsTest.java | 63 +- .../output/sample/SampleAlgoOutputTest.java | 16 +- .../MLConnectorDeleteRequestTests.java | 31 +- .../connector/MLConnectorGetRequestTests.java | 12 +- .../MLConnectorGetResponseTests.java | 45 +- .../MLCreateConnectorInputTests.java | 201 +++--- .../MLCreateConnectorRequestTests.java | 105 ++-- .../MLCreateConnectorResponseTests.java | 4 +- .../MLUpdateConnectorRequestTests.java | 45 +- .../deploy/MLDeployModelInputTest.java | 63 +- .../deploy/MLDeployModelNodeResponseTest.java | 32 +- .../deploy/MLDeployModelNodesRequestTest.java | 190 ++++-- .../MLDeployModelNodesResponseTest.java | 46 +- .../deploy/MLDeployModelRequestTest.java | 76 +-- .../deploy/MLDeployModelResponseTest.java | 15 +- .../execute/MLExecuteTaskRequestTest.java | 47 +- .../execute/MLExecuteTaskResponseTest.java | 94 +-- .../transport/forward/MLForwardInputTest.java | 135 +++-- .../forward/MLForwardRequestTest.java | 171 +++--- .../forward/MLForwardResponseTest.java | 31 +- .../model/MLModelDeleteRequestTest.java | 29 +- .../model/MLModelGetRequestTest.java | 30 +- .../model/MLModelGetResponseTest.java | 45 +- .../model/MLUpdateModelInputTest.java | 90 +-- .../model/MLUpdateModelRequestTest.java | 63 +- .../MLModelGroupDeleteRequestTest.java | 18 +- .../MLRegisterModelGroupInputTest.java | 25 +- .../MLRegisterModelGroupRequestTest.java | 90 +-- .../MLRegisterModelGroupResponseTest.java | 13 +- .../MLUpdateModelGroupInputTest.java | 27 +- .../MLUpdateModelGroupRequestTest.java | 63 +- .../MLUpdateModelGroupResponseTest.java | 13 +- .../MLPredictionTaskRequestTest.java | 77 ++- .../MLPredictionTaskResponseTest.java | 127 ++-- .../register/MLRegisterModelInputTest.java | 264 ++++---- .../register/MLRegisterModelRequestTest.java | 79 ++- .../register/MLRegisterModelResponseTest.java | 17 +- .../transport/sync/MLSyncUpInputTest.java | 67 +- .../sync/MLSyncUpNodeRequestTest.java | 113 ++-- .../sync/MLSyncUpNodeResponseTest.java | 49 +- .../sync/MLSyncUpNodesResponseTest.java | 13 +- .../transport/sync/MLSyncUpResponseTest.java | 12 +- .../transport/task/MLTaskGetRequestTest.java | 18 +- .../transport/task/MLTaskGetResponseTest.java | 80 +-- .../training/MLTrainingTaskRequestTest.java | 68 +-- .../training/MLTrainingTaskResponseTest.java | 39 +- .../undeploy/MLUndeployModelInputTest.java | 60 +- .../MLUndeployModelNodeResponseTest.java | 34 +- .../MLUndeployModelNodesRequestTest.java | 64 +- .../MLUndeployModelNodesResponseTest.java | 61 +- .../MLRegisterModelMetaInputTest.java | 200 +++--- .../MLRegisterModelMetaRequestTest.java | 230 +++---- .../MLRegisterModelMetaResponseTest.java | 99 ++- .../MLUploadModelChunkInputTest.java | 223 +++---- .../MLUploadModelChunkRequestTest.java | 165 +++-- .../MLUploadModelChunkResponseTest.java | 94 +-- .../ml/common/utils/StringUtilsTest.java | 21 +- 269 files changed, 5552 insertions(+), 4736 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/AccessMode.java b/common/src/main/java/org/opensearch/ml/common/AccessMode.java index 6b8e31e2fd..d4195206d5 100644 --- a/common/src/main/java/org/opensearch/ml/common/AccessMode.java +++ b/common/src/main/java/org/opensearch/ml/common/AccessMode.java @@ -7,11 +7,11 @@ package org.opensearch.ml.common; -import lombok.Getter; - import java.util.HashMap; import java.util.Map; +import lombok.Getter; + public enum AccessMode { PUBLIC("public"), PRIVATE("private"), diff --git a/common/src/main/java/org/opensearch/ml/common/CommonValue.java b/common/src/main/java/org/opensearch/ml/common/CommonValue.java index 84c3a96712..a04a525e48 100644 --- a/common/src/main/java/org/opensearch/ml/common/CommonValue.java +++ b/common/src/main/java/org/opensearch/ml/common/CommonValue.java @@ -5,8 +5,6 @@ package org.opensearch.ml.common; -import org.opensearch.ml.common.connector.AbstractConnector; - import static org.opensearch.ml.common.model.MLModelConfig.ALL_CONFIG_FIELD; import static org.opensearch.ml.common.model.MLModelConfig.MODEL_TYPE_FIELD; import static org.opensearch.ml.common.model.TextEmbeddingModelConfig.EMBEDDING_DIMENSION_FIELD; @@ -15,6 +13,8 @@ import static org.opensearch.ml.common.model.TextEmbeddingModelConfig.NORMALIZE_RESULT_FIELD; import static org.opensearch.ml.common.model.TextEmbeddingModelConfig.POOLING_MODE_FIELD; +import org.opensearch.ml.common.connector.AbstractConnector; + public class CommonValue { public static Integer NO_SCHEMA_VERSION = 0; @@ -29,7 +29,7 @@ public class CommonValue { public static final String CREATE_TIME_FIELD = "create_time"; public static final String BOX_TYPE_KEY = "box_type"; - //hot node + // hot node public static String HOT_BOX_TYPE = "hot"; // warm node public static String WARM_BOX_TYPE = "warm"; @@ -45,282 +45,311 @@ public class CommonValue { public static final Integer ML_CONFIG_INDEX_SCHEMA_VERSION = 2; public static final String ML_MAP_RESPONSE_KEY = "response"; public static final String USER_FIELD_MAPPING = " \"" - + CommonValue.USER - + "\": {\n" - + " \"type\": \"nested\",\n" - + " \"properties\": {\n" - + " \"name\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\", \"ignore_above\":256}}},\n" - + " \"backend_roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n" - + " \"roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n" - + " \"custom_attribute_names\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}}\n" - + " }\n" - + " }\n"; - public static final String ML_MODEL_GROUP_INDEX_MAPPING = "{\n" + - " \"_meta\": {\n" + - " \"schema_version\": "+ML_MODEL_GROUP_INDEX_SCHEMA_VERSION+"\n" + - " },\n" + - " \"properties\": {\n" + - " \""+MLModelGroup.MODEL_GROUP_NAME_FIELD+"\": {\n" + - " \"type\": \"text\",\n" + - " \"fields\": {\n" + - " \"keyword\": {\n" + - " \"type\": \"keyword\",\n" + - " \"ignore_above\": 256\n" + - " }\n" + - " }\n" + - " },\n" + - " \""+MLModelGroup.DESCRIPTION_FIELD+"\": {\n" + - " \"type\": \"text\"\n" + - " },\n" + - " \""+MLModelGroup.LATEST_VERSION_FIELD+"\": {\n" + - " \"type\": \"integer\"\n" + - " },\n" + - " \""+MLModelGroup.MODEL_GROUP_ID_FIELD+"\": {\n" + - " \"type\": \"keyword\"\n" + - " },\n" + - " \""+MLModelGroup.BACKEND_ROLES_FIELD+"\": {\n" + - " \"type\": \"text\",\n" + - " \"fields\": {\n" + - " \"keyword\": {\n" + - " \"type\": \"keyword\",\n" + - " \"ignore_above\": 256\n" + - " }\n" + - " }\n" + - " },\n" + - " \""+MLModelGroup.ACCESS+"\": {\n" + - " \"type\": \"keyword\"\n" + - " },\n" + - " \""+MLModelGroup.OWNER+"\": {\n" + - " \"type\": \"nested\",\n" + - " \"properties\": {\n" + - " \"name\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\", \"ignore_above\":256}}},\n" + - " \"backend_roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n" + - " \"roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n" + - " \"custom_attribute_names\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}}\n" + - " }\n" + - " },\n" + - " \""+MLModelGroup.CREATED_TIME_FIELD+"\": {\n" + - " \"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + - " \""+MLModelGroup.LAST_UPDATED_TIME_FIELD+"\": {\n" + - " \"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"}\n" + - " }\n" + - "}"; + + CommonValue.USER + + "\": {\n" + + " \"type\": \"nested\",\n" + + " \"properties\": {\n" + + " \"name\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\", \"ignore_above\":256}}},\n" + + " \"backend_roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n" + + " \"roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n" + + " \"custom_attribute_names\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}}\n" + + " }\n" + + " }\n"; + public static final String ML_MODEL_GROUP_INDEX_MAPPING = "{\n" + + " \"_meta\": {\n" + + " \"schema_version\": " + + ML_MODEL_GROUP_INDEX_SCHEMA_VERSION + + "\n" + + " },\n" + + " \"properties\": {\n" + + " \"" + + MLModelGroup.MODEL_GROUP_NAME_FIELD + + "\": {\n" + + " \"type\": \"text\",\n" + + " \"fields\": {\n" + + " \"keyword\": {\n" + + " \"type\": \"keyword\",\n" + + " \"ignore_above\": 256\n" + + " }\n" + + " }\n" + + " },\n" + + " \"" + + MLModelGroup.DESCRIPTION_FIELD + + "\": {\n" + + " \"type\": \"text\"\n" + + " },\n" + + " \"" + + MLModelGroup.LATEST_VERSION_FIELD + + "\": {\n" + + " \"type\": \"integer\"\n" + + " },\n" + + " \"" + + MLModelGroup.MODEL_GROUP_ID_FIELD + + "\": {\n" + + " \"type\": \"keyword\"\n" + + " },\n" + + " \"" + + MLModelGroup.BACKEND_ROLES_FIELD + + "\": {\n" + + " \"type\": \"text\",\n" + + " \"fields\": {\n" + + " \"keyword\": {\n" + + " \"type\": \"keyword\",\n" + + " \"ignore_above\": 256\n" + + " }\n" + + " }\n" + + " },\n" + + " \"" + + MLModelGroup.ACCESS + + "\": {\n" + + " \"type\": \"keyword\"\n" + + " },\n" + + " \"" + + MLModelGroup.OWNER + + "\": {\n" + + " \"type\": \"nested\",\n" + + " \"properties\": {\n" + + " \"name\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\", \"ignore_above\":256}}},\n" + + " \"backend_roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n" + + " \"roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n" + + " \"custom_attribute_names\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}}\n" + + " }\n" + + " },\n" + + " \"" + + MLModelGroup.CREATED_TIME_FIELD + + "\": {\n" + + " \"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + + " \"" + + MLModelGroup.LAST_UPDATED_TIME_FIELD + + "\": {\n" + + " \"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"}\n" + + " }\n" + + "}"; public static final String ML_CONNECTOR_INDEX_FIELDS = " \"properties\": {\n" - + " \"" - + AbstractConnector.NAME_FIELD - + "\" : {\"type\":\"text\",\"fields\":{\"keyword\":{\"type\":\"keyword\",\"ignore_above\":256}}},\n" - + " \"" - + AbstractConnector.VERSION_FIELD - + "\" : {\"type\": \"keyword\"},\n" - + " \"" - + AbstractConnector.DESCRIPTION_FIELD - + "\" : {\"type\": \"text\"},\n" - + " \"" - + AbstractConnector.PROTOCOL_FIELD - + "\" : {\"type\": \"keyword\"},\n" - + " \"" - + AbstractConnector.PARAMETERS_FIELD - + "\" : {\"type\": \"flat_object\"},\n" - + " \"" - + AbstractConnector.CREDENTIAL_FIELD - + "\" : {\"type\": \"flat_object\"},\n" - + " \"" - + AbstractConnector.ACTIONS_FIELD - + "\" : {\"type\": \"flat_object\"}\n"; + + " \"" + + AbstractConnector.NAME_FIELD + + "\" : {\"type\":\"text\",\"fields\":{\"keyword\":{\"type\":\"keyword\",\"ignore_above\":256}}},\n" + + " \"" + + AbstractConnector.VERSION_FIELD + + "\" : {\"type\": \"keyword\"},\n" + + " \"" + + AbstractConnector.DESCRIPTION_FIELD + + "\" : {\"type\": \"text\"},\n" + + " \"" + + AbstractConnector.PROTOCOL_FIELD + + "\" : {\"type\": \"keyword\"},\n" + + " \"" + + AbstractConnector.PARAMETERS_FIELD + + "\" : {\"type\": \"flat_object\"},\n" + + " \"" + + AbstractConnector.CREDENTIAL_FIELD + + "\" : {\"type\": \"flat_object\"},\n" + + " \"" + + AbstractConnector.ACTIONS_FIELD + + "\" : {\"type\": \"flat_object\"}\n"; public static final String ML_MODEL_INDEX_MAPPING = "{\n" - + " \"_meta\": {\"schema_version\": " - + ML_MODEL_INDEX_SCHEMA_VERSION - + "},\n" - + " \"properties\": {\n" - + " \"" - + MLModel.ALGORITHM_FIELD - + "\": {\"type\": \"keyword\"},\n" - + " \"" - + MLModel.MODEL_NAME_FIELD - + "\" : {\"type\":\"text\",\"fields\":{\"keyword\":{\"type\":\"keyword\",\"ignore_above\":256}}},\n" - + " \"" - + MLModel.OLD_MODEL_VERSION_FIELD - + "\" : {\"type\": \"long\"},\n" - + " \"" - + MLModel.MODEL_VERSION_FIELD - + "\" : {\"type\": \"keyword\"},\n" - + " \"" - + MLModel.MODEL_GROUP_ID_FIELD - + "\" : {\"type\": \"keyword\"},\n" - + " \"" - + MLModel.MODEL_CONTENT_FIELD - + "\" : {\"type\": \"binary\"},\n" - + " \"" - + MLModel.CHUNK_NUMBER_FIELD - + "\" : {\"type\": \"long\"},\n" - + " \"" - + MLModel.TOTAL_CHUNKS_FIELD - + "\" : {\"type\": \"long\"},\n" - + " \"" - + MLModel.MODEL_ID_FIELD - + "\" : {\"type\": \"keyword\"},\n" - + " \"" - + MLModel.DESCRIPTION_FIELD - + "\" : {\"type\": \"text\"},\n" - + " \"" - + MLModel.MODEL_FORMAT_FIELD - + "\" : {\"type\": \"keyword\"},\n" - + " \"" - + MLModel.MODEL_STATE_FIELD - + "\" : {\"type\": \"keyword\"},\n" - + " \"" - + MLModel.MODEL_CONTENT_SIZE_IN_BYTES_FIELD - + "\" : {\"type\": \"long\"},\n" - + " \"" - + MLModel.PLANNING_WORKER_NODE_COUNT_FIELD - + "\" : {\"type\": \"integer\"},\n" - + " \"" - + MLModel.CURRENT_WORKER_NODE_COUNT_FIELD - + "\" : {\"type\": \"integer\"},\n" - + " \"" - + MLModel.PLANNING_WORKER_NODES_FIELD - + "\": {\"type\": \"keyword\"},\n" - + " \"" - + MLModel.DEPLOY_TO_ALL_NODES_FIELD - + "\": {\"type\": \"boolean\"},\n" - + " \"" - + MLModel.MODEL_CONFIG_FIELD - + "\" : {\"properties\":{\"" - + MODEL_TYPE_FIELD + "\":{\"type\":\"keyword\"},\"" - + EMBEDDING_DIMENSION_FIELD + "\":{\"type\":\"integer\"},\"" - + FRAMEWORK_TYPE_FIELD + "\":{\"type\":\"keyword\"},\"" - + POOLING_MODE_FIELD + "\":{\"type\":\"keyword\"},\"" - + NORMALIZE_RESULT_FIELD + "\":{\"type\":\"boolean\"},\"" - + MODEL_MAX_LENGTH_FIELD + "\":{\"type\":\"integer\"},\"" - + ALL_CONFIG_FIELD + "\":{\"type\":\"text\"}}},\n" - + " \"" - + MLModel.MODEL_CONTENT_HASH_VALUE_FIELD - + "\" : {\"type\": \"keyword\"},\n" - + " \"" - + MLModel.AUTO_REDEPLOY_RETRY_TIMES_FIELD - + "\" : {\"type\": \"integer\"},\n" - + " \"" - + MLModel.CREATED_TIME_FIELD - + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" - + " \"" - + MLModel.LAST_UPDATED_TIME_FIELD - + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" - + " \"" - + MLModel.LAST_REGISTERED_TIME_FIELD - + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" - + " \"" - + MLModel.LAST_DEPLOYED_TIME_FIELD - + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" - + " \"" - + MLModel.LAST_UNDEPLOYED_TIME_FIELD - + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" - + " \"" - + MLModel.CONNECTOR_FIELD - + "\": {" + ML_CONNECTOR_INDEX_FIELDS + " }\n}," - + USER_FIELD_MAPPING - + " }\n" - + "}"; + + " \"_meta\": {\"schema_version\": " + + ML_MODEL_INDEX_SCHEMA_VERSION + + "},\n" + + " \"properties\": {\n" + + " \"" + + MLModel.ALGORITHM_FIELD + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + MLModel.MODEL_NAME_FIELD + + "\" : {\"type\":\"text\",\"fields\":{\"keyword\":{\"type\":\"keyword\",\"ignore_above\":256}}},\n" + + " \"" + + MLModel.OLD_MODEL_VERSION_FIELD + + "\" : {\"type\": \"long\"},\n" + + " \"" + + MLModel.MODEL_VERSION_FIELD + + "\" : {\"type\": \"keyword\"},\n" + + " \"" + + MLModel.MODEL_GROUP_ID_FIELD + + "\" : {\"type\": \"keyword\"},\n" + + " \"" + + MLModel.MODEL_CONTENT_FIELD + + "\" : {\"type\": \"binary\"},\n" + + " \"" + + MLModel.CHUNK_NUMBER_FIELD + + "\" : {\"type\": \"long\"},\n" + + " \"" + + MLModel.TOTAL_CHUNKS_FIELD + + "\" : {\"type\": \"long\"},\n" + + " \"" + + MLModel.MODEL_ID_FIELD + + "\" : {\"type\": \"keyword\"},\n" + + " \"" + + MLModel.DESCRIPTION_FIELD + + "\" : {\"type\": \"text\"},\n" + + " \"" + + MLModel.MODEL_FORMAT_FIELD + + "\" : {\"type\": \"keyword\"},\n" + + " \"" + + MLModel.MODEL_STATE_FIELD + + "\" : {\"type\": \"keyword\"},\n" + + " \"" + + MLModel.MODEL_CONTENT_SIZE_IN_BYTES_FIELD + + "\" : {\"type\": \"long\"},\n" + + " \"" + + MLModel.PLANNING_WORKER_NODE_COUNT_FIELD + + "\" : {\"type\": \"integer\"},\n" + + " \"" + + MLModel.CURRENT_WORKER_NODE_COUNT_FIELD + + "\" : {\"type\": \"integer\"},\n" + + " \"" + + MLModel.PLANNING_WORKER_NODES_FIELD + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + MLModel.DEPLOY_TO_ALL_NODES_FIELD + + "\": {\"type\": \"boolean\"},\n" + + " \"" + + MLModel.MODEL_CONFIG_FIELD + + "\" : {\"properties\":{\"" + + MODEL_TYPE_FIELD + + "\":{\"type\":\"keyword\"},\"" + + EMBEDDING_DIMENSION_FIELD + + "\":{\"type\":\"integer\"},\"" + + FRAMEWORK_TYPE_FIELD + + "\":{\"type\":\"keyword\"},\"" + + POOLING_MODE_FIELD + + "\":{\"type\":\"keyword\"},\"" + + NORMALIZE_RESULT_FIELD + + "\":{\"type\":\"boolean\"},\"" + + MODEL_MAX_LENGTH_FIELD + + "\":{\"type\":\"integer\"},\"" + + ALL_CONFIG_FIELD + + "\":{\"type\":\"text\"}}},\n" + + " \"" + + MLModel.MODEL_CONTENT_HASH_VALUE_FIELD + + "\" : {\"type\": \"keyword\"},\n" + + " \"" + + MLModel.AUTO_REDEPLOY_RETRY_TIMES_FIELD + + "\" : {\"type\": \"integer\"},\n" + + " \"" + + MLModel.CREATED_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + + " \"" + + MLModel.LAST_UPDATED_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + + " \"" + + MLModel.LAST_REGISTERED_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + + " \"" + + MLModel.LAST_DEPLOYED_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + + " \"" + + MLModel.LAST_UNDEPLOYED_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + + " \"" + + MLModel.CONNECTOR_FIELD + + "\": {" + + ML_CONNECTOR_INDEX_FIELDS + + " }\n}," + + USER_FIELD_MAPPING + + " }\n" + + "}"; public static final String ML_TASK_INDEX_MAPPING = "{\n" - + " \"_meta\": {\"schema_version\": " - + ML_TASK_INDEX_SCHEMA_VERSION - + "},\n" - + " \"properties\": {\n" - + " \"" - + MLTask.MODEL_ID_FIELD - + "\": {\"type\": \"keyword\"},\n" - + " \"" - + MLTask.TASK_TYPE_FIELD - + "\": {\"type\": \"keyword\"},\n" - + " \"" - + MLTask.FUNCTION_NAME_FIELD - + "\": {\"type\": \"keyword\"},\n" - + " \"" - + MLTask.STATE_FIELD - + "\": {\"type\": \"keyword\"},\n" - + " \"" - + MLTask.INPUT_TYPE_FIELD - + "\": {\"type\": \"keyword\"},\n" - + " \"" - + MLTask.PROGRESS_FIELD - + "\": {\"type\": \"float\"},\n" - + " \"" - + MLTask.OUTPUT_INDEX_FIELD - + "\": {\"type\": \"keyword\"},\n" - + " \"" - + MLTask.WORKER_NODE_FIELD - + "\": {\"type\": \"keyword\"},\n" - + " \"" - + MLTask.CREATE_TIME_FIELD - + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" - + " \"" - + MLTask.LAST_UPDATE_TIME_FIELD - + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" - + " \"" - + MLTask.ERROR_FIELD - + "\": {\"type\": \"text\"},\n" - + " \"" - + MLTask.IS_ASYNC_TASK_FIELD - + "\" : {\"type\" : \"boolean\"}, \n" - + USER_FIELD_MAPPING - + " }\n" - + "}"; + + " \"_meta\": {\"schema_version\": " + + ML_TASK_INDEX_SCHEMA_VERSION + + "},\n" + + " \"properties\": {\n" + + " \"" + + MLTask.MODEL_ID_FIELD + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + MLTask.TASK_TYPE_FIELD + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + MLTask.FUNCTION_NAME_FIELD + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + MLTask.STATE_FIELD + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + MLTask.INPUT_TYPE_FIELD + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + MLTask.PROGRESS_FIELD + + "\": {\"type\": \"float\"},\n" + + " \"" + + MLTask.OUTPUT_INDEX_FIELD + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + MLTask.WORKER_NODE_FIELD + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + MLTask.CREATE_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + + " \"" + + MLTask.LAST_UPDATE_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + + " \"" + + MLTask.ERROR_FIELD + + "\": {\"type\": \"text\"},\n" + + " \"" + + MLTask.IS_ASYNC_TASK_FIELD + + "\" : {\"type\" : \"boolean\"}, \n" + + USER_FIELD_MAPPING + + " }\n" + + "}"; public static final String ML_CONNECTOR_INDEX_MAPPING = "{\n" - + " \"_meta\": {\"schema_version\": " - + ML_CONNECTOR_SCHEMA_VERSION - + "},\n" - + ML_CONNECTOR_INDEX_FIELDS + ",\n" - + " \"" - + MLModelGroup.BACKEND_ROLES_FIELD - + "\": {\n" - + " \"type\": \"text\",\n" - + " \"fields\": {\n" - + " \"keyword\": {\n" - + " \"type\": \"keyword\",\n" - + " \"ignore_above\": 256\n" - + " }\n" - + " }\n" - + " },\n" - + " \"" - + MLModelGroup.ACCESS - + "\": {\n" - + " \"type\": \"keyword\"\n" - + " },\n" - + " \"" - + MLModelGroup.OWNER - + "\": {\n" - + " \"type\": \"nested\",\n" - + " \"properties\": {\n" - + " \"name\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\", \"ignore_above\":256}}},\n" - + " \"backend_roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n" - + " \"roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n" - + " \"custom_attribute_names\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}}\n" - + " }\n" - + " },\n" - + " \"" - + AbstractConnector.CREATED_TIME_FIELD - + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" - + " \"" - + AbstractConnector.LAST_UPDATED_TIME_FIELD - + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"}\n" - + " }\n" - + "}"; - + + " \"_meta\": {\"schema_version\": " + + ML_CONNECTOR_SCHEMA_VERSION + + "},\n" + + ML_CONNECTOR_INDEX_FIELDS + + ",\n" + + " \"" + + MLModelGroup.BACKEND_ROLES_FIELD + + "\": {\n" + + " \"type\": \"text\",\n" + + " \"fields\": {\n" + + " \"keyword\": {\n" + + " \"type\": \"keyword\",\n" + + " \"ignore_above\": 256\n" + + " }\n" + + " }\n" + + " },\n" + + " \"" + + MLModelGroup.ACCESS + + "\": {\n" + + " \"type\": \"keyword\"\n" + + " },\n" + + " \"" + + MLModelGroup.OWNER + + "\": {\n" + + " \"type\": \"nested\",\n" + + " \"properties\": {\n" + + " \"name\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\", \"ignore_above\":256}}},\n" + + " \"backend_roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n" + + " \"roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n" + + " \"custom_attribute_names\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}}\n" + + " }\n" + + " },\n" + + " \"" + + AbstractConnector.CREATED_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + + " \"" + + AbstractConnector.LAST_UPDATED_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"}\n" + + " }\n" + + "}"; public static final String ML_CONFIG_INDEX_MAPPING = "{\n" - + " \"_meta\": {\"schema_version\": " - + ML_CONFIG_INDEX_SCHEMA_VERSION - + "},\n" - + " \"properties\": {\n" - + " \"" - + MASTER_KEY - + "\": {\"type\": \"keyword\"},\n" - + " \"" - + CREATE_TIME_FIELD - + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"}\n" - + " }\n" - + "}"; + + " \"_meta\": {\"schema_version\": " + + ML_CONFIG_INDEX_SCHEMA_VERSION + + "},\n" + + " \"properties\": {\n" + + " \"" + + MASTER_KEY + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + CREATE_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"}\n" + + " }\n" + + "}"; } diff --git a/common/src/main/java/org/opensearch/ml/common/MLCommonsClassLoader.java b/common/src/main/java/org/opensearch/ml/common/MLCommonsClassLoader.java index 8f3e537e68..28dbc7eb12 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLCommonsClassLoader.java +++ b/common/src/main/java/org/opensearch/ml/common/MLCommonsClassLoader.java @@ -5,7 +5,14 @@ package org.opensearch.ml.common; -import lombok.extern.log4j.Log4j2; +import java.lang.reflect.Constructor; +import java.security.AccessController; +import java.security.PrivilegedActionException; +import java.security.PrivilegedExceptionAction; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; + import org.opensearch.ml.common.annotation.Connector; import org.opensearch.ml.common.annotation.ExecuteInput; import org.opensearch.ml.common.annotation.ExecuteOutput; @@ -18,13 +25,7 @@ import org.opensearch.ml.common.output.MLOutputType; import org.reflections.Reflections; -import java.lang.reflect.Constructor; -import java.security.AccessController; -import java.security.PrivilegedActionException; -import java.security.PrivilegedExceptionAction; -import java.util.HashMap; -import java.util.Map; -import java.util.Set; +import lombok.extern.log4j.Log4j2; @Log4j2 public class MLCommonsClassLoader { @@ -89,7 +90,7 @@ private static void loadMLAlgoParameterClassMapping() { if (mlAlgoParameter != null) { FunctionName[] algorithms = mlAlgoParameter.algorithms(); if (algorithms != null && algorithms.length > 0) { - for(FunctionName name : algorithms){ + for (FunctionName name : algorithms) { parameterClassMap.put(name, clazz); } } @@ -153,7 +154,7 @@ private static void loadExecuteInputClassMapping() { if (executeInput != null) { FunctionName[] algorithms = executeInput.algorithms(); if (algorithms != null && algorithms.length > 0) { - for(FunctionName name : algorithms){ + for (FunctionName name : algorithms) { executeInputClassMap.put(name, clazz); } } @@ -172,7 +173,7 @@ private static void loadExecuteOutputClassMapping() { if (executeOutput != null) { FunctionName[] algorithms = executeOutput.algorithms(); if (algorithms != null && algorithms.length > 0) { - for(FunctionName name : algorithms){ + for (FunctionName name : algorithms) { executeOutputClassMap.put(name, clazz); } } @@ -188,7 +189,7 @@ private static void loadMLInputClassMapping() { if (mlInput != null) { FunctionName[] algorithms = mlInput.functionNames(); if (algorithms != null && algorithms.length > 0) { - for(FunctionName name : algorithms){ + for (FunctionName name : algorithms) { mlInputClassMap.put(name, clazz); } } @@ -223,7 +224,7 @@ private static S init(Map> map, T type, I i } catch (Exception e) { Throwable cause = e.getCause(); if (cause instanceof MLException || cause instanceof IllegalArgumentException) { - throw (RuntimeException)cause; + throw (RuntimeException) cause; } else { log.error("Failed to init instance for type " + type, e); return null; @@ -235,19 +236,16 @@ public static boolean canInitMLInput(FunctionName functionName) { return mlInputClassMap.containsKey(functionName); } - public static S initConnector(String name, Object[] initArgs, - Class... constructorParameterTypes) { + public static S initConnector(String name, Object[] initArgs, Class... constructorParameterTypes) { return init(connectorClassMap, name, initArgs, constructorParameterTypes); } @SuppressWarnings("unchecked") - public static , S> S initMLInput(T type, Object[] initArgs, - Class... constructorParameterTypes) { + public static , S> S initMLInput(T type, Object[] initArgs, Class... constructorParameterTypes) { return init(mlInputClassMap, type, initArgs, constructorParameterTypes); } - private static S init(Map> map, T type, - Object[] initArgs, Class... constructorParameterTypes) { + private static S init(Map> map, T type, Object[] initArgs, Class... constructorParameterTypes) { Class clazz = map.get(type); if (clazz == null) { throw new IllegalArgumentException("Can't find class for type " + type); @@ -258,8 +256,8 @@ private static S init(Map> map, T type, } catch (Exception e) { Throwable cause = e.getCause(); if (cause instanceof MLException) { - throw (MLException)cause; - } else if (cause instanceof IllegalArgumentException) { + throw (MLException) cause; + } else if (cause instanceof IllegalArgumentException) { throw (IllegalArgumentException) cause; } else { log.error("Failed to init instance for type " + type, e); diff --git a/common/src/main/java/org/opensearch/ml/common/MLModel.java b/common/src/main/java/org/opensearch/ml/common/MLModel.java index 78f6f4ac60..f75e63d1c3 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLModel.java +++ b/common/src/main/java/org/opensearch/ml/common/MLModel.java @@ -5,12 +5,19 @@ package org.opensearch.ml.common; -import lombok.Builder; -import lombok.Getter; -import lombok.Setter; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.CommonValue.USER; +import static org.opensearch.ml.common.connector.Connector.createConnector; + +import java.io.IOException; +import java.time.Instant; +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; + +import org.opensearch.commons.authuser.User; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.commons.authuser.User; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; @@ -19,18 +26,12 @@ import org.opensearch.ml.common.model.MLModelConfig; import org.opensearch.ml.common.model.MLModelFormat; import org.opensearch.ml.common.model.MLModelState; -import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import org.opensearch.ml.common.model.MetricsCorrelationModelConfig; +import org.opensearch.ml.common.model.TextEmbeddingModelConfig; -import java.io.IOException; -import java.time.Instant; -import java.util.ArrayList; -import java.util.List; -import java.util.Locale; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.ml.common.CommonValue.USER; -import static org.opensearch.ml.common.connector.Connector.createConnector; +import lombok.Builder; +import lombok.Getter; +import lombok.Setter; @Getter public class MLModel implements ToXContentObject { @@ -50,7 +51,7 @@ public class MLModel implements ToXContentObject { public static final String MODEL_FORMAT_FIELD = "model_format"; public static final String MODEL_STATE_FIELD = "model_state"; public static final String MODEL_CONTENT_SIZE_IN_BYTES_FIELD = "model_content_size_in_bytes"; - //SHA256 hash value of model content. + // SHA256 hash value of model content. public static final String MODEL_CONTENT_HASH_VALUE_FIELD = "model_content_hash_value"; public static final String MODEL_CONFIG_FIELD = "model_config"; @@ -115,32 +116,35 @@ public class MLModel implements ToXContentObject { private String connectorId; @Builder(toBuilder = true) - public MLModel(String name, - String modelGroupId, - FunctionName algorithm, - String version, - String content, - User user, - String description, - MLModelFormat modelFormat, - MLModelState modelState, - Long modelContentSizeInBytes, - String modelContentHash, - MLModelConfig modelConfig, - Instant createdTime, - Instant lastUpdateTime, - Instant lastRegisteredTime, - Instant lastDeployedTime, - Instant lastUndeployedTime, - Integer autoRedeployRetryTimes, - String modelId, Integer chunkNumber, - Integer totalChunks, - Integer planningWorkerNodeCount, - Integer currentWorkerNodeCount, - String[] planningWorkerNodes, - boolean deployToAllNodes, - Connector connector, - String connectorId) { + public MLModel( + String name, + String modelGroupId, + FunctionName algorithm, + String version, + String content, + User user, + String description, + MLModelFormat modelFormat, + MLModelState modelState, + Long modelContentSizeInBytes, + String modelContentHash, + MLModelConfig modelConfig, + Instant createdTime, + Instant lastUpdateTime, + Instant lastRegisteredTime, + Instant lastDeployedTime, + Instant lastUndeployedTime, + Integer autoRedeployRetryTimes, + String modelId, + Integer chunkNumber, + Integer totalChunks, + Integer planningWorkerNodeCount, + Integer currentWorkerNodeCount, + String[] planningWorkerNodes, + boolean deployToAllNodes, + Connector connector, + String connectorId + ) { this.name = name; this.modelGroupId = modelGroupId; this.algorithm = algorithm; @@ -170,7 +174,7 @@ public MLModel(String name, this.connectorId = connectorId; } - public MLModel(StreamInput input) throws IOException{ + public MLModel(StreamInput input) throws IOException { name = input.readOptionalString(); algorithm = input.readEnum(FunctionName.class); version = input.readString(); @@ -371,7 +375,8 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws String oldContent = null; User user = null; - String description = null;; + String description = null; + ; MLModelFormat modelFormat = null; MLModelState modelState = null; Long modelContentSizeInBytes = null; @@ -511,35 +516,36 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws break; } } - return MLModel.builder() - .name(name) - .modelGroupId(modelGroupId) - .algorithm(algorithm) - .version(version == null ? oldVersion + "" : version) - .content(content == null ? oldContent : content) - .user(user) - .description(description) - .modelFormat(modelFormat) - .modelState(modelState) - .modelContentSizeInBytes(modelContentSizeInBytes) - .modelContentHash(modelContentHash) - .modelConfig(modelConfig) - .createdTime(createdTime) - .lastUpdateTime(lastUpdateTime) - .lastRegisteredTime(lastRegisteredTime == null? lastUploadedTime : lastRegisteredTime) - .lastDeployedTime(lastDeployedTime == null? lastLoadedTime : lastDeployedTime) - .lastUndeployedTime(lastUndeployedTime == null? lastUnloadedTime : lastUndeployedTime) - .modelId(modelId) - .autoRedeployRetryTimes(autoRedeployRetryTimes) - .chunkNumber(chunkNumber) - .totalChunks(totalChunks) - .planningWorkerNodeCount(planningWorkerNodeCount) - .currentWorkerNodeCount(currentWorkerNodeCount) - .planningWorkerNodes(planningWorkerNodes.toArray(new String[0])) - .deployToAllNodes(deployToAllNodes) - .connector(connector) - .connectorId(connectorId) - .build(); + return MLModel + .builder() + .name(name) + .modelGroupId(modelGroupId) + .algorithm(algorithm) + .version(version == null ? oldVersion + "" : version) + .content(content == null ? oldContent : content) + .user(user) + .description(description) + .modelFormat(modelFormat) + .modelState(modelState) + .modelContentSizeInBytes(modelContentSizeInBytes) + .modelContentHash(modelContentHash) + .modelConfig(modelConfig) + .createdTime(createdTime) + .lastUpdateTime(lastUpdateTime) + .lastRegisteredTime(lastRegisteredTime == null ? lastUploadedTime : lastRegisteredTime) + .lastDeployedTime(lastDeployedTime == null ? lastLoadedTime : lastDeployedTime) + .lastUndeployedTime(lastUndeployedTime == null ? lastUnloadedTime : lastUndeployedTime) + .modelId(modelId) + .autoRedeployRetryTimes(autoRedeployRetryTimes) + .chunkNumber(chunkNumber) + .totalChunks(totalChunks) + .planningWorkerNodeCount(planningWorkerNodeCount) + .currentWorkerNodeCount(currentWorkerNodeCount) + .planningWorkerNodes(planningWorkerNodes.toArray(new String[0])) + .deployToAllNodes(deployToAllNodes) + .connector(connector) + .connectorId(connectorId) + .build(); } public static MLModel fromStream(StreamInput in) throws IOException { diff --git a/common/src/main/java/org/opensearch/ml/common/MLModelGroup.java b/common/src/main/java/org/opensearch/ml/common/MLModelGroup.java index 0b9143f8cd..91b21131d4 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLModelGroup.java +++ b/common/src/main/java/org/opensearch/ml/common/MLModelGroup.java @@ -5,16 +5,7 @@ package org.opensearch.ml.common; -import lombok.Builder; -import lombok.Getter; -import lombok.Setter; -import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.commons.authuser.User; -import org.opensearch.core.common.util.CollectionUtils; -import org.opensearch.core.xcontent.ToXContentObject; -import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.core.xcontent.XContentParser; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import java.io.IOException; import java.time.Instant; @@ -22,21 +13,30 @@ import java.util.List; import java.util.Objects; -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.util.CollectionUtils; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; + +import lombok.Builder; +import lombok.Getter; +import lombok.Setter; @Getter public class MLModelGroup implements ToXContentObject { - public static final String MODEL_GROUP_NAME_FIELD = "name"; //name of the model group - public static final String DESCRIPTION_FIELD = "description"; //description of the model group - public static final String LATEST_VERSION_FIELD = "latest_version"; //latest model version added to the model group - public static final String BACKEND_ROLES_FIELD = "backend_roles"; //back_end roles as specified by the owner/admin - public static final String OWNER = "owner"; //user who creates/owns the model group - - public static final String ACCESS = "access"; //assigned to public, private, or null when model group created - public static final String MODEL_GROUP_ID_FIELD = "model_group_id"; //unique ID assigned to each model group - public static final String CREATED_TIME_FIELD = "created_time"; //model group created time stamp - public static final String LAST_UPDATED_TIME_FIELD = "last_updated_time"; //updated whenever a new model version is created + public static final String MODEL_GROUP_NAME_FIELD = "name"; // name of the model group + public static final String DESCRIPTION_FIELD = "description"; // description of the model group + public static final String LATEST_VERSION_FIELD = "latest_version"; // latest model version added to the model group + public static final String BACKEND_ROLES_FIELD = "backend_roles"; // back_end roles as specified by the owner/admin + public static final String OWNER = "owner"; // user who creates/owns the model group + public static final String ACCESS = "access"; // assigned to public, private, or null when model group created + public static final String MODEL_GROUP_ID_FIELD = "model_group_id"; // unique ID assigned to each model group + public static final String CREATED_TIME_FIELD = "created_time"; // model group created time stamp + public static final String LAST_UPDATED_TIME_FIELD = "last_updated_time"; // updated whenever a new model version is created @Setter private String name; @@ -52,13 +52,18 @@ public class MLModelGroup implements ToXContentObject { private Instant createdTime; private Instant lastUpdatedTime; - @Builder(toBuilder = true) - public MLModelGroup(String name, String description, int latestVersion, - List backendRoles, User owner, String access, - String modelGroupId, - Instant createdTime, - Instant lastUpdatedTime) { + public MLModelGroup( + String name, + String description, + int latestVersion, + List backendRoles, + User owner, + String access, + String modelGroupId, + Instant createdTime, + Instant lastUpdatedTime + ) { this.name = Objects.requireNonNull(name, "model group name must not be null"); this.description = description; this.latestVersion = latestVersion; @@ -70,8 +75,7 @@ public MLModelGroup(String name, String description, int latestVersion, this.lastUpdatedTime = lastUpdatedTime; } - - public MLModelGroup(StreamInput input) throws IOException{ + public MLModelGroup(StreamInput input) throws IOException { name = input.readString(); description = input.readOptionalString(); latestVersion = input.readInt(); @@ -194,20 +198,20 @@ public static MLModelGroup parse(XContentParser parser) throws IOException { break; } } - return MLModelGroup.builder() - .name(name) - .description(description) - .backendRoles(backendRoles) - .latestVersion(latestVersion) - .owner(owner) - .access(access) - .modelGroupId(modelGroupId) - .createdTime(createdTime) - .lastUpdatedTime(lastUpdateTime) - .build(); + return MLModelGroup + .builder() + .name(name) + .description(description) + .backendRoles(backendRoles) + .latestVersion(latestVersion) + .owner(owner) + .access(access) + .modelGroupId(modelGroupId) + .createdTime(createdTime) + .lastUpdatedTime(lastUpdateTime) + .build(); } - public static MLModelGroup fromStream(StreamInput in) throws IOException { MLModelGroup mlModel = new MLModelGroup(in); return mlModel; diff --git a/common/src/main/java/org/opensearch/ml/common/MLTask.java b/common/src/main/java/org/opensearch/ml/common/MLTask.java index 229bba5771..a810fa5159 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLTask.java +++ b/common/src/main/java/org/opensearch/ml/common/MLTask.java @@ -5,27 +5,28 @@ package org.opensearch.ml.common; -import lombok.Builder; -import lombok.EqualsAndHashCode; -import lombok.Getter; -import lombok.Setter; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.CommonValue.USER; + +import java.io.IOException; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import org.opensearch.commons.authuser.User; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; -import org.opensearch.commons.authuser.User; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.dataset.MLInputDataType; -import java.io.IOException; -import java.time.Instant; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.ml.common.CommonValue.USER; +import lombok.Builder; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.Setter; @Getter @EqualsAndHashCode @@ -279,21 +280,22 @@ public static MLTask parse(XContentParser parser) throws IOException { break; } } - return MLTask.builder() - .taskId(taskId) - .modelId(modelId) - .taskType(taskType) - .functionName(functionName) - .state(state) - .inputType(inputType) - .progress(progress) - .outputIndex(outputIndex) - .workerNodes(workerNodes) - .createTime(createTime) - .lastUpdateTime(lastUpdateTime) - .error(error) - .user(user) - .async(async) - .build(); + return MLTask + .builder() + .taskId(taskId) + .modelId(modelId) + .taskType(taskType) + .functionName(functionName) + .state(state) + .inputType(inputType) + .progress(progress) + .outputIndex(outputIndex) + .workerNodes(workerNodes) + .createTime(createTime) + .lastUpdateTime(lastUpdateTime) + .error(error) + .user(user) + .async(async) + .build(); } } diff --git a/common/src/main/java/org/opensearch/ml/common/annotation/ExecuteInput.java b/common/src/main/java/org/opensearch/ml/common/annotation/ExecuteInput.java index 34a879aaae..a9874a286a 100644 --- a/common/src/main/java/org/opensearch/ml/common/annotation/ExecuteInput.java +++ b/common/src/main/java/org/opensearch/ml/common/annotation/ExecuteInput.java @@ -5,13 +5,13 @@ package org.opensearch.ml.common.annotation; -import org.opensearch.ml.common.FunctionName; - import java.lang.annotation.ElementType; import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target; +import org.opensearch.ml.common.FunctionName; + @Retention(RetentionPolicy.RUNTIME) @Target(ElementType.TYPE) public @interface ExecuteInput { diff --git a/common/src/main/java/org/opensearch/ml/common/annotation/ExecuteOutput.java b/common/src/main/java/org/opensearch/ml/common/annotation/ExecuteOutput.java index e5a858f42d..42b2e1c1d0 100644 --- a/common/src/main/java/org/opensearch/ml/common/annotation/ExecuteOutput.java +++ b/common/src/main/java/org/opensearch/ml/common/annotation/ExecuteOutput.java @@ -5,13 +5,13 @@ package org.opensearch.ml.common.annotation; -import org.opensearch.ml.common.FunctionName; - import java.lang.annotation.ElementType; import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target; +import org.opensearch.ml.common.FunctionName; + @Retention(RetentionPolicy.RUNTIME) @Target(ElementType.TYPE) public @interface ExecuteOutput { diff --git a/common/src/main/java/org/opensearch/ml/common/annotation/InputDataSet.java b/common/src/main/java/org/opensearch/ml/common/annotation/InputDataSet.java index 847e00ac36..93965886ff 100644 --- a/common/src/main/java/org/opensearch/ml/common/annotation/InputDataSet.java +++ b/common/src/main/java/org/opensearch/ml/common/annotation/InputDataSet.java @@ -5,13 +5,13 @@ package org.opensearch.ml.common.annotation; -import org.opensearch.ml.common.dataset.MLInputDataType; - import java.lang.annotation.ElementType; import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target; +import org.opensearch.ml.common.dataset.MLInputDataType; + @Retention(RetentionPolicy.RUNTIME) @Target(ElementType.TYPE) public @interface InputDataSet { diff --git a/common/src/main/java/org/opensearch/ml/common/annotation/MLAlgoOutput.java b/common/src/main/java/org/opensearch/ml/common/annotation/MLAlgoOutput.java index df0afd7673..d24064be71 100644 --- a/common/src/main/java/org/opensearch/ml/common/annotation/MLAlgoOutput.java +++ b/common/src/main/java/org/opensearch/ml/common/annotation/MLAlgoOutput.java @@ -5,13 +5,13 @@ package org.opensearch.ml.common.annotation; -import org.opensearch.ml.common.output.MLOutputType; - import java.lang.annotation.ElementType; import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target; +import org.opensearch.ml.common.output.MLOutputType; + @Retention(RetentionPolicy.RUNTIME) @Target(ElementType.TYPE) public @interface MLAlgoOutput { diff --git a/common/src/main/java/org/opensearch/ml/common/annotation/MLAlgoParameter.java b/common/src/main/java/org/opensearch/ml/common/annotation/MLAlgoParameter.java index 18136a78f6..eff313fc8a 100644 --- a/common/src/main/java/org/opensearch/ml/common/annotation/MLAlgoParameter.java +++ b/common/src/main/java/org/opensearch/ml/common/annotation/MLAlgoParameter.java @@ -5,13 +5,13 @@ package org.opensearch.ml.common.annotation; -import org.opensearch.ml.common.FunctionName; - import java.lang.annotation.ElementType; import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target; +import org.opensearch.ml.common.FunctionName; + @Retention(RetentionPolicy.RUNTIME) @Target(ElementType.TYPE) public @interface MLAlgoParameter { diff --git a/common/src/main/java/org/opensearch/ml/common/annotation/MLInput.java b/common/src/main/java/org/opensearch/ml/common/annotation/MLInput.java index b8100473b0..31f520b181 100644 --- a/common/src/main/java/org/opensearch/ml/common/annotation/MLInput.java +++ b/common/src/main/java/org/opensearch/ml/common/annotation/MLInput.java @@ -5,13 +5,13 @@ package org.opensearch.ml.common.annotation; -import org.opensearch.ml.common.FunctionName; - import java.lang.annotation.ElementType; import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target; +import org.opensearch.ml.common.FunctionName; + @Retention(RetentionPolicy.RUNTIME) @Target(ElementType.TYPE) public @interface MLInput { diff --git a/common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java b/common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java index 5fa213db99..db4328208d 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/AbstractConnector.java @@ -5,8 +5,16 @@ package org.opensearch.ml.common.connector; -import lombok.Getter; -import lombok.Setter; +import static org.opensearch.ml.common.CommonValue.ML_MAP_RESPONSE_KEY; +import static org.opensearch.ml.common.utils.StringUtils.isJson; + +import java.io.IOException; +import java.time.Instant; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + import org.apache.commons.text.StringSubstitutor; import org.opensearch.common.xcontent.XContentType; import org.opensearch.commons.authuser.User; @@ -16,15 +24,8 @@ import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.utils.StringUtils; -import java.io.IOException; -import java.time.Instant; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Optional; - -import static org.opensearch.ml.common.CommonValue.ML_MAP_RESPONSE_KEY; -import static org.opensearch.ml.common.utils.StringUtils.isJson; +import lombok.Getter; +import lombok.Setter; @Getter public abstract class AbstractConnector implements Connector { @@ -101,7 +102,7 @@ public void parseResponse(T response, List modelTensors, boolea } return; } - if (response instanceof String && isJson((String)response)) { + if (response instanceof String && isJson((String) response)) { Map data = StringUtils.fromJson((String) response, ML_MAP_RESPONSE_KEY); modelTensors.add(ModelTensor.builder().name("response").dataAsMap(data).build()); } else { diff --git a/common/src/main/java/org/opensearch/ml/common/connector/AwsConnector.java b/common/src/main/java/org/opensearch/ml/common/connector/AwsConnector.java index ed9c64ac94..1ea25053a1 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/AwsConnector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/AwsConnector.java @@ -5,22 +5,23 @@ package org.opensearch.ml.common.connector; -import lombok.Builder; -import lombok.EqualsAndHashCode; -import lombok.NoArgsConstructor; -import lombok.extern.log4j.Log4j2; -import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.commons.authuser.User; -import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.ml.common.AccessMode; +import static org.opensearch.ml.common.connector.ConnectorProtocols.AWS_SIGV4; import java.io.IOException; import java.util.List; import java.util.Map; import java.util.Optional; -import static org.opensearch.ml.common.connector.ConnectorProtocols.AWS_SIGV4; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.AccessMode; + +import lombok.Builder; +import lombok.EqualsAndHashCode; +import lombok.NoArgsConstructor; +import lombok.extern.log4j.Log4j2; @Log4j2 @NoArgsConstructor @@ -29,9 +30,18 @@ public class AwsConnector extends HttpConnector { @Builder(builderMethodName = "awsConnectorBuilder") - public AwsConnector(String name, String description, String version, String protocol, - Map parameters, Map credential, List actions, - List backendRoles, AccessMode accessMode, User owner) { + public AwsConnector( + String name, + String description, + String version, + String protocol, + Map parameters, + Map credential, + List actions, + List backendRoles, + AccessMode accessMode, + User owner + ) { super(name, description, version, protocol, parameters, credential, actions, backendRoles, accessMode, owner); validate(); } @@ -41,7 +51,6 @@ public AwsConnector(String protocol, XContentParser parser) throws IOException { validate(); } - public AwsConnector(StreamInput input) throws IOException { super(input); validate(); @@ -51,17 +60,19 @@ private void validate() { if (credential == null || !credential.containsKey(ACCESS_KEY_FIELD) || !credential.containsKey(SECRET_KEY_FIELD)) { throw new IllegalArgumentException("Missing credential"); } - if ((credential == null || !credential.containsKey(SERVICE_NAME_FIELD)) && (parameters == null || !parameters.containsKey(SERVICE_NAME_FIELD))) { + if ((credential == null || !credential.containsKey(SERVICE_NAME_FIELD)) + && (parameters == null || !parameters.containsKey(SERVICE_NAME_FIELD))) { throw new IllegalArgumentException("Missing service name"); } - if ((credential == null || !credential.containsKey(REGION_FIELD)) && (parameters == null || !parameters.containsKey(REGION_FIELD))) { + if ((credential == null || !credential.containsKey(REGION_FIELD)) + && (parameters == null || !parameters.containsKey(REGION_FIELD))) { throw new IllegalArgumentException("Missing region"); } } @Override public Connector cloneConnector() { - try (BytesStreamOutput bytesStreamOutput = new BytesStreamOutput()){ + try (BytesStreamOutput bytesStreamOutput = new BytesStreamOutput()) { this.writeTo(bytesStreamOutput); StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); return new AwsConnector(streamInput); diff --git a/common/src/main/java/org/opensearch/ml/common/connector/Connector.java b/common/src/main/java/org/opensearch/ml/common/connector/Connector.java index 0652a83421..ca17ca8a7a 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/Connector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/Connector.java @@ -5,6 +5,8 @@ package org.opensearch.ml.common.connector; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.utils.StringUtils.gson; import java.io.IOException; import java.security.AccessController; @@ -16,12 +18,13 @@ import java.util.function.Function; import java.util.regex.Matcher; import java.util.regex.Pattern; + import org.apache.commons.text.StringSubstitutor; -import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.XContentType; import org.opensearch.commons.authuser.User; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContentObject; @@ -32,27 +35,31 @@ import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.ml.common.utils.StringUtils.gson; - /** * Connector defines how to connect to a remote service. */ public interface Connector extends ToXContentObject, Writeable { String getName(); + String getProtocol(); + User getOwner(); + void setOwner(User user); AccessMode getAccess(); + void setAccess(AccessMode access); + List getBackendRoles(); void setBackendRoles(List backendRoles); + Map getParameters(); List getActions(); + String getPredictEndpoint(Map parameters); String getPredictHttpMethod(); @@ -60,6 +67,7 @@ public interface Connector extends ToXContentObject, Writeable { T createPredictPayload(Map parameters); void decrypt(Function function); + void encrypt(Function function); Connector cloneConnector(); @@ -92,7 +100,8 @@ default void validatePayload(String payload) { static Connector fromStream(StreamInput in) throws IOException { try { String connectorProtocol = in.readString(); - return MLCommonsClassLoader.initConnector(connectorProtocol, new Object[]{connectorProtocol, in}, String.class, StreamInput.class); + return MLCommonsClassLoader + .initConnector(connectorProtocol, new Object[] { connectorProtocol, in }, String.class, StreamInput.class); } catch (IllegalArgumentException illegalArgumentException) { throw illegalArgumentException; } @@ -115,25 +124,30 @@ static Connector createConnector(XContentParser parser) throws IOException { } catch (PrivilegedActionException e) { throw new IllegalArgumentException("wrong connector"); } - String connectorProtocol = (String)connectorMap.get("protocol"); + String connectorProtocol = (String) connectorMap.get("protocol"); return createConnector(jsonStr, connectorProtocol); } private static Connector createConnector(String jsonStr, String connectorProtocol) throws IOException { - try (XContentParser connectorParser = XContentType.JSON.xContent().createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonStr)) { + try ( + XContentParser connectorParser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonStr) + ) { ensureExpectedToken(XContentParser.Token.START_OBJECT, connectorParser.nextToken(), connectorParser); if (connectorProtocol == null) { throw new IllegalArgumentException("connector protocol is null"); } - return MLCommonsClassLoader.initConnector(connectorProtocol, new Object[]{connectorProtocol, connectorParser}, String.class, XContentParser.class); + return MLCommonsClassLoader + .initConnector(connectorProtocol, new Object[] { connectorProtocol, connectorParser }, String.class, XContentParser.class); } catch (Exception ex) { if (ex instanceof IllegalArgumentException) { throw ex; } return null; - } + } } default void validateConnectorURL(List urlRegexes) { diff --git a/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java b/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java index ae43c10867..b06be717aa 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/ConnectorAction.java @@ -5,9 +5,12 @@ package org.opensearch.ml.common.connector; -import lombok.Builder; -import lombok.EqualsAndHashCode; -import lombok.Getter; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.util.Locale; +import java.util.Map; + import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; @@ -15,11 +18,9 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import java.io.IOException; -import java.util.Locale; -import java.util.Map; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import lombok.Builder; +import lombok.EqualsAndHashCode; +import lombok.Getter; @Getter @EqualsAndHashCode @@ -170,15 +171,16 @@ public static ConnectorAction parse(XContentParser parser) throws IOException { break; } } - return ConnectorAction.builder() - .actionType(actionType) - .method(method) - .url(url) - .headers(headers) - .requestBody(requestBody) - .preProcessFunction(preProcessFunction) - .postProcessFunction(postProcessFunction) - .build(); + return ConnectorAction + .builder() + .actionType(actionType) + .method(method) + .url(url) + .headers(headers) + .requestBody(requestBody) + .preProcessFunction(preProcessFunction) + .postProcessFunction(postProcessFunction) + .build(); } public enum ActionType { diff --git a/common/src/main/java/org/opensearch/ml/common/connector/ConnectorProtocols.java b/common/src/main/java/org/opensearch/ml/common/connector/ConnectorProtocols.java index 50412ce09a..408e4ea7c4 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/ConnectorProtocols.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/ConnectorProtocols.java @@ -7,7 +7,6 @@ import java.util.Arrays; import java.util.List; -import java.util.Set; public class ConnectorProtocols { @@ -18,10 +17,14 @@ public class ConnectorProtocols { public static void validateProtocol(String protocol) { if (protocol == null) { - throw new IllegalArgumentException("Connector protocol is null. Please use one of " + Arrays.toString(VALID_PROTOCOLS.toArray(new String[0]))); + throw new IllegalArgumentException( + "Connector protocol is null. Please use one of " + Arrays.toString(VALID_PROTOCOLS.toArray(new String[0])) + ); } if (!VALID_PROTOCOLS.contains(protocol)) { - throw new IllegalArgumentException("Unsupported connector protocol. Please use one of " + Arrays.toString(VALID_PROTOCOLS.toArray(new String[0]))); + throw new IllegalArgumentException( + "Unsupported connector protocol. Please use one of " + Arrays.toString(VALID_PROTOCOLS.toArray(new String[0])) + ); } } } diff --git a/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java b/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java index ef0e4bf4a1..7a9306d2eb 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java @@ -5,18 +5,11 @@ package org.opensearch.ml.common.connector; -import lombok.Builder; -import lombok.EqualsAndHashCode; -import lombok.NoArgsConstructor; -import lombok.extern.log4j.Log4j2; -import org.apache.commons.text.StringSubstitutor; -import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.commons.authuser.User; -import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.ml.common.AccessMode; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.connector.ConnectorProtocols.HTTP; +import static org.opensearch.ml.common.connector.ConnectorProtocols.validateProtocol; +import static org.opensearch.ml.common.utils.StringUtils.getParameterMap; +import static org.opensearch.ml.common.utils.StringUtils.isJson; import java.io.IOException; import java.time.Instant; @@ -29,13 +22,21 @@ import java.util.regex.Matcher; import java.util.regex.Pattern; -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.ml.common.connector.ConnectorProtocols.HTTP; -import static org.opensearch.ml.common.connector.ConnectorProtocols.validateProtocol; -import static org.opensearch.ml.common.utils.StringUtils.getParameterMap; -import static org.opensearch.ml.common.utils.StringUtils.isJson; +import org.apache.commons.text.StringSubstitutor; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.AccessMode; import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; +import lombok.Builder; +import lombok.EqualsAndHashCode; +import lombok.NoArgsConstructor; +import lombok.extern.log4j.Log4j2; + @Log4j2 @NoArgsConstructor @EqualsAndHashCode @@ -47,12 +48,21 @@ public class HttpConnector extends AbstractConnector { public static final String SERVICE_NAME_FIELD = "service_name"; public static final String REGION_FIELD = "region"; - //TODO: add RequestConfig like request time out, + // TODO: add RequestConfig like request time out, @Builder - public HttpConnector(String name, String description, String version, String protocol, - Map parameters, Map credential, List actions, - List backendRoles, AccessMode accessMode, User owner) { + public HttpConnector( + String name, + String description, + String version, + String protocol, + Map parameters, + Map credential, + List actions, + List backendRoles, + AccessMode accessMode, + User owner + ) { validateProtocol(protocol); this.name = name; this.description = description; @@ -282,7 +292,7 @@ public void update(MLCreateConnectorInput updateContent, Function T createPredictPayload(Map parameters) { + public T createPredictPayload(Map parameters) { Optional predictAction = findPredictAction(); if (predictAction.isPresent() && predictAction.get().getRequestBody() != null) { String payload = predictAction.get().getRequestBody(); @@ -336,7 +346,7 @@ public void decrypt(Function function) { @Override public Connector cloneConnector() { - try (BytesStreamOutput bytesStreamOutput = new BytesStreamOutput()){ + try (BytesStreamOutput bytesStreamOutput = new BytesStreamOutput()) { this.writeTo(bytesStreamOutput); StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); return new HttpConnector(streamInput); diff --git a/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java index 9d9ba90171..6ed25180b5 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/MLPostProcessFunction.java @@ -5,15 +5,15 @@ package org.opensearch.ml.common.connector; -import org.opensearch.ml.common.output.model.MLResultDataType; -import org.opensearch.ml.common.output.model.ModelTensor; - import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.function.Function; +import org.opensearch.ml.common.output.model.MLResultDataType; +import org.opensearch.ml.common.output.model.ModelTensor; + public class MLPostProcessFunction { public static final String COHERE_EMBEDDING = "connector.post_process.cohere.embedding"; @@ -25,7 +25,6 @@ public class MLPostProcessFunction { private static final Map>, List>> POST_PROCESS_FUNCTIONS = new HashMap<>(); - static { JSON_PATH_EXPRESSION.put(OPENAI_EMBEDDING, "$.data[*].embedding"); JSON_PATH_EXPRESSION.put(COHERE_EMBEDDING, "$.embeddings"); @@ -41,15 +40,19 @@ public static Function>, List> buildModelTensorLis if (embeddings == null) { throw new IllegalArgumentException("The list of embeddings is null when using the built-in post-processing function."); } - embeddings.forEach(embedding -> modelTensors.add( - ModelTensor - .builder() - .name("sentence_embedding") - .dataType(MLResultDataType.FLOAT32) - .shape(new long[]{embedding.size()}) - .data(embedding.toArray(new Number[0])) - .build() - )); + embeddings + .forEach( + embedding -> modelTensors + .add( + ModelTensor + .builder() + .name("sentence_embedding") + .dataType(MLResultDataType.FLOAT32) + .shape(new long[] { embedding.size() }) + .data(embedding.toArray(new Number[0])) + .build() + ) + ); return modelTensors; }; } diff --git a/common/src/main/java/org/opensearch/ml/common/conversation/ActionConstants.java b/common/src/main/java/org/opensearch/ml/common/conversation/ActionConstants.java index 5bb8334bc1..88d4e95188 100644 --- a/common/src/main/java/org/opensearch/ml/common/conversation/ActionConstants.java +++ b/common/src/main/java/org/opensearch/ml/common/conversation/ActionConstants.java @@ -54,7 +54,7 @@ public class ActionConstants { /** path for create conversation */ public final static String CREATE_CONVERSATION_REST_PATH = "/_plugins/_ml/memory/conversation"; /** path for list conversations */ - public final static String GET_CONVERSATIONS_REST_PATH = "/_plugins/_ml/memory/conversation"; + public final static String GET_CONVERSATIONS_REST_PATH = "/_plugins/_ml/memory/conversation"; /** path for put interaction */ public final static String CREATE_INTERACTION_REST_PATH = "/_plugins/_ml/memory/conversation/{conversation_id}"; /** path for get interactions */ @@ -67,4 +67,4 @@ public class ActionConstants { /** default username for reporting security errors if no or malformed username */ public final static String DEFAULT_USERNAME_FOR_ERRORS = "BAD_USER"; -} \ No newline at end of file +} diff --git a/common/src/main/java/org/opensearch/ml/common/conversation/ConversationMeta.java b/common/src/main/java/org/opensearch/ml/common/conversation/ConversationMeta.java index 8ba518a065..04cbacf306 100644 --- a/common/src/main/java/org/opensearch/ml/common/conversation/ConversationMeta.java +++ b/common/src/main/java/org/opensearch/ml/common/conversation/ConversationMeta.java @@ -94,7 +94,6 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalString(user); } - /** * Convert this conversationMeta object into an IndexRequest so it can be indexed * @param index the index to send this conversation to. Should usually be .conversational-meta @@ -102,19 +101,19 @@ public void writeTo(StreamOutput out) throws IOException { */ public IndexRequest toIndexRequest(String index) { IndexRequest request = new IndexRequest(index); - return request.id(this.id).source( - ConversationalIndexConstants.META_CREATED_FIELD, this.createdTime, - ConversationalIndexConstants.META_NAME_FIELD, this.name - ); + return request + .id(this.id) + .source( + ConversationalIndexConstants.META_CREATED_FIELD, + this.createdTime, + ConversationalIndexConstants.META_NAME_FIELD, + this.name + ); } @Override public String toString() { - return "{id=" + id - + ", name=" + name - + ", created=" + createdTime.toString() - + ", user=" + user - + "}"; + return "{id=" + id + ", name=" + name + ", created=" + createdTime.toString() + ", user=" + user + "}"; } @Override @@ -123,7 +122,7 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContentObject.Para builder.field(ActionConstants.CONVERSATION_ID_FIELD, this.id); builder.field(ConversationalIndexConstants.META_CREATED_FIELD, this.createdTime); builder.field(ConversationalIndexConstants.META_NAME_FIELD, this.name); - if(this.user != null) { + if (this.user != null) { builder.field(ConversationalIndexConstants.USER_FIELD, this.user); } builder.endObject(); @@ -132,14 +131,14 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContentObject.Para @Override public boolean equals(Object other) { - if(!(other instanceof ConversationMeta)) { + if (!(other instanceof ConversationMeta)) { return false; } ConversationMeta otherConversation = (ConversationMeta) other; - return Objects.equals(this.id, otherConversation.id) && - Objects.equals(this.user, otherConversation.user) && - Objects.equals(this.createdTime, otherConversation.createdTime) && - Objects.equals(this.name, otherConversation.name); + return Objects.equals(this.id, otherConversation.id) + && Objects.equals(this.user, otherConversation.user) + && Objects.equals(this.createdTime, otherConversation.createdTime) + && Objects.equals(this.name, otherConversation.name); } - + } diff --git a/common/src/main/java/org/opensearch/ml/common/conversation/ConversationalIndexConstants.java b/common/src/main/java/org/opensearch/ml/common/conversation/ConversationalIndexConstants.java index c8e652265b..701ca47871 100644 --- a/common/src/main/java/org/opensearch/ml/common/conversation/ConversationalIndexConstants.java +++ b/common/src/main/java/org/opensearch/ml/common/conversation/ConversationalIndexConstants.java @@ -36,7 +36,9 @@ public class ConversationalIndexConstants { /** Mappings for the conversational metadata index */ public final static String META_MAPPING = "{\n" + " \"_meta\": {\n" - + " \"schema_version\": " + META_INDEX_SCHEMA_VERSION + "\n" + + " \"schema_version\": " + + META_INDEX_SCHEMA_VERSION + + "\n" + " },\n" + " \"properties\": {\n" + " \"" @@ -72,7 +74,9 @@ public class ConversationalIndexConstants { /** Mappings for the interactions index */ public final static String INTERACTIONS_MAPPINGS = "{\n" + " \"_meta\": {\n" - + " \"schema_version\": " + INTERACTIONS_INDEX_SCHEMA_VERSION + "\n" + + " \"schema_version\": " + + INTERACTIONS_INDEX_SCHEMA_VERSION + + "\n" + " },\n" + " \"properties\": {\n" + " \"" @@ -102,4 +106,4 @@ public class ConversationalIndexConstants { /** Feature Flag setting for conversational memory */ public static final Setting ML_COMMONS_MEMORY_FEATURE_ENABLED = Setting .boolSetting("plugins.ml_commons.memory_feature_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic); -} \ No newline at end of file +} diff --git a/common/src/main/java/org/opensearch/ml/common/conversation/Interaction.java b/common/src/main/java/org/opensearch/ml/common/conversation/Interaction.java index 9b6ec636bd..82350216f6 100644 --- a/common/src/main/java/org/opensearch/ml/common/conversation/Interaction.java +++ b/common/src/main/java/org/opensearch/ml/common/conversation/Interaction.java @@ -64,12 +64,12 @@ public class Interaction implements Writeable, ToXContentObject { */ public static Interaction fromMap(String id, Map fields) { Instant createTime = Instant.parse((String) fields.get(ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD)); - String conversationId = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD); - String input = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD); + String conversationId = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD); + String input = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD); String promptTemplate = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_PROMPT_TEMPLATE_FIELD); - String response = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD); - String origin = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_ORIGIN_FIELD); - String additionalInfo = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD); + String response = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD); + String origin = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_ORIGIN_FIELD); + String additionalInfo = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD); return new Interaction(id, createTime, conversationId, input, promptTemplate, response, origin, additionalInfo); } @@ -101,7 +101,6 @@ public static Interaction fromStream(StreamInput in) throws IOException { return new Interaction(id, createTime, conversationId, input, promptTemplate, response, origin, additionalInfo); } - @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(id); @@ -124,7 +123,7 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContentObject.Para builder.field(ConversationalIndexConstants.INTERACTIONS_PROMPT_TEMPLATE_FIELD, promptTemplate); builder.field(ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, response); builder.field(ConversationalIndexConstants.INTERACTIONS_ORIGIN_FIELD, origin); - if(additionalInfo != null) { + if (additionalInfo != null) { builder.field(ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD, additionalInfo); } builder.endObject(); @@ -133,33 +132,38 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContentObject.Para @Override public boolean equals(Object other) { - return ( - other instanceof Interaction && - ((Interaction) other).id.equals(this.id) && - ((Interaction) other).conversationId.equals(this.conversationId) && - ((Interaction) other).createTime.equals(this.createTime) && - ((Interaction) other).input.equals(this.input) && - ((Interaction) other).promptTemplate.equals(this.promptTemplate) && - ((Interaction) other).response.equals(this.response) && - ((Interaction) other).origin.equals(this.origin) && - ( (((Interaction) other).additionalInfo == null && this.additionalInfo == null) || - ((Interaction) other).additionalInfo.equals(this.additionalInfo)) - ); + return (other instanceof Interaction + && ((Interaction) other).id.equals(this.id) + && ((Interaction) other).conversationId.equals(this.conversationId) + && ((Interaction) other).createTime.equals(this.createTime) + && ((Interaction) other).input.equals(this.input) + && ((Interaction) other).promptTemplate.equals(this.promptTemplate) + && ((Interaction) other).response.equals(this.response) + && ((Interaction) other).origin.equals(this.origin) + && ((((Interaction) other).additionalInfo == null && this.additionalInfo == null) + || ((Interaction) other).additionalInfo.equals(this.additionalInfo))); } @Override public String toString() { return "Interaction{" - + "id=" + id - + ",cid=" + conversationId - + ",create_time=" + createTime - + ",origin=" + origin - + ",input=" + input - + ",promt_template=" + promptTemplate - + ",response=" + response - + ",additional_info=" + additionalInfo + + "id=" + + id + + ",cid=" + + conversationId + + ",create_time=" + + createTime + + ",origin=" + + origin + + ",input=" + + input + + ",promt_template=" + + promptTemplate + + ",response=" + + response + + ",additional_info=" + + additionalInfo + "}"; } - -} \ No newline at end of file +} diff --git a/common/src/main/java/org/opensearch/ml/common/dataframe/ColumnMeta.java b/common/src/main/java/org/opensearch/ml/common/dataframe/ColumnMeta.java index cfbb6484cb..4ca89ae2fa 100644 --- a/common/src/main/java/org/opensearch/ml/common/dataframe/ColumnMeta.java +++ b/common/src/main/java/org/opensearch/ml/common/dataframe/ColumnMeta.java @@ -5,12 +5,17 @@ package org.opensearch.ml.common.dataframe; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + import java.io.IOException; import java.util.Locale; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; import lombok.AccessLevel; import lombok.Builder; @@ -18,11 +23,6 @@ import lombok.RequiredArgsConstructor; import lombok.ToString; import lombok.experimental.FieldDefaults; -import org.opensearch.core.xcontent.ToXContentObject; -import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.core.xcontent.XContentParser; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) @Getter diff --git a/common/src/main/java/org/opensearch/ml/common/dataframe/ColumnType.java b/common/src/main/java/org/opensearch/ml/common/dataframe/ColumnType.java index 28fe550cfe..1b15f3c7bf 100644 --- a/common/src/main/java/org/opensearch/ml/common/dataframe/ColumnType.java +++ b/common/src/main/java/org/opensearch/ml/common/dataframe/ColumnType.java @@ -16,31 +16,31 @@ public enum ColumnType { NULL; public static ColumnType from(Object object) { - if(object instanceof Short) { + if (object instanceof Short) { return SHORT; } - if(object instanceof Integer) { + if (object instanceof Integer) { return INTEGER; } - if(object instanceof Long) { + if (object instanceof Long) { return LONG; } - if(object instanceof String) { + if (object instanceof String) { return STRING; } - if(object instanceof Double) { + if (object instanceof Double) { return DOUBLE; } - if(object instanceof Boolean) { + if (object instanceof Boolean) { return BOOLEAN; } - if(object instanceof Float) { + if (object instanceof Float) { return FLOAT; } diff --git a/common/src/main/java/org/opensearch/ml/common/dataframe/ColumnValue.java b/common/src/main/java/org/opensearch/ml/common/dataframe/ColumnValue.java index 3a804d3f5e..03aa7d6acc 100644 --- a/common/src/main/java/org/opensearch/ml/common/dataframe/ColumnValue.java +++ b/common/src/main/java/org/opensearch/ml/common/dataframe/ColumnValue.java @@ -5,13 +5,13 @@ package org.opensearch.ml.common.dataframe; +import java.io.IOException; +import java.util.Objects; + import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; -import java.io.IOException; -import java.util.Objects; - public interface ColumnValue extends Writeable, ToXContentObject { ColumnType columnType(); diff --git a/common/src/main/java/org/opensearch/ml/common/dataframe/ColumnValueBuilder.java b/common/src/main/java/org/opensearch/ml/common/dataframe/ColumnValueBuilder.java index 6f91b11764..098699be63 100644 --- a/common/src/main/java/org/opensearch/ml/common/dataframe/ColumnValueBuilder.java +++ b/common/src/main/java/org/opensearch/ml/common/dataframe/ColumnValueBuilder.java @@ -18,36 +18,36 @@ public class ColumnValueBuilder { * @return ColumnValue */ public ColumnValue build(Object object) { - if(Objects.isNull(object)) { + if (Objects.isNull(object)) { return new NullValue(); } - if(object instanceof Short) { - return new ShortValue((Short)object); + if (object instanceof Short) { + return new ShortValue((Short) object); } - if(object instanceof Integer) { - return new IntValue((Integer)object); + if (object instanceof Integer) { + return new IntValue((Integer) object); } - if(object instanceof Long) { - return new LongValue((Long)object); + if (object instanceof Long) { + return new LongValue((Long) object); } - if(object instanceof String) { - return new StringValue((String)object); + if (object instanceof String) { + return new StringValue((String) object); } - if(object instanceof Double) { - return new DoubleValue((Double)object); + if (object instanceof Double) { + return new DoubleValue((Double) object); } - if(object instanceof Boolean) { - return new BooleanValue((Boolean)object); + if (object instanceof Boolean) { + return new BooleanValue((Boolean) object); } - if(object instanceof Float) { - return new FloatValue((Float)object); + if (object instanceof Float) { + return new FloatValue((Float) object); } throw new IllegalArgumentException("unsupported type:" + object.getClass().getName()); diff --git a/common/src/main/java/org/opensearch/ml/common/dataframe/ColumnValueReader.java b/common/src/main/java/org/opensearch/ml/common/dataframe/ColumnValueReader.java index 759b723faa..e132b362a5 100644 --- a/common/src/main/java/org/opensearch/ml/common/dataframe/ColumnValueReader.java +++ b/common/src/main/java/org/opensearch/ml/common/dataframe/ColumnValueReader.java @@ -14,7 +14,7 @@ public class ColumnValueReader implements Writeable.Reader { @Override public ColumnValue read(StreamInput in) throws IOException { ColumnType columnType = in.readEnum(ColumnType.class); - switch (columnType){ + switch (columnType) { case SHORT: return new ShortValue(in.readShort()); case INTEGER: diff --git a/common/src/main/java/org/opensearch/ml/common/dataframe/DataFrameBuilder.java b/common/src/main/java/org/opensearch/ml/common/dataframe/DataFrameBuilder.java index c225b742e2..fac9e2ce24 100644 --- a/common/src/main/java/org/opensearch/ml/common/dataframe/DataFrameBuilder.java +++ b/common/src/main/java/org/opensearch/ml/common/dataframe/DataFrameBuilder.java @@ -24,7 +24,7 @@ public class DataFrameBuilder { * @return empty data frame */ public DataFrame emptyDataFrame(final ColumnMeta[] columnMetas) { - if(columnMetas == null || columnMetas.length == 0) { + if (columnMetas == null || columnMetas.length == 0) { throw new IllegalArgumentException("columnMetas array is null or empty"); } return new DefaultDataFrame(columnMetas); @@ -37,7 +37,7 @@ public DataFrame emptyDataFrame(final ColumnMeta[] columnMetas) { * @return data frame */ public DataFrame load(final List> input) { - if(input == null || input.isEmpty()) { + if (input == null || input.isEmpty()) { throw new IllegalArgumentException("input is null or empty"); } @@ -45,11 +45,8 @@ public DataFrame load(final List> input) { ColumnMeta[] columnMetas = new ColumnMeta[element.size()]; int index = 0; - for(Map.Entry entry : element.entrySet()) { - ColumnMeta columnMeta = ColumnMeta.builder() - .name(entry.getKey()) - .columnType(ColumnType.from(entry.getValue())) - .build(); + for (Map.Entry entry : element.entrySet()) { + ColumnMeta columnMeta = ColumnMeta.builder().name(entry.getKey()).columnType(ColumnType.from(entry.getValue())).build(); columnMetas[index++] = columnMeta; } @@ -63,36 +60,36 @@ public DataFrame load(final List> input) { * @param input input list of map objects * @return data frame */ - public DataFrame load(final ColumnMeta[] columnMetas, final List> input){ - if(columnMetas == null || columnMetas.length == 0) { + public DataFrame load(final ColumnMeta[] columnMetas, final List> input) { + if (columnMetas == null || columnMetas.length == 0) { throw new IllegalArgumentException("columnMetas array is null or empty"); } - if(input == null || input.isEmpty()) { + if (input == null || input.isEmpty()) { throw new IllegalArgumentException("input data list is null or empty"); } int columnSize = columnMetas.length; Map columnsMap = new HashMap<>(); - for(int i = 0; i < columnSize; i++) { + for (int i = 0; i < columnSize; i++) { columnsMap.put(columnMetas[i].getName(), i); } List rows = input.stream().map(item -> { Row row = new Row(columnSize); - if(item.size() != columnSize) { + if (item.size() != columnSize) { throw new IllegalArgumentException("input item map size is different in the map"); } - for(Map.Entry entry : item.entrySet()) { - if(!columnsMap.containsKey(entry.getKey())) { + for (Map.Entry entry : item.entrySet()) { + if (!columnsMap.containsKey(entry.getKey())) { throw new IllegalArgumentException("field of input item doesn't exist in columns, filed:" + entry.getKey()); } String columnName = entry.getKey(); int index = columnsMap.get(columnName); ColumnType columnType = columnMetas[index].getColumnType(); ColumnValue value = ColumnValueBuilder.build(entry.getValue()); - if(columnType != value.columnType()) { + if (columnType != value.columnType()) { throw new IllegalArgumentException("the same field has different data type"); } row.setValue(index, value); diff --git a/common/src/main/java/org/opensearch/ml/common/dataframe/DefaultDataFrame.java b/common/src/main/java/org/opensearch/ml/common/dataframe/DefaultDataFrame.java index 27dd667de6..e7d67fcca6 100644 --- a/common/src/main/java/org/opensearch/ml/common/dataframe/DefaultDataFrame.java +++ b/common/src/main/java/org/opensearch/ml/common/dataframe/DefaultDataFrame.java @@ -5,6 +5,8 @@ package org.opensearch.ml.common.dataframe; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; @@ -14,31 +16,29 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; - -import lombok.AccessLevel; -import lombok.ToString; -import lombok.experimental.FieldDefaults; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import lombok.AccessLevel; +import lombok.ToString; +import lombok.experimental.FieldDefaults; @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) @ToString -public class DefaultDataFrame extends AbstractDataFrame{ +public class DefaultDataFrame extends AbstractDataFrame { private static final String COLUMN_META_FIELD = "column_metas"; private static final String ROWS_FIELD = "rows"; List rows; ColumnMeta[] columnMetas; - public DefaultDataFrame(final ColumnMeta[] columnMetas){ + public DefaultDataFrame(final ColumnMeta[] columnMetas) { super(DataFrameType.DEFAULT); this.columnMetas = columnMetas; this.rows = new ArrayList<>(); } - public DefaultDataFrame(final ColumnMeta[] columnMetas, final List rows){ + public DefaultDataFrame(final ColumnMeta[] columnMetas, final List rows) { super(DataFrameType.DEFAULT); this.columnMetas = columnMetas; this.rows = rows; @@ -52,12 +52,12 @@ public DefaultDataFrame(StreamInput streamInput) throws IOException { @Override public void appendRow(final Object[] values) { - if(values == null) { + if (values == null) { throw new IllegalArgumentException("input values can't be null"); } Row row = new Row(values.length); - for(int i = 0; i < values.length; i++) { + for (int i = 0; i < values.length; i++) { row.setValue(i, ColumnValueBuilder.build(values[i])); } @@ -66,20 +66,25 @@ public void appendRow(final Object[] values) { @Override public void appendRow(final Row row) { - if(row == null) { + if (row == null) { throw new IllegalArgumentException("input row can't be null"); } - if(row.size() != columnMetas.length) { - final String message = String.format("the size is different between input row:%d " + - "and column size in dataframe:%d", row.size(), columnMetas.length); + if (row.size() != columnMetas.length) { + final String message = String + .format("the size is different between input row:%d " + "and column size in dataframe:%d", row.size(), columnMetas.length); throw new IllegalArgumentException(message); } - for(int i = 0; i < columnMetas.length; i++) { - if(columnMetas[i].getColumnType() != row.getValue(i).columnType()) { - final String message = String.format("the column type is different in column meta:%s and input row:%s for index: %d", - columnMetas[i].getColumnType(), row.getValue(i).columnType(), i); + for (int i = 0; i < columnMetas.length; i++) { + if (columnMetas[i].getColumnType() != row.getValue(i).columnType()) { + final String message = String + .format( + "the column type is different in column meta:%s and input row:%s for index: %d", + columnMetas[i].getColumnType(), + row.getValue(i).columnType(), + i + ); throw new IllegalArgumentException(message); } } @@ -103,33 +108,33 @@ public ColumnMeta[] columnMetas() { @Override public DataFrame remove(int columnIndex) { - if(columnIndex < 0 || columnIndex >= columnMetas.length) { + if (columnIndex < 0 || columnIndex >= columnMetas.length) { throw new IllegalArgumentException("columnIndex can't be negative or bigger than columns length:" + columnMetas.length); } ColumnMeta[] newColumnMetas = new ColumnMeta[columnMetas.length - 1]; int index = 0; - for(int i = 0; i < columnMetas.length && i != columnIndex; i++) { + for (int i = 0; i < columnMetas.length && i != columnIndex; i++) { newColumnMetas[index++] = columnMetas[i]; } - return new DefaultDataFrame(newColumnMetas, rows.stream().map(row-> row.remove(columnIndex)).collect(Collectors.toList())); + return new DefaultDataFrame(newColumnMetas, rows.stream().map(row -> row.remove(columnIndex)).collect(Collectors.toList())); } @Override public DataFrame select(int[] columns) { - if(columns == null || columns.length == 0) { + if (columns == null || columns.length == 0) { throw new IllegalArgumentException("columns can't be null or empty"); } ColumnMeta[] newColumnMetas = new ColumnMeta[columns.length]; int index = 0; - for(int col : columns) { - if(col < 0 || col >= columnMetas.length) { + for (int col : columns) { + if (col < 0 || col >= columnMetas.length) { throw new IllegalArgumentException("columnIndex can't be negative or bigger than columns length"); } newColumnMetas[index++] = columnMetas[col]; } - return new DefaultDataFrame(newColumnMetas, rows.stream().map(row-> row.select(columns)).collect(Collectors.toList())); + return new DefaultDataFrame(newColumnMetas, rows.stream().map(row -> row.select(columns)).collect(Collectors.toList())); } @Override @@ -155,7 +160,6 @@ public Iterator iterator() { return rows.iterator(); } - @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); @@ -200,13 +204,13 @@ public XContentBuilder toXContent(XContentBuilder builder) throws IOException { @Override public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { builder.startArray(COLUMN_META_FIELD); - for(ColumnMeta columnMeta : columnMetas) { + for (ColumnMeta columnMeta : columnMetas) { columnMeta.toXContent(builder, params); } builder.endArray(); builder.startArray(ROWS_FIELD); - for(Row row : rows) { + for (Row row : rows) { row.toXContent(builder, params); } builder.endArray(); diff --git a/common/src/main/java/org/opensearch/ml/common/dataframe/FloatValue.java b/common/src/main/java/org/opensearch/ml/common/dataframe/FloatValue.java index 98727f537b..7cf0be0543 100644 --- a/common/src/main/java/org/opensearch/ml/common/dataframe/FloatValue.java +++ b/common/src/main/java/org/opensearch/ml/common/dataframe/FloatValue.java @@ -5,13 +5,14 @@ package org.opensearch.ml.common.dataframe; +import java.io.IOException; + +import org.opensearch.core.common.io.stream.StreamOutput; + import lombok.AccessLevel; import lombok.RequiredArgsConstructor; import lombok.ToString; import lombok.experimental.FieldDefaults; -import org.opensearch.core.common.io.stream.StreamOutput; - -import java.io.IOException; @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) @RequiredArgsConstructor diff --git a/common/src/main/java/org/opensearch/ml/common/dataframe/LongValue.java b/common/src/main/java/org/opensearch/ml/common/dataframe/LongValue.java index 24c15c9b8b..d3c7d70c8e 100644 --- a/common/src/main/java/org/opensearch/ml/common/dataframe/LongValue.java +++ b/common/src/main/java/org/opensearch/ml/common/dataframe/LongValue.java @@ -5,13 +5,14 @@ package org.opensearch.ml.common.dataframe; +import java.io.IOException; + +import org.opensearch.core.common.io.stream.StreamOutput; + import lombok.AccessLevel; import lombok.RequiredArgsConstructor; import lombok.ToString; import lombok.experimental.FieldDefaults; -import org.opensearch.core.common.io.stream.StreamOutput; - -import java.io.IOException; @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) @RequiredArgsConstructor diff --git a/common/src/main/java/org/opensearch/ml/common/dataframe/Row.java b/common/src/main/java/org/opensearch/ml/common/dataframe/Row.java index 8727c416e3..5d38b4bc21 100644 --- a/common/src/main/java/org/opensearch/ml/common/dataframe/Row.java +++ b/common/src/main/java/org/opensearch/ml/common/dataframe/Row.java @@ -5,9 +5,14 @@ package org.opensearch.ml.common.dataframe; -import lombok.AccessLevel; -import lombok.ToString; -import lombok.experimental.FieldDefaults; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; + import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; @@ -16,13 +21,9 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Iterator; -import java.util.List; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import lombok.AccessLevel; +import lombok.ToString; +import lombok.experimental.FieldDefaults; @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) @ToString @@ -43,14 +44,14 @@ public Row(ColumnValue[] values) { } void setValue(int index, ColumnValue value) { - if(index < 0 || index > size() - 1) { + if (index < 0 || index > size() - 1) { throw new IllegalArgumentException("index is out of scope, index:" + index + "; row size:" + size()); } this.values[index] = value; } public ColumnValue getValue(int index) { - if(index < 0 || index > size() - 1) { + if (index < 0 || index > size() - 1) { throw new IllegalArgumentException("index is out of scope, index:" + index + "; row size:" + size()); } return this.values[index]; @@ -71,7 +72,7 @@ public void writeTo(StreamOutput out) throws IOException { } Row remove(int removedIndex) { - if(removedIndex < 0 || removedIndex >= values.length) { + if (removedIndex < 0 || removedIndex >= values.length) { throw new IllegalArgumentException("removed index can't be negative or bigger than row's values length:" + values.length); } ColumnValue[] newValues = new ColumnValue[Math.max(values.length - 1, 0)]; @@ -86,7 +87,7 @@ Row remove(int removedIndex) { Row select(int[] columns) { ColumnValue[] newValues = new ColumnValue[columns.length]; int index = 0; - for(int col: columns) { + for (int col : columns) { newValues[index++] = values[col]; } @@ -109,7 +110,9 @@ public static Row parse(XContentParser parser) throws IOException { if (parser.nextToken() != XContentParser.Token.END_OBJECT) { String columnTypeField = parser.currentName(); if (!"column_type".equals(columnTypeField)) { - throw new IllegalArgumentException("wrong column type, expect column_type field but got " + columnTypeField); + throw new IllegalArgumentException( + "wrong column type, expect column_type field but got " + columnTypeField + ); } parser.nextToken(); String columnType = parser.text(); @@ -182,26 +185,28 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par @Override public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; Row other = (Row) o; if (this.size() != other.size()) { return false; } - for (int i = 0; i< this.size(); i++) { - if(!this.getValue(i).equals(other.getValue(i))) { + for (int i = 0; i < this.size(); i++) { + if (!this.getValue(i).equals(other.getValue(i))) { return false; } } return true; } - public boolean equals(Row other) { + public boolean equals(Row other) { if (this.size() != other.size()) { return false; } - for (int i = 0; i< this.size(); i++) { - if(!this.getValue(i).equals(other.getValue(i))) { + for (int i = 0; i < this.size(); i++) { + if (!this.getValue(i).equals(other.getValue(i))) { return false; } } diff --git a/common/src/main/java/org/opensearch/ml/common/dataframe/ShortValue.java b/common/src/main/java/org/opensearch/ml/common/dataframe/ShortValue.java index 77de5aecf4..c08f6e1ac9 100644 --- a/common/src/main/java/org/opensearch/ml/common/dataframe/ShortValue.java +++ b/common/src/main/java/org/opensearch/ml/common/dataframe/ShortValue.java @@ -5,13 +5,14 @@ package org.opensearch.ml.common.dataframe; +import java.io.IOException; + +import org.opensearch.core.common.io.stream.StreamOutput; + import lombok.AccessLevel; import lombok.RequiredArgsConstructor; import lombok.ToString; import lombok.experimental.FieldDefaults; -import org.opensearch.core.common.io.stream.StreamOutput; - -import java.io.IOException; @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) @RequiredArgsConstructor diff --git a/common/src/main/java/org/opensearch/ml/common/dataset/DataFrameInputDataset.java b/common/src/main/java/org/opensearch/ml/common/dataset/DataFrameInputDataset.java index a535144354..ccb5e84014 100644 --- a/common/src/main/java/org/opensearch/ml/common/dataset/DataFrameInputDataset.java +++ b/common/src/main/java/org/opensearch/ml/common/dataset/DataFrameInputDataset.java @@ -11,14 +11,14 @@ import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.ml.common.annotation.InputDataSet; import org.opensearch.ml.common.dataframe.DataFrame; +import org.opensearch.ml.common.dataframe.DataFrameType; +import org.opensearch.ml.common.dataframe.DefaultDataFrame; import lombok.AccessLevel; import lombok.Builder; import lombok.Getter; import lombok.NonNull; import lombok.experimental.FieldDefaults; -import org.opensearch.ml.common.dataframe.DataFrameType; -import org.opensearch.ml.common.dataframe.DefaultDataFrame; /** * DataFrame based input data. Client directly passes the data frame to ML plugin with this. diff --git a/common/src/main/java/org/opensearch/ml/common/dataset/MLInputDataset.java b/common/src/main/java/org/opensearch/ml/common/dataset/MLInputDataset.java index 2c3514530f..0d7374d4c1 100644 --- a/common/src/main/java/org/opensearch/ml/common/dataset/MLInputDataset.java +++ b/common/src/main/java/org/opensearch/ml/common/dataset/MLInputDataset.java @@ -10,12 +10,12 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.ml.common.MLCommonsClassLoader; import lombok.AccessLevel; import lombok.Getter; import lombok.RequiredArgsConstructor; import lombok.experimental.FieldDefaults; -import org.opensearch.ml.common.MLCommonsClassLoader; @Getter @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) diff --git a/common/src/main/java/org/opensearch/ml/common/dataset/SearchQueryInputDataset.java b/common/src/main/java/org/opensearch/ml/common/dataset/SearchQueryInputDataset.java index 636384adbc..6d737887a1 100644 --- a/common/src/main/java/org/opensearch/ml/common/dataset/SearchQueryInputDataset.java +++ b/common/src/main/java/org/opensearch/ml/common/dataset/SearchQueryInputDataset.java @@ -9,11 +9,11 @@ import java.util.Collections; import java.util.List; -import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.annotation.InputDataSet; @@ -60,7 +60,9 @@ public SearchQueryInputDataset(@NonNull List indices, @NonNull SearchSou public SearchQueryInputDataset(StreamInput streaminput) throws IOException { super(MLInputDataType.SEARCH_QUERY); String searchString = streaminput.readString(); - XContentParser parser = XContentType.JSON.xContent().createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, searchString); + XContentParser parser = XContentType.JSON + .xContent() + .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, searchString); this.searchSourceBuilder = SearchSourceBuilder.fromXContent(parser); this.indices = streaminput.readStringList(); } diff --git a/common/src/main/java/org/opensearch/ml/common/dataset/TextDocsInputDataSet.java b/common/src/main/java/org/opensearch/ml/common/dataset/TextDocsInputDataSet.java index 98672841d7..0c9e4f224d 100644 --- a/common/src/main/java/org/opensearch/ml/common/dataset/TextDocsInputDataSet.java +++ b/common/src/main/java/org/opensearch/ml/common/dataset/TextDocsInputDataSet.java @@ -5,24 +5,25 @@ package org.opensearch.ml.common.dataset; -import lombok.AccessLevel; -import lombok.Builder; -import lombok.Getter; -import lombok.experimental.FieldDefaults; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; + import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.ml.common.annotation.InputDataSet; import org.opensearch.ml.common.output.model.ModelResultFilter; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.Objects; +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.experimental.FieldDefaults; @Getter @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) @InputDataSet(MLInputDataType.TEXT_DOCS) -public class TextDocsInputDataSet extends MLInputDataset{ +public class TextDocsInputDataSet extends MLInputDataset { private ModelResultFilter resultFilter; @@ -43,7 +44,7 @@ public TextDocsInputDataSet(StreamInput streamInput) throws IOException { super(MLInputDataType.TEXT_DOCS); docs = new ArrayList<>(); int size = streamInput.readInt(); - for (int i=0; i parameters) { public RemoteInferenceInputDataSet(StreamInput streamInput) throws IOException { super(MLInputDataType.REMOTE); if (streamInput.readBoolean()) { - parameters = streamInput.readMap(s -> s.readString(), s-> s.readString()); + parameters = streamInput.readMap(s -> s.readString(), s -> s.readString()); } } @Override public void writeTo(StreamOutput streamOutput) throws IOException { super.writeTo(streamOutput); - if (parameters != null) { + if (parameters != null) { streamOutput.writeBoolean(true); streamOutput.writeMap(parameters, StreamOutput::writeString, StreamOutput::writeString); } else { diff --git a/common/src/main/java/org/opensearch/ml/common/exception/ExecuteException.java b/common/src/main/java/org/opensearch/ml/common/exception/ExecuteException.java index 756ec4f319..a8d4322e8c 100644 --- a/common/src/main/java/org/opensearch/ml/common/exception/ExecuteException.java +++ b/common/src/main/java/org/opensearch/ml/common/exception/ExecuteException.java @@ -1,7 +1,15 @@ package org.opensearch.ml.common.exception; -public class ExecuteException extends MLException{ - public ExecuteException(String msg) { super(msg); } - public ExecuteException(Throwable cause) { super(cause); } - public ExecuteException(String msg, Throwable cause) { super(msg, cause); } +public class ExecuteException extends MLException { + public ExecuteException(String msg) { + super(msg); + } + + public ExecuteException(Throwable cause) { + super(cause); + } + + public ExecuteException(String msg, Throwable cause) { + super(msg, cause); + } } diff --git a/common/src/main/java/org/opensearch/ml/common/exception/MLLimitExceededException.java b/common/src/main/java/org/opensearch/ml/common/exception/MLLimitExceededException.java index b5a529ad6d..476c0e0c8a 100644 --- a/common/src/main/java/org/opensearch/ml/common/exception/MLLimitExceededException.java +++ b/common/src/main/java/org/opensearch/ml/common/exception/MLLimitExceededException.java @@ -9,7 +9,7 @@ * This exception is thrown when a some limit is exceeded. * Won't count this exception in stats. */ -public class MLLimitExceededException extends MLException{ +public class MLLimitExceededException extends MLException { /** * Constructor with error message. diff --git a/common/src/main/java/org/opensearch/ml/common/input/InputHelper.java b/common/src/main/java/org/opensearch/ml/common/input/InputHelper.java index 067c74e4b2..fb2bbb8a7b 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/InputHelper.java +++ b/common/src/main/java/org/opensearch/ml/common/input/InputHelper.java @@ -5,15 +5,6 @@ package org.opensearch.ml.common.input; -import org.opensearch.ml.common.FunctionName; -import org.opensearch.ml.common.input.parameter.MLAlgoParams; -import org.opensearch.ml.common.input.parameter.clustering.KMeansParams; -import org.opensearch.ml.common.input.parameter.rcf.BatchRCFParams; -import org.opensearch.ml.common.input.parameter.rcf.FitRCFParams; - -import java.util.Locale; -import java.util.Map; - import static org.opensearch.ml.common.FunctionName.BATCH_RCF; import static org.opensearch.ml.common.FunctionName.FIT_RCF; import static org.opensearch.ml.common.FunctionName.KMEANS; @@ -34,6 +25,15 @@ import static org.opensearch.ml.common.input.Constants.KM_DISTANCE_TYPE; import static org.opensearch.ml.common.input.Constants.KM_ITERATIONS; +import java.util.Locale; +import java.util.Map; + +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.input.parameter.MLAlgoParams; +import org.opensearch.ml.common.input.parameter.clustering.KMeansParams; +import org.opensearch.ml.common.input.parameter.rcf.BatchRCFParams; +import org.opensearch.ml.common.input.parameter.rcf.FitRCFParams; + public class InputHelper { public static String getAction(Map arguments) { return (String) arguments.get(ACTION); @@ -42,22 +42,19 @@ public static String getAction(Map arguments) { public static FunctionName getFunctionName(Map arguments) { String algo = (String) arguments.get(ALGORITHM); if (algo == null) { - throw new IllegalArgumentException("The parameter algorithm is required."); + throw new IllegalArgumentException("The parameter algorithm is required."); } switch (algo.toLowerCase(Locale.ROOT)) { case Constants.KMEANS: return KMEANS; case Constants.RCF: - return arguments.get(AD_TIME_FIELD) == null ? - BATCH_RCF : FIT_RCF; + return arguments.get(AD_TIME_FIELD) == null ? BATCH_RCF : FIT_RCF; default: - throw new IllegalArgumentException( - String.format("unsupported algorithm: %s.", algo)); + throw new IllegalArgumentException(String.format("unsupported algorithm: %s.", algo)); } } - public static MLAlgoParams convertArgumentToMLParameter(Map arguments, - FunctionName func) { + public static MLAlgoParams convertArgumentToMLParameter(Map arguments, FunctionName func) { switch (func) { case KMEANS: return buildKMeansParameters(arguments); @@ -66,45 +63,46 @@ public static MLAlgoParams convertArgumentToMLParameter(Map argu case FIT_RCF: return buildFitRCFParameters(arguments); default: - throw new IllegalArgumentException( - String.format("unsupported algorithm: %s.", func)); + throw new IllegalArgumentException(String.format("unsupported algorithm: %s.", func)); } } private static MLAlgoParams buildKMeansParameters(Map arguments) { - return KMeansParams.builder() - .centroids((Integer) arguments.get(KM_CENTROIDS)) - .iterations((Integer) arguments.get(KM_ITERATIONS)) - .distanceType(arguments.containsKey(KM_DISTANCE_TYPE) - ? KMeansParams.DistanceType.valueOf(( - (String) arguments.get(KM_DISTANCE_TYPE)).toUpperCase(Locale.ROOT)) - : null) - .build(); + return KMeansParams + .builder() + .centroids((Integer) arguments.get(KM_CENTROIDS)) + .iterations((Integer) arguments.get(KM_ITERATIONS)) + .distanceType( + arguments.containsKey(KM_DISTANCE_TYPE) + ? KMeansParams.DistanceType.valueOf(((String) arguments.get(KM_DISTANCE_TYPE)).toUpperCase(Locale.ROOT)) + : null + ) + .build(); } private static MLAlgoParams buildBatchRCFParameters(Map arguments) { - return BatchRCFParams.builder() - .numberOfTrees((Integer) arguments.get(AD_NUMBER_OF_TREES)) - .sampleSize((Integer) arguments.get(AD_SAMPLE_SIZE)) - .outputAfter((Integer) arguments.get(AD_OUTPUT_AFTER)) - .trainingDataSize((Integer) arguments.get(AD_TRAINING_DATA_SIZE)) - .anomalyScoreThreshold((Double) arguments.get(AD_ANOMALY_SCORE_THRESHOLD)) - .build(); + return BatchRCFParams + .builder() + .numberOfTrees((Integer) arguments.get(AD_NUMBER_OF_TREES)) + .sampleSize((Integer) arguments.get(AD_SAMPLE_SIZE)) + .outputAfter((Integer) arguments.get(AD_OUTPUT_AFTER)) + .trainingDataSize((Integer) arguments.get(AD_TRAINING_DATA_SIZE)) + .anomalyScoreThreshold((Double) arguments.get(AD_ANOMALY_SCORE_THRESHOLD)) + .build(); } private static MLAlgoParams buildFitRCFParameters(Map arguments) { - return FitRCFParams.builder() - .numberOfTrees((Integer) arguments.get(AD_NUMBER_OF_TREES)) - .shingleSize((Integer) arguments.get(AD_SHINGLE_SIZE)) - .sampleSize((Integer) arguments.get(AD_SAMPLE_SIZE)) - .outputAfter((Integer) arguments.get(AD_OUTPUT_AFTER)) - .timeDecay((Double) arguments.get(AD_TIME_DECAY)) - .anomalyRate((Double) arguments.get(AD_ANOMALY_RATE)) - .timeField((String) arguments.get(AD_TIME_FIELD)) - .dateFormat(arguments.containsKey(AD_DATE_FORMAT) - ? ((String) arguments.get(AD_DATE_FORMAT)) - : "yyyy-MM-dd HH:mm:ss") - .timeZone((String) arguments.get(AD_TIME_ZONE)) - .build(); + return FitRCFParams + .builder() + .numberOfTrees((Integer) arguments.get(AD_NUMBER_OF_TREES)) + .shingleSize((Integer) arguments.get(AD_SHINGLE_SIZE)) + .sampleSize((Integer) arguments.get(AD_SAMPLE_SIZE)) + .outputAfter((Integer) arguments.get(AD_OUTPUT_AFTER)) + .timeDecay((Double) arguments.get(AD_TIME_DECAY)) + .anomalyRate((Double) arguments.get(AD_ANOMALY_RATE)) + .timeField((String) arguments.get(AD_TIME_FIELD)) + .dateFormat(arguments.containsKey(AD_DATE_FORMAT) ? ((String) arguments.get(AD_DATE_FORMAT)) : "yyyy-MM-dd HH:mm:ss") + .timeZone((String) arguments.get(AD_TIME_ZONE)) + .build(); } } diff --git a/common/src/main/java/org/opensearch/ml/common/input/MLInput.java b/common/src/main/java/org/opensearch/ml/common/input/MLInput.java index 574f13e9c3..578a29eec0 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/MLInput.java +++ b/common/src/main/java/org/opensearch/ml/common/input/MLInput.java @@ -5,31 +5,32 @@ package org.opensearch.ml.common.input; -import lombok.Builder; -import lombok.Data; -import lombok.NoArgsConstructor; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; + import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLCommonsClassLoader; import org.opensearch.ml.common.dataframe.DataFrame; import org.opensearch.ml.common.dataframe.DefaultDataFrame; import org.opensearch.ml.common.dataset.DataFrameInputDataset; -import org.opensearch.ml.common.output.model.ModelResultFilter; import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.dataset.SearchQueryInputDataset; -import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.dataset.TextDocsInputDataSet; import org.opensearch.ml.common.input.parameter.MLAlgoParams; +import org.opensearch.ml.common.output.model.ModelResultFilter; import org.opensearch.search.builder.SearchSourceBuilder; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.Locale; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; /** * ML input data: algorithm name, parameters and input data set. @@ -73,8 +74,14 @@ public MLInput(FunctionName algorithm, MLAlgoParams parameters, MLInputDataset i this.inputDataset = inputDataset; } - public MLInput(FunctionName algorithm, MLAlgoParams parameters, SearchSourceBuilder searchSourceBuilder, - List sourceIndices, DataFrame dataFrame, MLInputDataset inputDataset) { + public MLInput( + FunctionName algorithm, + MLAlgoParams parameters, + SearchSourceBuilder searchSourceBuilder, + List sourceIndices, + DataFrame dataFrame, + MLInputDataset inputDataset + ) { validate(algorithm); this.algorithm = algorithm; this.parameters = parameters; @@ -130,12 +137,12 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (inputDataset != null) { switch (inputDataset.getInputDataType()) { case SEARCH_QUERY: - builder.field(INPUT_INDEX_FIELD, ((SearchQueryInputDataset)inputDataset).getIndices().toArray(new String[0])); - builder.field(INPUT_QUERY_FIELD, ((SearchQueryInputDataset)inputDataset).getSearchSourceBuilder()); + builder.field(INPUT_INDEX_FIELD, ((SearchQueryInputDataset) inputDataset).getIndices().toArray(new String[0])); + builder.field(INPUT_QUERY_FIELD, ((SearchQueryInputDataset) inputDataset).getSearchSourceBuilder()); break; case DATA_FRAME: builder.startObject(INPUT_DATA_FIELD); - ((DataFrameInputDataset)inputDataset).getDataFrame().toXContent(builder, EMPTY_PARAMS); + ((DataFrameInputDataset) inputDataset).getDataFrame().toXContent(builder, EMPTY_PARAMS); builder.endObject(); break; case TEXT_DOCS: @@ -171,7 +178,8 @@ public static MLInput parse(XContentParser parser, String inputAlgoName) throws FunctionName algorithm = FunctionName.from(algorithmName); if (MLCommonsClassLoader.canInitMLInput(algorithm)) { - MLInput mlInput = MLCommonsClassLoader.initMLInput(algorithm, new Object[]{parser, algorithm}, XContentParser.class, FunctionName.class); + MLInput mlInput = MLCommonsClassLoader + .initMLInput(algorithm, new Object[] { parser, algorithm }, XContentParser.class, FunctionName.class); mlInput.setAlgorithm(algorithm); return mlInput; } @@ -239,7 +247,9 @@ public static MLInput parse(XContentParser parser, String inputAlgoName) throws } } MLInputDataset inputDataSet = null; - if (algorithm == FunctionName.TEXT_EMBEDDING || algorithm == FunctionName.SPARSE_ENCODING || algorithm == FunctionName.SPARSE_TOKENIZE) { + if (algorithm == FunctionName.TEXT_EMBEDDING + || algorithm == FunctionName.SPARSE_ENCODING + || algorithm == FunctionName.SPARSE_TOKENIZE) { ModelResultFilter filter = new ModelResultFilter(returnBytes, returnNumber, targetResponse, targetResponsePositions); inputDataSet = new TextDocsInputDataSet(textDocs, filter); } diff --git a/common/src/main/java/org/opensearch/ml/common/input/execute/anomalylocalization/AnomalyLocalizationInput.java b/common/src/main/java/org/opensearch/ml/common/input/execute/anomalylocalization/AnomalyLocalizationInput.java index 6383bf6646..823763aa9d 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/execute/anomalylocalization/AnomalyLocalizationInput.java +++ b/common/src/main/java/org/opensearch/ml/common/input/execute/anomalylocalization/AnomalyLocalizationInput.java @@ -5,22 +5,25 @@ package org.opensearch.ml.common.input.execute.anomalylocalization; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.index.query.AbstractQueryBuilder.parseInnerQueryBuilder; + import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Optional; +import org.opensearch.core.ParseField; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.core.ParseField; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.query.QueryBuilder; -import org.opensearch.ml.common.annotation.ExecuteInput; import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.annotation.ExecuteInput; import org.opensearch.ml.common.input.Input; import org.opensearch.search.aggregations.AggregationBuilder; import org.opensearch.search.aggregations.AggregatorFactories; @@ -28,13 +31,10 @@ import lombok.AllArgsConstructor; import lombok.Data; -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.index.query.AbstractQueryBuilder.parseInnerQueryBuilder; - /** * Information about aggregate, time, etc to localize. */ -@ExecuteInput(algorithms={FunctionName.ANOMALY_LOCALIZATION}) +@ExecuteInput(algorithms = { FunctionName.ANOMALY_LOCALIZATION }) @Data @AllArgsConstructor public class AnomalyLocalizationInput implements Input { @@ -50,9 +50,9 @@ public class AnomalyLocalizationInput implements Input { public static final String FIELD_ANOMALY_START_TIME = "anomaly_start_time"; public static final String FIELD_FILTER_QUERY = "filter_query"; public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY_ENTRY = new NamedXContentRegistry.Entry( - Input.class, - new ParseField(FunctionName.ANOMALY_LOCALIZATION.name()), - parser -> parse(parser) + Input.class, + new ParseField(FunctionName.ANOMALY_LOCALIZATION.name()), + parser -> parse(parser) ); public static AnomalyLocalizationInput parse(XContentParser parser) throws IOException { @@ -124,9 +124,18 @@ public static AnomalyLocalizationInput parse(XContentParser parser) throws IOExc break; } } - return new AnomalyLocalizationInput(indexName, attributeFieldNames, aggregations, timeFieldName, startTime, endTime, - minTimeInterval, numOutputs, - anomalyStartTime, filterQuery); + return new AnomalyLocalizationInput( + indexName, + attributeFieldNames, + aggregations, + timeFieldName, + startTime, + endTime, + minTimeInterval, + numOutputs, + anomalyStartTime, + filterQuery + ); } private final String indexName; // name pattern of the data index diff --git a/common/src/main/java/org/opensearch/ml/common/input/execute/metricscorrelation/MetricsCorrelationInput.java b/common/src/main/java/org/opensearch/ml/common/input/execute/metricscorrelation/MetricsCorrelationInput.java index 3de3cee60d..8a8713cd47 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/execute/metricscorrelation/MetricsCorrelationInput.java +++ b/common/src/main/java/org/opensearch/ml/common/input/execute/metricscorrelation/MetricsCorrelationInput.java @@ -5,11 +5,15 @@ package org.opensearch.ml.common.input.execute.metricscorrelation; -import lombok.Builder; -import lombok.Data; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import org.opensearch.core.ParseField; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.core.ParseField; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; @@ -17,20 +21,17 @@ import org.opensearch.ml.common.annotation.ExecuteInput; import org.opensearch.ml.common.input.Input; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import lombok.Builder; +import lombok.Data; -@ExecuteInput(algorithms={FunctionName.METRICS_CORRELATION}) +@ExecuteInput(algorithms = { FunctionName.METRICS_CORRELATION }) @Data public class MetricsCorrelationInput implements Input { public static final String PARSE_FIELD_NAME = FunctionName.METRICS_CORRELATION.name(); public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry( - Input.class, - new ParseField(PARSE_FIELD_NAME), - it -> parse(it) + Input.class, + new ParseField(PARSE_FIELD_NAME), + it -> parse(it) ); public static final String METRICS_FIELD = "metrics"; diff --git a/common/src/main/java/org/opensearch/ml/common/input/execute/samplecalculator/LocalSampleCalculatorInput.java b/common/src/main/java/org/opensearch/ml/common/input/execute/samplecalculator/LocalSampleCalculatorInput.java index a4d08fb69f..91c046e99b 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/execute/samplecalculator/LocalSampleCalculatorInput.java +++ b/common/src/main/java/org/opensearch/ml/common/input/execute/samplecalculator/LocalSampleCalculatorInput.java @@ -5,32 +5,33 @@ package org.opensearch.ml.common.input.execute.samplecalculator; -import lombok.Builder; -import lombok.Data; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import org.opensearch.core.ParseField; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.core.ParseField; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.ml.common.annotation.ExecuteInput; import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.annotation.ExecuteInput; import org.opensearch.ml.common.input.Input; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import lombok.Builder; +import lombok.Data; -@ExecuteInput(algorithms={FunctionName.LOCAL_SAMPLE_CALCULATOR}) +@ExecuteInput(algorithms = { FunctionName.LOCAL_SAMPLE_CALCULATOR }) @Data public class LocalSampleCalculatorInput implements Input { public static final String PARSE_FIELD_NAME = FunctionName.LOCAL_SAMPLE_CALCULATOR.name(); public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry( - Input.class, - new ParseField(PARSE_FIELD_NAME), - it -> parse(it) + Input.class, + new ParseField(PARSE_FIELD_NAME), + it -> parse(it) ); public static final String OPERATION_FIELD = "operation"; @@ -87,7 +88,7 @@ public LocalSampleCalculatorInput(StreamInput in) throws IOException { this.operation = in.readString(); int size = in.readInt(); this.inputData = new ArrayList<>(); - for (int i = 0; i parse(it) + MLAlgoParams.class, + new ParseField(PARSE_FIELD_NAME), + it -> parse(it) ); public static final String KERNEL_FIELD = "kernel"; @@ -47,9 +48,16 @@ public class AnomalyDetectionLibSVMParams implements MLAlgoParams { private Double epsilon; private Integer degree; - @Builder(toBuilder = true) - public AnomalyDetectionLibSVMParams(ADKernelType kernelType, Double gamma, Double nu, Double cost, Double coeff, Double epsilon, Integer degree) { + public AnomalyDetectionLibSVMParams( + ADKernelType kernelType, + Double gamma, + Double nu, + Double cost, + Double coeff, + Double epsilon, + Integer degree + ) { this.kernelType = kernelType; this.gamma = gamma; this.nu = nu; @@ -176,7 +184,7 @@ public enum ADKernelType { SIGMOID; public static ADKernelType from(String value) { - try{ + try { return ADKernelType.valueOf(value); } catch (Exception e) { throw new IllegalArgumentException("Wrong AD kernel type"); diff --git a/common/src/main/java/org/opensearch/ml/common/input/parameter/clustering/KMeansParams.java b/common/src/main/java/org/opensearch/ml/common/input/parameter/clustering/KMeansParams.java index 73fff86f94..39d3684a92 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/parameter/clustering/KMeansParams.java +++ b/common/src/main/java/org/opensearch/ml/common/input/parameter/clustering/KMeansParams.java @@ -5,11 +5,13 @@ package org.opensearch.ml.common.input.parameter.clustering; -import lombok.Builder; -import lombok.Data; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; + +import org.opensearch.core.ParseField; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.core.ParseField; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; @@ -18,32 +20,31 @@ import org.opensearch.ml.common.annotation.MLAlgoParameter; import org.opensearch.ml.common.input.parameter.MLAlgoParams; -import java.io.IOException; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import lombok.Builder; +import lombok.Data; @Data -@MLAlgoParameter(algorithms={FunctionName.KMEANS}) +@MLAlgoParameter(algorithms = { FunctionName.KMEANS }) public class KMeansParams implements MLAlgoParams { public static final String PARSE_FIELD_NAME = FunctionName.KMEANS.name(); public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry( - MLAlgoParams.class, - new ParseField(PARSE_FIELD_NAME), - it -> parse(it) + MLAlgoParams.class, + new ParseField(PARSE_FIELD_NAME), + it -> parse(it) ); public static final String CENTROIDS_FIELD = "centroids"; public static final String ITERATIONS_FIELD = "iterations"; public static final String DISTANCE_TYPE_FIELD = "distance_type"; - //The number of centroids to use. + // The number of centroids to use. private Integer centroids; - //The maximum number of iterations + // The maximum number of iterations private Integer iterations; - //The distance function. + // The distance function. private DistanceType distanceType; - //TODO: expose number of thread and seed? + // TODO: expose number of thread and seed? @Builder(toBuilder = true) public KMeansParams(Integer centroids, Integer iterations, DistanceType distanceType) { diff --git a/common/src/main/java/org/opensearch/ml/common/input/parameter/clustering/RCFSummarizeParams.java b/common/src/main/java/org/opensearch/ml/common/input/parameter/clustering/RCFSummarizeParams.java index b23461428c..3956514b47 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/parameter/clustering/RCFSummarizeParams.java +++ b/common/src/main/java/org/opensearch/ml/common/input/parameter/clustering/RCFSummarizeParams.java @@ -5,11 +5,13 @@ package org.opensearch.ml.common.input.parameter.clustering; -import lombok.Builder; -import lombok.Data; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; + +import org.opensearch.core.ParseField; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.core.ParseField; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; @@ -17,18 +19,17 @@ import org.opensearch.ml.common.annotation.MLAlgoParameter; import org.opensearch.ml.common.input.parameter.MLAlgoParams; -import java.io.IOException; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import lombok.Builder; +import lombok.Data; @Data -@MLAlgoParameter(algorithms={FunctionName.RCF_SUMMARIZE}) +@MLAlgoParameter(algorithms = { FunctionName.RCF_SUMMARIZE }) public class RCFSummarizeParams implements MLAlgoParams { public static final String PARSE_FIELD_NAME = FunctionName.RCF_SUMMARIZE.name(); public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry( - MLAlgoParams.class, - new ParseField(PARSE_FIELD_NAME), - it -> parse(it) + MLAlgoParams.class, + new ParseField(PARSE_FIELD_NAME), + it -> parse(it) ); public static final String MAX_K_FIELD = "max_k"; @@ -37,7 +38,7 @@ public class RCFSummarizeParams implements MLAlgoParams { public static final String PHASE1_REASSIGN_FIELD = "phase1_reassign"; public static final String PARALLEL__FIELD = "parallel"; - // The max of K allowed + // The max of K allowed private Integer maxK; // The initial K used private Integer initialK; diff --git a/common/src/main/java/org/opensearch/ml/common/input/parameter/rcf/BatchRCFParams.java b/common/src/main/java/org/opensearch/ml/common/input/parameter/rcf/BatchRCFParams.java index 3c284a51a6..fc23c13933 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/parameter/rcf/BatchRCFParams.java +++ b/common/src/main/java/org/opensearch/ml/common/input/parameter/rcf/BatchRCFParams.java @@ -5,30 +5,31 @@ package org.opensearch.ml.common.input.parameter.rcf; -import lombok.Builder; -import lombok.Data; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; + +import org.opensearch.core.ParseField; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.core.ParseField; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.ml.common.annotation.MLAlgoParameter; import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.annotation.MLAlgoParameter; import org.opensearch.ml.common.input.parameter.MLAlgoParams; -import java.io.IOException; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import lombok.Builder; +import lombok.Data; @Data -@MLAlgoParameter(algorithms={FunctionName.BATCH_RCF}) +@MLAlgoParameter(algorithms = { FunctionName.BATCH_RCF }) public class BatchRCFParams implements MLAlgoParams { public static final String PARSE_FIELD_NAME = FunctionName.BATCH_RCF.name(); public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry( - MLAlgoParams.class, - new ParseField(PARSE_FIELD_NAME), - it -> parse(it) + MLAlgoParams.class, + new ParseField(PARSE_FIELD_NAME), + it -> parse(it) ); public static final String NUMBER_OF_TREES = "number_of_trees"; @@ -45,12 +46,14 @@ public class BatchRCFParams implements MLAlgoParams { private Double anomalyScoreThreshold; @Builder - public BatchRCFParams(Integer numberOfTrees, - Integer shingleSize, - Integer sampleSize, - Integer outputAfter, - Integer trainingDataSize, - Double anomalyScoreThreshold) { + public BatchRCFParams( + Integer numberOfTrees, + Integer shingleSize, + Integer sampleSize, + Integer outputAfter, + Integer trainingDataSize, + Double anomalyScoreThreshold + ) { this.numberOfTrees = numberOfTrees; this.shingleSize = shingleSize; this.sampleSize = sampleSize; @@ -115,8 +118,7 @@ public static BatchRCFParams parse(XContentParser parser) throws IOException { break; } } - return new BatchRCFParams(numberOfTrees, shingleSize, sampleSize, outputAfter, - trainingDataSize, anomalyScoreThreshold); + return new BatchRCFParams(numberOfTrees, shingleSize, sampleSize, outputAfter, trainingDataSize, anomalyScoreThreshold); } @Override diff --git a/common/src/main/java/org/opensearch/ml/common/input/parameter/rcf/FitRCFParams.java b/common/src/main/java/org/opensearch/ml/common/input/parameter/rcf/FitRCFParams.java index d55fd57735..59ae8f7037 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/parameter/rcf/FitRCFParams.java +++ b/common/src/main/java/org/opensearch/ml/common/input/parameter/rcf/FitRCFParams.java @@ -5,30 +5,31 @@ package org.opensearch.ml.common.input.parameter.rcf; -import lombok.Builder; -import lombok.Data; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; + +import org.opensearch.core.ParseField; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.core.ParseField; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.ml.common.annotation.MLAlgoParameter; import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.annotation.MLAlgoParameter; import org.opensearch.ml.common.input.parameter.MLAlgoParams; -import java.io.IOException; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import lombok.Builder; +import lombok.Data; @Data -@MLAlgoParameter(algorithms={FunctionName.FIT_RCF}) +@MLAlgoParameter(algorithms = { FunctionName.FIT_RCF }) public class FitRCFParams implements MLAlgoParams { public static final String PARSE_FIELD_NAME = FunctionName.FIT_RCF.name(); public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry( - MLAlgoParams.class, - new ParseField(PARSE_FIELD_NAME), - it -> parse(it) + MLAlgoParams.class, + new ParseField(PARSE_FIELD_NAME), + it -> parse(it) ); public static final String NUMBER_OF_TREES = "number_of_trees"; @@ -51,15 +52,17 @@ public class FitRCFParams implements MLAlgoParams { private String timeZone; @Builder - public FitRCFParams(Integer numberOfTrees, - Integer shingleSize, - Integer sampleSize, - Integer outputAfter, - Double timeDecay, - Double anomalyRate, - String timeField, - String dateFormat, - String timeZone) { + public FitRCFParams( + Integer numberOfTrees, + Integer shingleSize, + Integer sampleSize, + Integer outputAfter, + Double timeDecay, + Double anomalyRate, + String timeField, + String dateFormat, + String timeZone + ) { this.numberOfTrees = numberOfTrees; this.shingleSize = shingleSize; this.sampleSize = sampleSize; @@ -145,8 +148,17 @@ public static FitRCFParams parse(XContentParser parser) throws IOException { break; } } - return new FitRCFParams(numberOfTrees, shingleSize, sampleSize, outputAfter, - timeDecay, anomalyRate, timeField, dateFormat, timeZone); + return new FitRCFParams( + numberOfTrees, + shingleSize, + sampleSize, + outputAfter, + timeDecay, + anomalyRate, + timeField, + dateFormat, + timeZone + ); } @Override diff --git a/common/src/main/java/org/opensearch/ml/common/input/parameter/regression/LinearRegressionParams.java b/common/src/main/java/org/opensearch/ml/common/input/parameter/regression/LinearRegressionParams.java index 9e9cb7f129..9ea9d88959 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/parameter/regression/LinearRegressionParams.java +++ b/common/src/main/java/org/opensearch/ml/common/input/parameter/regression/LinearRegressionParams.java @@ -5,11 +5,14 @@ package org.opensearch.ml.common.input.parameter.regression; -import lombok.Builder; -import lombok.Data; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.util.Locale; + +import org.opensearch.core.ParseField; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.core.ParseField; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; @@ -17,20 +20,18 @@ import org.opensearch.ml.common.annotation.MLAlgoParameter; import org.opensearch.ml.common.input.parameter.MLAlgoParams; -import java.io.IOException; -import java.util.Locale; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import lombok.Builder; +import lombok.Data; @Data -@MLAlgoParameter(algorithms={FunctionName.LINEAR_REGRESSION}) +@MLAlgoParameter(algorithms = { FunctionName.LINEAR_REGRESSION }) public class LinearRegressionParams implements MLAlgoParams { public static final String PARSE_FIELD_NAME = FunctionName.LINEAR_REGRESSION.name(); public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry( - MLAlgoParams.class, - new ParseField(PARSE_FIELD_NAME), - it -> parse(it) + MLAlgoParams.class, + new ParseField(PARSE_FIELD_NAME), + it -> parse(it) ); public static final String OBJECTIVE_FIELD = "objective"; @@ -64,7 +65,22 @@ public class LinearRegressionParams implements MLAlgoParams { private String target; @Builder(toBuilder = true) - public LinearRegressionParams(ObjectiveType objectiveType, OptimizerType optimizerType, Double learningRate, MomentumType momentumType, Double momentumFactor, Double epsilon, Double beta1, Double beta2, Double decayRate, Integer epochs, Integer batchSize, Integer loggingInterval, Long seed, String target) { + public LinearRegressionParams( + ObjectiveType objectiveType, + OptimizerType optimizerType, + Double learningRate, + MomentumType momentumType, + Double momentumFactor, + Double epsilon, + Double beta1, + Double beta2, + Double decayRate, + Integer epochs, + Integer batchSize, + Integer loggingInterval, + Long seed, + String target + ) { this.objectiveType = objectiveType; this.optimizerType = optimizerType; this.learningRate = learningRate; @@ -173,7 +189,22 @@ public static MLAlgoParams parse(XContentParser parser) throws IOException { break; } } - return new LinearRegressionParams(objective, optimizerType, learningRate, momentumType, momentumFactor, epsilon, beta1, beta2,decayRate, epochs, batchSize, loggingInterval, seed, target); + return new LinearRegressionParams( + objective, + optimizerType, + learningRate, + momentumType, + momentumFactor, + epsilon, + beta1, + beta2, + decayRate, + epochs, + batchSize, + loggingInterval, + seed, + target + ); } @Override @@ -272,8 +303,9 @@ public enum ObjectiveType { SQUARED_LOSS, ABSOLUTE_LOSS, HUBER; + public static ObjectiveType from(String value) { - try{ + try { return ObjectiveType.valueOf(value); } catch (Exception e) { throw new IllegalArgumentException("Wrong objective type"); @@ -286,7 +318,7 @@ public enum MomentumType { NESTEROV; public static MomentumType from(String value) { - try{ + try { return MomentumType.valueOf(value); } catch (Exception e) { throw new IllegalArgumentException("Wrong momentum type"); @@ -304,7 +336,7 @@ public enum OptimizerType { RMS_PROP; public static OptimizerType from(String value) { - try{ + try { return OptimizerType.valueOf(value); } catch (Exception e) { throw new IllegalArgumentException("Wrong optimizer type"); diff --git a/common/src/main/java/org/opensearch/ml/common/input/parameter/regression/LogisticRegressionParams.java b/common/src/main/java/org/opensearch/ml/common/input/parameter/regression/LogisticRegressionParams.java index 3340050ff5..d4238c1111 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/parameter/regression/LogisticRegressionParams.java +++ b/common/src/main/java/org/opensearch/ml/common/input/parameter/regression/LogisticRegressionParams.java @@ -5,11 +5,14 @@ package org.opensearch.ml.common.input.parameter.regression; -import lombok.Builder; -import lombok.Data; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.util.Locale; + +import org.opensearch.core.ParseField; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.core.ParseField; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; @@ -17,20 +20,18 @@ import org.opensearch.ml.common.annotation.MLAlgoParameter; import org.opensearch.ml.common.input.parameter.MLAlgoParams; -import java.io.IOException; -import java.util.Locale; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import lombok.Builder; +import lombok.Data; @Data -@MLAlgoParameter(algorithms={FunctionName.LOGISTIC_REGRESSION}) +@MLAlgoParameter(algorithms = { FunctionName.LOGISTIC_REGRESSION }) public class LogisticRegressionParams implements MLAlgoParams { public static final String PARSE_FIELD_NAME = FunctionName.LOGISTIC_REGRESSION.name(); public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry( - MLAlgoParams.class, - new ParseField(PARSE_FIELD_NAME), - it -> parse(it) + MLAlgoParams.class, + new ParseField(PARSE_FIELD_NAME), + it -> parse(it) ); public static final String OBJECTIVE_FIELD = "objective"; @@ -188,7 +189,22 @@ public static MLAlgoParams parse(XContentParser parser) throws IOException { break; } } - return new LogisticRegressionParams(objective, optimizerType, momentumType, learningRate, epsilon, momentumFactor, beta1, beta2, decayRate, epochs, batchSize, loggingInterval, seed, target); + return new LogisticRegressionParams( + objective, + optimizerType, + momentumType, + learningRate, + epsilon, + momentumFactor, + beta1, + beta2, + decayRate, + epochs, + batchSize, + loggingInterval, + seed, + target + ); } @Override @@ -286,8 +302,9 @@ public int getVersion() { public enum ObjectiveType { HINGE, LOGMULTICLASS; + public static ObjectiveType from(String value) { - try{ + try { return ObjectiveType.valueOf(value); } catch (Exception e) { throw new IllegalArgumentException("Wrong objective type"); @@ -300,7 +317,7 @@ public enum MomentumType { NESTEROV; public static MomentumType from(String value) { - try{ + try { return MomentumType.valueOf(value); } catch (Exception e) { throw new IllegalArgumentException("Wrong momentum type"); @@ -318,7 +335,7 @@ public enum OptimizerType { RMS_PROP; public static OptimizerType from(String value) { - try{ + try { return OptimizerType.valueOf(value); } catch (Exception e) { throw new IllegalArgumentException("Wrong optimizer type"); diff --git a/common/src/main/java/org/opensearch/ml/common/input/parameter/sample/SampleAlgoParams.java b/common/src/main/java/org/opensearch/ml/common/input/parameter/sample/SampleAlgoParams.java index 2544a748f5..7fc8c8be38 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/parameter/sample/SampleAlgoParams.java +++ b/common/src/main/java/org/opensearch/ml/common/input/parameter/sample/SampleAlgoParams.java @@ -5,11 +5,13 @@ package org.opensearch.ml.common.input.parameter.sample; -import lombok.Builder; -import lombok.Data; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; + +import org.opensearch.core.ParseField; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.core.ParseField; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; @@ -17,18 +19,17 @@ import org.opensearch.ml.common.annotation.MLAlgoParameter; import org.opensearch.ml.common.input.parameter.MLAlgoParams; -import java.io.IOException; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import lombok.Builder; +import lombok.Data; @Data -@MLAlgoParameter(algorithms={FunctionName.SAMPLE_ALGO}) +@MLAlgoParameter(algorithms = { FunctionName.SAMPLE_ALGO }) public class SampleAlgoParams implements MLAlgoParams { public static final String PARSE_FIELD_NAME = FunctionName.SAMPLE_ALGO.name(); public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry( - MLAlgoParams.class, - new ParseField(PARSE_FIELD_NAME), - it -> parse(it) + MLAlgoParams.class, + new ParseField(PARSE_FIELD_NAME), + it -> parse(it) ); public static final String SAMPLE_PARAM_FIELD = "sample_param"; diff --git a/common/src/main/java/org/opensearch/ml/common/input/remote/RemoteInferenceMLInput.java b/common/src/main/java/org/opensearch/ml/common/input/remote/RemoteInferenceMLInput.java index da4a9ad73d..cc3b2cdb5c 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/remote/RemoteInferenceMLInput.java +++ b/common/src/main/java/org/opensearch/ml/common/input/remote/RemoteInferenceMLInput.java @@ -5,6 +5,11 @@ package org.opensearch.ml.common.input.remote; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.util.Map; + import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.XContentParser; @@ -13,12 +18,7 @@ import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.utils.StringUtils; -import java.io.IOException; -import java.util.Map; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; - -@org.opensearch.ml.common.annotation.MLInput(functionNames = {FunctionName.REMOTE}) +@org.opensearch.ml.common.annotation.MLInput(functionNames = { FunctionName.REMOTE }) public class RemoteInferenceMLInput extends MLInput { public static final String PARAMETERS_FIELD = "parameters"; diff --git a/common/src/main/java/org/opensearch/ml/common/model/MLModelConfig.java b/common/src/main/java/org/opensearch/ml/common/model/MLModelConfig.java index 2fb07b6d8e..67f13e4f62 100644 --- a/common/src/main/java/org/opensearch/ml/common/model/MLModelConfig.java +++ b/common/src/main/java/org/opensearch/ml/common/model/MLModelConfig.java @@ -5,14 +5,15 @@ package org.opensearch.ml.common.model; -import lombok.Getter; -import lombok.Setter; +import java.io.IOException; + import org.opensearch.core.common.io.stream.NamedWriteable; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentObject; -import java.io.IOException; +import lombok.Getter; +import lombok.Setter; @Setter @Getter diff --git a/common/src/main/java/org/opensearch/ml/common/model/MLModelState.java b/common/src/main/java/org/opensearch/ml/common/model/MLModelState.java index cfd06be1f0..3bb84c6bbd 100644 --- a/common/src/main/java/org/opensearch/ml/common/model/MLModelState.java +++ b/common/src/main/java/org/opensearch/ml/common/model/MLModelState.java @@ -37,4 +37,4 @@ public static MLModelState from(String value) { throw new IllegalArgumentException("Wrong model state"); } } -} \ No newline at end of file +} diff --git a/common/src/main/java/org/opensearch/ml/common/model/MetricsCorrelationModelConfig.java b/common/src/main/java/org/opensearch/ml/common/model/MetricsCorrelationModelConfig.java index e1c9203cae..b690bd0342 100644 --- a/common/src/main/java/org/opensearch/ml/common/model/MetricsCorrelationModelConfig.java +++ b/common/src/main/java/org/opensearch/ml/common/model/MetricsCorrelationModelConfig.java @@ -5,18 +5,18 @@ package org.opensearch.ml.common.model; -import lombok.Builder; -import lombok.Getter; -import lombok.Setter; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; + import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.FunctionName; -import java.io.IOException; -import java.util.Locale; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import lombok.Builder; +import lombok.Getter; +import lombok.Setter; @Setter @Getter @@ -29,7 +29,7 @@ public MetricsCorrelationModelConfig(String modelType, String allConfig) { super(modelType, allConfig); } - public MetricsCorrelationModelConfig(StreamInput in) throws IOException{ + public MetricsCorrelationModelConfig(StreamInput in) throws IOException { super(in); } diff --git a/common/src/main/java/org/opensearch/ml/common/model/TextEmbeddingModelConfig.java b/common/src/main/java/org/opensearch/ml/common/model/TextEmbeddingModelConfig.java index dbb15fa2d6..77c324ca8d 100644 --- a/common/src/main/java/org/opensearch/ml/common/model/TextEmbeddingModelConfig.java +++ b/common/src/main/java/org/opensearch/ml/common/model/TextEmbeddingModelConfig.java @@ -5,31 +5,32 @@ package org.opensearch.ml.common.model; -import lombok.Builder; -import lombok.Getter; -import lombok.Setter; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.util.Locale; + +import org.opensearch.core.ParseField; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.core.ParseField; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.FunctionName; -import java.io.IOException; -import java.util.Locale; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import lombok.Builder; +import lombok.Getter; +import lombok.Setter; @Setter @Getter public class TextEmbeddingModelConfig extends MLModelConfig { public static final String PARSE_FIELD_NAME = FunctionName.TEXT_EMBEDDING.name(); public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry( - TextEmbeddingModelConfig.class, - new ParseField(PARSE_FIELD_NAME), - it -> parse(it) + TextEmbeddingModelConfig.class, + new ParseField(PARSE_FIELD_NAME), + it -> parse(it) ); public static final String EMBEDDING_DIMENSION_FIELD = "embedding_dimension"; @@ -45,8 +46,15 @@ public class TextEmbeddingModelConfig extends MLModelConfig { private final Integer modelMaxLength; @Builder(toBuilder = true) - public TextEmbeddingModelConfig(String modelType, Integer embeddingDimension, FrameworkType frameworkType, String allConfig, - PoolingMode poolingMode, boolean normalizeResult, Integer modelMaxLength) { + public TextEmbeddingModelConfig( + String modelType, + Integer embeddingDimension, + FrameworkType frameworkType, + String allConfig, + PoolingMode poolingMode, + boolean normalizeResult, + Integer modelMaxLength + ) { super(modelType, allConfig); if (embeddingDimension == null) { throw new IllegalArgumentException("embedding dimension is null"); @@ -102,7 +110,15 @@ public static TextEmbeddingModelConfig parse(XContentParser parser) throws IOExc break; } } - return new TextEmbeddingModelConfig(modelType, embeddingDimension, frameworkType, allConfig, poolingMode, normalizeResult, modelMaxLength); + return new TextEmbeddingModelConfig( + modelType, + embeddingDimension, + frameworkType, + allConfig, + poolingMode, + normalizeResult, + modelMaxLength + ); } @Override @@ -110,7 +126,7 @@ public String getWriteableName() { return PARSE_FIELD_NAME; } - public TextEmbeddingModelConfig(StreamInput in) throws IOException{ + public TextEmbeddingModelConfig(StreamInput in) throws IOException { super(in); embeddingDimension = in.readInt(); frameworkType = in.readEnum(FrameworkType.class); @@ -179,6 +195,7 @@ public enum PoolingMode { public String getName() { return name; } + PoolingMode(String name) { this.name = name; } @@ -191,6 +208,7 @@ public static PoolingMode from(String value) { } } } + public enum FrameworkType { HUGGINGFACE_TRANSFORMERS, SENTENCE_TRANSFORMERS, diff --git a/common/src/main/java/org/opensearch/ml/common/output/MLOutput.java b/common/src/main/java/org/opensearch/ml/common/output/MLOutput.java index 83fbfe1cc1..d967059892 100644 --- a/common/src/main/java/org/opensearch/ml/common/output/MLOutput.java +++ b/common/src/main/java/org/opensearch/ml/common/output/MLOutput.java @@ -5,13 +5,14 @@ package org.opensearch.ml.common.output; -import lombok.NonNull; -import lombok.RequiredArgsConstructor; +import java.io.IOException; + import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.ml.common.MLCommonsClassLoader; -import java.io.IOException; +import lombok.NonNull; +import lombok.RequiredArgsConstructor; /** * ML output data. Must specify output type and diff --git a/common/src/main/java/org/opensearch/ml/common/output/MLPredictionOutput.java b/common/src/main/java/org/opensearch/ml/common/output/MLPredictionOutput.java index 28b5b07821..5675dab409 100644 --- a/common/src/main/java/org/opensearch/ml/common/output/MLPredictionOutput.java +++ b/common/src/main/java/org/opensearch/ml/common/output/MLPredictionOutput.java @@ -5,10 +5,8 @@ package org.opensearch.ml.common.output; -import lombok.Builder; -import lombok.Data; -import lombok.EqualsAndHashCode; -import lombok.ToString; +import java.io.IOException; + import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.XContentBuilder; @@ -17,10 +15,13 @@ import org.opensearch.ml.common.dataframe.DataFrameType; import org.opensearch.ml.common.dataframe.DefaultDataFrame; -import java.io.IOException; +import lombok.Builder; +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.ToString; @Data -@EqualsAndHashCode(callSuper=false) +@EqualsAndHashCode(callSuper = false) @MLAlgoOutput(MLOutputType.PREDICTION) public class MLPredictionOutput extends MLOutput { diff --git a/common/src/main/java/org/opensearch/ml/common/output/MLTrainingOutput.java b/common/src/main/java/org/opensearch/ml/common/output/MLTrainingOutput.java index c69bb9ca74..c6bb98c73f 100644 --- a/common/src/main/java/org/opensearch/ml/common/output/MLTrainingOutput.java +++ b/common/src/main/java/org/opensearch/ml/common/output/MLTrainingOutput.java @@ -5,15 +5,16 @@ package org.opensearch.ml.common.output; -import lombok.Builder; -import lombok.Getter; +import java.io.IOException; + import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.annotation.MLAlgoOutput; -import java.io.IOException; +import lombok.Builder; +import lombok.Getter; @Getter @MLAlgoOutput(MLOutputType.TRAINING) @@ -32,7 +33,7 @@ public MLTrainingOutput(String modelId, String taskId, String status) { super(OUTPUT_TYPE); this.modelId = modelId; this.taskId = taskId; - this.status= status; + this.status = status; } public MLTrainingOutput(StreamInput in) throws IOException { diff --git a/common/src/main/java/org/opensearch/ml/common/output/execute/anomalylocalization/AnomalyLocalizationOutput.java b/common/src/main/java/org/opensearch/ml/common/output/execute/anomalylocalization/AnomalyLocalizationOutput.java index a4c6f5963b..490ea101f1 100644 --- a/common/src/main/java/org/opensearch/ml/common/output/execute/anomalylocalization/AnomalyLocalizationOutput.java +++ b/common/src/main/java/org/opensearch/ml/common/output/execute/anomalylocalization/AnomalyLocalizationOutput.java @@ -5,6 +5,8 @@ package org.opensearch.ml.common.output.execute.anomalylocalization; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + import java.io.IOException; import java.util.ArrayList; import java.util.HashMap; @@ -15,25 +17,23 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; - -import lombok.Data; -import lombok.EqualsAndHashCode; -import lombok.NoArgsConstructor; -import lombok.SneakyThrows; -import lombok.ToString; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.ml.common.annotation.ExecuteOutput; import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.annotation.ExecuteOutput; import org.opensearch.ml.common.output.Output; -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.NoArgsConstructor; +import lombok.SneakyThrows; +import lombok.ToString; /** * Output of localized results. */ -@ExecuteOutput(algorithms={FunctionName.ANOMALY_LOCALIZATION}) +@ExecuteOutput(algorithms = { FunctionName.ANOMALY_LOCALIZATION }) @Data @NoArgsConstructor public class AnomalyLocalizationOutput implements Output { @@ -102,8 +102,8 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par */ @Data @NoArgsConstructor - @ToString(exclude = {"base", "counter", "completed"}) - @EqualsAndHashCode(exclude = {"base", "counter", "completed"}) + @ToString(exclude = { "base", "counter", "completed" }) + @EqualsAndHashCode(exclude = { "base", "counter", "completed" }) public static class Bucket implements Output { public static final String FIELD_START_TIME = "start_time"; @@ -134,7 +134,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeLong(startTime); out.writeLong(endTime); out.writeDouble(overallAggValue); - if (entities == null) { + if (entities == null) { out.writeBoolean(false); } else { out.writeBoolean(true); diff --git a/common/src/main/java/org/opensearch/ml/common/output/execute/metrics_correlation/MCorrModelTensor.java b/common/src/main/java/org/opensearch/ml/common/output/execute/metrics_correlation/MCorrModelTensor.java index a8dc54481b..2783787cc4 100644 --- a/common/src/main/java/org/opensearch/ml/common/output/execute/metrics_correlation/MCorrModelTensor.java +++ b/common/src/main/java/org/opensearch/ml/common/output/execute/metrics_correlation/MCorrModelTensor.java @@ -5,15 +5,16 @@ package org.opensearch.ml.common.output.execute.metrics_correlation; -import lombok.Builder; -import lombok.Data; +import java.io.IOException; + import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; -import java.io.IOException; +import lombok.Builder; +import lombok.Data; @Data public class MCorrModelTensor implements Writeable, ToXContentObject { diff --git a/common/src/main/java/org/opensearch/ml/common/output/execute/metrics_correlation/MCorrModelTensors.java b/common/src/main/java/org/opensearch/ml/common/output/execute/metrics_correlation/MCorrModelTensors.java index d26a9e8b0e..5ebcc41248 100644 --- a/common/src/main/java/org/opensearch/ml/common/output/execute/metrics_correlation/MCorrModelTensors.java +++ b/common/src/main/java/org/opensearch/ml/common/output/execute/metrics_correlation/MCorrModelTensors.java @@ -5,11 +5,13 @@ package org.opensearch.ml.common.output.execute.metrics_correlation; -import lombok.Builder; -import lombok.Getter; -import lombok.extern.log4j.Log4j2; -import org.opensearch.core.common.bytes.BytesReference; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; + import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; @@ -18,10 +20,9 @@ import org.opensearch.ml.common.exception.MLException; import org.opensearch.ml.common.output.model.ModelResultFilter; -import java.io.IOException; -import java.nio.ByteBuffer; -import java.util.ArrayList; -import java.util.List; +import lombok.Builder; +import lombok.Getter; +import lombok.extern.log4j.Log4j2; @Log4j2 @Getter @@ -48,7 +49,7 @@ public MCorrModelTensors(StreamInput in) throws IOException { if (in.readBoolean()) { mCorrModelTensors = new ArrayList<>(); int size = in.readInt(); - for (int i=0; i targetResponse = resultFilter.getTargetResponse(); List targetResponsePositions = resultFilter.getTargetResponsePositions(); if ((targetResponse == null || targetResponse.size() == 0) - && (targetResponsePositions == null || targetResponsePositions.size() == 0)) { - mCorrModelTensors.forEach(output -> filter(output, returnNumber)); + && (targetResponsePositions == null || targetResponsePositions.size() == 0)) { + mCorrModelTensors.forEach(output -> filter(output, returnNumber)); return; } List targetOutput = new ArrayList<>(); if (mCorrModelTensors != null) { - for (int i = 0 ; i(); int size = in.readInt(); - for (int i=0; i targetResponsePositions; @Builder - public ModelResultFilter(boolean returnBytes, - boolean returnNumber, - List targetResponse, - List targetResponsePositions + public ModelResultFilter( + boolean returnBytes, + boolean returnNumber, + List targetResponse, + List targetResponsePositions ) { this.returnBytes = returnBytes; this.returnNumber = returnNumber; @@ -65,7 +67,7 @@ public ModelResultFilter(StreamInput streamInput) throws IOException { if (streamInput.readBoolean()) { int size = streamInput.readInt(); targetResponsePositions = new ArrayList<>(); - for (int i=0;i dataAsMap;// whole result in Map @Builder - public ModelTensor(String name, Number[] data, long[] shape, MLResultDataType dataType, ByteBuffer byteBuffer, String result, Map dataAsMap) { + public ModelTensor( + String name, + Number[] data, + long[] shape, + MLResultDataType dataType, + ByteBuffer byteBuffer, + String result, + Map dataAsMap + ) { if (data != null && (dataType == null || dataType == MLResultDataType.UNKNOWN)) { throw new IllegalArgumentException("data type is null"); } @@ -175,14 +184,7 @@ public static ModelTensor parser(XContentParser parser) throws IOException { data[i] = (Number) dataList.get(i); } } - return ModelTensor.builder() - .name(name) - .shape(shape) - .dataType(dataType) - .data(data) - .result(result) - .dataAsMap(dataAsMap) - .build(); + return ModelTensor.builder().name(name).shape(shape).dataType(dataType).data(data).result(result).dataAsMap(dataAsMap).build(); } public ModelTensor(StreamInput in) throws IOException { diff --git a/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensorOutput.java b/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensorOutput.java index 664bd3510f..32f3318718 100644 --- a/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensorOutput.java +++ b/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensorOutput.java @@ -5,9 +5,10 @@ package org.opensearch.ml.common.output.model; -import lombok.Builder; -import lombok.Data; -import lombok.EqualsAndHashCode; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.XContentBuilder; @@ -15,12 +16,12 @@ import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.output.MLOutputType; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; +import lombok.Builder; +import lombok.Data; +import lombok.EqualsAndHashCode; @Data -@EqualsAndHashCode(callSuper=false) +@EqualsAndHashCode(callSuper = false) @MLAlgoOutput(MLOutputType.MODEL_TENSOR) public class ModelTensorOutput extends MLOutput { private static final MLOutputType OUTPUT_TYPE = MLOutputType.MODEL_TENSOR; @@ -34,13 +35,12 @@ public ModelTensorOutput(List mlModelOutputs) { this.mlModelOutputs = mlModelOutputs; } - public ModelTensorOutput(StreamInput in) throws IOException { super(OUTPUT_TYPE); if (in.readBoolean()) { mlModelOutputs = new ArrayList<>(); int size = in.readInt(); - for (int i=0; i(); int size = in.readInt(); - for (int i=0; i targetResponse = resultFilter.getTargetResponse(); List targetResponsePositions = resultFilter.getTargetResponsePositions(); if ((targetResponse == null || targetResponse.size() == 0) - && (targetResponsePositions == null || targetResponsePositions.size() == 0)) { - mlModelTensors.forEach(output -> filter(output, returnBytes, returnNumber)); + && (targetResponsePositions == null || targetResponsePositions.size() == 0)) { + mlModelTensors.forEach(output -> filter(output, returnBytes, returnNumber)); return; } List targetOutput = new ArrayList<>(); if (mlModelTensors != null) { - for (int i = 0 ; i { public static final MLConnectorDeleteAction INSTANCE = new MLConnectorDeleteAction(); public static final String NAME = "cluster:admin/opensearch/ml/connectors/delete"; - private MLConnectorDeleteAction() { super(NAME, DeleteResponse::new);} + private MLConnectorDeleteAction() { + super(NAME, DeleteResponse::new); + } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLConnectorDeleteRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLConnectorDeleteRequest.java index 9da5be98aa..a1e3a6391e 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLConnectorDeleteRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLConnectorDeleteRequest.java @@ -5,8 +5,13 @@ package org.opensearch.ml.common.transport.connector; -import lombok.Builder; -import lombok.Getter; +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.core.common.io.stream.InputStreamStreamInput; @@ -14,12 +19,8 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.opensearch.action.ValidateActions.addValidationError; +import lombok.Builder; +import lombok.Getter; public class MLConnectorDeleteRequest extends ActionRequest { @Getter @@ -54,11 +55,10 @@ public ActionRequestValidationException validate() { public static MLConnectorDeleteRequest fromActionRequest(ActionRequest actionRequest) { if (actionRequest instanceof MLConnectorDeleteRequest) { - return (MLConnectorDeleteRequest)actionRequest; + return (MLConnectorDeleteRequest) actionRequest; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionRequest.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLConnectorDeleteRequest(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLConnectorGetAction.java b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLConnectorGetAction.java index da29dd86fe..6695e2ada1 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLConnectorGetAction.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLConnectorGetAction.java @@ -11,6 +11,8 @@ public class MLConnectorGetAction extends ActionType { public static final MLConnectorGetAction INSTANCE = new MLConnectorGetAction(); public static final String NAME = "cluster:admin/opensearch/ml/connectors/get"; - private MLConnectorGetAction() { super(NAME, MLConnectorGetResponse::new);} + private MLConnectorGetAction() { + super(NAME, MLConnectorGetResponse::new); + } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLConnectorGetRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLConnectorGetRequest.java index 118a70ccde..53c6c9c497 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLConnectorGetRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLConnectorGetRequest.java @@ -5,8 +5,13 @@ package org.opensearch.ml.common.transport.connector; -import lombok.Builder; -import lombok.Getter; +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.core.common.io.stream.InputStreamStreamInput; @@ -14,12 +19,8 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.opensearch.action.ValidateActions.addValidationError; +import lombok.Builder; +import lombok.Getter; @Getter public class MLConnectorGetRequest extends ActionRequest { @@ -62,8 +63,7 @@ public static MLConnectorGetRequest fromActionRequest(ActionRequest actionReques return (MLConnectorGetRequest) actionRequest; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionRequest.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLConnectorGetRequest(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLConnectorGetResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLConnectorGetResponse.java index 492566d20a..dbd7c9b42c 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLConnectorGetResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLConnectorGetResponse.java @@ -5,7 +5,11 @@ package org.opensearch.ml.common.transport.connector; -import lombok.Builder; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.InputStreamStreamInput; import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; @@ -16,10 +20,7 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.connector.Connector; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; +import lombok.Builder; public class MLConnectorGetResponse extends ActionResponse implements ToXContentObject { Connector mlConnector; @@ -35,7 +36,7 @@ public MLConnectorGetResponse(StreamInput in) throws IOException { } @Override - public void writeTo(StreamOutput out) throws IOException{ + public void writeTo(StreamOutput out) throws IOException { mlConnector.writeTo(out); } @@ -49,8 +50,7 @@ public static MLConnectorGetResponse fromActionResponse(ActionResponse actionRes return (MLConnectorGetResponse) actionResponse; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionResponse.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLConnectorGetResponse(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java index 9d9879daec..d0176440d5 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java @@ -5,8 +5,15 @@ package org.opensearch.ml.common.transport.connector; -import lombok.Builder; -import lombok.Data; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.utils.StringUtils.getParameterMap; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; @@ -17,14 +24,8 @@ import org.opensearch.ml.common.AccessMode; import org.opensearch.ml.common.connector.ConnectorAction; -import java.io.IOException; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.ml.common.utils.StringUtils.getParameterMap; +import lombok.Builder; +import lombok.Data; @Data public class MLCreateConnectorInput implements ToXContentObject, Writeable { @@ -59,18 +60,19 @@ public class MLCreateConnectorInput implements ToXContentObject, Writeable { private boolean updateConnector = false; @Builder(toBuilder = true) - public MLCreateConnectorInput(String name, - String description, - String version, - String protocol, - Map parameters, - Map credential, - List actions, - List backendRoles, - Boolean addAllBackendRoles, - AccessMode access, - boolean dryRun, - boolean updateConnector + public MLCreateConnectorInput( + String name, + String description, + String version, + String protocol, + Map parameters, + Map credential, + List actions, + List backendRoles, + Boolean addAllBackendRoles, + AccessMode access, + boolean dryRun, + boolean updateConnector ) { if (!dryRun && !updateConnector) { if (name == null) { @@ -166,7 +168,20 @@ public static MLCreateConnectorInput parse(XContentParser parser, boolean update break; } } - return new MLCreateConnectorInput(name, description, version, protocol, parameters, credential, actions, backendRoles, addAllBackendRoles, access, dryRun, updateConnector); + return new MLCreateConnectorInput( + name, + description, + version, + protocol, + parameters, + credential, + actions, + backendRoles, + addAllBackendRoles, + access, + dryRun, + updateConnector + ); } @Override @@ -259,7 +274,7 @@ public MLCreateConnectorInput(StreamInput input) throws IOException { parameters = input.readMap(s -> s.readString(), s -> s.readString()); } if (input.readBoolean()) { - credential = input.readMap(s -> s.readString(), s-> s.readString()); + credential = input.readMap(s -> s.readString(), s -> s.readString()); } if (input.readBoolean()) { actions = new ArrayList<>(); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorRequest.java index 107d5001b8..e227c30478 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorRequest.java @@ -5,8 +5,13 @@ package org.opensearch.ml.common.transport.connector; -import lombok.Builder; -import lombok.Getter; +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.core.common.io.stream.InputStreamStreamInput; @@ -14,12 +19,8 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.opensearch.action.ValidateActions.addValidationError; +import lombok.Builder; +import lombok.Getter; @Getter public class MLCreateConnectorRequest extends ActionRequest { @@ -56,8 +57,7 @@ public static MLCreateConnectorRequest fromActionRequest(ActionRequest actionReq return (MLCreateConnectorRequest) actionRequest; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionRequest.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLCreateConnectorRequest(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorResponse.java index 68ce877baa..08b1631853 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorResponse.java @@ -5,7 +5,11 @@ package org.opensearch.ml.common.transport.connector; -import lombok.Getter; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.InputStreamStreamInput; import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; @@ -14,10 +18,7 @@ import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; +import lombok.Getter; @Getter public class MLCreateConnectorResponse extends ActionResponse implements ToXContentObject { @@ -53,8 +54,7 @@ public static MLCreateConnectorResponse fromActionResponse(ActionResponse action return (MLCreateConnectorResponse) actionResponse; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionResponse.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLCreateConnectorResponse(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorAction.java b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorAction.java index 9fa10a39c6..8609af2134 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorAction.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorAction.java @@ -12,5 +12,7 @@ public class MLUpdateConnectorAction extends ActionType { public static final MLUpdateConnectorAction INSTANCE = new MLUpdateConnectorAction(); public static final String NAME = "cluster:admin/opensearch/ml/connectors/update"; - private MLUpdateConnectorAction() { super(NAME, UpdateResponse::new);} + private MLUpdateConnectorAction() { + super(NAME, UpdateResponse::new); + } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequest.java index 089180cdc5..8a365140de 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequest.java @@ -5,8 +5,13 @@ package org.opensearch.ml.common.transport.connector; -import lombok.Builder; -import lombok.Getter; +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.core.common.io.stream.InputStreamStreamInput; @@ -15,12 +20,8 @@ import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.XContentParser; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.opensearch.action.ValidateActions.addValidationError; +import lombok.Builder; +import lombok.Getter; @Getter public class MLUpdateConnectorRequest extends ActionRequest { @@ -72,8 +73,7 @@ public static MLUpdateConnectorRequest fromActionRequest(ActionRequest actionReq return (MLUpdateConnectorRequest) actionRequest; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionRequest.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLUpdateConnectorRequest(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelInput.java b/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelInput.java index 6831369b2b..d8ae7ab829 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelInput.java @@ -5,15 +5,15 @@ package org.opensearch.ml.common.transport.deploy; -import lombok.Builder; -import lombok.Data; +import java.io.IOException; + import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; -import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.ml.common.MLTask; -import java.io.IOException; +import lombok.Builder; +import lombok.Data; @Data public class MLDeployModelInput implements Writeable { @@ -36,7 +36,15 @@ public MLDeployModelInput(StreamInput in) throws IOException { } @Builder - public MLDeployModelInput(String modelId, String taskId, String modelContentHash, Integer nodeCount, String coordinatingNodeId, Boolean isDeployToAllNodes, MLTask mlTask) { + public MLDeployModelInput( + String modelId, + String taskId, + String modelContentHash, + Integer nodeCount, + String coordinatingNodeId, + Boolean isDeployToAllNodes, + MLTask mlTask + ) { this.modelId = modelId; this.taskId = taskId; this.modelContentHash = modelContentHash; @@ -46,8 +54,7 @@ public MLDeployModelInput(String modelId, String taskId, String modelContentHash this.mlTask = mlTask; } - public MLDeployModelInput() { - } + public MLDeployModelInput() {} @Override public void writeTo(StreamOutput out) throws IOException { diff --git a/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelNodeRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelNodeRequest.java index 1d99f47d15..0edf1b57d5 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelNodeRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelNodeRequest.java @@ -5,12 +5,13 @@ package org.opensearch.ml.common.transport.deploy; -import lombok.Getter; -import org.opensearch.transport.TransportRequest; +import java.io.IOException; + import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.transport.TransportRequest; -import java.io.IOException; +import lombok.Getter; public class MLDeployModelNodeRequest extends TransportRequest { @Getter diff --git a/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelNodeResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelNodeResponse.java index e2a0cbf084..ac9879e789 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelNodeResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelNodeResponse.java @@ -5,7 +5,9 @@ package org.opensearch.ml.common.transport.deploy; -import lombok.extern.log4j.Log4j2; +import java.io.IOException; +import java.util.Map; + import org.opensearch.action.support.nodes.BaseNodeResponse; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.core.common.io.stream.StreamInput; @@ -13,8 +15,8 @@ import org.opensearch.core.xcontent.ToXContentFragment; import org.opensearch.core.xcontent.XContentBuilder; -import java.io.IOException; -import java.util.Map; +import lombok.extern.log4j.Log4j2; + @Log4j2 public class MLDeployModelNodeResponse extends BaseNodeResponse implements ToXContentFragment { @@ -27,6 +29,7 @@ public MLDeployModelNodeResponse(DiscoveryNode node, Map modelDe super(node); this.modelDeployStatus = modelDeployStatus; } + /** * Constructor * @@ -36,7 +39,7 @@ public MLDeployModelNodeResponse(DiscoveryNode node, Map modelDe public MLDeployModelNodeResponse(StreamInput in) throws IOException { super(in); if (in.readBoolean()) { - this.modelDeployStatus = in.readMap(s -> s.readString(), s-> s.readString()); + this.modelDeployStatus = in.readMap(s -> s.readString(), s -> s.readString()); } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelNodesRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelNodesRequest.java index e2c8043b04..5f5c347dac 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelNodesRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelNodesRequest.java @@ -5,13 +5,14 @@ package org.opensearch.ml.common.transport.deploy; -import lombok.Getter; +import java.io.IOException; + import org.opensearch.action.support.nodes.BaseNodesRequest; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import java.io.IOException; +import lombok.Getter; public class MLDeployModelNodesRequest extends BaseNodesRequest { diff --git a/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelNodesResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelNodesResponse.java index c27abebfaf..be8d5cc1ed 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelNodesResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelNodesResponse.java @@ -5,6 +5,9 @@ package org.opensearch.ml.common.transport.deploy; +import java.io.IOException; +import java.util.List; + import org.opensearch.action.FailedNodeException; import org.opensearch.action.support.nodes.BaseNodesResponse; import org.opensearch.cluster.ClusterName; @@ -14,9 +17,6 @@ import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; -import java.io.IOException; -import java.util.List; - public class MLDeployModelNodesResponse extends BaseNodesResponse implements ToXContentObject { /** diff --git a/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelRequest.java index b0ad113d95..5aa155c81e 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelRequest.java @@ -5,11 +5,16 @@ package org.opensearch.ml.common.transport.deploy; -import lombok.AccessLevel; -import lombok.Builder; -import lombok.Getter; -import lombok.ToString; -import lombok.experimental.FieldDefaults; +import static org.opensearch.action.ValidateActions.addValidationError; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.ArrayList; +import java.util.List; + import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.core.common.io.stream.InputStreamStreamInput; @@ -19,15 +24,11 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.transport.MLTaskRequest; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; -import java.util.ArrayList; -import java.util.List; - -import static org.opensearch.action.ValidateActions.addValidationError; -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; +import lombok.experimental.FieldDefaults; @Getter @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) @@ -104,8 +105,7 @@ public static MLDeployModelRequest fromActionRequest(ActionRequest actionRequest return (MLDeployModelRequest) actionRequest; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionRequest.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLDeployModelRequest(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelResponse.java index ca35af68f0..eefcee7de5 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/deploy/MLDeployModelResponse.java @@ -5,7 +5,11 @@ package org.opensearch.ml.common.transport.deploy; -import lombok.Getter; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.InputStreamStreamInput; import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; @@ -16,10 +20,7 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.MLTaskType; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; +import lombok.Getter; @Getter public class MLDeployModelResponse extends ActionResponse implements ToXContentObject { @@ -41,7 +42,7 @@ public MLDeployModelResponse(StreamInput in) throws IOException { public MLDeployModelResponse(String taskId, MLTaskType mlTaskType, String status) { this.taskId = taskId; this.taskType = mlTaskType; - this.status= status; + this.status = status; } @Override @@ -68,8 +69,7 @@ public static MLDeployModelResponse fromActionResponse(ActionResponse actionResp return (MLDeployModelResponse) actionResponse; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionResponse.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLDeployModelResponse(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/execute/MLExecuteTaskRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/execute/MLExecuteTaskRequest.java index e772b78d2d..d998ea71de 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/execute/MLExecuteTaskRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/execute/MLExecuteTaskRequest.java @@ -5,29 +5,30 @@ package org.opensearch.ml.common.transport.execute; -import lombok.AccessLevel; -import lombok.Builder; -import lombok.Getter; -import lombok.NonNull; -import lombok.ToString; -import lombok.experimental.FieldDefaults; +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.core.common.io.stream.InputStreamStreamInput; import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.ml.common.MLCommonsClassLoader; import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.MLCommonsClassLoader; import org.opensearch.ml.common.input.Input; import org.opensearch.ml.common.transport.MLTaskRequest; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.opensearch.action.ValidateActions.addValidationError; +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.NonNull; +import lombok.ToString; +import lombok.experimental.FieldDefaults; @Getter @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) @@ -64,7 +65,7 @@ public void writeTo(StreamOutput out) throws IOException { @Override public ActionRequestValidationException validate() { ActionRequestValidationException exception = null; - if(this.input == null) { + if (this.input == null) { exception = addValidationError("ML input can't be null", exception); } else { if (this.input.getFunctionName() == null) { @@ -75,14 +76,12 @@ public ActionRequestValidationException validate() { return exception; } - public static MLExecuteTaskRequest fromActionRequest(ActionRequest actionRequest) { if (actionRequest instanceof MLExecuteTaskRequest) { return (MLExecuteTaskRequest) actionRequest; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionRequest.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLExecuteTaskRequest(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/execute/MLExecuteTaskResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/execute/MLExecuteTaskResponse.java index b97c02ca9c..1dccdb139f 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/execute/MLExecuteTaskResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/execute/MLExecuteTaskResponse.java @@ -5,10 +5,11 @@ package org.opensearch.ml.common.transport.execute; -import lombok.Builder; -import lombok.Getter; -import lombok.NonNull; -import lombok.ToString; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.InputStreamStreamInput; import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; @@ -20,10 +21,10 @@ import org.opensearch.ml.common.MLCommonsClassLoader; import org.opensearch.ml.common.output.Output; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; +import lombok.Builder; +import lombok.Getter; +import lombok.NonNull; +import lombok.ToString; @Getter @ToString @@ -62,8 +63,7 @@ public static MLExecuteTaskResponse fromActionResponse(ActionResponse actionResp return (MLExecuteTaskResponse) actionResponse; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionResponse.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLExecuteTaskResponse(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/forward/MLForwardInput.java b/common/src/main/java/org/opensearch/ml/common/transport/forward/MLForwardInput.java index 624cec3c7d..902603cf03 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/forward/MLForwardInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/forward/MLForwardInput.java @@ -5,9 +5,8 @@ package org.opensearch.ml.common.transport.forward; -import lombok.Builder; -import lombok.Data; -import lombok.extern.log4j.Log4j2; +import java.io.IOException; + import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; @@ -15,7 +14,9 @@ import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.transport.register.MLRegisterModelInput; -import java.io.IOException; +import lombok.Builder; +import lombok.Data; +import lombok.extern.log4j.Log4j2; @Data @Log4j2 @@ -32,9 +33,17 @@ public class MLForwardInput implements Writeable { private MLRegisterModelInput registerModelInput; @Builder(toBuilder = true) - public MLForwardInput(String taskId, String modelId, String workerNodeId, MLForwardRequestType requestType, - MLTask mlTask, MLInput modelInput, - String error, String[] workerNodes, MLRegisterModelInput registerModelInput) { + public MLForwardInput( + String taskId, + String modelId, + String workerNodeId, + MLForwardRequestType requestType, + MLTask mlTask, + MLInput modelInput, + String error, + String[] workerNodes, + MLRegisterModelInput registerModelInput + ) { this.taskId = taskId; this.modelId = modelId; this.workerNodeId = workerNodeId; diff --git a/common/src/main/java/org/opensearch/ml/common/transport/forward/MLForwardRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/forward/MLForwardRequest.java index 7d2949fd3a..c029e81bd2 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/forward/MLForwardRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/forward/MLForwardRequest.java @@ -5,12 +5,13 @@ package org.opensearch.ml.common.transport.forward; -import lombok.AccessLevel; -import lombok.Builder; -import lombok.Getter; -import lombok.ToString; -import lombok.experimental.FieldDefaults; -import lombok.extern.log4j.Log4j2; +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.core.common.io.stream.InputStreamStreamInput; @@ -18,12 +19,12 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.opensearch.action.ValidateActions.addValidationError; +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; +import lombok.experimental.FieldDefaults; +import lombok.extern.log4j.Log4j2; @Getter @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) @@ -64,8 +65,7 @@ public static MLForwardRequest fromActionRequest(ActionRequest actionRequest) { return (MLForwardRequest) actionRequest; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionRequest.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLForwardRequest(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/forward/MLForwardResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/forward/MLForwardResponse.java index ff51103671..f873c8a4b9 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/forward/MLForwardResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/forward/MLForwardResponse.java @@ -5,9 +5,11 @@ package org.opensearch.ml.common.transport.forward; -import lombok.Builder; -import lombok.Getter; -import lombok.ToString; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.InputStreamStreamInput; import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; @@ -18,10 +20,9 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.output.MLOutput; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; @Getter @ToString @@ -36,7 +37,6 @@ public MLForwardResponse(String status, MLOutput mlOutput) { this.mlOutput = mlOutput; } - public MLForwardResponse(StreamInput in) throws IOException { super(in); status = in.readOptionalString(); @@ -70,8 +70,7 @@ public static MLForwardResponse fromActionResponse(ActionResponse actionResponse return (MLForwardResponse) actionResponse; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionResponse.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLForwardResponse(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelDeleteAction.java b/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelDeleteAction.java index 6886fc57d6..8374eb5f9f 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelDeleteAction.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelDeleteAction.java @@ -12,5 +12,7 @@ public class MLModelDeleteAction extends ActionType { public static final MLModelDeleteAction INSTANCE = new MLModelDeleteAction(); public static final String NAME = "cluster:admin/opensearch/ml/models/delete"; - private MLModelDeleteAction() { super(NAME, DeleteResponse::new);} + private MLModelDeleteAction() { + super(NAME, DeleteResponse::new); + } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelDeleteRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelDeleteRequest.java index a42cf1d071..4c57c5912c 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelDeleteRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelDeleteRequest.java @@ -5,8 +5,13 @@ package org.opensearch.ml.common.transport.model; -import lombok.Builder; -import lombok.Getter; +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.core.common.io.stream.InputStreamStreamInput; @@ -14,12 +19,8 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.opensearch.action.ValidateActions.addValidationError; +import lombok.Builder; +import lombok.Getter; public class MLModelDeleteRequest extends ActionRequest { @Getter @@ -54,11 +55,10 @@ public ActionRequestValidationException validate() { public static MLModelDeleteRequest fromActionRequest(ActionRequest actionRequest) { if (actionRequest instanceof MLModelDeleteRequest) { - return (MLModelDeleteRequest)actionRequest; + return (MLModelDeleteRequest) actionRequest; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionRequest.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLModelDeleteRequest(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelGetAction.java b/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelGetAction.java index 37e3831404..dd47e8cdee 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelGetAction.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelGetAction.java @@ -11,5 +11,7 @@ public class MLModelGetAction extends ActionType { public static final MLModelGetAction INSTANCE = new MLModelGetAction(); public static final String NAME = "cluster:admin/opensearch/ml/models/get"; - private MLModelGetAction() { super(NAME, MLModelGetResponse::new);} + private MLModelGetAction() { + super(NAME, MLModelGetResponse::new); + } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelGetRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelGetRequest.java index 71c5200971..8d333f37fc 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelGetRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelGetRequest.java @@ -5,11 +5,13 @@ package org.opensearch.ml.common.transport.model; -import lombok.AccessLevel; -import lombok.Builder; -import lombok.Getter; -import lombok.ToString; -import lombok.experimental.FieldDefaults; +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.core.common.io.stream.InputStreamStreamInput; @@ -17,12 +19,11 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.opensearch.action.ValidateActions.addValidationError; +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; +import lombok.experimental.FieldDefaults; @Getter @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) @@ -64,11 +65,10 @@ public ActionRequestValidationException validate() { public static MLModelGetRequest fromActionRequest(ActionRequest actionRequest) { if (actionRequest instanceof MLModelGetRequest) { - return (MLModelGetRequest)actionRequest; + return (MLModelGetRequest) actionRequest; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionRequest.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLModelGetRequest(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelGetResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelGetResponse.java index b9a1040474..ec91a4ea43 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelGetResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/model/MLModelGetResponse.java @@ -5,9 +5,11 @@ package org.opensearch.ml.common.transport.model; -import lombok.Builder; -import lombok.Getter; -import lombok.ToString; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.InputStreamStreamInput; import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; @@ -18,10 +20,9 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.MLModel; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; @Getter @ToString @@ -34,14 +35,13 @@ public MLModelGetResponse(MLModel mlModel) { this.mlModel = mlModel; } - public MLModelGetResponse(StreamInput in) throws IOException { super(in); mlModel = mlModel.fromStream(in); } @Override - public void writeTo(StreamOutput out) throws IOException{ + public void writeTo(StreamOutput out) throws IOException { mlModel.writeTo(out); } @@ -55,8 +55,7 @@ public static MLModelGetResponse fromActionResponse(ActionResponse actionRespons return (MLModelGetResponse) actionResponse; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionResponse.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLModelGetResponse(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelInput.java b/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelInput.java index ca0a2f70d4..4d631cdb9f 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelInput.java @@ -5,28 +5,26 @@ package org.opensearch.ml.common.transport.model; -import lombok.Data; -import lombok.Builder; -import lombok.Getter; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; + import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.model.MLModelConfig; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; -import java.io.IOException; -import java.util.Map; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.ml.common.connector.Connector.createConnector; +import lombok.Builder; +import lombok.Data; +import lombok.Getter; @Data public class MLUpdateModelInput implements ToXContentObject, Writeable { - + public static final String MODEL_ID_FIELD = "model_id"; // mandatory public static final String DESCRIPTION_FIELD = "description"; // optional public static final String MODEL_VERSION_FIELD = "model_version"; // optional @@ -45,7 +43,15 @@ public class MLUpdateModelInput implements ToXContentObject, Writeable { private String connectorId; @Builder(toBuilder = true) - public MLUpdateModelInput(String modelId, String description, String version, String name, String modelGroupId, MLModelConfig modelConfig, String connectorId) { + public MLUpdateModelInput( + String modelId, + String description, + String version, + String name, + String modelGroupId, + MLModelConfig modelConfig, + String connectorId + ) { this.modelId = modelId; this.description = description; this.version = version; @@ -152,4 +158,4 @@ public static MLUpdateModelInput parse(XContentParser parser) throws IOException // Model ID can only be set through RestRequest. Model version can only be set automatically. return new MLUpdateModelInput(modelId, description, version, name, modelGroupId, modelConfig, connectorId); } -} \ No newline at end of file +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelRequest.java index b589f71ed4..61524689f7 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelRequest.java @@ -5,11 +5,13 @@ package org.opensearch.ml.common.transport.model; -import lombok.AccessLevel; -import lombok.Builder; -import lombok.Getter; -import lombok.ToString; -import lombok.experimental.FieldDefaults; +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.core.common.io.stream.InputStreamStreamInput; @@ -17,18 +19,17 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.opensearch.action.ValidateActions.addValidationError; +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; +import lombok.experimental.FieldDefaults; @Getter @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) @ToString public class MLUpdateModelRequest extends ActionRequest { - + MLUpdateModelInput updateModelInput; @Builder @@ -57,13 +58,12 @@ public void writeTo(StreamOutput out) throws IOException { this.updateModelInput.writeTo(out); } - public static MLUpdateModelRequest fromActionRequest(ActionRequest actionRequest){ + public static MLUpdateModelRequest fromActionRequest(ActionRequest actionRequest) { if (actionRequest instanceof MLUpdateModelRequest) { return (MLUpdateModelRequest) actionRequest; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionRequest.writeTo(osso); try (StreamInput in = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLUpdateModelRequest(in); @@ -72,4 +72,4 @@ public static MLUpdateModelRequest fromActionRequest(ActionRequest actionRequest throw new UncheckedIOException("Failed to parse ActionRequest into MLUpdateModelRequest", e); } } -} \ No newline at end of file +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLModelGroupDeleteAction.java b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLModelGroupDeleteAction.java index 7acd877c3a..434ace5a63 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLModelGroupDeleteAction.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLModelGroupDeleteAction.java @@ -12,5 +12,7 @@ public class MLModelGroupDeleteAction extends ActionType { public static final MLModelGroupDeleteAction INSTANCE = new MLModelGroupDeleteAction(); public static final String NAME = "cluster:admin/opensearch/ml/model_groups/delete"; - private MLModelGroupDeleteAction() { super(NAME, DeleteResponse::new);} + private MLModelGroupDeleteAction() { + super(NAME, DeleteResponse::new); + } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLModelGroupDeleteRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLModelGroupDeleteRequest.java index 8c5326ab8d..86a1d093ee 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLModelGroupDeleteRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLModelGroupDeleteRequest.java @@ -5,8 +5,13 @@ package org.opensearch.ml.common.transport.model_group; -import lombok.Builder; -import lombok.Getter; +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.core.common.io.stream.InputStreamStreamInput; @@ -14,12 +19,8 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.opensearch.action.ValidateActions.addValidationError; +import lombok.Builder; +import lombok.Getter; public class MLModelGroupDeleteRequest extends ActionRequest { @Getter @@ -54,11 +55,10 @@ public ActionRequestValidationException validate() { public static MLModelGroupDeleteRequest fromActionRequest(ActionRequest actionRequest) { if (actionRequest instanceof MLModelGroupDeleteRequest) { - return (MLModelGroupDeleteRequest)actionRequest; + return (MLModelGroupDeleteRequest) actionRequest; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionRequest.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLModelGroupDeleteRequest(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupInput.java b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupInput.java index 9d8dd67050..8f4162f11f 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupInput.java @@ -5,8 +5,14 @@ package org.opensearch.ml.common.transport.model_group; -import lombok.Builder; -import lombok.Data; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; +import java.util.Objects; + import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; @@ -15,22 +21,17 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.AccessMode; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.Locale; -import java.util.Objects; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import lombok.Builder; +import lombok.Data; @Data public class MLRegisterModelGroupInput implements ToXContentObject, Writeable { - public static final String NAME_FIELD = "name"; //mandatory - public static final String DESCRIPTION_FIELD = "description"; //optional - public static final String BACKEND_ROLES_FIELD = "backend_roles"; //optional - public static final String MODEL_ACCESS_MODE = "access_mode"; //optional - public static final String ADD_ALL_BACKEND_ROLES = "add_all_backend_roles"; //optional + public static final String NAME_FIELD = "name"; // mandatory + public static final String DESCRIPTION_FIELD = "description"; // optional + public static final String BACKEND_ROLES_FIELD = "backend_roles"; // optional + public static final String MODEL_ACCESS_MODE = "access_mode"; // optional + public static final String ADD_ALL_BACKEND_ROLES = "add_all_backend_roles"; // optional private String name; private String description; @@ -39,7 +40,13 @@ public class MLRegisterModelGroupInput implements ToXContentObject, Writeable { private Boolean isAddAllBackendRoles; @Builder(toBuilder = true) - public MLRegisterModelGroupInput(String name, String description, List backendRoles, AccessMode modelAccessMode, Boolean isAddAllBackendRoles) { + public MLRegisterModelGroupInput( + String name, + String description, + List backendRoles, + AccessMode modelAccessMode, + Boolean isAddAllBackendRoles + ) { this.name = Objects.requireNonNull(name, "model group name must not be null"); this.description = description; this.backendRoles = backendRoles; @@ -47,7 +54,7 @@ public MLRegisterModelGroupInput(String name, String description, List b this.isAddAllBackendRoles = isAddAllBackendRoles; } - public MLRegisterModelGroupInput(StreamInput in) throws IOException{ + public MLRegisterModelGroupInput(StreamInput in) throws IOException { this.name = in.readString(); this.description = in.readOptionalString(); this.backendRoles = in.readOptionalStringList(); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupRequest.java index 3bf3dabd03..0f3b8163f5 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupRequest.java @@ -5,11 +5,13 @@ package org.opensearch.ml.common.transport.model_group; -import lombok.AccessLevel; -import lombok.Builder; -import lombok.Getter; -import lombok.ToString; -import lombok.experimental.FieldDefaults; +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.core.common.io.stream.InputStreamStreamInput; @@ -17,12 +19,11 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.opensearch.action.ValidateActions.addValidationError; +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; +import lombok.experimental.FieldDefaults; @Getter @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) @@ -62,8 +63,7 @@ public static MLRegisterModelGroupRequest fromActionRequest(ActionRequest action return (MLRegisterModelGroupRequest) actionRequest; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionRequest.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLRegisterModelGroupRequest(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupResponse.java index 01c63d18de..83aace89f6 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupResponse.java @@ -5,7 +5,11 @@ package org.opensearch.ml.common.transport.model_group; -import lombok.Getter; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.InputStreamStreamInput; import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; @@ -14,10 +18,7 @@ import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; +import lombok.Getter; @Getter public class MLRegisterModelGroupResponse extends ActionResponse implements ToXContentObject { @@ -37,7 +38,7 @@ public MLRegisterModelGroupResponse(StreamInput in) throws IOException { public MLRegisterModelGroupResponse(String modelGroupId, String status) { this.modelGroupId = modelGroupId; - this.status= status; + this.status = status; } @Override @@ -60,8 +61,7 @@ public static MLRegisterModelGroupResponse fromActionResponse(ActionResponse act return (MLRegisterModelGroupResponse) actionResponse; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionResponse.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLRegisterModelGroupResponse(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupInput.java b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupInput.java index 22e612a5b1..3dd92082c8 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupInput.java @@ -5,8 +5,13 @@ package org.opensearch.ml.common.transport.model_group; -import lombok.Builder; -import lombok.Data; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; + import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; @@ -15,23 +20,18 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.AccessMode; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.Locale; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import lombok.Builder; +import lombok.Data; @Data public class MLUpdateModelGroupInput implements ToXContentObject, Writeable { - public static final String MODEL_GROUP_ID_FIELD = "model_group_id"; //mandatory - public static final String NAME_FIELD = "name"; //optional - public static final String DESCRIPTION_FIELD = "description"; //optional - public static final String BACKEND_ROLES_FIELD = "backend_roles"; //optional - public static final String MODEL_ACCESS_MODE = "access_mode"; //optional - public static final String ADD_ALL_BACKEND_ROLES_FIELD = "add_all_backend_roles"; //optional - + public static final String MODEL_GROUP_ID_FIELD = "model_group_id"; // mandatory + public static final String NAME_FIELD = "name"; // optional + public static final String DESCRIPTION_FIELD = "description"; // optional + public static final String BACKEND_ROLES_FIELD = "backend_roles"; // optional + public static final String MODEL_ACCESS_MODE = "access_mode"; // optional + public static final String ADD_ALL_BACKEND_ROLES_FIELD = "add_all_backend_roles"; // optional private String modelGroupID; private String name; @@ -41,7 +41,14 @@ public class MLUpdateModelGroupInput implements ToXContentObject, Writeable { private Boolean isAddAllBackendRoles; @Builder(toBuilder = true) - public MLUpdateModelGroupInput(String modelGroupID, String name, String description, List backendRoles, AccessMode modelAccessMode, Boolean isAddAllBackendRoles) { + public MLUpdateModelGroupInput( + String modelGroupID, + String name, + String description, + List backendRoles, + AccessMode modelAccessMode, + Boolean isAddAllBackendRoles + ) { this.modelGroupID = modelGroupID; this.name = name; this.description = description; diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupRequest.java index aecb62a8d2..e3f103dcf3 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupRequest.java @@ -5,11 +5,13 @@ package org.opensearch.ml.common.transport.model_group; -import lombok.AccessLevel; -import lombok.Builder; -import lombok.Getter; -import lombok.ToString; -import lombok.experimental.FieldDefaults; +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.core.common.io.stream.InputStreamStreamInput; @@ -17,12 +19,11 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.opensearch.action.ValidateActions.addValidationError; +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; +import lombok.experimental.FieldDefaults; @Getter @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) @@ -62,8 +63,7 @@ public static MLUpdateModelGroupRequest fromActionRequest(ActionRequest actionRe return (MLUpdateModelGroupRequest) actionRequest; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionRequest.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLUpdateModelGroupRequest(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupResponse.java index 23bec3b0aa..fbe5795c4f 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupResponse.java @@ -5,14 +5,15 @@ package org.opensearch.ml.common.transport.model_group; -import lombok.Getter; +import java.io.IOException; + import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; -import java.io.IOException; +import lombok.Getter; @Getter public class MLUpdateModelGroupResponse extends ActionResponse implements ToXContentObject { diff --git a/common/src/main/java/org/opensearch/ml/common/transport/package-info.java b/common/src/main/java/org/opensearch/ml/common/transport/package-info.java index 77111bf8f4..d01f4f9512 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/package-info.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/package-info.java @@ -3,4 +3,4 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.ml.common.transport; \ No newline at end of file +package org.opensearch.ml.common.transport; diff --git a/common/src/main/java/org/opensearch/ml/common/transport/prediction/MLPredictionTaskRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/prediction/MLPredictionTaskRequest.java index 963892215f..c55713fddd 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/prediction/MLPredictionTaskRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/prediction/MLPredictionTaskRequest.java @@ -5,12 +5,13 @@ package org.opensearch.ml.common.transport.prediction; +import static org.opensearch.action.ValidateActions.addValidationError; + import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.UncheckedIOException; -import lombok.Setter; import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.commons.authuser.User; @@ -19,15 +20,14 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.transport.MLTaskRequest; import lombok.AccessLevel; import lombok.Builder; import lombok.Getter; +import lombok.Setter; import lombok.ToString; import lombok.experimental.FieldDefaults; -import org.opensearch.ml.common.transport.MLTaskRequest; - -import static org.opensearch.action.ValidateActions.addValidationError; @Getter @FieldDefaults(level = AccessLevel.PRIVATE) @@ -85,14 +85,12 @@ public ActionRequestValidationException validate() { return exception; } - public static MLPredictionTaskRequest fromActionRequest(ActionRequest actionRequest) { if (actionRequest instanceof MLPredictionTaskRequest) { return (MLPredictionTaskRequest) actionRequest; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionRequest.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLPredictionTaskRequest(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java b/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java index a871332f95..ef17bba28b 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java @@ -5,8 +5,14 @@ package org.opensearch.ml.common.transport.register; -import lombok.Builder; -import lombok.Data; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.connector.Connector.createConnector; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; + import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; @@ -21,14 +27,8 @@ import org.opensearch.ml.common.model.MetricsCorrelationModelConfig; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.Locale; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; - -import static org.opensearch.ml.common.connector.Connector.createConnector; +import lombok.Builder; +import lombok.Data; /** * ML input data: algirithm name, parameters and input data set. @@ -76,23 +76,24 @@ public class MLRegisterModelInput implements ToXContentObject, Writeable { private Boolean doesVersionCreateModelGroup; @Builder(toBuilder = true) - public MLRegisterModelInput(FunctionName functionName, - String modelName, - String modelGroupId, - String version, - String description, - String url, - String hashValue, - MLModelFormat modelFormat, - MLModelConfig modelConfig, - boolean deployModel, - String[] modelNodeIds, - Connector connector, - String connectorId, - List backendRoles, - Boolean addAllBackendRoles, - AccessMode accessMode, - Boolean doesVersionCreateModelGroup + public MLRegisterModelInput( + FunctionName functionName, + String modelName, + String modelGroupId, + String version, + String description, + String url, + String hashValue, + MLModelFormat modelFormat, + MLModelConfig modelConfig, + boolean deployModel, + String[] modelNodeIds, + Connector connector, + String connectorId, + List backendRoles, + Boolean addAllBackendRoles, + AccessMode accessMode, + Boolean doesVersionCreateModelGroup ) { if (functionName == null) { this.functionName = FunctionName.TEXT_EMBEDDING; @@ -106,7 +107,12 @@ public MLRegisterModelInput(FunctionName functionName, if (modelFormat == null) { throw new IllegalArgumentException("model format is null"); } - if (url != null && modelConfig == null && functionName != FunctionName.SPARSE_TOKENIZE && functionName != FunctionName.SPARSE_ENCODING) { // The tokenize model doesn't require a model configuration. Currently, we only support one type of sparse model, which is pretrained, and it doesn't necessitate a model configuration. + if (url != null + && modelConfig == null + && functionName != FunctionName.SPARSE_TOKENIZE + && functionName != FunctionName.SPARSE_ENCODING) { // The tokenize model doesn't require a model configuration. Currently, + // we only support one type of sparse model, which is pretrained, and it + // doesn't necessitate a model configuration. throw new IllegalArgumentException("model config is null"); } } @@ -128,7 +134,6 @@ public MLRegisterModelInput(FunctionName functionName, this.doesVersionCreateModelGroup = doesVersionCreateModelGroup; } - public MLRegisterModelInput(StreamInput in) throws IOException { this.functionName = in.readEnum(FunctionName.class); this.modelName = in.readString(); @@ -261,7 +266,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws return builder; } - public static MLRegisterModelInput parse(XContentParser parser, String modelName, String version, boolean deployModel) throws IOException { + public static MLRegisterModelInput parse(XContentParser parser, String modelName, String version, boolean deployModel) + throws IOException { FunctionName functionName = null; String modelGroupId = null; String url = null; @@ -335,7 +341,25 @@ public static MLRegisterModelInput parse(XContentParser parser, String modelName break; } } - return new MLRegisterModelInput(functionName, modelName, modelGroupId, version, description, url, hashValue, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0]), connector, connectorId, backendRoles, addAllBackendRoles, accessMode, doesVersionCreateModelGroup); + return new MLRegisterModelInput( + functionName, + modelName, + modelGroupId, + version, + description, + url, + hashValue, + modelFormat, + modelConfig, + deployModel, + modelNodeIds.toArray(new String[0]), + connector, + connectorId, + backendRoles, + addAllBackendRoles, + accessMode, + doesVersionCreateModelGroup + ); } public static MLRegisterModelInput parse(XContentParser parser, boolean deployModel) throws IOException { @@ -421,6 +445,24 @@ public static MLRegisterModelInput parse(XContentParser parser, boolean deployMo break; } } - return new MLRegisterModelInput(functionName, name, modelGroupId, version, description, url, hashValue, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0]), connector, connectorId, backendRoles, addAllBackendRoles, accessMode, doesVersionCreateModelGroup); + return new MLRegisterModelInput( + functionName, + name, + modelGroupId, + version, + description, + url, + hashValue, + modelFormat, + modelConfig, + deployModel, + modelNodeIds.toArray(new String[0]), + connector, + connectorId, + backendRoles, + addAllBackendRoles, + accessMode, + doesVersionCreateModelGroup + ); } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelRequest.java index b57b65c524..adff46812f 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelRequest.java @@ -5,11 +5,13 @@ package org.opensearch.ml.common.transport.register; -import lombok.AccessLevel; -import lombok.Builder; -import lombok.Getter; -import lombok.ToString; -import lombok.experimental.FieldDefaults; +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.core.common.io.stream.InputStreamStreamInput; @@ -17,12 +19,11 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.opensearch.action.ValidateActions.addValidationError; +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; +import lombok.experimental.FieldDefaults; @Getter @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) @@ -62,8 +63,7 @@ public static MLRegisterModelRequest fromActionRequest(ActionRequest actionReque return (MLRegisterModelRequest) actionRequest; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionRequest.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLRegisterModelRequest(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelResponse.java index 18c64c6c5f..2714ddef3e 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelResponse.java @@ -5,7 +5,11 @@ package org.opensearch.ml.common.transport.register; -import lombok.Getter; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.InputStreamStreamInput; import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; @@ -14,12 +18,8 @@ import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.ml.common.transport.MLTaskResponse; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; +import lombok.Getter; @Getter public class MLRegisterModelResponse extends ActionResponse implements ToXContentObject { @@ -40,12 +40,12 @@ public MLRegisterModelResponse(StreamInput in) throws IOException { public MLRegisterModelResponse(String taskId, String status) { this.taskId = taskId; - this.status= status; + this.status = status; } public MLRegisterModelResponse(String taskId, String status, String modelId) { this.taskId = taskId; - this.status= status; + this.status = status; this.modelId = modelId; } @@ -73,8 +73,7 @@ public static MLRegisterModelResponse fromActionResponse(ActionResponse actionRe return (MLRegisterModelResponse) actionResponse; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionResponse.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLRegisterModelResponse(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpInput.java b/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpInput.java index de04b2936d..017b97761b 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpInput.java @@ -5,15 +5,16 @@ package org.opensearch.ml.common.transport.sync; -import lombok.Builder; -import lombok.Data; +import java.io.IOException; +import java.util.Map; +import java.util.Set; + import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; -import java.io.IOException; -import java.util.Map; -import java.util.Set; +import lombok.Builder; +import lombok.Data; @Data public class MLSyncUpInput implements Writeable { @@ -36,14 +37,16 @@ public class MLSyncUpInput implements Writeable { private Map deployToAllNodes; @Builder - public MLSyncUpInput(boolean getDeployedModels, - Map addedWorkerNodes, - Map removedWorkerNodes, - Map> modelRoutingTable, - Map> runningDeployModelTasks, - Map deployToAllNodes, - boolean clearRoutingTable, - boolean syncRunningDeployModelTasks) { + public MLSyncUpInput( + boolean getDeployedModels, + Map addedWorkerNodes, + Map removedWorkerNodes, + Map> modelRoutingTable, + Map> runningDeployModelTasks, + Map deployToAllNodes, + boolean clearRoutingTable, + boolean syncRunningDeployModelTasks + ) { this.getDeployedModels = getDeployedModels; this.addedWorkerNodes = addedWorkerNodes; this.removedWorkerNodes = removedWorkerNodes; @@ -54,7 +57,7 @@ public MLSyncUpInput(boolean getDeployedModels, this.syncRunningDeployModelTasks = syncRunningDeployModelTasks; } - public MLSyncUpInput(){} + public MLSyncUpInput() {} public MLSyncUpInput(StreamInput in) throws IOException { this.getDeployedModels = in.readBoolean(); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeRequest.java index 1158dcc843..0e52342f15 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeRequest.java @@ -5,12 +5,13 @@ package org.opensearch.ml.common.transport.sync; -import lombok.Getter; -import org.opensearch.transport.TransportRequest; +import java.io.IOException; + import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.transport.TransportRequest; -import java.io.IOException; +import lombok.Getter; public class MLSyncUpNodeRequest extends TransportRequest { @Getter diff --git a/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeResponse.java index e7ac993fba..f46a967ba5 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeResponse.java @@ -5,26 +5,32 @@ package org.opensearch.ml.common.transport.sync; -import lombok.Getter; -import lombok.extern.log4j.Log4j2; +import java.io.IOException; + import org.opensearch.action.support.nodes.BaseNodeResponse; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import java.io.IOException; +import lombok.Getter; +import lombok.extern.log4j.Log4j2; @Log4j2 @Getter -public class MLSyncUpNodeResponse extends BaseNodeResponse { +public class MLSyncUpNodeResponse extends BaseNodeResponse { private String modelStatus; private String[] deployedModelIds; private String[] runningDeployModelIds; // model ids which have deploying model task running private String[] runningDeployModelTaskIds; // deploy model task ids which is running - public MLSyncUpNodeResponse(DiscoveryNode node, String modelStatus, String[] deployedModelIds, String[] runningDeployModelIds, - String[] runningDeployModelTaskIds) { + public MLSyncUpNodeResponse( + DiscoveryNode node, + String modelStatus, + String[] deployedModelIds, + String[] runningDeployModelIds, + String[] runningDeployModelTaskIds + ) { super(node); this.modelStatus = modelStatus; this.deployedModelIds = deployedModelIds; diff --git a/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodesRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodesRequest.java index 56ec920f5f..d66af5d8f7 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodesRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodesRequest.java @@ -5,13 +5,14 @@ package org.opensearch.ml.common.transport.sync; -import lombok.Getter; +import java.io.IOException; + import org.opensearch.action.support.nodes.BaseNodesRequest; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import java.io.IOException; +import lombok.Getter; public class MLSyncUpNodesRequest extends BaseNodesRequest { diff --git a/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodesResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodesResponse.java index dee614685c..ecfd42f464 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodesResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodesResponse.java @@ -5,15 +5,15 @@ package org.opensearch.ml.common.transport.sync; +import java.io.IOException; +import java.util.List; + import org.opensearch.action.FailedNodeException; import org.opensearch.action.support.nodes.BaseNodesResponse; import org.opensearch.cluster.ClusterName; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import java.io.IOException; -import java.util.List; - public class MLSyncUpNodesResponse extends BaseNodesResponse { public MLSyncUpNodesResponse(StreamInput in) throws IOException { diff --git a/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpResponse.java index 6c4f4ed82f..1353edcb83 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/sync/MLSyncUpResponse.java @@ -5,7 +5,8 @@ package org.opensearch.ml.common.transport.sync; -import lombok.Getter; +import java.io.IOException; + import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; @@ -13,7 +14,7 @@ import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; -import java.io.IOException; +import lombok.Getter; @Getter public class MLSyncUpResponse extends ActionResponse implements ToXContentObject { @@ -27,7 +28,7 @@ public MLSyncUpResponse(StreamInput in) throws IOException { } public MLSyncUpResponse(String status) { - this.status= status; + this.status = status; } @Override diff --git a/common/src/main/java/org/opensearch/ml/common/transport/task/MLTaskDeleteAction.java b/common/src/main/java/org/opensearch/ml/common/transport/task/MLTaskDeleteAction.java index 7b00b6509a..5aed868589 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/task/MLTaskDeleteAction.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/task/MLTaskDeleteAction.java @@ -12,5 +12,7 @@ public class MLTaskDeleteAction extends ActionType { public static final MLTaskDeleteAction INSTANCE = new MLTaskDeleteAction(); public static final String NAME = "cluster:admin/opensearch/ml/tasks/delete"; - private MLTaskDeleteAction() { super(NAME, DeleteResponse::new);} + private MLTaskDeleteAction() { + super(NAME, DeleteResponse::new); + } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/task/MLTaskDeleteRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/task/MLTaskDeleteRequest.java index a7782a60ea..b109c52b42 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/task/MLTaskDeleteRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/task/MLTaskDeleteRequest.java @@ -5,8 +5,13 @@ package org.opensearch.ml.common.transport.task; -import lombok.Builder; -import lombok.Getter; +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.core.common.io.stream.InputStreamStreamInput; @@ -14,12 +19,8 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.opensearch.action.ValidateActions.addValidationError; +import lombok.Builder; +import lombok.Getter; public class MLTaskDeleteRequest extends ActionRequest { @Getter @@ -54,11 +55,10 @@ public ActionRequestValidationException validate() { public static MLTaskDeleteRequest fromActionRequest(ActionRequest actionRequest) { if (actionRequest instanceof MLTaskDeleteRequest) { - return (MLTaskDeleteRequest)actionRequest; + return (MLTaskDeleteRequest) actionRequest; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionRequest.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLTaskDeleteRequest(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/task/MLTaskGetAction.java b/common/src/main/java/org/opensearch/ml/common/transport/task/MLTaskGetAction.java index 4aaa143a1f..2d76df4dc7 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/task/MLTaskGetAction.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/task/MLTaskGetAction.java @@ -11,5 +11,7 @@ public class MLTaskGetAction extends ActionType { public static final MLTaskGetAction INSTANCE = new MLTaskGetAction(); public static final String NAME = "cluster:admin/opensearch/ml/tasks/get"; - private MLTaskGetAction() { super(NAME, MLTaskGetResponse::new);} + private MLTaskGetAction() { + super(NAME, MLTaskGetResponse::new); + } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/task/MLTaskGetRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/task/MLTaskGetRequest.java index 06145adef7..3feb5c661d 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/task/MLTaskGetRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/task/MLTaskGetRequest.java @@ -5,9 +5,13 @@ package org.opensearch.ml.common.transport.task; +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; -import lombok.Builder; -import lombok.Getter; import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.core.common.io.stream.InputStreamStreamInput; @@ -15,12 +19,8 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.opensearch.action.ValidateActions.addValidationError; +import lombok.Builder; +import lombok.Getter; public class MLTaskGetRequest extends ActionRequest { @Getter @@ -55,11 +55,10 @@ public ActionRequestValidationException validate() { public static MLTaskGetRequest fromActionRequest(ActionRequest actionRequest) { if (actionRequest instanceof MLTaskGetRequest) { - return (MLTaskGetRequest)actionRequest; + return (MLTaskGetRequest) actionRequest; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionRequest.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLTaskGetRequest(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/task/MLTaskGetResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/task/MLTaskGetResponse.java index cc4d51192a..071ab82682 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/task/MLTaskGetResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/task/MLTaskGetResponse.java @@ -5,8 +5,11 @@ package org.opensearch.ml.common.transport.task; -import lombok.Builder; -import lombok.Getter; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.InputStreamStreamInput; import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; @@ -17,10 +20,8 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.MLTask; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; +import lombok.Builder; +import lombok.Getter; @Getter public class MLTaskGetResponse extends ActionResponse implements ToXContentObject { @@ -37,7 +38,7 @@ public MLTaskGetResponse(StreamInput in) throws IOException { } @Override - public void writeTo(StreamOutput out) throws IOException{ + public void writeTo(StreamOutput out) throws IOException { mlTask.writeTo(out); } @@ -51,8 +52,7 @@ public static MLTaskGetResponse fromActionResponse(ActionResponse actionResponse return (MLTaskGetResponse) actionResponse; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionResponse.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLTaskGetResponse(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/task/MLTaskSearchAction.java b/common/src/main/java/org/opensearch/ml/common/transport/task/MLTaskSearchAction.java index 13b38aa687..cd2636e991 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/task/MLTaskSearchAction.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/task/MLTaskSearchAction.java @@ -2,7 +2,6 @@ import org.opensearch.action.ActionType; import org.opensearch.action.search.SearchResponse; -import org.opensearch.ml.common.transport.model.MLModelSearchAction; public class MLTaskSearchAction extends ActionType { // External Action which used for public facing RestAPIs. diff --git a/common/src/main/java/org/opensearch/ml/common/transport/training/MLTrainingTaskRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/training/MLTrainingTaskRequest.java index e03f1e8dda..45012f09eb 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/training/MLTrainingTaskRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/training/MLTrainingTaskRequest.java @@ -5,11 +5,14 @@ package org.opensearch.ml.common.transport.training; -import lombok.AccessLevel; -import lombok.Builder; -import lombok.Getter; -import lombok.ToString; -import lombok.experimental.FieldDefaults; +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Objects; + import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.core.common.io.stream.InputStreamStreamInput; @@ -19,13 +22,11 @@ import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.transport.MLTaskRequest; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; -import java.util.Objects; - -import static org.opensearch.action.ValidateActions.addValidationError; +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; +import lombok.experimental.FieldDefaults; @Getter @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) @@ -79,8 +80,7 @@ public static MLTrainingTaskRequest fromActionRequest(ActionRequest actionReques return (MLTrainingTaskRequest) actionRequest; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionRequest.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLTrainingTaskRequest(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelInput.java b/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelInput.java index d0e399f291..c08dffc336 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelInput.java @@ -5,8 +5,12 @@ package org.opensearch.ml.common.transport.undeploy; -import lombok.Builder; -import lombok.Data; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; @@ -15,11 +19,8 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import lombok.Builder; +import lombok.Data; @Data public class MLUndeployModelInput implements ToXContentObject, Writeable { diff --git a/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodeRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodeRequest.java index 4cd7bee8c6..9ff9255453 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodeRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodeRequest.java @@ -5,12 +5,13 @@ package org.opensearch.ml.common.transport.undeploy; -import lombok.Getter; +import java.io.IOException; + import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.transport.TransportRequest; -import java.io.IOException; +import lombok.Getter; public class MLUndeployModelNodeRequest extends TransportRequest { @Getter diff --git a/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodeResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodeResponse.java index 2af72a6d6a..5ba693585e 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodeResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodeResponse.java @@ -5,7 +5,9 @@ package org.opensearch.ml.common.transport.undeploy; -import lombok.Getter; +import java.io.IOException; +import java.util.Map; + import org.opensearch.action.support.nodes.BaseNodeResponse; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.core.common.io.stream.StreamInput; @@ -14,8 +16,7 @@ import org.opensearch.core.xcontent.ToXContentFragment; import org.opensearch.core.xcontent.XContentBuilder; -import java.io.IOException; -import java.util.Map; +import lombok.Getter; @Getter public class MLUndeployModelNodeResponse extends BaseNodeResponse implements ToXContentFragment { @@ -26,9 +27,10 @@ public class MLUndeployModelNodeResponse extends BaseNodeResponse implements ToX // This is to record before undeploy the model, which nodes are working nodes. private Map modelWorkerNodeBeforeRemoval; - public MLUndeployModelNodeResponse(DiscoveryNode node, - Map modelUndeployStatus, - Map modelWorkerNodeBeforeRemoval + public MLUndeployModelNodeResponse( + DiscoveryNode node, + Map modelUndeployStatus, + Map modelWorkerNodeBeforeRemoval ) { super(node); this.modelUndeployStatus = modelUndeployStatus; @@ -39,10 +41,10 @@ public MLUndeployModelNodeResponse(DiscoveryNode node, public MLUndeployModelNodeResponse(StreamInput in) throws IOException { super(in); if (in.readBoolean()) { - this.modelUndeployStatus = in.readMap(s -> s.readString(), s-> s.readString()); + this.modelUndeployStatus = in.readMap(s -> s.readString(), s -> s.readString()); } if (in.readBoolean()) { - this.modelWorkerNodeBeforeRemoval = in.readMap(s -> s.readString(), s-> s.readOptionalStringArray()); + this.modelWorkerNodeBeforeRemoval = in.readMap(s -> s.readString(), s -> s.readOptionalStringArray()); } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodesRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodesRequest.java index cea0d484fe..48b2bf7c5c 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodesRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodesRequest.java @@ -5,13 +5,14 @@ package org.opensearch.ml.common.transport.undeploy; -import lombok.Getter; +import java.io.IOException; + import org.opensearch.action.support.nodes.BaseNodesRequest; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import java.io.IOException; +import lombok.Getter; public class MLUndeployModelNodesRequest extends BaseNodesRequest { diff --git a/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodesResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodesResponse.java index 3728f4dd8e..22976eebf5 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodesResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodesResponse.java @@ -5,6 +5,9 @@ package org.opensearch.ml.common.transport.undeploy; +import java.io.IOException; +import java.util.List; + import org.opensearch.action.FailedNodeException; import org.opensearch.action.support.nodes.BaseNodesResponse; import org.opensearch.cluster.ClusterName; @@ -15,16 +18,17 @@ import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; -import java.io.IOException; -import java.util.List; - public class MLUndeployModelNodesResponse extends BaseNodesResponse implements ToXContentObject { public MLUndeployModelNodesResponse(StreamInput in) throws IOException { super(new ClusterName(in), in.readList(MLUndeployModelNodeResponse::readStats), in.readList(FailedNodeException::new)); } - public MLUndeployModelNodesResponse(ClusterName clusterName, List nodes, List failures) { + public MLUndeployModelNodesResponse( + ClusterName clusterName, + List nodes, + List failures + ) { super(clusterName, nodes, failures); } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelsRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelsRequest.java index e15987d753..32fdfced27 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelsRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelsRequest.java @@ -5,11 +5,15 @@ package org.opensearch.ml.common.transport.undeploy; -import lombok.AccessLevel; -import lombok.Builder; -import lombok.Getter; -import lombok.ToString; -import lombok.experimental.FieldDefaults; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.ArrayList; +import java.util.List; + import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.core.common.io.stream.InputStreamStreamInput; @@ -19,14 +23,11 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.transport.MLTaskRequest; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; -import java.util.ArrayList; -import java.util.List; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; +import lombok.experimental.FieldDefaults; @Getter @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) @@ -108,8 +109,7 @@ public static MLUndeployModelsRequest fromActionRequest(ActionRequest actionRequ return (MLUndeployModelsRequest) actionRequest; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionRequest.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLUndeployModelsRequest(input); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelsResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelsResponse.java index 7534b52187..0b3930ecca 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelsResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelsResponse.java @@ -5,14 +5,15 @@ package org.opensearch.ml.common.transport.undeploy; -import lombok.Getter; +import java.io.IOException; + import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; -import java.io.IOException; +import lombok.Getter; @Getter public class MLUndeployModelsResponse extends ActionResponse implements ToXContentObject { diff --git a/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaAction.java b/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaAction.java index 3ee8b66805..3fdd8bb09e 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaAction.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaAction.java @@ -15,4 +15,4 @@ private MLRegisterModelMetaAction() { super(NAME, MLRegisterModelMetaResponse::new); } -} \ No newline at end of file +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInput.java b/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInput.java index ecb03d9bb6..186a871406 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInput.java @@ -5,8 +5,13 @@ package org.opensearch.ml.common.transport.upload_chunk; -import lombok.Builder; -import lombok.Data; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; + import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; @@ -14,41 +19,36 @@ import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.AccessMode; +import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.model.MLModelConfig; import org.opensearch.ml.common.model.MLModelFormat; import org.opensearch.ml.common.model.MLModelState; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.Locale; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import lombok.Builder; +import lombok.Data; @Data -public class MLRegisterModelMetaInput implements ToXContentObject, Writeable{ +public class MLRegisterModelMetaInput implements ToXContentObject, Writeable { public static final String FUNCTION_NAME_FIELD = "function_name"; - public static final String MODEL_NAME_FIELD = "name"; //mandatory - public static final String DESCRIPTION_FIELD = "description"; //optional + public static final String MODEL_NAME_FIELD = "name"; // mandatory + public static final String DESCRIPTION_FIELD = "description"; // optional public static final String VERSION_FIELD = "version"; - public static final String MODEL_FORMAT_FIELD = "model_format"; //mandatory + public static final String MODEL_FORMAT_FIELD = "model_format"; // mandatory public static final String MODEL_STATE_FIELD = "model_state"; public static final String MODEL_CONTENT_SIZE_IN_BYTES_FIELD = "model_content_size_in_bytes"; - public static final String MODEL_CONTENT_HASH_VALUE_FIELD = "model_content_hash_value"; //mandatory - public static final String MODEL_CONFIG_FIELD = "model_config"; //mandatory - public static final String TOTAL_CHUNKS_FIELD = "total_chunks"; //mandatory - public static final String MODEL_GROUP_ID_FIELD = "model_group_id"; //optional - public static final String BACKEND_ROLES_FIELD = "backend_roles"; //optional - public static final String ACCESS_MODE = "access_mode"; //optional - public static final String ADD_ALL_BACKEND_ROLES = "add_all_backend_roles"; //optional + public static final String MODEL_CONTENT_HASH_VALUE_FIELD = "model_content_hash_value"; // mandatory + public static final String MODEL_CONFIG_FIELD = "model_config"; // mandatory + public static final String TOTAL_CHUNKS_FIELD = "total_chunks"; // mandatory + public static final String MODEL_GROUP_ID_FIELD = "model_group_id"; // optional + public static final String BACKEND_ROLES_FIELD = "backend_roles"; // optional + public static final String ACCESS_MODE = "access_mode"; // optional + public static final String ADD_ALL_BACKEND_ROLES = "add_all_backend_roles"; // optional public static final String DOES_VERSION_CREATE_MODEL_GROUP = "does_version_create_model_group"; - private FunctionName functionName; private String name; @@ -70,10 +70,23 @@ public class MLRegisterModelMetaInput implements ToXContentObject, Writeable{ private Boolean doesVersionCreateModelGroup; @Builder(toBuilder = true) - public MLRegisterModelMetaInput(String name, FunctionName functionName, String modelGroupId, String version, String description, MLModelFormat modelFormat, MLModelState modelState, Long modelContentSizeInBytes, String modelContentHashValue, MLModelConfig modelConfig, Integer totalChunks, List backendRoles, - AccessMode accessMode, - Boolean isAddAllBackendRoles, - Boolean doesVersionCreateModelGroup) { + public MLRegisterModelMetaInput( + String name, + FunctionName functionName, + String modelGroupId, + String version, + String description, + MLModelFormat modelFormat, + MLModelState modelState, + Long modelContentSizeInBytes, + String modelContentHashValue, + MLModelConfig modelConfig, + Integer totalChunks, + List backendRoles, + AccessMode accessMode, + Boolean isAddAllBackendRoles, + Boolean doesVersionCreateModelGroup + ) { if (name == null) { throw new IllegalArgumentException("model name is null"); } @@ -88,7 +101,32 @@ public MLRegisterModelMetaInput(String name, FunctionName functionName, String m if (modelContentHashValue == null) { throw new IllegalArgumentException("model content hash value is null"); } - if (modelConfig == null && functionName != FunctionName.SPARSE_TOKENIZE && functionName != FunctionName.SPARSE_ENCODING) { // The tokenize model doesn't require a model configuration. Currently, we only support one type of sparse model, which is pretrained, and it doesn't necessitate a model configuration. + if (modelConfig == null && functionName != FunctionName.SPARSE_TOKENIZE && functionName != FunctionName.SPARSE_ENCODING) { // The + // tokenize + // model + // doesn't + // require + // a + // model + // configuration. + // Currently, + // we + // only + // support + // one + // type + // of + // sparse + // model, + // which + // is + // pretrained, + // and it + // doesn't + // necessitate + // a + // model + // configuration. throw new IllegalArgumentException("model config is null"); } if (totalChunks == null) { @@ -110,7 +148,7 @@ public MLRegisterModelMetaInput(String name, FunctionName functionName, String m this.doesVersionCreateModelGroup = doesVersionCreateModelGroup; } - public MLRegisterModelMetaInput(StreamInput in) throws IOException{ + public MLRegisterModelMetaInput(StreamInput in) throws IOException { this.name = in.readString(); this.functionName = in.readEnum(FunctionName.class); this.modelGroupId = in.readOptionalString(); @@ -296,7 +334,23 @@ public static MLRegisterModelMetaInput parse(XContentParser parser) throws IOExc break; } } - return new MLRegisterModelMetaInput(name, functionName, modelGroupId, version, description, modelFormat, modelState, modelContentSizeInBytes, modelContentHashValue, modelConfig, totalChunks, backendRoles, accessMode, isAddAllBackendRoles, doesVersionCreateModelGroup); + return new MLRegisterModelMetaInput( + name, + functionName, + modelGroupId, + version, + description, + modelFormat, + modelState, + modelContentSizeInBytes, + modelContentHashValue, + modelConfig, + totalChunks, + backendRoles, + accessMode, + isAddAllBackendRoles, + doesVersionCreateModelGroup + ); } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaRequest.java index dbfc9283fc..19558cfb60 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaRequest.java @@ -5,11 +5,13 @@ package org.opensearch.ml.common.transport.upload_chunk; -import lombok.AccessLevel; -import lombok.Builder; -import lombok.Getter; -import lombok.ToString; -import lombok.experimental.FieldDefaults; +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.core.common.io.stream.InputStreamStreamInput; @@ -17,12 +19,11 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.opensearch.action.ValidateActions.addValidationError; +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; +import lombok.experimental.FieldDefaults; @Getter @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) @@ -62,8 +63,7 @@ public static MLRegisterModelMetaRequest fromActionRequest(ActionRequest actionR return (MLRegisterModelMetaRequest) actionRequest; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionRequest.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLRegisterModelMetaRequest(input); @@ -73,4 +73,4 @@ public static MLRegisterModelMetaRequest fromActionRequest(ActionRequest actionR } } -} \ No newline at end of file +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaResponse.java index 42b734a1a2..62c3ef1b7f 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaResponse.java @@ -5,7 +5,8 @@ package org.opensearch.ml.common.transport.upload_chunk; -import lombok.Getter; +import java.io.IOException; + import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; @@ -13,7 +14,7 @@ import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; -import java.io.IOException; +import lombok.Getter; public class MLRegisterModelMetaResponse extends ActionResponse implements ToXContentObject { @@ -33,7 +34,7 @@ public MLRegisterModelMetaResponse(StreamInput in) throws IOException { public MLRegisterModelMetaResponse(String modelId, String status) { this.modelId = modelId; - this.status= status; + this.status = status; } @Override diff --git a/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLUploadModelChunkAction.java b/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLUploadModelChunkAction.java index e6337f1347..1658bb6483 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLUploadModelChunkAction.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLUploadModelChunkAction.java @@ -5,7 +5,6 @@ package org.opensearch.ml.common.transport.upload_chunk; - import org.opensearch.action.ActionType; public class MLUploadModelChunkAction extends ActionType { diff --git a/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLUploadModelChunkInput.java b/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLUploadModelChunkInput.java index 256c4b1fe4..8f1392895f 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLUploadModelChunkInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLUploadModelChunkInput.java @@ -5,8 +5,10 @@ package org.opensearch.ml.common.transport.upload_chunk; -import lombok.Builder; -import lombok.Data; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; + import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; @@ -15,9 +17,8 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import java.io.IOException; - -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import lombok.Builder; +import lombok.Data; @Data public class MLUploadModelChunkInput implements ToXContentObject, Writeable { @@ -37,7 +38,6 @@ public MLUploadModelChunkInput(String modelId, Integer chunkNumber, byte[] conte this.chunkNumber = chunkNumber; } - public MLUploadModelChunkInput(StreamInput in) throws IOException { this.modelId = in.readString(); this.chunkNumber = in.readInt(); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLUploadModelChunkRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLUploadModelChunkRequest.java index 253d13c1ed..5edea364aa 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLUploadModelChunkRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLUploadModelChunkRequest.java @@ -5,25 +5,25 @@ package org.opensearch.ml.common.transport.upload_chunk; -import lombok.AccessLevel; -import lombok.Builder; -import lombok.Getter; -import lombok.ToString; -import lombok.experimental.FieldDefaults; +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.UncheckedIOException; + import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.core.common.io.stream.InputStreamStreamInput; import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.core.common.io.stream.InputStreamStreamInput; - -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.UncheckedIOException; -import static org.opensearch.action.ValidateActions.addValidationError; +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; +import lombok.experimental.FieldDefaults; @Getter @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) @@ -63,8 +63,7 @@ public static MLUploadModelChunkRequest fromActionRequest(ActionRequest actionRe return (MLUploadModelChunkRequest) actionRequest; } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionRequest.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { return new MLUploadModelChunkRequest(input); @@ -74,4 +73,4 @@ public static MLUploadModelChunkRequest fromActionRequest(ActionRequest actionRe } } -} \ No newline at end of file +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLUploadModelChunkResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLUploadModelChunkResponse.java index b6a065a1be..de5b1603a2 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLUploadModelChunkResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLUploadModelChunkResponse.java @@ -5,7 +5,8 @@ package org.opensearch.ml.common.transport.upload_chunk; -import lombok.Getter; +import java.io.IOException; + import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; @@ -13,20 +14,20 @@ import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; -import java.io.IOException; +import lombok.Getter; public class MLUploadModelChunkResponse extends ActionResponse implements ToXContentObject { public static final String STATUS_FIELD = "status"; @Getter private String status; - public MLUploadModelChunkResponse (StreamInput in) throws IOException { + public MLUploadModelChunkResponse(StreamInput in) throws IOException { super(in); this.status = in.readString(); } - public MLUploadModelChunkResponse (String status) { - this.status= status; + public MLUploadModelChunkResponse(String status) { + this.status = status; } @Override @@ -42,4 +43,3 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par return builder; } } - diff --git a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java index edbd94b37f..5d1e7dcb33 100644 --- a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java +++ b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java @@ -5,13 +5,6 @@ package org.opensearch.ml.common.utils; -import com.google.gson.Gson; -import com.google.gson.JsonElement; -import com.google.gson.JsonParser; -import org.json.JSONArray; -import org.json.JSONException; -import org.json.JSONObject; - import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.security.AccessController; @@ -21,6 +14,14 @@ import java.util.List; import java.util.Map; +import org.json.JSONArray; +import org.json.JSONException; +import org.json.JSONObject; + +import com.google.gson.Gson; +import com.google.gson.JsonElement; +import com.google.gson.JsonParser; + public class StringUtils { public static final Gson gson; @@ -71,7 +72,7 @@ public static Map getParameterMap(Map parameterObjs) try { AccessController.doPrivileged((PrivilegedExceptionAction) () -> { if (value instanceof String) { - parameters.put(key, (String)value); + parameters.put(key, (String) value); } else { parameters.put(key, gson.toJson(value)); } diff --git a/common/src/test/java/org/opensearch/ml/common/MLCommonsClassLoaderTests.java b/common/src/test/java/org/opensearch/ml/common/MLCommonsClassLoaderTests.java index f8884f11fd..f930390208 100644 --- a/common/src/test/java/org/opensearch/ml/common/MLCommonsClassLoaderTests.java +++ b/common/src/test/java/org/opensearch/ml/common/MLCommonsClassLoaderTests.java @@ -5,14 +5,25 @@ package org.opensearch.ml.common; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.dataset.TextDocsInputDataSet; @@ -29,18 +40,6 @@ import org.opensearch.ml.common.output.execute.samplecalculator.LocalSampleCalculatorOutput; import org.opensearch.search.SearchModule; -import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.List; - - -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertTrue; - public class MLCommonsClassLoaderTests { private SampleAlgoParams params; @@ -75,7 +74,8 @@ public void setUp() throws IOException { @Test public void testClassLoader_SampleAlgoParams() { - SampleAlgoParams sampleAlgoParams = MLCommonsClassLoader.initMLInstance(FunctionName.SAMPLE_ALGO, streamInputForParams, StreamInput.class); + SampleAlgoParams sampleAlgoParams = MLCommonsClassLoader + .initMLInstance(FunctionName.SAMPLE_ALGO, streamInputForParams, StreamInput.class); assertEquals(params.getSampleParam(), sampleAlgoParams.getSampleParam()); } @@ -83,7 +83,7 @@ public void testClassLoader_SampleAlgoParams() { public void testClassLoader_Return_MLAlgoParams() { MLAlgoParams mlAlgoParams = MLCommonsClassLoader.initMLInstance(FunctionName.SAMPLE_ALGO, streamInputForParams, StreamInput.class); assertTrue(mlAlgoParams instanceof SampleAlgoParams); - assertEquals(params.getSampleParam(), ((SampleAlgoParams)mlAlgoParams).getSampleParam()); + assertEquals(params.getSampleParam(), ((SampleAlgoParams) mlAlgoParams).getSampleParam()); } @Test @@ -97,27 +97,30 @@ public void testClassLoader_WrongType() { @Test public void testClassLoader_ExecuteInput() { - LocalSampleCalculatorInput calculatorInput = MLCommonsClassLoader.initExecuteInputInstance(FunctionName.LOCAL_SAMPLE_CALCULATOR, streamInputForInput, StreamInput.class); + LocalSampleCalculatorInput calculatorInput = MLCommonsClassLoader + .initExecuteInputInstance(FunctionName.LOCAL_SAMPLE_CALCULATOR, streamInputForInput, StreamInput.class); assertEquals(this.input, calculatorInput); } @Test public void testClassLoader_ExecuteOutput() { - LocalSampleCalculatorOutput calculatorOutput = MLCommonsClassLoader.initExecuteOutputInstance(FunctionName.LOCAL_SAMPLE_CALCULATOR, streamInputForOutput, StreamInput.class); + LocalSampleCalculatorOutput calculatorOutput = MLCommonsClassLoader + .initExecuteOutputInstance(FunctionName.LOCAL_SAMPLE_CALCULATOR, streamInputForOutput, StreamInput.class); assertEquals(this.output, calculatorOutput); } @Test public void testClassLoader_ExecuteMCorrInput() throws IOException { List inputData = new ArrayList<>(); - inputData.add(new float[]{1.0f, 2.0f, 3.0f, 4.0f}); - inputData.add(new float[]{1.0f, 2.0f, 3.0f, 4.0f}); - inputData.add(new float[]{1.0f, 2.0f, 3.0f, 4.0f}); + inputData.add(new float[] { 1.0f, 2.0f, 3.0f, 4.0f }); + inputData.add(new float[] { 1.0f, 2.0f, 3.0f, 4.0f }); + inputData.add(new float[] { 1.0f, 2.0f, 3.0f, 4.0f }); Input mcorrInput = new MetricsCorrelationInput(inputData); BytesStreamOutput bytesStreamOutputMCorrInput = new BytesStreamOutput(); mcorrInput.writeTo(bytesStreamOutputMCorrInput); StreamInput streamInputForMcorrInput = bytesStreamOutputMCorrInput.bytes().streamInput(); - MetricsCorrelationInput mcorrStreamInput = MLCommonsClassLoader.initExecuteInputInstance(FunctionName.METRICS_CORRELATION, streamInputForMcorrInput, StreamInput.class); + MetricsCorrelationInput mcorrStreamInput = MLCommonsClassLoader + .initExecuteInputInstance(FunctionName.METRICS_CORRELATION, streamInputForMcorrInput, StreamInput.class); assertArrayEquals(((MetricsCorrelationInput) mcorrInput).getInputData().toArray(), mcorrStreamInput.getInputData().toArray()); } @@ -125,11 +128,12 @@ public void testClassLoader_ExecuteMCorrInput() throws IOException { @Test public void testClassLoader_ExecuteOutputMCorr() throws IOException { List outputs = new ArrayList<>(); - MCorrModelTensor mCorrModelTensor = MCorrModelTensor.builder() - .event_pattern(new float[]{1.0f, 2.0f, 3.0f}) - .event_window(new float[]{4.0f, 5.0f, 6.0f}) - .suspected_metrics(new long[]{1, 2}) - .build(); + MCorrModelTensor mCorrModelTensor = MCorrModelTensor + .builder() + .event_pattern(new float[] { 1.0f, 2.0f, 3.0f }) + .event_window(new float[] { 4.0f, 5.0f, 6.0f }) + .suspected_metrics(new long[] { 1, 2 }) + .build(); List mlModelTensors = Arrays.asList(mCorrModelTensor); MCorrModelTensors modelTensors = MCorrModelTensors.builder().mCorrModelTensors(mlModelTensors).build(); outputs.add(modelTensors); @@ -137,7 +141,8 @@ public void testClassLoader_ExecuteOutputMCorr() throws IOException { BytesStreamOutput bytesStreamOutputMcorrOutput = new BytesStreamOutput(); output.writeTo(bytesStreamOutputMcorrOutput); StreamInput streamInputForOutput = bytesStreamOutputMcorrOutput.bytes().streamInput(); - MetricsCorrelationOutput mcorrOutput = MLCommonsClassLoader.initExecuteOutputInstance(FunctionName.METRICS_CORRELATION, streamInputForOutput, StreamInput.class); + MetricsCorrelationOutput mcorrOutput = MLCommonsClassLoader + .initExecuteOutputInstance(FunctionName.METRICS_CORRELATION, streamInputForOutput, StreamInput.class); assertEquals(1, mcorrOutput.getModelOutput().size()); MCorrModelTensors testmodelTensors = mcorrOutput.getModelOutput().get(0); @@ -145,22 +150,29 @@ public void testClassLoader_ExecuteOutputMCorr() throws IOException { MCorrModelTensor testmodelTensor = testmodelTensors.getMCorrModelTensors().get(0); float[] events = testmodelTensor.getEvent_pattern(); long[] metrics = testmodelTensor.getSuspected_metrics(); - assertArrayEquals(new float[]{1.0f, 2.0f, 3.0f}, events, 0.001f); - assertArrayEquals(new long[]{1, 2}, metrics); + assertArrayEquals(new float[] { 1.0f, 2.0f, 3.0f }, events, 0.001f); + assertArrayEquals(new long[] { 1, 2 }, metrics); } private void testClassLoader_MLInput_DlModel(FunctionName functionName) throws IOException { assertTrue(MLCommonsClassLoader.canInitMLInput(functionName)); - String jsonStr = "{\"text_docs\":[\"doc1\",\"doc2\"],\"result_filter\":{\"return_bytes\":true,\"return_number\":true,\"target_response\":[\"field1\"], \"target_response_positions\": [2]}}"; - XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, - Collections.emptyList()).getNamedXContents()), null, jsonStr); + String jsonStr = + "{\"text_docs\":[\"doc1\",\"doc2\"],\"result_filter\":{\"return_bytes\":true,\"return_number\":true,\"target_response\":[\"field1\"], \"target_response_positions\": [2]}}"; + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + jsonStr + ); parser.nextToken(); - TextDocsMLInput mlInput = MLCommonsClassLoader.initMLInput(functionName, new Object[]{parser, functionName}, XContentParser.class, FunctionName.class); + TextDocsMLInput mlInput = MLCommonsClassLoader + .initMLInput(functionName, new Object[] { parser, functionName }, XContentParser.class, FunctionName.class); assertNotNull(mlInput); assertEquals(functionName, mlInput.getFunctionName()); - assertEquals(2, ((TextDocsInputDataSet)mlInput.getInputDataset()).getDocs().size()); + assertEquals(2, ((TextDocsInputDataSet) mlInput.getInputDataset()).getDocs().size()); } @Test diff --git a/common/src/test/java/org/opensearch/ml/common/MLModelGroupTest.java b/common/src/test/java/org/opensearch/ml/common/MLModelGroupTest.java index 71f7f46cf2..c1abe07297 100644 --- a/common/src/test/java/org/opensearch/ml/common/MLModelGroupTest.java +++ b/common/src/test/java/org/opensearch/ml/common/MLModelGroupTest.java @@ -5,6 +5,10 @@ package org.opensearch.ml.common; +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; + import org.junit.Assert; import org.junit.Rule; import org.junit.Test; @@ -19,10 +23,6 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.search.SearchModule; -import java.io.IOException; -import java.util.Arrays; -import java.util.Collections; - public class MLModelGroupTest { @Rule @@ -47,31 +47,41 @@ public void toXContent_Empty() throws IOException { @Test public void toXContent() throws IOException { - MLModelGroup modelGroup = MLModelGroup.builder() - .name("test") - .description("this is test group") - .latestVersion(1) - .backendRoles(Arrays.asList("role1", "role2")) - .owner(new User()) - .access(AccessMode.PUBLIC.name()) - .build(); + MLModelGroup modelGroup = MLModelGroup + .builder() + .name("test") + .description("this is test group") + .latestVersion(1) + .backendRoles(Arrays.asList("role1", "role2")) + .owner(new User()) + .access(AccessMode.PUBLIC.name()) + .build(); XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); modelGroup.toXContent(builder, ToXContent.EMPTY_PARAMS); String content = TestHelper.xContentBuilderToString(builder); - Assert.assertEquals("{\"name\":\"test\",\"latest_version\":1,\"description\":\"this is test group\"," + - "\"backend_roles\":[\"role1\",\"role2\"]," + - "\"owner\":{\"name\":\"\",\"backend_roles\":[],\"roles\":[],\"custom_attribute_names\":[],\"user_requested_tenant\":null}," + - "\"access\":\"PUBLIC\"}", content); + Assert + .assertEquals( + "{\"name\":\"test\",\"latest_version\":1,\"description\":\"this is test group\"," + + "\"backend_roles\":[\"role1\",\"role2\"]," + + "\"owner\":{\"name\":\"\",\"backend_roles\":[],\"roles\":[],\"custom_attribute_names\":[],\"user_requested_tenant\":null}," + + "\"access\":\"PUBLIC\"}", + content + ); } @Test public void parse() throws IOException { - String jsonStr = "{\"name\":\"test\",\"latest_version\":1,\"description\":\"this is test group\"," + - "\"backend_roles\":[\"role1\",\"role2\"]," + - "\"owner\":{\"name\":\"\",\"backend_roles\":[],\"roles\":[],\"custom_attribute_names\":[],\"user_requested_tenant\":null}," + - "\"access\":\"PUBLIC\"}"; - XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, - Collections.emptyList()).getNamedXContents()), null, jsonStr); + String jsonStr = "{\"name\":\"test\",\"latest_version\":1,\"description\":\"this is test group\"," + + "\"backend_roles\":[\"role1\",\"role2\"]," + + "\"owner\":{\"name\":\"\",\"backend_roles\":[],\"roles\":[],\"custom_attribute_names\":[],\"user_requested_tenant\":null}," + + "\"access\":\"PUBLIC\"}"; + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + jsonStr + ); parser.nextToken(); MLModelGroup modelGroup = MLModelGroup.parse(parser); Assert.assertEquals("test", modelGroup.getName()); @@ -85,8 +95,13 @@ public void parse() throws IOException { @Test public void parse_Empty() throws IOException { String jsonStr = "{\"name\":\"test\"}"; - XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, - Collections.emptyList()).getNamedXContents()), null, jsonStr); + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + jsonStr + ); parser.nextToken(); MLModelGroup modelGroup = MLModelGroup.parse(parser); Assert.assertEquals("test", modelGroup.getName()); @@ -97,14 +112,15 @@ public void parse_Empty() throws IOException { @Test public void writeTo() throws IOException { - MLModelGroup originalModelGroup = MLModelGroup.builder() - .name("test") - .description("this is test group") - .latestVersion(1) - .backendRoles(Arrays.asList("role1", "role2")) - .owner(new User()) - .access(AccessMode.PUBLIC.name()) - .build(); + MLModelGroup originalModelGroup = MLModelGroup + .builder() + .name("test") + .description("this is test group") + .latestVersion(1) + .backendRoles(Arrays.asList("role1", "role2")) + .owner(new User()) + .access(AccessMode.PUBLIC.name()) + .build(); BytesStreamOutput output = new BytesStreamOutput(); originalModelGroup.writeTo(output); diff --git a/common/src/test/java/org/opensearch/ml/common/MLModelTests.java b/common/src/test/java/org/opensearch/ml/common/MLModelTests.java index b59fa2ac2a..45b6f9ce7e 100644 --- a/common/src/test/java/org/opensearch/ml/common/MLModelTests.java +++ b/common/src/test/java/org/opensearch/ml/common/MLModelTests.java @@ -5,61 +5,64 @@ package org.opensearch.ml.common; +import static org.junit.Assert.assertEquals; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; + +import java.io.IOException; +import java.time.Instant; +import java.util.function.Function; + import org.junit.Before; import org.junit.Test; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.common.xcontent.XContentType; import org.opensearch.commons.authuser.User; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.model.MLModelFormat; import org.opensearch.ml.common.model.MLModelState; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; -import java.io.IOException; -import java.time.Instant; -import java.util.function.Function; - -import static org.junit.Assert.assertEquals; -import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; - public class MLModelTests { MLModel mlModel; TextEmbeddingModelConfig config; Function function; + @Before public void setUp() { FunctionName algorithm = FunctionName.KMEANS; - User user = new User(); - config = TextEmbeddingModelConfig.builder() - .modelType("testModelType") - .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") - .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) - .embeddingDimension(100) - .build(); + User user = new User(); + config = TextEmbeddingModelConfig + .builder() + .modelType("testModelType") + .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") + .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) + .embeddingDimension(100) + .build(); Instant now = Instant.now(); - mlModel = MLModel.builder() - .name("some model") - .algorithm(algorithm) - .version("1.0.0") - .content("some content") - .user(user) - .description("test description") - .modelFormat(MLModelFormat.ONNX) - .modelState(MLModelState.DEPLOYED) - .modelContentSizeInBytes(10_000_000l) - .modelContentHash("test_hash") - .modelConfig(config) - .createdTime(now) - .lastRegisteredTime(now) - .lastDeployedTime(now) - .lastUndeployedTime(now) - .modelId("model_id") - .chunkNumber(1) - .totalChunks(10) - .build(); + mlModel = MLModel + .builder() + .name("some model") + .algorithm(algorithm) + .version("1.0.0") + .content("some content") + .user(user) + .description("test description") + .modelFormat(MLModelFormat.ONNX) + .modelState(MLModelState.DEPLOYED) + .modelContentSizeInBytes(10_000_000l) + .modelContentHash("test_hash") + .modelConfig(config) + .createdTime(now) + .lastRegisteredTime(now) + .lastDeployedTime(now) + .lastUndeployedTime(now) + .modelId("model_id") + .chunkNumber(1) + .totalChunks(10) + .build(); function = parser -> { try { return MLModel.parse(parser, algorithm.name()); @@ -71,11 +74,20 @@ public void setUp() { @Test public void toXContent() throws IOException { - MLModel mlModel = MLModel.builder().algorithm(FunctionName.KMEANS).name("model_name").version("1.0.0").content("test_content").build(); + MLModel mlModel = MLModel + .builder() + .algorithm(FunctionName.KMEANS) + .name("model_name") + .version("1.0.0") + .content("test_content") + .build(); XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); mlModel.toXContent(builder, EMPTY_PARAMS); String mlModelContent = TestHelper.xContentBuilderToString(builder); - assertEquals("{\"name\":\"model_name\",\"algorithm\":\"KMEANS\",\"model_version\":\"1.0.0\",\"model_content\":\"test_content\"}", mlModelContent); + assertEquals( + "{\"name\":\"model_name\",\"algorithm\":\"KMEANS\",\"model_version\":\"1.0.0\",\"model_content\":\"test_content\"}", + mlModelContent + ); } @Test diff --git a/common/src/test/java/org/opensearch/ml/common/MLTaskTests.java b/common/src/test/java/org/opensearch/ml/common/MLTaskTests.java index 39d391bf33..2ffdc32679 100644 --- a/common/src/test/java/org/opensearch/ml/common/MLTaskTests.java +++ b/common/src/test/java/org/opensearch/ml/common/MLTaskTests.java @@ -19,10 +19,6 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.dataset.MLInputDataType; -import java.io.IOException; -import java.time.Instant; -import java.time.temporal.ChronoUnit; - public class MLTaskTests { private MLTask mlTask; @@ -59,13 +55,14 @@ public void toXContent() throws IOException { XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); mlTask.toXContent(builder, ToXContent.EMPTY_PARAMS); String taskContent = TestHelper.xContentBuilderToString(builder); - Assert.assertEquals( - "{\"task_id\":\"dummy taskId\",\"model_id\":\"test_model_id\",\"task_type\":\"PREDICTION\"," - + "\"function_name\":\"KMEANS\",\"state\":\"RUNNING\",\"input_type\":\"DATA_FRAME\",\"progress\":0.0," - + "\"output_index\":\"test_index\",\"worker_node\":[\"node1\"],\"create_time\":1641599940000," - + "\"last_update_time\":1641600000000,\"error\":\"test_error\",\"is_async\":false}", - taskContent - ); + Assert + .assertEquals( + "{\"task_id\":\"dummy taskId\",\"model_id\":\"test_model_id\",\"task_type\":\"PREDICTION\"," + + "\"function_name\":\"KMEANS\",\"state\":\"RUNNING\",\"input_type\":\"DATA_FRAME\",\"progress\":0.0," + + "\"output_index\":\"test_index\",\"worker_node\":[\"node1\"],\"create_time\":1641599940000," + + "\"last_update_time\":1641600000000,\"error\":\"test_error\",\"is_async\":false}", + taskContent + ); } @Test diff --git a/common/src/test/java/org/opensearch/ml/common/RemoteModelTests.java b/common/src/test/java/org/opensearch/ml/common/RemoteModelTests.java index 263c1b0f31..e9c4328320 100644 --- a/common/src/test/java/org/opensearch/ml/common/RemoteModelTests.java +++ b/common/src/test/java/org/opensearch/ml/common/RemoteModelTests.java @@ -5,6 +5,11 @@ package org.opensearch.ml.common; +import static org.junit.Assert.assertEquals; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; + +import java.io.IOException; + import org.junit.Assert; import org.junit.Test; import org.opensearch.common.io.stream.BytesStreamOutput; @@ -17,72 +22,78 @@ import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.connector.HttpConnectorTest; -import java.io.IOException; - -import static org.junit.Assert.assertEquals; -import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; - public class RemoteModelTests { @Test public void toXContent_ConnectorId() throws IOException { - MLModel mlModel = MLModel.builder() - .algorithm(FunctionName.REMOTE) - .name("test_model_name") - .version("1.0.0") - .modelGroupId("test_group_id") - .description("test model") - .connectorId("test_connector_id") - .build(); + MLModel mlModel = MLModel + .builder() + .algorithm(FunctionName.REMOTE) + .name("test_model_name") + .version("1.0.0") + .modelGroupId("test_group_id") + .description("test model") + .connectorId("test_connector_id") + .build(); XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); mlModel.toXContent(builder, EMPTY_PARAMS); String mlModelContent = TestHelper.xContentBuilderToString(builder); - assertEquals("{\"name\":\"test_model_name\",\"model_group_id\":\"test_group_id\",\"algorithm\":\"REMOTE\"" + - ",\"model_version\":\"1.0.0\",\"description\":\"test model\"," + - "\"connector_id\":\"test_connector_id\"}", mlModelContent); + assertEquals( + "{\"name\":\"test_model_name\",\"model_group_id\":\"test_group_id\",\"algorithm\":\"REMOTE\"" + + ",\"model_version\":\"1.0.0\",\"description\":\"test model\"," + + "\"connector_id\":\"test_connector_id\"}", + mlModelContent + ); } @Test public void toXContent_InternalConnector() throws IOException { Connector connector = HttpConnectorTest.createHttpConnector(); - MLModel mlModel = MLModel.builder() - .algorithm(FunctionName.REMOTE) - .name("test_model_name") - .version("1.0.0") - .modelGroupId("test_group_id") - .description("test model") - .connector(connector) - .build(); + MLModel mlModel = MLModel + .builder() + .algorithm(FunctionName.REMOTE) + .name("test_model_name") + .version("1.0.0") + .modelGroupId("test_group_id") + .description("test model") + .connector(connector) + .build(); XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); mlModel.toXContent(builder, EMPTY_PARAMS); String mlModelContent = TestHelper.xContentBuilderToString(builder); - assertEquals("{\"name\":\"test_model_name\",\"model_group_id\":\"test_group_id\",\"algorithm\":\"REMOTE\"," + - "\"model_version\":\"1.0.0\",\"description\":\"test model\",\"connector\":{\"name\":\"test_connector_name\"," + - "\"version\":\"1\",\"description\":\"this is a test connector\",\"protocol\":\"http\"," + - "\"parameters\":{\"input\":\"test input value\"},\"credential\":{\"key\":\"test_key_value\"}," + - "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\"," + - "\"headers\":{\"api_key\":\"${credential.key}\"}," + - "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," + - "\"pre_process_function\":\"connector.pre_process.openai.embedding\"," + - "\"post_process_function\":\"connector.post_process.openai.embedding\"}]," + - "\"backend_roles\":[\"role1\",\"role2\"]," + - "\"access\":\"public\"}}", mlModelContent); + assertEquals( + "{\"name\":\"test_model_name\",\"model_group_id\":\"test_group_id\",\"algorithm\":\"REMOTE\"," + + "\"model_version\":\"1.0.0\",\"description\":\"test model\",\"connector\":{\"name\":\"test_connector_name\"," + + "\"version\":\"1\",\"description\":\"this is a test connector\",\"protocol\":\"http\"," + + "\"parameters\":{\"input\":\"test input value\"},\"credential\":{\"key\":\"test_key_value\"}," + + "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\"," + + "\"headers\":{\"api_key\":\"${credential.key}\"}," + + "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," + + "\"pre_process_function\":\"connector.pre_process.openai.embedding\"," + + "\"post_process_function\":\"connector.post_process.openai.embedding\"}]," + + "\"backend_roles\":[\"role1\",\"role2\"]," + + "\"access\":\"public\"}}", + mlModelContent + ); } @Test public void parse_ConnectorId() throws IOException { - MLModel mlModel = MLModel.builder() - .algorithm(FunctionName.REMOTE) - .name("test_model_name") - .version("1.0.0") - .modelGroupId("test_group_id") - .description("test model") - .connectorId("test_connector_id") - .build(); + MLModel mlModel = MLModel + .builder() + .algorithm(FunctionName.REMOTE) + .name("test_model_name") + .version("1.0.0") + .modelGroupId("test_group_id") + .description("test model") + .connectorId("test_connector_id") + .build(); XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); mlModel.toXContent(builder, EMPTY_PARAMS); String jsonStr = TestHelper.xContentBuilderToString(builder); - XContentParser parser = XContentType.JSON.xContent().createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonStr); + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonStr); parser.nextToken(); MLModel parsedModel = MLModel.parse(parser, FunctionName.REMOTE.name()); Assert.assertNull(parsedModel.getConnector()); @@ -92,49 +103,53 @@ public void parse_ConnectorId() throws IOException { @Test public void parse_InternalConnector() throws IOException { Connector connector = HttpConnectorTest.createHttpConnector(); - MLModel mlModel = MLModel.builder() - .algorithm(FunctionName.REMOTE) - .name("test_model_name") - .version("1.0.0") - .modelGroupId("test_group_id") - .description("test model") - .connector(connector) - .build(); + MLModel mlModel = MLModel + .builder() + .algorithm(FunctionName.REMOTE) + .name("test_model_name") + .version("1.0.0") + .modelGroupId("test_group_id") + .description("test model") + .connector(connector) + .build(); XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); mlModel.toXContent(builder, EMPTY_PARAMS); String jsonStr = TestHelper.xContentBuilderToString(builder); - XContentParser parser = XContentType.JSON.xContent().createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonStr); + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonStr); parser.nextToken(); MLModel parsedModel = MLModel.parse(parser, FunctionName.REMOTE.name()); Assert.assertEquals(mlModel.getConnector(), parsedModel.getConnector()); } - @Test public void readInputStream_ConnectorId() throws IOException { - MLModel mlModel = MLModel.builder() - .algorithm(FunctionName.REMOTE) - .name("test_model_name") - .version("1.0.0") - .modelGroupId("test_group_id") - .description("test model") - .connectorId("test_connector_id") - .build(); + MLModel mlModel = MLModel + .builder() + .algorithm(FunctionName.REMOTE) + .name("test_model_name") + .version("1.0.0") + .modelGroupId("test_group_id") + .description("test model") + .connectorId("test_connector_id") + .build(); readInputStream(mlModel); } @Test public void readInputStream_InternalConnector() throws IOException { Connector connector = HttpConnectorTest.createHttpConnector(); - MLModel mlModel = MLModel.builder() - .algorithm(FunctionName.REMOTE) - .name("test_model_name") - .version("1.0.0") - .modelGroupId("test_group_id") - .description("test model") - .connector(connector) - .build(); + MLModel mlModel = MLModel + .builder() + .algorithm(FunctionName.REMOTE) + .name("test_model_name") + .version("1.0.0") + .modelGroupId("test_group_id") + .description("test model") + .connector(connector) + .build(); readInputStream(mlModel); } diff --git a/common/src/test/java/org/opensearch/ml/common/TestHelper.java b/common/src/test/java/org/opensearch/ml/common/TestHelper.java index baa40f9102..723ca70a75 100644 --- a/common/src/test/java/org/opensearch/ml/common/TestHelper.java +++ b/common/src/test/java/org/opensearch/ml/common/TestHelper.java @@ -5,11 +5,12 @@ package org.opensearch.ml.common; -import org.opensearch.core.common.Strings; -import org.opensearch.core.common.bytes.BytesReference; +import java.io.IOException; +import java.util.function.Function; + import org.opensearch.common.xcontent.LoggingDeprecationHandler; -import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContent; @@ -17,9 +18,6 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import java.io.IOException; -import java.util.function.Function; - public class TestHelper { public static void testParse(ToXContentObject obj, Function function) throws IOException { @@ -27,7 +25,7 @@ public static void testParse(ToXContentObject obj, Function void testParse(ToXContentObject obj, Function function, boolean wrapWithObject) - throws IOException { + throws IOException { XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); if (wrapWithObject) { builder.startObject(); @@ -40,10 +38,11 @@ public static void testParse(ToXContentObject obj, Function void testParseFromString(ToXContentObject obj, String jsonStr, - Function function) throws IOException { - XContentParser parser = XContentType.JSON.xContent().createParser(NamedXContentRegistry.EMPTY, - LoggingDeprecationHandler.INSTANCE, jsonStr); + public static void testParseFromString(ToXContentObject obj, String jsonStr, Function function) + throws IOException { + XContentParser parser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, jsonStr); parser.nextToken(); T parsedObj = function.apply(parser); obj.equals(parsedObj); diff --git a/common/src/test/java/org/opensearch/ml/common/connector/AwsConnectorTest.java b/common/src/test/java/org/opensearch/ml/common/connector/AwsConnectorTest.java index a242c213ea..518559e067 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/AwsConnectorTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/AwsConnectorTest.java @@ -5,6 +5,20 @@ package org.opensearch.ml.common.connector; +import static org.opensearch.ml.common.connector.AbstractConnector.ACCESS_KEY_FIELD; +import static org.opensearch.ml.common.connector.AbstractConnector.SECRET_KEY_FIELD; +import static org.opensearch.ml.common.connector.AbstractConnector.SESSION_TOKEN_FIELD; +import static org.opensearch.ml.common.connector.HttpConnector.REGION_FIELD; +import static org.opensearch.ml.common.connector.HttpConnector.SERVICE_NAME_FIELD; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.Locale; +import java.util.Map; +import java.util.function.Function; + import org.junit.Assert; import org.junit.Before; import org.junit.Rule; @@ -20,20 +34,6 @@ import org.opensearch.ml.common.TestHelper; import org.opensearch.search.SearchModule; -import java.io.IOException; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashMap; -import java.util.Locale; -import java.util.Map; -import java.util.function.Function; - -import static org.opensearch.ml.common.connector.AbstractConnector.ACCESS_KEY_FIELD; -import static org.opensearch.ml.common.connector.AbstractConnector.SECRET_KEY_FIELD; -import static org.opensearch.ml.common.connector.AbstractConnector.SESSION_TOKEN_FIELD; -import static org.opensearch.ml.common.connector.HttpConnector.REGION_FIELD; -import static org.opensearch.ml.common.connector.HttpConnector.SERVICE_NAME_FIELD; - public class AwsConnectorTest { @Rule public ExpectedException exceptionRule = ExpectedException.none(); @@ -43,8 +43,8 @@ public class AwsConnectorTest { @Before public void setUp() { - encryptFunction = s -> "encrypted: "+s.toLowerCase(Locale.ROOT); - decryptFunction = s -> "decrypted: "+s.toUpperCase(Locale.ROOT); + encryptFunction = s -> "encrypted: " + s.toLowerCase(Locale.ROOT); + decryptFunction = s -> "decrypted: " + s.toUpperCase(Locale.ROOT); } @Test @@ -106,7 +106,12 @@ public void constructor_NoPredictAction() { credential.put(REGION_FIELD, "test_region"); Map parameters = new HashMap<>(); parameters.put(SERVICE_NAME_FIELD, "test_service"); - AwsConnector connector = AwsConnector.awsConnectorBuilder().protocol(ConnectorProtocols.AWS_SIGV4).credential(credential).parameters(parameters).build(); + AwsConnector connector = AwsConnector + .awsConnectorBuilder() + .protocol(ConnectorProtocols.AWS_SIGV4) + .credential(credential) + .parameters(parameters) + .build(); Assert.assertNotNull(connector); connector.encrypt(encryptFunction); @@ -125,8 +130,13 @@ public void constructor_Parser() throws IOException { awsConnector.toXContent(builder, ToXContent.EMPTY_PARAMS); String jsonStr = TestHelper.xContentBuilderToString(builder); - XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, - Collections.emptyList()).getNamedXContents()), null, jsonStr); + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + jsonStr + ); parser.nextToken(); AwsConnector connector = new AwsConnector(awsConnector.getProtocol(), parser); @@ -209,19 +219,28 @@ private AwsConnector createAwsConnector(Map parameters, Map "encrypted: "+s.toLowerCase(Locale.ROOT); - decryptFunction = s -> "decrypted: "+s.toUpperCase(Locale.ROOT); + encryptFunction = s -> "encrypted: " + s.toLowerCase(Locale.ROOT); + decryptFunction = s -> "decrypted: " + s.toUpperCase(Locale.ROOT); } @Test @@ -71,33 +71,41 @@ public void toXContent() throws IOException { XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); connector.toXContent(builder, ToXContent.EMPTY_PARAMS); String content = TestHelper.xContentBuilderToString(builder); - Assert.assertEquals("{\"name\":\"test_connector_name\",\"version\":\"1\",\"description\":\"this is a test connector\"," + - "\"protocol\":\"http\"," + - "\"parameters\":{\"input\":\"test input value\"}," + - "\"credential\":{\"key\":\"test_key_value\"}," + - "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\"," + - "\"headers\":{\"api_key\":\"${credential.key}\"},\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," + - "\"pre_process_function\":\"connector.pre_process.openai.embedding\"," + - "\"post_process_function\":\"connector.post_process.openai.embedding\"}]," + - "\"backend_roles\":[\"role1\",\"role2\"]," + - "\"access\":\"public\"}", content); + Assert + .assertEquals( + "{\"name\":\"test_connector_name\",\"version\":\"1\",\"description\":\"this is a test connector\"," + + "\"protocol\":\"http\"," + + "\"parameters\":{\"input\":\"test input value\"}," + + "\"credential\":{\"key\":\"test_key_value\"}," + + "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\"," + + "\"headers\":{\"api_key\":\"${credential.key}\"},\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," + + "\"pre_process_function\":\"connector.pre_process.openai.embedding\"," + + "\"post_process_function\":\"connector.post_process.openai.embedding\"}]," + + "\"backend_roles\":[\"role1\",\"role2\"]," + + "\"access\":\"public\"}", + content + ); } - @Test public void constructor_Parser() throws IOException { - String jsonStr = "{\"name\":\"test_connector_name\",\"version\":\"1\",\"description\":\"this is a test connector\"," + - "\"protocol\":\"http\"," + - "\"parameters\":{\"input\":\"test input value\"}," + - "\"credential\":{\"key\":\"test_key_value\"}," + - "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\"," + - "\"headers\":{\"api_key\":\"${credential.key}\"},\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," + - "\"pre_process_function\":\"connector.pre_process.openai.embedding\"," + - "\"post_process_function\":\"connector.post_process.openai.embedding\"}]," + - "\"backend_roles\":[\"role1\",\"role2\"]," + - "\"access\":\"public\"}"; - XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, - Collections.emptyList()).getNamedXContents()), null, jsonStr); + String jsonStr = "{\"name\":\"test_connector_name\",\"version\":\"1\",\"description\":\"this is a test connector\"," + + "\"protocol\":\"http\"," + + "\"parameters\":{\"input\":\"test input value\"}," + + "\"credential\":{\"key\":\"test_key_value\"}," + + "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\"," + + "\"headers\":{\"api_key\":\"${credential.key}\"},\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," + + "\"pre_process_function\":\"connector.pre_process.openai.embedding\"," + + "\"post_process_function\":\"connector.post_process.openai.embedding\"}]," + + "\"backend_roles\":[\"role1\",\"role2\"]," + + "\"access\":\"public\"}"; + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + jsonStr + ); parser.nextToken(); HttpConnector connector = new HttpConnector("http", parser); @@ -277,7 +285,15 @@ public static HttpConnector createHttpConnector() { String preProcessFunction = MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT; String postProcessFunction = MLPostProcessFunction.OPENAI_EMBEDDING; - ConnectorAction action = new ConnectorAction(actionType, method, url, headers, requestBody, preProcessFunction, postProcessFunction); + ConnectorAction action = new ConnectorAction( + actionType, + method, + url, + headers, + requestBody, + preProcessFunction, + postProcessFunction + ); Map parameters = new HashMap<>(); parameters.put("input", "test input value"); @@ -285,17 +301,18 @@ public static HttpConnector createHttpConnector() { Map credential = new HashMap<>(); credential.put("key", "test_key_value"); - HttpConnector connector = HttpConnector.builder() - .name("test_connector_name") - .description("this is a test connector") - .version("1") - .protocol("http") - .parameters(parameters) - .credential(credential) - .actions(Arrays.asList(action)) - .backendRoles(Arrays.asList("role1", "role2")) - .accessMode(AccessMode.PUBLIC) - .build(); + HttpConnector connector = HttpConnector + .builder() + .name("test_connector_name") + .description("this is a test connector") + .version("1") + .protocol("http") + .parameters(parameters) + .credential(credential) + .actions(Arrays.asList(action)) + .backendRoles(Arrays.asList("role1", "role2")) + .accessMode(AccessMode.PUBLIC) + .build(); return connector; } diff --git a/common/src/test/java/org/opensearch/ml/common/connector/MLPostProcessFunctionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/MLPostProcessFunctionTest.java index 5d4c0c88d7..292ddc05d8 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/MLPostProcessFunctionTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/MLPostProcessFunctionTest.java @@ -5,16 +5,16 @@ package org.opensearch.ml.common.connector; -import org.junit.Assert; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.ExpectedException; +import static org.opensearch.ml.common.connector.MLPostProcessFunction.OPENAI_EMBEDDING; import java.util.ArrayList; import java.util.Collections; import java.util.List; -import static org.opensearch.ml.common.connector.MLPostProcessFunction.OPENAI_EMBEDDING; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; public class MLPostProcessFunctionTest { diff --git a/common/src/test/java/org/opensearch/ml/common/connector/MLPreProcessFunctionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/MLPreProcessFunctionTest.java index b3784c1c1c..dfcd2a41a8 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/MLPreProcessFunctionTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/MLPreProcessFunctionTest.java @@ -5,11 +5,11 @@ package org.opensearch.ml.common.connector; +import static org.opensearch.ml.common.connector.MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT; + import org.junit.Assert; import org.junit.Test; -import static org.opensearch.ml.common.connector.MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT; - public class MLPreProcessFunctionTest { @Test diff --git a/common/src/test/java/org/opensearch/ml/common/dataframe/BooleanValueTest.java b/common/src/test/java/org/opensearch/ml/common/dataframe/BooleanValueTest.java index d1597ada6a..4e0e83c3bf 100644 --- a/common/src/test/java/org/opensearch/ml/common/dataframe/BooleanValueTest.java +++ b/common/src/test/java/org/opensearch/ml/common/dataframe/BooleanValueTest.java @@ -5,17 +5,17 @@ package org.opensearch.ml.common.dataframe; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +import java.io.IOException; + import org.junit.Test; import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.core.xcontent.XContentBuilder; -import java.io.IOException; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertTrue; - public class BooleanValueTest { @Test public void booleanValue() { diff --git a/common/src/test/java/org/opensearch/ml/common/dataframe/ColumnMetaTest.java b/common/src/test/java/org/opensearch/ml/common/dataframe/ColumnMetaTest.java index 95ec45a3ae..c940969db8 100644 --- a/common/src/test/java/org/opensearch/ml/common/dataframe/ColumnMetaTest.java +++ b/common/src/test/java/org/opensearch/ml/common/dataframe/ColumnMetaTest.java @@ -4,6 +4,13 @@ */ package org.opensearch.ml.common.dataframe; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +import java.io.IOException; +import java.util.function.Function; + import org.junit.Before; import org.junit.Test; import org.opensearch.common.io.stream.BytesStreamOutput; @@ -13,12 +20,6 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.TestHelper; -import java.io.IOException; -import java.util.function.Function; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; - public class ColumnMetaTest { ColumnMeta columnMeta; diff --git a/common/src/test/java/org/opensearch/ml/common/dataframe/ColumnTypeTest.java b/common/src/test/java/org/opensearch/ml/common/dataframe/ColumnTypeTest.java index 905fe31159..3c0a5cfa38 100644 --- a/common/src/test/java/org/opensearch/ml/common/dataframe/ColumnTypeTest.java +++ b/common/src/test/java/org/opensearch/ml/common/dataframe/ColumnTypeTest.java @@ -5,12 +5,12 @@ package org.opensearch.ml.common.dataframe; +import java.math.BigDecimal; + import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; -import java.math.BigDecimal; - public class ColumnTypeTest { @Rule diff --git a/common/src/test/java/org/opensearch/ml/common/dataframe/ColumnValueBuilderTest.java b/common/src/test/java/org/opensearch/ml/common/dataframe/ColumnValueBuilderTest.java index 4b783a3755..a9c137f459 100644 --- a/common/src/test/java/org/opensearch/ml/common/dataframe/ColumnValueBuilderTest.java +++ b/common/src/test/java/org/opensearch/ml/common/dataframe/ColumnValueBuilderTest.java @@ -5,13 +5,13 @@ package org.opensearch.ml.common.dataframe; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.ExpectedException; +import static org.junit.Assert.assertEquals; import java.math.BigDecimal; -import static org.junit.Assert.assertEquals; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; public class ColumnValueBuilderTest { @@ -45,12 +45,12 @@ public void build() { assertEquals(2.1f, value.floatValue(), 1e-5); assertEquals(2.1d, value.doubleValue(), 1e-5); - value = ColumnValueBuilder.build((short)2); + value = ColumnValueBuilder.build((short) 2); assertEquals(ColumnType.SHORT, value.columnType()); assertEquals(2, value.shortValue()); assertEquals(2.0d, value.doubleValue(), 1e-5); - value = ColumnValueBuilder.build((long)2); + value = ColumnValueBuilder.build((long) 2); assertEquals(ColumnType.LONG, value.columnType()); assertEquals(2, value.longValue()); assertEquals(2.0d, value.doubleValue(), 1e-5); @@ -63,4 +63,4 @@ public void build_IllegalType() { Object obj = new BigDecimal("0"); ColumnValueBuilder.build(obj); } -} \ No newline at end of file +} diff --git a/common/src/test/java/org/opensearch/ml/common/dataframe/ColumnValueReaderTest.java b/common/src/test/java/org/opensearch/ml/common/dataframe/ColumnValueReaderTest.java index 07287da537..baac846c6f 100644 --- a/common/src/test/java/org/opensearch/ml/common/dataframe/ColumnValueReaderTest.java +++ b/common/src/test/java/org/opensearch/ml/common/dataframe/ColumnValueReaderTest.java @@ -5,13 +5,13 @@ package org.opensearch.ml.common.dataframe; +import static org.junit.Assert.assertEquals; + import java.io.IOException; import org.junit.Test; import org.opensearch.common.io.stream.BytesStreamOutput; -import static org.junit.Assert.assertEquals; - public class ColumnValueReaderTest { ColumnValueReader reader = new ColumnValueReader(); @@ -86,7 +86,7 @@ public void read_FloatValue() throws IOException { @Test public void read_ShortValue() throws IOException { - ColumnValue value = new ShortValue((short)2); + ColumnValue value = new ShortValue((short) 2); BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); value.writeTo(bytesStreamOutput); value = reader.read(bytesStreamOutput.bytes().streamInput()); @@ -96,7 +96,7 @@ public void read_ShortValue() throws IOException { @Test public void read_LongValue() throws IOException { - ColumnValue value = new LongValue((long)2); + ColumnValue value = new LongValue((long) 2); BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); value.writeTo(bytesStreamOutput); value = reader.read(bytesStreamOutput.bytes().streamInput()); diff --git a/common/src/test/java/org/opensearch/ml/common/dataframe/ColumnValueTest.java b/common/src/test/java/org/opensearch/ml/common/dataframe/ColumnValueTest.java index 3e8661d275..cc60a85608 100644 --- a/common/src/test/java/org/opensearch/ml/common/dataframe/ColumnValueTest.java +++ b/common/src/test/java/org/opensearch/ml/common/dataframe/ColumnValueTest.java @@ -5,12 +5,12 @@ package org.opensearch.ml.common.dataframe; +import static org.junit.Assert.assertTrue; + import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; -import static org.junit.Assert.assertTrue; - public class ColumnValueTest { @Rule diff --git a/common/src/test/java/org/opensearch/ml/common/dataframe/DataFrameBuilderTest.java b/common/src/test/java/org/opensearch/ml/common/dataframe/DataFrameBuilderTest.java index 98f4282254..d9d20e7b4d 100644 --- a/common/src/test/java/org/opensearch/ml/common/dataframe/DataFrameBuilderTest.java +++ b/common/src/test/java/org/opensearch/ml/common/dataframe/DataFrameBuilderTest.java @@ -5,6 +5,8 @@ package org.opensearch.ml.common.dataframe; +import static org.junit.Assert.assertEquals; + import java.io.IOException; import java.util.Collections; import java.util.HashMap; @@ -16,8 +18,6 @@ import org.junit.rules.ExpectedException; import org.opensearch.common.io.stream.BytesStreamOutput; -import static org.junit.Assert.assertEquals; - public class DataFrameBuilderTest { @Rule @@ -25,10 +25,7 @@ public class DataFrameBuilderTest { @Test public void emptyDataFrame_Success() { - ColumnMeta[] columnMetas = new ColumnMeta[]{ColumnMeta.builder() - .name("k1") - .columnType(ColumnType.DOUBLE) - .build()}; + ColumnMeta[] columnMetas = new ColumnMeta[] { ColumnMeta.builder().name("k1").columnType(ColumnType.DOUBLE).build() }; DataFrame dataFrame = DataFrameBuilder.emptyDataFrame(columnMetas); assertEquals(0, dataFrame.size()); } @@ -68,9 +65,7 @@ public void load_Exception_NullInputMapList() { public void load_Success_ColumnMetasAndInputMapList() { Map map = new HashMap<>(); map.put("k1", 2.3D); - ColumnMeta[] columnMetas = new ColumnMeta[]{ - ColumnMeta.builder().name("k1").columnType(ColumnType.DOUBLE).build() - }; + ColumnMeta[] columnMetas = new ColumnMeta[] { ColumnMeta.builder().name("k1").columnType(ColumnType.DOUBLE).build() }; DataFrame dataFrame = DataFrameBuilder.load(columnMetas, Collections.singletonList(map)); assertEquals(1, dataFrame.size()); } @@ -91,17 +86,13 @@ public void load_Exception_NullColumnMetas() { @Test(expected = IllegalArgumentException.class) public void load_Exception_ColumnMetasAndEmptyInputMapList() { - ColumnMeta[] columnMetas = new ColumnMeta[]{ - ColumnMeta.builder().name("k1").columnType(ColumnType.DOUBLE).build() - }; + ColumnMeta[] columnMetas = new ColumnMeta[] { ColumnMeta.builder().name("k1").columnType(ColumnType.DOUBLE).build() }; DataFrameBuilder.load(columnMetas, Collections.emptyList()); } @Test(expected = IllegalArgumentException.class) public void load_Exception_ColumnMetasAndNullInputMapList() { - ColumnMeta[] columnMetas = new ColumnMeta[]{ - ColumnMeta.builder().name("k1").columnType(ColumnType.DOUBLE).build() - }; + ColumnMeta[] columnMetas = new ColumnMeta[] { ColumnMeta.builder().name("k1").columnType(ColumnType.DOUBLE).build() }; DataFrameBuilder.load(columnMetas, null); } @@ -112,10 +103,9 @@ public void load_Exception_DifferentColumnsInColumnMetasAndInputMapList() { Map map = new HashMap<>(); map.put("k1", 2.3D); - ColumnMeta[] columnMetas = new ColumnMeta[]{ - ColumnMeta.builder().name("k1").columnType(ColumnType.DOUBLE).build(), - ColumnMeta.builder().name("k2").columnType(ColumnType.DOUBLE).build() - }; + ColumnMeta[] columnMetas = new ColumnMeta[] { + ColumnMeta.builder().name("k1").columnType(ColumnType.DOUBLE).build(), + ColumnMeta.builder().name("k2").columnType(ColumnType.DOUBLE).build() }; DataFrameBuilder.load(columnMetas, Collections.singletonList(map)); } @@ -126,9 +116,7 @@ public void load_Exception_DifferentTypesForSameField() { Map map = new HashMap<>(); map.put("k1", 2.3D); - ColumnMeta[] columnMetas = new ColumnMeta[]{ - ColumnMeta.builder().name("k1").columnType(ColumnType.INTEGER).build() - }; + ColumnMeta[] columnMetas = new ColumnMeta[] { ColumnMeta.builder().name("k1").columnType(ColumnType.INTEGER).build() }; DataFrameBuilder.load(columnMetas, Collections.singletonList(map)); } @@ -139,9 +127,7 @@ public void load_Exception_DifferentFields() { Map map = new HashMap<>(); map.put("k2", 2.3D); - ColumnMeta[] columnMetas = new ColumnMeta[]{ - ColumnMeta.builder().name("k1").columnType(ColumnType.INTEGER).build() - }; + ColumnMeta[] columnMetas = new ColumnMeta[] { ColumnMeta.builder().name("k1").columnType(ColumnType.INTEGER).build() }; DataFrameBuilder.load(columnMetas, Collections.singletonList(map)); } @@ -158,4 +144,4 @@ public void load_Success_StreamInput() throws IOException { dataFrame = DataFrameBuilder.load(bytesStreamOutput.bytes().streamInput()); assertEquals(1, dataFrame.size()); } -} \ No newline at end of file +} diff --git a/common/src/test/java/org/opensearch/ml/common/dataframe/DefaultDataFrameTest.java b/common/src/test/java/org/opensearch/ml/common/dataframe/DefaultDataFrameTest.java index 3721f9d7a8..8ae385fd2a 100644 --- a/common/src/test/java/org/opensearch/ml/common/dataframe/DefaultDataFrameTest.java +++ b/common/src/test/java/org/opensearch/ml/common/dataframe/DefaultDataFrameTest.java @@ -5,6 +5,10 @@ package org.opensearch.ml.common.dataframe; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + import java.io.IOException; import java.util.ArrayList; import java.util.List; @@ -15,17 +19,13 @@ import org.junit.Test; import org.junit.rules.ExpectedException; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.TestHelper; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertTrue; - public class DefaultDataFrameTest { DefaultDataFrame defaultDataFrame; @@ -37,23 +37,11 @@ public class DefaultDataFrameTest { @Before public void setUp() { ColumnMeta[] columnMetas = new ColumnMeta[4]; - columnMetas[0] = ColumnMeta.builder() - .name("c1") - .columnType(ColumnType.STRING) - .build(); - columnMetas[1] = ColumnMeta.builder() - .name("c2") - .columnType(ColumnType.INTEGER) - .build(); - columnMetas[2] = ColumnMeta.builder() - .name("c3") - .columnType(ColumnType.DOUBLE) - .build(); - - columnMetas[3] = ColumnMeta.builder() - .name("c4") - .columnType(ColumnType.BOOLEAN) - .build(); + columnMetas[0] = ColumnMeta.builder().name("c1").columnType(ColumnType.STRING).build(); + columnMetas[1] = ColumnMeta.builder().name("c2").columnType(ColumnType.INTEGER).build(); + columnMetas[2] = ColumnMeta.builder().name("c3").columnType(ColumnType.DOUBLE).build(); + + columnMetas[3] = ColumnMeta.builder().name("c4").columnType(ColumnType.BOOLEAN).build(); Row row = new Row(4); row.setValue(0, new StringValue("string")); @@ -157,8 +145,7 @@ public void appendRow_Exception_DifferentColumns() { @Test public void appendRow_Exception_DifferentColumnTypes() { exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("the column type is different in column meta:BOOLEAN and input row:DOUBLE for " + - "index: 3"); + exceptionRule.expectMessage("the column type is different in column meta:BOOLEAN and input row:DOUBLE for " + "index: 3"); Row row = new Row(4); row.setValue(0, new StringValue("string2")); row.setValue(1, new IntValue(2)); @@ -174,21 +161,21 @@ public void columnMetas_Success() { } @Test - public void remove_Exception_InputColumnIndexBiggerThanColumensLength(){ + public void remove_Exception_InputColumnIndexBiggerThanColumensLength() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("columnIndex can't be negative or bigger than columns length:4"); defaultDataFrame.remove(4); } @Test - public void remove_Exception_InputColumnIndexNegtiveColumensLength(){ + public void remove_Exception_InputColumnIndexNegtiveColumensLength() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("columnIndex can't be negative or bigger than columns length:4"); defaultDataFrame.remove(-1); } @Test - public void remove_EmptyColumnMeta(){ + public void remove_EmptyColumnMeta() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("columnIndex can't be negative or bigger than columns length:0"); DefaultDataFrame dataFrame = new DefaultDataFrame(new ColumnMeta[0]); @@ -197,31 +184,31 @@ public void remove_EmptyColumnMeta(){ } @Test - public void remove_Success(){ + public void remove_Success() { DataFrame dataFrame = defaultDataFrame.remove(3); assertEquals(3, dataFrame.columnMetas().length); assertEquals(3, dataFrame.getRow(0).size()); } @Test - public void select_Success(){ - DataFrame dataFrame = defaultDataFrame.select(new int[]{1, 3}); + public void select_Success() { + DataFrame dataFrame = defaultDataFrame.select(new int[] { 1, 3 }); assertEquals(2, dataFrame.columnMetas().length); assertEquals(2, dataFrame.getRow(0).size()); } @Test - public void select_Exception_EmptyInputColumns(){ + public void select_Exception_EmptyInputColumns() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("columns can't be null or empty"); defaultDataFrame.select(new int[0]); } @Test - public void select_Exception_InvalidColumn(){ + public void select_Exception_InvalidColumn() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("columnIndex can't be negative or bigger than columns length"); - defaultDataFrame.select(new int[]{5}); + defaultDataFrame.select(new int[] { 5 }); } @Test @@ -233,44 +220,47 @@ public void testToXContent() throws IOException { assertNotNull(builder); String jsonStr = builder.toString(); - assertEquals("{\"column_metas\":[" + - "{\"name\":\"c1\",\"column_type\":\"STRING\"}," + - "{\"name\":\"c2\",\"column_type\":\"INTEGER\"}," + - "{\"name\":\"c3\",\"column_type\":\"DOUBLE\"}," + - "{\"name\":\"c4\",\"column_type\":\"BOOLEAN\"}]," + - "\"rows\":[" + - "{\"values\":[" + - "{\"column_type\":\"STRING\",\"value\":\"string\"}," + - "{\"column_type\":\"INTEGER\",\"value\":1}," + - "{\"column_type\":\"DOUBLE\",\"value\":2.0}," + - "{\"column_type\":\"BOOLEAN\",\"value\":true}]}]}", jsonStr); + assertEquals( + "{\"column_metas\":[" + + "{\"name\":\"c1\",\"column_type\":\"STRING\"}," + + "{\"name\":\"c2\",\"column_type\":\"INTEGER\"}," + + "{\"name\":\"c3\",\"column_type\":\"DOUBLE\"}," + + "{\"name\":\"c4\",\"column_type\":\"BOOLEAN\"}]," + + "\"rows\":[" + + "{\"values\":[" + + "{\"column_type\":\"STRING\",\"value\":\"string\"}," + + "{\"column_type\":\"INTEGER\",\"value\":1}," + + "{\"column_type\":\"DOUBLE\",\"value\":2.0}," + + "{\"column_type\":\"BOOLEAN\",\"value\":true}]}]}", + jsonStr + ); } @Test public void testParse_EmptyDataFrame() throws IOException { - ColumnMeta[] columnMetas = new ColumnMeta[] {new ColumnMeta("test_int", ColumnType.INTEGER)}; + ColumnMeta[] columnMetas = new ColumnMeta[] { new ColumnMeta("test_int", ColumnType.INTEGER) }; DefaultDataFrame dataFrame = new DefaultDataFrame(columnMetas); TestHelper.testParse(dataFrame, function, true); } @Test public void testParse_DataFrame() throws IOException { - ColumnMeta[] columnMetas = new ColumnMeta[] {new ColumnMeta("test_int", ColumnType.INTEGER)}; + ColumnMeta[] columnMetas = new ColumnMeta[] { new ColumnMeta("test_int", ColumnType.INTEGER) }; DefaultDataFrame dataFrame = new DefaultDataFrame(columnMetas); - dataFrame.appendRow(new Integer[]{1}); - dataFrame.appendRow(new Integer[]{2}); + dataFrame.appendRow(new Integer[] { 1 }); + dataFrame.appendRow(new Integer[] { 2 }); TestHelper.testParse(dataFrame, function, true); } @Test public void testParse_WrongExtraField() throws IOException { - ColumnMeta[] columnMetas = new ColumnMeta[] {new ColumnMeta("test_int", ColumnType.INTEGER)}; + ColumnMeta[] columnMetas = new ColumnMeta[] { new ColumnMeta("test_int", ColumnType.INTEGER) }; DefaultDataFrame dataFrame = new DefaultDataFrame(columnMetas); - dataFrame.appendRow(new Integer[]{1}); - dataFrame.appendRow(new Integer[]{2}); - String jsonStr = "{\"wrong_field\":{\"test\":\"abc\"},\"column_metas\":[{\"name\":\"test_int\",\"column_type\":" + - "\"INTEGER\"}],\"rows\":[{\"values\":[{\"column_type\":\"INTEGER\",\"value\":1}]},{\"values\":" + - "[{\"column_type\":\"INTEGER\",\"value\":2}]}]}"; + dataFrame.appendRow(new Integer[] { 1 }); + dataFrame.appendRow(new Integer[] { 2 }); + String jsonStr = "{\"wrong_field\":{\"test\":\"abc\"},\"column_metas\":[{\"name\":\"test_int\",\"column_type\":" + + "\"INTEGER\"}],\"rows\":[{\"values\":[{\"column_type\":\"INTEGER\",\"value\":1}]},{\"values\":" + + "[{\"column_type\":\"INTEGER\",\"value\":2}]}]}"; TestHelper.testParseFromString(dataFrame, jsonStr, function); } } diff --git a/common/src/test/java/org/opensearch/ml/common/dataframe/DoubleValueTest.java b/common/src/test/java/org/opensearch/ml/common/dataframe/DoubleValueTest.java index 9f2e665d56..980dce1784 100644 --- a/common/src/test/java/org/opensearch/ml/common/dataframe/DoubleValueTest.java +++ b/common/src/test/java/org/opensearch/ml/common/dataframe/DoubleValueTest.java @@ -5,6 +5,9 @@ package org.opensearch.ml.common.dataframe; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + import java.io.IOException; import org.junit.Test; @@ -13,9 +16,6 @@ import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.core.xcontent.XContentBuilder; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; - public class DoubleValueTest { @Test diff --git a/common/src/test/java/org/opensearch/ml/common/dataframe/FloatValueTest.java b/common/src/test/java/org/opensearch/ml/common/dataframe/FloatValueTest.java index c9e8e552d7..a500be1e12 100644 --- a/common/src/test/java/org/opensearch/ml/common/dataframe/FloatValueTest.java +++ b/common/src/test/java/org/opensearch/ml/common/dataframe/FloatValueTest.java @@ -5,16 +5,16 @@ package org.opensearch.ml.common.dataframe; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +import java.io.IOException; + import org.junit.Test; import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.core.xcontent.XContentBuilder; -import java.io.IOException; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; - public class FloatValueTest { @Test diff --git a/common/src/test/java/org/opensearch/ml/common/dataframe/IntValueTest.java b/common/src/test/java/org/opensearch/ml/common/dataframe/IntValueTest.java index 3048e88767..61c5decf4a 100644 --- a/common/src/test/java/org/opensearch/ml/common/dataframe/IntValueTest.java +++ b/common/src/test/java/org/opensearch/ml/common/dataframe/IntValueTest.java @@ -5,16 +5,16 @@ package org.opensearch.ml.common.dataframe; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +import java.io.IOException; + import org.junit.Test; import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.core.xcontent.XContentBuilder; -import java.io.IOException; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; - public class IntValueTest { @Test diff --git a/common/src/test/java/org/opensearch/ml/common/dataframe/LongValueTest.java b/common/src/test/java/org/opensearch/ml/common/dataframe/LongValueTest.java index 5cfbe8cf5e..0502d83d9e 100644 --- a/common/src/test/java/org/opensearch/ml/common/dataframe/LongValueTest.java +++ b/common/src/test/java/org/opensearch/ml/common/dataframe/LongValueTest.java @@ -5,21 +5,21 @@ package org.opensearch.ml.common.dataframe; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +import java.io.IOException; + import org.junit.Test; import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.core.xcontent.XContentBuilder; -import java.io.IOException; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; - public class LongValueTest { @Test public void longValue() { - LongValue longValue = new LongValue((long)2); + LongValue longValue = new LongValue((long) 2); assertEquals(ColumnType.LONG, longValue.columnType()); assertEquals(2L, longValue.getValue()); assertEquals(2.0d, longValue.doubleValue(), 1e-5); @@ -27,7 +27,7 @@ public void longValue() { @Test public void testToXContent() throws IOException { - LongValue longValue = new LongValue((long)2); + LongValue longValue = new LongValue((long) 2); XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); longValue.toXContent(builder); diff --git a/common/src/test/java/org/opensearch/ml/common/dataframe/NullValueTest.java b/common/src/test/java/org/opensearch/ml/common/dataframe/NullValueTest.java index 90021d29ad..3db4c050b8 100644 --- a/common/src/test/java/org/opensearch/ml/common/dataframe/NullValueTest.java +++ b/common/src/test/java/org/opensearch/ml/common/dataframe/NullValueTest.java @@ -5,18 +5,18 @@ package org.opensearch.ml.common.dataframe; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; + +import java.io.IOException; + import org.junit.Test; import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; -import java.io.IOException; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertNull; - public class NullValueTest { @Test diff --git a/common/src/test/java/org/opensearch/ml/common/dataframe/RowTest.java b/common/src/test/java/org/opensearch/ml/common/dataframe/RowTest.java index c2947ca944..09df496bf8 100644 --- a/common/src/test/java/org/opensearch/ml/common/dataframe/RowTest.java +++ b/common/src/test/java/org/opensearch/ml/common/dataframe/RowTest.java @@ -5,6 +5,13 @@ package org.opensearch.ml.common.dataframe; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.opensearch.ml.common.TestHelper.testParse; +import static org.opensearch.ml.common.TestHelper.testParseFromString; + import java.io.IOException; import java.util.Iterator; import java.util.function.Function; @@ -19,13 +26,6 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertTrue; -import static org.opensearch.ml.common.TestHelper.testParse; -import static org.opensearch.ml.common.TestHelper.testParseFromString; - public class RowTest { Row row; @@ -39,7 +39,7 @@ public void setup() { row = new Row(1); row.setValue(0, ColumnValueBuilder.build(0)); - function = parser -> { + function = parser -> { try { return Row.parse(parser); } catch (IOException e) { @@ -125,7 +125,7 @@ public void select() { row = new Row(2); row.setValue(0, ColumnValueBuilder.build(0)); row.setValue(1, ColumnValueBuilder.build(false)); - row = row.select(new int[]{1}); + row = row.select(new int[] { 1 }); assertEquals(1, row.size()); assertFalse(row.getValue(0).booleanValue()); } @@ -143,45 +143,70 @@ public void testToXContent() throws IOException { @Test public void testParse_NullValue() throws IOException { - ColumnValue[] values = new ColumnValue[] {new NullValue()}; + ColumnValue[] values = new ColumnValue[] { new NullValue() }; Row row = new Row(values); testParse(row, function); } @Test public void testParse_NullValue_AtLast() throws IOException { - ColumnValue[] values = new ColumnValue[] {new IntValue(1), new DoubleValue(2.0), new StringValue("test"), new BooleanValue(true), new NullValue()}; + ColumnValue[] values = new ColumnValue[] { + new IntValue(1), + new DoubleValue(2.0), + new StringValue("test"), + new BooleanValue(true), + new NullValue() }; Row row = new Row(values); testParse(row, function); } @Test public void testParse_NullValue_AtFirst() throws IOException { - ColumnValue[] values = new ColumnValue[] {new NullValue(), new IntValue(1), new DoubleValue(2.0), new StringValue("test"), new BooleanValue(true)}; + ColumnValue[] values = new ColumnValue[] { + new NullValue(), + new IntValue(1), + new DoubleValue(2.0), + new StringValue("test"), + new BooleanValue(true) }; Row row = new Row(values); testParse(row, function); } @Test public void testParse_NullValue_AtMiddle() throws IOException { - ColumnValue[] values = new ColumnValue[] {new IntValue(1), new DoubleValue(2.0), new NullValue(), new StringValue("test"), new BooleanValue(true)}; + ColumnValue[] values = new ColumnValue[] { + new IntValue(1), + new DoubleValue(2.0), + new NullValue(), + new StringValue("test"), + new BooleanValue(true) }; Row row = new Row(values); testParse(row, function); } @Test public void testParse_ExtraWrongValueField() throws IOException { - ColumnValue[] values = new ColumnValue[] {new IntValue(1), new DoubleValue(2.0), new NullValue(), new StringValue("test"), new BooleanValue(true)}; + ColumnValue[] values = new ColumnValue[] { + new IntValue(1), + new DoubleValue(2.0), + new NullValue(), + new StringValue("test"), + new BooleanValue(true) }; Row row = new Row(values); - String jsonStr = "{\"values\":[{\"column_type\":\"INTEGER\",\"value\":1},{\"column_type\":\"DOUBLE\",\"value\":2}," + - "{\"column_type\":\"NULL\"},{\"column_type\":\"STRING\",\"value\":\"test\"},{\"column_type\":\"BOOLEAN\"," + - "\"value\":true},{\"column_type\":\"WRONG\",\"value\":true}],\"wrong_filed\":{\"test\":\"abc\"}}"; + String jsonStr = "{\"values\":[{\"column_type\":\"INTEGER\",\"value\":1},{\"column_type\":\"DOUBLE\",\"value\":2}," + + "{\"column_type\":\"NULL\"},{\"column_type\":\"STRING\",\"value\":\"test\"},{\"column_type\":\"BOOLEAN\"," + + "\"value\":true},{\"column_type\":\"WRONG\",\"value\":true}],\"wrong_filed\":{\"test\":\"abc\"}}"; testParseFromString(row, jsonStr, function); } @Test public void testParse_EmptyValueField() throws IOException { - ColumnValue[] values = new ColumnValue[] {new IntValue(1), new DoubleValue(2.0), new NullValue(), new StringValue("test"), new BooleanValue(true)}; + ColumnValue[] values = new ColumnValue[] { + new IntValue(1), + new DoubleValue(2.0), + new NullValue(), + new StringValue("test"), + new BooleanValue(true) }; Row row = new Row(values); String jsonStr = "{\"values\":[{}]}"; testParseFromString(row, jsonStr, function); @@ -191,7 +216,12 @@ public void testParse_EmptyValueField() throws IOException { public void testParse_WrongColumnTypeField() throws IOException { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("wrong column type, expect column_type field but got column_type_wrong"); - ColumnValue[] values = new ColumnValue[] {new IntValue(1), new DoubleValue(2.0), new NullValue(), new StringValue("test"), new BooleanValue(true)}; + ColumnValue[] values = new ColumnValue[] { + new IntValue(1), + new DoubleValue(2.0), + new NullValue(), + new StringValue("test"), + new BooleanValue(true) }; Row row = new Row(values); String jsonStr = "{\"values\":[{\"column_type_wrong\":\"INTEGER\",\"value\":1}]}"; testParseFromString(row, jsonStr, function); @@ -201,7 +231,12 @@ public void testParse_WrongColumnTypeField() throws IOException { public void testParse_WrongValueField() throws IOException { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("wrong column value, expect value field but got value_wrong"); - ColumnValue[] values = new ColumnValue[] {new IntValue(1), new DoubleValue(2.0), new NullValue(), new StringValue("test"), new BooleanValue(true)}; + ColumnValue[] values = new ColumnValue[] { + new IntValue(1), + new DoubleValue(2.0), + new NullValue(), + new StringValue("test"), + new BooleanValue(true) }; Row row = new Row(values); String jsonStr = "{\"values\":[{\"column_type\":\"INTEGER\",\"value_wrong\":1}]}"; testParseFromString(row, jsonStr, function); @@ -215,34 +250,69 @@ public void testEquals_EmptyValues() { Row row2 = new Row(values2); assertTrue(row1.equals(row2)); - ColumnValue[] values3 = new ColumnValue[] {new IntValue(1), new DoubleValue(2.0), new NullValue(), new StringValue("test"), new BooleanValue(true)}; + ColumnValue[] values3 = new ColumnValue[] { + new IntValue(1), + new DoubleValue(2.0), + new NullValue(), + new StringValue("test"), + new BooleanValue(true) }; Row row3 = new Row(values3); assertFalse(row1.equals(row3)); } @Test public void testEquals_AllValuesMatch() { - ColumnValue[] values1 = new ColumnValue[] {new IntValue(1), new DoubleValue(2.0), new NullValue(), new StringValue("test"), new BooleanValue(true)}; + ColumnValue[] values1 = new ColumnValue[] { + new IntValue(1), + new DoubleValue(2.0), + new NullValue(), + new StringValue("test"), + new BooleanValue(true) }; Row row1 = new Row(values1); - ColumnValue[] values2 = new ColumnValue[] {new IntValue(1), new DoubleValue(2.0), new NullValue(), new StringValue("test"), new BooleanValue(true)}; + ColumnValue[] values2 = new ColumnValue[] { + new IntValue(1), + new DoubleValue(2.0), + new NullValue(), + new StringValue("test"), + new BooleanValue(true) }; Row row2 = new Row(values2); assertTrue(row1.equals(row2)); } @Test public void testEquals_SomeValueNotMatch() { - ColumnValue[] values1 = new ColumnValue[] {new IntValue(1), new DoubleValue(2.0), new NullValue(), new StringValue("test"), new BooleanValue(true)}; + ColumnValue[] values1 = new ColumnValue[] { + new IntValue(1), + new DoubleValue(2.0), + new NullValue(), + new StringValue("test"), + new BooleanValue(true) }; Row row1 = new Row(values1); - ColumnValue[] values2 = new ColumnValue[] {new IntValue(2), new DoubleValue(2.0), new NullValue(), new StringValue("test"), new BooleanValue(true)}; + ColumnValue[] values2 = new ColumnValue[] { + new IntValue(2), + new DoubleValue(2.0), + new NullValue(), + new StringValue("test"), + new BooleanValue(true) }; Row row2 = new Row(values2); assertFalse(row1.equals(row2)); } @Test public void testEquals_SomeTypeNotMatch() { - ColumnValue[] values1 = new ColumnValue[] {new IntValue(1), new DoubleValue(2.0), new NullValue(), new StringValue("test"), new BooleanValue(true)}; + ColumnValue[] values1 = new ColumnValue[] { + new IntValue(1), + new DoubleValue(2.0), + new NullValue(), + new StringValue("test"), + new BooleanValue(true) }; Row row1 = new Row(values1); - ColumnValue[] values2 = new ColumnValue[] {new DoubleValue(1.0), new DoubleValue(2.0), new NullValue(), new StringValue("test"), new BooleanValue(true)}; + ColumnValue[] values2 = new ColumnValue[] { + new DoubleValue(1.0), + new DoubleValue(2.0), + new NullValue(), + new StringValue("test"), + new BooleanValue(true) }; Row row2 = new Row(values2); assertFalse(row1.equals(row2)); } diff --git a/common/src/test/java/org/opensearch/ml/common/dataframe/ShortValueTest.java b/common/src/test/java/org/opensearch/ml/common/dataframe/ShortValueTest.java index c6e5b5366d..b1a6c08052 100644 --- a/common/src/test/java/org/opensearch/ml/common/dataframe/ShortValueTest.java +++ b/common/src/test/java/org/opensearch/ml/common/dataframe/ShortValueTest.java @@ -5,29 +5,29 @@ package org.opensearch.ml.common.dataframe; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +import java.io.IOException; + import org.junit.Test; import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.core.xcontent.XContentBuilder; -import java.io.IOException; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; - public class ShortValueTest { @Test public void shortValue() { - ShortValue shortValue = new ShortValue((short)2); + ShortValue shortValue = new ShortValue((short) 2); assertEquals(ColumnType.SHORT, shortValue.columnType()); - assertEquals((short)2, shortValue.getValue()); + assertEquals((short) 2, shortValue.getValue()); assertEquals(2.0d, shortValue.doubleValue(), 1e-5); } @Test public void testToXContent() throws IOException { - ShortValue shortValue = new ShortValue((short)2); + ShortValue shortValue = new ShortValue((short) 2); XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); shortValue.toXContent(builder); diff --git a/common/src/test/java/org/opensearch/ml/common/dataframe/StringValueTest.java b/common/src/test/java/org/opensearch/ml/common/dataframe/StringValueTest.java index 944d1efd59..e48bbfc2e1 100644 --- a/common/src/test/java/org/opensearch/ml/common/dataframe/StringValueTest.java +++ b/common/src/test/java/org/opensearch/ml/common/dataframe/StringValueTest.java @@ -5,17 +5,17 @@ package org.opensearch.ml.common.dataframe; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +import java.io.IOException; + import org.junit.Test; import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; -import java.io.IOException; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; - public class StringValueTest { @Test public void stringValue() { diff --git a/common/src/test/java/org/opensearch/ml/common/dataset/DataFrameInputDatasetTest.java b/common/src/test/java/org/opensearch/ml/common/dataset/DataFrameInputDatasetTest.java index d7c6294d20..eb1d4c33fe 100644 --- a/common/src/test/java/org/opensearch/ml/common/dataset/DataFrameInputDatasetTest.java +++ b/common/src/test/java/org/opensearch/ml/common/dataset/DataFrameInputDatasetTest.java @@ -5,6 +5,8 @@ package org.opensearch.ml.common.dataset; +import static org.junit.Assert.assertEquals; + import java.io.IOException; import java.util.Collections; import java.util.HashMap; @@ -13,19 +15,20 @@ import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.ml.common.dataframe.DataFrameBuilder; -import static org.junit.Assert.assertEquals; - public class DataFrameInputDatasetTest { @Test public void writeTo_Success() throws IOException { - DataFrameInputDataset dataFrameInputDataset = DataFrameInputDataset.builder() - .dataFrame(DataFrameBuilder.load(Collections.singletonList(new HashMap() {{ - put("key1", 2.0D); - }}))) + DataFrameInputDataset dataFrameInputDataset = DataFrameInputDataset + .builder() + .dataFrame(DataFrameBuilder.load(Collections.singletonList(new HashMap() { + { + put("key1", 2.0D); + } + }))) .build(); BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); dataFrameInputDataset.writeTo(bytesStreamOutput); assertEquals(21, bytesStreamOutput.size()); } -} \ No newline at end of file +} diff --git a/common/src/test/java/org/opensearch/ml/common/dataset/SearchQueryInputDatasetTest.java b/common/src/test/java/org/opensearch/ml/common/dataset/SearchQueryInputDatasetTest.java index d1a3af0ce2..3502ea8723 100644 --- a/common/src/test/java/org/opensearch/ml/common/dataset/SearchQueryInputDatasetTest.java +++ b/common/src/test/java/org/opensearch/ml/common/dataset/SearchQueryInputDatasetTest.java @@ -5,6 +5,8 @@ package org.opensearch.ml.common.dataset; +import static org.junit.Assert.assertEquals; + import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; @@ -17,8 +19,6 @@ import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.search.builder.SearchSourceBuilder; -import static org.junit.Assert.assertEquals; - public class SearchQueryInputDatasetTest { @Rule @@ -26,7 +26,8 @@ public class SearchQueryInputDatasetTest { @Test public void writeTo_Success() throws IOException { - SearchQueryInputDataset searchQueryInputDataset = SearchQueryInputDataset.builder() + SearchQueryInputDataset searchQueryInputDataset = SearchQueryInputDataset + .builder() .indices(Arrays.asList("index1")) .searchSourceBuilder(new SearchSourceBuilder().query(new MatchAllQueryBuilder()).size(1)) .build(); @@ -45,9 +46,6 @@ public void writeTo_Success() throws IOException { public void init_EmptyIndices() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("indices can't be empty"); - SearchQueryInputDataset.builder() - .indices(new ArrayList<>()) - .searchSourceBuilder(new SearchSourceBuilder().size(1)) - .build(); + SearchQueryInputDataset.builder().indices(new ArrayList<>()).searchSourceBuilder(new SearchSourceBuilder().size(1)).build(); } -} \ No newline at end of file +} diff --git a/common/src/test/java/org/opensearch/ml/common/dataset/TextDocsInputDataSetTest.java b/common/src/test/java/org/opensearch/ml/common/dataset/TextDocsInputDataSetTest.java index 89f629e7c2..811a1243e7 100644 --- a/common/src/test/java/org/opensearch/ml/common/dataset/TextDocsInputDataSetTest.java +++ b/common/src/test/java/org/opensearch/ml/common/dataset/TextDocsInputDataSetTest.java @@ -5,17 +5,17 @@ package org.opensearch.ml.common.dataset; +import static org.junit.Assert.assertEquals; + +import java.io.IOException; +import java.util.Arrays; + import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; -import java.io.IOException; -import java.util.Arrays; - -import static org.junit.Assert.assertEquals; - public class TextDocsInputDataSetTest { @Rule diff --git a/common/src/test/java/org/opensearch/ml/common/dataset/remote/RemoteInferenceInputDataSetTest.java b/common/src/test/java/org/opensearch/ml/common/dataset/remote/RemoteInferenceInputDataSetTest.java index c08cb1a468..e8b805e17f 100644 --- a/common/src/test/java/org/opensearch/ml/common/dataset/remote/RemoteInferenceInputDataSetTest.java +++ b/common/src/test/java/org/opensearch/ml/common/dataset/remote/RemoteInferenceInputDataSetTest.java @@ -1,16 +1,16 @@ package org.opensearch.ml.common.dataset.remote; -import org.junit.Assert; -import org.junit.Test; -import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.ml.common.dataset.MLInputDataset; +import static org.opensearch.ml.common.dataset.MLInputDataType.REMOTE; import java.io.IOException; import java.util.HashMap; import java.util.Map; -import static org.opensearch.ml.common.dataset.MLInputDataType.REMOTE; +import org.junit.Assert; +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.ml.common.dataset.MLInputDataset; public class RemoteInferenceInputDataSetTest { diff --git a/common/src/test/java/org/opensearch/ml/common/exception/MLExceptionTest.java b/common/src/test/java/org/opensearch/ml/common/exception/MLExceptionTest.java index 5b872c0f36..21ce33f4ed 100644 --- a/common/src/test/java/org/opensearch/ml/common/exception/MLExceptionTest.java +++ b/common/src/test/java/org/opensearch/ml/common/exception/MLExceptionTest.java @@ -5,12 +5,12 @@ package org.opensearch.ml.common.exception; -import org.junit.Test; - import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; +import org.junit.Test; + public class MLExceptionTest { @Test diff --git a/common/src/test/java/org/opensearch/ml/common/exception/MLLimitExceededExceptionTest.java b/common/src/test/java/org/opensearch/ml/common/exception/MLLimitExceededExceptionTest.java index a85e07e02d..ed3b6bb180 100644 --- a/common/src/test/java/org/opensearch/ml/common/exception/MLLimitExceededExceptionTest.java +++ b/common/src/test/java/org/opensearch/ml/common/exception/MLLimitExceededExceptionTest.java @@ -5,11 +5,10 @@ package org.opensearch.ml.common.exception; -import org.junit.Test; - import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; + +import org.junit.Test; public class MLLimitExceededExceptionTest { diff --git a/common/src/test/java/org/opensearch/ml/common/exception/MLResourceNotFoundExceptionTest.java b/common/src/test/java/org/opensearch/ml/common/exception/MLResourceNotFoundExceptionTest.java index 9409859e71..7e1c41d297 100644 --- a/common/src/test/java/org/opensearch/ml/common/exception/MLResourceNotFoundExceptionTest.java +++ b/common/src/test/java/org/opensearch/ml/common/exception/MLResourceNotFoundExceptionTest.java @@ -5,11 +5,11 @@ package org.opensearch.ml.common.exception; -import org.junit.Test; - import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import org.junit.Test; + public class MLResourceNotFoundExceptionTest { @Test diff --git a/common/src/test/java/org/opensearch/ml/common/exception/MLValidationExceptionTest.java b/common/src/test/java/org/opensearch/ml/common/exception/MLValidationExceptionTest.java index f9e3a04fdc..eee1038a52 100644 --- a/common/src/test/java/org/opensearch/ml/common/exception/MLValidationExceptionTest.java +++ b/common/src/test/java/org/opensearch/ml/common/exception/MLValidationExceptionTest.java @@ -5,12 +5,11 @@ package org.opensearch.ml.common.exception; -import org.junit.Test; - import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; +import org.junit.Test; + public class MLValidationExceptionTest { @Test diff --git a/common/src/test/java/org/opensearch/ml/common/input/MLInputTest.java b/common/src/test/java/org/opensearch/ml/common/input/MLInputTest.java index cee7956ce8..7607461580 100644 --- a/common/src/test/java/org/opensearch/ml/common/input/MLInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/input/MLInputTest.java @@ -5,15 +5,24 @@ package org.opensearch.ml.common.input; -import lombok.NonNull; +import static org.junit.Assert.*; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.function.Consumer; +import java.util.function.Function; + import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContent; @@ -31,15 +40,7 @@ import org.opensearch.search.SearchModule; import org.opensearch.search.builder.SearchSourceBuilder; -import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.List; -import java.util.function.Consumer; -import java.util.function.Function; - -import static org.junit.Assert.*; +import lombok.NonNull; public class MLInputTest { @@ -66,11 +67,12 @@ public void setUp() throws Exception { rows.add(new Row(new ColumnValue[] { new DoubleValue(2.0) })); rows.add(new Row(new ColumnValue[] { new DoubleValue(3.0) })); DataFrame dataFrame = new DefaultDataFrame(columnMetas, rows); - input = MLInput.builder() - .algorithm(algorithm) - .parameters(LinearRegressionParams.builder().learningRate(0.1).build()) - .inputDataset(DataFrameInputDataset.builder().dataFrame(dataFrame).build()) - .build(); + input = MLInput + .builder() + .algorithm(algorithm) + .parameters(LinearRegressionParams.builder().learningRate(0.1).build()) + .inputDataset(DataFrameInputDataset.builder().dataFrame(dataFrame).build()) + .build(); } @Test @@ -83,11 +85,13 @@ public void constructor_NullAlgorithm() { @Test public void parse_LinearRegression() throws IOException { String indexName = "index1"; - SearchQueryInputDataset inputDataset = SearchQueryInputDataset.builder() - .indices(Arrays.asList(indexName)) - .searchSourceBuilder(new SearchSourceBuilder().query(new MatchAllQueryBuilder()).size(1)) - .build(); - String expectedInputStr = "{\"algorithm\":\"LINEAR_REGRESSION\",\"input_index\":[\"index1\"],\"input_query\":{\"size\":1,\"query\":{\"match_all\":{\"boost\":1.0}}}}"; + SearchQueryInputDataset inputDataset = SearchQueryInputDataset + .builder() + .indices(Arrays.asList(indexName)) + .searchSourceBuilder(new SearchSourceBuilder().query(new MatchAllQueryBuilder()).size(1)) + .build(); + String expectedInputStr = + "{\"algorithm\":\"LINEAR_REGRESSION\",\"input_index\":[\"index1\"],\"input_query\":{\"size\":1,\"query\":{\"match_all\":{\"boost\":1.0}}}}"; testParse(FunctionName.LINEAR_REGRESSION, inputDataset, expectedInputStr, parsedInput -> { assertNotNull(parsedInput.getInputDataset()); assertEquals(1, ((SearchQueryInputDataset) parsedInput.getInputDataset()).getIndices().size()); @@ -96,15 +100,20 @@ public void parse_LinearRegression() throws IOException { @NonNull DataFrame dataFrame = new DefaultDataFrame( - new ColumnMeta[] { ColumnMeta.builder().name("value").columnType(ColumnType.FLOAT).build() }); + new ColumnMeta[] { ColumnMeta.builder().name("value").columnType(ColumnType.FLOAT).build() } + ); dataFrame.appendRow(new Float[] { 1.0f }); DataFrameInputDataset dataFrameInputDataset = DataFrameInputDataset.builder().dataFrame(dataFrame).build(); - expectedInputStr = "{\"algorithm\":\"LINEAR_REGRESSION\",\"input_data\":{\"column_metas\":[{\"name\":\"value\",\"column_type\":\"FLOAT\"}],\"rows\":[{\"values\":[{\"column_type\":\"FLOAT\",\"value\":1.0}]}]}}"; + expectedInputStr = + "{\"algorithm\":\"LINEAR_REGRESSION\",\"input_data\":{\"column_metas\":[{\"name\":\"value\",\"column_type\":\"FLOAT\"}],\"rows\":[{\"values\":[{\"column_type\":\"FLOAT\",\"value\":1.0}]}]}}"; testParse(FunctionName.LINEAR_REGRESSION, dataFrameInputDataset, expectedInputStr, parsedInput -> { assertNotNull(parsedInput.getInputDataset()); assertEquals(1, ((DataFrameInputDataset) parsedInput.getInputDataset()).getDataFrame().size()); - assertEquals(1.0f, ((DataFrameInputDataset) parsedInput.getInputDataset()).getDataFrame().getRow(0) - .getValue(0).floatValue(), 1e-5); + assertEquals( + 1.0f, + ((DataFrameInputDataset) parsedInput.getInputDataset()).getDataFrame().getRow(0).getValue(0).floatValue(), + 1e-5 + ); }); } @@ -112,13 +121,15 @@ private void parse_NLPModel(FunctionName functionName) throws IOException { String sentence = "test sentence"; String column = "column1"; Integer position = 1; - ModelResultFilter resultFilter = ModelResultFilter.builder() - .targetResponse(Arrays.asList(column)) - .targetResponsePositions(Arrays.asList(position)) - .build(); + ModelResultFilter resultFilter = ModelResultFilter + .builder() + .targetResponse(Arrays.asList(column)) + .targetResponsePositions(Arrays.asList(position)) + .build(); TextDocsInputDataSet inputDataset = TextDocsInputDataSet.builder().docs(Arrays.asList(sentence)).resultFilter(resultFilter).build(); - String expectedInputStr = "{\"algorithm\":\"functionName\",\"text_docs\":[\"test sentence\"],\"return_bytes\":false,\"return_number\":false,\"target_response\":[\"column1\"],\"target_response_positions\":[1]}"; + String expectedInputStr = + "{\"algorithm\":\"functionName\",\"text_docs\":[\"test sentence\"],\"return_bytes\":false,\"return_number\":false,\"target_response\":[\"column1\"],\"target_response_positions\":[1]}"; expectedInputStr = expectedInputStr.replace("functionName", functionName.toString()); testParse(functionName, inputDataset, expectedInputStr, parsedInput -> { assertNotNull(parsedInput.getInputDataset()); @@ -150,7 +161,6 @@ private void parse_NLPModel_NullResultFilter(FunctionName functionName) throws I }); } - @Test public void parse_NLPRelated_NullResultFilter() throws IOException { parse_NLPModel_NullResultFilter(FunctionName.TEXT_EMBEDDING); @@ -158,7 +168,8 @@ public void parse_NLPRelated_NullResultFilter() throws IOException { parse_NLPModel_NullResultFilter(FunctionName.SPARSE_ENCODING); } - private void testParse(FunctionName algorithm, MLInputDataset inputDataset, String expectedInputStr, Consumer verify) throws IOException { + private void testParse(FunctionName algorithm, MLInputDataset inputDataset, String expectedInputStr, Consumer verify) + throws IOException { MLInput input = MLInput.builder().inputDataset(inputDataset).algorithm(algorithm).build(); XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); input.toXContent(builder, ToXContent.EMPTY_PARAMS); @@ -166,9 +177,13 @@ private void testParse(FunctionName algorithm, MLInputDataset inputDataset, Stri String jsonStr = builder.toString(); assertEquals(expectedInputStr, jsonStr); - XContentParser parser = XContentType.JSON.xContent() - .createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, - Collections.emptyList()).getNamedXContents()), null, jsonStr); + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + jsonStr + ); parser.nextToken(); MLInput parsedInput = MLInput.parse(parser, algorithm.name()); assertEquals(input.getFunctionName(), parsedInput.getFunctionName()); diff --git a/common/src/test/java/org/opensearch/ml/common/input/execute/anomalylocalization/AnomalyLocalizationInputTests.java b/common/src/test/java/org/opensearch/ml/common/input/execute/anomalylocalization/AnomalyLocalizationInputTests.java index 34bff7c246..34682155da 100644 --- a/common/src/test/java/org/opensearch/ml/common/input/execute/anomalylocalization/AnomalyLocalizationInputTests.java +++ b/common/src/test/java/org/opensearch/ml/common/input/execute/anomalylocalization/AnomalyLocalizationInputTests.java @@ -5,15 +5,20 @@ package org.opensearch.ml.common.input.execute.anomalylocalization; +import static org.junit.Assert.assertEquals; + +import java.util.Arrays; +import java.util.Collections; +import java.util.Optional; + import org.junit.Test; -import org.opensearch.core.common.Strings; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; -import org.opensearch.core.common.io.stream.NamedWriteableRegistry; -import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; @@ -21,67 +26,93 @@ import org.opensearch.search.SearchModule; import org.opensearch.search.aggregations.AggregationBuilders; -import java.util.Arrays; -import java.util.Collections; -import java.util.Optional; - -import static org.junit.Assert.assertEquals; - public class AnomalyLocalizationInputTests { - @Test - public void testXContentFullObject() throws Exception { - AnomalyLocalizationInput input = new AnomalyLocalizationInput("indexName", Arrays.asList("attribute"), - Arrays.asList(AggregationBuilders.max("max").field("field"), - AggregationBuilders.min("min").field("field")), - "@timestamp", 0L, 10L, 1L, 2, Optional.of(3L), - Optional.of(QueryBuilders.matchAllQuery())); - XContentBuilder builder = XContentFactory.jsonBuilder(); - builder = input.toXContent(builder, null); - String json = builder.toString(); + @Test + public void testXContentFullObject() throws Exception { + AnomalyLocalizationInput input = new AnomalyLocalizationInput( + "indexName", + Arrays.asList("attribute"), + Arrays.asList(AggregationBuilders.max("max").field("field"), AggregationBuilders.min("min").field("field")), + "@timestamp", + 0L, + 10L, + 1L, + 2, + Optional.of(3L), + Optional.of(QueryBuilders.matchAllQuery()) + ); + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder = input.toXContent(builder, null); + String json = builder.toString(); - XContentParser parser = XContentType.JSON.xContent() - .createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, - Collections.emptyList()).getNamedXContents()), null, json); - parser.nextToken(); - AnomalyLocalizationInput newInput = AnomalyLocalizationInput.parse(parser); + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + json + ); + parser.nextToken(); + AnomalyLocalizationInput newInput = AnomalyLocalizationInput.parse(parser); - assertEquals(input, newInput); - } + assertEquals(input, newInput); + } - @Test - public void testXContentMissingAnomalyStartFilter() throws Exception { - AnomalyLocalizationInput input = new AnomalyLocalizationInput("indexName", Arrays.asList("attribute"), - Arrays.asList(AggregationBuilders.max("max").field("field")), - "@timestamp", 0L, 10L, 1L, 2, Optional.empty(), Optional.empty()); - XContentBuilder builder = XContentFactory.jsonBuilder(); - builder = input.toXContent(builder, null); - String json = builder.toString(); + @Test + public void testXContentMissingAnomalyStartFilter() throws Exception { + AnomalyLocalizationInput input = new AnomalyLocalizationInput( + "indexName", + Arrays.asList("attribute"), + Arrays.asList(AggregationBuilders.max("max").field("field")), + "@timestamp", + 0L, + 10L, + 1L, + 2, + Optional.empty(), + Optional.empty() + ); + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder = input.toXContent(builder, null); + String json = builder.toString(); - XContentParser parser = XContentType.JSON.xContent() - .createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, - Collections.emptyList()).getNamedXContents()), null, json); - parser.nextToken(); - AnomalyLocalizationInput newInput = AnomalyLocalizationInput.parse(parser); + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + json + ); + parser.nextToken(); + AnomalyLocalizationInput newInput = AnomalyLocalizationInput.parse(parser); - assertEquals(input, newInput); - } + assertEquals(input, newInput); + } - @Test - public void testWriteable() throws Exception { - AnomalyLocalizationInput input = new AnomalyLocalizationInput("indexName", Arrays.asList("attribute"), - Arrays.asList(AggregationBuilders.max("max").field("field"), - AggregationBuilders.min("min").field("field")), - "@timestamp", 0L, 10L, 1L, 2, Optional.of(3L), - Optional.of(QueryBuilders.matchAllQuery())); + @Test + public void testWriteable() throws Exception { + AnomalyLocalizationInput input = new AnomalyLocalizationInput( + "indexName", + Arrays.asList("attribute"), + Arrays.asList(AggregationBuilders.max("max").field("field"), AggregationBuilders.min("min").field("field")), + "@timestamp", + 0L, + 10L, + 1L, + 2, + Optional.of(3L), + Optional.of(QueryBuilders.matchAllQuery()) + ); - BytesStreamOutput out = new BytesStreamOutput(); - input.writeTo(out); - StreamInput in = new NamedWriteableAwareStreamInput(out.bytes().streamInput(), - new NamedWriteableRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()) - .getNamedWriteables())); - AnomalyLocalizationInput newInput = new AnomalyLocalizationInput(in); + BytesStreamOutput out = new BytesStreamOutput(); + input.writeTo(out); + StreamInput in = new NamedWriteableAwareStreamInput( + out.bytes().streamInput(), + new NamedWriteableRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedWriteables()) + ); + AnomalyLocalizationInput newInput = new AnomalyLocalizationInput(in); - assertEquals(input, newInput); - } + assertEquals(input, newInput); + } } diff --git a/common/src/test/java/org/opensearch/ml/common/input/execute/metricscorrelation/MetricsCorrelationInputTest.java b/common/src/test/java/org/opensearch/ml/common/input/execute/metricscorrelation/MetricsCorrelationInputTest.java index 832e07b7e7..c9ad55fc22 100644 --- a/common/src/test/java/org/opensearch/ml/common/input/execute/metricscorrelation/MetricsCorrelationInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/input/execute/metricscorrelation/MetricsCorrelationInputTest.java @@ -5,6 +5,13 @@ package org.opensearch.ml.common.input.execute.metricscorrelation; +import static org.junit.Assert.assertEquals; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.function.Function; + import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -14,13 +21,6 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.TestHelper; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.function.Function; - -import static org.junit.Assert.assertEquals; - public class MetricsCorrelationInputTest { MetricsCorrelationInput input; @@ -39,9 +39,9 @@ public class MetricsCorrelationInputTest { @Before public void setUp() { List inputData = new ArrayList<>(); - inputData.add(new float[]{1.0f, 2.0f, 3.0f, 4.0f}); - inputData.add(new float[]{1.0f, 2.0f, 3.0f, 4.0f}); - inputData.add(new float[]{1.0f, 2.0f, 3.0f, 4.0f}); + inputData.add(new float[] { 1.0f, 2.0f, 3.0f, 4.0f }); + inputData.add(new float[] { 1.0f, 2.0f, 3.0f, 4.0f }); + inputData.add(new float[] { 1.0f, 2.0f, 3.0f, 4.0f }); input = MetricsCorrelationInput.builder().inputData(inputData).build(); } @@ -57,9 +57,9 @@ public void constructor_variableLengthInput() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("All the input metrics sizes should be same"); List inputData = new ArrayList<>(); - inputData.add(new float[]{1.0f, 2.0f, 3.0f, 4.0f}); - inputData.add(new float[]{1.0f, 2.0f, 3.0f}); - inputData.add(new float[]{1.0f, 2.0f, 3.0f, 4.0f}); + inputData.add(new float[] { 1.0f, 2.0f, 3.0f, 4.0f }); + inputData.add(new float[] { 1.0f, 2.0f, 3.0f }); + inputData.add(new float[] { 1.0f, 2.0f, 3.0f, 4.0f }); MetricsCorrelationInput.builder().inputData(inputData).build(); } diff --git a/common/src/test/java/org/opensearch/ml/common/input/execute/samplecalculator/LocalSampleCalculatorInputTest.java b/common/src/test/java/org/opensearch/ml/common/input/execute/samplecalculator/LocalSampleCalculatorInputTest.java index fec5c99ced..19535c1d7a 100644 --- a/common/src/test/java/org/opensearch/ml/common/input/execute/samplecalculator/LocalSampleCalculatorInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/input/execute/samplecalculator/LocalSampleCalculatorInputTest.java @@ -5,6 +5,13 @@ package org.opensearch.ml.common.input.execute.samplecalculator; +import static org.junit.Assert.assertEquals; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.function.Function; + import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -14,13 +21,6 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.TestHelper; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.function.Function; - -import static org.junit.Assert.assertEquals; - public class LocalSampleCalculatorInputTest { LocalSampleCalculatorInput input; @@ -42,10 +42,7 @@ public void setUp() { inputData.add(1.0); inputData.add(2.0); inputData.add(3.0); - input = LocalSampleCalculatorInput.builder() - .operation("sum") - .inputData(inputData) - .build(); + input = LocalSampleCalculatorInput.builder().operation("sum").inputData(inputData).build(); } @Test diff --git a/common/src/test/java/org/opensearch/ml/common/input/nlp/TextDocsMLInputTest.java b/common/src/test/java/org/opensearch/ml/common/input/nlp/TextDocsMLInputTest.java index 10468d44a3..1c652b4162 100644 --- a/common/src/test/java/org/opensearch/ml/common/input/nlp/TextDocsMLInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/input/nlp/TextDocsMLInputTest.java @@ -1,5 +1,14 @@ package org.opensearch.ml.common.input.nlp; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; + import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -18,15 +27,6 @@ import org.opensearch.ml.common.output.model.ModelResultFilter; import org.opensearch.search.SearchModule; -import java.io.IOException; -import java.util.Arrays; -import java.util.Collections; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertTrue; - public class TextDocsMLInputTest { MLInput input; @@ -38,10 +38,14 @@ public class TextDocsMLInputTest { @Before public void setUp() throws Exception { - ModelResultFilter resultFilter = ModelResultFilter.builder().returnBytes(true).returnNumber(true) - .targetResponse(Arrays.asList("field1")).targetResponsePositions(Arrays.asList(2)).build(); - MLInputDataset inputDataset = TextDocsInputDataSet.builder().docs(Arrays.asList("doc1", "doc2")) - .resultFilter(resultFilter).build(); + ModelResultFilter resultFilter = ModelResultFilter + .builder() + .returnBytes(true) + .returnNumber(true) + .targetResponse(Arrays.asList("field1")) + .targetResponsePositions(Arrays.asList(2)) + .build(); + MLInputDataset inputDataset = TextDocsInputDataSet.builder().docs(Arrays.asList("doc1", "doc2")).resultFilter(resultFilter).build(); input = new TextDocsMLInput(algorithm, inputDataset); } @@ -56,20 +60,26 @@ public void parseTextDocsMLInput() throws IOException { @Test public void parseTextDocsMLInput_OldWay() throws IOException { - String jsonStr = "{\"text_docs\": [ \"doc1\", \"doc2\", null ],\"return_number\": true, \"return_bytes\": true,\"target_response\": [ \"field1\" ], \"target_response_positions\": [2]}"; + String jsonStr = + "{\"text_docs\": [ \"doc1\", \"doc2\", null ],\"return_number\": true, \"return_bytes\": true,\"target_response\": [ \"field1\" ], \"target_response_positions\": [2]}"; parseMLInput(jsonStr, 3); } @Test public void parseTextDocsMLInput_NewWay() throws IOException { - String jsonStr = "{\"text_docs\":[\"doc1\",\"doc2\"],\"result_filter\":{\"return_bytes\":true,\"return_number\":true,\"target_response\":[\"field1\"], \"target_response_positions\": [2]}}"; + String jsonStr = + "{\"text_docs\":[\"doc1\",\"doc2\"],\"result_filter\":{\"return_bytes\":true,\"return_number\":true,\"target_response\":[\"field1\"], \"target_response_positions\": [2]}}"; parseMLInput(jsonStr, 2); } private void parseMLInput(String jsonStr, int docSize) throws IOException { - XContentParser parser = XContentType.JSON.xContent() - .createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, - Collections.emptyList()).getNamedXContents()), null, jsonStr); + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + jsonStr + ); parser.nextToken(); MLInput parsedInput = MLInput.parse(parser, input.getFunctionName().name()); diff --git a/common/src/test/java/org/opensearch/ml/common/input/parameter/ad/AnomalyDetectionLibSVMParamsTest.java b/common/src/test/java/org/opensearch/ml/common/input/parameter/ad/AnomalyDetectionLibSVMParamsTest.java index 544f43ec1e..0e87cf62d4 100644 --- a/common/src/test/java/org/opensearch/ml/common/input/parameter/ad/AnomalyDetectionLibSVMParamsTest.java +++ b/common/src/test/java/org/opensearch/ml/common/input/parameter/ad/AnomalyDetectionLibSVMParamsTest.java @@ -5,6 +5,11 @@ package org.opensearch.ml.common.input.parameter.ad; +import static org.junit.Assert.assertEquals; + +import java.io.IOException; +import java.util.function.Function; + import org.junit.Before; import org.junit.Test; import org.opensearch.common.io.stream.BytesStreamOutput; @@ -12,11 +17,6 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.TestHelper; -import java.io.IOException; -import java.util.function.Function; - -import static org.junit.Assert.assertEquals; - public class AnomalyDetectionLibSVMParamsTest { AnomalyDetectionLibSVMParams params; @@ -30,15 +30,16 @@ public class AnomalyDetectionLibSVMParamsTest { @Before public void setUp() { - params = AnomalyDetectionLibSVMParams.builder() - .kernelType(AnomalyDetectionLibSVMParams.ADKernelType.POLY) - .gamma(1.0) - .nu(0.5) - .cost(1.0) - .coeff(0.1) - .epsilon(0.2) - .degree(2) - .build(); + params = AnomalyDetectionLibSVMParams + .builder() + .kernelType(AnomalyDetectionLibSVMParams.ADKernelType.POLY) + .gamma(1.0) + .nu(0.5) + .cost(1.0) + .coeff(0.1) + .epsilon(0.2) + .degree(2) + .build(); } @Test diff --git a/common/src/test/java/org/opensearch/ml/common/input/parameter/clustering/KMeansParamsTest.java b/common/src/test/java/org/opensearch/ml/common/input/parameter/clustering/KMeansParamsTest.java index b4cd2f0c81..a9886a0d3f 100644 --- a/common/src/test/java/org/opensearch/ml/common/input/parameter/clustering/KMeansParamsTest.java +++ b/common/src/test/java/org/opensearch/ml/common/input/parameter/clustering/KMeansParamsTest.java @@ -5,6 +5,13 @@ package org.opensearch.ml.common.input.parameter.clustering; +import static org.junit.Assert.assertEquals; +import static org.opensearch.ml.common.TestHelper.contentObjectToString; +import static org.opensearch.ml.common.TestHelper.testParseFromString; + +import java.io.IOException; +import java.util.function.Function; + import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -14,13 +21,6 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.TestHelper; -import java.io.IOException; -import java.util.function.Function; - -import static org.junit.Assert.assertEquals; -import static org.opensearch.ml.common.TestHelper.contentObjectToString; -import static org.opensearch.ml.common.TestHelper.testParseFromString; - public class KMeansParamsTest { @Rule public ExpectedException exceptionRule = ExpectedException.none(); @@ -28,7 +28,7 @@ public class KMeansParamsTest { KMeansParams params; private Function function = parser -> { try { - return (KMeansParams)KMeansParams.parse(parser); + return (KMeansParams) KMeansParams.parse(parser); } catch (IOException e) { throw new RuntimeException("failed to parse KMeansParams", e); } @@ -36,11 +36,7 @@ public class KMeansParamsTest { @Before public void setUp() { - params = KMeansParams.builder() - .centroids(2) - .iterations(10) - .distanceType(KMeansParams.DistanceType.COSINE) - .build(); + params = KMeansParams.builder().centroids(2).iterations(10).distanceType(KMeansParams.DistanceType.COSINE).build(); } @Test diff --git a/common/src/test/java/org/opensearch/ml/common/input/parameter/clustering/RCFSummarizeParamsTest.java b/common/src/test/java/org/opensearch/ml/common/input/parameter/clustering/RCFSummarizeParamsTest.java index 0a29ebfffa..e6cbba3722 100644 --- a/common/src/test/java/org/opensearch/ml/common/input/parameter/clustering/RCFSummarizeParamsTest.java +++ b/common/src/test/java/org/opensearch/ml/common/input/parameter/clustering/RCFSummarizeParamsTest.java @@ -5,6 +5,13 @@ package org.opensearch.ml.common.input.parameter.clustering; +import static org.junit.Assert.assertEquals; +import static org.opensearch.ml.common.TestHelper.contentObjectToString; +import static org.opensearch.ml.common.TestHelper.testParseFromString; + +import java.io.IOException; +import java.util.function.Function; + import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -14,13 +21,6 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.TestHelper; -import java.io.IOException; -import java.util.function.Function; - -import static org.junit.Assert.assertEquals; -import static org.opensearch.ml.common.TestHelper.contentObjectToString; -import static org.opensearch.ml.common.TestHelper.testParseFromString; - public class RCFSummarizeParamsTest { @Rule public ExpectedException exceptionRule = ExpectedException.none(); @@ -28,7 +28,7 @@ public class RCFSummarizeParamsTest { RCFSummarizeParams params; private Function function = parser -> { try { - return (RCFSummarizeParams)RCFSummarizeParams.parse(parser); + return (RCFSummarizeParams) RCFSummarizeParams.parse(parser); } catch (IOException e) { throw new RuntimeException("failed to parse RCFSummarizeParams", e); } @@ -36,11 +36,7 @@ public class RCFSummarizeParamsTest { @Before public void setUp() { - params = RCFSummarizeParams.builder() - .maxK(2) - .initialK(10) - .distanceType(RCFSummarizeParams.DistanceType.L1) - .build(); + params = RCFSummarizeParams.builder().maxK(2).initialK(10).distanceType(RCFSummarizeParams.DistanceType.L1).build(); } @Test diff --git a/common/src/test/java/org/opensearch/ml/common/input/parameter/rcf/BatchRCFParamsTest.java b/common/src/test/java/org/opensearch/ml/common/input/parameter/rcf/BatchRCFParamsTest.java index 7137763146..e9e2737174 100644 --- a/common/src/test/java/org/opensearch/ml/common/input/parameter/rcf/BatchRCFParamsTest.java +++ b/common/src/test/java/org/opensearch/ml/common/input/parameter/rcf/BatchRCFParamsTest.java @@ -5,6 +5,11 @@ package org.opensearch.ml.common.input.parameter.rcf; +import static org.junit.Assert.assertEquals; + +import java.io.IOException; +import java.util.function.Function; + import org.junit.Before; import org.junit.Test; import org.opensearch.common.io.stream.BytesStreamOutput; @@ -12,11 +17,6 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.TestHelper; -import java.io.IOException; -import java.util.function.Function; - -import static org.junit.Assert.assertEquals; - public class BatchRCFParamsTest { BatchRCFParams params; @@ -30,13 +30,7 @@ public class BatchRCFParamsTest { @Before public void setUp() { - params = BatchRCFParams.builder() - .numberOfTrees(10) - .shingleSize(8) - .sampleSize(256) - .outputAfter(32) - .trainingDataSize(200) - .build(); + params = BatchRCFParams.builder().numberOfTrees(10).shingleSize(8).sampleSize(256).outputAfter(32).trainingDataSize(200).build(); } @Test diff --git a/common/src/test/java/org/opensearch/ml/common/input/parameter/rcf/FitRCFParamsTest.java b/common/src/test/java/org/opensearch/ml/common/input/parameter/rcf/FitRCFParamsTest.java index e87973a98a..e2e7050a72 100644 --- a/common/src/test/java/org/opensearch/ml/common/input/parameter/rcf/FitRCFParamsTest.java +++ b/common/src/test/java/org/opensearch/ml/common/input/parameter/rcf/FitRCFParamsTest.java @@ -5,6 +5,11 @@ package org.opensearch.ml.common.input.parameter.rcf; +import static org.junit.Assert.assertEquals; + +import java.io.IOException; +import java.util.function.Function; + import org.junit.Before; import org.junit.Test; import org.opensearch.common.io.stream.BytesStreamOutput; @@ -12,11 +17,6 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.TestHelper; -import java.io.IOException; -import java.util.function.Function; - -import static org.junit.Assert.assertEquals; - public class FitRCFParamsTest { FitRCFParams params; @@ -30,17 +30,18 @@ public class FitRCFParamsTest { @Before public void setUp() { - params = FitRCFParams.builder() - .numberOfTrees(10) - .shingleSize(8) - .sampleSize(256) - .outputAfter(32) - .timeDecay(0.001) - .anomalyRate(0.005) - .timeField("timestamp") - .dateFormat("yyyy-mm-dd") - .timeZone("UTC") - .build(); + params = FitRCFParams + .builder() + .numberOfTrees(10) + .shingleSize(8) + .sampleSize(256) + .outputAfter(32) + .timeDecay(0.001) + .anomalyRate(0.005) + .timeField("timestamp") + .dateFormat("yyyy-mm-dd") + .timeZone("UTC") + .build(); } @Test diff --git a/common/src/test/java/org/opensearch/ml/common/input/parameter/regression/LinearRegressionParamsTest.java b/common/src/test/java/org/opensearch/ml/common/input/parameter/regression/LinearRegressionParamsTest.java index be71883e92..dc9d92d0cb 100644 --- a/common/src/test/java/org/opensearch/ml/common/input/parameter/regression/LinearRegressionParamsTest.java +++ b/common/src/test/java/org/opensearch/ml/common/input/parameter/regression/LinearRegressionParamsTest.java @@ -5,6 +5,13 @@ package org.opensearch.ml.common.input.parameter.regression; +import static org.junit.Assert.assertEquals; +import static org.opensearch.ml.common.TestHelper.contentObjectToString; +import static org.opensearch.ml.common.TestHelper.testParseFromString; + +import java.io.IOException; +import java.util.function.Function; + import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -14,13 +21,6 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.TestHelper; -import java.io.IOException; -import java.util.function.Function; - -import static org.junit.Assert.assertEquals; -import static org.opensearch.ml.common.TestHelper.contentObjectToString; -import static org.opensearch.ml.common.TestHelper.testParseFromString; - public class LinearRegressionParamsTest { @Rule @@ -39,21 +39,21 @@ public class LinearRegressionParamsTest { @Before public void setUp() { params = LinearRegressionParams - .builder() - .objectiveType(LinearRegressionParams.ObjectiveType.ABSOLUTE_LOSS) - .optimizerType(LinearRegressionParams.OptimizerType.ADAM) - .learningRate(0.1) - .momentumType(LinearRegressionParams.MomentumType.NESTEROV) - .momentumFactor(0.2) - .epsilon(0.3) - .beta1(0.4) - .beta2(0.5) - .decayRate(0.6) - .epochs(1) - .batchSize(2) - .seed(3L) - .target("test_target") - .build(); + .builder() + .objectiveType(LinearRegressionParams.ObjectiveType.ABSOLUTE_LOSS) + .optimizerType(LinearRegressionParams.OptimizerType.ADAM) + .learningRate(0.1) + .momentumType(LinearRegressionParams.MomentumType.NESTEROV) + .momentumFactor(0.2) + .epsilon(0.3) + .beta1(0.4) + .beta2(0.5) + .decayRate(0.6) + .epochs(1) + .batchSize(2) + .seed(3L) + .target("test_target") + .build(); } @Test @@ -69,21 +69,21 @@ public void readInputStream_Success() throws IOException { @Test public void parse_PassIntValueToDoubleField() throws IOException { LinearRegressionParams params = LinearRegressionParams - .builder() - .objectiveType(LinearRegressionParams.ObjectiveType.ABSOLUTE_LOSS) - .optimizerType(LinearRegressionParams.OptimizerType.ADAM) - .learningRate(0.1) - .momentumType(LinearRegressionParams.MomentumType.NESTEROV) - .momentumFactor(0.2) - .epsilon(3.0) - .beta1(0.4) - .beta2(0.5) - .decayRate(0.6) - .epochs(1) - .batchSize(2) - .seed(3L) - .target("test_target") - .build(); + .builder() + .objectiveType(LinearRegressionParams.ObjectiveType.ABSOLUTE_LOSS) + .optimizerType(LinearRegressionParams.OptimizerType.ADAM) + .learningRate(0.1) + .momentumType(LinearRegressionParams.MomentumType.NESTEROV) + .momentumFactor(0.2) + .epsilon(3.0) + .beta1(0.4) + .beta2(0.5) + .decayRate(0.6) + .epochs(1) + .batchSize(2) + .seed(3L) + .target("test_target") + .build(); String paramsStr = contentObjectToString(params); testParseFromString(params, paramsStr.replace("\"epsilon\":3.0,", "\"epsilon\":3,"), function); } diff --git a/common/src/test/java/org/opensearch/ml/common/input/parameter/regression/LogisticRegressionParamsTest.java b/common/src/test/java/org/opensearch/ml/common/input/parameter/regression/LogisticRegressionParamsTest.java index ab45c9e41e..12b5a7f5cb 100644 --- a/common/src/test/java/org/opensearch/ml/common/input/parameter/regression/LogisticRegressionParamsTest.java +++ b/common/src/test/java/org/opensearch/ml/common/input/parameter/regression/LogisticRegressionParamsTest.java @@ -5,6 +5,14 @@ package org.opensearch.ml.common.input.parameter.regression; +import static org.junit.Assert.assertEquals; +import static org.opensearch.ml.common.TestHelper.contentObjectToString; +import static org.opensearch.ml.common.TestHelper.testParseFromString; +import static org.opensearch.ml.common.input.parameter.regression.LogisticRegressionParams.PARSE_FIELD_NAME; + +import java.io.IOException; +import java.util.function.Function; + import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -14,14 +22,6 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.TestHelper; -import java.io.IOException; -import java.util.function.Function; - -import static org.junit.Assert.assertEquals; -import static org.opensearch.ml.common.TestHelper.contentObjectToString; -import static org.opensearch.ml.common.TestHelper.testParseFromString; -import static org.opensearch.ml.common.input.parameter.regression.LogisticRegressionParams.PARSE_FIELD_NAME; - public class LogisticRegressionParamsTest { @Rule @@ -40,21 +40,21 @@ public class LogisticRegressionParamsTest { @Before public void setUp() { logisticRegressionParams = LogisticRegressionParams - .builder() - .objectiveType(LogisticRegressionParams.ObjectiveType.LOGMULTICLASS) - .optimizerType(LogisticRegressionParams.OptimizerType.ADA_GRAD) - .learningRate(0.1) - .momentumType(LogisticRegressionParams.MomentumType.STANDARD) - .momentumFactor(0.2) - .epsilon(0.3) - .beta1(0.4) - .beta2(0.5) - .decayRate(0.6) - .epochs(1) - .batchSize(2) - .seed(3L) - .target("test_target") - .build(); + .builder() + .objectiveType(LogisticRegressionParams.ObjectiveType.LOGMULTICLASS) + .optimizerType(LogisticRegressionParams.OptimizerType.ADA_GRAD) + .learningRate(0.1) + .momentumType(LogisticRegressionParams.MomentumType.STANDARD) + .momentumFactor(0.2) + .epsilon(0.3) + .beta1(0.4) + .beta2(0.5) + .decayRate(0.6) + .epochs(1) + .batchSize(2) + .seed(3L) + .target("test_target") + .build(); } @Test @@ -114,8 +114,12 @@ public void parse_EmptyLogisticRegressionParams() throws IOException { @Test public void parse_LogisticRegressionParams_WrongExtraField() throws IOException { - TestHelper.testParseFromString(logisticRegressionParams, "{\"objective\":\"LOGMULTICLASS\",\"learning_rate\":0.1,\"wrong_field\":1.0}", function); + TestHelper + .testParseFromString( + logisticRegressionParams, + "{\"objective\":\"LOGMULTICLASS\",\"learning_rate\":0.1,\"wrong_field\":1.0}", + function + ); } } - diff --git a/common/src/test/java/org/opensearch/ml/common/input/parameter/sample/SampleAlgoParamsTest.java b/common/src/test/java/org/opensearch/ml/common/input/parameter/sample/SampleAlgoParamsTest.java index 2ad3fcca39..9dbc858e36 100644 --- a/common/src/test/java/org/opensearch/ml/common/input/parameter/sample/SampleAlgoParamsTest.java +++ b/common/src/test/java/org/opensearch/ml/common/input/parameter/sample/SampleAlgoParamsTest.java @@ -5,21 +5,21 @@ package org.opensearch.ml.common.input.parameter.sample; +import java.io.IOException; +import java.util.function.Function; + import org.junit.Before; import org.junit.Test; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.TestHelper; -import java.io.IOException; -import java.util.function.Function; - public class SampleAlgoParamsTest { SampleAlgoParams params; private Function function = parser -> { try { - return (SampleAlgoParams)SampleAlgoParams.parse(parser); + return (SampleAlgoParams) SampleAlgoParams.parse(parser); } catch (IOException e) { throw new RuntimeException("failed to parse SampleAlgoParams", e); } @@ -27,9 +27,7 @@ public class SampleAlgoParamsTest { @Before public void setUp() { - params = SampleAlgoParams.builder() - .sampleParam(2) - .build(); + params = SampleAlgoParams.builder().sampleParam(2).build(); } @Test diff --git a/common/src/test/java/org/opensearch/ml/common/input/remote/RemoteInferenceMLInputTest.java b/common/src/test/java/org/opensearch/ml/common/input/remote/RemoteInferenceMLInputTest.java index a01a955e7a..5ad201d9e1 100644 --- a/common/src/test/java/org/opensearch/ml/common/input/remote/RemoteInferenceMLInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/input/remote/RemoteInferenceMLInputTest.java @@ -1,5 +1,8 @@ package org.opensearch.ml.common.input.remote; +import java.io.IOException; +import java.util.Collections; + import org.junit.Assert; import org.junit.Test; import org.opensearch.common.io.stream.BytesStreamOutput; @@ -12,9 +15,6 @@ import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.search.SearchModule; -import java.io.IOException; -import java.util.Collections; - public class RemoteInferenceMLInputTest { @Test @@ -22,7 +22,7 @@ public void constructor_parser() throws IOException { RemoteInferenceMLInput input = createRemoteInferenceMLInput(); Assert.assertNotNull(input.getInputDataset()); Assert.assertEquals(MLInputDataType.REMOTE, input.getInputDataset().getInputDataType()); - RemoteInferenceInputDataSet inputDataSet = (RemoteInferenceInputDataSet)input.getInputDataset(); + RemoteInferenceInputDataSet inputDataSet = (RemoteInferenceInputDataSet) input.getInputDataset(); Assert.assertEquals(1, inputDataSet.getParameters().size()); Assert.assertEquals("hello world", inputDataSet.getParameters().get("prompt")); } @@ -36,15 +36,20 @@ public void constructor_stream() throws IOException { RemoteInferenceMLInput input = new RemoteInferenceMLInput(output.bytes().streamInput()); Assert.assertNotNull(input.getInputDataset()); Assert.assertEquals(MLInputDataType.REMOTE, input.getInputDataset().getInputDataType()); - RemoteInferenceInputDataSet inputDataSet = (RemoteInferenceInputDataSet)input.getInputDataset(); + RemoteInferenceInputDataSet inputDataSet = (RemoteInferenceInputDataSet) input.getInputDataset(); Assert.assertEquals(1, inputDataSet.getParameters().size()); Assert.assertEquals("hello world", inputDataSet.getParameters().get("prompt")); } private static RemoteInferenceMLInput createRemoteInferenceMLInput() throws IOException { String jsonStr = "{ \"parameters\": { \"prompt\": \"hello world\" } }"; - XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, - Collections.emptyList()).getNamedXContents()), null, jsonStr); + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + jsonStr + ); parser.nextToken(); RemoteInferenceMLInput input = new RemoteInferenceMLInput(parser, FunctionName.REMOTE); return input; diff --git a/common/src/test/java/org/opensearch/ml/common/model/MLModelFormatTests.java b/common/src/test/java/org/opensearch/ml/common/model/MLModelFormatTests.java index 8bdf0564e2..ee14189592 100644 --- a/common/src/test/java/org/opensearch/ml/common/model/MLModelFormatTests.java +++ b/common/src/test/java/org/opensearch/ml/common/model/MLModelFormatTests.java @@ -5,12 +5,12 @@ package org.opensearch.ml.common.model; +import static org.junit.Assert.assertEquals; + import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; -import static org.junit.Assert.assertEquals; - public class MLModelFormatTests { @Rule diff --git a/common/src/test/java/org/opensearch/ml/common/model/MLModelStateTests.java b/common/src/test/java/org/opensearch/ml/common/model/MLModelStateTests.java index c4f8e7e51f..713f793ddf 100644 --- a/common/src/test/java/org/opensearch/ml/common/model/MLModelStateTests.java +++ b/common/src/test/java/org/opensearch/ml/common/model/MLModelStateTests.java @@ -5,12 +5,11 @@ package org.opensearch.ml.common.model; +import static org.junit.Assert.assertEquals; + import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; -import org.opensearch.ml.common.CommonValue; - -import static org.junit.Assert.assertEquals; public class MLModelStateTests { diff --git a/common/src/test/java/org/opensearch/ml/common/model/MetricsCorrelationModelConfigTests.java b/common/src/test/java/org/opensearch/ml/common/model/MetricsCorrelationModelConfigTests.java index 4700039939..c115c9d1d7 100644 --- a/common/src/test/java/org/opensearch/ml/common/model/MetricsCorrelationModelConfigTests.java +++ b/common/src/test/java/org/opensearch/ml/common/model/MetricsCorrelationModelConfigTests.java @@ -5,6 +5,12 @@ package org.opensearch.ml.common.model; +import static org.junit.Assert.assertEquals; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; + +import java.io.IOException; +import java.util.function.Function; + import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -16,12 +22,6 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.TestHelper; -import java.io.IOException; -import java.util.function.Function; - -import static org.junit.Assert.assertEquals; -import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; - public class MetricsCorrelationModelConfigTests { MetricsCorrelationModelConfig config; @@ -31,10 +31,11 @@ public class MetricsCorrelationModelConfigTests { @Before public void setUp() { - config = MetricsCorrelationModelConfig.builder() - .modelType("testModelType") - .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") - .build(); + config = MetricsCorrelationModelConfig + .builder() + .modelType("testModelType") + .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") + .build(); function = parser -> { try { return MetricsCorrelationModelConfig.parse(parser); @@ -49,20 +50,23 @@ public void toXContent() throws IOException { XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); config.toXContent(builder, EMPTY_PARAMS); String configContent = TestHelper.xContentBuilderToString(builder); - assertEquals("{\"model_type\":\"testModelType\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"}", configContent); + assertEquals( + "{\"model_type\":\"testModelType\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"}", + configContent + ); } @Test public void nullFields_ModelType() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("model type is null"); - config = MetricsCorrelationModelConfig.builder() - .build(); + config = MetricsCorrelationModelConfig.builder().build(); } @Test public void parse() throws IOException { - String content = "{\"wrong_field\":\"test_value\", \"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"}"; + String content = + "{\"wrong_field\":\"test_value\", \"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"}"; TestHelper.testParseFromString(config, content, function); } diff --git a/common/src/test/java/org/opensearch/ml/common/model/TextEmbeddingModelConfigTests.java b/common/src/test/java/org/opensearch/ml/common/model/TextEmbeddingModelConfigTests.java index acba744ced..876aafb58b 100644 --- a/common/src/test/java/org/opensearch/ml/common/model/TextEmbeddingModelConfigTests.java +++ b/common/src/test/java/org/opensearch/ml/common/model/TextEmbeddingModelConfigTests.java @@ -5,23 +5,23 @@ package org.opensearch.ml.common.model; +import static org.junit.Assert.assertEquals; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; + +import java.io.IOException; +import java.util.function.Function; + import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.TestHelper; -import java.io.IOException; -import java.util.function.Function; - -import static org.junit.Assert.assertEquals; -import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; - public class TextEmbeddingModelConfigTests { TextEmbeddingModelConfig config; @@ -31,12 +31,13 @@ public class TextEmbeddingModelConfigTests { @Before public void setUp() { - config = TextEmbeddingModelConfig.builder() - .modelType("testModelType") - .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") - .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) - .embeddingDimension(100) - .build(); + config = TextEmbeddingModelConfig + .builder() + .modelType("testModelType") + .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") + .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) + .embeddingDimension(100) + .build(); function = parser -> { try { return TextEmbeddingModelConfig.parse(parser); @@ -51,39 +52,37 @@ public void toXContent() throws IOException { XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); config.toXContent(builder, EMPTY_PARAMS); String configContent = TestHelper.xContentBuilderToString(builder); - assertEquals("{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"}", configContent); + assertEquals( + "{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"}", + configContent + ); } @Test public void nullFields_ModelType() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("model type is null"); - config = TextEmbeddingModelConfig.builder() - .build(); + config = TextEmbeddingModelConfig.builder().build(); } - @Test public void nullFields_EmbeddingDimension() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("embedding dimension is null"); - config = TextEmbeddingModelConfig.builder().modelType("testModelType") - .build(); + config = TextEmbeddingModelConfig.builder().modelType("testModelType").build(); } @Test public void nullFields_FrameworkType() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("framework type is null"); - config = TextEmbeddingModelConfig.builder() - .modelType("testModelType") - .embeddingDimension(100) - .build(); + config = TextEmbeddingModelConfig.builder().modelType("testModelType").embeddingDimension(100).build(); } @Test public void parse() throws IOException { - String content = "{\"wrong_field\":\"test_value\", \"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"}"; + String content = + "{\"wrong_field\":\"test_value\", \"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"}"; TestHelper.testParseFromString(config, content, function); } diff --git a/common/src/test/java/org/opensearch/ml/common/output/MLPredictionOutputTest.java b/common/src/test/java/org/opensearch/ml/common/output/MLPredictionOutputTest.java index 612ea3c104..857e92f5a3 100644 --- a/common/src/test/java/org/opensearch/ml/common/output/MLPredictionOutputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/output/MLPredictionOutputTest.java @@ -5,13 +5,17 @@ package org.opensearch.ml.common.output; +import static org.junit.Assert.assertEquals; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + import org.junit.Before; import org.junit.Test; -import org.opensearch.core.common.Strings; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; @@ -23,12 +27,6 @@ import org.opensearch.ml.common.dataframe.IntValue; import org.opensearch.ml.common.dataframe.Row; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; - -import static org.junit.Assert.assertEquals; - public class MLPredictionOutputTest { MLPredictionOutput output; @@ -40,11 +38,7 @@ public void setUp() { rows.add(new Row(new ColumnValue[] { new IntValue(1) })); rows.add(new Row(new ColumnValue[] { new IntValue(2) })); DataFrame dataFrame = new DefaultDataFrame(columnMetas, rows); - output = MLPredictionOutput.builder() - .taskId("test_task_id") - .status("test_status") - .predictionResult(dataFrame) - .build(); + output = MLPredictionOutput.builder().taskId("test_task_id").status("test_status").predictionResult(dataFrame).build(); } @Test @@ -52,10 +46,13 @@ public void toXContent() throws IOException { XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); output.toXContent(builder, ToXContent.EMPTY_PARAMS); String jsonStr = builder.toString(); - assertEquals("{\"task_id\":\"test_task_id\",\"status\":\"test_status\",\"prediction_result\":" + - "{\"column_metas\":[{\"name\":\"test\",\"column_type\":\"INTEGER\"}],\"rows\":[{\"values\":" + - "[{\"column_type\":\"INTEGER\",\"value\":1}]},{\"values\":[{\"column_type\":\"INTEGER\"," + - "\"value\":2}]}]}}", jsonStr); + assertEquals( + "{\"task_id\":\"test_task_id\",\"status\":\"test_status\",\"prediction_result\":" + + "{\"column_metas\":[{\"name\":\"test\",\"column_type\":\"INTEGER\"}],\"rows\":[{\"values\":" + + "[{\"column_type\":\"INTEGER\",\"value\":1}]},{\"values\":[{\"column_type\":\"INTEGER\"," + + "\"value\":2}]}]}}", + jsonStr + ); } @Test diff --git a/common/src/test/java/org/opensearch/ml/common/output/MLTrainingOutputTest.java b/common/src/test/java/org/opensearch/ml/common/output/MLTrainingOutputTest.java index 01b997adfa..1c3706adf6 100644 --- a/common/src/test/java/org/opensearch/ml/common/output/MLTrainingOutputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/output/MLTrainingOutputTest.java @@ -5,21 +5,21 @@ package org.opensearch.ml.common.output; +import static org.junit.Assert.assertEquals; + +import java.io.IOException; + import org.junit.Test; import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; -import java.io.IOException; -import static org.junit.Assert.assertEquals; - public class MLTrainingOutputTest { @Test public void parse_MLTrain() throws IOException { - MLTrainingOutput output = MLTrainingOutput.builder() - .modelId("test_modelId").status("test_status").build(); + MLTrainingOutput output = MLTrainingOutput.builder().modelId("test_modelId").status("test_status").build(); XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); output.toXContent(builder, ToXContent.EMPTY_PARAMS); String jsonStr = builder.toString(); diff --git a/common/src/test/java/org/opensearch/ml/common/output/execute/anomalylocalization/AnomalyLocalizationOutputTests.java b/common/src/test/java/org/opensearch/ml/common/output/execute/anomalylocalization/AnomalyLocalizationOutputTests.java index d14508ea43..5be23ca9dc 100644 --- a/common/src/test/java/org/opensearch/ml/common/output/execute/anomalylocalization/AnomalyLocalizationOutputTests.java +++ b/common/src/test/java/org/opensearch/ml/common/output/execute/anomalylocalization/AnomalyLocalizationOutputTests.java @@ -5,6 +5,8 @@ package org.opensearch.ml.common.output.execute.anomalylocalization; +import static org.junit.Assert.assertEquals; + import java.util.Arrays; import org.junit.Before; @@ -16,8 +18,6 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import static org.junit.Assert.assertEquals; - public class AnomalyLocalizationOutputTests { private AnomalyLocalizationOutput output; diff --git a/common/src/test/java/org/opensearch/ml/common/output/execute/metricscorrelation/MCorrModelTensorTest.java b/common/src/test/java/org/opensearch/ml/common/output/execute/metricscorrelation/MCorrModelTensorTest.java index e600fa3325..6f41fc963a 100644 --- a/common/src/test/java/org/opensearch/ml/common/output/execute/metricscorrelation/MCorrModelTensorTest.java +++ b/common/src/test/java/org/opensearch/ml/common/output/execute/metricscorrelation/MCorrModelTensorTest.java @@ -5,24 +5,21 @@ package org.opensearch.ml.common.output.execute.metricscorrelation; +import static org.junit.Assert.assertEquals; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; + +import java.io.IOException; + import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.TestHelper; import org.opensearch.ml.common.output.execute.metrics_correlation.MCorrModelTensor; -import org.opensearch.ml.common.output.model.MLResultDataType; -import org.opensearch.ml.common.output.model.ModelTensor; - -import java.io.IOException; -import java.nio.ByteBuffer; - -import static org.junit.Assert.assertEquals; -import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; public class MCorrModelTensorTest { @@ -33,11 +30,12 @@ public class MCorrModelTensorTest { @Before public void setUp() { - mCorrModelTensor = MCorrModelTensor.builder() - .event_pattern(new float[]{1.0f, 2.0f, 3.0f}) - .event_window(new float[]{4.0f, 5.0f, 6.0f}) - .suspected_metrics(new long[]{1, 2}) - .build(); + mCorrModelTensor = MCorrModelTensor + .builder() + .event_pattern(new float[] { 1.0f, 2.0f, 3.0f }) + .event_window(new float[] { 4.0f, 5.0f, 6.0f }) + .suspected_metrics(new long[] { 1, 2 }) + .build(); } @Test @@ -78,4 +76,3 @@ public void test_StreamInAndOut_NullValue() throws IOException { assertEquals(parsedTensor, tensor); } } - diff --git a/common/src/test/java/org/opensearch/ml/common/output/execute/metricscorrelation/MCorrModelTensorsTest.java b/common/src/test/java/org/opensearch/ml/common/output/execute/metricscorrelation/MCorrModelTensorsTest.java index 516a50cb5b..8ed35a486b 100644 --- a/common/src/test/java/org/opensearch/ml/common/output/execute/metricscorrelation/MCorrModelTensorsTest.java +++ b/common/src/test/java/org/opensearch/ml/common/output/execute/metricscorrelation/MCorrModelTensorsTest.java @@ -5,28 +5,24 @@ package org.opensearch.ml.common.output.execute.metricscorrelation; +import static org.junit.Assert.assertEquals; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; + +import java.io.IOException; +import java.util.Arrays; + import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.TestHelper; import org.opensearch.ml.common.output.execute.metrics_correlation.MCorrModelTensor; import org.opensearch.ml.common.output.execute.metrics_correlation.MCorrModelTensors; -import org.opensearch.ml.common.output.model.MLResultDataType; import org.opensearch.ml.common.output.model.ModelResultFilter; -import org.opensearch.ml.common.output.model.ModelTensor; -import org.opensearch.ml.common.output.model.ModelTensors; - -import java.io.IOException; -import java.nio.ByteBuffer; -import java.util.Arrays; - -import static org.junit.Assert.assertEquals; -import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; public class MCorrModelTensorsTest { @@ -39,16 +35,18 @@ public class MCorrModelTensorsTest { public void setUp() { String column = "model_tensor"; Integer position = 1; - modelResultFilter = ModelResultFilter.builder() - .targetResponse(Arrays.asList(column)) - .targetResponsePositions(Arrays.asList(position)) - .build(); + modelResultFilter = ModelResultFilter + .builder() + .targetResponse(Arrays.asList(column)) + .targetResponsePositions(Arrays.asList(position)) + .build(); - MCorrModelTensor mCorrModelTensor = MCorrModelTensor.builder() - .event_pattern(new float[]{1.0f, 2.0f, 3.0f}) - .event_window(new float[]{4.0f, 5.0f, 6.0f}) - .suspected_metrics(new long[]{1, 2}) - .build(); + MCorrModelTensor mCorrModelTensor = MCorrModelTensor + .builder() + .event_pattern(new float[] { 1.0f, 2.0f, 3.0f }) + .event_window(new float[] { 4.0f, 5.0f, 6.0f }) + .suspected_metrics(new long[] { 1, 2 }) + .build(); mcorrModelTensors = MCorrModelTensors.builder().mCorrModelTensors(Arrays.asList(mCorrModelTensor)).build(); } @@ -81,4 +79,3 @@ public void test_StreamInAndOut_NullValue() throws IOException { assertEquals(parsedTensors.getMCorrModelTensors(), tensors.getMCorrModelTensors()); } } - diff --git a/common/src/test/java/org/opensearch/ml/common/output/execute/metricscorrelation/MetricsCorrelationOutputTest.java b/common/src/test/java/org/opensearch/ml/common/output/execute/metricscorrelation/MetricsCorrelationOutputTest.java index 6cd7c3ec3f..2f6e3b568b 100644 --- a/common/src/test/java/org/opensearch/ml/common/output/execute/metricscorrelation/MetricsCorrelationOutputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/output/execute/metricscorrelation/MetricsCorrelationOutputTest.java @@ -5,13 +5,7 @@ package org.opensearch.ml.common.output.execute.metricscorrelation; -import org.junit.Before; -import org.junit.Test; -import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.ml.common.output.execute.metrics_correlation.MCorrModelTensor; -import org.opensearch.ml.common.output.execute.metrics_correlation.MCorrModelTensors; -import org.opensearch.ml.common.output.execute.metrics_correlation.MetricsCorrelationOutput; +import static org.junit.Assert.*; import java.io.IOException; import java.util.ArrayList; @@ -19,7 +13,13 @@ import java.util.List; import java.util.function.Consumer; -import static org.junit.Assert.*; +import org.junit.Before; +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.ml.common.output.execute.metrics_correlation.MCorrModelTensor; +import org.opensearch.ml.common.output.execute.metrics_correlation.MCorrModelTensors; +import org.opensearch.ml.common.output.execute.metrics_correlation.MetricsCorrelationOutput; public class MetricsCorrelationOutputTest { @@ -28,11 +28,12 @@ public class MetricsCorrelationOutputTest { @Before public void setUp() { List outputs = new ArrayList<>(); - MCorrModelTensor mCorrModelTensor = MCorrModelTensor.builder() - .event_pattern(new float[]{1.0f, 2.0f, 3.0f}) - .event_window(new float[]{4.0f, 5.0f, 6.0f}) - .suspected_metrics(new long[]{1, 2}) - .build(); + MCorrModelTensor mCorrModelTensor = MCorrModelTensor + .builder() + .event_pattern(new float[] { 1.0f, 2.0f, 3.0f }) + .event_window(new float[] { 4.0f, 5.0f, 6.0f }) + .suspected_metrics(new long[] { 1, 2 }) + .build(); List mlModelTensors = Arrays.asList(mCorrModelTensor); MCorrModelTensors modelTensors = MCorrModelTensors.builder().mCorrModelTensors(mlModelTensors).build(); outputs.add(modelTensors); @@ -48,8 +49,8 @@ public void readInputStream_Success() throws IOException { MCorrModelTensor modelTensor = modelTensors.getMCorrModelTensors().get(0); float[] events = modelTensor.getEvent_pattern(); long[] metrics = modelTensor.getSuspected_metrics(); - assertArrayEquals(new float[]{1.0f, 2.0f, 3.0f}, events, 0.001f); - assertArrayEquals(new long[]{1, 2}, metrics); + assertArrayEquals(new float[] { 1.0f, 2.0f, 3.0f }, events, 0.001f); + assertArrayEquals(new long[] { 1, 2 }, metrics); }); } @@ -57,9 +58,7 @@ public void readInputStream_Success() throws IOException { @Test public void readInputStream_NullField() throws IOException { MetricsCorrelationOutput modelTensorOutput = MetricsCorrelationOutput.builder().build(); - readInputStream(modelTensorOutput, parsedTensorOutput -> { - assertNull(parsedTensorOutput.getModelOutput()); - }); + readInputStream(modelTensorOutput, parsedTensorOutput -> { assertNull(parsedTensorOutput.getModelOutput()); }); } private void readInputStream(MetricsCorrelationOutput input, Consumer verify) throws IOException { @@ -67,8 +66,8 @@ private void readInputStream(MetricsCorrelationOutput input, Consumer { - assertArrayEquals(resultFilter.getTargetResponse().toArray(new String[0]), parsedFilter.getTargetResponse().toArray(new String[0])); + assertArrayEquals( + resultFilter.getTargetResponse().toArray(new String[0]), + parsedFilter.getTargetResponse().toArray(new String[0]) + ); assertFalse(parsedFilter.returnBytes); assertFalse(parsedFilter.returnNumber); }); @@ -48,7 +51,6 @@ public void readInputStream_NullFields() throws IOException { }); } - private void readInputStream(ModelResultFilter input, Consumer verify) throws IOException { BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); input.writeTo(bytesStreamOutput); diff --git a/common/src/test/java/org/opensearch/ml/common/output/model/ModelTensorOutputTest.java b/common/src/test/java/org/opensearch/ml/common/output/model/ModelTensorOutputTest.java index d9f4c2c968..67690ed2bf 100644 --- a/common/src/test/java/org/opensearch/ml/common/output/model/ModelTensorOutputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/output/model/ModelTensorOutputTest.java @@ -1,10 +1,8 @@ package org.opensearch.ml.common.output.model; -import org.junit.Before; -import org.junit.Test; -import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.ml.common.output.MLOutputType; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; import java.io.IOException; import java.nio.ByteBuffer; @@ -13,9 +11,11 @@ import java.util.List; import java.util.function.Consumer; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNull; +import org.junit.Before; +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.ml.common.output.MLOutputType; public class ModelTensorOutputTest { @@ -24,11 +24,16 @@ public class ModelTensorOutputTest { @Before public void setUp() throws Exception { - value = new Float[]{1.0f, 2.0f, 3.0f}; + value = new Float[] { 1.0f, 2.0f, 3.0f }; List outputs = new ArrayList<>(); - ModelTensor tensor = ModelTensor.builder().data(value) - .name("test").shape(new long[]{1, 3}).dataType(MLResultDataType.FLOAT32) - .byteBuffer(ByteBuffer.wrap(new byte[]{0,1,0,1})).build(); + ModelTensor tensor = ModelTensor + .builder() + .data(value) + .name("test") + .shape(new long[] { 1, 3 }) + .dataType(MLResultDataType.FLOAT32) + .byteBuffer(ByteBuffer.wrap(new byte[] { 0, 1, 0, 1 })) + .build(); List mlModelTensors = Arrays.asList(tensor); ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(mlModelTensors).build(); outputs.add(modelTensors); diff --git a/common/src/test/java/org/opensearch/ml/common/output/model/ModelTensorTest.java b/common/src/test/java/org/opensearch/ml/common/output/model/ModelTensorTest.java index 68904cb390..d41ba82fe4 100644 --- a/common/src/test/java/org/opensearch/ml/common/output/model/ModelTensorTest.java +++ b/common/src/test/java/org/opensearch/ml/common/output/model/ModelTensorTest.java @@ -5,24 +5,24 @@ package org.opensearch.ml.common.output.model; +import static org.junit.Assert.assertEquals; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.HashMap; +import java.util.Map; + import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.TestHelper; -import java.io.IOException; -import java.nio.ByteBuffer; -import java.util.HashMap; -import java.util.Map; - -import static org.junit.Assert.assertEquals; -import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; - public class ModelTensorTest { @Rule @@ -35,15 +35,16 @@ public void setUp() { Map dataMap = new HashMap<>(); dataMap.put("key1", "test value1"); dataMap.put("key2", "test value2"); - modelTensor = ModelTensor.builder() - .name("model_tensor") - .data(new Number[]{1, 2, 3}) - .shape(new long[]{1, 2, 3,}) - .dataType(MLResultDataType.INT32) - .byteBuffer(ByteBuffer.wrap(new byte[]{0,1,0,1})) - .result("test result") - .dataAsMap(dataMap) - .build(); + modelTensor = ModelTensor + .builder() + .name("model_tensor") + .data(new Number[] { 1, 2, 3 }) + .shape(new long[] { 1, 2, 3, }) + .dataType(MLResultDataType.INT32) + .byteBuffer(ByteBuffer.wrap(new byte[] { 0, 1, 0, 1 })) + .result("test result") + .dataAsMap(dataMap) + .build(); } @Test @@ -61,13 +62,16 @@ public void test_ModelTensorSuccess() throws IOException { XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); modelTensor.toXContent(builder, EMPTY_PARAMS); String modelTensorContent = TestHelper.xContentBuilderToString(builder); - assertEquals("{\"name\":\"model_tensor\"," + - "\"data_type\":\"INT32\"," + - "\"shape\":[1,2,3]," + - "\"data\":[1,2,3]," + - "\"byte_buffer\":{\"array\":\"AAEAAQ==\",\"order\":\"BIG_ENDIAN\"}," + - "\"result\":\"test result\"," + - "\"dataAsMap\":{\"key1\":\"test value1\",\"key2\":\"test value2\"}}", modelTensorContent); + assertEquals( + "{\"name\":\"model_tensor\"," + + "\"data_type\":\"INT32\"," + + "\"shape\":[1,2,3]," + + "\"data\":[1,2,3]," + + "\"byte_buffer\":{\"array\":\"AAEAAQ==\",\"order\":\"BIG_ENDIAN\"}," + + "\"result\":\"test result\"," + + "\"dataAsMap\":{\"key1\":\"test value1\",\"key2\":\"test value2\"}}", + modelTensorContent + ); } @Test @@ -95,26 +99,27 @@ public void test_UnknownDataType() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("data type is null"); - ModelTensor.builder() - .name("null_data") - .data(new Number[]{1, 2, 3}) - .shape(null) - .dataType(MLResultDataType.UNKNOWN) - .byteBuffer(ByteBuffer.wrap(new byte[]{0,1,0,1})) - .build(); + ModelTensor + .builder() + .name("null_data") + .data(new Number[] { 1, 2, 3 }) + .shape(null) + .dataType(MLResultDataType.UNKNOWN) + .byteBuffer(ByteBuffer.wrap(new byte[] { 0, 1, 0, 1 })) + .build(); } @Test public void test_NullDataType() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("data type is null"); - ModelTensor.builder() - .name("null_data") - .data(new Number[]{1, 2, 3}) - .shape(null) - .dataType(null) - .byteBuffer(ByteBuffer.wrap(new byte[]{0,1,0,1})) - .build(); + ModelTensor + .builder() + .name("null_data") + .data(new Number[] { 1, 2, 3 }) + .shape(null) + .dataType(null) + .byteBuffer(ByteBuffer.wrap(new byte[] { 0, 1, 0, 1 })) + .build(); } } - diff --git a/common/src/test/java/org/opensearch/ml/common/output/model/ModelTensorsTest.java b/common/src/test/java/org/opensearch/ml/common/output/model/ModelTensorsTest.java index f8e3fee984..a4f7dc51b1 100644 --- a/common/src/test/java/org/opensearch/ml/common/output/model/ModelTensorsTest.java +++ b/common/src/test/java/org/opensearch/ml/common/output/model/ModelTensorsTest.java @@ -5,23 +5,23 @@ package org.opensearch.ml.common.output.model; +import static org.junit.Assert.assertEquals; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Arrays; + import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.TestHelper; -import java.io.IOException; -import java.nio.ByteBuffer; -import java.util.Arrays; - -import static org.junit.Assert.assertEquals; -import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; - public class ModelTensorsTest { @Rule @@ -34,18 +34,20 @@ public void setUp() { String sentence = "test sentence"; String column = "model_tensor"; Integer position = 1; - modelResultFilter = ModelResultFilter.builder() - .targetResponse(Arrays.asList(column)) - .targetResponsePositions(Arrays.asList(position)) - .build(); - - ModelTensor modelTensor = ModelTensor.builder() - .name("model_tensor") - .data(new Number[]{1, 2, 3}) - .shape(new long[]{1, 2, 3,}) - .dataType(MLResultDataType.INT32) - .byteBuffer(ByteBuffer.wrap(new byte[]{0,1,0,1})) - .build(); + modelResultFilter = ModelResultFilter + .builder() + .targetResponse(Arrays.asList(column)) + .targetResponsePositions(Arrays.asList(position)) + .build(); + + ModelTensor modelTensor = ModelTensor + .builder() + .name("model_tensor") + .data(new Number[] { 1, 2, 3 }) + .shape(new long[] { 1, 2, 3, }) + .dataType(MLResultDataType.INT32) + .byteBuffer(ByteBuffer.wrap(new byte[] { 0, 1, 0, 1 })) + .build(); modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); } @@ -55,7 +57,10 @@ public void test_ModelTensortoXContent() throws IOException { XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); modelTensors.toXContent(builder, EMPTY_PARAMS); String modelTensorContent = TestHelper.xContentBuilderToString(builder); - assertEquals("{\"output\":[{\"name\":\"model_tensor\",\"data_type\":\"INT32\",\"shape\":[1,2,3],\"data\":[1,2,3],\"byte_buffer\":{\"array\":\"AAEAAQ==\",\"order\":\"BIG_ENDIAN\"}}]}", modelTensorContent); + assertEquals( + "{\"output\":[{\"name\":\"model_tensor\",\"data_type\":\"INT32\",\"shape\":[1,2,3],\"data\":[1,2,3],\"byte_buffer\":{\"array\":\"AAEAAQ==\",\"order\":\"BIG_ENDIAN\"}}]}", + modelTensorContent + ); } @Test @@ -80,14 +85,15 @@ public void test_StreamInAndOut_NullValue() throws IOException { @Test public void test_Filter() { - ModelTensor modelTensorFiltered = ModelTensor.builder() - .name("model_tensor") - .shape(new long[]{1, 2, 3,}) - .dataType(MLResultDataType.INT32) - .build(); + ModelTensor modelTensorFiltered = ModelTensor + .builder() + .name("model_tensor") + .shape(new long[] { 1, 2, 3, }) + .dataType(MLResultDataType.INT32) + .build(); modelTensors.filter(modelResultFilter); assertEquals(modelTensors.getMlModelTensors().size(), 1); - //assertEquals(modelTensors.getMlModelTensors().get(0), modelTensorFiltered); + // assertEquals(modelTensors.getMlModelTensors().get(0), modelTensorFiltered); } @Test @@ -112,7 +118,6 @@ public void test_ToAndFromBytes() throws IOException { assertEquals(bytes.length, bytesStreamOutput.bytes().toBytesRef().bytes.length); ModelTensors tensors = ModelTensors.fromBytes(bytes); - //assertEquals(modelTensors.getMlModelTensors(), tensors.getMlModelTensors()); + // assertEquals(modelTensors.getMlModelTensors(), tensors.getMlModelTensors()); } } - diff --git a/common/src/test/java/org/opensearch/ml/common/output/sample/SampleAlgoOutputTest.java b/common/src/test/java/org/opensearch/ml/common/output/sample/SampleAlgoOutputTest.java index c417758058..73e6c74043 100644 --- a/common/src/test/java/org/opensearch/ml/common/output/sample/SampleAlgoOutputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/output/sample/SampleAlgoOutputTest.java @@ -5,30 +5,26 @@ package org.opensearch.ml.common.output.sample; +import static org.junit.Assert.assertEquals; + +import java.io.IOException; + import org.junit.Before; import org.junit.Test; -import org.opensearch.core.common.Strings; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.output.MLOutputType; -import java.io.IOException; - -import static org.junit.Assert.assertEquals; - public class SampleAlgoOutputTest { SampleAlgoOutput output; @Before public void setUp() { - output = SampleAlgoOutput.builder() - .sampleResult(1.0) - .build(); + output = SampleAlgoOutput.builder().sampleResult(1.0).build(); } @Test diff --git a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLConnectorDeleteRequestTests.java b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLConnectorDeleteRequestTests.java index d404f49ab4..84d97df88a 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLConnectorDeleteRequestTests.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLConnectorDeleteRequestTests.java @@ -5,6 +5,14 @@ package org.opensearch.ml.common.transport.connector; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; + +import java.io.IOException; +import java.io.UncheckedIOException; + import org.junit.Before; import org.junit.Test; import org.opensearch.action.ActionRequest; @@ -12,14 +20,6 @@ import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.io.stream.StreamOutput; -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotSame; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertSame; - public class MLConnectorDeleteRequestTests { private String connectorId; @@ -30,8 +30,7 @@ public void setUp() { @Test public void writeTo_Success() throws IOException { - MLConnectorDeleteRequest mlConnectorDeleteRequest = MLConnectorDeleteRequest.builder() - .connectorId(connectorId).build(); + MLConnectorDeleteRequest mlConnectorDeleteRequest = MLConnectorDeleteRequest.builder().connectorId(connectorId).build(); BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); mlConnectorDeleteRequest.writeTo(bytesStreamOutput); MLConnectorDeleteRequest parsedConnector = new MLConnectorDeleteRequest(bytesStreamOutput.bytes().streamInput()); @@ -47,16 +46,14 @@ public void valid_Exception_NullConnectorId() { @Test public void validate_Success() { - MLConnectorDeleteRequest mlConnectorDeleteRequest = MLConnectorDeleteRequest.builder() - .connectorId(connectorId).build(); + MLConnectorDeleteRequest mlConnectorDeleteRequest = MLConnectorDeleteRequest.builder().connectorId(connectorId).build(); ActionRequestValidationException actionRequestValidationException = mlConnectorDeleteRequest.validate(); assertNull(actionRequestValidationException); } @Test public void fromActionRequest_Success() { - MLConnectorDeleteRequest mlConnectorDeleteRequest = MLConnectorDeleteRequest.builder() - .connectorId(connectorId).build(); + MLConnectorDeleteRequest mlConnectorDeleteRequest = MLConnectorDeleteRequest.builder().connectorId(connectorId).build(); ActionRequest actionRequest = new ActionRequest() { @Override public ActionRequestValidationException validate() { @@ -91,9 +88,9 @@ public void writeTo(StreamOutput out) throws IOException { @Test public void fromActionRequestWithConnectorDeleteRequest_Success() { - MLConnectorDeleteRequest mlConnectorDeleteRequest = MLConnectorDeleteRequest.builder() - .connectorId(connectorId).build(); - MLConnectorDeleteRequest mlConnectorDeleteRequestFromActionRequest = MLConnectorDeleteRequest.fromActionRequest(mlConnectorDeleteRequest); + MLConnectorDeleteRequest mlConnectorDeleteRequest = MLConnectorDeleteRequest.builder().connectorId(connectorId).build(); + MLConnectorDeleteRequest mlConnectorDeleteRequestFromActionRequest = MLConnectorDeleteRequest + .fromActionRequest(mlConnectorDeleteRequest); assertSame(mlConnectorDeleteRequest, mlConnectorDeleteRequestFromActionRequest); assertEquals(mlConnectorDeleteRequest.getConnectorId(), mlConnectorDeleteRequestFromActionRequest.getConnectorId()); } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLConnectorGetRequestTests.java b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLConnectorGetRequestTests.java index 53fcce560b..fce8492957 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLConnectorGetRequestTests.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLConnectorGetRequestTests.java @@ -3,9 +3,13 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.ml.common.transport.connector; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; + import java.io.IOException; import java.io.UncheckedIOException; @@ -16,11 +20,6 @@ import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.io.stream.StreamOutput; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotSame; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertSame; - public class MLConnectorGetRequestTests { private String connectorId; @@ -95,4 +94,3 @@ public void validate_Success() { assertNull(actionRequestValidationException); } } - diff --git a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLConnectorGetResponseTests.java b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLConnectorGetResponseTests.java index 417f77506c..f1582ec2e8 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLConnectorGetResponseTests.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLConnectorGetResponseTests.java @@ -5,27 +5,27 @@ package org.opensearch.ml.common.transport.connector; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertSame; + +import java.io.IOException; +import java.io.UncheckedIOException; + import org.junit.Before; import org.junit.Test; -import org.opensearch.core.action.ActionResponse; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.connector.HttpConnectorTest; -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertNotSame; -import static org.junit.Assert.assertSame; - public class MLConnectorGetResponseTests { Connector connector; @@ -59,16 +59,19 @@ public void toXContentTest() throws IOException { mlConnectorGetResponse.toXContent(builder, ToXContent.EMPTY_PARAMS); assertNotNull(builder); String jsonStr = builder.toString(); - assertEquals("{\"name\":\"test_connector_name\"," + - "\"version\":\"1\",\"description\":\"this is a test connector\",\"protocol\":\"http\"," + - "\"parameters\":{\"input\":\"test input value\"},\"credential\":{\"key\":\"test_key_value\"}," + - "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\"," + - "\"headers\":{\"api_key\":\"${credential.key}\"}," + - "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," + - "\"pre_process_function\":\"connector.pre_process.openai.embedding\"," + - "\"post_process_function\":\"connector.post_process.openai.embedding\"}]," + - "\"backend_roles\":[\"role1\",\"role2\"]," + - "\"access\":\"public\"}", jsonStr); + assertEquals( + "{\"name\":\"test_connector_name\"," + + "\"version\":\"1\",\"description\":\"this is a test connector\",\"protocol\":\"http\"," + + "\"parameters\":{\"input\":\"test input value\"},\"credential\":{\"key\":\"test_key_value\"}," + + "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\"," + + "\"headers\":{\"api_key\":\"${credential.key}\"}," + + "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," + + "\"pre_process_function\":\"connector.pre_process.openai.embedding\"," + + "\"post_process_function\":\"connector.post_process.openai.embedding\"}]," + + "\"backend_roles\":[\"role1\",\"role2\"]," + + "\"access\":\"public\"}", + jsonStr + ); } @Test diff --git a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java index b938d81941..51cd560dee 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java @@ -5,8 +5,12 @@ package org.opensearch.ml.common.transport.connector; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + import java.io.IOException; -import java.io.UncheckedIOException; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; @@ -19,12 +23,10 @@ import org.junit.Test; import org.junit.rules.ExpectedException; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.LoggingDeprecationHandler; -import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContent; @@ -36,32 +38,25 @@ import org.opensearch.ml.common.connector.MLPreProcessFunction; import org.opensearch.search.SearchModule; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertNotSame; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertSame; -import static org.junit.Assert.assertTrue; - public class MLCreateConnectorInputTests { private MLCreateConnectorInput mlCreateConnectorInput; private MLCreateConnectorInput mlCreateDryRunConnectorInput; @Rule public final ExpectedException exceptionRule = ExpectedException.none(); - private final String expectedInputStr = "{\"name\":\"test_connector_name\"," + - "\"description\":\"this is a test connector\",\"version\":\"1\",\"protocol\":\"http\"," + - "\"parameters\":{\"input\":\"test input value\"},\"credential\":{\"key\":\"test_key_value\"}," + - "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\"," + - "\"headers\":{\"api_key\":\"${credential.key}\"}," + - "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," + - "\"pre_process_function\":\"connector.pre_process.openai.embedding\"," + - "\"post_process_function\":\"connector.post_process.openai.embedding\"}]," + - "\"backend_roles\":[\"role1\",\"role2\"],\"add_all_backend_roles\":false," + - "\"access_mode\":\"PUBLIC\"}"; + private final String expectedInputStr = "{\"name\":\"test_connector_name\"," + + "\"description\":\"this is a test connector\",\"version\":\"1\",\"protocol\":\"http\"," + + "\"parameters\":{\"input\":\"test input value\"},\"credential\":{\"key\":\"test_key_value\"}," + + "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\"," + + "\"headers\":{\"api_key\":\"${credential.key}\"}," + + "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," + + "\"pre_process_function\":\"connector.pre_process.openai.embedding\"," + + "\"post_process_function\":\"connector.post_process.openai.embedding\"}]," + + "\"backend_roles\":[\"role1\",\"role2\"],\"add_all_backend_roles\":false," + + "\"access_mode\":\"PUBLIC\"}"; @Before - public void setUp(){ + public void setUp() { ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT; String method = "POST"; String url = "https://test.com"; @@ -70,78 +65,88 @@ public void setUp(){ String mlCreateConnectorRequestBody = "{\"input\": \"${parameters.input}\"}"; String preProcessFunction = MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT; String postProcessFunction = MLPostProcessFunction.OPENAI_EMBEDDING; - ConnectorAction action = new ConnectorAction(actionType, method, url, headers, mlCreateConnectorRequestBody, preProcessFunction, postProcessFunction); - - mlCreateConnectorInput = MLCreateConnectorInput.builder() - .name("test_connector_name") - .description("this is a test connector") - .version("1") - .protocol("http") - .parameters(Map.of("input", "test input value")) - .credential(Map.of("key", "test_key_value")) - .actions(List.of(action)) - .access(AccessMode.PUBLIC) - .backendRoles(Arrays.asList("role1", "role2")) - .addAllBackendRoles(false) - .build(); - - mlCreateDryRunConnectorInput = MLCreateConnectorInput.builder() - .dryRun(true) - .build(); + ConnectorAction action = new ConnectorAction( + actionType, + method, + url, + headers, + mlCreateConnectorRequestBody, + preProcessFunction, + postProcessFunction + ); + + mlCreateConnectorInput = MLCreateConnectorInput + .builder() + .name("test_connector_name") + .description("this is a test connector") + .version("1") + .protocol("http") + .parameters(Map.of("input", "test input value")) + .credential(Map.of("key", "test_key_value")) + .actions(List.of(action)) + .access(AccessMode.PUBLIC) + .backendRoles(Arrays.asList("role1", "role2")) + .addAllBackendRoles(false) + .build(); + + mlCreateDryRunConnectorInput = MLCreateConnectorInput.builder().dryRun(true).build(); } @Test public void constructorMLCreateConnectorInput_NullName() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("Connector name is null"); - MLCreateConnectorInput.builder() - .name(null) - .description("this is a test connector") - .version("1") - .protocol("http") - .parameters(Map.of("input", "test input value")) - .credential(Map.of("key", "test_key_value")) - .actions(List.of()) - .access(AccessMode.PUBLIC) - .backendRoles(Arrays.asList("role1", "role2")) - .addAllBackendRoles(false) - .build(); + MLCreateConnectorInput + .builder() + .name(null) + .description("this is a test connector") + .version("1") + .protocol("http") + .parameters(Map.of("input", "test input value")) + .credential(Map.of("key", "test_key_value")) + .actions(List.of()) + .access(AccessMode.PUBLIC) + .backendRoles(Arrays.asList("role1", "role2")) + .addAllBackendRoles(false) + .build(); } @Test public void constructorMLCreateConnectorInput_NullVersion() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("Connector version is null"); - MLCreateConnectorInput.builder() - .name("test_connector_name") - .description("this is a test connector") - .version(null) - .protocol("http") - .parameters(Map.of("input", "test input value")) - .credential(Map.of("key", "test_key_value")) - .actions(List.of()) - .access(AccessMode.PUBLIC) - .backendRoles(Arrays.asList("role1", "role2")) - .addAllBackendRoles(false) - .build(); + MLCreateConnectorInput + .builder() + .name("test_connector_name") + .description("this is a test connector") + .version(null) + .protocol("http") + .parameters(Map.of("input", "test input value")) + .credential(Map.of("key", "test_key_value")) + .actions(List.of()) + .access(AccessMode.PUBLIC) + .backendRoles(Arrays.asList("role1", "role2")) + .addAllBackendRoles(false) + .build(); } @Test public void constructorMLCreateConnectorInput_NullProtocol() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("Connector protocol is null"); - MLCreateConnectorInput.builder() - .name("test_connector_name") - .description("this is a test connector") - .version("1") - .protocol(null) - .parameters(Map.of("input", "test input value")) - .credential(Map.of("key", "test_key_value")) - .actions(List.of()) - .access(AccessMode.PUBLIC) - .backendRoles(Arrays.asList("role1", "role2")) - .addAllBackendRoles(false) - .build(); + MLCreateConnectorInput + .builder() + .name("test_connector_name") + .description("this is a test connector") + .version("1") + .protocol(null) + .parameters(Map.of("input", "test input value")) + .credential(Map.of("key", "test_key_value")) + .actions(List.of()) + .access(AccessMode.PUBLIC) + .backendRoles(Arrays.asList("role1", "role2")) + .addAllBackendRoles(false) + .build(); } @Test @@ -164,23 +169,21 @@ public void testToXContent_NullFields() throws Exception { @Test public void testParse() throws Exception { - testParseFromJsonString(expectedInputStr, parsedInput -> { - assertEquals("test_connector_name", parsedInput.getName()); - }); + testParseFromJsonString(expectedInputStr, parsedInput -> { assertEquals("test_connector_name", parsedInput.getName()); }); } @Test public void testParse_ArrayParameter() throws Exception { - String expectedInputStr = "{\"name\":\"test_connector_name\"," + - "\"description\":\"this is a test connector\",\"version\":\"1\",\"protocol\":\"http\"," + - "\"parameters\":{\"input\":[\"test input value\"]},\"credential\":{\"key\":\"test_key_value\"}," + - "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\"," + - "\"headers\":{\"api_key\":\"${credential.key}\"}," + - "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," + - "\"pre_process_function\":\"connector.pre_process.openai.embedding\"," + - "\"post_process_function\":\"connector.post_process.openai.embedding\"}]," + - "\"backend_roles\":[\"role1\",\"role2\"],\"add_all_backend_roles\":false," + - "\"access_mode\":\"PUBLIC\"}"; + String expectedInputStr = "{\"name\":\"test_connector_name\"," + + "\"description\":\"this is a test connector\",\"version\":\"1\",\"protocol\":\"http\"," + + "\"parameters\":{\"input\":[\"test input value\"]},\"credential\":{\"key\":\"test_key_value\"}," + + "\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":\"https://test.com\"," + + "\"headers\":{\"api_key\":\"${credential.key}\"}," + + "\"request_body\":\"{\\\"input\\\": \\\"${parameters.input}\\\"}\"," + + "\"pre_process_function\":\"connector.pre_process.openai.embedding\"," + + "\"post_process_function\":\"connector.post_process.openai.embedding\"}]," + + "\"backend_roles\":[\"role1\",\"role2\"],\"add_all_backend_roles\":false," + + "\"access_mode\":\"PUBLIC\"}"; testParseFromJsonString(expectedInputStr, parsedInput -> { assertEquals("test_connector_name", parsedInput.getName()); assertEquals(1, parsedInput.getParameters().size()); @@ -204,11 +207,12 @@ public void readInputStream_Success() throws IOException { @Test public void readInputStream_SuccessWithNullFields() throws IOException { - MLCreateConnectorInput mlCreateMinimalConnectorInput = MLCreateConnectorInput.builder() - .name("test_connector_name") - .version("1") - .protocol("http") - .build(); + MLCreateConnectorInput mlCreateMinimalConnectorInput = MLCreateConnectorInput + .builder() + .name("test_connector_name") + .version("1") + .protocol("http") + .build(); readInputStream(mlCreateMinimalConnectorInput, parsedInput -> { assertEquals(mlCreateMinimalConnectorInput.getName(), parsedInput.getName()); assertNull(parsedInput.getActions()); @@ -216,8 +220,13 @@ public void readInputStream_SuccessWithNullFields() throws IOException { } private void testParseFromJsonString(String expectedInputString, Consumer verify) throws Exception { - XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, - Collections.emptyList()).getNamedXContents()), LoggingDeprecationHandler.INSTANCE, expectedInputString); + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + LoggingDeprecationHandler.INSTANCE, + expectedInputString + ); parser.nextToken(); MLCreateConnectorInput parsedInput = MLCreateConnectorInput.parse(parser); verify.accept(parsedInput); diff --git a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorRequestTests.java b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorRequestTests.java index 5310be6582..b8cddf6af3 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorRequestTests.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorRequestTests.java @@ -5,6 +5,11 @@ package org.opensearch.ml.common.transport.connector; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; + import java.io.IOException; import java.io.UncheckedIOException; import java.util.Arrays; @@ -23,16 +28,11 @@ import org.opensearch.ml.common.connector.MLPostProcessFunction; import org.opensearch.ml.common.connector.MLPreProcessFunction; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotSame; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertSame; - public class MLCreateConnectorRequestTests { private MLCreateConnectorInput mlCreateConnectorInput; @Before - public void setUp(){ + public void setUp() { ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT; String method = "POST"; String url = "https://test.com"; @@ -41,66 +41,95 @@ public void setUp(){ String mlCreateConnectorRequestBody = "{\"input\": \"${parameters.input}\"}"; String preProcessFunction = MLPreProcessFunction.TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT; String postProcessFunction = MLPostProcessFunction.OPENAI_EMBEDDING; - ConnectorAction action = new ConnectorAction(actionType, method, url, headers, mlCreateConnectorRequestBody, preProcessFunction, postProcessFunction); - - mlCreateConnectorInput = MLCreateConnectorInput.builder() - .name("test_connector_name") - .description("this is a test connector") - .version("1") - .protocol("http") - .parameters(Map.of("input", "test input value")) - .credential(Map.of("key", "test_key_value")) - .actions(List.of(action)) - .access(AccessMode.PUBLIC) - .backendRoles(Arrays.asList("role1", "role2")) - .addAllBackendRoles(false) - .build(); + ConnectorAction action = new ConnectorAction( + actionType, + method, + url, + headers, + mlCreateConnectorRequestBody, + preProcessFunction, + postProcessFunction + ); + + mlCreateConnectorInput = MLCreateConnectorInput + .builder() + .name("test_connector_name") + .description("this is a test connector") + .version("1") + .protocol("http") + .parameters(Map.of("input", "test input value")) + .credential(Map.of("key", "test_key_value")) + .actions(List.of(action)) + .access(AccessMode.PUBLIC) + .backendRoles(Arrays.asList("role1", "role2")) + .addAllBackendRoles(false) + .build(); } @Test - public void writeTo_Success() throws IOException { - MLCreateConnectorRequest mlCreateConnectorRequest = MLCreateConnectorRequest.builder().mlCreateConnectorInput(mlCreateConnectorInput).build(); + public void writeTo_Success() throws IOException { + MLCreateConnectorRequest mlCreateConnectorRequest = MLCreateConnectorRequest + .builder() + .mlCreateConnectorInput(mlCreateConnectorInput) + .build(); BytesStreamOutput output = new BytesStreamOutput(); mlCreateConnectorRequest.writeTo(output); MLCreateConnectorRequest parsedRequest = new MLCreateConnectorRequest(output.bytes().streamInput()); assertEquals(mlCreateConnectorRequest.getMlCreateConnectorInput().getName(), parsedRequest.getMlCreateConnectorInput().getName()); - assertEquals(mlCreateConnectorRequest.getMlCreateConnectorInput().getAccess(), parsedRequest.getMlCreateConnectorInput().getAccess()); - assertEquals(mlCreateConnectorRequest.getMlCreateConnectorInput().getProtocol(), parsedRequest.getMlCreateConnectorInput().getProtocol()); - assertEquals(mlCreateConnectorRequest.getMlCreateConnectorInput().getBackendRoles(), parsedRequest.getMlCreateConnectorInput().getBackendRoles()); - assertEquals(mlCreateConnectorRequest.getMlCreateConnectorInput().getActions(), parsedRequest.getMlCreateConnectorInput().getActions()); - assertEquals(mlCreateConnectorRequest.getMlCreateConnectorInput().getParameters(), parsedRequest.getMlCreateConnectorInput().getParameters()); + assertEquals( + mlCreateConnectorRequest.getMlCreateConnectorInput().getAccess(), + parsedRequest.getMlCreateConnectorInput().getAccess() + ); + assertEquals( + mlCreateConnectorRequest.getMlCreateConnectorInput().getProtocol(), + parsedRequest.getMlCreateConnectorInput().getProtocol() + ); + assertEquals( + mlCreateConnectorRequest.getMlCreateConnectorInput().getBackendRoles(), + parsedRequest.getMlCreateConnectorInput().getBackendRoles() + ); + assertEquals( + mlCreateConnectorRequest.getMlCreateConnectorInput().getActions(), + parsedRequest.getMlCreateConnectorInput().getActions() + ); + assertEquals( + mlCreateConnectorRequest.getMlCreateConnectorInput().getParameters(), + parsedRequest.getMlCreateConnectorInput().getParameters() + ); } @Test public void validate_Success() { - MLCreateConnectorRequest mlCreateConnectorRequest = MLCreateConnectorRequest.builder() - .mlCreateConnectorInput(mlCreateConnectorInput) - .build(); + MLCreateConnectorRequest mlCreateConnectorRequest = MLCreateConnectorRequest + .builder() + .mlCreateConnectorInput(mlCreateConnectorInput) + .build(); assertNull(mlCreateConnectorRequest.validate()); } @Test public void validate_Exception_NullMLRegisterModelGroupInput() { - MLCreateConnectorRequest mlCreateConnectorRequest = MLCreateConnectorRequest.builder() - .build(); + MLCreateConnectorRequest mlCreateConnectorRequest = MLCreateConnectorRequest.builder().build(); ActionRequestValidationException exception = mlCreateConnectorRequest.validate(); assertEquals("Validation Failed: 1: ML Connector input can't be null;", exception.getMessage()); } @Test public void fromActionRequest_Success_WithMLRegisterModelRequest() { - MLCreateConnectorRequest mlCreateConnectorRequest = MLCreateConnectorRequest.builder() - .mlCreateConnectorInput(mlCreateConnectorInput) - .build(); + MLCreateConnectorRequest mlCreateConnectorRequest = MLCreateConnectorRequest + .builder() + .mlCreateConnectorInput(mlCreateConnectorInput) + .build(); assertSame(MLCreateConnectorRequest.fromActionRequest(mlCreateConnectorRequest), mlCreateConnectorRequest); } @Test public void fromActionRequest_Success_WithNonMLRegisterModelRequest() { - MLCreateConnectorRequest mlCreateConnectorRequest = MLCreateConnectorRequest.builder() - .mlCreateConnectorInput(mlCreateConnectorInput) - .build(); + MLCreateConnectorRequest mlCreateConnectorRequest = MLCreateConnectorRequest + .builder() + .mlCreateConnectorInput(mlCreateConnectorInput) + .build(); ActionRequest actionRequest = new ActionRequest() { @Override public ActionRequestValidationException validate() { diff --git a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorResponseTests.java b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorResponseTests.java index 8d58047980..c242146847 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorResponseTests.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorResponseTests.java @@ -5,6 +5,8 @@ package org.opensearch.ml.common.transport.connector; +import java.io.IOException; + import org.junit.Assert; import org.junit.Test; import org.opensearch.common.io.stream.BytesStreamOutput; @@ -13,8 +15,6 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.TestHelper; -import java.io.IOException; - public class MLCreateConnectorResponseTests { @Test diff --git a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequestTests.java b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequestTests.java index 44e970f95c..49b013cdf2 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequestTests.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLUpdateConnectorRequestTests.java @@ -5,6 +5,16 @@ package org.opensearch.ml.common.transport.connector; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Collections; + import org.junit.Before; import org.junit.Test; import org.mockito.MockitoAnnotations; @@ -18,16 +28,6 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.search.SearchModule; -import java.io.IOException; -import java.io.UncheckedIOException; -import java.util.Collections; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotSame; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertSame; -import static org.junit.Assert.assertTrue; - public class MLUpdateConnectorRequestTests { private String connectorId; private MLCreateConnectorInput updateContent; @@ -38,10 +38,7 @@ public void setUp() { MockitoAnnotations.openMocks(this); this.connectorId = "test-connector_id"; this.updateContent = MLCreateConnectorInput.builder().description("new description").updateConnector(true).build(); - mlUpdateConnectorRequest = MLUpdateConnectorRequest.builder() - .connectorId(connectorId) - .updateContent(updateContent) - .build(); + mlUpdateConnectorRequest = MLUpdateConnectorRequest.builder().connectorId(connectorId).updateContent(updateContent).build(); } @Test @@ -63,14 +60,22 @@ public void validate_Exception_NullConnectorId() { MLUpdateConnectorRequest updateConnectorRequest = MLUpdateConnectorRequest.builder().build(); Exception exception = updateConnectorRequest.validate(); - assertEquals("Validation Failed: 1: ML connector id can't be null;2: Update connector content can't be null;", exception.getMessage()); + assertEquals( + "Validation Failed: 1: ML connector id can't be null;2: Update connector content can't be null;", + exception.getMessage() + ); } @Test public void parse_success() throws IOException { String jsonStr = "{\"version\":\"new version\",\"description\":\"new description\"}"; - XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, - Collections.emptyList()).getNamedXContents()), null, jsonStr); + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + jsonStr + ); parser.nextToken(); MLUpdateConnectorRequest updateConnectorRequest = MLUpdateConnectorRequest.parse(parser, connectorId); assertEquals(updateConnectorRequest.getConnectorId(), connectorId); @@ -81,7 +86,8 @@ public void parse_success() throws IOException { @Test public void fromActionRequest_Success() { - MLUpdateConnectorRequest mlUpdateConnectorRequest = MLUpdateConnectorRequest.builder() + MLUpdateConnectorRequest mlUpdateConnectorRequest = MLUpdateConnectorRequest + .builder() .connectorId(connectorId) .updateContent(updateContent) .build(); @@ -90,7 +96,8 @@ public void fromActionRequest_Success() { @Test public void fromActionRequest_Success_fromActionRequest() { - MLUpdateConnectorRequest mlUpdateConnectorRequest = MLUpdateConnectorRequest.builder() + MLUpdateConnectorRequest mlUpdateConnectorRequest = MLUpdateConnectorRequest + .builder() .connectorId(connectorId) .updateContent(updateContent) .build(); diff --git a/common/src/test/java/org/opensearch/ml/common/transport/deploy/MLDeployModelInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/deploy/MLDeployModelInputTest.java index e3f1583f13..8d93213a4a 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/deploy/MLDeployModelInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/deploy/MLDeployModelInputTest.java @@ -1,5 +1,12 @@ package org.opensearch.ml.common.transport.deploy; +import static org.junit.Assert.*; + +import java.io.IOException; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Arrays; + import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -12,16 +19,6 @@ import org.opensearch.ml.common.MLTaskType; import org.opensearch.ml.common.dataset.MLInputDataType; -import java.io.IOException; -import java.time.Instant; -import java.time.temporal.ChronoUnit; -import java.util.Arrays; - -import static org.junit.Assert.*; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.doThrow; -import static org.mockito.Mockito.verify; - @RunWith(MockitoJUnitRunner.class) public class MLDeployModelInputTest { @@ -31,29 +28,31 @@ public class MLDeployModelInputTest { @Before public void setUp() throws Exception { Instant time = Instant.now(); - mlTask = MLTask.builder() - .taskId("mlTaskTaskId") - .modelId("mlTaskModelId") - .taskType(MLTaskType.PREDICTION) - .functionName(FunctionName.LINEAR_REGRESSION) - .state(MLTaskState.RUNNING) - .inputType(MLInputDataType.DATA_FRAME) - .workerNodes(Arrays.asList("node1")) - .progress(0.0f) - .outputIndex("test_index") - .error("test_error") - .createTime(time.minus(1, ChronoUnit.MINUTES)) - .lastUpdateTime(time) - .build(); + mlTask = MLTask + .builder() + .taskId("mlTaskTaskId") + .modelId("mlTaskModelId") + .taskType(MLTaskType.PREDICTION) + .functionName(FunctionName.LINEAR_REGRESSION) + .state(MLTaskState.RUNNING) + .inputType(MLInputDataType.DATA_FRAME) + .workerNodes(Arrays.asList("node1")) + .progress(0.0f) + .outputIndex("test_index") + .error("test_error") + .createTime(time.minus(1, ChronoUnit.MINUTES)) + .lastUpdateTime(time) + .build(); - mlDeployModelInput = mlDeployModelInput.builder() - .modelId("testModelId") - .taskId("testTaskId") - .modelContentHash("modelContentHash") - .nodeCount(3) - .coordinatingNodeId("coordinatingNodeId") - .mlTask(mlTask) - .build(); + mlDeployModelInput = mlDeployModelInput + .builder() + .modelId("testModelId") + .taskId("testTaskId") + .modelContentHash("modelContentHash") + .nodeCount(3) + .coordinatingNodeId("coordinatingNodeId") + .mlTask(mlTask) + .build(); } @Test diff --git a/common/src/test/java/org/opensearch/ml/common/transport/deploy/MLDeployModelNodeResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/deploy/MLDeployModelNodeResponseTest.java index cce0c463be..5d4136b06c 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/deploy/MLDeployModelNodeResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/deploy/MLDeployModelNodeResponseTest.java @@ -1,5 +1,15 @@ package org.opensearch.ml.common.transport.deploy; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; +import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; + +import java.io.IOException; +import java.net.InetAddress; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -10,16 +20,6 @@ import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.transport.TransportAddress; -import java.io.IOException; -import java.net.InetAddress; -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotEquals; -import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; - @RunWith(MockitoJUnitRunner.class) public class MLDeployModelNodeResponseTest { @@ -29,12 +29,12 @@ public class MLDeployModelNodeResponseTest { @Before public void setUp() throws Exception { localNode = new DiscoveryNode( - "foo0", - "foo0", - new TransportAddress(InetAddress.getLoopbackAddress(), 9300), - Collections.emptyMap(), - Collections.singleton(CLUSTER_MANAGER_ROLE), - Version.CURRENT + "foo0", + "foo0", + new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT ); } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/deploy/MLDeployModelNodesRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/deploy/MLDeployModelNodesRequestTest.java index 938543f230..eb510d2a11 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/deploy/MLDeployModelNodesRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/deploy/MLDeployModelNodesRequestTest.java @@ -1,5 +1,15 @@ package org.opensearch.ml.common.transport.deploy; +import static org.junit.Assert.*; +import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; + +import java.io.IOException; +import java.net.InetAddress; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Arrays; +import java.util.Collections; + import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -16,16 +26,6 @@ import org.opensearch.ml.common.MLTaskType; import org.opensearch.ml.common.dataset.MLInputDataType; -import java.io.IOException; -import java.net.InetAddress; -import java.time.Instant; -import java.time.temporal.ChronoUnit; -import java.util.Arrays; -import java.util.Collections; - -import static org.junit.Assert.*; -import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; - @RunWith(MockitoJUnitRunner.class) public class MLDeployModelNodesRequestTest { @@ -39,54 +39,63 @@ public class MLDeployModelNodesRequestTest { @Before public void setUp() throws Exception { localNode1 = new DiscoveryNode( - "foo1", - "foo1", - new TransportAddress(InetAddress.getLoopbackAddress(), 9300), - Collections.emptyMap(), - Collections.singleton(CLUSTER_MANAGER_ROLE), - Version.CURRENT + "foo1", + "foo1", + new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT ); localNode2 = new DiscoveryNode( - "foo2", - "foo2", - new TransportAddress(InetAddress.getLoopbackAddress(), 9300), - Collections.emptyMap(), - Collections.singleton(CLUSTER_MANAGER_ROLE), - Version.CURRENT + "foo2", + "foo2", + new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT ); localNode3 = new DiscoveryNode( - "foo3", - "foo3", - new TransportAddress(InetAddress.getLoopbackAddress(), 9300), - Collections.emptyMap(), - Collections.singleton(CLUSTER_MANAGER_ROLE), - Version.CURRENT + "foo3", + "foo3", + new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT ); Instant time = Instant.now(); - mlTask = MLTask.builder() - .taskId("mlTaskTaskId") - .modelId("mlTaskModelId") - .taskType(MLTaskType.PREDICTION) - .functionName(FunctionName.LINEAR_REGRESSION) - .state(MLTaskState.RUNNING) - .inputType(MLInputDataType.DATA_FRAME) - .workerNodes(Arrays.asList("node1")) - .progress(0.0f) - .outputIndex("test_index") - .error("test_error") - .createTime(time.minus(1, ChronoUnit.MINUTES)) - .lastUpdateTime(time) - .build(); + mlTask = MLTask + .builder() + .taskId("mlTaskTaskId") + .modelId("mlTaskModelId") + .taskType(MLTaskType.PREDICTION) + .functionName(FunctionName.LINEAR_REGRESSION) + .state(MLTaskState.RUNNING) + .inputType(MLInputDataType.DATA_FRAME) + .workerNodes(Arrays.asList("node1")) + .progress(0.0f) + .outputIndex("test_index") + .error("test_error") + .createTime(time.minus(1, ChronoUnit.MINUTES)) + .lastUpdateTime(time) + .build(); } @Test public void testConstructorSerialization1() throws IOException { - String [] nodeIds = {"id1", "id2", "id3"}; - MLDeployModelInput deployModelInput = new MLDeployModelInput("modelId", "taskId", "modelContentHash", 3, "coordinatingNodeId", true, mlTask); + String[] nodeIds = { "id1", "id2", "id3" }; + MLDeployModelInput deployModelInput = new MLDeployModelInput( + "modelId", + "taskId", + "modelContentHash", + 3, + "coordinatingNodeId", + true, + mlTask + ); MLDeployModelNodeRequest MLDeployModelNodeRequest = new MLDeployModelNodeRequest( - new MLDeployModelNodesRequest(nodeIds, deployModelInput) + new MLDeployModelNodesRequest(nodeIds, deployModelInput) ); BytesStreamOutput output = new BytesStreamOutput(); @@ -95,54 +104,103 @@ public void testConstructorSerialization1() throws IOException { assertNotNull(MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput()); assertEquals("modelId", MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().getModelId()); assertEquals("taskId", MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().getTaskId()); - assertEquals("modelContentHash", MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().getModelContentHash()); + assertEquals( + "modelContentHash", + MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().getModelContentHash() + ); assertEquals(3, MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().getNodeCount().intValue()); - assertEquals("coordinatingNodeId", MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().getCoordinatingNodeId()); - assertEquals(mlTask.getTaskId(), MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().getMlTask().getTaskId()); + assertEquals( + "coordinatingNodeId", + MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().getCoordinatingNodeId() + ); + assertEquals( + mlTask.getTaskId(), + MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().getMlTask().getTaskId() + ); } @Test public void testConstructorSerialization2() throws IOException { - DiscoveryNode [] nodeIds = {localNode1, localNode2, localNode3}; - MLDeployModelInput deployModelInput = new MLDeployModelInput("modelId", "taskId", "modelContentHash", 3, "coordinatingNodeId", true, mlTask); + DiscoveryNode[] nodeIds = { localNode1, localNode2, localNode3 }; + MLDeployModelInput deployModelInput = new MLDeployModelInput( + "modelId", + "taskId", + "modelContentHash", + 3, + "coordinatingNodeId", + true, + mlTask + ); MLDeployModelNodeRequest MLDeployModelNodeRequest = new MLDeployModelNodeRequest( - new MLDeployModelNodesRequest(nodeIds, deployModelInput) + new MLDeployModelNodesRequest(nodeIds, deployModelInput) ); assertNotNull(MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput()); assertEquals("modelId", MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().getModelId()); assertEquals("taskId", MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().getTaskId()); - assertEquals("modelContentHash", MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().getModelContentHash()); + assertEquals( + "modelContentHash", + MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().getModelContentHash() + ); assertEquals(3, MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().getNodeCount().intValue()); - assertEquals("coordinatingNodeId", MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().getCoordinatingNodeId()); - assertEquals(mlTask.getTaskId(), MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().getMlTask().getTaskId()); + assertEquals( + "coordinatingNodeId", + MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().getCoordinatingNodeId() + ); + assertEquals( + mlTask.getTaskId(), + MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().getMlTask().getTaskId() + ); } @Test public void testConstructorSerialization3() throws IOException { MLDeployModelNodeRequest MLDeployModelNodeRequest = new MLDeployModelNodeRequest( - new MLDeployModelNodesRequest(localNode1, localNode2, localNode3) + new MLDeployModelNodesRequest(localNode1, localNode2, localNode3) ); MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().setModelId("modelIdSetDuringTest"); MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().setTaskId("taskIdSetDuringTest"); - MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().setModelContentHash("modelContentHashSetDuringTest"); + MLDeployModelNodeRequest + .getMLDeployModelNodesRequest() + .getMlDeployModelInput() + .setModelContentHash("modelContentHashSetDuringTest"); MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().setNodeCount(2); - MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().setCoordinatingNodeId("coordinatingNodeIdSetDuringTest"); + MLDeployModelNodeRequest + .getMLDeployModelNodesRequest() + .getMlDeployModelInput() + .setCoordinatingNodeId("coordinatingNodeIdSetDuringTest"); MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().setMlTask(mlTask); assertNotNull(MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput()); assertEquals("modelIdSetDuringTest", MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().getModelId()); assertEquals("taskIdSetDuringTest", MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().getTaskId()); - assertEquals("modelContentHashSetDuringTest", MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().getModelContentHash()); + assertEquals( + "modelContentHashSetDuringTest", + MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().getModelContentHash() + ); assertEquals(2, MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().getNodeCount().intValue()); - assertEquals("coordinatingNodeIdSetDuringTest", MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().getCoordinatingNodeId()); - assertEquals(mlTask.getTaskId(), MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().getMlTask().getTaskId()); + assertEquals( + "coordinatingNodeIdSetDuringTest", + MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().getCoordinatingNodeId() + ); + assertEquals( + mlTask.getTaskId(), + MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().getMlTask().getTaskId() + ); } @Test public void testConstructorFromInputStream() throws IOException { - String [] nodeIds = {"id1", "id2", "id3"}; - MLDeployModelInput deployModelInput = new MLDeployModelInput("modelId", "taskId", "modelContentHash", 3, "coordinatingNodeId", true, mlTask); + String[] nodeIds = { "id1", "id2", "id3" }; + MLDeployModelInput deployModelInput = new MLDeployModelInput( + "modelId", + "taskId", + "modelContentHash", + 3, + "coordinatingNodeId", + true, + mlTask + ); MLDeployModelNodeRequest MLDeployModelNodeRequest = new MLDeployModelNodeRequest( - new MLDeployModelNodesRequest(nodeIds, deployModelInput) + new MLDeployModelNodesRequest(nodeIds, deployModelInput) ); BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); MLDeployModelNodeRequest.writeTo(bytesStreamOutput); @@ -150,8 +208,10 @@ public void testConstructorFromInputStream() throws IOException { MLDeployModelNodeRequest parsedNodeRequest = new MLDeployModelNodeRequest(streamInput); assertNotNull(parsedNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput()); - assertEquals(MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().getModelId(), - parsedNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().getModelId()); + assertEquals( + MLDeployModelNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().getModelId(), + parsedNodeRequest.getMLDeployModelNodesRequest().getMlDeployModelInput().getModelId() + ); } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/deploy/MLDeployModelNodesResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/deploy/MLDeployModelNodesResponseTest.java index c7bd1dcc3a..e7e8fb6009 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/deploy/MLDeployModelNodesResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/deploy/MLDeployModelNodesResponseTest.java @@ -1,5 +1,12 @@ package org.opensearch.ml.common.transport.deploy; +import static org.junit.Assert.assertEquals; +import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; + +import java.io.IOException; +import java.net.InetAddress; +import java.util.*; + import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -9,7 +16,6 @@ import org.opensearch.action.FailedNodeException; import org.opensearch.cluster.ClusterName; import org.opensearch.cluster.node.DiscoveryNode; -import org.opensearch.core.common.Strings; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.common.transport.TransportAddress; @@ -17,13 +23,6 @@ import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; -import java.io.IOException; -import java.net.InetAddress; -import java.util.*; - -import static org.junit.Assert.assertEquals; -import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; - @RunWith(MockitoJUnitRunner.class) public class MLDeployModelNodesResponseTest { @@ -50,23 +49,25 @@ public void testSerializationDeserialization() throws IOException { public void testToXContent() throws IOException { List nodes = new ArrayList<>(); DiscoveryNode node1 = new DiscoveryNode( - "foo1", - "foo1", - new TransportAddress(InetAddress.getLoopbackAddress(), 9300), - Collections.emptyMap(), - Collections.singleton(CLUSTER_MANAGER_ROLE), - Version.CURRENT); + "foo1", + "foo1", + new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT + ); Map modelToDeployStatus1 = new HashMap<>(); modelToDeployStatus1.put("modelName:version1", "response"); nodes.add(new MLDeployModelNodeResponse(node1, modelToDeployStatus1)); DiscoveryNode node2 = new DiscoveryNode( - "foo2", - "foo2", - new TransportAddress(InetAddress.getLoopbackAddress(), 9300), - Collections.emptyMap(), - Collections.singleton(CLUSTER_MANAGER_ROLE), - Version.CURRENT); + "foo2", + "foo2", + new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT + ); Map modelToDeployStatus2 = new HashMap<>(); modelToDeployStatus2.put("modelName:version2", "response"); nodes.add(new MLDeployModelNodeResponse(node2, modelToDeployStatus2)); @@ -77,7 +78,8 @@ public void testToXContent() throws IOException { response.toXContent(builder, ToXContent.EMPTY_PARAMS); String jsonStr = builder.toString(); assertEquals( - "{\"foo1\":{\"stats\":{\"modelName:version1\":\"response\"}},\"foo2\":{\"stats\":{\"modelName:version2\":\"response\"}}}", - jsonStr); + "{\"foo1\":{\"stats\":{\"modelName:version1\":\"response\"}},\"foo2\":{\"stats\":{\"modelName:version2\":\"response\"}}}", + jsonStr + ); } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/deploy/MLDeployModelRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/deploy/MLDeployModelRequestTest.java index e2945dc212..b75193d970 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/deploy/MLDeployModelRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/deploy/MLDeployModelRequestTest.java @@ -1,44 +1,43 @@ package org.opensearch.ml.common.transport.deploy; +import static org.junit.Assert.*; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Collections; +import java.util.function.Consumer; + import org.junit.Before; import org.junit.Test; import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.*; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.search.SearchModule; -import java.io.IOException; -import java.io.UncheckedIOException; -import java.util.Collections; -import java.util.function.Consumer; - -import static org.junit.Assert.*; - public class MLDeployModelRequestTest { private MLDeployModelRequest mlDeployModelRequest; @Before public void setUp() throws Exception { - mlDeployModelRequest = mlDeployModelRequest.builder(). - modelId("modelId"). - modelNodeIds(new String[]{"modelNodeIds"}). - async(true). - dispatchTask(true). - build(); + mlDeployModelRequest = mlDeployModelRequest + .builder() + .modelId("modelId") + .modelNodeIds(new String[] { "modelNodeIds" }) + .async(true) + .dispatchTask(true) + .build(); } @Test public void testValidateWithBuilder() { - MLDeployModelRequest request = mlDeployModelRequest.builder(). - modelId("modelId"). - build(); + MLDeployModelRequest request = mlDeployModelRequest.builder().modelId("modelId").build(); assertNull(request.validate()); } @@ -50,12 +49,13 @@ public void testValidateWithoutBuilder() { @Test public void validate_Exception_WithNullModelId() { - MLDeployModelRequest request = mlDeployModelRequest.builder(). - modelId(null). - modelNodeIds(new String[]{"modelNodeIds"}). - async(true). - dispatchTask(true). - build(); + MLDeployModelRequest request = mlDeployModelRequest + .builder() + .modelId(null) + .modelNodeIds(new String[] { "modelNodeIds" }) + .async(true) + .dispatchTask(true) + .build(); ActionRequestValidationException exception = request.validate(); assertEquals("Validation Failed: 1: ML model id can't be null;", exception.getMessage()); } @@ -69,7 +69,7 @@ public void writeTo_Success() throws IOException { request = new MLDeployModelRequest(bytesStreamOutput.bytes().streamInput()); assertEquals("modelId", request.getModelId()); - assertArrayEquals(new String[]{"modelNodeIds"}, request.getModelNodeIds()); + assertArrayEquals(new String[] { "modelNodeIds" }, request.getModelNodeIds()); assertTrue(request.isAsync()); assertTrue(request.isDispatchTask()); } @@ -92,9 +92,7 @@ public void writeTo(StreamOutput out) throws IOException { @Test public void fromActionRequest_Success_WithMLDeployModelRequest() { - MLDeployModelRequest request = mlDeployModelRequest.builder(). - modelId("modelId"). - build(); + MLDeployModelRequest request = mlDeployModelRequest.builder().modelId("modelId").build(); assertSame(mlDeployModelRequest.fromActionRequest(request), request); } @@ -124,27 +122,33 @@ public void testParse() throws Exception { String expectedInputStr = "{\"node_ids\":[\"modelNodeIds\"]}"; parseFromJsonString(modelId, expectedInputStr, parsedInput -> { assertEquals("modelId", parsedInput.getModelId()); - assertArrayEquals(new String [] {"modelNodeIds"}, parsedInput.getModelNodeIds()); + assertArrayEquals(new String[] { "modelNodeIds" }, parsedInput.getModelNodeIds()); assertFalse(parsedInput.isAsync()); - assertTrue(parsedInput.isDispatchTask());} - ); + assertTrue(parsedInput.isDispatchTask()); + }); } @Test public void testParseWithInvalidField() throws Exception { String modelId = "modelId"; - String withInvalidFieldInputStr = "{\"void\":\"void\", \"dispatchTask\":\"false\", \"async\":\"true\", \"node_ids\":[\"modelNodeIds\"]}"; + String withInvalidFieldInputStr = + "{\"void\":\"void\", \"dispatchTask\":\"false\", \"async\":\"true\", \"node_ids\":[\"modelNodeIds\"]}"; parseFromJsonString(modelId, withInvalidFieldInputStr, parsedInput -> { assertEquals("modelId", parsedInput.getModelId()); - assertArrayEquals(new String [] {"modelNodeIds"}, parsedInput.getModelNodeIds()); + assertArrayEquals(new String[] { "modelNodeIds" }, parsedInput.getModelNodeIds()); assertFalse(parsedInput.isAsync()); - assertTrue(parsedInput.isDispatchTask());} - ); + assertTrue(parsedInput.isDispatchTask()); + }); } private void parseFromJsonString(String modelId, String expectedInputStr, Consumer verify) throws Exception { - XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, - Collections.emptyList()).getNamedXContents()), LoggingDeprecationHandler.INSTANCE, expectedInputStr); + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + LoggingDeprecationHandler.INSTANCE, + expectedInputStr + ); parser.nextToken(); MLDeployModelRequest parsedInput = mlDeployModelRequest.parse(parser, modelId); verify.accept(parsedInput); diff --git a/common/src/test/java/org/opensearch/ml/common/transport/deploy/MLDeployModelResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/deploy/MLDeployModelResponseTest.java index ff5792b0a8..12d3eaef72 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/deploy/MLDeployModelResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/deploy/MLDeployModelResponseTest.java @@ -1,21 +1,19 @@ package org.opensearch.ml.common.transport.deploy; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +import java.io.IOException; + import org.junit.Before; import org.junit.Test; -import org.opensearch.core.common.Strings; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.MLTaskType; -import java.io.IOException; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; - public class MLDeployModelResponseTest { private String taskId; @@ -53,7 +51,6 @@ public void testToXContent() throws IOException { assertNotNull(builder); String jsonStr = builder.toString(); // Verify the results - assertEquals("{\"task_id\":\"test_id\"," + "\"task_type\":\"DEPLOY_MODEL\"," + - "\"status\":\"test\"}", jsonStr); + assertEquals("{\"task_id\":\"test_id\"," + "\"task_type\":\"DEPLOY_MODEL\"," + "\"status\":\"test\"}", jsonStr); } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/execute/MLExecuteTaskRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/execute/MLExecuteTaskRequestTest.java index 3c5faa1559..8458c090e4 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/execute/MLExecuteTaskRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/execute/MLExecuteTaskRequestTest.java @@ -1,5 +1,11 @@ package org.opensearch.ml.common.transport.execute; +import static org.junit.Assert.*; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -7,26 +13,8 @@ import org.opensearch.action.ActionRequestValidationException; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.ml.common.FunctionName; -import org.opensearch.ml.common.dataframe.ColumnType; -import org.opensearch.ml.common.dataframe.DataFrame; -import org.opensearch.ml.common.dataframe.DataFrameBuilder; -import org.opensearch.ml.common.dataset.DataFrameInputDataset; -import org.opensearch.ml.common.dataset.MLInputDataType; -import org.opensearch.ml.common.dataset.MLInputDataset; -import org.opensearch.ml.common.exception.MLException; import org.opensearch.ml.common.input.Input; -import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.input.execute.metricscorrelation.MetricsCorrelationInput; -import org.opensearch.ml.common.input.parameter.clustering.KMeansParams; -import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; - -import static org.junit.Assert.*; public class MLExecuteTaskRequestTest { private Input exInput; @@ -37,9 +25,9 @@ public class MLExecuteTaskRequestTest { @Before public void setUp() { inputData = new ArrayList<>(); - inputData.add(new float[]{1.0f, 2.0f, 3.0f, 4.0f}); - inputData.add(new float[]{1.0f, 2.0f, 3.0f, 4.0f}); - inputData.add(new float[]{1.0f, 2.0f, 3.0f, 4.0f}); + inputData.add(new float[] { 1.0f, 2.0f, 3.0f, 4.0f }); + inputData.add(new float[] { 1.0f, 2.0f, 3.0f, 4.0f }); + inputData.add(new float[] { 1.0f, 2.0f, 3.0f, 4.0f }); exInput = MetricsCorrelationInput.builder().inputData(inputData).build(); } @@ -47,10 +35,7 @@ public void setUp() { @Test public void writeTo_Success() throws IOException { - MLExecuteTaskRequest request = MLExecuteTaskRequest.builder() - .functionName(FunctionName.METRICS_CORRELATION) - .input(exInput) - .build(); + MLExecuteTaskRequest request = MLExecuteTaskRequest.builder().functionName(FunctionName.METRICS_CORRELATION).input(exInput).build(); BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); request.writeTo(bytesStreamOutput); request = new MLExecuteTaskRequest(bytesStreamOutput.bytes().streamInput()); @@ -62,10 +47,7 @@ public void writeTo_Success() throws IOException { @Test public void validate_Success() { - MLExecuteTaskRequest request = MLExecuteTaskRequest.builder() - .functionName(FunctionName.METRICS_CORRELATION) - .input(exInput) - .build(); + MLExecuteTaskRequest request = MLExecuteTaskRequest.builder().functionName(FunctionName.METRICS_CORRELATION).input(exInput).build(); assertNull(request.validate()); } @@ -74,17 +56,14 @@ public void validate_Success() { public void validate_Exception_NullFunctionNane() { exceptionRule.expect(NullPointerException.class); exceptionRule.expectMessage("functionName is marked non-null but is null"); - MLExecuteTaskRequest request = MLExecuteTaskRequest.builder() - .build(); + MLExecuteTaskRequest request = MLExecuteTaskRequest.builder().build(); request.validate(); } @Test public void validate_Exception_NullMLInput() { - MLExecuteTaskRequest request = MLExecuteTaskRequest.builder() - .functionName(FunctionName.METRICS_CORRELATION) - .build(); + MLExecuteTaskRequest request = MLExecuteTaskRequest.builder().functionName(FunctionName.METRICS_CORRELATION).build(); ActionRequestValidationException exception = request.validate(); assertEquals("Validation Failed: 1: ML input can't be null;", exception.getMessage()); } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/execute/MLExecuteTaskResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/execute/MLExecuteTaskResponseTest.java index 1dfc5f29cd..dc5013fe22 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/execute/MLExecuteTaskResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/execute/MLExecuteTaskResponseTest.java @@ -1,8 +1,14 @@ package org.opensearch.ml.common.transport.execute; +import static org.junit.Assert.*; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + import org.junit.Test; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.core.xcontent.ToXContent; @@ -12,89 +18,91 @@ import org.opensearch.ml.common.output.execute.metrics_correlation.MCorrModelTensors; import org.opensearch.ml.common.output.execute.metrics_correlation.MetricsCorrelationOutput; -import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; - -import static org.junit.Assert.*; - public class MLExecuteTaskResponseTest { @Test public void writeTo_Success() throws IOException { List outputs = new ArrayList<>(); - MCorrModelTensor mCorrModelTensor = MCorrModelTensor.builder() - .event_pattern(new float[]{1.0f, 2.0f, 3.0f}) - .event_window(new float[]{4.0f, 5.0f, 6.0f}) - .suspected_metrics(new long[]{1, 2}) - .build(); + MCorrModelTensor mCorrModelTensor = MCorrModelTensor + .builder() + .event_pattern(new float[] { 1.0f, 2.0f, 3.0f }) + .event_window(new float[] { 4.0f, 5.0f, 6.0f }) + .suspected_metrics(new long[] { 1, 2 }) + .build(); List mlModelTensors = Arrays.asList(mCorrModelTensor); MCorrModelTensors modelTensors = MCorrModelTensors.builder().mCorrModelTensors(mlModelTensors).build(); outputs.add(modelTensors); MetricsCorrelationOutput output = MetricsCorrelationOutput.builder().modelOutput(outputs).build(); - MLExecuteTaskResponse response = MLExecuteTaskResponse.builder() - .functionName(FunctionName.METRICS_CORRELATION) - .output(output) - .build(); + MLExecuteTaskResponse response = MLExecuteTaskResponse + .builder() + .functionName(FunctionName.METRICS_CORRELATION) + .output(output) + .build(); BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); response.writeTo(bytesStreamOutput); response = new MLExecuteTaskResponse(bytesStreamOutput.bytes().streamInput()); - MetricsCorrelationOutput mcorrOutputTest = (MetricsCorrelationOutput)response.getOutput(); + MetricsCorrelationOutput mcorrOutputTest = (MetricsCorrelationOutput) response.getOutput(); assertEquals(1, mcorrOutputTest.getModelOutput().size()); MCorrModelTensors testmodelTensors = mcorrOutputTest.getModelOutput().get(0); assertEquals(1, testmodelTensors.getMCorrModelTensors().size()); MCorrModelTensor testmodelTensor = testmodelTensors.getMCorrModelTensors().get(0); float[] events = testmodelTensor.getEvent_pattern(); long[] metrics = testmodelTensor.getSuspected_metrics(); - assertArrayEquals(new float[]{1.0f, 2.0f, 3.0f}, events, 0.001f); - assertArrayEquals(new long[]{1, 2}, metrics); + assertArrayEquals(new float[] { 1.0f, 2.0f, 3.0f }, events, 0.001f); + assertArrayEquals(new long[] { 1, 2 }, metrics); } @Test public void fromActionResponse_WithMLPredictionTaskResponse() { List outputs = new ArrayList<>(); - MCorrModelTensor mCorrModelTensor = MCorrModelTensor.builder() - .event_pattern(new float[]{1.0f, 2.0f, 3.0f}) - .event_window(new float[]{4.0f, 5.0f, 6.0f}) - .suspected_metrics(new long[]{1, 2}) - .build(); + MCorrModelTensor mCorrModelTensor = MCorrModelTensor + .builder() + .event_pattern(new float[] { 1.0f, 2.0f, 3.0f }) + .event_window(new float[] { 4.0f, 5.0f, 6.0f }) + .suspected_metrics(new long[] { 1, 2 }) + .build(); List mlModelTensors = Arrays.asList(mCorrModelTensor); MCorrModelTensors modelTensors = MCorrModelTensors.builder().mCorrModelTensors(mlModelTensors).build(); outputs.add(modelTensors); MetricsCorrelationOutput output = MetricsCorrelationOutput.builder().modelOutput(outputs).build(); - MLExecuteTaskResponse response = MLExecuteTaskResponse.builder() - .functionName(FunctionName.METRICS_CORRELATION) - .output(output) - .build(); + MLExecuteTaskResponse response = MLExecuteTaskResponse + .builder() + .functionName(FunctionName.METRICS_CORRELATION) + .output(output) + .build(); assertSame(response, MLExecuteTaskResponse.fromActionResponse(response)); } @Test public void toXContentTest() throws IOException { List outputs = new ArrayList<>(); - MCorrModelTensor mCorrModelTensor = MCorrModelTensor.builder() - .event_pattern(new float[]{1.0f, 2.0f, 3.0f}) - .event_window(new float[]{4.0f, 5.0f, 6.0f}) - .suspected_metrics(new long[]{1, 2}) - .build(); + MCorrModelTensor mCorrModelTensor = MCorrModelTensor + .builder() + .event_pattern(new float[] { 1.0f, 2.0f, 3.0f }) + .event_window(new float[] { 4.0f, 5.0f, 6.0f }) + .suspected_metrics(new long[] { 1, 2 }) + .build(); List mlModelTensors = Arrays.asList(mCorrModelTensor); MCorrModelTensors modelTensors = MCorrModelTensors.builder().mCorrModelTensors(mlModelTensors).build(); outputs.add(modelTensors); MetricsCorrelationOutput output = MetricsCorrelationOutput.builder().modelOutput(outputs).build(); - MLExecuteTaskResponse response = MLExecuteTaskResponse.builder() - .functionName(FunctionName.METRICS_CORRELATION) - .output(output) - .build(); + MLExecuteTaskResponse response = MLExecuteTaskResponse + .builder() + .functionName(FunctionName.METRICS_CORRELATION) + .output(output) + .build(); XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); response.toXContent(builder, ToXContent.EMPTY_PARAMS); assertNotNull(builder); String jsonStr = builder.toString(); - assertEquals("{\"function_name\":\"METRICS_CORRELATION\"," + - "\"output\":{\"inference_results\":[{" + - "\"event_window\":[4.0,5.0,6.0]," + - "\"event_pattern\":[1.0,2.0,3.0]," + - "\"suspected_metrics\":[1,2]}]}}", jsonStr); + assertEquals( + "{\"function_name\":\"METRICS_CORRELATION\"," + + "\"output\":{\"inference_results\":[{" + + "\"event_window\":[4.0,5.0,6.0]," + + "\"event_pattern\":[1.0,2.0,3.0]," + + "\"suspected_metrics\":[1,2]}]}}", + jsonStr + ); } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/forward/MLForwardInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/forward/MLForwardInputTest.java index d9b3fc77c4..7720f478e7 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/forward/MLForwardInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/forward/MLForwardInputTest.java @@ -1,5 +1,15 @@ package org.opensearch.ml.common.transport.forward; +import static org.junit.Assert.*; + +import java.io.IOException; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.function.Consumer; + import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -21,79 +31,74 @@ import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import org.opensearch.ml.common.transport.register.MLRegisterModelInput; -import java.io.IOException; -import java.time.Instant; -import java.time.temporal.ChronoUnit; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashMap; -import java.util.function.Consumer; - -import static org.junit.Assert.*; - - @RunWith(MockitoJUnitRunner.class) public class MLForwardInputTest { private MLForwardInput forwardInput; private final FunctionName functionName = FunctionName.KMEANS; - @Before public void setUp() throws Exception { Instant time = Instant.now(); - MLTask mlTask = MLTask.builder() - .taskId("mlTaskTaskId") - .modelId("mlTaskModelId") - .taskType(MLTaskType.PREDICTION) - .functionName(functionName) - .state(MLTaskState.RUNNING) - .inputType(MLInputDataType.DATA_FRAME) - .workerNodes(Arrays.asList("mlTaskNode1")) - .progress(0.0f) - .outputIndex("test_index") - .error("test_error") - .createTime(time.minus(1, ChronoUnit.MINUTES)) - .lastUpdateTime(time) - .build(); - - DataFrame dataFrame = DataFrameBuilder.load(Collections.singletonList(new HashMap() {{ - put("key1", 2.0D); - }})); - MLInput modelInput = MLInput.builder() - .algorithm(FunctionName.KMEANS) - .parameters(KMeansParams.builder().centroids(1).build()) - .inputDataset(DataFrameInputDataset.builder().dataFrame(dataFrame).build()) - .build(); - MLModelConfig config = TextEmbeddingModelConfig.builder() - .modelType("testModelType") - .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") - .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) - .embeddingDimension(100) - .build(); - MLRegisterModelInput registerModelInput = MLRegisterModelInput.builder() - .functionName(functionName) - .modelName("testModelName") - .version("testModelVersion") - .modelGroupId("mockModelGroupId") - .url("url") - .modelFormat(MLModelFormat.ONNX) - .modelConfig(config) - .deployModel(true) - .modelNodeIds(new String[]{"modelNodeIds"}) - .build(); - - forwardInput = MLForwardInput.builder() - .taskId("forwardInputTaskId") - .modelId("forwardInputModelId") - .workerNodeId("forwardInputWorkerNodeId") - .requestType(MLForwardRequestType.DEPLOY_MODEL_DONE) - .mlTask(mlTask) - .modelInput(modelInput) - .error("forwardInputError") - .workerNodes(new String [] {"forwardInputNodeId1", "forwardInputNodeId2", "forwardInputNodeId3"}) - .registerModelInput(registerModelInput) - .build(); + MLTask mlTask = MLTask + .builder() + .taskId("mlTaskTaskId") + .modelId("mlTaskModelId") + .taskType(MLTaskType.PREDICTION) + .functionName(functionName) + .state(MLTaskState.RUNNING) + .inputType(MLInputDataType.DATA_FRAME) + .workerNodes(Arrays.asList("mlTaskNode1")) + .progress(0.0f) + .outputIndex("test_index") + .error("test_error") + .createTime(time.minus(1, ChronoUnit.MINUTES)) + .lastUpdateTime(time) + .build(); + + DataFrame dataFrame = DataFrameBuilder.load(Collections.singletonList(new HashMap() { + { + put("key1", 2.0D); + } + })); + MLInput modelInput = MLInput + .builder() + .algorithm(FunctionName.KMEANS) + .parameters(KMeansParams.builder().centroids(1).build()) + .inputDataset(DataFrameInputDataset.builder().dataFrame(dataFrame).build()) + .build(); + MLModelConfig config = TextEmbeddingModelConfig + .builder() + .modelType("testModelType") + .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") + .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) + .embeddingDimension(100) + .build(); + MLRegisterModelInput registerModelInput = MLRegisterModelInput + .builder() + .functionName(functionName) + .modelName("testModelName") + .version("testModelVersion") + .modelGroupId("mockModelGroupId") + .url("url") + .modelFormat(MLModelFormat.ONNX) + .modelConfig(config) + .deployModel(true) + .modelNodeIds(new String[] { "modelNodeIds" }) + .build(); + + forwardInput = MLForwardInput + .builder() + .taskId("forwardInputTaskId") + .modelId("forwardInputModelId") + .workerNodeId("forwardInputWorkerNodeId") + .requestType(MLForwardRequestType.DEPLOY_MODEL_DONE) + .mlTask(mlTask) + .modelInput(modelInput) + .error("forwardInputError") + .workerNodes(new String[] { "forwardInputNodeId1", "forwardInputNodeId2", "forwardInputNodeId3" }) + .registerModelInput(registerModelInput) + .build(); } @Test @@ -104,7 +109,6 @@ public void readInputStream_Success() throws IOException { }); } - @Test public void readInputStream_SuccessWithNullFields() throws IOException { forwardInput.setMlTask(null); @@ -117,7 +121,6 @@ public void readInputStream_SuccessWithNullFields() throws IOException { }); } - private void readInputStream(MLForwardInput input, Consumer verify) throws IOException { BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); input.writeTo(bytesStreamOutput); diff --git a/common/src/test/java/org/opensearch/ml/common/transport/forward/MLForwardRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/forward/MLForwardRequestTest.java index b0eabfcb83..65815c921a 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/forward/MLForwardRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/forward/MLForwardRequestTest.java @@ -1,5 +1,15 @@ package org.opensearch.ml.common.transport.forward; +import static org.junit.Assert.*; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; + import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -23,16 +33,6 @@ import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import org.opensearch.ml.common.transport.register.MLRegisterModelInput; -import java.io.IOException; -import java.io.UncheckedIOException; -import java.time.Instant; -import java.time.temporal.ChronoUnit; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashMap; - -import static org.junit.Assert.*; - @RunWith(MockitoJUnitRunner.class) public class MLForwardRequestTest { @@ -42,70 +42,74 @@ public class MLForwardRequestTest { private MLRegisterModelInput registerModelInput; private final FunctionName functionName = FunctionName.KMEANS; - @Before public void setUp() throws Exception { Instant time = Instant.now(); - mlTask = MLTask.builder() - .taskId("mlTaskTaskId") - .modelId("mlTaskModelId") - .taskType(MLTaskType.PREDICTION) - .functionName(functionName) - .state(MLTaskState.RUNNING) - .inputType(MLInputDataType.DATA_FRAME) - .workerNodes(Arrays.asList("mlTaskNode1")) - .progress(0.0f) - .outputIndex("test_index") - .error("test_error") - .createTime(time.minus(1, ChronoUnit.MINUTES)) - .lastUpdateTime(time) - .build(); - - DataFrame dataFrame = DataFrameBuilder.load(Collections.singletonList(new HashMap() {{ - put("key1", 2.0D); - }})); - modelInput = MLInput.builder() - .algorithm(FunctionName.KMEANS) - .parameters(KMeansParams.builder().centroids(1).build()) - .inputDataset(DataFrameInputDataset.builder().dataFrame(dataFrame).build()) - .build(); - MLModelConfig config = TextEmbeddingModelConfig.builder() - .modelType("testModelType") - .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") - .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) - .embeddingDimension(100) - .build(); - registerModelInput = MLRegisterModelInput.builder() - .functionName(functionName) - .modelName("testModelName") - .version("testModelVersion") - .modelGroupId("modelGroupId") - .url("url") - .modelFormat(MLModelFormat.ONNX) - .modelConfig(config) - .deployModel(true) - .modelNodeIds(new String[]{"modelNodeIds" }) - .build(); - - forwardInput = MLForwardInput.builder() - .taskId("forwardInputTaskId") - .modelId("forwardInputModelId") - .workerNodeId("forwardInputWorkerNodeId") - .requestType(MLForwardRequestType.DEPLOY_MODEL_DONE) - .mlTask(mlTask) - .modelInput(modelInput) - .error("forwardInputError") - .workerNodes(new String [] {"forwardInputNodeId1", "forwardInputNodeId2", "forwardInputNodeId3"}) - .registerModelInput(registerModelInput) - .build(); + mlTask = MLTask + .builder() + .taskId("mlTaskTaskId") + .modelId("mlTaskModelId") + .taskType(MLTaskType.PREDICTION) + .functionName(functionName) + .state(MLTaskState.RUNNING) + .inputType(MLInputDataType.DATA_FRAME) + .workerNodes(Arrays.asList("mlTaskNode1")) + .progress(0.0f) + .outputIndex("test_index") + .error("test_error") + .createTime(time.minus(1, ChronoUnit.MINUTES)) + .lastUpdateTime(time) + .build(); + + DataFrame dataFrame = DataFrameBuilder.load(Collections.singletonList(new HashMap() { + { + put("key1", 2.0D); + } + })); + modelInput = MLInput + .builder() + .algorithm(FunctionName.KMEANS) + .parameters(KMeansParams.builder().centroids(1).build()) + .inputDataset(DataFrameInputDataset.builder().dataFrame(dataFrame).build()) + .build(); + MLModelConfig config = TextEmbeddingModelConfig + .builder() + .modelType("testModelType") + .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") + .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) + .embeddingDimension(100) + .build(); + registerModelInput = MLRegisterModelInput + .builder() + .functionName(functionName) + .modelName("testModelName") + .version("testModelVersion") + .modelGroupId("modelGroupId") + .url("url") + .modelFormat(MLModelFormat.ONNX) + .modelConfig(config) + .deployModel(true) + .modelNodeIds(new String[] { "modelNodeIds" }) + .build(); + + forwardInput = MLForwardInput + .builder() + .taskId("forwardInputTaskId") + .modelId("forwardInputModelId") + .workerNodeId("forwardInputWorkerNodeId") + .requestType(MLForwardRequestType.DEPLOY_MODEL_DONE) + .mlTask(mlTask) + .modelInput(modelInput) + .error("forwardInputError") + .workerNodes(new String[] { "forwardInputNodeId1", "forwardInputNodeId2", "forwardInputNodeId3" }) + .registerModelInput(registerModelInput) + .build(); } @Test public void writeTo_Success() throws IOException { - MLForwardRequest request = MLForwardRequest.builder() - .forwardInput(forwardInput) - .build(); + MLForwardRequest request = MLForwardRequest.builder().forwardInput(forwardInput).build(); BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); request.writeTo(bytesStreamOutput); request = new MLForwardRequest(bytesStreamOutput.bytes().streamInput()); @@ -114,7 +118,10 @@ public void writeTo_Success() throws IOException { assertEquals("forwardInputWorkerNodeId", request.getForwardInput().getWorkerNodeId()); assertEquals(MLForwardRequestType.DEPLOY_MODEL_DONE, request.getForwardInput().getRequestType()); assertEquals("forwardInputError", request.getForwardInput().getError()); - assertArrayEquals(new String [] {"forwardInputNodeId1", "forwardInputNodeId2", "forwardInputNodeId3"}, request.getForwardInput().getWorkerNodes()); + assertArrayEquals( + new String[] { "forwardInputNodeId1", "forwardInputNodeId2", "forwardInputNodeId3" }, + request.getForwardInput().getWorkerNodes() + ); assertEquals(mlTask.getTaskId(), request.getForwardInput().getMlTask().getTaskId()); assertEquals(modelInput.getAlgorithm().toString(), request.getForwardInput().getModelInput().getAlgorithm().toString()); assertEquals(registerModelInput.getModelName(), request.getForwardInput().getRegisterModelInput().getModelName()); @@ -122,17 +129,14 @@ public void writeTo_Success() throws IOException { @Test public void validate_Success() { - MLForwardRequest request = MLForwardRequest.builder() - .forwardInput(forwardInput) - .build(); + MLForwardRequest request = MLForwardRequest.builder().forwardInput(forwardInput).build(); assertNull(request.validate()); } @Test public void validate_Exception_NullMLInput() { - MLForwardRequest request = MLForwardRequest.builder() - .build(); + MLForwardRequest request = MLForwardRequest.builder().build(); ActionRequestValidationException exception = request.validate(); assertEquals("Validation Failed: 1: ML input can't be null;", exception.getMessage()); } @@ -141,9 +145,7 @@ public void validate_Exception_NullMLInput() { // MLForwardInput check its parameters when created, so exception is not thrown here public void validate_Exception_NullMLModelName() { forwardInput.setTaskId(null); - MLForwardRequest request = MLForwardRequest.builder() - .forwardInput(forwardInput) - .build(); + MLForwardRequest request = MLForwardRequest.builder().forwardInput(forwardInput).build(); assertNull(request.validate()); assertNull(request.getForwardInput().getTaskId()); @@ -151,19 +153,14 @@ public void validate_Exception_NullMLModelName() { @Test public void fromActionRequest_Success_WithMLForwardRequest() { - MLForwardRequest request = MLForwardRequest.builder() - .forwardInput(forwardInput) - .build(); + MLForwardRequest request = MLForwardRequest.builder().forwardInput(forwardInput).build(); assertSame(MLForwardRequest.fromActionRequest(request), request); } - @Test public void fromActionRequest_Success_WithNonMLForwardRequest() { - MLForwardRequest request = MLForwardRequest.builder() - .forwardInput(forwardInput) - .build(); + MLForwardRequest request = MLForwardRequest.builder().forwardInput(forwardInput).build(); ActionRequest actionRequest = new ActionRequest() { @Override public ActionRequestValidationException validate() { @@ -179,8 +176,14 @@ public void writeTo(StreamOutput out) throws IOException { assertNotSame(result, request); assertEquals(request.getForwardInput().getTaskId(), result.getForwardInput().getTaskId()); assertEquals(request.getForwardInput().getMlTask().getTaskId(), result.getForwardInput().getMlTask().getTaskId()); - assertEquals(request.getForwardInput().getModelInput().getAlgorithm().toString(), result.getForwardInput().getModelInput().getAlgorithm().toString()); - assertEquals(request.getForwardInput().getRegisterModelInput().getModelName(), result.getForwardInput().getRegisterModelInput().getModelName()); + assertEquals( + request.getForwardInput().getModelInput().getAlgorithm().toString(), + result.getForwardInput().getModelInput().getAlgorithm().toString() + ); + assertEquals( + request.getForwardInput().getRegisterModelInput().getModelName(), + result.getForwardInput().getRegisterModelInput().getModelName() + ); } @Test(expected = UncheckedIOException.class) diff --git a/common/src/test/java/org/opensearch/ml/common/transport/forward/MLForwardResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/forward/MLForwardResponseTest.java index fa2d86a54d..a159c6e761 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/forward/MLForwardResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/forward/MLForwardResponseTest.java @@ -1,15 +1,20 @@ package org.opensearch.ml.common.transport.forward; +import static org.junit.Assert.*; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Collections; +import java.util.HashMap; + import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.junit.MockitoJUnitRunner; -import org.opensearch.core.action.ActionResponse; -import org.opensearch.core.common.Strings; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; @@ -17,13 +22,6 @@ import org.opensearch.ml.common.dataframe.DataFrameBuilder; import org.opensearch.ml.common.output.MLPredictionOutput; -import java.io.IOException; -import java.io.UncheckedIOException; -import java.util.Collections; -import java.util.HashMap; - -import static org.junit.Assert.*; - @RunWith(MockitoJUnitRunner.class) public class MLForwardResponseTest { @@ -38,11 +36,7 @@ public void setUp() throws Exception { put("key1", 2.0D); } })); - predictionOutput = MLPredictionOutput.builder() - .status("Success") - .predictionResult(dataFrame) - .taskId("taskId") - .build(); + predictionOutput = MLPredictionOutput.builder().status("Success").predictionResult(dataFrame).taskId("taskId").build(); } @Test @@ -69,8 +63,9 @@ public void testToXContent() throws IOException { String jsonStr = builder.toString(); // Verify the results assertEquals( - "{\"result\":{\"task_id\":\"taskId\",\"status\":\"Success\",\"prediction_result\":{\"column_metas\":[{\"name\":\"key1\",\"column_type\":\"DOUBLE\"}],\"rows\":[{\"values\":[{\"column_type\":\"DOUBLE\",\"value\":2.0}]}]}}}", - jsonStr); + "{\"result\":{\"task_id\":\"taskId\",\"status\":\"Success\",\"prediction_result\":{\"column_metas\":[{\"name\":\"key1\",\"column_type\":\"DOUBLE\"}],\"rows\":[{\"values\":[{\"column_type\":\"DOUBLE\",\"value\":2.0}]}]}}}", + jsonStr + ); } @Test diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model/MLModelDeleteRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model/MLModelDeleteRequestTest.java index 533b96ecdf..109eafde95 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/model/MLModelDeleteRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/model/MLModelDeleteRequestTest.java @@ -5,6 +5,14 @@ package org.opensearch.ml.common.transport.model; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; + +import java.io.IOException; +import java.io.UncheckedIOException; + import org.junit.Before; import org.junit.Test; import org.opensearch.action.ActionRequest; @@ -12,14 +20,6 @@ import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.io.stream.StreamOutput; -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotSame; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertSame; - public class MLModelDeleteRequestTest { private String modelId; @@ -30,8 +30,7 @@ public void setUp() { @Test public void writeTo_Success() throws IOException { - MLModelDeleteRequest mlModelDeleteRequest = MLModelDeleteRequest.builder() - .modelId(modelId).build(); + MLModelDeleteRequest mlModelDeleteRequest = MLModelDeleteRequest.builder().modelId(modelId).build(); BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); mlModelDeleteRequest.writeTo(bytesStreamOutput); MLModelDeleteRequest parsedModel = new MLModelDeleteRequest(bytesStreamOutput.bytes().streamInput()); @@ -40,8 +39,7 @@ public void writeTo_Success() throws IOException { @Test public void validate_Success() { - MLModelDeleteRequest mlModelDeleteRequest = MLModelDeleteRequest.builder() - .modelId(modelId).build(); + MLModelDeleteRequest mlModelDeleteRequest = MLModelDeleteRequest.builder().modelId(modelId).build(); ActionRequestValidationException actionRequestValidationException = mlModelDeleteRequest.validate(); assertNull(actionRequestValidationException); } @@ -56,8 +54,7 @@ public void validate_Exception_NullModelId() { @Test public void fromActionRequest_Success() { - MLModelDeleteRequest mlModelDeleteRequest = MLModelDeleteRequest.builder() - .modelId(modelId).build(); + MLModelDeleteRequest mlModelDeleteRequest = MLModelDeleteRequest.builder().modelId(modelId).build(); ActionRequest actionRequest = new ActionRequest() { @Override public ActionRequestValidationException validate() { @@ -90,11 +87,9 @@ public void writeTo(StreamOutput out) throws IOException { MLModelDeleteRequest.fromActionRequest(actionRequest); } - @Test public void fromActionRequestWithModelDeleteRequest_Success() { - MLModelDeleteRequest mlModelDeleteRequest = MLModelDeleteRequest.builder() - .modelId(modelId).build(); + MLModelDeleteRequest mlModelDeleteRequest = MLModelDeleteRequest.builder().modelId(modelId).build(); MLModelDeleteRequest mlModelDeleteRequestFromActionRequest = MLModelDeleteRequest.fromActionRequest(mlModelDeleteRequest); assertSame(mlModelDeleteRequest, mlModelDeleteRequestFromActionRequest); assertEquals(mlModelDeleteRequest.getModelId(), mlModelDeleteRequestFromActionRequest.getModelId()); diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model/MLModelGetRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model/MLModelGetRequestTest.java index 97f784d868..4a16bf9347 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/model/MLModelGetRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/model/MLModelGetRequestTest.java @@ -5,6 +5,14 @@ package org.opensearch.ml.common.transport.model; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; + +import java.io.IOException; +import java.io.UncheckedIOException; + import org.junit.Before; import org.junit.Test; import org.opensearch.action.ActionRequest; @@ -12,14 +20,6 @@ import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.io.stream.StreamOutput; -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotSame; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertSame; - public class MLModelGetRequestTest { private String modelId; @@ -30,8 +30,7 @@ public void setUp() { @Test public void writeTo_Success() throws IOException { - MLModelGetRequest mlModelGetRequest = MLModelGetRequest.builder() - .modelId(modelId).build(); + MLModelGetRequest mlModelGetRequest = MLModelGetRequest.builder().modelId(modelId).build(); BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); mlModelGetRequest.writeTo(bytesStreamOutput); MLModelGetRequest parsedModel = new MLModelGetRequest(bytesStreamOutput.bytes().streamInput()); @@ -48,17 +47,16 @@ public void validate_Exception_NullModelId() { @Test public void fromActionRequest_Success() { - MLModelGetRequest mlModelGetRequest = MLModelGetRequest.builder() - .modelId(modelId).build(); + MLModelGetRequest mlModelGetRequest = MLModelGetRequest.builder().modelId(modelId).build(); ActionRequest actionRequest = new ActionRequest() { @Override public ActionRequestValidationException validate() { - return null; + return null; } @Override public void writeTo(StreamOutput out) throws IOException { - mlModelGetRequest.writeTo(out); + mlModelGetRequest.writeTo(out); } }; MLModelGetRequest result = MLModelGetRequest.fromActionRequest(actionRequest); @@ -71,12 +69,12 @@ public void fromActionRequest_IOException() { ActionRequest actionRequest = new ActionRequest() { @Override public ActionRequestValidationException validate() { - return null; + return null; } @Override public void writeTo(StreamOutput out) throws IOException { - throw new IOException("test"); + throw new IOException("test"); } }; MLModelGetRequest.fromActionRequest(actionRequest); diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model/MLModelGetResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model/MLModelGetResponseTest.java index 81a25ce8d8..2b191c14a2 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/model/MLModelGetResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/model/MLModelGetResponseTest.java @@ -5,12 +5,17 @@ package org.opensearch.ml.common.transport.model; +import static org.junit.Assert.*; + +import java.io.IOException; +import java.io.UncheckedIOException; + import org.junit.Before; import org.junit.Test; -import org.opensearch.common.xcontent.XContentType; -import org.opensearch.core.action.ActionResponse; import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentType; import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.core.xcontent.ToXContent; @@ -19,25 +24,21 @@ import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.model.MLModelState; -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.junit.Assert.*; - public class MLModelGetResponseTest { MLModel mlModel; @Before public void setUp() { - mlModel = MLModel.builder() - .name("model") - .algorithm(FunctionName.KMEANS) - .version("1.0.0") - .content("content") - .user(new User()) - .modelState(MLModelState.TRAINED) - .build(); + mlModel = MLModel + .builder() + .name("model") + .algorithm(FunctionName.KMEANS) + .version("1.0.0") + .content("content") + .user(new User()) + .modelState(MLModelState.TRAINED) + .build(); } @Test @@ -61,12 +62,14 @@ public void toXContentTest() throws IOException { mlModelGetResponse.toXContent(builder, ToXContent.EMPTY_PARAMS); assertNotNull(builder); String jsonStr = builder.toString(); - assertEquals("{\"name\":\"model\"," + - "\"algorithm\":\"KMEANS\"," + - "\"model_version\":\"1.0.0\"," + - "\"model_content\":\"content\"," + - "\"user\":{\"name\":\"\",\"backend_roles\":[],\"roles\":[],\"custom_attribute_names\":[],\"user_requested_tenant\":null},\"model_state\":\"TRAINED\"}", - jsonStr); + assertEquals( + "{\"name\":\"model\"," + + "\"algorithm\":\"KMEANS\"," + + "\"model_version\":\"1.0.0\"," + + "\"model_content\":\"content\"," + + "\"user\":{\"name\":\"\",\"backend_roles\":[],\"roles\":[],\"custom_attribute_names\":[],\"user_requested_tenant\":null},\"model_state\":\"TRAINED\"}", + jsonStr + ); } @Test diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelInputTest.java index 6bafe81692..ffcadd9c48 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelInputTest.java @@ -6,12 +6,10 @@ package org.opensearch.ml.common.transport.model; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import java.io.IOException; -import java.util.Arrays; import java.util.Collections; import java.util.function.Consumer; @@ -20,57 +18,63 @@ import org.junit.Test; import org.junit.rules.ExpectedException; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.search.SearchModule; import org.opensearch.ml.common.model.MLModelConfig; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; +import org.opensearch.search.SearchModule; public class MLUpdateModelInputTest { private MLUpdateModelInput updateModelInput; - private final String expectedInputStr = "{\"model_id\":\"test-model_id\",\"name\":\"name\",\"description\":\"description\",\"model_version\":\"2\",\"model_group_id\":\"modelGroupId\",\"model_config\":" + - "{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"" + - "{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"connector_id\":\"test-connector_id\"}"; - private final String expectedInputStrWithNullField = "{\"model_id\":\"test-model_id\",\"name\":null,\"description\":\"description\",\"model_version\":\"2\",\"model_group_id\":\"modelGroupId\",\"model_config\":" + - "{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"" + - "{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"connector_id\":\"test-connector_id\"}"; - private final String expectedOutputStr = "{\"model_id\":\"test-model_id\",\"name\":\"name\",\"description\":\"description\",\"model_version\":\"2\",\"model_group_id\":\"modelGroupId\",\"model_config\":" + - "{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"" + - "{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"connector_id\":\"test-connector_id\"}"; - private final String expectedInputStrWithIllegalField = "{\"model_id\":\"test-model_id\",\"description\":\"description\",\"model_version\":\"2\",\"name\":\"name\",\"model_group_id\":\"modelGroupId\",\"model_config\":" + - "{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"" + - "{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"connector_id\":\"test-connector_id\",\"illegal_field\":\"This field need to be skipped.\"}"; + private final String expectedInputStr = + "{\"model_id\":\"test-model_id\",\"name\":\"name\",\"description\":\"description\",\"model_version\":\"2\",\"model_group_id\":\"modelGroupId\",\"model_config\":" + + "{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"" + + "{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"connector_id\":\"test-connector_id\"}"; + private final String expectedInputStrWithNullField = + "{\"model_id\":\"test-model_id\",\"name\":null,\"description\":\"description\",\"model_version\":\"2\",\"model_group_id\":\"modelGroupId\",\"model_config\":" + + "{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"" + + "{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"connector_id\":\"test-connector_id\"}"; + private final String expectedOutputStr = + "{\"model_id\":\"test-model_id\",\"name\":\"name\",\"description\":\"description\",\"model_version\":\"2\",\"model_group_id\":\"modelGroupId\",\"model_config\":" + + "{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"" + + "{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"connector_id\":\"test-connector_id\"}"; + private final String expectedInputStrWithIllegalField = + "{\"model_id\":\"test-model_id\",\"description\":\"description\",\"model_version\":\"2\",\"name\":\"name\",\"model_group_id\":\"modelGroupId\",\"model_config\":" + + "{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"" + + "{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"connector_id\":\"test-connector_id\",\"illegal_field\":\"This field need to be skipped.\"}"; @Rule public ExpectedException exceptionRule = ExpectedException.none(); @Before public void setUp() throws Exception { - MLModelConfig config = TextEmbeddingModelConfig.builder() - .modelType("testModelType") - .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") - .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) - .embeddingDimension(100) - .build(); - - updateModelInput = MLUpdateModelInput.builder() - .modelId("test-model_id") - .modelGroupId("modelGroupId") - .version("2") - .name("name") - .description("description") - .modelConfig(config) - .connectorId("test-connector_id") - .build(); - } + MLModelConfig config = TextEmbeddingModelConfig + .builder() + .modelType("testModelType") + .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") + .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) + .embeddingDimension(100) + .build(); + + updateModelInput = MLUpdateModelInput + .builder() + .modelId("test-model_id") + .modelGroupId("modelGroupId") + .version("2") + .name("name") + .description("description") + .modelConfig(config) + .connectorId("test-connector_id") + .build(); + } @Test public void readInputStream_Success() throws IOException { @@ -83,9 +87,7 @@ public void readInputStream_Success() throws IOException { @Test public void readInputStream_SuccessWithNullFields() throws IOException { updateModelInput.setModelConfig(null); - readInputStream(updateModelInput, parsedInput -> { - assertNull(parsedInput.getModelConfig()); - }); + readInputStream(updateModelInput, parsedInput -> { assertNull(parsedInput.getModelConfig()); }); } @Test @@ -96,8 +98,7 @@ public void testToXContent() throws Exception { @Test public void testToXContent_Incomplete() throws Exception { - String expectedIncompleteInputStr = - "{\"model_id\":\"test-model_id\"}"; + String expectedIncompleteInputStr = "{\"model_id\":\"test-model_id\"}"; updateModelInput.setDescription(null); updateModelInput.setVersion(null); updateModelInput.setName(null); @@ -110,9 +111,7 @@ public void testToXContent_Incomplete() throws Exception { @Test public void parse_Success() throws Exception { - testParseFromJsonString(expectedInputStr, parsedInput -> { - assertEquals("name", parsedInput.getName()); - }); + testParseFromJsonString(expectedInputStr, parsedInput -> { assertEquals("name", parsedInput.getName()); }); } @Test @@ -139,8 +138,13 @@ public void parse_WithIllegalFieldWithoutModel() throws Exception { } private void testParseFromJsonString(String expectedInputStr, Consumer verify) throws Exception { - XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, - Collections.emptyList()).getNamedXContents()), LoggingDeprecationHandler.INSTANCE, expectedInputStr); + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + LoggingDeprecationHandler.INSTANCE, + expectedInputStr + ); parser.nextToken(); MLUpdateModelInput parsedInput = MLUpdateModelInput.parse(parser); verify.accept(parsedInput); diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelRequestTest.java index cadf865b1c..184ab097d2 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelRequestTest.java @@ -5,55 +5,50 @@ package org.opensearch.ml.common.transport.model; -import org.junit.Before; -import org.opensearch.ml.common.model.MLModelConfig; -import org.opensearch.ml.common.model.TextEmbeddingModelConfig; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; +import java.io.IOException; +import java.io.UncheckedIOException; + +import org.junit.Before; import org.junit.Test; -import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.rest.RestRequest; - -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotSame; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertSame; - +import org.opensearch.ml.common.model.MLModelConfig; +import org.opensearch.ml.common.model.TextEmbeddingModelConfig; public class MLUpdateModelRequestTest { private MLUpdateModelRequest updateModelRequest; @Before - public void setUp(){ + public void setUp() { MockitoAnnotations.openMocks(this); - MLModelConfig config = TextEmbeddingModelConfig.builder() - .modelType("testModelType") - .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") - .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) - .embeddingDimension(100) - .build(); - - MLUpdateModelInput updateModelInput = MLUpdateModelInput.builder() - .modelId("test-model_id") - .modelGroupId("modelGroupId") - .name("name") - .description("description") - .modelConfig(config) - .build(); - - updateModelRequest = MLUpdateModelRequest.builder() - .updateModelInput(updateModelInput) - .build(); + MLModelConfig config = TextEmbeddingModelConfig + .builder() + .modelType("testModelType") + .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") + .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) + .embeddingDimension(100) + .build(); + + MLUpdateModelInput updateModelInput = MLUpdateModelInput + .builder() + .modelId("test-model_id") + .modelGroupId("modelGroupId") + .name("name") + .description("description") + .modelConfig(config) + .build(); + + updateModelRequest = MLUpdateModelRequest.builder().updateModelInput(updateModelInput).build(); } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLModelGroupDeleteRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLModelGroupDeleteRequestTest.java index 68e4491674..f9bcf6fc1d 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLModelGroupDeleteRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLModelGroupDeleteRequestTest.java @@ -1,5 +1,11 @@ package org.opensearch.ml.common.transport.model_group; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; + +import java.io.IOException; +import java.io.UncheckedIOException; + import org.junit.Before; import org.junit.Test; import org.opensearch.action.ActionRequest; @@ -7,12 +13,6 @@ import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.io.stream.StreamOutput; -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotSame; - public class MLModelGroupDeleteRequestTest { private String modelGroupId; @@ -24,8 +24,7 @@ public void setUp() { @Test public void writeTo_Success() throws IOException { - MLModelGroupDeleteRequest mlModelGroupDeleteRequest = MLModelGroupDeleteRequest.builder() - .modelGroupId(modelGroupId).build(); + MLModelGroupDeleteRequest mlModelGroupDeleteRequest = MLModelGroupDeleteRequest.builder().modelGroupId(modelGroupId).build(); BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); mlModelGroupDeleteRequest.writeTo(bytesStreamOutput); MLModelGroupDeleteRequest parsedModel = new MLModelGroupDeleteRequest(bytesStreamOutput.bytes().streamInput()); @@ -42,8 +41,7 @@ public void validate_Exception_NullModelId() { @Test public void fromActionRequest_Success() { - MLModelGroupDeleteRequest mlModelDeleteRequest = MLModelGroupDeleteRequest.builder() - .modelGroupId(modelGroupId).build(); + MLModelGroupDeleteRequest mlModelDeleteRequest = MLModelGroupDeleteRequest.builder().modelGroupId(modelGroupId).build(); ActionRequest actionRequest = new ActionRequest() { @Override public ActionRequestValidationException validate() { diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupInputTest.java index 628336d1ef..68cd72836e 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupInputTest.java @@ -1,16 +1,16 @@ package org.opensearch.ml.common.transport.model_group; +import static org.junit.Assert.assertEquals; + +import java.io.IOException; +import java.util.Arrays; + import org.junit.Before; import org.junit.Test; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.ml.common.AccessMode; -import java.io.IOException; -import java.util.Arrays; - -import static org.junit.Assert.assertEquals; - public class MLRegisterModelGroupInputTest { private MLRegisterModelGroupInput mlRegisterModelGroupInput; @@ -18,13 +18,14 @@ public class MLRegisterModelGroupInputTest { @Before public void setUp() throws Exception { - mlRegisterModelGroupInput = mlRegisterModelGroupInput.builder() - .name("name") - .description("description") - .backendRoles(Arrays.asList("IT")) - .modelAccessMode(AccessMode.RESTRICTED) - .isAddAllBackendRoles(true) - .build(); + mlRegisterModelGroupInput = mlRegisterModelGroupInput + .builder() + .name("name") + .description("description") + .backendRoles(Arrays.asList("IT")) + .modelAccessMode(AccessMode.RESTRICTED) + .isAddAllBackendRoles(true) + .build(); } @Test diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupRequestTest.java index 8e27325e47..188180a98c 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupRequestTest.java @@ -1,5 +1,14 @@ package org.opensearch.ml.common.transport.model_group; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Arrays; + import org.junit.Before; import org.junit.Test; import org.opensearch.action.ActionRequest; @@ -8,59 +17,61 @@ import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.ml.common.AccessMode; -import java.io.IOException; -import java.io.UncheckedIOException; -import java.util.Arrays; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotSame; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertSame; - public class MLRegisterModelGroupRequestTest { private MLRegisterModelGroupInput mlRegisterModelGroupInput; @Before - public void setUp(){ - - mlRegisterModelGroupInput = mlRegisterModelGroupInput.builder() - .name("name") - .description("description") - .backendRoles(Arrays.asList("IT")) - .modelAccessMode(AccessMode.RESTRICTED) - .isAddAllBackendRoles(true) - .build(); + public void setUp() { + + mlRegisterModelGroupInput = mlRegisterModelGroupInput + .builder() + .name("name") + .description("description") + .backendRoles(Arrays.asList("IT")) + .modelAccessMode(AccessMode.RESTRICTED) + .isAddAllBackendRoles(true) + .build(); } @Test public void writeTo_Success() throws IOException { - MLRegisterModelGroupRequest request = MLRegisterModelGroupRequest.builder() - .registerModelGroupInput(mlRegisterModelGroupInput) - .build(); + MLRegisterModelGroupRequest request = MLRegisterModelGroupRequest + .builder() + .registerModelGroupInput(mlRegisterModelGroupInput) + .build(); BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); request.writeTo(bytesStreamOutput); MLRegisterModelGroupRequest parsedRequest = new MLRegisterModelGroupRequest(bytesStreamOutput.bytes().streamInput()); assertEquals(request.getRegisterModelGroupInput().getName(), parsedRequest.getRegisterModelGroupInput().getName()); assertEquals(request.getRegisterModelGroupInput().getDescription(), parsedRequest.getRegisterModelGroupInput().getDescription()); - assertEquals(request.getRegisterModelGroupInput().getBackendRoles().get(0), parsedRequest.getRegisterModelGroupInput().getBackendRoles().get(0)); - assertEquals(request.getRegisterModelGroupInput().getModelAccessMode(), parsedRequest.getRegisterModelGroupInput().getModelAccessMode()); - assertEquals(request.getRegisterModelGroupInput().getIsAddAllBackendRoles() ,parsedRequest.getRegisterModelGroupInput().getIsAddAllBackendRoles()); + assertEquals( + request.getRegisterModelGroupInput().getBackendRoles().get(0), + parsedRequest.getRegisterModelGroupInput().getBackendRoles().get(0) + ); + assertEquals( + request.getRegisterModelGroupInput().getModelAccessMode(), + parsedRequest.getRegisterModelGroupInput().getModelAccessMode() + ); + assertEquals( + request.getRegisterModelGroupInput().getIsAddAllBackendRoles(), + parsedRequest.getRegisterModelGroupInput().getIsAddAllBackendRoles() + ); } @Test public void validate_Success() { - MLRegisterModelGroupRequest request = MLRegisterModelGroupRequest.builder() - .registerModelGroupInput(mlRegisterModelGroupInput) - .build(); + MLRegisterModelGroupRequest request = MLRegisterModelGroupRequest + .builder() + .registerModelGroupInput(mlRegisterModelGroupInput) + .build(); assertNull(request.validate()); } @Test public void validate_Exception_NullMLRegisterModelGroupInput() { - MLRegisterModelGroupRequest request = MLRegisterModelGroupRequest.builder() - .build(); + MLRegisterModelGroupRequest request = MLRegisterModelGroupRequest.builder().build(); ActionRequestValidationException exception = request.validate(); assertEquals("Validation Failed: 1: Model meta input can't be null;", exception.getMessage()); } @@ -69,9 +80,10 @@ public void validate_Exception_NullMLRegisterModelGroupInput() { // MLRegisterModelGroupInput check its parameters when created, so exception is not thrown here public void validate_Exception_NullMLModelName() { mlRegisterModelGroupInput.setName(null); - MLRegisterModelGroupRequest request = MLRegisterModelGroupRequest.builder() - .registerModelGroupInput(mlRegisterModelGroupInput) - .build(); + MLRegisterModelGroupRequest request = MLRegisterModelGroupRequest + .builder() + .registerModelGroupInput(mlRegisterModelGroupInput) + .build(); assertNull(request.validate()); assertNull(request.getRegisterModelGroupInput().getName()); @@ -79,17 +91,19 @@ public void validate_Exception_NullMLModelName() { @Test public void fromActionRequest_Success_WithMLRegisterModelRequest() { - MLRegisterModelGroupRequest request = MLRegisterModelGroupRequest.builder() - .registerModelGroupInput(mlRegisterModelGroupInput) - .build(); + MLRegisterModelGroupRequest request = MLRegisterModelGroupRequest + .builder() + .registerModelGroupInput(mlRegisterModelGroupInput) + .build(); assertSame(MLRegisterModelGroupRequest.fromActionRequest(request), request); } @Test public void fromActionRequest_Success_WithNonMLRegisterModelRequest() { - MLRegisterModelGroupRequest request = MLRegisterModelGroupRequest.builder() - .registerModelGroupInput(mlRegisterModelGroupInput) - .build(); + MLRegisterModelGroupRequest request = MLRegisterModelGroupRequest + .builder() + .registerModelGroupInput(mlRegisterModelGroupInput) + .build(); ActionRequest actionRequest = new ActionRequest() { @Override public ActionRequestValidationException validate() { diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupResponseTest.java index 9299307539..483ccedcc1 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupResponseTest.java @@ -5,6 +5,12 @@ package org.opensearch.ml.common.transport.model_group; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; + +import java.io.IOException; + import org.junit.Before; import org.junit.Test; import org.opensearch.common.io.stream.BytesStreamOutput; @@ -12,12 +18,6 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.TestHelper; -import java.io.IOException; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; - public class MLRegisterModelGroupResponseTest { MLRegisterModelGroupResponse mlRegisterModelGroupResponse; @@ -27,7 +27,6 @@ public void setup() { mlRegisterModelGroupResponse = new MLRegisterModelGroupResponse("ModelGroupId", "Status"); } - @Test public void writeTo_Success() throws IOException { BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupInputTest.java index be5a9c0862..96d6b36a45 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupInputTest.java @@ -1,16 +1,16 @@ package org.opensearch.ml.common.transport.model_group; +import static org.junit.Assert.assertEquals; + +import java.io.IOException; +import java.util.Arrays; + import org.junit.Before; import org.junit.Test; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.ml.common.AccessMode; -import java.io.IOException; -import java.util.Arrays; - -import static org.junit.Assert.assertEquals; - public class MLUpdateModelGroupInputTest { private MLUpdateModelGroupInput mlUpdateModelGroupInput; @@ -18,14 +18,15 @@ public class MLUpdateModelGroupInputTest { @Before public void setUp() throws Exception { - mlUpdateModelGroupInput = mlUpdateModelGroupInput.builder() - .modelGroupID("modelGroupId") - .name("name") - .description("description") - .backendRoles(Arrays.asList("IT")) - .modelAccessMode(AccessMode.RESTRICTED) - .isAddAllBackendRoles(true) - .build(); + mlUpdateModelGroupInput = mlUpdateModelGroupInput + .builder() + .modelGroupID("modelGroupId") + .name("name") + .description("description") + .backendRoles(Arrays.asList("IT")) + .modelAccessMode(AccessMode.RESTRICTED) + .isAddAllBackendRoles(true) + .build(); } @Test diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupRequestTest.java index 483d7c6c85..0462e47eeb 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupRequestTest.java @@ -1,5 +1,14 @@ package org.opensearch.ml.common.transport.model_group; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Arrays; + import org.junit.Before; import org.junit.Test; import org.opensearch.action.ActionRequest; @@ -8,38 +17,28 @@ import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.ml.common.AccessMode; -import java.io.IOException; -import java.io.UncheckedIOException; -import java.util.Arrays; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotSame; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertSame; - public class MLUpdateModelGroupRequestTest { private MLUpdateModelGroupInput mlUpdateModelGroupInput; @Before - public void setUp(){ - - mlUpdateModelGroupInput = mlUpdateModelGroupInput.builder() - .modelGroupID("modelGroupId") - .name("name") - .description("description") - .backendRoles(Arrays.asList("IT")) - .modelAccessMode(AccessMode.RESTRICTED) - .isAddAllBackendRoles(true) - .build(); + public void setUp() { + + mlUpdateModelGroupInput = mlUpdateModelGroupInput + .builder() + .modelGroupID("modelGroupId") + .name("name") + .description("description") + .backendRoles(Arrays.asList("IT")) + .modelAccessMode(AccessMode.RESTRICTED) + .isAddAllBackendRoles(true) + .build(); } @Test public void writeTo_Success() throws IOException { - MLUpdateModelGroupRequest request = MLUpdateModelGroupRequest.builder() - .updateModelGroupInput(mlUpdateModelGroupInput) - .build(); + MLUpdateModelGroupRequest request = MLUpdateModelGroupRequest.builder().updateModelGroupInput(mlUpdateModelGroupInput).build(); BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); request.writeTo(bytesStreamOutput); request = new MLUpdateModelGroupRequest(bytesStreamOutput.bytes().streamInput()); @@ -53,17 +52,14 @@ public void writeTo_Success() throws IOException { @Test public void validate_Success() { - MLUpdateModelGroupRequest request = MLUpdateModelGroupRequest.builder() - .updateModelGroupInput(mlUpdateModelGroupInput) - .build(); + MLUpdateModelGroupRequest request = MLUpdateModelGroupRequest.builder().updateModelGroupInput(mlUpdateModelGroupInput).build(); assertNull(request.validate()); } @Test public void validate_Exception_NullMLRegisterModelGroupInput() { - MLUpdateModelGroupRequest request = MLUpdateModelGroupRequest.builder() - .build(); + MLUpdateModelGroupRequest request = MLUpdateModelGroupRequest.builder().build(); ActionRequestValidationException exception = request.validate(); assertEquals("Validation Failed: 1: Update Model group input can't be null;", exception.getMessage()); } @@ -72,28 +68,21 @@ public void validate_Exception_NullMLRegisterModelGroupInput() { // MLRegisterModelGroupInput check its parameters when created, so exception is not thrown here public void validate_Exception_NullMLModelName() { mlUpdateModelGroupInput.setName(null); - MLUpdateModelGroupRequest request = MLUpdateModelGroupRequest.builder() - .updateModelGroupInput(mlUpdateModelGroupInput) - .build(); + MLUpdateModelGroupRequest request = MLUpdateModelGroupRequest.builder().updateModelGroupInput(mlUpdateModelGroupInput).build(); assertNull(request.validate()); assertNull(request.getUpdateModelGroupInput().getName()); } - @Test public void fromActionRequest_Success_WithMLUpdateModelRequest() { - MLUpdateModelGroupRequest request = MLUpdateModelGroupRequest.builder() - .updateModelGroupInput(mlUpdateModelGroupInput) - .build(); + MLUpdateModelGroupRequest request = MLUpdateModelGroupRequest.builder().updateModelGroupInput(mlUpdateModelGroupInput).build(); assertSame(MLUpdateModelGroupRequest.fromActionRequest(request), request); } @Test public void fromActionRequest_Success_WithNonMLUpdateModelRequest() { - MLUpdateModelGroupRequest request = MLUpdateModelGroupRequest.builder() - .updateModelGroupInput(mlUpdateModelGroupInput) - .build(); + MLUpdateModelGroupRequest request = MLUpdateModelGroupRequest.builder().updateModelGroupInput(mlUpdateModelGroupInput).build(); ActionRequest actionRequest = new ActionRequest() { @Override public ActionRequestValidationException validate() { diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupResponseTest.java index 2c1305a73e..f42a1bf4d1 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupResponseTest.java @@ -5,6 +5,12 @@ package org.opensearch.ml.common.transport.model_group; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; + +import java.io.IOException; + import org.junit.Before; import org.junit.Test; import org.opensearch.common.io.stream.BytesStreamOutput; @@ -12,12 +18,6 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.TestHelper; -import java.io.IOException; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; - public class MLUpdateModelGroupResponseTest { MLUpdateModelGroupResponse mlUpdateModelGroupResponse; @@ -27,7 +27,6 @@ public void setup() { mlUpdateModelGroupResponse = new MLUpdateModelGroupResponse("Status"); } - @Test public void writeTo_Success() throws IOException { BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); diff --git a/common/src/test/java/org/opensearch/ml/common/transport/prediction/MLPredictionTaskRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/prediction/MLPredictionTaskRequestTest.java index ce96aa56c1..bc7fd94f40 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/prediction/MLPredictionTaskRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/prediction/MLPredictionTaskRequestTest.java @@ -5,12 +5,16 @@ package org.opensearch.ml.common.transport.prediction; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; + import java.io.IOException; import java.io.UncheckedIOException; import java.util.Collections; import java.util.HashMap; -import lombok.NonNull; import org.junit.Before; import org.junit.Test; import org.opensearch.action.ActionRequest; @@ -18,6 +22,7 @@ import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.index.query.MatchAllQueryBuilder; +import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.dataframe.ColumnType; import org.opensearch.ml.common.dataframe.DataFrame; import org.opensearch.ml.common.dataframe.DataFrameBuilder; @@ -25,15 +30,11 @@ import org.opensearch.ml.common.dataset.MLInputDataType; import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.dataset.SearchQueryInputDataset; -import org.opensearch.ml.common.input.parameter.clustering.KMeansParams; -import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.input.parameter.clustering.KMeansParams; import org.opensearch.search.builder.SearchSourceBuilder; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotSame; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertSame; +import lombok.NonNull; public class MLPredictionTaskRequestTest { @@ -41,27 +42,28 @@ public class MLPredictionTaskRequestTest { @Before public void setUp() { - DataFrame dataFrame = DataFrameBuilder.load(Collections.singletonList(new HashMap() {{ - put("key1", 2.0D); - }})); - mlInput = MLInput.builder() - .algorithm(FunctionName.KMEANS) - .parameters(KMeansParams.builder().centroids(1).build()) - .inputDataset(DataFrameInputDataset.builder().dataFrame(dataFrame).build()) - .build(); + DataFrame dataFrame = DataFrameBuilder.load(Collections.singletonList(new HashMap() { + { + put("key1", 2.0D); + } + })); + mlInput = MLInput + .builder() + .algorithm(FunctionName.KMEANS) + .parameters(KMeansParams.builder().centroids(1).build()) + .inputDataset(DataFrameInputDataset.builder().dataFrame(dataFrame).build()) + .build(); } @Test public void writeTo_Success() throws IOException { - MLPredictionTaskRequest request = MLPredictionTaskRequest.builder() - .mlInput(mlInput) - .build(); + MLPredictionTaskRequest request = MLPredictionTaskRequest.builder().mlInput(mlInput).build(); BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); request.writeTo(bytesStreamOutput); request = new MLPredictionTaskRequest(bytesStreamOutput.bytes().streamInput()); assertEquals(FunctionName.KMEANS, request.getMlInput().getAlgorithm()); - KMeansParams params = (KMeansParams)request.getMlInput().getParameters(); + KMeansParams params = (KMeansParams) request.getMlInput().getParameters(); assertEquals(1, params.getCentroids().intValue()); MLInputDataset inputDataset = request.getMlInput().getInputDataset(); assertEquals(MLInputDataType.DATA_FRAME, inputDataset.getInputDataType()); @@ -78,9 +80,7 @@ public void writeTo_Success() throws IOException { @Test public void validate_Success() { - MLPredictionTaskRequest request = MLPredictionTaskRequest.builder() - .mlInput(mlInput) - .build(); + MLPredictionTaskRequest request = MLPredictionTaskRequest.builder().mlInput(mlInput).build(); assertNull(request.validate()); } @@ -88,8 +88,7 @@ public void validate_Success() { @Test public void validate_Exception_NullMLInput() { mlInput.setAlgorithm(null); - MLPredictionTaskRequest request = MLPredictionTaskRequest.builder() - .build(); + MLPredictionTaskRequest request = MLPredictionTaskRequest.builder().build(); ActionRequestValidationException exception = request.validate(); assertEquals("Validation Failed: 1: ML input can't be null;", exception.getMessage()); @@ -98,21 +97,16 @@ public void validate_Exception_NullMLInput() { @Test public void validate_Exception_NullInputDataset() { mlInput.setInputDataset(null); - MLPredictionTaskRequest request = MLPredictionTaskRequest.builder() - .mlInput(mlInput) - .build(); + MLPredictionTaskRequest request = MLPredictionTaskRequest.builder().mlInput(mlInput).build(); ActionRequestValidationException exception = request.validate(); assertEquals("Validation Failed: 1: input data can't be null;", exception.getMessage()); } - @Test public void fromActionRequest_Success_WithMLPredictionTaskRequest() { - MLPredictionTaskRequest request = MLPredictionTaskRequest.builder() - .mlInput(mlInput) - .build(); + MLPredictionTaskRequest request = MLPredictionTaskRequest.builder().mlInput(mlInput).build(); assertSame(MLPredictionTaskRequest.fromActionRequest(request), request); } @@ -123,19 +117,22 @@ public void fromActionRequest_Success_WithNonMLPredictionTaskRequest_DataFrameIn @Test public void fromActionRequest_Success_WithNonMLPredictionTaskRequest_SearchQueryInput() { - @NonNull SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + @NonNull + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); searchSourceBuilder.query(new MatchAllQueryBuilder()); - mlInput.setInputDataset(SearchQueryInputDataset.builder() - .indices(Collections.singletonList("test_index")) - .searchSourceBuilder(searchSourceBuilder) - .build()); + mlInput + .setInputDataset( + SearchQueryInputDataset + .builder() + .indices(Collections.singletonList("test_index")) + .searchSourceBuilder(searchSourceBuilder) + .build() + ); fromActionRequest_Success_WithNonMLPredictionTaskRequest(mlInput); } private void fromActionRequest_Success_WithNonMLPredictionTaskRequest(MLInput mlInput) { - MLPredictionTaskRequest request = MLPredictionTaskRequest.builder() - .mlInput(mlInput) - .build(); + MLPredictionTaskRequest request = MLPredictionTaskRequest.builder().mlInput(mlInput).build(); ActionRequest actionRequest = new ActionRequest() { @Override public ActionRequestValidationException validate() { @@ -168,4 +165,4 @@ public void writeTo(StreamOutput out) throws IOException { }; MLPredictionTaskRequest.fromActionRequest(actionRequest); } -} \ No newline at end of file +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/prediction/MLPredictionTaskResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/prediction/MLPredictionTaskResponseTest.java index fdd3883441..3e99a5a972 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/prediction/MLPredictionTaskResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/prediction/MLPredictionTaskResponseTest.java @@ -5,12 +5,18 @@ package org.opensearch.ml.common.transport.prediction; +import static org.junit.Assert.*; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Collections; +import java.util.HashMap; + import org.junit.Test; -import org.opensearch.core.action.ActionResponse; -import org.opensearch.core.common.Strings; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; @@ -18,29 +24,21 @@ import org.opensearch.ml.common.output.MLPredictionOutput; import org.opensearch.ml.common.transport.MLTaskResponse; -import java.io.IOException; -import java.io.UncheckedIOException; -import java.util.Collections; -import java.util.HashMap; - -import static org.junit.Assert.*; - public class MLPredictionTaskResponseTest { @Test public void writeTo_Success() throws IOException { - MLPredictionOutput output = MLPredictionOutput.builder() - .taskId("taskId") - .status("Success") - .predictionResult(DataFrameBuilder.load(Collections.singletonList(new HashMap() { - { - put("key1", 2.0D); - } - }))) - .build(); - MLTaskResponse response = MLTaskResponse.builder() - .output(output) - .build(); + MLPredictionOutput output = MLPredictionOutput + .builder() + .taskId("taskId") + .status("Success") + .predictionResult(DataFrameBuilder.load(Collections.singletonList(new HashMap() { + { + put("key1", 2.0D); + } + }))) + .build(); + MLTaskResponse response = MLTaskResponse.builder().output(output).build(); BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); response.writeTo(bytesStreamOutput); response = new MLTaskResponse(bytesStreamOutput.bytes().streamInput()); @@ -52,35 +50,33 @@ public void writeTo_Success() throws IOException { @Test public void fromActionResponse_WithMLPredictionTaskResponse() { - MLPredictionOutput output = MLPredictionOutput.builder() - .taskId("taskId") - .status("Success") - .predictionResult(DataFrameBuilder.load(Collections.singletonList(new HashMap() { - { - put("key1", 2.0D); - } - }))) - .build(); - MLTaskResponse response = MLTaskResponse.builder() - .output(output) - .build(); + MLPredictionOutput output = MLPredictionOutput + .builder() + .taskId("taskId") + .status("Success") + .predictionResult(DataFrameBuilder.load(Collections.singletonList(new HashMap() { + { + put("key1", 2.0D); + } + }))) + .build(); + MLTaskResponse response = MLTaskResponse.builder().output(output).build(); assertSame(response, MLTaskResponse.fromActionResponse(response)); } @Test public void fromActionResponse_WithNonMLPredictionTaskResponse() { - MLPredictionOutput output = MLPredictionOutput.builder() - .taskId("taskId") - .status("Success") - .predictionResult(DataFrameBuilder.load(Collections.singletonList(new HashMap() { - { - put("key1", 2.0D); - } - }))) - .build(); - MLTaskResponse response = MLTaskResponse.builder() - .output(output) - .build(); + MLPredictionOutput output = MLPredictionOutput + .builder() + .taskId("taskId") + .status("Success") + .predictionResult(DataFrameBuilder.load(Collections.singletonList(new HashMap() { + { + put("key1", 2.0D); + } + }))) + .build(); + MLTaskResponse response = MLTaskResponse.builder().output(output).build(); ActionResponse actionResponse = new ActionResponse() { @Override public void writeTo(StreamOutput out) throws IOException { @@ -94,8 +90,7 @@ public void writeTo(StreamOutput out) throws IOException { MLPredictionOutput resultMlPredictionOutput = (MLPredictionOutput) result.getOutput(); assertEquals(mlPredictionOutput.getTaskId(), resultMlPredictionOutput.getTaskId()); assertEquals(mlPredictionOutput.getStatus(), resultMlPredictionOutput.getStatus()); - assertEquals(mlPredictionOutput.getPredictionResult().size(), - resultMlPredictionOutput.getPredictionResult().size()); + assertEquals(mlPredictionOutput.getPredictionResult().size(), resultMlPredictionOutput.getPredictionResult().size()); } @Test(expected = UncheckedIOException.class) @@ -112,27 +107,29 @@ public void writeTo(StreamOutput out) throws IOException { @Test public void toXContentTest() throws IOException { - MLPredictionOutput output = MLPredictionOutput.builder() - .taskId("b5009b99-268f-476d-a676-379a30f82457") - .status("Success") - .predictionResult(DataFrameBuilder.load(Collections.singletonList(new HashMap() { - { - put("ClusterID", 0); - } - }))) - .build(); - MLTaskResponse response = MLTaskResponse.builder() - .output(output) - .build(); + MLPredictionOutput output = MLPredictionOutput + .builder() + .taskId("b5009b99-268f-476d-a676-379a30f82457") + .status("Success") + .predictionResult(DataFrameBuilder.load(Collections.singletonList(new HashMap() { + { + put("ClusterID", 0); + } + }))) + .build(); + MLTaskResponse response = MLTaskResponse.builder().output(output).build(); XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); response.toXContent(builder, ToXContent.EMPTY_PARAMS); assertNotNull(builder); String jsonStr = builder.toString(); - assertEquals("{\"task_id\":\"b5009b99-268f-476d-a676-379a30f82457\"," + - "\"status\":\"Success\"," + - "\"prediction_result\":{" + - "\"column_metas\":[{\"name\":\"ClusterID\",\"column_type\":\"INTEGER\"}]," + - "\"rows\":[{\"values\":[{\"column_type\":\"INTEGER\",\"value\":0}]}]}}", jsonStr); + assertEquals( + "{\"task_id\":\"b5009b99-268f-476d-a676-379a30f82457\"," + + "\"status\":\"Success\"," + + "\"prediction_result\":{" + + "\"column_metas\":[{\"name\":\"ClusterID\",\"column_type\":\"INTEGER\"}]," + + "\"rows\":[{\"values\":[{\"column_type\":\"INTEGER\",\"value\":0}]}]}}", + jsonStr + ); } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelInputTest.java index 6de3788243..a0351e9100 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelInputTest.java @@ -1,5 +1,14 @@ package org.opensearch.ml.common.transport.register; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; + +import java.io.IOException; +import java.util.Collections; +import java.util.function.Consumer; + import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -8,11 +17,11 @@ import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.common.settings.Settings; -import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; @@ -26,15 +35,6 @@ import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import org.opensearch.search.SearchModule; -import java.io.IOException; -import java.util.Collections; -import java.util.function.Consumer; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertNull; - @RunWith(MockitoJUnitRunner.class) public class MLRegisterModelInputTest { @@ -44,10 +44,11 @@ public class MLRegisterModelInputTest { @Rule public ExpectedException exceptionRule = ExpectedException.none(); - private final String expectedInputStr = "{\"function_name\":\"LINEAR_REGRESSION\",\"name\":\"modelName\",\"version\":\"version\",\"model_group_id\":\"modelGroupId\",\"url\":\"url\",\"model_format\":\"ONNX\"," + - "\"model_config\":{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\"," + - "\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"" + - "},\"deploy_model\":true,\"model_node_ids\":[\"modelNodeIds\"]}"; + private final String expectedInputStr = + "{\"function_name\":\"LINEAR_REGRESSION\",\"name\":\"modelName\",\"version\":\"version\",\"model_group_id\":\"modelGroupId\",\"url\":\"url\",\"model_format\":\"ONNX\"," + + "\"model_config\":{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\"," + + "\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"" + + "},\"deploy_model\":true,\"model_node_ids\":[\"modelNodeIds\"]}"; private final FunctionName functionName = FunctionName.LINEAR_REGRESSION; private final String modelName = "modelName"; private final String version = "version"; @@ -57,24 +58,26 @@ public class MLRegisterModelInputTest { @Before public void setUp() throws Exception { - config = TextEmbeddingModelConfig.builder() - .modelType("testModelType") - .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") - .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) - .embeddingDimension(100) - .build(); + config = TextEmbeddingModelConfig + .builder() + .modelType("testModelType") + .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") + .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) + .embeddingDimension(100) + .build(); - input = MLRegisterModelInput.builder() - .functionName(functionName) - .modelName(modelName) - .version(version) - .modelGroupId(modelGroupId) - .url(url) - .modelFormat(MLModelFormat.ONNX) - .modelConfig(config) - .deployModel(true) - .modelNodeIds(new String[]{"modelNodeIds" }) - .build(); + input = MLRegisterModelInput + .builder() + .functionName(functionName) + .modelName(modelName) + .version(version) + .modelGroupId(modelGroupId) + .url(url) + .modelFormat(MLModelFormat.ONNX) + .modelConfig(config) + .deployModel(true) + .modelNodeIds(new String[] { "modelNodeIds" }) + .build(); } @Test @@ -88,52 +91,51 @@ public void constructor_NullModel() { public void constructor_NullModelName() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("model name is null"); - MLRegisterModelInput.builder() - .functionName(functionName) - .modelGroupId(modelGroupId) - .modelName(null) - .build(); + MLRegisterModelInput.builder().functionName(functionName).modelGroupId(modelGroupId).modelName(null).build(); } @Test public void constructor_NullModelFormat() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("model format is null"); - MLRegisterModelInput.builder() - .functionName(functionName) - .modelName(modelName) - .version(version) - .modelGroupId(modelGroupId) - .modelFormat(null) - .url(url) - .build(); + MLRegisterModelInput + .builder() + .functionName(functionName) + .modelName(modelName) + .version(version) + .modelGroupId(modelGroupId) + .modelFormat(null) + .url(url) + .build(); } @Test public void constructor_NullModelConfig() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("model config is null"); - MLRegisterModelInput.builder() - .functionName(functionName) - .modelName(modelName) - .version(version) - .modelGroupId(modelGroupId) - .modelFormat(MLModelFormat.ONNX) - .modelConfig(null) - .url(url) - .build(); + MLRegisterModelInput + .builder() + .functionName(functionName) + .modelName(modelName) + .version(version) + .modelGroupId(modelGroupId) + .modelFormat(MLModelFormat.ONNX) + .modelConfig(null) + .url(url) + .build(); } @Test public void constructor_SuccessWithMinimalSetup() { - MLRegisterModelInput input = MLRegisterModelInput.builder() - .modelName(modelName) - .version(version) - .modelGroupId(modelGroupId) - .modelFormat(MLModelFormat.ONNX) - .modelConfig(config) - .url(url) - .build(); + MLRegisterModelInput input = MLRegisterModelInput + .builder() + .modelName(modelName) + .version(version) + .modelGroupId(modelGroupId) + .modelFormat(MLModelFormat.ONNX) + .modelConfig(config) + .url(url) + .build(); // MLRegisterModelInput.functionName is set to FunctionName.TEXT_EMBEDDING if not explicitly passed, with no exception thrown assertEquals(FunctionName.TEXT_EMBEDDING, input.getFunctionName()); // MLRegisterModelInput.deployModel is set to false if not explicitly passed, with no exception thrown @@ -153,9 +155,8 @@ public void testToXContent() throws Exception { @Test public void testToXContent_Incomplete() throws Exception { - String expectedIncompleteInputStr = - "{\"function_name\":\"LINEAR_REGRESSION\",\"name\":\"modelName\"," + - "\"version\":\"version\",\"model_group_id\":\"modelGroupId\",\"deploy_model\":true}"; + String expectedIncompleteInputStr = "{\"function_name\":\"LINEAR_REGRESSION\",\"name\":\"modelName\"," + + "\"version\":\"version\",\"model_group_id\":\"modelGroupId\",\"deploy_model\":true}"; input.setUrl(null); input.setModelConfig(null); input.setModelFormat(null); @@ -178,24 +179,41 @@ public void parse_WithModel() throws Exception { @Test public void parse_WithoutModel() throws Exception { - testParseFromJsonString( false, expectedInputStr, parsedInput -> { + testParseFromJsonString(false, expectedInputStr, parsedInput -> { assertFalse(parsedInput.isDeployModel()); assertEquals("modelName", parsedInput.getModelName()); assertEquals("version", parsedInput.getVersion()); }); } - private void testParseFromJsonString(String modelName, String version, Boolean deployModel, String expectedInputStr, Consumer verify) throws Exception { - XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, - Collections.emptyList()).getNamedXContents()), LoggingDeprecationHandler.INSTANCE, expectedInputStr); + private void testParseFromJsonString( + String modelName, + String version, + Boolean deployModel, + String expectedInputStr, + Consumer verify + ) throws Exception { + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + LoggingDeprecationHandler.INSTANCE, + expectedInputStr + ); parser.nextToken(); MLRegisterModelInput parsedInput = MLRegisterModelInput.parse(parser, modelName, version, deployModel); verify.accept(parsedInput); } - private void testParseFromJsonString(Boolean deployModel,String expectedInputStr, Consumer verify) throws Exception { - XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, - Collections.emptyList()).getNamedXContents()), LoggingDeprecationHandler.INSTANCE, expectedInputStr); + private void testParseFromJsonString(Boolean deployModel, String expectedInputStr, Consumer verify) + throws Exception { + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + LoggingDeprecationHandler.INSTANCE, + expectedInputStr + ); parser.nextToken(); MLRegisterModelInput parsedInput = MLRegisterModelInput.parse(parser, deployModel); verify.accept(parsedInput); @@ -209,7 +227,6 @@ public void readInputStream_Success() throws IOException { }); } - @Test public void readInputStream_SuccessWithNullFields() throws IOException { input.setModelFormat(null); @@ -223,14 +240,15 @@ public void readInputStream_SuccessWithNullFields() throws IOException { @Test public void readInputStream_WithConnectorId() throws IOException { String connectorId = "test_connector_id"; - input = MLRegisterModelInput.builder() - .functionName(FunctionName.REMOTE) - .modelName(modelName) - .description("test model input") - .version(version) - .modelGroupId(modelGroupId) - .connectorId(connectorId) - .build(); + input = MLRegisterModelInput + .builder() + .functionName(FunctionName.REMOTE) + .modelName(modelName) + .description("test model input") + .version(version) + .modelGroupId(modelGroupId) + .connectorId(connectorId) + .build(); readInputStream(input, parsedInput -> { assertNull(parsedInput.getModelConfig()); assertNull(parsedInput.getModelFormat()); @@ -242,14 +260,15 @@ public void readInputStream_WithConnectorId() throws IOException { @Test public void readInputStream_WithInternalConnector() throws IOException { HttpConnector connector = HttpConnectorTest.createHttpConnector(); - input = MLRegisterModelInput.builder() - .functionName(FunctionName.REMOTE) - .modelName(modelName) - .description("test model input") - .version(version) - .modelGroupId(modelGroupId) - .connector(connector) - .build(); + input = MLRegisterModelInput + .builder() + .functionName(FunctionName.REMOTE) + .modelName(modelName) + .description("test model input") + .version(version) + .modelGroupId(modelGroupId) + .connector(connector) + .build(); readInputStream(input, parsedInput -> { assertNull(parsedInput.getModelConfig()); assertNull(parsedInput.getModelFormat()); @@ -259,24 +278,27 @@ public void readInputStream_WithInternalConnector() throws IOException { @Test public void testMCorrInput() throws IOException { - String testString = "{\"function_name\":\"METRICS_CORRELATION\",\"name\":\"METRICS_CORRELATION\",\"version\":\"1.0.0b1\",\"model_group_id\":\"modelGroupId\",\"url\":\"url\",\"model_format\":\"TORCH_SCRIPT\",\"model_config\":{\"model_type\":\"testModelType\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"deploy_model\":true,\"model_node_ids\":[\"modelNodeIds\"]}"; + String testString = + "{\"function_name\":\"METRICS_CORRELATION\",\"name\":\"METRICS_CORRELATION\",\"version\":\"1.0.0b1\",\"model_group_id\":\"modelGroupId\",\"url\":\"url\",\"model_format\":\"TORCH_SCRIPT\",\"model_config\":{\"model_type\":\"testModelType\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"deploy_model\":true,\"model_node_ids\":[\"modelNodeIds\"]}"; - MetricsCorrelationModelConfig mcorrConfig = MetricsCorrelationModelConfig.builder() - .modelType("testModelType") - .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") - .build(); + MetricsCorrelationModelConfig mcorrConfig = MetricsCorrelationModelConfig + .builder() + .modelType("testModelType") + .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") + .build(); - MLRegisterModelInput mcorrInput = MLRegisterModelInput.builder() - .functionName(FunctionName.METRICS_CORRELATION) - .modelName(FunctionName.METRICS_CORRELATION.name()) - .version("1.0.0b1") - .modelGroupId(modelGroupId) - .url(url) - .modelFormat(MLModelFormat.TORCH_SCRIPT) - .modelConfig(mcorrConfig) - .deployModel(true) - .modelNodeIds(new String[]{"modelNodeIds" }) - .build(); + MLRegisterModelInput mcorrInput = MLRegisterModelInput + .builder() + .functionName(FunctionName.METRICS_CORRELATION) + .modelName(FunctionName.METRICS_CORRELATION.name()) + .version("1.0.0b1") + .modelGroupId(modelGroupId) + .url(url) + .modelFormat(MLModelFormat.TORCH_SCRIPT) + .modelConfig(mcorrConfig) + .deployModel(true) + .modelNodeIds(new String[] { "modelNodeIds" }) + .build(); XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); mcorrInput.toXContent(builder, ToXContent.EMPTY_PARAMS); String jsonStr = builder.toString(); @@ -285,22 +307,24 @@ public void testMCorrInput() throws IOException { @Test public void readInputStream_MCorr() throws IOException { - MetricsCorrelationModelConfig mcorrConfig = MetricsCorrelationModelConfig.builder() - .modelType("testModelType") - .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") - .build(); + MetricsCorrelationModelConfig mcorrConfig = MetricsCorrelationModelConfig + .builder() + .modelType("testModelType") + .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") + .build(); - MLRegisterModelInput mcorrInput = MLRegisterModelInput.builder() - .functionName(FunctionName.METRICS_CORRELATION) - .modelName(FunctionName.METRICS_CORRELATION.name()) - .version("1.0.0b1") - .modelGroupId(modelGroupId) - .url(url) - .modelFormat(MLModelFormat.TORCH_SCRIPT) - .modelConfig(mcorrConfig) - .deployModel(true) - .modelNodeIds(new String[]{"modelNodeIds" }) - .build(); + MLRegisterModelInput mcorrInput = MLRegisterModelInput + .builder() + .functionName(FunctionName.METRICS_CORRELATION) + .modelName(FunctionName.METRICS_CORRELATION.name()) + .version("1.0.0b1") + .modelGroupId(modelGroupId) + .url(url) + .modelFormat(MLModelFormat.TORCH_SCRIPT) + .modelConfig(mcorrConfig) + .deployModel(true) + .modelNodeIds(new String[] { "modelNodeIds" }) + .build(); readInputStream(mcorrInput, parsedInput -> { assertEquals(parsedInput.getModelConfig().getModelType(), mcorrConfig.getModelType()); assertEquals(parsedInput.getModelConfig().getAllConfig(), mcorrConfig.getAllConfig()); diff --git a/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelRequestTest.java index b983fb1827..bcbee60593 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelRequestTest.java @@ -1,5 +1,7 @@ package org.opensearch.ml.common.transport.register; +import static org.junit.Assert.*; + import java.io.IOException; import java.io.UncheckedIOException; @@ -9,47 +11,44 @@ import org.opensearch.action.ActionRequestValidationException; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import org.opensearch.ml.common.FunctionName; -import org.opensearch.ml.common.model.MLModelFormat; import org.opensearch.ml.common.model.MLModelConfig; - -import static org.junit.Assert.*; +import org.opensearch.ml.common.model.MLModelFormat; +import org.opensearch.ml.common.model.TextEmbeddingModelConfig; public class MLRegisterModelRequestTest { private MLRegisterModelInput mlRegisterModelInput; @Before - public void setUp(){ - - TextEmbeddingModelConfig config = TextEmbeddingModelConfig.builder() - .modelType("testModelType") - .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") - .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) - .embeddingDimension(100) - .build(); - - - mlRegisterModelInput = mlRegisterModelInput.builder() - .functionName(FunctionName.KMEANS) - .modelName("modelName") - .version("version") - .modelGroupId("modelGroupId") - .url("url") - .modelFormat(MLModelFormat.ONNX) - .modelConfig(config) - .deployModel(true) - .modelNodeIds(new String[]{"modelNodeIds" }) - .build(); + public void setUp() { + + TextEmbeddingModelConfig config = TextEmbeddingModelConfig + .builder() + .modelType("testModelType") + .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") + .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) + .embeddingDimension(100) + .build(); + + mlRegisterModelInput = mlRegisterModelInput + .builder() + .functionName(FunctionName.KMEANS) + .modelName("modelName") + .version("version") + .modelGroupId("modelGroupId") + .url("url") + .modelFormat(MLModelFormat.ONNX) + .modelConfig(config) + .deployModel(true) + .modelNodeIds(new String[] { "modelNodeIds" }) + .build(); } @Test public void writeTo_Success() throws IOException { - MLRegisterModelRequest request = MLRegisterModelRequest.builder() - .registerModelInput(mlRegisterModelInput) - .build(); + MLRegisterModelRequest request = MLRegisterModelRequest.builder().registerModelInput(mlRegisterModelInput).build(); BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); request.writeTo(bytesStreamOutput); request = new MLRegisterModelRequest(bytesStreamOutput.bytes().streamInput()); @@ -69,17 +68,14 @@ public void writeTo_Success() throws IOException { @Test public void validate_Success() { - MLRegisterModelRequest request = MLRegisterModelRequest.builder() - .registerModelInput(mlRegisterModelInput) - .build(); + MLRegisterModelRequest request = MLRegisterModelRequest.builder().registerModelInput(mlRegisterModelInput).build(); assertNull(request.validate()); } @Test public void validate_Exception_NullMLRegisterModelInput() { - MLRegisterModelRequest request = MLRegisterModelRequest.builder() - .build(); + MLRegisterModelRequest request = MLRegisterModelRequest.builder().build(); ActionRequestValidationException exception = request.validate(); assertEquals("Validation Failed: 1: ML input can't be null;", exception.getMessage()); } @@ -88,9 +84,7 @@ public void validate_Exception_NullMLRegisterModelInput() { // MLRegisterModelInput check its parameters when created, so exception is not thrown here public void validate_Exception_NullMLModelName() { mlRegisterModelInput.setModelName(null); - MLRegisterModelRequest request = MLRegisterModelRequest.builder() - .registerModelInput(mlRegisterModelInput) - .build(); + MLRegisterModelRequest request = MLRegisterModelRequest.builder().registerModelInput(mlRegisterModelInput).build(); assertNull(request.validate()); assertNull(request.getRegisterModelInput().getModelName()); @@ -98,17 +92,13 @@ public void validate_Exception_NullMLModelName() { @Test public void fromActionRequest_Success_WithMLRegisterModelRequest() { - MLRegisterModelRequest request = MLRegisterModelRequest.builder() - .registerModelInput(mlRegisterModelInput) - .build(); + MLRegisterModelRequest request = MLRegisterModelRequest.builder().registerModelInput(mlRegisterModelInput).build(); assertSame(MLRegisterModelRequest.fromActionRequest(request), request); } @Test public void fromActionRequest_Success_WithNonMLRegisterModelRequest() { - MLRegisterModelRequest request = MLRegisterModelRequest.builder() - .registerModelInput(mlRegisterModelInput) - .build(); + MLRegisterModelRequest request = MLRegisterModelRequest.builder().registerModelInput(mlRegisterModelInput).build(); ActionRequest actionRequest = new ActionRequest() { @Override public ActionRequestValidationException validate() { @@ -123,7 +113,10 @@ public void writeTo(StreamOutput out) throws IOException { MLRegisterModelRequest result = MLRegisterModelRequest.fromActionRequest(actionRequest); assertNotSame(result, request); assertEquals(request.getRegisterModelInput().getModelName(), result.getRegisterModelInput().getModelName()); - assertEquals(request.getRegisterModelInput().getModelConfig().getModelType(), result.getRegisterModelInput().getModelConfig().getModelType()); + assertEquals( + request.getRegisterModelInput().getModelConfig().getModelType(), + result.getRegisterModelInput().getModelConfig().getModelType() + ); } @Test(expected = UncheckedIOException.class) diff --git a/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelResponseTest.java index a91e85fe93..d01e593e61 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelResponseTest.java @@ -1,8 +1,12 @@ package org.opensearch.ml.common.transport.register; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +import java.io.IOException; + import org.junit.Before; import org.junit.Test; -import org.opensearch.core.common.Strings; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.common.xcontent.XContentType; @@ -10,11 +14,6 @@ import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; -import java.io.IOException; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; - public class MLRegisterModelResponseTest { private String taskId; @@ -52,8 +51,7 @@ public void testToXContent() throws IOException { assertNotNull(builder); String jsonStr = builder.toString(); // Verify the results - assertEquals("{\"task_id\":\"test_id\"," + - "\"status\":\"test\"}", jsonStr); + assertEquals("{\"task_id\":\"test_id\"," + "\"status\":\"test\"}", jsonStr); } @Test @@ -66,7 +64,6 @@ public void testToXContent_withModelId() throws IOException { assertNotNull(builder); String jsonStr = builder.toString(); // Verify the results - assertEquals("{\"task_id\":\"test_id\"," + - "\"status\":\"test\"," + "\"model_id\":\"model_id\"}", jsonStr); + assertEquals("{\"task_id\":\"test_id\"," + "\"status\":\"test\"," + "\"model_id\":\"model_id\"}", jsonStr); } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/sync/MLSyncUpInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/sync/MLSyncUpInputTest.java index 7a728a1e41..354017e30e 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/sync/MLSyncUpInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/sync/MLSyncUpInputTest.java @@ -1,8 +1,6 @@ package org.opensearch.ml.common.transport.sync; -import org.junit.Test; -import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.core.common.io.stream.StreamInput; +import static org.junit.Assert.*; import java.io.IOException; import java.util.HashMap; @@ -10,18 +8,20 @@ import java.util.Map; import java.util.Set; -import static org.junit.Assert.*; +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; public class MLSyncUpInputTest { - @Test public void testConstructorSerialization_SuccessWithNullFields() throws IOException { - MLSyncUpInput syncUpInputWithNullFields = MLSyncUpInput.builder() - .getDeployedModels(true) - .clearRoutingTable(true) - .syncRunningDeployModelTasks(true) - .build(); + MLSyncUpInput syncUpInputWithNullFields = MLSyncUpInput + .builder() + .getDeployedModels(true) + .clearRoutingTable(true) + .syncRunningDeployModelTasks(true) + .build(); BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); syncUpInputWithNullFields.writeTo(bytesStreamOutput); @@ -41,34 +41,47 @@ public void testConstructorSerialization_SuccessWithFullFields() throws IOExcept Map> modelRoutingTable = new HashMap<>(); Map> runningDeployModelTasks = new HashMap<>(); - MLSyncUpInput syncUpInput = MLSyncUpInput.builder() - .getDeployedModels(true) - .addedWorkerNodes(addedWorkerNodes) - .removedWorkerNodes(removedWorkerNodes) - .modelRoutingTable(modelRoutingTable) - .runningDeployModelTasks(runningDeployModelTasks) - .clearRoutingTable(true) - .syncRunningDeployModelTasks(true) - .build(); + MLSyncUpInput syncUpInput = MLSyncUpInput + .builder() + .getDeployedModels(true) + .addedWorkerNodes(addedWorkerNodes) + .removedWorkerNodes(removedWorkerNodes) + .modelRoutingTable(modelRoutingTable) + .runningDeployModelTasks(runningDeployModelTasks) + .clearRoutingTable(true) + .syncRunningDeployModelTasks(true) + .build(); Set modelRoutingTableSet = new HashSet<>(); Set runningDeployModelTaskSet = new HashSet<>(); modelRoutingTableSet.add("modelRoutingTable1"); runningDeployModelTaskSet.add("runningDeployModelTask1"); - addedWorkerNodes.put("addedWorkerNodesKey1", new String [] {"addedWorkerNode1"}); - removedWorkerNodes.put("removedWorkerNodesKey1", new String [] {"removedWorkerNode1"}); - modelRoutingTable.put("modelRoutingTableKey1",modelRoutingTableSet); - runningDeployModelTasks.put("runningDeployModelTaskKey1",runningDeployModelTaskSet); + addedWorkerNodes.put("addedWorkerNodesKey1", new String[] { "addedWorkerNode1" }); + removedWorkerNodes.put("removedWorkerNodesKey1", new String[] { "removedWorkerNode1" }); + modelRoutingTable.put("modelRoutingTableKey1", modelRoutingTableSet); + runningDeployModelTasks.put("runningDeployModelTaskKey1", runningDeployModelTaskSet); BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); syncUpInput.writeTo(bytesStreamOutput); StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); MLSyncUpInput parsedInput = new MLSyncUpInput(streamInput); - assertArrayEquals(syncUpInput.getAddedWorkerNodes().get("addedWorkerNodesKey1"), parsedInput.getAddedWorkerNodes().get("addedWorkerNodesKey1")); - assertArrayEquals(syncUpInput.getRemovedWorkerNodes().get("removedWorkerNodesKey1"), parsedInput.getRemovedWorkerNodes().get("removedWorkerNodesKey1")); - assertEquals(syncUpInput.getModelRoutingTable().get("modelRoutingTableKey1"), parsedInput.getModelRoutingTable().get("modelRoutingTableKey1")); - assertEquals(syncUpInput.getRunningDeployModelTasks().get("runningDeployModelTaskKey1"), parsedInput.getRunningDeployModelTasks().get("runningDeployModelTaskKey1")); + assertArrayEquals( + syncUpInput.getAddedWorkerNodes().get("addedWorkerNodesKey1"), + parsedInput.getAddedWorkerNodes().get("addedWorkerNodesKey1") + ); + assertArrayEquals( + syncUpInput.getRemovedWorkerNodes().get("removedWorkerNodesKey1"), + parsedInput.getRemovedWorkerNodes().get("removedWorkerNodesKey1") + ); + assertEquals( + syncUpInput.getModelRoutingTable().get("modelRoutingTableKey1"), + parsedInput.getModelRoutingTable().get("modelRoutingTableKey1") + ); + assertEquals( + syncUpInput.getRunningDeployModelTasks().get("runningDeployModelTaskKey1"), + parsedInput.getRunningDeployModelTasks().get("runningDeployModelTaskKey1") + ); } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeRequestTest.java index dc16986f64..5ed0ea07ac 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeRequestTest.java @@ -1,5 +1,15 @@ package org.opensearch.ml.common.transport.sync; +import static org.junit.Assert.*; +import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; + +import java.io.IOException; +import java.net.InetAddress; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; + import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -11,17 +21,6 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.transport.TransportAddress; -import java.io.IOException; -import java.net.InetAddress; -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; -import java.util.Set; - -import static org.junit.Assert.*; -import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; - - @RunWith(MockitoJUnitRunner.class) public class MLSyncUpNodeRequestTest { @@ -35,28 +34,28 @@ public class MLSyncUpNodeRequestTest { @Before public void setUp() throws Exception { localNode1 = new DiscoveryNode( - "foo1", - "foo1", - new TransportAddress(InetAddress.getLoopbackAddress(), 9300), - Collections.emptyMap(), - Collections.singleton(CLUSTER_MANAGER_ROLE), - Version.CURRENT + "foo1", + "foo1", + new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT ); localNode2 = new DiscoveryNode( - "foo2", - "foo2", - new TransportAddress(InetAddress.getLoopbackAddress(), 9300), - Collections.emptyMap(), - Collections.singleton(CLUSTER_MANAGER_ROLE), - Version.CURRENT + "foo2", + "foo2", + new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT ); localNode3 = new DiscoveryNode( - "foo3", - "foo3", - new TransportAddress(InetAddress.getLoopbackAddress(), 9300), - Collections.emptyMap(), - Collections.singleton(CLUSTER_MANAGER_ROLE), - Version.CURRENT + "foo3", + "foo3", + new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT ); Map addedWorkerNodes = new HashMap<>(); @@ -64,23 +63,22 @@ public void setUp() throws Exception { Map> modelRoutingTable = new HashMap<>(); Map> runningDeployModelTasks = new HashMap<>(); - syncUpInput = MLSyncUpInput.builder() - .getDeployedModels(true) - .addedWorkerNodes(addedWorkerNodes) - .removedWorkerNodes(removedWorkerNodes) - .modelRoutingTable(modelRoutingTable) - .runningDeployModelTasks(runningDeployModelTasks) - .clearRoutingTable(true) - .syncRunningDeployModelTasks(true) - .build(); + syncUpInput = MLSyncUpInput + .builder() + .getDeployedModels(true) + .addedWorkerNodes(addedWorkerNodes) + .removedWorkerNodes(removedWorkerNodes) + .modelRoutingTable(modelRoutingTable) + .runningDeployModelTasks(runningDeployModelTasks) + .clearRoutingTable(true) + .syncRunningDeployModelTasks(true) + .build(); } @Test public void testConstructorSerialization1() throws IOException { - String [] nodeIds = {"id1", "id2", "id3"}; - MLSyncUpNodeRequest syncUpNodeRequest = new MLSyncUpNodeRequest( - new MLSyncUpNodesRequest(nodeIds, syncUpInput) - ); + String[] nodeIds = { "id1", "id2", "id3" }; + MLSyncUpNodeRequest syncUpNodeRequest = new MLSyncUpNodeRequest(new MLSyncUpNodesRequest(nodeIds, syncUpInput)); BytesStreamOutput output = new BytesStreamOutput(); syncUpNodeRequest.writeTo(output); @@ -94,10 +92,8 @@ public void testConstructorSerialization1() throws IOException { @Test public void testConstructorSerialization2() { - DiscoveryNode [] nodeIds = {localNode1, localNode2, localNode3}; - MLSyncUpNodeRequest syncUpNodeRequest = new MLSyncUpNodeRequest( - new MLSyncUpNodesRequest(nodeIds, syncUpInput) - ); + DiscoveryNode[] nodeIds = { localNode1, localNode2, localNode3 }; + MLSyncUpNodeRequest syncUpNodeRequest = new MLSyncUpNodeRequest(new MLSyncUpNodesRequest(nodeIds, syncUpInput)); assertEquals(3, syncUpNodeRequest.getSyncUpNodesRequest().concreteNodes().length); assertTrue(syncUpNodeRequest.getSyncUpNodesRequest().getSyncUpInput().isClearRoutingTable()); @@ -107,9 +103,7 @@ public void testConstructorSerialization2() { @Test public void testConstructorSerialization3() { - MLSyncUpNodeRequest syncUpNodeRequest = new MLSyncUpNodeRequest( - new MLSyncUpNodesRequest(localNode1, localNode2, localNode3) - ); + MLSyncUpNodeRequest syncUpNodeRequest = new MLSyncUpNodeRequest(new MLSyncUpNodesRequest(localNode1, localNode2, localNode3)); syncUpNodeRequest.getSyncUpNodesRequest().getSyncUpInput().setClearRoutingTable(true); syncUpNodeRequest.getSyncUpNodesRequest().getSyncUpInput().setSyncRunningDeployModelTasks(true); syncUpNodeRequest.getSyncUpNodesRequest().getSyncUpInput().setClearRoutingTable(true); @@ -122,19 +116,26 @@ public void testConstructorSerialization3() { @Test public void testConstructorFromInputStream() throws IOException { - String [] nodeIds = {"id1", "id2", "id3"}; - MLSyncUpNodeRequest syncUpNodeRequest = new MLSyncUpNodeRequest( - new MLSyncUpNodesRequest(nodeIds, syncUpInput) - ); + String[] nodeIds = { "id1", "id2", "id3" }; + MLSyncUpNodeRequest syncUpNodeRequest = new MLSyncUpNodeRequest(new MLSyncUpNodesRequest(nodeIds, syncUpInput)); BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); syncUpNodeRequest.writeTo(bytesStreamOutput); StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); MLSyncUpNodeRequest parsedNodeRequest = new MLSyncUpNodeRequest(streamInput); assertEquals(3, parsedNodeRequest.getSyncUpNodesRequest().nodesIds().length); - assertEquals(parsedNodeRequest.getSyncUpNodesRequest().getSyncUpInput().isClearRoutingTable(), syncUpNodeRequest.getSyncUpNodesRequest().getSyncUpInput().isClearRoutingTable()); - assertEquals(parsedNodeRequest.getSyncUpNodesRequest().getSyncUpInput().isClearRoutingTable(), syncUpNodeRequest.getSyncUpNodesRequest().getSyncUpInput().isSyncRunningDeployModelTasks()); - assertEquals(parsedNodeRequest.getSyncUpNodesRequest().getSyncUpInput().isClearRoutingTable(), syncUpNodeRequest.getSyncUpNodesRequest().getSyncUpInput().isClearRoutingTable()); + assertEquals( + parsedNodeRequest.getSyncUpNodesRequest().getSyncUpInput().isClearRoutingTable(), + syncUpNodeRequest.getSyncUpNodesRequest().getSyncUpInput().isClearRoutingTable() + ); + assertEquals( + parsedNodeRequest.getSyncUpNodesRequest().getSyncUpInput().isClearRoutingTable(), + syncUpNodeRequest.getSyncUpNodesRequest().getSyncUpInput().isSyncRunningDeployModelTasks() + ); + assertEquals( + parsedNodeRequest.getSyncUpNodesRequest().getSyncUpInput().isClearRoutingTable(), + syncUpNodeRequest.getSyncUpNodesRequest().getSyncUpInput().isClearRoutingTable() + ); } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeResponseTest.java index 56e1672852..c5b53199a6 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodeResponseTest.java @@ -1,5 +1,12 @@ package org.opensearch.ml.common.transport.sync; +import static org.junit.Assert.*; +import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; + +import java.io.IOException; +import java.net.InetAddress; +import java.util.Collections; + import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -9,36 +16,36 @@ import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.transport.TransportAddress; -import java.io.IOException; -import java.net.InetAddress; -import java.util.Collections; - -import static org.junit.Assert.*; -import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; - @RunWith(MockitoJUnitRunner.class) public class MLSyncUpNodeResponseTest { private DiscoveryNode localNode; private final String modelStatus = "modelStatus"; - private final String[] loadedModelIds = {"loadedModelIds"}; - private final String[] runningLoadModelTaskIds = {"runningLoadModelTaskIds"}; - private final String[] runningLoadModelIds = {"modelid1"}; + private final String[] loadedModelIds = { "loadedModelIds" }; + private final String[] runningLoadModelTaskIds = { "runningLoadModelTaskIds" }; + private final String[] runningLoadModelIds = { "modelid1" }; + @Before public void setUp() throws Exception { localNode = new DiscoveryNode( - "foo0", - "foo0", - new TransportAddress(InetAddress.getLoopbackAddress(), 9300), - Collections.emptyMap(), - Collections.singleton(CLUSTER_MANAGER_ROLE), - Version.CURRENT + "foo0", + "foo0", + new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT ); } @Test public void testSerializationDeserialization() throws IOException { - MLSyncUpNodeResponse response = new MLSyncUpNodeResponse(localNode, modelStatus, loadedModelIds, runningLoadModelIds, runningLoadModelTaskIds); + MLSyncUpNodeResponse response = new MLSyncUpNodeResponse( + localNode, + modelStatus, + loadedModelIds, + runningLoadModelIds, + runningLoadModelTaskIds + ); BytesStreamOutput output = new BytesStreamOutput(); response.writeTo(output); MLSyncUpNodeResponse newResponse = new MLSyncUpNodeResponse(output.bytes().streamInput()); @@ -51,7 +58,13 @@ public void testSerializationDeserialization() throws IOException { @Test public void testReadProfile() throws IOException { - MLSyncUpNodeResponse response = new MLSyncUpNodeResponse(localNode, modelStatus, loadedModelIds, runningLoadModelIds, runningLoadModelTaskIds); + MLSyncUpNodeResponse response = new MLSyncUpNodeResponse( + localNode, + modelStatus, + loadedModelIds, + runningLoadModelIds, + runningLoadModelTaskIds + ); BytesStreamOutput output = new BytesStreamOutput(); response.writeTo(output); MLSyncUpNodeResponse newResponse = MLSyncUpNodeResponse.readStats(output.bytes().streamInput()); diff --git a/common/src/test/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodesResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodesResponseTest.java index 6603f17cbb..80d2608669 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodesResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/sync/MLSyncUpNodesResponseTest.java @@ -1,5 +1,11 @@ package org.opensearch.ml.common.transport.sync; +import static org.junit.Assert.assertEquals; + +import java.io.IOException; +import java.util.*; +import java.util.List; + import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -8,13 +14,6 @@ import org.opensearch.cluster.ClusterName; import org.opensearch.common.io.stream.BytesStreamOutput; -import java.io.IOException; -import java.util.*; - -import static org.junit.Assert.assertEquals; - -import java.util.List; - @RunWith(MockitoJUnitRunner.class) public class MLSyncUpNodesResponseTest { diff --git a/common/src/test/java/org/opensearch/ml/common/transport/sync/MLSyncUpResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/sync/MLSyncUpResponseTest.java index 19f6f6a7d6..3b0ef66b6c 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/sync/MLSyncUpResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/sync/MLSyncUpResponseTest.java @@ -1,22 +1,20 @@ package org.opensearch.ml.common.transport.sync; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +import java.io.IOException; + import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.junit.MockitoJUnitRunner; -import org.opensearch.core.common.Strings; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; -import java.io.IOException; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; - @RunWith(MockitoJUnitRunner.class) public class MLSyncUpResponseTest { private String status; diff --git a/common/src/test/java/org/opensearch/ml/common/transport/task/MLTaskGetRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/task/MLTaskGetRequestTest.java index 5d4e300904..0219c68e0e 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/task/MLTaskGetRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/task/MLTaskGetRequestTest.java @@ -1,5 +1,11 @@ package org.opensearch.ml.common.transport.task; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; + +import java.io.IOException; +import java.io.UncheckedIOException; + import org.junit.Before; import org.junit.Test; import org.opensearch.action.ActionRequest; @@ -7,12 +13,6 @@ import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.io.stream.StreamOutput; -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotSame; - public class MLTaskGetRequestTest { private String taskId; @@ -23,8 +23,7 @@ public void setUp() { @Test public void writeTo_Success() throws IOException { - MLTaskGetRequest mlTaskGetRequest = MLTaskGetRequest.builder() - .taskId(taskId).build(); + MLTaskGetRequest mlTaskGetRequest = MLTaskGetRequest.builder().taskId(taskId).build(); BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); mlTaskGetRequest.writeTo(bytesStreamOutput); MLTaskGetRequest parsedModel = new MLTaskGetRequest(bytesStreamOutput.bytes().streamInput()); @@ -41,8 +40,7 @@ public void validate_Exception_NullModelId() { @Test public void fromActionRequest_Success() { - MLTaskGetRequest mlTaskGetRequest = MLTaskGetRequest.builder() - .taskId(taskId).build(); + MLTaskGetRequest mlTaskGetRequest = MLTaskGetRequest.builder().taskId(taskId).build(); ActionRequest actionRequest = new ActionRequest() { @Override public ActionRequestValidationException validate() { diff --git a/common/src/test/java/org/opensearch/ml/common/transport/task/MLTaskGetResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/task/MLTaskGetResponseTest.java index ba4d8d7d95..2d4db967c1 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/task/MLTaskGetResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/task/MLTaskGetResponseTest.java @@ -1,5 +1,12 @@ package org.opensearch.ml.common.transport.task; +import static org.junit.Assert.*; +import static org.junit.Assert.assertEquals; + +import java.io.IOException; +import java.time.Instant; +import java.util.Arrays; + import org.junit.Before; import org.junit.Test; import org.opensearch.common.io.stream.BytesStreamOutput; @@ -8,40 +15,34 @@ import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.ml.common.dataset.MLInputDataType; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.MLTaskState; import org.opensearch.ml.common.MLTaskType; - -import java.io.IOException; -import java.time.Instant; -import java.util.Arrays; - -import static org.junit.Assert.*; -import static org.junit.Assert.assertEquals; +import org.opensearch.ml.common.dataset.MLInputDataType; public class MLTaskGetResponseTest { MLTask mlTask; @Before public void setUp() { - mlTask = MLTask.builder() - .taskId("id") - .modelId("model id") - .taskType(MLTaskType.EXECUTION) - .functionName(FunctionName.LINEAR_REGRESSION) - .state(MLTaskState.CREATED) - .inputType(MLInputDataType.DATA_FRAME) - .progress(1.3f) - .outputIndex("some index") - .workerNodes(Arrays.asList("some node")) - .createTime(Instant.ofEpochMilli(123)) - .lastUpdateTime(Instant.ofEpochMilli(123)) - .error("error") - .user(new User()) - .async(true) - .build(); + mlTask = MLTask + .builder() + .taskId("id") + .modelId("model id") + .taskType(MLTaskType.EXECUTION) + .functionName(FunctionName.LINEAR_REGRESSION) + .state(MLTaskState.CREATED) + .inputType(MLInputDataType.DATA_FRAME) + .progress(1.3f) + .outputIndex("some index") + .workerNodes(Arrays.asList("some node")) + .createTime(Instant.ofEpochMilli(123)) + .lastUpdateTime(Instant.ofEpochMilli(123)) + .error("error") + .user(new User()) + .async(true) + .build(); } @Test @@ -72,19 +73,22 @@ public void toXContentTest() throws IOException { mlTaskGetResponse.toXContent(builder, ToXContent.EMPTY_PARAMS); assertNotNull(builder); String jsonStr = builder.toString(); - assertEquals("{\"task_id\":\"id\"," + - "\"model_id\":\"model id\"," + - "\"task_type\":\"EXECUTION\"," + - "\"function_name\":\"LINEAR_REGRESSION\"," + - "\"state\":\"CREATED\"," + - "\"input_type\":\"DATA_FRAME\"," + - "\"progress\":1.3," + - "\"output_index\":\"some index\"," + - "\"worker_node\":[\"some node\"]," + - "\"create_time\":123," + - "\"last_update_time\":123," + - "\"error\":\"error\"," + - "\"user\":{\"name\":\"\",\"backend_roles\":[],\"roles\":[],\"custom_attribute_names\":[],\"user_requested_tenant\":null}," + - "\"is_async\":true}", jsonStr); + assertEquals( + "{\"task_id\":\"id\"," + + "\"model_id\":\"model id\"," + + "\"task_type\":\"EXECUTION\"," + + "\"function_name\":\"LINEAR_REGRESSION\"," + + "\"state\":\"CREATED\"," + + "\"input_type\":\"DATA_FRAME\"," + + "\"progress\":1.3," + + "\"output_index\":\"some index\"," + + "\"worker_node\":[\"some node\"]," + + "\"create_time\":123," + + "\"last_update_time\":123," + + "\"error\":\"error\"," + + "\"user\":{\"name\":\"\",\"backend_roles\":[],\"roles\":[],\"custom_attribute_names\":[],\"user_requested_tenant\":null}," + + "\"is_async\":true}", + jsonStr + ); } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/training/MLTrainingTaskRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/training/MLTrainingTaskRequestTest.java index 7c3e9eaa06..3de755892e 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/training/MLTrainingTaskRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/training/MLTrainingTaskRequestTest.java @@ -5,29 +5,29 @@ package org.opensearch.ml.common.transport.training; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Collections; +import java.util.HashMap; + import org.junit.Before; import org.junit.Test; import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.dataframe.DataFrame; import org.opensearch.ml.common.dataframe.DataFrameBuilder; import org.opensearch.ml.common.dataset.DataFrameInputDataset; import org.opensearch.ml.common.dataset.MLInputDataType; -import org.opensearch.ml.common.input.parameter.clustering.KMeansParams; -import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.input.MLInput; - -import java.io.IOException; -import java.io.UncheckedIOException; -import java.util.Collections; -import java.util.HashMap; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotSame; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertSame; +import org.opensearch.ml.common.input.parameter.clustering.KMeansParams; public class MLTrainingTaskRequestTest { @@ -35,14 +35,17 @@ public class MLTrainingTaskRequestTest { @Before public void setUp() { - DataFrame dataFrame = DataFrameBuilder.load(Collections.singletonList(new HashMap() {{ - put("key1", 2.0D); - }})); - mlInput = MLInput.builder() - .algorithm(FunctionName.KMEANS) - .parameters(KMeansParams.builder().centroids(1).build()) - .inputDataset(DataFrameInputDataset.builder().dataFrame(dataFrame).build()) - .build(); + DataFrame dataFrame = DataFrameBuilder.load(Collections.singletonList(new HashMap() { + { + put("key1", 2.0D); + } + })); + mlInput = MLInput + .builder() + .algorithm(FunctionName.KMEANS) + .parameters(KMeansParams.builder().centroids(1).build()) + .inputDataset(DataFrameInputDataset.builder().dataFrame(dataFrame).build()) + .build(); } @Test @@ -53,16 +56,13 @@ public void validate_Success() { @Test public void validate_SuccessWithBuilder() { - MLTrainingTaskRequest request = MLTrainingTaskRequest.builder() - .mlInput(mlInput) - .build(); + MLTrainingTaskRequest request = MLTrainingTaskRequest.builder().mlInput(mlInput).build(); assertNull(request.validate()); } @Test public void validate_Exception_NullMLInput() { - MLTrainingTaskRequest request = MLTrainingTaskRequest.builder() - .build(); + MLTrainingTaskRequest request = MLTrainingTaskRequest.builder().build(); ActionRequestValidationException exception = request.validate(); assertEquals("Validation Failed: 1: ML input can't be null;", exception.getMessage()); } @@ -70,18 +70,14 @@ public void validate_Exception_NullMLInput() { @Test public void validate_Exception_NullInputDataInMLInput() { mlInput.setInputDataset(null); - MLTrainingTaskRequest request = MLTrainingTaskRequest.builder() - .mlInput(mlInput) - .build(); + MLTrainingTaskRequest request = MLTrainingTaskRequest.builder().mlInput(mlInput).build(); ActionRequestValidationException exception = request.validate(); assertEquals("Validation Failed: 1: input data can't be null;", exception.getMessage()); } @Test public void writeTo() throws IOException { - MLTrainingTaskRequest request = MLTrainingTaskRequest.builder() - .mlInput(mlInput) - .build(); + MLTrainingTaskRequest request = MLTrainingTaskRequest.builder().mlInput(mlInput).build(); BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); request.writeTo(bytesStreamOutput); request = new MLTrainingTaskRequest(bytesStreamOutput.bytes().streamInput()); @@ -92,17 +88,13 @@ public void writeTo() throws IOException { @Test public void fromActionRequest_WithMLTrainingTaskRequest() { - MLTrainingTaskRequest request = MLTrainingTaskRequest.builder() - .mlInput(mlInput) - .build(); + MLTrainingTaskRequest request = MLTrainingTaskRequest.builder().mlInput(mlInput).build(); assertSame(request, MLTrainingTaskRequest.fromActionRequest(request)); } @Test public void fromActionRequest_WithNonMLTrainingTaskRequest() { - MLTrainingTaskRequest request = MLTrainingTaskRequest.builder() - .mlInput(mlInput) - .build(); + MLTrainingTaskRequest request = MLTrainingTaskRequest.builder().mlInput(mlInput).build(); ActionRequest actionRequest = new ActionRequest() { @Override public ActionRequestValidationException validate() { @@ -136,4 +128,4 @@ public void writeTo(StreamOutput out) throws IOException { }; MLTrainingTaskRequest.fromActionRequest(actionRequest); } -} \ No newline at end of file +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/training/MLTrainingTaskResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/training/MLTrainingTaskResponseTest.java index cca7a158cf..9b5a033d7a 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/training/MLTrainingTaskResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/training/MLTrainingTaskResponseTest.java @@ -5,54 +5,45 @@ package org.opensearch.ml.common.transport.training; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertSame; + import java.io.IOException; import java.io.UncheckedIOException; import org.junit.Test; -import org.opensearch.core.action.ActionResponse; import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.ml.common.output.MLTrainingOutput; import org.opensearch.ml.common.transport.MLTaskResponse; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotSame; -import static org.junit.Assert.assertSame; - public class MLTrainingTaskResponseTest { @Test public void writeTo() throws IOException { - MLTrainingOutput output = MLTrainingOutput.builder().status("success") - .modelId("taskId").build(); - MLTaskResponse response = MLTaskResponse.builder() - .output(output) - .build(); + MLTrainingOutput output = MLTrainingOutput.builder().status("success").modelId("taskId").build(); + MLTaskResponse response = MLTaskResponse.builder().output(output).build(); BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); response.writeTo(bytesStreamOutput); response = new MLTaskResponse(bytesStreamOutput.bytes().streamInput()); - MLTrainingOutput modelTrainingOutput = (MLTrainingOutput)response.getOutput(); + MLTrainingOutput modelTrainingOutput = (MLTrainingOutput) response.getOutput(); assertEquals("success", modelTrainingOutput.getStatus()); assertEquals("taskId", modelTrainingOutput.getModelId()); } @Test public void fromActionResponse_Success_WithMLTrainingTaskResponse() { - MLTrainingOutput output = MLTrainingOutput.builder().status("success") - .modelId("taskId").build(); - MLTaskResponse response = MLTaskResponse.builder() - .output(output) - .build(); + MLTrainingOutput output = MLTrainingOutput.builder().status("success").modelId("taskId").build(); + MLTaskResponse response = MLTaskResponse.builder().output(output).build(); assertSame(response, MLTaskResponse.fromActionResponse(response)); } @Test public void fromActionResponse_Success_WithNonMLTrainingTaskResponse() { - MLTrainingOutput output = MLTrainingOutput.builder().status("success") - .modelId("taskId").build(); - MLTaskResponse response = MLTaskResponse.builder() - .output(output) - .build(); + MLTrainingOutput output = MLTrainingOutput.builder().status("success").modelId("taskId").build(); + MLTaskResponse response = MLTaskResponse.builder().output(output).build(); ActionResponse actionResponse = new ActionResponse() { @Override public void writeTo(StreamOutput out) throws IOException { @@ -62,8 +53,8 @@ public void writeTo(StreamOutput out) throws IOException { MLTaskResponse result = MLTaskResponse.fromActionResponse(actionResponse); assertNotSame(response, result); - MLTrainingOutput modelTrainingOutput = (MLTrainingOutput)response.getOutput(); - MLTrainingOutput resultModelTrainingOutput = (MLTrainingOutput)result.getOutput(); + MLTrainingOutput modelTrainingOutput = (MLTrainingOutput) response.getOutput(); + MLTrainingOutput resultModelTrainingOutput = (MLTrainingOutput) result.getOutput(); assertEquals(modelTrainingOutput.getStatus(), resultModelTrainingOutput.getStatus()); assertEquals(modelTrainingOutput.getModelId(), resultModelTrainingOutput.getModelId()); } @@ -79,4 +70,4 @@ public void writeTo(StreamOutput out) throws IOException { MLTaskResponse.fromActionResponse(actionResponse); } -} \ No newline at end of file +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelInputTest.java index c76e73afe9..283a19836e 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelInputTest.java @@ -1,11 +1,16 @@ package org.opensearch.ml.common.transport.undeploy; +import static org.junit.Assert.*; + +import java.io.IOException; +import java.util.Collections; + import org.junit.Before; import org.junit.Test; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.*; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContent; @@ -13,24 +18,16 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.search.SearchModule; -import java.io.IOException; -import java.util.Collections; - -import static org.junit.Assert.*; - public class MLUndeployModelInputTest { private MLUndeployModelInput input; - private final String [] modelIds = new String [] {"modelId1","modelId2","modelId3"}; - private final String [] nodeIds = new String [] {"nodeId1","nodeId2","nodeId3"}; - private final String expectedInputStr = "{\"model_ids\":[\"modelId1\",\"modelId2\",\"modelId3\"]," + - "\"node_ids\":[\"nodeId1\",\"nodeId2\",\"nodeId3\"]}"; + private final String[] modelIds = new String[] { "modelId1", "modelId2", "modelId3" }; + private final String[] nodeIds = new String[] { "nodeId1", "nodeId2", "nodeId3" }; + private final String expectedInputStr = "{\"model_ids\":[\"modelId1\",\"modelId2\",\"modelId3\"]," + + "\"node_ids\":[\"nodeId1\",\"nodeId2\",\"nodeId3\"]}"; @Before public void setUp() throws Exception { - input = MLUndeployModelInput.builder() - .modelIds(modelIds) - .nodeIds(nodeIds) - .build(); + input = MLUndeployModelInput.builder().modelIds(modelIds).nodeIds(nodeIds).build(); } @Test @@ -44,25 +41,35 @@ public void testToXContent() throws Exception { @Test public void testParse() throws Exception { - XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, - Collections.emptyList()).getNamedXContents()), LoggingDeprecationHandler.INSTANCE, expectedInputStr); + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + LoggingDeprecationHandler.INSTANCE, + expectedInputStr + ); parser.nextToken(); MLUndeployModelInput parsedInput = MLUndeployModelInput.parse(parser); - assertArrayEquals(new String [] {"modelId1","modelId2","modelId3"}, parsedInput.getModelIds()); - assertArrayEquals(new String [] {"nodeId1","nodeId2","nodeId3"}, parsedInput.getNodeIds()); + assertArrayEquals(new String[] { "modelId1", "modelId2", "modelId3" }, parsedInput.getModelIds()); + assertArrayEquals(new String[] { "nodeId1", "nodeId2", "nodeId3" }, parsedInput.getNodeIds()); } @Test public void testParseWithInvalidField() throws Exception { - String withInvalidFieldInputStr = "{\"void\":\"void\"," + - "\"model_ids\":[\"modelId1\",\"modelId2\",\"modelId3\"]," + - "\"node_ids\":[\"nodeId1\",\"nodeId2\",\"nodeId3\"]}"; - XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, - Collections.emptyList()).getNamedXContents()), LoggingDeprecationHandler.INSTANCE, withInvalidFieldInputStr); + String withInvalidFieldInputStr = "{\"void\":\"void\"," + + "\"model_ids\":[\"modelId1\",\"modelId2\",\"modelId3\"]," + + "\"node_ids\":[\"nodeId1\",\"nodeId2\",\"nodeId3\"]}"; + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + LoggingDeprecationHandler.INSTANCE, + withInvalidFieldInputStr + ); parser.nextToken(); MLUndeployModelInput parsedInput = MLUndeployModelInput.parse(parser); - assertArrayEquals(new String [] {"modelId1","modelId2","modelId3"}, parsedInput.getModelIds()); - assertArrayEquals(new String [] {"nodeId1","nodeId2","nodeId3"}, parsedInput.getNodeIds()); + assertArrayEquals(new String[] { "modelId1", "modelId2", "modelId3" }, parsedInput.getModelIds()); + assertArrayEquals(new String[] { "nodeId1", "nodeId2", "nodeId3" }, parsedInput.getNodeIds()); } @Test @@ -74,5 +81,4 @@ public void readInputStream() throws IOException { assertArrayEquals(input.getModelIds(), parsedInput.getModelIds()); assertArrayEquals(input.getNodeIds(), parsedInput.getNodeIds()); } - } - +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodeResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodeResponseTest.java index baa17e2d11..84afd8d166 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodeResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodeResponseTest.java @@ -1,5 +1,15 @@ package org.opensearch.ml.common.transport.undeploy; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; +import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; + +import java.io.IOException; +import java.net.InetAddress; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -10,16 +20,6 @@ import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.transport.TransportAddress; -import java.io.IOException; -import java.net.InetAddress; -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotEquals; -import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; - @RunWith(MockitoJUnitRunner.class) public class MLUndeployModelNodeResponseTest { @@ -31,15 +31,15 @@ public class MLUndeployModelNodeResponseTest { @Before public void setUp() throws Exception { localNode = new DiscoveryNode( - "foo0", - "foo0", - new TransportAddress(InetAddress.getLoopbackAddress(), 9300), - Collections.emptyMap(), - Collections.singleton(CLUSTER_MANAGER_ROLE), - Version.CURRENT + "foo0", + "foo0", + new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT ); modelWorkerNodeCounts = new HashMap<>(); - modelWorkerNodeCounts.put("modelId1", new String[]{"node"}); + modelWorkerNodeCounts.put("modelId1", new String[] { "node" }); } @Test diff --git a/common/src/test/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodesRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodesRequestTest.java index 434ba2dbef..1cc69bec70 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodesRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodesRequestTest.java @@ -1,5 +1,13 @@ package org.opensearch.ml.common.transport.undeploy; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; + +import java.io.IOException; +import java.net.InetAddress; +import java.util.Collections; + import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.junit.MockitoJUnitRunner; @@ -9,14 +17,6 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.transport.TransportAddress; -import java.io.IOException; -import java.net.InetAddress; -import java.util.Collections; - -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; -import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; - @RunWith(MockitoJUnitRunner.class) public class MLUndeployModelNodesRequestTest { @@ -26,16 +26,19 @@ public class MLUndeployModelNodesRequestTest { @Test public void testConstructorSerialization1() throws IOException { - String[] modelIds = {"modelId1", "modelId2", "modelId3"}; - String[] nodeIds = {"nodeId1", "nodeId2", "nodeId3"}; + String[] modelIds = { "modelId1", "modelId2", "modelId3" }; + String[] nodeIds = { "nodeId1", "nodeId2", "nodeId3" }; MLUndeployModelNodeRequest undeployModelNodeRequest = new MLUndeployModelNodeRequest( - new MLUndeployModelNodesRequest(nodeIds, modelIds) + new MLUndeployModelNodesRequest(nodeIds, modelIds) ); BytesStreamOutput output = new BytesStreamOutput(); undeployModelNodeRequest.writeTo(output); - assertArrayEquals(new String[] {"modelId1", "modelId2", "modelId3"}, undeployModelNodeRequest.getMlUndeployModelNodesRequest().getModelIds()); + assertArrayEquals( + new String[] { "modelId1", "modelId2", "modelId3" }, + undeployModelNodeRequest.getMlUndeployModelNodesRequest().getModelIds() + ); } @@ -43,24 +46,24 @@ public void testConstructorSerialization1() throws IOException { public void testConstructorSerialization2() throws IOException { localNode1 = new DiscoveryNode( - "foo1", - "foo1", - new TransportAddress(InetAddress.getLoopbackAddress(), 9300), - Collections.emptyMap(), - Collections.singleton(CLUSTER_MANAGER_ROLE), - Version.CURRENT + "foo1", + "foo1", + new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT ); localNode2 = new DiscoveryNode( - "foo2", - "foo2", - new TransportAddress(InetAddress.getLoopbackAddress(), 9300), - Collections.emptyMap(), - Collections.singleton(CLUSTER_MANAGER_ROLE), - Version.CURRENT + "foo2", + "foo2", + new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT ); MLUndeployModelNodeRequest undeployModelNodeRequest = new MLUndeployModelNodeRequest( - new MLUndeployModelNodesRequest(localNode1,localNode2) + new MLUndeployModelNodesRequest(localNode1, localNode2) ); assertEquals(2, undeployModelNodeRequest.getMlUndeployModelNodesRequest().concreteNodes().length); @@ -69,11 +72,11 @@ public void testConstructorSerialization2() throws IOException { @Test public void testConstructorFromInputStream() throws IOException { - String[] modelIds = {"modelId1", "modelId2", "modelId3"}; - String[] nodeIds = {"nodeId1", "nodeId2", "nodeId3"}; + String[] modelIds = { "modelId1", "modelId2", "modelId3" }; + String[] nodeIds = { "nodeId1", "nodeId2", "nodeId3" }; MLUndeployModelNodeRequest undeployModelNodeRequest = new MLUndeployModelNodeRequest( - new MLUndeployModelNodesRequest(nodeIds, modelIds) + new MLUndeployModelNodesRequest(nodeIds, modelIds) ); BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); undeployModelNodeRequest.writeTo(bytesStreamOutput); @@ -81,7 +84,10 @@ public void testConstructorFromInputStream() throws IOException { StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); MLUndeployModelNodeRequest parsedNodeRequest = new MLUndeployModelNodeRequest(streamInput); - assertArrayEquals(undeployModelNodeRequest.getMlUndeployModelNodesRequest().getModelIds(), parsedNodeRequest.getMlUndeployModelNodesRequest().getModelIds()); + assertArrayEquals( + undeployModelNodeRequest.getMlUndeployModelNodesRequest().getModelIds(), + parsedNodeRequest.getMlUndeployModelNodesRequest().getModelIds() + ); } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodesResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodesResponseTest.java index d80c269679..629eb8cec0 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodesResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/undeploy/MLUndeployModelNodesResponseTest.java @@ -1,5 +1,16 @@ package org.opensearch.ml.common.transport.undeploy; +import static org.junit.Assert.assertEquals; +import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; + +import java.io.IOException; +import java.net.InetAddress; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -9,26 +20,13 @@ import org.opensearch.action.FailedNodeException; import org.opensearch.cluster.ClusterName; import org.opensearch.cluster.node.DiscoveryNode; -import org.opensearch.core.common.Strings; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.common.transport.TransportAddress; import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; -import java.io.IOException; -import java.net.InetAddress; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -import static org.junit.Assert.assertEquals; -import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; - @RunWith(MockitoJUnitRunner.class) public class MLUndeployModelNodesResponseTest { @@ -42,19 +40,21 @@ public class MLUndeployModelNodesResponseTest { public void setUp() throws Exception { clusterName = new ClusterName("clusterName"); node1 = new DiscoveryNode( - "foo1", - "foo1", - new TransportAddress(InetAddress.getLoopbackAddress(), 9300), - Collections.emptyMap(), - Collections.singleton(CLUSTER_MANAGER_ROLE), - Version.CURRENT); + "foo1", + "foo1", + new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT + ); node2 = new DiscoveryNode( - "foo2", - "foo2", - new TransportAddress(InetAddress.getLoopbackAddress(), 9300), - Collections.emptyMap(), - Collections.singleton(CLUSTER_MANAGER_ROLE), - Version.CURRENT); + "foo2", + "foo2", + new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT + ); modelWorkerNodeCounts = new HashMap<>(); modelWorkerNodeCounts.put("modelId1", 1); } @@ -63,8 +63,7 @@ public void setUp() throws Exception { public void testSerializationDeserialization1() throws IOException { List responseList = new ArrayList<>(); List failuresList = new ArrayList<>(); - MLUndeployModelNodesResponse response = new MLUndeployModelNodesResponse(clusterName, responseList, - failuresList); + MLUndeployModelNodesResponse response = new MLUndeployModelNodesResponse(clusterName, responseList, failuresList); BytesStreamOutput output = new BytesStreamOutput(); response.writeTo(output); MLUndeployModelNodesResponse newResponse = new MLUndeployModelNodesResponse(output.bytes().streamInput()); @@ -78,13 +77,13 @@ public void testToXContent() throws IOException { Map modelToUndeployStatus1 = new HashMap<>(); modelToUndeployStatus1.put("modelId1", "response"); Map modelWorkerNodeCounts1 = new HashMap<>(); - modelWorkerNodeCounts1.put("modelId1", new String[]{"mockNode1"}); + modelWorkerNodeCounts1.put("modelId1", new String[] { "mockNode1" }); nodes.add(new MLUndeployModelNodeResponse(node1, modelToUndeployStatus1, modelWorkerNodeCounts1)); Map modelToUndeployStatus2 = new HashMap<>(); modelToUndeployStatus2.put("modelId2", "response"); Map modelWorkerNodeCounts2 = new HashMap<>(); - modelWorkerNodeCounts2.put("modelId2", new String[]{"mockNode2"}); + modelWorkerNodeCounts2.put("modelId2", new String[] { "mockNode2" }); nodes.add(new MLUndeployModelNodeResponse(node2, modelToUndeployStatus2, modelWorkerNodeCounts2)); List failures = new ArrayList<>(); @@ -92,8 +91,6 @@ public void testToXContent() throws IOException { XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON); response.toXContent(builder, ToXContent.EMPTY_PARAMS); String jsonStr = builder.toString(); - assertEquals( - "{\"foo1\":{\"stats\":{\"modelId1\":\"response\"}},\"foo2\":{\"stats\":{\"modelId2\":\"response\"}}}", - jsonStr); + assertEquals("{\"foo1\":{\"stats\":{\"modelId1\":\"response\"}},\"foo2\":{\"stats\":{\"modelId2\":\"response\"}}}", jsonStr); } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInputTest.java index 61e57d4ac6..ff4d6f7ed9 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInputTest.java @@ -1,89 +1,111 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.ml.common.transport.upload_chunk; - -import org.junit.Before; -import org.junit.Test; -import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.common.xcontent.XContentType; -import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.ml.common.FunctionName; -import org.opensearch.ml.common.TestHelper; -import org.opensearch.ml.common.model.MLModelFormat; -import org.opensearch.ml.common.model.MLModelState; -import org.opensearch.ml.common.model.TextEmbeddingModelConfig; -import org.opensearch.ml.common.model.TextEmbeddingModelConfig.FrameworkType; - -import java.io.IOException; -import java.util.function.Function; - -import static org.junit.Assert.assertEquals; -import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; - -public class MLRegisterModelMetaInputTest { - - - Function function = parser -> { - try { - return MLRegisterModelMetaInput.parse(parser); - } catch (Exception e) { - throw new RuntimeException("Failed to parse MLRegisterModelMetaInput", e); - } - }; - TextEmbeddingModelConfig config; - MLRegisterModelMetaInput mLRegisterModelMetaInput; - - @Before - public void setup() { - config = new TextEmbeddingModelConfig("Model Type", 123, FrameworkType.SENTENCE_TRANSFORMERS, "All Config", - TextEmbeddingModelConfig.PoolingMode.MEAN, true, 512); - mLRegisterModelMetaInput = new MLRegisterModelMetaInput("Model Name", FunctionName.BATCH_RCF, "model_group_id", "1.0", - "Model Description", MLModelFormat.TORCH_SCRIPT, MLModelState.DEPLOYING, 200L, "123", config, 2, null, null, null, null); - } - - @Test - public void parse_MLRegisterModelMetaInput() throws IOException { - TestHelper.testParse(mLRegisterModelMetaInput, function); - } - - @Test - public void readInputStream_Success() throws IOException { - readInputStream(mLRegisterModelMetaInput); - } - - - private void readInputStream(MLRegisterModelMetaInput input) throws IOException { - BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); - input.writeTo(bytesStreamOutput); - StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); - MLRegisterModelMetaInput newInput = new MLRegisterModelMetaInput(streamInput); - assertEquals(input.getName(), newInput.getName()); - assertEquals(input.getDescription(), newInput.getDescription()); - assertEquals(input.getModelFormat(), newInput.getModelFormat()); - assertEquals(input.getModelConfig().getAllConfig(), newInput.getModelConfig().getAllConfig()); - assertEquals(input.getModelConfig().getModelType(), newInput.getModelConfig().getModelType()); - } - - - @Test - public void testToXContent() throws IOException {{ - XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); - mLRegisterModelMetaInput.toXContent(builder, EMPTY_PARAMS); - String mlModelContent = TestHelper.xContentBuilderToString(builder); - final String expected = "{\"name\":\"Model Name\",\"function_name\":\"BATCH_RCF\",\"model_group_id\":\"model_group_id\",\"version\":\"1.0\",\"description\":\"Model Description\",\"model_format\":\"TORCH_SCRIPT\",\"model_state\":\"DEPLOYING\",\"model_content_size_in_bytes\":200,\"model_content_hash_value\":\"123\",\"model_config\":{\"model_type\":\"Model Type\"," + - "\"embedding_dimension\":123,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"All Config\",\"model_max_length\":512,\"pooling_mode\":\"MEAN\",\"normalize_result\":true},\"total_chunks\":2}"; - assertEquals(expected, mlModelContent); - } - XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); - mLRegisterModelMetaInput.toXContent(builder, EMPTY_PARAMS); - String mlModelContent = TestHelper.xContentBuilderToString(builder); - final String expected = "{\"name\":\"Model Name\",\"function_name\":\"BATCH_RCF\",\"model_group_id\":\"model_group_id\",\"version\":\"1.0\",\"description\":\"Model Description\",\"model_format\":\"TORCH_SCRIPT\",\"model_state\":\"DEPLOYING\",\"model_content_size_in_bytes\":200,\"model_content_hash_value\":\"123\",\"model_config\":{\"model_type\":\"Model Type\"," + - "\"embedding_dimension\":123,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"All Config\",\"model_max_length\":512,\"pooling_mode\":\"MEAN\",\"normalize_result\":true},\"total_chunks\":2}"; - assertEquals(expected, mlModelContent); - } -} +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.upload_chunk; + +import static org.junit.Assert.assertEquals; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; + +import java.io.IOException; +import java.util.function.Function; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.TestHelper; +import org.opensearch.ml.common.model.MLModelFormat; +import org.opensearch.ml.common.model.MLModelState; +import org.opensearch.ml.common.model.TextEmbeddingModelConfig; +import org.opensearch.ml.common.model.TextEmbeddingModelConfig.FrameworkType; + +public class MLRegisterModelMetaInputTest { + + Function function = parser -> { + try { + return MLRegisterModelMetaInput.parse(parser); + } catch (Exception e) { + throw new RuntimeException("Failed to parse MLRegisterModelMetaInput", e); + } + }; + TextEmbeddingModelConfig config; + MLRegisterModelMetaInput mLRegisterModelMetaInput; + + @Before + public void setup() { + config = new TextEmbeddingModelConfig( + "Model Type", + 123, + FrameworkType.SENTENCE_TRANSFORMERS, + "All Config", + TextEmbeddingModelConfig.PoolingMode.MEAN, + true, + 512 + ); + mLRegisterModelMetaInput = new MLRegisterModelMetaInput( + "Model Name", + FunctionName.BATCH_RCF, + "model_group_id", + "1.0", + "Model Description", + MLModelFormat.TORCH_SCRIPT, + MLModelState.DEPLOYING, + 200L, + "123", + config, + 2, + null, + null, + null, + null + ); + } + + @Test + public void parse_MLRegisterModelMetaInput() throws IOException { + TestHelper.testParse(mLRegisterModelMetaInput, function); + } + + @Test + public void readInputStream_Success() throws IOException { + readInputStream(mLRegisterModelMetaInput); + } + + private void readInputStream(MLRegisterModelMetaInput input) throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + input.writeTo(bytesStreamOutput); + StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); + MLRegisterModelMetaInput newInput = new MLRegisterModelMetaInput(streamInput); + assertEquals(input.getName(), newInput.getName()); + assertEquals(input.getDescription(), newInput.getDescription()); + assertEquals(input.getModelFormat(), newInput.getModelFormat()); + assertEquals(input.getModelConfig().getAllConfig(), newInput.getModelConfig().getAllConfig()); + assertEquals(input.getModelConfig().getModelType(), newInput.getModelConfig().getModelType()); + } + + @Test + public void testToXContent() throws IOException { + { + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + mLRegisterModelMetaInput.toXContent(builder, EMPTY_PARAMS); + String mlModelContent = TestHelper.xContentBuilderToString(builder); + final String expected = + "{\"name\":\"Model Name\",\"function_name\":\"BATCH_RCF\",\"model_group_id\":\"model_group_id\",\"version\":\"1.0\",\"description\":\"Model Description\",\"model_format\":\"TORCH_SCRIPT\",\"model_state\":\"DEPLOYING\",\"model_content_size_in_bytes\":200,\"model_content_hash_value\":\"123\",\"model_config\":{\"model_type\":\"Model Type\"," + + "\"embedding_dimension\":123,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"All Config\",\"model_max_length\":512,\"pooling_mode\":\"MEAN\",\"normalize_result\":true},\"total_chunks\":2}"; + assertEquals(expected, mlModelContent); + } + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + mLRegisterModelMetaInput.toXContent(builder, EMPTY_PARAMS); + String mlModelContent = TestHelper.xContentBuilderToString(builder); + final String expected = + "{\"name\":\"Model Name\",\"function_name\":\"BATCH_RCF\",\"model_group_id\":\"model_group_id\",\"version\":\"1.0\",\"description\":\"Model Description\",\"model_format\":\"TORCH_SCRIPT\",\"model_state\":\"DEPLOYING\",\"model_content_size_in_bytes\":200,\"model_content_hash_value\":\"123\",\"model_config\":{\"model_type\":\"Model Type\"," + + "\"embedding_dimension\":123,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"All Config\",\"model_max_length\":512,\"pooling_mode\":\"MEAN\",\"normalize_result\":true},\"total_chunks\":2}"; + assertEquals(expected, mlModelContent); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaRequestTest.java index d7039780f0..fbd318d609 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaRequestTest.java @@ -1,105 +1,125 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.ml.common.transport.upload_chunk; - -import org.junit.Before; -import org.junit.Test; -import org.opensearch.action.ActionRequest; -import org.opensearch.action.ActionRequestValidationException; -import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.ml.common.FunctionName; -import org.opensearch.ml.common.model.MLModelFormat; -import org.opensearch.ml.common.model.MLModelState; -import org.opensearch.ml.common.model.TextEmbeddingModelConfig; -import org.opensearch.ml.common.model.TextEmbeddingModelConfig.FrameworkType; - -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotSame; - -public class MLRegisterModelMetaRequestTest { - - TextEmbeddingModelConfig config; - MLRegisterModelMetaInput mlRegisterModelMetaInput; - - @Before - public void setUp() { - config = new TextEmbeddingModelConfig("Model Type", 123, FrameworkType.SENTENCE_TRANSFORMERS, "All Config", - TextEmbeddingModelConfig.PoolingMode.MEAN, true, 512); - mlRegisterModelMetaInput = new MLRegisterModelMetaInput("Model Name", FunctionName.BATCH_RCF, "Model Group Id", "1.0", - "Model Description", MLModelFormat.TORCH_SCRIPT, MLModelState.DEPLOYING, 200L, "123", config, 2, null, null, null, null); - } - - @Test - public void writeTo_Succeess() throws IOException { - MLRegisterModelMetaRequest request = new MLRegisterModelMetaRequest(mlRegisterModelMetaInput); - BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); - request.writeTo(bytesStreamOutput); - MLRegisterModelMetaRequest newRequest = new MLRegisterModelMetaRequest(bytesStreamOutput.bytes().streamInput()); - assertEquals(request.getMlRegisterModelMetaInput().getName(), newRequest.getMlRegisterModelMetaInput().getName()); - assertEquals(request.getMlRegisterModelMetaInput().getDescription(), - newRequest.getMlRegisterModelMetaInput().getDescription()); - assertEquals(request.getMlRegisterModelMetaInput().getFunctionName(), - newRequest.getMlRegisterModelMetaInput().getFunctionName()); - assertEquals(request.getMlRegisterModelMetaInput().getModelConfig().getAllConfig(), - newRequest.getMlRegisterModelMetaInput().getModelConfig().getAllConfig()); - assertEquals(request.getMlRegisterModelMetaInput().getModelGroupId(), - newRequest.getMlRegisterModelMetaInput().getModelGroupId()); - } - - @Test - public void validate_Exception_NullModelId() { - MLRegisterModelMetaRequest mlRegisterModelMetaRequest = MLRegisterModelMetaRequest.builder().build(); - ActionRequestValidationException exception = mlRegisterModelMetaRequest.validate(); - assertEquals("Validation Failed: 1: Model meta input can't be null;", exception.getMessage()); - } - - @Test - public void fromActionRequest_Success() { - MLRegisterModelMetaRequest request = new MLRegisterModelMetaRequest(mlRegisterModelMetaInput); - ActionRequest actionRequest = new ActionRequest() { - @Override - public ActionRequestValidationException validate() { - return null; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - request.writeTo(out); - } - }; - MLRegisterModelMetaRequest newRequest = MLRegisterModelMetaRequest.fromActionRequest(actionRequest); - assertNotSame(request, newRequest); - assertEquals(request.getMlRegisterModelMetaInput().getName(), newRequest.getMlRegisterModelMetaInput().getName()); - assertEquals(request.getMlRegisterModelMetaInput().getDescription(), - newRequest.getMlRegisterModelMetaInput().getDescription()); - assertEquals(request.getMlRegisterModelMetaInput().getFunctionName(), - newRequest.getMlRegisterModelMetaInput().getFunctionName()); - assertEquals(request.getMlRegisterModelMetaInput().getModelConfig().getAllConfig(), - newRequest.getMlRegisterModelMetaInput().getModelConfig().getAllConfig()); - assertEquals(request.getMlRegisterModelMetaInput().getModelGroupId(), - newRequest.getMlRegisterModelMetaInput().getModelGroupId()); - } - - @Test(expected = UncheckedIOException.class) - public void fromActionRequest_IOException() { - ActionRequest actionRequest = new ActionRequest() { - @Override - public ActionRequestValidationException validate() { - return null; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - throw new IOException("test"); - } - }; - MLRegisterModelMetaRequest.fromActionRequest(actionRequest); - } -} +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.upload_chunk; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; + +import java.io.IOException; +import java.io.UncheckedIOException; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.model.MLModelFormat; +import org.opensearch.ml.common.model.MLModelState; +import org.opensearch.ml.common.model.TextEmbeddingModelConfig; +import org.opensearch.ml.common.model.TextEmbeddingModelConfig.FrameworkType; + +public class MLRegisterModelMetaRequestTest { + + TextEmbeddingModelConfig config; + MLRegisterModelMetaInput mlRegisterModelMetaInput; + + @Before + public void setUp() { + config = new TextEmbeddingModelConfig( + "Model Type", + 123, + FrameworkType.SENTENCE_TRANSFORMERS, + "All Config", + TextEmbeddingModelConfig.PoolingMode.MEAN, + true, + 512 + ); + mlRegisterModelMetaInput = new MLRegisterModelMetaInput( + "Model Name", + FunctionName.BATCH_RCF, + "Model Group Id", + "1.0", + "Model Description", + MLModelFormat.TORCH_SCRIPT, + MLModelState.DEPLOYING, + 200L, + "123", + config, + 2, + null, + null, + null, + null + ); + } + + @Test + public void writeTo_Succeess() throws IOException { + MLRegisterModelMetaRequest request = new MLRegisterModelMetaRequest(mlRegisterModelMetaInput); + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + request.writeTo(bytesStreamOutput); + MLRegisterModelMetaRequest newRequest = new MLRegisterModelMetaRequest(bytesStreamOutput.bytes().streamInput()); + assertEquals(request.getMlRegisterModelMetaInput().getName(), newRequest.getMlRegisterModelMetaInput().getName()); + assertEquals(request.getMlRegisterModelMetaInput().getDescription(), newRequest.getMlRegisterModelMetaInput().getDescription()); + assertEquals(request.getMlRegisterModelMetaInput().getFunctionName(), newRequest.getMlRegisterModelMetaInput().getFunctionName()); + assertEquals( + request.getMlRegisterModelMetaInput().getModelConfig().getAllConfig(), + newRequest.getMlRegisterModelMetaInput().getModelConfig().getAllConfig() + ); + assertEquals(request.getMlRegisterModelMetaInput().getModelGroupId(), newRequest.getMlRegisterModelMetaInput().getModelGroupId()); + } + + @Test + public void validate_Exception_NullModelId() { + MLRegisterModelMetaRequest mlRegisterModelMetaRequest = MLRegisterModelMetaRequest.builder().build(); + ActionRequestValidationException exception = mlRegisterModelMetaRequest.validate(); + assertEquals("Validation Failed: 1: Model meta input can't be null;", exception.getMessage()); + } + + @Test + public void fromActionRequest_Success() { + MLRegisterModelMetaRequest request = new MLRegisterModelMetaRequest(mlRegisterModelMetaInput); + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + request.writeTo(out); + } + }; + MLRegisterModelMetaRequest newRequest = MLRegisterModelMetaRequest.fromActionRequest(actionRequest); + assertNotSame(request, newRequest); + assertEquals(request.getMlRegisterModelMetaInput().getName(), newRequest.getMlRegisterModelMetaInput().getName()); + assertEquals(request.getMlRegisterModelMetaInput().getDescription(), newRequest.getMlRegisterModelMetaInput().getDescription()); + assertEquals(request.getMlRegisterModelMetaInput().getFunctionName(), newRequest.getMlRegisterModelMetaInput().getFunctionName()); + assertEquals( + request.getMlRegisterModelMetaInput().getModelConfig().getAllConfig(), + newRequest.getMlRegisterModelMetaInput().getModelConfig().getAllConfig() + ); + assertEquals(request.getMlRegisterModelMetaInput().getModelGroupId(), newRequest.getMlRegisterModelMetaInput().getModelGroupId()); + } + + @Test(expected = UncheckedIOException.class) + public void fromActionRequest_IOException() { + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IOException("test"); + } + }; + MLRegisterModelMetaRequest.fromActionRequest(actionRequest); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaResponseTest.java index 92f66530e9..fb8fcd0e81 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaResponseTest.java @@ -1,50 +1,49 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.ml.common.transport.upload_chunk; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; - -import java.io.IOException; - -import org.junit.Before; -import org.junit.Test; -import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.xcontent.XContentType; -import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.ml.common.TestHelper; - -public class MLRegisterModelMetaResponseTest { - - MLRegisterModelMetaResponse mlRegisterModelMetaResponse; - - @Before - public void setup() { - mlRegisterModelMetaResponse = new MLRegisterModelMetaResponse("Model Id", "Status"); - } - - - @Test - public void writeTo_Success() throws IOException { - BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); - mlRegisterModelMetaResponse.writeTo(bytesStreamOutput); - MLRegisterModelMetaResponse newResponse = new MLRegisterModelMetaResponse(bytesStreamOutput.bytes().streamInput()); - assertEquals(mlRegisterModelMetaResponse.getModelId(), newResponse.getModelId()); - assertEquals(mlRegisterModelMetaResponse.getStatus(), newResponse.getStatus()); - } - - @Test - public void testToXContent() throws IOException { - MLRegisterModelMetaResponse response = new MLRegisterModelMetaResponse("Model Id", "Status"); - XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); - response.toXContent(builder, EMPTY_PARAMS); - assertNotNull(builder); - String jsonStr = TestHelper.xContentBuilderToString(builder); - final String expected = "{\"model_id\":\"Model Id\",\"status\":\"Status\"}"; - assertEquals(expected, jsonStr); - } -} +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.upload_chunk; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; + +import java.io.IOException; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.TestHelper; + +public class MLRegisterModelMetaResponseTest { + + MLRegisterModelMetaResponse mlRegisterModelMetaResponse; + + @Before + public void setup() { + mlRegisterModelMetaResponse = new MLRegisterModelMetaResponse("Model Id", "Status"); + } + + @Test + public void writeTo_Success() throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + mlRegisterModelMetaResponse.writeTo(bytesStreamOutput); + MLRegisterModelMetaResponse newResponse = new MLRegisterModelMetaResponse(bytesStreamOutput.bytes().streamInput()); + assertEquals(mlRegisterModelMetaResponse.getModelId(), newResponse.getModelId()); + assertEquals(mlRegisterModelMetaResponse.getStatus(), newResponse.getStatus()); + } + + @Test + public void testToXContent() throws IOException { + MLRegisterModelMetaResponse response = new MLRegisterModelMetaResponse("Model Id", "Status"); + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + response.toXContent(builder, EMPTY_PARAMS); + assertNotNull(builder); + String jsonStr = TestHelper.xContentBuilderToString(builder); + final String expected = "{\"model_id\":\"Model Id\",\"status\":\"Status\"}"; + assertEquals(expected, jsonStr); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLUploadModelChunkInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLUploadModelChunkInputTest.java index cafec77356..a04e5fcb51 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLUploadModelChunkInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLUploadModelChunkInputTest.java @@ -1,107 +1,116 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.ml.common.transport.upload_chunk; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; - -import java.io.IOException; -import java.util.Collections; -import java.util.function.Function; - -import org.junit.Before; -import org.junit.Test; -import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.common.settings.Settings; -import org.opensearch.common.xcontent.XContentFactory; -import org.opensearch.common.xcontent.XContentType; -import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.ml.common.TestHelper; -import org.opensearch.search.SearchModule; - -public class MLUploadModelChunkInputTest { - - MLUploadModelChunkInput mlUploadModelChunkInput; - private Function function = parser -> { - try { - return MLUploadModelChunkInput.parse(parser, new byte[] { 12, 4, 5, 3 }); - } catch (Exception e) { - throw new RuntimeException("Failed to parse MLUploadModelChunkInput", e); - } - }; - - @Before - public void setup() { - mlUploadModelChunkInput = MLUploadModelChunkInput.builder().modelId("modelId").chunkNumber(1) - .content(new byte[] { 1, 3, 4 }).build(); - } - - @Test - public void parse_MLUploadModelChunkInput() throws IOException { - TestHelper.testParse(mlUploadModelChunkInput, function); - } - - @Test - public void readInputStream_Success() throws IOException { - readInputStream(mlUploadModelChunkInput); - } - - private void readInputStream(MLUploadModelChunkInput input) throws IOException { - BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); - input.writeTo(bytesStreamOutput); - StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); - MLUploadModelChunkInput newInput = new MLUploadModelChunkInput(streamInput); - assertEquals(input.getChunkNumber(), newInput.getChunkNumber()); - assertEquals(input.getModelId(), newInput.getModelId()); - } - - @Test - public void testMLUploadModelChunkInputConstructor() { - MLUploadModelChunkInput input = new MLUploadModelChunkInput("modelId", 1, new byte[] { 12, 3 }); - assertNotNull(input); - } - - @Test - public void testMLUploadModelChunkInputWriteToSuccess() throws IOException { - BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); - mlUploadModelChunkInput.writeTo(bytesStreamOutput); - final var newLlUploadModelChunkInput = new MLUploadModelChunkInput(bytesStreamOutput.bytes().streamInput()); - assertEquals(mlUploadModelChunkInput.getModelId(), newLlUploadModelChunkInput.getModelId()); - assertEquals(mlUploadModelChunkInput.getChunkNumber(), newLlUploadModelChunkInput.getChunkNumber()); - } - - @Test - public void testToXContent() throws IOException { - XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); - mlUploadModelChunkInput.toXContent(builder, EMPTY_PARAMS); - String mlModelContent = TestHelper.xContentBuilderToString(builder); - assertEquals("{\"model_id\":\"modelId\",\"chunk_number\":1,\"model_content\":\"AQME\"}", mlModelContent); - } - - @Test - public void testMLUploadModelChunkInputParser() throws IOException { - XContentBuilder builder = XContentFactory.jsonBuilder(); - builder = mlUploadModelChunkInput.toXContent(builder, null); - String json = builder.toString(); - XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry( - new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), null, json); - parser.nextToken(); - MLUploadModelChunkInput newMlUploadModelChunkInput = MLUploadModelChunkInput.parse(parser, new byte[] { 1, 3, 4 }); - assertEquals(mlUploadModelChunkInput, newMlUploadModelChunkInput); - } - - @Test - public void testMLUploadModelChunkInputParser_XContentParser() throws IOException { - XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); - mlUploadModelChunkInput.toXContent(builder, EMPTY_PARAMS); - String mlModelContent = TestHelper.xContentBuilderToString(builder); - TestHelper.testParseFromString(mlUploadModelChunkInput, mlModelContent, function); - } -} +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.upload_chunk; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; + +import java.io.IOException; +import java.util.Collections; +import java.util.function.Function; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ml.common.TestHelper; +import org.opensearch.search.SearchModule; + +public class MLUploadModelChunkInputTest { + + MLUploadModelChunkInput mlUploadModelChunkInput; + private Function function = parser -> { + try { + return MLUploadModelChunkInput.parse(parser, new byte[] { 12, 4, 5, 3 }); + } catch (Exception e) { + throw new RuntimeException("Failed to parse MLUploadModelChunkInput", e); + } + }; + + @Before + public void setup() { + mlUploadModelChunkInput = MLUploadModelChunkInput + .builder() + .modelId("modelId") + .chunkNumber(1) + .content(new byte[] { 1, 3, 4 }) + .build(); + } + + @Test + public void parse_MLUploadModelChunkInput() throws IOException { + TestHelper.testParse(mlUploadModelChunkInput, function); + } + + @Test + public void readInputStream_Success() throws IOException { + readInputStream(mlUploadModelChunkInput); + } + + private void readInputStream(MLUploadModelChunkInput input) throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + input.writeTo(bytesStreamOutput); + StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); + MLUploadModelChunkInput newInput = new MLUploadModelChunkInput(streamInput); + assertEquals(input.getChunkNumber(), newInput.getChunkNumber()); + assertEquals(input.getModelId(), newInput.getModelId()); + } + + @Test + public void testMLUploadModelChunkInputConstructor() { + MLUploadModelChunkInput input = new MLUploadModelChunkInput("modelId", 1, new byte[] { 12, 3 }); + assertNotNull(input); + } + + @Test + public void testMLUploadModelChunkInputWriteToSuccess() throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + mlUploadModelChunkInput.writeTo(bytesStreamOutput); + final var newLlUploadModelChunkInput = new MLUploadModelChunkInput(bytesStreamOutput.bytes().streamInput()); + assertEquals(mlUploadModelChunkInput.getModelId(), newLlUploadModelChunkInput.getModelId()); + assertEquals(mlUploadModelChunkInput.getChunkNumber(), newLlUploadModelChunkInput.getChunkNumber()); + } + + @Test + public void testToXContent() throws IOException { + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + mlUploadModelChunkInput.toXContent(builder, EMPTY_PARAMS); + String mlModelContent = TestHelper.xContentBuilderToString(builder); + assertEquals("{\"model_id\":\"modelId\",\"chunk_number\":1,\"model_content\":\"AQME\"}", mlModelContent); + } + + @Test + public void testMLUploadModelChunkInputParser() throws IOException { + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder = mlUploadModelChunkInput.toXContent(builder, null); + String json = builder.toString(); + XContentParser parser = XContentType.JSON + .xContent() + .createParser( + new NamedXContentRegistry(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()), + null, + json + ); + parser.nextToken(); + MLUploadModelChunkInput newMlUploadModelChunkInput = MLUploadModelChunkInput.parse(parser, new byte[] { 1, 3, 4 }); + assertEquals(mlUploadModelChunkInput, newMlUploadModelChunkInput); + } + + @Test + public void testMLUploadModelChunkInputParser_XContentParser() throws IOException { + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + mlUploadModelChunkInput.toXContent(builder, EMPTY_PARAMS); + String mlModelContent = TestHelper.xContentBuilderToString(builder); + TestHelper.testParseFromString(mlUploadModelChunkInput, mlModelContent, function); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLUploadModelChunkRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLUploadModelChunkRequestTest.java index 9571c5db53..4f20046077 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLUploadModelChunkRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLUploadModelChunkRequestTest.java @@ -1,84 +1,81 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.ml.common.transport.upload_chunk; - -import org.junit.Before; -import org.junit.Test; -import org.opensearch.action.ActionRequest; -import org.opensearch.action.ActionRequestValidationException; -import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.core.common.io.stream.StreamOutput; - -import java.io.IOException; -import java.io.UncheckedIOException; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotSame; - -public class MLUploadModelChunkRequestTest { - - MLUploadModelChunkInput mlUploadModelChunkInput; - - @Before - public void setUp() { - mlUploadModelChunkInput = new MLUploadModelChunkInput("modelId", 1, new byte[] { 12, 3 }); - } - - - @Test - public void writeTo_Succeess() throws IOException { - MLUploadModelChunkRequest request = new MLUploadModelChunkRequest(mlUploadModelChunkInput); - BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); - request.writeTo(bytesStreamOutput); - MLUploadModelChunkRequest newRequest = new MLUploadModelChunkRequest(bytesStreamOutput.bytes().streamInput()); - assertEquals(request.getUploadModelChunkInput(), newRequest.getUploadModelChunkInput()); - } - - @Test - public void validate_Exception_NullModelId() { - MLUploadModelChunkRequest mlUploadModelChunkRequest = MLUploadModelChunkRequest.builder().build(); - ActionRequestValidationException exception = mlUploadModelChunkRequest.validate(); - assertEquals("Validation Failed: 1: ML input can't be null;", exception.getMessage()); - } - - - @Test - public void fromActionRequest_Success() { - MLUploadModelChunkRequest request = new MLUploadModelChunkRequest(mlUploadModelChunkInput); - ActionRequest actionRequest = new ActionRequest() { - @Override - public ActionRequestValidationException validate() { - return null; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - request.writeTo(out); - } - }; - MLUploadModelChunkRequest result = MLUploadModelChunkRequest.fromActionRequest(actionRequest); - assertNotSame(request, result); - assertEquals(request.getUploadModelChunkInput(), result.getUploadModelChunkInput()); - } - - - @Test(expected = UncheckedIOException.class) - public void fromActionRequest_IOException() { - ActionRequest actionRequest = new ActionRequest() { - @Override - public ActionRequestValidationException validate() { - return null; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - throw new IOException("test"); - } - }; - MLUploadModelChunkRequest.fromActionRequest(actionRequest); - } - -} +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.upload_chunk; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; + +import java.io.IOException; +import java.io.UncheckedIOException; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamOutput; + +public class MLUploadModelChunkRequestTest { + + MLUploadModelChunkInput mlUploadModelChunkInput; + + @Before + public void setUp() { + mlUploadModelChunkInput = new MLUploadModelChunkInput("modelId", 1, new byte[] { 12, 3 }); + } + + @Test + public void writeTo_Succeess() throws IOException { + MLUploadModelChunkRequest request = new MLUploadModelChunkRequest(mlUploadModelChunkInput); + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + request.writeTo(bytesStreamOutput); + MLUploadModelChunkRequest newRequest = new MLUploadModelChunkRequest(bytesStreamOutput.bytes().streamInput()); + assertEquals(request.getUploadModelChunkInput(), newRequest.getUploadModelChunkInput()); + } + + @Test + public void validate_Exception_NullModelId() { + MLUploadModelChunkRequest mlUploadModelChunkRequest = MLUploadModelChunkRequest.builder().build(); + ActionRequestValidationException exception = mlUploadModelChunkRequest.validate(); + assertEquals("Validation Failed: 1: ML input can't be null;", exception.getMessage()); + } + + @Test + public void fromActionRequest_Success() { + MLUploadModelChunkRequest request = new MLUploadModelChunkRequest(mlUploadModelChunkInput); + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + request.writeTo(out); + } + }; + MLUploadModelChunkRequest result = MLUploadModelChunkRequest.fromActionRequest(actionRequest); + assertNotSame(request, result); + assertEquals(request.getUploadModelChunkInput(), result.getUploadModelChunkInput()); + } + + @Test(expected = UncheckedIOException.class) + public void fromActionRequest_IOException() { + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IOException("test"); + } + }; + MLUploadModelChunkRequest.fromActionRequest(actionRequest); + } + +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLUploadModelChunkResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLUploadModelChunkResponseTest.java index 9bff6e68de..14aa51f7ae 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLUploadModelChunkResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLUploadModelChunkResponseTest.java @@ -1,47 +1,47 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.ml.common.transport.upload_chunk; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; - -import java.io.IOException; - -import org.junit.Before; -import org.junit.Test; -import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.xcontent.XContentType; -import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.ml.common.TestHelper; - -public class MLUploadModelChunkResponseTest { - - MLUploadModelChunkResponse mlUploadModelChunkResponse; - - @Before - public void setup() { - mlUploadModelChunkResponse = new MLUploadModelChunkResponse("Status"); - } - - @Test - public void writeTo_Success() throws IOException { - BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); - mlUploadModelChunkResponse.writeTo(bytesStreamOutput); - MLUploadModelChunkResponse newResponse = new MLUploadModelChunkResponse(bytesStreamOutput.bytes().streamInput()); - assertEquals(mlUploadModelChunkResponse.getStatus(), newResponse.getStatus()); - } - - @Test - public void testToXContent() throws IOException { - XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); - mlUploadModelChunkResponse.toXContent(builder, EMPTY_PARAMS); - assertNotNull(builder); - String jsonStr = TestHelper.xContentBuilderToString(builder); - final String expected = "{\"status\":\"Status\"}"; - assertEquals(expected, jsonStr); - } -} +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.upload_chunk; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; + +import java.io.IOException; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.TestHelper; + +public class MLUploadModelChunkResponseTest { + + MLUploadModelChunkResponse mlUploadModelChunkResponse; + + @Before + public void setup() { + mlUploadModelChunkResponse = new MLUploadModelChunkResponse("Status"); + } + + @Test + public void writeTo_Success() throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + mlUploadModelChunkResponse.writeTo(bytesStreamOutput); + MLUploadModelChunkResponse newResponse = new MLUploadModelChunkResponse(bytesStreamOutput.bytes().streamInput()); + assertEquals(mlUploadModelChunkResponse.getStatus(), newResponse.getStatus()); + } + + @Test + public void testToXContent() throws IOException { + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + mlUploadModelChunkResponse.toXContent(builder, EMPTY_PARAMS); + assertNotNull(builder); + String jsonStr = TestHelper.xContentBuilderToString(builder); + final String expected = "{\"status\":\"Status\"}"; + assertEquals(expected, jsonStr); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java b/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java index a4b34d75b5..44a72044a8 100644 --- a/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java +++ b/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java @@ -1,12 +1,12 @@ package org.opensearch.ml.common.utils; -import org.junit.Assert; -import org.junit.Test; - import java.util.HashMap; import java.util.List; import java.util.Map; +import org.junit.Assert; +import org.junit.Test; + public class StringUtilsTest { @Test @@ -45,12 +45,13 @@ public void fromJson_SimpleMap() { @Test public void fromJson_NestedMap() { - Map response = StringUtils.fromJson("{\"key\": {\"nested_key\": \"nested_value\", \"nested_array\": [1, \"a\"]}}", "response"); + Map response = StringUtils + .fromJson("{\"key\": {\"nested_key\": \"nested_value\", \"nested_array\": [1, \"a\"]}}", "response"); Assert.assertEquals(1, response.size()); Assert.assertTrue(response.get("key") instanceof Map); - Map nestedMap = (Map)response.get("key"); + Map nestedMap = (Map) response.get("key"); Assert.assertEquals("nested_value", nestedMap.get("nested_key")); - List list = (List)nestedMap.get("nested_array"); + List list = (List) nestedMap.get("nested_array"); Assert.assertEquals(2, list.size()); Assert.assertEquals(1.0, list.get(0)); Assert.assertEquals("a", list.get(1)); @@ -61,7 +62,7 @@ public void fromJson_SimpleList() { Map response = StringUtils.fromJson("[1, \"a\"]", "response"); Assert.assertEquals(1, response.size()); Assert.assertTrue(response.get("response") instanceof List); - List list = (List)response.get("response"); + List list = (List) response.get("response"); Assert.assertEquals(1.0, list.get(0)); Assert.assertEquals("a", list.get(1)); } @@ -71,7 +72,7 @@ public void fromJson_NestedList() { Map response = StringUtils.fromJson("[1, \"a\", [2, 3], {\"key\": \"value\"}]", "response"); Assert.assertEquals(1, response.size()); Assert.assertTrue(response.get("response") instanceof List); - List list = (List)response.get("response"); + List list = (List) response.get("response"); Assert.assertEquals(1.0, list.get(0)); Assert.assertEquals("a", list.get(1)); Assert.assertTrue(list.get(2) instanceof List); @@ -84,8 +85,8 @@ public void getParameterMap() { parameters.put("key1", "value1"); parameters.put("key2", 2); parameters.put("key3", 2.1); - parameters.put("key4", new int[]{10, 20}); - parameters.put("key5", new Object[]{1.01, "abc"}); + parameters.put("key4", new int[] { 10, 20 }); + parameters.put("key5", new Object[] { 1.01, "abc" }); Map parameterMap = StringUtils.getParameterMap(parameters); System.out.println(parameterMap); Assert.assertEquals(5, parameterMap.size());