Skip to content

Commit

Permalink
fix more places where thread context not restored (#1421) (#1423)
Browse files Browse the repository at this point in the history
* fix more places where thread context not restored

* fix failed ut

* remove unused import

---------

Signed-off-by: Yaliang Wu <[email protected]>
(cherry picked from commit d8c1162)
  • Loading branch information
ylwu-amzn authored and github-actions[bot] committed Oct 4, 2023
1 parent 8b6199d commit c57f0de
Show file tree
Hide file tree
Showing 11 changed files with 67 additions and 57 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ public void search(SearchRequest request, ActionListener<SearchResponse> actionL
User user = RestActionUtils.getUserContext(client);
ActionListener<SearchResponse> listener = wrapRestActionListener(actionListener, "Fail to search model version");
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
ActionListener<SearchResponse> wrappedListener = ActionListener.runBefore(listener, () -> context.restore());
List<String> excludes = Optional
.ofNullable(request.source())
.map(SearchSourceBuilder::fetchSource)
Expand All @@ -98,9 +99,9 @@ public void search(SearchRequest request, ActionListener<SearchResponse> actionL
);
request.source().fetchSource(rebuiltFetchSourceContext);
if (modelAccessControlHelper.skipModelAccessControl(user)) {
client.search(request, listener);
client.search(request, wrappedListener);
} else if (!clusterService.state().metadata().hasIndex(CommonValue.ML_MODEL_GROUP_INDEX)) {
client.search(request, listener);
client.search(request, wrappedListener);
} else {
SearchSourceBuilder sourceBuilder = modelAccessControlHelper.createSearchSourceBuilder(user);
SearchRequest modelGroupSearchRequest = new SearchRequest();
Expand All @@ -119,15 +120,15 @@ public void search(SearchRequest request, ActionListener<SearchResponse> actionL
Arrays.stream(r.getHits().getHits()).forEach(hit -> { modelGroupIds.add(hit.getId()); });

request.source().query(rewriteQueryBuilder(request.source().query(), modelGroupIds));
client.search(request, listener);
client.search(request, wrappedListener);
} else {
log.debug("No model group found");
request.source().query(rewriteQueryBuilder(request.source().query(), null));
client.search(request, listener);
client.search(request, wrappedListener);
}
}, e -> {
log.error("Fail to search model groups!", e);
actionListener.onFailure(e);
wrappedListener.onFailure(e);
});
client.search(modelGroupSearchRequest, modelGroupSearchActionListener);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,10 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Delete
DeleteRequest deleteRequest = new DeleteRequest(ML_MODEL_GROUP_INDEX, modelGroupId);
User user = RestActionUtils.getUserContext(client);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
ActionListener<DeleteResponse> wrappedListener = ActionListener.runBefore(actionListener, () -> context.restore());
modelAccessControlHelper.validateModelGroupAccess(user, modelGroupId, client, ActionListener.wrap(access -> {
if (!access) {
actionListener.onFailure(new MLValidationException("User doesn't have privilege to delete this model group"));
wrappedListener.onFailure(new MLValidationException("User doesn't have privilege to delete this model group"));
} else {
BoolQueryBuilder query = new BoolQueryBuilder();
query.filter(new TermQueryBuilder(PARAMETER_MODEL_GROUP_ID, modelGroupId));
Expand All @@ -87,13 +88,13 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Delete
@Override
public void onResponse(DeleteResponse deleteResponse) {
log.debug("Completed Delete Model Group Request, task id:{} deleted", modelGroupId);
actionListener.onResponse(deleteResponse);
wrappedListener.onResponse(deleteResponse);
}

@Override
public void onFailure(Exception e) {
log.error("Failed to delete ML Model Group " + modelGroupId, e);
actionListener.onFailure(e);
wrappedListener.onFailure(e);
}
});
} else {
Expand All @@ -102,16 +103,16 @@ public void onFailure(Exception e) {

}, e -> {
if (e instanceof IndexNotFoundException) {
actionListener.onFailure(new MLResourceNotFoundException("Fail to find model group"));
wrappedListener.onFailure(new MLResourceNotFoundException("Fail to find model group"));
} else {
log.error("Failed to search models with the specified Model Group Id " + modelGroupId, e);
actionListener.onFailure(e);
wrappedListener.onFailure(e);
}
}));
}
}, e -> {
log.error("Failed to validate Access for Model Group " + modelGroupId, e);
actionListener.onFailure(e);
wrappedListener.onFailure(e);
}));
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLUpda
if (modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(user)) {
GetRequest getModelGroupRequest = new GetRequest(ML_MODEL_GROUP_INDEX).id(modelGroupId);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
ActionListener<MLUpdateModelGroupResponse> wrappedListener = ActionListener.runBefore(listener, () -> context.restore());
client.get(getModelGroupRequest, ActionListener.wrap(modelGroup -> {
if (modelGroup.isExists()) {
try (
Expand All @@ -102,17 +103,17 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLUpda
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
MLModelGroup mlModelGroup = MLModelGroup.parse(parser);
validateRequestForAccessControl(updateModelGroupInput, user, mlModelGroup);
updateModelGroup(modelGroupId, modelGroup.getSource(), updateModelGroupInput, listener, user);
updateModelGroup(modelGroupId, modelGroup.getSource(), updateModelGroupInput, wrappedListener, user);
}
} else {
listener.onFailure(new OpenSearchStatusException("Failed to find model group", RestStatus.NOT_FOUND));
wrappedListener.onFailure(new OpenSearchStatusException("Failed to find model group", RestStatus.NOT_FOUND));
}
}, e -> {
if (e instanceof IndexNotFoundException) {
listener.onFailure(new MLResourceNotFoundException("Fail to find model group"));
wrappedListener.onFailure(new MLResourceNotFoundException("Fail to find model group"));
} else {
logException("Failed to get model group", e, log);
listener.onFailure(e);
wrappedListener.onFailure(e);
}
}));
} catch (Exception e) {
Expand Down Expand Up @@ -186,15 +187,16 @@ private void updateModelGroup(String modelGroupId, Map<String, Object> source, A
UpdateRequest updateModelGroupRequest = new UpdateRequest();
updateModelGroupRequest.index(ML_MODEL_GROUP_INDEX).id(modelGroupId).doc(source);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
ActionListener<MLUpdateModelGroupResponse> wrappedListener = ActionListener.runBefore(listener, () -> context.restore());
client
.update(
updateModelGroupRequest,
ActionListener.wrap(r -> { listener.onResponse(new MLUpdateModelGroupResponse("Updated")); }, e -> {
ActionListener.wrap(r -> { wrappedListener.onResponse(new MLUpdateModelGroupResponse("Updated")); }, e -> {
if (e instanceof IndexNotFoundException) {
listener.onFailure(new MLResourceNotFoundException("Fail to find model group"));
wrappedListener.onFailure(new MLResourceNotFoundException("Fail to find model group"));
} else {
log.error("Failed to update model group", e, log);
listener.onFailure(new MLValidationException("Failed to update Model Group"));
wrappedListener.onFailure(new MLValidationException("Failed to update Model Group"));
}
})
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Delete
User user = RestActionUtils.getUserContext(client);

try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
ActionListener<DeleteResponse> wrappedListener = ActionListener.runBefore(actionListener, () -> context.restore());
client.get(getRequest, ActionListener.wrap(r -> {
if (r != null && r.isExists()) {
try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) {
Expand All @@ -113,7 +114,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Delete
modelAccessControlHelper
.validateModelGroupAccess(user, mlModel.getModelGroupId(), client, ActionListener.wrap(access -> {
if (!access) {
actionListener
wrappedListener
.onFailure(
new MLValidationException("User doesn't have privilege to perform this operation on this model")
);
Expand All @@ -125,7 +126,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Delete
|| mlModelState.equals(MLModelState.DEPLOYED)
|| mlModelState.equals(MLModelState.DEPLOYING)
|| mlModelState.equals(MLModelState.PARTIALLY_DEPLOYED)) {
actionListener
wrappedListener
.onFailure(
new Exception(
"Model cannot be deleted in deploying or deployed state. Try undeploy model first then delete"
Expand All @@ -140,27 +141,27 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Delete
&& response.getHits().getTotalHits().value == 1) {
isLastModelOfGroup = true;
}
deleteModel(modelId, mlModel.getModelGroupId(), isLastModelOfGroup, actionListener);
deleteModel(modelId, mlModel.getModelGroupId(), isLastModelOfGroup, wrappedListener);
}, e -> {
log.error("Failed to Search Model index " + modelId, e);
actionListener.onFailure(e);
wrappedListener.onFailure(e);
}));
} else {
deleteModel(modelId, mlModel.getModelGroupId(), false, actionListener);
deleteModel(modelId, mlModel.getModelGroupId(), false, wrappedListener);
}
}
}, e -> {
log.error("Failed to validate Access for Model Id " + modelId, e);
actionListener.onFailure(e);
wrappedListener.onFailure(e);
}));
} catch (Exception e) {
log.error("Failed to parse ml model " + r.getId(), e);
actionListener.onFailure(e);
wrappedListener.onFailure(e);
}
} else {
actionListener.onFailure(new OpenSearchStatusException("Failed to find model", RestStatus.NOT_FOUND));
wrappedListener.onFailure(new OpenSearchStatusException("Failed to find model", RestStatus.NOT_FOUND));
}
}, e -> { actionListener.onFailure(e); }));
}, e -> { wrappedListener.onFailure(e); }));
} catch (Exception e) {
log.error("Failed to delete ML model " + modelId, e);
actionListener.onFailure(e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLMode
User user = RestActionUtils.getUserContext(client);

try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
client.get(getRequest, ActionListener.runBefore(ActionListener.wrap(r -> {
ActionListener<MLModelGetResponse> wrappedListener = ActionListener.runBefore(actionListener, () -> context.restore());
client.get(getRequest, ActionListener.wrap(r -> {
if (r != null && r.isExists()) {
try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) {
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
Expand All @@ -90,7 +91,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLMode
modelAccessControlHelper
.validateModelGroupAccess(user, mlModel.getModelGroupId(), client, ActionListener.wrap(access -> {
if (!access) {
actionListener
wrappedListener
.onFailure(
new MLValidationException("User Doesn't have privilege to perform this operation on this model")
);
Expand All @@ -100,19 +101,19 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLMode
if (connector != null) {
connector.removeCredential();
}
actionListener.onResponse(MLModelGetResponse.builder().mlModel(mlModel).build());
wrappedListener.onResponse(MLModelGetResponse.builder().mlModel(mlModel).build());
}
}, e -> {
log.error("Failed to validate Access for Model Id " + modelId, e);
actionListener.onFailure(e);
wrappedListener.onFailure(e);
}));

} catch (Exception e) {
log.error("Failed to parse ml model " + r.getId(), e);
actionListener.onFailure(e);
wrappedListener.onFailure(e);
}
} else {
actionListener
wrappedListener
.onFailure(
new OpenSearchStatusException(
"Failed to find model with the provided model id: " + modelId,
Expand All @@ -122,12 +123,12 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLMode
}
}, e -> {
if (e instanceof IndexNotFoundException) {
actionListener.onFailure(new MLResourceNotFoundException("Fail to find model"));
wrappedListener.onFailure(new MLResourceNotFoundException("Fail to find model"));
} else {
log.error("Failed to get ML model " + modelId, e);
actionListener.onFailure(e);
wrappedListener.onFailure(e);
}
}), () -> context.restore()));
}));
} catch (Exception e) {
log.error("Failed to get ML model " + modelId, e);
actionListener.onFailure(e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Delete
GetRequest getRequest = new GetRequest(ML_TASK_INDEX).id(taskId);

try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
ActionListener<DeleteResponse> wrappedListener = ActionListener.runBefore(actionListener, () -> context.restore());
client.get(getRequest, ActionListener.wrap(r -> {

if (r != null && r.isExists()) {
Expand All @@ -72,24 +73,24 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Delete
@Override
public void onResponse(DeleteResponse deleteResponse) {
log.debug("Completed Delete Task Request, task id:{} deleted", taskId);
actionListener.onResponse(deleteResponse);
wrappedListener.onResponse(deleteResponse);
}

@Override
public void onFailure(Exception e) {
log.error("Failed to delete ML Task " + taskId, e);
actionListener.onFailure(e);
wrappedListener.onFailure(e);
}
});
}
} catch (Exception e) {
log.error("Failed to parse ML task " + taskId, e);
actionListener.onFailure(e);
wrappedListener.onFailure(e);
}
} else {
actionListener.onFailure(new MLResourceNotFoundException("Fail to find task"));
wrappedListener.onFailure(new MLResourceNotFoundException("Fail to find task"));
}
}, e -> { actionListener.onFailure(new MLResourceNotFoundException("Fail to find task")); }));
}, e -> { wrappedListener.onFailure(new MLResourceNotFoundException("Fail to find task")); }));
} catch (Exception e) {
log.error("Failed to delete ml task " + taskId, e);
actionListener.onFailure(e);
Expand Down
Loading

0 comments on commit c57f0de

Please sign in to comment.