Skip to content

Commit

Permalink
fix bug to fetch all pipelines
Browse files Browse the repository at this point in the history
Signed-off-by: xinyual <[email protected]>
  • Loading branch information
xinyual committed Nov 19, 2024
1 parent ef68097 commit 525926f
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,6 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Delete
);
} else if (isModelNotDeployed(mlModelState)) {
checkDownstreamTaskBeforeDeleteModel(modelId, isHidden, actionListener);
;
} else {
wrappedListener
.onFailure(
Expand Down Expand Up @@ -291,29 +290,26 @@ private void checkAgentBeforeDeleteModel(String modelId, ActionListener<Boolean>
private void checkIngestPipelineBeforeDeleteModel(String modelId, ActionListener<Boolean> actionListener) {
GetPipelineRequest getPipelineRequest = new GetPipelineRequest();
client.execute(GetPipelineAction.INSTANCE, getPipelineRequest, ActionListener.wrap(ingestPipelineResponse -> {
if (!isPipelineContainsModel(
ingestPipelineResponse.pipelines(),
modelId,
org.opensearch.ingest.PipelineConfiguration::getConfigAsMap
)) {
actionListener.onResponse(true);
} else {
List<String> searchPipelineIds = getAllPipelineIds(
List<String> allRelevantPipelineIds = findRelevantPipelines(
ingestPipelineResponse.pipelines(),
modelId,
org.opensearch.ingest.PipelineConfiguration::getConfigAsMap,
org.opensearch.ingest.PipelineConfiguration::getId
);
);
if (allRelevantPipelineIds.isEmpty()) {
actionListener.onResponse(true);
}
else {
actionListener
.onFailure(
new OpenSearchStatusException(
searchPipelineIds.size()
+ " ingest pipelines are still using this model, please delete or update the pipelines first: "
+ Arrays.toString(searchPipelineIds.toArray(new String[0])),
RestStatus.CONFLICT
)
);

.onFailure(
new OpenSearchStatusException(
allRelevantPipelineIds.size()
+ " ingest pipelines are still using this model, please delete or update the pipelines first: "
+ Arrays.toString(allRelevantPipelineIds.toArray(new String[0])),
RestStatus.CONFLICT
)
);
}

}, e -> {
log.error("Failed to delete ML Model: " + modelId, e);
actionListener.onFailure(e);
Expand All @@ -325,29 +321,26 @@ private void checkIngestPipelineBeforeDeleteModel(String modelId, ActionListener
private void checkSearchPipelineBeforeDeleteModel(String modelId, ActionListener<Boolean> actionListener) {
GetSearchPipelineRequest getSearchPipelineRequest = new GetSearchPipelineRequest();
client.execute(GetSearchPipelineAction.INSTANCE, getSearchPipelineRequest, ActionListener.wrap(searchPipelineResponse -> {
if (!isPipelineContainsModel(
searchPipelineResponse.pipelines(),
modelId,
org.opensearch.search.pipeline.PipelineConfiguration::getConfigAsMap
)) {
actionListener.onResponse(true);
} else {
List<String> searchPipelineIds = getAllPipelineIds(
List<String> allRelevantPipelineIds = findRelevantPipelines(
searchPipelineResponse.pipelines(),
modelId,
org.opensearch.search.pipeline.PipelineConfiguration::getConfigAsMap,
org.opensearch.search.pipeline.PipelineConfiguration::getId
);
);
if (allRelevantPipelineIds.isEmpty()) {
actionListener.onResponse(true);
}
else {
actionListener
.onFailure(
new OpenSearchStatusException(
searchPipelineIds.size()
+ " search pipelines are still using this model, please delete or update the pipelines first: "
+ Arrays.toString(searchPipelineIds.toArray(new String[0])),
RestStatus.CONFLICT
)
);

.onFailure(
new OpenSearchStatusException(
allRelevantPipelineIds.size()
+ " search pipelines are still using this model, please delete or update the pipelines first: "
+ Arrays.toString(allRelevantPipelineIds.toArray(new String[0])),
RestStatus.CONFLICT
)
);
}

}, e -> {
log.error("Failed to delete ML Model: " + modelId, e);
actionListener.onFailure(e);
Expand Down Expand Up @@ -484,18 +477,20 @@ private Boolean isModelNotDeployed(MLModelState mlModelState) {
&& !mlModelState.equals(MLModelState.PARTIALLY_DEPLOYED);
}

private <T> Boolean isPipelineContainsModel(
private <T> List<String> findRelevantPipelines(
List<T> pipelineConfigurations,
String candidateModelId,
Function<T, Map<String, Object>> getConfigFunction
Function<T, Map<String, Object>> getConfigFunction,
Function<T, String> getIdFunction
) {
List<String> relevantPipelineConfigurations = new ArrayList<>();
for (T pipelineConfiguration : pipelineConfigurations) {
Map<String, Object> config = getConfigFunction.apply(pipelineConfiguration);
if (searchThroughConfig(config, candidateModelId, "")) {
return true;
relevantPipelineConfigurations.add(getIdFunction.apply(pipelineConfiguration));
}
}
return false;
return relevantPipelineConfigurations;
}

private Boolean searchThroughConfig(Object searchCandidate, String candidateId, String targetModelKey) {
Expand All @@ -518,13 +513,6 @@ private Boolean searchThroughConfig(Object searchCandidate, String candidateId,
return flag;
}

private <T> List<String> getAllPipelineIds(List<T> pipelineConfigurations, Function<T, String> getIdFunction) {
List<String> pipelineIds = new ArrayList<>();
for (T pipelineConfiguration : pipelineConfigurations) {
pipelineIds.add(getIdFunction.apply(pipelineConfiguration));
}
return pipelineIds;
}

// this method is only to stub static method.
@VisibleForTesting
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

Expand Down Expand Up @@ -188,11 +189,45 @@ public void testDeleteModel_Success() throws IOException {
verify(actionListener).onResponse(deleteResponse);
}

public void testDeleteModel_BlockedBySearchPipelineAndIngestionPipeline() throws IOException {
when(searchPipelineConfiguration.getId()).thenReturn("1");
when(searchPipelineConfiguration.getConfigAsMap()).thenReturn(configDataMap);
when(getSearchPipelineResponse.pipelines()).thenReturn(List.of(searchPipelineConfiguration));
doAnswer(invocation -> {
ActionListener<GetSearchPipelineResponse> listener = invocation.getArgument(2);
listener.onResponse(getSearchPipelineResponse);
return null;
}).when(client).execute(eq(GetSearchPipelineAction.INSTANCE), any(), any());

org.opensearch.ingest.PipelineConfiguration ingestPipelineConfiguration = new org.opensearch.ingest.PipelineConfiguration(
"1",
new BytesArray("{\"model_id\": \"test_id\"}".getBytes(StandardCharsets.UTF_8)),
MediaTypeRegistry.JSON
);
when(getIngestionPipelineResponse.pipelines()).thenReturn(List.of(ingestPipelineConfiguration));
doAnswer(invocation -> {
ActionListener<GetPipelineResponse> listener = invocation.getArgument(2);
listener.onResponse(getIngestionPipelineResponse);
return null;
}).when(client).execute(eq(GetPipelineAction.INSTANCE), any(), any());

deleteModelTransportAction.doExecute(null, mlModelDeleteRequest, actionListener);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals("1 ingest pipelines are still using this model, please delete or update the pipelines first: [1],1 search pipelines are still using this model, please delete or update the pipelines first: [1]", argumentCaptor.getValue().getMessage());
}

public void testDeleteModel_BlockedBySearchPipeline() throws IOException {
//org.opensearch.search.pipeline.PipelineConfiguration pipelineConfiguration = new PipelineConfiguration();
when(searchPipelineConfiguration.getId()).thenReturn("1");
when(searchPipelineConfiguration.getConfigAsMap()).thenReturn(configDataMap);
when(getSearchPipelineResponse.pipelines()).thenReturn(List.of(searchPipelineConfiguration));

org.opensearch.search.pipeline.PipelineConfiguration irrelevantSearchPipelineConfiguration = mock(org.opensearch.search.pipeline.PipelineConfiguration.class);
Map<String, Object> irrelevantConfigMap = new HashMap<>();
irrelevantConfigMap.put("nothing", "nothing");
when(irrelevantSearchPipelineConfiguration.getConfigAsMap()).thenReturn(irrelevantConfigMap);
when(irrelevantSearchPipelineConfiguration.getId()).thenReturn("2");
when(getSearchPipelineResponse.pipelines()).thenReturn(List.of(searchPipelineConfiguration, irrelevantSearchPipelineConfiguration));
doAnswer(invocation -> {
ActionListener<GetSearchPipelineResponse> listener = invocation.getArgument(2);
listener.onResponse(getSearchPipelineResponse);
Expand All @@ -211,7 +246,14 @@ public void testDeleteModel_BlockedByIngestPipeline() throws IOException {
new BytesArray("{\"model_id\": \"test_id\"}".getBytes(StandardCharsets.UTF_8)),
MediaTypeRegistry.JSON
);
when(getIngestionPipelineResponse.pipelines()).thenReturn(List.of(ingestPipelineConfiguration));

org.opensearch.ingest.PipelineConfiguration irrelevantIngestPipelineConfiguration = new org.opensearch.ingest.PipelineConfiguration(
"2",
new BytesArray("{\"nothing\": \"test_id\"}".getBytes(StandardCharsets.UTF_8)),
MediaTypeRegistry.JSON
);

when(getIngestionPipelineResponse.pipelines()).thenReturn(List.of(ingestPipelineConfiguration, irrelevantIngestPipelineConfiguration));
doAnswer(invocation -> {
ActionListener<GetPipelineResponse> listener = invocation.getArgument(2);
listener.onResponse(getIngestionPipelineResponse);
Expand Down

0 comments on commit 525926f

Please sign in to comment.