Skip to content

Commit

Permalink
Change to delete model metadata first
Browse files Browse the repository at this point in the history
Signed-off-by: zane-neo <[email protected]>
  • Loading branch information
zane-neo committed Feb 7, 2024
1 parent a5687bd commit a60ef31
Show file tree
Hide file tree
Showing 2 changed files with 149 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,8 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Delete
wrappedListener.onFailure(e);
}
} else {
wrappedListener.onFailure(new OpenSearchStatusException("Failed to find model", RestStatus.NOT_FOUND));
// when model metadata is not found, model chunk and controller might still there, delete them here and return success response
deleteModelChunksAndController(null, wrappedListener, modelId, null);
}
}, wrappedListener::onFailure));
} catch (Exception e) {
Expand Down Expand Up @@ -211,8 +212,26 @@ private void returnFailure(BulkByScrollResponse response, String modelId, Action
}

private void deleteModel(String modelId, FunctionName functionName, ActionListener<DeleteResponse> actionListener) {
// Always delete model chunks and model controller first, because deleting metadata first user is not able clean up model chunks and
// model controller.
DeleteRequest deleteRequest = new DeleteRequest(ML_MODEL_INDEX, modelId).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
client.delete(deleteRequest, new ActionListener<>() {
@Override
public void onResponse(DeleteResponse deleteResponse) {
deleteModelChunksAndController(functionName, actionListener, modelId, deleteResponse);
}

@Override
public void onFailure(Exception e) {
if (e instanceof ResourceNotFoundException) {
deleteModelChunksAndController(functionName, actionListener, modelId, null);
} else {
log.error("Failed to delete model meta data for model, please try again: " + modelId, e);
actionListener.onFailure(e);
}
}
});
}

private void deleteModelChunksAndController(FunctionName functionName, ActionListener<DeleteResponse> actionListener, String modelId, DeleteResponse deleteResponse) {
if (FunctionName.REMOTE != functionName) {
CountDownLatch countDownLatch = new CountDownLatch(2);
AtomicBoolean bothDeleted = new AtomicBoolean(true);
Expand All @@ -223,10 +242,14 @@ private void deleteModel(String modelId, FunctionName functionName, ActionListen
if (bothDeleted.get()) {
log
.debug(
"model chunks and model controller for model {} deleted successfully, starting to delete model meta data",
"model chunks and model controller for model {} deleted successfully",
modelId
);
deleteModelMetadata(modelId, actionListener);
if (deleteResponse != null) {
actionListener.onResponse(deleteResponse);
} else {
actionListener.onFailure(new OpenSearchStatusException("Failed to find model", RestStatus.NOT_FOUND));
}
} else {
actionListener
.onFailure(
Expand All @@ -248,35 +271,23 @@ private void deleteModel(String modelId, FunctionName functionName, ActionListen
deleteController(modelId, countDownActionListener);
} else {
ActionListener<Boolean> deleteControllerListener = ActionListener.wrap(b -> {
log.debug("model controller for model {} deleted successfully, starting to delete model meta data", modelId);
deleteModelMetadata(modelId, actionListener);
log.debug("model controller for model {} deleted successfully", modelId);
if (deleteResponse != null) {
actionListener.onResponse(deleteResponse);
} else {
actionListener.onFailure(new OpenSearchStatusException("Failed to find model", RestStatus.NOT_FOUND));
}
}, e -> {
log.error("Failed to delete model chunks or model controller, please try again: " + modelId, e);
log.error("Failed to delete model controller, please try again: " + modelId, e);
actionListener
.onFailure(
new IllegalStateException("Failed to delete model chunks or model controller, please try again: " + modelId, e)
new IllegalStateException("Failed to delete model controller, please try again: " + modelId, e)
);
});
deleteController(modelId, deleteControllerListener);
}
}

private void deleteModelMetadata(String modelId, ActionListener<DeleteResponse> actionListener) {
DeleteRequest deleteRequest = new DeleteRequest(ML_MODEL_INDEX, modelId).setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
client.delete(deleteRequest, new ActionListener<>() {
@Override
public void onResponse(DeleteResponse deleteResponse) {
actionListener.onResponse(deleteResponse);
}

@Override
public void onFailure(Exception e) {
log.error("Failed to delete model meta data for model, please try again: " + modelId, e);
actionListener.onFailure(e);
}
});
}

/**
* Delete the model controller for a model after the model is deleted from the
* ML index.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensearch.ml.action.models.DeleteModelTransportAction.BULK_FAILURE_MSG;
Expand Down Expand Up @@ -185,6 +186,10 @@ public void testDeleteRemoteModel_Success() throws IOException {

public void testDeleteRemoteModel_deleteModelController_failed() throws IOException {
doAnswer(invocation -> {
ActionListener<DeleteResponse> listener = invocation.getArgument(1);
listener.onResponse(deleteResponse);
return null;
}).doAnswer(invocation -> {
ActionListener<DeleteResponse> listener = invocation.getArgument(1);
listener.onFailure(new RuntimeException("runtime exception"));
return null;
Expand All @@ -204,6 +209,40 @@ public void testDeleteRemoteModel_deleteModelController_failed() throws IOExcept
return null;
}).when(client).get(any(), any());

deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals(
"Failed to delete model controller, please try again: test_id",
argumentCaptor.getValue().getMessage()
);
}

public void testDeleteLocalModel_deleteModelController_failed() throws IOException {
doAnswer(invocation -> {
ActionListener<DeleteResponse> listener = invocation.getArgument(1);
listener.onResponse(deleteResponse);
return null;
}).doAnswer(invocation -> {
ActionListener<DeleteResponse> listener = invocation.getArgument(1);
listener.onFailure(new RuntimeException("runtime exception"));
return null;
}).when(client).delete(any(), any());

doAnswer(invocation -> {
ActionListener<BulkByScrollResponse> listener = invocation.getArgument(2);
BulkByScrollResponse response = new BulkByScrollResponse(new ArrayList<>(), null);
listener.onResponse(response);
return null;
}).when(client).execute(any(), any(), any());

GetResponse getResponse = prepareModelWithFunction(MLModelState.REGISTERED, null, false, FunctionName.TEXT_EMBEDDING);
doAnswer(invocation -> {
ActionListener<GetResponse> actionListener = invocation.getArgument(1);
actionListener.onResponse(getResponse);
return null;
}).when(client).get(any(), any());

deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argumentCaptor.capture());
Expand Down Expand Up @@ -371,8 +410,12 @@ public void testDeleteModel_ModelNotFoundException() throws IOException {
assertEquals("Fail to find model", argumentCaptor.getValue().getMessage());
}

public void testDeleteModel_ResourceNotFoundException() throws IOException {
public void testDeleteModel_deleteModelController_ResourceNotFoundException() throws IOException {
doAnswer(invocation -> {
ActionListener<DeleteResponse> listener = invocation.getArgument(1);
listener.onResponse(deleteResponse);
return null;
}).doAnswer(invocation -> {
ActionListener<DeleteResponse> listener = invocation.getArgument(1);
listener.onFailure(new ResourceNotFoundException("errorMessage"));
return null;
Expand All @@ -393,9 +436,8 @@ public void testDeleteModel_ResourceNotFoundException() throws IOException {
}).when(client).get(any(), any());

deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener);
ArgumentCaptor<ResourceNotFoundException> argumentCaptor = ArgumentCaptor.forClass(ResourceNotFoundException.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals("errorMessage", argumentCaptor.getValue().getMessage());
ArgumentCaptor<DeleteResponse> argumentCaptor = ArgumentCaptor.forClass(DeleteResponse.class);
verify(actionListener, times(1)).onResponse(argumentCaptor.capture());
}

public void test_ValidationFailedException() throws IOException {
Expand All @@ -418,56 +460,103 @@ public void test_ValidationFailedException() throws IOException {
assertEquals("Failed to validate access", argumentCaptor.getValue().getMessage());
}

public void testModelNotFound() throws IOException {
public void testDeleteRemoteModel_modelNotFound_ResourceNotFoundException() throws IOException {
doAnswer(invocation -> {
ActionListener<DeleteResponse> listener = invocation.getArgument(1);
listener.onFailure(new ResourceNotFoundException("resource not found"));
return null;
}).doAnswer(invocation -> {
ActionListener<DeleteResponse> listener = invocation.getArgument(1);
listener.onResponse(deleteResponse);
return null;
}).when(client).delete(any(), any());

doAnswer(invocation -> {
ActionListener<BulkByScrollResponse> listener = invocation.getArgument(2);
BulkByScrollResponse response = new BulkByScrollResponse(new ArrayList<>(), null);
listener.onResponse(response);
return null;
}).when(client).execute(any(), any(), any());

GetResponse getResponse = prepareModelWithFunction(MLModelState.REGISTERED, null, false, FunctionName.REMOTE);
doAnswer(invocation -> {
ActionListener<GetResponse> actionListener = invocation.getArgument(1);
actionListener.onResponse(null);
actionListener.onResponse(getResponse);
return null;
}).when(client).get(any(), any());

deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener);
ArgumentCaptor<OpenSearchStatusException> argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals("Failed to find model", argumentCaptor.getValue().getMessage());
assert argumentCaptor.getValue().getMessage().equals("Failed to find model");
}

public void testDeleteModelChunks_Success() {
when(bulkByScrollResponse.getBulkFailures()).thenReturn(null);
public void testDeleteRemoteModel_modelNotFound_RuntimeException() throws IOException {
doAnswer(invocation -> {
ActionListener<DeleteResponse> listener = invocation.getArgument(1);
listener.onFailure(new RuntimeException("runtime exception"));
return null;
}).doAnswer(invocation -> {
ActionListener<DeleteResponse> listener = invocation.getArgument(1);
listener.onResponse(deleteResponse);
return null;
}).when(client).delete(any(), any());

doAnswer(invocation -> {
ActionListener<BulkByScrollResponse> listener = invocation.getArgument(2);
listener.onResponse(bulkByScrollResponse);
BulkByScrollResponse response = new BulkByScrollResponse(new ArrayList<>(), null);
listener.onResponse(response);
return null;
}).when(client).execute(any(), any(), any());
ActionListener<Boolean> deleteChunksListener = mock(ActionListener.class);
deleteModelTransportAction.deleteModelChunks("test_id", deleteChunksListener);
verify(deleteChunksListener).onResponse(true);
}

public void testDeleteModel_RuntimeException() throws IOException {
GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, null, false);
GetResponse getResponse = prepareModelWithFunction(MLModelState.REGISTERED, null, false, FunctionName.REMOTE);
doAnswer(invocation -> {
ActionListener<GetResponse> actionListener = invocation.getArgument(1);
actionListener.onResponse(getResponse);
return null;
}).when(client).get(any(), any());

deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener);
ArgumentCaptor<RuntimeException> argumentCaptor = ArgumentCaptor.forClass(RuntimeException.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assert argumentCaptor.getValue().getMessage().equals("runtime exception");
}

public void testModelNotFound_modelChunks_modelController_delete_success() throws IOException {
doAnswer(invocation -> {
ActionListener<GetResponse> actionListener = invocation.getArgument(1);
actionListener.onResponse(null);
return null;
}).when(client).get(any(), any());

doAnswer(invocation -> {
ActionListener<DeleteResponse> listener = invocation.getArgument(1);
listener.onFailure(new RuntimeException("errorMessage"));
listener.onResponse(deleteResponse);
return null;
}).when(client).delete(any(), any());

doAnswer(invocation -> {
ActionListener<BulkByScrollResponse> listener = invocation.getArgument(2);
listener.onResponse(mock(BulkByScrollResponse.class));
BulkByScrollResponse response = new BulkByScrollResponse(new ArrayList<>(), null);
listener.onResponse(response);
return null;
}).when(client).execute(eq(DeleteByQueryAction.INSTANCE), any(), any());
}).when(client).execute(any(), any(), any());
deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
ArgumentCaptor<OpenSearchStatusException> argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals(
"Failed to delete model chunks or model controller, please try again: test_id",
argumentCaptor.getValue().getMessage()
);
assertEquals("Failed to find model", argumentCaptor.getValue().getMessage());
}

public void testDeleteModelChunks_Success() {
when(bulkByScrollResponse.getBulkFailures()).thenReturn(null);
doAnswer(invocation -> {
ActionListener<BulkByScrollResponse> listener = invocation.getArgument(2);
listener.onResponse(bulkByScrollResponse);
return null;
}).when(client).execute(any(), any(), any());
ActionListener<Boolean> deleteChunksListener = mock(ActionListener.class);
deleteModelTransportAction.deleteModelChunks("test_id", deleteChunksListener);
verify(deleteChunksListener).onResponse(true);
}

@Ignore
Expand Down

0 comments on commit a60ef31

Please sign in to comment.