diff --git a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java index a6e1cebe85..c739a29faa 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java @@ -262,6 +262,7 @@ private void createModelGroup(MLRegisterModelInput registerModelInput, ActionLis listener.onFailure(e); })); } else { + registerModelInput.setDoesVersionCreateModelGroup(false); registerModel(registerModelInput, listener); } } diff --git a/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java index 66c78d2431..5153b2845b 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java @@ -495,6 +495,42 @@ public void test_ModelNameAlreadyExists() throws IOException { verify(actionListener).onResponse(argumentCaptor.capture()); } + public void test_DoesVersionCreateModelGroupFieldSetToTrueByUserByMistake() throws IOException { + when(node1.getId()).thenReturn("NodeId1"); + when(node2.getId()).thenReturn("NodeId2"); + MLForwardResponse forwardResponse = Mockito.mock(MLForwardResponse.class); + doAnswer(invocation -> { + ActionListenerResponseHandler handler = invocation.getArgument(3); + handler.handleResponse(forwardResponse); + return null; + }).when(transportService).sendRequest(any(), any(), any(), any()); + + MLRegisterModelInput registerModelInput = MLRegisterModelInput + .builder() + .functionName(FunctionName.BATCH_RCF) + .modelGroupId("model_group_ID") + .modelName("Test Model") + .modelConfig( + new TextEmbeddingModelConfig( + "CUSTOM", + 123, + TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS, + "all config", + TextEmbeddingModelConfig.PoolingMode.MEAN, + true, + 512 + ) + ) + .modelFormat(MLModelFormat.TORCH_SCRIPT) + .url("http://test_url") + .doesVersionCreateModelGroup(true) + .build(); + + transportRegisterModelAction.doExecute(task, new MLRegisterModelRequest(registerModelInput), actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLRegisterModelResponse.class); + verify(actionListener).onResponse(argumentCaptor.capture()); + } + public void test_FailureWhenPreBuildModelNameAlreadyExists() throws IOException { SearchResponse searchResponse = createModelGroupSearchResponse(1); doAnswer(invocation -> {