Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix more places where thread context not restored #1421

Merged
merged 3 commits into from
Oct 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ public void search(SearchRequest request, ActionListener<SearchResponse> actionL
User user = RestActionUtils.getUserContext(client);
ActionListener<SearchResponse> listener = wrapRestActionListener(actionListener, "Fail to search model version");
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
ActionListener<SearchResponse> wrappedListener = ActionListener.runBefore(listener, () -> context.restore());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume we can't provide this listener as generic from utils because context doesn't exist?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes

List<String> excludes = Optional
.ofNullable(request.source())
.map(SearchSourceBuilder::fetchSource)
Expand All @@ -98,9 +99,9 @@ public void search(SearchRequest request, ActionListener<SearchResponse> actionL
);
request.source().fetchSource(rebuiltFetchSourceContext);
if (modelAccessControlHelper.skipModelAccessControl(user)) {
client.search(request, listener);
client.search(request, wrappedListener);
} else if (!clusterService.state().metadata().hasIndex(CommonValue.ML_MODEL_GROUP_INDEX)) {
client.search(request, listener);
client.search(request, wrappedListener);
} else {
SearchSourceBuilder sourceBuilder = modelAccessControlHelper.createSearchSourceBuilder(user);
SearchRequest modelGroupSearchRequest = new SearchRequest();
Expand All @@ -119,15 +120,15 @@ public void search(SearchRequest request, ActionListener<SearchResponse> actionL
Arrays.stream(r.getHits().getHits()).forEach(hit -> { modelGroupIds.add(hit.getId()); });

request.source().query(rewriteQueryBuilder(request.source().query(), modelGroupIds));
client.search(request, listener);
client.search(request, wrappedListener);
} else {
log.debug("No model group found");
request.source().query(rewriteQueryBuilder(request.source().query(), null));
client.search(request, listener);
client.search(request, wrappedListener);
}
}, e -> {
log.error("Fail to search model groups!", e);
actionListener.onFailure(e);
wrappedListener.onFailure(e);
});
client.search(modelGroupSearchRequest, modelGroupSearchActionListener);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,10 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Delete
DeleteRequest deleteRequest = new DeleteRequest(ML_MODEL_GROUP_INDEX, modelGroupId);
User user = RestActionUtils.getUserContext(client);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
ActionListener<DeleteResponse> wrappedListener = ActionListener.runBefore(actionListener, () -> context.restore());
modelAccessControlHelper.validateModelGroupAccess(user, modelGroupId, client, ActionListener.wrap(access -> {
if (!access) {
actionListener.onFailure(new MLValidationException("User doesn't have privilege to delete this model group"));
wrappedListener.onFailure(new MLValidationException("User doesn't have privilege to delete this model group"));
} else {
BoolQueryBuilder query = new BoolQueryBuilder();
query.filter(new TermQueryBuilder(PARAMETER_MODEL_GROUP_ID, modelGroupId));
Expand All @@ -87,13 +88,13 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Delete
@Override
public void onResponse(DeleteResponse deleteResponse) {
log.debug("Completed Delete Model Group Request, task id:{} deleted", modelGroupId);
actionListener.onResponse(deleteResponse);
wrappedListener.onResponse(deleteResponse);
}

@Override
public void onFailure(Exception e) {
log.error("Failed to delete ML Model Group " + modelGroupId, e);
actionListener.onFailure(e);
wrappedListener.onFailure(e);
}
});
} else {
Expand All @@ -102,16 +103,16 @@ public void onFailure(Exception e) {

}, e -> {
if (e instanceof IndexNotFoundException) {
actionListener.onFailure(new MLResourceNotFoundException("Fail to find model group"));
wrappedListener.onFailure(new MLResourceNotFoundException("Fail to find model group"));
} else {
log.error("Failed to search models with the specified Model Group Id " + modelGroupId, e);
actionListener.onFailure(e);
wrappedListener.onFailure(e);
}
}));
}
}, e -> {
log.error("Failed to validate Access for Model Group " + modelGroupId, e);
actionListener.onFailure(e);
wrappedListener.onFailure(e);
}));
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLUpda
if (modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(user)) {
GetRequest getModelGroupRequest = new GetRequest(ML_MODEL_GROUP_INDEX).id(modelGroupId);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
ActionListener<MLUpdateModelGroupResponse> wrappedListener = ActionListener.runBefore(listener, () -> context.restore());
client.get(getModelGroupRequest, ActionListener.wrap(modelGroup -> {
if (modelGroup.isExists()) {
try (
Expand All @@ -102,17 +103,17 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLUpda
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
MLModelGroup mlModelGroup = MLModelGroup.parse(parser);
validateRequestForAccessControl(updateModelGroupInput, user, mlModelGroup);
updateModelGroup(modelGroupId, modelGroup.getSource(), updateModelGroupInput, listener, user);
updateModelGroup(modelGroupId, modelGroup.getSource(), updateModelGroupInput, wrappedListener, user);
}
} else {
listener.onFailure(new OpenSearchStatusException("Failed to find model group", RestStatus.NOT_FOUND));
wrappedListener.onFailure(new OpenSearchStatusException("Failed to find model group", RestStatus.NOT_FOUND));
}
}, e -> {
if (e instanceof IndexNotFoundException) {
listener.onFailure(new MLResourceNotFoundException("Fail to find model group"));
wrappedListener.onFailure(new MLResourceNotFoundException("Fail to find model group"));
} else {
logException("Failed to get model group", e, log);
listener.onFailure(e);
wrappedListener.onFailure(e);
}
}));
} catch (Exception e) {
Expand Down Expand Up @@ -186,15 +187,16 @@ private void updateModelGroup(String modelGroupId, Map<String, Object> source, A
UpdateRequest updateModelGroupRequest = new UpdateRequest();
updateModelGroupRequest.index(ML_MODEL_GROUP_INDEX).id(modelGroupId).doc(source);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
ActionListener<MLUpdateModelGroupResponse> wrappedListener = ActionListener.runBefore(listener, () -> context.restore());
client
.update(
updateModelGroupRequest,
ActionListener.wrap(r -> { listener.onResponse(new MLUpdateModelGroupResponse("Updated")); }, e -> {
ActionListener.wrap(r -> { wrappedListener.onResponse(new MLUpdateModelGroupResponse("Updated")); }, e -> {
if (e instanceof IndexNotFoundException) {
listener.onFailure(new MLResourceNotFoundException("Fail to find model group"));
wrappedListener.onFailure(new MLResourceNotFoundException("Fail to find model group"));
} else {
log.error("Failed to update model group", e, log);
listener.onFailure(new MLValidationException("Failed to update Model Group"));
wrappedListener.onFailure(new MLValidationException("Failed to update Model Group"));
}
})
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Delete
User user = RestActionUtils.getUserContext(client);

try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
ActionListener<DeleteResponse> wrappedListener = ActionListener.runBefore(actionListener, () -> context.restore());
client.get(getRequest, ActionListener.wrap(r -> {
if (r != null && r.isExists()) {
try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) {
Expand All @@ -113,7 +114,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Delete
modelAccessControlHelper
.validateModelGroupAccess(user, mlModel.getModelGroupId(), client, ActionListener.wrap(access -> {
if (!access) {
actionListener
wrappedListener
.onFailure(
new MLValidationException("User doesn't have privilege to perform this operation on this model")
);
Expand All @@ -125,7 +126,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Delete
|| mlModelState.equals(MLModelState.DEPLOYED)
|| mlModelState.equals(MLModelState.DEPLOYING)
|| mlModelState.equals(MLModelState.PARTIALLY_DEPLOYED)) {
actionListener
wrappedListener
.onFailure(
new Exception(
"Model cannot be deleted in deploying or deployed state. Try undeploy model first then delete"
Expand All @@ -140,27 +141,27 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Delete
&& response.getHits().getTotalHits().value == 1) {
isLastModelOfGroup = true;
}
deleteModel(modelId, mlModel.getModelGroupId(), isLastModelOfGroup, actionListener);
deleteModel(modelId, mlModel.getModelGroupId(), isLastModelOfGroup, wrappedListener);
}, e -> {
log.error("Failed to Search Model index " + modelId, e);
actionListener.onFailure(e);
wrappedListener.onFailure(e);
}));
} else {
deleteModel(modelId, mlModel.getModelGroupId(), false, actionListener);
deleteModel(modelId, mlModel.getModelGroupId(), false, wrappedListener);
}
}
}, e -> {
log.error("Failed to validate Access for Model Id " + modelId, e);
actionListener.onFailure(e);
wrappedListener.onFailure(e);
}));
} catch (Exception e) {
log.error("Failed to parse ml model " + r.getId(), e);
actionListener.onFailure(e);
wrappedListener.onFailure(e);
}
} else {
actionListener.onFailure(new OpenSearchStatusException("Failed to find model", RestStatus.NOT_FOUND));
wrappedListener.onFailure(new OpenSearchStatusException("Failed to find model", RestStatus.NOT_FOUND));
}
}, e -> { actionListener.onFailure(e); }));
}, e -> { wrappedListener.onFailure(e); }));
} catch (Exception e) {
log.error("Failed to delete ML model " + modelId, e);
actionListener.onFailure(e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLMode
User user = RestActionUtils.getUserContext(client);

try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
client.get(getRequest, ActionListener.runBefore(ActionListener.wrap(r -> {
ActionListener<MLModelGetResponse> wrappedListener = ActionListener.runBefore(actionListener, () -> context.restore());
client.get(getRequest, ActionListener.wrap(r -> {
if (r != null && r.isExists()) {
try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) {
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
Expand All @@ -90,7 +91,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLMode
modelAccessControlHelper
.validateModelGroupAccess(user, mlModel.getModelGroupId(), client, ActionListener.wrap(access -> {
if (!access) {
actionListener
wrappedListener
.onFailure(
new MLValidationException("User Doesn't have privilege to perform this operation on this model")
);
Expand All @@ -100,19 +101,19 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLMode
if (connector != null) {
connector.removeCredential();
}
actionListener.onResponse(MLModelGetResponse.builder().mlModel(mlModel).build());
wrappedListener.onResponse(MLModelGetResponse.builder().mlModel(mlModel).build());
}
}, e -> {
log.error("Failed to validate Access for Model Id " + modelId, e);
actionListener.onFailure(e);
wrappedListener.onFailure(e);
}));

} catch (Exception e) {
log.error("Failed to parse ml model " + r.getId(), e);
actionListener.onFailure(e);
wrappedListener.onFailure(e);
}
} else {
actionListener
wrappedListener
.onFailure(
new OpenSearchStatusException(
"Failed to find model with the provided model id: " + modelId,
Expand All @@ -122,12 +123,12 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLMode
}
}, e -> {
if (e instanceof IndexNotFoundException) {
actionListener.onFailure(new MLResourceNotFoundException("Fail to find model"));
wrappedListener.onFailure(new MLResourceNotFoundException("Fail to find model"));
} else {
log.error("Failed to get ML model " + modelId, e);
actionListener.onFailure(e);
wrappedListener.onFailure(e);
}
}), () -> context.restore()));
}));
} catch (Exception e) {
log.error("Failed to get ML model " + modelId, e);
actionListener.onFailure(e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Delete
GetRequest getRequest = new GetRequest(ML_TASK_INDEX).id(taskId);

try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
ActionListener<DeleteResponse> wrappedListener = ActionListener.runBefore(actionListener, () -> context.restore());
client.get(getRequest, ActionListener.wrap(r -> {

if (r != null && r.isExists()) {
Expand All @@ -72,24 +73,24 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Delete
@Override
public void onResponse(DeleteResponse deleteResponse) {
log.debug("Completed Delete Task Request, task id:{} deleted", taskId);
actionListener.onResponse(deleteResponse);
wrappedListener.onResponse(deleteResponse);
}

@Override
public void onFailure(Exception e) {
log.error("Failed to delete ML Task " + taskId, e);
actionListener.onFailure(e);
wrappedListener.onFailure(e);
}
});
}
} catch (Exception e) {
log.error("Failed to parse ML task " + taskId, e);
actionListener.onFailure(e);
wrappedListener.onFailure(e);
}
} else {
actionListener.onFailure(new MLResourceNotFoundException("Fail to find task"));
wrappedListener.onFailure(new MLResourceNotFoundException("Fail to find task"));
}
}, e -> { actionListener.onFailure(new MLResourceNotFoundException("Fail to find task")); }));
}, e -> { wrappedListener.onFailure(new MLResourceNotFoundException("Fail to find task")); }));
} catch (Exception e) {
log.error("Failed to delete ml task " + taskId, e);
actionListener.onFailure(e);
Expand Down
Loading