From dc6029f7cd1553e76edd8b2e7a2292a4ebb8bc69 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Wed, 4 Aug 2021 13:14:53 -0400 Subject: [PATCH] [ML] allow for larger models in the inference step for data frame analytics --- .../InferenceToXContentCompressor.java | 10 +++- .../core/ml/inference/TrainedModelConfig.java | 16 +++++ .../integration/TrainedModelProviderIT.java | 2 +- .../dataframe/inference/InferenceRunner.java | 2 +- .../loadingservice/ModelLoadingService.java | 22 +++++-- .../persistence/TrainedModelProvider.java | 15 +++-- .../ModelLoadingServiceTests.java | 60 +++++++++---------- 7 files changed, 81 insertions(+), 46 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/InferenceToXContentCompressor.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/InferenceToXContentCompressor.java index 163d397061977..35ca522e3f7d0 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/InferenceToXContentCompressor.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/InferenceToXContentCompressor.java @@ -37,9 +37,9 @@ */ public final class InferenceToXContentCompressor { private static final int BUFFER_SIZE = 4096; - // Either 10% of the configured JVM heap, or 1 GB, which ever is smaller + // Either 25% of the configured JVM heap, or 1 GB, which ever is smaller private static final long MAX_INFLATED_BYTES = Math.min( - (long)((0.10) * JvmInfo.jvmInfo().getMem().getHeapMax().getBytes()), + (long)((0.25) * JvmInfo.jvmInfo().getMem().getHeapMax().getBytes()), ByteSizeValue.ofGb(1).getBytes()); private InferenceToXContentCompressor() {} @@ -49,6 +49,12 @@ public static BytesReference deflate(T objectToComp return deflate(reference); } + public static T inflateUnsafe(BytesReference compressedBytes, + CheckedFunction parserFunction, + NamedXContentRegistry xContentRegistry) throws IOException { + return inflate(compressedBytes, parserFunction, xContentRegistry, Long.MAX_VALUE); + } + public static T inflate(BytesReference compressedBytes, CheckedFunction parserFunction, NamedXContentRegistry xContentRegistry) throws IOException { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java index e0ae853d34449..785fe97c39ee7 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java @@ -286,6 +286,14 @@ public TrainedModelConfig ensureParsedDefinition(NamedXContentRegistry xContentR return this; } + public TrainedModelConfig ensureParsedDefinitionUnsafe(NamedXContentRegistry xContentRegistry) throws IOException { + if (definition == null) { + return null; + } + definition.ensureParsedDefinitionUnsafe(xContentRegistry); + return this; + } + @Nullable public TrainedModelDefinition getModelDefinition() { if (definition == null) { @@ -872,6 +880,14 @@ private void ensureParsedDefinition(NamedXContentRegistry xContentRegistry) thro } } + private void ensureParsedDefinitionUnsafe(NamedXContentRegistry xContentRegistry) throws IOException { + if (parsedDefinition == null) { + parsedDefinition = InferenceToXContentCompressor.inflateUnsafe(compressedRepresentation, + parser -> 
TrainedModelDefinition.fromXContent(parser, true).build(), + xContentRegistry); + } + } + @Override public void writeTo(StreamOutput out) throws IOException { if (out.getVersion().onOrAfter(Version.V_8_0_0)) { // TODO adjust on backport diff --git a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java index 1c395b6baf6ce..713b3412a8b4b 100644 --- a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java +++ b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java @@ -339,7 +339,7 @@ public void testGetTrainedModelForInference() throws InterruptedException, IOExc AtomicReference definitionHolder = new AtomicReference<>(); blockingCall( - listener -> trainedModelProvider.getTrainedModelForInference(modelId, listener), + listener -> trainedModelProvider.getTrainedModelForInference(modelId, false, listener), definitionHolder, exceptionHolder); assertThat(exceptionHolder.get(), is(nullValue())); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/inference/InferenceRunner.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/inference/InferenceRunner.java index d696c72ab5533..ceea0ae5fb07d 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/inference/InferenceRunner.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/inference/InferenceRunner.java @@ -91,7 +91,7 @@ public void run(String modelId) { LOGGER.info("[{}] Started inference on test data against model [{}]", config.getId(), modelId); try { PlainActionFuture localModelPlainActionFuture = new PlainActionFuture<>(); - modelLoadingService.getModelForPipeline(modelId, localModelPlainActionFuture); + modelLoadingService.getModelForInternalInference(modelId, localModelPlainActionFuture); InferenceState inferenceState = restoreInferenceState(); dataCountsTracker.setTestDocsCount(inferenceState.processedTestDocsCount); TestDocsIterator testDocsIterator = new TestDocsIterator(new OriginSettingClient(client, ClientHelper.ML_ORIGIN), config, diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java index 9481070833b62..fd6dc208f7ab2 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java @@ -99,7 +99,7 @@ public class ModelLoadingService implements ClusterStateListener { // The feature requesting the model public enum Consumer { - PIPELINE, SEARCH + PIPELINE, SEARCH, INTERNAL } private static class ModelAndConsumer { @@ -175,6 +175,16 @@ public void getModelForPipeline(String modelId, ActionListener model getModel(modelId, Consumer.PIPELINE, modelActionListener); } + /** + * Load the model for internal use. Note, this decompresses the model if the stored estimate doesn't trip circuit breakers. 
+ * Consequently, it assumes the model was created by an ML process + * @param modelId the model to get + * @param modelActionListener the listener to alert when the model has been retrieved + */ + public void getModelForInternalInference(String modelId, ActionListener modelActionListener) { + getModel(modelId, Consumer.INTERNAL, modelActionListener); + } + /** * Load the model for use by at search. Models requested by search are always cached. * @@ -272,7 +282,7 @@ private boolean loadModelIfNecessary(String modelIdOrAlias, Consumer consumer, A return true; } - if (Consumer.PIPELINE == consumer && referencedModels.contains(modelId) == false) { + if (Consumer.SEARCH != consumer && referencedModels.contains(modelId) == false) { // The model is requested by a pipeline but not referenced by any ingest pipelines. // This means it is a simulate call and the model should not be cached logger.trace(() -> new ParameterizedMessage( @@ -280,7 +290,7 @@ private boolean loadModelIfNecessary(String modelIdOrAlias, Consumer consumer, A modelId, modelIdOrAlias )); - loadWithoutCaching(modelId, modelActionListener); + loadWithoutCaching(modelId, consumer, modelActionListener); } else { logger.trace(() -> new ParameterizedMessage( "[{}] (model_alias [{}]) attempting to load and cache", @@ -298,7 +308,7 @@ private void loadModel(String modelId, Consumer consumer) { provider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.empty(), ActionListener.wrap( trainedModelConfig -> { trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(trainedModelConfig.getEstimatedHeapMemory(), modelId); - provider.getTrainedModelForInference(modelId, ActionListener.wrap( + provider.getTrainedModelForInference(modelId, consumer == Consumer.INTERNAL, ActionListener.wrap( inferenceDefinition -> { try { // Since we have used the previously stored estimate to help guard against OOM we need @@ -327,14 +337,14 @@ private void loadModel(String modelId, Consumer consumer) { )); } - private void loadWithoutCaching(String modelId, ActionListener modelActionListener) { + private void loadWithoutCaching(String modelId, Consumer consumer, ActionListener modelActionListener) { // If we the model is not loaded and we did not kick off a new loading attempt, this means that we may be getting called // by a simulated pipeline provider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.empty(), ActionListener.wrap( trainedModelConfig -> { // Verify we can pull the model into memory without causing OOM trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(trainedModelConfig.getEstimatedHeapMemory(), modelId); - provider.getTrainedModelForInference(modelId, ActionListener.wrap( + provider.getTrainedModelForInference(modelId, consumer == Consumer.INTERNAL, ActionListener.wrap( inferenceDefinition -> { InferenceConfig inferenceConfig = trainedModelConfig.getInferenceConfig() == null ? 
 inferenceConfigFromTargetType(inferenceDefinition.getTargetType()) :
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java
index 01c23409240da..eef2c5ed6758c 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java
@@ -392,13 +392,17 @@ private void storeTrainedModelAndDefinition(TrainedModelConfig trainedModelConfi
      * do not.
      *
      * @param modelId The model to get
+     * @param unsafe when true, the compressed bytes size is not checked and the circuit breaker is solely responsible for
+     *               preventing OOMs
      * @param listener The listener
      */
-    public void getTrainedModelForInference(final String modelId, final ActionListener listener) {
+    public void getTrainedModelForInference(final String modelId, boolean unsafe, final ActionListener listener) {
         // TODO Change this when we get more than just langIdent stored
         if (MODELS_STORED_AS_RESOURCE.contains(modelId)) {
             try {
-                TrainedModelConfig config = loadModelFromResource(modelId, false).build().ensureParsedDefinition(xContentRegistry);
+                TrainedModelConfig config = loadModelFromResource(modelId, false)
+                    .build()
+                    .ensureParsedDefinitionUnsafe(xContentRegistry);
                 assert config.getModelDefinition().getTrainedModel() instanceof LangIdentNeuralNetwork;
                 assert config.getModelType() == TrainedModelType.LANG_IDENT;
                 listener.onResponse(
@@ -425,10 +429,9 @@ public void getTrainedModelForInference(final String modelId, final ActionListen
             success -> {
                 try {
                     BytesReference compressedData = getDefinitionFromDocs(docs, modelId);
-                    InferenceDefinition inferenceDefinition = InferenceToXContentCompressor.inflate(
-                        compressedData,
-                        InferenceDefinition::fromXContent,
-                        xContentRegistry);
+                    InferenceDefinition inferenceDefinition = unsafe ?
+ InferenceToXContentCompressor.inflateUnsafe(compressedData, InferenceDefinition::fromXContent, xContentRegistry) : + InferenceToXContentCompressor.inflate(compressedData, InferenceDefinition::fromXContent, xContentRegistry); listener.onResponse(inferenceDefinition); } catch (Exception e) { diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java index 27621aac8b436..9ede4b215c7ac 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java @@ -143,9 +143,9 @@ public void testGetCachedModels() throws Exception { assertThat(future.get(), is(not(nullValue()))); } - verify(trainedModelProvider, times(1)).getTrainedModelForInference(eq(model1), any()); - verify(trainedModelProvider, times(1)).getTrainedModelForInference(eq(model2), any()); - verify(trainedModelProvider, times(1)).getTrainedModelForInference(eq(model3), any()); + verify(trainedModelProvider, times(1)).getTrainedModelForInference(eq(model1), eq(false), any()); + verify(trainedModelProvider, times(1)).getTrainedModelForInference(eq(model2), eq(false), any()); + verify(trainedModelProvider, times(1)).getTrainedModelForInference(eq(model3), eq(false), any()); assertTrue(modelLoadingService.isModelCached(model1)); assertTrue(modelLoadingService.isModelCached(model2)); @@ -160,10 +160,10 @@ public void testGetCachedModels() throws Exception { assertThat(future.get(), is(not(nullValue()))); } - verify(trainedModelProvider, times(1)).getTrainedModelForInference(eq(model1), any()); - verify(trainedModelProvider, times(1)).getTrainedModelForInference(eq(model2), any()); + verify(trainedModelProvider, times(1)).getTrainedModelForInference(eq(model1), eq(false), any()); + verify(trainedModelProvider, times(1)).getTrainedModelForInference(eq(model2), eq(false), any()); // It is not referenced, so called eagerly - verify(trainedModelProvider, times(4)).getTrainedModelForInference(eq(model3), any()); + verify(trainedModelProvider, times(4)).getTrainedModelForInference(eq(model3), eq(false), any()); } public void testMaxCachedLimitReached() throws Exception { @@ -196,9 +196,9 @@ public void testMaxCachedLimitReached() throws Exception { // the loading occurred or which models are currently in the cache due to evictions. 
// Verify that we have at least loaded all three assertBusy(() -> { - verify(trainedModelProvider, times(1)).getTrainedModelForInference(eq(model1), any()); - verify(trainedModelProvider, times(1)).getTrainedModelForInference(eq(model2), any()); - verify(trainedModelProvider, times(1)).getTrainedModelForInference(eq(model3), any()); + verify(trainedModelProvider, times(1)).getTrainedModelForInference(eq(model1), eq(false), any()); + verify(trainedModelProvider, times(1)).getTrainedModelForInference(eq(model2), eq(false), any()); + verify(trainedModelProvider, times(1)).getTrainedModelForInference(eq(model3), eq(false), any()); }); // all models loaded put in the cache @@ -215,10 +215,10 @@ public void testMaxCachedLimitReached() throws Exception { // Depending on the order the models were first loaded in the first step // models 1 & 2 may have been evicted by model 3 in which case they have // been loaded at most twice - verify(trainedModelProvider, atMost(2)).getTrainedModelForInference(eq(model1), any()); - verify(trainedModelProvider, atMost(2)).getTrainedModelForInference(eq(model2), any()); + verify(trainedModelProvider, atMost(2)).getTrainedModelForInference(eq(model1), eq(false), any()); + verify(trainedModelProvider, atMost(2)).getTrainedModelForInference(eq(model2), eq(false), any()); // Only loaded requested once on the initial load from the change event - verify(trainedModelProvider, times(1)).getTrainedModelForInference(eq(model3), any()); + verify(trainedModelProvider, times(1)).getTrainedModelForInference(eq(model3), eq(false), any()); // model 3 has been loaded and evicted exactly once verify(trainedModelStatsService, times(1)).queueStats(argThat(new ArgumentMatcher<>() { @@ -234,7 +234,7 @@ public boolean matches(final Object o) { modelLoadingService.getModelForPipeline(model3, future3); assertThat(future3.get(), is(not(nullValue()))); } - verify(trainedModelProvider, times(2)).getTrainedModelForInference(eq(model3), any()); + verify(trainedModelProvider, times(2)).getTrainedModelForInference(eq(model3), eq(false), any()); verify(trainedModelStatsService, atMost(2)).queueStats(argThat(new ArgumentMatcher<>() { @Override @@ -255,7 +255,7 @@ public boolean matches(final Object o) { modelLoadingService.getModelForPipeline(model1, future1); assertThat(future1.get(), is(not(nullValue()))); } - verify(trainedModelProvider, atMost(3)).getTrainedModelForInference(eq(model1), any()); + verify(trainedModelProvider, atMost(3)).getTrainedModelForInference(eq(model1), eq(false), any()); verify(trainedModelStatsService, times(2)).queueStats(argThat(new ArgumentMatcher<>() { @Override public boolean matches(final Object o) { @@ -269,7 +269,7 @@ public boolean matches(final Object o) { modelLoadingService.getModelForPipeline(model2, future2); assertThat(future2.get(), is(not(nullValue()))); } - verify(trainedModelProvider, atMost(3)).getTrainedModelForInference(eq(model2), any()); + verify(trainedModelProvider, atMost(3)).getTrainedModelForInference(eq(model2), eq(false), any()); // Test invalidate cache for model3 // Now both model 1 and 2 should fit in cache without issues @@ -281,9 +281,9 @@ public boolean matches(final Object o) { assertThat(future.get(), is(not(nullValue()))); } - verify(trainedModelProvider, atMost(3)).getTrainedModelForInference(eq(model1), any()); - verify(trainedModelProvider, atMost(3)).getTrainedModelForInference(eq(model2), any()); - verify(trainedModelProvider, times(5)).getTrainedModelForInference(eq(model3), any()); + verify(trainedModelProvider, 
atMost(3)).getTrainedModelForInference(eq(model1), eq(false), any()); + verify(trainedModelProvider, atMost(3)).getTrainedModelForInference(eq(model2), eq(false), any()); + verify(trainedModelProvider, times(5)).getTrainedModelForInference(eq(model3), eq(false), any()); } public void testWhenCacheEnabledButNotIngestNode() throws Exception { @@ -308,7 +308,7 @@ public void testWhenCacheEnabledButNotIngestNode() throws Exception { } assertFalse(modelLoadingService.isModelCached(model1)); - verify(trainedModelProvider, times(10)).getTrainedModelForInference(eq(model1), any()); + verify(trainedModelProvider, times(10)).getTrainedModelForInference(eq(model1), eq(false), any()); verify(trainedModelStatsService, never()).queueStats(any(InferenceStats.class), anyBoolean()); } @@ -337,7 +337,7 @@ public void testGetCachedMissingModel() throws Exception { } assertFalse(modelLoadingService.isModelCached(model)); - verify(trainedModelProvider, atMost(2)).getTrainedModelForInference(eq(model), any()); + verify(trainedModelProvider, atMost(2)).getTrainedModelForInference(eq(model), eq(false), any()); verify(trainedModelStatsService, never()).queueStats(any(InferenceStats.class), anyBoolean()); } @@ -384,7 +384,7 @@ public void testGetModelEagerly() throws Exception { assertThat(future.get(), is(not(nullValue()))); } - verify(trainedModelProvider, times(3)).getTrainedModelForInference(eq(model), any()); + verify(trainedModelProvider, times(3)).getTrainedModelForInference(eq(model), eq(false), any()); assertFalse(modelLoadingService.isModelCached(model)); verify(trainedModelStatsService, never()).queueStats(any(InferenceStats.class), anyBoolean()); } @@ -410,7 +410,7 @@ public void testGetModelForSearch() throws Exception { assertTrue(modelLoadingService.isModelCached(modelId)); - verify(trainedModelProvider, times(1)).getTrainedModelForInference(eq(modelId), any()); + verify(trainedModelProvider, times(1)).getTrainedModelForInference(eq(modelId), eq(false), any()); verify(trainedModelStatsService, never()).queueStats(any(InferenceStats.class), anyBoolean()); } @@ -571,7 +571,7 @@ public void testGetCachedModelViaModelAliases() throws Exception { assertThat(future.get(), is(not(nullValue()))); } - verify(trainedModelProvider, times(1)).getTrainedModelForInference(eq(model1), any()); + verify(trainedModelProvider, times(1)).getTrainedModelForInference(eq(model1), eq(false), any()); assertTrue(modelLoadingService.isModelCached(model1)); assertTrue(modelLoadingService.isModelCached("loaded_model")); @@ -592,7 +592,7 @@ public void testGetCachedModelViaModelAliases() throws Exception { assertThat(future.get(), is(not(nullValue()))); } - verify(trainedModelProvider, times(1)).getTrainedModelForInference(eq(model2), any()); + verify(trainedModelProvider, times(1)).getTrainedModelForInference(eq(model2), eq(false), any()); assertTrue(modelLoadingService.isModelCached(model2)); assertTrue(modelLoadingService.isModelCached("loaded_model")); } @@ -647,10 +647,10 @@ private void withTrainedModel(String modelId, long size) { when(trainedModelConfig.getEstimatedHeapMemory()).thenReturn(size); doAnswer(invocationOnMock -> { @SuppressWarnings("rawtypes") - ActionListener listener = (ActionListener) invocationOnMock.getArguments()[1]; + ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2]; listener.onResponse(definition); return null; - }).when(trainedModelProvider).getTrainedModelForInference(eq(modelId), any()); + }).when(trainedModelProvider).getTrainedModelForInference(eq(modelId), eq(false), 
any()); doAnswer(invocationOnMock -> { @SuppressWarnings("rawtypes") ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2]; @@ -680,19 +680,19 @@ private void withMissingModel(String modelId) { }).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(GetTrainedModelsAction.Includes.empty()), any()); doAnswer(invocationOnMock -> { @SuppressWarnings("rawtypes") - ActionListener listener = (ActionListener) invocationOnMock.getArguments()[1]; + ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2]; listener.onFailure(new ResourceNotFoundException( Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId))); return null; - }).when(trainedModelProvider).getTrainedModelForInference(eq(modelId), any()); + }).when(trainedModelProvider).getTrainedModelForInference(eq(modelId), eq(false), any()); } doAnswer(invocationOnMock -> { @SuppressWarnings("rawtypes") - ActionListener listener = (ActionListener) invocationOnMock.getArguments()[1]; + ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2]; listener.onFailure(new ResourceNotFoundException( Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId))); return null; - }).when(trainedModelProvider).getTrainedModelForInference(eq(modelId), any()); + }).when(trainedModelProvider).getTrainedModelForInference(eq(modelId), eq(false), any()); } private static ClusterChangedEvent ingestChangedEvent(String... modelId) throws IOException {
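
Reviewer note on the inflate cap: the patch raises the safety limit in InferenceToXContentCompressor from 10% to 25% of the configured JVM heap (still capped at 1 GB) and adds inflateUnsafe, which is simply inflate with a Long.MAX_VALUE limit, so the stored circuit-breaker estimate, rather than the compressed-stream cap, guards against OOM. Below is a minimal, self-contained sketch of that capped-versus-uncapped pattern; it uses only java.util.zip, all names (BoundedInflateSketch and its deflate/inflate helpers) are hypothetical, and it is not the Elasticsearch implementation.

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.util.zip.DeflaterOutputStream;
import java.util.zip.InflaterInputStream;

// Illustrative only: a bounded inflate in the spirit of
// InferenceToXContentCompressor.inflate(...) vs inflateUnsafe(...).
final class BoundedInflateSketch {

    static byte[] deflate(byte[] raw) throws IOException {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        try (DeflaterOutputStream deflater = new DeflaterOutputStream(out)) {
            deflater.write(raw);
        }
        return out.toByteArray();
    }

    // Inflate at most maxInflatedBytes; fail fast instead of exhausting the heap.
    static byte[] inflate(byte[] compressed, long maxInflatedBytes) throws IOException {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        byte[] buffer = new byte[4096];
        long total = 0;
        try (InflaterInputStream inflater = new InflaterInputStream(new ByteArrayInputStream(compressed))) {
            int read;
            while ((read = inflater.read(buffer)) != -1) {
                total += read;
                if (total > maxInflatedBytes) {
                    throw new IOException("inflated size exceeds the configured limit of " + maxInflatedBytes + " bytes");
                }
                out.write(buffer, 0, read);
            }
        }
        return out.toByteArray();
    }

    // The "unsafe" variant simply removes the cap, as inflateUnsafe does with Long.MAX_VALUE.
    static byte[] inflateUnsafe(byte[] compressed) throws IOException {
        return inflate(compressed, Long.MAX_VALUE);
    }

    public static void main(String[] args) throws IOException {
        byte[] model = new byte[1 << 20]; // pretend this is a serialized model definition
        byte[] compressed = deflate(model);
        System.out.println("capped:   " + inflate(compressed, 2L << 20).length + " bytes");
        System.out.println("uncapped: " + inflateUnsafe(compressed).length + " bytes");
    }
}

The point carried over from the real class is that the limit is enforced while streaming, so an over-sized definition fails early instead of being fully materialized on the heap.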
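
Reviewer note on the routing: the data frame analytics InferenceRunner now requests the model through getModelForInternalInference, which maps to the new Consumer.INTERNAL; that consumer is what becomes the unsafe flag passed to getTrainedModelForInference(modelId, unsafe, listener) and, when true, switches inflate to inflateUnsafe. The stored heap estimate (trainedModelConfig.getEstimatedHeapMemory()) is still charged to the trained-model circuit breaker before inflation for every consumer, so the breaker remains the guard for internal loads. The sketch below illustrates that decision only; the names are hypothetical stand-ins, not the ModelLoadingService API.

import java.util.function.LongConsumer;

// Illustrative only: how the requesting consumer decides whether the inflate cap applies.
final class ConsumerRoutingSketch {

    enum Consumer { PIPELINE, SEARCH, INTERNAL }

    static final long MAX_INFLATED_BYTES = 1L << 30; // stand-in for the 25%-of-heap / 1 GB cap

    // Every consumer reserves the stored heap estimate on the breaker first;
    // only INTERNAL consumers then inflate without the additional size cap.
    static long inflateLimitFor(Consumer consumer, long estimatedHeapBytes, LongConsumer circuitBreaker) {
        circuitBreaker.accept(estimatedHeapBytes);
        return consumer == Consumer.INTERNAL ? Long.MAX_VALUE : MAX_INFLATED_BYTES;
    }

    public static void main(String[] args) {
        LongConsumer breaker = bytes -> System.out.println("reserved " + bytes + " bytes on the breaker");
        System.out.println("pipeline cap: " + inflateLimitFor(Consumer.PIPELINE, 123_456L, breaker));
        System.out.println("internal cap: " + inflateLimitFor(Consumer.INTERNAL, 123_456L, breaker));
    }
}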