Fix bug where document embedding fails to be generated because the document has a dot in a field name (opensearch-project#1062)

* Fix bug where document embedding fails to be generated because the document has a dot in a field name

Signed-off-by: Yizhe Liu <[email protected]>

* Address comments

Signed-off-by: Yizhe Liu <[email protected]>

---------

Signed-off-by: Yizhe Liu <[email protected]>
yizheliu-amazon committed Jan 10, 2025
1 parent 38e1f30 commit 0c65880
Showing 7 changed files with 658 additions and 96 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -26,6 +26,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Address inconsistent scoring in hybrid query results ([#998](https://github.com/opensearch-project/neural-search/pull/998))
- Fix bug where ingested document has list of nested objects ([#1040](https://github.com/opensearch-project/neural-search/pull/1040))
- Fixed document source and score field mismatch in sorted hybrid queries ([#1043](https://github.com/opensearch-project/neural-search/pull/1043))
- Fix bug where embedding is missing when ingested document has "." in field name, and mismatches fieldMap config ([#1062](https://github.com/opensearch-project/neural-search/pull/1062))
### Infrastructure
- Update batch related tests to use batch_size in processor & refactor BWC version check ([#852](https://github.com/opensearch-project/neural-search/pull/852))
- Fix CI for JDK upgrade towards 21 ([#835](https://github.com/opensearch-project/neural-search/pull/835))
InferenceProcessor.java
@@ -137,6 +137,7 @@ public IngestDocument execute(IngestDocument ingestDocument) throws Exception {
@Override
public void execute(IngestDocument ingestDocument, BiConsumer<IngestDocument, Exception> handler) {
try {
preprocessIngestDocument(ingestDocument);
validateEmbeddingFieldsValue(ingestDocument);
Map<String, Object> processMap = buildMapWithTargetKeys(ingestDocument);
List<String> inferenceList = createInferenceList(processMap);
@@ -150,6 +151,15 @@ public void execute(IngestDocument ingestDocument, BiConsumer<IngestDocument, Exception> handler) {
}
}

@VisibleForTesting
void preprocessIngestDocument(IngestDocument ingestDocument) {
if (ingestDocument == null || ingestDocument.getSourceAndMetadata() == null) return;
Map<String, Object> sourceAndMetadataMap = ingestDocument.getSourceAndMetadata();
Map<String, Object> unflattened = ProcessorDocumentUtils.unflattenJson(sourceAndMetadataMap);
unflattened.forEach(ingestDocument::setFieldValue);
sourceAndMetadataMap.keySet().removeIf(key -> key.contains("."));
}
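The preprocessing step above rewrites dotted keys into nested form before validation and field mapping run. A minimal sketch of the transformation (hypothetical field names; imports elided):

Map<String, Object> source = new HashMap<>();
source.put("level_1.level_2", "hello");
Map<String, Object> unflattened = ProcessorDocumentUtils.unflattenJson(source);
// unflattened: {level_1={level_2=hello}}; preprocessIngestDocument writes this
// nested form back via setFieldValue and then removes the dotted original key.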

/**
* This is the function which does actual inference work for batchExecute interface.
* @param inferenceList a list of String for inference.
Expand Down Expand Up @@ -244,12 +254,14 @@ private List<DataForInference> getDataForInference(List<IngestDocumentWrapper> i
for (IngestDocumentWrapper ingestDocumentWrapper : ingestDocumentWrappers) {
Map<String, Object> processMap = null;
List<String> inferenceList = null;
IngestDocument ingestDocument = ingestDocumentWrapper.getIngestDocument();
try {
validateEmbeddingFieldsValue(ingestDocumentWrapper.getIngestDocument());
processMap = buildMapWithTargetKeys(ingestDocumentWrapper.getIngestDocument());
preprocessIngestDocument(ingestDocument);
validateEmbeddingFieldsValue(ingestDocument);
processMap = buildMapWithTargetKeys(ingestDocument);
inferenceList = createInferenceList(processMap);
} catch (Exception e) {
ingestDocumentWrapper.update(ingestDocumentWrapper.getIngestDocument(), e);
ingestDocumentWrapper.update(ingestDocument, e);
} finally {
dataForInferences.add(new DataForInference(ingestDocumentWrapper, processMap, inferenceList));
}
@@ -333,13 +345,14 @@ void buildNestedMap(String parentKey, Object processorKey, Map<String, Object> sourceAndMetadataMap, Map<String, Object> treeRes) {
} else if (sourceAndMetadataMap.get(parentKey) instanceof List) {
for (Map.Entry<String, Object> nestedFieldMapEntry : ((Map<String, Object>) processorKey).entrySet()) {
List<Map<String, Object>> list = (List<Map<String, Object>>) sourceAndMetadataMap.get(parentKey);
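// Note: processNestedKey splits a dotted field-map key into its first segment
// plus a nested remainder (e.g. "a.b" -> ("a", {"b": value})), so each list
// element below is looked up by the resolved outer segment rather than the
// raw dotted key.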
Pair<String, Object> processedNestedKey = processNestedKey(nestedFieldMapEntry);
List<Object> listOfStrings = list.stream().map(x -> {
Object nestedSourceValue = x.get(nestedFieldMapEntry.getKey());
Object nestedSourceValue = x.get(processedNestedKey.getKey());
return normalizeSourceValue(nestedSourceValue);
}).collect(Collectors.toList());
Map<String, Object> map = new LinkedHashMap<>();
map.put(nestedFieldMapEntry.getKey(), listOfStrings);
buildNestedMap(nestedFieldMapEntry.getKey(), nestedFieldMapEntry.getValue(), map, next);
map.put(processedNestedKey.getKey(), listOfStrings);
buildNestedMap(processedNestedKey.getKey(), processedNestedKey.getValue(), map, next);
}
}
treeRes.merge(parentKey, next, REMAPPING_FUNCTION);
@@ -387,7 +400,7 @@ private void validateEmbeddingFieldsValue(IngestDocument ingestDocument) {
ProcessorDocumentUtils.validateMapTypeValue(
FIELD_MAP_FIELD,
sourceAndMetadataMap,
fieldMap,
ProcessorDocumentUtils.unflattenJson(fieldMap),
indexName,
clusterService,
environment,
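Unflattening the processor's fieldMap keeps validation symmetric with the preprocessed document: a configuration written with dotted keys is compared in the same nested shape the document takes after preprocessIngestDocument. A hypothetical example:

// A fieldMap configured as {"level_1.level_2": "embedding"} is validated as
// {"level_1": {"level_2": "embedding"}}, matching the unflattened source.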
ProcessorDocumentUtils.java
@@ -12,11 +12,14 @@
import org.opensearch.env.Environment;
import org.opensearch.index.mapper.MapperService;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Stack;

/**
* This class is used to accommodate the common code pieces of parsing, validating and processing the document for multiple
@@ -178,4 +181,166 @@ private static void validateDepth(
);
}
}

/**
* Unflattens a JSON object represented as a {@code Map<String, Object>}, whose field names may contain dots,
* into a nested {@code Map<String, Object>}.
* Each value can be a {@code Map<String, Object>}, a {@code List<Object>}, or simply a String.
* For example, input is {"a.b": "c"}, output is {"a":{"b": "c"}}.
* Another example:
* input is {"a": [{"b.c": "d"}, {"b.c": "e"}]},
* output is {"a": [{"b": {"c": "d"}}, {"b": {"c": "e"}}]}
* @param originalJsonMap the original JSON object represented as a {@code Map<String, Object>}
* @return the nested JSON object represented as a nested {@code Map<String, Object>}
* @throws IllegalArgumentException if the originalJsonMap is null or has invalid dot usage in field name
*/
public static Map<String, Object> unflattenJson(Map<String, Object> originalJsonMap) {
if (originalJsonMap == null) {
throw new IllegalArgumentException("originalJsonMap cannot be null");
}
Map<String, Object> result = new HashMap<>();
Stack<ProcessJsonObjectItem> stack = new Stack<>();

// Push initial items to stack
for (Map.Entry<String, Object> entry : originalJsonMap.entrySet()) {
stack.push(new ProcessJsonObjectItem(entry.getKey(), entry.getValue(), result));
}

// Process items until stack is empty
while (!stack.isEmpty()) {
ProcessJsonObjectItem item = stack.pop();
String key = item.key;
Object value = item.value;
Map<String, Object> currentMap = item.targetMap;

// Handle nested value
if (value instanceof Map) {
Map<String, Object> nestedMap = new HashMap<>();
for (Map.Entry<String, Object> entry : ((Map<String, Object>) value).entrySet()) {
stack.push(new ProcessJsonObjectItem(entry.getKey(), entry.getValue(), nestedMap));
}
value = nestedMap;
} else if (value instanceof List) {
value = handleList((List<Object>) value);
}

// If key contains dot, split and create nested structure
unflattenSingleItem(key, value, currentMap);
}

return result;
}
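A usage sketch of the Javadoc examples against the method as written (assumed inputs; imports elided):

Map<String, Object> flat = new HashMap<>();
flat.put("a.b", "c");
Map<String, Object> nested = ProcessorDocumentUtils.unflattenJson(flat);
// nested: {a={b=c}}

Map<String, Object> withList = new HashMap<>();
withList.put("a", List.of(Map.of("b.c", "d"), Map.of("b.c", "e")));
Map<String, Object> nestedList = ProcessorDocumentUtils.unflattenJson(withList);
// nestedList: {a=[{b={c=d}}, {b={c=e}}]}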

private static List<Object> handleList(List<Object> list) {
List<Object> result = new ArrayList<>();
Stack<ProcessJsonListItem> stack = new Stack<>();

// Push initial items to stack
for (int i = list.size() - 1; i >= 0; i--) {
stack.push(new ProcessJsonListItem(list.get(i), result));
}

// Process items until stack is empty
while (!stack.isEmpty()) {
ProcessJsonListItem item = stack.pop();
Object value = item.value;
List<Object> targetList = item.targetList;

if (value instanceof Map) {
Map<String, Object> nestedMap = new HashMap<>();
Map<String, Object> sourceMap = (Map<String, Object>) value;
for (Map.Entry<String, Object> entry : sourceMap.entrySet()) {
stack.push(new ProcessJsonListItem(new ProcessJsonObjectItem(entry.getKey(), entry.getValue(), nestedMap), targetList));
}
targetList.add(nestedMap);
} else if (value instanceof List) {
List<Object> nestedList = new ArrayList<>();
for (Object listItem : (List<Object>) value) {
stack.push(new ProcessJsonListItem(listItem, nestedList));
}
targetList.add(nestedList);
} else if (value instanceof ProcessJsonObjectItem) {
ProcessJsonObjectItem processJsonObjectItem = (ProcessJsonObjectItem) value;
Map<String, Object> tempMap = new HashMap<>();
unflattenSingleItem(processJsonObjectItem.key, processJsonObjectItem.value, tempMap);
targetList.set(targetList.size() - 1, tempMap);
} else {
targetList.add(value);
}
}

return result;
}
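A short trace of handleList on a single-entry map (assumed input), showing the placeholder-then-replace pattern:

// handleList(List.of(Map.of("b.c", "d"))) first adds an empty placeholder map
// to the result and pushes a ProcessJsonObjectItem wrapper; popping the wrapper
// unflattens the entry into a fresh map and overwrites the list's last element,
// yielding [{b={c=d}}].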

private static void unflattenSingleItem(String key, Object value, Map<String, Object> result) {
if (StringUtils.isBlank(key)) {
throw new IllegalArgumentException("Field name cannot be null or empty");
}
if (key.contains(".")) {
// Use split with -1 limit to preserve trailing empty strings
String[] parts = key.split("\\.", -1);
Map<String, Object> current = result;

for (int i = 0; i < parts.length; i++) {
if (StringUtils.isBlank(parts[i])) {
throw new IllegalArgumentException(String.format(Locale.ROOT, "Field name '%s' contains invalid dot usage", key));
}
if (i == parts.length - 1) {
current.put(parts[i], value);
continue;
}
current = (Map<String, Object>) current.computeIfAbsent(parts[i], k -> new HashMap<>());
}
} else {
result.put(key, value);
}
}
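For instance (illustrative only, since the method is private):

// unflattenSingleItem("a.b.c", "v", m) walks the segments, creating
// intermediate maps as needed, and leaves m as {a={b={c=v}}}; a blank segment,
// as in "a..b", triggers the IllegalArgumentException above.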

/**
* Validates that a field name is in the correct format: either a single segment like "a" or a dotted path like "a.b.c".
* Field names like "..a..b", "a..b", or "a.b.." are invalid.
* This is done by checking whether the string contains empty segments when split on dots.
*
* @param input the string to check
* @throws IllegalArgumentException if the input is null or has invalid dot usage
*/
private static void validateFieldName(String input) {
if (StringUtils.isBlank(input)) {
throw new IllegalArgumentException("Field name cannot be null or empty");
}

// Use split with -1 limit to preserve trailing empty strings
String[] segments = input.split("\\.", -1);

// Check if any segment is empty
for (String segment : segments) {
if (StringUtils.isBlank(segment)) {
throw new IllegalArgumentException(String.format(Locale.ROOT, "Field name '%s' contains invalid dot usage", input));
}
}
}

// Helper classes to maintain state during iteration
private static class ProcessJsonObjectItem {
String key;
Object value;
Map<String, Object> targetMap;

ProcessJsonObjectItem(String key, Object value, Map<String, Object> targetMap) {
this.key = key;
this.value = value;
this.targetMap = targetMap;
}
}

private static class ProcessJsonListItem {
Object value;
List<Object> targetList;

ProcessJsonListItem(Object value, List<Object> targetList) {
this.value = value;
this.targetList = targetList;
}
}

}
TextEmbeddingProcessorIT.java
@@ -50,6 +50,7 @@ public class TextEmbeddingProcessorIT extends BaseNeuralSearchIT {
private final String INGEST_DOC2 = Files.readString(Path.of(classLoader.getResource("processor/ingest_doc2.json").toURI()));
private final String INGEST_DOC3 = Files.readString(Path.of(classLoader.getResource("processor/ingest_doc3.json").toURI()));
private final String INGEST_DOC4 = Files.readString(Path.of(classLoader.getResource("processor/ingest_doc4.json").toURI()));
private final String INGEST_DOC5 = Files.readString(Path.of(classLoader.getResource("processor/ingest_doc5.json").toURI()));
private final String BULK_ITEM_TEMPLATE = Files.readString(
Path.of(classLoader.getResource("processor/bulk_item_template.json").toURI())
);
@@ -168,6 +169,23 @@ private void assertDoc(Map<String, Object> sourceMap, String textFieldValue, Opt
}
}

private void assertDocWithLevel2AsList(Map<String, Object> sourceMap) {
assertNotNull(sourceMap);
assertTrue(sourceMap.containsKey(LEVEL_1_FIELD));
assertTrue(sourceMap.get(LEVEL_1_FIELD) instanceof List);
List<Map<String, Object>> nestedPassages = (List<Map<String, Object>>) sourceMap.get(LEVEL_1_FIELD);
nestedPassages.forEach(nestedPassage -> {
assertTrue(nestedPassage.containsKey(LEVEL_2_FIELD));
Map<String, Object> level2 = (Map<String, Object>) nestedPassage.get(LEVEL_2_FIELD);
Map<String, Object> level3 = (Map<String, Object>) level2.get(LEVEL_3_FIELD_CONTAINER);
List<Double> embeddings = (List<Double>) level3.get(LEVEL_3_FIELD_EMBEDDING);
assertEquals(768, embeddings.size());
for (Double embedding : embeddings) {
assertTrue(embedding >= 0.0 && embedding <= 1.0);
}
});
}
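For reference, a document shape that would satisfy this assertion might look like the sketch below; the actual processor/ingest_doc5.json is not part of the rendered diff, so these field names are hypothetical:

// { "level_1": [
//     { "level_2": { "level_3_container": { "level_3_text": "..." } } },
//     { "level_2": { "level_3_container": { "level_3_text": "..." } } }
// ] }
// After the pipeline runs, each level_3_container additionally holds a
// 768-dimensional embedding under the key named by LEVEL_3_FIELD_EMBEDDING.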

public void testTextEmbeddingProcessor_withBatchSizeInProcessor() throws Exception {
String modelId = null;
try {
@@ -232,6 +250,49 @@ public void testTextEmbeddingProcessor_withFailureAndSkip() throws Exception {
}
}

@SuppressWarnings("unchecked")
public void testNestedFieldMapping_whenDocumentInListIngested_thenSuccessful() throws Exception {
String modelId = null;
try {
modelId = uploadTextEmbeddingModel();
loadModel(modelId);
createPipelineProcessor(modelId, PIPELINE_NAME, ProcessorType.TEXT_EMBEDDING_WITH_NESTED_FIELDS_MAPPING);
createIndexWithPipeline(INDEX_NAME, "IndexMappings.json", PIPELINE_NAME);
ingestDocument(INDEX_NAME, INGEST_DOC5, "5");

assertDocWithLevel2AsList((Map<String, Object>) getDocById(INDEX_NAME, "5").get("_source"));

NeuralQueryBuilder neuralQueryBuilderQuery = NeuralQueryBuilder.builder()
.fieldName(LEVEL_1_FIELD + "." + LEVEL_2_FIELD + "." + LEVEL_3_FIELD_CONTAINER + "." + LEVEL_3_FIELD_EMBEDDING)
.queryText(QUERY_TEXT)
.modelId(modelId)
.k(10)
.build();

QueryBuilder queryNestedLowerLevel = QueryBuilders.nestedQuery(
LEVEL_1_FIELD + "." + LEVEL_2_FIELD,
neuralQueryBuilderQuery,
ScoreMode.Total
);
QueryBuilder queryNestedHighLevel = QueryBuilders.nestedQuery(LEVEL_1_FIELD, queryNestedLowerLevel, ScoreMode.Total);

Map<String, Object> searchResponseAsMap = search(INDEX_NAME, queryNestedHighLevel, 2);
assertNotNull(searchResponseAsMap);

Map<String, Object> hits = (Map<String, Object>) searchResponseAsMap.get("hits");
assertNotNull(hits);

List<Map<String, Object>> listOfHits = (List<Map<String, Object>>) hits.get("hits");
assertNotNull(listOfHits);
assertEquals(1, listOfHits.size());

Map<String, Object> innerHitDetails = listOfHits.getFirst();
assertEquals("5", innerHitDetails.get("_id"));
} finally {
wipeOfTestResources(INDEX_NAME, PIPELINE_NAME, modelId, null);
}
}

private String uploadTextEmbeddingModel() throws Exception {
String requestBody = Files.readString(Path.of(classLoader.getResource("processor/UploadModelRequestBody.json").toURI()));
return registerModelGroupAndUploadModel(requestBody);
(Diff for the remaining changed files is not shown.)