From 442af354b8199179cd06a91abd4f71395a819a13 Mon Sep 17 00:00:00 2001 From: Aurelien FOUCRET Date: Thu, 6 Jun 2024 16:44:17 +0200 Subject: [PATCH] Do not add fields extracted using a query to the FieldValueFeatureExtractor. --- .../ltr/LearningToRankRescorerContext.java | 17 ++++++++++++----- ...rningToRankRescorerBuilderRewriteTests.java | 18 ++++++++++++------ .../ltr/LearningToRankServiceTests.java | 2 +- 3 files changed, 25 insertions(+), 12 deletions(-) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankRescorerContext.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankRescorerContext.java index b1df3a2da7c42..e03370b415417 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankRescorerContext.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankRescorerContext.java @@ -24,6 +24,8 @@ import java.util.ArrayList; import java.util.List; +import static java.util.function.Predicate.not; + public class LearningToRankRescorerContext extends RescoreContext { final SearchExecutionContext executionContext; @@ -52,12 +54,9 @@ public LearningToRankRescorerContext( List buildFeatureExtractors(IndexSearcher searcher) throws IOException { assert this.regressionModelDefinition != null && this.learningToRankConfig != null; + List featureExtractors = new ArrayList<>(); - if (this.regressionModelDefinition.inputFields().isEmpty() == false) { - featureExtractors.add( - new FieldValueFeatureExtractor(new ArrayList<>(this.regressionModelDefinition.inputFields()), this.executionContext) - ); - } + List weights = new ArrayList<>(); List queryFeatureNames = new ArrayList<>(); for (LearningToRankFeatureExtractorBuilder featureExtractorBuilder : learningToRankConfig.getFeatureExtractorBuilders()) { @@ -72,6 +71,14 @@ List buildFeatureExtractors(IndexSearcher searcher) throws IOE featureExtractors.add(new QueryFeatureExtractor(queryFeatureNames, weights)); } + List fieldValueExtractorFields = this.regressionModelDefinition.inputFields() + .stream() + .filter(not(queryFeatureNames::contains)) + .toList(); + if (fieldValueExtractorFields.isEmpty() == false) { + featureExtractors.add(new FieldValueFeatureExtractor(fieldValueExtractorFields, this.executionContext)); + } + return featureExtractors; } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankRescorerBuilderRewriteTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankRescorerBuilderRewriteTests.java index 3bfe8aa390d8b..1f8995b25a349 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankRescorerBuilderRewriteTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankRescorerBuilderRewriteTests.java @@ -193,8 +193,7 @@ private static LearningToRankService learningToRankServiceMock() { public void testBuildContext() throws Exception { LocalModel localModel = mock(LocalModel.class); - List inputFields = List.of(DOUBLE_FIELD_NAME, INT_FIELD_NAME); - when(localModel.inputFields()).thenReturn(inputFields); + when(localModel.inputFields()).thenReturn(GOOD_MODEL_CONFIG.getInput().getFieldNames()); IndexSearcher searcher = mock(IndexSearcher.class); doAnswer(invocation -> invocation.getArgument(0)).when(searcher).rewrite(any(Query.class)); @@ -212,10 +211,17 @@ public void testBuildContext() throws Exception { assertThat(rescoreContext.getWindowSize(), equalTo(20)); List featureExtractors = rescoreContext.buildFeatureExtractors(context.searcher()); assertThat(featureExtractors, hasSize(2)); - assertThat( - featureExtractors.stream().flatMap(featureExtractor -> featureExtractor.featureNames().stream()).toList(), - containsInAnyOrder("feature_1", "feature_2", DOUBLE_FIELD_NAME, INT_FIELD_NAME) - ); + + FeatureExtractor queryExtractor = featureExtractors.stream().filter(fe -> fe instanceof QueryFeatureExtractor).findFirst().get(); + assertThat(queryExtractor.featureNames(), hasSize(2)); + assertThat(queryExtractor.featureNames(), containsInAnyOrder("feature_1", "feature_2")); + + FeatureExtractor fieldValueExtractor = featureExtractors.stream() + .filter(fe -> fe instanceof FieldValueFeatureExtractor) + .findFirst() + .get(); + assertThat(fieldValueExtractor.featureNames(), hasSize(2)); + assertThat(fieldValueExtractor.featureNames(), containsInAnyOrder("field1", "field2")); } private LearningToRankRescorerBuilder rewriteAndFetch( diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankServiceTests.java index 026dcca4bfcf7..d0b68d15951bf 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankServiceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankServiceTests.java @@ -54,7 +54,7 @@ public class LearningToRankServiceTests extends ESTestCase { public static final String BAD_MODEL = "badModel"; public static final TrainedModelConfig GOOD_MODEL_CONFIG = TrainedModelConfig.builder() .setModelId(GOOD_MODEL) - .setInput(new TrainedModelInput(List.of("field1", "field2"))) + .setInput(new TrainedModelInput(List.of("field1", "field2", "feature_1", "feature_2"))) .setEstimatedOperations(1) .setModelSize(2) .setModelType(TrainedModelType.TREE_ENSEMBLE)