Skip to content

Commit

Permalink
Do not add fields extracted using a query to the FieldValueFeatureExtractor.
Browse files Browse the repository at this point in the history
  • Loading branch information
afoucret committed Jun 6, 2024
1 parent d4d5d93 commit 442af35
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import java.util.ArrayList;
import java.util.List;

import static java.util.function.Predicate.not;

public class LearningToRankRescorerContext extends RescoreContext {

final SearchExecutionContext executionContext;
Expand Down Expand Up @@ -52,12 +54,9 @@ public LearningToRankRescorerContext(

List<FeatureExtractor> buildFeatureExtractors(IndexSearcher searcher) throws IOException {
assert this.regressionModelDefinition != null && this.learningToRankConfig != null;

List<FeatureExtractor> featureExtractors = new ArrayList<>();
if (this.regressionModelDefinition.inputFields().isEmpty() == false) {
featureExtractors.add(
new FieldValueFeatureExtractor(new ArrayList<>(this.regressionModelDefinition.inputFields()), this.executionContext)
);
}

List<Weight> weights = new ArrayList<>();
List<String> queryFeatureNames = new ArrayList<>();
for (LearningToRankFeatureExtractorBuilder featureExtractorBuilder : learningToRankConfig.getFeatureExtractorBuilders()) {
Expand All @@ -72,6 +71,14 @@ List<FeatureExtractor> buildFeatureExtractors(IndexSearcher searcher) throws IOE
featureExtractors.add(new QueryFeatureExtractor(queryFeatureNames, weights));
}

List<String> fieldValueExtractorFields = this.regressionModelDefinition.inputFields()
.stream()
.filter(not(queryFeatureNames::contains))
.toList();
if (fieldValueExtractorFields.isEmpty() == false) {
featureExtractors.add(new FieldValueFeatureExtractor(fieldValueExtractorFields, this.executionContext));
}

return featureExtractors;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,7 @@ private static LearningToRankService learningToRankServiceMock() {

public void testBuildContext() throws Exception {
LocalModel localModel = mock(LocalModel.class);
List<String> inputFields = List.of(DOUBLE_FIELD_NAME, INT_FIELD_NAME);
when(localModel.inputFields()).thenReturn(inputFields);
when(localModel.inputFields()).thenReturn(GOOD_MODEL_CONFIG.getInput().getFieldNames());

IndexSearcher searcher = mock(IndexSearcher.class);
doAnswer(invocation -> invocation.getArgument(0)).when(searcher).rewrite(any(Query.class));
Expand All @@ -212,10 +211,17 @@ public void testBuildContext() throws Exception {
assertThat(rescoreContext.getWindowSize(), equalTo(20));
List<FeatureExtractor> featureExtractors = rescoreContext.buildFeatureExtractors(context.searcher());
assertThat(featureExtractors, hasSize(2));
assertThat(
featureExtractors.stream().flatMap(featureExtractor -> featureExtractor.featureNames().stream()).toList(),
containsInAnyOrder("feature_1", "feature_2", DOUBLE_FIELD_NAME, INT_FIELD_NAME)
);

FeatureExtractor queryExtractor = featureExtractors.stream().filter(fe -> fe instanceof QueryFeatureExtractor).findFirst().get();
assertThat(queryExtractor.featureNames(), hasSize(2));
assertThat(queryExtractor.featureNames(), containsInAnyOrder("feature_1", "feature_2"));

FeatureExtractor fieldValueExtractor = featureExtractors.stream()
.filter(fe -> fe instanceof FieldValueFeatureExtractor)
.findFirst()
.get();
assertThat(fieldValueExtractor.featureNames(), hasSize(2));
assertThat(fieldValueExtractor.featureNames(), containsInAnyOrder("field1", "field2"));
}

private LearningToRankRescorerBuilder rewriteAndFetch(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ public class LearningToRankServiceTests extends ESTestCase {
public static final String BAD_MODEL = "badModel";
public static final TrainedModelConfig GOOD_MODEL_CONFIG = TrainedModelConfig.builder()
.setModelId(GOOD_MODEL)
.setInput(new TrainedModelInput(List.of("field1", "field2")))
.setInput(new TrainedModelInput(List.of("field1", "field2", "feature_1", "feature_2")))
.setEstimatedOperations(1)
.setModelSize(2)
.setModelType(TrainedModelType.TREE_ENSEMBLE)
Expand Down

0 comments on commit 442af35

Please sign in to comment.