diff --git a/x-pack/plugin/ml/build.gradle b/x-pack/plugin/ml/build.gradle index f66d49b7880f5..92833da0bd1eb 100644 --- a/x-pack/plugin/ml/build.gradle +++ b/x-pack/plugin/ml/build.gradle @@ -1,5 +1,4 @@ import org.elasticsearch.gradle.VersionProperties -import org.elasticsearch.gradle.internal.dra.DraResolvePlugin apply plugin: 'elasticsearch.internal-es-plugin' apply plugin: 'elasticsearch.internal-cluster-test' @@ -86,6 +85,7 @@ dependencies { testImplementation project(':modules:reindex') testImplementation project(':modules:analysis-common') testImplementation project(':modules:mapper-extras') + testImplementation project(':modules:lang-mustache') // This should not be here testImplementation(testArtifact(project(xpackModule('security')))) testImplementation project(path: xpackModule('wildcard')) @@ -133,4 +133,4 @@ tasks.named("dependencyLicenses").configure { mapping from: /lucene-.*/, to: 'lucene' } -addQaCheckDependencies(project) \ No newline at end of file +addQaCheckDependencies(project) diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ltr/LearnToRankServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ltr/LearnToRankServiceTests.java index 57784654e4f63..c6bddf08e06a8 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ltr/LearnToRankServiceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ltr/LearnToRankServiceTests.java @@ -11,7 +11,10 @@ import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.script.ScriptEngine; +import org.elasticsearch.script.ScriptModule; import org.elasticsearch.script.ScriptService; +import org.elasticsearch.script.mustache.MustacheScriptEngine; import org.elasticsearch.search.SearchModule; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xcontent.NamedXContentRegistry; @@ -26,13 +29,16 @@ import org.elasticsearch.xpack.core.ml.utils.QueryProviderTests; import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; -import org.junit.AssumptionViolatedException; import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.Map; +import static org.elasticsearch.script.Script.DEFAULT_TEMPLATE_LANG; +import static org.hamcrest.Matchers.hasSize; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.doAnswer; @@ -41,8 +47,8 @@ public class LearnToRankServiceTests extends ESTestCase { public static final String GOOD_MODEL = "modelId"; - public static final String BAD_MODEL = "badModel"; + public static final String TEMPLATED_GOOD_MODEL = "templatedModelId"; public static final TrainedModelConfig GOOD_MODEL_CONFIG = TrainedModelConfig.builder() .setModelId(GOOD_MODEL) .setInput(new TrainedModelInput(List.of("field1", "field2"))) @@ -68,6 +74,19 @@ public class LearnToRankServiceTests extends ESTestCase { .setInferenceConfig(new RegressionConfig(null, null)) .build(); + public static final TrainedModelConfig TEMPLATED_GOOD_MODEL_CONFIG = new TrainedModelConfig.Builder(GOOD_MODEL_CONFIG) + .setModelId(TEMPLATED_GOOD_MODEL) + .setInferenceConfig( + new LearnToRankConfig( + 2, + List.of( + new QueryExtractorBuilder("feature_1", QueryProviderTests.createRandomValidQueryProvider("field_1", "{{foo_param}}")), + new QueryExtractorBuilder("feature_2", QueryProviderTests.createRandomValidQueryProvider("field_2", "{{bar_param}}")) + ) + ) + ) + .build(); + @SuppressWarnings("unchecked") public void testLoadLearnToRankConfig() throws Exception { LearnToRankService learnToRankService = new LearnToRankService( @@ -78,7 +97,7 @@ public void testLoadLearnToRankConfig() throws Exception { ); ActionListener listener = mock(ActionListener.class); learnToRankService.loadLearnToRankConfig(GOOD_MODEL, Collections.emptyMap(), listener); - assertBusy(() -> { verify(listener).onResponse(eq((LearnToRankConfig) GOOD_MODEL_CONFIG.getInferenceConfig())); }); + assertBusy(() -> verify(listener).onResponse(eq((LearnToRankConfig) GOOD_MODEL_CONFIG.getInferenceConfig()))); } @SuppressWarnings("unchecked") @@ -91,7 +110,7 @@ public void testLoadMissingLearnToRankConfig() throws Exception { ); ActionListener listener = mock(ActionListener.class); learnToRankService.loadLearnToRankConfig("non-existing-model", Collections.emptyMap(), listener); - assertBusy(() -> { verify(listener).onFailure(isA(ResourceNotFoundException.class)); }); + assertBusy(() -> verify(listener).onFailure(isA(ResourceNotFoundException.class))); } @SuppressWarnings("unchecked") @@ -104,12 +123,39 @@ public void testLoadBadLearnToRankConfig() throws Exception { ); ActionListener listener = mock(ActionListener.class); learnToRankService.loadLearnToRankConfig(BAD_MODEL, Collections.emptyMap(), listener); - assertBusy(() -> { verify(listener).onFailure(isA(ElasticsearchStatusException.class)); }); + assertBusy(() -> verify(listener).onFailure(isA(ElasticsearchStatusException.class))); } - public void testLoadLearnToRankConfigWithTemplate() { - // TODO: write the test. - throw new AssumptionViolatedException("Test to be written"); + @SuppressWarnings("unchecked") + public void testLoadLearnToRankConfigWithTemplate() throws Exception { + LearnToRankService learnToRankService = new LearnToRankService( + mockModelLoadingService(), + mockTrainedModelProvider(), + mockScriptService(), + xContentRegistry() + ); + + // When no parameters are provided we expect the templated queries not being part of the retrieved config. + ActionListener noParamsListener = mock(ActionListener.class); + learnToRankService.loadLearnToRankConfig(TEMPLATED_GOOD_MODEL, Collections.emptyMap(), noParamsListener); + assertBusy(() -> verify(noParamsListener).onResponse(argThat(retrievedConfig -> { + assertThat(retrievedConfig.getFeatureExtractorBuilders(), hasSize(2)); + assertEquals(retrievedConfig, TEMPLATED_GOOD_MODEL_CONFIG.getInferenceConfig()); + return true; + }))); + + // Now testing when providing all the params of the template. + ActionListener allParamsListener = mock(ActionListener.class); + learnToRankService.loadLearnToRankConfig( + TEMPLATED_GOOD_MODEL, + Map.ofEntries(Map.entry("foo_param", "foo"), Map.entry("bar_param", "bar")), + allParamsListener + ); + assertBusy(() -> verify(allParamsListener).onResponse(argThat(retrievedConfig -> { + assertThat(retrievedConfig.getFeatureExtractorBuilders(), hasSize(2)); + assertEquals(retrievedConfig, GOOD_MODEL_CONFIG.getInferenceConfig()); + return true; + }))); } @Override @@ -132,12 +178,11 @@ private TrainedModelProvider mockTrainedModelProvider() { doAnswer(invocation -> { String modelId = invocation.getArgument(0); ActionListener l = invocation.getArgument(3, ActionListener.class); - if (modelId.equals(GOOD_MODEL)) { - l.onResponse(GOOD_MODEL_CONFIG); - } else if (modelId.equals(BAD_MODEL)) { - l.onResponse(BAD_MODEL_CONFIG); - } else { - l.onFailure(new ResourceNotFoundException("missing model")); + switch (modelId) { + case GOOD_MODEL -> l.onResponse(GOOD_MODEL_CONFIG); + case TEMPLATED_GOOD_MODEL -> l.onResponse(TEMPLATED_GOOD_MODEL_CONFIG); + case BAD_MODEL -> l.onResponse(BAD_MODEL_CONFIG); + default -> l.onFailure(new ResourceNotFoundException("missing model")); } return null; @@ -147,6 +192,7 @@ private TrainedModelProvider mockTrainedModelProvider() { } private ScriptService mockScriptService() { - return mock(ScriptService.class); + ScriptEngine scriptEngine = new MustacheScriptEngine(); + return new ScriptService(Settings.EMPTY, Map.of(DEFAULT_TEMPLATE_LANG, scriptEngine), ScriptModule.CORE_CONTEXTS, () -> 1L); } } diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/learn_to_rank_rescorer.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/learn_to_rank_rescorer.yml index a0ae4b7c44316..c65621426f7fe 100644 --- a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/learn_to_rank_rescorer.yml +++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/learn_to_rank_rescorer.yml @@ -192,10 +192,6 @@ setup: - match: { hits.hits.4._score: 1.0 } --- "Test rescore with stored model and chained rescorers": - - skip: - version: all - reason: "@AwaitsFix https://github.com/elastic/elasticsearch/issues/80703" - - do: search: index: store