From 127ba1ecfb9a5f3692df17a2e0d102b3a9480acb Mon Sep 17 00:00:00 2001 From: zane-neo Date: Thu, 25 May 2023 00:54:42 +0800 Subject: [PATCH] 1. revert model content hash change; 2.fix task search with permission issue; 3.fix failure UTs (#910) Signed-off-by: Zan Niu --- .../transport/forward/MLForwardInputTest.java | 1 + .../forward/MLForwardRequestTest.java | 1 + .../register/MLRegisterModelInputTest.java | 17 +++++++++---- .../register/MLRegisterModelRequestTest.java | 1 + .../MLRegisterModelMetaInputTest.java | 11 ++++++-- .../org/opensearch/ml/engine/ModelHelper.java | 2 +- .../MetricsCorrelation.java | 1 + .../text_embedding/ModelHelperTest.java | 6 +++++ .../DeleteModelGroupTransportAction.java | 2 +- .../tasks/SearchTaskTransportAction.java | 17 ++++++++++--- .../forward/TransportForwardActionTests.java | 1 + .../tasks/SearchTaskTransportActionTests.java | 18 ++++++++----- .../ml/model/MLModelManagerTests.java | 25 ++++++++++++++++++- .../rest/RestMLRegisterModelActionTests.java | 2 +- .../RestMLRegisterModelMetaActionTests.java | 2 ++ .../org/opensearch/ml/utils/MockHelper.java | 6 ++++- 16 files changed, 91 insertions(+), 22 deletions(-) diff --git a/common/src/test/java/org/opensearch/ml/common/transport/forward/MLForwardInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/forward/MLForwardInputTest.java index 6136c99dc2..f7f1b6901d 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/forward/MLForwardInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/forward/MLForwardInputTest.java @@ -75,6 +75,7 @@ public void setUp() throws Exception { .functionName(functionName) .modelName("testModelName") .version("testModelVersion") + .modelGroupId("mockModelGroupId") .url("url") .modelFormat(MLModelFormat.ONNX) .modelConfig(config) diff --git a/common/src/test/java/org/opensearch/ml/common/transport/forward/MLForwardRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/forward/MLForwardRequestTest.java index a8e2ad6338..735b459c22 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/forward/MLForwardRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/forward/MLForwardRequestTest.java @@ -79,6 +79,7 @@ public void setUp() throws Exception { .functionName(functionName) .modelName("testModelName") .version("testModelVersion") + .modelGroupId("modelGroupId") .url("url") .modelFormat(MLModelFormat.ONNX) .modelConfig(config) diff --git a/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelInputTest.java index 1de3623c99..cb7b61ca50 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelInputTest.java @@ -42,7 +42,7 @@ public class MLRegisterModelInputTest { @Rule public ExpectedException exceptionRule = ExpectedException.none(); - private final String expectedInputStr = "{\"function_name\":\"LINEAR_REGRESSION\",\"name\":\"modelName\",\"version\":\"version\",\"url\":\"url\",\"model_format\":\"ONNX\"," + + private final String expectedInputStr = "{\"function_name\":\"LINEAR_REGRESSION\",\"name\":\"modelName\",\"version\":\"version\",\"model_group_id\":\"modelGroupId\",\"url\":\"url\",\"model_format\":\"ONNX\"," + "\"model_config\":{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\"," + "\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"" + "},\"deploy_model\":true,\"model_node_ids\":[\"modelNodeIds\"]}"; @@ -51,6 +51,8 @@ public class MLRegisterModelInputTest { private final String version = "version"; private final String url = "url"; + private final String modelGroupId = "modelGroupId"; + @Before public void setUp() throws Exception { config = TextEmbeddingModelConfig.builder() @@ -64,6 +66,7 @@ public void setUp() throws Exception { .functionName(functionName) .modelName(modelName) .version(version) + .modelGroupId(modelGroupId) .url(url) .modelFormat(MLModelFormat.ONNX) .modelConfig(config) @@ -86,18 +89,19 @@ public void constructor_NullModelName() { exceptionRule.expectMessage("model name is null"); MLRegisterModelInput.builder() .functionName(functionName) + .modelGroupId(modelGroupId) .modelName(null) .build(); } @Test - public void constructor_NullModelVersion() { + public void constructor_NullModelGroupId() { exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("model version is null"); + exceptionRule.expectMessage("model group id is null"); MLRegisterModelInput.builder() .functionName(functionName) .modelName(modelName) - .version(null) + .modelGroupId(null) .build(); } @@ -109,6 +113,7 @@ public void constructor_NullModelFormat() { .functionName(functionName) .modelName(modelName) .version(version) + .modelGroupId(modelGroupId) .modelFormat(null) .url(url) .build(); @@ -122,6 +127,7 @@ public void constructor_NullModelConfig() { .functionName(functionName) .modelName(modelName) .version(version) + .modelGroupId(modelGroupId) .modelFormat(MLModelFormat.ONNX) .modelConfig(null) .url(url) @@ -133,6 +139,7 @@ public void constructor_SuccessWithMinimalSetup() { MLRegisterModelInput input = MLRegisterModelInput.builder() .modelName(modelName) .version(version) + .modelGroupId(modelGroupId) .modelFormat(MLModelFormat.ONNX) .modelConfig(config) .url(url) @@ -158,7 +165,7 @@ public void testToXContent() throws Exception { public void testToXContent_Incomplete() throws Exception { String expectedIncompleteInputStr = "{\"function_name\":\"LINEAR_REGRESSION\",\"name\":\"modelName\"," + - "\"version\":\"version\",\"deploy_model\":true}"; + "\"version\":\"version\",\"model_group_id\":\"modelGroupId\",\"deploy_model\":true}"; input.setUrl(null); input.setModelConfig(null); input.setModelFormat(null); diff --git a/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelRequestTest.java index 3ea52a36ec..b5289da2a7 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/register/MLRegisterModelRequestTest.java @@ -35,6 +35,7 @@ public void setUp(){ .functionName(FunctionName.KMEANS) .modelName("modelName") .version("version") + .modelGroupId("modelGroupId") .url("url") .modelFormat(MLModelFormat.ONNX) .modelConfig(config) diff --git a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInputTest.java index 7f89d9914e..a27c556642 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLRegisterModelMetaInputTest.java @@ -71,11 +71,18 @@ private void readInputStream(MLRegisterModelMetaInput input) throws IOException @Test - public void testToXContent() throws IOException { + public void testToXContent() throws IOException {{ XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); mLRegisterModelMetaInput.toXContent(builder, EMPTY_PARAMS); String mlModelContent = TestHelper.xContentBuilderToString(builder); - final String expected = "{\"name\":\"Model Name\",\"function_name\":\"BATCH_RCF\",\"version\":\"1.0\",\"description\":\"Model Description\",\"model_format\":\"TORCH_SCRIPT\",\"model_state\":\"DEPLOYING\",\"model_content_size_in_bytes\":200,\"model_content_hash_value\":\"123\",\"model_config\":{\"model_type\":\"Model Type\"," + + final String expected = "{\"name\":\"Model Name\",\"function_name\":\"BATCH_RCF\",\"model_group_id\":\"1.0\",\"description\":\"Model Description\",\"model_format\":\"TORCH_SCRIPT\",\"model_state\":\"DEPLOYING\",\"model_content_size_in_bytes\":200,\"model_content_hash_value\":\"123\",\"model_config\":{\"model_type\":\"Model Type\"," + + "\"embedding_dimension\":123,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"All Config\",\"model_max_length\":512,\"pooling_mode\":\"MEAN\",\"normalize_result\":true},\"total_chunks\":2}"; + assertEquals(expected, mlModelContent); + } + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + mLRegisterModelMetaInput.toXContent(builder, EMPTY_PARAMS); + String mlModelContent = TestHelper.xContentBuilderToString(builder); + final String expected = "{\"name\":\"Model Name\",\"function_name\":\"BATCH_RCF\",\"model_group_id\":\"1.0\",\"description\":\"Model Description\",\"model_format\":\"TORCH_SCRIPT\",\"model_state\":\"DEPLOYING\",\"model_content_size_in_bytes\":200,\"model_content_hash_value\":\"123\",\"model_config\":{\"model_type\":\"Model Type\"," + "\"embedding_dimension\":123,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"All Config\",\"model_max_length\":512,\"pooling_mode\":\"MEAN\",\"normalize_result\":true},\"total_chunks\":2}"; assertEquals(expected, mlModelContent); } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java index bf7a7b31fd..abcc9a9ecb 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/ModelHelper.java @@ -203,7 +203,7 @@ public void downloadAndSplit(MLModelFormat modelFormat, String taskId, String mo DownloadUtils.download(url, modelPath, new ProgressBar()); verifyModelZipFile(modelFormat, modelPath, modelName); String hash = calculateFileHash(modelZipFile); - if (modelContentHash == null || hash.equals(modelContentHash)) { + if (hash.equals(modelContentHash)) { List chunkFiles = splitFileIntoChunks(modelZipFile, modelPartsPath, CHUNK_SIZE); Map result = new HashMap<>(); result.put(CHUNK_FILES, chunkFiles); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelation.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelation.java index 8d3eca4c02..7177e6e862 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelation.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelation.java @@ -196,6 +196,7 @@ void registerModel(ActionListener listener) throws Inte .functionName(functionName) .modelName(FunctionName.METRICS_CORRELATION.name()) .version(MCORR_ML_VERSION) + .modelGroupId(functionName.name()) .modelFormat(modelFormat) .hashValue(MODEL_CONTENT_HASH) .modelConfig(modelConfig) diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/ModelHelperTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/ModelHelperTest.java index e06c7ec890..fc2cf82f4e 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/ModelHelperTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/ModelHelperTest.java @@ -141,6 +141,7 @@ public void testDownloadPrebuiltModelConfig_WrongModelName() { MLRegisterModelInput registerModelInput = MLRegisterModelInput.builder() .modelName("test_model_name") .version("1.0.1") + .modelGroupId("mockGroupId") .modelFormat(modelFormat) .deployModel(false) .modelNodeIds(new String[]{"node_id1"}) @@ -157,6 +158,7 @@ public void testDownloadPrebuiltModelConfig() { MLRegisterModelInput registerModelInput = MLRegisterModelInput.builder() .modelName("huggingface/sentence-transformers/all-mpnet-base-v2") .version("1.0.1") + .modelGroupId("mockGroupId") .modelFormat(modelFormat) .deployModel(false) .modelNodeIds(new String[]{"node_id1"}) @@ -176,6 +178,7 @@ public void testDownloadPrebuiltModelMetaList() throws PrivilegedActionException MLRegisterModelInput registerModelInput = MLRegisterModelInput.builder() .modelName("huggingface/sentence-transformers/all-mpnet-base-v2") .version("1.0.1") + .modelGroupId("mockGroupId") .modelFormat(modelFormat) .deployModel(false) .modelNodeIds(new String[]{"node_id1"}) @@ -190,6 +193,7 @@ public void testIsModelAllowed_true() throws PrivilegedActionException { MLRegisterModelInput registerModelInput = MLRegisterModelInput.builder() .modelName("huggingface/sentence-transformers/all-mpnet-base-v2") .version("1.0.1") + .modelGroupId("mockGroupId") .modelFormat(modelFormat) .deployModel(false) .modelNodeIds(new String[]{"node_id1"}) @@ -204,6 +208,7 @@ public void testIsModelAllowed_WrongModelName() throws PrivilegedActionException MLRegisterModelInput registerModelInput = MLRegisterModelInput.builder() .modelName("huggingface/sentence-transformers/all-mpnet-base-v2-wrong") .version("1.0.1") + .modelGroupId("mockGroupId") .modelFormat(modelFormat) .deployModel(false) .modelNodeIds(new String[]{"node_id1"}) @@ -218,6 +223,7 @@ public void testIsModelAllowed_WrongModelVersion() throws PrivilegedActionExcept MLRegisterModelInput registerModelInput = MLRegisterModelInput.builder() .modelName("huggingface/sentence-transformers/all-mpnet-base-v2") .version("000") + .modelGroupId("mockGroupId") .modelFormat(modelFormat) .deployModel(false) .modelNodeIds(new String[]{"node_id1"}) diff --git a/plugin/src/main/java/org/opensearch/ml/action/model_group/DeleteModelGroupTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/model_group/DeleteModelGroupTransportAction.java index c271b1508a..d1367fba90 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/model_group/DeleteModelGroupTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/model_group/DeleteModelGroupTransportAction.java @@ -73,7 +73,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { if (!access) { - actionListener.onFailure(new MLValidationException("User Doesn't have previlege to perform this operation")); + actionListener.onFailure(new MLValidationException("User Doesn't have privilege to perform this operation")); } else { BoolQueryBuilder query = new BoolQueryBuilder(); query.filter(new TermQueryBuilder(PARAMETER_MODEL_GROUP_ID, modelGroupId)); diff --git a/plugin/src/main/java/org/opensearch/ml/action/tasks/SearchTaskTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/tasks/SearchTaskTransportAction.java index 4d811064bc..317a14e5cf 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/tasks/SearchTaskTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/tasks/SearchTaskTransportAction.java @@ -5,28 +5,37 @@ package org.opensearch.ml.action.tasks; +import lombok.extern.log4j.Log4j2; import org.opensearch.action.ActionListener; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.ml.action.handler.MLSearchHandler; import org.opensearch.ml.common.transport.task.MLTaskSearchAction; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; +@Log4j2 public class SearchTaskTransportAction extends HandledTransportAction { - private MLSearchHandler mlSearchHandler; + private Client client; @Inject - public SearchTaskTransportAction(TransportService transportService, ActionFilters actionFilters, MLSearchHandler mlSearchHandler) { + public SearchTaskTransportAction(TransportService transportService, ActionFilters actionFilters, Client client) { super(MLTaskSearchAction.NAME, transportService, actionFilters, SearchRequest::new); - this.mlSearchHandler = mlSearchHandler; + this.client = client; } @Override protected void doExecute(Task task, SearchRequest request, ActionListener actionListener) { - mlSearchHandler.search(request, actionListener); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + client.search(request, actionListener); + } catch (Exception e) { + log.error(e.getMessage(), e); + actionListener.onFailure(e); + } } } diff --git a/plugin/src/test/java/org/opensearch/ml/action/forward/TransportForwardActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/forward/TransportForwardActionTests.java index 16f2ea804f..9acce3c108 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/forward/TransportForwardActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/forward/TransportForwardActionTests.java @@ -328,6 +328,7 @@ private MLRegisterModelInput prepareInput() { .functionName(FunctionName.BATCH_RCF) .deployModel(true) .version("1.0") + .modelGroupId("model group id") .modelName("Test Model") .modelConfig( new TextEmbeddingModelConfig( diff --git a/plugin/src/test/java/org/opensearch/ml/action/tasks/SearchTaskTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/tasks/SearchTaskTransportActionTests.java index 219ce7befb..8dc9f354c9 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/tasks/SearchTaskTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/tasks/SearchTaskTransportActionTests.java @@ -5,8 +5,9 @@ package org.opensearch.ml.action.tasks; -import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import org.junit.Before; import org.mockito.Mock; @@ -16,12 +17,15 @@ import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.client.Client; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.ml.action.handler.MLSearchHandler; import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; + public class SearchTaskTransportActionTests extends OpenSearchTestCase { @Mock Client client; @@ -41,7 +45,6 @@ public class SearchTaskTransportActionTests extends OpenSearchTestCase { @Mock ActionListener actionListener; - MLSearchHandler mlSearchHandler; SearchTaskTransportAction searchTaskTransportAction; @Mock @@ -50,12 +53,15 @@ public class SearchTaskTransportActionTests extends OpenSearchTestCase { @Before public void setup() { MockitoAnnotations.openMocks(this); - mlSearchHandler = spy(new MLSearchHandler(client, namedXContentRegistry, modelAccessControlHelper)); - searchTaskTransportAction = new SearchTaskTransportAction(transportService, actionFilters, mlSearchHandler); + searchTaskTransportAction = new SearchTaskTransportAction(transportService, actionFilters, client); + ThreadPool threadPool = mock(ThreadPool.class); + when(client.threadPool()).thenReturn(threadPool); + ThreadContext threadContext = new ThreadContext(Settings.EMPTY); + when(threadPool.getThreadContext()).thenReturn(threadContext); } public void test_DoExecute() { searchTaskTransportAction.doExecute(null, searchRequest, actionListener); - verify(mlSearchHandler).search(searchRequest, actionListener); + verify(client).search(searchRequest, actionListener); } } diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java index c2aaa6b3e6..5b7112e071 100644 --- a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java @@ -11,6 +11,7 @@ import static org.mockito.ArgumentMatchers.anyMap; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.doThrow; @@ -50,6 +51,7 @@ import java.nio.file.Path; import java.util.Arrays; import java.util.Base64; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -64,7 +66,10 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.opensearch.action.ActionListener; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.update.UpdateRequest; import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; @@ -178,7 +183,7 @@ public void setup() throws URISyntaxException { modelName = "model_name1"; modelId = randomAlphaOfLength(10); modelContentHashValue = "c446f747520bcc6af053813cb1e8d34944a7c4686bbb405aeaa23883b5a806c8"; - version = "1.0.0"; + version = "1"; url = "http://testurl"; MLModelConfig modelConfig = TextEmbeddingModelConfig .builder() @@ -191,6 +196,7 @@ public void setup() throws URISyntaxException { .builder() .modelName(modelName) .version(version) + .modelGroupId("modelGroupId") .functionName(FunctionName.TEXT_EMBEDDING) .modelFormat(modelFormat) .modelConfig(modelConfig) @@ -263,6 +269,23 @@ public void setup() throws URISyntaxException { .build(); modelChunk0 = model.toBuilder().content(Base64.getEncoder().encodeToString("test chunk1".getBytes(StandardCharsets.UTF_8))).build(); modelChunk1 = model.toBuilder().content(Base64.getEncoder().encodeToString("test chunk2".getBytes(StandardCharsets.UTF_8))).build(); + + GetResponse getResponse = mock(GetResponse.class); + when(getResponse.isExists()).thenReturn(true); + Map sourceMap = new HashMap<>(); + sourceMap.put("latest_version", 0); + when(getResponse.getSourceAsMap()).thenReturn(sourceMap); + doAnswer(invocation -> { + ActionListener getResponseActionListener = invocation.getArgument(1); + getResponseActionListener.onResponse(getResponse); + return null; + }).when(client).get(any(GetRequest.class), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener updateActionListener = invocation.getArgument(1); + updateActionListener.onResponse(null); + return null; + }).when(client).update(any(UpdateRequest.class), isA(ActionListener.class)); } public void testRegisterMLModel_ExceedMaxRunningTask() { diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRegisterModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRegisterModelActionTests.java index 52c4850d88..cb1a666eb6 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRegisterModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRegisterModelActionTests.java @@ -180,7 +180,7 @@ private RestRequest getRestRequest_NullModelId() { final Map modelConfig = Map .of("model_type", "bert", "embedding_dimension", 384, "framework_type", "sentence_transformers", "all_config", "All Config"); final Map model = Map - .of("name", "test_model", "version", "2", "url", "testUrl", "model_format", "TORCH_SCRIPT", "model_config", modelConfig); + .of("name", "test_model", "version", "2", "model_group_id", "modelGroupId", "url", "testUrl", "model_format", "TORCH_SCRIPT", "model_config", modelConfig); String requestContent = new Gson().toJson(model).toString(); RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) .withMethod(method) diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRegisterModelMetaActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRegisterModelMetaActionTests.java index 516859b3ac..cc3c8ff97a 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRegisterModelMetaActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRegisterModelMetaActionTests.java @@ -170,6 +170,8 @@ private String prepareCustomModel() { "all-MiniLM-L6-v3", "version", "1", + "model_group_id", + "1", "model_format", "TORCH_SCRIPT", "model_task_type", diff --git a/plugin/src/test/java/org/opensearch/ml/utils/MockHelper.java b/plugin/src/test/java/org/opensearch/ml/utils/MockHelper.java index 894cb7aec1..d4604db332 100644 --- a/plugin/src/test/java/org/opensearch/ml/utils/MockHelper.java +++ b/plugin/src/test/java/org/opensearch/ml/utils/MockHelper.java @@ -105,7 +105,11 @@ public static void mock_client_index(Client client, String modelId) { public static void mock_client_update_failure(Client client) { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); - listener.onFailure(new RuntimeException("failed to update")); + listener.onResponse(null); + return null; + }).doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException("update failure")); return null; }).when(client).update(any(), any()); }