diff --git a/plugin/src/main/java/org/opensearch/ml/action/handler/MLSearchHandler.java b/plugin/src/main/java/org/opensearch/ml/action/handler/MLSearchHandler.java index c4fd51eb29..c8e0af525e 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/handler/MLSearchHandler.java +++ b/plugin/src/main/java/org/opensearch/ml/action/handler/MLSearchHandler.java @@ -80,6 +80,7 @@ public void search(SearchRequest request, ActionListener actionL User user = RestActionUtils.getUserContext(client); ActionListener listener = wrapRestActionListener(actionListener, "Fail to search model version"); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + ActionListener wrappedListener = ActionListener.runBefore(listener, () -> context.restore()); List excludes = Optional .ofNullable(request.source()) .map(SearchSourceBuilder::fetchSource) @@ -98,9 +99,9 @@ public void search(SearchRequest request, ActionListener actionL ); request.source().fetchSource(rebuiltFetchSourceContext); if (modelAccessControlHelper.skipModelAccessControl(user)) { - client.search(request, listener); + client.search(request, wrappedListener); } else if (!clusterService.state().metadata().hasIndex(CommonValue.ML_MODEL_GROUP_INDEX)) { - client.search(request, listener); + client.search(request, wrappedListener); } else { SearchSourceBuilder sourceBuilder = modelAccessControlHelper.createSearchSourceBuilder(user); SearchRequest modelGroupSearchRequest = new SearchRequest(); @@ -119,15 +120,15 @@ public void search(SearchRequest request, ActionListener actionL Arrays.stream(r.getHits().getHits()).forEach(hit -> { modelGroupIds.add(hit.getId()); }); request.source().query(rewriteQueryBuilder(request.source().query(), modelGroupIds)); - client.search(request, listener); + client.search(request, wrappedListener); } else { log.debug("No model group found"); request.source().query(rewriteQueryBuilder(request.source().query(), null)); - client.search(request, listener); + client.search(request, wrappedListener); } }, e -> { log.error("Fail to search model groups!", e); - actionListener.onFailure(e); + wrappedListener.onFailure(e); }); client.search(modelGroupSearchRequest, modelGroupSearchActionListener); } diff --git a/plugin/src/main/java/org/opensearch/ml/action/model_group/DeleteModelGroupTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/model_group/DeleteModelGroupTransportAction.java index 295afab26c..0fb5ff55c4 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/model_group/DeleteModelGroupTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/model_group/DeleteModelGroupTransportAction.java @@ -72,9 +72,10 @@ protected void doExecute(Task task, ActionRequest request, ActionListener wrappedListener = ActionListener.runBefore(actionListener, () -> context.restore()); modelAccessControlHelper.validateModelGroupAccess(user, modelGroupId, client, ActionListener.wrap(access -> { if (!access) { - actionListener.onFailure(new MLValidationException("User doesn't have privilege to delete this model group")); + wrappedListener.onFailure(new MLValidationException("User doesn't have privilege to delete this model group")); } else { BoolQueryBuilder query = new BoolQueryBuilder(); query.filter(new TermQueryBuilder(PARAMETER_MODEL_GROUP_ID, modelGroupId)); @@ -87,13 +88,13 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { if (e instanceof IndexNotFoundException) { - actionListener.onFailure(new MLResourceNotFoundException("Fail to find model group")); + wrappedListener.onFailure(new MLResourceNotFoundException("Fail to find model group")); } else { log.error("Failed to search models with the specified Model Group Id " + modelGroupId, e); - actionListener.onFailure(e); + wrappedListener.onFailure(e); } })); } }, e -> { log.error("Failed to validate Access for Model Group " + modelGroupId, e); - actionListener.onFailure(e); + wrappedListener.onFailure(e); })); } } diff --git a/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java b/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java index ea699da4af..3e4b9cb78a 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java @@ -93,6 +93,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener wrappedListener = ActionListener.runBefore(listener, () -> context.restore()); client.get(getModelGroupRequest, ActionListener.wrap(modelGroup -> { if (modelGroup.isExists()) { try ( @@ -102,17 +103,17 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { if (e instanceof IndexNotFoundException) { - listener.onFailure(new MLResourceNotFoundException("Fail to find model group")); + wrappedListener.onFailure(new MLResourceNotFoundException("Fail to find model group")); } else { logException("Failed to get model group", e, log); - listener.onFailure(e); + wrappedListener.onFailure(e); } })); } catch (Exception e) { @@ -186,15 +187,16 @@ private void updateModelGroup(String modelGroupId, Map source, A UpdateRequest updateModelGroupRequest = new UpdateRequest(); updateModelGroupRequest.index(ML_MODEL_GROUP_INDEX).id(modelGroupId).doc(source); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + ActionListener wrappedListener = ActionListener.runBefore(listener, () -> context.restore()); client .update( updateModelGroupRequest, - ActionListener.wrap(r -> { listener.onResponse(new MLUpdateModelGroupResponse("Updated")); }, e -> { + ActionListener.wrap(r -> { wrappedListener.onResponse(new MLUpdateModelGroupResponse("Updated")); }, e -> { if (e instanceof IndexNotFoundException) { - listener.onFailure(new MLResourceNotFoundException("Fail to find model group")); + wrappedListener.onFailure(new MLResourceNotFoundException("Fail to find model group")); } else { log.error("Failed to update model group", e, log); - listener.onFailure(new MLValidationException("Failed to update Model Group")); + wrappedListener.onFailure(new MLValidationException("Failed to update Model Group")); } }) ); diff --git a/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java index 6e193d6909..812cfc7f9f 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java @@ -99,6 +99,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener wrappedListener = ActionListener.runBefore(actionListener, () -> context.restore()); client.get(getRequest, ActionListener.wrap(r -> { if (r != null && r.isExists()) { try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) { @@ -113,7 +114,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { if (!access) { - actionListener + wrappedListener .onFailure( new MLValidationException("User doesn't have privilege to perform this operation on this model") ); @@ -125,7 +126,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { log.error("Failed to Search Model index " + modelId, e); - actionListener.onFailure(e); + wrappedListener.onFailure(e); })); } else { - deleteModel(modelId, mlModel.getModelGroupId(), false, actionListener); + deleteModel(modelId, mlModel.getModelGroupId(), false, wrappedListener); } } }, e -> { log.error("Failed to validate Access for Model Id " + modelId, e); - actionListener.onFailure(e); + wrappedListener.onFailure(e); })); } catch (Exception e) { log.error("Failed to parse ml model " + r.getId(), e); - actionListener.onFailure(e); + wrappedListener.onFailure(e); } } else { - actionListener.onFailure(new OpenSearchStatusException("Failed to find model", RestStatus.NOT_FOUND)); + wrappedListener.onFailure(new OpenSearchStatusException("Failed to find model", RestStatus.NOT_FOUND)); } - }, e -> { actionListener.onFailure(e); })); + }, e -> { wrappedListener.onFailure(e); })); } catch (Exception e) { log.error("Failed to delete ML model " + modelId, e); actionListener.onFailure(e); diff --git a/plugin/src/main/java/org/opensearch/ml/action/models/GetModelTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/models/GetModelTransportAction.java index 05c4deaa86..622ceec8f0 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/models/GetModelTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/models/GetModelTransportAction.java @@ -79,7 +79,8 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { + ActionListener wrappedListener = ActionListener.runBefore(actionListener, () -> context.restore()); + client.get(getRequest, ActionListener.wrap(r -> { if (r != null && r.isExists()) { try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) { ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); @@ -90,7 +91,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { if (!access) { - actionListener + wrappedListener .onFailure( new MLValidationException("User Doesn't have privilege to perform this operation on this model") ); @@ -100,19 +101,19 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { log.error("Failed to validate Access for Model Id " + modelId, e); - actionListener.onFailure(e); + wrappedListener.onFailure(e); })); } catch (Exception e) { log.error("Failed to parse ml model " + r.getId(), e); - actionListener.onFailure(e); + wrappedListener.onFailure(e); } } else { - actionListener + wrappedListener .onFailure( new OpenSearchStatusException( "Failed to find model with the provided model id: " + modelId, @@ -122,12 +123,12 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { if (e instanceof IndexNotFoundException) { - actionListener.onFailure(new MLResourceNotFoundException("Fail to find model")); + wrappedListener.onFailure(new MLResourceNotFoundException("Fail to find model")); } else { log.error("Failed to get ML model " + modelId, e); - actionListener.onFailure(e); + wrappedListener.onFailure(e); } - }), () -> context.restore())); + })); } catch (Exception e) { log.error("Failed to get ML model " + modelId, e); actionListener.onFailure(e); diff --git a/plugin/src/main/java/org/opensearch/ml/action/tasks/DeleteTaskTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/tasks/DeleteTaskTransportAction.java index 26b808faa9..a5aa023c79 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/tasks/DeleteTaskTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/tasks/DeleteTaskTransportAction.java @@ -57,6 +57,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener wrappedListener = ActionListener.runBefore(actionListener, () -> context.restore()); client.get(getRequest, ActionListener.wrap(r -> { if (r != null && r.isExists()) { @@ -72,24 +73,24 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { actionListener.onFailure(new MLResourceNotFoundException("Fail to find task")); })); + }, e -> { wrappedListener.onFailure(new MLResourceNotFoundException("Fail to find task")); })); } catch (Exception e) { log.error("Failed to delete ml task " + taskId, e); actionListener.onFailure(e); diff --git a/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java index 11e9c67d2c..8ae9dfaaa8 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java @@ -57,7 +57,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { + client.get(getRequest, ActionListener.runBefore(ActionListener.wrap(r -> { log.debug("Completed Get Task Request, id:{}", taskId); if (r != null && r.isExists()) { @@ -79,7 +79,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener context.restore())); } catch (Exception e) { log.error("Failed to get ML task " + taskId, e); actionListener.onFailure(e); diff --git a/plugin/src/main/java/org/opensearch/ml/action/tasks/SearchTaskTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/tasks/SearchTaskTransportAction.java index 166f776260..5f35fdf788 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/tasks/SearchTaskTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/tasks/SearchTaskTransportAction.java @@ -32,7 +32,7 @@ public SearchTaskTransportAction(TransportService transportService, ActionFilter @Override protected void doExecute(Task task, SearchRequest request, ActionListener actionListener) { try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - client.search(request, actionListener); + client.search(request, ActionListener.runBefore(actionListener, () -> context.restore())); } catch (Exception e) { log.error(e.getMessage(), e); actionListener.onFailure(e); diff --git a/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java b/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java index 103ea757e1..7801397ab1 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java @@ -116,12 +116,12 @@ private void validateAccess(String modelId, ActionListener listener) { User user = RestActionUtils.getUserContext(client); String[] excludes = new String[] { MLModel.MODEL_CONTENT_FIELD, MLModel.OLD_MODEL_CONTENT_FIELD }; try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - mlModelManager.getModel(modelId, null, excludes, ActionListener.wrap(mlModel -> { + mlModelManager.getModel(modelId, null, excludes, ActionListener.runBefore(ActionListener.wrap(mlModel -> { modelAccessControlHelper.validateModelGroupAccess(user, mlModel.getModelGroupId(), client, listener); }, e -> { log.error("Failed to find Model", e); listener.onFailure(e); - })); + }), () -> context.restore())); } catch (Exception e) { log.error("Failed to undeploy ML model"); listener.onFailure(e); diff --git a/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/MLModelChunkUploader.java b/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/MLModelChunkUploader.java index 77be5ef425..1b9c1868bf 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/MLModelChunkUploader.java +++ b/plugin/src/main/java/org/opensearch/ml/action/upload_chunk/MLModelChunkUploader.java @@ -68,6 +68,7 @@ public void uploadModelChunk(MLUploadModelChunkInput uploadModelChunkInput, Acti User user = RestActionUtils.getUserContext(client); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + ActionListener wrappedListener = ActionListener.runBefore(listener, () -> context.restore()); client.get(getRequest, ActionListener.wrap(r -> { if (r != null && r.isExists()) { try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) { @@ -82,7 +83,7 @@ public void uploadModelChunk(MLUploadModelChunkInput uploadModelChunkInput, Acti .validateModelGroupAccess(user, existingModel.getModelGroupId(), client, ActionListener.wrap(access -> { if (!access) { log.error("You don't have permissions to perform this operation on this model."); - listener + wrappedListener .onFailure( new IllegalArgumentException( "You don't have permissions to perform this operation on this model." @@ -167,36 +168,36 @@ public void uploadModelChunk(MLUploadModelChunkInput uploadModelChunkInput, Acti }, e -> { log.error("Failed to update model state", e); semaphore.release(); - listener.onFailure(e); + wrappedListener.onFailure(e); })); } - listener.onResponse(new MLUploadModelChunkResponse("Uploaded")); + wrappedListener.onResponse(new MLUploadModelChunkResponse("Uploaded")); }, e -> { log.error("Failed to upload chunk model", e); - listener.onFailure(e); + wrappedListener.onFailure(e); })); }, ex -> { log.error("Failed to init model index", ex); - listener.onFailure(ex); + wrappedListener.onFailure(ex); })); } }, e -> { logException("Failed to validate model access", e, log); - listener.onFailure(e); + wrappedListener.onFailure(e); })); } catch (Exception e) { log.error("Failed to parse ml model " + r.getId(), e); - listener.onFailure(e); + wrappedListener.onFailure(e); } } else { - listener.onFailure(new MLResourceNotFoundException("Failed to find model")); + wrappedListener.onFailure(new MLResourceNotFoundException("Failed to find model")); } }, e -> { if (e instanceof IndexNotFoundException) { - listener.onFailure(new MLResourceNotFoundException("Failed to find model")); + wrappedListener.onFailure(new MLResourceNotFoundException("Failed to find model")); } else { log.error("Failed to get ML model " + modelId, e); - listener.onFailure(e); + wrappedListener.onFailure(e); } })); } catch (Exception e) { diff --git a/plugin/src/test/java/org/opensearch/ml/action/tasks/SearchTaskTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/tasks/SearchTaskTransportActionTests.java index c9b6209ccd..6c616e6537 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/tasks/SearchTaskTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/tasks/SearchTaskTransportActionTests.java @@ -5,6 +5,8 @@ package org.opensearch.ml.action.tasks; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -61,6 +63,6 @@ public void setup() { public void test_DoExecute() { searchTaskTransportAction.doExecute(null, searchRequest, actionListener); - verify(client).search(searchRequest, actionListener); + verify(client).search(eq(searchRequest), any()); } }