diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/DeleteConnectorTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/DeleteConnectorTransportAction.java index 68c97ca28b..22b8e8be5b 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/connector/DeleteConnectorTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/DeleteConnectorTransportAction.java @@ -119,6 +119,7 @@ private void handleConnectorAccessValidationFailure(String connectorId, Exceptio private void checkForModelsUsingConnector(String connectorId, String tenantId, ActionListener actionListener) { try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + ActionListener restoringListener = ActionListener.runBefore(actionListener, context::restore); SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); sourceBuilder.query(QueryBuilders.matchQuery(MLModel.CONNECTOR_ID_FIELD, connectorId)); if (mlFeatureEnabledSetting.isMultiTenancyEnabled()) { @@ -133,26 +134,25 @@ private void checkForModelsUsingConnector(String connectorId, String tenantId, A sdkClient .searchDataObjectAsync(searchDataObjectRequest, client.threadPool().executor(GENERAL_THREAD_POOL)) .whenComplete((sr, st) -> { - context.restore(); if (sr != null) { try { SearchResponse searchResponse = SearchResponse.fromXContent(sr.parser()); SearchHit[] searchHits = searchResponse.getHits().getHits(); if (searchHits.length == 0) { - deleteConnector(connectorId, actionListener); + deleteConnector(connectorId, restoringListener); } else { - handleModelsUsingConnector(searchHits, connectorId, actionListener); + handleModelsUsingConnector(searchHits, connectorId, restoringListener); } } catch (Exception e) { log.error("Failed to parse search response", e); - actionListener + restoringListener .onFailure( new OpenSearchStatusException("Failed to parse search response", RestStatus.INTERNAL_SERVER_ERROR) ); } } else { Exception cause = SdkClientUtils.unwrapAndConvertToException(st); - handleSearchFailure(connectorId, cause, actionListener); + handleSearchFailure(connectorId, cause, restoringListener); } }); } catch (Exception e) { diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/UpdateConnectorTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/UpdateConnectorTransportAction.java index 2eef299c7a..362761393a 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/connector/UpdateConnectorTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/UpdateConnectorTransportAction.java @@ -115,6 +115,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { + // context is already restored here if (TenantAwareHelper .validateTenantResource( mlFeatureEnabledSetting, @@ -123,7 +124,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener listener, - ThreadContext.StoredContext context + ActionListener listener ) { SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery(); @@ -180,11 +186,7 @@ private void updateUndeployedConnector( sdkClient .updateDataObjectAsync(updateDataObjectRequest, client.threadPool().executor(GENERAL_THREAD_POOL)) .whenComplete((r, throwable) -> { - handleUpdateDataObjectCompletionStage( - r, - throwable, - getUpdateResponseListener(connectorId, listener, context) - ); + handleUpdateDataObjectCompletionStage(r, throwable, getUpdateResponseListener(connectorId, listener)); }); } else { log.error(searchHits.length + " models are still using this connector, please undeploy the models first!"); @@ -214,11 +216,7 @@ private void updateUndeployedConnector( sdkClient .updateDataObjectAsync(updateDataObjectRequest, client.threadPool().executor(GENERAL_THREAD_POOL)) .whenComplete((r, throwable) -> { - handleUpdateDataObjectCompletionStage( - r, - throwable, - getUpdateResponseListener(connectorId, listener, context) - ); + handleUpdateDataObjectCompletionStage(r, throwable, getUpdateResponseListener(connectorId, listener)); }); return; } else { @@ -246,12 +244,8 @@ private void handleUpdateDataObjectCompletionStage( } } - private ActionListener getUpdateResponseListener( - String connectorId, - ActionListener actionListener, - ThreadContext.StoredContext context - ) { - return ActionListener.runBefore(ActionListener.wrap(updateResponse -> { + private ActionListener getUpdateResponseListener(String connectorId, ActionListener actionListener) { + return ActionListener.wrap(updateResponse -> { if (updateResponse != null && updateResponse.getResult() != DocWriteResponse.Result.UPDATED) { log.error("Failed to update the connector with ID: {}", connectorId); actionListener.onResponse(updateResponse); @@ -262,6 +256,6 @@ private ActionListener getUpdateResponseListener( }, exception -> { log.error("Failed to update ML connector with ID {}. Details: {}", connectorId, exception); actionListener.onFailure(exception); - }), context::restore); + }); } } diff --git a/plugin/src/main/java/org/opensearch/ml/helper/ConnectorAccessControlHelper.java b/plugin/src/main/java/org/opensearch/ml/helper/ConnectorAccessControlHelper.java index ff650e3c41..91eb3d1d74 100644 --- a/plugin/src/main/java/org/opensearch/ml/helper/ConnectorAccessControlHelper.java +++ b/plugin/src/main/java/org/opensearch/ml/helper/ConnectorAccessControlHelper.java @@ -162,6 +162,15 @@ public void getConnector(Client client, String connectorId, ActionListener