diff --git a/docs/changelog/102734.yaml b/docs/changelog/102734.yaml new file mode 100644 index 0000000000000..c27846d7d8478 --- /dev/null +++ b/docs/changelog/102734.yaml @@ -0,0 +1,5 @@ +pr: 102734 +summary: Allow match field in enrich fields +area: ES|QL +type: bug +issues: [] diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java index 557da9639a086..ddb4ed05f450f 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java @@ -256,8 +256,10 @@ public static List calculateEnrichFields( List enrichFields, EnrichPolicy policy ) { - Map fieldMap = mapping.stream().collect(Collectors.toMap(NamedExpression::name, Function.identity())); - fieldMap.remove(policy.getMatchField()); + Set policyEnrichFieldSet = new HashSet<>(policy.getEnrichFields()); + Map fieldMap = mapping.stream() + .filter(e -> policyEnrichFieldSet.contains(e.name())) + .collect(Collectors.toMap(NamedExpression::name, Function.identity())); List result = new ArrayList<>(); if (enrichFields == null || enrichFields.isEmpty()) { // use the policy to infer the enrich fields diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java index 6e75eea75f655..95d52c0a93a60 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java @@ -9,6 +9,9 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.enrich.EnrichPolicy; +import org.elasticsearch.xpack.esql.enrich.EnrichPolicyResolution; +import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry; import org.elasticsearch.xpack.esql.expression.function.aggregate.Max; import org.elasticsearch.xpack.esql.plan.logical.EsqlUnresolvedRelation; import org.elasticsearch.xpack.esql.plan.logical.Eval; @@ -29,6 +32,7 @@ import org.elasticsearch.xpack.ql.plan.logical.EsRelation; import org.elasticsearch.xpack.ql.plan.logical.Filter; import org.elasticsearch.xpack.ql.plan.logical.Limit; +import org.elasticsearch.xpack.ql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.ql.plan.logical.OrderBy; import org.elasticsearch.xpack.ql.type.DataType; import org.elasticsearch.xpack.ql.type.DataTypes; @@ -36,9 +40,12 @@ import java.util.List; import java.util.Map; +import java.util.Set; import java.util.stream.IntStream; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_VERIFIER; import static org.elasticsearch.xpack.esql.EsqlTestUtils.as; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.configuration; import static org.elasticsearch.xpack.esql.EsqlTestUtils.withDefaultLimitWarning; import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.analyze; import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.analyzer; @@ -1311,6 +1318,27 @@ public void testEnrichExcludesPolicyKey() { assertThat(e.getMessage(), containsString("Unknown column [id]")); } + public void testEnrichFieldsIncludeMatchField() { + String query = """ + FROM test + | EVAL x = to_string(languages) + | ENRICH languages ON x + | KEEP language_name, language_code + """; + IndexResolution testIndex = loadMapping("mapping-basic.json", "test"); + IndexResolution languageIndex = loadMapping("mapping-languages.json", "languages"); + var enrichPolicy = new EnrichPolicy("match", null, List.of("unused"), "language_code", List.of("language_code", "language_name")); + EnrichResolution enrichResolution = new EnrichResolution( + Set.of(new EnrichPolicyResolution("languages", enrichPolicy, languageIndex)), + Set.of("languages") + ); + AnalyzerContext context = new AnalyzerContext(configuration(query), new EsqlFunctionRegistry(), testIndex, enrichResolution); + Analyzer analyzer = new Analyzer(context, TEST_VERIFIER); + LogicalPlan plan = analyze(query, analyzer); + var limit = as(plan, Limit.class); + assertThat(Expressions.names(limit.output()), contains("language_name", "language_code")); + } + public void testChainedEvalFieldsUse() { var query = "from test | eval x0 = pow(salary, 1), x1 = pow(x0, 2), x2 = pow(x1, 3)"; int additionalEvals = randomIntBetween(0, 5);