From 95e109a9bc7118d22ac766c41a8d118181669d75 Mon Sep 17 00:00:00 2001 From: Bhavana Ramaram Date: Fri, 6 Oct 2023 11:35:19 -0700 Subject: [PATCH] fix model group auto-deletion when last version is deleted (#1444) * fix model group auto-deletion when last version is deleted Signed-off-by: Bhavana Ramaram (cherry picked from commit 1f43b2895498ef8eaab9401b1a99c0bc26bd47f0) --- .../models/DeleteModelTransportAction.java | 88 ++---------- .../DeleteModelTransportActionTests.java | 133 ------------------ 2 files changed, 15 insertions(+), 206 deletions(-) diff --git a/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java index 5d89e1c113..9070781d63 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java @@ -6,14 +6,12 @@ package org.opensearch.ml.action.models; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; import static org.opensearch.ml.common.MLModel.ALGORITHM_FIELD; import static org.opensearch.ml.common.MLModel.MODEL_ID_FIELD; import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry; import static org.opensearch.ml.utils.RestActionUtils.getFetchSourceContext; -import org.apache.commons.lang3.StringUtils; import org.opensearch.OpenSearchStatusException; import org.opensearch.ResourceNotFoundException; import org.opensearch.action.ActionRequest; @@ -21,8 +19,6 @@ import org.opensearch.action.delete.DeleteResponse; import org.opensearch.action.get.GetRequest; import org.opensearch.action.get.GetResponse; -import org.opensearch.action.search.SearchRequest; -import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; import org.opensearch.client.Client; @@ -34,8 +30,6 @@ import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.index.query.BoolQueryBuilder; -import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.index.query.TermsQueryBuilder; import org.opensearch.index.reindex.BulkByScrollResponse; import org.opensearch.index.reindex.DeleteByQueryAction; @@ -48,7 +42,6 @@ import org.opensearch.ml.common.transport.model.MLModelGetRequest; import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.utils.RestActionUtils; -import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.fetch.subphase.FetchSourceContext; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @@ -110,6 +103,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { @@ -118,37 +112,20 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { - boolean isLastModelOfGroup = false; - if (response != null - && response.getHits() != null - && response.getHits().getTotalHits() != null - && response.getHits().getTotalHits().value == 1) { - isLastModelOfGroup = true; - } - deleteModel(modelId, mlModel.getModelGroupId(), isLastModelOfGroup, wrappedListener); - }, e -> { - log.error("Failed to Search Model index " + modelId, e); - wrappedListener.onFailure(e); - })); - } else { - deleteModel(modelId, mlModel.getModelGroupId(), false, wrappedListener); - } + deleteModel(modelId, actionListener); } }, e -> { log.error("Failed to validate Access for Model Id " + modelId, e); @@ -168,18 +145,6 @@ protected void doExecute(Task task, ActionRequest request, ActionListener listener) { - BoolQueryBuilder query = new BoolQueryBuilder(); - query.filter(new TermQueryBuilder(MLModel.MODEL_GROUP_ID_FIELD, modelGroupId)); - - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query); - SearchRequest searchRequest = new SearchRequest(ML_MODEL_INDEX).source(searchSourceBuilder); - client.search(searchRequest, ActionListener.wrap(response -> { listener.onResponse(response); }, e -> { - log.error("Failed to search Model index", e); - listener.onFailure(e); - })); - } - @VisibleForTesting void deleteModelChunks(String modelId, DeleteResponse deleteResponse, ActionListener actionListener) { DeleteByQueryRequest deleteModelsRequest = new DeleteByQueryRequest(ML_MODEL_INDEX); @@ -218,19 +183,11 @@ private void returnFailure(BulkByScrollResponse response, String modelId, Action actionListener.onFailure(new OpenSearchStatusException(errorMessage, RestStatus.INTERNAL_SERVER_ERROR)); } - private void deleteModel( - String modelId, - String modelGroupId, - boolean isLastModelOfGroup, - ActionListener actionListener - ) { + private void deleteModel(String modelId, ActionListener actionListener) { DeleteRequest deleteRequest = new DeleteRequest(ML_MODEL_INDEX, modelId); client.delete(deleteRequest, new ActionListener() { @Override public void onResponse(DeleteResponse deleteResponse) { - if (isLastModelOfGroup) { - deleteModelGroup(modelGroupId); - } deleteModelChunks(modelId, deleteResponse, actionListener); } @@ -244,19 +201,4 @@ public void onFailure(Exception e) { } }); } - - private void deleteModelGroup(String modelGroupId) { - DeleteRequest deleteRequest = new DeleteRequest(ML_MODEL_GROUP_INDEX, modelGroupId); - client.delete(deleteRequest, new ActionListener() { - @Override - public void onResponse(DeleteResponse deleteResponse) { - log.debug("Completed Delete Model Group for modelGroupId:{}", modelGroupId); - } - - @Override - public void onFailure(Exception e) { - log.error("Failed to delete ML Model Group with Id:{} " + modelGroupId, e); - } - }); - } } diff --git a/plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java index c0f431e64f..57051643a6 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/models/DeleteModelTransportActionTests.java @@ -6,9 +6,7 @@ package org.opensearch.ml.action.models; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -22,7 +20,6 @@ import java.util.ArrayList; import java.util.Arrays; -import org.apache.lucene.search.TotalHits; import org.junit.Before; import org.junit.Ignore; import org.junit.Rule; @@ -35,7 +32,6 @@ import org.opensearch.action.bulk.BulkItemResponse; import org.opensearch.action.delete.DeleteResponse; import org.opensearch.action.get.GetResponse; -import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; @@ -50,15 +46,11 @@ import org.opensearch.index.get.GetResult; import org.opensearch.index.reindex.BulkByScrollResponse; import org.opensearch.index.reindex.ScrollableHitSource; -import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.model.MLModelState; import org.opensearch.ml.common.transport.model.MLModelDeleteRequest; import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.model.MLModelManager; -import org.opensearch.ml.utils.TestHelper; -import org.opensearch.search.SearchHit; -import org.opensearch.search.SearchHits; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; @@ -184,115 +176,6 @@ public void testDeleteModel_Success_AlgorithmNotNull() throws IOException { verify(actionListener).onResponse(deleteResponse); } - public void test_Success_ModelGroupIDNotNull_LastModelOfGroup() throws IOException { - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(deleteResponse); - return null; - }).when(client).delete(any(), any()); - - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - BulkByScrollResponse response = new BulkByScrollResponse(new ArrayList<>(), null); - listener.onResponse(response); - return null; - }).when(client).execute(any(), any(), any()); - - SearchResponse searchResponse = createModelGroupSearchResponse(1); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(searchResponse); - return null; - }).when(client).search(any(), isA(ActionListener.class)); - - GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, "modelGroupID"); - - doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(1); - actionListener.onResponse(getResponse); - return null; - }).when(client).get(any(), any()); - - deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener); - verify(actionListener).onResponse(deleteResponse); - } - - public void test_Success_ModelGroupIDNotNull_NotLastModelOfGroup() throws IOException { - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(deleteResponse); - return null; - }).when(client).delete(any(), any()); - - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - BulkByScrollResponse response = new BulkByScrollResponse(new ArrayList<>(), null); - listener.onResponse(response); - return null; - }).when(client).execute(any(), any(), any()); - - SearchResponse searchResponse = createModelGroupSearchResponse(2); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(searchResponse); - return null; - }).when(client).search(any(), isA(ActionListener.class)); - - MLModel mlModel = MLModel - .builder() - .modelId("test_id") - .modelGroupId("modelGroupID") - .modelState(MLModelState.REGISTERED) - .algorithm(FunctionName.TEXT_EMBEDDING) - .build(); - XContentBuilder content = mlModel.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); - BytesReference bytesReference = BytesReference.bytes(content); - GetResult getResult = new GetResult("indexName", "111", 111l, 111l, 111l, true, bytesReference, null, null); - GetResponse getResponse = new GetResponse(getResult); - doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(1); - actionListener.onResponse(getResponse); - return null; - }).when(client).get(any(), any()); - - deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener); - verify(actionListener).onResponse(deleteResponse); - } - - public void test_Failure_FailedToSearchLastModel() throws IOException { - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onResponse(deleteResponse); - return null; - }).when(client).delete(any(), any()); - - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - BulkByScrollResponse response = new BulkByScrollResponse(new ArrayList<>(), null); - listener.onResponse(response); - return null; - }).when(client).execute(any(), any(), any()); - - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(1); - listener.onFailure(new Exception("Failed to search Model index")); - return null; - }).when(client).search(any(), isA(ActionListener.class)); - - GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, "modelGroupID"); - - doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(1); - actionListener.onResponse(getResponse); - return null; - }).when(client).get(any(), any()); - - deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); - verify(actionListener).onFailure(argumentCaptor.capture()); - assertEquals("Failed to search Model index", argumentCaptor.getValue().getMessage()); - } - public void test_UserHasNoAccessException() throws IOException { GetResponse getResponse = prepareMLModel(MLModelState.REGISTERED, "modelGroupID"); doAnswer(invocation -> { @@ -517,20 +400,4 @@ public GetResponse prepareMLModel(MLModelState mlModelState, String modelGroupID GetResponse getResponse = new GetResponse(getResult); return getResponse; } - - private SearchResponse createModelGroupSearchResponse(long totalHits) throws IOException { - SearchResponse searchResponse = mock(SearchResponse.class); - String modelContent = "{\n" - + " \"created_time\": 1684981986069,\n" - + " \"access\": \"public\",\n" - + " \"latest_version\": 0,\n" - + " \"last_updated_time\": 1684981986069,\n" - + " \"name\": \"model_group_IT\",\n" - + " \"description\": \"This is an example description\"\n" - + " }"; - SearchHit modelGroup = SearchHit.fromXContent(TestHelper.parser(modelContent)); - SearchHits hits = new SearchHits(new SearchHit[] { modelGroup }, new TotalHits(totalHits, TotalHits.Relation.EQUAL_TO), Float.NaN); - when(searchResponse.getHits()).thenReturn(hits); - return searchResponse; - } }