Commit d0b70c7: Rebasing

vibrantvarun committed Sep 27, 2023
Signed-off-by: Varun Jain <[email protected]>
2 parents 9c010e7 + 7bef7a0

Showing 24 changed files with 2,272 additions and 314 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.10...2.x)
### Features
Support sparse semantic retrieval by introducing `sparse_encoding` ingest processor and query builder ([#333](https://github.com/opensearch-project/neural-search/pull/333))
### Enhancements
### Bug Fixes
### Infrastructure
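As context for the feature entry above, the sketch below shows how the new sparse encoding query clause might be constructed programmatically. It is a hypothetical illustration: the `passage_embedding` field name and the fluent `fieldName`/`queryText`/`modelId` setters are assumptions made for the example (mirroring the style of the existing `NeuralQueryBuilder`), not details confirmed by this diff.

```java
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.neuralsearch.query.SparseEncodingQueryBuilder;

public class SparseEncodingQuerySketch {

    // Hypothetical sketch of the new sparse_encoding query clause; the setter
    // names and the target field name are assumptions, not taken from this diff.
    static QueryBuilder buildSparseQuery() {
        return new SparseEncodingQueryBuilder()
            .fieldName("passage_embedding")                  // field assumed to be populated by the sparse_encoding ingest processor
            .queryText("what is sparse semantic retrieval")  // raw query text to be expanded into token weights by the model
            .modelId("<sparse-encoding-model-id>");          // placeholder ID of a deployed sparse encoding model
    }
}
```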
2 changes: 2 additions & 0 deletions build.gradle
@@ -151,6 +151,8 @@ dependencies {
runtimeOnly group: 'org.reflections', name: 'reflections', version: '0.9.12'
runtimeOnly group: 'org.javassist', name: 'javassist', version: '3.29.2-GA'
runtimeOnly group: 'org.opensearch', name: 'common-utils', version: "${opensearch_build}"
runtimeOnly group: 'com.google.code.gson', name: 'gson', version: '2.10.1'
runtimeOnly group: 'org.json', name: 'json', version: '20230227'
}

// In order to add the jar to the classpath, we need to unzip the
src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java
@@ -8,13 +8,15 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import lombok.NonNull;
import lombok.RequiredArgsConstructor;
import lombok.extern.log4j.Log4j2;

import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.util.CollectionUtils;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.MLInputDataset;
@@ -100,10 +102,38 @@ public void inferenceSentences(
@NonNull final List<String> inputText,
@NonNull final ActionListener<List<List<Float>>> listener
) {
inferenceSentencesWithRetry(targetResponseFilters, modelId, inputText, 0, listener);
retryableInferenceSentencesWithVectorResult(targetResponseFilters, modelId, inputText, 0, listener);
}

private void inferenceSentencesWithRetry(
public void inferenceSentencesWithMapResult(
@NonNull final String modelId,
@NonNull final List<String> inputText,
@NonNull final ActionListener<List<Map<String, ?>>> listener
) {
retryableInferenceSentencesWithMapResult(modelId, inputText, 0, listener);
}

private void retryableInferenceSentencesWithMapResult(
final String modelId,
final List<String> inputText,
final int retryTime,
final ActionListener<List<Map<String, ?>>> listener
) {
MLInput mlInput = createMLInput(null, inputText);
mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> {
final List<Map<String, ?>> result = buildMapResultFromResponse(mlOutput);
listener.onResponse(result);
}, e -> {
if (RetryUtil.shouldRetry(e, retryTime)) {
final int retryTimeAdd = retryTime + 1;
retryableInferenceSentencesWithMapResult(modelId, inputText, retryTimeAdd, listener);
} else {
listener.onFailure(e);
}
}));
}

private void retryableInferenceSentencesWithVectorResult(
final List<String> targetResponseFilters,
final String modelId,
final List<String> inputText,
@@ -113,12 +143,11 @@ private void inferenceSentencesWithRetry(
MLInput mlInput = createMLInput(targetResponseFilters, inputText);
mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> {
final List<List<Float>> vector = buildVectorFromResponse(mlOutput);
log.debug("Inference Response for input sentence {} is : {} ", inputText, vector);
listener.onResponse(vector);
}, e -> {
if (RetryUtil.shouldRetry(e, retryTime)) {
final int retryTimeAdd = retryTime + 1;
inferenceSentencesWithRetry(targetResponseFilters, modelId, inputText, retryTimeAdd, listener);
retryableInferenceSentencesWithVectorResult(targetResponseFilters, modelId, inputText, retryTimeAdd, listener);
} else {
listener.onFailure(e);
}
@@ -144,4 +173,22 @@ private List<List<Float>> buildVectorFromResponse(MLOutput mlOutput) {
return vector;
}

private List<Map<String, ?>> buildMapResultFromResponse(MLOutput mlOutput) {
final ModelTensorOutput modelTensorOutput = (ModelTensorOutput) mlOutput;
final List<ModelTensors> tensorOutputList = modelTensorOutput.getMlModelOutputs();
if (CollectionUtils.isEmpty(tensorOutputList) || CollectionUtils.isEmpty(tensorOutputList.get(0).getMlModelTensors())) {
throw new IllegalStateException(
"Empty model result produced. Expected at least [1] tensor output and [1] model tensor, but got [0]"
);
}
List<Map<String, ?>> resultMaps = new ArrayList<>();
for (ModelTensors tensors : tensorOutputList) {
List<ModelTensor> tensorList = tensors.getMlModelTensors();
for (ModelTensor tensor : tensorList) {
resultMaps.add(tensor.getDataAsMap());
}
}
return resultMaps;
}

}
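For readers unfamiliar with the asynchronous ML Commons client, below is a minimal sketch of how a caller (for example, an ingest processor) might consume the new `inferenceSentencesWithMapResult` API added above. The wrapper class and the token-to-weight interpretation of each returned map are assumptions made for illustration; the accessor method signature and the `ActionListener.wrap` idiom are taken from the diff.

```java
import java.util.List;
import java.util.Map;

import org.opensearch.core.action.ActionListener;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;

public class SparseEncodingCallerSketch {

    private final MLCommonsClientAccessor mlClientAccessor;

    public SparseEncodingCallerSketch(MLCommonsClientAccessor mlClientAccessor) {
        this.mlClientAccessor = mlClientAccessor;
    }

    // Sketch only: requests map-shaped inference results and prints them.
    // Each map is assumed to be a token -> weight expansion from a sparse encoding model.
    public void encodeAsync(String modelId, List<String> texts) {
        mlClientAccessor.inferenceSentencesWithMapResult(modelId, texts, ActionListener.wrap(results -> {
            for (Map<String, ?> tokenWeights : results) {
                tokenWeights.forEach((token, weight) -> System.out.println(token + " -> " + weight));
            }
        }, e -> {
            // By the time this fires, the accessor has already exhausted its internal retries.
            System.err.println("Sparse encoding inference failed: " + e.getMessage());
        }));
    }
}
```

Note that retry handling lives inside the accessor (the `RetryUtil.shouldRetry` check and the recursive `retryableInferenceSentencesWith*Result` calls), so callers only observe the final success or failure.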
15 changes: 12 additions & 3 deletions src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java
@@ -9,7 +9,6 @@

import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
@@ -32,15 +31,18 @@
import org.opensearch.neuralsearch.processor.NeuralQueryProcessor;
import org.opensearch.neuralsearch.processor.NormalizationProcessor;
import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow;
import org.opensearch.neuralsearch.processor.SparseEncodingProcessor;
import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory;
import org.opensearch.neuralsearch.processor.combination.ScoreCombiner;
import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.SparseEncodingProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.TextEmbeddingProcessorFactory;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer;
import org.opensearch.neuralsearch.query.HybridQueryBuilder;
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;
import org.opensearch.neuralsearch.query.SparseEncodingQueryBuilder;
import org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher;
import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil;
import org.opensearch.plugins.ActionPlugin;
@@ -83,6 +85,7 @@ public Collection<Object> createComponents(
) {
NeuralSearchClusterUtil.instance().initialize(clusterService);

[Codecov / codecov/patch: added line #L86 in src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java was not covered by tests]
NeuralQueryBuilder.initialize(clientAccessor);
SparseEncodingQueryBuilder.initialize(clientAccessor);
normalizationProcessorWorkflow = new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner());
return List.of(clientAccessor);
}
@@ -91,14 +94,20 @@ public Collection<Object> createComponents(
public List<QuerySpec<?>> getQueries() {
return Arrays.asList(
new QuerySpec<>(NeuralQueryBuilder.NAME, NeuralQueryBuilder::new, NeuralQueryBuilder::fromXContent),
new QuerySpec<>(HybridQueryBuilder.NAME, HybridQueryBuilder::new, HybridQueryBuilder::fromXContent)
new QuerySpec<>(HybridQueryBuilder.NAME, HybridQueryBuilder::new, HybridQueryBuilder::fromXContent),
new QuerySpec<>(SparseEncodingQueryBuilder.NAME, SparseEncodingQueryBuilder::new, SparseEncodingQueryBuilder::fromXContent)
);
}

@Override
public Map<String, Processor.Factory> getProcessors(Processor.Parameters parameters) {
clientAccessor = new MLCommonsClientAccessor(new MachineLearningNodeClient(parameters.client));
return Collections.singletonMap(TextEmbeddingProcessor.TYPE, new TextEmbeddingProcessorFactory(clientAccessor, parameters.env));
return Map.of(
TextEmbeddingProcessor.TYPE,
new TextEmbeddingProcessorFactory(clientAccessor, parameters.env),
SparseEncodingProcessor.TYPE,
new SparseEncodingProcessorFactory(clientAccessor, parameters.env)
);
}

@Override
(Diff for the remaining changed files is not shown.)
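The retry flow added to the ML client accessor in this commit follows a simple recursive pattern: attempt the prediction, and on failure invoke the same method again with an incremented retry counter until a cap is reached (the real code additionally consults `RetryUtil.shouldRetry` to decide whether the error is retryable). The self-contained sketch below illustrates that pattern in isolation; it is not the plugin's `RetryUtil`, and the three-attempt cap is an assumption chosen for the example.

```java
import java.util.function.BiConsumer;
import java.util.function.Consumer;

public final class RecursiveRetrySketch {

    private static final int MAX_RETRIES = 3; // assumed cap, chosen only for this illustration

    // Runs an asynchronous task and retries recursively on failure, mirroring the
    // structure of the retryableInferenceSentencesWith*Result methods above.
    public static <T> void runWithRetry(
        BiConsumer<Consumer<T>, Consumer<Exception>> task,
        int retryTime,
        Consumer<T> onSuccess,
        Consumer<Exception> onFailure
    ) {
        task.accept(onSuccess, e -> {
            if (retryTime < MAX_RETRIES) {
                runWithRetry(task, retryTime + 1, onSuccess, onFailure); // retry with an incremented counter
            } else {
                onFailure.accept(e); // give up and surface the last failure
            }
        });
    }
}
```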
