diff --git a/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java b/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java index 494f197857..9d6dc32b86 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java @@ -9,7 +9,6 @@ import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; import static org.opensearch.ml.utils.MLExceptionUtils.logException; -import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; import java.util.Map; @@ -90,39 +89,41 @@ protected void doExecute(Task task, ActionRequest request, ActionListener wrappedListener = ActionListener.runBefore(listener, () -> context.restore()); - client.get(getModelGroupRequest, ActionListener.wrap(modelGroup -> { - if (modelGroup.isExists()) { - try ( - XContentParser parser = MLNodeUtils - .createXContentParserFromRegistry(NamedXContentRegistry.EMPTY, modelGroup.getSourceAsBytesRef()) - ) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - MLModelGroup mlModelGroup = MLModelGroup.parse(parser); + GetRequest getModelGroupRequest = new GetRequest(ML_MODEL_GROUP_INDEX).id(modelGroupId); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + ActionListener wrappedListener = ActionListener.runBefore(listener, () -> context.restore()); + client.get(getModelGroupRequest, ActionListener.wrap(modelGroup -> { + if (modelGroup.isExists()) { + try ( + XContentParser parser = MLNodeUtils + .createXContentParserFromRegistry(NamedXContentRegistry.EMPTY, modelGroup.getSourceAsBytesRef()) + ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + MLModelGroup mlModelGroup = MLModelGroup.parse(parser); + if (modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(user)) { validateRequestForAccessControl(updateModelGroupInput, user, mlModelGroup); - updateModelGroup(modelGroupId, modelGroup.getSource(), updateModelGroupInput, wrappedListener, user); + } else { + validateSecurityDisabledOrModelAccessControlDisabled(updateModelGroupInput); } - } else { - wrappedListener.onFailure(new OpenSearchStatusException("Failed to find model group", RestStatus.NOT_FOUND)); - } - }, e -> { - if (e instanceof IndexNotFoundException) { - wrappedListener.onFailure(new MLResourceNotFoundException("Fail to find model group")); - } else { - logException("Failed to get model group", e, log); + updateModelGroup(modelGroupId, modelGroup.getSource(), updateModelGroupInput, wrappedListener, user); + } catch (Exception e) { + log.error("Failed to parse ml model group" + modelGroup.getId(), e); wrappedListener.onFailure(e); } - })); - } catch (Exception e) { - logException("Failed to Update model group", e, log); - listener.onFailure(e); - } - } else { - validateSecurityDisabledOrModelAccessControlDisabled(updateModelGroupInput); - updateModelGroup(modelGroupId, new HashMap<>(), updateModelGroupInput, listener, user); + } else { + wrappedListener.onFailure(new OpenSearchStatusException("Failed to find model group", RestStatus.NOT_FOUND)); + } + }, e -> { + if (e instanceof IndexNotFoundException) { + wrappedListener.onFailure(new MLResourceNotFoundException("Fail to find model group")); + } else { + logException("Failed to get model group", e, log); + wrappedListener.onFailure(e); + } + })); + } catch (Exception e) { + logException("Failed to Update model group", e, log); + listener.onFailure(e); } } diff --git a/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupActionTests.java index 1a67977291..a99488d0e9 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupActionTests.java @@ -38,6 +38,7 @@ import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.get.GetResult; import org.opensearch.ml.common.AccessMode; import org.opensearch.ml.common.MLModelGroup; @@ -367,6 +368,20 @@ public void test_FailedToGetModelGroupException() { assertEquals("Failed to get model group", argumentCaptor.getValue().getMessage()); } + public void test_ModelGroupIndexNotFoundException() { + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onFailure(new IndexNotFoundException("Fail to find model group")); + return null; + }).when(client).get(any(), any()); + + MLUpdateModelGroupRequest actionRequest = prepareRequest(null, AccessMode.RESTRICTED, null); + transportUpdateModelGroupAction.doExecute(task, actionRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Fail to find model group", argumentCaptor.getValue().getMessage()); + } + public void test_FailedToUpdatetModelGroupException() { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); @@ -414,15 +429,16 @@ public void test_ModelGroupNameNotUnique() throws IOException { } public void test_ExceptionSecurityDisabledCluster() { - exceptionRule.expect(IllegalArgumentException.class); - exceptionRule - .expectMessage( - "You cannot specify model access control parameters because the Security plugin or model access control is disabled on your cluster." - ); when(modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(any())).thenReturn(false); MLUpdateModelGroupRequest actionRequest = prepareRequest(null, null, true); transportUpdateModelGroupAction.doExecute(task, actionRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "You cannot specify model access control parameters because the Security plugin or model access control is disabled on your cluster.", + argumentCaptor.getValue().getMessage() + ); } private MLUpdateModelGroupRequest prepareRequest(List backendRoles, AccessMode modelAccessMode, Boolean isAddAllBackendRoles) {