diff --git a/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java index 698b97733c..a20f74c5db 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java @@ -169,7 +169,8 @@ protected void doExecute(Task task, ActionRequest request, ActionListener actionListener) { - // Always delete model chunks and model controller first, because deleting metadata first user is not able clean up model chunks and - // model controller. + DeleteRequest deleteRequest = new DeleteRequest(ML_MODEL_INDEX, modelId).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + client.delete(deleteRequest, new ActionListener<>() { + @Override + public void onResponse(DeleteResponse deleteResponse) { + deleteModelChunksAndController(functionName, actionListener, modelId, deleteResponse); + } + + @Override + public void onFailure(Exception e) { + if (e instanceof ResourceNotFoundException) { + deleteModelChunksAndController(functionName, actionListener, modelId, null); + } else { + log.error("Failed to delete model meta data for model, please try again: " + modelId, e); + actionListener.onFailure(e); + } + } + }); + } + + private void deleteModelChunksAndController(FunctionName functionName, ActionListener actionListener, String modelId, DeleteResponse deleteResponse) { if (FunctionName.REMOTE != functionName) { CountDownLatch countDownLatch = new CountDownLatch(2); AtomicBoolean bothDeleted = new AtomicBoolean(true); @@ -223,10 +242,14 @@ private void deleteModel(String modelId, FunctionName functionName, ActionListen if (bothDeleted.get()) { log .debug( - "model chunks and model controller for model {} deleted successfully, starting to delete model meta data", + "model chunks and model controller for model {} deleted successfully", modelId ); - deleteModelMetadata(modelId, actionListener); + if (deleteResponse != null) { + actionListener.onResponse(deleteResponse); + } else { + actionListener.onFailure(new OpenSearchStatusException("Failed to find model", RestStatus.NOT_FOUND)); + } } else { actionListener .onFailure( @@ -248,35 +271,23 @@ private void deleteModel(String modelId, FunctionName functionName, ActionListen deleteController(modelId, countDownActionListener); } else { ActionListener deleteControllerListener = ActionListener.wrap(b -> { - log.debug("model controller for model {} deleted successfully, starting to delete model meta data", modelId); - deleteModelMetadata(modelId, actionListener); + log.debug("model controller for model {} deleted successfully", modelId); + if (deleteResponse != null) { + actionListener.onResponse(deleteResponse); + } else { + actionListener.onFailure(new OpenSearchStatusException("Failed to find model", RestStatus.NOT_FOUND)); + } }, e -> { - log.error("Failed to delete model chunks or model controller, please try again: " + modelId, e); + log.error("Failed to delete model controller, please try again: " + modelId, e); actionListener .onFailure( - new IllegalStateException("Failed to delete model chunks or model controller, please try again: " + modelId, e) + new IllegalStateException("Failed to delete model controller, please try again: " + modelId, e) ); }); deleteController(modelId, deleteControllerListener); } } - private void deleteModelMetadata(String modelId, ActionListener actionListener) { - DeleteRequest deleteRequest = new DeleteRequest(ML_MODEL_INDEX, modelId).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - client.delete(deleteRequest, new ActionListener<>() { - @Override - public void onResponse(DeleteResponse deleteResponse) { - actionListener.onResponse(deleteResponse); - } - - @Override - public void onFailure(Exception e) { - log.error("Failed to delete model meta data for model, please try again: " + modelId, e); - actionListener.onFailure(e); - } - }); - } - /** * Delete the model controller for a model after the model is deleted from the * ML index. diff --git a/plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java index f931bf225e..8e2f108b60 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java @@ -11,6 +11,7 @@ import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.opensearch.ml.action.models.DeleteModelTransportAction.BULK_FAILURE_MSG; @@ -185,6 +186,10 @@ public void testDeleteRemoteModel_Success() throws IOException { public void testDeleteRemoteModel_deleteModelController_failed() throws IOException { doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(deleteResponse); + return null; + }).doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); listener.onFailure(new RuntimeException("runtime exception")); return null; @@ -204,6 +209,40 @@ public void testDeleteRemoteModel_deleteModelController_failed() throws IOExcept return null; }).when(client).get(any(), any()); + deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "Failed to delete model controller, please try again: test_id", + argumentCaptor.getValue().getMessage() + ); + } + + public void testDeleteLocalModel_deleteModelController_failed() throws IOException { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(deleteResponse); + return null; + }).doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException("runtime exception")); + return null; + }).when(client).delete(any(), any()); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + BulkByScrollResponse response = new BulkByScrollResponse(new ArrayList<>(), null); + listener.onResponse(response); + return null; + }).when(client).execute(any(), any(), any()); + + GetResponse getResponse = prepareModelWithFunction(MLModelState.REGISTERED, null, false, FunctionName.TEXT_EMBEDDING); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(getResponse); + return null; + }).when(client).get(any(), any()); + deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); @@ -371,8 +410,12 @@ public void testDeleteModel_ModelNotFoundException() throws IOException { assertEquals("Fail to find model", argumentCaptor.getValue().getMessage()); } - public void testDeleteModel_ResourceNotFoundException() throws IOException { + public void testDeleteModel_deleteModelController_ResourceNotFoundException() throws IOException { doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(deleteResponse); + return null; + }).doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); listener.onFailure(new ResourceNotFoundException("errorMessage")); return null; @@ -393,9 +436,8 @@ public void testDeleteModel_ResourceNotFoundException() throws IOException { }).when(client).get(any(), any()); deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(ResourceNotFoundException.class); - verify(actionListener).onFailure(argumentCaptor.capture()); - assertEquals("errorMessage", argumentCaptor.getValue().getMessage()); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(DeleteResponse.class); + verify(actionListener, times(1)).onResponse(argumentCaptor.capture()); } public void test_ValidationFailedException() throws IOException { @@ -418,56 +460,103 @@ public void test_ValidationFailedException() throws IOException { assertEquals("Failed to validate access", argumentCaptor.getValue().getMessage()); } - public void testModelNotFound() throws IOException { + public void testDeleteRemoteModel_modelNotFound_ResourceNotFoundException() throws IOException { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new ResourceNotFoundException("resource not found")); + return null; + }).doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(deleteResponse); + return null; + }).when(client).delete(any(), any()); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + BulkByScrollResponse response = new BulkByScrollResponse(new ArrayList<>(), null); + listener.onResponse(response); + return null; + }).when(client).execute(any(), any(), any()); + + GetResponse getResponse = prepareModelWithFunction(MLModelState.REGISTERED, null, false, FunctionName.REMOTE); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); - actionListener.onResponse(null); + actionListener.onResponse(getResponse); return null; }).when(client).get(any(), any()); + deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class); verify(actionListener).onFailure(argumentCaptor.capture()); - assertEquals("Failed to find model", argumentCaptor.getValue().getMessage()); + assert argumentCaptor.getValue().getMessage().equals("Failed to find model"); } - public void testDeleteModelChunks_Success() { - when(bulkByScrollResponse.getBulkFailures()).thenReturn(null); + public void testDeleteRemoteModel_modelNotFound_RuntimeException() throws IOException { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException("runtime exception")); + return null; + }).doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(deleteResponse); + return null; + }).when(client).delete(any(), any()); + doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); - listener.onResponse(bulkByScrollResponse); + BulkByScrollResponse response = new BulkByScrollResponse(new ArrayList<>(), null); + listener.onResponse(response); return null; }).when(client).execute(any(), any(), any()); - ActionListener deleteChunksListener = mock(ActionListener.class); - deleteModelTransportAction.deleteModelChunks("test_id", deleteChunksListener); - verify(deleteChunksListener).onResponse(true); - } - public void testDeleteModel_RuntimeException() throws IOException { - GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, null, false); + GetResponse getResponse = prepareModelWithFunction(MLModelState.REGISTERED, null, false, FunctionName.REMOTE); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); actionListener.onResponse(getResponse); return null; }).when(client).get(any(), any()); + deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(RuntimeException.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assert argumentCaptor.getValue().getMessage().equals("runtime exception"); + } + + public void testModelNotFound_modelChunks_modelController_delete_success() throws IOException { + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(null); + return null; + }).when(client).get(any(), any()); + doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); - listener.onFailure(new RuntimeException("errorMessage")); + listener.onResponse(deleteResponse); return null; }).when(client).delete(any(), any()); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(2); - listener.onResponse(mock(BulkByScrollResponse.class)); + BulkByScrollResponse response = new BulkByScrollResponse(new ArrayList<>(), null); + listener.onResponse(response); return null; - }).when(client).execute(eq(DeleteByQueryAction.INSTANCE), any(), any()); + }).when(client).execute(any(), any(), any()); deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class); verify(actionListener).onFailure(argumentCaptor.capture()); - assertEquals( - "Failed to delete model chunks or model controller, please try again: test_id", - argumentCaptor.getValue().getMessage() - ); + assertEquals("Failed to find model", argumentCaptor.getValue().getMessage()); + } + + public void testDeleteModelChunks_Success() { + when(bulkByScrollResponse.getBulkFailures()).thenReturn(null); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(bulkByScrollResponse); + return null; + }).when(client).execute(any(), any(), any()); + ActionListener deleteChunksListener = mock(ActionListener.class); + deleteModelTransportAction.deleteModelChunks("test_id", deleteChunksListener); + verify(deleteChunksListener).onResponse(true); } @Ignore