From f18eaf3a9cf04cad5519f2845594615373c9c13a Mon Sep 17 00:00:00 2001 From: zane-neo Date: Wed, 7 Feb 2024 23:17:19 +0800 Subject: [PATCH] Fix long pending issue when deleting model (#1882) * Fix long pending issue when deleting model Signed-off-by: zane-neo * Refine the delete model code Signed-off-by: zane-neo * refactor delete model flow to make sure all dependent resources are deleted together with model metadata Signed-off-by: zane-neo * fix minor issue to make sure only non-remote model will deelete chunks Signed-off-by: zane-neo * format code Signed-off-by: zane-neo * fix failure UTs Signed-off-by: zane-neo * Change to delete model metadata first Signed-off-by: zane-neo * format code Signed-off-by: zane-neo * Remove remote function check Signed-off-by: zane-neo * Fix failure UTs Signed-off-by: zane-neo --------- Signed-off-by: zane-neo --- .../models/DeleteModelTransportAction.java | 102 +++---- .../TransportUndeployModelAction.java | 3 +- .../DeleteModelTransportActionTests.java | 262 +++++++++++++++--- 3 files changed, 287 insertions(+), 80 deletions(-) diff --git a/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java index 307ddc97b0..464faecdeb 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java @@ -14,16 +14,19 @@ import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry; import static org.opensearch.ml.utils.RestActionUtils.getFetchSourceContext; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicBoolean; + import org.opensearch.OpenSearchStatusException; import org.opensearch.ResourceNotFoundException; import org.opensearch.action.ActionRequest; -import org.opensearch.action.DocWriteResponse; import org.opensearch.action.delete.DeleteRequest; import org.opensearch.action.delete.DeleteResponse; import org.opensearch.action.get.GetRequest; import org.opensearch.action.get.GetResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.action.support.WriteRequest; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; @@ -34,7 +37,6 @@ import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.query.TermsQueryBuilder; import org.opensearch.index.reindex.BulkByScrollResponse; import org.opensearch.index.reindex.DeleteByQueryAction; @@ -167,7 +169,9 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { wrappedListener.onFailure((new OpenSearchStatusException("Failed to find model", RestStatus.NOT_FOUND))); })); } catch (Exception e) { @@ -177,7 +181,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener actionListener) { + void deleteModelChunks(String modelId, ActionListener actionListener) { DeleteByQueryRequest deleteModelsRequest = new DeleteByQueryRequest(ML_MODEL_INDEX); deleteModelsRequest.setQuery(new TermsQueryBuilder(MODEL_ID_FIELD, modelId)); @@ -185,23 +189,17 @@ void deleteModelChunks(String modelId, DeleteResponse deleteResponse, ActionList if ((r.getBulkFailures() == null || r.getBulkFailures().size() == 0) && (r.getSearchFailures() == null || r.getSearchFailures().size() == 0)) { log.debug("All model chunks are deleted for model {}", modelId); - if (deleteResponse != null) { - // If model metaData not found and deleteResponse is null, do not return here. - // ResourceNotFound is returned to notify that this model was deleted. - // This is a walk around to avoid cleaning up model leftovers. Will revisit if - // necessary. - actionListener.onResponse(deleteResponse); - } + actionListener.onResponse(true); } else { returnFailure(r, modelId, actionListener); } }, e -> { - log.error("Failed to delete ML model for " + modelId, e); + log.error("Failed to delete model chunks for: " + modelId, e); actionListener.onFailure(e); })); } - private void returnFailure(BulkByScrollResponse response, String modelId, ActionListener actionListener) { + private void returnFailure(BulkByScrollResponse response, String modelId, ActionListener actionListener) { String errorMessage = ""; if (response.isTimedOut()) { errorMessage = OS_STATUS_EXCEPTION_MESSAGE + ", " + TIMEOUT_MSG + modelId; @@ -215,24 +213,56 @@ private void returnFailure(BulkByScrollResponse response, String modelId, Action } private void deleteModel(String modelId, ActionListener actionListener) { - DeleteRequest deleteRequest = new DeleteRequest(ML_MODEL_INDEX, modelId); - client.delete(deleteRequest, new ActionListener() { + DeleteRequest deleteRequest = new DeleteRequest(ML_MODEL_INDEX, modelId).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + client.delete(deleteRequest, new ActionListener<>() { @Override public void onResponse(DeleteResponse deleteResponse) { - deleteModelChunks(modelId, deleteResponse, actionListener); - deleteController(modelId); + deleteModelChunksAndController(actionListener, modelId, deleteResponse); } @Override public void onFailure(Exception e) { - log.error("Failed to delete model meta data for model: " + modelId, e); if (e instanceof ResourceNotFoundException) { - deleteModelChunks(modelId, null, actionListener); - deleteController(modelId); + deleteModelChunksAndController(actionListener, modelId, null); + } else { + log.error("Model is not all cleaned up, please try again: " + modelId, e); + actionListener.onFailure(e); + } + } + }); + } + + private void deleteModelChunksAndController( + ActionListener actionListener, + String modelId, + DeleteResponse deleteResponse + ) { + CountDownLatch countDownLatch = new CountDownLatch(2); + AtomicBoolean bothDeleted = new AtomicBoolean(true); + ActionListener countDownActionListener = ActionListener.wrap(b -> { + countDownLatch.countDown(); + bothDeleted.compareAndSet(true, b); + if (countDownLatch.getCount() == 0) { + if (bothDeleted.get()) { + log.debug("model chunks and model controller for model {} deleted successfully", modelId); + if (deleteResponse != null) { + actionListener.onResponse(deleteResponse); + } else { + actionListener.onFailure(new OpenSearchStatusException("Failed to find model", RestStatus.NOT_FOUND)); + } + } else { + actionListener.onFailure(new IllegalStateException("Model is not all cleaned up, please try again: " + modelId)); } - actionListener.onFailure(e); + } + }, e -> { + countDownLatch.countDown(); + bothDeleted.compareAndSet(true, false); + if (countDownLatch.getCount() == 0) { + actionListener.onFailure(new IllegalStateException("Model is not all cleaned up, please try again: " + modelId, e)); } }); + deleteModelChunks(modelId, countDownActionListener); + deleteController(modelId, countDownActionListener); } /** @@ -241,20 +271,20 @@ public void onFailure(Exception e) { * * @param modelId model ID */ - private void deleteController(String modelId, ActionListener actionListener) { + private void deleteController(String modelId, ActionListener actionListener) { DeleteRequest deleteRequest = new DeleteRequest(ML_CONTROLLER_INDEX, modelId); client.delete(deleteRequest, new ActionListener<>() { @Override public void onResponse(DeleteResponse deleteResponse) { log.info("Model controller for model {} successfully deleted from index, result: {}", modelId, deleteResponse.getResult()); - actionListener.onResponse(deleteResponse); + actionListener.onResponse(true); } @Override public void onFailure(Exception e) { - if (e instanceof IndexNotFoundException) { + if (e instanceof ResourceNotFoundException) { log.info("Model controller not deleted due to no model controller found for model: " + modelId); - actionListener.onFailure(e); + actionListener.onResponse(true); // we consider this as success } else { log.error("Failed to delete model controller for model: " + modelId, e); actionListener.onFailure(e); @@ -263,28 +293,6 @@ public void onFailure(Exception e) { }); } - /** - * Delete the model controller for a model after the model is deleted from the - * ML index with build-in listener. - * - * @param modelId model ID - */ - private void deleteController(String modelId) { - deleteController(modelId, ActionListener.wrap(deleteResponse -> { - if (deleteResponse.getResult() == DocWriteResponse.Result.DELETED) { - log.info("Model controller for model {} successfully deleted from index, result: {}", modelId, deleteResponse.getResult()); - } else { - log.info("The deletion of model controller for model {} returned with result: {}", modelId, deleteResponse.getResult()); - } - }, e -> { - if (e instanceof IndexNotFoundException) { - log.debug("Model controller not deleted due to no model controller found for model: " + modelId); - } else { - log.error("Failed to delete model controller for model: " + modelId, e); - } - })); - } - private Boolean isModelNotDeployed(MLModelState mlModelState) { return !mlModelState.equals(MLModelState.LOADED) && !mlModelState.equals(MLModelState.LOADING) diff --git a/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelAction.java index b8097c72e5..662971b2c7 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelAction.java @@ -20,6 +20,7 @@ import org.opensearch.action.bulk.BulkRequest; import org.opensearch.action.bulk.BulkResponse; import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.WriteRequest; import org.opensearch.action.support.nodes.TransportNodesAction; import org.opensearch.action.update.UpdateRequest; import org.opensearch.client.Client; @@ -174,7 +175,7 @@ protected MLUndeployModelNodesResponse newResponse( deployToAllNodes.put(modelId, false); } updateRequest.index(ML_MODEL_INDEX).id(modelId).doc(updateDocument); - bulkRequest.add(updateRequest); + bulkRequest.add(updateRequest).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); } syncUpInput.setDeployToAllNodes(deployToAllNodes); ActionListener actionListener = ActionListener.wrap(r -> { diff --git a/plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java index e9c4e23cfd..bfc7f628ff 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java @@ -6,9 +6,12 @@ package org.opensearch.ml.action.models; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.opensearch.ml.action.models.DeleteModelTransportAction.BULK_FAILURE_MSG; @@ -46,7 +49,9 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.index.get.GetResult; import org.opensearch.index.reindex.BulkByScrollResponse; +import org.opensearch.index.reindex.DeleteByQueryAction; import org.opensearch.index.reindex.ScrollableHitSource; +import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.model.MLModelState; import org.opensearch.ml.common.transport.model.MLModelDeleteRequest; @@ -154,6 +159,119 @@ public void testDeleteModel_Success() throws IOException { verify(actionListener).onResponse(deleteResponse); } + public void testDeleteRemoteModel_Success() throws IOException { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(deleteResponse); + return null; + }).when(client).delete(any(), any()); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + BulkByScrollResponse response = new BulkByScrollResponse(new ArrayList<>(), null); + listener.onResponse(response); + return null; + }).when(client).execute(any(), any(), any()); + + GetResponse getResponse = prepareModelWithFunction(MLModelState.REGISTERED, null, false, FunctionName.REMOTE); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(getResponse); + return null; + }).when(client).get(any(), any()); + + deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener); + verify(actionListener).onResponse(deleteResponse); + } + + public void testDeleteRemoteModel_deleteModelController_failed() throws IOException { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(deleteResponse); + return null; + }).doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException("runtime exception")); + return null; + }).when(client).delete(any(), any()); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + BulkByScrollResponse response = new BulkByScrollResponse(new ArrayList<>(), null); + listener.onResponse(response); + return null; + }).when(client).execute(any(), any(), any()); + + GetResponse getResponse = prepareModelWithFunction(MLModelState.REGISTERED, null, false, FunctionName.REMOTE); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(getResponse); + return null; + }).when(client).get(any(), any()); + + deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Model is not all cleaned up, please try again: test_id", argumentCaptor.getValue().getMessage()); + } + + public void testDeleteLocalModel_deleteModelController_failed() throws IOException { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(deleteResponse); + return null; + }).doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException("runtime exception")); + return null; + }).when(client).delete(any(), any()); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + BulkByScrollResponse response = new BulkByScrollResponse(new ArrayList<>(), null); + listener.onResponse(response); + return null; + }).when(client).execute(any(), any(), any()); + + GetResponse getResponse = prepareModelWithFunction(MLModelState.REGISTERED, null, false, FunctionName.TEXT_EMBEDDING); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(getResponse); + return null; + }).when(client).get(any(), any()); + + deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Model is not all cleaned up, please try again: test_id", argumentCaptor.getValue().getMessage()); + } + + public void testDeleteRemoteModel_deleteModelChunks_failed() throws IOException { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(deleteResponse); + return null; + }).when(client).delete(any(), any()); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onFailure(new RuntimeException("runtime exception")); + return null; + }).when(client).execute(eq(DeleteByQueryAction.INSTANCE), any(), any()); + + GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, null, false); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(getResponse); + return null; + }).when(client).get(any(), any()); + + deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Model is not all cleaned up, please try again: test_id", argumentCaptor.getValue().getMessage()); + } + public void testDeleteHiddenModel_Success() throws IOException { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); @@ -283,8 +401,12 @@ public void testDeleteModel_ModelNotFoundException() throws IOException { assertEquals("Failed to find model", argumentCaptor.getValue().getMessage()); } - public void testDeleteModel_ResourceNotFoundException() throws IOException { + public void testDeleteModel_deleteModelController_ResourceNotFoundException() throws IOException { doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(deleteResponse); + return null; + }).doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); listener.onFailure(new ResourceNotFoundException("errorMessage")); return null; @@ -305,9 +427,8 @@ public void testDeleteModel_ResourceNotFoundException() throws IOException { }).when(client).get(any(), any()); deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(ResourceNotFoundException.class); - verify(actionListener).onFailure(argumentCaptor.capture()); - assertEquals("errorMessage", argumentCaptor.getValue().getMessage()); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(DeleteResponse.class); + verify(actionListener, times(1)).onResponse(argumentCaptor.capture()); } public void test_ValidationFailedException() throws IOException { @@ -330,48 +451,103 @@ public void test_ValidationFailedException() throws IOException { assertEquals("Failed to validate access", argumentCaptor.getValue().getMessage()); } - public void testModelNotFound() throws IOException { + public void testDeleteRemoteModel_modelNotFound_ResourceNotFoundException() throws IOException { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new ResourceNotFoundException("resource not found")); + return null; + }).doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(deleteResponse); + return null; + }).when(client).delete(any(), any()); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + BulkByScrollResponse response = new BulkByScrollResponse(new ArrayList<>(), null); + listener.onResponse(response); + return null; + }).when(client).execute(any(), any(), any()); + + GetResponse getResponse = prepareModelWithFunction(MLModelState.REGISTERED, null, false, FunctionName.REMOTE); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); - actionListener.onResponse(null); + actionListener.onResponse(getResponse); return null; }).when(client).get(any(), any()); + deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class); verify(actionListener).onFailure(argumentCaptor.capture()); - assertEquals("Failed to find model", argumentCaptor.getValue().getMessage()); + assert argumentCaptor.getValue().getMessage().equals("Failed to find model"); } - public void testDeleteModelChunks_Success() { - when(bulkByScrollResponse.getBulkFailures()).thenReturn(null); + public void testDeleteRemoteModel_modelNotFound_RuntimeException() throws IOException { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException("runtime exception")); + return null; + }).doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(deleteResponse); + return null; + }).when(client).delete(any(), any()); + doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); - listener.onResponse(bulkByScrollResponse); + BulkByScrollResponse response = new BulkByScrollResponse(new ArrayList<>(), null); + listener.onResponse(response); return null; }).when(client).execute(any(), any(), any()); - deleteModelTransportAction.deleteModelChunks("test_id", deleteResponse, actionListener); - verify(actionListener).onResponse(deleteResponse); + GetResponse getResponse = prepareModelWithFunction(MLModelState.REGISTERED, null, false, FunctionName.REMOTE); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(getResponse); + return null; + }).when(client).get(any(), any()); + + deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(RuntimeException.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assert argumentCaptor.getValue().getMessage().equals("runtime exception"); } - public void testDeleteModel_RuntimeException() throws IOException { - GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, null, false); + public void testModelNotFound_modelChunks_modelController_delete_success() throws IOException { doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); - actionListener.onResponse(getResponse); + actionListener.onResponse(null); return null; }).when(client).get(any(), any()); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); - listener.onFailure(new RuntimeException("errorMessage")); + listener.onResponse(deleteResponse); return null; }).when(client).delete(any(), any()); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + BulkByScrollResponse response = new BulkByScrollResponse(new ArrayList<>(), null); + listener.onResponse(response); + return null; + }).when(client).execute(any(), any(), any()); deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class); verify(actionListener).onFailure(argumentCaptor.capture()); - assertEquals("errorMessage", argumentCaptor.getValue().getMessage()); + assertEquals("Failed to find model", argumentCaptor.getValue().getMessage()); + } + + public void testDeleteModelChunks_Success() { + when(bulkByScrollResponse.getBulkFailures()).thenReturn(null); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(bulkByScrollResponse); + return null; + }).when(client).execute(any(), any(), any()); + ActionListener deleteChunksListener = mock(ActionListener.class); + deleteModelTransportAction.deleteModelChunks("test_id", deleteChunksListener); + verify(deleteChunksListener).onResponse(true); } @Ignore @@ -389,10 +565,10 @@ public void test_FailToDeleteModel() { listener.onFailure(new RuntimeException("errorMessage")); return null; }).when(client).execute(any(), any(), any()); - - deleteModelTransportAction.deleteModelChunks("test_id", deleteResponse, actionListener); + ActionListener deleteChunksListener = mock(ActionListener.class); + deleteModelTransportAction.deleteModelChunks("test_id", deleteChunksListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); - verify(actionListener).onFailure(argumentCaptor.capture()); + verify(deleteChunksListener).onFailure(argumentCaptor.capture()); assertEquals("errorMessage", argumentCaptor.getValue().getMessage()); } @@ -404,10 +580,10 @@ public void test_FailToDeleteAllModelChunks() { listener.onResponse(bulkByScrollResponse); return null; }).when(client).execute(any(), any(), any()); - - deleteModelTransportAction.deleteModelChunks("test_id", deleteResponse, actionListener); + ActionListener deleteChunksListener = mock(ActionListener.class); + deleteModelTransportAction.deleteModelChunks("test_id", deleteChunksListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); - verify(actionListener).onFailure(argumentCaptor.capture()); + verify(deleteChunksListener).onFailure(argumentCaptor.capture()); assertEquals(OS_STATUS_EXCEPTION_MESSAGE + ", " + BULK_FAILURE_MSG + "test_id", argumentCaptor.getValue().getMessage()); } @@ -420,10 +596,10 @@ public void test_FailToDeleteAllModelChunks_TimeOut() { listener.onResponse(bulkByScrollResponse); return null; }).when(client).execute(any(), any(), any()); - - deleteModelTransportAction.deleteModelChunks("test_id", deleteResponse, actionListener); + ActionListener deleteChunksListener = mock(ActionListener.class); + deleteModelTransportAction.deleteModelChunks("test_id", deleteChunksListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); - verify(actionListener).onFailure(argumentCaptor.capture()); + verify(deleteChunksListener).onFailure(argumentCaptor.capture()); assertEquals(OS_STATUS_EXCEPTION_MESSAGE + ", " + TIMEOUT_MSG + "test_id", argumentCaptor.getValue().getMessage()); } @@ -442,16 +618,38 @@ public void test_FailToDeleteAllModelChunks_SearchFailure() { listener.onResponse(bulkByScrollResponse); return null; }).when(client).execute(any(), any(), any()); - - deleteModelTransportAction.deleteModelChunks("test_id", deleteResponse, actionListener); + ActionListener deleteChunksListener = mock(ActionListener.class); + deleteModelTransportAction.deleteModelChunks("test_id", deleteChunksListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); - verify(actionListener).onFailure(argumentCaptor.capture()); + verify(deleteChunksListener).onFailure(argumentCaptor.capture()); assertEquals(OS_STATUS_EXCEPTION_MESSAGE + ", " + SEARCH_FAILURE_MSG + "test_id", argumentCaptor.getValue().getMessage()); } public GetResponse prepareMLModel(MLModelState mlModelState, String modelGroupID, boolean isHidden) throws IOException { - MLModel mlModel; - mlModel = MLModel.builder().modelId("test_id").modelState(mlModelState).modelGroupId(modelGroupID).isHidden(isHidden).build(); + MLModel mlModel = MLModel + .builder() + .modelId("test_id") + .modelState(mlModelState) + .modelGroupId(modelGroupID) + .isHidden(isHidden) + .build(); + return buildResponse(mlModel); + } + + public GetResponse prepareModelWithFunction(MLModelState mlModelState, String modelGroupID, boolean isHidden, FunctionName functionName) + throws IOException { + MLModel mlModel = MLModel + .builder() + .modelId("test_id") + .algorithm(functionName) + .modelState(mlModelState) + .modelGroupId(modelGroupID) + .isHidden(isHidden) + .build(); + return buildResponse(mlModel); + } + + private GetResponse buildResponse(MLModel mlModel) throws IOException { XContentBuilder content = mlModel.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); BytesReference bytesReference = BytesReference.bytes(content); GetResult getResult = new GetResult("indexName", "111", 111l, 111l, 111l, true, bytesReference, null, null);