From 6a1408dc4859c76e09decbb6ad662848fe9bc61a Mon Sep 17 00:00:00 2001 From: Aurelien FOUCRET Date: Thu, 16 Nov 2023 11:27:16 +0100 Subject: [PATCH] Remove suppliers when it is possible to avoid it. --- .../rescorer/LearnToRankRescorerBuilder.java | 161 ++++++++++-------- ...earnToRankRescorerBuilderRewriteTests.java | 44 ++--- ...RankRescorerBuilderSerializationTests.java | 11 +- 3 files changed, 110 insertions(+), 106 deletions(-) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/rescorer/LearnToRankRescorerBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/rescorer/LearnToRankRescorerBuilder.java index af36dca311ecf..562273c024232 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/rescorer/LearnToRankRescorerBuilder.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/rescorer/LearnToRankRescorerBuilder.java @@ -70,8 +70,8 @@ public static LearnToRankRescorerBuilder fromXContent( private final Map params; private final ScriptService scriptService; private final ModelLoadingService modelLoadingService; - private final Supplier localModelSupplier; - private final Supplier learnToRankConfigSupplier; + private final LocalModel localModel; + private final LearnToRankConfig learnToRankConfig; private boolean rescoreOccurred = false; LearnToRankRescorerBuilder( @@ -86,40 +86,23 @@ public static LearnToRankRescorerBuilder fromXContent( this.modelLoadingService = modelLoadingService; // Config and model will be set during successive rewrite phases. - this.learnToRankConfigSupplier = null; - this.localModelSupplier = null; + this.learnToRankConfig = null; + this.localModel = null; } - LearnToRankRescorerBuilder( - String modelId, - ModelLoadingService modelLoadingService, - Supplier learnToRankConfigSupplier - ) { + LearnToRankRescorerBuilder(String modelId, ModelLoadingService modelLoadingService, LearnToRankConfig learnToRankConfig) { this.modelId = modelId; this.modelLoadingService = modelLoadingService; - this.learnToRankConfigSupplier = learnToRankConfigSupplier; + this.learnToRankConfig = learnToRankConfig; // Local inference model is not loaded yet. Will be done in a later rewrite. - this.localModelSupplier = null; + this.localModel = null; // Templates has been applied already, so we do not need params and script service anymore. this.params = null; this.scriptService = null; } - LearnToRankRescorerBuilder(Supplier learnToRankConfigSupplier, Supplier localModelSupplier) { - this.modelId = localModelSupplier.get() != null ? localModelSupplier.get().getModelId() : null; - this.learnToRankConfigSupplier = learnToRankConfigSupplier; - this.localModelSupplier = localModelSupplier; - - // Model is loaded already, so we do not need the model loading service anymore. - this.modelLoadingService = null; - - // Template has been applied already, so we do not need params and script service anymore. - this.params = null; - this.scriptService = null; - } - public LearnToRankRescorerBuilder( StreamInput input, Supplier modelLoadingServiceSupplier, @@ -128,14 +111,25 @@ public LearnToRankRescorerBuilder( super(input); this.modelId = input.readString(); this.params = input.readMap(); - - LearnToRankConfig learnToRankConfig = input.readOptionalNamedWriteable(LearnToRankConfig.class); - this.learnToRankConfigSupplier = learnToRankConfig != null ? () -> learnToRankConfig : null; + this.learnToRankConfig = input.readOptionalNamedWriteable(LearnToRankConfig.class); this.modelLoadingService = modelLoadingServiceSupplier.get(); this.scriptService = scriptServiceSupplier.get(); - this.localModelSupplier = null; + this.localModel = null; + } + + LearnToRankRescorerBuilder(LearnToRankConfig learnToRankConfig, LocalModel localModel) { + this.modelId = localModel.getModelId(); + this.learnToRankConfig = learnToRankConfig; + this.localModel = localModel; + + // Model is loaded already, so we do not need the model loading service anymore. + this.modelLoadingService = null; + + // Template has been applied already, so we do not need params and script service anymore. + this.params = null; + this.scriptService = null; } public String modelId() { @@ -146,16 +140,16 @@ public Map params() { return params; } - public Supplier learnToRankConfigSupplier() { - return learnToRankConfigSupplier; + public LearnToRankConfig learnToRankConfig() { + return learnToRankConfig; } - public ModelLoadingService modelLoadingServiceSupplier() { + public ModelLoadingService modelLoadingService() { return modelLoadingService; } - public Supplier localModelSupplier() { - return localModelSupplier; + public LocalModel localModel() { + return localModel; } @Override @@ -179,18 +173,13 @@ public RescorerBuilder rewrite(QueryRewriteContext c * @throws IOException when rewrite fails */ private RescorerBuilder doCoordinatorNodeRewrite(QueryRewriteContext ctx) throws IOException { - if (learnToRankConfigSupplier != null && learnToRankConfigSupplier.get() == null) { - // Awaiting to fetch the model. - return this; - } - // We have requested for the stored config and fetch is completed, get the config and rewrite further if required - if (learnToRankConfigSupplier != null) { - LearnToRankConfig rewrittenConfig = Rewriteable.rewrite(learnToRankConfigSupplier.get(), ctx); - if (rewrittenConfig == learnToRankConfigSupplier.get()) { + if (learnToRankConfig != null) { + LearnToRankConfig rewrittenConfig = Rewriteable.rewrite(learnToRankConfig, ctx); + if (rewrittenConfig == learnToRankConfig) { return this; } - LearnToRankRescorerBuilder builder = new LearnToRankRescorerBuilder(modelId, modelLoadingService, () -> rewrittenConfig); + LearnToRankRescorerBuilder builder = new LearnToRankRescorerBuilder(modelId, modelLoadingService, rewrittenConfig); if (windowSize != null) { builder.windowSize(windowSize); } @@ -228,7 +217,24 @@ private RescorerBuilder doCoordinatorNodeRewrite(Que }, l::onFailure) ) ); - LearnToRankRescorerBuilder builder = new LearnToRankRescorerBuilder(modelId, modelLoadingService, configSetOnce::get); + LearnToRankRescorerBuilder builder = new LearnToRankRescorerBuilder(modelId, modelLoadingService, null) { + @Override + public RescorerBuilder rewrite(QueryRewriteContext ctx) throws IOException { + if (configSetOnce.get() == null) { + // Still waiting for the model to be loaded. + return this; + } + + LearnToRankRescorerBuilder builder = new LearnToRankRescorerBuilder(modelId, modelLoadingService, configSetOnce.get()); + + if (windowSize() != null) { + builder.windowSize(windowSize()); + } + + return builder; + } + }; + if (windowSize() != null) { builder.windowSize(windowSize); } @@ -264,23 +270,41 @@ private LearnToRankConfig applyParams(LearnToRankConfig config, QueryRewriteCont * @return A rewritten rescorer with a model definition or a model definition supplier populated */ private RescorerBuilder doDataNodeRewrite(QueryRewriteContext ctx) throws IOException { - assert learnToRankConfigSupplier.get() != null; + assert learnToRankConfig != null; - // The model supplier is already created, no need to rewrite further. - if (localModelSupplier != null) { + // The model is already loaded, no need to rewrite further. + if (localModel != null) { return this; } if (modelLoadingService == null) { throw new IllegalStateException("Model loading service must be available"); } - LearnToRankConfig rewrittenConfig = Rewriteable.rewrite(learnToRankConfigSupplier.get(), ctx); - SetOnce inferenceDefinitionSetOnce = new SetOnce<>(); + LearnToRankConfig rewrittenConfig = Rewriteable.rewrite(learnToRankConfig, ctx); + SetOnce localModelSetOnce = new SetOnce<>(); ctx.registerAsyncAction((c, l) -> modelLoadingService.getModelForLearnToRank(modelId, ActionListener.wrap(lm -> { - inferenceDefinitionSetOnce.set(lm); + localModelSetOnce.set(lm); l.onResponse(null); }, l::onFailure))); - LearnToRankRescorerBuilder builder = new LearnToRankRescorerBuilder(() -> rewrittenConfig, inferenceDefinitionSetOnce::get); + + LearnToRankRescorerBuilder builder = new LearnToRankRescorerBuilder(modelId, modelLoadingService, learnToRankConfig) { + @Override + public RescorerBuilder rewrite(QueryRewriteContext ctx) throws IOException { + if (localModelSetOnce.get() == null) { + // Still waiting for the model to be loaded. + return this; + } + + LearnToRankRescorerBuilder builder = new LearnToRankRescorerBuilder(learnToRankConfig, localModelSetOnce.get()); + + if (windowSize() != null) { + builder.windowSize(windowSize()); + } + + return builder; + } + }; + if (windowSize() != null) { builder.windowSize(windowSize()); } @@ -294,14 +318,14 @@ private RescorerBuilder doDataNodeRewrite(QueryRewri * @throws IOException If fetching, parsing, or overall rewrite failures occur */ private RescorerBuilder doSearchRewrite(QueryRewriteContext ctx) throws IOException { - if (learnToRankConfigSupplier == null || learnToRankConfigSupplier.get() == null) { + if (learnToRankConfig == null) { return this; } - LearnToRankConfig rewrittenConfig = Rewriteable.rewrite(learnToRankConfigSupplier.get(), ctx); - if (rewrittenConfig == learnToRankConfigSupplier.get()) { + LearnToRankConfig rewrittenConfig = Rewriteable.rewrite(learnToRankConfig, ctx); + if (rewrittenConfig == learnToRankConfig) { return this; } - LearnToRankRescorerBuilder builder = new LearnToRankRescorerBuilder(() -> rewrittenConfig, localModelSupplier); + LearnToRankRescorerBuilder builder = new LearnToRankRescorerBuilder(rewrittenConfig, localModel); if (windowSize != null) { builder.windowSize(windowSize); } @@ -311,9 +335,11 @@ private RescorerBuilder doSearchRewrite(QueryRewrite @Override protected LearnToRankRescorerContext innerBuildContext(int windowSize, SearchExecutionContext context) { rescoreOccurred = true; - LearnToRankConfig learnToRankConfig = learnToRankConfigSupplier != null ? learnToRankConfigSupplier.get() : null; - LocalModel inferenceDefinition = localModelSupplier != null ? localModelSupplier.get() : null; - return new LearnToRankRescorerContext(windowSize, LearnToRankRescorer.INSTANCE, learnToRankConfig, inferenceDefinition, context); + + assert learnToRankConfig != null; + assert localModel != null; + + return new LearnToRankRescorerContext(windowSize, LearnToRankRescorer.INSTANCE, learnToRankConfig, localModel, context); } @Override @@ -329,10 +355,10 @@ public TransportVersion getMinimalSupportedVersion() { @Override protected void doWriteTo(StreamOutput out) throws IOException { - assert localModelSupplier == null || rescoreOccurred : "Unnecessarily populated local model object"; + assert localModel == null || rescoreOccurred : "Unnecessarily populated local model object"; out.writeString(modelId); out.writeGenericMap(params); - out.writeOptionalNamedWriteable(learnToRankConfigSupplier != null ? learnToRankConfigSupplier.get() : null); + out.writeOptionalNamedWriteable(learnToRankConfig); } @Override @@ -352,19 +378,10 @@ public boolean equals(Object o) { if (super.equals(o) == false) return false; LearnToRankRescorerBuilder that = (LearnToRankRescorerBuilder) o; - if (learnToRankConfigSupplier != null - && (that.learnToRankConfigSupplier == null - || Objects.equals(learnToRankConfigSupplier.get(), that.learnToRankConfigSupplier.get()) == false)) { - return false; - } - - if (localModelSupplier != null - && (that.localModelSupplier == null || Objects.equals(localModelSupplier.get(), that.localModelSupplier.get()) == false)) { - return false; - } - return Objects.equals(modelId, that.modelId) && Objects.equals(params, that.params) + && Objects.equals(learnToRankConfig, that.learnToRankConfig) + && Objects.equals(localModel, that.localModel) && Objects.equals(modelLoadingService, that.modelLoadingService) && Objects.equals(scriptService, that.scriptService) && rescoreOccurred == that.rescoreOccurred; @@ -376,10 +393,10 @@ public int hashCode() { super.hashCode(), modelId, params, + learnToRankConfig, + localModel, modelLoadingService, scriptService, - learnToRankConfigSupplier != null ? learnToRankConfigSupplier.get() : null, - localModelSupplier != null ? localModelSupplier.get() : null, rescoreOccurred ); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/rescorer/LearnToRankRescorerBuilderRewriteTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/rescorer/LearnToRankRescorerBuilderRewriteTests.java index 5090be6880224..0f4f6a1d5cdad 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/rescorer/LearnToRankRescorerBuilderRewriteTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/rescorer/LearnToRankRescorerBuilderRewriteTests.java @@ -38,7 +38,6 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput; import org.elasticsearch.xpack.core.ml.inference.TrainedModelType; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearnToRankConfig; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearnToRankConfigTests; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ltr.LearnToRankFeatureExtractorBuilder; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ltr.QueryExtractorBuilder; @@ -55,6 +54,7 @@ import java.util.Collections; import java.util.List; +import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearnToRankConfigTests.randomLearnToRankConfig; import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; @@ -97,11 +97,8 @@ public class LearnToRankRescorerBuilderRewriteTests extends AbstractBuilderTestC public void testMustRewrite() { TestModelLoader testModelLoader = new TestModelLoader(); - LearnToRankRescorerBuilder rescorerBuilder = new LearnToRankRescorerBuilder( - GOOD_MODEL, - testModelLoader, - () -> LearnToRankConfigTests.randomLearnToRankConfig() - ); + LearnToRankRescorerBuilder rescorerBuilder = new LearnToRankRescorerBuilder(GOOD_MODEL, testModelLoader, randomLearnToRankConfig()); + SearchExecutionContext context = createSearchExecutionContext(); LearnToRankRescorerContext rescorerContext = rescorerBuilder.innerBuildContext(randomIntBetween(1, 30), context); IllegalStateException e = expectThrows( @@ -127,14 +124,13 @@ public void testRewriteOnCoordinator() throws IOException { randomIntBetween(1_500_000, Integer.MAX_VALUE) ); LearnToRankRescorerBuilder rewritten = rewriteAndFetch(rescorerBuilder, context); - assertThat(rewritten.learnToRankConfigSupplier().get(), not(nullValue())); - assertThat(rewritten.learnToRankConfigSupplier().get().getNumTopFeatureImportanceValues(), equalTo(2)); + assertThat(rewritten.learnToRankConfig(), not(nullValue())); + assertThat(rewritten.learnToRankConfig().getNumTopFeatureImportanceValues(), equalTo(2)); assertThat( "feature_1", is( in( - rewritten.learnToRankConfigSupplier() - .get() + rewritten.learnToRankConfig() .getFeatureExtractorBuilders() .stream() .map(LearnToRankFeatureExtractorBuilder::featureName) @@ -173,17 +169,15 @@ public void testRewriteOnCoordinatorWithMissingModel() { public void testSearchRewrite() throws IOException { LocalModel localModel = localModel(); when(localModel.getModelId()).thenReturn(GOOD_MODEL); - LearnToRankRescorerBuilder rescorerBuilder = new LearnToRankRescorerBuilder( - () -> LearnToRankConfigTests.randomLearnToRankConfig(), - () -> localModel - ); + + LearnToRankRescorerBuilder rescorerBuilder = new LearnToRankRescorerBuilder(randomLearnToRankConfig(), localModel); QueryRewriteContext context = createSearchExecutionContext(); LearnToRankRescorerBuilder rewritten = (LearnToRankRescorerBuilder) Rewriteable.rewrite(rescorerBuilder, context, true); - LearnToRankConfig rewrittenLearnToRankConfig = Rewriteable.rewrite(rewritten.learnToRankConfigSupplier().get(), context); - assertThat(rewritten.localModelSupplier().get(), is(localModel)); - assertThat(rewritten.learnToRankConfigSupplier().get(), is(rewrittenLearnToRankConfig)); + LearnToRankConfig rewrittenLearnToRankConfig = Rewriteable.rewrite(rewritten.learnToRankConfig(), context); + assertThat(rewritten.localModel(), is(localModel)); + assertThat(rewritten.learnToRankConfig(), is(rewrittenLearnToRankConfig)); } protected LearnToRankRescorerBuilder rewriteAndFetch(RescorerBuilder builder, QueryRewriteContext context) { @@ -220,7 +214,7 @@ public void testRewriteOnShard() throws IOException { LearnToRankRescorerBuilder rescorerBuilder = new LearnToRankRescorerBuilder( GOOD_MODEL, testModelLoader, - () -> (LearnToRankConfig) GOOD_MODEL_CONFIG.getInferenceConfig() + (LearnToRankConfig) GOOD_MODEL_CONFIG.getInferenceConfig() ); SearchExecutionContext searchExecutionContext = createSearchExecutionContext(); LearnToRankRescorerBuilder rewritten = (LearnToRankRescorerBuilder) rescorerBuilder.rewrite(createSearchExecutionContext()); @@ -230,11 +224,9 @@ public void testRewriteOnShard() throws IOException { public void testRewriteAndFetchOnDataNode() throws IOException { TestModelLoader testModelLoader = new TestModelLoader(); - LearnToRankRescorerBuilder rescorerBuilder = new LearnToRankRescorerBuilder( - GOOD_MODEL, - testModelLoader, - () -> LearnToRankConfigTests.randomLearnToRankConfig() - ); + + LearnToRankRescorerBuilder rescorerBuilder = new LearnToRankRescorerBuilder(GOOD_MODEL, testModelLoader, randomLearnToRankConfig()); + boolean setWindowSize = randomBoolean(); if (setWindowSize) { rescorerBuilder.windowSize(42); @@ -253,10 +245,8 @@ public void testBuildContext() throws Exception { List inputFields = List.of(DOUBLE_FIELD_NAME, INT_FIELD_NAME); when(localModel.inputFields()).thenReturn(inputFields); SearchExecutionContext context = createSearchExecutionContext(); - LearnToRankRescorerBuilder rescorerBuilder = new LearnToRankRescorerBuilder( - () -> LearnToRankConfigTests.randomLearnToRankConfig(), - () -> localModel - ); + + LearnToRankRescorerBuilder rescorerBuilder = new LearnToRankRescorerBuilder(randomLearnToRankConfig(), localModel); LearnToRankRescorerContext rescoreContext = rescorerBuilder.innerBuildContext(20, context); assertNotNull(rescoreContext); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/rescorer/LearnToRankRescorerBuilderSerializationTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/rescorer/LearnToRankRescorerBuilderSerializationTests.java index 5a1cf3cb6e442..3a967ee4a8fba 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/rescorer/LearnToRankRescorerBuilderSerializationTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/rescorer/LearnToRankRescorerBuilderSerializationTests.java @@ -72,10 +72,9 @@ protected Writeable.Reader instanceReader() { @Override protected LearnToRankRescorerBuilder createTestInstance() { - LearnToRankConfig learnToRankConfig = randomLearnToRankConfig(); LearnToRankRescorerBuilder builder = randomBoolean() ? createXContextTestInstance(null) - : new LearnToRankRescorerBuilder(randomAlphaOfLength(10), null, () -> learnToRankConfig); + : new LearnToRankRescorerBuilder(randomAlphaOfLength(10), null, randomLearnToRankConfig()); if (randomBoolean()) { builder.windowSize(randomIntBetween(1, 10000)); @@ -127,11 +126,9 @@ protected LearnToRankRescorerBuilder mutateInstance(LearnToRankRescorerBuilder i yield builder; } case 3 -> { - LearnToRankConfig learnToRankConfig = instance.learnToRankConfigSupplier() != null - ? randomValueOtherThan(instance.learnToRankConfigSupplier().get(), () -> randomLearnToRankConfig()) - : randomLearnToRankConfig(); + LearnToRankConfig learnToRankConfig = randomValueOtherThan(instance.learnToRankConfig(), () -> randomLearnToRankConfig()); - LearnToRankRescorerBuilder builder = new LearnToRankRescorerBuilder(instance.modelId(), null, () -> learnToRankConfig); + LearnToRankRescorerBuilder builder = new LearnToRankRescorerBuilder(instance.modelId(), null, learnToRankConfig); if (instance.windowSize() != null) { builder.windowSize(instance.windowSize()); } @@ -139,7 +136,7 @@ protected LearnToRankRescorerBuilder mutateInstance(LearnToRankRescorerBuilder i } case 4 -> { LocalModel localModel = mock(LocalModel.class); - LearnToRankRescorerBuilder builder = new LearnToRankRescorerBuilder(instance.learnToRankConfigSupplier(), () -> localModel); + LearnToRankRescorerBuilder builder = new LearnToRankRescorerBuilder(instance.learnToRankConfig(), localModel); if (instance.windowSize() != null) { builder.windowSize(instance.windowSize()); }