Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

semantic text bulk inference integration test #17

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions x-pack/plugin/inference/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ dependencies {
compileOnly project(":server")
compileOnly project(path: xpackModule('core'))
testImplementation(testArtifact(project(xpackModule('core'))))
testImplementation(project(':x-pack:plugin:inference:qa:test-service-plugin'))
testImplementation project(':modules:reindex')
clusterPlugins project(':x-pack:plugin:inference:qa:test-service-plugin')

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,6 @@

public abstract class AbstractTestInferenceService implements InferenceService {

protected static int stringWeight(String input, int position) {
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved to specific methods on the different dense and sparse vector mock services

int hashCode = input.hashCode();
if (hashCode < 0) {
hashCode = -hashCode;
}
return hashCode + position;
}

@Override
public TransportVersion getMinimalSupportedVersion() {
return TransportVersion.current(); // fine for these tests but will not work for cluster upgrade tests
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
Expand All @@ -43,8 +44,22 @@ public List<Factory> getInferenceServiceFactories() {
return List.of(TestInferenceService::new);
}

public static class TestDenseModel extends Model {
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added mock models as well

/**
 * Builds a mock dense (text-embedding) model for the given inference entity id.
 * Registers under {@code TaskType.TEXT_EMBEDDING} with the dense test service name
 * and a hard-coded {@code "api_key"} secret — secrets are never validated by the
 * mock service, the value only needs to be present.
 */
public TestDenseModel(String inferenceEntityId, TestDenseInferenceServiceExtension.TestServiceSettings serviceSettings) {
super(
new ModelConfigurations(
inferenceEntityId,
TaskType.TEXT_EMBEDDING,
TestDenseInferenceServiceExtension.TestInferenceService.NAME,
serviceSettings
),
new ModelSecrets(new AbstractTestInferenceService.TestSecretSettings("api_key"))
);
}
}

public static class TestInferenceService extends AbstractTestInferenceService {
private static final String NAME = "text_embedding_test_service";
public static final String NAME = "text_embedding_test_service";

public TestInferenceService(InferenceServiceFactoryContext context) {}

Expand Down Expand Up @@ -83,9 +98,10 @@ public void infer(
ActionListener<InferenceServiceResults> listener
) {
switch (model.getConfigurations().getTaskType()) {
case ANY, TEXT_EMBEDDING -> listener.onResponse(
makeResults(input, ((TestServiceModel) model).getServiceSettings().dimensions())
);
case ANY, TEXT_EMBEDDING -> {
ServiceSettings modelServiceSettings = model.getServiceSettings();
listener.onResponse(makeResults(input, modelServiceSettings.dimensions(), modelServiceSettings.similarity()));
}
default -> listener.onFailure(
new ElasticsearchStatusException(
TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), name()),
Expand All @@ -107,9 +123,10 @@ public void chunkedInfer(
ActionListener<List<ChunkedInferenceServiceResults>> listener
) {
switch (model.getConfigurations().getTaskType()) {
case ANY, TEXT_EMBEDDING -> listener.onResponse(
makeChunkedResults(input, ((TestServiceModel) model).getServiceSettings().dimensions())
);
case ANY, TEXT_EMBEDDING -> {
ServiceSettings modelServiceSettings = model.getServiceSettings();
listener.onResponse(makeChunkedResults(input, modelServiceSettings.dimensions(), modelServiceSettings.similarity()));
}
default -> listener.onFailure(
new ElasticsearchStatusException(
TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), name()),
Expand All @@ -119,28 +136,30 @@ public void chunkedInfer(
}
}

private TextEmbeddingResults makeResults(List<String> input, int dimensions) {
private TextEmbeddingResults makeResults(List<String> input, int dimensions, SimilarityMeasure similarityMeasure) {
List<TextEmbeddingResults.Embedding> embeddings = new ArrayList<>();
for (int i = 0; i < input.size(); i++) {
List<Float> values = new ArrayList<>();
double[] doubleEmbeddings = generateEmbedding(input.get(i), dimensions, similarityMeasure);
List<Float> floatEmbeddings = new ArrayList<>(dimensions);
for (int j = 0; j < dimensions; j++) {
values.add((float) stringWeight(input.get(i), j));
floatEmbeddings.add((float) doubleEmbeddings[j]);
}
embeddings.add(new TextEmbeddingResults.Embedding(values));
embeddings.add(new TextEmbeddingResults.Embedding(floatEmbeddings));
}
return new TextEmbeddingResults(embeddings);
}

private List<ChunkedInferenceServiceResults> makeChunkedResults(List<String> input, int dimensions) {
private List<ChunkedInferenceServiceResults> makeChunkedResults(
List<String> input,
int dimensions,
SimilarityMeasure similarityMeasure
) {
var results = new ArrayList<ChunkedInferenceServiceResults>();
for (int i = 0; i < input.size(); i++) {
double[] values = new double[dimensions];
for (int j = 0; j < dimensions; j++) {
values[j] = stringWeight(input.get(i), j);
}
double[] embeddings = generateEmbedding(input.get(i), dimensions, similarityMeasure);
results.add(
new org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults(
List.of(new ChunkedTextEmbeddingResults.EmbeddingChunk(input.get(i), values))
List.of(new ChunkedTextEmbeddingResults.EmbeddingChunk(input.get(i), embeddings))
)
);
}
Expand All @@ -150,6 +169,15 @@ private List<ChunkedInferenceServiceResults> makeChunkedResults(List<String> inp
protected ServiceSettings getServiceSettingsFromMap(Map<String, Object> serviceSettingsMap) {
return TestServiceSettings.fromMap(serviceSettingsMap);
}

/**
 * Produces a deterministic mock embedding for {@code input}: component {@code j} is
 * {@code input.hashCode() + j}.
 * <p>
 * Previously the {@code similarityMeasure} parameter was accepted but ignored, so the
 * returned vectors were never normalized and were invalid for dot-product similarity
 * (callers had to exclude {@code DOT_PRODUCT} explicitly). For {@code DOT_PRODUCT} the
 * vector is now L2-normalized to unit length; every other measure keeps the raw values,
 * preserving the previous behavior.
 *
 * @param input             text to embed; only its {@code hashCode()} is used
 * @param dimensions        number of vector components to generate
 * @param similarityMeasure similarity the index is configured with; drives normalization
 * @return a {@code dimensions}-length deterministic vector
 */
private static double[] generateEmbedding(String input, int dimensions, SimilarityMeasure similarityMeasure) {
    double[] embedding = new double[dimensions];
    for (int j = 0; j < dimensions; j++) {
        embedding[j] = input.hashCode() + (double) j;
    }

    // Dot product requires unit-length vectors; L2-normalize so the mock yields
    // valid embeddings for every similarity measure.
    if (similarityMeasure == SimilarityMeasure.DOT_PRODUCT) {
        double magnitude = 0.0;
        for (double value : embedding) {
            magnitude += value * value;
        }
        magnitude = Math.sqrt(magnitude);
        if (magnitude > 0.0) {
            for (int j = 0; j < dimensions; j++) {
                embedding[j] /= magnitude;
            }
        }
    }
    return embedding;
}
}

public record TestServiceSettings(String model, Integer dimensions, SimilarityMeasure similarity) implements ServiceSettings {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus;
Expand All @@ -44,8 +45,17 @@ public List<Factory> getInferenceServiceFactories() {
return List.of(TestInferenceService::new);
}

/**
 * Mock sparse-embedding model used by integration tests. Registers under
 * {@code TaskType.SPARSE_EMBEDDING} with the sparse test service name and a
 * hard-coded {@code "api_key"} secret (the mock service never validates it).
 */
public static class TestSparseModel extends Model {
public TestSparseModel(String inferenceEntityId, TestServiceSettings serviceSettings) {
super(
new ModelConfigurations(inferenceEntityId, TaskType.SPARSE_EMBEDDING, TestInferenceService.NAME, serviceSettings),
new ModelSecrets(new AbstractTestInferenceService.TestSecretSettings("api_key"))
);
}
}

public static class TestInferenceService extends AbstractTestInferenceService {
private static final String NAME = "test_service";
public static final String NAME = "test_service";

public TestInferenceService(InferenceServiceExtension.InferenceServiceFactoryContext context) {}

Expand Down Expand Up @@ -121,7 +131,7 @@ private SparseEmbeddingResults makeResults(List<String> input) {
for (int i = 0; i < input.size(); i++) {
var tokens = new ArrayList<SparseEmbeddingResults.WeightedToken>();
for (int j = 0; j < 5; j++) {
tokens.add(new SparseEmbeddingResults.WeightedToken("feature_" + j, stringWeight(input.get(i), j)));
tokens.add(new SparseEmbeddingResults.WeightedToken("feature_" + j, generateEmbedding(input.get(i), j)));
}
embeddings.add(new SparseEmbeddingResults.Embedding(tokens, false));
}
Expand All @@ -133,7 +143,7 @@ private List<ChunkedInferenceServiceResults> makeChunkedResults(List<String> inp
for (int i = 0; i < input.size(); i++) {
var tokens = new ArrayList<TextExpansionResults.WeightedToken>();
for (int j = 0; j < 5; j++) {
tokens.add(new TextExpansionResults.WeightedToken("feature_" + j, stringWeight(input.get(i), j)));
tokens.add(new TextExpansionResults.WeightedToken("feature_" + j, generateEmbedding(input.get(i), j)));
}
results.add(
new ChunkedSparseEmbeddingResults(List.of(new ChunkedTextExpansionResults.ChunkedResult(input.get(i), tokens)))
Expand All @@ -145,6 +155,11 @@ private List<ChunkedInferenceServiceResults> makeChunkedResults(List<String> inp
protected ServiceSettings getServiceSettingsFromMap(Map<String, Object> serviceSettingsMap) {
return TestServiceSettings.fromMap(serviceSettingsMap);
}

/**
 * Deterministic mock feature weight for {@code input} at token {@code position}.
 * <p>
 * Bug fix: {@code Math.abs(input.hashCode())} is still negative when the hash is
 * {@link Integer#MIN_VALUE} (two's-complement abs overflow), which would break the
 * non-negative guarantee below. Widening to {@code long} before taking the absolute
 * value makes the result strictly positive for every input while returning exactly
 * the same value as before for all other hash codes.
 */
private static float generateEmbedding(String input, int position) {
    // Ensure non-negative and non-zero values for features
    return Math.abs((long) input.hashCode()) + 1 + position;
}
}

public record TestServiceSettings(String model, String hiddenField, boolean shouldReturnHiddenField) implements ServiceSettings {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.action.filter;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
import org.elasticsearch.action.bulk.BulkItemResponse;
import org.elasticsearch.action.bulk.BulkRequestBuilder;
import org.elasticsearch.action.bulk.BulkResponse;
import org.elasticsearch.action.index.IndexRequestBuilder;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.update.UpdateRequestBuilder;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.inference.InferenceServiceExtension;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.xpack.inference.InferencePlugin;
import org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension;
import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension;
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
import org.junit.Before;

import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;

import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.nullValue;

/**
 * Integration test for the shard-level bulk inference action filter: randomized bulk
 * requests index/update documents with {@code semantic_text} fields backed by mock
 * sparse and dense inference services, then a count query verifies no document was lost.
 */
public class ShardBulkInferenceActionFilterIT extends ESIntegTestCase {

    public static final String INDEX_NAME = "test-index";

    @Before
    public void setup() throws Exception {
        // Register one mock model per task type so the semantic_text mappings below
        // can resolve their inference ids.
        storeSparseModel();
        storeDenseModel();
    }

    @Override
    protected Collection<Class<? extends Plugin>> nodePlugins() {
        return Arrays.asList(TestInferencePlugin.class);
    }

    public void testBulkOperations() throws Exception {
        // Random shard count exercises routing of inference results across shards.
        Map<String, Integer> shardsSettings = Collections.singletonMap(IndexMetadata.SETTING_NUMBER_OF_SHARDS, randomIntBetween(1, 10));
        indicesAdmin().prepareCreate(INDEX_NAME).setMapping("""
            {
                "properties": {
                    "sparse_field": {
                        "type": "semantic_text",
                        "inference_id": "test_service"
                    },
                    "dense_field": {
                        "type": "semantic_text",
                        "inference_id": "text_embedding_test_service"
                    }
                }
            }
            """).setSettings(shardsSettings).get();

        int totalBulkReqs = randomIntBetween(2, 100);
        long totalDocs = 0;
        for (int bulkReqs = 0; bulkReqs < totalBulkReqs; bulkReqs++) {
            BulkRequestBuilder bulkReqBuilder = client().prepareBulk();
            int totalBulkSize = randomIntBetween(1, 100);
            for (int bulkSize = 0; bulkSize < totalBulkSize; bulkSize++) {
                String id = Long.toString(totalDocs);
                boolean isIndexRequest = randomBoolean();
                Map<String, Object> source = new HashMap<>();
                // rarely() leaves a field null on index requests to cover the "no inference needed" path.
                source.put("sparse_field", isIndexRequest && rarely() ? null : randomAlphaOfLengthBetween(0, 1000));
                source.put("dense_field", isIndexRequest && rarely() ? null : randomAlphaOfLengthBetween(0, 1000));
                if (isIndexRequest) {
                    bulkReqBuilder.add(new IndexRequestBuilder(client()).setIndex(INDEX_NAME).setId(id).setSource(source));
                    totalDocs++;
                } else {
                    boolean isUpsert = randomBoolean();
                    UpdateRequestBuilder request = new UpdateRequestBuilder(client()).setIndex(INDEX_NAME).setDoc(source);
                    // Force an upsert while no documents exist yet so the update cannot fail.
                    if (isUpsert || totalDocs == 0) {
                        request.setDocAsUpsert(true);
                        totalDocs++;
                    } else {
                        // Update already existing document
                        id = Long.toString(randomLongBetween(0, totalDocs - 1));
                    }
                    request.setId(id);
                    bulkReqBuilder.add(request);
                }
            }
            BulkResponse bulkResponse = bulkReqBuilder.get();
            if (bulkResponse.hasFailures()) {
                // Get more details in case something fails
                for (BulkItemResponse bulkItemResponse : bulkResponse.getItems()) {
                    if (bulkItemResponse.isFailed()) {
                        fail(
                            bulkItemResponse.getFailure().getCause(),
                            "Failed to index document %s: %s",
                            bulkItemResponse.getId(),
                            bulkItemResponse.getFailureMessage()
                        );
                    }
                }
            }
            assertFalse(bulkResponse.hasFailures());
        }

        client().admin().indices().refresh(new RefreshRequest(INDEX_NAME)).get();

        SearchSourceBuilder sourceBuilder = new SearchSourceBuilder().size(0).trackTotalHits(true);
        SearchResponse searchResponse = client().search(new SearchRequest(INDEX_NAME).source(sourceBuilder)).get();
        try {
            // Every indexed or upserted document must be searchable; plain updates do not change the count.
            assertThat(searchResponse.getHits().getTotalHits().value, equalTo(totalDocs));
        } finally {
            // Release the response even when the assertion above throws; previously the
            // reference was leaked on assertion failure.
            searchResponse.decRef();
        }
    }

    // Registers the mock sparse model under the sparse test-service name used by the mapping.
    private void storeSparseModel() throws Exception {
        Model model = new TestSparseInferenceServiceExtension.TestSparseModel(
            TestSparseInferenceServiceExtension.TestInferenceService.NAME,
            new TestSparseInferenceServiceExtension.TestServiceSettings(
                TestSparseInferenceServiceExtension.TestInferenceService.NAME,
                null,
                false
            )
        );
        storeModel(model);
    }

    // Registers the mock dense model with random dimensions and a random similarity.
    private void storeDenseModel() throws Exception {
        Model model = new TestDenseInferenceServiceExtension.TestDenseModel(
            TestDenseInferenceServiceExtension.TestInferenceService.NAME,
            new TestDenseInferenceServiceExtension.TestServiceSettings(
                TestDenseInferenceServiceExtension.TestInferenceService.NAME,
                randomIntBetween(1, 100),
                // dot product means that we need normalized vectors; it's not worth doing that in this test
                randomValueOtherThan(SimilarityMeasure.DOT_PRODUCT, () -> randomFrom(SimilarityMeasure.values()))
            )
        );

        storeModel(model);
    }

    // Persists a model through the ModelRegistry and asserts the async store succeeded.
    private void storeModel(Model model) throws Exception {
        ModelRegistry modelRegistry = new ModelRegistry(client());

        AtomicReference<Boolean> storeModelHolder = new AtomicReference<>();
        AtomicReference<Exception> exceptionHolder = new AtomicReference<>();

        blockingCall(listener -> modelRegistry.storeModel(model, listener), storeModelHolder, exceptionHolder);

        assertThat(storeModelHolder.get(), is(true));
        assertThat(exceptionHolder.get(), is(nullValue()));
    }

    // Bridges a listener-based async call into a blocking one, capturing either the
    // response or the failure for later assertions.
    private <T> void blockingCall(Consumer<ActionListener<T>> function, AtomicReference<T> response, AtomicReference<Exception> error)
        throws InterruptedException {
        CountDownLatch latch = new CountDownLatch(1);
        ActionListener<T> listener = ActionListener.wrap(r -> {
            response.set(r);
            latch.countDown();
        }, e -> {
            error.set(e);
            latch.countDown();
        });

        function.accept(listener);
        latch.await();
    }

    // InferencePlugin variant that swaps in the mock sparse and dense inference services.
    public static class TestInferencePlugin extends InferencePlugin {
        public TestInferencePlugin(Settings settings) {
            super(settings);
        }

        @Override
        public List<InferenceServiceExtension.Factory> getInferenceServiceFactories() {
            return List.of(
                TestSparseInferenceServiceExtension.TestInferenceService::new,
                TestDenseInferenceServiceExtension.TestInferenceService::new
            );
        }
    }
}
Loading