From 092f95e06e7b6ad123dc45f3ae31e7a46703a010 Mon Sep 17 00:00:00 2001 From: David Kyle Date: Mon, 25 Mar 2024 18:25:24 +0000 Subject: [PATCH] [ML] Always update tokenisation options for chunked inference (#106718) Fixes an issue where if the defaults were used the input was truncated --- .../ElasticsearchInternalService.java | 4 +- .../services/elser/ElserInternalService.java | 7 ++- .../ElasticsearchInternalServiceTests.java | 60 ++++++++++++++++++ .../elser/ElserInternalServiceTests.java | 61 +++++++++++++++++++ 4 files changed, 127 insertions(+), 5 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index a07ebe56a9258..02090ee84e708 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -251,9 +251,9 @@ public void chunkedInfer( return; } - var configUpdate = chunkingOptions.settingsArePresent() + var configUpdate = chunkingOptions != null ? new TokenizationConfigUpdate(chunkingOptions.windowSize(), chunkingOptions.span()) - : TextEmbeddingConfigUpdate.EMPTY_INSTANCE; + : new TokenizationConfigUpdate(null, null); var request = InferTrainedModelDeploymentAction.Request.forTextInput( model.getConfigurations().getInferenceEntityId(), diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalService.java index 5069724697818..bb88193612ff4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalService.java @@ -15,6 +15,7 @@ import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; import org.elasticsearch.client.internal.OriginSettingClient; +import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInferenceServiceResults; import org.elasticsearch.inference.ChunkingOptions; @@ -288,7 +289,7 @@ public void chunkedInfer( List input, Map taskSettings, InputType inputType, - ChunkingOptions chunkingOptions, + @Nullable ChunkingOptions chunkingOptions, ActionListener> listener ) { try { @@ -298,9 +299,9 @@ public void chunkedInfer( return; } - var configUpdate = chunkingOptions.settingsArePresent() + var configUpdate = chunkingOptions != null ? new TokenizationConfigUpdate(chunkingOptions.windowSize(), chunkingOptions.span()) - : TextExpansionConfigUpdate.EMPTY_UPDATE; + : new TokenizationConfigUpdate(null, null); var request = InferTrainedModelDeploymentAction.Request.forTextInput( model.getConfigurations().getInferenceEntityId(), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java index 0757012b234bd..073712beb8050 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java @@ -28,6 +28,7 @@ import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextEmbeddingResultsTests; import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TokenizationConfigUpdate; import org.elasticsearch.xpack.inference.services.settings.InternalServiceSettings; import java.util.ArrayList; @@ -38,6 +39,8 @@ import java.util.Random; import java.util.Set; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.hasSize; @@ -410,6 +413,63 @@ public void testChunkInfer() { assertTrue("Listener not called", gotResults.get()); } + @SuppressWarnings("unchecked") + public void testChunkInferSetsTokenization() { + var expectedSpan = new AtomicInteger(); + var expectedWindowSize = new AtomicReference(); + + Client client = mock(Client.class); + ThreadPool threadpool = new TestThreadPool("test"); + try { + when(client.threadPool()).thenReturn(threadpool); + doAnswer(invocationOnMock -> { + var request = (InferTrainedModelDeploymentAction.Request) invocationOnMock.getArguments()[1]; + assertThat(request.getUpdate(), instanceOf(TokenizationConfigUpdate.class)); + var update = (TokenizationConfigUpdate) request.getUpdate(); + assertEquals(update.getSpanSettings().span(), expectedSpan.get()); + assertEquals(update.getSpanSettings().maxSequenceLength(), expectedWindowSize.get()); + return null; + }).when(client) + .execute( + same(InferTrainedModelDeploymentAction.INSTANCE), + any(InferTrainedModelDeploymentAction.Request.class), + any(ActionListener.class) + ); + + var model = new MultilingualE5SmallModel( + "foo", + TaskType.TEXT_EMBEDDING, + "e5", + new MultilingualE5SmallInternalServiceSettings(1, 1, "cross-platform") + ); + var service = createService(client); + + expectedSpan.set(-1); + expectedWindowSize.set(null); + service.chunkedInfer( + model, + List.of("foo", "bar"), + Map.of(), + InputType.SEARCH, + null, + ActionListener.wrap(r -> fail("unexpected result"), e -> fail(e.getMessage())) + ); + + expectedSpan.set(-1); + expectedWindowSize.set(256); + service.chunkedInfer( + model, + List.of("foo", "bar"), + Map.of(), + InputType.SEARCH, + new ChunkingOptions(256, null), + ActionListener.wrap(r -> fail("unexpected result"), e -> fail(e.getMessage())) + ); + } finally { + terminate(threadpool); + } + } + private ElasticsearchInternalService createService(Client client) { var context = new InferenceServiceExtension.InferenceServiceFactoryContext(client); return new ElasticsearchInternalService(context); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceTests.java index f2fd195ab8c5a..dbb50260edaf1 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceTests.java @@ -27,6 +27,7 @@ import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextExpansionResultsTests; import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TokenizationConfigUpdate; import java.util.ArrayList; import java.util.Collections; @@ -35,6 +36,8 @@ import java.util.Map; import java.util.Set; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.hasSize; @@ -394,6 +397,64 @@ public void testChunkInfer() { assertTrue("Listener not called", gotResults.get()); } + @SuppressWarnings("unchecked") + public void testChunkInferSetsTokenization() { + var expectedSpan = new AtomicInteger(); + var expectedWindowSize = new AtomicReference(); + + ThreadPool threadpool = new TestThreadPool("test"); + Client client = mock(Client.class); + try { + when(client.threadPool()).thenReturn(threadpool); + doAnswer(invocationOnMock -> { + var request = (InferTrainedModelDeploymentAction.Request) invocationOnMock.getArguments()[1]; + assertThat(request.getUpdate(), instanceOf(TokenizationConfigUpdate.class)); + var update = (TokenizationConfigUpdate) request.getUpdate(); + assertEquals(update.getSpanSettings().span(), expectedSpan.get()); + assertEquals(update.getSpanSettings().maxSequenceLength(), expectedWindowSize.get()); + return null; + }).when(client) + .execute( + same(InferTrainedModelDeploymentAction.INSTANCE), + any(InferTrainedModelDeploymentAction.Request.class), + any(ActionListener.class) + ); + + var model = new ElserInternalModel( + "foo", + TaskType.SPARSE_EMBEDDING, + "elser", + new ElserInternalServiceSettings(1, 1, "elser"), + new ElserMlNodeTaskSettings() + ); + var service = createService(client); + + expectedSpan.set(-1); + expectedWindowSize.set(null); + service.chunkedInfer( + model, + List.of("foo", "bar"), + Map.of(), + InputType.SEARCH, + null, + ActionListener.wrap(r -> fail("unexpected result"), e -> fail(e.getMessage())) + ); + + expectedSpan.set(-1); + expectedWindowSize.set(256); + service.chunkedInfer( + model, + List.of("foo", "bar"), + Map.of(), + InputType.SEARCH, + new ChunkingOptions(256, null), + ActionListener.wrap(r -> fail("unexpected result"), e -> fail(e.getMessage())) + ); + } finally { + terminate(threadpool); + } + } + private ElserInternalService createService(Client client) { var context = new InferenceServiceExtension.InferenceServiceFactoryContext(client); return new ElserInternalService(context);