From 617fa47f314da680ba6b05a2ef6c36f054969437 Mon Sep 17 00:00:00 2001 From: Aurelien FOUCRET Date: Wed, 10 Apr 2024 13:40:29 +0200 Subject: [PATCH] Add an end-to-end test to check the cache is flushed correctly. --- .../MlLearningToRankRescorerIT.java | 90 ++++++++++++++++++- 1 file changed, 88 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/ml/qa/basic-multi-node/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlLearningToRankRescorerIT.java b/x-pack/plugin/ml/qa/basic-multi-node/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlLearningToRankRescorerIT.java index 0dab4f9e4256c..16df0a576573b 100644 --- a/x-pack/plugin/ml/qa/basic-multi-node/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlLearningToRankRescorerIT.java +++ b/x-pack/plugin/ml/qa/basic-multi-node/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlLearningToRankRescorerIT.java @@ -28,7 +28,7 @@ public class MlLearningToRankRescorerIT extends ESRestTestCase { @Before public void setupModelAndData() throws IOException { - putRegressionModel(MODEL_ID, """ + putLearningToRankModel(MODEL_ID, """ { "description": "super complex model for tests", "input": { "field_names": ["cost", "product"] }, @@ -328,6 +328,87 @@ public void testLtrCanMatch() throws Exception { assertThat(response.toString(), (List) XContentMapValues.extractValue("hits.hits._score", response), contains(20.0, 20.0)); } + @SuppressWarnings("unchecked") + public void testModelCacheIsFlushedOnModelChange() throws IOException { + String searchBody = """ + { + "rescore": { + "window_size": 10, + "learning_to_rank": { + "model_id": "basic-ltr-model", + "params": { "keyword": "TV" } + } + } + }"""; + + Response searchResponse = searchDfs(searchBody); + Map response = responseAsMap(searchResponse); + assertThat( + response.toString(), + (List) XContentMapValues.extractValue("hits.hits._score", response), + contains(20.0, 20.0, 9.0, 9.0, 6.0) + ); + + deleteLearningToRankModel(MODEL_ID); + putLearningToRankModel(MODEL_ID, """ + { + "input": { "field_names": ["product_bm25"] }, + "inference_config": { + "learning_to_rank": { + "feature_extractors": [ + { "query_extractor": { "feature_name": "product_bm25", "query": { "match": { "product": "{{keyword}}" } } } + } + ] + } + }, + "definition": { + "trained_model": { + "ensemble": { + "feature_names": ["product_bm25"], + "target_type": "regression", + "trained_models": [ + { + "tree": { + "feature_names": [ "product_bm25" ], + "tree_structure": [ + { + "node_index": 0, + "split_feature": 0, + "split_gain": 12, + "threshold": 0, + "decision_type": "lte", + "default_left": true, + "left_child": 1, + "right_child": 2 + }, + { + "node_index": 1, + "leaf_value": 0.0 + }, + { + "node_index": 2, + "leaf_value": 10.0 + } + ], + "target_type": "regression" + } + } + ] + } + } + } + } + """); + + searchResponse = searchDfs(searchBody); + response = responseAsMap(searchResponse); + assertThat( + response.toString(), + (List) XContentMapValues.extractValue("hits.hits._score", response), + contains(10.0, 10.0, 0.0, 0.0, 0.0) + ); + } + private void indexData(String data) throws IOException { Request request = new Request("POST", INDEX_NAME + "/_doc"); request.setJsonEntity(data); @@ -354,7 +435,12 @@ private Response searchCanMatch(String searchBody, boolean dfs) throws IOExcepti return client().performRequest(request); } - private void putRegressionModel(String modelId, String body) throws IOException { + private void deleteLearningToRankModel(String modelId) throws IOException { + Request model = new Request("DELETE", "_ml/trained_models/" + modelId); + assertThat(client().performRequest(model).getStatusLine().getStatusCode(), equalTo(200)); + } + + private void putLearningToRankModel(String modelId, String body) throws IOException { Request model = new Request("PUT", "_ml/trained_models/" + modelId); model.setJsonEntity(body); assertThat(client().performRequest(model).getStatusLine().getStatusCode(), equalTo(200));