Fix bug where document embedding fails to be generated because the document has a dot in a field name, which does not match the field mapping exactly

Signed-off-by: Yizhe Liu <[email protected]>
yizheliu-amazon committed Jan 6, 2025
1 parent 4c119f0 commit 8559262
Showing 6 changed files with 631 additions and 95 deletions.
@@ -137,6 +137,7 @@ public IngestDocument execute(IngestDocument ingestDocument) throws Exception {
@Override
public void execute(IngestDocument ingestDocument, BiConsumer<IngestDocument, Exception> handler) {
try {
+ preprocessIngestDocument(ingestDocument);
validateEmbeddingFieldsValue(ingestDocument);
Map<String, Object> processMap = buildMapWithTargetKeys(ingestDocument);
List<String> inferenceList = createInferenceList(processMap);
@@ -150,6 +151,16 @@ public void execute(IngestDocument ingestDocument, BiConsumer<IngestDocument, Exception> handler) {
}
}

@VisibleForTesting
void preprocessIngestDocument(IngestDocument ingestDocument) {
if (ingestDocument == null) return;
Map<String, Object> sourceAndMetadataMap = ingestDocument.getSourceAndMetadata();
if (sourceAndMetadataMap == null) return;
Map<String, Object> unflattened = ProcessorDocumentUtils.unflattenJson(sourceAndMetadataMap);
unflattened.forEach(ingestDocument::setFieldValue);
sourceAndMetadataMap.keySet().removeIf(key -> key.contains("."));
}
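
A rough sketch of what this preprocessing achieves (the dotted field name here is hypothetical, not taken from this change):

Map<String, Object> source = new HashMap<>(Map.of("a.b", "hello"));
Map<String, Object> unflattened = ProcessorDocumentUtils.unflattenJson(source);
// unflattened is {a={b=hello}}; setFieldValue writes this nested form back
// into the document, and removeIf then drops the original dotted key so the
// source matches the field mapping exactly.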

/**
* This is the function which does actual inference work for batchExecute interface.
* @param inferenceList a list of String for inference.
@@ -244,12 +255,14 @@ private List<DataForInference> getDataForInference(List<IngestDocumentWrapper> ingestDocumentWrappers) {
for (IngestDocumentWrapper ingestDocumentWrapper : ingestDocumentWrappers) {
Map<String, Object> processMap = null;
List<String> inferenceList = null;
+ IngestDocument ingestDocument = ingestDocumentWrapper.getIngestDocument();
try {
- validateEmbeddingFieldsValue(ingestDocumentWrapper.getIngestDocument());
- processMap = buildMapWithTargetKeys(ingestDocumentWrapper.getIngestDocument());
+ preprocessIngestDocument(ingestDocument);
+ validateEmbeddingFieldsValue(ingestDocument);
+ processMap = buildMapWithTargetKeys(ingestDocument);
inferenceList = createInferenceList(processMap);
} catch (Exception e) {
- ingestDocumentWrapper.update(ingestDocumentWrapper.getIngestDocument(), e);
+ ingestDocumentWrapper.update(ingestDocument, e);
} finally {
dataForInferences.add(new DataForInference(ingestDocumentWrapper, processMap, inferenceList));
}
@@ -333,13 +346,14 @@ void buildNestedMap(String parentKey, Object processorKey, Map<String, Object> sourceAndMetadataMap, Map<String, Object> treeRes) {
} else if (sourceAndMetadataMap.get(parentKey) instanceof List) {
for (Map.Entry<String, Object> nestedFieldMapEntry : ((Map<String, Object>) processorKey).entrySet()) {
List<Map<String, Object>> list = (List<Map<String, Object>>) sourceAndMetadataMap.get(parentKey);
+ Pair<String, Object> processedNestedKey = processNestedKey(nestedFieldMapEntry);
List<Object> listOfStrings = list.stream().map(x -> {
- Object nestedSourceValue = x.get(nestedFieldMapEntry.getKey());
+ Object nestedSourceValue = x.get(processedNestedKey.getKey());
return normalizeSourceValue(nestedSourceValue);
}).collect(Collectors.toList());
Map<String, Object> map = new LinkedHashMap<>();
- map.put(nestedFieldMapEntry.getKey(), listOfStrings);
- buildNestedMap(nestedFieldMapEntry.getKey(), nestedFieldMapEntry.getValue(), map, next);
+ map.put(processedNestedKey.getKey(), listOfStrings);
+ buildNestedMap(processedNestedKey.getKey(), processedNestedKey.getValue(), map, next);
}
}
treeRes.merge(parentKey, next, REMAPPING_FUNCTION);
@@ -387,7 +401,7 @@ private void validateEmbeddingFieldsValue(IngestDocument ingestDocument) {
ProcessorDocumentUtils.validateMapTypeValue(
FIELD_MAP_FIELD,
sourceAndMetadataMap,
- fieldMap,
+ ProcessorDocumentUtils.unflattenJson(fieldMap),
indexName,
clusterService,
environment,
@@ -12,11 +12,14 @@
import org.opensearch.env.Environment;
import org.opensearch.index.mapper.MapperService;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Stack;

/**
* This class is used to accommodate the common code pieces of parsing, validating and processing the document for multiple
@@ -178,4 +181,128 @@ private static void validateDepth(
);
}
}

/**
* Unflatten a JSON object represented as a {@code Map<String, Object>}, whose field names may contain dots,
* into a nested {@code Map<String, Object>}.
* The "Object" value can be a {@code Map<String, Object>}, a {@code List<Object>}, or simply a String.
* For example, input is {"a.b": "c"}, output is {"a":{"b": "c"}}.
* Another example:
* input is {"a": [{"b.c": "d"}, {"b.c": "e"}]},
* output is {"a": [{"b": {"c": "d"}}, {"b": {"c": "e"}}]}
* @param originalJsonMap the original JSON object represented as a {@code Map<String, Object>}
* @return the nested JSON object represented as a nested {@code Map<String, Object>}
*/
public static Map<String, Object> unflattenJson(Map<String, Object> originalJsonMap) {
Map<String, Object> result = new HashMap<>();
Stack<ProcessJsonObjectItem> stack = new Stack<>();

// Push initial items to stack
for (Map.Entry<String, Object> entry : originalJsonMap.entrySet()) {
stack.push(new ProcessJsonObjectItem(entry.getKey(), entry.getValue(), result));
}

// Process items until stack is empty
while (!stack.isEmpty()) {
ProcessJsonObjectItem item = stack.pop();
String key = item.key;
Object value = item.value;
Map<String, Object> currentMap = item.targetMap;

// Handle nested value
if (value instanceof Map) {
Map<String, Object> nestedMap = new HashMap<>();
for (Map.Entry<String, Object> entry : ((Map<String, Object>) value).entrySet()) {
stack.push(new ProcessJsonObjectItem(entry.getKey(), entry.getValue(), nestedMap));
}
value = nestedMap;
} else if (value instanceof List) {
value = handleList((List<Object>) value);
}

// If key contains dot, split and create nested structure
unflattenSingleItem(key, value, currentMap);
}

return result;
}
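
A minimal usage sketch of the two Javadoc examples above (keys and values are illustrative only):

Map<String, Object> flat = Map.of("a.b", "c");
// unflattenJson(flat) -> {a={b=c}}
Map<String, Object> withList = Map.of("a", List.of(Map.of("b.c", "d"), Map.of("b.c", "e")));
// unflattenJson(withList) -> {a=[{b={c=d}}, {b={c=e}}]}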

private static List<Object> handleList(List<Object> list) {
List<Object> result = new ArrayList<>();
Stack<ProcessJsonListItem> stack = new Stack<>();

// Push initial items to stack
for (int i = list.size() - 1; i >= 0; i--) {
stack.push(new ProcessJsonListItem(list.get(i), result));
}

// Process items until stack is empty
while (!stack.isEmpty()) {
ProcessJsonListItem item = stack.pop();
Object value = item.value;
List<Object> targetList = item.targetList;

if (value instanceof Map) {
Map<String, Object> nestedMap = new HashMap<>();
Map<String, Object> sourceMap = (Map<String, Object>) value;
for (Map.Entry<String, Object> entry : sourceMap.entrySet()) {
stack.push(new ProcessJsonListItem(new ProcessJsonObjectItem(entry.getKey(), entry.getValue(), nestedMap), targetList));
}
targetList.add(nestedMap);
} else if (value instanceof List) {
List<Object> nestedList = new ArrayList<>();
for (Object listItem : (List<Object>) value) {
stack.push(new ProcessJsonListItem(listItem, nestedList));
}
targetList.add(nestedList);
} else if (value instanceof ProcessJsonObjectItem) {
ProcessJsonObjectItem processJsonObjectItem = (ProcessJsonObjectItem) value;
Map<String, Object> tempMap = new HashMap<>();
unflattenSingleItem(processJsonObjectItem.key, processJsonObjectItem.value, tempMap);
targetList.set(targetList.size() - 1, tempMap);
} else {
targetList.add(value);
}
}

return result;
}
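
Under the same assumptions as the Javadoc example, a sketch of what handleList produces for a list of single-key maps:

// handleList(List.of(Map.of("b.c", "d"), Map.of("b.c", "e")))
// -> [{b={c=d}}, {b={c=e}}]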

private static void unflattenSingleItem(String key, Object value, Map<String, Object> result) {
if (key.contains(".")) {
String[] parts = key.split("\\.");
Map<String, Object> current = result;

for (int i = 0; i < parts.length - 1; i++) {
current = (Map<String, Object>) current.computeIfAbsent(parts[i], k -> new HashMap<>());
}
current.put(parts[parts.length - 1], value);
} else {
result.put(key, value);
}
}
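
For instance (illustrative key and value), unflattenSingleItem walks the dotted path, creating one nested map per segment:

Map<String, Object> out = new HashMap<>();
unflattenSingleItem("a.b.c", "d", out);
// out == {a={b={c=d}}}: "a" and "b" are created via computeIfAbsent, and the
// final segment "c" receives the value.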

// Helper classes to maintain state during iteration
private static class ProcessJsonObjectItem {
String key;
Object value;
Map<String, Object> targetMap;

ProcessJsonObjectItem(String key, Object value, Map<String, Object> targetMap) {
this.key = key;
this.value = value;
this.targetMap = targetMap;
}
}

private static class ProcessJsonListItem {
Object value;
List<Object> targetList;

ProcessJsonListItem(Object value, List<Object> targetList) {
this.value = value;
this.targetList = targetList;
}
}

}
@@ -51,6 +51,7 @@ public class TextEmbeddingProcessorIT extends BaseNeuralSearchIT {
private final String INGEST_DOC2 = Files.readString(Path.of(classLoader.getResource("processor/ingest_doc2.json").toURI()));
private final String INGEST_DOC3 = Files.readString(Path.of(classLoader.getResource("processor/ingest_doc3.json").toURI()));
private final String INGEST_DOC4 = Files.readString(Path.of(classLoader.getResource("processor/ingest_doc4.json").toURI()));
private final String INGEST_DOC5 = Files.readString(Path.of(classLoader.getResource("processor/ingest_doc5.json").toURI()));
private final String BULK_ITEM_TEMPLATE = Files.readString(
Path.of(classLoader.getResource("processor/bulk_item_template.json").toURI())
);
@@ -176,6 +177,23 @@ private void assertDoc(Map<String, Object> sourceMap, String textFieldValue, Opt
}
}

private void assertDocWithLevel2AsList(Map<String, Object> sourceMap) {
assertNotNull(sourceMap);
assertTrue(sourceMap.containsKey(LEVEL_1_FIELD));
assertTrue(sourceMap.get(LEVEL_1_FIELD) instanceof List);
List<Map<String, Object>> nestedPassages = (List<Map<String, Object>>) sourceMap.get(LEVEL_1_FIELD);
nestedPassages.forEach(nestedPassage -> {
assertTrue(nestedPassage.containsKey(LEVEL_2_FIELD));
Map<String, Object> level2 = (Map<String, Object>) nestedPassage.get(LEVEL_2_FIELD);
Map<String, Object> level3 = (Map<String, Object>) level2.get(LEVEL_3_FIELD_CONTAINER);
List<Double> embeddings = (List<Double>) level3.get(LEVEL_3_FIELD_EMBEDDING);
assertEquals(768, embeddings.size());
for (Double embedding : embeddings) {
assertTrue(embedding >= 0.0 && embedding <= 1.0);
}
});
}

public void testTextEmbeddingProcessor_withBatchSizeInProcessor() throws Exception {
String modelId = null;
try {
@@ -240,6 +258,56 @@ public void testTextEmbeddingProcessor_withFailureAndSkip() throws Exception {
}
}

@SuppressWarnings("unchecked")
public void testNestedFieldMapping_whenDocumentInListIngested_thenSuccessful() throws Exception {
String modelId = null;
try {
modelId = uploadTextEmbeddingModel();
loadModel(modelId);
createPipelineProcessor(modelId, PIPELINE_NAME, ProcessorType.TEXT_EMBEDDING_WITH_NESTED_FIELDS_MAPPING);
createTextEmbeddingIndex();
ingestDocument(INGEST_DOC5, "5");

assertDocWithLevel2AsList((Map<String, Object>) getDocById(INDEX_NAME, "5").get("_source"));

NeuralQueryBuilder neuralQueryBuilderQuery = new NeuralQueryBuilder(
LEVEL_1_FIELD + "." + LEVEL_2_FIELD + "." + LEVEL_3_FIELD_CONTAINER + "." + LEVEL_3_FIELD_EMBEDDING,
QUERY_TEXT,
"",
modelId,
10,
null,
null,
null,
null,
null,
null,
null
);
QueryBuilder queryNestedLowerLevel = QueryBuilders.nestedQuery(
LEVEL_1_FIELD + "." + LEVEL_2_FIELD,
neuralQueryBuilderQuery,
ScoreMode.Total
);
QueryBuilder queryNestedHighLevel = QueryBuilders.nestedQuery(LEVEL_1_FIELD, queryNestedLowerLevel, ScoreMode.Total);

Map<String, Object> searchResponseAsMap = search(INDEX_NAME, queryNestedHighLevel, 2);
assertNotNull(searchResponseAsMap);

Map<String, Object> hits = (Map<String, Object>) searchResponseAsMap.get("hits");
assertNotNull(hits);

List<Map<String, Object>> listOfHits = (List<Map<String, Object>>) hits.get("hits");
assertNotNull(listOfHits);
assertEquals(1, listOfHits.size());

Map<String, Object> innerHitDetails = listOfHits.getFirst();
assertEquals("5", innerHitDetails.get("_id"));
} finally {
wipeOfTestResources(INDEX_NAME, PIPELINE_NAME, modelId, null);
}
}

private String uploadTextEmbeddingModel() throws Exception {
String requestBody = Files.readString(Path.of(classLoader.getResource("processor/UploadModelRequestBody.json").toURI()));
return registerModelGroupAndUploadModel(requestBody);