From d0895bb1f361d1a8f0f0b897abea89ad9dcdb0f0 Mon Sep 17 00:00:00 2001 From: Sicheng Song Date: Tue, 23 Jan 2024 12:06:20 -0800 Subject: [PATCH] Adding UT coverage for in-cache update and fine-tuning throttling feature (#1837) --- .../org/opensearch/ml/common/CommonValue.java | 842 +++++++++--------- .../org/opensearch/ml/common/MLModel.java | 115 +-- ...ModelController.java => MLController.java} | 79 +- .../ml/common/controller/MLRateLimiter.java | 88 +- ...ion.java => MLControllerDeleteAction.java} | 8 +- ...st.java => MLControllerDeleteRequest.java} | 18 +- .../controller/MLControllerGetAction.java | 17 + ...quest.java => MLControllerGetRequest.java} | 20 +- ...onse.java => MLControllerGetResponse.java} | 32 +- .../controller/MLCreateControllerAction.java | 17 + ...st.java => MLCreateControllerRequest.java} | 37 +- ...e.java => MLCreateControllerResponse.java} | 18 +- .../MLCreateModelControllerAction.java | 17 - ...ion.java => MLDeployControllerAction.java} | 8 +- .../MLDeployControllerNodeRequest.java | 32 + ...va => MLDeployControllerNodeResponse.java} | 28 +- ...va => MLDeployControllerNodesRequest.java} | 10 +- ...a => MLDeployControllerNodesResponse.java} | 21 +- .../MLDeployModelControllerNodeRequest.java | 32 - .../MLModelControllerGetAction.java | 15 - ...n.java => MLUndeployControllerAction.java} | 8 +- .../MLUndeployControllerNodeRequest.java | 33 + ... => MLUndeployControllerNodeResponse.java} | 28 +- ... => MLUndeployControllerNodesRequest.java} | 10 +- ...=> MLUndeployControllerNodesResponse.java} | 23 +- .../MLUndeployModelControllerNodeRequest.java | 34 - ...ion.java => MLUpdateControllerAction.java} | 6 +- ...st.java => MLUpdateControllerRequest.java} | 34 +- .../transport/model/MLUpdateModelInput.java | 43 +- .../register/MLRegisterModelInput.java | 86 +- .../MLRegisterModelMetaInput.java | 70 +- .../common/controller/MLControllerTest.java | 365 ++++++++ .../controller/MLModelControllerTest.java | 329 ------- .../common/controller/MLRateLimiterTest.java | 55 +- ...ava => MLControllerDeleteRequestTest.java} | 20 +- ...t.java => MLControllerGetRequestTest.java} | 18 +- ....java => MLControllerGetResponseTest.java} | 55 +- .../MLCreateControllerRequestTest.java | 128 +++ ...va => MLCreateControllerResponseTest.java} | 22 +- .../MLCreateModelControllerRequestTest.java | 124 --- ...> MLDeployControllerNodeResponseTest.java} | 20 +- .../MLDeployControllerNodesRequestTest.java | 77 ++ ... MLDeployControllerNodesResponseTest.java} | 38 +- ...DeployModelControllerNodesRequestTest.java | 77 -- ...MLUndeployControllerNodeResponseTest.java} | 25 +- .../MLUndeployControllerNodesRequestTest.java | 76 ++ ...LUndeployControllerNodesResponseTest.java} | 41 +- ...deployModelControllerNodesRequestTest.java | 76 -- .../MLUpdateControllerRequestTest.java | 129 +++ .../MLUpdateModelControllerRequestTest.java | 124 --- .../model/MLUpdateModelInputTest.java | 82 +- .../MLModelGroupDeleteRequestTest.java | 5 +- .../remote/AwsConnectorExecutor.java | 2 +- .../remote/HttpJsonConnectorExecutor.java | 2 +- .../remote/RemoteConnectorExecutor.java | 6 +- .../engine/algorithms/remote/RemoteModel.java | 4 +- .../opensearch/ml/engine/indices/MLIndex.java | 8 +- .../ml/engine/indices/MLIndicesHandler.java | 13 +- ...a => CreateControllerTransportAction.java} | 130 ++- ...a => DeleteControllerTransportAction.java} | 79 +- ...a => DeployControllerTransportAction.java} | 58 +- ...java => GetControllerTransportAction.java} | 33 +- ...=> UndeployControllerTransportAction.java} | 58 +- ...a => UpdateControllerTransportAction.java} | 121 ++- .../models/DeleteModelTransportAction.java | 30 +- .../models/UpdateModelTransportAction.java | 7 +- .../TransportPredictionTaskAction.java | 7 +- .../org/opensearch/ml/model/MLModelCache.java | 5 +- .../ml/model/MLModelCacheHelper.java | 140 ++- .../opensearch/ml/model/MLModelManager.java | 176 ++-- .../ml/plugin/MachineLearningPlugin.java | 69 +- ...java => RestMLCreateControllerAction.java} | 34 +- ...java => RestMLDeleteControllerAction.java} | 23 +- ...on.java => RestMLGetControllerAction.java} | 29 +- ...java => RestMLUpdateControllerAction.java} | 34 +- .../ml/stats/MLClusterLevelStat.java | 2 +- ...CreateControllerTransportActionTests.java} | 162 ++-- ...DeleteControllerTransportActionTests.java} | 133 ++- ...DeployControllerTransportActionTests.java} | 57 +- ...=> GetControllerTransportActionTests.java} | 68 +- ...deployControllerTransportActionTests.java} | 61 +- ...UpdateControllerTransportActionTests.java} | 187 ++-- .../UpdateModelTransportActionTests.java | 372 ++++++-- .../ml/model/MLModelCacheHelperTests.java | 61 +- ...=> RestMLCreateControllerActionTests.java} | 62 +- ...=> RestMLDeleteControllerActionTests.java} | 30 +- ...va => RestMLGetControllerActionTests.java} | 34 +- ...=> RestMLUpdateControllerActionTests.java} | 62 +- 88 files changed, 3232 insertions(+), 2842 deletions(-) rename common/src/main/java/org/opensearch/ml/common/controller/{MLModelController.java => MLController.java} (58%) rename common/src/main/java/org/opensearch/ml/common/transport/controller/{MLModelControllerDeleteAction.java => MLControllerDeleteAction.java} (55%) rename common/src/main/java/org/opensearch/ml/common/transport/controller/{MLModelControllerDeleteRequest.java => MLControllerDeleteRequest.java} (72%) create mode 100644 common/src/main/java/org/opensearch/ml/common/transport/controller/MLControllerGetAction.java rename common/src/main/java/org/opensearch/ml/common/transport/controller/{MLModelControllerGetRequest.java => MLControllerGetRequest.java} (75%) rename common/src/main/java/org/opensearch/ml/common/transport/controller/{MLModelControllerGetResponse.java => MLControllerGetResponse.java} (56%) create mode 100644 common/src/main/java/org/opensearch/ml/common/transport/controller/MLCreateControllerAction.java rename common/src/main/java/org/opensearch/ml/common/transport/controller/{MLCreateModelControllerRequest.java => MLCreateControllerRequest.java} (60%) rename common/src/main/java/org/opensearch/ml/common/transport/controller/{MLCreateModelControllerResponse.java => MLCreateControllerResponse.java} (73%) delete mode 100644 common/src/main/java/org/opensearch/ml/common/transport/controller/MLCreateModelControllerAction.java rename common/src/main/java/org/opensearch/ml/common/transport/controller/{MLDeployModelControllerAction.java => MLDeployControllerAction.java} (56%) create mode 100644 common/src/main/java/org/opensearch/ml/common/transport/controller/MLDeployControllerNodeRequest.java rename common/src/main/java/org/opensearch/ml/common/transport/controller/{MLDeployModelControllerNodeResponse.java => MLDeployControllerNodeResponse.java} (50%) rename common/src/main/java/org/opensearch/ml/common/transport/controller/{MLUndeployModelControllerNodesRequest.java => MLDeployControllerNodesRequest.java} (64%) rename common/src/main/java/org/opensearch/ml/common/transport/controller/{MLUndeployModelControllerNodesResponse.java => MLDeployControllerNodesResponse.java} (59%) delete mode 100644 common/src/main/java/org/opensearch/ml/common/transport/controller/MLDeployModelControllerNodeRequest.java delete mode 100644 common/src/main/java/org/opensearch/ml/common/transport/controller/MLModelControllerGetAction.java rename common/src/main/java/org/opensearch/ml/common/transport/controller/{MLUndeployModelControllerAction.java => MLUndeployControllerAction.java} (54%) create mode 100644 common/src/main/java/org/opensearch/ml/common/transport/controller/MLUndeployControllerNodeRequest.java rename common/src/main/java/org/opensearch/ml/common/transport/controller/{MLUndeployModelControllerNodeResponse.java => MLUndeployControllerNodeResponse.java} (50%) rename common/src/main/java/org/opensearch/ml/common/transport/controller/{MLDeployModelControllerNodesRequest.java => MLUndeployControllerNodesRequest.java} (63%) rename common/src/main/java/org/opensearch/ml/common/transport/controller/{MLDeployModelControllerNodesResponse.java => MLUndeployControllerNodesResponse.java} (57%) delete mode 100644 common/src/main/java/org/opensearch/ml/common/transport/controller/MLUndeployModelControllerNodeRequest.java rename common/src/main/java/org/opensearch/ml/common/transport/controller/{MLUpdateModelControllerAction.java => MLUpdateControllerAction.java} (62%) rename common/src/main/java/org/opensearch/ml/common/transport/controller/{MLUpdateModelControllerRequest.java => MLUpdateControllerRequest.java} (61%) create mode 100644 common/src/test/java/org/opensearch/ml/common/controller/MLControllerTest.java delete mode 100644 common/src/test/java/org/opensearch/ml/common/controller/MLModelControllerTest.java rename common/src/test/java/org/opensearch/ml/common/transport/controller/{MLModelControllerDeleteRequestTest.java => MLControllerDeleteRequestTest.java} (76%) rename common/src/test/java/org/opensearch/ml/common/transport/controller/{MLModelControllerGetRequestTest.java => MLControllerGetRequestTest.java} (77%) rename common/src/test/java/org/opensearch/ml/common/transport/controller/{MLModelControllerGetResponseTest.java => MLControllerGetResponseTest.java} (53%) create mode 100644 common/src/test/java/org/opensearch/ml/common/transport/controller/MLCreateControllerRequestTest.java rename common/src/test/java/org/opensearch/ml/common/transport/controller/{MLCreateModelControllerResponseTest.java => MLCreateControllerResponseTest.java} (74%) delete mode 100644 common/src/test/java/org/opensearch/ml/common/transport/controller/MLCreateModelControllerRequestTest.java rename common/src/test/java/org/opensearch/ml/common/transport/controller/{MLDeployModelControllerNodeResponseTest.java => MLDeployControllerNodeResponseTest.java} (68%) create mode 100644 common/src/test/java/org/opensearch/ml/common/transport/controller/MLDeployControllerNodesRequestTest.java rename common/src/test/java/org/opensearch/ml/common/transport/controller/{MLDeployModelControllerNodesResponseTest.java => MLDeployControllerNodesResponseTest.java} (67%) delete mode 100644 common/src/test/java/org/opensearch/ml/common/transport/controller/MLDeployModelControllerNodesRequestTest.java rename common/src/test/java/org/opensearch/ml/common/transport/controller/{MLUndeployModelControllerNodeResponseTest.java => MLUndeployControllerNodeResponseTest.java} (69%) create mode 100644 common/src/test/java/org/opensearch/ml/common/transport/controller/MLUndeployControllerNodesRequestTest.java rename common/src/test/java/org/opensearch/ml/common/transport/controller/{MLUndeployModelControllerNodesResponseTest.java => MLUndeployControllerNodesResponseTest.java} (67%) delete mode 100644 common/src/test/java/org/opensearch/ml/common/transport/controller/MLUndeployModelControllerNodesRequestTest.java create mode 100644 common/src/test/java/org/opensearch/ml/common/transport/controller/MLUpdateControllerRequestTest.java delete mode 100644 common/src/test/java/org/opensearch/ml/common/transport/controller/MLUpdateModelControllerRequestTest.java rename plugin/src/main/java/org/opensearch/ml/action/controller/{CreateModelControllerTransportAction.java => CreateControllerTransportAction.java} (64%) rename plugin/src/main/java/org/opensearch/ml/action/controller/{DeleteModelControllerTransportAction.java => DeleteControllerTransportAction.java} (75%) rename plugin/src/main/java/org/opensearch/ml/action/controller/{DeployModelControllerTransportAction.java => DeployControllerTransportAction.java} (53%) rename plugin/src/main/java/org/opensearch/ml/action/controller/{GetModelControllerTransportAction.java => GetControllerTransportAction.java} (81%) rename plugin/src/main/java/org/opensearch/ml/action/controller/{UndeployModelControllerTransportAction.java => UndeployControllerTransportAction.java} (52%) rename plugin/src/main/java/org/opensearch/ml/action/controller/{UpdateModelControllerTransportAction.java => UpdateControllerTransportAction.java} (66%) rename plugin/src/main/java/org/opensearch/ml/rest/{RestMLCreateModelControllerAction.java => RestMLCreateControllerAction.java} (55%) rename plugin/src/main/java/org/opensearch/ml/rest/{RestMLDeleteModelControllerAction.java => RestMLDeleteControllerAction.java} (52%) rename plugin/src/main/java/org/opensearch/ml/rest/{RestMLGetModelControllerAction.java => RestMLGetControllerAction.java} (53%) rename plugin/src/main/java/org/opensearch/ml/rest/{RestMLUpdateModelControllerAction.java => RestMLUpdateControllerAction.java} (56%) rename plugin/src/test/java/org/opensearch/ml/action/controller/{CreateModelControllerTransportActionTests.java => CreateControllerTransportActionTests.java} (62%) rename plugin/src/test/java/org/opensearch/ml/action/controller/{DeleteModelControllerTransportActionTests.java => DeleteControllerTransportActionTests.java} (66%) rename plugin/src/test/java/org/opensearch/ml/action/controller/{DeployModelControllerTransportActionTests.java => DeployControllerTransportActionTests.java} (61%) rename plugin/src/test/java/org/opensearch/ml/action/controller/{GetModelControllerTransportActionTests.java => GetControllerTransportActionTests.java} (75%) rename plugin/src/test/java/org/opensearch/ml/action/controller/{UndeployModelControllerTransportActionTests.java => UndeployControllerTransportActionTests.java} (61%) rename plugin/src/test/java/org/opensearch/ml/action/controller/{UpdateModelControllerTransportActionTests.java => UpdateControllerTransportActionTests.java} (62%) rename plugin/src/test/java/org/opensearch/ml/rest/{RestMLCreateModelControllerActionTests.java => RestMLCreateControllerActionTests.java} (66%) rename plugin/src/test/java/org/opensearch/ml/rest/{RestMLDeleteModelControllerActionTests.java => RestMLDeleteControllerActionTests.java} (67%) rename plugin/src/test/java/org/opensearch/ml/rest/{RestMLGetModelControllerActionTests.java => RestMLGetControllerActionTests.java} (64%) rename plugin/src/test/java/org/opensearch/ml/rest/{RestMLUpdateModelControllerActionTests.java => RestMLUpdateControllerActionTests.java} (66%) diff --git a/common/src/main/java/org/opensearch/ml/common/CommonValue.java b/common/src/main/java/org/opensearch/ml/common/CommonValue.java index bf07ad5040..4ff3917080 100644 --- a/common/src/main/java/org/opensearch/ml/common/CommonValue.java +++ b/common/src/main/java/org/opensearch/ml/common/CommonValue.java @@ -7,7 +7,7 @@ import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.connector.AbstractConnector; -import org.opensearch.ml.common.controller.MLModelController; +import org.opensearch.ml.common.controller.MLController; import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.APPLICATION_TYPE_FIELD; import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD; @@ -35,434 +35,436 @@ public class CommonValue { - public static Integer NO_SCHEMA_VERSION = 0; - public static final String REMOTE_SERVICE_ERROR = "Error from remote service: "; - public static final String USER = "user"; - public static final String META = "_meta"; - public static final String SCHEMA_VERSION_FIELD = "schema_version"; - public static final String UNDEPLOYED = "undeployed"; - public static final String NOT_FOUND = "not_found"; + public static Integer NO_SCHEMA_VERSION = 0; + public static final String REMOTE_SERVICE_ERROR = "Error from remote service: "; + public static final String USER = "user"; + public static final String META = "_meta"; + public static final String SCHEMA_VERSION_FIELD = "schema_version"; + public static final String UNDEPLOYED = "undeployed"; + public static final String NOT_FOUND = "not_found"; - public static final String MASTER_KEY = "master_key"; - public static final String CREATE_TIME_FIELD = "create_time"; + public static final String MASTER_KEY = "master_key"; + public static final String CREATE_TIME_FIELD = "create_time"; - public static final String BOX_TYPE_KEY = "box_type"; - //hot node - public static String HOT_BOX_TYPE = "hot"; - // warm node - public static String WARM_BOX_TYPE = "warm"; - public static final String ML_MODEL_GROUP_INDEX = ".plugins-ml-model-group"; - public static final String ML_MODEL_INDEX = ".plugins-ml-model"; - public static final String ML_TASK_INDEX = ".plugins-ml-task"; - public static final Integer ML_MODEL_GROUP_INDEX_SCHEMA_VERSION = 2; - public static final Integer ML_MODEL_INDEX_SCHEMA_VERSION = 9; - public static final String ML_CONNECTOR_INDEX = ".plugins-ml-connector"; - public static final Integer ML_TASK_INDEX_SCHEMA_VERSION = 2; - public static final Integer ML_CONNECTOR_SCHEMA_VERSION = 2; - public static final String ML_CONFIG_INDEX = ".plugins-ml-config"; - public static final Integer ML_CONFIG_INDEX_SCHEMA_VERSION = 2; - public static final String ML_MODEL_CONTROLLER_INDEX = ".plugins-ml-controller"; - public static final Integer ML_MODEL_CONTROLLER_INDEX_SCHEMA_VERSION = 1; - public static final String ML_MAP_RESPONSE_KEY = "response"; - public static final String ML_AGENT_INDEX = ".plugins-ml-agent"; - public static final Integer ML_AGENT_INDEX_SCHEMA_VERSION = 1; - public static final String ML_MEMORY_META_INDEX = ".plugins-ml-memory-meta"; - public static final Integer ML_MEMORY_META_INDEX_SCHEMA_VERSION = 1; - public static final String ML_MEMORY_MESSAGE_INDEX = ".plugins-ml-memory-message"; - public static final Integer ML_MEMORY_MESSAGE_INDEX_SCHEMA_VERSION = 1; - public static final String USER_FIELD_MAPPING = " \"" - + CommonValue.USER - + "\": {\n" - + " \"type\": \"nested\",\n" - + " \"properties\": {\n" - + " \"name\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\", \"ignore_above\":256}}},\n" - + " \"backend_roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n" - + " \"roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n" - + " \"custom_attribute_names\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}}\n" - + " }\n" - + " }\n"; - public static final String ML_MODEL_GROUP_INDEX_MAPPING = "{\n" + - " \"_meta\": {\n" + - " \"schema_version\": "+ML_MODEL_GROUP_INDEX_SCHEMA_VERSION+"\n" + - " },\n" + - " \"properties\": {\n" + - " \""+MLModelGroup.MODEL_GROUP_NAME_FIELD+"\": {\n" + - " \"type\": \"text\",\n" + - " \"fields\": {\n" + - " \"keyword\": {\n" + - " \"type\": \"keyword\",\n" + - " \"ignore_above\": 256\n" + - " }\n" + - " }\n" + - " },\n" + - " \""+MLModelGroup.DESCRIPTION_FIELD+"\": {\n" + - " \"type\": \"text\"\n" + - " },\n" + - " \""+MLModelGroup.LATEST_VERSION_FIELD+"\": {\n" + - " \"type\": \"integer\"\n" + - " },\n" + - " \""+MLModelGroup.MODEL_GROUP_ID_FIELD+"\": {\n" + - " \"type\": \"keyword\"\n" + - " },\n" + - " \""+MLModelGroup.BACKEND_ROLES_FIELD+"\": {\n" + - " \"type\": \"text\",\n" + - " \"fields\": {\n" + - " \"keyword\": {\n" + - " \"type\": \"keyword\",\n" + - " \"ignore_above\": 256\n" + - " }\n" + - " }\n" + - " },\n" + - " \""+MLModelGroup.ACCESS+"\": {\n" + - " \"type\": \"keyword\"\n" + - " },\n" + - " \""+MLModelGroup.OWNER+"\": {\n" + - " \"type\": \"nested\",\n" + - " \"properties\": {\n" + - " \"name\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\", \"ignore_above\":256}}},\n" + - " \"backend_roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n" + - " \"roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n" + - " \"custom_attribute_names\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}}\n" + - " }\n" + - " },\n" + - " \""+MLModelGroup.CREATED_TIME_FIELD+"\": {\n" + - " \"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + - " \""+MLModelGroup.LAST_UPDATED_TIME_FIELD+"\": {\n" + - " \"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"}\n" + - " }\n" + - "}"; + public static final String BOX_TYPE_KEY = "box_type"; + // hot node + public static String HOT_BOX_TYPE = "hot"; + // warm node + public static String WARM_BOX_TYPE = "warm"; + public static final String ML_MODEL_GROUP_INDEX = ".plugins-ml-model-group"; + public static final String ML_MODEL_INDEX = ".plugins-ml-model"; + public static final String ML_TASK_INDEX = ".plugins-ml-task"; + public static final Integer ML_MODEL_GROUP_INDEX_SCHEMA_VERSION = 2; + public static final Integer ML_MODEL_INDEX_SCHEMA_VERSION = 9; + public static final String ML_CONNECTOR_INDEX = ".plugins-ml-connector"; + public static final Integer ML_TASK_INDEX_SCHEMA_VERSION = 2; + public static final Integer ML_CONNECTOR_SCHEMA_VERSION = 2; + public static final String ML_CONFIG_INDEX = ".plugins-ml-config"; + public static final Integer ML_CONFIG_INDEX_SCHEMA_VERSION = 2; + public static final String ML_CONTROLLER_INDEX = ".plugins-ml-controller"; + public static final Integer ML_CONTROLLER_INDEX_SCHEMA_VERSION = 1; + public static final String ML_MAP_RESPONSE_KEY = "response"; + public static final String ML_AGENT_INDEX = ".plugins-ml-agent"; + public static final Integer ML_AGENT_INDEX_SCHEMA_VERSION = 1; + public static final String ML_MEMORY_META_INDEX = ".plugins-ml-memory-meta"; + public static final Integer ML_MEMORY_META_INDEX_SCHEMA_VERSION = 1; + public static final String ML_MEMORY_MESSAGE_INDEX = ".plugins-ml-memory-message"; + public static final Integer ML_MEMORY_MESSAGE_INDEX_SCHEMA_VERSION = 1; + public static final String USER_FIELD_MAPPING = " \"" + + CommonValue.USER + + "\": {\n" + + " \"type\": \"nested\",\n" + + " \"properties\": {\n" + + " \"name\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\", \"ignore_above\":256}}},\n" + + " \"backend_roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n" + + " \"roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n" + + " \"custom_attribute_names\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}}\n" + + " }\n" + + " }\n"; + public static final String ML_MODEL_GROUP_INDEX_MAPPING = "{\n" + + " \"_meta\": {\n" + + " \"schema_version\": " + ML_MODEL_GROUP_INDEX_SCHEMA_VERSION + "\n" + + " },\n" + + " \"properties\": {\n" + + " \"" + MLModelGroup.MODEL_GROUP_NAME_FIELD + "\": {\n" + + " \"type\": \"text\",\n" + + " \"fields\": {\n" + + " \"keyword\": {\n" + + " \"type\": \"keyword\",\n" + + " \"ignore_above\": 256\n" + + " }\n" + + " }\n" + + " },\n" + + " \"" + MLModelGroup.DESCRIPTION_FIELD + "\": {\n" + + " \"type\": \"text\"\n" + + " },\n" + + " \"" + MLModelGroup.LATEST_VERSION_FIELD + "\": {\n" + + " \"type\": \"integer\"\n" + + " },\n" + + " \"" + MLModelGroup.MODEL_GROUP_ID_FIELD + "\": {\n" + + " \"type\": \"keyword\"\n" + + " },\n" + + " \"" + MLModelGroup.BACKEND_ROLES_FIELD + "\": {\n" + + " \"type\": \"text\",\n" + + " \"fields\": {\n" + + " \"keyword\": {\n" + + " \"type\": \"keyword\",\n" + + " \"ignore_above\": 256\n" + + " }\n" + + " }\n" + + " },\n" + + " \"" + MLModelGroup.ACCESS + "\": {\n" + + " \"type\": \"keyword\"\n" + + " },\n" + + " \"" + MLModelGroup.OWNER + "\": {\n" + + " \"type\": \"nested\",\n" + + " \"properties\": {\n" + + " \"name\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\", \"ignore_above\":256}}},\n" + + + " \"backend_roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n" + + + " \"roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n" + + " \"custom_attribute_names\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}}\n" + + + " }\n" + + " },\n" + + " \"" + MLModelGroup.CREATED_TIME_FIELD + "\": {\n" + + " \"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + + " \"" + MLModelGroup.LAST_UPDATED_TIME_FIELD + "\": {\n" + + " \"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"}\n" + + " }\n" + + "}"; - public static final String ML_CONNECTOR_INDEX_FIELDS = " \"properties\": {\n" - + " \"" - + AbstractConnector.NAME_FIELD - + "\" : {\"type\":\"text\",\"fields\":{\"keyword\":{\"type\":\"keyword\",\"ignore_above\":256}}},\n" - + " \"" - + AbstractConnector.VERSION_FIELD - + "\" : {\"type\": \"keyword\"},\n" - + " \"" - + AbstractConnector.DESCRIPTION_FIELD - + "\" : {\"type\": \"text\"},\n" - + " \"" - + AbstractConnector.PROTOCOL_FIELD - + "\" : {\"type\": \"keyword\"},\n" - + " \"" - + AbstractConnector.PARAMETERS_FIELD - + "\" : {\"type\": \"flat_object\"},\n" - + " \"" - + AbstractConnector.CREDENTIAL_FIELD - + "\" : {\"type\": \"flat_object\"},\n" - + " \"" - + AbstractConnector.ACTIONS_FIELD - + "\" : {\"type\": \"flat_object\"}\n"; + public static final String ML_CONNECTOR_INDEX_FIELDS = " \"properties\": {\n" + + " \"" + + AbstractConnector.NAME_FIELD + + "\" : {\"type\":\"text\",\"fields\":{\"keyword\":{\"type\":\"keyword\",\"ignore_above\":256}}},\n" + + " \"" + + AbstractConnector.VERSION_FIELD + + "\" : {\"type\": \"keyword\"},\n" + + " \"" + + AbstractConnector.DESCRIPTION_FIELD + + "\" : {\"type\": \"text\"},\n" + + " \"" + + AbstractConnector.PROTOCOL_FIELD + + "\" : {\"type\": \"keyword\"},\n" + + " \"" + + AbstractConnector.PARAMETERS_FIELD + + "\" : {\"type\": \"flat_object\"},\n" + + " \"" + + AbstractConnector.CREDENTIAL_FIELD + + "\" : {\"type\": \"flat_object\"},\n" + + " \"" + + AbstractConnector.ACTIONS_FIELD + + "\" : {\"type\": \"flat_object\"}\n"; - public static final String ML_MODEL_INDEX_MAPPING = "{\n" - + " \"_meta\": {\"schema_version\": " - + ML_MODEL_INDEX_SCHEMA_VERSION - + "},\n" - + " \"properties\": {\n" - + " \"" - + MLModel.ALGORITHM_FIELD - + "\": {\"type\": \"keyword\"},\n" - + " \"" - + MLModel.MODEL_NAME_FIELD - + "\" : {\"type\":\"text\",\"fields\":{\"keyword\":{\"type\":\"keyword\",\"ignore_above\":256}}},\n" - + " \"" - + MLModel.OLD_MODEL_VERSION_FIELD - + "\" : {\"type\": \"long\"},\n" - + " \"" - + MLModel.MODEL_VERSION_FIELD - + "\" : {\"type\": \"keyword\"},\n" - + " \"" - + MLModel.MODEL_GROUP_ID_FIELD - + "\" : {\"type\": \"keyword\"},\n" - + " \"" - + MLModel.MODEL_CONTENT_FIELD - + "\" : {\"type\": \"binary\"},\n" - + " \"" - + MLModel.CHUNK_NUMBER_FIELD - + "\" : {\"type\": \"long\"},\n" - + " \"" - + MLModel.TOTAL_CHUNKS_FIELD - + "\" : {\"type\": \"long\"},\n" - + " \"" - + MLModel.MODEL_ID_FIELD - + "\" : {\"type\": \"keyword\"},\n" - + " \"" - + MLModel.DESCRIPTION_FIELD - + "\" : {\"type\": \"text\"},\n" - + " \"" - + MLModel.MODEL_FORMAT_FIELD - + "\" : {\"type\": \"keyword\"},\n" - + " \"" - + MLModel.MODEL_STATE_FIELD - + "\" : {\"type\": \"keyword\"},\n" - + " \"" - + MLModel.MODEL_CONTENT_SIZE_IN_BYTES_FIELD - + "\" : {\"type\": \"long\"},\n" - + " \"" - + MLModel.PLANNING_WORKER_NODE_COUNT_FIELD - + "\" : {\"type\": \"integer\"},\n" - + " \"" - + MLModel.CURRENT_WORKER_NODE_COUNT_FIELD - + "\" : {\"type\": \"integer\"},\n" - + " \"" - + MLModel.PLANNING_WORKER_NODES_FIELD - + "\": {\"type\": \"keyword\"},\n" - + " \"" - + MLModel.DEPLOY_TO_ALL_NODES_FIELD - + "\": {\"type\": \"boolean\"},\n" - + " \"" - + MLModel.IS_HIDDEN_FIELD - + "\": {\"type\": \"boolean\"},\n" - + " \"" - + MLModel.MODEL_CONFIG_FIELD - + "\" : {\"properties\":{\"" - + MODEL_TYPE_FIELD + "\":{\"type\":\"keyword\"},\"" - + EMBEDDING_DIMENSION_FIELD + "\":{\"type\":\"integer\"},\"" - + FRAMEWORK_TYPE_FIELD + "\":{\"type\":\"keyword\"},\"" - + POOLING_MODE_FIELD + "\":{\"type\":\"keyword\"},\"" - + NORMALIZE_RESULT_FIELD + "\":{\"type\":\"boolean\"},\"" - + MODEL_MAX_LENGTH_FIELD + "\":{\"type\":\"integer\"},\"" - + ALL_CONFIG_FIELD + "\":{\"type\":\"text\"}}},\n" - + " \"" - + MLModel.IS_ENABLED_FIELD - + "\" : {\"type\": \"boolean\"},\n" - + " \"" - + MLModel.IS_MODEL_CONTROLLER_ENABLED_FIELD - + "\" : {\"type\": \"boolean\"},\n" - + " \"" - + MLModel.MODEL_RATE_LIMITER_CONFIG_FIELD - + "\" : {\"type\": \"flat_object\"},\n" - + " \"" - + MLModel.MODEL_CONTENT_HASH_VALUE_FIELD - + "\" : {\"type\": \"keyword\"},\n" - + " \"" - + MLModel.AUTO_REDEPLOY_RETRY_TIMES_FIELD - + "\" : {\"type\": \"integer\"},\n" - + " \"" - + MLModel.CREATED_TIME_FIELD - + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" - + " \"" - + MLModel.LAST_UPDATED_TIME_FIELD - + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" - + " \"" - + MLModel.LAST_REGISTERED_TIME_FIELD - + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" - + " \"" - + MLModel.LAST_DEPLOYED_TIME_FIELD - + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" - + " \"" - + MLModel.LAST_UNDEPLOYED_TIME_FIELD - + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" - + " \"" - + MLModel.CONNECTOR_FIELD - + "\": {" + ML_CONNECTOR_INDEX_FIELDS + " }\n}," - + USER_FIELD_MAPPING - + " }\n" - + "}"; + public static final String ML_MODEL_INDEX_MAPPING = "{\n" + + " \"_meta\": {\"schema_version\": " + + ML_MODEL_INDEX_SCHEMA_VERSION + + "},\n" + + " \"properties\": {\n" + + " \"" + + MLModel.ALGORITHM_FIELD + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + MLModel.MODEL_NAME_FIELD + + "\" : {\"type\":\"text\",\"fields\":{\"keyword\":{\"type\":\"keyword\",\"ignore_above\":256}}},\n" + + " \"" + + MLModel.OLD_MODEL_VERSION_FIELD + + "\" : {\"type\": \"long\"},\n" + + " \"" + + MLModel.MODEL_VERSION_FIELD + + "\" : {\"type\": \"keyword\"},\n" + + " \"" + + MLModel.MODEL_GROUP_ID_FIELD + + "\" : {\"type\": \"keyword\"},\n" + + " \"" + + MLModel.MODEL_CONTENT_FIELD + + "\" : {\"type\": \"binary\"},\n" + + " \"" + + MLModel.CHUNK_NUMBER_FIELD + + "\" : {\"type\": \"long\"},\n" + + " \"" + + MLModel.TOTAL_CHUNKS_FIELD + + "\" : {\"type\": \"long\"},\n" + + " \"" + + MLModel.MODEL_ID_FIELD + + "\" : {\"type\": \"keyword\"},\n" + + " \"" + + MLModel.DESCRIPTION_FIELD + + "\" : {\"type\": \"text\"},\n" + + " \"" + + MLModel.MODEL_FORMAT_FIELD + + "\" : {\"type\": \"keyword\"},\n" + + " \"" + + MLModel.MODEL_STATE_FIELD + + "\" : {\"type\": \"keyword\"},\n" + + " \"" + + MLModel.MODEL_CONTENT_SIZE_IN_BYTES_FIELD + + "\" : {\"type\": \"long\"},\n" + + " \"" + + MLModel.PLANNING_WORKER_NODE_COUNT_FIELD + + "\" : {\"type\": \"integer\"},\n" + + " \"" + + MLModel.CURRENT_WORKER_NODE_COUNT_FIELD + + "\" : {\"type\": \"integer\"},\n" + + " \"" + + MLModel.PLANNING_WORKER_NODES_FIELD + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + MLModel.DEPLOY_TO_ALL_NODES_FIELD + + "\": {\"type\": \"boolean\"},\n" + + " \"" + + MLModel.IS_HIDDEN_FIELD + + "\": {\"type\": \"boolean\"},\n" + + " \"" + + MLModel.MODEL_CONFIG_FIELD + + "\" : {\"properties\":{\"" + + MODEL_TYPE_FIELD + "\":{\"type\":\"keyword\"},\"" + + EMBEDDING_DIMENSION_FIELD + "\":{\"type\":\"integer\"},\"" + + FRAMEWORK_TYPE_FIELD + "\":{\"type\":\"keyword\"},\"" + + POOLING_MODE_FIELD + "\":{\"type\":\"keyword\"},\"" + + NORMALIZE_RESULT_FIELD + "\":{\"type\":\"boolean\"},\"" + + MODEL_MAX_LENGTH_FIELD + "\":{\"type\":\"integer\"},\"" + + ALL_CONFIG_FIELD + "\":{\"type\":\"text\"}}},\n" + + " \"" + + MLModel.IS_ENABLED_FIELD + + "\" : {\"type\": \"boolean\"},\n" + + " \"" + + MLModel.IS_CONTROLLER_ENABLED_FIELD + + "\" : {\"type\": \"boolean\"},\n" + + " \"" + + MLModel.RATE_LIMITER_FIELD + + "\" : {\"type\": \"flat_object\"},\n" + + " \"" + + MLModel.MODEL_CONTENT_HASH_VALUE_FIELD + + "\" : {\"type\": \"keyword\"},\n" + + " \"" + + MLModel.AUTO_REDEPLOY_RETRY_TIMES_FIELD + + "\" : {\"type\": \"integer\"},\n" + + " \"" + + MLModel.CREATED_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + + " \"" + + MLModel.LAST_UPDATED_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + + " \"" + + MLModel.LAST_REGISTERED_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + + " \"" + + MLModel.LAST_DEPLOYED_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + + " \"" + + MLModel.LAST_UNDEPLOYED_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + + " \"" + + MLModel.CONNECTOR_FIELD + + "\": {" + ML_CONNECTOR_INDEX_FIELDS + " }\n}," + + USER_FIELD_MAPPING + + " }\n" + + "}"; - public static final String ML_TASK_INDEX_MAPPING = "{\n" - + " \"_meta\": {\"schema_version\": " - + ML_TASK_INDEX_SCHEMA_VERSION - + "},\n" - + " \"properties\": {\n" - + " \"" - + MLTask.MODEL_ID_FIELD - + "\": {\"type\": \"keyword\"},\n" - + " \"" - + MLTask.TASK_TYPE_FIELD - + "\": {\"type\": \"keyword\"},\n" - + " \"" - + MLTask.FUNCTION_NAME_FIELD - + "\": {\"type\": \"keyword\"},\n" - + " \"" - + MLTask.STATE_FIELD - + "\": {\"type\": \"keyword\"},\n" - + " \"" - + MLTask.INPUT_TYPE_FIELD - + "\": {\"type\": \"keyword\"},\n" - + " \"" - + MLTask.PROGRESS_FIELD - + "\": {\"type\": \"float\"},\n" - + " \"" - + MLTask.OUTPUT_INDEX_FIELD - + "\": {\"type\": \"keyword\"},\n" - + " \"" - + MLTask.WORKER_NODE_FIELD - + "\": {\"type\": \"keyword\"},\n" - + " \"" - + MLTask.CREATE_TIME_FIELD - + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" - + " \"" - + MLTask.LAST_UPDATE_TIME_FIELD - + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" - + " \"" - + MLTask.ERROR_FIELD - + "\": {\"type\": \"text\"},\n" - + " \"" - + MLTask.IS_ASYNC_TASK_FIELD - + "\" : {\"type\" : \"boolean\"}, \n" - + USER_FIELD_MAPPING - + " }\n" - + "}"; + public static final String ML_TASK_INDEX_MAPPING = "{\n" + + " \"_meta\": {\"schema_version\": " + + ML_TASK_INDEX_SCHEMA_VERSION + + "},\n" + + " \"properties\": {\n" + + " \"" + + MLTask.MODEL_ID_FIELD + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + MLTask.TASK_TYPE_FIELD + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + MLTask.FUNCTION_NAME_FIELD + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + MLTask.STATE_FIELD + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + MLTask.INPUT_TYPE_FIELD + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + MLTask.PROGRESS_FIELD + + "\": {\"type\": \"float\"},\n" + + " \"" + + MLTask.OUTPUT_INDEX_FIELD + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + MLTask.WORKER_NODE_FIELD + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + MLTask.CREATE_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + + " \"" + + MLTask.LAST_UPDATE_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + + " \"" + + MLTask.ERROR_FIELD + + "\": {\"type\": \"text\"},\n" + + " \"" + + MLTask.IS_ASYNC_TASK_FIELD + + "\" : {\"type\" : \"boolean\"}, \n" + + USER_FIELD_MAPPING + + " }\n" + + "}"; - public static final String ML_CONNECTOR_INDEX_MAPPING = "{\n" - + " \"_meta\": {\"schema_version\": " - + ML_CONNECTOR_SCHEMA_VERSION - + "},\n" - + ML_CONNECTOR_INDEX_FIELDS + ",\n" - + " \"" - + MLModelGroup.BACKEND_ROLES_FIELD - + "\": {\n" - + " \"type\": \"text\",\n" - + " \"fields\": {\n" - + " \"keyword\": {\n" - + " \"type\": \"keyword\",\n" - + " \"ignore_above\": 256\n" - + " }\n" - + " }\n" - + " },\n" - + " \"" - + MLModelGroup.ACCESS - + "\": {\n" - + " \"type\": \"keyword\"\n" - + " },\n" - + " \"" - + MLModelGroup.OWNER - + "\": {\n" - + " \"type\": \"nested\",\n" - + " \"properties\": {\n" - + " \"name\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\", \"ignore_above\":256}}},\n" - + " \"backend_roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n" - + " \"roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n" - + " \"custom_attribute_names\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}}\n" - + " }\n" - + " },\n" - + " \"" - + AbstractConnector.CREATED_TIME_FIELD - + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" - + " \"" - + AbstractConnector.LAST_UPDATED_TIME_FIELD - + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"}\n" - + " }\n" - + "}"; + public static final String ML_CONNECTOR_INDEX_MAPPING = "{\n" + + " \"_meta\": {\"schema_version\": " + + ML_CONNECTOR_SCHEMA_VERSION + + "},\n" + + ML_CONNECTOR_INDEX_FIELDS + ",\n" + + " \"" + + MLModelGroup.BACKEND_ROLES_FIELD + + "\": {\n" + + " \"type\": \"text\",\n" + + " \"fields\": {\n" + + " \"keyword\": {\n" + + " \"type\": \"keyword\",\n" + + " \"ignore_above\": 256\n" + + " }\n" + + " }\n" + + " },\n" + + " \"" + + MLModelGroup.ACCESS + + "\": {\n" + + " \"type\": \"keyword\"\n" + + " },\n" + + " \"" + + MLModelGroup.OWNER + + "\": {\n" + + " \"type\": \"nested\",\n" + + " \"properties\": {\n" + + " \"name\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\", \"ignore_above\":256}}},\n" + + " \"backend_roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n" + + " \"roles\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}},\n" + + " \"custom_attribute_names\": {\"type\":\"text\", \"fields\":{\"keyword\":{\"type\":\"keyword\"}}}\n" + + " }\n" + + " },\n" + + " \"" + + AbstractConnector.CREATED_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + + " \"" + + AbstractConnector.LAST_UPDATED_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"}\n" + + " }\n" + + "}"; + public static final String ML_CONFIG_INDEX_MAPPING = "{\n" + + " \"_meta\": {\"schema_version\": " + + ML_CONFIG_INDEX_SCHEMA_VERSION + + "},\n" + + " \"properties\": {\n" + + " \"" + + MASTER_KEY + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + CREATE_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"}\n" + + " }\n" + + "}"; - public static final String ML_CONFIG_INDEX_MAPPING = "{\n" - + " \"_meta\": {\"schema_version\": " - + ML_CONFIG_INDEX_SCHEMA_VERSION - + "},\n" - + " \"properties\": {\n" - + " \"" - + MASTER_KEY - + "\": {\"type\": \"keyword\"},\n" - + " \"" - + CREATE_TIME_FIELD - + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"}\n" - + " }\n" - + "}"; + public static final String ML_CONTROLLER_INDEX_MAPPING = "{\n" + + " \"_meta\": {\"schema_version\": " + + ML_CONTROLLER_INDEX_SCHEMA_VERSION + + "},\n" + + " \"properties\": {\n" + + " \"" + + MLController.USER_RATE_LIMITER + + "\" : {\"type\": \"flat_object\"}\n" + + " }\n" + + "}"; - public static final String ML_MODEL_CONTROLLER_INDEX_MAPPING = "{\n" - + " \"_meta\": {\"schema_version\": " - + ML_MODEL_CONTROLLER_INDEX_SCHEMA_VERSION - + "},\n" - + " \"properties\": {\n" - + " \"" - + MLModelController.USER_RATE_LIMITER_CONFIG - + "\" : {\"type\": \"flat_object\"}\n" - + " }\n" - + "}"; + public static final String ML_AGENT_INDEX_MAPPING = "{\n" + + " \"_meta\": {\"schema_version\": " + + ML_AGENT_INDEX_SCHEMA_VERSION + + "},\n" + + " \"properties\": {\n" + + " \"" + + MLAgent.AGENT_NAME_FIELD + + "\" : {\"type\":\"text\",\"fields\":{\"keyword\":{\"type\":\"keyword\",\"ignore_above\":256}}},\n" + + " \"" + + MLAgent.AGENT_TYPE_FIELD + + "\" : {\"type\":\"keyword\"},\n" + + " \"" + + MLAgent.DESCRIPTION_FIELD + + "\" : {\"type\": \"text\"},\n" + + " \"" + + MLAgent.LLM_FIELD + + "\" : {\"type\": \"flat_object\"},\n" + + " \"" + + MLAgent.TOOLS_FIELD + + "\" : {\"type\": \"flat_object\"},\n" + + " \"" + + MLAgent.PARAMETERS_FIELD + + "\" : {\"type\": \"flat_object\"},\n" + + " \"" + + MLAgent.MEMORY_FIELD + + "\" : {\"type\": \"flat_object\"},\n" + + " \"" + + MLAgent.CREATED_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + + " \"" + + MLAgent.LAST_UPDATED_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"}\n" + + " }\n" + + "}"; - public static final String ML_AGENT_INDEX_MAPPING = "{\n" - + " \"_meta\": {\"schema_version\": " - + ML_AGENT_INDEX_SCHEMA_VERSION - + "},\n" - + " \"properties\": {\n" - + " \"" - + MLAgent.AGENT_NAME_FIELD - + "\" : {\"type\":\"text\",\"fields\":{\"keyword\":{\"type\":\"keyword\",\"ignore_above\":256}}},\n" - + " \"" - + MLAgent.AGENT_TYPE_FIELD - + "\" : {\"type\":\"keyword\"},\n" - + " \"" - + MLAgent.DESCRIPTION_FIELD - + "\" : {\"type\": \"text\"},\n" - + " \"" - + MLAgent.LLM_FIELD - + "\" : {\"type\": \"flat_object\"},\n" - + " \"" - + MLAgent.TOOLS_FIELD - + "\" : {\"type\": \"flat_object\"},\n" - + " \"" - + MLAgent.PARAMETERS_FIELD - + "\" : {\"type\": \"flat_object\"},\n" - + " \"" - + MLAgent.MEMORY_FIELD - + "\" : {\"type\": \"flat_object\"},\n" - + " \"" - + MLAgent.CREATED_TIME_FIELD - + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" - + " \"" - + MLAgent.LAST_UPDATED_TIME_FIELD - + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"}\n" - + " }\n" - + "}"; + public static final String ML_MEMORY_META_INDEX_MAPPING = "{\n" + + " \"_meta\": {\n" + + " \"schema_version\": " + META_INDEX_SCHEMA_VERSION + "\n" + + " },\n" + + " \"properties\": {\n" + + " \"" + + META_NAME_FIELD + + "\": {\"type\": \"text\"},\n" + + " \"" + + META_CREATED_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + + " \"" + + META_UPDATED_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + + " \"" + + USER_FIELD + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + APPLICATION_TYPE_FIELD + + "\": {\"type\": \"keyword\"}\n" + + " }\n" + + "}"; - public static final String ML_MEMORY_META_INDEX_MAPPING = "{\n" - + " \"_meta\": {\n" - + " \"schema_version\": " + META_INDEX_SCHEMA_VERSION + "\n" - + " },\n" - + " \"properties\": {\n" - + " \"" - + META_NAME_FIELD - + "\": {\"type\": \"text\"},\n" - + " \"" - + META_CREATED_TIME_FIELD - + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" - + " \"" - + META_UPDATED_TIME_FIELD - + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" - + " \"" - + USER_FIELD - + "\": {\"type\": \"keyword\"},\n" - + " \"" - + APPLICATION_TYPE_FIELD - + "\": {\"type\": \"keyword\"}\n" - + " }\n" - + "}"; - - public static final String ML_MEMORY_MESSAGE_INDEX_MAPPING = "{\n" - + " \"_meta\": {\n" - + " \"schema_version\": " + INTERACTIONS_INDEX_SCHEMA_VERSION + "\n" - + " },\n" - + " \"properties\": {\n" - + " \"" - + INTERACTIONS_CONVERSATION_ID_FIELD - + "\": {\"type\": \"keyword\"},\n" - + " \"" - + INTERACTIONS_CREATE_TIME_FIELD - + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" - + " \"" - + INTERACTIONS_INPUT_FIELD - + "\": {\"type\": \"text\"},\n" - + " \"" - + INTERACTIONS_PROMPT_TEMPLATE_FIELD - + "\": {\"type\": \"text\"},\n" - + " \"" - + INTERACTIONS_RESPONSE_FIELD - + "\": {\"type\": \"text\"},\n" - + " \"" - + INTERACTIONS_ORIGIN_FIELD - + "\": {\"type\": \"keyword\"},\n" - + " \"" - + INTERACTIONS_ADDITIONAL_INFO_FIELD - + "\": {\"type\": \"flat_object\"},\n" - + " \"" - + PARENT_INTERACTIONS_ID_FIELD - + "\": {\"type\": \"keyword\"},\n" - + " \"" - + INTERACTIONS_TRACE_NUMBER_FIELD - + "\": {\"type\": \"long\"}\n" - + " }\n" - + "}"; + public static final String ML_MEMORY_MESSAGE_INDEX_MAPPING = "{\n" + + " \"_meta\": {\n" + + " \"schema_version\": " + INTERACTIONS_INDEX_SCHEMA_VERSION + "\n" + + " },\n" + + " \"properties\": {\n" + + " \"" + + INTERACTIONS_CONVERSATION_ID_FIELD + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + INTERACTIONS_CREATE_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + + " \"" + + INTERACTIONS_INPUT_FIELD + + "\": {\"type\": \"text\"},\n" + + " \"" + + INTERACTIONS_PROMPT_TEMPLATE_FIELD + + "\": {\"type\": \"text\"},\n" + + " \"" + + INTERACTIONS_RESPONSE_FIELD + + "\": {\"type\": \"text\"},\n" + + " \"" + + INTERACTIONS_ORIGIN_FIELD + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + INTERACTIONS_ADDITIONAL_INFO_FIELD + + "\": {\"type\": \"flat_object\"},\n" + + " \"" + + PARENT_INTERACTIONS_ID_FIELD + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + INTERACTIONS_TRACE_NUMBER_FIELD + + "\": {\"type\": \"long\"}\n" + + " }\n" + + "}"; } diff --git a/common/src/main/java/org/opensearch/ml/common/MLModel.java b/common/src/main/java/org/opensearch/ml/common/MLModel.java index 8a2b50d07f..cec2805891 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLModel.java +++ b/common/src/main/java/org/opensearch/ml/common/MLModel.java @@ -56,8 +56,8 @@ public class MLModel implements ToXContentObject { // Model level quota and throttling control public static final String IS_ENABLED_FIELD = "is_enabled"; - public static final String MODEL_RATE_LIMITER_CONFIG_FIELD = "model_rate_limiter_config"; - public static final String IS_MODEL_CONTROLLER_ENABLED_FIELD = "is_model_controller_enabled"; + public static final String RATE_LIMITER_FIELD = "rate_limiter"; + public static final String IS_CONTROLLER_ENABLED_FIELD = "is_controller_enabled"; public static final String MODEL_CONFIG_FIELD = "model_config"; public static final String CREATED_TIME_FIELD = "created_time"; public static final String LAST_UPDATED_TIME_FIELD = "last_updated_time"; @@ -100,8 +100,8 @@ public class MLModel implements ToXContentObject { private String modelContentHash; private MLModelConfig modelConfig; private Boolean isEnabled; - private Boolean isModelControllerEnabled; - private MLRateLimiter modelRateLimiterConfig; + private Boolean isControllerEnabled; + private MLRateLimiter rateLimiter; private Instant createdTime; private Instant lastUpdateTime; private Instant lastRegisteredTime; @@ -120,7 +120,8 @@ public class MLModel implements ToXContentObject { private String[] planningWorkerNodes; // plan to deploy model to these nodes private boolean deployToAllNodes; - //is domain manager creates any special hidden model in the cluster this status will be true. Otherwise, + // is domain manager creates any special hidden model in the cluster this status + // will be true. Otherwise, // False by default private Boolean isHidden; @Setter @@ -129,35 +130,35 @@ public class MLModel implements ToXContentObject { @Builder(toBuilder = true) public MLModel(String name, - String modelGroupId, - FunctionName algorithm, - String version, - String content, - User user, - String description, - MLModelFormat modelFormat, - MLModelState modelState, - Long modelContentSizeInBytes, - String modelContentHash, - Boolean isEnabled, - Boolean isModelControllerEnabled, - MLRateLimiter modelRateLimiterConfig, - MLModelConfig modelConfig, - Instant createdTime, - Instant lastUpdateTime, - Instant lastRegisteredTime, - Instant lastDeployedTime, - Instant lastUndeployedTime, - Integer autoRedeployRetryTimes, - String modelId, Integer chunkNumber, - Integer totalChunks, - Integer planningWorkerNodeCount, - Integer currentWorkerNodeCount, - String[] planningWorkerNodes, - boolean deployToAllNodes, - Boolean isHidden, - Connector connector, - String connectorId) { + String modelGroupId, + FunctionName algorithm, + String version, + String content, + User user, + String description, + MLModelFormat modelFormat, + MLModelState modelState, + Long modelContentSizeInBytes, + String modelContentHash, + Boolean isEnabled, + Boolean isControllerEnabled, + MLRateLimiter rateLimiter, + MLModelConfig modelConfig, + Instant createdTime, + Instant lastUpdateTime, + Instant lastRegisteredTime, + Instant lastDeployedTime, + Instant lastUndeployedTime, + Integer autoRedeployRetryTimes, + String modelId, Integer chunkNumber, + Integer totalChunks, + Integer planningWorkerNodeCount, + Integer currentWorkerNodeCount, + String[] planningWorkerNodes, + boolean deployToAllNodes, + Boolean isHidden, + Connector connector, + String connectorId) { this.name = name; this.modelGroupId = modelGroupId; this.algorithm = algorithm; @@ -170,8 +171,8 @@ public MLModel(String name, this.modelContentSizeInBytes = modelContentSizeInBytes; this.modelContentHash = modelContentHash; this.isEnabled = isEnabled; - this.isModelControllerEnabled = isModelControllerEnabled; - this.modelRateLimiterConfig = modelRateLimiterConfig; + this.isControllerEnabled = isControllerEnabled; + this.rateLimiter = rateLimiter; this.modelConfig = modelConfig; this.createdTime = createdTime; this.lastUpdateTime = lastUpdateTime; @@ -191,7 +192,7 @@ public MLModel(String name, this.connectorId = connectorId; } - public MLModel(StreamInput input) throws IOException{ + public MLModel(StreamInput input) throws IOException { name = input.readOptionalString(); algorithm = input.readEnum(FunctionName.class); version = input.readString(); @@ -219,9 +220,9 @@ public MLModel(StreamInput input) throws IOException{ } } isEnabled = input.readOptionalBoolean(); - isModelControllerEnabled = input.readOptionalBoolean(); + isControllerEnabled = input.readOptionalBoolean(); if (input.readBoolean()) { - modelRateLimiterConfig = new MLRateLimiter(input); + rateLimiter = new MLRateLimiter(input); } createdTime = input.readOptionalInstant(); lastUpdateTime = input.readOptionalInstant(); @@ -278,10 +279,10 @@ public void writeTo(StreamOutput out) throws IOException { out.writeBoolean(false); } out.writeOptionalBoolean(isEnabled); - out.writeOptionalBoolean(isModelControllerEnabled); - if (modelRateLimiterConfig != null) { + out.writeOptionalBoolean(isControllerEnabled); + if (rateLimiter != null) { out.writeBoolean(true); - modelRateLimiterConfig.writeTo(out); + rateLimiter.writeTo(out); } else { out.writeBoolean(false); } @@ -351,11 +352,11 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par if (isEnabled != null) { builder.field(IS_ENABLED_FIELD, isEnabled); } - if (isModelControllerEnabled != null) { - builder.field(IS_MODEL_CONTROLLER_ENABLED_FIELD, isModelControllerEnabled); + if (isControllerEnabled != null) { + builder.field(IS_CONTROLLER_ENABLED_FIELD, isControllerEnabled); } - if (modelRateLimiterConfig != null) { - builder.field(MODEL_RATE_LIMITER_CONFIG_FIELD, modelRateLimiterConfig); + if (rateLimiter != null) { + builder.field(RATE_LIMITER_FIELD, rateLimiter); } if (createdTime != null) { builder.field(CREATED_TIME_FIELD, createdTime.toEpochMilli()); @@ -426,8 +427,8 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws String modelContentHash = null; MLModelConfig modelConfig = null; Boolean isEnabled = null; - Boolean isModelControllerEnabled = null; - MLRateLimiter modelRateLimiterConfig = null; + Boolean isControllerEnabled = null; + MLRateLimiter rateLimiter = null; Instant createdTime = null; Instant lastUpdateTime = null; Instant lastUploadedTime = null; @@ -516,11 +517,11 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws case IS_ENABLED_FIELD: isEnabled = parser.booleanValue(); break; - case IS_MODEL_CONTROLLER_ENABLED_FIELD: - isModelControllerEnabled = parser.booleanValue(); + case IS_CONTROLLER_ENABLED_FIELD: + isControllerEnabled = parser.booleanValue(); break; - case MODEL_RATE_LIMITER_CONFIG_FIELD: - modelRateLimiterConfig = MLRateLimiter.parse(parser); + case RATE_LIMITER_FIELD: + rateLimiter = MLRateLimiter.parse(parser); break; case PLANNING_WORKER_NODE_COUNT_FIELD: planningWorkerNodeCount = parser.intValue(); @@ -589,13 +590,13 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws .modelContentHash(modelContentHash) .modelConfig(modelConfig) .isEnabled(isEnabled) - .isModelControllerEnabled(isModelControllerEnabled) - .modelRateLimiterConfig(modelRateLimiterConfig) + .isControllerEnabled(isControllerEnabled) + .rateLimiter(rateLimiter) .createdTime(createdTime) .lastUpdateTime(lastUpdateTime) - .lastRegisteredTime(lastRegisteredTime == null? lastUploadedTime : lastRegisteredTime) - .lastDeployedTime(lastDeployedTime == null? lastLoadedTime : lastDeployedTime) - .lastUndeployedTime(lastUndeployedTime == null? lastUnloadedTime : lastUndeployedTime) + .lastRegisteredTime(lastRegisteredTime == null ? lastUploadedTime : lastRegisteredTime) + .lastDeployedTime(lastDeployedTime == null ? lastLoadedTime : lastDeployedTime) + .lastUndeployedTime(lastUndeployedTime == null ? lastUnloadedTime : lastUndeployedTime) .modelId(modelId) .autoRedeployRetryTimes(autoRedeployRetryTimes) .chunkNumber(chunkNumber) diff --git a/common/src/main/java/org/opensearch/ml/common/controller/MLModelController.java b/common/src/main/java/org/opensearch/ml/common/controller/MLController.java similarity index 58% rename from common/src/main/java/org/opensearch/ml/common/controller/MLModelController.java rename to common/src/main/java/org/opensearch/ml/common/controller/MLController.java index 269a875f57..f356cad477 100644 --- a/common/src/main/java/org/opensearch/ml/common/controller/MLModelController.java +++ b/common/src/main/java/org/opensearch/ml/common/controller/MLController.java @@ -29,25 +29,26 @@ import static org.opensearch.ml.common.utils.StringUtils.getParameterMap; @Data -public class MLModelController implements ToXContentObject, Writeable { +public class MLController implements ToXContentObject, Writeable { public static final String MODEL_ID_FIELD = "model_id"; // mandatory - public static final String USER_RATE_LIMITER_CONFIG = "user_rate_limiter_config"; + public static final String USER_RATE_LIMITER = "user_rate_limiter"; @Getter private String modelId; - // The String is the username field where the MLRateLimiter is its corresponding rate limiter config. - private Map userRateLimiterConfig; + // The String is the username field where the MLRateLimiter is its corresponding + // rate limiter config. + private Map userRateLimiter; @Builder(toBuilder = true) - public MLModelController(String modelId, Map userRateLimiterConfig) { + public MLController(String modelId, Map userRateLimiter) { this.modelId = modelId; - this.userRateLimiterConfig = userRateLimiterConfig; + this.userRateLimiter = userRateLimiter; } - public static MLModelController parse(XContentParser parser) throws IOException { + public static MLController parse(XContentParser parser) throws IOException { String modelId = null; - Map userRateLimiterConfig = new HashMap<>(); + Map userRateLimiter = new HashMap<>(); ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -58,15 +59,16 @@ public static MLModelController parse(XContentParser parser) throws IOException case MODEL_ID_FIELD: modelId = parser.text(); break; - case USER_RATE_LIMITER_CONFIG: - Map userRateLimiterConfigStringMap = getParameterMap(parser.map()); - userRateLimiterConfigStringMap.forEach((user, rateLimiterString) -> { + case USER_RATE_LIMITER: + Map userRateLimiterStringMap = getParameterMap(parser.map()); + userRateLimiterStringMap.forEach((user, rateLimiterString) -> { try { - XContentParser rateLimiterParser = XContentType.JSON.xContent().createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, rateLimiterString); + XContentParser rateLimiterParser = XContentType.JSON.xContent().createParser( + NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, rateLimiterString); rateLimiterParser.nextToken(); MLRateLimiter rateLimiter = MLRateLimiter.parse(rateLimiterParser); if (!rateLimiter.isEmpty()) { - userRateLimiterConfig.put(user, rateLimiter); + userRateLimiter.put(user, rateLimiter); } } catch (IOException e) { throw new RuntimeException(e); @@ -79,22 +81,23 @@ public static MLModelController parse(XContentParser parser) throws IOException } } // Model ID can only be set through RestRequest. - return new MLModelController(modelId, userRateLimiterConfig); + return new MLController(modelId, userRateLimiter); } - public MLModelController(StreamInput in) throws IOException{ + public MLController(StreamInput in) throws IOException { modelId = in.readString(); if (in.readBoolean()) { - userRateLimiterConfig = in.readMap(StreamInput::readString, MLRateLimiter::new); + userRateLimiter = in.readMap(StreamInput::readString, MLRateLimiter::new); } } @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(modelId); - if (userRateLimiterConfig != null) { + if (userRateLimiter != null) { out.writeBoolean(true); - out.writeMap(userRateLimiterConfig, StreamOutput::writeString, (streamOutput, rateLimiter) -> rateLimiter.writeTo(streamOutput)); + out.writeMap(userRateLimiter, StreamOutput::writeString, + (streamOutput, rateLimiter) -> rateLimiter.writeTo(streamOutput)); } else { out.writeBoolean(false); } @@ -104,28 +107,28 @@ public void writeTo(StreamOutput out) throws IOException { public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { builder.startObject(); builder.field(MODEL_ID_FIELD, modelId); - if (userRateLimiterConfig != null) { - builder.field(USER_RATE_LIMITER_CONFIG, userRateLimiterConfig); + if (userRateLimiter != null) { + builder.field(USER_RATE_LIMITER, userRateLimiter); } builder.endObject(); return builder; } - /** - * Checks if a deployment is required after updating the MLModelController. + * Checks if a deployment is required after updating the MLController. * - * @param updateContent The updated MLModelController object. + * @param updateContent The updated MLController object. * @return True if a deployment is required, false otherwise. */ - public boolean isDeployRequiredAfterUpdate(MLModelController updateContent) { - if (updateContent != null && updateContent.getUserRateLimiterConfig() != null && !updateContent.getUserRateLimiterConfig().isEmpty()) { - Map updateUserRateLimiterConfig = updateContent.getUserRateLimiterConfig(); - for (Map.Entry entry : updateUserRateLimiterConfig.entrySet()) { + public boolean isDeployRequiredAfterUpdate(MLController updateContent) { + if (updateContent != null && updateContent.getUserRateLimiter() != null + && !updateContent.getUserRateLimiter().isEmpty()) { + Map updateUserRateLimiter = updateContent.getUserRateLimiter(); + for (Map.Entry entry : updateUserRateLimiter.entrySet()) { String newUser = entry.getKey(); MLRateLimiter newRateLimiter = entry.getValue(); - if (this.userRateLimiterConfig.containsKey(newUser)) { - MLRateLimiter oldRateLimiter = this.userRateLimiterConfig.get(newUser); + if (this.userRateLimiter.containsKey(newUser)) { + MLRateLimiter oldRateLimiter = this.userRateLimiter.get(newUser); if (MLRateLimiter.isDeployRequiredAfterUpdate(oldRateLimiter, newRateLimiter)) { return true; } @@ -139,16 +142,16 @@ public boolean isDeployRequiredAfterUpdate(MLModelController updateContent) { return false; } - public void update(MLModelController updateContent) { - Map updateUserRateLimiterConfig = updateContent.getUserRateLimiterConfig(); - if (updateUserRateLimiterConfig != null && !updateUserRateLimiterConfig.isEmpty()) { - updateUserRateLimiterConfig.forEach((user, rateLimiter) -> { + public void update(MLController updateContent) { + Map updateUserRateLimiter = updateContent.getUserRateLimiter(); + if (updateUserRateLimiter != null && !updateUserRateLimiter.isEmpty()) { + updateUserRateLimiter.forEach((user, rateLimiter) -> { // rateLimiter can't be null due to parsing exception - if (this.userRateLimiterConfig.containsKey(user)) { - this.userRateLimiterConfig.get(user).update(rateLimiter); - } else { - this.userRateLimiterConfig.put(user, rateLimiter); - } + if (this.userRateLimiter.containsKey(user)) { + this.userRateLimiter.get(user).update(rateLimiter); + } else { + this.userRateLimiter.put(user, rateLimiter); + } }); } } diff --git a/common/src/main/java/org/opensearch/ml/common/controller/MLRateLimiter.java b/common/src/main/java/org/opensearch/ml/common/controller/MLRateLimiter.java index c132392708..d5906a29ed 100644 --- a/common/src/main/java/org/opensearch/ml/common/controller/MLRateLimiter.java +++ b/common/src/main/java/org/opensearch/ml/common/controller/MLRateLimiter.java @@ -25,21 +25,21 @@ @Setter @Getter public class MLRateLimiter implements ToXContentObject, Writeable { - public static final String RATE_LIMIT_NUMBER_FIELD = "rate_limit_number"; - public static final String RATE_LIMIT_UNIT_FIELD = "rate_limit_unit"; + public static final String LIMIT_FIELD = "limit"; + public static final String UNIT_FIELD = "unit"; - private String rateLimitNumber; - private TimeUnit rateLimitUnit; + private String limit; + private TimeUnit unit; @Builder(toBuilder = true) - public MLRateLimiter(String rateLimitNumber, TimeUnit rateLimitUnit) { - this.rateLimitNumber = rateLimitNumber; - this.rateLimitUnit = rateLimitUnit; + public MLRateLimiter(String limit, TimeUnit unit) { + this.limit = limit; + this.unit = unit; } public static MLRateLimiter parse(XContentParser parser) throws IOException { - String rateLimitNumber = null; - TimeUnit rateLimitUnit = null; + String limit = null; + TimeUnit unit = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -47,33 +47,33 @@ public static MLRateLimiter parse(XContentParser parser) throws IOException { parser.nextToken(); switch (fieldName) { - case RATE_LIMIT_NUMBER_FIELD: - rateLimitNumber = parser.text(); + case LIMIT_FIELD: + limit = parser.text(); break; - case RATE_LIMIT_UNIT_FIELD: - rateLimitUnit = TimeUnit.valueOf(parser.text()); + case UNIT_FIELD: + unit = TimeUnit.valueOf(parser.text()); break; default: parser.skipChildren(); break; } } - return new MLRateLimiter(rateLimitNumber, rateLimitUnit); + return new MLRateLimiter(limit, unit); } - public MLRateLimiter(StreamInput in) throws IOException{ - this.rateLimitNumber = in.readOptionalString(); + public MLRateLimiter(StreamInput in) throws IOException { + this.limit = in.readOptionalString(); if (in.readBoolean()) { - this.rateLimitUnit = in.readEnum(TimeUnit.class); + this.unit = in.readEnum(TimeUnit.class); } } @Override public void writeTo(StreamOutput out) throws IOException { - out.writeOptionalString(rateLimitNumber); - if (rateLimitUnit != null) { + out.writeOptionalString(limit); + if (unit != null) { out.writeBoolean(true); - out.writeEnum(rateLimitUnit); + out.writeEnum(unit); } else { out.writeBoolean(false); } @@ -82,22 +82,22 @@ public void writeTo(StreamOutput out) throws IOException { @Override public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { builder.startObject(); - if (rateLimitNumber != null) { - builder.field(RATE_LIMIT_NUMBER_FIELD, rateLimitNumber); + if (limit != null) { + builder.field(LIMIT_FIELD, limit); } - if (rateLimitUnit != null) { - builder.field(RATE_LIMIT_UNIT_FIELD, rateLimitUnit); + if (unit != null) { + builder.field(UNIT_FIELD, unit); } builder.endObject(); return builder; } public void update(MLRateLimiter updateContent) { - if (updateContent.getRateLimitNumber() != null) { - this.rateLimitNumber = updateContent.getRateLimitNumber(); + if (updateContent.getLimit() != null) { + this.limit = updateContent.getLimit(); } - if (updateContent.getRateLimitUnit() != null) { - this.rateLimitUnit = updateContent.getRateLimitUnit(); + if (updateContent.getUnit() != null) { + this.unit = updateContent.getUnit(); } } @@ -111,11 +111,13 @@ public static MLRateLimiter update(MLRateLimiter rateLimiter, MLRateLimiter upda } /** - * Checks the validity of this incoming update before performing an update operation. - * A valid update indicates the corresponding index will be updated with the current MLRateLimiter config and the update content + * Checks the validity of this incoming update before performing an update + * operation. + * A valid update indicates the corresponding index will be updated with the + * current MLRateLimiter config and the update content * - * @param rateLimiter The existing rate limiter. - * @param updateContent The update content. + * @param rateLimiter The existing rate limiter. + * @param updateContent The update content. * @return true if the update is valid, false otherwise. */ public static boolean updateValidityPreCheck(MLRateLimiter rateLimiter, MLRateLimiter updateContent) { @@ -125,15 +127,19 @@ public static boolean updateValidityPreCheck(MLRateLimiter rateLimiter, MLRateLi return true; } else if (updateContent.isEmpty()) { return false; - } else return (!Objects.equals(updateContent.getRateLimitNumber(), rateLimiter.getRateLimitNumber()) && updateContent.getRateLimitNumber() != null) - || (!Objects.equals(updateContent.getRateLimitUnit(), rateLimiter.getRateLimitUnit()) && updateContent.getRateLimitUnit() != null); + } else + return (!Objects.equals(updateContent.getLimit(), rateLimiter.getLimit()) + && updateContent.getLimit() != null) + || (!Objects.equals(updateContent.getUnit(), rateLimiter.getUnit()) + && updateContent.getUnit() != null); } /** - * Checks if we need to deploy this update into ML Cache (if model is deployed) after performing this update operation. + * Checks if we need to deploy this update into ML Cache (if model is deployed) + * after performing this update operation. * - * @param rateLimiter The existing rate limiter. - * @param updateContent The update content. + * @param rateLimiter The existing rate limiter. + * @param updateContent The update content. * @return true if the update is valid, false otherwise. */ public static boolean isDeployRequiredAfterUpdate(MLRateLimiter rateLimiter, MLRateLimiter updateContent) { @@ -141,16 +147,16 @@ public static boolean isDeployRequiredAfterUpdate(MLRateLimiter rateLimiter, MLR return false; } else { return updateContent.isValid() - || (rateLimiter.getRateLimitUnit() != null && updateContent.getRateLimitNumber() != null) - || (rateLimiter.getRateLimitNumber() != null && updateContent.getRateLimitUnit() != null); + || (rateLimiter.getUnit() != null && updateContent.getLimit() != null) + || (rateLimiter.getLimit() != null && updateContent.getUnit() != null); } } public boolean isValid() { - return (this.rateLimitUnit != null && this.rateLimitNumber != null); + return (this.unit != null && this.limit != null); } public boolean isEmpty() { - return (this.rateLimitUnit == null && this.rateLimitNumber == null); + return (this.unit == null && this.limit == null); } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLModelControllerDeleteAction.java b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLControllerDeleteAction.java similarity index 55% rename from common/src/main/java/org/opensearch/ml/common/transport/controller/MLModelControllerDeleteAction.java rename to common/src/main/java/org/opensearch/ml/common/transport/controller/MLControllerDeleteAction.java index 2e44fffa5c..7924245b26 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLModelControllerDeleteAction.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLControllerDeleteAction.java @@ -8,9 +8,11 @@ import org.opensearch.action.ActionType; import org.opensearch.action.delete.DeleteResponse; -public class MLModelControllerDeleteAction extends ActionType { - public static final MLModelControllerDeleteAction INSTANCE = new MLModelControllerDeleteAction(); +public class MLControllerDeleteAction extends ActionType { + public static final MLControllerDeleteAction INSTANCE = new MLControllerDeleteAction(); public static final String NAME = "cluster:admin/opensearch/ml/controllers/delete"; - private MLModelControllerDeleteAction() { super(NAME, DeleteResponse::new);} + private MLControllerDeleteAction() { + super(NAME, DeleteResponse::new); + } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLModelControllerDeleteRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLControllerDeleteRequest.java similarity index 72% rename from common/src/main/java/org/opensearch/ml/common/transport/controller/MLModelControllerDeleteRequest.java rename to common/src/main/java/org/opensearch/ml/common/transport/controller/MLControllerDeleteRequest.java index d7709d808d..8fdb8bc564 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLModelControllerDeleteRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLControllerDeleteRequest.java @@ -21,16 +21,16 @@ import static org.opensearch.action.ValidateActions.addValidationError; -public class MLModelControllerDeleteRequest extends ActionRequest { +public class MLControllerDeleteRequest extends ActionRequest { @Getter String modelId; @Builder - public MLModelControllerDeleteRequest(String modelId) { + public MLControllerDeleteRequest(String modelId) { this.modelId = modelId; } - public MLModelControllerDeleteRequest(StreamInput input) throws IOException { + public MLControllerDeleteRequest(StreamInput input) throws IOException { super(input); this.modelId = input.readString(); } @@ -52,19 +52,19 @@ public ActionRequestValidationException validate() { return exception; } - public static MLModelControllerDeleteRequest fromActionRequest(ActionRequest actionRequest) { - if (actionRequest instanceof MLModelControllerDeleteRequest) { - return (MLModelControllerDeleteRequest)actionRequest; + public static MLControllerDeleteRequest fromActionRequest(ActionRequest actionRequest) { + if (actionRequest instanceof MLControllerDeleteRequest) { + return (MLControllerDeleteRequest) actionRequest; } try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionRequest.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { - return new MLModelControllerDeleteRequest(input); + return new MLControllerDeleteRequest(input); } } catch (IOException e) { - throw new UncheckedIOException("failed to parse ActionRequest into MLModelControllerDeleteRequest", e); + throw new UncheckedIOException("failed to parse ActionRequest into MLControllerDeleteRequest", e); } } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLControllerGetAction.java b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLControllerGetAction.java new file mode 100644 index 0000000000..cc5b511b2b --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLControllerGetAction.java @@ -0,0 +1,17 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.controller; + +import org.opensearch.action.ActionType; + +public class MLControllerGetAction extends ActionType { + public static final MLControllerGetAction INSTANCE = new MLControllerGetAction(); + public static final String NAME = "cluster:admin/opensearch/ml/controllers/get"; + + private MLControllerGetAction() { + super(NAME, MLControllerGetResponse::new); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLModelControllerGetRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLControllerGetRequest.java similarity index 75% rename from common/src/main/java/org/opensearch/ml/common/transport/controller/MLModelControllerGetRequest.java rename to common/src/main/java/org/opensearch/ml/common/transport/controller/MLControllerGetRequest.java index d46afd93b3..86754c1732 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLModelControllerGetRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLControllerGetRequest.java @@ -27,18 +27,18 @@ @Getter @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) @ToString -public class MLModelControllerGetRequest extends ActionRequest { - +public class MLControllerGetRequest extends ActionRequest { + String modelId; boolean returnContent; @Builder - public MLModelControllerGetRequest(String modelId, boolean returnContent) { + public MLControllerGetRequest(String modelId, boolean returnContent) { this.modelId = modelId; this.returnContent = returnContent; } - public MLModelControllerGetRequest(StreamInput in) throws IOException { + public MLControllerGetRequest(StreamInput in) throws IOException { super(in); this.modelId = in.readString(); this.returnContent = in.readBoolean(); @@ -62,19 +62,19 @@ public ActionRequestValidationException validate() { return exception; } - public static MLModelControllerGetRequest fromActionRequest(ActionRequest actionRequest) { - if (actionRequest instanceof MLModelControllerGetRequest) { - return (MLModelControllerGetRequest) actionRequest; + public static MLControllerGetRequest fromActionRequest(ActionRequest actionRequest) { + if (actionRequest instanceof MLControllerGetRequest) { + return (MLControllerGetRequest) actionRequest; } try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionRequest.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { - return new MLModelControllerGetRequest(input); + return new MLControllerGetRequest(input); } } catch (IOException e) { - throw new UncheckedIOException("failed to parse ActionRequest into MLModelControllerGetRequest", e); + throw new UncheckedIOException("failed to parse ActionRequest into MLControllerGetRequest", e); } } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLModelControllerGetResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLControllerGetResponse.java similarity index 56% rename from common/src/main/java/org/opensearch/ml/common/transport/controller/MLModelControllerGetResponse.java rename to common/src/main/java/org/opensearch/ml/common/transport/controller/MLControllerGetResponse.java index 6c5fe8db09..7c07e91a1f 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLModelControllerGetResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLControllerGetResponse.java @@ -15,51 +15,51 @@ import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.ml.common.controller.MLModelController; +import org.opensearch.ml.common.controller.MLController; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.UncheckedIOException; -public class MLModelControllerGetResponse extends ActionResponse implements ToXContentObject { +public class MLControllerGetResponse extends ActionResponse implements ToXContentObject { @Getter - MLModelController modelController; + MLController controller; @Builder - public MLModelControllerGetResponse(MLModelController modelController) { - this.modelController = modelController; + public MLControllerGetResponse(MLController controller) { + this.controller = controller; } - public MLModelControllerGetResponse(StreamInput in) throws IOException { + public MLControllerGetResponse(StreamInput in) throws IOException { super(in); - modelController = new MLModelController(in); + controller = new MLController(in); } @Override - public void writeTo(StreamOutput out) throws IOException{ - modelController.writeTo(out); + public void writeTo(StreamOutput out) throws IOException { + controller.writeTo(out); } @Override public XContentBuilder toXContent(XContentBuilder xContentBuilder, ToXContent.Params params) throws IOException { - return modelController.toXContent(xContentBuilder, params); + return controller.toXContent(xContentBuilder, params); } - public static MLModelControllerGetResponse fromActionResponse(ActionResponse actionResponse) { - if (actionResponse instanceof MLModelControllerGetResponse) { - return (MLModelControllerGetResponse) actionResponse; + public static MLControllerGetResponse fromActionResponse(ActionResponse actionResponse) { + if (actionResponse instanceof MLControllerGetResponse) { + return (MLControllerGetResponse) actionResponse; } try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionResponse.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { - return new MLModelControllerGetResponse(input); + return new MLControllerGetResponse(input); } } catch (IOException e) { - throw new UncheckedIOException("failed to parse ActionResponse into MLModelControllerGetResponse", e); + throw new UncheckedIOException("failed to parse ActionResponse into MLControllerGetResponse", e); } } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLCreateControllerAction.java b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLCreateControllerAction.java new file mode 100644 index 0000000000..ba8dda6696 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLCreateControllerAction.java @@ -0,0 +1,17 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.ml.common.transport.controller; + +import org.opensearch.action.ActionType; + +public class MLCreateControllerAction extends ActionType { + public static final MLCreateControllerAction INSTANCE = new MLCreateControllerAction(); + public static final String NAME = "cluster:admin/opensearch/ml/controllers/create"; + + private MLCreateControllerAction() { + super(NAME, MLCreateControllerResponse::new); + } + +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLCreateModelControllerRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLCreateControllerRequest.java similarity index 60% rename from common/src/main/java/org/opensearch/ml/common/transport/controller/MLCreateModelControllerRequest.java rename to common/src/main/java/org/opensearch/ml/common/transport/controller/MLCreateControllerRequest.java index 136a3bd373..efea44da24 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLCreateModelControllerRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLCreateControllerRequest.java @@ -15,7 +15,7 @@ import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.ml.common.controller.MLModelController; +import org.opensearch.ml.common.controller.MLController; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; @@ -27,47 +27,46 @@ @Getter @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) @ToString -public class MLCreateModelControllerRequest extends ActionRequest { - private MLModelController modelControllerInput; +public class MLCreateControllerRequest extends ActionRequest { + private MLController controllerInput; @Builder - public MLCreateModelControllerRequest(MLModelController modelControllerInput) { - this.modelControllerInput = modelControllerInput; + public MLCreateControllerRequest(MLController controllerInput) { + this.controllerInput = controllerInput; } - public MLCreateModelControllerRequest(StreamInput in) throws IOException { + public MLCreateControllerRequest(StreamInput in) throws IOException { super(in); - this.modelControllerInput = new MLModelController(in); + this.controllerInput = new MLController(in); } @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); - modelControllerInput.writeTo(out); + controllerInput.writeTo(out); } @Override public ActionRequestValidationException validate() { ActionRequestValidationException exception = null; - if (modelControllerInput == null) { + if (controllerInput == null) { exception = addValidationError("Model controller input can't be null", exception); - } + } return exception; } - public static MLCreateModelControllerRequest fromActionRequest(ActionRequest actionRequest) { - if (actionRequest instanceof MLCreateModelControllerRequest) { - return (MLCreateModelControllerRequest) actionRequest; + public static MLCreateControllerRequest fromActionRequest(ActionRequest actionRequest) { + if (actionRequest instanceof MLCreateControllerRequest) { + return (MLCreateControllerRequest) actionRequest; } try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionRequest.writeTo(osso); - try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) - { - return new MLCreateModelControllerRequest(input); - } + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLCreateControllerRequest(input); + } } catch (IOException e) { - throw new UncheckedIOException("Failed to parse ActionRequest into MLCreateModelControllerRequest", e); + throw new UncheckedIOException("Failed to parse ActionRequest into MLCreateControllerRequest", e); } } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLCreateModelControllerResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLCreateControllerResponse.java similarity index 73% rename from common/src/main/java/org/opensearch/ml/common/transport/controller/MLCreateModelControllerResponse.java rename to common/src/main/java/org/opensearch/ml/common/transport/controller/MLCreateControllerResponse.java index 531aa4daf7..592caf5d6b 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLCreateModelControllerResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLCreateControllerResponse.java @@ -20,7 +20,7 @@ import java.io.UncheckedIOException; @Getter -public class MLCreateModelControllerResponse extends ActionResponse implements ToXContentObject { +public class MLCreateControllerResponse extends ActionResponse implements ToXContentObject { public static final String MODEL_ID_FIELD = "model_id"; public static final String STATUS_FIELD = "status"; @@ -29,14 +29,14 @@ public class MLCreateModelControllerResponse extends ActionResponse implements T String modelId; String status; - public MLCreateModelControllerResponse(StreamInput in) throws IOException { + public MLCreateControllerResponse(StreamInput in) throws IOException { super(in); this.modelId = in.readString(); this.status = in.readString(); } @Builder - public MLCreateModelControllerResponse(String modelId, String status) { + public MLCreateControllerResponse(String modelId, String status) { this.modelId = modelId; this.status = status; } @@ -56,18 +56,18 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws return builder; } - public static MLCreateModelControllerResponse fromActionResponse(ActionResponse actionResponse) { - if (actionResponse instanceof MLCreateModelControllerResponse) { - return (MLCreateModelControllerResponse) actionResponse; + public static MLCreateControllerResponse fromActionResponse(ActionResponse actionResponse) { + if (actionResponse instanceof MLCreateControllerResponse) { + return (MLCreateControllerResponse) actionResponse; } try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionResponse.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { - return new MLCreateModelControllerResponse(input); + return new MLCreateControllerResponse(input); } } catch (IOException e) { - throw new UncheckedIOException("Failed to parse ActionResponse into MLCreateModelControllerResponse", e); + throw new UncheckedIOException("Failed to parse ActionResponse into MLCreateControllerResponse", e); } } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLCreateModelControllerAction.java b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLCreateModelControllerAction.java deleted file mode 100644 index 4e99704771..0000000000 --- a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLCreateModelControllerAction.java +++ /dev/null @@ -1,17 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ -package org.opensearch.ml.common.transport.controller; - -import org.opensearch.action.ActionType; - -public class MLCreateModelControllerAction extends ActionType{ - public static final MLCreateModelControllerAction INSTANCE = new MLCreateModelControllerAction(); - public static final String NAME = "cluster:admin/opensearch/ml/controllers/create"; - - private MLCreateModelControllerAction() { - super(NAME, MLCreateModelControllerResponse::new); - } - -} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLDeployModelControllerAction.java b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLDeployControllerAction.java similarity index 56% rename from common/src/main/java/org/opensearch/ml/common/transport/controller/MLDeployModelControllerAction.java rename to common/src/main/java/org/opensearch/ml/common/transport/controller/MLDeployControllerAction.java index 68b1094e0e..2c4a61c706 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLDeployModelControllerAction.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLDeployControllerAction.java @@ -8,9 +8,11 @@ import org.opensearch.action.ActionType; // This action will only be passively called when creating or updating a model controller when the model is deployed. -public class MLDeployModelControllerAction extends ActionType { - public static final MLDeployModelControllerAction INSTANCE = new MLDeployModelControllerAction(); +public class MLDeployControllerAction extends ActionType { + public static final MLDeployControllerAction INSTANCE = new MLDeployControllerAction(); public static final String NAME = "cluster:admin/opensearch/ml/controllers/deploy"; - private MLDeployModelControllerAction() { super(NAME, MLDeployModelControllerNodesResponse::new);} + private MLDeployControllerAction() { + super(NAME, MLDeployControllerNodesResponse::new); + } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLDeployControllerNodeRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLDeployControllerNodeRequest.java new file mode 100644 index 0000000000..0396fc3a16 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLDeployControllerNodeRequest.java @@ -0,0 +1,32 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.controller; + +import java.io.IOException; +import lombok.Getter; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.transport.TransportRequest; + +public class MLDeployControllerNodeRequest extends TransportRequest { + @Getter + private MLDeployControllerNodesRequest deployControllerNodesRequest; + + public MLDeployControllerNodeRequest(StreamInput in) throws IOException { + super(in); + this.deployControllerNodesRequest = new MLDeployControllerNodesRequest(in); + } + + public MLDeployControllerNodeRequest(MLDeployControllerNodesRequest request) { + this.deployControllerNodesRequest = request; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + deployControllerNodesRequest.writeTo(out); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLDeployModelControllerNodeResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLDeployControllerNodeResponse.java similarity index 50% rename from common/src/main/java/org/opensearch/ml/common/transport/controller/MLDeployModelControllerNodeResponse.java rename to common/src/main/java/org/opensearch/ml/common/transport/controller/MLDeployControllerNodeResponse.java index 9587a4e40b..7b038a4a21 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLDeployModelControllerNodeResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLDeployControllerNodeResponse.java @@ -19,32 +19,32 @@ @Getter @Log4j2 -public class MLDeployModelControllerNodeResponse extends BaseNodeResponse implements ToXContentFragment { - private Map modelControllerDeployStatus; +public class MLDeployControllerNodeResponse extends BaseNodeResponse implements ToXContentFragment { + private Map controllerDeployStatus; - public MLDeployModelControllerNodeResponse(DiscoveryNode node, Map modelControllerDeployStatus) { + public MLDeployControllerNodeResponse(DiscoveryNode node, Map controllerDeployStatus) { super(node); - this.modelControllerDeployStatus = modelControllerDeployStatus; + this.controllerDeployStatus = controllerDeployStatus; } - public MLDeployModelControllerNodeResponse(StreamInput in) throws IOException { + public MLDeployControllerNodeResponse(StreamInput in) throws IOException { super(in); if (in.readBoolean()) { - this.modelControllerDeployStatus = in.readMap(StreamInput::readString, StreamInput::readString); + this.controllerDeployStatus = in.readMap(StreamInput::readString, StreamInput::readString); } } - public static MLDeployModelControllerNodeResponse readStats(StreamInput in) throws IOException { - return new MLDeployModelControllerNodeResponse(in); + public static MLDeployControllerNodeResponse readStats(StreamInput in) throws IOException { + return new MLDeployControllerNodeResponse(in); } @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); - if (!isModelControllerDeployStatusEmpty()) { + if (!isControllerDeployStatusEmpty()) { out.writeBoolean(true); - out.writeMap(modelControllerDeployStatus, StreamOutput::writeString, StreamOutput::writeString); + out.writeMap(controllerDeployStatus, StreamOutput::writeString, StreamOutput::writeString); } else { out.writeBoolean(false); } @@ -52,8 +52,8 @@ public void writeTo(StreamOutput out) throws IOException { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject("stats"); - if (!isModelControllerDeployStatusEmpty()) { - for (Map.Entry stat : modelControllerDeployStatus.entrySet()) { + if (!isControllerDeployStatusEmpty()) { + for (Map.Entry stat : controllerDeployStatus.entrySet()) { builder.field(stat.getKey(), stat.getValue()); } } @@ -61,7 +61,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws return builder; } - public boolean isModelControllerDeployStatusEmpty() { - return modelControllerDeployStatus == null || modelControllerDeployStatus.isEmpty(); + public boolean isControllerDeployStatusEmpty() { + return controllerDeployStatus == null || controllerDeployStatus.isEmpty(); } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUndeployModelControllerNodesRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLDeployControllerNodesRequest.java similarity index 64% rename from common/src/main/java/org/opensearch/ml/common/transport/controller/MLUndeployModelControllerNodesRequest.java rename to common/src/main/java/org/opensearch/ml/common/transport/controller/MLDeployControllerNodesRequest.java index 19c638f872..1a70c53a90 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUndeployModelControllerNodesRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLDeployControllerNodesRequest.java @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.ml.common.transport.controller; +package org.opensearch.ml.common.transport.controller; import lombok.Getter; import org.opensearch.action.support.nodes.BaseNodesRequest; @@ -12,22 +12,22 @@ import org.opensearch.core.common.io.stream.StreamOutput; import java.io.IOException; -public class MLUndeployModelControllerNodesRequest extends BaseNodesRequest { +public class MLDeployControllerNodesRequest extends BaseNodesRequest { @Getter private String modelId; - public MLUndeployModelControllerNodesRequest(StreamInput in) throws IOException { + public MLDeployControllerNodesRequest(StreamInput in) throws IOException { super(in); this.modelId = in.readString(); } - public MLUndeployModelControllerNodesRequest(String[] nodeIds, String modelId) { + public MLDeployControllerNodesRequest(String[] nodeIds, String modelId) { super(nodeIds); this.modelId = modelId; } - public MLUndeployModelControllerNodesRequest(DiscoveryNode[] nodeIds, String modelId) { + public MLDeployControllerNodesRequest(DiscoveryNode[] nodeIds, String modelId) { super(nodeIds); this.modelId = modelId; } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUndeployModelControllerNodesResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLDeployControllerNodesResponse.java similarity index 59% rename from common/src/main/java/org/opensearch/ml/common/transport/controller/MLUndeployModelControllerNodesResponse.java rename to common/src/main/java/org/opensearch/ml/common/transport/controller/MLDeployControllerNodesResponse.java index 36f046d81f..50d60d1801 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUndeployModelControllerNodesResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLDeployControllerNodesResponse.java @@ -18,13 +18,16 @@ import java.io.IOException; import java.util.List; -public class MLUndeployModelControllerNodesResponse extends BaseNodesResponse implements ToXContentObject { +public class MLDeployControllerNodesResponse extends BaseNodesResponse + implements ToXContentObject { - public MLUndeployModelControllerNodesResponse(StreamInput in) throws IOException { - super(new ClusterName(in), in.readList(MLUndeployModelControllerNodeResponse::readStats), in.readList(FailedNodeException::new)); + public MLDeployControllerNodesResponse(StreamInput in) throws IOException { + super(new ClusterName(in), in.readList(MLDeployControllerNodeResponse::readStats), + in.readList(FailedNodeException::new)); } - public MLUndeployModelControllerNodesResponse(ClusterName clusterName, List nodes, List failures) { + public MLDeployControllerNodesResponse(ClusterName clusterName, List nodes, + List failures) { super(clusterName, nodes, failures); } @@ -34,13 +37,13 @@ public void writeTo(StreamOutput out) throws IOException { } @Override - public void writeNodesTo(StreamOutput out, List nodes) throws IOException { + public void writeNodesTo(StreamOutput out, List nodes) throws IOException { out.writeList(nodes); } @Override - public List readNodesFrom(StreamInput in) throws IOException { - return in.readList(MLUndeployModelControllerNodeResponse::readStats); + public List readNodesFrom(StreamInput in) throws IOException { + return in.readList(MLDeployControllerNodeResponse::readStats); } @Override @@ -48,8 +51,8 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par String nodeId; DiscoveryNode node; builder.startObject(); - for (MLUndeployModelControllerNodeResponse deployStats : getNodes()) { - if (!deployStats.isModelControllerUndeployStatusEmpty()) { + for (MLDeployControllerNodeResponse deployStats : getNodes()) { + if (!deployStats.isControllerDeployStatusEmpty()) { node = deployStats.getNode(); nodeId = node.getId(); builder.startObject(nodeId); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLDeployModelControllerNodeRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLDeployModelControllerNodeRequest.java deleted file mode 100644 index d11e488641..0000000000 --- a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLDeployModelControllerNodeRequest.java +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.ml.common.transport.controller; - -import java.io.IOException; -import lombok.Getter; -import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.transport.TransportRequest; - -public class MLDeployModelControllerNodeRequest extends TransportRequest { - @Getter - private MLDeployModelControllerNodesRequest deployModelControllerNodesRequest; - - public MLDeployModelControllerNodeRequest(StreamInput in) throws IOException { - super(in); - this.deployModelControllerNodesRequest = new MLDeployModelControllerNodesRequest(in); - } - - public MLDeployModelControllerNodeRequest(MLDeployModelControllerNodesRequest request) { - this.deployModelControllerNodesRequest = request; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - super.writeTo(out); - deployModelControllerNodesRequest.writeTo(out); - } -} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLModelControllerGetAction.java b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLModelControllerGetAction.java deleted file mode 100644 index bbae2ac7de..0000000000 --- a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLModelControllerGetAction.java +++ /dev/null @@ -1,15 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.ml.common.transport.controller; - -import org.opensearch.action.ActionType; - -public class MLModelControllerGetAction extends ActionType { - public static final MLModelControllerGetAction INSTANCE = new MLModelControllerGetAction(); - public static final String NAME = "cluster:admin/opensearch/ml/controllers/get"; - - private MLModelControllerGetAction() { super(NAME, MLModelControllerGetResponse::new);} -} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUndeployModelControllerAction.java b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUndeployControllerAction.java similarity index 54% rename from common/src/main/java/org/opensearch/ml/common/transport/controller/MLUndeployModelControllerAction.java rename to common/src/main/java/org/opensearch/ml/common/transport/controller/MLUndeployControllerAction.java index 3be1af7306..bfd2991fbf 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUndeployModelControllerAction.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUndeployControllerAction.java @@ -8,9 +8,11 @@ import org.opensearch.action.ActionType; // This action will only be passively called when deleting a model controller when the model is deployed. -public class MLUndeployModelControllerAction extends ActionType { - public static final MLUndeployModelControllerAction INSTANCE = new MLUndeployModelControllerAction(); +public class MLUndeployControllerAction extends ActionType { + public static final MLUndeployControllerAction INSTANCE = new MLUndeployControllerAction(); public static final String NAME = "cluster:admin/opensearch/ml/controllers/undeploy"; - private MLUndeployModelControllerAction() { super(NAME, MLUndeployModelControllerNodesResponse::new);} + private MLUndeployControllerAction() { + super(NAME, MLUndeployControllerNodesResponse::new); + } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUndeployControllerNodeRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUndeployControllerNodeRequest.java new file mode 100644 index 0000000000..dc4f6af75b --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUndeployControllerNodeRequest.java @@ -0,0 +1,33 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.controller; + +import java.io.IOException; +import lombok.Getter; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.transport.TransportRequest; + +public class MLUndeployControllerNodeRequest extends TransportRequest { + @Getter + private MLUndeployControllerNodesRequest undeployControllerNodesRequest; + + public MLUndeployControllerNodeRequest(StreamInput in) throws IOException { + super(in); + this.undeployControllerNodesRequest = new MLUndeployControllerNodesRequest(in); + } + + public MLUndeployControllerNodeRequest(MLUndeployControllerNodesRequest request) { + this.undeployControllerNodesRequest = request; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + undeployControllerNodesRequest.writeTo(out); + } + +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUndeployModelControllerNodeResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUndeployControllerNodeResponse.java similarity index 50% rename from common/src/main/java/org/opensearch/ml/common/transport/controller/MLUndeployModelControllerNodeResponse.java rename to common/src/main/java/org/opensearch/ml/common/transport/controller/MLUndeployControllerNodeResponse.java index bf4a1cb8a0..7438871caf 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUndeployModelControllerNodeResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUndeployControllerNodeResponse.java @@ -19,32 +19,32 @@ @Getter @Log4j2 -public class MLUndeployModelControllerNodeResponse extends BaseNodeResponse implements ToXContentFragment { - private Map modelControllerUndeployStatus; +public class MLUndeployControllerNodeResponse extends BaseNodeResponse implements ToXContentFragment { + private Map controllerUndeployStatus; - public MLUndeployModelControllerNodeResponse(DiscoveryNode node, Map modelControllerUndeployStatus) { + public MLUndeployControllerNodeResponse(DiscoveryNode node, Map controllerUndeployStatus) { super(node); - this.modelControllerUndeployStatus = modelControllerUndeployStatus; + this.controllerUndeployStatus = controllerUndeployStatus; } - public MLUndeployModelControllerNodeResponse(StreamInput in) throws IOException { + public MLUndeployControllerNodeResponse(StreamInput in) throws IOException { super(in); if (in.readBoolean()) { - this.modelControllerUndeployStatus = in.readMap(StreamInput::readString, StreamInput::readString); + this.controllerUndeployStatus = in.readMap(StreamInput::readString, StreamInput::readString); } } - public static MLUndeployModelControllerNodeResponse readStats(StreamInput in) throws IOException { - return new MLUndeployModelControllerNodeResponse(in); + public static MLUndeployControllerNodeResponse readStats(StreamInput in) throws IOException { + return new MLUndeployControllerNodeResponse(in); } @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); - if (!isModelControllerUndeployStatusEmpty()) { + if (!isControllerUndeployStatusEmpty()) { out.writeBoolean(true); - out.writeMap(modelControllerUndeployStatus, StreamOutput::writeString, StreamOutput::writeString); + out.writeMap(controllerUndeployStatus, StreamOutput::writeString, StreamOutput::writeString); } else { out.writeBoolean(false); } @@ -52,8 +52,8 @@ public void writeTo(StreamOutput out) throws IOException { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject("stats"); - if (!isModelControllerUndeployStatusEmpty()) { - for (Map.Entry stat : modelControllerUndeployStatus.entrySet()) { + if (!isControllerUndeployStatusEmpty()) { + for (Map.Entry stat : controllerUndeployStatus.entrySet()) { builder.field(stat.getKey(), stat.getValue()); } } @@ -61,7 +61,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws return builder; } - public boolean isModelControllerUndeployStatusEmpty() { - return modelControllerUndeployStatus == null || modelControllerUndeployStatus.isEmpty(); + public boolean isControllerUndeployStatusEmpty() { + return controllerUndeployStatus == null || controllerUndeployStatus.isEmpty(); } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLDeployModelControllerNodesRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUndeployControllerNodesRequest.java similarity index 63% rename from common/src/main/java/org/opensearch/ml/common/transport/controller/MLDeployModelControllerNodesRequest.java rename to common/src/main/java/org/opensearch/ml/common/transport/controller/MLUndeployControllerNodesRequest.java index ac399828f1..af9785dcee 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLDeployModelControllerNodesRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUndeployControllerNodesRequest.java @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.ml.common.transport.controller; +package org.opensearch.ml.common.transport.controller; import lombok.Getter; import org.opensearch.action.support.nodes.BaseNodesRequest; @@ -12,22 +12,22 @@ import org.opensearch.core.common.io.stream.StreamOutput; import java.io.IOException; -public class MLDeployModelControllerNodesRequest extends BaseNodesRequest { +public class MLUndeployControllerNodesRequest extends BaseNodesRequest { @Getter private String modelId; - public MLDeployModelControllerNodesRequest(StreamInput in) throws IOException { + public MLUndeployControllerNodesRequest(StreamInput in) throws IOException { super(in); this.modelId = in.readString(); } - public MLDeployModelControllerNodesRequest(String[] nodeIds, String modelId) { + public MLUndeployControllerNodesRequest(String[] nodeIds, String modelId) { super(nodeIds); this.modelId = modelId; } - public MLDeployModelControllerNodesRequest(DiscoveryNode[] nodeIds, String modelId) { + public MLUndeployControllerNodesRequest(DiscoveryNode[] nodeIds, String modelId) { super(nodeIds); this.modelId = modelId; } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLDeployModelControllerNodesResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUndeployControllerNodesResponse.java similarity index 57% rename from common/src/main/java/org/opensearch/ml/common/transport/controller/MLDeployModelControllerNodesResponse.java rename to common/src/main/java/org/opensearch/ml/common/transport/controller/MLUndeployControllerNodesResponse.java index bbd27c7cca..11996955c9 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLDeployModelControllerNodesResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUndeployControllerNodesResponse.java @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.ml.common.transport.controller; +package org.opensearch.ml.common.transport.controller; import org.opensearch.action.FailedNodeException; import org.opensearch.action.support.nodes.BaseNodesResponse; @@ -18,13 +18,16 @@ import java.io.IOException; import java.util.List; -public class MLDeployModelControllerNodesResponse extends BaseNodesResponse implements ToXContentObject { +public class MLUndeployControllerNodesResponse extends BaseNodesResponse + implements ToXContentObject { - public MLDeployModelControllerNodesResponse(StreamInput in) throws IOException { - super(new ClusterName(in), in.readList(MLDeployModelControllerNodeResponse::readStats), in.readList(FailedNodeException::new)); + public MLUndeployControllerNodesResponse(StreamInput in) throws IOException { + super(new ClusterName(in), in.readList(MLUndeployControllerNodeResponse::readStats), + in.readList(FailedNodeException::new)); } - public MLDeployModelControllerNodesResponse(ClusterName clusterName, List nodes, List failures) { + public MLUndeployControllerNodesResponse(ClusterName clusterName, List nodes, + List failures) { super(clusterName, nodes, failures); } @@ -34,13 +37,13 @@ public void writeTo(StreamOutput out) throws IOException { } @Override - public void writeNodesTo(StreamOutput out, List nodes) throws IOException { + public void writeNodesTo(StreamOutput out, List nodes) throws IOException { out.writeList(nodes); } @Override - public List readNodesFrom(StreamInput in) throws IOException { - return in.readList(MLDeployModelControllerNodeResponse::readStats); + public List readNodesFrom(StreamInput in) throws IOException { + return in.readList(MLUndeployControllerNodeResponse::readStats); } @Override @@ -48,8 +51,8 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par String nodeId; DiscoveryNode node; builder.startObject(); - for (MLDeployModelControllerNodeResponse deployStats : getNodes()) { - if (!deployStats.isModelControllerDeployStatusEmpty()) { + for (MLUndeployControllerNodeResponse deployStats : getNodes()) { + if (!deployStats.isControllerUndeployStatusEmpty()) { node = deployStats.getNode(); nodeId = node.getId(); builder.startObject(nodeId); diff --git a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUndeployModelControllerNodeRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUndeployModelControllerNodeRequest.java deleted file mode 100644 index 0cbd67891b..0000000000 --- a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUndeployModelControllerNodeRequest.java +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.ml.common.transport.controller; - -import java.io.IOException; -import lombok.Getter; -import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.transport.TransportRequest; - - -public class MLUndeployModelControllerNodeRequest extends TransportRequest { - @Getter - private MLUndeployModelControllerNodesRequest undeployModelControllerNodesRequest; - - public MLUndeployModelControllerNodeRequest(StreamInput in) throws IOException { - super(in); - this.undeployModelControllerNodesRequest = new MLUndeployModelControllerNodesRequest(in); - } - - public MLUndeployModelControllerNodeRequest(MLUndeployModelControllerNodesRequest request) { - this.undeployModelControllerNodesRequest = request; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - super.writeTo(out); - undeployModelControllerNodesRequest.writeTo(out); - } - -} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUpdateModelControllerAction.java b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUpdateControllerAction.java similarity index 62% rename from common/src/main/java/org/opensearch/ml/common/transport/controller/MLUpdateModelControllerAction.java rename to common/src/main/java/org/opensearch/ml/common/transport/controller/MLUpdateControllerAction.java index 7c48429765..782d62b279 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUpdateModelControllerAction.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUpdateControllerAction.java @@ -7,11 +7,11 @@ import org.opensearch.action.ActionType; import org.opensearch.action.update.UpdateResponse; -public class MLUpdateModelControllerAction extends ActionType { - public static final MLUpdateModelControllerAction INSTANCE = new MLUpdateModelControllerAction(); +public class MLUpdateControllerAction extends ActionType { + public static final MLUpdateControllerAction INSTANCE = new MLUpdateControllerAction(); public static final String NAME = "cluster:admin/opensearch/ml/controllers/update"; - private MLUpdateModelControllerAction() { + private MLUpdateControllerAction() { super(NAME, UpdateResponse::new); } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUpdateModelControllerRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUpdateControllerRequest.java similarity index 61% rename from common/src/main/java/org/opensearch/ml/common/transport/controller/MLUpdateModelControllerRequest.java rename to common/src/main/java/org/opensearch/ml/common/transport/controller/MLUpdateControllerRequest.java index 7e8abedda5..5a067a1411 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUpdateModelControllerRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/controller/MLUpdateControllerRequest.java @@ -15,7 +15,7 @@ import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.ml.common.controller.MLModelController; +import org.opensearch.ml.common.controller.MLController; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; @@ -28,46 +28,46 @@ @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) @ToString -public class MLUpdateModelControllerRequest extends ActionRequest { - private MLModelController updateModelControllerInput; +public class MLUpdateControllerRequest extends ActionRequest { + private MLController updateControllerInput; @Builder - public MLUpdateModelControllerRequest(MLModelController updateModelControllerInput) { - this.updateModelControllerInput = updateModelControllerInput; + public MLUpdateControllerRequest(MLController updateControllerInput) { + this.updateControllerInput = updateControllerInput; } - public MLUpdateModelControllerRequest(StreamInput in) throws IOException { + public MLUpdateControllerRequest(StreamInput in) throws IOException { super(in); - this.updateModelControllerInput = new MLModelController(in); + this.updateControllerInput = new MLController(in); } @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); - updateModelControllerInput.writeTo(out); + updateControllerInput.writeTo(out); } @Override public ActionRequestValidationException validate() { ActionRequestValidationException exception = null; - if (updateModelControllerInput == null) { + if (updateControllerInput == null) { exception = addValidationError("Update model controller input can't be null", exception); - } + } return exception; } - public static MLUpdateModelControllerRequest fromActionRequest(ActionRequest actionRequest) { - if (actionRequest instanceof MLUpdateModelControllerRequest) { - return (MLUpdateModelControllerRequest) actionRequest; + public static MLUpdateControllerRequest fromActionRequest(ActionRequest actionRequest) { + if (actionRequest instanceof MLUpdateControllerRequest) { + return (MLUpdateControllerRequest) actionRequest; } try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); - OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionRequest.writeTo(osso); try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { - return new MLUpdateModelControllerRequest(input); - } + return new MLUpdateControllerRequest(input); + } } catch (IOException e) { - throw new UncheckedIOException("Failed to parse action request to MLCreateModelControllerRequest", e); + throw new UncheckedIOException("Failed to parse action request to MLCreateControllerRequest", e); } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelInput.java b/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelInput.java index 4cada116a6..74090c3491 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/model/MLUpdateModelInput.java @@ -27,19 +27,22 @@ @Data public class MLUpdateModelInput implements ToXContentObject, Writeable { - + public static final String MODEL_ID_FIELD = "model_id"; // passively set when passing url to rest API public static final String DESCRIPTION_FIELD = "description"; // optional - public static final String MODEL_VERSION_FIELD = "model_version"; // passively set when register model to a new model group + public static final String MODEL_VERSION_FIELD = "model_version"; // passively set when register model to a new + // model group public static final String MODEL_NAME_FIELD = "name"; // optional public static final String MODEL_GROUP_ID_FIELD = "model_group_id"; // optional public static final String IS_ENABLED_FIELD = "is_enabled"; // optional - public static final String MODEL_RATE_LIMITER_CONFIG_FIELD = "model_rate_limiter_config"; // optional + public static final String RATE_LIMITER_FIELD = "rate_limiter"; // optional public static final String MODEL_CONFIG_FIELD = "model_config"; // optional - public static final String UPDATED_CONNECTOR_FIELD = "updated_connector"; // passively set when updating the internal connector + public static final String UPDATED_CONNECTOR_FIELD = "updated_connector"; // passively set when updating the + // internal connector public static final String CONNECTOR_ID_FIELD = "connector_id"; // optional public static final String CONNECTOR_FIELD = "connector"; // optional - public static final String LAST_UPDATED_TIME_FIELD = "last_updated_time"; // passively set when sending update request + public static final String LAST_UPDATED_TIME_FIELD = "last_updated_time"; // passively set when sending update + // request @Getter private String modelId; @@ -48,7 +51,7 @@ public class MLUpdateModelInput implements ToXContentObject, Writeable { private String name; private String modelGroupId; private Boolean isEnabled; - private MLRateLimiter modelRateLimiterConfig; + private MLRateLimiter rateLimiter; private MLModelConfig modelConfig; private Connector updatedConnector; private String connectorId; @@ -57,15 +60,15 @@ public class MLUpdateModelInput implements ToXContentObject, Writeable { @Builder(toBuilder = true) public MLUpdateModelInput(String modelId, String description, String version, String name, String modelGroupId, - Boolean isEnabled, MLRateLimiter modelRateLimiterConfig, MLModelConfig modelConfig, - Connector updatedConnector, String connectorId, MLCreateConnectorInput connector, Instant lastUpdateTime) { + Boolean isEnabled, MLRateLimiter rateLimiter, MLModelConfig modelConfig, + Connector updatedConnector, String connectorId, MLCreateConnectorInput connector, Instant lastUpdateTime) { this.modelId = modelId; this.description = description; this.version = version; this.name = name; this.modelGroupId = modelGroupId; this.isEnabled = isEnabled; - this.modelRateLimiterConfig = modelRateLimiterConfig; + this.rateLimiter = rateLimiter; this.modelConfig = modelConfig; this.updatedConnector = updatedConnector; this.connectorId = connectorId; @@ -81,7 +84,7 @@ public MLUpdateModelInput(StreamInput in) throws IOException { modelGroupId = in.readOptionalString(); isEnabled = in.readOptionalBoolean(); if (in.readBoolean()) { - modelRateLimiterConfig = new MLRateLimiter(in); + rateLimiter = new MLRateLimiter(in); } if (in.readBoolean()) { modelConfig = new TextEmbeddingModelConfig(in); @@ -115,8 +118,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (isEnabled != null) { builder.field(IS_ENABLED_FIELD, isEnabled); } - if (modelRateLimiterConfig != null) { - builder.field(MODEL_RATE_LIMITER_CONFIG_FIELD, modelRateLimiterConfig); + if (rateLimiter != null) { + builder.field(RATE_LIMITER_FIELD, rateLimiter); } if (modelConfig != null) { builder.field(MODEL_CONFIG_FIELD, modelConfig); @@ -145,9 +148,9 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalString(name); out.writeOptionalString(modelGroupId); out.writeOptionalBoolean(isEnabled); - if (modelRateLimiterConfig != null) { + if (rateLimiter != null) { out.writeBoolean(true); - modelRateLimiterConfig.writeTo(out); + rateLimiter.writeTo(out); } else { out.writeBoolean(false); } @@ -180,7 +183,7 @@ public static MLUpdateModelInput parse(XContentParser parser) throws IOException String name = null; String modelGroupId = null; Boolean isEnabled = null; - MLRateLimiter modelRateLimiterConfig = null; + MLRateLimiter rateLimiter = null; MLModelConfig modelConfig = null; Connector updatedConnector = null; String connectorId = null; @@ -204,8 +207,8 @@ public static MLUpdateModelInput parse(XContentParser parser) throws IOException case IS_ENABLED_FIELD: isEnabled = parser.booleanValue(); break; - case MODEL_RATE_LIMITER_CONFIG_FIELD: - modelRateLimiterConfig = MLRateLimiter.parse(parser); + case RATE_LIMITER_FIELD: + rateLimiter = MLRateLimiter.parse(parser); break; case MODEL_CONFIG_FIELD: modelConfig = TextEmbeddingModelConfig.parse(parser); @@ -221,7 +224,9 @@ public static MLUpdateModelInput parse(XContentParser parser) throws IOException break; } } - // Model ID can only be set through RestRequest. Model version can only be set automatically. - return new MLUpdateModelInput(modelId, description, version, name, modelGroupId, isEnabled, modelRateLimiterConfig, modelConfig, updatedConnector, connectorId, connector, lastUpdateTime); + // Model ID can only be set through RestRequest. Model version can only be set + // automatically. + return new MLUpdateModelInput(modelId, description, version, name, modelGroupId, isEnabled, rateLimiter, + modelConfig, updatedConnector, connectorId, connector, lastUpdateTime); } } \ No newline at end of file diff --git a/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java b/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java index ee5c89f1da..420f8b3e0f 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelInput.java @@ -44,7 +44,7 @@ public class MLRegisterModelInput implements ToXContentObject, Writeable { public static final String DESCRIPTION_FIELD = "description"; public static final String VERSION_FIELD = "version"; public static final String IS_ENABLED_FIELD = "is_enabled"; - public static final String MODEL_RATE_LIMITER_CONFIG_FIELD = "model_rate_limiter_config"; + public static final String RATE_LIMITER_FIELD = "rate_limiter"; public static final String URL_FIELD = "url"; public static final String MODEL_FORMAT_FIELD = "model_format"; public static final String MODEL_CONFIG_FIELD = "model_config"; @@ -63,7 +63,7 @@ public class MLRegisterModelInput implements ToXContentObject, Writeable { private String version; private String description; private Boolean isEnabled; - private MLRateLimiter modelRateLimiterConfig; + private MLRateLimiter rateLimiter; private String url; private String hashValue; private MLModelFormat modelFormat; @@ -84,26 +84,25 @@ public class MLRegisterModelInput implements ToXContentObject, Writeable { @Builder(toBuilder = true) public MLRegisterModelInput(FunctionName functionName, - String modelName, - String modelGroupId, - String version, - String description, - Boolean isEnabled, - MLRateLimiter modelRateLimiterConfig, - String url, - String hashValue, - MLModelFormat modelFormat, - MLModelConfig modelConfig, - boolean deployModel, - String[] modelNodeIds, - Connector connector, - String connectorId, - List backendRoles, - Boolean addAllBackendRoles, - AccessMode accessMode, - Boolean doesVersionCreateModelGroup, - Boolean isHidden - ) { + String modelName, + String modelGroupId, + String version, + String description, + Boolean isEnabled, + MLRateLimiter rateLimiter, + String url, + String hashValue, + MLModelFormat modelFormat, + MLModelConfig modelConfig, + boolean deployModel, + String[] modelNodeIds, + Connector connector, + String connectorId, + List backendRoles, + Boolean addAllBackendRoles, + AccessMode accessMode, + Boolean doesVersionCreateModelGroup, + Boolean isHidden) { this.functionName = Objects.requireNonNullElse(functionName, FunctionName.TEXT_EMBEDDING); if (modelName == null) { throw new IllegalArgumentException("model name is null"); @@ -112,7 +111,11 @@ public MLRegisterModelInput(FunctionName functionName, if (modelFormat == null) { throw new IllegalArgumentException("model format is null"); } - if (url != null && modelConfig == null && functionName != FunctionName.SPARSE_TOKENIZE && functionName != FunctionName.SPARSE_ENCODING) { // The tokenize model doesn't require a model configuration. Currently, we only support one type of sparse model, which is pretrained, and it doesn't necessitate a model configuration. + if (url != null && modelConfig == null && functionName != FunctionName.SPARSE_TOKENIZE + && functionName != FunctionName.SPARSE_ENCODING) { // The tokenize model doesn't require a model + // configuration. Currently, we only support one + // type of sparse model, which is pretrained, and + // it doesn't necessitate a model configuration. throw new IllegalArgumentException("model config is null"); } } @@ -121,7 +124,7 @@ public MLRegisterModelInput(FunctionName functionName, this.version = version; this.description = description; this.isEnabled = isEnabled; - this.modelRateLimiterConfig = modelRateLimiterConfig; + this.rateLimiter = rateLimiter; this.url = url; this.hashValue = hashValue; this.modelFormat = modelFormat; @@ -137,7 +140,6 @@ public MLRegisterModelInput(FunctionName functionName, this.isHidden = isHidden; } - public MLRegisterModelInput(StreamInput in) throws IOException { this.functionName = in.readEnum(FunctionName.class); this.modelName = in.readString(); @@ -146,7 +148,7 @@ public MLRegisterModelInput(StreamInput in) throws IOException { this.description = in.readOptionalString(); this.isEnabled = in.readOptionalBoolean(); if (in.readBoolean()) { - this.modelRateLimiterConfig = new MLRateLimiter(in); + this.rateLimiter = new MLRateLimiter(in); } this.url = in.readOptionalString(); this.hashValue = in.readOptionalString(); @@ -185,9 +187,9 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalString(version); out.writeOptionalString(description); out.writeOptionalBoolean(isEnabled); - if (modelRateLimiterConfig != null) { + if (rateLimiter != null) { out.writeBoolean(true); - modelRateLimiterConfig.writeTo(out); + rateLimiter.writeTo(out); } else { out.writeBoolean(false); } @@ -248,8 +250,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (isEnabled != null) { builder.field(IS_ENABLED_FIELD, isEnabled); } - if (modelRateLimiterConfig != null) { - builder.field(MODEL_RATE_LIMITER_CONFIG_FIELD, modelRateLimiterConfig); + if (rateLimiter != null) { + builder.field(RATE_LIMITER_FIELD, rateLimiter); } if (url != null) { builder.field(URL_FIELD, url); @@ -292,11 +294,12 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws return builder; } - public static MLRegisterModelInput parse(XContentParser parser, String modelName, String version, boolean deployModel) throws IOException { + public static MLRegisterModelInput parse(XContentParser parser, String modelName, String version, + boolean deployModel) throws IOException { FunctionName functionName = null; String modelGroupId = null; Boolean isEnabled = null; - MLRateLimiter modelRateLimiterConfig = null; + MLRateLimiter rateLimiter = null; String url = null; String hashValue = null; String description = null; @@ -325,8 +328,8 @@ public static MLRegisterModelInput parse(XContentParser parser, String modelName case IS_ENABLED_FIELD: isEnabled = parser.booleanValue(); break; - case MODEL_RATE_LIMITER_CONFIG_FIELD: - modelRateLimiterConfig = MLRateLimiter.parse(parser); + case RATE_LIMITER_FIELD: + rateLimiter = MLRateLimiter.parse(parser); break; case URL_FIELD: url = parser.text(); @@ -378,7 +381,10 @@ public static MLRegisterModelInput parse(XContentParser parser, String modelName break; } } - return new MLRegisterModelInput(functionName, modelName, modelGroupId, version, description, isEnabled, modelRateLimiterConfig, url, hashValue, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0]), connector, connectorId, backendRoles, addAllBackendRoles, accessMode, doesVersionCreateModelGroup, isHidden); + return new MLRegisterModelInput(functionName, modelName, modelGroupId, version, description, isEnabled, + rateLimiter, url, hashValue, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0]), + connector, connectorId, backendRoles, addAllBackendRoles, accessMode, doesVersionCreateModelGroup, + isHidden); } public static MLRegisterModelInput parse(XContentParser parser, boolean deployModel) throws IOException { @@ -387,7 +393,7 @@ public static MLRegisterModelInput parse(XContentParser parser, boolean deployMo String modelGroupId = null; String version = null; Boolean isEnabled = null; - MLRateLimiter modelRateLimiterConfig = null; + MLRateLimiter rateLimiter = null; String url = null; String hashValue = null; String description = null; @@ -426,8 +432,8 @@ public static MLRegisterModelInput parse(XContentParser parser, boolean deployMo case IS_ENABLED_FIELD: isEnabled = parser.booleanValue(); break; - case MODEL_RATE_LIMITER_CONFIG_FIELD: - modelRateLimiterConfig = MLRateLimiter.parse(parser); + case RATE_LIMITER_FIELD: + rateLimiter = MLRateLimiter.parse(parser); break; case URL_FIELD: url = parser.text(); @@ -476,6 +482,8 @@ public static MLRegisterModelInput parse(XContentParser parser, boolean deployMo break; } } - return new MLRegisterModelInput(functionName, name, modelGroupId, version, description, isEnabled, modelRateLimiterConfig, url, hashValue, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0]), connector, connectorId, backendRoles, addAllBackendRoles, accessMode, doesVersionCreateModelGroup, isHidden); + return new MLRegisterModelInput(functionName, name, modelGroupId, version, description, isEnabled, rateLimiter, + url, hashValue, modelFormat, modelConfig, deployModel, modelNodeIds.toArray(new String[0]), connector, + connectorId, backendRoles, addAllBackendRoles, accessMode, doesVersionCreateModelGroup, isHidden); } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInput.java b/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInput.java index cb08dff812..25378e42c4 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInput.java @@ -31,27 +31,26 @@ import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; @Data -public class MLRegisterModelMetaInput implements ToXContentObject, Writeable{ +public class MLRegisterModelMetaInput implements ToXContentObject, Writeable { public static final String FUNCTION_NAME_FIELD = "function_name"; - public static final String MODEL_NAME_FIELD = "name"; //mandatory - public static final String DESCRIPTION_FIELD = "description"; //optional - public static final String IS_ENABLED_FIELD = "is_enabled"; //optional - public static final String MODEL_RATE_LIMITER_CONFIG_FIELD = "model_rate_limiter_config"; //optional + public static final String MODEL_NAME_FIELD = "name"; // mandatory + public static final String DESCRIPTION_FIELD = "description"; // optional + public static final String IS_ENABLED_FIELD = "is_enabled"; // optional + public static final String RATE_LIMITER_FIELD = "rate_limiter"; // optional public static final String VERSION_FIELD = "version"; - public static final String MODEL_FORMAT_FIELD = "model_format"; //mandatory + public static final String MODEL_FORMAT_FIELD = "model_format"; // mandatory public static final String MODEL_STATE_FIELD = "model_state"; public static final String MODEL_CONTENT_SIZE_IN_BYTES_FIELD = "model_content_size_in_bytes"; - public static final String MODEL_CONTENT_HASH_VALUE_FIELD = "model_content_hash_value"; //mandatory - public static final String MODEL_CONFIG_FIELD = "model_config"; //mandatory - public static final String TOTAL_CHUNKS_FIELD = "total_chunks"; //mandatory - public static final String MODEL_GROUP_ID_FIELD = "model_group_id"; //optional - public static final String BACKEND_ROLES_FIELD = "backend_roles"; //optional - public static final String ACCESS_MODE = "access_mode"; //optional - public static final String ADD_ALL_BACKEND_ROLES = "add_all_backend_roles"; //optional + public static final String MODEL_CONTENT_HASH_VALUE_FIELD = "model_content_hash_value"; // mandatory + public static final String MODEL_CONFIG_FIELD = "model_config"; // mandatory + public static final String TOTAL_CHUNKS_FIELD = "total_chunks"; // mandatory + public static final String MODEL_GROUP_ID_FIELD = "model_group_id"; // optional + public static final String BACKEND_ROLES_FIELD = "backend_roles"; // optional + public static final String ACCESS_MODE = "access_mode"; // optional + public static final String ADD_ALL_BACKEND_ROLES = "add_all_backend_roles"; // optional public static final String DOES_VERSION_CREATE_MODEL_GROUP = "does_version_create_model_group"; - private FunctionName functionName; private String name; @@ -59,7 +58,7 @@ public class MLRegisterModelMetaInput implements ToXContentObject, Writeable{ private String description; private String version; private Boolean isEnabled; - private MLRateLimiter modelRateLimiterConfig; + private MLRateLimiter rateLimiter; private MLModelFormat modelFormat; private MLModelState modelState; @@ -75,10 +74,13 @@ public class MLRegisterModelMetaInput implements ToXContentObject, Writeable{ private Boolean isHidden; @Builder(toBuilder = true) - public MLRegisterModelMetaInput(String name, FunctionName functionName, String modelGroupId, String version, String description, Boolean isEnabled, MLRateLimiter modelRateLimiterConfig, MLModelFormat modelFormat, MLModelState modelState, Long modelContentSizeInBytes, String modelContentHashValue, MLModelConfig modelConfig, Integer totalChunks, List backendRoles, - AccessMode accessMode, - Boolean isAddAllBackendRoles, - Boolean doesVersionCreateModelGroup, Boolean isHidden) { + public MLRegisterModelMetaInput(String name, FunctionName functionName, String modelGroupId, String version, + String description, Boolean isEnabled, MLRateLimiter rateLimiter, MLModelFormat modelFormat, + MLModelState modelState, Long modelContentSizeInBytes, String modelContentHashValue, + MLModelConfig modelConfig, Integer totalChunks, List backendRoles, + AccessMode accessMode, + Boolean isAddAllBackendRoles, + Boolean doesVersionCreateModelGroup, Boolean isHidden) { if (name == null) { throw new IllegalArgumentException("model name is null"); } @@ -93,7 +95,11 @@ public MLRegisterModelMetaInput(String name, FunctionName functionName, String m if (modelContentHashValue == null) { throw new IllegalArgumentException("model content hash value is null"); } - if (modelConfig == null && functionName != FunctionName.SPARSE_TOKENIZE && functionName != FunctionName.SPARSE_ENCODING) { // The tokenize model doesn't require a model configuration. Currently, we only support one type of sparse model, which is pretrained, and it doesn't necessitate a model configuration. + if (modelConfig == null && functionName != FunctionName.SPARSE_TOKENIZE + && functionName != FunctionName.SPARSE_ENCODING) { // The tokenize model doesn't require a model + // configuration. Currently, we only support one type + // of sparse model, which is pretrained, and it + // doesn't necessitate a model configuration. throw new IllegalArgumentException("model config is null"); } if (totalChunks == null) { @@ -104,7 +110,7 @@ public MLRegisterModelMetaInput(String name, FunctionName functionName, String m this.version = version; this.description = description; this.isEnabled = isEnabled; - this.modelRateLimiterConfig = modelRateLimiterConfig; + this.rateLimiter = rateLimiter; this.modelFormat = modelFormat; this.modelState = modelState; this.modelContentSizeInBytes = modelContentSizeInBytes; @@ -118,7 +124,7 @@ public MLRegisterModelMetaInput(String name, FunctionName functionName, String m this.isHidden = isHidden; } - public MLRegisterModelMetaInput(StreamInput in) throws IOException{ + public MLRegisterModelMetaInput(StreamInput in) throws IOException { this.name = in.readString(); this.functionName = in.readEnum(FunctionName.class); this.modelGroupId = in.readOptionalString(); @@ -126,7 +132,7 @@ public MLRegisterModelMetaInput(StreamInput in) throws IOException{ this.description = in.readOptionalString(); this.isEnabled = in.readOptionalBoolean(); if (in.readBoolean()) { - modelRateLimiterConfig = new MLRateLimiter(in); + rateLimiter = new MLRateLimiter(in); } if (in.readBoolean()) { modelFormat = in.readEnum(MLModelFormat.class); @@ -157,9 +163,9 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalString(version); out.writeOptionalString(description); out.writeOptionalBoolean(isEnabled); - if (modelRateLimiterConfig != null) { + if (rateLimiter != null) { out.writeBoolean(true); - modelRateLimiterConfig.writeTo(out); + rateLimiter.writeTo(out); } else { out.writeBoolean(false); } @@ -218,8 +224,8 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par if (isEnabled != null) { builder.field(IS_ENABLED_FIELD, isEnabled); } - if (modelRateLimiterConfig != null) { - builder.field(MODEL_RATE_LIMITER_CONFIG_FIELD, modelRateLimiterConfig); + if (rateLimiter != null) { + builder.field(RATE_LIMITER_FIELD, rateLimiter); } builder.field(MODEL_FORMAT_FIELD, modelFormat); if (modelState != null) { @@ -257,7 +263,7 @@ public static MLRegisterModelMetaInput parse(XContentParser parser) throws IOExc String version = null; String description = null; Boolean isEnabled = null; - MLRateLimiter modelRateLimiterConfig = null; + MLRateLimiter rateLimiter = null; MLModelFormat modelFormat = null; MLModelState modelState = null; Long modelContentSizeInBytes = null; @@ -293,8 +299,8 @@ public static MLRegisterModelMetaInput parse(XContentParser parser) throws IOExc case IS_ENABLED_FIELD: isEnabled = parser.booleanValue(); break; - case MODEL_RATE_LIMITER_CONFIG_FIELD: - modelRateLimiterConfig = MLRateLimiter.parse(parser); + case RATE_LIMITER_FIELD: + rateLimiter = MLRateLimiter.parse(parser); break; case MODEL_FORMAT_FIELD: modelFormat = MLModelFormat.from(parser.text()); @@ -338,7 +344,9 @@ public static MLRegisterModelMetaInput parse(XContentParser parser) throws IOExc break; } } - return new MLRegisterModelMetaInput(name, functionName, modelGroupId, version, description, isEnabled, modelRateLimiterConfig, modelFormat, modelState, modelContentSizeInBytes, modelContentHashValue, modelConfig, totalChunks, backendRoles, accessMode, isAddAllBackendRoles, doesVersionCreateModelGroup, isHidden); + return new MLRegisterModelMetaInput(name, functionName, modelGroupId, version, description, isEnabled, + rateLimiter, modelFormat, modelState, modelContentSizeInBytes, modelContentHashValue, modelConfig, + totalChunks, backendRoles, accessMode, isAddAllBackendRoles, doesVersionCreateModelGroup, isHidden); } } diff --git a/common/src/test/java/org/opensearch/ml/common/controller/MLControllerTest.java b/common/src/test/java/org/opensearch/ml/common/controller/MLControllerTest.java new file mode 100644 index 0000000000..da5d6415ff --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/controller/MLControllerTest.java @@ -0,0 +1,365 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.ml.common.controller; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.concurrent.TimeUnit; +import java.util.function.Consumer; + +import org.junit.Before; +import org.junit.Ignore; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.search.SearchModule; + +public class MLControllerTest { + private MLRateLimiter rateLimiter; + + private MLController controller; + + private MLController controllerNull; + + private final String expectedInputStr = "{\"model_id\":\"testModelId\",\"user_rate_limiter\":" + + "{\"testUser\":{\"limit\":\"1\",\"unit\":\"MILLISECONDS\"}}}"; + + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + @Before + public void setUp() throws Exception { + rateLimiter = MLRateLimiter.builder() + .limit("1") + .unit(TimeUnit.MILLISECONDS) + .build(); + + controllerNull = MLController.builder() + .modelId("testModelId").build(); + + controller = MLControllerGenerator("testUser", rateLimiter); + + } + + @Test + public void readInputStreamSuccess() throws IOException { + readInputStream(controller, parsedInput -> { + assertEquals("testModelId", parsedInput.getModelId()); + assertEquals(controller.getUserRateLimiter().get("testUser").getLimit(), + parsedInput.getUserRateLimiter().get("testUser").getLimit()); + }); + } + + @Test + public void readInputStreamSuccessWithNullFields() throws IOException { + controller.setUserRateLimiter(null); + readInputStream(controller, parsedInput -> { + assertNull(parsedInput.getUserRateLimiter()); + }); + } + + @Test + public void testToXContent() throws Exception { + String jsonStr = serializationWithToXContent(controller); + assertEquals(expectedInputStr, jsonStr); + } + + @Test + public void testToXContentIncomplete() throws Exception { + final String expectedIncompleteInputStr = "{\"model_id\":\"testModelId\"}"; + String jsonStr = serializationWithToXContent(controllerNull); + assertEquals(expectedIncompleteInputStr, jsonStr); + } + + @Test + public void testToXContentWithNullMLRateLimiterInUserRateLimiter() throws Exception { + // Notice that MLController will throw an exception if it parses this + // output string, check + // parseWithNullMLRateLimiterInUserRateLimiterFieldWithException test + // below. + final String expectedOutputStrWithNullField = "{\"model_id\":\"testModelId\",\"user_rate_limiter\":{\"testUser\":null}}"; + MLController controllerWithTestUserAndEmptyRateLimiter = MLController.builder() + .modelId("testModelId") + .userRateLimiter(new HashMap<>() { + { + put("testUser", null); + } + }) + .build(); + String jsonStr = serializationWithToXContent(controllerWithTestUserAndEmptyRateLimiter); + assertEquals(expectedOutputStrWithNullField, jsonStr); + } + + @Test + public void parseSuccess() throws Exception { + testParseFromJsonString(expectedInputStr, parsedInput -> assertEquals("testModelId", parsedInput.getModelId())); + } + + @Test + // Notice that this won't throw an IllegalStateException, which is pretty + // different from usual + public void parseWithoutUserRateLimiterFieldWithNoException() throws Exception { + final String expectedIncompleteInputStr = "{\"model_id\":\"testModelId\"}"; + final String expectedOutputStr = "{\"model_id\":\"testModelId\",\"user_rate_limiter\":{}}"; + + testParseFromJsonString(expectedIncompleteInputStr, parsedInput -> { + try { + assertEquals(expectedOutputStr, serializationWithToXContent(parsedInput)); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + } + + @Test + // Notice that this won't throw an IllegalStateException, which is pretty + // different from usual + public void parseWithNullUserRateLimiterFieldWithNoException() throws Exception { + final String expectedInputStrWithNullField = "{\"model_id\":\"testModelId\",\"user_rate_limiter\":null}"; + final String expectedOutputStr = "{\"model_id\":\"testModelId\",\"user_rate_limiter\":{}}"; + + testParseFromJsonString(expectedInputStrWithNullField, parsedInput -> { + try { + assertEquals(expectedOutputStr, serializationWithToXContent(parsedInput)); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + } + + @Test + // Notice that this won't throw an IllegalStateException, which is pretty + // different from usual + public void parseWithTestUserAndEmptyRateLimiterFieldWithNoException() throws Exception { + final String expectedInputStrWithEmptyField = "{\"model_id\":\"testModelId\",\"user_rate_limiter\":" + + "{\"testUser\":{}}}"; + final String expectedOutputStr = "{\"model_id\":\"testModelId\",\"user_rate_limiter\":" + + "{}}"; + testParseFromJsonString(expectedInputStrWithEmptyField, parsedInput -> { + try { + assertEquals(expectedOutputStr, serializationWithToXContent(parsedInput)); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + } + + @Test + public void parseWithNullField() throws Exception { + exceptionRule.expect(IllegalStateException.class); + final String expectedInputStrWithNullField = "{\"model_id\":null,\"user_rate_limiter\":" + + "{\"testUser\":{\"limit\":\"1\",\"unit\":\"MILLISECONDS\"}}}"; + + testParseFromJsonString(expectedInputStrWithNullField, parsedInput -> { + try { + assertEquals(expectedInputStr, serializationWithToXContent(parsedInput)); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + } + + @Test + public void parseWithIllegalField() throws Exception { + final String expectedInputStrWithIllegalField = "{\"model_id\":\"testModelId\",\"illegal_field\":\"This field need to be skipped.\",\"user_rate_limiter\":" + + + "{\"testUser\":{\"limit\":\"1\",\"unit\":\"MILLISECONDS\"}}}"; + + testParseFromJsonString(expectedInputStrWithIllegalField, parsedInput -> { + try { + assertEquals(expectedInputStr, serializationWithToXContent(parsedInput)); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + } + + @Test + // This will throw a ParsingException because MLRateLimiter parser cannot parse + // null field. + public void parseWithNullMLRateLimiterInUserRateLimiterFieldWithException() throws Exception { + exceptionRule.expect(RuntimeException.class); + final String expectedInputStrWithNullField = "{\"model_id\":\"testModelId\",\"user_rate_limiter\":{\"testUser\":null}}"; + final String expectedOutputStr = "{\"model_id\":\"testModelId\",\"user_rate_limiter\":{\"testUser\":null}}"; + + testParseFromJsonString(expectedInputStrWithNullField, parsedInput -> { + try { + assertEquals(expectedOutputStr, serializationWithToXContent(parsedInput)); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + } + + @Test + public void parseWithIllegalRateLimiterFieldWithException() throws Exception { + exceptionRule.expect(RuntimeException.class); + final String expectedInputStrWithIllegalField = "{\"model_id\":\"testModelId\",\"illegal_field\":\"This field need to be skipped.\",\"user_rate_limiter\":" + + + "{\"testUser\":\"Some illegal content that MLRateLimiter parser cannot parse.\"}}"; + + testParseFromJsonString(expectedInputStrWithIllegalField, parsedInput -> { + try { + assertEquals(expectedInputStr, serializationWithToXContent(parsedInput)); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + } + + @Test + public void testUserRateLimiterUpdate() { + MLRateLimiter rateLimiterWithNumber = MLRateLimiter.builder().limit("1").build(); + + MLController controllerWithEmptyUserRateLimiter = MLControllerGenerator(); + MLController controllerWithTestUserAndRateLimiterWithNumber = MLControllerGenerator("testUser", + rateLimiterWithNumber); + MLController controllerWithNewUserAndEmptyRateLimiter = MLControllerGenerator("newUser"); + + controllerWithEmptyUserRateLimiter.update(controllerNull); + assertTrue(controllerWithEmptyUserRateLimiter.getUserRateLimiter().isEmpty()); + + controllerWithEmptyUserRateLimiter.update(controllerWithEmptyUserRateLimiter); + assertTrue(controllerWithEmptyUserRateLimiter.getUserRateLimiter().isEmpty()); + + controllerWithEmptyUserRateLimiter.update(controllerWithTestUserAndRateLimiterWithNumber); + assertEquals("1", controllerWithEmptyUserRateLimiter.getUserRateLimiter().get("testUser") + .getLimit()); + assertNull(controllerWithEmptyUserRateLimiter.getUserRateLimiter().get("testUser") + .getUnit()); + + controllerWithEmptyUserRateLimiter.update(controller); + assertEquals("1", controllerWithEmptyUserRateLimiter.getUserRateLimiter().get("testUser") + .getLimit()); + assertEquals(TimeUnit.MILLISECONDS, controllerWithEmptyUserRateLimiter.getUserRateLimiter() + .get("testUser").getUnit()); + + controllerWithEmptyUserRateLimiter.update(controllerWithNewUserAndEmptyRateLimiter); + assertTrue(controllerWithEmptyUserRateLimiter.getUserRateLimiter().get("newUser").isEmpty()); + } + + @Test + public void testUserRateLimiterIsUpdatable() { + MLRateLimiter rateLimiterWithNumber = MLRateLimiter.builder().limit("1").build(); + + MLController controllerWithEmptyUserRateLimiter = MLControllerGenerator(); + MLController controllerWithTestUserAndRateLimiterWithNumber = MLControllerGenerator("testUser", + rateLimiterWithNumber); + MLController controllerWithNewUserAndRateLimiterWithNumber = MLControllerGenerator("newUser", + rateLimiterWithNumber); + MLController controllerWithNewUserAndEmptyRateLimiter = MLControllerGenerator("newUser"); + MLController controllerWithNewUserAndRateLimiter = MLControllerGenerator("newUser", rateLimiter); + + assertFalse(controllerWithEmptyUserRateLimiter.isDeployRequiredAfterUpdate(null)); + assertFalse(controllerWithEmptyUserRateLimiter.isDeployRequiredAfterUpdate(controllerNull)); + assertFalse(controllerWithEmptyUserRateLimiter + .isDeployRequiredAfterUpdate(controllerWithEmptyUserRateLimiter)); + assertFalse(controllerWithEmptyUserRateLimiter + .isDeployRequiredAfterUpdate(controllerWithNewUserAndEmptyRateLimiter)); + + assertFalse(controllerWithEmptyUserRateLimiter + .isDeployRequiredAfterUpdate(controllerWithTestUserAndRateLimiterWithNumber)); + assertFalse(controllerWithTestUserAndRateLimiterWithNumber + .isDeployRequiredAfterUpdate(controllerWithTestUserAndRateLimiterWithNumber)); + assertTrue(controllerWithEmptyUserRateLimiter.isDeployRequiredAfterUpdate(controller)); + assertTrue(controllerWithTestUserAndRateLimiterWithNumber.isDeployRequiredAfterUpdate(controller)); + + assertFalse(controllerWithTestUserAndRateLimiterWithNumber + .isDeployRequiredAfterUpdate(controllerWithNewUserAndRateLimiterWithNumber)); + assertTrue(controllerWithTestUserAndRateLimiterWithNumber + .isDeployRequiredAfterUpdate(controllerWithNewUserAndRateLimiter)); + } + + private void testParseFromJsonString(String expectedInputStr, Consumer verify) throws Exception { + XContentParser parser = XContentType.JSON.xContent() + .createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, + Collections.emptyList()).getNamedXContents()), LoggingDeprecationHandler.INSTANCE, + expectedInputStr); + parser.nextToken(); + MLController parsedInput = MLController.parse(parser); + verify.accept(parsedInput); + } + + private void readInputStream(MLController input, Consumer verify) throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + input.writeTo(bytesStreamOutput); + StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); + MLController parsedInput = new MLController(streamInput); + verify.accept(parsedInput); + } + + private String serializationWithToXContent(MLController input) throws IOException { + XContentBuilder builder = XContentFactory.jsonBuilder(); + input.toXContent(builder, ToXContent.EMPTY_PARAMS); + assertNotNull(builder); + return builder.toString(); + } + + private MLController MLControllerGenerator(String user, MLRateLimiter rateLimiter) { + return MLController.builder() + .modelId("testModelId") + .userRateLimiter(new HashMap<>() { + { + put(user, rateLimiter); + } + }) + .build(); + + } + + private MLController MLControllerGenerator(String user) { + return MLController.builder() + .modelId("testModelId") + .userRateLimiter(new HashMap<>() { + { + put(user, MLRateLimiter.builder().build()); + } + }) + .build(); + + } + + private MLController MLControllerGenerator() { + return MLController.builder() + .modelId("testModelId") + .userRateLimiter(new HashMap<>()) + .build(); + + } + + @Ignore + @Test + public void testRateLimiterRemove() { + MLController controllerWithTestUserAndEmptyRateLimiter = MLController.builder() + .modelId("testModelId") + .userRateLimiter(new HashMap<>() { + { + put("testUser", MLRateLimiter.builder().build()); + } + }) + .build(); + + controller.update(controllerWithTestUserAndEmptyRateLimiter); + assertNull(controller.getUserRateLimiter().get("testUser")); + } + +} diff --git a/common/src/test/java/org/opensearch/ml/common/controller/MLModelControllerTest.java b/common/src/test/java/org/opensearch/ml/common/controller/MLModelControllerTest.java deleted file mode 100644 index e9283c4958..0000000000 --- a/common/src/test/java/org/opensearch/ml/common/controller/MLModelControllerTest.java +++ /dev/null @@ -1,329 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ -package org.opensearch.ml.common.controller; - - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertTrue; - -import java.io.IOException; -import java.util.Collections; -import java.util.HashMap; -import java.util.concurrent.TimeUnit; -import java.util.function.Consumer; - -import org.junit.Before; -import org.junit.Ignore; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.ExpectedException; -import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.common.settings.Settings; -import org.opensearch.common.xcontent.LoggingDeprecationHandler; -import org.opensearch.common.xcontent.XContentFactory; -import org.opensearch.common.xcontent.XContentType; -import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.core.xcontent.ToXContent; -import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.search.SearchModule; - -public class MLModelControllerTest { - private MLRateLimiter rateLimiter; - - private MLModelController modelController; - - private MLModelController modelControllerNull; - - private final String expectedInputStr = "{\"model_id\":\"testModelId\",\"user_rate_limiter_config\":" + - "{\"testUser\":{\"rate_limit_number\":\"1\",\"rate_limit_unit\":\"MILLISECONDS\"}}}"; - - @Rule - public ExpectedException exceptionRule = ExpectedException.none(); - - @Before - public void setUp() throws Exception { - rateLimiter = MLRateLimiter.builder() - .rateLimitNumber("1") - .rateLimitUnit(TimeUnit.MILLISECONDS) - .build(); - - modelControllerNull = MLModelController.builder() - .modelId("testModelId").build(); - - modelController = MLModelControllerGenerator("testUser", rateLimiter); - - } - - @Test - public void readInputStreamSuccess() throws IOException { - readInputStream(modelController, parsedInput -> { - assertEquals("testModelId", parsedInput.getModelId()); - assertEquals(modelController.getUserRateLimiterConfig().get("testUser").getRateLimitNumber(), - parsedInput.getUserRateLimiterConfig().get("testUser").getRateLimitNumber()); - }); - } - - @Test - public void readInputStreamSuccessWithNullFields() throws IOException { - modelController.setUserRateLimiterConfig(null); - readInputStream(modelController, parsedInput -> { - assertNull(parsedInput.getUserRateLimiterConfig()); - }); - } - - @Test - public void testToXContent() throws Exception { - String jsonStr = serializationWithToXContent(modelController); - assertEquals(expectedInputStr, jsonStr); - } - - - @Test - public void testToXContentIncomplete() throws Exception { - final String expectedIncompleteInputStr = - "{\"model_id\":\"testModelId\"}"; - String jsonStr = serializationWithToXContent(modelControllerNull); - assertEquals(expectedIncompleteInputStr, jsonStr); - } - - @Test - public void testToXContentWithNullMLRateLimiterInUserRateLimiterConfig() throws Exception { - // Notice that MLModelController will throw an exception if it parses this output string, check parseWithNullMLRateLimiterInUserRateLimiterConfigFieldWithException test below. - final String expectedOutputStrWithNullField = - "{\"model_id\":\"testModelId\",\"user_rate_limiter_config\":{\"testUser\":null}}"; - MLModelController modelControllerWithTestUserAndEmptyRateLimiter = MLModelController.builder() - .modelId("testModelId") - .userRateLimiterConfig(new HashMap<>(){{put("testUser", null);}}) - .build(); - String jsonStr = serializationWithToXContent(modelControllerWithTestUserAndEmptyRateLimiter); - assertEquals(expectedOutputStrWithNullField, jsonStr); - } - - @Test - public void parseSuccess() throws Exception { - testParseFromJsonString(expectedInputStr, parsedInput -> assertEquals("testModelId", parsedInput.getModelId())); - } - - @Test - // Notice that this won't throw an IllegalStateException, which is pretty different from usual - public void parseWithoutUserRateLimiterConfigFieldWithNoException() throws Exception { - final String expectedIncompleteInputStr = "{\"model_id\":\"testModelId\"}"; - final String expectedOutputStr = "{\"model_id\":\"testModelId\",\"user_rate_limiter_config\":{}}"; - - testParseFromJsonString(expectedIncompleteInputStr, parsedInput -> { - try { - assertEquals(expectedOutputStr, serializationWithToXContent(parsedInput)); - } catch (IOException e) { - throw new RuntimeException(e); - } - }); - } - - @Test - // Notice that this won't throw an IllegalStateException, which is pretty different from usual - public void parseWithNullUserRateLimiterConfigFieldWithNoException() throws Exception { - final String expectedInputStrWithNullField = "{\"model_id\":\"testModelId\",\"user_rate_limiter_config\":null}"; - final String expectedOutputStr = "{\"model_id\":\"testModelId\",\"user_rate_limiter_config\":{}}"; - - testParseFromJsonString(expectedInputStrWithNullField, parsedInput -> { - try { - assertEquals(expectedOutputStr, serializationWithToXContent(parsedInput)); - } catch (IOException e) { - throw new RuntimeException(e); - } - }); - } - - @Test - // Notice that this won't throw an IllegalStateException, which is pretty different from usual - public void parseWithTestUserAndEmptyRateLimiterFieldWithNoException() throws Exception { - final String expectedInputStrWithEmptyField = "{\"model_id\":\"testModelId\",\"user_rate_limiter_config\":" + - "{\"testUser\":{}}}"; - final String expectedOutputStr = "{\"model_id\":\"testModelId\",\"user_rate_limiter_config\":" + - "{}}"; - testParseFromJsonString(expectedInputStrWithEmptyField, parsedInput -> { - try { - assertEquals(expectedOutputStr, serializationWithToXContent(parsedInput)); - } catch (IOException e) { - throw new RuntimeException(e); - } - }); - } - - @Test - public void parseWithNullField() throws Exception { - exceptionRule.expect(IllegalStateException.class); - final String expectedInputStrWithNullField = "{\"model_id\":null,\"user_rate_limiter_config\":" + - "{\"testUser\":{\"rate_limit_number\":\"1\",\"rate_limit_unit\":\"MILLISECONDS\"}}}"; - - testParseFromJsonString(expectedInputStrWithNullField, parsedInput -> { - try { - assertEquals(expectedInputStr, serializationWithToXContent(parsedInput)); - } catch (IOException e) { - throw new RuntimeException(e); - } - }); - } - - @Test - public void parseWithIllegalField() throws Exception { - final String expectedInputStrWithIllegalField = "{\"model_id\":\"testModelId\",\"illegal_field\":\"This field need to be skipped.\",\"user_rate_limiter_config\":" + - "{\"testUser\":{\"rate_limit_number\":\"1\",\"rate_limit_unit\":\"MILLISECONDS\"}}}"; - - testParseFromJsonString(expectedInputStrWithIllegalField, parsedInput -> { - try { - assertEquals(expectedInputStr, serializationWithToXContent(parsedInput)); - } catch (IOException e) { - throw new RuntimeException(e); - } - }); - } - - @Test - // This will throw a ParsingException because MLRateLimiter parser cannot parse null field. - public void parseWithNullMLRateLimiterInUserRateLimiterConfigFieldWithException() throws Exception { - exceptionRule.expect(RuntimeException.class); - final String expectedInputStrWithNullField = "{\"model_id\":\"testModelId\",\"user_rate_limiter_config\":{\"testUser\":null}}"; - final String expectedOutputStr = "{\"model_id\":\"testModelId\",\"user_rate_limiter_config\":{\"testUser\":null}}"; - - testParseFromJsonString(expectedInputStrWithNullField, parsedInput -> { - try { - assertEquals(expectedOutputStr, serializationWithToXContent(parsedInput)); - } catch (IOException e) { - throw new RuntimeException(e); - } - }); - } - - @Test - public void parseWithIllegalRateLimiterFieldWithException() throws Exception { - exceptionRule.expect(RuntimeException.class); - final String expectedInputStrWithIllegalField = "{\"model_id\":\"testModelId\",\"illegal_field\":\"This field need to be skipped.\",\"user_rate_limiter_config\":" + - "{\"testUser\":\"Some illegal content that MLRateLimiter parser cannot parse.\"}}"; - - testParseFromJsonString(expectedInputStrWithIllegalField, parsedInput -> { - try { - assertEquals(expectedInputStr, serializationWithToXContent(parsedInput)); - } catch (IOException e) { - throw new RuntimeException(e); - } - }); - } - - @Test - public void testUserRateLimiterConfigUpdate() { - MLRateLimiter rateLimiterWithNumber = MLRateLimiter.builder().rateLimitNumber("1").build(); - - MLModelController modelControllerWithEmptyUserRateLimiterConfig = MLModelControllerGenerator(); - MLModelController modelControllerWithTestUserAndRateLimiterWithNumber = MLModelControllerGenerator("testUser", rateLimiterWithNumber); - MLModelController modelControllerWithNewUserAndEmptyRateLimiter = MLModelControllerGenerator("newUser"); - - modelControllerWithEmptyUserRateLimiterConfig.update(modelControllerNull); - assertTrue(modelControllerWithEmptyUserRateLimiterConfig.getUserRateLimiterConfig().isEmpty()); - - modelControllerWithEmptyUserRateLimiterConfig.update(modelControllerWithEmptyUserRateLimiterConfig); - assertTrue(modelControllerWithEmptyUserRateLimiterConfig.getUserRateLimiterConfig().isEmpty()); - - modelControllerWithEmptyUserRateLimiterConfig.update(modelControllerWithTestUserAndRateLimiterWithNumber); - assertEquals("1", modelControllerWithEmptyUserRateLimiterConfig.getUserRateLimiterConfig().get("testUser").getRateLimitNumber()); - assertNull(modelControllerWithEmptyUserRateLimiterConfig.getUserRateLimiterConfig().get("testUser").getRateLimitUnit()); - - modelControllerWithEmptyUserRateLimiterConfig.update(modelController); - assertEquals("1", modelControllerWithEmptyUserRateLimiterConfig.getUserRateLimiterConfig().get("testUser").getRateLimitNumber()); - assertEquals(TimeUnit.MILLISECONDS, modelControllerWithEmptyUserRateLimiterConfig.getUserRateLimiterConfig().get("testUser").getRateLimitUnit()); - - modelControllerWithEmptyUserRateLimiterConfig.update(modelControllerWithNewUserAndEmptyRateLimiter); - assertTrue(modelControllerWithEmptyUserRateLimiterConfig.getUserRateLimiterConfig().get("newUser").isEmpty()); - } - - @Test - public void testUserRateLimiterConfigIsUpdatable() { - MLRateLimiter rateLimiterWithNumber = MLRateLimiter.builder().rateLimitNumber("1").build(); - - MLModelController modelControllerWithEmptyUserRateLimiterConfig = MLModelControllerGenerator(); - MLModelController modelControllerWithTestUserAndRateLimiterWithNumber = MLModelControllerGenerator("testUser", rateLimiterWithNumber); - MLModelController modelControllerWithNewUserAndRateLimiterWithNumber = MLModelControllerGenerator("newUser", rateLimiterWithNumber); - MLModelController modelControllerWithNewUserAndEmptyRateLimiter = MLModelControllerGenerator("newUser"); - MLModelController modelControllerWithNewUserAndRateLimiter = MLModelControllerGenerator("newUser", rateLimiter); - - assertFalse(modelControllerWithEmptyUserRateLimiterConfig.isDeployRequiredAfterUpdate(null)); - assertFalse(modelControllerWithEmptyUserRateLimiterConfig.isDeployRequiredAfterUpdate(modelControllerNull)); - assertFalse(modelControllerWithEmptyUserRateLimiterConfig.isDeployRequiredAfterUpdate(modelControllerWithEmptyUserRateLimiterConfig)); - assertFalse(modelControllerWithEmptyUserRateLimiterConfig.isDeployRequiredAfterUpdate(modelControllerWithNewUserAndEmptyRateLimiter)); - - assertFalse(modelControllerWithEmptyUserRateLimiterConfig.isDeployRequiredAfterUpdate(modelControllerWithTestUserAndRateLimiterWithNumber)); - assertFalse(modelControllerWithTestUserAndRateLimiterWithNumber.isDeployRequiredAfterUpdate(modelControllerWithTestUserAndRateLimiterWithNumber)); - assertTrue(modelControllerWithEmptyUserRateLimiterConfig.isDeployRequiredAfterUpdate(modelController)); - assertTrue(modelControllerWithTestUserAndRateLimiterWithNumber.isDeployRequiredAfterUpdate(modelController)); - - assertFalse(modelControllerWithTestUserAndRateLimiterWithNumber.isDeployRequiredAfterUpdate(modelControllerWithNewUserAndRateLimiterWithNumber)); - assertTrue(modelControllerWithTestUserAndRateLimiterWithNumber.isDeployRequiredAfterUpdate(modelControllerWithNewUserAndRateLimiter)); - } - - private void testParseFromJsonString(String expectedInputStr, Consumer verify) throws Exception { - XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, - Collections.emptyList()).getNamedXContents()), LoggingDeprecationHandler.INSTANCE, expectedInputStr); - parser.nextToken(); - MLModelController parsedInput = MLModelController.parse(parser); - verify.accept(parsedInput); - } - - private void readInputStream(MLModelController input, Consumer verify) throws IOException { - BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); - input.writeTo(bytesStreamOutput); - StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); - MLModelController parsedInput = new MLModelController(streamInput); - verify.accept(parsedInput); - } - - private String serializationWithToXContent(MLModelController input) throws IOException { - XContentBuilder builder = XContentFactory.jsonBuilder(); - input.toXContent(builder, ToXContent.EMPTY_PARAMS); - assertNotNull(builder); - return builder.toString(); - } - - private MLModelController MLModelControllerGenerator(String user, MLRateLimiter rateLimiter) { - return MLModelController.builder() - .modelId("testModelId") - .userRateLimiterConfig(new HashMap<>(){{put(user, rateLimiter);}}) - .build(); - - } - - private MLModelController MLModelControllerGenerator(String user) { - return MLModelController.builder() - .modelId("testModelId") - .userRateLimiterConfig(new HashMap<>(){{put(user, MLRateLimiter.builder().build());}}) - .build(); - - } - - private MLModelController MLModelControllerGenerator() { - return MLModelController.builder() - .modelId("testModelId") - .userRateLimiterConfig(new HashMap<>()) - .build(); - - } - - @Ignore - @Test - public void testRateLimiterRemove() { - MLModelController modelControllerWithTestUserAndEmptyRateLimiter = MLModelController.builder() - .modelId("testModelId") - .userRateLimiterConfig(new HashMap<>(){{put("testUser", MLRateLimiter.builder().build());}}) - .build(); - - modelController.update(modelControllerWithTestUserAndEmptyRateLimiter); - assertNull(modelController.getUserRateLimiterConfig().get("testUser")); - } - -} diff --git a/common/src/test/java/org/opensearch/ml/common/controller/MLRateLimiterTest.java b/common/src/test/java/org/opensearch/ml/common/controller/MLRateLimiterTest.java index 7cc85974ab..f2e1f89dd5 100644 --- a/common/src/test/java/org/opensearch/ml/common/controller/MLRateLimiterTest.java +++ b/common/src/test/java/org/opensearch/ml/common/controller/MLRateLimiterTest.java @@ -4,7 +4,6 @@ */ package org.opensearch.ml.common.controller; - import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; @@ -43,7 +42,7 @@ public class MLRateLimiterTest { private MLRateLimiter rateLimiterNull; - private final String expectedInputStr = "{\"rate_limit_number\":\"1\",\"rate_limit_unit\":\"MILLISECONDS\"}"; + private final String expectedInputStr = "{\"limit\":\"1\",\"unit\":\"MILLISECONDS\"}"; @Rule public ExpectedException exceptionRule = ExpectedException.none(); @@ -51,15 +50,15 @@ public class MLRateLimiterTest { @Before public void setUp() throws Exception { rateLimiter = MLRateLimiter.builder() - .rateLimitNumber("1") - .rateLimitUnit(TimeUnit.MILLISECONDS) + .limit("1") + .unit(TimeUnit.MILLISECONDS) .build(); rateLimiterWithNumber = MLRateLimiter.builder() - .rateLimitNumber("1") + .limit("1") .build(); rateLimiterWithUnit = MLRateLimiter.builder() - .rateLimitUnit(TimeUnit.MILLISECONDS) + .unit(TimeUnit.MILLISECONDS) .build(); rateLimiterNull = MLRateLimiter.builder().build(); @@ -69,15 +68,15 @@ public void setUp() throws Exception { @Test public void readInputStreamSuccess() throws IOException { readInputStream(rateLimiter, parsedInput -> { - assertEquals("1", parsedInput.getRateLimitNumber()); - assertEquals(TimeUnit.MILLISECONDS, parsedInput.getRateLimitUnit()); + assertEquals("1", parsedInput.getLimit()); + assertEquals(TimeUnit.MILLISECONDS, parsedInput.getUnit()); }); } @Test public void readInputStreamSuccessWithNullFields() throws IOException { readInputStream(rateLimiterWithNumber, parsedInput -> { - assertNull(parsedInput.getRateLimitUnit()); + assertNull(parsedInput.getUnit()); }); } @@ -98,15 +97,15 @@ public void testToXContentIncomplete() throws Exception { @Test public void parseSuccess() throws Exception { testParseFromJsonString(expectedInputStr, parsedInput -> { - assertEquals("1", parsedInput.getRateLimitNumber()); - assertEquals(TimeUnit.MILLISECONDS, parsedInput.getRateLimitUnit()); + assertEquals("1", parsedInput.getLimit()); + assertEquals(TimeUnit.MILLISECONDS, parsedInput.getUnit()); }); } @Test public void parseWithNullField() throws Exception { exceptionRule.expect(IllegalStateException.class); - final String expectedInputStrWithNullField = "{\"rate_limit_number\":\"1\",\"rate_limit_unit\":null}"; + final String expectedInputStrWithNullField = "{\"limit\":\"1\",\"unit\":null}"; testParseFromJsonString(expectedInputStrWithNullField, parsedInput -> { try { @@ -119,7 +118,7 @@ public void parseWithNullField() throws Exception { @Test public void parseWithIllegalField() throws Exception { - final String expectedInputStrWithIllegalField = "{\"rate_limit_number\":\"1\",\"rate_limit_unit\":" + + final String expectedInputStrWithIllegalField = "{\"limit\":\"1\",\"unit\":" + "\"MILLISECONDS\",\"illegal_field\":\"This field need to be skipped.\"}"; testParseFromJsonString(expectedInputStrWithIllegalField, parsedInput -> { @@ -150,25 +149,25 @@ public void testIsRateLimiterRemovable() { @Test public void testRateLimiterUpdate() { MLRateLimiter updatedRateLimiter = MLRateLimiter.update(rateLimiterNull, rateLimiter); - assertEquals("1", updatedRateLimiter.getRateLimitNumber()); - assertEquals(TimeUnit.MILLISECONDS, updatedRateLimiter.getRateLimitUnit()); + assertEquals("1", updatedRateLimiter.getLimit()); + assertEquals(TimeUnit.MILLISECONDS, updatedRateLimiter.getUnit()); } @Test public void testRateLimiterPartiallyUpdate() { rateLimiterNull.update(rateLimiterWithNumber); - assertEquals("1", rateLimiterNull.getRateLimitNumber()); - assertNull(rateLimiterNull.getRateLimitUnit()); + assertEquals("1", rateLimiterNull.getLimit()); + assertNull(rateLimiterNull.getUnit()); rateLimiterNull.update(rateLimiterWithUnit); - assertEquals("1", rateLimiterNull.getRateLimitNumber()); - assertEquals(TimeUnit.MILLISECONDS, rateLimiterNull.getRateLimitUnit()); + assertEquals("1", rateLimiterNull.getLimit()); + assertEquals(TimeUnit.MILLISECONDS, rateLimiterNull.getUnit()); } @Test public void testRateLimiterUpdateNull() { MLRateLimiter updatedRateLimiter = MLRateLimiter.update(null, rateLimiter); - assertEquals("1", updatedRateLimiter.getRateLimitNumber()); - assertEquals(TimeUnit.MILLISECONDS, updatedRateLimiter.getRateLimitUnit()); + assertEquals("1", updatedRateLimiter.getLimit()); + assertEquals(TimeUnit.MILLISECONDS, updatedRateLimiter.getUnit()); } @Test @@ -191,11 +190,11 @@ public void testRateLimiterIsUpdatable() { @Test public void testRateLimiterIsDeployRequiredAfterUpdate() { MLRateLimiter rateLimiterWithNumber2 = MLRateLimiter.builder() - .rateLimitNumber("2") + .limit("2") .build(); MLRateLimiter rateLimiterWithUnit2 = MLRateLimiter.builder() - .rateLimitUnit(TimeUnit.NANOSECONDS) + .unit(TimeUnit.NANOSECONDS) .build(); assertTrue(MLRateLimiter.isDeployRequiredAfterUpdate(rateLimiter, rateLimiterWithNumber2)); @@ -209,8 +208,10 @@ public void testRateLimiterIsDeployRequiredAfterUpdate() { } private void testParseFromJsonString(String expectedInputStr, Consumer verify) throws Exception { - XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, - Collections.emptyList()).getNamedXContents()), LoggingDeprecationHandler.INSTANCE, expectedInputStr); + XContentParser parser = XContentType.JSON.xContent() + .createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY, + Collections.emptyList()).getNamedXContents()), LoggingDeprecationHandler.INSTANCE, + expectedInputStr); parser.nextToken(); MLRateLimiter parsedInput = MLRateLimiter.parse(parser); verify.accept(parsedInput); @@ -235,7 +236,7 @@ private String serializationWithToXContent(MLRateLimiter input) throws IOExcepti @Test public void testRateLimiterRemove() { MLRateLimiter updatedRateLimiter = MLRateLimiter.update(rateLimiter, rateLimiterNull); - assertNull(updatedRateLimiter.getRateLimitUnit()); - assertNull(updatedRateLimiter.getRateLimitNumber()); + assertNull(updatedRateLimiter.getUnit()); + assertNull(updatedRateLimiter.getLimit()); } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/controller/MLModelControllerDeleteRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLControllerDeleteRequestTest.java similarity index 76% rename from common/src/test/java/org/opensearch/ml/common/transport/controller/MLModelControllerDeleteRequestTest.java rename to common/src/test/java/org/opensearch/ml/common/transport/controller/MLControllerDeleteRequestTest.java index ae5ef8dd49..a9ff1a6361 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/controller/MLModelControllerDeleteRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLControllerDeleteRequestTest.java @@ -20,27 +20,27 @@ import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.io.stream.StreamOutput; -public class MLModelControllerDeleteRequestTest { +public class MLControllerDeleteRequestTest { private String modelId; - private MLModelControllerDeleteRequest request; + private MLControllerDeleteRequest request; @Before public void setUp() { modelId = "testModelId"; - request = MLModelControllerDeleteRequest.builder() + request = MLControllerDeleteRequest.builder() .modelId(modelId).build(); } - @Test public void writeToSuccess() throws IOException { BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); request.writeTo(bytesStreamOutput); - MLModelControllerDeleteRequest parsedRequest = new MLModelControllerDeleteRequest(bytesStreamOutput.bytes().streamInput()); + MLControllerDeleteRequest parsedRequest = new MLControllerDeleteRequest( + bytesStreamOutput.bytes().streamInput()); assertEquals(parsedRequest.getModelId(), modelId); } @@ -51,15 +51,15 @@ public void validateSuccess() { @Test public void validateWithNullModelIdException() { - MLModelControllerDeleteRequest request = MLModelControllerDeleteRequest.builder().build(); + MLControllerDeleteRequest request = MLControllerDeleteRequest.builder().build(); ActionRequestValidationException exception = request.validate(); assertEquals("Validation Failed: 1: ML model id can't be null;", exception.getMessage()); } @Test - public void fromActionRequestWithMLUpdateModelControllerRequestSuccess() { - assertSame(MLModelControllerDeleteRequest.fromActionRequest(request), request); + public void fromActionRequestWithMLUpdateControllerRequestSuccess() { + assertSame(MLControllerDeleteRequest.fromActionRequest(request), request); } @Test @@ -75,7 +75,7 @@ public void writeTo(StreamOutput out) throws IOException { request.writeTo(out); } }; - MLModelControllerDeleteRequest result = MLModelControllerDeleteRequest.fromActionRequest(actionRequest); + MLControllerDeleteRequest result = MLControllerDeleteRequest.fromActionRequest(actionRequest); assertNotSame(result, request); assertEquals(result.getModelId(), request.getModelId()); } @@ -93,7 +93,7 @@ public void writeTo(StreamOutput out) throws IOException { throw new IOException("test"); } }; - MLModelControllerDeleteRequest.fromActionRequest(actionRequest); + MLControllerDeleteRequest.fromActionRequest(actionRequest); } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/controller/MLModelControllerGetRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLControllerGetRequestTest.java similarity index 77% rename from common/src/test/java/org/opensearch/ml/common/transport/controller/MLModelControllerGetRequestTest.java rename to common/src/test/java/org/opensearch/ml/common/transport/controller/MLControllerGetRequestTest.java index f45b790250..d7d688e2bd 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/controller/MLModelControllerGetRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLControllerGetRequestTest.java @@ -20,25 +20,25 @@ import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.io.stream.StreamOutput; -public class MLModelControllerGetRequestTest { +public class MLControllerGetRequestTest { private String modelId; - private MLModelControllerGetRequest request; + private MLControllerGetRequest request; @Before public void setUp() { modelId = "testModelId"; - request = MLModelControllerGetRequest.builder().modelId(modelId).build(); + request = MLControllerGetRequest.builder().modelId(modelId).build(); } @Test public void writeToSuccess() throws IOException { BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); request.writeTo(bytesStreamOutput); - MLModelControllerGetRequest parsedRequest = new MLModelControllerGetRequest(bytesStreamOutput.bytes().streamInput()); + MLControllerGetRequest parsedRequest = new MLControllerGetRequest(bytesStreamOutput.bytes().streamInput()); assertEquals(modelId, parsedRequest.getModelId()); } @@ -55,7 +55,7 @@ public void writeTo(StreamOutput out) throws IOException { request.writeTo(out); } }; - MLModelControllerGetRequest requestFromActionRequest = MLModelControllerGetRequest.fromActionRequest(actionRequest); + MLControllerGetRequest requestFromActionRequest = MLControllerGetRequest.fromActionRequest(actionRequest); assertNotSame(request, requestFromActionRequest); assertEquals(request.getModelId(), requestFromActionRequest.getModelId()); } @@ -73,19 +73,19 @@ public void writeTo(StreamOutput out) throws IOException { throw new IOException(); } }; - MLModelControllerGetRequest.fromActionRequest(actionRequest); + MLControllerGetRequest.fromActionRequest(actionRequest); } @Test - public void fromActionRequestWithMLModelControllerGetRequestSuccess() { - MLModelControllerGetRequest requestFromActionRequest = MLModelControllerGetRequest.fromActionRequest(request); + public void fromActionRequestWithMLControllerGetRequestSuccess() { + MLControllerGetRequest requestFromActionRequest = MLControllerGetRequest.fromActionRequest(request); assertSame(request, requestFromActionRequest); assertEquals(request.getModelId(), requestFromActionRequest.getModelId()); } @Test public void validateNullModelIdException() { - MLModelControllerGetRequest request = MLModelControllerGetRequest.builder().build(); + MLControllerGetRequest request = MLControllerGetRequest.builder().build(); ActionRequestValidationException actionRequestValidationException = request.validate(); assertEquals("Validation Failed: 1: ML model id can't be null;", actionRequestValidationException.getMessage()); } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/controller/MLModelControllerGetResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLControllerGetResponseTest.java similarity index 53% rename from common/src/test/java/org/opensearch/ml/common/transport/controller/MLModelControllerGetResponseTest.java rename to common/src/test/java/org/opensearch/ml/common/transport/controller/MLControllerGetResponseTest.java index af6526638a..6d29106842 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/controller/MLModelControllerGetResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLControllerGetResponseTest.java @@ -26,39 +26,44 @@ import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.ml.common.controller.MLModelController; +import org.opensearch.ml.common.controller.MLController; import org.opensearch.ml.common.controller.MLRateLimiter; -public class MLModelControllerGetResponseTest { +public class MLControllerGetResponseTest { - private MLModelController modelController; + private MLController controller; - private MLModelControllerGetResponse response; + private MLControllerGetResponse response; @Before public void setUp() { MLRateLimiter rateLimiter = MLRateLimiter.builder() - .rateLimitNumber("1") - .rateLimitUnit(TimeUnit.MILLISECONDS) + .limit("1") + .unit(TimeUnit.MILLISECONDS) .build(); - modelController = MLModelController.builder() + controller = MLController.builder() .modelId("testModelId") - .userRateLimiterConfig(new HashMap<>() {{ - put("testUser", rateLimiter); - }}) + .userRateLimiter(new HashMap<>() { + { + put("testUser", rateLimiter); + } + }) .build(); - response = MLModelControllerGetResponse.builder().modelController(modelController).build(); + response = MLControllerGetResponse.builder().controller(controller).build(); } @Test public void writeToSuccess() throws IOException { BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); response.writeTo(bytesStreamOutput); - MLModelControllerGetResponse parsedResponse = new MLModelControllerGetResponse(bytesStreamOutput.bytes().streamInput()); - assertNotEquals(response.getModelController(), parsedResponse.getModelController()); - assertEquals(response.getModelController().getModelId(), parsedResponse.getModelController().getModelId()); - assertEquals(response.getModelController().getUserRateLimiterConfig().get("testUser").getRateLimitNumber(), parsedResponse.getModelController().getUserRateLimiterConfig().get("testUser").getRateLimitNumber()); - assertEquals(response.getModelController().getUserRateLimiterConfig().get("testUser").getRateLimitUnit(), parsedResponse.getModelController().getUserRateLimiterConfig().get("testUser").getRateLimitUnit()); + MLControllerGetResponse parsedResponse = new MLControllerGetResponse( + bytesStreamOutput.bytes().streamInput()); + assertNotEquals(response.getController(), parsedResponse.getController()); + assertEquals(response.getController().getModelId(), parsedResponse.getController().getModelId()); + assertEquals(response.getController().getUserRateLimiter().get("testUser").getLimit(), + parsedResponse.getController().getUserRateLimiter().get("testUser").getLimit()); + assertEquals(response.getController().getUserRateLimiter().get("testUser").getUnit(), + parsedResponse.getController().getUserRateLimiter().get("testUser").getUnit()); } @Test @@ -67,14 +72,17 @@ public void toXContentTest() throws IOException { response.toXContent(builder, ToXContent.EMPTY_PARAMS); assertNotNull(builder); String jsonStr = builder.toString(); - assertEquals("{\"model_id\":\"testModelId\",\"user_rate_limiter_config\":{\"testUser\":{\"rate_limit_number\":\"1\",\"rate_limit_unit\":\"MILLISECONDS\"}}}",jsonStr); + assertEquals( + "{\"model_id\":\"testModelId\",\"user_rate_limiter\":{\"testUser\":{\"limit\":\"1\",\"unit\":\"MILLISECONDS\"}}}", + jsonStr); } @Test - public void fromActionResponseWithMLModelControllerGetResponseSuccess() { - MLModelControllerGetResponse responseFromActionResponse = MLModelControllerGetResponse.fromActionResponse(response); + public void fromActionResponseWithMLControllerGetResponseSuccess() { + MLControllerGetResponse responseFromActionResponse = MLControllerGetResponse + .fromActionResponse(response); assertSame(response, responseFromActionResponse); - assertEquals(response.getModelController(), responseFromActionResponse.getModelController()); + assertEquals(response.getController(), responseFromActionResponse.getController()); } @Test @@ -85,9 +93,10 @@ public void writeTo(StreamOutput out) throws IOException { response.writeTo(out); } }; - MLModelControllerGetResponse responseFromActionResponse = MLModelControllerGetResponse.fromActionResponse(actionResponse); + MLControllerGetResponse responseFromActionResponse = MLControllerGetResponse + .fromActionResponse(actionResponse); assertNotSame(response, responseFromActionResponse); - assertNotEquals(response.getModelController(), responseFromActionResponse.getModelController()); + assertNotEquals(response.getController(), responseFromActionResponse.getController()); } @Test(expected = UncheckedIOException.class) @@ -98,6 +107,6 @@ public void writeTo(StreamOutput out) throws IOException { throw new IOException(); } }; - MLModelControllerGetResponse.fromActionResponse(actionResponse); + MLControllerGetResponse.fromActionResponse(actionResponse); } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/controller/MLCreateControllerRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLCreateControllerRequestTest.java new file mode 100644 index 0000000000..a8e19b58a7 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLCreateControllerRequestTest.java @@ -0,0 +1,128 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.controller; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.HashMap; +import java.util.concurrent.TimeUnit; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.ml.common.controller.MLController; +import org.opensearch.ml.common.controller.MLRateLimiter; + +public class MLCreateControllerRequestTest { + private MLController controllerInput; + + private MLCreateControllerRequest request; + + @Before + public void setUp() throws Exception { + + MLRateLimiter rateLimiter = MLRateLimiter.builder() + .limit("1") + .unit(TimeUnit.MILLISECONDS) + .build(); + controllerInput = MLController.builder() + .modelId("testModelId") + .userRateLimiter(new HashMap<>() { + { + put("testUser", rateLimiter); + } + }) + .build(); + request = MLCreateControllerRequest.builder() + .controllerInput(controllerInput) + .build(); + } + + @Test + public void writeToSuccess() throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + request.writeTo(bytesStreamOutput); + MLCreateControllerRequest parsedRequest = new MLCreateControllerRequest( + bytesStreamOutput.bytes().streamInput()); + assertEquals("testModelId", parsedRequest.getControllerInput().getModelId()); + assertTrue(parsedRequest.getControllerInput().getUserRateLimiter().containsKey("testUser")); + assertEquals("1", parsedRequest.getControllerInput().getUserRateLimiter().get("testUser") + .getLimit()); + assertEquals(TimeUnit.MILLISECONDS, + parsedRequest.getControllerInput().getUserRateLimiter().get("testUser").getUnit()); + } + + @Test + public void validateSuccess() { + assertNull(request.validate()); + } + + @Test + public void validateWithNullMLControllerInputException() { + MLCreateControllerRequest request = MLCreateControllerRequest.builder().build(); + ActionRequestValidationException exception = request.validate(); + assertEquals("Validation Failed: 1: Model controller input can't be null;", exception.getMessage()); + } + + @Test + public void validateWithNullMLModelID() { + controllerInput.setModelId(null); + MLCreateControllerRequest request = MLCreateControllerRequest.builder() + .controllerInput(controllerInput) + .build(); + + assertNull(request.validate()); + assertNull(request.getControllerInput().getModelId()); + } + + @Test + public void fromActionRequestWithMLCreateControllerRequestSuccess() { + assertSame(MLCreateControllerRequest.fromActionRequest(request), request); + } + + @Test + public void fromActionRequestWithNonMLCreateControllerRequestSuccess() { + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + request.writeTo(out); + } + }; + MLCreateControllerRequest result = MLCreateControllerRequest.fromActionRequest(actionRequest); + assertNotSame(result, request); + assertEquals(request.getControllerInput().getModelId(), result.getControllerInput().getModelId()); + } + + @Test(expected = UncheckedIOException.class) + public void fromActionRequestIOException() { + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IOException("test"); + } + }; + MLCreateControllerRequest.fromActionRequest(actionRequest); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/controller/MLCreateModelControllerResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLCreateControllerResponseTest.java similarity index 74% rename from common/src/test/java/org/opensearch/ml/common/transport/controller/MLCreateModelControllerResponseTest.java rename to common/src/test/java/org/opensearch/ml/common/transport/controller/MLCreateControllerResponseTest.java index 1c3f4160f2..6c11a667d1 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/controller/MLCreateModelControllerResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLCreateControllerResponseTest.java @@ -13,7 +13,6 @@ import java.io.IOException; import java.io.UncheckedIOException; - import org.junit.Before; import org.junit.Test; import org.opensearch.common.io.stream.BytesStreamOutput; @@ -24,28 +23,28 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.TestHelper; -public class MLCreateModelControllerResponseTest { +public class MLCreateControllerResponseTest { - private MLCreateModelControllerResponse response; + private MLCreateControllerResponse response; @Before public void setup() { - response = new MLCreateModelControllerResponse("testModelId", "Status"); + response = new MLCreateControllerResponse("testModelId", "Status"); } - @Test public void writeToSuccess() throws IOException { BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); response.writeTo(bytesStreamOutput); - MLCreateModelControllerResponse newResponse = new MLCreateModelControllerResponse(bytesStreamOutput.bytes().streamInput()); + MLCreateControllerResponse newResponse = new MLCreateControllerResponse( + bytesStreamOutput.bytes().streamInput()); assertEquals(response.getModelId(), newResponse.getModelId()); assertEquals(response.getStatus(), newResponse.getStatus()); } @Test public void testToXContent() throws IOException { - MLCreateModelControllerResponse response = new MLCreateModelControllerResponse("testModelId", "Status"); + MLCreateControllerResponse response = new MLCreateControllerResponse("testModelId", "Status"); XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); response.toXContent(builder, ToXContent.EMPTY_PARAMS); assertNotNull(builder); @@ -55,8 +54,8 @@ public void testToXContent() throws IOException { } @Test - public void fromActionResponseWithMLCreateModelControllerResponseSuccess() { - MLCreateModelControllerResponse responseFromActionResponse = MLCreateModelControllerResponse.fromActionResponse(response); + public void fromActionResponseWithMLCreateControllerResponseSuccess() { + MLCreateControllerResponse responseFromActionResponse = MLCreateControllerResponse.fromActionResponse(response); assertSame(response, responseFromActionResponse); assertEquals(response.getModelId(), responseFromActionResponse.getModelId()); } @@ -69,7 +68,8 @@ public void writeTo(StreamOutput out) throws IOException { response.writeTo(out); } }; - MLCreateModelControllerResponse responseFromActionResponse = MLCreateModelControllerResponse.fromActionResponse(actionResponse); + MLCreateControllerResponse responseFromActionResponse = MLCreateControllerResponse + .fromActionResponse(actionResponse); assertNotSame(response, responseFromActionResponse); assertEquals(response.getModelId(), responseFromActionResponse.getModelId()); } @@ -82,6 +82,6 @@ public void writeTo(StreamOutput out) throws IOException { throw new IOException(); } }; - MLCreateModelControllerResponse.fromActionResponse(actionResponse); + MLCreateControllerResponse.fromActionResponse(actionResponse); } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/controller/MLCreateModelControllerRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLCreateModelControllerRequestTest.java deleted file mode 100644 index b95eb49bd9..0000000000 --- a/common/src/test/java/org/opensearch/ml/common/transport/controller/MLCreateModelControllerRequestTest.java +++ /dev/null @@ -1,124 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.ml.common.transport.controller; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotSame; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertSame; -import static org.junit.Assert.assertTrue; - -import java.io.IOException; -import java.io.UncheckedIOException; -import java.util.HashMap; -import java.util.concurrent.TimeUnit; - -import org.junit.Before; -import org.junit.Test; -import org.opensearch.action.ActionRequest; -import org.opensearch.action.ActionRequestValidationException; -import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.ml.common.controller.MLModelController; -import org.opensearch.ml.common.controller.MLRateLimiter; - - -public class MLCreateModelControllerRequestTest { - private MLModelController modelControllerInput; - - private MLCreateModelControllerRequest request; - - @Before - public void setUp() throws Exception { - - MLRateLimiter rateLimiter = MLRateLimiter.builder() - .rateLimitNumber("1") - .rateLimitUnit(TimeUnit.MILLISECONDS) - .build(); - modelControllerInput = MLModelController.builder() - .modelId("testModelId") - .userRateLimiterConfig(new HashMap<>() {{ - put("testUser", rateLimiter); - }}) - .build(); - request = MLCreateModelControllerRequest.builder() - .modelControllerInput(modelControllerInput) - .build(); - } - - @Test - public void writeToSuccess() throws IOException { - BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); - request.writeTo(bytesStreamOutput); - MLCreateModelControllerRequest parsedRequest = new MLCreateModelControllerRequest(bytesStreamOutput.bytes().streamInput()); - assertEquals("testModelId", parsedRequest.getModelControllerInput().getModelId()); - assertTrue(parsedRequest.getModelControllerInput().getUserRateLimiterConfig().containsKey("testUser")); - assertEquals("1", parsedRequest.getModelControllerInput().getUserRateLimiterConfig().get("testUser").getRateLimitNumber()); - assertEquals(TimeUnit.MILLISECONDS, parsedRequest.getModelControllerInput().getUserRateLimiterConfig().get("testUser").getRateLimitUnit()); - } - - @Test - public void validateSuccess() { - assertNull(request.validate()); - } - - @Test - public void validateWithNullMLModelControllerInputException() { - MLCreateModelControllerRequest request = MLCreateModelControllerRequest.builder().build(); - ActionRequestValidationException exception = request.validate(); - assertEquals("Validation Failed: 1: Model controller input can't be null;", exception.getMessage()); - } - - @Test - public void validateWithNullMLModelID() { - modelControllerInput.setModelId(null); - MLCreateModelControllerRequest request = MLCreateModelControllerRequest.builder() - .modelControllerInput(modelControllerInput) - .build(); - - assertNull(request.validate()); - assertNull(request.getModelControllerInput().getModelId()); - } - - @Test - public void fromActionRequestWithMLCreateModelControllerRequestSuccess() { - assertSame(MLCreateModelControllerRequest.fromActionRequest(request), request); - } - - @Test - public void fromActionRequestWithNonMLCreateModelControllerRequestSuccess() { - ActionRequest actionRequest = new ActionRequest() { - @Override - public ActionRequestValidationException validate() { - return null; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - request.writeTo(out); - } - }; - MLCreateModelControllerRequest result = MLCreateModelControllerRequest.fromActionRequest(actionRequest); - assertNotSame(result, request); - assertEquals(request.getModelControllerInput().getModelId(), result.getModelControllerInput().getModelId()); - } - - @Test(expected = UncheckedIOException.class) - public void fromActionRequestIOException() { - ActionRequest actionRequest = new ActionRequest() { - @Override - public ActionRequestValidationException validate() { - return null; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - throw new IOException("test"); - } - }; - MLCreateModelControllerRequest.fromActionRequest(actionRequest); - } -} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/controller/MLDeployModelControllerNodeResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLDeployControllerNodeResponseTest.java similarity index 68% rename from common/src/test/java/org/opensearch/ml/common/transport/controller/MLDeployModelControllerNodeResponseTest.java rename to common/src/test/java/org/opensearch/ml/common/transport/controller/MLDeployControllerNodeResponseTest.java index 3a2a3104a7..e0d817f2d7 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/controller/MLDeployModelControllerNodeResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLDeployControllerNodeResponseTest.java @@ -28,7 +28,7 @@ import org.opensearch.core.common.transport.TransportAddress; @RunWith(MockitoJUnitRunner.class) -public class MLDeployModelControllerNodeResponseTest { +public class MLDeployControllerNodeResponseTest { @Mock private DiscoveryNode localNode; @@ -44,35 +44,35 @@ public void setUp() throws Exception { new TransportAddress(InetAddress.getLoopbackAddress(), 9300), Collections.emptyMap(), Collections.singleton(CLUSTER_MANAGER_ROLE), - Version.CURRENT - ); + Version.CURRENT); } @Test public void testSerializationDeserialization() throws IOException { - Map deployModelControllerStatus = Map.of("modelName:version", "response"); - MLDeployModelControllerNodeResponse response = new MLDeployModelControllerNodeResponse(localNode, deployModelControllerStatus); + Map deployControllerStatus = Map.of("modelName:version", "response"); + MLDeployControllerNodeResponse response = new MLDeployControllerNodeResponse(localNode, deployControllerStatus); BytesStreamOutput output = new BytesStreamOutput(); response.writeTo(output); - MLDeployModelControllerNodeResponse newResponse = new MLDeployModelControllerNodeResponse(output.bytes().streamInput()); + MLDeployControllerNodeResponse newResponse = new MLDeployControllerNodeResponse(output.bytes().streamInput()); assertEquals(newResponse.getNode().getId(), response.getNode().getId()); } @Test public void testSerializationDeserializationNullModelUpdateModelCacheStatus() throws IOException { - MLDeployModelControllerNodeResponse response = new MLDeployModelControllerNodeResponse(localNode, null); + MLDeployControllerNodeResponse response = new MLDeployControllerNodeResponse(localNode, null); BytesStreamOutput output = new BytesStreamOutput(); response.writeTo(output); - MLDeployModelControllerNodeResponse newResponse = new MLDeployModelControllerNodeResponse(output.bytes().streamInput()); + MLDeployControllerNodeResponse newResponse = new MLDeployControllerNodeResponse(output.bytes().streamInput()); assertEquals(newResponse.getNode().getId(), response.getNode().getId()); } @Test public void testReadProfile() throws IOException { - MLDeployModelControllerNodeResponse response = new MLDeployModelControllerNodeResponse(localNode, new HashMap<>()); + MLDeployControllerNodeResponse response = new MLDeployControllerNodeResponse(localNode, new HashMap<>()); BytesStreamOutput output = new BytesStreamOutput(); response.writeTo(output); - MLDeployModelControllerNodeResponse newResponse = MLDeployModelControllerNodeResponse.readStats(output.bytes().streamInput()); + MLDeployControllerNodeResponse newResponse = MLDeployControllerNodeResponse + .readStats(output.bytes().streamInput()); assertNotEquals(newResponse, response); } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/controller/MLDeployControllerNodesRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLDeployControllerNodesRequestTest.java new file mode 100644 index 0000000000..cb07cdbbc0 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLDeployControllerNodesRequestTest.java @@ -0,0 +1,77 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.controller; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; + +import java.io.IOException; + +import static org.junit.Assert.assertEquals; + +// This test combined MLDeployControllerNodesRequestTest and MLDeployControllerNodeRequestTest together. +@RunWith(MockitoJUnitRunner.class) +public class MLDeployControllerNodesRequestTest { + + @Mock + private DiscoveryNode localNode1; + + @Mock + private DiscoveryNode localNode2; + + private MLDeployControllerNodeRequest deployControllerNodeRequestWithStringNodeIds; + + private MLDeployControllerNodeRequest deployControllerNodeRequestWithDiscoveryNodeIds; + + @Before + public void setUp() throws Exception { + + String modelId = "testModelId"; + String[] stringNodeIds = { "nodeId1", "nodeId2", "nodeId3" }; + DiscoveryNode[] discoveryNodeIds = { localNode1, localNode2 }; + + deployControllerNodeRequestWithStringNodeIds = new MLDeployControllerNodeRequest( + new MLDeployControllerNodesRequest(stringNodeIds, modelId)); + deployControllerNodeRequestWithDiscoveryNodeIds = new MLDeployControllerNodeRequest( + new MLDeployControllerNodesRequest(discoveryNodeIds, modelId)); + + } + + @Test + public void testConstructorSerialization1() throws IOException { + BytesStreamOutput output = new BytesStreamOutput(); + deployControllerNodeRequestWithStringNodeIds.writeTo(output); + assertEquals("testModelId", + deployControllerNodeRequestWithStringNodeIds.getDeployControllerNodesRequest().getModelId()); + + } + + @Test + public void testConstructorSerialization2() { + assertEquals(2, deployControllerNodeRequestWithDiscoveryNodeIds.getDeployControllerNodesRequest() + .concreteNodes().length); + + } + + @Test + public void testConstructorFromInputStream() throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + deployControllerNodeRequestWithStringNodeIds.writeTo(bytesStreamOutput); + + StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); + MLDeployControllerNodeRequest parsedNodeRequest = new MLDeployControllerNodeRequest(streamInput); + + assertEquals(deployControllerNodeRequestWithStringNodeIds.getDeployControllerNodesRequest().getModelId(), + parsedNodeRequest.getDeployControllerNodesRequest().getModelId()); + + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/controller/MLDeployModelControllerNodesResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLDeployControllerNodesResponseTest.java similarity index 67% rename from common/src/test/java/org/opensearch/ml/common/transport/controller/MLDeployModelControllerNodesResponseTest.java rename to common/src/test/java/org/opensearch/ml/common/transport/controller/MLDeployControllerNodesResponseTest.java index 47620d015b..bcc3f0c38a 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/controller/MLDeployModelControllerNodesResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLDeployControllerNodesResponseTest.java @@ -32,7 +32,7 @@ import org.opensearch.core.xcontent.XContentBuilder; @RunWith(MockitoJUnitRunner.class) -public class MLDeployModelControllerNodesResponseTest { +public class MLDeployControllerNodesResponseTest { @Mock private ClusterName clusterName; @Mock @@ -49,59 +49,57 @@ public void setUp() throws Exception { new TransportAddress(InetAddress.getLoopbackAddress(), 9300), Collections.emptyMap(), Collections.singleton(CLUSTER_MANAGER_ROLE), - Version.CURRENT - ); + Version.CURRENT); node2 = new DiscoveryNode( "foo2", "foo2", new TransportAddress(InetAddress.getLoopbackAddress(), 9300), Collections.emptyMap(), Collections.singleton(CLUSTER_MANAGER_ROLE), - Version.CURRENT - ); + Version.CURRENT); } @Test public void testSerializationDeserialization1() throws IOException { - List responseList = new ArrayList<>(); + List responseList = new ArrayList<>(); List failuresList = new ArrayList<>(); - MLDeployModelControllerNodesResponse response = new MLDeployModelControllerNodesResponse(clusterName, responseList, failuresList); + MLDeployControllerNodesResponse response = new MLDeployControllerNodesResponse(clusterName, responseList, + failuresList); BytesStreamOutput output = new BytesStreamOutput(); response.writeTo(output); - MLDeployModelControllerNodesResponse newResponse = new MLDeployModelControllerNodesResponse(output.bytes().streamInput()); + MLDeployControllerNodesResponse newResponse = new MLDeployControllerNodesResponse(output.bytes().streamInput()); assertEquals(newResponse.getNodes().size(), response.getNodes().size()); } @Test public void testToXContent() throws IOException { - List nodes = new ArrayList<>(); + List nodes = new ArrayList<>(); - Map deployModelControllerStatus1 = Map.of("modelId1", "response"); - nodes.add(new MLDeployModelControllerNodeResponse(node1, deployModelControllerStatus1)); + Map deployControllerStatus1 = Map.of("modelId1", "response"); + nodes.add(new MLDeployControllerNodeResponse(node1, deployControllerStatus1)); - Map deployModelControllerStatus2 = Map.of("modelId2", "response"); - nodes.add(new MLDeployModelControllerNodeResponse(node2, deployModelControllerStatus2)); + Map deployControllerStatus2 = Map.of("modelId2", "response"); + nodes.add(new MLDeployControllerNodeResponse(node2, deployControllerStatus2)); List failures = new ArrayList<>(); - MLDeployModelControllerNodesResponse response = new MLDeployModelControllerNodesResponse(clusterName, nodes, failures); + MLDeployControllerNodesResponse response = new MLDeployControllerNodesResponse(clusterName, nodes, failures); XContentBuilder builder = XContentFactory.jsonBuilder(); response.toXContent(builder, ToXContent.EMPTY_PARAMS); String jsonStr = builder.toString(); assertEquals( "{\"foo1\":{\"stats\":{\"modelId1\":\"response\"}},\"foo2\":{\"stats\":{\"modelId2\":\"response\"}}}", - jsonStr - ); + jsonStr); } @Test public void testNullUpdateModelCacheStatusToXContent() throws IOException { - List nodes = new ArrayList<>(); - nodes.add(new MLDeployModelControllerNodeResponse(node1, null)); + List nodes = new ArrayList<>(); + nodes.add(new MLDeployControllerNodeResponse(node1, null)); List failures = new ArrayList<>(); - MLDeployModelControllerNodesResponse response = new MLDeployModelControllerNodesResponse(clusterName, nodes, failures); + MLDeployControllerNodesResponse response = new MLDeployControllerNodesResponse(clusterName, nodes, failures); XContentBuilder builder = XContentFactory.jsonBuilder(); response.toXContent(builder, ToXContent.EMPTY_PARAMS); String jsonStr = builder.toString(); - assertEquals("{}",jsonStr); + assertEquals("{}", jsonStr); } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/controller/MLDeployModelControllerNodesRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLDeployModelControllerNodesRequestTest.java deleted file mode 100644 index 7f30734f96..0000000000 --- a/common/src/test/java/org/opensearch/ml/common/transport/controller/MLDeployModelControllerNodesRequestTest.java +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.ml.common.transport.controller; - -import org.junit.Before; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.mockito.Mock; -import org.mockito.junit.MockitoJUnitRunner; -import org.opensearch.cluster.node.DiscoveryNode; -import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.core.common.io.stream.StreamInput; - -import java.io.IOException; - -import static org.junit.Assert.assertEquals; - -// This test combined MLDeployModelControllerNodesRequestTest and MLDeployModelControllerNodeRequestTest together. -@RunWith(MockitoJUnitRunner.class) -public class MLDeployModelControllerNodesRequestTest { - - @Mock - private DiscoveryNode localNode1; - - @Mock - private DiscoveryNode localNode2; - - private MLDeployModelControllerNodeRequest deployModelControllerNodeRequestWithStringNodeIds; - - private MLDeployModelControllerNodeRequest deployModelControllerNodeRequestWithDiscoveryNodeIds; - - @Before - public void setUp() throws Exception { - - String modelId = "testModelId"; - String[] stringNodeIds = {"nodeId1", "nodeId2", "nodeId3"}; - DiscoveryNode[] discoveryNodeIds = {localNode1, localNode2}; - - deployModelControllerNodeRequestWithStringNodeIds = new MLDeployModelControllerNodeRequest( - new MLDeployModelControllerNodesRequest(stringNodeIds, modelId) - ); - deployModelControllerNodeRequestWithDiscoveryNodeIds = new MLDeployModelControllerNodeRequest( - new MLDeployModelControllerNodesRequest(discoveryNodeIds, modelId) - ); - - } - - @Test - public void testConstructorSerialization1() throws IOException { - BytesStreamOutput output = new BytesStreamOutput(); - deployModelControllerNodeRequestWithStringNodeIds.writeTo(output); - assertEquals("testModelId", deployModelControllerNodeRequestWithStringNodeIds.getDeployModelControllerNodesRequest().getModelId()); - - } - - @Test - public void testConstructorSerialization2() { - assertEquals(2, deployModelControllerNodeRequestWithDiscoveryNodeIds.getDeployModelControllerNodesRequest().concreteNodes().length); - - } - - @Test - public void testConstructorFromInputStream() throws IOException { - BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); - deployModelControllerNodeRequestWithStringNodeIds.writeTo(bytesStreamOutput); - - StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); - MLDeployModelControllerNodeRequest parsedNodeRequest = new MLDeployModelControllerNodeRequest(streamInput); - - assertEquals(deployModelControllerNodeRequestWithStringNodeIds.getDeployModelControllerNodesRequest().getModelId(), - parsedNodeRequest.getDeployModelControllerNodesRequest().getModelId()); - - } -} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/controller/MLUndeployModelControllerNodeResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLUndeployControllerNodeResponseTest.java similarity index 69% rename from common/src/test/java/org/opensearch/ml/common/transport/controller/MLUndeployModelControllerNodeResponseTest.java rename to common/src/test/java/org/opensearch/ml/common/transport/controller/MLUndeployControllerNodeResponseTest.java index 5f0e045418..e1df438393 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/controller/MLUndeployModelControllerNodeResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLUndeployControllerNodeResponseTest.java @@ -31,8 +31,8 @@ import org.opensearch.core.xcontent.XContentBuilder; @RunWith(MockitoJUnitRunner.class) -public class MLUndeployModelControllerNodeResponseTest { - +public class MLUndeployControllerNodeResponseTest { + @Mock private DiscoveryNode localNode; @@ -47,35 +47,38 @@ public void setUp() throws Exception { new TransportAddress(InetAddress.getLoopbackAddress(), 9300), Collections.emptyMap(), Collections.singleton(CLUSTER_MANAGER_ROLE), - Version.CURRENT - ); + Version.CURRENT); } @Test public void testSerializationDeserialization() throws IOException { - Map undeployModelControllerStatus = Map.of("modelName:version", "response"); - MLUndeployModelControllerNodeResponse response = new MLUndeployModelControllerNodeResponse(localNode, undeployModelControllerStatus); + Map undeployControllerStatus = Map.of("modelName:version", "response"); + MLUndeployControllerNodeResponse response = new MLUndeployControllerNodeResponse(localNode, + undeployControllerStatus); BytesStreamOutput output = new BytesStreamOutput(); response.writeTo(output); - MLUndeployModelControllerNodeResponse newResponse = new MLUndeployModelControllerNodeResponse(output.bytes().streamInput()); + MLUndeployControllerNodeResponse newResponse = new MLUndeployControllerNodeResponse( + output.bytes().streamInput()); assertEquals(newResponse.getNode().getId(), response.getNode().getId()); } @Test public void testSerializationDeserializationNullModelUpdateModelCacheStatus() throws IOException { - MLUndeployModelControllerNodeResponse response = new MLUndeployModelControllerNodeResponse(localNode, null); + MLUndeployControllerNodeResponse response = new MLUndeployControllerNodeResponse(localNode, null); BytesStreamOutput output = new BytesStreamOutput(); response.writeTo(output); - MLUndeployModelControllerNodeResponse newResponse = new MLUndeployModelControllerNodeResponse(output.bytes().streamInput()); + MLUndeployControllerNodeResponse newResponse = new MLUndeployControllerNodeResponse( + output.bytes().streamInput()); assertEquals(newResponse.getNode().getId(), response.getNode().getId()); } @Test public void testReadProfile() throws IOException { - MLUndeployModelControllerNodeResponse response = new MLUndeployModelControllerNodeResponse(localNode, new HashMap<>()); + MLUndeployControllerNodeResponse response = new MLUndeployControllerNodeResponse(localNode, new HashMap<>()); BytesStreamOutput output = new BytesStreamOutput(); response.writeTo(output); - MLUndeployModelControllerNodeResponse newResponse = MLUndeployModelControllerNodeResponse.readStats(output.bytes().streamInput()); + MLUndeployControllerNodeResponse newResponse = MLUndeployControllerNodeResponse + .readStats(output.bytes().streamInput()); assertNotEquals(newResponse, response); } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/controller/MLUndeployControllerNodesRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLUndeployControllerNodesRequestTest.java new file mode 100644 index 0000000000..20d6d66fb5 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLUndeployControllerNodesRequestTest.java @@ -0,0 +1,76 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.controller; + +import static org.junit.Assert.assertEquals; +import java.io.IOException; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; + +// This test combined MLUndeployControllerNodesRequestTest and MLUndeployControllerNodeRequestTest together. +@RunWith(MockitoJUnitRunner.class) +public class MLUndeployControllerNodesRequestTest { + + @Mock + private DiscoveryNode localNode1; + + @Mock + private DiscoveryNode localNode2; + + private MLUndeployControllerNodeRequest undeployControllerNodeRequestWithStringNodeIds; + + private MLUndeployControllerNodeRequest undeployControllerNodeRequestWithDiscoveryNodeIds; + + @Before + public void setUp() throws Exception { + + String modelId = "testModelId"; + String[] stringNodeIds = { "nodeId1", "nodeId2", "nodeId3" }; + DiscoveryNode[] discoveryNodeIds = { localNode1, localNode2 }; + + undeployControllerNodeRequestWithStringNodeIds = new MLUndeployControllerNodeRequest( + new MLUndeployControllerNodesRequest(stringNodeIds, modelId)); + undeployControllerNodeRequestWithDiscoveryNodeIds = new MLUndeployControllerNodeRequest( + new MLUndeployControllerNodesRequest(discoveryNodeIds, modelId)); + + } + + @Test + public void testConstructorSerialization1() throws IOException { + BytesStreamOutput output = new BytesStreamOutput(); + undeployControllerNodeRequestWithStringNodeIds.writeTo(output); + assertEquals("testModelId", + undeployControllerNodeRequestWithStringNodeIds.getUndeployControllerNodesRequest().getModelId()); + + } + + @Test + public void testConstructorSerialization2() { + assertEquals(2, undeployControllerNodeRequestWithDiscoveryNodeIds.getUndeployControllerNodesRequest() + .concreteNodes().length); + + } + + @Test + public void testConstructorFromInputStream() throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + undeployControllerNodeRequestWithStringNodeIds.writeTo(bytesStreamOutput); + + StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); + MLUndeployControllerNodeRequest parsedNodeRequest = new MLUndeployControllerNodeRequest(streamInput); + + assertEquals(undeployControllerNodeRequestWithStringNodeIds.getUndeployControllerNodesRequest().getModelId(), + parsedNodeRequest.getUndeployControllerNodesRequest().getModelId()); + + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/controller/MLUndeployModelControllerNodesResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLUndeployControllerNodesResponseTest.java similarity index 67% rename from common/src/test/java/org/opensearch/ml/common/transport/controller/MLUndeployModelControllerNodesResponseTest.java rename to common/src/test/java/org/opensearch/ml/common/transport/controller/MLUndeployControllerNodesResponseTest.java index 77aa947fa6..c374b85312 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/controller/MLUndeployModelControllerNodesResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLUndeployControllerNodesResponseTest.java @@ -32,7 +32,7 @@ import org.opensearch.core.xcontent.XContentBuilder; @RunWith(MockitoJUnitRunner.class) -public class MLUndeployModelControllerNodesResponseTest { +public class MLUndeployControllerNodesResponseTest { @Mock private ClusterName clusterName; @Mock @@ -49,59 +49,60 @@ public void setUp() throws Exception { new TransportAddress(InetAddress.getLoopbackAddress(), 9300), Collections.emptyMap(), Collections.singleton(CLUSTER_MANAGER_ROLE), - Version.CURRENT - ); + Version.CURRENT); node2 = new DiscoveryNode( "foo2", "foo2", new TransportAddress(InetAddress.getLoopbackAddress(), 9300), Collections.emptyMap(), Collections.singleton(CLUSTER_MANAGER_ROLE), - Version.CURRENT - ); + Version.CURRENT); } @Test public void testSerializationDeserialization1() throws IOException { - List responseList = new ArrayList<>(); + List responseList = new ArrayList<>(); List failuresList = new ArrayList<>(); - MLUndeployModelControllerNodesResponse response = new MLUndeployModelControllerNodesResponse(clusterName, responseList, failuresList); + MLUndeployControllerNodesResponse response = new MLUndeployControllerNodesResponse(clusterName, responseList, + failuresList); BytesStreamOutput output = new BytesStreamOutput(); response.writeTo(output); - MLUndeployModelControllerNodesResponse newResponse = new MLUndeployModelControllerNodesResponse(output.bytes().streamInput()); + MLUndeployControllerNodesResponse newResponse = new MLUndeployControllerNodesResponse( + output.bytes().streamInput()); assertEquals(newResponse.getNodes().size(), response.getNodes().size()); } @Test public void testToXContent() throws IOException { - List nodes = new ArrayList<>(); + List nodes = new ArrayList<>(); - Map undeployModelControllerStatus1 = Map.of("modelId1", "response"); - nodes.add(new MLUndeployModelControllerNodeResponse(node1, undeployModelControllerStatus1)); + Map undeployControllerStatus1 = Map.of("modelId1", "response"); + nodes.add(new MLUndeployControllerNodeResponse(node1, undeployControllerStatus1)); - Map undeployModelControllerStatus2 = Map.of("modelId2", "response"); - nodes.add(new MLUndeployModelControllerNodeResponse(node2, undeployModelControllerStatus2)); + Map undeployControllerStatus2 = Map.of("modelId2", "response"); + nodes.add(new MLUndeployControllerNodeResponse(node2, undeployControllerStatus2)); List failures = new ArrayList<>(); - MLUndeployModelControllerNodesResponse response = new MLUndeployModelControllerNodesResponse(clusterName, nodes, failures); + MLUndeployControllerNodesResponse response = new MLUndeployControllerNodesResponse(clusterName, nodes, + failures); XContentBuilder builder = XContentFactory.jsonBuilder(); response.toXContent(builder, ToXContent.EMPTY_PARAMS); String jsonStr = builder.toString(); assertEquals( "{\"foo1\":{\"stats\":{\"modelId1\":\"response\"}},\"foo2\":{\"stats\":{\"modelId2\":\"response\"}}}", - jsonStr - ); + jsonStr); } @Test public void testNullUpdateModelCacheStatusToXContent() throws IOException { - List nodes = new ArrayList<>(); - nodes.add(new MLUndeployModelControllerNodeResponse(node1, null)); + List nodes = new ArrayList<>(); + nodes.add(new MLUndeployControllerNodeResponse(node1, null)); List failures = new ArrayList<>(); - MLUndeployModelControllerNodesResponse response = new MLUndeployModelControllerNodesResponse(clusterName, nodes, failures); + MLUndeployControllerNodesResponse response = new MLUndeployControllerNodesResponse(clusterName, nodes, + failures); XContentBuilder builder = XContentFactory.jsonBuilder(); response.toXContent(builder, ToXContent.EMPTY_PARAMS); String jsonStr = builder.toString(); - assertEquals("{}",jsonStr); + assertEquals("{}", jsonStr); } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/controller/MLUndeployModelControllerNodesRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLUndeployModelControllerNodesRequestTest.java deleted file mode 100644 index ad08f76408..0000000000 --- a/common/src/test/java/org/opensearch/ml/common/transport/controller/MLUndeployModelControllerNodesRequestTest.java +++ /dev/null @@ -1,76 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.ml.common.transport.controller; - -import static org.junit.Assert.assertEquals; -import java.io.IOException; - -import org.junit.Before; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.mockito.Mock; -import org.mockito.junit.MockitoJUnitRunner; -import org.opensearch.cluster.node.DiscoveryNode; -import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.core.common.io.stream.StreamInput; - -// This test combined MLUndeployModelControllerNodesRequestTest and MLUndeployModelControllerNodeRequestTest together. -@RunWith(MockitoJUnitRunner.class) -public class MLUndeployModelControllerNodesRequestTest { - - @Mock - private DiscoveryNode localNode1; - - @Mock - private DiscoveryNode localNode2; - - private MLUndeployModelControllerNodeRequest undeployModelControllerNodeRequestWithStringNodeIds; - - private MLUndeployModelControllerNodeRequest undeployModelControllerNodeRequestWithDiscoveryNodeIds; - - @Before - public void setUp() throws Exception { - - String modelId = "testModelId"; - String[] stringNodeIds = {"nodeId1", "nodeId2", "nodeId3"}; - DiscoveryNode[] discoveryNodeIds = {localNode1, localNode2}; - - undeployModelControllerNodeRequestWithStringNodeIds = new MLUndeployModelControllerNodeRequest( - new MLUndeployModelControllerNodesRequest(stringNodeIds, modelId) - ); - undeployModelControllerNodeRequestWithDiscoveryNodeIds = new MLUndeployModelControllerNodeRequest( - new MLUndeployModelControllerNodesRequest(discoveryNodeIds, modelId) - ); - - } - - @Test - public void testConstructorSerialization1() throws IOException { - BytesStreamOutput output = new BytesStreamOutput(); - undeployModelControllerNodeRequestWithStringNodeIds.writeTo(output); - assertEquals("testModelId", undeployModelControllerNodeRequestWithStringNodeIds.getUndeployModelControllerNodesRequest().getModelId()); - - } - - @Test - public void testConstructorSerialization2() { - assertEquals(2, undeployModelControllerNodeRequestWithDiscoveryNodeIds.getUndeployModelControllerNodesRequest().concreteNodes().length); - - } - - @Test - public void testConstructorFromInputStream() throws IOException { - BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); - undeployModelControllerNodeRequestWithStringNodeIds.writeTo(bytesStreamOutput); - - StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); - MLUndeployModelControllerNodeRequest parsedNodeRequest = new MLUndeployModelControllerNodeRequest(streamInput); - - assertEquals(undeployModelControllerNodeRequestWithStringNodeIds.getUndeployModelControllerNodesRequest().getModelId(), - parsedNodeRequest.getUndeployModelControllerNodesRequest().getModelId()); - - } -} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/controller/MLUpdateControllerRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLUpdateControllerRequestTest.java new file mode 100644 index 0000000000..73a40047d2 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLUpdateControllerRequestTest.java @@ -0,0 +1,129 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.controller; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.HashMap; +import java.util.concurrent.TimeUnit; + +import org.junit.Before; +import org.junit.Test; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.ml.common.controller.MLController; +import org.opensearch.ml.common.controller.MLRateLimiter; + +public class MLUpdateControllerRequestTest { + private MLController updateControllerInput; + + private MLUpdateControllerRequest request; + + @Before + public void setUp() throws Exception { + + MLRateLimiter rateLimiter = MLRateLimiter.builder() + .limit("1") + .unit(TimeUnit.MILLISECONDS) + .build(); + updateControllerInput = MLController.builder() + .modelId("testModelId") + .userRateLimiter(new HashMap<>() { + { + put("testUser", rateLimiter); + } + }) + .build(); + request = MLUpdateControllerRequest.builder() + .updateControllerInput(updateControllerInput) + .build(); + } + + @Test + public void writeToSuccess() throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + request.writeTo(bytesStreamOutput); + MLUpdateControllerRequest parsedRequest = new MLUpdateControllerRequest( + bytesStreamOutput.bytes().streamInput()); + assertEquals("testModelId", parsedRequest.getUpdateControllerInput().getModelId()); + assertTrue(parsedRequest.getUpdateControllerInput().getUserRateLimiter().containsKey("testUser")); + assertEquals("1", parsedRequest.getUpdateControllerInput().getUserRateLimiter().get("testUser") + .getLimit()); + assertEquals(TimeUnit.MILLISECONDS, parsedRequest.getUpdateControllerInput().getUserRateLimiter() + .get("testUser").getUnit()); + } + + @Test + public void validateSuccess() { + assertNull(request.validate()); + } + + @Test + public void validateWithNullMLControllerInputException() { + MLUpdateControllerRequest request = MLUpdateControllerRequest.builder().build(); + ActionRequestValidationException exception = request.validate(); + assertEquals("Validation Failed: 1: Update model controller input can't be null;", exception.getMessage()); + } + + @Test + public void validateWithNullMLModelID() { + updateControllerInput.setModelId(null); + MLUpdateControllerRequest request = MLUpdateControllerRequest.builder() + .updateControllerInput(updateControllerInput) + .build(); + + assertNull(request.validate()); + assertNull(request.getUpdateControllerInput().getModelId()); + } + + @Test + public void fromActionRequestWithMLUpdateControllerRequestSuccess() { + assertSame(MLUpdateControllerRequest.fromActionRequest(request), request); + } + + @Test + public void fromActionRequestWithNonMLUpdateControllerRequestSuccess() { + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + request.writeTo(out); + } + }; + MLUpdateControllerRequest result = MLUpdateControllerRequest.fromActionRequest(actionRequest); + assertNotSame(result, request); + assertEquals(request.getUpdateControllerInput().getModelId(), + result.getUpdateControllerInput().getModelId()); + } + + @Test(expected = UncheckedIOException.class) + public void fromActionRequestIOException() { + ActionRequest actionRequest = new ActionRequest() { + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new IOException("test"); + } + }; + MLUpdateControllerRequest.fromActionRequest(actionRequest); + } +} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/controller/MLUpdateModelControllerRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/controller/MLUpdateModelControllerRequestTest.java deleted file mode 100644 index e452cfccc9..0000000000 --- a/common/src/test/java/org/opensearch/ml/common/transport/controller/MLUpdateModelControllerRequestTest.java +++ /dev/null @@ -1,124 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - - -package org.opensearch.ml.common.transport.controller; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotSame; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertSame; -import static org.junit.Assert.assertTrue; - -import java.io.IOException; -import java.io.UncheckedIOException; -import java.util.HashMap; -import java.util.concurrent.TimeUnit; - -import org.junit.Before; -import org.junit.Test; -import org.opensearch.action.ActionRequest; -import org.opensearch.action.ActionRequestValidationException; -import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.ml.common.controller.MLModelController; -import org.opensearch.ml.common.controller.MLRateLimiter; - -public class MLUpdateModelControllerRequestTest { - private MLModelController updateModelControllerInput; - - private MLUpdateModelControllerRequest request; - - @Before - public void setUp() throws Exception { - - MLRateLimiter rateLimiter = MLRateLimiter.builder() - .rateLimitNumber("1") - .rateLimitUnit(TimeUnit.MILLISECONDS) - .build(); - updateModelControllerInput = MLModelController.builder() - .modelId("testModelId") - .userRateLimiterConfig(new HashMap<>() {{ - put("testUser", rateLimiter); - }}) - .build(); - request = MLUpdateModelControllerRequest.builder() - .updateModelControllerInput(updateModelControllerInput) - .build(); - } - - @Test - public void writeToSuccess() throws IOException { - BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); - request.writeTo(bytesStreamOutput); - MLUpdateModelControllerRequest parsedRequest = new MLUpdateModelControllerRequest(bytesStreamOutput.bytes().streamInput()); - assertEquals("testModelId", parsedRequest.getUpdateModelControllerInput().getModelId()); - assertTrue(parsedRequest.getUpdateModelControllerInput().getUserRateLimiterConfig().containsKey("testUser")); - assertEquals("1", parsedRequest.getUpdateModelControllerInput().getUserRateLimiterConfig().get("testUser").getRateLimitNumber()); - assertEquals(TimeUnit.MILLISECONDS, parsedRequest.getUpdateModelControllerInput().getUserRateLimiterConfig().get("testUser").getRateLimitUnit()); - } - - @Test - public void validateSuccess() { - assertNull(request.validate()); - } - - @Test - public void validateWithNullMLModelControllerInputException() { - MLUpdateModelControllerRequest request = MLUpdateModelControllerRequest.builder().build(); - ActionRequestValidationException exception = request.validate(); - assertEquals("Validation Failed: 1: Update model controller input can't be null;", exception.getMessage()); - } - - @Test - public void validateWithNullMLModelID() { - updateModelControllerInput.setModelId(null); - MLUpdateModelControllerRequest request = MLUpdateModelControllerRequest.builder() - .updateModelControllerInput(updateModelControllerInput) - .build(); - - assertNull(request.validate()); - assertNull(request.getUpdateModelControllerInput().getModelId()); - } - - @Test - public void fromActionRequestWithMLUpdateModelControllerRequestSuccess() { - assertSame(MLUpdateModelControllerRequest.fromActionRequest(request), request); - } - - @Test - public void fromActionRequestWithNonMLUpdateModelControllerRequestSuccess() { - ActionRequest actionRequest = new ActionRequest() { - @Override - public ActionRequestValidationException validate() { - return null; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - request.writeTo(out); - } - }; - MLUpdateModelControllerRequest result = MLUpdateModelControllerRequest.fromActionRequest(actionRequest); - assertNotSame(result, request); - assertEquals(request.getUpdateModelControllerInput().getModelId(), result.getUpdateModelControllerInput().getModelId()); - } - - @Test(expected = UncheckedIOException.class) - public void fromActionRequestIOException() { - ActionRequest actionRequest = new ActionRequest() { - @Override - public ActionRequestValidationException validate() { - return null; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - throw new IOException("test"); - } - }; - MLUpdateModelControllerRequest.fromActionRequest(actionRequest); - } -} diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelInputTest.java index e283647f38..a53f1ee02d 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/model/MLUpdateModelInputTest.java @@ -44,21 +44,29 @@ public class MLUpdateModelInputTest { private MLUpdateModelInput updateModelInput; - private final String expectedInputStr = "{\"model_id\":\"test-model_id\",\"name\":\"name\",\"description\":\"description\",\"model_version\":" + - "\"2\",\"model_group_id\":\"modelGroupId\",\"is_enabled\":false,\"model_rate_limiter_config\":" + - "{\"rate_limit_number\":\"1\",\"rate_limit_unit\":\"MILLISECONDS\"},\"model_config\":" + - "{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"" + + private final String expectedInputStr = "{\"model_id\":\"test-model_id\",\"name\":\"name\",\"description\":\"description\",\"model_version\":" + + + "\"2\",\"model_group_id\":\"modelGroupId\",\"is_enabled\":false,\"rate_limiter\":" + + "{\"limit\":\"1\",\"unit\":\"MILLISECONDS\"},\"model_config\":" + + "{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"" + + "{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"updated_connector\":" + - "{\"name\":\"test\",\"version\":\"1\",\"protocol\":\"http\",\"parameters\":{\"param1\":\"value1\"},\"credential\":" + - "{\"api_key\":\"credential_value\"},\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":" + - "\"https://api.openai.com/v1/chat/completions\",\"headers\":{\"Authorization\":\"Bearer ${credential.api_key}\"},\"request_body\":" + - "\"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"messages\\\": ${parameters.messages} }\"}]},\"connector_id\":" + + "{\"name\":\"test\",\"version\":\"1\",\"protocol\":\"http\",\"parameters\":{\"param1\":\"value1\"},\"credential\":" + + + "{\"api_key\":\"credential_value\"},\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":" + + + "\"https://api.openai.com/v1/chat/completions\",\"headers\":{\"Authorization\":\"Bearer ${credential.api_key}\"},\"request_body\":" + + + "\"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"messages\\\": ${parameters.messages} }\"}]},\"connector_id\":" + + "\"test-connector_id\",\"connector\":{\"description\":\"updated description\",\"version\":\"1\"},\"last_updated_time\":1}"; - private final String expectedOutputStr = "{\"model_id\":null,\"name\":\"name\",\"description\":\"description\",\"model_group_id\":" + - "\"modelGroupId\",\"is_enabled\":false,\"model_rate_limiter_config\":" + - "{\"rate_limit_number\":\"1\",\"rate_limit_unit\":\"MILLISECONDS\"},\"model_config\":" + - "{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"" + + private final String expectedOutputStr = "{\"model_id\":null,\"name\":\"name\",\"description\":\"description\",\"model_group_id\":" + + + "\"modelGroupId\",\"is_enabled\":false,\"rate_limiter\":" + + "{\"limit\":\"1\",\"unit\":\"MILLISECONDS\"},\"model_config\":" + + "{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"" + + "{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"connector_id\":" + "\"test-connector_id\",\"connector\":{\"description\":\"updated description\",\"version\":\"1\",\"parameters\":{},\"credential\":{}}}"; @@ -90,10 +98,9 @@ public void setUp() throws Exception { .method("POST") .url("https://api.openai.com/v1/chat/completions") .headers(Map.of("Authorization", "Bearer ${credential.api_key}")) - .requestBody("{ \"model\": \"${parameters.model}\", \"messages\": ${parameters.messages} }") - .build() - ) - ) + .requestBody( + "{ \"model\": \"${parameters.model}\", \"messages\": ${parameters.messages} }") + .build())) .build(); MLCreateConnectorInput updateContent = MLCreateConnectorInput @@ -104,8 +111,8 @@ public void setUp() throws Exception { .build(); MLRateLimiter rateLimiter = MLRateLimiter.builder() - .rateLimitNumber("1") - .rateLimitUnit(TimeUnit.MILLISECONDS) + .limit("1") + .unit(TimeUnit.MILLISECONDS) .build(); updateModelInput = MLUpdateModelInput.builder() @@ -115,14 +122,14 @@ public void setUp() throws Exception { .name("name") .description("description") .isEnabled(false) - .modelRateLimiterConfig(rateLimiter) + .rateLimiter(rateLimiter) .modelConfig(config) .updatedConnector(updatedConnector) .connectorId("test-connector_id") .connector(updateContent) .lastUpdateTime(Instant.ofEpochMilli(1)) .build(); - } + } @Test public void readInputStreamSuccess() throws IOException { @@ -148,8 +155,7 @@ public void testToXContent() throws Exception { @Test public void testToXContentIncomplete() throws Exception { - String expectedIncompleteInputStr = - "{\"model_id\":\"test-model_id\"}"; + String expectedIncompleteInputStr = "{\"model_id\":\"test-model_id\"}"; updateModelInput = MLUpdateModelInput.builder() .modelId("test-model_id").build(); String jsonStr = serializationWithToXContent(updateModelInput); @@ -166,10 +172,12 @@ public void parseSuccess() throws Exception { @Test public void parseWithNullFieldWithoutModel() throws Exception { exceptionRule.expect(IllegalStateException.class); - String expectedInputStrWithNullField = "{\"model_id\":\"test-model_id\",\"name\":null,\"description\":\"description\",\"model_version\":" + - "\"2\",\"model_group_id\":\"modelGroupId\",\"is_enabled\":false,\"model_rate_limiter_config\":" + - "{\"rate_limit_number\":\"1\",\"rate_limit_unit\":\"MILLISECONDS\"},\"model_config\":" + - "{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"" + + String expectedInputStrWithNullField = "{\"model_id\":\"test-model_id\",\"name\":null,\"description\":\"description\",\"model_version\":" + + + "\"2\",\"model_group_id\":\"modelGroupId\",\"is_enabled\":false,\"rate_limiter\":" + + "{\"limit\":\"1\",\"unit\":\"MILLISECONDS\"},\"model_config\":" + + "{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"" + + "{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"connector_id\":\"test-connector_id\"}"; testParseFromJsonString(expectedInputStrWithNullField, parsedInput -> { try { @@ -182,15 +190,21 @@ public void parseWithNullFieldWithoutModel() throws Exception { @Test public void parseWithIllegalFieldWithoutModel() throws Exception { - String expectedInputStrWithIllegalField = "{\"model_id\":\"test-model_id\",\"name\":\"name\",\"description\":\"description\",\"model_version\":" + - "\"2\",\"model_group_id\":\"modelGroupId\",\"is_enabled\":false,\"model_rate_limiter_config\":" + - "{\"rate_limit_number\":\"1\",\"rate_limit_unit\":\"MILLISECONDS\"},\"model_config\":" + - "{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"" + + String expectedInputStrWithIllegalField = "{\"model_id\":\"test-model_id\",\"name\":\"name\",\"description\":\"description\",\"model_version\":" + + + "\"2\",\"model_group_id\":\"modelGroupId\",\"is_enabled\":false,\"rate_limiter\":" + + "{\"limit\":\"1\",\"unit\":\"MILLISECONDS\"},\"model_config\":" + + "{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"" + + "{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"},\"updated_connector\":" + - "{\"name\":\"test\",\"version\":\"1\",\"protocol\":\"http\",\"parameters\":{\"param1\":\"value1\"},\"credential\":" + - "{\"api_key\":\"credential_value\"},\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":" + - "\"https://api.openai.com/v1/chat/completions\",\"headers\":{\"Authorization\":\"Bearer ${credential.api_key}\"},\"request_body\":" + - "\"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"messages\\\": ${parameters.messages} }\"}]},\"connector_id\":" + + "{\"name\":\"test\",\"version\":\"1\",\"protocol\":\"http\",\"parameters\":{\"param1\":\"value1\"},\"credential\":" + + + "{\"api_key\":\"credential_value\"},\"actions\":[{\"action_type\":\"PREDICT\",\"method\":\"POST\",\"url\":" + + + "\"https://api.openai.com/v1/chat/completions\",\"headers\":{\"Authorization\":\"Bearer ${credential.api_key}\"},\"request_body\":" + + + "\"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"messages\\\": ${parameters.messages} }\"}]},\"connector_id\":" + + "\"test-connector_id\",\"connector\":{\"description\":\"updated description\",\"version\":\"1\"},\"last_updated_time\":1,\"illegal_field\":\"This field need to be skipped.\"}"; testParseFromJsonString(expectedInputStrWithIllegalField, parsedInput -> { try { diff --git a/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLModelGroupDeleteRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLModelGroupDeleteRequestTest.java index 6c25f36dfe..627985813e 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLModelGroupDeleteRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/model_group/MLModelGroupDeleteRequestTest.java @@ -33,7 +33,8 @@ public void setUp() { public void writeToSuccess() throws IOException { BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); request.writeTo(bytesStreamOutput); - MLModelGroupDeleteRequest parsedRequest = new MLModelGroupDeleteRequest(bytesStreamOutput.bytes().streamInput()); + MLModelGroupDeleteRequest parsedRequest = new MLModelGroupDeleteRequest( + bytesStreamOutput.bytes().streamInput()); assertEquals(parsedRequest.getModelGroupId(), modelGroupId); } @@ -51,7 +52,7 @@ public void validateWithNullModelIdException() { } @Test - public void fromActionRequestWithMLUpdateModelControllerRequestSuccess() { + public void fromActionRequestWithMLUpdateControllerRequestSuccess() { assertSame(MLModelGroupDeleteRequest.fromActionRequest(request), request); } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java index 178228992e..0e8169ac64 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java @@ -54,7 +54,7 @@ public class AwsConnectorExecutor implements RemoteConnectorExecutor { private ScriptService scriptService; @Setter @Getter - private TokenBucket modelRateLimiter; + private TokenBucket rateLimiter; @Setter @Getter private Map userRateLimiterMap; diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java index d881707195..cc3670e5d1 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java @@ -53,7 +53,7 @@ public class HttpJsonConnectorExecutor implements RemoteConnectorExecutor { @Setter @Getter - private TokenBucket modelRateLimiter; + private TokenBucket rateLimiter; @Setter @Getter private Map userRateLimiterMap; diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java index 34575b7ce8..be50af3aff 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java @@ -82,7 +82,7 @@ default void setScriptService(ScriptService scriptService) {} Connector getConnector(); - TokenBucket getModelRateLimiter(); + TokenBucket getRateLimiter(); Map getUserRateLimiterMap(); @@ -94,7 +94,7 @@ default void setXContentRegistry(NamedXContentRegistry xContentRegistry) {} default void setClusterService(ClusterService clusterService) {} - default void setModelRateLimiter(TokenBucket modelRateLimiter) {} + default void setRateLimiter(TokenBucket rateLimiter) {} default void setUserRateLimiterMap(Map userRateLimiterMap) {} @@ -121,7 +121,7 @@ default void preparePayloadAndInvokeRemoteModel(MLInput mlInput, List params, Encryptor encry this.connectorExecutor.setClusterService((ClusterService) params.get(CLUSTER_SERVICE)); this.connectorExecutor.setClient((Client) params.get(CLIENT)); this.connectorExecutor.setXContentRegistry((NamedXContentRegistry) params.get(XCONTENT_REGISTRY)); - this.connectorExecutor.setModelRateLimiter((TokenBucket) params.get(MODEL_RATE_LIMITER)); + this.connectorExecutor.setRateLimiter((TokenBucket) params.get(RATE_LIMITER)); this.connectorExecutor.setUserRateLimiterMap((Map) params.get(USER_RATE_LIMITER_MAP)); } catch (RuntimeException e) { log.error("Failed to init remote model.", e); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndex.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndex.java index 26fabba0c2..0cc329f1ac 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndex.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndex.java @@ -14,15 +14,15 @@ import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX_MAPPING; import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_SCHEMA_VERSION; +import static org.opensearch.ml.common.CommonValue.ML_CONTROLLER_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_CONTROLLER_INDEX_MAPPING; +import static org.opensearch.ml.common.CommonValue.ML_CONTROLLER_INDEX_SCHEMA_VERSION; import static org.opensearch.ml.common.CommonValue.ML_MEMORY_MESSAGE_INDEX; import static org.opensearch.ml.common.CommonValue.ML_MEMORY_MESSAGE_INDEX_MAPPING; import static org.opensearch.ml.common.CommonValue.ML_MEMORY_MESSAGE_INDEX_SCHEMA_VERSION; import static org.opensearch.ml.common.CommonValue.ML_MEMORY_META_INDEX; import static org.opensearch.ml.common.CommonValue.ML_MEMORY_META_INDEX_MAPPING; import static org.opensearch.ml.common.CommonValue.ML_MEMORY_META_INDEX_SCHEMA_VERSION; -import static org.opensearch.ml.common.CommonValue.ML_MODEL_CONTROLLER_INDEX; -import static org.opensearch.ml.common.CommonValue.ML_MODEL_CONTROLLER_INDEX_MAPPING; -import static org.opensearch.ml.common.CommonValue.ML_MODEL_CONTROLLER_INDEX_SCHEMA_VERSION; import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX_MAPPING; import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX_SCHEMA_VERSION; @@ -39,7 +39,7 @@ public enum MLIndex { TASK(ML_TASK_INDEX, false, ML_TASK_INDEX_MAPPING, ML_TASK_INDEX_SCHEMA_VERSION), CONNECTOR(ML_CONNECTOR_INDEX, false, ML_CONNECTOR_INDEX_MAPPING, ML_CONNECTOR_SCHEMA_VERSION), CONFIG(ML_CONFIG_INDEX, false, ML_CONFIG_INDEX_MAPPING, ML_CONFIG_INDEX_SCHEMA_VERSION), - MODEL_CONTROLLER(ML_MODEL_CONTROLLER_INDEX, false, ML_MODEL_CONTROLLER_INDEX_MAPPING, ML_MODEL_CONTROLLER_INDEX_SCHEMA_VERSION), + CONTROLLER(ML_CONTROLLER_INDEX, false, ML_CONTROLLER_INDEX_MAPPING, ML_CONTROLLER_INDEX_SCHEMA_VERSION), AGENT(ML_AGENT_INDEX, false, ML_AGENT_INDEX_MAPPING, ML_AGENT_INDEX_SCHEMA_VERSION), MEMORY_META(ML_MEMORY_META_INDEX, false, ML_MEMORY_META_INDEX_MAPPING, ML_MEMORY_META_INDEX_SCHEMA_VERSION), MEMORY_MESSAGE(ML_MEMORY_MESSAGE_INDEX, false, ML_MEMORY_MESSAGE_INDEX_MAPPING, ML_MEMORY_MESSAGE_INDEX_SCHEMA_VERSION); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndicesHandler.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndicesHandler.java index 46135e5496..34b06212a7 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndicesHandler.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndicesHandler.java @@ -76,8 +76,8 @@ public void initMLConfigIndex(ActionListener listener) { initMLIndexIfAbsent(MLIndex.CONFIG, listener); } - public void initMLModelControllerIndex(ActionListener listener) { - initMLIndexIfAbsent(MLIndex.MODEL_CONTROLLER, listener); + public void initMLControllerIndex(ActionListener listener) { + initMLIndexIfAbsent(MLIndex.CONTROLLER, listener); } public void initMLAgentIndex(ActionListener listener) { @@ -148,7 +148,8 @@ public void initMLIndexIfAbsent(MLIndex index, ActionListener listener) }) ); } else { - // no need to update index if it does not exist or the version is already up-to-date. + // no need to update index if it does not exist or the version is already + // up-to-date. indexMappingUpdated.get(indexName).set(true); internalListener.onResponse(true); } @@ -169,9 +170,11 @@ public void initMLIndexIfAbsent(MLIndex index, ActionListener listener) /** * Check if we should update index based on schema version. - * @param indexName index name + * + * @param indexName index name * @param newVersion new index mapping version - * @param listener action listener, if should update index, will pass true to its onResponse method + * @param listener action listener, if should update index, will pass true to + * its onResponse method */ public void shouldUpdateIndex(String indexName, Integer newVersion, ActionListener listener) { IndexMetadata indexMetaData = clusterService.state().getMetadata().indices().get(indexName); diff --git a/plugin/src/main/java/org/opensearch/ml/action/controller/CreateModelControllerTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/controller/CreateControllerTransportAction.java similarity index 64% rename from plugin/src/main/java/org/opensearch/ml/action/controller/CreateModelControllerTransportAction.java rename to plugin/src/main/java/org/opensearch/ml/action/controller/CreateControllerTransportAction.java index ece13bcd3b..9a9c97b23c 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/controller/CreateModelControllerTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/controller/CreateControllerTransportAction.java @@ -5,7 +5,7 @@ package org.opensearch.ml.action.controller; -import static org.opensearch.ml.common.CommonValue.ML_MODEL_CONTROLLER_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_CONTROLLER_INDEX; import static org.opensearch.ml.common.FunctionName.REMOTE; import static org.opensearch.ml.common.FunctionName.TEXT_EMBEDDING; @@ -15,6 +15,7 @@ import java.util.List; import java.util.Map; +import org.apache.commons.lang3.ArrayUtils; import org.opensearch.OpenSearchStatusException; import org.opensearch.action.ActionRequest; import org.opensearch.action.DocWriteResponse; @@ -37,14 +38,14 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; -import org.opensearch.ml.common.controller.MLModelController; +import org.opensearch.ml.common.controller.MLController; import org.opensearch.ml.common.model.MLModelState; -import org.opensearch.ml.common.transport.controller.MLCreateModelControllerAction; -import org.opensearch.ml.common.transport.controller.MLCreateModelControllerRequest; -import org.opensearch.ml.common.transport.controller.MLCreateModelControllerResponse; -import org.opensearch.ml.common.transport.controller.MLDeployModelControllerAction; -import org.opensearch.ml.common.transport.controller.MLDeployModelControllerNodesRequest; -import org.opensearch.ml.common.transport.controller.MLDeployModelControllerNodesResponse; +import org.opensearch.ml.common.transport.controller.MLCreateControllerAction; +import org.opensearch.ml.common.transport.controller.MLCreateControllerRequest; +import org.opensearch.ml.common.transport.controller.MLCreateControllerResponse; +import org.opensearch.ml.common.transport.controller.MLDeployControllerAction; +import org.opensearch.ml.common.transport.controller.MLDeployControllerNodesRequest; +import org.opensearch.ml.common.transport.controller.MLDeployControllerNodesResponse; import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.model.MLModelCacheHelper; @@ -59,7 +60,7 @@ @Log4j2 @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) -public class CreateModelControllerTransportAction extends HandledTransportAction { +public class CreateControllerTransportAction extends HandledTransportAction { MLIndicesHandler mlIndicesHandler; Client client; MLModelManager mlModelManager; @@ -68,7 +69,7 @@ public class CreateModelControllerTransportAction extends HandledTransportAction ModelAccessControlHelper modelAccessControlHelper; @Inject - public CreateModelControllerTransportAction( + public CreateControllerTransportAction( TransportService transportService, ActionFilters actionFilters, MLIndicesHandler mlIndicesHandler, @@ -78,7 +79,7 @@ public CreateModelControllerTransportAction( MLModelCacheHelper mlModelCacheHelper, MLModelManager mlModelManager ) { - super(MLCreateModelControllerAction.NAME, transportService, actionFilters, MLCreateModelControllerRequest::new); + super(MLCreateControllerAction.NAME, transportService, actionFilters, MLCreateControllerRequest::new); this.mlIndicesHandler = mlIndicesHandler; this.client = client; this.mlModelManager = mlModelManager; @@ -88,15 +89,15 @@ public CreateModelControllerTransportAction( } @Override - protected void doExecute(Task task, ActionRequest request, ActionListener actionListener) { - MLCreateModelControllerRequest createModelControllerRequest = MLCreateModelControllerRequest.fromActionRequest(request); - MLModelController modelController = createModelControllerRequest.getModelControllerInput(); - String modelId = modelController.getModelId(); + protected void doExecute(Task task, ActionRequest request, ActionListener actionListener) { + MLCreateControllerRequest createControllerRequest = MLCreateControllerRequest.fromActionRequest(request); + MLController controller = createControllerRequest.getControllerInput(); + String modelId = controller.getModelId(); User user = RestActionUtils.getUserContext(client); String[] excludes = new String[] { MLModel.MODEL_CONTENT_FIELD, MLModel.OLD_MODEL_CONTENT_FIELD }; try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - ActionListener wrappedListener = ActionListener.runBefore(actionListener, context::restore); + ActionListener wrappedListener = ActionListener.runBefore(actionListener, context::restore); mlModelManager.getModel(modelId, null, excludes, ActionListener.wrap(mlModel -> { FunctionName functionName = mlModel.getAlgorithm(); if (functionName == TEXT_EMBEDDING || functionName == REMOTE) { @@ -104,7 +105,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { if (hasPermission) { if (mlModel.getModelState() != MLModelState.DEPLOYING) { - indexAndCreateModelController(mlModel, modelController, wrappedListener); + indexAndCreateController(mlModel, controller, wrappedListener); } else { wrappedListener .onFailure( @@ -166,13 +167,12 @@ protected void doExecute(Task task, ActionRequest request, ActionListener actionListener + MLController controller, + ActionListener actionListener ) { - log.info("Indexing the model controller into system index"); - mlIndicesHandler.initMLModelControllerIndex(ActionListener.wrap(indexCreated -> { + mlIndicesHandler.initMLControllerIndex(ActionListener.wrap(indexCreated -> { if (!indexCreated) { actionListener.onFailure(new RuntimeException("Failed to create model controller index.")); return; @@ -180,61 +180,53 @@ private void indexAndCreateModelController( try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { ActionListener indexResponseListener = ActionListener.wrap(indexResponse -> { String modelId = indexResponse.getId(); - MLCreateModelControllerResponse response = new MLCreateModelControllerResponse( - modelId, - indexResponse.getResult().name() - ); + MLCreateControllerResponse response = new MLCreateControllerResponse(modelId, indexResponse.getResult().name()); log.info("Model controller for model id {} saved into index, result:{}", modelId, indexResponse.getResult()); if (indexResponse.getResult() == DocWriteResponse.Result.CREATED) { - mlModelManager.updateModel(modelId, Map.of(MLModel.IS_MODEL_CONTROLLER_ENABLED_FIELD, true)); + mlModelManager.updateModel(modelId, Map.of(MLModel.IS_CONTROLLER_ENABLED_FIELD, true)); } - if (mlModelCacheHelper.isModelDeployed(modelId)) { + if (!ArrayUtils.isEmpty(mlModelCacheHelper.getWorkerNodes(modelId))) { log.info("Model {} is deployed. Start to deploy the model controller into cache.", modelId); String[] targetNodeIds = mlModelManager.getWorkerNodes(modelId, mlModel.getAlgorithm()); - MLDeployModelControllerNodesRequest deployModelControllerNodesRequest = new MLDeployModelControllerNodesRequest( + MLDeployControllerNodesRequest deployControllerNodesRequest = new MLDeployControllerNodesRequest( targetNodeIds, - modelController.getModelId() + controller.getModelId() ); client - .execute( - MLDeployModelControllerAction.INSTANCE, - deployModelControllerNodesRequest, - ActionListener.wrap(nodesResponse -> { - if (nodesResponse != null && isDeployModelControllerSuccessOnAllNodes(nodesResponse)) { - log.info("Successfully create model controller and deploy it into cache with model ID {}", modelId); - actionListener.onResponse(response); - } else { - String[] nodeIds = getDeployModelControllerFailedNodesList(nodesResponse); - log - .error( - "Successfully create model controller index with model ID {} but deploy model controller to cache was failed on following nodes {}, please retry.", - modelId, - Arrays.toString(nodeIds) - ); - actionListener - .onFailure( - new RuntimeException( - "Successfully create model controller index with model ID " - + modelId - + " but deploy model controller to cache was failed on following nodes " - + Arrays.toString(nodeIds) - + ", please retry." - ) - ); - } - }, e -> { - log.error("Failed to deploy model controller for model: {}" + modelId, e); - actionListener.onFailure(e); - }) - ); + .execute(MLDeployControllerAction.INSTANCE, deployControllerNodesRequest, ActionListener.wrap(nodesResponse -> { + if (nodesResponse != null && isDeployControllerSuccessOnAllNodes(nodesResponse)) { + log.info("Successfully create model controller and deploy it into cache with model ID {}", modelId); + actionListener.onResponse(response); + } else { + String[] nodeIds = getDeployControllerFailedNodesList(nodesResponse); + log + .error( + "Successfully create model controller index with model ID {} but deploy model controller to cache was failed on following nodes {}, please retry.", + modelId, + Arrays.toString(nodeIds) + ); + actionListener + .onFailure( + new RuntimeException( + "Successfully create model controller index with model ID " + + modelId + + " but deploy model controller to cache was failed on following nodes " + + Arrays.toString(nodeIds) + + ", please retry." + ) + ); + } + }, e -> { + log.error("Failed to deploy model controller for model: {}" + modelId, e); + actionListener.onFailure(e); + })); } else { actionListener.onResponse(response); } }, actionListener::onFailure); - IndexRequest indexRequest = new IndexRequest(ML_MODEL_CONTROLLER_INDEX).id(modelController.getModelId()); - indexRequest - .source(modelController.toXContent(XContentBuilder.builder(XContentType.JSON.xContent()), ToXContent.EMPTY_PARAMS)); + IndexRequest indexRequest = new IndexRequest(ML_CONTROLLER_INDEX).id(controller.getModelId()); + indexRequest.source(controller.toXContent(XContentBuilder.builder(XContentType.JSON.xContent()), ToXContent.EMPTY_PARAMS)); indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); client.index(indexRequest, ActionListener.runBefore(indexResponseListener, context::restore)); } catch (Exception e) { @@ -247,16 +239,16 @@ private void indexAndCreateModelController( })); } - private boolean isDeployModelControllerSuccessOnAllNodes(MLDeployModelControllerNodesResponse deployModelControllerNodesResponse) { - return deployModelControllerNodesResponse.failures() == null || deployModelControllerNodesResponse.failures().isEmpty(); + private boolean isDeployControllerSuccessOnAllNodes(MLDeployControllerNodesResponse deployControllerNodesResponse) { + return deployControllerNodesResponse.failures() == null || deployControllerNodesResponse.failures().isEmpty(); } - private String[] getDeployModelControllerFailedNodesList(MLDeployModelControllerNodesResponse deployModelControllerNodesResponse) { - if (deployModelControllerNodesResponse == null) { + private String[] getDeployControllerFailedNodesList(MLDeployControllerNodesResponse deployControllerNodesResponse) { + if (deployControllerNodesResponse == null) { return getAllNodes(); } else { List nodeIds = new ArrayList<>(); - for (FailedNodeException failedNodeException : deployModelControllerNodesResponse.failures()) { + for (FailedNodeException failedNodeException : deployControllerNodesResponse.failures()) { nodeIds.add(failedNodeException.nodeId()); } return nodeIds.toArray(new String[0]); diff --git a/plugin/src/main/java/org/opensearch/ml/action/controller/DeleteModelControllerTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/controller/DeleteControllerTransportAction.java similarity index 75% rename from plugin/src/main/java/org/opensearch/ml/action/controller/DeleteModelControllerTransportAction.java rename to plugin/src/main/java/org/opensearch/ml/action/controller/DeleteControllerTransportAction.java index 97c002e3f2..80c9b7568a 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/controller/DeleteModelControllerTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/controller/DeleteControllerTransportAction.java @@ -5,7 +5,7 @@ package org.opensearch.ml.action.controller; -import static org.opensearch.ml.common.CommonValue.ML_MODEL_CONTROLLER_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_CONTROLLER_INDEX; import java.util.ArrayList; import java.util.Arrays; @@ -13,6 +13,7 @@ import java.util.List; import java.util.Map; +import org.apache.commons.lang3.ArrayUtils; import org.opensearch.OpenSearchStatusException; import org.opensearch.action.ActionRequest; import org.opensearch.action.FailedNodeException; @@ -30,11 +31,11 @@ import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ml.common.MLModel; -import org.opensearch.ml.common.transport.controller.MLModelControllerDeleteAction; -import org.opensearch.ml.common.transport.controller.MLModelControllerDeleteRequest; -import org.opensearch.ml.common.transport.controller.MLUndeployModelControllerAction; -import org.opensearch.ml.common.transport.controller.MLUndeployModelControllerNodesRequest; -import org.opensearch.ml.common.transport.controller.MLUndeployModelControllerNodesResponse; +import org.opensearch.ml.common.transport.controller.MLControllerDeleteAction; +import org.opensearch.ml.common.transport.controller.MLControllerDeleteRequest; +import org.opensearch.ml.common.transport.controller.MLUndeployControllerAction; +import org.opensearch.ml.common.transport.controller.MLUndeployControllerNodesRequest; +import org.opensearch.ml.common.transport.controller.MLUndeployControllerNodesResponse; import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.model.MLModelCacheHelper; import org.opensearch.ml.model.MLModelManager; @@ -48,7 +49,7 @@ @Log4j2 @FieldDefaults(level = AccessLevel.PRIVATE) -public class DeleteModelControllerTransportAction extends HandledTransportAction { +public class DeleteControllerTransportAction extends HandledTransportAction { Client client; NamedXContentRegistry xContentRegistry; ClusterService clusterService; @@ -57,7 +58,7 @@ public class DeleteModelControllerTransportAction extends HandledTransportAction ModelAccessControlHelper modelAccessControlHelper; @Inject - public DeleteModelControllerTransportAction( + public DeleteControllerTransportAction( TransportService transportService, ActionFilters actionFilters, Client client, @@ -67,7 +68,7 @@ public DeleteModelControllerTransportAction( MLModelCacheHelper mlModelCacheHelper, ModelAccessControlHelper modelAccessControlHelper ) { - super(MLModelControllerDeleteAction.NAME, transportService, actionFilters, MLModelControllerDeleteRequest::new); + super(MLControllerDeleteAction.NAME, transportService, actionFilters, MLControllerDeleteRequest::new); this.client = client; this.xContentRegistry = xContentRegistry; this.clusterService = clusterService; @@ -78,8 +79,8 @@ public DeleteModelControllerTransportAction( @Override protected void doExecute(Task task, ActionRequest request, ActionListener actionListener) { - MLModelControllerDeleteRequest modelControllerDeleteRequest = MLModelControllerDeleteRequest.fromActionRequest(request); - String modelId = modelControllerDeleteRequest.getModelId(); + MLControllerDeleteRequest controllerDeleteRequest = MLControllerDeleteRequest.fromActionRequest(request); + String modelId = controllerDeleteRequest.getModelId(); User user = RestActionUtils.getUserContext(client); String[] excludes = new String[] { MLModel.MODEL_CONTENT_FIELD, MLModel.OLD_MODEL_CONTENT_FIELD }; try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { @@ -89,11 +90,11 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { if (hasPermission) { mlModelManager - .getModelController( + .getController( modelId, ActionListener .wrap( - modelController -> deleteModelControllerWithDeployedModel(modelId, wrappedListener), + controller -> deleteControllerWithDeployedModel(modelId, wrappedListener), deleteException -> { log.error(deleteException); wrappedListener.onFailure(deleteException); @@ -126,13 +127,12 @@ protected void doExecute(Task task, ActionRequest request, ActionListener deleteModelControllerWithDeployedModel(modelId, wrappedListener), deleteException -> { - log.error(deleteException); - wrappedListener.onFailure(deleteException); - }) + ActionListener.wrap(controller -> deleteControllerWithDeployedModel(modelId, wrappedListener), deleteException -> { + log.error(deleteException); + wrappedListener.onFailure(deleteException); + }) ); })); } catch (Exception e) { @@ -141,30 +141,31 @@ protected void doExecute(Task task, ActionRequest request, ActionListener actionListener) { + // This method is used to handle the condition if we need to undeploy the model + // controller before deleting it from the index or not. + private void deleteControllerWithDeployedModel(String modelId, ActionListener actionListener) { try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - if (mlModelCacheHelper.isModelDeployed(modelId)) { + if (!ArrayUtils.isEmpty(mlModelCacheHelper.getWorkerNodes(modelId))) { log.info("Model has already been deployed in ML cache, need undeploy model controller before sending delete request."); String[] targetNodeIds = getAllNodes(); - MLUndeployModelControllerNodesRequest undeployModelControllerNodesRequest = new MLUndeployModelControllerNodesRequest( + MLUndeployControllerNodesRequest undeployControllerNodesRequest = new MLUndeployControllerNodesRequest( targetNodeIds, modelId ); client .execute( - MLUndeployModelControllerAction.INSTANCE, - undeployModelControllerNodesRequest, + MLUndeployControllerAction.INSTANCE, + undeployControllerNodesRequest, ActionListener.runBefore(ActionListener.wrap(nodesResponse -> { - if (nodesResponse != null && isUndeployModelControllerSuccessOnAllNodes(nodesResponse)) { + if (nodesResponse != null && isUndeployControllerSuccessOnAllNodes(nodesResponse)) { log .info( "Successfully undeploy model controller from cache. Start to delete the model controller for model {}", modelId ); - deleteModelController(modelId, actionListener); + deleteController(modelId, actionListener); } else { - String[] nodeIds = getUndeployModelControllerFailedNodesList(nodesResponse); + String[] nodeIds = getUndeployControllerFailedNodesList(nodesResponse); log .error( "Failed to undeploy model controller with model ID {} on following nodes {}, deletion is aborted. Please retry or undeploy the model manually and then perform the deletion.", @@ -193,7 +194,7 @@ private void deleteModelControllerWithDeployedModel(String modelId, ActionListen }), context::restore) ); } else { - deleteModelController(modelId, actionListener); + deleteController(modelId, actionListener); } } catch (Exception e) { log.error("Failed to delete model controller", e); @@ -201,13 +202,13 @@ private void deleteModelControllerWithDeployedModel(String modelId, ActionListen } } - private void deleteModelController(String modelId, ActionListener actionListener) { - DeleteRequest deleteRequest = new DeleteRequest(ML_MODEL_CONTROLLER_INDEX, modelId); + private void deleteController(String modelId, ActionListener actionListener) { + DeleteRequest deleteRequest = new DeleteRequest(ML_CONTROLLER_INDEX, modelId); client.delete(deleteRequest, new ActionListener<>() { @Override public void onResponse(DeleteResponse deleteResponse) { log.info("Model controller for model {} successfully deleted from index, result: {}", modelId, deleteResponse.getResult()); - mlModelManager.updateModel(modelId, Map.of(MLModel.IS_MODEL_CONTROLLER_ENABLED_FIELD, false)); + mlModelManager.updateModel(modelId, Map.of(MLModel.IS_CONTROLLER_ENABLED_FIELD, false)); actionListener.onResponse(deleteResponse); } @@ -219,20 +220,16 @@ public void onFailure(Exception e) { }); } - private boolean isUndeployModelControllerSuccessOnAllNodes( - MLUndeployModelControllerNodesResponse undeployModelControllerNodesResponse - ) { - return undeployModelControllerNodesResponse.failures() == null || undeployModelControllerNodesResponse.failures().isEmpty(); + private boolean isUndeployControllerSuccessOnAllNodes(MLUndeployControllerNodesResponse undeployControllerNodesResponse) { + return undeployControllerNodesResponse.failures() == null || undeployControllerNodesResponse.failures().isEmpty(); } - private String[] getUndeployModelControllerFailedNodesList( - MLUndeployModelControllerNodesResponse undeployModelControllerNodesResponse - ) { - if (undeployModelControllerNodesResponse == null) { + private String[] getUndeployControllerFailedNodesList(MLUndeployControllerNodesResponse undeployControllerNodesResponse) { + if (undeployControllerNodesResponse == null) { return getAllNodes(); } else { List nodeIds = new ArrayList<>(); - for (FailedNodeException failedNodeException : undeployModelControllerNodesResponse.failures()) { + for (FailedNodeException failedNodeException : undeployControllerNodesResponse.failures()) { nodeIds.add(failedNodeException.nodeId()); } return nodeIds.toArray(new String[0]); diff --git a/plugin/src/main/java/org/opensearch/ml/action/controller/DeployModelControllerTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/controller/DeployControllerTransportAction.java similarity index 53% rename from plugin/src/main/java/org/opensearch/ml/action/controller/DeployModelControllerTransportAction.java rename to plugin/src/main/java/org/opensearch/ml/action/controller/DeployControllerTransportAction.java index 9744cdc917..37ef462c19 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/controller/DeployModelControllerTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/controller/DeployControllerTransportAction.java @@ -20,11 +20,11 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ml.cluster.DiscoveryNodeHelper; -import org.opensearch.ml.common.transport.controller.MLDeployModelControllerAction; -import org.opensearch.ml.common.transport.controller.MLDeployModelControllerNodeRequest; -import org.opensearch.ml.common.transport.controller.MLDeployModelControllerNodeResponse; -import org.opensearch.ml.common.transport.controller.MLDeployModelControllerNodesRequest; -import org.opensearch.ml.common.transport.controller.MLDeployModelControllerNodesResponse; +import org.opensearch.ml.common.transport.controller.MLDeployControllerAction; +import org.opensearch.ml.common.transport.controller.MLDeployControllerNodeRequest; +import org.opensearch.ml.common.transport.controller.MLDeployControllerNodeResponse; +import org.opensearch.ml.common.transport.controller.MLDeployControllerNodesRequest; +import org.opensearch.ml.common.transport.controller.MLDeployControllerNodesResponse; import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.model.MLModelManager; import org.opensearch.ml.stats.MLStats; @@ -34,8 +34,8 @@ import lombok.extern.log4j.Log4j2; @Log4j2 -public class DeployModelControllerTransportAction extends - TransportNodesAction { +public class DeployControllerTransportAction extends + TransportNodesAction { private final MLModelManager mlModelManager; private final ClusterService clusterService; @@ -47,7 +47,7 @@ public class DeployModelControllerTransportAction extends private ModelAccessControlHelper modelAccessControlHelper; @Inject - public DeployModelControllerTransportAction( + public DeployControllerTransportAction( TransportService transportService, ActionFilters actionFilters, MLModelManager mlModelManager, @@ -60,15 +60,15 @@ public DeployModelControllerTransportAction( ModelAccessControlHelper modelAccessControlHelper ) { super( - MLDeployModelControllerAction.NAME, + MLDeployControllerAction.NAME, threadPool, clusterService, transportService, actionFilters, - MLDeployModelControllerNodesRequest::new, - MLDeployModelControllerNodeRequest::new, + MLDeployControllerNodesRequest::new, + MLDeployControllerNodeRequest::new, ThreadPool.Names.MANAGEMENT, - MLDeployModelControllerNodeResponse.class + MLDeployControllerNodeResponse.class ); this.mlModelManager = mlModelManager; this.clusterService = clusterService; @@ -80,42 +80,40 @@ public DeployModelControllerTransportAction( } @Override - protected MLDeployModelControllerNodesResponse newResponse( - MLDeployModelControllerNodesRequest request, - List responses, + protected MLDeployControllerNodesResponse newResponse( + MLDeployControllerNodesRequest request, + List responses, List failures ) { - return new MLDeployModelControllerNodesResponse(clusterService.getClusterName(), responses, failures); + return new MLDeployControllerNodesResponse(clusterService.getClusterName(), responses, failures); } @Override - protected MLDeployModelControllerNodeRequest newNodeRequest(MLDeployModelControllerNodesRequest request) { - return new MLDeployModelControllerNodeRequest(request); + protected MLDeployControllerNodeRequest newNodeRequest(MLDeployControllerNodesRequest request) { + return new MLDeployControllerNodeRequest(request); } @Override - protected MLDeployModelControllerNodeResponse newNodeResponse(StreamInput in) throws IOException { - return new MLDeployModelControllerNodeResponse(in); + protected MLDeployControllerNodeResponse newNodeResponse(StreamInput in) throws IOException { + return new MLDeployControllerNodeResponse(in); } @Override - protected MLDeployModelControllerNodeResponse nodeOperation(MLDeployModelControllerNodeRequest request) { - return createDeployModelControllerNodeResponse(request.getDeployModelControllerNodesRequest()); + protected MLDeployControllerNodeResponse nodeOperation(MLDeployControllerNodeRequest request) { + return createDeployControllerNodeResponse(request.getDeployControllerNodesRequest()); } - private MLDeployModelControllerNodeResponse createDeployModelControllerNodeResponse( - MLDeployModelControllerNodesRequest deployModelControllerNodesRequest - ) { - String modelId = deployModelControllerNodesRequest.getModelId(); + private MLDeployControllerNodeResponse createDeployControllerNodeResponse(MLDeployControllerNodesRequest deployControllerNodesRequest) { + String modelId = deployControllerNodesRequest.getModelId(); - Map modelControllerDeployStatus = new HashMap<>(); - modelControllerDeployStatus.put(modelId, "received"); + Map controllerDeployStatus = new HashMap<>(); + controllerDeployStatus.put(modelId, "received"); String localNodeId = clusterService.localNode().getId(); - mlModelManager.deployModelControllerWithDeployedModel(modelId, ActionListener.wrap(r -> { + mlModelManager.deployControllerWithDeployedModel(modelId, ActionListener.wrap(r -> { log.info("Successfully deployed model controller for model {} on node {}", modelId, localNodeId); }, e -> { log.error("Failed to deploy model controller for model {} on node {}", modelId, localNodeId, e); })); - return new MLDeployModelControllerNodeResponse(clusterService.localNode(), modelControllerDeployStatus); + return new MLDeployControllerNodeResponse(clusterService.localNode(), controllerDeployStatus); } } diff --git a/plugin/src/main/java/org/opensearch/ml/action/controller/GetModelControllerTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/controller/GetControllerTransportAction.java similarity index 81% rename from plugin/src/main/java/org/opensearch/ml/action/controller/GetModelControllerTransportAction.java rename to plugin/src/main/java/org/opensearch/ml/action/controller/GetControllerTransportAction.java index ddbf69b0f6..6cc7ae9e59 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/controller/GetModelControllerTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/controller/GetControllerTransportAction.java @@ -6,7 +6,7 @@ package org.opensearch.ml.action.controller; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.ml.common.CommonValue.ML_MODEL_CONTROLLER_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_CONTROLLER_INDEX; import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry; import static org.opensearch.ml.utils.RestActionUtils.getFetchSourceContext; @@ -26,10 +26,10 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.IndexNotFoundException; import org.opensearch.ml.common.MLModel; -import org.opensearch.ml.common.controller.MLModelController; -import org.opensearch.ml.common.transport.controller.MLModelControllerGetAction; -import org.opensearch.ml.common.transport.controller.MLModelControllerGetRequest; -import org.opensearch.ml.common.transport.controller.MLModelControllerGetResponse; +import org.opensearch.ml.common.controller.MLController; +import org.opensearch.ml.common.transport.controller.MLControllerGetAction; +import org.opensearch.ml.common.transport.controller.MLControllerGetRequest; +import org.opensearch.ml.common.transport.controller.MLControllerGetResponse; import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.model.MLModelManager; import org.opensearch.ml.utils.RestActionUtils; @@ -43,7 +43,7 @@ @Log4j2 @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) -public class GetModelControllerTransportAction extends HandledTransportAction { +public class GetControllerTransportAction extends HandledTransportAction { Client client; NamedXContentRegistry xContentRegistry; ClusterService clusterService; @@ -51,7 +51,7 @@ public class GetModelControllerTransportAction extends HandledTransportAction actionListener) { - MLModelControllerGetRequest modelControllerGetRequest = MLModelControllerGetRequest.fromActionRequest(request); - String modelId = modelControllerGetRequest.getModelId(); - FetchSourceContext fetchSourceContext = getFetchSourceContext(modelControllerGetRequest.isReturnContent()); - GetRequest getRequest = new GetRequest(ML_MODEL_CONTROLLER_INDEX).id(modelId).fetchSourceContext(fetchSourceContext); + protected void doExecute(Task task, ActionRequest request, ActionListener actionListener) { + MLControllerGetRequest controllerGetRequest = MLControllerGetRequest.fromActionRequest(request); + String modelId = controllerGetRequest.getModelId(); + FetchSourceContext fetchSourceContext = getFetchSourceContext(controllerGetRequest.isReturnContent()); + GetRequest getRequest = new GetRequest(ML_CONTROLLER_INDEX).id(modelId).fetchSourceContext(fetchSourceContext); User user = RestActionUtils.getUserContext(client); String[] excludes = new String[] { MLModel.MODEL_CONTENT_FIELD, MLModel.OLD_MODEL_CONTENT_FIELD }; try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - ActionListener wrappedListener = ActionListener.runBefore(actionListener, context::restore); + ActionListener wrappedListener = ActionListener.runBefore(actionListener, context::restore); client.get(getRequest, ActionListener.wrap(r -> { if (r != null && r.isExists()) { try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) { ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - MLModelController modelController = MLModelController.parse(parser); + MLController controller = MLController.parse(parser); mlModelManager.getModel(modelId, null, excludes, ActionListener.wrap(mlModel -> { modelAccessControlHelper .validateModelGroupAccess(user, mlModel.getModelGroupId(), client, ActionListener.wrap(hasPermission -> { if (hasPermission) { - wrappedListener - .onResponse(MLModelControllerGetResponse.builder().modelController(modelController).build()); + wrappedListener.onResponse(MLControllerGetResponse.builder().controller(controller).build()); } else { wrappedListener .onFailure( diff --git a/plugin/src/main/java/org/opensearch/ml/action/controller/UndeployModelControllerTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/controller/UndeployControllerTransportAction.java similarity index 52% rename from plugin/src/main/java/org/opensearch/ml/action/controller/UndeployModelControllerTransportAction.java rename to plugin/src/main/java/org/opensearch/ml/action/controller/UndeployControllerTransportAction.java index eb8bc04d8c..db175e0888 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/controller/UndeployModelControllerTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/controller/UndeployControllerTransportAction.java @@ -20,11 +20,11 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ml.cluster.DiscoveryNodeHelper; -import org.opensearch.ml.common.transport.controller.MLUndeployModelControllerAction; -import org.opensearch.ml.common.transport.controller.MLUndeployModelControllerNodeRequest; -import org.opensearch.ml.common.transport.controller.MLUndeployModelControllerNodeResponse; -import org.opensearch.ml.common.transport.controller.MLUndeployModelControllerNodesRequest; -import org.opensearch.ml.common.transport.controller.MLUndeployModelControllerNodesResponse; +import org.opensearch.ml.common.transport.controller.MLUndeployControllerAction; +import org.opensearch.ml.common.transport.controller.MLUndeployControllerNodeRequest; +import org.opensearch.ml.common.transport.controller.MLUndeployControllerNodeResponse; +import org.opensearch.ml.common.transport.controller.MLUndeployControllerNodesRequest; +import org.opensearch.ml.common.transport.controller.MLUndeployControllerNodesResponse; import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.model.MLModelManager; import org.opensearch.ml.stats.MLStats; @@ -34,8 +34,8 @@ import lombok.extern.log4j.Log4j2; @Log4j2 -public class UndeployModelControllerTransportAction extends - TransportNodesAction { +public class UndeployControllerTransportAction extends + TransportNodesAction { private final MLModelManager mlModelManager; private final ClusterService clusterService; @@ -47,7 +47,7 @@ public class UndeployModelControllerTransportAction extends private ModelAccessControlHelper modelAccessControlHelper; @Inject - public UndeployModelControllerTransportAction( + public UndeployControllerTransportAction( TransportService transportService, ActionFilters actionFilters, MLModelManager mlModelManager, @@ -60,15 +60,15 @@ public UndeployModelControllerTransportAction( ModelAccessControlHelper modelAccessControlHelper ) { super( - MLUndeployModelControllerAction.NAME, + MLUndeployControllerAction.NAME, threadPool, clusterService, transportService, actionFilters, - MLUndeployModelControllerNodesRequest::new, - MLUndeployModelControllerNodeRequest::new, + MLUndeployControllerNodesRequest::new, + MLUndeployControllerNodeRequest::new, ThreadPool.Names.MANAGEMENT, - MLUndeployModelControllerNodeResponse.class + MLUndeployControllerNodeResponse.class ); this.mlModelManager = mlModelManager; this.clusterService = clusterService; @@ -80,42 +80,42 @@ public UndeployModelControllerTransportAction( } @Override - protected MLUndeployModelControllerNodesResponse newResponse( - MLUndeployModelControllerNodesRequest request, - List responses, + protected MLUndeployControllerNodesResponse newResponse( + MLUndeployControllerNodesRequest request, + List responses, List failures ) { - return new MLUndeployModelControllerNodesResponse(clusterService.getClusterName(), responses, failures); + return new MLUndeployControllerNodesResponse(clusterService.getClusterName(), responses, failures); } @Override - protected MLUndeployModelControllerNodeRequest newNodeRequest(MLUndeployModelControllerNodesRequest request) { - return new MLUndeployModelControllerNodeRequest(request); + protected MLUndeployControllerNodeRequest newNodeRequest(MLUndeployControllerNodesRequest request) { + return new MLUndeployControllerNodeRequest(request); } @Override - protected MLUndeployModelControllerNodeResponse newNodeResponse(StreamInput in) throws IOException { - return new MLUndeployModelControllerNodeResponse(in); + protected MLUndeployControllerNodeResponse newNodeResponse(StreamInput in) throws IOException { + return new MLUndeployControllerNodeResponse(in); } @Override - protected MLUndeployModelControllerNodeResponse nodeOperation(MLUndeployModelControllerNodeRequest request) { - return createUndeployModelControllerNodeResponse(request.getUndeployModelControllerNodesRequest()); + protected MLUndeployControllerNodeResponse nodeOperation(MLUndeployControllerNodeRequest request) { + return createUndeployControllerNodeResponse(request.getUndeployControllerNodesRequest()); } - private MLUndeployModelControllerNodeResponse createUndeployModelControllerNodeResponse( - MLUndeployModelControllerNodesRequest undeployModelControllerNodesRequest + private MLUndeployControllerNodeResponse createUndeployControllerNodeResponse( + MLUndeployControllerNodesRequest undeployControllerNodesRequest ) { - String modelId = undeployModelControllerNodesRequest.getModelId(); + String modelId = undeployControllerNodesRequest.getModelId(); - Map modelControllerUndeployStatus = new HashMap<>(); - modelControllerUndeployStatus.put(modelId, "received"); + Map controllerUndeployStatus = new HashMap<>(); + controllerUndeployStatus.put(modelId, "received"); String localNodeId = clusterService.localNode().getId(); - mlModelManager.undeployModelController(modelId, ActionListener.wrap(r -> { + mlModelManager.undeployController(modelId, ActionListener.wrap(r -> { log.info("Successfully undeployed model controller for model {} on node {}", modelId, localNodeId); }, e -> { log.error("Failed to undeploy model controller for model {} on node {}", modelId, localNodeId, e); })); - return new MLUndeployModelControllerNodeResponse(clusterService.localNode(), modelControllerUndeployStatus); + return new MLUndeployControllerNodeResponse(clusterService.localNode(), controllerUndeployStatus); } } diff --git a/plugin/src/main/java/org/opensearch/ml/action/controller/UpdateModelControllerTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/controller/UpdateControllerTransportAction.java similarity index 66% rename from plugin/src/main/java/org/opensearch/ml/action/controller/UpdateModelControllerTransportAction.java rename to plugin/src/main/java/org/opensearch/ml/action/controller/UpdateControllerTransportAction.java index f5ce76066f..5c0c7e4c59 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/controller/UpdateModelControllerTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/controller/UpdateControllerTransportAction.java @@ -5,7 +5,7 @@ package org.opensearch.ml.action.controller; -import static org.opensearch.ml.common.CommonValue.ML_MODEL_CONTROLLER_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_CONTROLLER_INDEX; import static org.opensearch.ml.common.FunctionName.REMOTE; import static org.opensearch.ml.common.FunctionName.TEXT_EMBEDDING; @@ -14,6 +14,7 @@ import java.util.Iterator; import java.util.List; +import org.apache.commons.lang3.ArrayUtils; import org.opensearch.OpenSearchStatusException; import org.opensearch.action.ActionRequest; import org.opensearch.action.DocWriteResponse; @@ -35,12 +36,12 @@ import org.opensearch.core.xcontent.ToXContent; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; -import org.opensearch.ml.common.controller.MLModelController; -import org.opensearch.ml.common.transport.controller.MLDeployModelControllerAction; -import org.opensearch.ml.common.transport.controller.MLDeployModelControllerNodesRequest; -import org.opensearch.ml.common.transport.controller.MLDeployModelControllerNodesResponse; -import org.opensearch.ml.common.transport.controller.MLUpdateModelControllerAction; -import org.opensearch.ml.common.transport.controller.MLUpdateModelControllerRequest; +import org.opensearch.ml.common.controller.MLController; +import org.opensearch.ml.common.transport.controller.MLDeployControllerAction; +import org.opensearch.ml.common.transport.controller.MLDeployControllerNodesRequest; +import org.opensearch.ml.common.transport.controller.MLDeployControllerNodesResponse; +import org.opensearch.ml.common.transport.controller.MLUpdateControllerAction; +import org.opensearch.ml.common.transport.controller.MLUpdateControllerRequest; import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.model.MLModelCacheHelper; import org.opensearch.ml.model.MLModelManager; @@ -54,7 +55,7 @@ @Log4j2 @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) -public class UpdateModelControllerTransportAction extends HandledTransportAction { +public class UpdateControllerTransportAction extends HandledTransportAction { Client client; MLModelManager mlModelManager; MLModelCacheHelper mlModelCacheHelper; @@ -62,7 +63,7 @@ public class UpdateModelControllerTransportAction extends HandledTransportAction ModelAccessControlHelper modelAccessControlHelper; @Inject - public UpdateModelControllerTransportAction( + public UpdateControllerTransportAction( TransportService transportService, ActionFilters actionFilters, Client client, @@ -71,7 +72,7 @@ public UpdateModelControllerTransportAction( MLModelCacheHelper mlModelCacheHelper, MLModelManager mlModelManager ) { - super(MLUpdateModelControllerAction.NAME, transportService, actionFilters, MLUpdateModelControllerRequest::new); + super(MLUpdateControllerAction.NAME, transportService, actionFilters, MLUpdateControllerRequest::new); this.client = client; this.mlModelManager = mlModelManager; this.clusterService = clusterService; @@ -81,9 +82,9 @@ public UpdateModelControllerTransportAction( @Override protected void doExecute(Task task, ActionRequest request, ActionListener actionListener) { - MLUpdateModelControllerRequest updateModelControllerRequest = MLUpdateModelControllerRequest.fromActionRequest(request); - MLModelController updateModelControllerInput = updateModelControllerRequest.getUpdateModelControllerInput(); - String modelId = updateModelControllerInput.getModelId(); + MLUpdateControllerRequest updateControllerRequest = MLUpdateControllerRequest.fromActionRequest(request); + MLController updateControllerInput = updateControllerRequest.getUpdateControllerInput(); + String modelId = updateControllerInput.getModelId(); User user = RestActionUtils.getUserContext(client); String[] excludes = new String[] { MLModel.MODEL_CONTENT_FIELD, MLModel.OLD_MODEL_CONTENT_FIELD }; @@ -95,13 +96,12 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { if (hasPermission) { - mlModelManager.getModelController(modelId, ActionListener.wrap(modelController -> { - boolean isDeployRequiredAfterUpdate = modelController - .isDeployRequiredAfterUpdate(updateModelControllerInput); - modelController.update(updateModelControllerInput); - updateModelController(mlModel, modelController, isDeployRequiredAfterUpdate, wrappedListener); + mlModelManager.getController(modelId, ActionListener.wrap(controller -> { + boolean isDeployRequiredAfterUpdate = controller.isDeployRequiredAfterUpdate(updateControllerInput); + controller.update(updateControllerInput); + updateController(mlModel, controller, isDeployRequiredAfterUpdate, wrappedListener); }, e -> { - if (mlModel.getIsModelControllerEnabled() == null || !mlModel.getIsModelControllerEnabled()) { + if (mlModel.getIsControllerEnabled() == null || !mlModel.getIsControllerEnabled()) { wrappedListener .onFailure( new OpenSearchStatusException( @@ -161,9 +161,9 @@ protected void doExecute(Task task, ActionRequest request, ActionListener actionListener ) { @@ -177,54 +177,51 @@ private void updateModelController( modelId, updateResponse.getResult() ); - if (mlModelCacheHelper.isModelDeployed(modelId) && isDeployRequiredAfterUpdate) { + if (!ArrayUtils.isEmpty(mlModelCacheHelper.getWorkerNodes(modelId)) && isDeployRequiredAfterUpdate) { log .info( "Model {} is deployed and the user rate limiter config is constructable. Start to deploy the model controller into cache.", modelId ); String[] targetNodeIds = mlModelManager.getWorkerNodes(modelId, mlModel.getAlgorithm()); - MLDeployModelControllerNodesRequest deployModelControllerNodesRequest = new MLDeployModelControllerNodesRequest( + MLDeployControllerNodesRequest deployControllerNodesRequest = new MLDeployControllerNodesRequest( targetNodeIds, modelId ); client - .execute( - MLDeployModelControllerAction.INSTANCE, - deployModelControllerNodesRequest, - ActionListener.wrap(nodesResponse -> { - if (nodesResponse != null && isDeployModelControllerSuccessOnAllNodes(nodesResponse)) { - log.info("Successfully update model controller and deploy it into cache with model ID {}", modelId); - actionListener.onResponse(updateResponse); - } else { - String[] nodeIds = getDeployModelControllerFailedNodesList(nodesResponse); - log - .error( - "Successfully update model controller index with model ID {} but deploy model controller to cache was failed on following nodes {}, please retry.", - modelId, - Arrays.toString(nodeIds) - ); - actionListener - .onFailure( - new RuntimeException( - "Successfully update model controller index with model ID " - + modelId - + " but deploy model controller to cache was failed on following nodes " - + Arrays.toString(nodeIds) - + ", please retry." - ) - ); - } - }, e -> { - log.error("Failed to deploy model controller for model: {}" + modelId, e); - actionListener.onFailure(e); - }) - ); + .execute(MLDeployControllerAction.INSTANCE, deployControllerNodesRequest, ActionListener.wrap(nodesResponse -> { + if (nodesResponse != null && isDeployControllerSuccessOnAllNodes(nodesResponse)) { + log.info("Successfully update model controller and deploy it into cache with model ID {}", modelId); + actionListener.onResponse(updateResponse); + } else { + String[] nodeIds = getDeployControllerFailedNodesList(nodesResponse); + log + .error( + "Successfully update model controller index with model ID {} but deploy model controller to cache was failed on following nodes {}, please retry.", + modelId, + Arrays.toString(nodeIds) + ); + actionListener + .onFailure( + new RuntimeException( + "Successfully update model controller index with model ID " + + modelId + + " but deploy model controller to cache was failed on following nodes " + + Arrays.toString(nodeIds) + + ", please retry." + ) + ); + } + }, e -> { + log.error("Failed to deploy model controller for model: {}" + modelId, e); + actionListener.onFailure(e); + })); } else { actionListener.onResponse(updateResponse); } } else if (updateResponse != null && updateResponse.getResult() != DocWriteResponse.Result.UPDATED) { - // The update response returned an unexpected status may indicate a failed update + // The update response returned an unexpected status may indicate a failed + // update log .warn( "Update model controller for model {} got a result status other than update, result status: {}", @@ -237,8 +234,8 @@ private void updateModelController( actionListener.onFailure(new RuntimeException("Failed to update model controller with model ID: " + modelId)); } }, actionListener::onFailure); - UpdateRequest updateRequest = new UpdateRequest(ML_MODEL_CONTROLLER_INDEX, modelId); - updateRequest.doc(modelController.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS)); + UpdateRequest updateRequest = new UpdateRequest(ML_CONTROLLER_INDEX, modelId); + updateRequest.doc(controller.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS)); updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); client.update(updateRequest, ActionListener.runBefore(updateResponseListener, context::restore)); } catch (Exception e) { @@ -247,16 +244,16 @@ private void updateModelController( } } - private boolean isDeployModelControllerSuccessOnAllNodes(MLDeployModelControllerNodesResponse deployModelControllerNodesResponse) { - return deployModelControllerNodesResponse.failures() == null || deployModelControllerNodesResponse.failures().isEmpty(); + private boolean isDeployControllerSuccessOnAllNodes(MLDeployControllerNodesResponse deployControllerNodesResponse) { + return deployControllerNodesResponse.failures() == null || deployControllerNodesResponse.failures().isEmpty(); } - private String[] getDeployModelControllerFailedNodesList(MLDeployModelControllerNodesResponse deployModelControllerNodesResponse) { - if (deployModelControllerNodesResponse == null) { + private String[] getDeployControllerFailedNodesList(MLDeployControllerNodesResponse deployControllerNodesResponse) { + if (deployControllerNodesResponse == null) { return getAllNodes(); } else { List nodeIds = new ArrayList<>(); - for (FailedNodeException failedNodeException : deployModelControllerNodesResponse.failures()) { + for (FailedNodeException failedNodeException : deployControllerNodesResponse.failures()) { nodeIds.add(failedNodeException.nodeId()); } return nodeIds.toArray(new String[0]); diff --git a/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java index 4c06a0391c..69d8c7dbfd 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java @@ -6,7 +6,7 @@ package org.opensearch.ml.action.models; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.ml.common.CommonValue.ML_MODEL_CONTROLLER_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_CONTROLLER_INDEX; import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; import static org.opensearch.ml.common.MLModel.ALGORITHM_FIELD; import static org.opensearch.ml.common.MLModel.IS_HIDDEN_FIELD; @@ -34,6 +34,7 @@ import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.query.TermsQueryBuilder; import org.opensearch.index.reindex.BulkByScrollResponse; import org.opensearch.index.reindex.DeleteByQueryAction; @@ -217,7 +218,7 @@ private void deleteModel(String modelId, ActionListener actionLi @Override public void onResponse(DeleteResponse deleteResponse) { deleteModelChunks(modelId, deleteResponse, actionListener); - deleteModelController(modelId); + deleteController(modelId); } @Override @@ -225,7 +226,7 @@ public void onFailure(Exception e) { log.error("Failed to delete model meta data for model: " + modelId, e); if (e instanceof ResourceNotFoundException) { deleteModelChunks(modelId, null, actionListener); - deleteModelController(modelId); + deleteController(modelId); } actionListener.onFailure(e); } @@ -233,12 +234,13 @@ public void onFailure(Exception e) { } /** - * Delete the model controller for a model after the model is deleted from the ML index. + * Delete the model controller for a model after the model is deleted from the + * ML index. * * @param modelId model ID */ - private void deleteModelController(String modelId, ActionListener actionListener) { - DeleteRequest deleteRequest = new DeleteRequest(ML_MODEL_CONTROLLER_INDEX, modelId); + private void deleteController(String modelId, ActionListener actionListener) { + DeleteRequest deleteRequest = new DeleteRequest(ML_CONTROLLER_INDEX, modelId); client.delete(deleteRequest, new ActionListener<>() { @Override public void onResponse(DeleteResponse deleteResponse) { @@ -248,19 +250,25 @@ public void onResponse(DeleteResponse deleteResponse) { @Override public void onFailure(Exception e) { - log.error("Failed to delete model controller for model: " + modelId, e); - actionListener.onFailure(e); + if (e instanceof IndexNotFoundException) { + log.info("Model controller not deleted due to no model controller was found for model: " + modelId); + actionListener.onFailure(e); + } else { + log.error("Failed to delete model controller for model: " + modelId, e); + actionListener.onFailure(e); + } } }); } /** - * Delete the model controller for a model after the model is deleted from the ML index with build-in listener. + * Delete the model controller for a model after the model is deleted from the + * ML index with build-in listener. * * @param modelId model ID */ - private void deleteModelController(String modelId) { - deleteModelController(modelId, ActionListener.wrap(deleteResponse -> { + private void deleteController(String modelId) { + deleteController(modelId, ActionListener.wrap(deleteResponse -> { if (deleteResponse.getResult() == DocWriteResponse.Result.DELETED) { log.info("Model controller for model {} successfully deleted from index, result: {}", modelId, deleteResponse.getResult()); } else { diff --git a/plugin/src/main/java/org/opensearch/ml/action/models/UpdateModelTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/models/UpdateModelTransportAction.java index ccfb206aa9..d520977715 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/models/UpdateModelTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/models/UpdateModelTransportAction.java @@ -211,10 +211,9 @@ private void updateRemoteOrTextEmbeddingModel( boolean isPredictorUpdate = (updateModelInput.getConnector() != null) || (newConnectorId != null) || !Objects.equals(updateModelInput.getIsEnabled(), mlModel.getIsEnabled()); - if (MLRateLimiter.updateValidityPreCheck(mlModel.getModelRateLimiterConfig(), updateModelInput.getModelRateLimiterConfig())) { - MLRateLimiter updatedRateLimiterConfig = MLRateLimiter - .update(mlModel.getModelRateLimiterConfig(), updateModelInput.getModelRateLimiterConfig()); - updateModelInput.setModelRateLimiterConfig(updatedRateLimiterConfig); + if (MLRateLimiter.updateValidityPreCheck(mlModel.getRateLimiter(), updateModelInput.getRateLimiter())) { + MLRateLimiter updatedRateLimiterConfig = MLRateLimiter.update(mlModel.getRateLimiter(), updateModelInput.getRateLimiter()); + updateModelInput.setRateLimiter(updatedRateLimiterConfig); // An un-constructable updatedRateLimiterConfig does not require predictor to be re-deployed. isPredictorUpdate = isPredictorUpdate || (updatedRateLimiterConfig.isValid()); } diff --git a/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java b/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java index 3c5321697c..d0f3ea7391 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java @@ -107,12 +107,11 @@ public void onResponse(MLModel mlModel) { ); } else { if (modelCacheHelper.getIsModelEnabled(modelId) != null && !modelCacheHelper.getIsModelEnabled(modelId)) { - wrappedListener - .onFailure(new OpenSearchStatusException("Quota is depleted.", RestStatus.TOO_MANY_REQUESTS)); + wrappedListener.onFailure(new OpenSearchStatusException("Model is disabled.", RestStatus.FORBIDDEN)); } else { if (FunctionName.isDLModel(functionName)) { - if (modelCacheHelper.getModelRateLimiter(modelId) != null - && !modelCacheHelper.getModelRateLimiter(modelId).request()) { + if (modelCacheHelper.getRateLimiter(modelId) != null + && !modelCacheHelper.getRateLimiter(modelId).request()) { wrappedListener .onFailure( new OpenSearchStatusException( diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelCache.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelCache.java index a54247e359..16ed1be826 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelCache.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelCache.java @@ -35,7 +35,7 @@ public class MLModelCache { private @Setter(AccessLevel.PROTECTED) @Getter(AccessLevel.PROTECTED) FunctionName functionName; private @Setter(AccessLevel.PROTECTED) @Getter(AccessLevel.PROTECTED) Predictable predictor; private @Setter(AccessLevel.PROTECTED) @Getter(AccessLevel.PROTECTED) MLExecutable executor; - private @Setter(AccessLevel.PROTECTED) @Getter(AccessLevel.PROTECTED) TokenBucket modelRateLimiter; + private @Setter(AccessLevel.PROTECTED) @Getter(AccessLevel.PROTECTED) TokenBucket rateLimiter; private @Setter(AccessLevel.PROTECTED) @Getter(AccessLevel.PROTECTED) Map userRateLimiterMap; private @Setter(AccessLevel.PROTECTED) @Getter(AccessLevel.PROTECTED) Boolean isModelEnabled; private final Set targetWorkerNodes; @@ -113,6 +113,7 @@ public void removeWorkerNodes(Set removedNodes, boolean isFromUndeploy) * New ml node joins cluster, the new node will not be deployed with model, but in Cron job the new node will be regards as * a planning worker node and the model status is PARTIALLY_DEPLOYED, if we don't update here, the model status in model index * and profile API will be not consistent. + * * @param nodeId */ public void addWorkerNode(String nodeId) { @@ -163,7 +164,7 @@ public void clear() { executor.close(); } isModelEnabled = null; - modelRateLimiter = null; + rateLimiter = null; userRateLimiterMap = null; } diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java index 570db1bc42..99ccc9cce1 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java @@ -7,7 +7,6 @@ import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MONITORING_REQUEST_COUNT; -import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; @@ -16,11 +15,9 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.stream.Collectors; -import org.opensearch.OpenSearchStatusException; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.TokenBucket; -import org.opensearch.core.rest.RestStatus; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.exception.MLLimitExceededException; @@ -46,8 +43,9 @@ public MLModelCacheHelper(ClusterService clusterService, Settings settings) { /** * Initialize model state. - * @param modelId model id - * @param state model state + * + * @param modelId model id + * @param state model state * @param functionName function name */ public synchronized void initModelState( @@ -71,8 +69,9 @@ public synchronized void initModelState( /** * Set model state + * * @param modelId model id - * @param state model state + * @param state model state */ public synchronized void setModelState(String modelId, MLModelState state) { log.debug("Updating State of Model {} to state {}", modelId, state); @@ -81,12 +80,13 @@ public synchronized void setModelState(String modelId, MLModelState state) { /** * Set a rate limiter to enable model level throttling - * @param modelId model id + * + * @param modelId model id * @param rateLimiter rate limiter */ - public synchronized void setModelRateLimiter(String modelId, TokenBucket rateLimiter) { + public synchronized void setRateLimiter(String modelId, TokenBucket rateLimiter) { log.debug("Setting the rate limiter for Model {}", modelId); - getExistingModelCache(modelId).setModelRateLimiter(rateLimiter); + getExistingModelCache(modelId).setRateLimiter(rateLimiter); } /** @@ -94,92 +94,36 @@ public synchronized void setModelRateLimiter(String modelId, TokenBucket rateLim * * @param modelId model id */ - public TokenBucket getModelRateLimiter(String modelId) { + public TokenBucket getRateLimiter(String modelId) { MLModelCache modelCache = modelCaches.get(modelId); if (modelCache == null) { return null; } - return modelCache.getModelRateLimiter(); + return modelCache.getRateLimiter(); } /** * Remove the rate limiter from cache to disable model level throttling + * * @param modelId model id */ - public synchronized void removeModelRateLimiter(String modelId) { + public synchronized void removeRateLimiter(String modelId) { log.debug("Removing the rate limiter for Model {}", modelId); - getExistingModelCache(modelId).setModelRateLimiter(null); - } - - /** - * Set the user rate limiter map for a single user to enable user level throttling. - * - * @param modelId model id - * @param user user - * @param rateLimiter rate limiter - */ - public synchronized void setUserRateLimiterMap(String modelId, String user, TokenBucket rateLimiter) { - log.debug("Setting the user level rate limiter for Model {}", modelId); - Map userRateLimiterMap = new HashMap<>() { - { - put(user, rateLimiter); - } - }; - getExistingModelCache(modelId).setUserRateLimiterMap(userRateLimiterMap); + getExistingModelCache(modelId).setRateLimiter(null); } /** * Set the user rate limiter map to enable user level throttling. * - * @param modelId model id - * @param userRateLimiterMap a map with user's name and its corresponding rate limiter + * @param modelId model id + * @param userRateLimiterMap a map with user's name and its corresponding rate + * limiter */ public synchronized void setUserRateLimiterMap(String modelId, Map userRateLimiterMap) { log.debug("Setting the user level rate limiter for Model {}", modelId); getExistingModelCache(modelId).setUserRateLimiterMap(userRateLimiterMap); } - /** - * Update the user rate limiter map with the user rate limiter map. - * If the user rate limiter map doesn't exist for the model, consider calling setUserRateLimiterMap instead. - * - * @param modelId model id - * @param updateUserRateLimiterMap a map with user's name and its corresponding rate limiter - */ - public synchronized void updateUserRateLimiterMap(String modelId, Map updateUserRateLimiterMap) { - log.debug("Updating the user level rate limiter for Model {}", modelId); - Map userRateLimiterMap = getExistingModelCache(modelId).getUserRateLimiterMap(); - if (userRateLimiterMap != null) { - userRateLimiterMap.putAll(updateUserRateLimiterMap); - } else { - throw new OpenSearchStatusException( - "Model controller doesn't exist for the model. Consider calling create model controller api instead. Model ID: " + modelId, - RestStatus.CONFLICT - ); - } - } - - /** - * Update the user rate limiter map for a single user. - * If the user rate limiter map doesn't exist for the model, consider calling setUserRateLimiterMap instead. - * - * @param modelId model id - * @param user user - * @param rateLimiter rate limiter - */ - public synchronized void updateUserRateLimiterMap(String modelId, String user, TokenBucket rateLimiter) { - log.debug("Updating the user level rate limiter for Model {}", modelId); - Map userRateLimiterMap = getExistingModelCache(modelId).getUserRateLimiterMap(); - if (userRateLimiterMap != null) { - userRateLimiterMap.put(user, rateLimiter); - } else { - throw new OpenSearchStatusException( - "Model controller doesn't exist for the model. Consider calling create model controller api instead. Model ID: " + modelId, - RestStatus.CONFLICT - ); - } - } - /** * Remove the user rate limiter map from cache to disable user level throttling. * @@ -218,7 +162,8 @@ public TokenBucket getUserRateLimiter(String modelId, String user) { /** * Set a quota flag to control if the model can still receive request - * @param modelId model id + * + * @param modelId model id * @param isModelEnabled quota flag */ public synchronized void setIsModelEnabled(String modelId, Boolean isModelEnabled) { @@ -228,6 +173,7 @@ public synchronized void setIsModelEnabled(String modelId, Boolean isModelEnable /** * Get the current quota flag condition for the model + * * @param modelId model id */ public Boolean getIsModelEnabled(String modelId) { @@ -240,9 +186,10 @@ public Boolean getIsModelEnabled(String modelId) { /** * Set memory size estimation CPU/GPU + * * @param modelId model id - * @param format model format like onnx - * @param size memory size + * @param format model format like onnx + * @param size memory size */ public synchronized void setMemSizeEstimation(String modelId, MLModelFormat format, Long size) { Long memSize = getMemSizeEstimation(format, size); @@ -267,6 +214,7 @@ private Long getMemSizeEstimation(MLModelFormat format, Long size) { /** * Get CPU memory estimation. + * * @param modelId model id * @return Long */ @@ -280,6 +228,7 @@ public Long getMemEstCPU(String modelId) { /** * Get GPU memory estimation. + * * @param modelId model id * @return Long */ @@ -293,6 +242,7 @@ public Long getMemEstGPU(String modelId) { /** * Check if model deployed on node. + * * @param modelId model id * @return true if model deployed */ @@ -303,6 +253,7 @@ public synchronized boolean isModelDeployed(String modelId) { /** * Get deployed models on node. + * * @return array of model id */ public String[] getDeployedModels() { @@ -317,6 +268,7 @@ public String[] getDeployedModels() { /** * Get deployed local models on node. + * * @return array of model id */ public String[] getLocalDeployedModels() { @@ -334,6 +286,7 @@ public String[] getLocalDeployedModels() { /** * Check if model is running on node. + * * @param modelId model id * @return true if model is running on node. */ @@ -344,7 +297,8 @@ public boolean isModelRunningOnNode(String modelId) { /** * Set predictor of model. - * @param modelId model id + * + * @param modelId model id * @param predictor predictor */ public synchronized void setPredictor(String modelId, Predictable predictor) { @@ -367,6 +321,7 @@ public MLExecutable getMLExecutor(String modelId) { /** * Get predictor of model. + * * @param modelId model id * @return predictor */ @@ -380,7 +335,8 @@ public Predictable getPredictor(String modelId) { /** * Set target worker nodes of model. - * @param modelId model id + * + * @param modelId model id * @param targetWorkerNodes target worker nodes of model */ public void setTargetWorkerNodes(String modelId, List targetWorkerNodes) { @@ -392,6 +348,7 @@ public void setTargetWorkerNodes(String modelId, List targetWorkerNodes) /** * Remove model. + * * @param modelId model id */ public void removeModel(String modelId) { @@ -405,6 +362,7 @@ public void removeModel(String modelId) { /** * Get all model IDs in model cache. + * * @return array of model id */ public String[] getAllModels() { @@ -413,6 +371,7 @@ public String[] getAllModels() { /** * Get worker nodes of model. + * * @param modelId model id * @return array of node id; return null if model not exists in cache */ @@ -426,8 +385,9 @@ public String[] getWorkerNodes(String modelId) { /** * Add worker node of model. + * * @param modelId model id - * @param nodeId node id + * @param nodeId node id */ public synchronized void addWorkerNode(String modelId, String nodeId) { log.debug("add node {} to model routing table for model: {}", nodeId, modelId); @@ -437,6 +397,7 @@ public synchronized void addWorkerNode(String modelId, String nodeId) { /** * Remove worker nodes for all models. + * * @param removedNodes removed nodes */ public void removeWorkerNodes(Set removedNodes, boolean isFromUndeploy) { @@ -454,11 +415,15 @@ public void removeWorkerNodes(Set removedNodes, boolean isFromUndeploy) /** * Remove worker node of model. - * @param modelId model id - * @param nodeId node id - * @param isFromUndeploy Only allow custom deploy is true and user undeployed partial nodes, the isFromUndeploy is true, in - * this case, we need to change the deployToAllNodes flag to false in cache to make sure it's consistent - * with model index, also we need to change the target worker nodes to exclude the removed worker nodes. + * + * @param modelId model id + * @param nodeId node id + * @param isFromUndeploy Only allow custom deploy is true and user undeployed + * partial nodes, the isFromUndeploy is true, in + * this case, we need to change the deployToAllNodes flag + * to false in cache to make sure it's consistent + * with model index, also we need to change the target + * worker nodes to exclude the removed worker nodes. */ public void removeWorkerNode(String modelId, String nodeId, boolean isFromUndeploy) { MLModelCache modelCache = modelCaches.get(modelId); @@ -474,6 +439,7 @@ public void removeWorkerNode(String modelId, String nodeId, boolean isFromUndepl /** * Sync worker nodes for all models. + * * @param modelWorkerNodes worker nodes of all models */ public void syncWorkerNodes(Map> modelWorkerNodes) { @@ -499,6 +465,7 @@ public void clearWorkerNodes() { /** * Clear worker node of model. + * * @param modelId model id */ public void clearWorkerNodes(String modelId) { @@ -514,6 +481,7 @@ public void clearWorkerNodes(String modelId) { /** * Get model profile. + * * @param modelId model id * @return model profile */ @@ -545,7 +513,8 @@ public MLModelProfile getModelProfile(String modelId) { /** * Add model inference duration. - * @param modelId model id + * + * @param modelId model id * @param duration time in milliseconds used to run inference. */ public void addModelInferenceDuration(String modelId, double duration) { @@ -566,6 +535,7 @@ public void resizeMonitoringQueue(long monitoringReqCount) { /** * Get function name of model + * * @param modelId model id * @return function name */ diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index 3214bf4417..4926f414d6 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -8,7 +8,7 @@ import static org.opensearch.common.xcontent.XContentType.JSON; import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.ml.common.CommonValue.ML_MODEL_CONTROLLER_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_CONTROLLER_INDEX; import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; import static org.opensearch.ml.common.CommonValue.NOT_FOUND; @@ -26,7 +26,7 @@ import static org.opensearch.ml.engine.ModelHelper.MODEL_SIZE_IN_BYTES; import static org.opensearch.ml.engine.algorithms.remote.RemoteModel.CLIENT; import static org.opensearch.ml.engine.algorithms.remote.RemoteModel.CLUSTER_SERVICE; -import static org.opensearch.ml.engine.algorithms.remote.RemoteModel.MODEL_RATE_LIMITER; +import static org.opensearch.ml.engine.algorithms.remote.RemoteModel.RATE_LIMITER; import static org.opensearch.ml.engine.algorithms.remote.RemoteModel.SCRIPT_SERVICE; import static org.opensearch.ml.engine.algorithms.remote.RemoteModel.USER_RATE_LIMITER_MAP; import static org.opensearch.ml.engine.algorithms.remote.RemoteModel.XCONTENT_REGISTRY; @@ -100,7 +100,7 @@ import org.opensearch.ml.common.MLTask; import org.opensearch.ml.common.MLTaskState; import org.opensearch.ml.common.connector.Connector; -import org.opensearch.ml.common.controller.MLModelController; +import org.opensearch.ml.common.controller.MLController; import org.opensearch.ml.common.controller.MLRateLimiter; import org.opensearch.ml.common.exception.MLException; import org.opensearch.ml.common.exception.MLLimitExceededException; @@ -138,7 +138,8 @@ import lombok.extern.log4j.Log4j2; /** - * Manager class for ML models. It contains ML model related operations like register, deploy model etc. + * Manager class for ML models. It contains ML model related operations like + * register, deploy model etc. */ @Log4j2 public class MLModelManager { @@ -286,7 +287,7 @@ private void uploadMLModelMeta(MLRegisterModelMetaInput mlRegisterModelMetaInput .version(version) .modelGroupId(mlRegisterModelMetaInput.getModelGroupId()) .description(mlRegisterModelMetaInput.getDescription()) - .modelRateLimiterConfig(mlRegisterModelMetaInput.getModelRateLimiterConfig()) + .rateLimiter(mlRegisterModelMetaInput.getRateLimiter()) .modelFormat(mlRegisterModelMetaInput.getModelFormat()) .modelState(MLModelState.REGISTERING) .modelConfig(mlRegisterModelMetaInput.getModelConfig()) @@ -329,9 +330,9 @@ private void uploadMLModelMeta(MLRegisterModelMetaInput mlRegisterModelMetaInput /** * - * @param mlRegisterModelInput register model input for remote models - * @param mlTask ML task - * @param listener action listener + * @param mlRegisterModelInput register model input for remote models + * @param mlTask ML task + * @param listener action listener */ public void registerMLRemoteModel( MLRegisterModelInput mlRegisterModelInput, @@ -398,10 +399,11 @@ public void registerMLRemoteModel( } /** - * Register model. Basically download model file, split into chunks and save into model index. + * Register model. Basically download model file, split into chunks and save + * into model index. * * @param registerModelInput register model input - * @param mlTask ML task + * @param mlTask ML task */ public void registerMLModel(MLRegisterModelInput registerModelInput, MLTask mlTask) { @@ -516,7 +518,7 @@ private void indexRemoteModel( .modelGroupId(registerModelInput.getModelGroupId()) .version(version) .description(registerModelInput.getDescription()) - .modelRateLimiterConfig(registerModelInput.getModelRateLimiterConfig()) + .rateLimiter(registerModelInput.getRateLimiter()) .modelFormat(registerModelInput.getModelFormat()) .modelState(MLModelState.REGISTERED) .connector(registerModelInput.getConnector()) @@ -580,7 +582,7 @@ void indexRemoteModel(MLRegisterModelInput registerModelInput, MLTask mlTask, St .modelGroupId(registerModelInput.getModelGroupId()) .version(version) .description(registerModelInput.getDescription()) - .modelRateLimiterConfig(registerModelInput.getModelRateLimiterConfig()) + .rateLimiter(registerModelInput.getRateLimiter()) .modelFormat(registerModelInput.getModelFormat()) .modelState(MLModelState.REGISTERED) .connector(registerModelInput.getConnector()) @@ -646,7 +648,7 @@ private void registerModelFromUrl(MLRegisterModelInput registerModelInput, MLTas .algorithm(functionName) .version(version) .description(registerModelInput.getDescription()) - .modelRateLimiterConfig(registerModelInput.getModelRateLimiterConfig()) + .rateLimiter(registerModelInput.getRateLimiter()) .modelFormat(registerModelInput.getModelFormat()) .modelState(MLModelState.REGISTERING) .modelConfig(registerModelInput.getModelConfig()) @@ -729,7 +731,7 @@ private void registerModel( .algorithm(functionName) .version(version) .modelFormat(registerModelInput.getModelFormat()) - .modelRateLimiterConfig(registerModelInput.getModelRateLimiterConfig()) + .rateLimiter(registerModelInput.getRateLimiter()) .chunkNumber(chunkNum) .totalChunks(chunkFiles.size()) .content(Base64.getEncoder().encodeToString(bytes)) @@ -805,7 +807,8 @@ private ThreadedActionListener threadedActionListener(String threadPoolNa /** * Check if exceed running task limit and if circuit breaker is open. - * @param mlTask ML task + * + * @param mlTask ML task * @param runningTaskLimit limit */ public void checkAndAddRunningTask(MLTask mlTask, Integer runningTaskLimit) { @@ -873,7 +876,8 @@ private void deleteModel(String modelId, MLRegisterModelInput registerModelInput } private void deleteOrUpdateModelGroup(String modelGroupID, Boolean doesVersionCreateModelGroup, String modelVersion) { - // This checks if model group is created when registering the version. If yes, model group is deleted since the version registration + // This checks if model group is created when registering the version. If yes, + // model group is deleted since the version registration // had failed. Else model group latest version is decremented by 1 if (doesVersionCreateModelGroup) { DeleteRequest deleteModelGroupRequest = new DeleteRequest(); @@ -915,7 +919,8 @@ private void handleException(FunctionName functionName, String taskId, Exception } /** - * Read model chunks from model index. Concat chunks into a whole model file, then load + * Read model chunks from model index. Concat chunks into a whole model file, + * then load * into memory. * * @param modelId model id @@ -960,10 +965,11 @@ public void deployModel( if (FunctionName.REMOTE == mlModel.getAlgorithm() || (!FunctionName.isDLModel(mlModel.getAlgorithm()) && mlModel.getAlgorithm() != FunctionName.METRICS_CORRELATION)) { // deploy remote model or model trained by built-in algorithm like kmeans - // deploy remote model with internal connector or model trained by built-in algorithm like kmeans - if (BooleanUtils.isTrue(mlModel.getIsModelControllerEnabled())) { - getModelController(modelId, ActionListener.wrap(modelController -> { - setupUserRateLimiterMap(modelId, eligibleNodeCount, modelController.getUserRateLimiterConfig()); + // deploy remote model with internal connector or model trained by built-in + // algorithm like kmeans + if (BooleanUtils.isTrue(mlModel.getIsControllerEnabled())) { + getController(modelId, ActionListener.wrap(controller -> { + setupUserRateLimiterMap(modelId, eligibleNodeCount, controller.getUserRateLimiter()); log.info("Successfully redeployed model controller for model " + modelId); log.info("Trying to deploy remote model with model controller configured."); deployRemoteOrBuiltInModel(mlModel, eligibleNodeCount, wrappedListener); @@ -984,8 +990,8 @@ public void deployModel( return; } - setupModelRateLimiter(modelId, eligibleNodeCount, mlModel.getModelRateLimiterConfig()); - deployModelControllerWithDeployingModel(mlModel, eligibleNodeCount); + setupRateLimiter(modelId, eligibleNodeCount, mlModel.getRateLimiter()); + deployControllerWithDeployingModel(mlModel, eligibleNodeCount); // check circuit breaker before deploying custom model chunks checkOpenCircuitBreaker(mlCircuitBreakerService, mlStats); retrieveModelChunks(mlModel, ActionListener.wrap(modelZipFile -> {// read model chunks @@ -1045,7 +1051,7 @@ public void deployModel( private void deployRemoteOrBuiltInModel(MLModel mlModel, Integer eligibleNodeCount, ActionListener wrappedListener) { String modelId = mlModel.getModelId(); - setupModelRateLimiter(modelId, eligibleNodeCount, mlModel.getModelRateLimiterConfig()); + setupRateLimiter(modelId, eligibleNodeCount, mlModel.getRateLimiter()); if (mlModel.getConnector() != null || FunctionName.REMOTE != mlModel.getAlgorithm()) { setupParamsAndPredictable(modelId, mlModel); mlStats.getStat(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT).increment(); @@ -1071,7 +1077,7 @@ private void setupParamsAndPredictable(String modelId, MLModel mlModel) { } private Map setUpParameterMap(String modelId) { - TokenBucket modelRateLimiter = getModelRateLimiter(modelId); + TokenBucket rateLimiter = getRateLimiter(modelId); Map userRateLimiterMap = getUserRateLimiterMap(modelId); Map params = new HashMap<>(); @@ -1081,19 +1087,19 @@ private Map setUpParameterMap(String modelId) { params.put(XCONTENT_REGISTRY, xContentRegistry); params.put(CLUSTER_SERVICE, clusterService); - if (modelRateLimiter == null && userRateLimiterMap == null) { + if (rateLimiter == null && userRateLimiterMap == null) { log.info("Setting up basic ML predictor parameters."); return Collections.unmodifiableMap(params); - } else if (modelRateLimiter != null && userRateLimiterMap == null) { - params.put(MODEL_RATE_LIMITER, modelRateLimiter); + } else if (rateLimiter != null && userRateLimiterMap == null) { + params.put(RATE_LIMITER, rateLimiter); log.info("Setting up basic ML predictor parameters with model level throttling."); return Collections.unmodifiableMap(params); - } else if (modelRateLimiter == null) { + } else if (rateLimiter == null) { params.put(USER_RATE_LIMITER_MAP, userRateLimiterMap); log.info("Setting up basic ML predictor parameters with user level throttling."); return Collections.unmodifiableMap(params); } else { - params.put(MODEL_RATE_LIMITER, modelRateLimiter); + params.put(RATE_LIMITER, rateLimiter); params.put(USER_RATE_LIMITER_MAP, userRateLimiterMap); log.info("Setting up basic ML predictor parameters with both model and user level throttling."); return Collections.unmodifiableMap(params); @@ -1118,7 +1124,7 @@ public synchronized void updateModelCache(String modelId, ActionListener getModel(modelId, ActionListener.wrap(mlModel -> { int eligibleNodeCount = getWorkerNodes(modelId, mlModel.getAlgorithm()).length; modelCacheHelper.setIsModelEnabled(modelId, mlModel.getIsEnabled()); - setupModelRateLimiter(modelId, eligibleNodeCount, mlModel.getModelRateLimiterConfig()); + setupRateLimiter(modelId, eligibleNodeCount, mlModel.getRateLimiter()); if (mlModel.getAlgorithm() == FunctionName.REMOTE) { if (mlModel.getConnector() != null) { setupParamsAndPredictable(modelId, mlModel); @@ -1143,13 +1149,15 @@ public synchronized void updateModelCache(String modelId, ActionListener } /** - * Deploy the model controller with a model id. This method should be called AFTER a model is deployed. - * If you want to implement similar behavior during model deploy, deployModelControllerWithDeployingModel is the one supposed be called. + * Deploy the model controller with a model id. This method should be called + * AFTER a model is deployed. + * If you want to implement similar behavior during model deploy, + * deployControllerWithDeployingModel is the one supposed be called. * - * @param modelId ml model ID + * @param modelId ml model ID * @param listener action listener */ - public synchronized void deployModelControllerWithDeployedModel(String modelId, ActionListener listener) { + public synchronized void deployControllerWithDeployedModel(String modelId, ActionListener listener) { try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { if (!modelCacheHelper.isModelDeployed(modelId)) { throw new OpenSearchStatusException( @@ -1159,9 +1167,9 @@ public synchronized void deployModelControllerWithDeployedModel(String modelId, } ActionListener wrappedListener = ActionListener.runBefore(listener, context::restore); getModel(modelId, ActionListener.wrap(mlModel -> { - getModelController(modelId, ActionListener.wrap(modelController -> { + getController(modelId, ActionListener.wrap(controller -> { int eligibleNodeCount = getWorkerNodes(modelId, mlModel.getAlgorithm()).length; - setupUserRateLimiterMap(modelId, eligibleNodeCount, modelController.getUserRateLimiterConfig()); + setupUserRateLimiterMap(modelId, eligibleNodeCount, controller.getUserRateLimiter()); if (mlModel.getAlgorithm() == FunctionName.REMOTE) { if (mlModel.getConnector() != null) { setupParamsAndPredictable(modelId, mlModel); @@ -1191,10 +1199,10 @@ public synchronized void deployModelControllerWithDeployedModel(String modelId, * Undploy the model controller for a model. * Usually this method is called during deleting the model controller. * - * @param modelId ml model ID + * @param modelId ml model ID * @param listener action listener */ - public synchronized void undeployModelController(String modelId, ActionListener listener) { + public synchronized void undeployController(String modelId, ActionListener listener) { if (modelCacheHelper.isModelDeployed(modelId)) { try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { ActionListener wrappedListener = ActionListener.runBefore(listener, context::restore); @@ -1245,31 +1253,32 @@ public synchronized void undeployModelController(String modelId, ActionListener< /** * Deploy the model controller for a model during model is deploying. * - * @param mlModel ml model + * @param mlModel ml model * @param listener action listener */ - private synchronized void deployModelControllerWithDeployingModel( + private synchronized void deployControllerWithDeployingModel( MLModel mlModel, Integer eligibleNodeCount, ActionListener listener ) { String modelId = mlModel.getModelId(); FetchSourceContext fetchContext = new FetchSourceContext(true); - GetRequest getRequest = new GetRequest(ML_MODEL_CONTROLLER_INDEX).id(modelId).fetchSourceContext(fetchContext); + GetRequest getRequest = new GetRequest(ML_CONTROLLER_INDEX).id(modelId).fetchSourceContext(fetchContext); client.get(getRequest, ActionListener.wrap(r -> { if (r != null && r.isExists()) { try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) { ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - MLModelController modelController = MLModelController.parse(parser); - setupUserRateLimiterMap(modelId, eligibleNodeCount, modelController.getUserRateLimiterConfig()); + MLController controller = MLController.parse(parser); + setupUserRateLimiterMap(modelId, eligibleNodeCount, controller.getUserRateLimiter()); log.info("Successfully redeployed model controller for model " + modelId); listener.onResponse("Successfully redeployed model controller for model " + modelId); } catch (Exception e) { log.error("Failed to parse ml task" + r.getId(), e); listener.onFailure(e); } - } else if (mlModel.getIsModelControllerEnabled() == null || !mlModel.getIsModelControllerEnabled()) { - // Not going to respond the failure here due to the model deploy can still work well + } else if (mlModel.getIsControllerEnabled() == null || !mlModel.getIsControllerEnabled()) { + // Not going to respond the failure here due to the model deploy can still work + // well listener .onResponse( "The model " @@ -1289,19 +1298,21 @@ private synchronized void deployModelControllerWithDeployingModel( } /** - * Deploy the model controller for a model during model is deploying with build-in listener. - * Usually this method is called when re-deploying a previous un-deployed model with the model controller. + * Deploy the model controller for a model during model is deploying with + * build-in listener. + * Usually this method is called when re-deploying a previous un-deployed model + * with the model controller. * * @param mlModel ml model */ - public void deployModelControllerWithDeployingModel(MLModel mlModel, Integer eligibleNodeCount) { + public void deployControllerWithDeployingModel(MLModel mlModel, Integer eligibleNodeCount) { if (mlModel.getModelState() != MLModelState.DEPLOYING) { throw new OpenSearchStatusException( "This method should only be called when model is in DEPLOYING state, but the model is in state: " + mlModel.getModelState(), RestStatus.CONFLICT ); } - deployModelControllerWithDeployingModel(mlModel, eligibleNodeCount, ActionListener.wrap(response -> { + deployControllerWithDeployingModel(mlModel, eligibleNodeCount, ActionListener.wrap(response -> { if (response.startsWith("Successfully")) { log.debug(response, mlModel.getModelId()); } else if (response.startsWith("Failed")) { @@ -1312,19 +1323,18 @@ public void deployModelControllerWithDeployingModel(MLModel mlModel, Integer eli }, e -> log.error("Failed to re-deploy the model controller for model: " + mlModel.getModelId(), e))); } - private void setupModelRateLimiter(String modelId, Integer eligibleNodeCount, MLRateLimiter modelRateLimiter) { - if (modelRateLimiter != null) { - modelCacheHelper.setModelRateLimiter(modelId, rateLimiterConstructor(eligibleNodeCount, modelRateLimiter)); + private void setupRateLimiter(String modelId, Integer eligibleNodeCount, MLRateLimiter rateLimiter) { + if (rateLimiter != null) { + modelCacheHelper.setRateLimiter(modelId, createTokenBucket(eligibleNodeCount, rateLimiter)); } else { - modelCacheHelper.removeModelRateLimiter(modelId); + modelCacheHelper.removeRateLimiter(modelId); } } - private void setupUserRateLimiterMap(String modelId, Integer eligibleNodeCount, Map userRateLimiterConfig) { - if (userRateLimiterConfig != null && !userRateLimiterConfig.isEmpty()) { + private void setupUserRateLimiterMap(String modelId, Integer eligibleNodeCount, Map userRateLimiter) { + if (userRateLimiter != null && !userRateLimiter.isEmpty()) { Map userRateLimiterMap = new HashMap<>(); - userRateLimiterConfig - .forEach((user, rateLimiter) -> userRateLimiterMap.put(user, rateLimiterConstructor(eligibleNodeCount, rateLimiter))); + userRateLimiter.forEach((user, rateLimiter) -> userRateLimiterMap.put(user, createTokenBucket(eligibleNodeCount, rateLimiter))); modelCacheHelper.setUserRateLimiterMap(modelId, userRateLimiterMap); } else { modelCacheHelper.removeUserRateLimiterMap(modelId); @@ -1339,22 +1349,22 @@ private void removeUserRateLimiterMap(String modelId) { * Construct a TokenBucket object from its rate limiter config. * * @param eligibleNodeCount eligible node count - * @param modelRateLimiter model rate limiter config + * @param rateLimiter model rate limiter config * @return a TokenBucket object to enable throttling */ - private TokenBucket rateLimiterConstructor(Integer eligibleNodeCount, MLRateLimiter modelRateLimiter) { - if (modelRateLimiter.isValid()) { - double rateLimitNumber = Double.parseDouble(modelRateLimiter.getRateLimitNumber()); - TimeUnit rateLimitUnit = modelRateLimiter.getRateLimitUnit(); + private TokenBucket createTokenBucket(Integer eligibleNodeCount, MLRateLimiter rateLimiter) { + if (rateLimiter.isValid()) { + double limit = Double.parseDouble(rateLimiter.getLimit()); + TimeUnit unit = rateLimiter.getUnit(); log .info( "Initializing the rate limiter with setting {} per {} (TPS limit {}), evenly distributed on {} nodes", - rateLimitNumber, - rateLimitUnit, - rateLimitNumber / rateLimitUnit.toSeconds(1), + limit, + unit, + limit / unit.toSeconds(1), eligibleNodeCount ); - return new TokenBucket(System::nanoTime, rateLimitNumber / rateLimitUnit.toNanos(1) / eligibleNodeCount, rateLimitNumber); + return new TokenBucket(System::nanoTime, limit / unit.toNanos(1) / eligibleNodeCount, limit, limit / eligibleNodeCount); } return null; } @@ -1365,15 +1375,16 @@ private TokenBucket rateLimiterConstructor(Integer eligibleNodeCount, MLRateLimi * @param modelId model id * @return a TokenBucket object to enable model-level throttling */ - public TokenBucket getModelRateLimiter(String modelId) { - return modelCacheHelper.getModelRateLimiter(modelId); + public TokenBucket getRateLimiter(String modelId) { + return modelCacheHelper.getRateLimiter(modelId); } /** * Get model-level rate limiter with model id. * * @param modelId model id - * @return a map with user's name and its corresponding rate limiter object to track user-level throttling + * @return a map with user's name and its corresponding rate limiter object to + * track user-level throttling */ public Map getUserRateLimiterMap(String modelId) { return modelCacheHelper.getUserRateLimiterMap(modelId); @@ -1423,18 +1434,18 @@ public void getModel(String modelId, String[] includes, String[] excludes, Actio /** * Get model controller from model controller index. * - * @param modelId model id + * @param modelId model id * @param listener action listener */ - public void getModelController(String modelId, ActionListener listener) { + public void getController(String modelId, ActionListener listener) { FetchSourceContext fetchContext = new FetchSourceContext(true); - GetRequest getRequest = new GetRequest(ML_MODEL_CONTROLLER_INDEX).id(modelId).fetchSourceContext(fetchContext); + GetRequest getRequest = new GetRequest(ML_CONTROLLER_INDEX).id(modelId).fetchSourceContext(fetchContext); client.get(getRequest, ActionListener.wrap(r -> { if (r != null && r.isExists()) { try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) { ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - MLModelController modelController = MLModelController.parse(parser); - listener.onResponse(modelController); + MLController controller = MLController.parse(parser); + listener.onResponse(controller); } catch (Exception e) { log.error("Failed to parse ml task" + r.getId(), e); listener.onFailure(e); @@ -1449,7 +1460,7 @@ public void getModelController(String modelId, ActionListener * Get connector from connector index. * * @param connectorId connector id - * @param listener action listener + * @param listener action listener */ private void getConnector(String connectorId, ActionListener listener) { GetRequest getRequest = new GetRequest().index(CommonValue.ML_CONNECTOR_INDEX).id(connectorId); @@ -1479,7 +1490,7 @@ private void getConnector(String connectorId, ActionListener listener * Retreive a model's all chunks. * * @param mlModelMeta model meta - * @param listener action listener + * @param listener action listener */ private void retrieveModelChunks(MLModel mlModelMeta, ActionListener listener) throws InterruptedException { String modelId = mlModelMeta.getModelId(); @@ -1570,7 +1581,7 @@ public void updateModel(String modelId, Map updatedFields, Actio /** * Get model chunk id. * - * @param modelId model id + * @param modelId model id * @param chunkNumber model chunk number * @return model chunk id */ @@ -1669,8 +1680,8 @@ private void removeModel(String modelId) { /** * Get worker nodes of specific model. * - * @param modelId model id - * @param functionName function name + * @param modelId model id + * @param functionName function name * @param onlyEligibleNode return only eligible node * @return list of worker node ids */ @@ -1697,7 +1708,7 @@ public int getWorkerNodesSize(String modelId, FunctionName functionName, boolean /** * Get worker node of specific model without filtering eligible node. * - * @param modelId model id + * @param modelId model id * @param functionName function name * @return list of worker node ids */ @@ -1720,7 +1731,8 @@ public Predictable getPredictor(String modelId) { } /** - * Get all model ids in cache, both local model id and remote model in routing table. + * Get all model ids in cache, both local model id and remote model in routing + * table. * * @return array of model ids */ diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 78309ce0a1..7a5a0f9f39 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -7,7 +7,7 @@ import static org.opensearch.ml.common.CommonValue.ML_CONFIG_INDEX; import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; -import static org.opensearch.ml.common.CommonValue.ML_MODEL_CONTROLLER_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_CONTROLLER_INDEX; import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX; @@ -46,12 +46,12 @@ import org.opensearch.ml.action.connector.SearchConnectorTransportAction; import org.opensearch.ml.action.connector.TransportCreateConnectorAction; import org.opensearch.ml.action.connector.UpdateConnectorTransportAction; -import org.opensearch.ml.action.controller.CreateModelControllerTransportAction; -import org.opensearch.ml.action.controller.DeleteModelControllerTransportAction; -import org.opensearch.ml.action.controller.DeployModelControllerTransportAction; -import org.opensearch.ml.action.controller.GetModelControllerTransportAction; -import org.opensearch.ml.action.controller.UndeployModelControllerTransportAction; -import org.opensearch.ml.action.controller.UpdateModelControllerTransportAction; +import org.opensearch.ml.action.controller.CreateControllerTransportAction; +import org.opensearch.ml.action.controller.DeleteControllerTransportAction; +import org.opensearch.ml.action.controller.DeployControllerTransportAction; +import org.opensearch.ml.action.controller.GetControllerTransportAction; +import org.opensearch.ml.action.controller.UndeployControllerTransportAction; +import org.opensearch.ml.action.controller.UpdateControllerTransportAction; import org.opensearch.ml.action.deploy.TransportDeployModelAction; import org.opensearch.ml.action.deploy.TransportDeployModelOnNodeAction; import org.opensearch.ml.action.execute.TransportExecuteTaskAction; @@ -115,12 +115,12 @@ import org.opensearch.ml.common.transport.connector.MLConnectorSearchAction; import org.opensearch.ml.common.transport.connector.MLCreateConnectorAction; import org.opensearch.ml.common.transport.connector.MLUpdateConnectorAction; -import org.opensearch.ml.common.transport.controller.MLCreateModelControllerAction; -import org.opensearch.ml.common.transport.controller.MLDeployModelControllerAction; -import org.opensearch.ml.common.transport.controller.MLModelControllerDeleteAction; -import org.opensearch.ml.common.transport.controller.MLModelControllerGetAction; -import org.opensearch.ml.common.transport.controller.MLUndeployModelControllerAction; -import org.opensearch.ml.common.transport.controller.MLUpdateModelControllerAction; +import org.opensearch.ml.common.transport.controller.MLControllerDeleteAction; +import org.opensearch.ml.common.transport.controller.MLControllerGetAction; +import org.opensearch.ml.common.transport.controller.MLCreateControllerAction; +import org.opensearch.ml.common.transport.controller.MLDeployControllerAction; +import org.opensearch.ml.common.transport.controller.MLUndeployControllerAction; +import org.opensearch.ml.common.transport.controller.MLUpdateControllerAction; import org.opensearch.ml.common.transport.deploy.MLDeployModelAction; import org.opensearch.ml.common.transport.deploy.MLDeployModelOnNodeAction; import org.opensearch.ml.common.transport.execute.MLExecuteTaskAction; @@ -198,19 +198,19 @@ import org.opensearch.ml.model.MLModelManager; import org.opensearch.ml.repackage.com.google.common.collect.ImmutableList; import org.opensearch.ml.rest.RestMLCreateConnectorAction; -import org.opensearch.ml.rest.RestMLCreateModelControllerAction; +import org.opensearch.ml.rest.RestMLCreateControllerAction; import org.opensearch.ml.rest.RestMLDeleteAgentAction; import org.opensearch.ml.rest.RestMLDeleteConnectorAction; +import org.opensearch.ml.rest.RestMLDeleteControllerAction; import org.opensearch.ml.rest.RestMLDeleteModelAction; -import org.opensearch.ml.rest.RestMLDeleteModelControllerAction; import org.opensearch.ml.rest.RestMLDeleteModelGroupAction; import org.opensearch.ml.rest.RestMLDeleteTaskAction; import org.opensearch.ml.rest.RestMLDeployModelAction; import org.opensearch.ml.rest.RestMLExecuteAction; import org.opensearch.ml.rest.RestMLGetAgentAction; import org.opensearch.ml.rest.RestMLGetConnectorAction; +import org.opensearch.ml.rest.RestMLGetControllerAction; import org.opensearch.ml.rest.RestMLGetModelAction; -import org.opensearch.ml.rest.RestMLGetModelControllerAction; import org.opensearch.ml.rest.RestMLGetModelGroupAction; import org.opensearch.ml.rest.RestMLGetTaskAction; import org.opensearch.ml.rest.RestMLGetToolAction; @@ -231,8 +231,8 @@ import org.opensearch.ml.rest.RestMLTrainingAction; import org.opensearch.ml.rest.RestMLUndeployModelAction; import org.opensearch.ml.rest.RestMLUpdateConnectorAction; +import org.opensearch.ml.rest.RestMLUpdateControllerAction; import org.opensearch.ml.rest.RestMLUpdateModelAction; -import org.opensearch.ml.rest.RestMLUpdateModelControllerAction; import org.opensearch.ml.rest.RestMLUpdateModelGroupAction; import org.opensearch.ml.rest.RestMLUploadModelChunkAction; import org.opensearch.ml.rest.RestMemoryCreateConversationAction; @@ -390,12 +390,12 @@ public class MachineLearningPlugin extends Plugin implements ActionPlugin, Searc new ActionHandler<>(SearchConversationsAction.INSTANCE, SearchConversationsTransportAction.class), new ActionHandler<>(GetConversationAction.INSTANCE, GetConversationTransportAction.class), new ActionHandler<>(GetInteractionAction.INSTANCE, GetInteractionTransportAction.class), - new ActionHandler<>(MLCreateModelControllerAction.INSTANCE, CreateModelControllerTransportAction.class), - new ActionHandler<>(MLModelControllerGetAction.INSTANCE, GetModelControllerTransportAction.class), - new ActionHandler<>(MLDeployModelControllerAction.INSTANCE, DeployModelControllerTransportAction.class), - new ActionHandler<>(MLUpdateModelControllerAction.INSTANCE, UpdateModelControllerTransportAction.class), - new ActionHandler<>(MLModelControllerDeleteAction.INSTANCE, DeleteModelControllerTransportAction.class), - new ActionHandler<>(MLUndeployModelControllerAction.INSTANCE, UndeployModelControllerTransportAction.class), + new ActionHandler<>(MLCreateControllerAction.INSTANCE, CreateControllerTransportAction.class), + new ActionHandler<>(MLControllerGetAction.INSTANCE, GetControllerTransportAction.class), + new ActionHandler<>(MLDeployControllerAction.INSTANCE, DeployControllerTransportAction.class), + new ActionHandler<>(MLUpdateControllerAction.INSTANCE, UpdateControllerTransportAction.class), + new ActionHandler<>(MLControllerDeleteAction.INSTANCE, DeleteControllerTransportAction.class), + new ActionHandler<>(MLUndeployControllerAction.INSTANCE, UndeployControllerTransportAction.class), new ActionHandler<>(MLAgentGetAction.INSTANCE, GetAgentTransportAction.class), new ActionHandler<>(MLAgentDeleteAction.INSTANCE, DeleteAgentTransportAction.class), new ActionHandler<>(UpdateConversationAction.INSTANCE, UpdateConversationTransportAction.class), @@ -451,8 +451,8 @@ public Collection createComponents( stats.put(MLClusterLevelStat.ML_TASK_INDEX_STATUS, new MLStat<>(true, new IndexStatusSupplier(indexUtils, ML_TASK_INDEX))); stats .put( - MLClusterLevelStat.ML_MODEL_CONTROLLER_INDEX_STATUS, - new MLStat<>(true, new IndexStatusSupplier(indexUtils, ML_MODEL_CONTROLLER_INDEX)) + MLClusterLevelStat.ML_CONTROLLER_INDEX_STATUS, + new MLStat<>(true, new IndexStatusSupplier(indexUtils, ML_CONTROLLER_INDEX)) ); stats.put(MLClusterLevelStat.ML_MODEL_COUNT, new MLStat<>(true, new CounterSupplier())); stats.put(MLClusterLevelStat.ML_CONNECTOR_COUNT, new MLStat<>(true, new CounterSupplier())); @@ -609,7 +609,8 @@ public Collection createComponents( ); // TODO move this into MLFeatureEnabledSetting - // search processor factories below will get BooleanSupplier that supplies the current value being updated through this. + // search processor factories below will get BooleanSupplier that supplies the + // current value being updated through this. clusterService .getClusterSettings() .addSettingsUpdateConsumer(MLCommonsSettings.ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED, it -> ragSearchPipelineEnabled = it); @@ -696,10 +697,10 @@ public List getRestHandlers( RestMemorySearchInteractionsAction restSearchInteractionsAction = new RestMemorySearchInteractionsAction(); RestMemoryGetConversationAction restGetConversationAction = new RestMemoryGetConversationAction(); RestMemoryGetInteractionAction restGetInteractionAction = new RestMemoryGetInteractionAction(); - RestMLCreateModelControllerAction restMLCreateModelControllerAction = new RestMLCreateModelControllerAction(); - RestMLGetModelControllerAction restMLGetModelControllerAction = new RestMLGetModelControllerAction(); - RestMLUpdateModelControllerAction restMLUpdateModelControllerAction = new RestMLUpdateModelControllerAction(); - RestMLDeleteModelControllerAction restMLDeleteModelControllerAction = new RestMLDeleteModelControllerAction(); + RestMLCreateControllerAction restMLCreateControllerAction = new RestMLCreateControllerAction(); + RestMLGetControllerAction restMLGetControllerAction = new RestMLGetControllerAction(); + RestMLUpdateControllerAction restMLUpdateControllerAction = new RestMLUpdateControllerAction(); + RestMLDeleteControllerAction restMLDeleteControllerAction = new RestMLDeleteControllerAction(); RestMLGetAgentAction restMLGetAgentAction = new RestMLGetAgentAction(); RestMLDeleteAgentAction restMLDeleteAgentAction = new RestMLDeleteAgentAction(); RestMemoryUpdateConversationAction restMemoryUpdateConversationAction = new RestMemoryUpdateConversationAction(); @@ -748,10 +749,10 @@ public List getRestHandlers( restSearchInteractionsAction, restGetConversationAction, restGetInteractionAction, - restMLCreateModelControllerAction, - restMLGetModelControllerAction, - restMLUpdateModelControllerAction, - restMLDeleteModelControllerAction, + restMLCreateControllerAction, + restMLGetControllerAction, + restMLUpdateControllerAction, + restMLDeleteControllerAction, restMLGetAgentAction, restMLDeleteAgentAction, restMemoryUpdateConversationAction, diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLCreateModelControllerAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLCreateControllerAction.java similarity index 55% rename from plugin/src/main/java/org/opensearch/ml/rest/RestMLCreateModelControllerAction.java rename to plugin/src/main/java/org/opensearch/ml/rest/RestMLCreateControllerAction.java index b6f180e49a..6eb0041edd 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLCreateModelControllerAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLCreateControllerAction.java @@ -17,52 +17,50 @@ import org.opensearch.OpenSearchParseException; import org.opensearch.client.node.NodeClient; import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.ml.common.controller.MLModelController; -import org.opensearch.ml.common.transport.controller.MLCreateModelControllerAction; -import org.opensearch.ml.common.transport.controller.MLCreateModelControllerRequest; +import org.opensearch.ml.common.controller.MLController; +import org.opensearch.ml.common.transport.controller.MLCreateControllerAction; +import org.opensearch.ml.common.transport.controller.MLCreateControllerRequest; import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.RestRequest; import org.opensearch.rest.action.RestToXContentListener; import com.google.common.collect.ImmutableList; -public class RestMLCreateModelControllerAction extends BaseRestHandler { +public class RestMLCreateControllerAction extends BaseRestHandler { - public final static String ML_CREATE_MODEL_CONTROLLER_ACTION = "ml_create_model_controller_action"; + public final static String ML_CREATE_CONTROLLER_ACTION = "ml_create_controller_action"; /** * Constructor */ - public RestMLCreateModelControllerAction() {} + public RestMLCreateControllerAction() {} @Override public String getName() { - return ML_CREATE_MODEL_CONTROLLER_ACTION; + return ML_CREATE_CONTROLLER_ACTION; } @Override public List routes() { return ImmutableList - .of( - new Route(RestRequest.Method.POST, String.format(Locale.ROOT, "%s/model_controllers/{%s}", ML_BASE_URI, PARAMETER_MODEL_ID)) - ); + .of(new Route(RestRequest.Method.POST, String.format(Locale.ROOT, "%s/controllers/{%s}", ML_BASE_URI, PARAMETER_MODEL_ID))); } @Override public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { - MLCreateModelControllerRequest createModelControllerRequest = getRequest(request); + MLCreateControllerRequest createControllerRequest = getRequest(request); return channel -> { - client.execute(MLCreateModelControllerAction.INSTANCE, createModelControllerRequest, new RestToXContentListener<>(channel)); + client.execute(MLCreateControllerAction.INSTANCE, createControllerRequest, new RestToXContentListener<>(channel)); }; } /** - * Creates a MLCreateModelControllerRequest from a RestRequest + * Creates a MLCreateControllerRequest from a RestRequest * * @param request RestRequest - * @return MLCreateModelControllerRequest + * @return MLCreateControllerRequest */ - private MLCreateModelControllerRequest getRequest(RestRequest request) throws IOException { + private MLCreateControllerRequest getRequest(RestRequest request) throws IOException { if (!request.hasContent()) { throw new OpenSearchParseException("Create model controller request has empty body"); } @@ -70,8 +68,8 @@ private MLCreateModelControllerRequest getRequest(RestRequest request) throws IO String modelId = getParameterId(request, PARAMETER_MODEL_ID); XContentParser parser = request.contentParser(); ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - MLModelController modelControllerInput = MLModelController.parse(parser); - modelControllerInput.setModelId(modelId); - return new MLCreateModelControllerRequest(modelControllerInput); + MLController controllerInput = MLController.parse(parser); + controllerInput.setModelId(modelId); + return new MLCreateControllerRequest(controllerInput); } } diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLDeleteModelControllerAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLDeleteControllerAction.java similarity index 52% rename from plugin/src/main/java/org/opensearch/ml/rest/RestMLDeleteModelControllerAction.java rename to plugin/src/main/java/org/opensearch/ml/rest/RestMLDeleteControllerAction.java index 51208a82f0..1524c20d6e 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLDeleteModelControllerAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLDeleteControllerAction.java @@ -13,8 +13,8 @@ import java.util.Locale; import org.opensearch.client.node.NodeClient; -import org.opensearch.ml.common.transport.controller.MLModelControllerDeleteAction; -import org.opensearch.ml.common.transport.controller.MLModelControllerDeleteRequest; +import org.opensearch.ml.common.transport.controller.MLControllerDeleteAction; +import org.opensearch.ml.common.transport.controller.MLControllerDeleteRequest; import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.RestRequest; import org.opensearch.rest.action.RestToXContentListener; @@ -24,33 +24,28 @@ /** * This class consists of the REST handler to delete ML Model. */ -public class RestMLDeleteModelControllerAction extends BaseRestHandler { - private static final String ML_DELETE_MODEL_CONTROLLER_ACTION = "ml_delete_model_controller_action"; +public class RestMLDeleteControllerAction extends BaseRestHandler { + private static final String ML_DELETE_CONTROLLER_ACTION = "ml_delete_controller_action"; - public void RestMLDeleteModelControllerAction() {} + public void RestMLDeleteControllerAction() {} @Override public String getName() { - return ML_DELETE_MODEL_CONTROLLER_ACTION; + return ML_DELETE_CONTROLLER_ACTION; } @Override public List routes() { return ImmutableList - .of( - new Route( - RestRequest.Method.DELETE, - String.format(Locale.ROOT, "%s/model_controllers/{%s}", ML_BASE_URI, PARAMETER_MODEL_ID) - ) - ); + .of(new Route(RestRequest.Method.DELETE, String.format(Locale.ROOT, "%s/controllers/{%s}", ML_BASE_URI, PARAMETER_MODEL_ID))); } @Override protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { String modelId = request.param(PARAMETER_MODEL_ID); - MLModelControllerDeleteRequest mlModelControllerDeleteRequest = new MLModelControllerDeleteRequest(modelId); + MLControllerDeleteRequest mlControllerDeleteRequest = new MLControllerDeleteRequest(modelId); return channel -> client - .execute(MLModelControllerDeleteAction.INSTANCE, mlModelControllerDeleteRequest, new RestToXContentListener<>(channel)); + .execute(MLControllerDeleteAction.INSTANCE, mlControllerDeleteRequest, new RestToXContentListener<>(channel)); } } diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetModelControllerAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetControllerAction.java similarity index 53% rename from plugin/src/main/java/org/opensearch/ml/rest/RestMLGetModelControllerAction.java rename to plugin/src/main/java/org/opensearch/ml/rest/RestMLGetControllerAction.java index d4946ec2e1..0df1ce2893 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetModelControllerAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetControllerAction.java @@ -15,8 +15,8 @@ import java.util.Locale; import org.opensearch.client.node.NodeClient; -import org.opensearch.ml.common.transport.controller.MLModelControllerGetAction; -import org.opensearch.ml.common.transport.controller.MLModelControllerGetRequest; +import org.opensearch.ml.common.transport.controller.MLControllerGetAction; +import org.opensearch.ml.common.transport.controller.MLControllerGetRequest; import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.RestRequest; import org.opensearch.rest.action.RestToXContentListener; @@ -24,45 +24,42 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; -public class RestMLGetModelControllerAction extends BaseRestHandler { - private static final String ML_GET_MODEL_CONTROLLER_ACTION = "ml_get_model_controller_action"; +public class RestMLGetControllerAction extends BaseRestHandler { + private static final String ML_GET_CONTROLLER_ACTION = "ml_get_controller_action"; /** * Constructor */ - public RestMLGetModelControllerAction() {} + public RestMLGetControllerAction() {} @Override public String getName() { - return ML_GET_MODEL_CONTROLLER_ACTION; + return ML_GET_CONTROLLER_ACTION; } @Override public List routes() { return ImmutableList - .of( - new Route(RestRequest.Method.GET, String.format(Locale.ROOT, "%s/model_controllers/{%s}", ML_BASE_URI, PARAMETER_MODEL_ID)) - ); + .of(new Route(RestRequest.Method.GET, String.format(Locale.ROOT, "%s/controllers/{%s}", ML_BASE_URI, PARAMETER_MODEL_ID))); } @Override public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { - MLModelControllerGetRequest modelControllerGetRequest = getRequest(request); - return channel -> client - .execute(MLModelControllerGetAction.INSTANCE, modelControllerGetRequest, new RestToXContentListener<>(channel)); + MLControllerGetRequest controllerGetRequest = getRequest(request); + return channel -> client.execute(MLControllerGetAction.INSTANCE, controllerGetRequest, new RestToXContentListener<>(channel)); } /** - * Creates a MLModelControllerGetRequest from a RestRequest + * Creates a MLControllerGetRequest from a RestRequest * * @param request RestRequest - * @return MLModelControllerGetRequest + * @return MLControllerGetRequest */ @VisibleForTesting - MLModelControllerGetRequest getRequest(RestRequest request) throws IOException { + MLControllerGetRequest getRequest(RestRequest request) throws IOException { String modelId = getParameterId(request, PARAMETER_MODEL_ID); boolean returnContent = returnContent(request); - return new MLModelControllerGetRequest(modelId, returnContent); + return new MLControllerGetRequest(modelId, returnContent); } } diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateModelControllerAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateControllerAction.java similarity index 56% rename from plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateModelControllerAction.java rename to plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateControllerAction.java index b64d3e37e7..07fa1cc8a9 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateModelControllerAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateControllerAction.java @@ -17,53 +17,51 @@ import org.opensearch.OpenSearchParseException; import org.opensearch.client.node.NodeClient; import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.ml.common.controller.MLModelController; -import org.opensearch.ml.common.transport.controller.MLUpdateModelControllerAction; -import org.opensearch.ml.common.transport.controller.MLUpdateModelControllerRequest; +import org.opensearch.ml.common.controller.MLController; +import org.opensearch.ml.common.transport.controller.MLUpdateControllerAction; +import org.opensearch.ml.common.transport.controller.MLUpdateControllerRequest; import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.RestRequest; import org.opensearch.rest.action.RestToXContentListener; import com.google.common.collect.ImmutableList; -public class RestMLUpdateModelControllerAction extends BaseRestHandler { +public class RestMLUpdateControllerAction extends BaseRestHandler { - public final static String ML_UPDATE_MODEL_CONTROLLER_ACTION = "ml_update_model_controller_action"; + public final static String ML_UPDATE_CONTROLLER_ACTION = "ml_update_controller_action"; /** * Constructor */ - public RestMLUpdateModelControllerAction() {} + public RestMLUpdateControllerAction() {} @Override public String getName() { - return ML_UPDATE_MODEL_CONTROLLER_ACTION; + return ML_UPDATE_CONTROLLER_ACTION; } @Override public List routes() { return ImmutableList - .of( - new Route(RestRequest.Method.PUT, String.format(Locale.ROOT, "%s/model_controllers/{%s}", ML_BASE_URI, PARAMETER_MODEL_ID)) - ); + .of(new Route(RestRequest.Method.PUT, String.format(Locale.ROOT, "%s/controllers/{%s}", ML_BASE_URI, PARAMETER_MODEL_ID))); } @Override public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { - MLUpdateModelControllerRequest updateModelControllerRequest = getRequest(request); + MLUpdateControllerRequest updateControllerRequest = getRequest(request); return channel -> { - client.execute(MLUpdateModelControllerAction.INSTANCE, updateModelControllerRequest, new RestToXContentListener<>(channel)); + client.execute(MLUpdateControllerAction.INSTANCE, updateControllerRequest, new RestToXContentListener<>(channel)); }; } /** - * Creates a MLUpdateModelControllerRequest from a RestRequest + * Creates a MLUpdateControllerRequest from a RestRequest * * @param request RestRequest to parse - * @return MLUpdateModelControllerRequest + * @return MLUpdateControllerRequest * @throws IOException if an error occurs while parsing the request */ - private MLUpdateModelControllerRequest getRequest(RestRequest request) throws IOException { + private MLUpdateControllerRequest getRequest(RestRequest request) throws IOException { if (!request.hasContent()) { throw new OpenSearchParseException("Update model controller request has empty body"); } @@ -71,8 +69,8 @@ private MLUpdateModelControllerRequest getRequest(RestRequest request) throws IO String modelId = getParameterId(request, PARAMETER_MODEL_ID); XContentParser parser = request.contentParser(); ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - MLModelController modelControllerInput = MLModelController.parse(parser); - modelControllerInput.setModelId(modelId); - return new MLUpdateModelControllerRequest(modelControllerInput); + MLController controllerInput = MLController.parse(parser); + controllerInput.setModelId(modelId); + return new MLUpdateControllerRequest(controllerInput); } } diff --git a/plugin/src/main/java/org/opensearch/ml/stats/MLClusterLevelStat.java b/plugin/src/main/java/org/opensearch/ml/stats/MLClusterLevelStat.java index b07b876825..fe77d068ad 100644 --- a/plugin/src/main/java/org/opensearch/ml/stats/MLClusterLevelStat.java +++ b/plugin/src/main/java/org/opensearch/ml/stats/MLClusterLevelStat.java @@ -14,7 +14,7 @@ public enum MLClusterLevelStat { ML_CONNECTOR_INDEX_STATUS, ML_CONFIG_INDEX_STATUS, ML_TASK_INDEX_STATUS, - ML_MODEL_CONTROLLER_INDEX_STATUS, + ML_CONTROLLER_INDEX_STATUS, ML_MODEL_COUNT, ML_CONNECTOR_COUNT; diff --git a/plugin/src/test/java/org/opensearch/ml/action/controller/CreateModelControllerTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/controller/CreateControllerTransportActionTests.java similarity index 62% rename from plugin/src/test/java/org/opensearch/ml/action/controller/CreateModelControllerTransportActionTests.java rename to plugin/src/test/java/org/opensearch/ml/action/controller/CreateControllerTransportActionTests.java index f19a7eda1f..22d610a0a4 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/controller/CreateModelControllerTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/controller/CreateControllerTransportActionTests.java @@ -46,13 +46,13 @@ import org.opensearch.core.common.transport.TransportAddress; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; -import org.opensearch.ml.common.controller.MLModelController; +import org.opensearch.ml.common.controller.MLController; import org.opensearch.ml.common.controller.MLRateLimiter; import org.opensearch.ml.common.model.MLModelState; -import org.opensearch.ml.common.transport.controller.MLCreateModelControllerRequest; -import org.opensearch.ml.common.transport.controller.MLCreateModelControllerResponse; -import org.opensearch.ml.common.transport.controller.MLDeployModelControllerAction; -import org.opensearch.ml.common.transport.controller.MLDeployModelControllerNodesResponse; +import org.opensearch.ml.common.transport.controller.MLCreateControllerRequest; +import org.opensearch.ml.common.transport.controller.MLCreateControllerResponse; +import org.opensearch.ml.common.transport.controller.MLDeployControllerAction; +import org.opensearch.ml.common.transport.controller.MLDeployControllerNodesResponse; import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.model.MLModelCacheHelper; @@ -61,7 +61,7 @@ import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; -public class CreateModelControllerTransportActionTests extends OpenSearchTestCase { +public class CreateControllerTransportActionTests extends OpenSearchTestCase { @Mock ThreadPool threadPool; @@ -75,7 +75,7 @@ public class CreateModelControllerTransportActionTests extends OpenSearchTestCas ActionFilters actionFilters; @Mock - ActionListener actionListener; + ActionListener actionListener; @Mock IndexResponse indexResponse; @@ -102,13 +102,13 @@ public class CreateModelControllerTransportActionTests extends OpenSearchTestCas MLModel mlModel; @Mock - MLDeployModelControllerNodesResponse mlDeployModelControllerNodesResponse; + MLDeployControllerNodesResponse mlDeployControllerNodesResponse; @Rule public ExpectedException exceptionRule = ExpectedException.none(); - CreateModelControllerTransportAction createModelControllerTransportAction; - MLCreateModelControllerRequest createModelControllerRequest; + CreateControllerTransportAction createControllerTransportAction; + MLCreateControllerRequest createControllerRequest; ThreadContext threadContext; @Before @@ -139,8 +139,8 @@ public void setup() throws IOException { DiscoveryNodes nodes = DiscoveryNodes.builder().add(node1).add(node2).build(); String[] targetNodeIds = new String[] { node1.getId(), node2.getId() }; - createModelControllerTransportAction = spy( - new CreateModelControllerTransportAction( + createControllerTransportAction = spy( + new CreateControllerTransportAction( transportService, actionFilters, mlIndicesHandler, @@ -152,15 +152,15 @@ public void setup() throws IOException { ) ); - MLRateLimiter rateLimiter = MLRateLimiter.builder().rateLimitNumber("1").rateLimitUnit(TimeUnit.MILLISECONDS).build(); + MLRateLimiter rateLimiter = MLRateLimiter.builder().limit("1").unit(TimeUnit.MILLISECONDS).build(); - MLModelController modelController = MLModelController.builder().modelId("testModelId").userRateLimiterConfig(new HashMap<>() { + MLController controller = MLController.builder().modelId("testModelId").userRateLimiter(new HashMap<>() { { put("testUser", rateLimiter); } }).build(); - createModelControllerRequest = MLCreateModelControllerRequest.builder().modelControllerInput(modelController).build(); + createControllerRequest = MLCreateControllerRequest.builder().controllerInput(controller).build(); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); @@ -179,7 +179,7 @@ public void setup() throws IOException { ActionListener listener = invocation.getArgument(0); listener.onResponse(true); return null; - }).when(mlIndicesHandler).initMLModelControllerIndex(isA(ActionListener.class)); + }).when(mlIndicesHandler).initMLControllerIndex(isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); @@ -201,27 +201,27 @@ public void setup() throws IOException { } @Test - public void testCreateModelControllerSuccess() { - createModelControllerTransportAction.doExecute(null, createModelControllerRequest, actionListener); - verify(actionListener).onResponse(any(MLCreateModelControllerResponse.class)); + public void testCreateControllerSuccess() { + createControllerTransportAction.doExecute(null, createControllerRequest, actionListener); + verify(actionListener).onResponse(any(MLCreateControllerResponse.class)); } @Test - public void testCreateModelControllerWithTextEmbeddingModelSuccess() { + public void testCreateControllerWithTextEmbeddingModelSuccess() { when(mlModel.getAlgorithm()).thenReturn(FunctionName.TEXT_EMBEDDING); - createModelControllerTransportAction.doExecute(null, createModelControllerRequest, actionListener); - verify(actionListener).onResponse(any(MLCreateModelControllerResponse.class)); + createControllerTransportAction.doExecute(null, createControllerRequest, actionListener); + verify(actionListener).onResponse(any(MLCreateControllerResponse.class)); } @Test - public void testCreateModelControllerWithModelAccessControlNoPermission() { + public void testCreateControllerWithModelAccessControlNoPermission() { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); listener.onResponse(false); return null; }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); - createModelControllerTransportAction.doExecute(null, createModelControllerRequest, actionListener); + createControllerTransportAction.doExecute(null, createControllerRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals( @@ -231,28 +231,28 @@ public void testCreateModelControllerWithModelAccessControlNoPermission() { } @Test - public void testCreateModelControllerWithModelAccessControlOtherException() { + public void testCreateControllerWithModelAccessControlOtherException() { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); listener.onFailure(new RuntimeException("Exception occurred. Please check log for more details.")); return null; }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); - createModelControllerTransportAction.doExecute(null, createModelControllerRequest, actionListener); + createControllerTransportAction.doExecute(null, createControllerRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals("Exception occurred. Please check log for more details.", argumentCaptor.getValue().getMessage()); } @Test - public void testCreateModelControllerWithModelNotFound() { + public void testCreateControllerWithModelNotFound() { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); listener.onResponse(null); return null; }).when(mlModelManager).getModel(eq("testModelId"), any(), any(), isA(ActionListener.class)); - createModelControllerTransportAction.doExecute(null, createModelControllerRequest, actionListener); + createControllerTransportAction.doExecute(null, createControllerRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals( @@ -262,104 +262,108 @@ public void testCreateModelControllerWithModelNotFound() { } @Test - public void testCreateModelControllerWithModelStateDeploying() { + public void testCreateControllerWithModelStateDeploying() { when(mlModel.getModelState()).thenReturn(MLModelState.DEPLOYING); - createModelControllerTransportAction.doExecute(null, createModelControllerRequest, actionListener); + createControllerTransportAction.doExecute(null, createControllerRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); - assertEquals("Creating a model controller during its corresponding model in DEPLOYING state is not allowed, please either create the model controller after it is deployed or before deploying it. Model ID: testModelId", argumentCaptor.getValue().getMessage()); + assertEquals( + "Creating a model controller during its corresponding model in DEPLOYING state is not allowed, please either create the model controller after it is deployed or before deploying it. Model ID: testModelId", + argumentCaptor.getValue().getMessage()); } @Test - public void testCreateModelControllerWithModelFunctionUnsupported() { + public void testCreateControllerWithModelFunctionUnsupported() { when(mlModel.getAlgorithm()).thenReturn(FunctionName.METRICS_CORRELATION); - createModelControllerTransportAction.doExecute(null, createModelControllerRequest, actionListener); + createControllerTransportAction.doExecute(null, createControllerRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); - assertEquals("Creating model controller on this operation on the function category METRICS_CORRELATION is not supported.", argumentCaptor.getValue().getMessage()); + assertEquals( + "Creating model controller on this operation on the function category METRICS_CORRELATION is not supported.", + argumentCaptor.getValue().getMessage()); } @Test - public void testCreateModelControllerWithIndexCreatedFailure() { + public void testCreateControllerWithIndexCreatedFailure() { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(0); listener.onResponse(false); return null; - }).when(mlIndicesHandler).initMLModelControllerIndex(isA(ActionListener.class)); + }).when(mlIndicesHandler).initMLControllerIndex(isA(ActionListener.class)); - createModelControllerTransportAction.doExecute(null, createModelControllerRequest, actionListener); + createControllerTransportAction.doExecute(null, createControllerRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals("Failed to create model controller index.", argumentCaptor.getValue().getMessage()); } @Test - public void testCreateModelControllerWithIndexCreatedOtherException() { + public void testCreateControllerWithIndexCreatedOtherException() { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(0); listener.onFailure(new RuntimeException("Exception occurred. Please check log for more details.")); return null; - }).when(mlIndicesHandler).initMLModelControllerIndex(isA(ActionListener.class)); + }).when(mlIndicesHandler).initMLControllerIndex(isA(ActionListener.class)); - createModelControllerTransportAction.doExecute(null, createModelControllerRequest, actionListener); + createControllerTransportAction.doExecute(null, createControllerRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals("Exception occurred. Please check log for more details.", argumentCaptor.getValue().getMessage()); } @Test - public void testCreateModelControllerWithIndexResponseUpdated() { + public void testCreateControllerWithIndexResponseUpdated() { when(indexResponse.getResult()).thenReturn(DocWriteResponse.Result.UPDATED); - createModelControllerTransportAction.doExecute(null, createModelControllerRequest, actionListener); - verify(actionListener).onResponse(any(MLCreateModelControllerResponse.class)); + createControllerTransportAction.doExecute(null, createControllerRequest, actionListener); + verify(actionListener).onResponse(any(MLCreateControllerResponse.class)); } @Test - public void testCreateModelControllerWithDeploySuccessNullFailures() { + public void testCreateControllerWithDeploySuccessNullFailures() { when(mlModelCacheHelper.isModelDeployed("testModelId")).thenReturn(true); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(mlDeployModelControllerNodesResponse); + ActionListener listener = invocation.getArgument(2); + listener.onResponse(mlDeployControllerNodesResponse); return null; - }).when(client).execute(eq(MLDeployModelControllerAction.INSTANCE), any(), any()); - when(mlDeployModelControllerNodesResponse.failures()).thenReturn(null); + }).when(client).execute(eq(MLDeployControllerAction.INSTANCE), any(), any()); + when(mlDeployControllerNodesResponse.failures()).thenReturn(null); - createModelControllerTransportAction.doExecute(null, createModelControllerRequest, actionListener); - verify(actionListener).onResponse(any(MLCreateModelControllerResponse.class)); + createControllerTransportAction.doExecute(null, createControllerRequest, actionListener); + verify(actionListener).onResponse(any(MLCreateControllerResponse.class)); } @Test - public void testCreateModelControllerWithUndeploySuccessEmptyFailures() { + public void testCreateControllerWithUndeploySuccessEmptyFailures() { when(mlModelCacheHelper.isModelDeployed("testModelId")).thenReturn(true); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(mlDeployModelControllerNodesResponse); + ActionListener listener = invocation.getArgument(2); + listener.onResponse(mlDeployControllerNodesResponse); return null; - }).when(client).execute(eq(MLDeployModelControllerAction.INSTANCE), any(), any()); - when(mlDeployModelControllerNodesResponse.failures()).thenReturn(new ArrayList<>()); + }).when(client).execute(eq(MLDeployControllerAction.INSTANCE), any(), any()); + when(mlDeployControllerNodesResponse.failures()).thenReturn(new ArrayList<>()); - createModelControllerTransportAction.doExecute(null, createModelControllerRequest, actionListener); - verify(actionListener).onResponse(any(MLCreateModelControllerResponse.class)); + createControllerTransportAction.doExecute(null, createControllerRequest, actionListener); + verify(actionListener).onResponse(any(MLCreateControllerResponse.class)); } @Test - public void testCreateModelControllerWithUndeploySuccessPartiallyFailures() { + public void testCreateControllerWithUndeploySuccessPartiallyFailures() { List failures = List .of(new FailedNodeException("foo1", "Undeploy failed.", new RuntimeException("Exception occurred."))); - when(mlModelCacheHelper.isModelDeployed("testModelId")).thenReturn(true); + when(mlModelCacheHelper.getWorkerNodes("testModelId")).thenReturn(new String[] { "foo1", "foo2" }); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(mlDeployModelControllerNodesResponse); + ActionListener listener = invocation.getArgument(2); + listener.onResponse(mlDeployControllerNodesResponse); return null; - }).when(client).execute(eq(MLDeployModelControllerAction.INSTANCE), any(), any()); - when(mlDeployModelControllerNodesResponse.failures()).thenReturn(failures); + }).when(client).execute(eq(MLDeployControllerAction.INSTANCE), any(), any()); + when(mlDeployControllerNodesResponse.failures()).thenReturn(failures); - createModelControllerTransportAction.doExecute(null, createModelControllerRequest, actionListener); + createControllerTransportAction.doExecute(null, createControllerRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals( @@ -370,44 +374,40 @@ public void testCreateModelControllerWithUndeploySuccessPartiallyFailures() { } @Test - public void testCreateModelControllerWithUndeployNullResponse() { - when(mlModelCacheHelper.isModelDeployed("testModelId")).thenReturn(true); - + public void testCreateControllerWithUndeployNullResponse() { + when(mlModelCacheHelper.getWorkerNodes("testModelId")).thenReturn(new String[] { "foo1", "foo2" }); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); + ActionListener listener = invocation.getArgument(2); listener.onResponse(null); return null; - }).when(client).execute(eq(MLDeployModelControllerAction.INSTANCE), any(), any()); + }).when(client).execute(eq(MLDeployControllerAction.INSTANCE), any(), any()); - createModelControllerTransportAction.doExecute(null, createModelControllerRequest, actionListener); + createControllerTransportAction.doExecute(null, createControllerRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals( "Successfully create model controller index with model ID testModelId " + "but deploy model controller to cache was failed on following nodes [foo1, foo2], please retry.", - argumentCaptor.getValue().getMessage() - ); + argumentCaptor.getValue().getMessage()); } @Test - public void testCreateModelControllerWithUndeployOtherException() { - when(mlModelCacheHelper.isModelDeployed("testModelId")).thenReturn(true); + public void testCreateControllerWithUndeployOtherException() { + when(mlModelCacheHelper.getWorkerNodes("testModelId")).thenReturn(new String[] { "foo1", "foo2" }); doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(2); + ActionListener actionListener = invocation.getArgument(2); actionListener .onFailure( - new RuntimeException("Exception occurred. Please check log for more details.") - ); + new RuntimeException("Exception occurred. Please check log for more details.")); return null; - }).when(client).execute(eq(MLDeployModelControllerAction.INSTANCE), any(), any()); + }).when(client).execute(eq(MLDeployControllerAction.INSTANCE), any(), any()); - createModelControllerTransportAction.doExecute(null, createModelControllerRequest, actionListener); + createControllerTransportAction.doExecute(null, createControllerRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals( "Exception occurred. Please check log for more details.", - argumentCaptor.getValue().getMessage() - ); + argumentCaptor.getValue().getMessage()); } } diff --git a/plugin/src/test/java/org/opensearch/ml/action/controller/DeleteModelControllerTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/controller/DeleteControllerTransportActionTests.java similarity index 66% rename from plugin/src/test/java/org/opensearch/ml/action/controller/DeleteModelControllerTransportActionTests.java rename to plugin/src/test/java/org/opensearch/ml/action/controller/DeleteControllerTransportActionTests.java index 5da46c6479..57631d3410 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/controller/DeleteModelControllerTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/controller/DeleteControllerTransportActionTests.java @@ -42,10 +42,10 @@ import org.opensearch.core.common.transport.TransportAddress; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ml.common.MLModel; -import org.opensearch.ml.common.controller.MLModelController; -import org.opensearch.ml.common.transport.controller.MLModelControllerDeleteRequest; -import org.opensearch.ml.common.transport.controller.MLUndeployModelControllerAction; -import org.opensearch.ml.common.transport.controller.MLUndeployModelControllerNodesResponse; +import org.opensearch.ml.common.controller.MLController; +import org.opensearch.ml.common.transport.controller.MLControllerDeleteRequest; +import org.opensearch.ml.common.transport.controller.MLUndeployControllerAction; +import org.opensearch.ml.common.transport.controller.MLUndeployControllerNodesResponse; import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.model.MLModelCacheHelper; import org.opensearch.ml.model.MLModelManager; @@ -53,7 +53,7 @@ import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; -public class DeleteModelControllerTransportActionTests extends OpenSearchTestCase { +public class DeleteControllerTransportActionTests extends OpenSearchTestCase { @Mock ThreadPool threadPool; @@ -94,16 +94,16 @@ public class DeleteModelControllerTransportActionTests extends OpenSearchTestCas MLModel mlModel; @Mock - MLModelController mlModelController; + MLController mlController; @Mock - MLUndeployModelControllerNodesResponse mlUndeployModelControllerNodesResponse; + MLUndeployControllerNodesResponse mlUndeployControllerNodesResponse; @Rule public ExpectedException exceptionRule = ExpectedException.none(); - DeleteModelControllerTransportAction deleteModelControllerTransportAction; - MLModelControllerDeleteRequest mlModelControllerDeleteRequest; + DeleteControllerTransportAction deleteControllerTransportAction; + MLControllerDeleteRequest mlControllerDeleteRequest; ThreadContext threadContext; @Before @@ -133,8 +133,8 @@ public void setup() throws IOException { DiscoveryNodes nodes = DiscoveryNodes.builder().add(node1).add(node2).build(); - deleteModelControllerTransportAction = spy( - new DeleteModelControllerTransportAction( + deleteControllerTransportAction = spy( + new DeleteControllerTransportAction( transportService, actionFilters, client, @@ -146,7 +146,7 @@ public void setup() throws IOException { ) ); - mlModelControllerDeleteRequest = MLModelControllerDeleteRequest.builder().modelId("testModelId").build(); + mlControllerDeleteRequest = MLControllerDeleteRequest.builder().modelId("testModelId").build(); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); @@ -161,10 +161,10 @@ public void setup() throws IOException { }).when(mlModelManager).getModel(eq("testModelId"), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(mlModelController); + ActionListener listener = invocation.getArgument(1); + listener.onResponse(mlController); return null; - }).when(mlModelManager).getModelController(eq("testModelId"), isA(ActionListener.class)); + }).when(mlModelManager).getController(eq("testModelId"), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); @@ -183,20 +183,20 @@ public void setup() throws IOException { } @Test - public void testDeleteModelControllerSuccess() { - deleteModelControllerTransportAction.doExecute(null, mlModelControllerDeleteRequest, actionListener); + public void testDeleteControllerSuccess() { + deleteControllerTransportAction.doExecute(null, mlControllerDeleteRequest, actionListener); verify(actionListener).onResponse(deleteResponse); } @Test - public void testDeleteModelControllerWithModelAccessControlNoPermission() { + public void testDeleteControllerWithModelAccessControlNoPermission() { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); listener.onResponse(false); return null; }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); - deleteModelControllerTransportAction.doExecute(null, mlModelControllerDeleteRequest, actionListener); + deleteControllerTransportAction.doExecute(null, mlControllerDeleteRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals( @@ -206,61 +206,61 @@ public void testDeleteModelControllerWithModelAccessControlNoPermission() { } @Test - public void testDeleteModelControllerWithModelAccessControlOtherException() { + public void testDeleteControllerWithModelAccessControlOtherException() { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); listener.onFailure(new RuntimeException("Exception occurred. Please check log for more details.")); return null; }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); - deleteModelControllerTransportAction.doExecute(null, mlModelControllerDeleteRequest, actionListener); + deleteControllerTransportAction.doExecute(null, mlControllerDeleteRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals("Exception occurred. Please check log for more details.", argumentCaptor.getValue().getMessage()); } @Test - public void testDeleteModelControllerWithGetModelNotFoundSuccess() { + public void testDeleteControllerWithGetModelNotFoundSuccess() { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); listener.onResponse(null); return null; }).when(mlModelManager).getModel(eq("testModelId"), any(), any(), isA(ActionListener.class)); - deleteModelControllerTransportAction.doExecute(null, mlModelControllerDeleteRequest, actionListener); + deleteControllerTransportAction.doExecute(null, mlControllerDeleteRequest, actionListener); verify(actionListener).onResponse(deleteResponse); } @Test - public void testDeleteModelControllerOtherException() { + public void testDeleteControllerOtherException() { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); listener.onFailure(new RuntimeException("Exception occurred. Please check log for more details.")); return null; }).when(client).delete(any(), any()); - deleteModelControllerTransportAction.doExecute(null, mlModelControllerDeleteRequest, actionListener); + deleteControllerTransportAction.doExecute(null, mlControllerDeleteRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals("Exception occurred. Please check log for more details.", argumentCaptor.getValue().getMessage()); } @Test - public void testDeleteModelControllerWithGetModelControllerOtherException() { + public void testDeleteControllerWithGetControllerOtherException() { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(1); listener.onFailure(new RuntimeException("Exception occurred. Please check log for more details.")); return null; - }).when(mlModelManager).getModelController(eq("testModelId"), isA(ActionListener.class)); + }).when(mlModelManager).getController(eq("testModelId"), isA(ActionListener.class)); - deleteModelControllerTransportAction.doExecute(null, mlModelControllerDeleteRequest, actionListener); + deleteControllerTransportAction.doExecute(null, mlControllerDeleteRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals("Exception occurred. Please check log for more details.", argumentCaptor.getValue().getMessage()); } @Test - public void testDeleteModelControllerWithGetModelNotFoundWithGetModelControllerOtherException() { + public void testDeleteControllerWithGetModelNotFoundWithGetControllerOtherException() { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); listener.onResponse(null); @@ -268,61 +268,61 @@ public void testDeleteModelControllerWithGetModelNotFoundWithGetModelControllerO }).when(mlModelManager).getModel(eq("testModelId"), any(), any(), isA(ActionListener.class)); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(1); listener.onFailure(new RuntimeException("Exception occurred. Please check log for more details.")); return null; - }).when(mlModelManager).getModelController(eq("testModelId"), isA(ActionListener.class)); + }).when(mlModelManager).getController(eq("testModelId"), isA(ActionListener.class)); - deleteModelControllerTransportAction.doExecute(null, mlModelControllerDeleteRequest, actionListener); + deleteControllerTransportAction.doExecute(null, mlControllerDeleteRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals("Exception occurred. Please check log for more details.", argumentCaptor.getValue().getMessage()); } @Test - public void testDeleteModelControllerWithUndeploySuccessNullFailures() { + public void testDeleteControllerWithUndeploySuccessNullFailures() { when(mlModelCacheHelper.isModelDeployed("testModelId")).thenReturn(true); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(mlUndeployModelControllerNodesResponse); + ActionListener listener = invocation.getArgument(2); + listener.onResponse(mlUndeployControllerNodesResponse); return null; - }).when(client).execute(eq(MLUndeployModelControllerAction.INSTANCE), any(), any()); - when(mlUndeployModelControllerNodesResponse.failures()).thenReturn(null); + }).when(client).execute(eq(MLUndeployControllerAction.INSTANCE), any(), any()); + when(mlUndeployControllerNodesResponse.failures()).thenReturn(null); - deleteModelControllerTransportAction.doExecute(null, mlModelControllerDeleteRequest, actionListener); + deleteControllerTransportAction.doExecute(null, mlControllerDeleteRequest, actionListener); verify(actionListener).onResponse(deleteResponse); } @Test - public void testDeleteModelControllerWithUndeploySuccessEmptyFailures() { + public void testDeleteControllerWithUndeploySuccessEmptyFailures() { when(mlModelCacheHelper.isModelDeployed("testModelId")).thenReturn(true); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(mlUndeployModelControllerNodesResponse); + ActionListener listener = invocation.getArgument(2); + listener.onResponse(mlUndeployControllerNodesResponse); return null; - }).when(client).execute(eq(MLUndeployModelControllerAction.INSTANCE), any(), any()); - when(mlUndeployModelControllerNodesResponse.failures()).thenReturn(new ArrayList<>()); + }).when(client).execute(eq(MLUndeployControllerAction.INSTANCE), any(), any()); + when(mlUndeployControllerNodesResponse.failures()).thenReturn(new ArrayList<>()); - deleteModelControllerTransportAction.doExecute(null, mlModelControllerDeleteRequest, actionListener); + deleteControllerTransportAction.doExecute(null, mlControllerDeleteRequest, actionListener); verify(actionListener).onResponse(deleteResponse); } @Test - public void testDeleteModelControllerWithUndeploySuccessPartiallyFailures() { + public void testDeleteControllerWithUndeploySuccessPartiallyFailures() { List failures = List .of(new FailedNodeException("foo1", "Undeploy failed.", new RuntimeException("Exception occurred."))); - when(mlModelCacheHelper.isModelDeployed("testModelId")).thenReturn(true); + when(mlModelCacheHelper.getWorkerNodes("testModelId")).thenReturn(new String[] { "foo1", "foo2" }); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(mlUndeployModelControllerNodesResponse); + ActionListener listener = invocation.getArgument(2); + listener.onResponse(mlUndeployControllerNodesResponse); return null; - }).when(client).execute(eq(MLUndeployModelControllerAction.INSTANCE), any(), any()); - when(mlUndeployModelControllerNodesResponse.failures()).thenReturn(failures); + }).when(client).execute(eq(MLUndeployControllerAction.INSTANCE), any(), any()); + when(mlUndeployControllerNodesResponse.failures()).thenReturn(failures); - deleteModelControllerTransportAction.doExecute(null, mlModelControllerDeleteRequest, actionListener); + deleteControllerTransportAction.doExecute(null, mlControllerDeleteRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals( @@ -332,43 +332,40 @@ public void testDeleteModelControllerWithUndeploySuccessPartiallyFailures() { } @Test - public void testDeleteModelControllerWithUndeployNullResponse() { - when(mlModelCacheHelper.isModelDeployed("testModelId")).thenReturn(true); + public void testDeleteControllerWithUndeployNullResponse() { + when(mlModelCacheHelper.getWorkerNodes("testModelId")).thenReturn(new String[] { "foo1", "foo2" }); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); + ActionListener listener = invocation.getArgument(2); listener.onResponse(null); return null; - }).when(client).execute(eq(MLUndeployModelControllerAction.INSTANCE), any(), any()); + }).when(client).execute(eq(MLUndeployControllerAction.INSTANCE), any(), any()); - deleteModelControllerTransportAction.doExecute(null, mlModelControllerDeleteRequest, actionListener); + deleteControllerTransportAction.doExecute(null, mlControllerDeleteRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals( "Failed to undeploy model controller with model ID testModelId on following nodes [foo1, foo2], deletion is aborted. Please retry or undeploy the model manually and then perform the deletion.", - argumentCaptor.getValue().getMessage() - ); + argumentCaptor.getValue().getMessage()); } @Test - public void testDeleteModelControllerWithUndeployOtherException() { - when(mlModelCacheHelper.isModelDeployed("testModelId")).thenReturn(true); + public void testDeleteControllerWithUndeployOtherException() { + when(mlModelCacheHelper.getWorkerNodes("testModelId")).thenReturn(new String[] { "foo1", "foo2" }); doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(2); + ActionListener actionListener = invocation.getArgument(2); actionListener .onFailure( - new RuntimeException("Exception occurred. Please check log for more details.") - ); + new RuntimeException("Exception occurred. Please check log for more details.")); return null; - }).when(client).execute(eq(MLUndeployModelControllerAction.INSTANCE), any(), any()); + }).when(client).execute(eq(MLUndeployControllerAction.INSTANCE), any(), any()); - deleteModelControllerTransportAction.doExecute(null, mlModelControllerDeleteRequest, actionListener); + deleteControllerTransportAction.doExecute(null, mlControllerDeleteRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals( "Exception occurred. Please check log for more details.", - argumentCaptor.getValue().getMessage() - ); + argumentCaptor.getValue().getMessage()); } } diff --git a/plugin/src/test/java/org/opensearch/ml/action/controller/DeployModelControllerTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/controller/DeployControllerTransportActionTests.java similarity index 61% rename from plugin/src/test/java/org/opensearch/ml/action/controller/DeployModelControllerTransportActionTests.java rename to plugin/src/test/java/org/opensearch/ml/action/controller/DeployControllerTransportActionTests.java index 4f37e1e354..5d584639d9 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/controller/DeployModelControllerTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/controller/DeployControllerTransportActionTests.java @@ -36,17 +36,17 @@ import org.opensearch.core.common.transport.TransportAddress; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ml.cluster.DiscoveryNodeHelper; -import org.opensearch.ml.common.transport.controller.MLDeployModelControllerNodeRequest; -import org.opensearch.ml.common.transport.controller.MLDeployModelControllerNodeResponse; -import org.opensearch.ml.common.transport.controller.MLDeployModelControllerNodesRequest; -import org.opensearch.ml.common.transport.controller.MLDeployModelControllerNodesResponse; +import org.opensearch.ml.common.transport.controller.MLDeployControllerNodeRequest; +import org.opensearch.ml.common.transport.controller.MLDeployControllerNodeResponse; +import org.opensearch.ml.common.transport.controller.MLDeployControllerNodesRequest; +import org.opensearch.ml.common.transport.controller.MLDeployControllerNodesResponse; import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.model.MLModelManager; import org.opensearch.ml.stats.MLStats; import org.opensearch.transport.TransportService; @RunWith(MockitoJUnitRunner.class) -public class DeployModelControllerTransportActionTests { +public class DeployControllerTransportActionTests { @Mock private TransportService transportService; @@ -72,7 +72,7 @@ public class DeployModelControllerTransportActionTests { @Mock NamedXContentRegistry xContentRegistry; - private DeployModelControllerTransportAction action; + private DeployControllerTransportAction action; private DiscoveryNode localNode; @@ -81,7 +81,7 @@ public class DeployModelControllerTransportActionTests { @Before public void setUp() throws Exception { - action = new DeployModelControllerTransportAction( + action = new DeployControllerTransportAction( transportService, actionFilters, mlModelManager, @@ -108,55 +108,52 @@ public void setUp() throws Exception { ActionListener listener = invocation.getArgument(1); listener.onResponse("successful"); return null; - }).when(mlModelManager).deployModelControllerWithDeployedModel(any(), any()); + }).when(mlModelManager).deployControllerWithDeployedModel(any(), any()); } @Test public void testNewResponses() { - final MLDeployModelControllerNodesRequest nodesRequest = new MLDeployModelControllerNodesRequest( + final MLDeployControllerNodesRequest nodesRequest = new MLDeployControllerNodesRequest( new String[] { "nodeId1", "nodeId2" }, "testModelId" ); - Map modelDeployModelControllerStatusMap = new HashMap<>(); - modelDeployModelControllerStatusMap.put("modelName:version", "response"); - MLDeployModelControllerNodeResponse response = new MLDeployModelControllerNodeResponse( - localNode, - modelDeployModelControllerStatusMap - ); - final List responses = List.of(response); + Map modelDeployControllerStatusMap = new HashMap<>(); + modelDeployControllerStatusMap.put("modelName:version", "response"); + MLDeployControllerNodeResponse response = new MLDeployControllerNodeResponse(localNode, modelDeployControllerStatusMap); + final List responses = List.of(response); final List failures = new ArrayList<>(); - MLDeployModelControllerNodesResponse response1 = action.newResponse(nodesRequest, responses, failures); + MLDeployControllerNodesResponse response1 = action.newResponse(nodesRequest, responses, failures); assertNotNull(response1); } @Test public void testNewNodeRequest() { - final MLDeployModelControllerNodesRequest request = new MLDeployModelControllerNodesRequest( + final MLDeployControllerNodesRequest request = new MLDeployControllerNodesRequest( new String[] { "nodeId1", "nodeId2" }, "testModelId" ); - final MLDeployModelControllerNodeRequest deployModelControllerNodeRequest = action.newNodeRequest(request); - assertNotNull(deployModelControllerNodeRequest); + final MLDeployControllerNodeRequest deployControllerNodeRequest = action.newNodeRequest(request); + assertNotNull(deployControllerNodeRequest); } @Test public void testNewNodeStreamRequest() throws IOException { - Map deployModelControllerStatus = new HashMap<>(); - deployModelControllerStatus.put("modelId1", "response"); - MLDeployModelControllerNodeResponse response = new MLDeployModelControllerNodeResponse(localNode, deployModelControllerStatus); + Map deployControllerStatus = new HashMap<>(); + deployControllerStatus.put("modelId1", "response"); + MLDeployControllerNodeResponse response = new MLDeployControllerNodeResponse(localNode, deployControllerStatus); BytesStreamOutput output = new BytesStreamOutput(); response.writeTo(output); - final MLDeployModelControllerNodeResponse deployModelControllerNodeResponse = action.newNodeResponse(output.bytes().streamInput()); - assertNotNull(deployModelControllerNodeResponse); + final MLDeployControllerNodeResponse deployControllerNodeResponse = action.newNodeResponse(output.bytes().streamInput()); + assertNotNull(deployControllerNodeResponse); } @Test public void testNodeOperation() { - final MLDeployModelControllerNodesRequest request = new MLDeployModelControllerNodesRequest( + final MLDeployControllerNodesRequest request = new MLDeployControllerNodesRequest( new String[] { "nodeId1", "nodeId2" }, "testModelId" ); - final MLDeployModelControllerNodeResponse response = action.nodeOperation(new MLDeployModelControllerNodeRequest(request)); + final MLDeployControllerNodeResponse response = action.nodeOperation(new MLDeployControllerNodeRequest(request)); assertNotNull(response); } @@ -166,12 +163,12 @@ public void testNodeOperationException() { ActionListener listener = invocation.getArgument(1); listener.onFailure(new RuntimeException("Test exception")); return null; - }).when(mlModelManager).deployModelControllerWithDeployedModel(any(), any()); - final MLDeployModelControllerNodesRequest request = new MLDeployModelControllerNodesRequest( + }).when(mlModelManager).deployControllerWithDeployedModel(any(), any()); + final MLDeployControllerNodesRequest request = new MLDeployControllerNodesRequest( new String[] { "nodeId1", "nodeId2" }, "testModelId" ); - final MLDeployModelControllerNodeResponse response = action.nodeOperation(new MLDeployModelControllerNodeRequest(request)); + final MLDeployControllerNodeResponse response = action.nodeOperation(new MLDeployControllerNodeRequest(request)); assertNotNull(response); } } diff --git a/plugin/src/test/java/org/opensearch/ml/action/controller/GetModelControllerTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/controller/GetControllerTransportActionTests.java similarity index 75% rename from plugin/src/test/java/org/opensearch/ml/action/controller/GetModelControllerTransportActionTests.java rename to plugin/src/test/java/org/opensearch/ml/action/controller/GetControllerTransportActionTests.java index 9e8ff54f10..bb9c76a195 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/controller/GetModelControllerTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/controller/GetControllerTransportActionTests.java @@ -40,17 +40,17 @@ import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.get.GetResult; import org.opensearch.ml.common.MLModel; -import org.opensearch.ml.common.controller.MLModelController; +import org.opensearch.ml.common.controller.MLController; import org.opensearch.ml.common.controller.MLRateLimiter; -import org.opensearch.ml.common.transport.controller.MLModelControllerGetRequest; -import org.opensearch.ml.common.transport.controller.MLModelControllerGetResponse; +import org.opensearch.ml.common.transport.controller.MLControllerGetRequest; +import org.opensearch.ml.common.transport.controller.MLControllerGetResponse; import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.model.MLModelManager; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; -public class GetModelControllerTransportActionTests extends OpenSearchTestCase { +public class GetControllerTransportActionTests extends OpenSearchTestCase { @Mock ThreadPool threadPool; @@ -67,7 +67,7 @@ public class GetModelControllerTransportActionTests extends OpenSearchTestCase { ActionFilters actionFilters; @Mock - ActionListener actionListener; + ActionListener actionListener; @Mock ClusterService clusterService; @@ -84,8 +84,8 @@ public class GetModelControllerTransportActionTests extends OpenSearchTestCase { @Rule public ExpectedException exceptionRule = ExpectedException.none(); - GetModelControllerTransportAction getModelControllerTransportAction; - MLModelControllerGetRequest mlModelControllerGetRequest; + GetControllerTransportAction getControllerTransportAction; + MLControllerGetRequest mlControllerGetRequest; ThreadContext threadContext; @Before @@ -93,8 +93,8 @@ public void setup() throws IOException { MockitoAnnotations.openMocks(this); Settings settings = Settings.builder().build(); - getModelControllerTransportAction = spy( - new GetModelControllerTransportAction( + getControllerTransportAction = spy( + new GetControllerTransportAction( transportService, actionFilters, client, @@ -104,7 +104,7 @@ public void setup() throws IOException { modelAccessControlHelper ) ); - mlModelControllerGetRequest = MLModelControllerGetRequest.builder().modelId("testModelId").build(); + mlControllerGetRequest = MLControllerGetRequest.builder().modelId("testModelId").build(); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); @@ -118,7 +118,7 @@ public void setup() throws IOException { return null; }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); - GetResponse getResponse = prepareModelControllerGetResponse(); + GetResponse getResponse = prepareControllerGetResponse(); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); listener.onResponse(getResponse); @@ -131,20 +131,20 @@ public void setup() throws IOException { } @Test - public void testGetModelControllerSuccess() { - getModelControllerTransportAction.doExecute(null, mlModelControllerGetRequest, actionListener); - verify(actionListener).onResponse(any(MLModelControllerGetResponse.class)); + public void testGetControllerSuccess() { + getControllerTransportAction.doExecute(null, mlControllerGetRequest, actionListener); + verify(actionListener).onResponse(any(MLControllerGetResponse.class)); } @Test - public void testGetModelControllerWithModelAccessControlNoPermission() { + public void testGetControllerWithModelAccessControlNoPermission() { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); listener.onResponse(false); return null; }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); - getModelControllerTransportAction.doExecute(null, mlModelControllerGetRequest, actionListener); + getControllerTransportAction.doExecute(null, mlControllerGetRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals( @@ -154,28 +154,28 @@ public void testGetModelControllerWithModelAccessControlNoPermission() { } @Test - public void testGetModelControllerWithModelAccessControlOtherException() { + public void testGetControllerWithModelAccessControlOtherException() { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); listener.onFailure(new RuntimeException("Exception occurred. Please check log for more details.")); return null; }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); - getModelControllerTransportAction.doExecute(null, mlModelControllerGetRequest, actionListener); + getControllerTransportAction.doExecute(null, mlControllerGetRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals("Exception occurred. Please check log for more details.", argumentCaptor.getValue().getMessage()); } @Test - public void testGetModelControllerWithGetModelNotFound() { + public void testGetControllerWithGetModelNotFound() { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); listener.onResponse(null); return null; }).when(mlModelManager).getModel(eq("testModelId"), any(), any(), isA(ActionListener.class)); - getModelControllerTransportAction.doExecute(null, mlModelControllerGetRequest, actionListener); + getControllerTransportAction.doExecute(null, mlControllerGetRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals( @@ -185,14 +185,14 @@ public void testGetModelControllerWithGetModelNotFound() { } @Test - public void testGetModelControllerWithGetModelOtherException() { + public void testGetControllerWithGetModelOtherException() { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); listener.onFailure(new RuntimeException("Exception occurred. Please check log for more details.")); return null; }).when(mlModelManager).getModel(eq("testModelId"), any(), any(), isA(ActionListener.class)); - getModelControllerTransportAction.doExecute(null, mlModelControllerGetRequest, actionListener); + getControllerTransportAction.doExecute(null, mlControllerGetRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals( @@ -202,67 +202,67 @@ public void testGetModelControllerWithGetModelOtherException() { } @Test - public void testGetModelControllerOtherException() { + public void testGetControllerOtherException() { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); listener.onFailure(new RuntimeException("Exception occurred. Please check log for more details.")); return null; }).when(client).get(any(), any()); - getModelControllerTransportAction.doExecute(null, mlModelControllerGetRequest, actionListener); + getControllerTransportAction.doExecute(null, mlControllerGetRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals("Exception occurred. Please check log for more details.", argumentCaptor.getValue().getMessage()); } @Test - public void testGetModelControllerNotFound() { + public void testGetControllerNotFound() { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); listener.onResponse(null); return null; }).when(client).get(any(), any()); - getModelControllerTransportAction.doExecute(null, mlModelControllerGetRequest, actionListener); + getControllerTransportAction.doExecute(null, mlControllerGetRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals("Failed to find model controller with the provided model ID: testModelId", argumentCaptor.getValue().getMessage()); } @Test - public void testGetModelControllerClientFailedToGetThreadPool() { + public void testGetControllerClientFailedToGetThreadPool() { mock_client_get_NotExist(client); - getModelControllerTransportAction.doExecute(null, mlModelControllerGetRequest, actionListener); + getControllerTransportAction.doExecute(null, mlControllerGetRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals("Failed to find model controller with the provided model ID: testModelId", argumentCaptor.getValue().getMessage()); } @Test - public void testGetModelControllerIndexNotFoundException() { + public void testGetControllerIndexNotFoundException() { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); listener.onFailure(new IndexNotFoundException("Failed to find model controller")); return null; }).when(client).get(any(), any()); - getModelControllerTransportAction.doExecute(null, mlModelControllerGetRequest, actionListener); + getControllerTransportAction.doExecute(null, mlControllerGetRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals("Failed to find model controller", argumentCaptor.getValue().getMessage()); } - public GetResponse prepareModelControllerGetResponse() throws IOException { + public GetResponse prepareControllerGetResponse() throws IOException { - MLRateLimiter rateLimiter = MLRateLimiter.builder().rateLimitNumber("1").rateLimitUnit(TimeUnit.MILLISECONDS).build(); + MLRateLimiter rateLimiter = MLRateLimiter.builder().limit("1").unit(TimeUnit.MILLISECONDS).build(); - MLModelController modelController = MLModelController.builder().modelId("testModelId").userRateLimiterConfig(new HashMap<>() { + MLController controller = MLController.builder().modelId("testModelId").userRateLimiter(new HashMap<>() { { put("testUser", rateLimiter); } }).build(); - XContentBuilder content = modelController.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); + XContentBuilder content = controller.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); BytesReference bytesReference = BytesReference.bytes(content); GetResult getResult = new GetResult("indexName", "111", 111l, 111l, 111l, true, bytesReference, null, null); return new GetResponse(getResult); diff --git a/plugin/src/test/java/org/opensearch/ml/action/controller/UndeployModelControllerTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/controller/UndeployControllerTransportActionTests.java similarity index 61% rename from plugin/src/test/java/org/opensearch/ml/action/controller/UndeployModelControllerTransportActionTests.java rename to plugin/src/test/java/org/opensearch/ml/action/controller/UndeployControllerTransportActionTests.java index b18218cbd0..042a2a5974 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/controller/UndeployModelControllerTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/controller/UndeployControllerTransportActionTests.java @@ -36,17 +36,17 @@ import org.opensearch.core.common.transport.TransportAddress; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ml.cluster.DiscoveryNodeHelper; -import org.opensearch.ml.common.transport.controller.MLUndeployModelControllerNodeRequest; -import org.opensearch.ml.common.transport.controller.MLUndeployModelControllerNodeResponse; -import org.opensearch.ml.common.transport.controller.MLUndeployModelControllerNodesRequest; -import org.opensearch.ml.common.transport.controller.MLUndeployModelControllerNodesResponse; +import org.opensearch.ml.common.transport.controller.MLUndeployControllerNodeRequest; +import org.opensearch.ml.common.transport.controller.MLUndeployControllerNodeResponse; +import org.opensearch.ml.common.transport.controller.MLUndeployControllerNodesRequest; +import org.opensearch.ml.common.transport.controller.MLUndeployControllerNodesResponse; import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.model.MLModelManager; import org.opensearch.ml.stats.MLStats; import org.opensearch.transport.TransportService; @RunWith(MockitoJUnitRunner.class) -public class UndeployModelControllerTransportActionTests { +public class UndeployControllerTransportActionTests { @Mock private TransportService transportService; @@ -72,7 +72,7 @@ public class UndeployModelControllerTransportActionTests { @Mock NamedXContentRegistry xContentRegistry; - private UndeployModelControllerTransportAction action; + private UndeployControllerTransportAction action; private DiscoveryNode localNode; @@ -81,7 +81,7 @@ public class UndeployModelControllerTransportActionTests { @Before public void setUp() throws Exception { - action = new UndeployModelControllerTransportAction( + action = new UndeployControllerTransportAction( transportService, actionFilters, mlModelManager, @@ -108,59 +108,52 @@ public void setUp() throws Exception { ActionListener listener = invocation.getArgument(1); listener.onResponse("successful"); return null; - }).when(mlModelManager).undeployModelController(any(), any()); + }).when(mlModelManager).undeployController(any(), any()); } @Test public void testNewResponses() { - final MLUndeployModelControllerNodesRequest nodesRequest = new MLUndeployModelControllerNodesRequest( + final MLUndeployControllerNodesRequest nodesRequest = new MLUndeployControllerNodesRequest( new String[] { "nodeId1", "nodeId2" }, "testModelId" ); - Map modelUndeployModelControllerStatusMap = new HashMap<>(); - modelUndeployModelControllerStatusMap.put("modelName:version", "response"); - MLUndeployModelControllerNodeResponse response = new MLUndeployModelControllerNodeResponse( - localNode, - modelUndeployModelControllerStatusMap - ); - final List responses = List.of(response); + Map modelUndeployControllerStatusMap = new HashMap<>(); + modelUndeployControllerStatusMap.put("modelName:version", "response"); + MLUndeployControllerNodeResponse response = new MLUndeployControllerNodeResponse(localNode, modelUndeployControllerStatusMap); + final List responses = List.of(response); final List failures = new ArrayList<>(); - MLUndeployModelControllerNodesResponse response1 = action.newResponse(nodesRequest, responses, failures); + MLUndeployControllerNodesResponse response1 = action.newResponse(nodesRequest, responses, failures); assertNotNull(response1); } @Test public void testNewNodeRequest() { - final MLUndeployModelControllerNodesRequest request = new MLUndeployModelControllerNodesRequest( + final MLUndeployControllerNodesRequest request = new MLUndeployControllerNodesRequest( new String[] { "nodeId1", "nodeId2" }, "testModelId" ); - final MLUndeployModelControllerNodeRequest undeployModelControllerNodeRequest = action.newNodeRequest(request); - assertNotNull(undeployModelControllerNodeRequest); + final MLUndeployControllerNodeRequest undeployControllerNodeRequest = action.newNodeRequest(request); + assertNotNull(undeployControllerNodeRequest); } @Test public void testNewNodeStreamRequest() throws IOException { - Map undeployModelControllerStatus = new HashMap<>(); - undeployModelControllerStatus.put("modelId1", "response"); - MLUndeployModelControllerNodeResponse response = new MLUndeployModelControllerNodeResponse( - localNode, - undeployModelControllerStatus - ); + Map undeployControllerStatus = new HashMap<>(); + undeployControllerStatus.put("modelId1", "response"); + MLUndeployControllerNodeResponse response = new MLUndeployControllerNodeResponse(localNode, undeployControllerStatus); BytesStreamOutput output = new BytesStreamOutput(); response.writeTo(output); - final MLUndeployModelControllerNodeResponse undeployModelControllerNodeResponse = action - .newNodeResponse(output.bytes().streamInput()); - assertNotNull(undeployModelControllerNodeResponse); + final MLUndeployControllerNodeResponse undeployControllerNodeResponse = action.newNodeResponse(output.bytes().streamInput()); + assertNotNull(undeployControllerNodeResponse); } @Test public void testNodeOperation() { - final MLUndeployModelControllerNodesRequest request = new MLUndeployModelControllerNodesRequest( + final MLUndeployControllerNodesRequest request = new MLUndeployControllerNodesRequest( new String[] { "nodeId1", "nodeId2" }, "testModelId" ); - final MLUndeployModelControllerNodeResponse response = action.nodeOperation(new MLUndeployModelControllerNodeRequest(request)); + final MLUndeployControllerNodeResponse response = action.nodeOperation(new MLUndeployControllerNodeRequest(request)); assertNotNull(response); } @@ -170,12 +163,12 @@ public void testNodeOperationException() { ActionListener listener = invocation.getArgument(1); listener.onFailure(new RuntimeException("Test exception")); return null; - }).when(mlModelManager).undeployModelController(any(), any()); - final MLUndeployModelControllerNodesRequest request = new MLUndeployModelControllerNodesRequest( + }).when(mlModelManager).undeployController(any(), any()); + final MLUndeployControllerNodesRequest request = new MLUndeployControllerNodesRequest( new String[] { "nodeId1", "nodeId2" }, "testModelId" ); - final MLUndeployModelControllerNodeResponse response = action.nodeOperation(new MLUndeployModelControllerNodeRequest(request)); + final MLUndeployControllerNodeResponse response = action.nodeOperation(new MLUndeployControllerNodeRequest(request)); assertNotNull(response); } } diff --git a/plugin/src/test/java/org/opensearch/ml/action/controller/UpdateModelControllerTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/controller/UpdateControllerTransportActionTests.java similarity index 62% rename from plugin/src/test/java/org/opensearch/ml/action/controller/UpdateModelControllerTransportActionTests.java rename to plugin/src/test/java/org/opensearch/ml/action/controller/UpdateControllerTransportActionTests.java index 64fd99ee4a..cc610441e3 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/controller/UpdateModelControllerTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/controller/UpdateControllerTransportActionTests.java @@ -46,11 +46,11 @@ import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; -import org.opensearch.ml.common.controller.MLModelController; +import org.opensearch.ml.common.controller.MLController; import org.opensearch.ml.common.controller.MLRateLimiter; -import org.opensearch.ml.common.transport.controller.MLDeployModelControllerAction; -import org.opensearch.ml.common.transport.controller.MLDeployModelControllerNodesResponse; -import org.opensearch.ml.common.transport.controller.MLUpdateModelControllerRequest; +import org.opensearch.ml.common.transport.controller.MLDeployControllerAction; +import org.opensearch.ml.common.transport.controller.MLDeployControllerNodesResponse; +import org.opensearch.ml.common.transport.controller.MLUpdateControllerRequest; import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.model.MLModelCacheHelper; import org.opensearch.ml.model.MLModelManager; @@ -58,7 +58,7 @@ import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; -public class UpdateModelControllerTransportActionTests extends OpenSearchTestCase { +public class UpdateControllerTransportActionTests extends OpenSearchTestCase { @Mock ThreadPool threadPool; @@ -99,15 +99,15 @@ public class UpdateModelControllerTransportActionTests extends OpenSearchTestCas MLModel mlModel; @Mock - MLDeployModelControllerNodesResponse mlDeployModelControllerNodesResponse; + MLDeployControllerNodesResponse mlDeployControllerNodesResponse; @Rule public ExpectedException exceptionRule = ExpectedException.none(); - MLModelController modelController; - MLModelController updatedModelController; - UpdateModelControllerTransportAction updateModelControllerTransportAction; - MLUpdateModelControllerRequest updateModelControllerRequest; + MLController controller; + MLController updatedController; + UpdateControllerTransportAction updateControllerTransportAction; + MLUpdateControllerRequest updateControllerRequest; ThreadContext threadContext; @Before @@ -138,8 +138,8 @@ public void setup() throws IOException { DiscoveryNodes nodes = DiscoveryNodes.builder().add(node1).add(node2).build(); String[] targetNodeIds = new String[] { node1.getId(), node2.getId() }; - updateModelControllerTransportAction = spy( - new UpdateModelControllerTransportAction( + updateControllerTransportAction = spy( + new UpdateControllerTransportAction( transportService, actionFilters, client, @@ -150,23 +150,23 @@ public void setup() throws IOException { ) ); - MLRateLimiter rateLimiter = MLRateLimiter.builder().rateLimitNumber("1").rateLimitUnit(TimeUnit.MILLISECONDS).build(); + MLRateLimiter rateLimiter = MLRateLimiter.builder().limit("1").unit(TimeUnit.MILLISECONDS).build(); - modelController = MLModelController.builder().modelId("testModelId").userRateLimiterConfig(new HashMap<>() { + controller = MLController.builder().modelId("testModelId").userRateLimiter(new HashMap<>() { { put("testUser", rateLimiter); } }).build(); - MLRateLimiter updateRateLimiter = MLRateLimiter.builder().rateLimitNumber("2").rateLimitUnit(TimeUnit.NANOSECONDS).build(); + MLRateLimiter updateRateLimiter = MLRateLimiter.builder().limit("2").unit(TimeUnit.NANOSECONDS).build(); - updatedModelController = MLModelController.builder().modelId("testModelId").userRateLimiterConfig(new HashMap<>() { + updatedController = MLController.builder().modelId("testModelId").userRateLimiter(new HashMap<>() { { put("newUser", updateRateLimiter); } }).build(); - updateModelControllerRequest = MLUpdateModelControllerRequest.builder().updateModelControllerInput(updatedModelController).build(); + updateControllerRequest = MLUpdateControllerRequest.builder().updateControllerInput(updatedController).build(); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); @@ -183,10 +183,10 @@ public void setup() throws IOException { when(mlModel.getModelId()).thenReturn("testModelId"); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(modelController); + ActionListener listener = invocation.getArgument(1); + listener.onResponse(controller); return null; - }).when(mlModelManager).getModelController(eq("testModelId"), isA(ActionListener.class)); + }).when(mlModelManager).getController(eq("testModelId"), isA(ActionListener.class)); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); @@ -207,27 +207,27 @@ public void setup() throws IOException { } @Test - public void testUpdateModelControllerSuccess() { - updateModelControllerTransportAction.doExecute(null, updateModelControllerRequest, actionListener); + public void testUpdateControllerSuccess() { + updateControllerTransportAction.doExecute(null, updateControllerRequest, actionListener); verify(actionListener).onResponse(updateResponse); } @Test - public void testUpdateModelControllerWithTextEmbeddingModelSuccess() { + public void testUpdateControllerWithTextEmbeddingModelSuccess() { when(mlModel.getAlgorithm()).thenReturn(FunctionName.TEXT_EMBEDDING); - updateModelControllerTransportAction.doExecute(null, updateModelControllerRequest, actionListener); + updateControllerTransportAction.doExecute(null, updateControllerRequest, actionListener); verify(actionListener).onResponse(updateResponse); } @Test - public void testUpdateModelControllerWithModelAccessControlNoPermission() { + public void testUpdateControllerWithModelAccessControlNoPermission() { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); listener.onResponse(false); return null; }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); - updateModelControllerTransportAction.doExecute(null, updateModelControllerRequest, actionListener); + updateControllerTransportAction.doExecute(null, updateControllerRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals( @@ -237,29 +237,29 @@ public void testUpdateModelControllerWithModelAccessControlNoPermission() { } @Test - public void testUpdateModelControllerWithModelAccessControlOtherException() { + public void testUpdateControllerWithModelAccessControlOtherException() { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); listener.onFailure(new RuntimeException("Exception occurred. Please check log for more details.")); return null; }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); - updateModelControllerTransportAction.doExecute(null, updateModelControllerRequest, actionListener); + updateControllerTransportAction.doExecute(null, updateControllerRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals("Exception occurred. Please check log for more details.", argumentCaptor.getValue().getMessage()); } @Test - public void testUpdateModelControllerWithModelControllerEnabledNull() { + public void testUpdateControllerWithControllerEnabledNull() { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(1); listener.onFailure(new RuntimeException("Exception occurred. Please check log for more details.")); return null; - }).when(mlModelManager).getModelController(eq("testModelId"), isA(ActionListener.class)); - when(mlModel.getIsModelControllerEnabled()).thenReturn(null); + }).when(mlModelManager).getController(eq("testModelId"), isA(ActionListener.class)); + when(mlModel.getIsControllerEnabled()).thenReturn(null); - updateModelControllerTransportAction.doExecute(null, updateModelControllerRequest, actionListener); + updateControllerTransportAction.doExecute(null, updateControllerRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals( @@ -269,15 +269,15 @@ public void testUpdateModelControllerWithModelControllerEnabledNull() { } @Test - public void testUpdateModelControllerWithModelControllerNotEnabled() { + public void testUpdateControllerWithControllerNotEnabled() { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(1); listener.onFailure(new RuntimeException("Exception occurred. Please check log for more details.")); return null; - }).when(mlModelManager).getModelController(eq("testModelId"), isA(ActionListener.class)); - when(mlModel.getIsModelControllerEnabled()).thenReturn(false); + }).when(mlModelManager).getController(eq("testModelId"), isA(ActionListener.class)); + when(mlModel.getIsControllerEnabled()).thenReturn(false); - updateModelControllerTransportAction.doExecute(null, updateModelControllerRequest, actionListener); + updateControllerTransportAction.doExecute(null, updateControllerRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals( @@ -287,39 +287,41 @@ public void testUpdateModelControllerWithModelControllerNotEnabled() { } @Test - public void testUpdateModelControllerWithModelControllerEnabledNotFound() { + public void testUpdateControllerWithControllerEnabledNotFound() { doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); + ActionListener listener = invocation.getArgument(1); listener.onFailure(new RuntimeException("Exception occurred. Please check log for more details.")); return null; - }).when(mlModelManager).getModelController(eq("testModelId"), isA(ActionListener.class)); - when(mlModel.getIsModelControllerEnabled()).thenReturn(true); + }).when(mlModelManager).getController(eq("testModelId"), isA(ActionListener.class)); + when(mlModel.getIsControllerEnabled()).thenReturn(true); - updateModelControllerTransportAction.doExecute(null, updateModelControllerRequest, actionListener); + updateControllerTransportAction.doExecute(null, updateControllerRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals("Exception occurred. Please check log for more details.", argumentCaptor.getValue().getMessage()); } @Test - public void testUpdateModelControllerWithModelFunctionUnsupported() { + public void testUpdateControllerWithModelFunctionUnsupported() { when(mlModel.getAlgorithm()).thenReturn(FunctionName.METRICS_CORRELATION); - updateModelControllerTransportAction.doExecute(null, updateModelControllerRequest, actionListener); + updateControllerTransportAction.doExecute(null, updateControllerRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); - assertEquals("Creating model controller on this operation on the function category METRICS_CORRELATION is not supported.", argumentCaptor.getValue().getMessage()); + assertEquals( + "Creating model controller on this operation on the function category METRICS_CORRELATION is not supported.", + argumentCaptor.getValue().getMessage()); } @Test - public void tesUpdateModelControllerWithGetModelNotFound() { + public void tesUpdateControllerWithGetModelNotFound() { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(3); listener.onResponse(null); return null; }).when(mlModelManager).getModel(eq("testModelId"), any(), any(), isA(ActionListener.class)); - updateModelControllerTransportAction.doExecute(null, updateModelControllerRequest, actionListener); + updateControllerTransportAction.doExecute(null, updateControllerRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals( @@ -329,21 +331,21 @@ public void tesUpdateModelControllerWithGetModelNotFound() { } @Test - public void testUpdateModelControllerWithUpdateResponseNoop() { + public void testUpdateControllerWithUpdateResponseNoop() { when(updateResponse.getResult()).thenReturn(DocWriteResponse.Result.NOOP); - updateModelControllerTransportAction.doExecute(null, updateModelControllerRequest, actionListener); + updateControllerTransportAction.doExecute(null, updateControllerRequest, actionListener); verify(actionListener).onResponse(updateResponse); } @Test - public void testUpdateModelControllerWithNullUpdateResponse() { + public void testUpdateControllerWithNullUpdateResponse() { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); listener.onResponse(null); return null; }).when(client).update(any(), any()); - updateModelControllerTransportAction.doExecute(null, updateModelControllerRequest, actionListener); + updateControllerTransportAction.doExecute(null, updateControllerRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals("Failed to update model controller with model ID: testModelId", argumentCaptor.getValue().getMessage()); @@ -351,64 +353,64 @@ public void testUpdateModelControllerWithNullUpdateResponse() { } @Test - public void testUpdateModelControllerWithDeploySuccessNullFailures() { + public void testUpdateControllerWithDeploySuccessNullFailures() { when(mlModelCacheHelper.isModelDeployed("testModelId")).thenReturn(true); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(mlDeployModelControllerNodesResponse); + ActionListener listener = invocation.getArgument(2); + listener.onResponse(mlDeployControllerNodesResponse); return null; - }).when(client).execute(eq(MLDeployModelControllerAction.INSTANCE), any(), any()); - when(mlDeployModelControllerNodesResponse.failures()).thenReturn(null); + }).when(client).execute(eq(MLDeployControllerAction.INSTANCE), any(), any()); + when(mlDeployControllerNodesResponse.failures()).thenReturn(null); - updateModelControllerTransportAction.doExecute(null, updateModelControllerRequest, actionListener); + updateControllerTransportAction.doExecute(null, updateControllerRequest, actionListener); verify(actionListener).onResponse(updateResponse); } @Test - public void testUpdateModelControllerWithDeployNotRequiredAfterUpdateSuccess() { - updateModelControllerRequest = MLUpdateModelControllerRequest.builder().updateModelControllerInput(modelController).build(); + public void testUpdateControllerWithDeployNotRequiredAfterUpdateSuccess() { + updateControllerRequest = MLUpdateControllerRequest.builder().updateControllerInput(controller).build(); when(mlModelCacheHelper.isModelDeployed("testModelId")).thenReturn(true); - updateModelControllerTransportAction.doExecute(null, updateModelControllerRequest, actionListener); + updateControllerTransportAction.doExecute(null, updateControllerRequest, actionListener); verify(actionListener).onResponse(updateResponse); } @Test - public void testUpdateModelControllerWithModelNDeployedAndDeployNotRequiredAfterUpdateSuccess() { - updateModelControllerRequest = MLUpdateModelControllerRequest.builder().updateModelControllerInput(modelController).build(); + public void testUpdateControllerWithModelNDeployedAndDeployNotRequiredAfterUpdateSuccess() { + updateControllerRequest = MLUpdateControllerRequest.builder().updateControllerInput(controller).build(); when(mlModelCacheHelper.isModelDeployed("testModelId")).thenReturn(false); - updateModelControllerTransportAction.doExecute(null, updateModelControllerRequest, actionListener); + updateControllerTransportAction.doExecute(null, updateControllerRequest, actionListener); verify(actionListener).onResponse(updateResponse); } @Test - public void testUpdateModelControllerWithUndeploySuccessEmptyFailures() { + public void testUpdateControllerWithUndeploySuccessEmptyFailures() { when(mlModelCacheHelper.isModelDeployed("testModelId")).thenReturn(true); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(mlDeployModelControllerNodesResponse); + ActionListener listener = invocation.getArgument(2); + listener.onResponse(mlDeployControllerNodesResponse); return null; - }).when(client).execute(eq(MLDeployModelControllerAction.INSTANCE), any(), any()); - when(mlDeployModelControllerNodesResponse.failures()).thenReturn(new ArrayList<>()); + }).when(client).execute(eq(MLDeployControllerAction.INSTANCE), any(), any()); + when(mlDeployControllerNodesResponse.failures()).thenReturn(new ArrayList<>()); - updateModelControllerTransportAction.doExecute(null, updateModelControllerRequest, actionListener); - verify(actionListener).onResponse(updateResponse); + updateControllerTransportAction.doExecute(null, updateControllerRequest, actionListener); + verify(actionListener).onResponse(updateResponse); } @Test - public void testUpdateModelControllerWithUndeploySuccessPartiallyFailures() { + public void testUpdateControllerWithUndeploySuccessPartiallyFailures() { List failures = List .of(new FailedNodeException("foo1", "Undeploy failed.", new RuntimeException("Exception occurred."))); - when(mlModelCacheHelper.isModelDeployed("testModelId")).thenReturn(true); + when(mlModelCacheHelper.getWorkerNodes("testModelId")).thenReturn(new String[] { "foo1", "foo2" }); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onResponse(mlDeployModelControllerNodesResponse); + ActionListener listener = invocation.getArgument(2); + listener.onResponse(mlDeployControllerNodesResponse); return null; - }).when(client).execute(eq(MLDeployModelControllerAction.INSTANCE), any(), any()); - when(mlDeployModelControllerNodesResponse.failures()).thenReturn(failures); + }).when(client).execute(eq(MLDeployControllerAction.INSTANCE), any(), any()); + when(mlDeployControllerNodesResponse.failures()).thenReturn(failures); - updateModelControllerTransportAction.doExecute(null, updateModelControllerRequest, actionListener); + updateControllerTransportAction.doExecute(null, updateControllerRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals( @@ -419,45 +421,42 @@ public void testUpdateModelControllerWithUndeploySuccessPartiallyFailures() { } @Test - public void testUpdateModelControllerWithUndeployNullResponse() { - when(mlModelCacheHelper.isModelDeployed("testModelId")).thenReturn(true); + public void testUpdateControllerWithUndeployNullResponse() { + when(mlModelCacheHelper.getWorkerNodes("testModelId")).thenReturn(new String[] { "foo1", "foo2" }); doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); + ActionListener listener = invocation.getArgument(2); listener.onResponse(null); return null; - }).when(client).execute(eq(MLDeployModelControllerAction.INSTANCE), any(), any()); + }).when(client).execute(eq(MLDeployControllerAction.INSTANCE), any(), any()); - updateModelControllerTransportAction.doExecute(null, updateModelControllerRequest, actionListener); + updateControllerTransportAction.doExecute(null, updateControllerRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals( "Successfully update model controller index with model ID testModelId " + "but deploy model controller to cache was failed on following nodes [foo1, foo2], please retry.", - argumentCaptor.getValue().getMessage() - ); + argumentCaptor.getValue().getMessage()); } @Test - public void testUpdateModelControllerWithUndeployOtherException() { - when(mlModelCacheHelper.isModelDeployed("testModelId")).thenReturn(true); + public void testUpdateControllerWithUndeployOtherException() { + when(mlModelCacheHelper.getWorkerNodes("testModelId")).thenReturn(new String[] { "foo1", "foo2" }); doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(2); + ActionListener actionListener = invocation.getArgument(2); actionListener .onFailure( - new RuntimeException("Exception occurred. Please check log for more details.") - ); + new RuntimeException("Exception occurred. Please check log for more details.")); return null; - }).when(client).execute(eq(MLDeployModelControllerAction.INSTANCE), any(), any()); + }).when(client).execute(eq(MLDeployControllerAction.INSTANCE), any(), any()); - updateModelControllerTransportAction.doExecute(null, updateModelControllerRequest, actionListener); + updateControllerTransportAction.doExecute(null, updateControllerRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals( "Exception occurred. Please check log for more details.", - argumentCaptor.getValue().getMessage() - ); + argumentCaptor.getValue().getMessage()); } } diff --git a/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java index 529163cc5c..c24d4bf816 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/models/UpdateModelTransportActionTests.java @@ -14,14 +14,18 @@ import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX; import static org.opensearch.ml.utils.TestHelper.clusterSetting; -import static org.opensearch.ml.utils.TestHelper.setupTestClusterState; import java.io.IOException; +import java.net.InetAddress; +import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.concurrent.TimeUnit; import org.junit.Before; import org.junit.Ignore; @@ -32,13 +36,17 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.opensearch.OpenSearchStatusException; +import org.opensearch.Version; import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.FailedNodeException; import org.opensearch.action.get.GetResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.update.UpdateRequest; import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.node.DiscoveryNodes; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; @@ -47,6 +55,7 @@ import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.transport.TransportAddress; import org.opensearch.core.index.Index; import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.xcontent.ToXContent; @@ -59,6 +68,7 @@ import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.connector.ConnectorAction; import org.opensearch.ml.common.connector.HttpConnector; +import org.opensearch.ml.common.controller.MLRateLimiter; import org.opensearch.ml.common.exception.MLResourceNotFoundException; import org.opensearch.ml.common.model.MLModelState; import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; @@ -143,7 +153,8 @@ public class UpdateModelTransportActionTests extends OpenSearchTestCase { ThreadContext threadContext; - ClusterState testState; + @Mock + ClusterState clusterState; @Mock ClusterService clusterService; @@ -156,7 +167,6 @@ public class UpdateModelTransportActionTests extends OpenSearchTestCase { @Before public void setup() throws IOException { MockitoAnnotations.openMocks(this); - testState = setupTestClusterState(); updateLocalModelInput = MLUpdateModelInput .builder() .modelId("test_model_id") @@ -182,13 +192,40 @@ public void setup() throws IOException { ClusterSettings clusterSettings = clusterSetting(settings, ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX); + InetAddress inetAddress1 = InetAddress.getByAddress(new byte[] { (byte) 192, (byte) 168, (byte) 0, (byte) 1 }); + InetAddress inetAddress2 = InetAddress.getByAddress(new byte[] { (byte) 192, (byte) 168, (byte) 0, (byte) 2 }); + + DiscoveryNode node1 = new DiscoveryNode( + "foo1", + "foo1", + new TransportAddress(inetAddress1, 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT + ); + + DiscoveryNode node2 = new DiscoveryNode( + "foo2", + "foo2", + new TransportAddress(inetAddress2, 9300), + Collections.emptyMap(), + Collections.singleton(CLUSTER_MANAGER_ROLE), + Version.CURRENT + ); + + DiscoveryNodes nodes = DiscoveryNodes.builder().add(node1).add(node2).build(); + String[] targetNodeIds = new String[] { node1.getId(), node2.getId() }; + localModel = prepareMLModel("TEXT_EMBEDDING"); threadContext = new ThreadContext(settings); when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); - when(clusterService.state()).thenReturn(testState); + when(clusterService.state()).thenReturn(clusterState); when(clusterService.getSettings()).thenReturn(settings); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); + when(clusterState.nodes()).thenReturn(nodes); + when(mlModelManager.getWorkerNodes("test_model_id", FunctionName.REMOTE)).thenReturn(targetNodeIds); + shardId = new ShardId(new Index("indexName", "uuid"), 1); updateResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.UPDATED); @@ -705,6 +742,20 @@ public void testGetUpdateResponseListenerWithVersionBumpOtherException() { ); } + @Test + public void testGetUpdateResponseListenerWithNullUpdateResponse() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(null); + return null; + }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + + transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Failed to update ML model: test_model_id", argumentCaptor.getValue().getMessage()); + } + @Test public void testGetUpdateResponseListenerWrongStatus() { UpdateResponse updateWrongResponse = new UpdateResponse(shardId, "taskId", 1, 1, 1, DocWriteResponse.Result.CREATED); @@ -758,6 +809,24 @@ public void testUpdateModelStateDeployingException() { ); } + @Test + public void testUpdateModelStateLoadingException() { + MLModel testDeployingModel = prepareMLModel("TEXT_EMBEDDING", MLModelState.LOADING); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(testDeployingModel); + return null; + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + + transportUpdateModelAction.doExecute(task, updateLocalModelRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "Model is deploying. Please wait for the model to complete deployment. model ID test_model_id", + argumentCaptor.getValue().getMessage() + ); + } + @Test public void testUpdateModelCacheModelStateDeployedSuccess() { MLModel testUpdateModelCacheModel = prepareMLModel("REMOTE_INTERNAL", MLModelState.DEPLOYED); @@ -779,6 +848,251 @@ public void testUpdateModelCacheModelStateDeployedSuccess() { verify(actionListener).onResponse(updateResponse); } + @Test + public void testUpdateModelCacheModelWithIsModelEnabledSuccess() { + MLModel testUpdateModelCacheModel = prepareMLModel("REMOTE_INTERNAL", MLModelState.DEPLOYED); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(testUpdateModelCacheModel); + return null; + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(updateModelCacheNodesResponse); + return null; + }).when(client).execute(any(), any(), isA(ActionListener.class)); + + MLUpdateModelRequest testUpdateModelCacheRequest = prepareRemoteRequest("REMOTE_INTERNAL"); + testUpdateModelCacheRequest.getUpdateModelInput().setModelGroupId(null); + testUpdateModelCacheRequest.getUpdateModelInput().setConnector(null); + testUpdateModelCacheRequest.getUpdateModelInput().setIsEnabled(true); + transportUpdateModelAction.doExecute(task, testUpdateModelCacheRequest, actionListener); + verify(actionListener).onResponse(updateResponse); + } + + @Test + public void testUpdateModelCacheModelWithoutUpdateConnectorWithRateLimiterSuccess() { + MLModel testUpdateModelCacheModel = prepareMLModel("REMOTE_INTERNAL", MLModelState.DEPLOYED); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(testUpdateModelCacheModel); + return null; + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(updateModelCacheNodesResponse); + return null; + }).when(client).execute(any(), any(), isA(ActionListener.class)); + + MLRateLimiter rateLimiter = MLRateLimiter.builder().limit("1").unit(TimeUnit.MILLISECONDS).build(); + MLUpdateModelRequest testUpdateModelCacheRequest = prepareRemoteRequest("REMOTE_INTERNAL"); + testUpdateModelCacheRequest.getUpdateModelInput().setModelGroupId(null); + testUpdateModelCacheRequest.getUpdateModelInput().setRateLimiter(rateLimiter); + testUpdateModelCacheRequest.getUpdateModelInput().setConnector(null); + transportUpdateModelAction.doExecute(task, testUpdateModelCacheRequest, actionListener); + verify(actionListener).onResponse(updateResponse); + } + + @Test + public void testUpdateModelCacheModelWithRateLimiterSuccess() { + MLModel testUpdateModelCacheModel = prepareMLModel("REMOTE_INTERNAL", MLModelState.DEPLOYED); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(testUpdateModelCacheModel); + return null; + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(updateModelCacheNodesResponse); + return null; + }).when(client).execute(any(), any(), isA(ActionListener.class)); + + MLRateLimiter rateLimiter = MLRateLimiter.builder().limit("1").unit(TimeUnit.MILLISECONDS).build(); + MLUpdateModelRequest testUpdateModelCacheRequest = prepareRemoteRequest("REMOTE_INTERNAL"); + testUpdateModelCacheRequest.getUpdateModelInput().setModelGroupId(null); + testUpdateModelCacheRequest.getUpdateModelInput().setRateLimiter(rateLimiter); + transportUpdateModelAction.doExecute(task, testUpdateModelCacheRequest, actionListener); + verify(actionListener).onResponse(updateResponse); + } + + @Test + public void testUpdateModelWithPartialRateLimiterSuccess() { + MLModel testUpdateModelCacheModel = prepareMLModel("REMOTE_INTERNAL", MLModelState.DEPLOYED); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(testUpdateModelCacheModel); + return null; + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + + MLRateLimiter rateLimiter = MLRateLimiter.builder().limit("1").build(); + MLUpdateModelRequest testUpdateModelCacheRequest = prepareRemoteRequest("REMOTE_INTERNAL"); + testUpdateModelCacheRequest.getUpdateModelInput().setModelGroupId(null); + testUpdateModelCacheRequest.getUpdateModelInput().setRateLimiter(rateLimiter); + testUpdateModelCacheRequest.getUpdateModelInput().setConnector(null); + transportUpdateModelAction.doExecute(task, testUpdateModelCacheRequest, actionListener); + verify(actionListener).onResponse(updateResponse); + } + + @Test + public void testUpdateModelCacheModelWithPartialRateLimiterSuccess() { + MLModel testUpdateModelCacheModel = prepareMLModel("REMOTE_INTERNAL", MLModelState.DEPLOYED); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(testUpdateModelCacheModel); + return null; + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(updateModelCacheNodesResponse); + return null; + }).when(client).execute(any(), any(), isA(ActionListener.class)); + + MLRateLimiter rateLimiter = MLRateLimiter.builder().limit("1").build(); + MLUpdateModelRequest testUpdateModelCacheRequest = prepareRemoteRequest("REMOTE_INTERNAL"); + testUpdateModelCacheRequest.getUpdateModelInput().setModelGroupId(null); + testUpdateModelCacheRequest.getUpdateModelInput().setRateLimiter(rateLimiter); + transportUpdateModelAction.doExecute(task, testUpdateModelCacheRequest, actionListener); + verify(actionListener).onResponse(updateResponse); + } + + @Test + public void testUpdateModelCacheUpdateResponseListenerWithNullUpdateResponse() { + MLModel testUpdateModelCacheModel = prepareMLModel("REMOTE_INTERNAL", MLModelState.DEPLOYED); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(testUpdateModelCacheModel); + return null; + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(null); + return null; + }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(updateModelCacheNodesResponse); + return null; + }).when(client).execute(any(), any(), isA(ActionListener.class)); + when(updateModelCacheNodesResponse.failures()).thenReturn(null); + + MLUpdateModelRequest testUpdateModelCacheRequest = prepareRemoteRequest("REMOTE_INTERNAL"); + testUpdateModelCacheRequest.getUpdateModelInput().setModelGroupId(null); + transportUpdateModelAction.doExecute(task, testUpdateModelCacheRequest, actionListener); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Failed to update ML model: test_model_id", argumentCaptor.getValue().getMessage()); + } + + @Test + public void testUpdateModelCacheModelWithUndeploySuccessEmptyFailures() { + MLModel testUpdateModelCacheModel = prepareMLModel("REMOTE_INTERNAL", MLModelState.DEPLOYED); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(testUpdateModelCacheModel); + return null; + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(updateModelCacheNodesResponse); + return null; + }).when(client).execute(any(), any(), isA(ActionListener.class)); + when(updateModelCacheNodesResponse.failures()).thenReturn(new ArrayList<>()); + + MLUpdateModelRequest testUpdateModelCacheRequest = prepareRemoteRequest("REMOTE_INTERNAL"); + testUpdateModelCacheRequest.getUpdateModelInput().setModelGroupId(null); + transportUpdateModelAction.doExecute(task, testUpdateModelCacheRequest, actionListener); + verify(actionListener).onResponse(updateResponse); + } + + @Test + public void testUpdateControllerWithUndeploySuccessPartiallyFailures() { + List failures = List + .of(new FailedNodeException("foo1", "Undeploy failed.", new RuntimeException("Exception occurred."))); + + MLModel testUpdateModelCacheModel = prepareMLModel("REMOTE_INTERNAL", MLModelState.DEPLOYED); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(testUpdateModelCacheModel); + return null; + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(updateModelCacheNodesResponse); + return null; + }).when(client).execute(any(), any(), isA(ActionListener.class)); + when(updateModelCacheNodesResponse.failures()).thenReturn(failures); + + MLUpdateModelRequest testUpdateModelCacheRequest = prepareRemoteRequest("REMOTE_INTERNAL"); + testUpdateModelCacheRequest.getUpdateModelInput().setModelGroupId(null); + transportUpdateModelAction.doExecute(task, testUpdateModelCacheRequest, actionListener); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "Successfully update ML model index with model ID test_model_id but update model cache was failed on following nodes [foo1], please retry or redeploy model manually.", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testUpdateControllerWithUndeployNullResponse() { + MLModel testUpdateModelCacheModel = prepareMLModel("REMOTE_INTERNAL", MLModelState.DEPLOYED); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(testUpdateModelCacheModel); + return null; + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(null); + return null; + }).when(client).execute(any(), any(), isA(ActionListener.class)); + + MLUpdateModelRequest testUpdateModelCacheRequest = prepareRemoteRequest("REMOTE_INTERNAL"); + testUpdateModelCacheRequest.getUpdateModelInput().setModelGroupId(null); + transportUpdateModelAction.doExecute(task, testUpdateModelCacheRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "Successfully update ML model index with model ID test_model_id but update model cache was failed on following nodes [foo1, foo2], please retry or redeploy model manually.", + argumentCaptor.getValue().getMessage() + ); + } + + @Test + public void testUpdateControllerWithUndeployOtherException() { + MLModel testUpdateModelCacheModel = prepareMLModel("REMOTE_INTERNAL", MLModelState.DEPLOYED); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(testUpdateModelCacheModel); + return null; + }).when(mlModelManager).getModel(eq("test_model_id"), any(), any(), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onFailure(new RuntimeException("Exception occurred. Please check log for more details.")); + return null; + }).when(client).execute(any(), any(), isA(ActionListener.class)); + + MLUpdateModelRequest testUpdateModelCacheRequest = prepareRemoteRequest("REMOTE_INTERNAL"); + testUpdateModelCacheRequest.getUpdateModelInput().setModelGroupId(null); + transportUpdateModelAction.doExecute(task, testUpdateModelCacheRequest, actionListener); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Exception occurred. Please check log for more details.", argumentCaptor.getValue().getMessage()); + } + @Test public void testUpdateModelCacheModelStateDeployedWrongStatus() { MLModel testUpdateModelCacheModel = prepareMLModel("REMOTE_INTERNAL", MLModelState.DEPLOYED); @@ -1181,56 +1495,6 @@ private GetResponse prepareGetResponse(MLModelGroup mlModelGroup) throws IOExcep return new GetResponse(getResult); } - @Ignore - @Test - public void testUpdateModelStateLoadingException() { - doReturn(mockUpdateModelInput).when(mockUpdateModelRequest).getUpdateModelInput(); - doReturn("mockId").when(mockUpdateModelInput).getModelId(); - - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(mockModel); - return null; - }).when(mlModelManager).getModel(eq("mockId"), any(), any(), isA(ActionListener.class)); - - doReturn("test_model_group_id").when(mockModel).getModelGroupId(); - doReturn(FunctionName.TEXT_EMBEDDING).when(mockModel).getAlgorithm(); - doReturn(MLModelState.LOADING).when(mockModel).getModelState(); - - transportUpdateModelAction.doExecute(task, mockUpdateModelRequest, actionListener); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class); - verify(actionListener).onFailure(argumentCaptor.capture()); - assertEquals( - "ML Model mockId is in deploying or deployed state, please undeploy the models first!", - argumentCaptor.getValue().getMessage() - ); - } - - @Ignore - @Test - public void testUpdateModelStateLoadedException() { - doReturn(mockUpdateModelInput).when(mockUpdateModelRequest).getUpdateModelInput(); - doReturn("mockId").when(mockUpdateModelInput).getModelId(); - - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(3); - listener.onResponse(mockModel); - return null; - }).when(mlModelManager).getModel(eq("mockId"), any(), any(), isA(ActionListener.class)); - - doReturn("test_model_group_id").when(mockModel).getModelGroupId(); - doReturn(FunctionName.TEXT_EMBEDDING).when(mockModel).getAlgorithm(); - doReturn(MLModelState.LOADED).when(mockModel).getModelState(); - - transportUpdateModelAction.doExecute(task, mockUpdateModelRequest, actionListener); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class); - verify(actionListener).onFailure(argumentCaptor.capture()); - assertEquals( - "ML Model mockId is in deploying or deployed state, please undeploy the models first!", - argumentCaptor.getValue().getMessage() - ); - } - @Ignore @Test public void testUpdateModelStatePartiallyLoadedException() { diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelCacheHelperTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelCacheHelperTests.java index 603c315a12..232290520d 100644 --- a/plugin/src/test/java/org/opensearch/ml/model/MLModelCacheHelperTests.java +++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelCacheHelperTests.java @@ -22,15 +22,18 @@ import org.junit.Before; import org.junit.Rule; import org.junit.rules.ExpectedException; +import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.TokenBucket; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.exception.MLLimitExceededException; import org.opensearch.ml.common.model.MLModelFormat; import org.opensearch.ml.common.model.MLModelState; +import org.opensearch.ml.engine.MLExecutable; import org.opensearch.ml.engine.algorithms.text_embedding.TextEmbeddingDenseModel; import org.opensearch.ml.profile.MLModelProfile; import org.opensearch.ml.profile.MLPredictRequestStats; @@ -53,6 +56,13 @@ public class MLModelCacheHelperTests extends OpenSearchTestCase { private int maxMonitoringRequests; private List targetWorkerNodes; + private Map userRateLimiterMap; + + @Mock + private MLExecutable mlExecutor; + + @Mock + private TokenBucket rateLimiter; @Before public void setup() { @@ -70,6 +80,8 @@ public void setup() { predictor = spy(new TextEmbeddingDenseModel()); targetWorkerNodes = new ArrayList<>(); targetWorkerNodes.add(nodeId); + + userRateLimiterMap = Map.of("user1", rateLimiter); } public void testModelState() { @@ -129,6 +141,42 @@ public void testPredictor() { assertEquals(predictor, cacheHelper.getPredictor(modelId)); } + public void testExecutor() { + cacheHelper.initModelState(modelId, MLModelState.DEPLOYING, FunctionName.METRICS_CORRELATION, targetWorkerNodes, true); + assertNull(cacheHelper.getMLExecutor(modelId)); + cacheHelper.setMLExecutor(modelId, mlExecutor); + assertEquals(mlExecutor, cacheHelper.getMLExecutor(modelId)); + cacheHelper.removeModel(modelId); + assertNull(cacheHelper.getMLExecutor(modelId)); + } + + public void testRateLimiter() { + cacheHelper.initModelState(modelId, MLModelState.DEPLOYING, FunctionName.METRICS_CORRELATION, targetWorkerNodes, true); + assertNull(cacheHelper.getRateLimiter(modelId)); + cacheHelper.setRateLimiter(modelId, rateLimiter); + assertEquals(rateLimiter, cacheHelper.getRateLimiter(modelId)); + cacheHelper.removeRateLimiter(modelId); + assertNull(cacheHelper.getRateLimiter(modelId)); + } + + public void testModelEnabled() { + cacheHelper.initModelState(modelId, MLModelState.DEPLOYING, FunctionName.METRICS_CORRELATION, targetWorkerNodes, true); + assertNull(cacheHelper.getIsModelEnabled(modelId)); + cacheHelper.setIsModelEnabled(modelId, true); + assertTrue(cacheHelper.getIsModelEnabled(modelId)); + } + + public void testUserRateLimiter() { + cacheHelper.initModelState(modelId, MLModelState.DEPLOYING, FunctionName.METRICS_CORRELATION, targetWorkerNodes, true); + assertNull(cacheHelper.getUserRateLimiterMap(modelId)); + cacheHelper.setUserRateLimiterMap(modelId, userRateLimiterMap); + assertEquals(userRateLimiterMap, cacheHelper.getUserRateLimiterMap(modelId)); + assertEquals(rateLimiter, cacheHelper.getUserRateLimiter(modelId, "user1")); + assertNull(cacheHelper.getUserRateLimiter(modelId, "user2")); + cacheHelper.removeUserRateLimiterMap(modelId); + assertNull(cacheHelper.getUserRateLimiterMap(modelId)); + } + public void testGetAndRemoveModel() { assertFalse(cacheHelper.isModelRunningOnNode(modelId)); cacheHelper.initModelState(modelId, MLModelState.DEPLOYING, FunctionName.TEXT_EMBEDDING, targetWorkerNodes, true); @@ -201,10 +249,21 @@ public void testRemoveWorkerNode_ModelState() { public void testRemoveModel_Deployed() { cacheHelper.initModelState(modelId, MLModelState.DEPLOYING, FunctionName.TEXT_EMBEDDING, targetWorkerNodes, true); - cacheHelper.setModelState(modelId, MLModelState.DEPLOYED); + cacheHelper.setIsModelEnabled(modelId, true); + cacheHelper.setRateLimiter(modelId, rateLimiter); + cacheHelper.setUserRateLimiterMap(modelId, userRateLimiterMap); cacheHelper.setPredictor(modelId, predictor); + cacheHelper.setModelState(modelId, MLModelState.DEPLOYED); cacheHelper.removeModel(modelId); verify(predictor, times(1)).close(); + assertNull(cacheHelper.getPredictor(modelId)); + assertNull(cacheHelper.getMemEstCPU(modelId)); + assertNull(cacheHelper.getMemEstGPU(modelId)); + assertNull(cacheHelper.getModelInfo(modelId)); + assertNull(cacheHelper.getIsModelEnabled(modelId)); + assertNull(cacheHelper.getRateLimiter(modelId)); + assertNull(cacheHelper.getUserRateLimiter(modelId, "user1")); + assertNull(cacheHelper.getUserRateLimiterMap(modelId)); } public void testClearWorkerNodes_NullModelState() { diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLCreateModelControllerActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLCreateControllerActionTests.java similarity index 66% rename from plugin/src/test/java/org/opensearch/ml/rest/RestMLCreateModelControllerActionTests.java rename to plugin/src/test/java/org/opensearch/ml/rest/RestMLCreateControllerActionTests.java index 9e358b8ceb..42b4cbe92c 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLCreateModelControllerActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLCreateControllerActionTests.java @@ -31,9 +31,9 @@ import org.opensearch.core.common.Strings; import org.opensearch.core.common.bytes.BytesArray; import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.ml.common.controller.MLModelController; -import org.opensearch.ml.common.transport.controller.MLCreateModelControllerAction; -import org.opensearch.ml.common.transport.controller.MLCreateModelControllerRequest; +import org.opensearch.ml.common.controller.MLController; +import org.opensearch.ml.common.transport.controller.MLCreateControllerAction; +import org.opensearch.ml.common.transport.controller.MLCreateControllerRequest; import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestHandler; import org.opensearch.rest.RestRequest; @@ -42,11 +42,11 @@ import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; -public class RestMLCreateModelControllerActionTests extends OpenSearchTestCase { +public class RestMLCreateControllerActionTests extends OpenSearchTestCase { @Rule public ExpectedException exceptionRule = ExpectedException.none(); - private RestMLCreateModelControllerAction restMLCreateModelControllerAction; + private RestMLCreateControllerAction restMLCreateControllerAction; private NodeClient client; private ThreadPool threadPool; @@ -58,11 +58,11 @@ public void setup() { MockitoAnnotations.openMocks(this); threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); client = spy(new NodeClient(Settings.EMPTY, threadPool)); - restMLCreateModelControllerAction = new RestMLCreateModelControllerAction(); + restMLCreateControllerAction = new RestMLCreateControllerAction(); doAnswer(invocation -> { invocation.getArgument(2); return null; - }).when(client).execute(eq(MLCreateModelControllerAction.INSTANCE), any(), any()); + }).when(client).execute(eq(MLCreateControllerAction.INSTANCE), any(), any()); } @Override @@ -74,68 +74,68 @@ public void tearDown() throws Exception { @Test public void testConstructor() { - RestMLCreateModelControllerAction CreateModelAction = new RestMLCreateModelControllerAction(); + RestMLCreateControllerAction CreateModelAction = new RestMLCreateControllerAction(); assertNotNull(CreateModelAction); } @Test public void testGetName() { - String actionName = restMLCreateModelControllerAction.getName(); + String actionName = restMLCreateControllerAction.getName(); assertFalse(Strings.isNullOrEmpty(actionName)); - assertEquals("ml_create_model_controller_action", actionName); + assertEquals("ml_create_controller_action", actionName); } @Test public void testRoutes() { - List routes = restMLCreateModelControllerAction.routes(); + List routes = restMLCreateControllerAction.routes(); assertNotNull(routes); assertFalse(routes.isEmpty()); RestHandler.Route route = routes.get(0); assertEquals(RestRequest.Method.POST, route.getMethod()); - assertEquals("/_plugins/_ml/model_controllers/{model_id}", route.getPath()); + assertEquals("/_plugins/_ml/controllers/{model_id}", route.getPath()); } @Test - public void testCreateModelControllerRequest() throws Exception { + public void testCreateControllerRequest() throws Exception { RestRequest request = getRestRequest(); - restMLCreateModelControllerAction.handleRequest(request, channel, client); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLCreateModelControllerRequest.class); - verify(client, times(1)).execute(eq(MLCreateModelControllerAction.INSTANCE), argumentCaptor.capture(), any()); - MLModelController createModelControllerInput = argumentCaptor.getValue().getModelControllerInput(); - assertEquals("testModelId", createModelControllerInput.getModelId()); + restMLCreateControllerAction.handleRequest(request, channel, client); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLCreateControllerRequest.class); + verify(client, times(1)).execute(eq(MLCreateControllerAction.INSTANCE), argumentCaptor.capture(), any()); + MLController createControllerInput = argumentCaptor.getValue().getControllerInput(); + assertEquals("testModelId", createControllerInput.getModelId()); } @Test - public void testCreateModelControllerRequestWithEmptyContent() throws Exception { + public void testCreateControllerRequestWithEmptyContent() throws Exception { exceptionRule.expect(OpenSearchParseException.class); exceptionRule.expectMessage("Create model controller request has empty body"); RestRequest request = getRestRequestWithEmptyContent(); - restMLCreateModelControllerAction.handleRequest(request, channel, client); + restMLCreateControllerAction.handleRequest(request, channel, client); } @Test - public void testCreateModelControllerRequestWithNullModelId() throws Exception { + public void testCreateControllerRequestWithNullModelId() throws Exception { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("Request should contain model_id"); RestRequest request = getRestRequestWithNullModelId(); - restMLCreateModelControllerAction.handleRequest(request, channel, client); + restMLCreateControllerAction.handleRequest(request, channel, client); } @Test - public void testCreateModelControllerRequestWithNullField() throws Exception { + public void testCreateControllerRequestWithNullField() throws Exception { exceptionRule.expect(ParsingException.class); exceptionRule.expectMessage("expecting token of type [START_OBJECT] but found [VALUE_NULL]"); RestRequest request = getRestRequestWithNullField(); - restMLCreateModelControllerAction.handleRequest(request, channel, client); + restMLCreateControllerAction.handleRequest(request, channel, client); } private RestRequest getRestRequest() { RestRequest.Method method = RestRequest.Method.POST; - String requestContent = "{\"user_rate_limiter_config\":{\"testUser\":{}}}"; + String requestContent = "{\"user_rate_limiter\":{\"testUser\":{}}}"; Map params = Map.of("model_id", "testModelId"); RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) .withMethod(method) - .withPath("/_plugins/_ml/model_controllers/{model_id}") + .withPath("/_plugins/_ml/controllers/{model_id}") .withParams(params) .withContent(new BytesArray(requestContent), XContentType.JSON) .build(); @@ -147,7 +147,7 @@ private RestRequest getRestRequestWithEmptyContent() { Map params = Map.of("model_id", "testModelId"); RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) .withMethod(method) - .withPath("/_plugins/_ml/model_controllers/{model_id}") + .withPath("/_plugins/_ml/controllers/{model_id}") .withParams(params) .withContent(new BytesArray(""), XContentType.JSON) .build(); @@ -156,11 +156,11 @@ private RestRequest getRestRequestWithEmptyContent() { private RestRequest getRestRequestWithNullModelId() { RestRequest.Method method = RestRequest.Method.POST; - String requestContent = "{\"user_rate_limiter_config\":{\"testUser\":{}}}"; + String requestContent = "{\"user_rate_limiter\":{\"testUser\":{}}}"; Map params = new HashMap<>(); RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) .withMethod(method) - .withPath("/_plugins/_ml/model_controllers/{model_id}") + .withPath("/_plugins/_ml/controllers/{model_id}") .withParams(params) .withContent(new BytesArray(requestContent), XContentType.JSON) .build(); @@ -169,12 +169,12 @@ private RestRequest getRestRequestWithNullModelId() { private RestRequest getRestRequestWithNullField() { RestRequest.Method method = RestRequest.Method.POST; - String requestContent = "{\"user_rate_limiter_config\":{\"testUser\":null}}"; + String requestContent = "{\"user_rate_limiter\":{\"testUser\":null}}"; Map params = new HashMap<>(); params.put("model_id", "testModelId"); RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) .withMethod(method) - .withPath("/_plugins/_ml/model_controllers/{model_id}") + .withPath("/_plugins/_ml/controllers/{model_id}") .withParams(params) .withContent(new BytesArray(requestContent), XContentType.JSON) .build(); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeleteModelControllerActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeleteControllerActionTests.java similarity index 67% rename from plugin/src/test/java/org/opensearch/ml/rest/RestMLDeleteModelControllerActionTests.java rename to plugin/src/test/java/org/opensearch/ml/rest/RestMLDeleteControllerActionTests.java index 867ec2ce33..d3ba17ae5a 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeleteModelControllerActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeleteControllerActionTests.java @@ -28,8 +28,8 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.Strings; import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.ml.common.transport.controller.MLModelControllerDeleteAction; -import org.opensearch.ml.common.transport.controller.MLModelControllerDeleteRequest; +import org.opensearch.ml.common.transport.controller.MLControllerDeleteAction; +import org.opensearch.ml.common.transport.controller.MLControllerDeleteRequest; import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestHandler; import org.opensearch.rest.RestRequest; @@ -38,12 +38,12 @@ import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; -public class RestMLDeleteModelControllerActionTests extends OpenSearchTestCase { +public class RestMLDeleteControllerActionTests extends OpenSearchTestCase { @Rule public ExpectedException thrown = ExpectedException.none(); - private RestMLDeleteModelControllerAction restMLDeleteModelControllerAction; + private RestMLDeleteControllerAction restMLDeleteControllerAction; NodeClient client; private ThreadPool threadPool; @@ -53,7 +53,7 @@ public class RestMLDeleteModelControllerActionTests extends OpenSearchTestCase { @Before public void setup() { - restMLDeleteModelControllerAction = new RestMLDeleteModelControllerAction(); + restMLDeleteControllerAction = new RestMLDeleteControllerAction(); threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); client = spy(new NodeClient(Settings.EMPTY, threadPool)); @@ -61,7 +61,7 @@ public void setup() { doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(2); return null; - }).when(client).execute(eq(MLModelControllerDeleteAction.INSTANCE), any(), any()); + }).when(client).execute(eq(MLControllerDeleteAction.INSTANCE), any(), any()); } @@ -73,31 +73,31 @@ public void tearDown() throws Exception { } public void testConstructor() { - RestMLDeleteModelControllerAction mlDeleteModelControllerAction = new RestMLDeleteModelControllerAction(); - assertNotNull(mlDeleteModelControllerAction); + RestMLDeleteControllerAction mlDeleteControllerAction = new RestMLDeleteControllerAction(); + assertNotNull(mlDeleteControllerAction); } public void testGetName() { - String actionName = restMLDeleteModelControllerAction.getName(); + String actionName = restMLDeleteControllerAction.getName(); assertFalse(Strings.isNullOrEmpty(actionName)); - assertEquals("ml_delete_model_controller_action", actionName); + assertEquals("ml_delete_controller_action", actionName); } public void testRoutes() { - List routes = restMLDeleteModelControllerAction.routes(); + List routes = restMLDeleteControllerAction.routes(); assertNotNull(routes); assertFalse(routes.isEmpty()); RestHandler.Route route = routes.get(0); assertEquals(RestRequest.Method.DELETE, route.getMethod()); - assertEquals("/_plugins/_ml/model_controllers/{model_id}", route.getPath()); + assertEquals("/_plugins/_ml/controllers/{model_id}", route.getPath()); } public void test_PrepareRequest() throws Exception { RestRequest request = getRestRequest(); - restMLDeleteModelControllerAction.handleRequest(request, channel, client); + restMLDeleteControllerAction.handleRequest(request, channel, client); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLModelControllerDeleteRequest.class); - verify(client, times(1)).execute(eq(MLModelControllerDeleteAction.INSTANCE), argumentCaptor.capture(), any()); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLControllerDeleteRequest.class); + verify(client, times(1)).execute(eq(MLControllerDeleteAction.INSTANCE), argumentCaptor.capture(), any()); String taskId = argumentCaptor.getValue().getModelId(); assertEquals(taskId, "testModelId"); } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetModelControllerActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetControllerActionTests.java similarity index 64% rename from plugin/src/test/java/org/opensearch/ml/rest/RestMLGetModelControllerActionTests.java rename to plugin/src/test/java/org/opensearch/ml/rest/RestMLGetControllerActionTests.java index 7597aed17d..2c16533d65 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetModelControllerActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetControllerActionTests.java @@ -27,9 +27,9 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.Strings; import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.ml.common.transport.controller.MLModelControllerGetAction; -import org.opensearch.ml.common.transport.controller.MLModelControllerGetRequest; -import org.opensearch.ml.common.transport.controller.MLModelControllerGetResponse; +import org.opensearch.ml.common.transport.controller.MLControllerGetAction; +import org.opensearch.ml.common.transport.controller.MLControllerGetRequest; +import org.opensearch.ml.common.transport.controller.MLControllerGetResponse; import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestHandler; import org.opensearch.rest.RestRequest; @@ -38,12 +38,12 @@ import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; -public class RestMLGetModelControllerActionTests extends OpenSearchTestCase { +public class RestMLGetControllerActionTests extends OpenSearchTestCase { @Rule public ExpectedException thrown = ExpectedException.none(); - private RestMLGetModelControllerAction restMLGetModelControllerAction; + private RestMLGetControllerAction restMLGetControllerAction; NodeClient client; private ThreadPool threadPool; @@ -53,15 +53,15 @@ public class RestMLGetModelControllerActionTests extends OpenSearchTestCase { @Before public void setup() { - restMLGetModelControllerAction = new RestMLGetModelControllerAction(); + restMLGetControllerAction = new RestMLGetControllerAction(); threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); client = spy(new NodeClient(Settings.EMPTY, threadPool)); doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(2); + ActionListener actionListener = invocation.getArgument(2); return null; - }).when(client).execute(eq(MLModelControllerGetAction.INSTANCE), any(), any()); + }).when(client).execute(eq(MLControllerGetAction.INSTANCE), any(), any()); } @@ -73,31 +73,31 @@ public void tearDown() throws Exception { } public void testConstructor() { - RestMLGetModelControllerAction mlGetModelControllerAction = new RestMLGetModelControllerAction(); - assertNotNull(mlGetModelControllerAction); + RestMLGetControllerAction mlGetControllerAction = new RestMLGetControllerAction(); + assertNotNull(mlGetControllerAction); } public void testGetName() { - String actionName = restMLGetModelControllerAction.getName(); + String actionName = restMLGetControllerAction.getName(); assertFalse(Strings.isNullOrEmpty(actionName)); - assertEquals("ml_get_model_controller_action", actionName); + assertEquals("ml_get_controller_action", actionName); } public void testRoutes() { - List routes = restMLGetModelControllerAction.routes(); + List routes = restMLGetControllerAction.routes(); assertNotNull(routes); assertFalse(routes.isEmpty()); RestHandler.Route route = routes.get(0); assertEquals(RestRequest.Method.GET, route.getMethod()); - assertEquals("/_plugins/_ml/model_controllers/{model_id}", route.getPath()); + assertEquals("/_plugins/_ml/controllers/{model_id}", route.getPath()); } public void test_PrepareRequest() throws Exception { RestRequest request = getRestRequest(); - restMLGetModelControllerAction.handleRequest(request, channel, client); + restMLGetControllerAction.handleRequest(request, channel, client); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLModelControllerGetRequest.class); - verify(client, times(1)).execute(eq(MLModelControllerGetAction.INSTANCE), argumentCaptor.capture(), any()); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLControllerGetRequest.class); + verify(client, times(1)).execute(eq(MLControllerGetAction.INSTANCE), argumentCaptor.capture(), any()); String taskId = argumentCaptor.getValue().getModelId(); assertEquals(taskId, "testModelId"); } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateModelControllerActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateControllerActionTests.java similarity index 66% rename from plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateModelControllerActionTests.java rename to plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateControllerActionTests.java index caadbab1fb..98ab0f1e73 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateModelControllerActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLUpdateControllerActionTests.java @@ -31,9 +31,9 @@ import org.opensearch.core.common.Strings; import org.opensearch.core.common.bytes.BytesArray; import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.ml.common.controller.MLModelController; -import org.opensearch.ml.common.transport.controller.MLUpdateModelControllerAction; -import org.opensearch.ml.common.transport.controller.MLUpdateModelControllerRequest; +import org.opensearch.ml.common.controller.MLController; +import org.opensearch.ml.common.transport.controller.MLUpdateControllerAction; +import org.opensearch.ml.common.transport.controller.MLUpdateControllerRequest; import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestHandler; import org.opensearch.rest.RestRequest; @@ -42,11 +42,11 @@ import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; -public class RestMLUpdateModelControllerActionTests extends OpenSearchTestCase { +public class RestMLUpdateControllerActionTests extends OpenSearchTestCase { @Rule public ExpectedException exceptionRule = ExpectedException.none(); - private RestMLUpdateModelControllerAction restMLUpdateModelControllerAction; + private RestMLUpdateControllerAction restMLUpdateControllerAction; private NodeClient client; private ThreadPool threadPool; @@ -58,11 +58,11 @@ public void setup() { MockitoAnnotations.openMocks(this); threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); client = spy(new NodeClient(Settings.EMPTY, threadPool)); - restMLUpdateModelControllerAction = new RestMLUpdateModelControllerAction(); + restMLUpdateControllerAction = new RestMLUpdateControllerAction(); doAnswer(invocation -> { invocation.getArgument(2); return null; - }).when(client).execute(eq(MLUpdateModelControllerAction.INSTANCE), any(), any()); + }).when(client).execute(eq(MLUpdateControllerAction.INSTANCE), any(), any()); } @Override @@ -74,68 +74,68 @@ public void tearDown() throws Exception { @Test public void testConstructor() { - RestMLUpdateModelControllerAction UpdateModelAction = new RestMLUpdateModelControllerAction(); + RestMLUpdateControllerAction UpdateModelAction = new RestMLUpdateControllerAction(); assertNotNull(UpdateModelAction); } @Test public void testGetName() { - String actionName = restMLUpdateModelControllerAction.getName(); + String actionName = restMLUpdateControllerAction.getName(); assertFalse(Strings.isNullOrEmpty(actionName)); - assertEquals("ml_update_model_controller_action", actionName); + assertEquals("ml_update_controller_action", actionName); } @Test public void testRoutes() { - List routes = restMLUpdateModelControllerAction.routes(); + List routes = restMLUpdateControllerAction.routes(); assertNotNull(routes); assertFalse(routes.isEmpty()); RestHandler.Route route = routes.get(0); assertEquals(RestRequest.Method.PUT, route.getMethod()); - assertEquals("/_plugins/_ml/model_controllers/{model_id}", route.getPath()); + assertEquals("/_plugins/_ml/controllers/{model_id}", route.getPath()); } @Test - public void testUpdateModelControllerRequest() throws Exception { + public void testUpdateControllerRequest() throws Exception { RestRequest request = getRestRequest(); - restMLUpdateModelControllerAction.handleRequest(request, channel, client); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLUpdateModelControllerRequest.class); - verify(client, times(1)).execute(eq(MLUpdateModelControllerAction.INSTANCE), argumentCaptor.capture(), any()); - MLModelController updateModelControllerInput = argumentCaptor.getValue().getUpdateModelControllerInput(); - assertEquals("testModelId", updateModelControllerInput.getModelId()); + restMLUpdateControllerAction.handleRequest(request, channel, client); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLUpdateControllerRequest.class); + verify(client, times(1)).execute(eq(MLUpdateControllerAction.INSTANCE), argumentCaptor.capture(), any()); + MLController updateControllerInput = argumentCaptor.getValue().getUpdateControllerInput(); + assertEquals("testModelId", updateControllerInput.getModelId()); } @Test - public void testUpdateModelControllerRequestWithEmptyContent() throws Exception { + public void testUpdateControllerRequestWithEmptyContent() throws Exception { exceptionRule.expect(OpenSearchParseException.class); exceptionRule.expectMessage("Update model controller request has empty body"); RestRequest request = getRestRequestWithEmptyContent(); - restMLUpdateModelControllerAction.handleRequest(request, channel, client); + restMLUpdateControllerAction.handleRequest(request, channel, client); } @Test - public void testUpdateModelControllerRequestWithNullModelId() throws Exception { + public void testUpdateControllerRequestWithNullModelId() throws Exception { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("Request should contain model_id"); RestRequest request = getRestRequestWithNullModelId(); - restMLUpdateModelControllerAction.handleRequest(request, channel, client); + restMLUpdateControllerAction.handleRequest(request, channel, client); } @Test - public void testUpdateModelControllerRequestWithNullField() throws Exception { + public void testUpdateControllerRequestWithNullField() throws Exception { exceptionRule.expect(ParsingException.class); exceptionRule.expectMessage("expecting token of type [START_OBJECT] but found [VALUE_NULL]"); RestRequest request = getRestRequestWithNullField(); - restMLUpdateModelControllerAction.handleRequest(request, channel, client); + restMLUpdateControllerAction.handleRequest(request, channel, client); } private RestRequest getRestRequest() { RestRequest.Method method = RestRequest.Method.PUT; - String requestContent = "{\"user_rate_limiter_config\":{\"testUser\":{}}}"; + String requestContent = "{\"user_rate_limiter\":{\"testUser\":{}}}"; Map params = Map.of("model_id", "testModelId"); RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) .withMethod(method) - .withPath("/_plugins/_ml/model_controllers/{model_id}") + .withPath("/_plugins/_ml/controllers/{model_id}") .withParams(params) .withContent(new BytesArray(requestContent), XContentType.JSON) .build(); @@ -147,7 +147,7 @@ private RestRequest getRestRequestWithEmptyContent() { Map params = Map.of("model_id", "testModelId"); RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) .withMethod(method) - .withPath("/_plugins/_ml/model_controllers/{model_id}") + .withPath("/_plugins/_ml/controllers/{model_id}") .withParams(params) .withContent(new BytesArray(""), XContentType.JSON) .build(); @@ -156,11 +156,11 @@ private RestRequest getRestRequestWithEmptyContent() { private RestRequest getRestRequestWithNullModelId() { RestRequest.Method method = RestRequest.Method.PUT; - String requestContent = "{\"user_rate_limiter_config\":{\"testUser\":{}}}"; + String requestContent = "{\"user_rate_limiter\":{\"testUser\":{}}}"; Map params = new HashMap<>(); RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) .withMethod(method) - .withPath("/_plugins/_ml/model_controllers/{model_id}") + .withPath("/_plugins/_ml/controllers/{model_id}") .withParams(params) .withContent(new BytesArray(requestContent), XContentType.JSON) .build(); @@ -169,12 +169,12 @@ private RestRequest getRestRequestWithNullModelId() { private RestRequest getRestRequestWithNullField() { RestRequest.Method method = RestRequest.Method.PUT; - String requestContent = "{\"user_rate_limiter_config\":{\"testUser\":null}}"; + String requestContent = "{\"user_rate_limiter\":{\"testUser\":null}}"; Map params = new HashMap<>(); params.put("model_id", "testModelId"); RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) .withMethod(method) - .withPath("/_plugins/_ml/model_controllers/{model_id}") + .withPath("/_plugins/_ml/controllers/{model_id}") .withParams(params) .withContent(new BytesArray(requestContent), XContentType.JSON) .build();