Skip to content

Commit

Permalink
change name and reformat logic
Browse files Browse the repository at this point in the history
Signed-off-by: xinyual <[email protected]>
  • Loading branch information
xinyual committed Dec 4, 2024
1 parent cde3934 commit 6bb7c63
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 121 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ public class CommonValue {
public static final String ML_STOP_WORDS_INDEX = ".plugins-ml-stop-words";
public static final Set<String> stopWordsIndices = ImmutableSet.of(".plugins-ml-stop-words");
public static final Integer ML_MEMORY_MESSAGE_INDEX_SCHEMA_VERSION = 1;
public static final String TOOL_MODEL_RELATED_FIELD_PREFIX = "tools.parameters.";
public static final String TOOL_PARAMETERS_PREFIX = "tools.parameters.";
public static final String USER_FIELD_MAPPING = " \""
+ CommonValue.USER
+ "\": {\n"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package org.opensearch.ml.engine.tools;
package org.opensearch.ml.engine.utils;

import static org.opensearch.ml.common.CommonValue.ML_AGENT_INDEX;
import static org.opensearch.ml.common.CommonValue.TOOL_MODEL_RELATED_FIELD_PREFIX;
import static org.opensearch.ml.common.CommonValue.TOOL_PARAMETERS_PREFIX;

import java.util.HashSet;
import java.util.Map;
Expand All @@ -28,7 +28,7 @@ public SearchRequest constructQueryRequest(String candidateModelId) {
SearchRequest searchRequest = new SearchRequest(ML_AGENT_INDEX);
BoolQueryBuilder shouldQuery = QueryBuilders.boolQuery();
for (String keyField : relatedModelIdSet) {
shouldQuery.should(QueryBuilders.termsQuery(TOOL_MODEL_RELATED_FIELD_PREFIX + keyField, candidateModelId));
shouldQuery.should(QueryBuilders.termsQuery(TOOL_PARAMETERS_PREFIX + keyField, candidateModelId));
}
searchRequest.source(new SearchSourceBuilder().query(shouldQuery));
return searchRequest;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,13 @@
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;
import java.util.function.Supplier;

import org.apache.commons.lang3.tuple.Pair;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.ResourceNotFoundException;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionType;
import org.opensearch.action.delete.DeleteRequest;
import org.opensearch.action.delete.DeleteResponse;
import org.opensearch.action.get.GetRequest;
Expand Down Expand Up @@ -66,7 +68,8 @@
import org.opensearch.ml.common.transport.model.MLModelDeleteAction;
import org.opensearch.ml.common.transport.model.MLModelDeleteRequest;
import org.opensearch.ml.common.transport.model.MLModelGetRequest;
import org.opensearch.ml.engine.tools.AgentModelsSearcher;
import org.opensearch.ml.common.utils.StringUtils;
import org.opensearch.ml.engine.utils.AgentModelsSearcher;
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.search.SearchHit;
Expand Down Expand Up @@ -284,48 +287,27 @@ private void checkAgentBeforeDeleteModel(String modelId, ActionListener<Boolean>
}

private void checkIngestPipelineBeforeDeleteModel(String modelId, ActionListener<Boolean> actionListener) {
GetPipelineRequest getPipelineRequest = new GetPipelineRequest();
client.execute(GetPipelineAction.INSTANCE, getPipelineRequest, ActionListener.wrap(ingestPipelineResponse -> {
List<String> allDependentPipelineIds = findDependentPipelines(
ingestPipelineResponse.pipelines(),
modelId,
org.opensearch.ingest.PipelineConfiguration::getConfigAsMap,
org.opensearch.ingest.PipelineConfiguration::getId
);
if (allDependentPipelineIds.isEmpty()) {
actionListener.onResponse(true);
} else {
actionListener
.onFailure(
new OpenSearchStatusException(
String
.format(
Locale.ROOT,
"%d ingest pipelines are still using this model, please delete or update the pipelines first: %s",
allDependentPipelineIds.size(),
Arrays.toString(allDependentPipelineIds.toArray(new String[0]))
),
RestStatus.CONFLICT
)
);
}
}, e -> {
log.error("Failed to delete ML Model: " + modelId, e);
actionListener.onFailure(e);

}));
checkPipelineBeforeDeleteModel(modelId, actionListener, "ingest", GetPipelineRequest::new, GetPipelineAction.INSTANCE);

}

private void checkSearchPipelineBeforeDeleteModel(String modelId, ActionListener<Boolean> actionListener) {
GetSearchPipelineRequest getSearchPipelineRequest = new GetSearchPipelineRequest();
client.execute(GetSearchPipelineAction.INSTANCE, getSearchPipelineRequest, ActionListener.wrap(searchPipelineResponse -> {
List<String> allDependentPipelineIds = findDependentPipelines(
searchPipelineResponse.pipelines(),
modelId,
org.opensearch.search.pipeline.PipelineConfiguration::getConfigAsMap,
org.opensearch.search.pipeline.PipelineConfiguration::getId
);
checkPipelineBeforeDeleteModel(modelId, actionListener, "search", GetSearchPipelineRequest::new, GetSearchPipelineAction.INSTANCE);

}

private void checkPipelineBeforeDeleteModel(
String modelId,
ActionListener<Boolean> actionListener,
String pipelineType,
Supplier<ActionRequest> requestSupplier,
ActionType actionType
) {
ActionRequest request = requestSupplier.get();
client.execute(actionType, request, ActionListener.wrap(pipelineResponse -> {
String responseString = pipelineResponse.toString();
Map<String, Object> allConfigMap = StringUtils.fromJson(pipelineResponse.toString(), "");
List<String> allDependentPipelineIds = findDependentPipelinesEasy(allConfigMap, modelId);
if (allDependentPipelineIds.isEmpty()) {
actionListener.onResponse(true);
} else {
Expand All @@ -335,8 +317,9 @@ private void checkSearchPipelineBeforeDeleteModel(String modelId, ActionListener
String
.format(
Locale.ROOT,
"%d search pipelines are still using this model, please delete or update the pipelines first: %s",
"%d %s pipelines are still using this model, please delete or update the pipelines first: %s",
allDependentPipelineIds.size(),
pipelineType,
Arrays.toString(allDependentPipelineIds.toArray(new String[0]))
),
RestStatus.CONFLICT
Expand Down Expand Up @@ -479,6 +462,18 @@ private Boolean isModelNotDeployed(MLModelState mlModelState) {
&& !mlModelState.equals(MLModelState.PARTIALLY_DEPLOYED);
}

private List<String> findDependentPipelinesEasy(Map<String, Object> allConfigMap, String candidateModelId) {
List<String> dependentPipelineConfigurations = new ArrayList<>();
for (Map.Entry<String, Object> entry : allConfigMap.entrySet()) {
String id = entry.getKey();
Map<String, Object> config = (Map<String, Object>) entry.getValue();
if (searchThroughConfig(config, candidateModelId)) {
dependentPipelineConfigurations.add(id);
}
}
return dependentPipelineConfigurations;
}

private <T> List<String> findDependentPipelines(
List<T> pipelineConfigurations,
String candidateModelId,
Expand Down Expand Up @@ -533,24 +528,22 @@ private Boolean searchThroughConfig(Object searchCandidate, String candidateId)
}

private String formatAgentErrorMessage(SearchHit[] hits) {
boolean isHidden = false;
List<String> agentIds = new ArrayList<>();
for (SearchHit hit : hits) {
Map<String, Object> sourceAsMap = hit.getSourceAsMap();
isHidden = isHidden || Boolean.parseBoolean((String) sourceAsMap.getOrDefault(MLAgent.IS_HIDDEN_FIELD, false));
agentIds.add(hit.getId());
}
if (isHidden) {
return String
.format(Locale.ROOT, "%d agents are still using this model, please delete or update the agents first", hits.length);
Boolean isHidden = (Boolean) sourceAsMap.getOrDefault(MLAgent.IS_HIDDEN_FIELD, false);
if (!isHidden) {
agentIds.add(hit.getId());
}
}
return String
.format(
Locale.ROOT,
"%d agents are still using this model, please delete or update the agents first: %s",
"%d agents are still using this model, please delete or update the agents first, all visible agents are: %s",
hits.length,
Arrays.toString(agentIds.toArray(new String[0]))
);

}

// this method is only to stub static method.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,14 +178,14 @@
import org.opensearch.ml.engine.indices.MLInputDatasetHandler;
import org.opensearch.ml.engine.memory.ConversationIndexMemory;
import org.opensearch.ml.engine.memory.MLMemoryManager;
import org.opensearch.ml.engine.tools.AgentModelsSearcher;
import org.opensearch.ml.engine.tools.AgentTool;
import org.opensearch.ml.engine.tools.CatIndexTool;
import org.opensearch.ml.engine.tools.ConnectorTool;
import org.opensearch.ml.engine.tools.IndexMappingTool;
import org.opensearch.ml.engine.tools.MLModelTool;
import org.opensearch.ml.engine.tools.SearchIndexTool;
import org.opensearch.ml.engine.tools.VisualizationsTool;
import org.opensearch.ml.engine.utils.AgentModelsSearcher;
import org.opensearch.ml.helper.ConnectorAccessControlHelper;
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.ml.memory.ConversationalMemoryHandler;
Expand Down
Loading

0 comments on commit 6bb7c63

Please sign in to comment.