Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] Fix bug where document embedding fails to be generated due to document has dot in field name #1071

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Address inconsistent scoring in hybrid query results ([#998](https://github.com/opensearch-project/neural-search/pull/998))
- Fix bug where ingested document has list of nested objects ([#1040](https://github.com/opensearch-project/neural-search/pull/1040))
- Fixed document source and score field mismatch in sorted hybrid queries ([#1043](https://github.com/opensearch-project/neural-search/pull/1043))
- Fix bug where embedding is missing when ingested document has "." in field name, and mismatches fieldMap config ([#1062](https://github.com/opensearch-project/neural-search/pull/1062))
### Infrastructure
### Documentation
### Maintenance
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ public IngestDocument execute(IngestDocument ingestDocument) throws Exception {
@Override
public void execute(IngestDocument ingestDocument, BiConsumer<IngestDocument, Exception> handler) {
try {
preprocessIngestDocument(ingestDocument);
validateEmbeddingFieldsValue(ingestDocument);
Map<String, Object> processMap = buildMapWithTargetKeys(ingestDocument);
List<String> inferenceList = createInferenceList(processMap);
Expand All @@ -150,6 +151,15 @@ public void execute(IngestDocument ingestDocument, BiConsumer<IngestDocument, Ex
}
}

@VisibleForTesting
void preprocessIngestDocument(IngestDocument ingestDocument) {
if (ingestDocument == null || ingestDocument.getSourceAndMetadata() == null) return;
Map<String, Object> sourceAndMetadataMap = ingestDocument.getSourceAndMetadata();
Map<String, Object> unflattened = ProcessorDocumentUtils.unflattenJson(sourceAndMetadataMap);
unflattened.forEach(ingestDocument::setFieldValue);
sourceAndMetadataMap.keySet().removeIf(key -> key.contains("."));
}

/**
* This is the function which does actual inference work for batchExecute interface.
* @param inferenceList a list of String for inference.
Expand Down Expand Up @@ -244,12 +254,14 @@ private List<DataForInference> getDataForInference(List<IngestDocumentWrapper> i
for (IngestDocumentWrapper ingestDocumentWrapper : ingestDocumentWrappers) {
Map<String, Object> processMap = null;
List<String> inferenceList = null;
IngestDocument ingestDocument = ingestDocumentWrapper.getIngestDocument();
try {
validateEmbeddingFieldsValue(ingestDocumentWrapper.getIngestDocument());
processMap = buildMapWithTargetKeys(ingestDocumentWrapper.getIngestDocument());
preprocessIngestDocument(ingestDocument);
validateEmbeddingFieldsValue(ingestDocument);
processMap = buildMapWithTargetKeys(ingestDocument);
inferenceList = createInferenceList(processMap);
} catch (Exception e) {
ingestDocumentWrapper.update(ingestDocumentWrapper.getIngestDocument(), e);
ingestDocumentWrapper.update(ingestDocument, e);
} finally {
dataForInferences.add(new DataForInference(ingestDocumentWrapper, processMap, inferenceList));
}
Expand Down Expand Up @@ -333,13 +345,14 @@ void buildNestedMap(String parentKey, Object processorKey, Map<String, Object> s
} else if (sourceAndMetadataMap.get(parentKey) instanceof List) {
for (Map.Entry<String, Object> nestedFieldMapEntry : ((Map<String, Object>) processorKey).entrySet()) {
List<Map<String, Object>> list = (List<Map<String, Object>>) sourceAndMetadataMap.get(parentKey);
Pair<String, Object> processedNestedKey = processNestedKey(nestedFieldMapEntry);
List<Object> listOfStrings = list.stream().map(x -> {
Object nestedSourceValue = x.get(nestedFieldMapEntry.getKey());
Object nestedSourceValue = x.get(processedNestedKey.getKey());
return normalizeSourceValue(nestedSourceValue);
}).collect(Collectors.toList());
Map<String, Object> map = new LinkedHashMap<>();
map.put(nestedFieldMapEntry.getKey(), listOfStrings);
buildNestedMap(nestedFieldMapEntry.getKey(), nestedFieldMapEntry.getValue(), map, next);
map.put(processedNestedKey.getKey(), listOfStrings);
buildNestedMap(processedNestedKey.getKey(), processedNestedKey.getValue(), map, next);
}
}
treeRes.merge(parentKey, next, REMAPPING_FUNCTION);
Expand Down Expand Up @@ -387,7 +400,7 @@ private void validateEmbeddingFieldsValue(IngestDocument ingestDocument) {
ProcessorDocumentUtils.validateMapTypeValue(
FIELD_MAP_FIELD,
sourceAndMetadataMap,
fieldMap,
ProcessorDocumentUtils.unflattenJson(fieldMap),
indexName,
clusterService,
environment,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,14 @@
import org.opensearch.env.Environment;
import org.opensearch.index.mapper.MapperService;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Stack;

/**
* This class is used to accommodate the common code pieces of parsing, validating and processing the document for multiple
Expand Down Expand Up @@ -178,4 +181,166 @@ private static void validateDepth(
);
}
}

/**
* Unflatten a JSON object represented as a {@code Map<String, Object>}, possibly with dot in field name,
* into a nested {@code Map<String, Object>}
* "Object" can be either a {@code Map<String, Object>} or a {@code List<Object>} or simply a String.
* For example, input is {"a.b": "c"}, output is {"a":{"b": "c"}}.
* Another example:
* input is {"a": [{"b.c": "d"}, {"b.c": "e"}]},
* output is {"a": [{"b": {"c": "d"}}, {"b": {"c": "e"}}]}
* @param originalJsonMap the original JSON object represented as a {@code Map<String, Object>}
* @return the nested JSON object represented as a nested {@code Map<String, Object>}
* @throws IllegalArgumentException if the originalJsonMap is null or has invalid dot usage in field name
*/
public static Map<String, Object> unflattenJson(Map<String, Object> originalJsonMap) {
if (originalJsonMap == null) {
throw new IllegalArgumentException("originalJsonMap cannot be null");
}
Map<String, Object> result = new HashMap<>();
Stack<ProcessJsonObjectItem> stack = new Stack<>();

// Push initial items to stack
for (Map.Entry<String, Object> entry : originalJsonMap.entrySet()) {
stack.push(new ProcessJsonObjectItem(entry.getKey(), entry.getValue(), result));
}

// Process items until stack is empty
while (!stack.isEmpty()) {
ProcessJsonObjectItem item = stack.pop();
String key = item.key;
Object value = item.value;
Map<String, Object> currentMap = item.targetMap;

// Handle nested value
if (value instanceof Map) {
Map<String, Object> nestedMap = new HashMap<>();
for (Map.Entry<String, Object> entry : ((Map<String, Object>) value).entrySet()) {
stack.push(new ProcessJsonObjectItem(entry.getKey(), entry.getValue(), nestedMap));
}
value = nestedMap;
} else if (value instanceof List) {
value = handleList((List<Object>) value);
}

// If key contains dot, split and create nested structure
unflattenSingleItem(key, value, currentMap);
}

return result;
}

private static List<Object> handleList(List<Object> list) {
List<Object> result = new ArrayList<>();
Stack<ProcessJsonListItem> stack = new Stack<>();

// Push initial items to stack
for (int i = list.size() - 1; i >= 0; i--) {
stack.push(new ProcessJsonListItem(list.get(i), result));
}

// Process items until stack is empty
while (!stack.isEmpty()) {
ProcessJsonListItem item = stack.pop();
Object value = item.value;
List<Object> targetList = item.targetList;

if (value instanceof Map) {
Map<String, Object> nestedMap = new HashMap<>();
Map<String, Object> sourceMap = (Map<String, Object>) value;
for (Map.Entry<String, Object> entry : sourceMap.entrySet()) {
stack.push(new ProcessJsonListItem(new ProcessJsonObjectItem(entry.getKey(), entry.getValue(), nestedMap), targetList));
}
targetList.add(nestedMap);
} else if (value instanceof List) {
List<Object> nestedList = new ArrayList<>();
for (Object listItem : (List<Object>) value) {
stack.push(new ProcessJsonListItem(listItem, nestedList));
}
targetList.add(nestedList);
} else if (value instanceof ProcessJsonObjectItem) {
ProcessJsonObjectItem processJsonObjectItem = (ProcessJsonObjectItem) value;
Map<String, Object> tempMap = new HashMap<>();
unflattenSingleItem(processJsonObjectItem.key, processJsonObjectItem.value, tempMap);
targetList.set(targetList.size() - 1, tempMap);
} else {
targetList.add(value);
}
}

return result;
}

private static void unflattenSingleItem(String key, Object value, Map<String, Object> result) {
if (StringUtils.isBlank(key)) {
throw new IllegalArgumentException("Field name cannot be null or empty");
}
if (key.contains(".")) {
// Use split with -1 limit to preserve trailing empty strings
String[] parts = key.split("\\.", -1);
Map<String, Object> current = result;

for (int i = 0; i < parts.length; i++) {
if (StringUtils.isBlank(parts[i])) {
throw new IllegalArgumentException(String.format(Locale.ROOT, "Field name '%s' contains invalid dot usage", key));
}
if (i == parts.length - 1) {
current.put(parts[i], value);
continue;
}
current = (Map<String, Object>) current.computeIfAbsent(parts[i], k -> new HashMap<>());
}
} else {
result.put(key, value);
}
}

/**
* Validate if field name is in correct format, which is either "a", or "a.b.c".
* If field name is like "..a..b", "a..b", "a.b..", it should be invalid.
* This is done via checking if a string contains empty segments when split by dots.
*
* @param input the string to check
* @throws IllegalArgumentException if the input is null or has invalid dot usage
*/
private static void validateFieldName(String input) {
if (StringUtils.isBlank(input)) {
throw new IllegalArgumentException("Field name cannot be null or empty");
}

// Use split with -1 limit to preserve trailing empty strings
String[] segments = input.split("\\.", -1);

// Check if any segment is empty
for (String segment : segments) {
if (StringUtils.isBlank(segment)) {
throw new IllegalArgumentException(String.format(Locale.ROOT, "Field name '%s' contains invalid dot usage", input));
}
}
}

// Helper classes to maintain state during iteration
private static class ProcessJsonObjectItem {
String key;
Object value;
Map<String, Object> targetMap;

ProcessJsonObjectItem(String key, Object value, Map<String, Object> targetMap) {
this.key = key;
this.value = value;
this.targetMap = targetMap;
}
}

private static class ProcessJsonListItem {
Object value;
List<Object> targetList;

ProcessJsonListItem(Object value, List<Object> targetList) {
this.value = value;
this.targetList = targetList;
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ public class TextEmbeddingProcessorIT extends BaseNeuralSearchIT {
private final String INGEST_DOC2 = Files.readString(Path.of(classLoader.getResource("processor/ingest_doc2.json").toURI()));
private final String INGEST_DOC3 = Files.readString(Path.of(classLoader.getResource("processor/ingest_doc3.json").toURI()));
private final String INGEST_DOC4 = Files.readString(Path.of(classLoader.getResource("processor/ingest_doc4.json").toURI()));
private final String INGEST_DOC5 = Files.readString(Path.of(classLoader.getResource("processor/ingest_doc5.json").toURI()));
private final String BULK_ITEM_TEMPLATE = Files.readString(
Path.of(classLoader.getResource("processor/bulk_item_template.json").toURI())
);
Expand Down Expand Up @@ -176,6 +177,23 @@ private void assertDoc(Map<String, Object> sourceMap, String textFieldValue, Opt
}
}

private void assertDocWithLevel2AsList(Map<String, Object> sourceMap) {
assertNotNull(sourceMap);
assertTrue(sourceMap.containsKey(LEVEL_1_FIELD));
assertTrue(sourceMap.get(LEVEL_1_FIELD) instanceof List);
List<Map<String, Object>> nestedPassages = (List<Map<String, Object>>) sourceMap.get(LEVEL_1_FIELD);
nestedPassages.forEach(nestedPassage -> {
assertTrue(nestedPassage.containsKey(LEVEL_2_FIELD));
Map<String, Object> level2 = (Map<String, Object>) nestedPassage.get(LEVEL_2_FIELD);
Map<String, Object> level3 = (Map<String, Object>) level2.get(LEVEL_3_FIELD_CONTAINER);
List<Double> embeddings = (List<Double>) level3.get(LEVEL_3_FIELD_EMBEDDING);
assertEquals(768, embeddings.size());
for (Double embedding : embeddings) {
assertTrue(embedding >= 0.0 && embedding <= 1.0);
}
});
}

public void testTextEmbeddingProcessor_withBatchSizeInProcessor() throws Exception {
String modelId = null;
try {
Expand Down Expand Up @@ -240,6 +258,49 @@ public void testTextEmbeddingProcessor_withFailureAndSkip() throws Exception {
}
}

@SuppressWarnings("unchecked")
public void testNestedFieldMapping_whenDocumentInListIngested_thenSuccessful() throws Exception {
String modelId = null;
try {
modelId = uploadTextEmbeddingModel();
loadModel(modelId);
createPipelineProcessor(modelId, PIPELINE_NAME, ProcessorType.TEXT_EMBEDDING_WITH_NESTED_FIELDS_MAPPING);
createTextEmbeddingIndex();
ingestDocument(INGEST_DOC5, "5");

assertDocWithLevel2AsList((Map<String, Object>) getDocById(INDEX_NAME, "5").get("_source"));

NeuralQueryBuilder neuralQueryBuilderQuery = NeuralQueryBuilder.builder()
.fieldName(LEVEL_1_FIELD + "." + LEVEL_2_FIELD + "." + LEVEL_3_FIELD_CONTAINER + "." + LEVEL_3_FIELD_EMBEDDING)
.queryText(QUERY_TEXT)
.modelId(modelId)
.k(10)
.build();

QueryBuilder queryNestedLowerLevel = QueryBuilders.nestedQuery(
LEVEL_1_FIELD + "." + LEVEL_2_FIELD,
neuralQueryBuilderQuery,
ScoreMode.Total
);
QueryBuilder queryNestedHighLevel = QueryBuilders.nestedQuery(LEVEL_1_FIELD, queryNestedLowerLevel, ScoreMode.Total);

Map<String, Object> searchResponseAsMap = search(INDEX_NAME, queryNestedHighLevel, 2);
assertNotNull(searchResponseAsMap);

Map<String, Object> hits = (Map<String, Object>) searchResponseAsMap.get("hits");
assertNotNull(hits);

List<Map<String, Object>> listOfHits = (List<Map<String, Object>>) hits.get("hits");
assertNotNull(listOfHits);
assertEquals(1, listOfHits.size());

Map<String, Object> innerHitDetails = listOfHits.getFirst();
assertEquals("5", innerHitDetails.get("_id"));
} finally {
wipeOfTestResources(INDEX_NAME, PIPELINE_NAME, modelId, null);
}
}

private String uploadTextEmbeddingModel() throws Exception {
String requestBody = Files.readString(Path.of(classLoader.getResource("processor/UploadModelRequestBody.json").toURI()));
return registerModelGroupAndUploadModel(requestBody);
Expand Down
Loading
Loading