Skip to content

Commit

Permalink
[ML] Always update tokenisation options for chunked inference (#106718)
Browse files Browse the repository at this point in the history
Fixes an issue where the input was truncated when the default chunking options were used
davidkyle authored Mar 25, 2024
1 parent 2196576 commit 092f95e
Showing 4 changed files with 127 additions and 5 deletions.
Original file line number Diff line number Diff line change
@@ -251,9 +251,9 @@ public void chunkedInfer(
return;
}

var configUpdate = chunkingOptions.settingsArePresent()
var configUpdate = chunkingOptions != null
? new TokenizationConfigUpdate(chunkingOptions.windowSize(), chunkingOptions.span())
: TextEmbeddingConfigUpdate.EMPTY_INSTANCE;
: new TokenizationConfigUpdate(null, null);

var request = InferTrainedModelDeploymentAction.Request.forTextInput(
model.getConfigurations().getInferenceEntityId(),
Original file line number Diff line number Diff line change
@@ -15,6 +15,7 @@
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.client.internal.OriginSettingClient;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkedInferenceServiceResults;
import org.elasticsearch.inference.ChunkingOptions;
@@ -288,7 +289,7 @@ public void chunkedInfer(
List<String> input,
Map<String, Object> taskSettings,
InputType inputType,
ChunkingOptions chunkingOptions,
@Nullable ChunkingOptions chunkingOptions,
ActionListener<List<ChunkedInferenceServiceResults>> listener
) {
try {
@@ -298,9 +299,9 @@ public void chunkedInfer(
return;
}

var configUpdate = chunkingOptions.settingsArePresent()
var configUpdate = chunkingOptions != null
? new TokenizationConfigUpdate(chunkingOptions.windowSize(), chunkingOptions.span())
: TextExpansionConfigUpdate.EMPTY_UPDATE;
: new TokenizationConfigUpdate(null, null);

var request = InferTrainedModelDeploymentAction.Request.forTextInput(
model.getConfigurations().getInferenceEntityId(),
Original file line number Diff line number Diff line change
@@ -28,6 +28,7 @@
import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextEmbeddingResultsTests;
import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TokenizationConfigUpdate;
import org.elasticsearch.xpack.inference.services.settings.InternalServiceSettings;

import java.util.ArrayList;
@@ -38,6 +39,8 @@
import java.util.Random;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;

import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.hasSize;
@@ -410,6 +413,63 @@ public void testChunkInfer() {
assertTrue("Listener not called", gotResults.get());
}

@SuppressWarnings("unchecked")
public void testChunkInferSetsTokenization() {
    // Expected tokenization settings; mutated between the two chunkedInfer calls
    // below so the single mocked client answer can verify both scenarios.
    var expectedSpan = new AtomicInteger();
    var expectedWindowSize = new AtomicReference<Integer>();

    Client client = mock(Client.class);
    ThreadPool threadpool = new TestThreadPool("test");
    try {
        when(client.threadPool()).thenReturn(threadpool);
        // Intercept the deployment inference request and assert that a
        // TokenizationConfigUpdate carrying the expected span/window values is
        // always attached — even when the caller passed null ChunkingOptions.
        doAnswer(invocationOnMock -> {
            var request = (InferTrainedModelDeploymentAction.Request) invocationOnMock.getArguments()[1];
            assertThat(request.getUpdate(), instanceOf(TokenizationConfigUpdate.class));
            var update = (TokenizationConfigUpdate) request.getUpdate();
            // JUnit convention: expected value first, actual second, so a
            // failure message reports the values in the right roles.
            assertEquals(expectedSpan.get(), update.getSpanSettings().span());
            assertEquals(expectedWindowSize.get(), update.getSpanSettings().maxSequenceLength());
            // The listener is deliberately never invoked; this answer only
            // inspects the outgoing request.
            return null;
        }).when(client)
            .execute(
                same(InferTrainedModelDeploymentAction.INSTANCE),
                any(InferTrainedModelDeploymentAction.Request.class),
                any(ActionListener.class)
            );

        var model = new MultilingualE5SmallModel(
            "foo",
            TaskType.TEXT_EMBEDDING,
            "e5",
            new MultilingualE5SmallInternalServiceSettings(1, 1, "cross-platform")
        );
        var service = createService(client);

        // Case 1: null ChunkingOptions — defaults (span -1, no window override)
        // must still be sent as an explicit TokenizationConfigUpdate.
        expectedSpan.set(-1);
        expectedWindowSize.set(null);
        service.chunkedInfer(
            model,
            List.of("foo", "bar"),
            Map.of(),
            InputType.SEARCH,
            null,
            // Neither branch should ever run: the mocked client never calls the listener.
            ActionListener.wrap(r -> fail("unexpected result"), e -> fail(e.getMessage()))
        );

        // Case 2: explicit window size is propagated into the update.
        expectedSpan.set(-1);
        expectedWindowSize.set(256);
        service.chunkedInfer(
            model,
            List.of("foo", "bar"),
            Map.of(),
            InputType.SEARCH,
            new ChunkingOptions(256, null),
            ActionListener.wrap(r -> fail("unexpected result"), e -> fail(e.getMessage()))
        );
    } finally {
        terminate(threadpool);
    }
}

private ElasticsearchInternalService createService(Client client) {
var context = new InferenceServiceExtension.InferenceServiceFactoryContext(client);
return new ElasticsearchInternalService(context);
Original file line number Diff line number Diff line change
@@ -27,6 +27,7 @@
import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextExpansionResultsTests;
import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TokenizationConfigUpdate;

import java.util.ArrayList;
import java.util.Collections;
@@ -35,6 +36,8 @@
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;

import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.hasSize;
@@ -394,6 +397,64 @@ public void testChunkInfer() {
assertTrue("Listener not called", gotResults.get());
}

@SuppressWarnings("unchecked")
public void testChunkInferSetsTokenization() {
    // Expected tokenization settings; mutated between the two chunkedInfer calls
    // below so the single mocked client answer can verify both scenarios.
    var expectedSpan = new AtomicInteger();
    var expectedWindowSize = new AtomicReference<Integer>();

    ThreadPool threadpool = new TestThreadPool("test");
    Client client = mock(Client.class);
    try {
        when(client.threadPool()).thenReturn(threadpool);
        // Intercept the deployment inference request and assert that a
        // TokenizationConfigUpdate carrying the expected span/window values is
        // always attached — even when the caller passed null ChunkingOptions.
        doAnswer(invocationOnMock -> {
            var request = (InferTrainedModelDeploymentAction.Request) invocationOnMock.getArguments()[1];
            assertThat(request.getUpdate(), instanceOf(TokenizationConfigUpdate.class));
            var update = (TokenizationConfigUpdate) request.getUpdate();
            // JUnit convention: expected value first, actual second, so a
            // failure message reports the values in the right roles.
            assertEquals(expectedSpan.get(), update.getSpanSettings().span());
            assertEquals(expectedWindowSize.get(), update.getSpanSettings().maxSequenceLength());
            // The listener is deliberately never invoked; this answer only
            // inspects the outgoing request.
            return null;
        }).when(client)
            .execute(
                same(InferTrainedModelDeploymentAction.INSTANCE),
                any(InferTrainedModelDeploymentAction.Request.class),
                any(ActionListener.class)
            );

        var model = new ElserInternalModel(
            "foo",
            TaskType.SPARSE_EMBEDDING,
            "elser",
            new ElserInternalServiceSettings(1, 1, "elser"),
            new ElserMlNodeTaskSettings()
        );
        var service = createService(client);

        // Case 1: null ChunkingOptions — defaults (span -1, no window override)
        // must still be sent as an explicit TokenizationConfigUpdate.
        expectedSpan.set(-1);
        expectedWindowSize.set(null);
        service.chunkedInfer(
            model,
            List.of("foo", "bar"),
            Map.of(),
            InputType.SEARCH,
            null,
            // Neither branch should ever run: the mocked client never calls the listener.
            ActionListener.wrap(r -> fail("unexpected result"), e -> fail(e.getMessage()))
        );

        // Case 2: explicit window size is propagated into the update.
        expectedSpan.set(-1);
        expectedWindowSize.set(256);
        service.chunkedInfer(
            model,
            List.of("foo", "bar"),
            Map.of(),
            InputType.SEARCH,
            new ChunkingOptions(256, null),
            ActionListener.wrap(r -> fail("unexpected result"), e -> fail(e.getMessage()))
        );
    } finally {
        terminate(threadpool);
    }
}

private ElserInternalService createService(Client client) {
var context = new InferenceServiceExtension.InferenceServiceFactoryContext(client);
return new ElserInternalService(context);

0 comments on commit 092f95e

Please sign in to comment.