Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.11] throw exception when model group not found during update request #1451

Merged
merged 1 commit into from
Oct 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX;
import static org.opensearch.ml.utils.MLExceptionUtils.logException;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
Expand Down Expand Up @@ -90,39 +89,41 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLUpda
MLUpdateModelGroupInput updateModelGroupInput = updateModelGroupRequest.getUpdateModelGroupInput();
String modelGroupId = updateModelGroupInput.getModelGroupID();
User user = RestActionUtils.getUserContext(client);
if (modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(user)) {
GetRequest getModelGroupRequest = new GetRequest(ML_MODEL_GROUP_INDEX).id(modelGroupId);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
ActionListener<MLUpdateModelGroupResponse> wrappedListener = ActionListener.runBefore(listener, () -> context.restore());
client.get(getModelGroupRequest, ActionListener.wrap(modelGroup -> {
if (modelGroup.isExists()) {
try (
XContentParser parser = MLNodeUtils
.createXContentParserFromRegistry(NamedXContentRegistry.EMPTY, modelGroup.getSourceAsBytesRef())
) {
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
MLModelGroup mlModelGroup = MLModelGroup.parse(parser);
GetRequest getModelGroupRequest = new GetRequest(ML_MODEL_GROUP_INDEX).id(modelGroupId);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
ActionListener<MLUpdateModelGroupResponse> wrappedListener = ActionListener.runBefore(listener, () -> context.restore());
client.get(getModelGroupRequest, ActionListener.wrap(modelGroup -> {
if (modelGroup.isExists()) {
try (
XContentParser parser = MLNodeUtils
.createXContentParserFromRegistry(NamedXContentRegistry.EMPTY, modelGroup.getSourceAsBytesRef())
) {
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
MLModelGroup mlModelGroup = MLModelGroup.parse(parser);
if (modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(user)) {
validateRequestForAccessControl(updateModelGroupInput, user, mlModelGroup);
updateModelGroup(modelGroupId, modelGroup.getSource(), updateModelGroupInput, wrappedListener, user);
} else {
validateSecurityDisabledOrModelAccessControlDisabled(updateModelGroupInput);
}
} else {
wrappedListener.onFailure(new OpenSearchStatusException("Failed to find model group", RestStatus.NOT_FOUND));
}
}, e -> {
if (e instanceof IndexNotFoundException) {
wrappedListener.onFailure(new MLResourceNotFoundException("Fail to find model group"));
} else {
logException("Failed to get model group", e, log);
updateModelGroup(modelGroupId, modelGroup.getSource(), updateModelGroupInput, wrappedListener, user);
} catch (Exception e) {
log.error("Failed to parse ml model group" + modelGroup.getId(), e);
wrappedListener.onFailure(e);
}
}));
} catch (Exception e) {
logException("Failed to Update model group", e, log);
listener.onFailure(e);
}
} else {
validateSecurityDisabledOrModelAccessControlDisabled(updateModelGroupInput);
updateModelGroup(modelGroupId, new HashMap<>(), updateModelGroupInput, listener, user);
} else {
wrappedListener.onFailure(new OpenSearchStatusException("Failed to find model group", RestStatus.NOT_FOUND));
}
}, e -> {
if (e instanceof IndexNotFoundException) {
wrappedListener.onFailure(new MLResourceNotFoundException("Fail to find model group"));
} else {
logException("Failed to get model group", e, log);
wrappedListener.onFailure(e);
}
}));
} catch (Exception e) {
logException("Failed to Update model group", e, log);
listener.onFailure(e);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.index.get.GetResult;
import org.opensearch.ml.common.AccessMode;
import org.opensearch.ml.common.MLModelGroup;
Expand Down Expand Up @@ -367,6 +368,20 @@ public void test_FailedToGetModelGroupException() {
assertEquals("Failed to get model group", argumentCaptor.getValue().getMessage());
}

public void test_ModelGroupIndexNotFoundException() {
doAnswer(invocation -> {
ActionListener<GetResponse> actionListener = invocation.getArgument(1);
actionListener.onFailure(new IndexNotFoundException("Fail to find model group"));
return null;
}).when(client).get(any(), any());

MLUpdateModelGroupRequest actionRequest = prepareRequest(null, AccessMode.RESTRICTED, null);
transportUpdateModelGroupAction.doExecute(task, actionRequest, actionListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals("Fail to find model group", argumentCaptor.getValue().getMessage());
}

public void test_FailedToUpdatetModelGroupException() {
doAnswer(invocation -> {
ActionListener<UpdateResponse> listener = invocation.getArgument(1);
Expand Down Expand Up @@ -414,15 +429,16 @@ public void test_ModelGroupNameNotUnique() throws IOException {
}

public void test_ExceptionSecurityDisabledCluster() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule
.expectMessage(
"You cannot specify model access control parameters because the Security plugin or model access control is disabled on your cluster."
);
when(modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(any())).thenReturn(false);

MLUpdateModelGroupRequest actionRequest = prepareRequest(null, null, true);
transportUpdateModelGroupAction.doExecute(task, actionRequest, actionListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals(
"You cannot specify model access control parameters because the Security plugin or model access control is disabled on your cluster.",
argumentCaptor.getValue().getMessage()
);
}

private MLUpdateModelGroupRequest prepareRequest(List<String> backendRoles, AccessMode modelAccessMode, Boolean isAddAllBackendRoles) {
Expand Down
Loading