Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] Enable '.' for nested field in text embedding processor (#811) #825

Merged
merged 1 commit into from
Jul 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.15...2.x)
### Features
### Enhancements
* Adds dynamic knn query parameters efsearch and nprobes [#814](https://github.com/opensearch-project/neural-search/pull/814/)
- Adds dynamic knn query parameters efsearch and nprobes [#814](https://github.com/opensearch-project/neural-search/pull/814/)
- Enable '.' for nested field in text embedding processor ([#811](https://github.com/opensearch-project/neural-search/pull/811))
### Bug Fixes
- Fix for missing HybridQuery results when concurrent segment search is enabled ([#800](https://github.com/opensearch-project/neural-search/pull/800))
### Infrastructure
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
Expand All @@ -21,6 +22,8 @@
import lombok.AllArgsConstructor;
import lombok.Getter;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.opensearch.common.collect.Tuple;
import org.opensearch.core.common.util.CollectionUtils;
import org.opensearch.cluster.service.ClusterService;
Expand Down Expand Up @@ -120,7 +123,7 @@ public IngestDocument execute(IngestDocument ingestDocument) throws Exception {
public void execute(IngestDocument ingestDocument, BiConsumer<IngestDocument, Exception> handler) {
try {
validateEmbeddingFieldsValue(ingestDocument);
Map<String, Object> processMap = buildMapWithTargetKeyAndOriginalValue(ingestDocument);
Map<String, Object> processMap = buildMapWithTargetKeys(ingestDocument);
List<String> inferenceList = createInferenceList(processMap);
if (inferenceList.size() == 0) {
handler.accept(ingestDocument, null);
Expand Down Expand Up @@ -228,7 +231,7 @@ private List<DataForInference> getDataForInference(List<IngestDocumentWrapper> i
List<String> inferenceList = null;
try {
validateEmbeddingFieldsValue(ingestDocumentWrapper.getIngestDocument());
processMap = buildMapWithTargetKeyAndOriginalValue(ingestDocumentWrapper.getIngestDocument());
processMap = buildMapWithTargetKeys(ingestDocumentWrapper.getIngestDocument());
inferenceList = createInferenceList(processMap);
} catch (Exception e) {
ingestDocumentWrapper.update(ingestDocumentWrapper.getIngestDocument(), e);
Expand Down Expand Up @@ -276,15 +279,17 @@ private void createInferenceListForMapTypeInput(Object sourceValue, List<String>
}

@VisibleForTesting
Map<String, Object> buildMapWithTargetKeyAndOriginalValue(IngestDocument ingestDocument) {
Map<String, Object> buildMapWithTargetKeys(IngestDocument ingestDocument) {
Map<String, Object> sourceAndMetadataMap = ingestDocument.getSourceAndMetadata();
Map<String, Object> mapWithProcessorKeys = new LinkedHashMap<>();
for (Map.Entry<String, Object> fieldMapEntry : fieldMap.entrySet()) {
String originalKey = fieldMapEntry.getKey();
Object targetKey = fieldMapEntry.getValue();
Pair<String, Object> processedNestedKey = processNestedKey(fieldMapEntry);
String originalKey = processedNestedKey.getKey();
Object targetKey = processedNestedKey.getValue();

if (targetKey instanceof Map) {
Map<String, Object> treeRes = new LinkedHashMap<>();
buildMapWithProcessorKeyAndOriginalValueForMapType(originalKey, targetKey, sourceAndMetadataMap, treeRes);
buildNestedMap(originalKey, targetKey, sourceAndMetadataMap, treeRes);
mapWithProcessorKeys.put(originalKey, treeRes.get(originalKey));
} else {
mapWithProcessorKeys.put(String.valueOf(targetKey), sourceAndMetadataMap.get(originalKey));
Expand All @@ -293,20 +298,19 @@ Map<String, Object> buildMapWithTargetKeyAndOriginalValue(IngestDocument ingestD
return mapWithProcessorKeys;
}

private void buildMapWithProcessorKeyAndOriginalValueForMapType(
String parentKey,
Object processorKey,
Map<String, Object> sourceAndMetadataMap,
Map<String, Object> treeRes
) {
if (processorKey == null || sourceAndMetadataMap == null) return;
@VisibleForTesting
void buildNestedMap(String parentKey, Object processorKey, Map<String, Object> sourceAndMetadataMap, Map<String, Object> treeRes) {
if (Objects.isNull(processorKey) || Objects.isNull(sourceAndMetadataMap)) {
return;
}
if (processorKey instanceof Map) {
Map<String, Object> next = new LinkedHashMap<>();
if (sourceAndMetadataMap.get(parentKey) instanceof Map) {
for (Map.Entry<String, Object> nestedFieldMapEntry : ((Map<String, Object>) processorKey).entrySet()) {
buildMapWithProcessorKeyAndOriginalValueForMapType(
nestedFieldMapEntry.getKey(),
nestedFieldMapEntry.getValue(),
Pair<String, Object> processedNestedKey = processNestedKey(nestedFieldMapEntry);
buildNestedMap(
processedNestedKey.getKey(),
processedNestedKey.getValue(),
(Map<String, Object>) sourceAndMetadataMap.get(parentKey),
next
);
Expand All @@ -317,21 +321,46 @@ private void buildMapWithProcessorKeyAndOriginalValueForMapType(
List<Object> listOfStrings = list.stream().map(x -> x.get(nestedFieldMapEntry.getKey())).collect(Collectors.toList());
Map<String, Object> map = new LinkedHashMap<>();
map.put(nestedFieldMapEntry.getKey(), listOfStrings);
buildMapWithProcessorKeyAndOriginalValueForMapType(
nestedFieldMapEntry.getKey(),
nestedFieldMapEntry.getValue(),
map,
next
);
buildNestedMap(nestedFieldMapEntry.getKey(), nestedFieldMapEntry.getValue(), map, next);
}
}
treeRes.put(parentKey, next);
treeRes.merge(parentKey, next, (v1, v2) -> {
if (v1 instanceof Collection && v2 instanceof Collection) {
((Collection) v1).addAll((Collection) v2);
return v1;
} else if (v1 instanceof Map && v2 instanceof Map) {
((Map) v1).putAll((Map) v2);
return v1;
} else {
return v2;
}
});
} else {
String key = String.valueOf(processorKey);
treeRes.put(key, sourceAndMetadataMap.get(parentKey));
}
}

/**
* Process the nested key, such as "a.b.c" to "a", "b.c"
* @param nestedFieldMapEntry
* @return A pair of the original key and the target key
*/
@VisibleForTesting
protected Pair<String, Object> processNestedKey(final Map.Entry<String, Object> nestedFieldMapEntry) {
String originalKey = nestedFieldMapEntry.getKey();
Object targetKey = nestedFieldMapEntry.getValue();
int nestedDotIndex = originalKey.indexOf('.');
if (nestedDotIndex != -1) {
Map<String, Object> newTargetKey = new LinkedHashMap<>();
newTargetKey.put(originalKey.substring(nestedDotIndex + 1), targetKey);
targetKey = newTargetKey;

originalKey = originalKey.substring(0, nestedDotIndex);
}
return new ImmutablePair<>(originalKey, targetKey);
}

private void validateEmbeddingFieldsValue(IngestDocument ingestDocument) {
Map<String, Object> sourceAndMetadataMap = ingestDocument.getSourceAndMetadata();
String indexName = sourceAndMetadataMap.get(IndexFieldMapper.NAME).toString();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package org.opensearch.neuralsearch.processor;

import com.google.common.collect.ImmutableList;
import org.apache.commons.lang.math.RandomUtils;
import org.opensearch.index.mapper.IndexFieldMapper;
import org.opensearch.ingest.IngestDocument;
import org.opensearch.ingest.IngestDocumentWrapper;
Expand Down Expand Up @@ -58,4 +59,17 @@ protected List<List<Float>> createMockVectorResult() {
modelTensorList.add(number7);
return modelTensorList;
}

protected List<List<Float>> createRandomOneDimensionalMockVector(int numOfVectors, int vectorDimension, float min, float max) {
List<List<Float>> result = new ArrayList<>();
for (int i = 0; i < numOfVectors; i++) {
List<Float> numbers = new ArrayList<>();
for (int j = 0; j < vectorDimension; j++) {
Float nextFloat = RandomUtils.nextFloat() * (max - min) + min;
numbers.add(nextFloat);
}
result.add(numbers);
}
return result;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,31 @@
import org.apache.http.message.BasicHeader;
import org.apache.http.util.EntityUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.lucene.search.join.ScoreMode;
import org.junit.Before;
import org.opensearch.client.Response;
import org.opensearch.common.xcontent.XContentHelper;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.neuralsearch.BaseNeuralSearchIT;

import com.google.common.collect.ImmutableList;
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;

public class TextEmbeddingProcessorIT extends BaseNeuralSearchIT {

private static final String INDEX_NAME = "text_embedding_index";

private static final String PIPELINE_NAME = "pipeline-hybrid";
protected static final String QUERY_TEXT = "hello";
protected static final String LEVEL_1_FIELD = "nested_passages";
protected static final String LEVEL_2_FIELD = "level_2";
protected static final String LEVEL_3_FIELD_TEXT = "level_3_text";
protected static final String LEVEL_3_FIELD_EMBEDDING = "level_3_embedding";
private final String INGEST_DOC1 = Files.readString(Path.of(classLoader.getResource("processor/ingest_doc1.json").toURI()));
private final String INGEST_DOC2 = Files.readString(Path.of(classLoader.getResource("processor/ingest_doc2.json").toURI()));
private final String INGEST_DOC3 = Files.readString(Path.of(classLoader.getResource("processor/ingest_doc3.json").toURI()));
private final String BULK_ITEM_TEMPLATE = Files.readString(
Path.of(classLoader.getResource("processor/bulk_item_template.json").toURI())
);
Expand Down Expand Up @@ -77,6 +87,66 @@ public void testTextEmbeddingProcessor_batch() throws Exception {
}
}

public void testNestedFieldMapping_whenDocumentsIngested_thenSuccessful() throws Exception {
String modelId = null;
try {
modelId = uploadTextEmbeddingModel();
loadModel(modelId);
createPipelineProcessor(modelId, PIPELINE_NAME, ProcessorType.TEXT_EMBEDDING_WITH_NESTED_FIELDS_MAPPING);
createTextEmbeddingIndex();
ingestDocument(INGEST_DOC3, "3");

Map<String, Object> sourceMap = (Map<String, Object>) getDocById(INDEX_NAME, "3").get("_source");
assertNotNull(sourceMap);
assertTrue(sourceMap.containsKey(LEVEL_1_FIELD));
Map<String, Object> nestedPassages = (Map<String, Object>) sourceMap.get(LEVEL_1_FIELD);
assertTrue(nestedPassages.containsKey(LEVEL_2_FIELD));
Map<String, Object> level2 = (Map<String, Object>) nestedPassages.get(LEVEL_2_FIELD);
assertEquals(QUERY_TEXT, level2.get(LEVEL_3_FIELD_TEXT));
assertTrue(level2.containsKey(LEVEL_3_FIELD_EMBEDDING));
List<Double> embeddings = (List<Double>) level2.get(LEVEL_3_FIELD_EMBEDDING);
assertEquals(768, embeddings.size());
for (Double embedding : embeddings) {
assertTrue(embedding >= 0.0 && embedding <= 1.0);
}

NeuralQueryBuilder neuralQueryBuilderQuery = new NeuralQueryBuilder(
LEVEL_1_FIELD + "." + LEVEL_2_FIELD + "." + LEVEL_3_FIELD_EMBEDDING,
QUERY_TEXT,
"",
modelId,
10,
null,
null,
null,
null,
null
);
QueryBuilder queryNestedLowerLevel = QueryBuilders.nestedQuery(
LEVEL_1_FIELD + "." + LEVEL_2_FIELD,
neuralQueryBuilderQuery,
ScoreMode.Total
);
QueryBuilder queryNestedHighLevel = QueryBuilders.nestedQuery(LEVEL_1_FIELD, queryNestedLowerLevel, ScoreMode.Total);

Map<String, Object> searchResponseAsMap = search(INDEX_NAME, queryNestedHighLevel, 1);
assertNotNull(searchResponseAsMap);

Map<String, Object> hits = (Map<String, Object>) searchResponseAsMap.get("hits");
assertNotNull(hits);

assertEquals(1.0, hits.get("max_score"));
List<Map<String, Object>> listOfHits = (List<Map<String, Object>>) hits.get("hits");
assertNotNull(listOfHits);
assertEquals(1, listOfHits.size());
Map<String, Object> hitsInner = listOfHits.get(0);
assertEquals("3", hitsInner.get("_id"));
assertEquals(1.0, hitsInner.get("_score"));
} finally {
wipeOfTestResources(INDEX_NAME, PIPELINE_NAME, modelId, null);
}
}

private String uploadTextEmbeddingModel() throws Exception {
String requestBody = Files.readString(Path.of(classLoader.getResource("processor/UploadModelRequestBody.json").toURI()));
return registerModelGroupAndUploadModel(requestBody);
Expand Down
Loading
Loading