Skip to content

Commit

Permalink
Rename the inference rescorer into learn_to_rank
Browse files Browse the repository at this point in the history
  • Loading branch information
afoucret committed Nov 23, 2023
1 parent 99cacab commit 2113447
Show file tree
Hide file tree
Showing 21 changed files with 512 additions and 488 deletions.
1 change: 1 addition & 0 deletions x-pack/plugin/core/src/main/java/module-info.java
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@
exports org.elasticsearch.xpack.core.watcher.trigger;
exports org.elasticsearch.xpack.core.watcher.watch;
exports org.elasticsearch.xpack.core.watcher;
exports org.elasticsearch.xpack.core.ml.ltr;

provides org.elasticsearch.action.admin.cluster.node.info.ComponentVersionNumber
with
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.core.ml.inference;
package org.elasticsearch.xpack.core.ml.ltr;

import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.plugins.spi.NamedXContentProvider;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,9 @@
public class QueryProvider implements Writeable, ToXContentObject, Rewriteable<QueryProvider> {

private static final Logger logger = LogManager.getLogger(QueryProvider.class);

private Exception parsingException;
private QueryBuilder parsedQuery;
private Map<String, Object> query;
private final Exception parsingException;
private final QueryBuilder parsedQuery;
private final Map<String, Object> query;

public static QueryProvider defaultQuery() {
return new QueryProvider(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextSimilarityConfigTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ZeroShotClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ZeroShotClassificationConfigTests;
import org.elasticsearch.xpack.core.ml.ltr.MlLTRNamedXContentProvider;

import java.util.ArrayList;
import java.util.Collections;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.InferenceConfigItemTestCase;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.MlLTRNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ltr.LearnToRankFeatureExtractorBuilder;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ltr.QueryExtractorBuilderTests;
import org.elasticsearch.xpack.core.ml.ltr.MlLTRNamedXContentProvider;
import org.junit.Before;

import java.io.IOException;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.MlLTRNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ltr.LearnToRankFeatureExtractorBuilder;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ltr.QueryExtractorBuilder;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ltr.QueryExtractorBuilderTests;
import org.elasticsearch.xpack.core.ml.ltr.MlLTRNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.utils.QueryProvider;

import java.io.IOException;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,15 @@ public class MlRescorerIT extends ESRestTestCase {
@Before
public void setupModelAndData() throws IOException {
putRegressionModel(MODEL_ID, """
{
{
"description": "super complex model for tests",
"input": {"field_names": ["cost", "product"]},
"inference_config": {
"learn_to_rank": {
"feature_extractors": [{
"query_extractor": {
"feature_name": "two",
"query": {"script_score": {"query": {"match_all":{}}, "script": {"source": "return 2.0;"}}}
"query": { "script_score": { "query": { "match_all":{} }, "script": { "source": "return 2.0;" } } }
}
}]
}
Expand Down Expand Up @@ -174,11 +174,13 @@ public void setupModelAndData() throws IOException {
}
}
}
}""");
}
""");
createIndex(INDEX_NAME, Settings.builder().put("number_of_shards", randomIntBetween(1, 3)).build(), """
"properties":{
"product":{"type": "keyword"},
"cost":{"type": "integer"}}""");
"properties": {
"product": {"type": "keyword"},
"cost": {"type": "integer"}
}""");
indexData("{ \"product\": \"TV\", \"cost\": 300 }");
indexData("{ \"product\": \"TV\", \"cost\": 400 }");
indexData("{ \"product\": \"VCR\", \"cost\": 150 }");
Expand All @@ -190,18 +192,18 @@ public void setupModelAndData() throws IOException {
@SuppressWarnings("unchecked")
public void testLtrSimple() throws Exception {
Response searchResponse = search("""
{
"query": {
"match": { "product": { "query": "TV"}}
},
"rescore": {
"window_size": 10,
"inference": {
"model_id": "basic-ltr-model"
}
{
"query": {
"match": { "product": { "query": "TV" } }
},
"rescore": {
"window_size": 10,
"learn_to_rank": {
"model_id": "basic-ltr-model"
}
}
}""");
}
""");

Map<String, Object> response = responseAsMap(searchResponse);
assertThat((List<Double>) XContentMapValues.extractValue("hits.hits._score", response), contains(20.0, 20.0));
Expand All @@ -211,44 +213,42 @@ public void testLtrSimple() throws Exception {
public void testLtrSimpleDFS() throws Exception {
Response searchResponse = searchDfs("""
{
"query": {
"match": { "product": { "query": "TV"}}
},
"rescore": {
"window_size": 10,
"inference": {
"model_id": "basic-ltr-model",
"inference_config": {
"learn_to_rank": {
"feature_extractors":[
{"query_extractor": {"feature_name": "product_bm25", "query": {"term": {"product": "TV"}}}}
]
}
}
}
"query": {
"match": { "product": { "query": "TV" } }
},
"rescore": {
"window_size": 10,
"learn_to_rank": {
"model_id": "basic-ltr-model",
"inference_config": {
"learn_to_rank": {
"feature_extractors":[
{ "query_extractor": { "feature_name": "product_bm25", "query": { "term": { "product": "TV" } } } }
]
}
}
}
}
}""");

Map<String, Object> response = responseAsMap(searchResponse);
assertThat(response.toString(), (List<Double>) XContentMapValues.extractValue("hits.hits._score", response), contains(20.0, 20.0));

searchResponse = searchDfs("""
{
"rescore": {
"window_size": 10,
"inference": {
"model_id": "basic-ltr-model",
"inference_config": {
"learn_to_rank": {
"feature_extractors":[
{"query_extractor": {"feature_name": "product_bm25", "query": {"term": {"product": "TV"}}}}
]
}
}
}
"rescore": {
"window_size": 10,
"learn_to_rank": {
"model_id": "basic-ltr-model",
"inference_config": {
"learn_to_rank": {
"feature_extractors":[
{ "query_extractor": { "feature_name": "product_bm25", "query": { "term": { "product": "TV" } } } }
]
}
}
}
}
}""");

response = responseAsMap(searchResponse);
Expand All @@ -262,16 +262,16 @@ public void testLtrSimpleDFS() throws Exception {
@SuppressWarnings("unchecked")
public void testLtrSimpleEmpty() throws Exception {
Response searchResponse = search("""
{ "query": {
"term": { "product": "computer"}
},
"rescore": {
"window_size": 10,
"inference": {
"model_id": "basic-ltr-model"
}
{
"query": {
"term": { "product": "computer" }
},
"rescore": {
"window_size": 10,
"learn_to_rank": {
"model_id": "basic-ltr-model"
}
}
}""");

Map<String, Object> response = responseAsMap(searchResponse);
Expand All @@ -281,16 +281,16 @@ public void testLtrSimpleEmpty() throws Exception {
@SuppressWarnings("unchecked")
public void testLtrEmptyDFS() throws Exception {
Response searchResponse = searchDfs("""
{ "query": {
"match": { "product": { "query": "computer"}}
},
"rescore": {
"window_size": 10,
"inference": {
"model_id": "basic-ltr-model"
}
{
"query": {
"match": { "product": { "query": "computer"} }
},
"rescore": {
"window_size": 10,
"learn_to_rank": {
"model_id": "basic-ltr-model"
}
}
}""");

Map<String, Object> response = responseAsMap(searchResponse);
Expand All @@ -300,32 +300,32 @@ public void testLtrEmptyDFS() throws Exception {
@SuppressWarnings("unchecked")
public void testLtrCanMatch() throws Exception {
Response searchResponse = searchCanMatch("""
{ "query": {
"match": { "product": { "query": "TV"}}
},
"rescore": {
"window_size": 10,
"inference": {
"model_id": "basic-ltr-model"
}
{
"query": {
"match": { "product": { "query": "TV"}}
},
"rescore": {
"window_size": 10,
"learn_to_rank": {
"model_id": "basic-ltr-model"
}
}
}""", false);

Map<String, Object> response = responseAsMap(searchResponse);
assertThat(response.toString(), (List<Double>) XContentMapValues.extractValue("hits.hits._score", response), contains(20.0, 20.0));

searchResponse = searchCanMatch("""
{ "query": {
"match": { "product": { "query": "TV"}}
},
"rescore": {
"window_size": 10,
"inference": {
"model_id": "basic-ltr-model"
}
{
"query": {
"match": { "product": { "query": "TV"} }
},
"rescore": {
"window_size": 10,
"learn_to_rank": {
"model_id": "basic-ltr-model"
}
}
}""", true);

response = responseAsMap(searchResponse);
Expand Down
Loading

0 comments on commit 2113447

Please sign in to comment.