diff --git a/src/main/java/org/opensearch/agent/tools/CreateAnomalyDetectorTool.java b/src/main/java/org/opensearch/agent/tools/CreateAnomalyDetectorTool.java index 9014c907..52811e61 100644 --- a/src/main/java/org/opensearch/agent/tools/CreateAnomalyDetectorTool.java +++ b/src/main/java/org/opensearch/agent/tools/CreateAnomalyDetectorTool.java @@ -199,7 +199,7 @@ public void run(Map parameters, ActionListener listener) // flatten all the fields in the mapping Map fieldsToType = new HashMap<>(); - ToolHelper.extractFieldNamesTypes(mappingSource, fieldsToType, ""); + ToolHelper.extractFieldNamesTypes(mappingSource, fieldsToType, "", true); // find all date type fields from the mapping final Set dateFields = findDateTypeFields(fieldsToType); diff --git a/src/main/java/org/opensearch/agent/tools/PPLTool.java b/src/main/java/org/opensearch/agent/tools/PPLTool.java index 82787836..037a321e 100644 --- a/src/main/java/org/opensearch/agent/tools/PPLTool.java +++ b/src/main/java/org/opensearch/agent/tools/PPLTool.java @@ -402,7 +402,7 @@ private String constructTableInfo(SearchHit[] searchHits, Map fieldsToType = new HashMap<>(); - ToolHelper.extractFieldNamesTypes(mappingSource, fieldsToType, ""); + ToolHelper.extractFieldNamesTypes(mappingSource, fieldsToType, "", false); StringJoiner tableInfoJoiner = new StringJoiner("\n"); List sortedKeys = new ArrayList<>(fieldsToType.keySet()); Collections.sort(sortedKeys); diff --git a/src/main/java/org/opensearch/agent/tools/utils/ToolHelper.java b/src/main/java/org/opensearch/agent/tools/utils/ToolHelper.java index e88c5175..b60f46c9 100644 --- a/src/main/java/org/opensearch/agent/tools/utils/ToolHelper.java +++ b/src/main/java/org/opensearch/agent/tools/utils/ToolHelper.java @@ -40,8 +40,15 @@ public static Map loadDefaultPromptDictFromFile(Class source, * @param mappingSource the mappings of an index * @param fieldsToType the result containing the field to fieldType mapping * @param prefix the parent field path + * @param includeFields whether include the `fields` in a text type field, for some use case like PPLTool, `fields` in a text type field + * cannot be included, but for CreateAnomalyDetectorTool, `fields` must be included. */ - public static void extractFieldNamesTypes(Map mappingSource, Map fieldsToType, String prefix) { + public static void extractFieldNamesTypes( + Map mappingSource, + Map fieldsToType, + String prefix, + boolean includeFields + ) { if (prefix.length() > 0) { prefix += "."; } @@ -53,15 +60,17 @@ public static void extractFieldNamesTypes(Map mappingSource, Map if (v instanceof Map) { Map vMap = (Map) v; if (vMap.containsKey("type")) { - if (!((vMap.getOrDefault("type", "")).equals("alias"))) { + String fieldType = (String) vMap.getOrDefault("type", ""); + // no need to extract alias into the result, and for object field, extract the subfields only + if (!fieldType.equals("alias") && !fieldType.equals("object")) { fieldsToType.put(prefix + n, (String) vMap.get("type")); } } if (vMap.containsKey("properties")) { - extractFieldNamesTypes((Map) vMap.get("properties"), fieldsToType, prefix + n); + extractFieldNamesTypes((Map) vMap.get("properties"), fieldsToType, prefix + n, includeFields); } - if (vMap.containsKey("fields")) { - extractFieldNamesTypes((Map) vMap.get("fields"), fieldsToType, prefix + n); + if (includeFields && vMap.containsKey("fields")) { + extractFieldNamesTypes((Map) vMap.get("fields"), fieldsToType, prefix + n, true); } } } diff --git a/src/test/java/org/opensearch/agent/tools/ToolHelperTests.java b/src/test/java/org/opensearch/agent/tools/ToolHelperTests.java new file mode 100644 index 00000000..5b6dfa7f --- /dev/null +++ b/src/test/java/org/opensearch/agent/tools/ToolHelperTests.java @@ -0,0 +1,90 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.agent.tools; + +import static org.junit.Assert.assertEquals; + +import java.util.HashMap; +import java.util.Map; + +import org.junit.Test; +import org.opensearch.agent.tools.utils.ToolHelper; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class ToolHelperTests { + @Test + public void TestExtractFieldNamesTypes() { + Map indexMappings = Map + .of( + "response", + Map.of("type", "integer"), + "responseLatency", + Map.of("type", "float"), + "date", + Map.of("type", "date"), + "objectA", + Map.of("type", "object", "properties", Map.of("subA", Map.of("type", "keyword"))), + "objectB", + Map.of("properties", Map.of("subB", Map.of("type", "keyword"))), + "textC", + Map.of("type", "text", "fields", Map.of("subC", Map.of("type", "keyword"))), + "aliasD", + Map.of("type", "alias", "path", "date") + ); + Map result = new HashMap<>(); + ToolHelper.extractFieldNamesTypes(indexMappings, result, "", true); + assertMapEquals( + result, + Map + .of( + "response", + "integer", + "responseLatency", + "float", + "date", + "date", + "objectA.subA", + "keyword", + "objectB.subB", + "keyword", + "textC", + "text", + "textC.subC", + "keyword" + ) + ); + + Map result1 = new HashMap<>(); + ToolHelper.extractFieldNamesTypes(indexMappings, result1, "", false); + assertMapEquals( + result1, + Map + .of( + "response", + "integer", + "responseLatency", + "float", + "date", + "date", + "objectA.subA", + "keyword", + "objectB.subB", + "keyword", + "textC", + "text" + ) + ); + } + + private void assertMapEquals(Map expected, Map actual) { + assertEquals(expected.size(), actual.size()); + for (Map.Entry entry : expected.entrySet()) { + assertEquals(entry.getValue(), actual.get(entry.getKey())); + } + } +}