Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] Supporting sparse semantic retrieval in neural search #343

Merged
merged 2 commits into from
Sep 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.10...2.x)
### Features
Support sparse semantic retrieval by introducing `sparse_encoding` ingest processor and query builder ([#333](https://github.com/opensearch-project/neural-search/pull/333))
### Enhancements
### Bug Fixes
### Infrastructure
Expand Down
2 changes: 2 additions & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,8 @@ dependencies {
runtimeOnly group: 'org.reflections', name: 'reflections', version: '0.9.12'
runtimeOnly group: 'org.javassist', name: 'javassist', version: '3.29.2-GA'
runtimeOnly group: 'org.opensearch', name: 'common-utils', version: "${opensearch_build}"
runtimeOnly group: 'com.google.code.gson', name: 'gson', version: '2.10.1'
runtimeOnly group: 'org.json', name: 'json', version: '20230227'
}

// In order to add the jar to the classpath, we need to unzip the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import lombok.NonNull;
import lombok.RequiredArgsConstructor;
import lombok.extern.log4j.Log4j2;

import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.util.CollectionUtils;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.MLInputDataset;
Expand Down Expand Up @@ -100,10 +102,38 @@ public void inferenceSentences(
@NonNull final List<String> inputText,
@NonNull final ActionListener<List<List<Float>>> listener
) {
inferenceSentencesWithRetry(targetResponseFilters, modelId, inputText, 0, listener);
retryableInferenceSentencesWithVectorResult(targetResponseFilters, modelId, inputText, 0, listener);
}

private void inferenceSentencesWithRetry(
public void inferenceSentencesWithMapResult(
@NonNull final String modelId,
@NonNull final List<String> inputText,
@NonNull final ActionListener<List<Map<String, ?>>> listener
) {
retryableInferenceSentencesWithMapResult(modelId, inputText, 0, listener);
}

private void retryableInferenceSentencesWithMapResult(
final String modelId,
final List<String> inputText,
final int retryTime,
final ActionListener<List<Map<String, ?>>> listener
) {
MLInput mlInput = createMLInput(null, inputText);
mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> {
final List<Map<String, ?>> result = buildMapResultFromResponse(mlOutput);
listener.onResponse(result);
}, e -> {
if (RetryUtil.shouldRetry(e, retryTime)) {
final int retryTimeAdd = retryTime + 1;
retryableInferenceSentencesWithMapResult(modelId, inputText, retryTimeAdd, listener);
} else {
listener.onFailure(e);
}
}));
}

private void retryableInferenceSentencesWithVectorResult(
final List<String> targetResponseFilters,
final String modelId,
final List<String> inputText,
Expand All @@ -113,12 +143,11 @@ private void inferenceSentencesWithRetry(
MLInput mlInput = createMLInput(targetResponseFilters, inputText);
mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> {
final List<List<Float>> vector = buildVectorFromResponse(mlOutput);
log.debug("Inference Response for input sentence {} is : {} ", inputText, vector);
listener.onResponse(vector);
}, e -> {
if (RetryUtil.shouldRetry(e, retryTime)) {
final int retryTimeAdd = retryTime + 1;
inferenceSentencesWithRetry(targetResponseFilters, modelId, inputText, retryTimeAdd, listener);
retryableInferenceSentencesWithVectorResult(targetResponseFilters, modelId, inputText, retryTimeAdd, listener);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why you name function like that? An function should be a verb phrase.

} else {
listener.onFailure(e);
}
Expand All @@ -144,4 +173,22 @@ private List<List<Float>> buildVectorFromResponse(MLOutput mlOutput) {
return vector;
}

private List<Map<String, ?>> buildMapResultFromResponse(MLOutput mlOutput) {
final ModelTensorOutput modelTensorOutput = (ModelTensorOutput) mlOutput;
final List<ModelTensors> tensorOutputList = modelTensorOutput.getMlModelOutputs();
if (CollectionUtils.isEmpty(tensorOutputList) || CollectionUtils.isEmpty(tensorOutputList.get(0).getMlModelTensors())) {
throw new IllegalStateException(
"Empty model result produced. Expected at least [1] tensor output and [1] model tensor, but got [0]"
);
}
List<Map<String, ?>> resultMaps = new ArrayList<>();
for (ModelTensors tensors : tensorOutputList) {
List<ModelTensor> tensorList = tensors.getMlModelTensors();
for (ModelTensor tensor : tensorList) {
resultMaps.add(tensor.getDataAsMap());
}
}
return resultMaps;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
Expand All @@ -31,15 +30,18 @@
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.NormalizationProcessor;
import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow;
import org.opensearch.neuralsearch.processor.SparseEncodingProcessor;
import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory;
import org.opensearch.neuralsearch.processor.combination.ScoreCombiner;
import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.SparseEncodingProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.TextEmbeddingProcessorFactory;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer;
import org.opensearch.neuralsearch.query.HybridQueryBuilder;
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;
import org.opensearch.neuralsearch.query.SparseEncodingQueryBuilder;
import org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher;
import org.opensearch.plugins.ActionPlugin;
import org.opensearch.plugins.ExtensiblePlugin;
Expand All @@ -62,7 +64,7 @@ public class NeuralSearch extends Plugin implements ActionPlugin, SearchPlugin,
private MLCommonsClientAccessor clientAccessor;
private NormalizationProcessorWorkflow normalizationProcessorWorkflow;
private final ScoreNormalizationFactory scoreNormalizationFactory = new ScoreNormalizationFactory();
private final ScoreCombinationFactory scoreCombinationFactory = new ScoreCombinationFactory();;
private final ScoreCombinationFactory scoreCombinationFactory = new ScoreCombinationFactory();

@Override
public Collection<Object> createComponents(
Expand All @@ -79,6 +81,7 @@ public Collection<Object> createComponents(
final Supplier<RepositoriesService> repositoriesServiceSupplier
) {
NeuralQueryBuilder.initialize(clientAccessor);
SparseEncodingQueryBuilder.initialize(clientAccessor);
normalizationProcessorWorkflow = new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner());
return List.of(clientAccessor);
}
Expand All @@ -87,14 +90,20 @@ public Collection<Object> createComponents(
public List<QuerySpec<?>> getQueries() {
return Arrays.asList(
new QuerySpec<>(NeuralQueryBuilder.NAME, NeuralQueryBuilder::new, NeuralQueryBuilder::fromXContent),
new QuerySpec<>(HybridQueryBuilder.NAME, HybridQueryBuilder::new, HybridQueryBuilder::fromXContent)
new QuerySpec<>(HybridQueryBuilder.NAME, HybridQueryBuilder::new, HybridQueryBuilder::fromXContent),
new QuerySpec<>(SparseEncodingQueryBuilder.NAME, SparseEncodingQueryBuilder::new, SparseEncodingQueryBuilder::fromXContent)
);
}

@Override
public Map<String, Processor.Factory> getProcessors(Processor.Parameters parameters) {
clientAccessor = new MLCommonsClientAccessor(new MachineLearningNodeClient(parameters.client));
return Collections.singletonMap(TextEmbeddingProcessor.TYPE, new TextEmbeddingProcessorFactory(clientAccessor, parameters.env));
return Map.of(
TextEmbeddingProcessor.TYPE,
new TextEmbeddingProcessorFactory(clientAccessor, parameters.env),
SparseEncodingProcessor.TYPE,
new SparseEncodingProcessorFactory(clientAccessor, parameters.env)
);
}

@Override
Expand Down
Loading
Loading