diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/RestHighLevelClient.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/RestHighLevelClient.java index ff4101be7c0de..4af3c9a9b700b 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/RestHighLevelClient.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/RestHighLevelClient.java @@ -43,12 +43,15 @@ import org.elasticsearch.action.update.UpdateResponse; import org.elasticsearch.common.CheckedFunction; import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.ContextParser; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentType; -import org.elasticsearch.join.aggregations.ChildrenAggregationBuilder; -import org.elasticsearch.join.aggregations.ParsedChildren; +import org.elasticsearch.join.ParentJoinPlugin; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.plugins.PluginsService; +import org.elasticsearch.plugins.SearchPlugin; import org.elasticsearch.rest.BytesRestResponse; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.search.aggregations.Aggregation; @@ -90,8 +93,7 @@ import org.elasticsearch.search.aggregations.bucket.terms.ParsedLongTerms; import org.elasticsearch.search.aggregations.bucket.terms.ParsedStringTerms; import org.elasticsearch.search.aggregations.bucket.terms.StringTerms; -import org.elasticsearch.search.aggregations.matrix.stats.MatrixStatsAggregationBuilder; -import org.elasticsearch.search.aggregations.matrix.stats.ParsedMatrixStats; +import org.elasticsearch.search.aggregations.matrix.MatrixAggregationPlugin; import org.elasticsearch.search.aggregations.metrics.avg.AvgAggregationBuilder; import org.elasticsearch.search.aggregations.metrics.avg.ParsedAvg; import org.elasticsearch.search.aggregations.metrics.cardinality.CardinalityAggregationBuilder; @@ -140,6 +142,9 @@ import org.elasticsearch.search.suggest.term.TermSuggestion; import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -160,23 +165,52 @@ * Can be sub-classed to expose additional client methods that make use of endpoints added to Elasticsearch through plugins, or to * add support for custom response sections, again added to Elasticsearch through plugins. */ +@SuppressWarnings("varargs") public class RestHighLevelClient { + static final Collection> PRE_INSTALLED_PLUGINS = + Collections.unmodifiableList(Arrays.asList(ParentJoinPlugin.class, MatrixAggregationPlugin.class)); + private final RestClient client; private final NamedXContentRegistry registry; /** - * Creates a {@link RestHighLevelClient} given the low level {@link RestClient} that it should use to perform requests. + * Creates a {@link RestHighLevelClient} with a given low level {@link RestClient} and pre installed plugins + * + * @param restClient the low level {@link RestClient} used to perform requests + * @param plugins an optional array of additional plugins to run with the {@link RestHighLevelClient} + */ + @SafeVarargs + public RestHighLevelClient(RestClient restClient, Class... plugins) { + this(restClient, Arrays.asList(plugins)); + } + + /** + * Creates a {@link RestHighLevelClient} with a given low level {@link RestClient} and pre installed plugins + * + * @param restClient the low level {@link RestClient} used to perform requests + * @param plugins a collection of additional plugins to run with the {@link RestHighLevelClient} */ - public RestHighLevelClient(RestClient restClient) { - this(restClient, Collections.emptyList()); + public RestHighLevelClient(RestClient restClient, Collection> plugins) { + this(restClient, Settings.EMPTY, plugins); + } + + /** + * Creates a {@link RestHighLevelClient} with a given low level {@link RestClient} and pre installed plugins + * + * @param restClient the low level {@link RestClient} used to perform requests + * @param settings the settings passed to the {@link RestHighLevelClient} + * @param plugins a collection of additional plugins to run with the {@link RestHighLevelClient} + */ + public RestHighLevelClient(RestClient restClient, Settings settings, Collection> plugins) { + this(restClient, pluginsNamedXContents(settings, PRE_INSTALLED_PLUGINS, plugins)); } /** * Creates a {@link RestHighLevelClient} given the low level {@link RestClient} that it should use to perform requests and * a list of entries that allow to parse custom response sections added to Elasticsearch through plugins. */ - protected RestHighLevelClient(RestClient restClient, List namedXContentEntries) { + private RestHighLevelClient(RestClient restClient, List namedXContentEntries) { this.client = Objects.requireNonNull(restClient); this.registry = new NamedXContentRegistry(Stream.of(getDefaultNamedXContents().stream(), namedXContentEntries.stream()) .flatMap(Function.identity()).collect(toList())); @@ -542,8 +576,6 @@ static List getDefaultNamedXContents() { map.put(SignificantLongTerms.NAME, (p, c) -> ParsedSignificantLongTerms.fromXContent(p, (String) c)); map.put(SignificantStringTerms.NAME, (p, c) -> ParsedSignificantStringTerms.fromXContent(p, (String) c)); map.put(ScriptedMetricAggregationBuilder.NAME, (p, c) -> ParsedScriptedMetric.fromXContent(p, (String) c)); - map.put(ChildrenAggregationBuilder.NAME, (p, c) -> ParsedChildren.fromXContent(p, (String) c)); - map.put(MatrixStatsAggregationBuilder.NAME, (p, c) -> ParsedMatrixStats.fromXContent(p, (String) c)); List entries = map.entrySet().stream() .map(entry -> new NamedXContentRegistry.Entry(Aggregation.class, new ParseField(entry.getKey()), entry.getValue())) .collect(Collectors.toList()); @@ -555,4 +587,27 @@ static List getDefaultNamedXContents() { (parser, context) -> CompletionSuggestion.fromXContent(parser, (String)context))); return entries; } + + static List pluginsNamedXContents(Settings settings, + Collection> preInstalledPlugins, + Collection> plugins) { + List> listOfPlugins = new ArrayList<>(preInstalledPlugins); + for (Class plugin : plugins) { + if (listOfPlugins.contains(plugin)) { + throw new IllegalArgumentException("plugin already exists: " + plugin); + } + listOfPlugins.add(plugin); + } + + PluginsService pluginsService = new PluginsService(settings, null, null, listOfPlugins); + + List entries = new ArrayList<>(); + pluginsService.filterPlugins(Plugin.class) + .forEach(plugin -> entries.addAll(plugin.getNamedXContent())); + pluginsService.filterPlugins(SearchPlugin.class).stream() + .flatMap(plugin -> plugin.getAggregations().stream()) + .flatMap(aggregationSpec -> aggregationSpec.getResultParsers().entrySet().stream()) + .forEach(e -> entries.add(new NamedXContentRegistry.Entry(Aggregation.class, new ParseField(e.getKey()), e.getValue()))); + return entries; + } } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientExtTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientExtTests.java index cb32f9ae9dd93..ea81621c3aa06 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientExtTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientExtTests.java @@ -25,6 +25,7 @@ import org.elasticsearch.common.ParseField; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.plugins.Plugin; import org.elasticsearch.test.ESTestCase; import org.junit.Before; @@ -36,7 +37,7 @@ import static org.mockito.Mockito.mock; /** - * This test works against a {@link RestHighLevelClient} subclass that simulats how custom response sections returned by + * This test works against a {@link RestHighLevelClient} subclass that simulates how custom response sections returned by * Elasticsearch plugins can be parsed using the high level client. */ public class RestHighLevelClientExtTests extends ESTestCase { @@ -46,7 +47,7 @@ public class RestHighLevelClientExtTests extends ESTestCase { @Before public void initClient() throws IOException { RestClient restClient = mock(RestClient.class); - restHighLevelClient = new RestHighLevelClientExt(restClient); + restHighLevelClient = new RestHighLevelClient(restClient, RestHighLevelClientExtPlugin.class); } public void testParseEntityCustomResponseSection() throws IOException { @@ -66,13 +67,13 @@ public void testParseEntityCustomResponseSection() throws IOException { } } - private static class RestHighLevelClientExt extends RestHighLevelClient { + public static class RestHighLevelClientExtPlugin extends Plugin { - private RestHighLevelClientExt(RestClient restClient) { - super(restClient, getNamedXContentsExt()); + public RestHighLevelClientExtPlugin() { } - private static List getNamedXContentsExt() { + @Override + public List getNamedXContent() { List entries = new ArrayList<>(); entries.add(new NamedXContentRegistry.Entry(BaseCustomResponseSection.class, new ParseField("custom1"), CustomResponseSection1::fromXContent)); diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java index 8c5cdc6d68933..9efa3b3bd417f 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java @@ -34,6 +34,7 @@ import org.apache.http.message.BasicRequestLine; import org.apache.http.message.BasicStatusLine; import org.apache.http.nio.entity.NStringEntity; +import org.apache.lucene.util.LuceneTestCase; import org.elasticsearch.Build; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.Version; @@ -48,12 +49,14 @@ import org.elasticsearch.action.search.ShardSearchFailure; import org.elasticsearch.cluster.ClusterName; import org.elasticsearch.common.CheckedFunction; +import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.cbor.CborXContent; import org.elasticsearch.common.xcontent.smile.SmileXContent; +import org.elasticsearch.plugins.Plugin; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.search.SearchHits; import org.elasticsearch.search.aggregations.Aggregation; @@ -76,6 +79,7 @@ import static org.elasticsearch.client.RestClientTestUtil.randomHeaders; import static org.elasticsearch.common.xcontent.XContentHelper.toXContent; +import static org.hamcrest.CoreMatchers.containsString; import static org.hamcrest.CoreMatchers.instanceOf; import static org.mockito.Matchers.anyMapOf; import static org.mockito.Matchers.anyObject; @@ -598,9 +602,9 @@ public void testWrapResponseListenerOnResponseExceptionWithIgnoresErrorValidBody assertEquals("Elasticsearch exception [type=exception, reason=test error message]", elasticsearchException.getMessage()); } - public void testNamedXContents() { + public void testDefaultNamedXContents() { List namedXContents = RestHighLevelClient.getDefaultNamedXContents(); - assertEquals(45, namedXContents.size()); + assertEquals(43, namedXContents.size()); Map, Integer> categories = new HashMap<>(); for (NamedXContentRegistry.Entry namedXContent : namedXContents) { Integer counter = categories.putIfAbsent(namedXContent.categoryClass, 1); @@ -609,10 +613,22 @@ public void testNamedXContents() { } } assertEquals(2, categories.size()); - assertEquals(Integer.valueOf(42), categories.get(Aggregation.class)); + assertEquals(Integer.valueOf(40), categories.get(Aggregation.class)); assertEquals(Integer.valueOf(3), categories.get(Suggest.Suggestion.class)); } + public void testPreInstalledPlugins() { + for (Class pluginClass : RestHighLevelClient.PRE_INSTALLED_PLUGINS) { + LuceneTestCase.ThrowingRunnable runnable = randomFrom( + () -> new RestHighLevelClient(restClient, pluginClass), + () -> new RestHighLevelClient(restClient, Collections.singletonList(pluginClass)), + () -> new RestHighLevelClient(restClient, Settings.EMPTY, Collections.singletonList(pluginClass))); + + IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, runnable); + assertThat(exception.getMessage(), containsString("plugin already exists: " + pluginClass)); + } + } + private static class TrackingActionListener implements ActionListener { private final AtomicInteger statusCode = new AtomicInteger(-1); private final AtomicReference exception = new AtomicReference<>(); diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientWithCustomAggregationTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientWithCustomAggregationTests.java new file mode 100644 index 0000000000000..cd2d6034ba84e --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientWithCustomAggregationTests.java @@ -0,0 +1,284 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client; + +import org.apache.http.HttpEntity; +import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.action.search.SearchResponseSections; +import org.elasticsearch.action.search.ShardSearchFailure; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.index.query.QueryParseContext; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.plugins.SearchPlugin; +import org.elasticsearch.search.SearchHits; +import org.elasticsearch.search.aggregations.AbstractAggregationBuilder; +import org.elasticsearch.search.aggregations.Aggregation; +import org.elasticsearch.search.aggregations.AggregatorFactories; +import org.elasticsearch.search.aggregations.AggregatorFactory; +import org.elasticsearch.search.aggregations.InternalAggregation; +import org.elasticsearch.search.aggregations.InternalAggregations; +import org.elasticsearch.search.aggregations.bucket.InternalSingleBucketAggregation; +import org.elasticsearch.search.aggregations.bucket.ParsedSingleBucketAggregation; +import org.elasticsearch.search.aggregations.bucket.SingleBucketAggregation; +import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.search.internal.SearchContext; + +import java.io.IOException; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import static java.util.Collections.singletonList; + +/** + * Test usage of a custom aggregation provided by a plugin with the {@link RestHighLevelClient}. + */ +public class RestHighLevelClientWithCustomAggregationTests extends RestHighLevelClientWithPluginTestCase { + + private static final String CUSTOM = "custom"; + + @Override + protected Collection> getPlugins() { + return singletonList(CustomPlugin.class); + } + + public void testCustomAggregation() throws Exception { + String aggregationName = randomAlphaOfLengthBetween(5, 10); + CustomAggregationBuilder customAggregationBuilder = new CustomAggregationBuilder(aggregationName); + final int customNumber = randomIntBetween(1, 1000); + customAggregationBuilder.setCustomNumber(customNumber); + + Map metaData = null; + if (randomBoolean()) { + metaData = new HashMap<>(); + int metaDataCount = between(0, 10); + while (metaData.size() < metaDataCount) { + metaData.put(randomAlphaOfLength(5), randomAlphaOfLength(5)); + } + customAggregationBuilder.setMetaData(metaData); + } + + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.query(QueryBuilders.termQuery("field", "hello")); + searchSourceBuilder.aggregation(customAggregationBuilder); + SearchRequest searchRequest = new SearchRequest(); + searchRequest.source(searchSourceBuilder); + + SearchResponse searchResponse = search(searchRequest); + + Map aggregations = searchResponse.getAggregations().getAsMap(); + assertEquals(1, aggregations.size()); + assertTrue(aggregations.containsKey(aggregationName)); + + CustomAggregation customAggregation = (CustomAggregation) aggregations.get(aggregationName); + assertEquals(customNumber, customAggregation.getDocCount()); + assertEquals(metaData, customAggregation.getMetaData()); + assertEquals(customNumber % 2 == 1, customAggregation.isOdd()); + } + + @Override + @SuppressWarnings("unchecked") + protected Response performRequest(HttpEntity httpEntity) throws IOException { + try (XContentParser parser = createParser(Request.REQUEST_BODY_CONTENT_TYPE.xContent(), httpEntity.getContent())) { + Map requestAsMap = parser.map(); + assertEquals(2, requestAsMap.size()); + assertTrue("Search request does not contain the term query", requestAsMap.containsKey("query")); + assertTrue("Search request does not contain the aggregation", requestAsMap.containsKey("aggregations")); + + Map queryAsMap = (Map) requestAsMap.get("query"); + assertEquals(1, queryAsMap.size()); + assertTrue("Query must be a term query", queryAsMap.containsKey("term")); + + Map aggsAsMap = (Map) requestAsMap.get("aggregations"); + assertEquals("Aggregations must contain the custom aggregation", 1, aggsAsMap.size()); + String name = aggsAsMap.keySet().iterator().next(); + + Map customAggregationAsMap = (Map) aggsAsMap.get(name); + assertTrue("Custom aggregation must have the 'custom' type", customAggregationAsMap.containsKey(CUSTOM)); + Map customMetadata = (Map) customAggregationAsMap.get("meta"); + + customAggregationAsMap = (Map) customAggregationAsMap.get(CUSTOM); + assertTrue("Custom aggregation must contain the random number", customAggregationAsMap.containsKey("number")); + + int customNumber = ((Number) customAggregationAsMap.get("number")).intValue(); + boolean isOdd = (customNumber % 2) != 0; + + InternalCustom internal = new InternalCustom(name, isOdd, customNumber, InternalAggregations.EMPTY, null, customMetadata); + InternalAggregations internalAggregations = new InternalAggregations(singletonList(internal)); + SearchResponse searchResponse = + new SearchResponse( + new SearchResponseSections(SearchHits.empty(), internalAggregations, null, false, false, null, 1), + randomAlphaOfLengthBetween(5, 10), 5, 5, 100, ShardSearchFailure.EMPTY_ARRAY); + return createResponse(searchResponse); + } + } + + /** + * A plugin that provides a custom aggregation. + */ + public static class CustomPlugin extends Plugin implements SearchPlugin { + + public CustomPlugin() { + } + + @Override + public List getAggregations() { + return singletonList(new AggregationSpec(CUSTOM, CustomAggregationBuilder::new, null) + .addResultParser((p, c) -> ParsedCustom.fromXContent(p, (String) c))); + } + } + + interface CustomAggregation extends SingleBucketAggregation { + boolean isOdd(); + } + + static class CustomAggregationBuilder extends AbstractAggregationBuilder { + + private int customNumber; + + CustomAggregationBuilder(String name) { + super(name); + } + + CustomAggregationBuilder(StreamInput in) throws IOException { + super(in); + this.customNumber = in.readInt(); + } + + @Override + public String getType() { + return CUSTOM; + } + + void setCustomNumber(int value) { + this.customNumber = value; + } + + @Override + protected XContentBuilder internalXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field("number", customNumber); + return builder.endObject(); + } + + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + protected AggregatorFactory doBuild(SearchContext context, + AggregatorFactory parent, + AggregatorFactories.Builder subFactories) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + protected int doHashCode() { + return Objects.hash(customNumber); + } + + @Override + protected boolean doEquals(Object obj) { + CustomAggregationBuilder other = (CustomAggregationBuilder) obj; + return customNumber == other.customNumber; + } + } + + static class InternalCustom extends InternalSingleBucketAggregation implements CustomAggregation { + + private final boolean odd; + + InternalCustom(String name, boolean odd, long docCount, InternalAggregations subAggregations, + List pipelineAggregators, Map metaData) { + super(name, docCount, subAggregations, pipelineAggregators, metaData); + this.odd = odd; + } + + @Override + public String getWriteableName() { + return CUSTOM; + } + + @Override + public boolean isOdd() { + return odd; + } + + @Override + public XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException { + super.doXContentBody(builder, params); + return builder.field("is_odd", odd); + } + + @Override + public InternalAggregation doReduce(List aggregations, ReduceContext reduceContext) { + throw new UnsupportedOperationException(); + } + + @Override + protected InternalSingleBucketAggregation newAggregation(String name, long docCount, InternalAggregations subAggregations) { + throw new UnsupportedOperationException(); + } + } + + static class ParsedCustom extends ParsedSingleBucketAggregation implements CustomAggregation { + + private boolean odd; + + @Override + public String getType() { + return CUSTOM; + } + + @Override + public boolean isOdd() { + return odd; + } + + void setOdd(boolean odd) { + this.odd = odd; + } + + private static final ObjectParser PARSER; + static { + PARSER = new ObjectParser<>(CUSTOM, ParsedCustom::new); + PARSER.declareBoolean(ParsedCustom::setOdd, new ParseField("is_odd")); + PARSER.declareLong(ParsedCustom::setDocCount, CommonFields.DOC_COUNT); + PARSER.declareObject(ParsedCustom::setMetadata, (p, c) -> p.mapOrdered(), CommonFields.META); + } + + static ParsedCustom fromXContent(XContentParser parser, final String name) throws IOException { + ParsedCustom aggregation = PARSER.parse(parser, null); + aggregation.setName(name); + return aggregation; + } + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientWithCustomSuggesterTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientWithCustomSuggesterTests.java new file mode 100644 index 0000000000000..ef23583e212fb --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientWithCustomSuggesterTests.java @@ -0,0 +1,180 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client; + +import org.apache.http.HttpEntity; +import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.action.search.SearchResponseSections; +import org.elasticsearch.action.search.ShardSearchFailure; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.text.Text; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.plugins.SearchPlugin; +import org.elasticsearch.search.SearchHits; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.search.suggest.SortBy; +import org.elasticsearch.search.suggest.Suggest; +import org.elasticsearch.search.suggest.SuggestBuilder; +import org.elasticsearch.search.suggest.term.TermSuggestion; +import org.elasticsearch.search.suggest.term.TermSuggestionBuilder; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Map; + +import static java.util.Collections.singletonList; +import static org.hamcrest.Matchers.greaterThan; + +/** + * Test usage of a custom suggester provided by a plugin with the {@link RestHighLevelClient}. + */ +public class RestHighLevelClientWithCustomSuggesterTests extends RestHighLevelClientWithPluginTestCase { + + private static final String CUSTOM = "custom"; + + @Override + protected Collection> getPlugins() { + return singletonList(CustomPlugin.class); + } + + public void testCustomSuggester() throws Exception { + final int numSuggestions = randomIntBetween(1, 5); + final int[] customNumbers = new int[numSuggestions]; + + SuggestBuilder suggestBuilder = new SuggestBuilder(); + for (int i = 0; i < numSuggestions; i++) { + customNumbers[i] = randomIntBetween(0, 10); + suggestBuilder.addSuggestion("suggest_" + i, + new CustomSuggestionBuilder("field", customNumbers[i]).text("custom number is " + customNumbers[i])); + } + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.query(QueryBuilders.matchAllQuery()); + searchSourceBuilder.suggest(suggestBuilder); + SearchRequest searchRequest = new SearchRequest(); + searchRequest.source(searchSourceBuilder); + + SearchResponse searchResponse = search(searchRequest); + + Suggest suggest = searchResponse.getSuggest(); + assertEquals(numSuggestions, suggest.size()); + for (int i = 0; i < numSuggestions; i++) { + TermSuggestion suggestion = suggest.getSuggestion("suggest_" + i); + assertEquals(customNumbers[i], suggestion.getEntries().size()); + + for (int j = 0; j < customNumbers[i]; j++) { + assertEquals("term " + j, suggestion.getEntries().get(j).getText().string()); + } + } + } + + @Override + @SuppressWarnings("unchecked") + protected Response performRequest(HttpEntity httpEntity) throws IOException { + try (XContentParser parser = createParser(Request.REQUEST_BODY_CONTENT_TYPE.xContent(), httpEntity.getContent())) { + Map requestAsMap = parser.map(); + assertEquals(2, requestAsMap.size()); + assertTrue("Search request does not contain the match all query", requestAsMap.containsKey("query")); + assertTrue("Search request does not contain any suggest", requestAsMap.containsKey("suggest")); + + Map queryAsMap = (Map) requestAsMap.get("query"); + assertEquals(1, queryAsMap.size()); + assertTrue("Query must be a match query", queryAsMap.containsKey("match_all")); + + Map suggestAsMap = (Map) requestAsMap.get("suggest"); + assertThat(suggestAsMap.size(), greaterThan(0)); + + List>> suggestions = + new ArrayList<>(); + + for (Map.Entry suggestEntry : suggestAsMap.entrySet()) { + assertTrue(suggestEntry.getKey().startsWith("suggest_")); + + Map suggestEntryAsMap = (Map) suggestEntry.getValue(); + assertTrue("Suggest must have 'custom' type", suggestEntryAsMap.containsKey(CUSTOM)); + assertTrue("Suggest must have some text", suggestEntryAsMap.containsKey("text")); + + Map customSuggestAsMap = (Map) suggestEntryAsMap.get(CUSTOM); + assertTrue("Custom suggest must contain the random number", customSuggestAsMap.containsKey("number")); + + int customNumber = ((Number) customSuggestAsMap.get("number")).intValue(); + assertEquals("custom number is " + customNumber, suggestEntryAsMap.get("text")); + + TermSuggestion suggestion = new TermSuggestion(suggestEntry.getKey(), customNumber, SortBy.SCORE); + for (int i = 0; i < customNumber; i++) { + suggestion.addTerm(new TermSuggestion.Entry(new Text("term " + i), i, customNumber)); + } + suggestions.add(suggestion); + } + + Suggest suggests = new Suggest(suggestions); + SearchResponse searchResponse = + new SearchResponse( + new SearchResponseSections(SearchHits.empty(), null, suggests, false, false, null, 1), + randomAlphaOfLengthBetween(5, 10), 5, 5, 100, ShardSearchFailure.EMPTY_ARRAY); + return createResponse(searchResponse); + } + } + + /** + * A plugin that provides a custom suggester. + */ + public static class CustomPlugin extends Plugin implements SearchPlugin { + + public CustomPlugin() { + } + + @Override + public List> getSuggesters() { + return singletonList(new SuggesterSpec<>(CUSTOM, CustomSuggestionBuilder::new, CustomSuggestionBuilder::fromXContent)); + } + } + + static class CustomSuggestionBuilder extends TermSuggestionBuilder { + + private final int customNumber; + + CustomSuggestionBuilder(String field, int customNumber) { + super(field); + this.customNumber = customNumber; + } + + CustomSuggestionBuilder(StreamInput in) throws IOException { + super(in); + this.customNumber = in.readInt(); + } + + @Override + public String getWriteableName() { + return CUSTOM; + } + + @Override + public XContentBuilder innerToXContent(XContentBuilder builder, Params params) throws IOException { + builder.field("number", customNumber); + return super.innerToXContent(builder, params); + } + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientWithPluginTestCase.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientWithPluginTestCase.java new file mode 100644 index 0000000000000..d299c3bde9b72 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientWithPluginTestCase.java @@ -0,0 +1,118 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client; + +import org.apache.http.HttpEntity; +import org.apache.http.HttpHost; +import org.apache.http.HttpResponse; +import org.apache.http.ProtocolVersion; +import org.apache.http.RequestLine; +import org.apache.http.client.methods.HttpGet; +import org.apache.http.entity.ByteArrayEntity; +import org.apache.http.entity.ContentType; +import org.apache.http.message.BasicHttpResponse; +import org.apache.http.message.BasicRequestLine; +import org.apache.http.message.BasicStatusLine; +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.rest.action.search.RestSearchAction; +import org.elasticsearch.test.ESTestCase; +import org.junit.Before; + +import java.io.IOException; +import java.util.Collection; +import java.util.Collections; + +import static java.util.Collections.singletonMap; +import static org.elasticsearch.client.ESRestHighLevelClientTestCase.execute; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyMapOf; +import static org.mockito.Matchers.anyObject; +import static org.mockito.Matchers.anyVararg; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; + +/** + * Test usage of search extensions provided by plugins with the {@link RestHighLevelClient}. + */ +public abstract class RestHighLevelClientWithPluginTestCase extends ESTestCase { + + private static final String CUSTOM = "custom"; + + private RestClient restClient; + private RestHighLevelClient restHighLevelClient; + + protected Collection> getPlugins() { + return Collections.emptyList(); + } + + @Before + public void iniClients() throws IOException { + if (restHighLevelClient == null) { + restClient = mock(RestClient.class); + restHighLevelClient = new RestHighLevelClient(restClient, getPlugins()); + + doAnswer(mock -> performRequest((HttpEntity) mock.getArguments()[3])) + .when(restClient) + .performRequest(eq(HttpGet.METHOD_NAME), eq("/_search"), anyMapOf(String.class, String.class), + anyObject(), anyVararg()); + doAnswer(mock -> performRequestAsync((HttpEntity) mock.getArguments()[3], (ResponseListener) mock.getArguments()[4])) + .when(restClient) + .performRequestAsync(eq(HttpGet.METHOD_NAME), eq("/_search"), anyMapOf(String.class, String.class), + any(HttpEntity.class), any(ResponseListener.class), anyVararg()); + } + } + + protected abstract Response performRequest(HttpEntity httpEntity) throws IOException; + + protected Void performRequestAsync(HttpEntity httpEntity, ResponseListener responseListener) { + try { + responseListener.onSuccess(performRequest(httpEntity)); + } catch (IOException e) { + responseListener.onFailure(e); + } + return null; + } + + protected SearchResponse search(SearchRequest searchRequest) throws IOException { + return execute(searchRequest, restHighLevelClient::search, restHighLevelClient::searchAsync); + } + + /** + * Creates a {@link Response} from a {@link SearchResponse}? + */ + protected Response createResponse(SearchResponse searchResponse) throws IOException { + ProtocolVersion protocol = new ProtocolVersion("HTTP", 1, 1); + HttpResponse httpResponse = new BasicHttpResponse(new BasicStatusLine(protocol, 200, "OK")); + + final ToXContent.Params params = new ToXContent.MapParams(singletonMap(RestSearchAction.TYPED_KEYS_PARAM, "true")); + BytesRef bytesRef = XContentHelper.toXContent(searchResponse, XContentType.JSON, params, false).toBytesRef(); + httpResponse.setEntity(new ByteArrayEntity(bytesRef.bytes, ContentType.APPLICATION_JSON)); + + RequestLine requestLine = new BasicRequestLine(HttpGet.METHOD_NAME, "/_search", protocol); + return new Response(requestLine, new HttpHost("localhost", 9200), httpResponse); + } +} diff --git a/core/src/main/java/org/elasticsearch/plugins/SearchPlugin.java b/core/src/main/java/org/elasticsearch/plugins/SearchPlugin.java index 01685535a4e0e..9dd3ada50f703 100644 --- a/core/src/main/java/org/elasticsearch/plugins/SearchPlugin.java +++ b/core/src/main/java/org/elasticsearch/plugins/SearchPlugin.java @@ -28,6 +28,8 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.lucene.search.function.ScoreFunction; +import org.elasticsearch.common.xcontent.ContextParser; +import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.XContent; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.index.query.QueryBuilder; @@ -39,6 +41,7 @@ import org.elasticsearch.search.aggregations.AggregationBuilder; import org.elasticsearch.search.aggregations.Aggregator; import org.elasticsearch.search.aggregations.InternalAggregation; +import org.elasticsearch.search.aggregations.ParsedAggregation; import org.elasticsearch.search.aggregations.PipelineAggregationBuilder; import org.elasticsearch.search.aggregations.bucket.significant.SignificantTerms; import org.elasticsearch.search.aggregations.bucket.significant.heuristics.SignificanceHeuristic; @@ -216,7 +219,9 @@ public QuerySpec(String name, Writeable.Reader reader, QueryParser parser) * Specification for an {@link Aggregation}. */ class AggregationSpec extends SearchExtensionSpec { + private final Map> resultReaders = new TreeMap<>(); + private final Map> resultParsers = new TreeMap<>(); /** * Specification for an {@link Aggregation}. @@ -262,11 +267,33 @@ public AggregationSpec addResultReader(String writeableName, Writeable.Reader> getResultReaders() { return resultReaders; } + + /** + * Adds a {@link ContextParser} that can be used to parse this aggregation when it has been printed out as a {@link ToXContent}. + */ + public AggregationSpec addResultParser(ContextParser resultParser) { + return addResultParser(getName().getPreferredName(), resultParser); + } + + /** + * Adds a {@link ContextParser} that can be used to parse this aggregation when it has been printed out as a {@link ToXContent}. + */ + public AggregationSpec addResultParser(String typeName, ContextParser resultParser) { + resultParsers.put(typeName, resultParser); + return this; + } + + /** + * Get the parsers that can be used to parse this aggregation's results when it has been printed out as a {@link ToXContent}. + */ + public Map> getResultParsers() { + return resultParsers; + } } /** diff --git a/core/src/main/java/org/elasticsearch/search/aggregations/ParsedAggregation.java b/core/src/main/java/org/elasticsearch/search/aggregations/ParsedAggregation.java index d79baac06b097..ebcc61610474a 100644 --- a/core/src/main/java/org/elasticsearch/search/aggregations/ParsedAggregation.java +++ b/core/src/main/java/org/elasticsearch/search/aggregations/ParsedAggregation.java @@ -48,7 +48,7 @@ public final String getName() { return name; } - protected void setName(String name) { + public void setName(String name) { this.name = name; } @@ -72,6 +72,10 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par protected abstract XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException; + protected void setMetadata(Map metadata) { + this.metadata = metadata; + } + /** * Parse a token of type XContentParser.Token.VALUE_NUMBER or XContentParser.Token.STRING to a double. * In other cases the default value is returned instead. diff --git a/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/MatrixAggregationPlugin.java b/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/MatrixAggregationPlugin.java index a712371fa10d2..b7c1ca873658e 100644 --- a/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/MatrixAggregationPlugin.java +++ b/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/MatrixAggregationPlugin.java @@ -24,6 +24,7 @@ import org.elasticsearch.search.aggregations.matrix.stats.InternalMatrixStats; import org.elasticsearch.search.aggregations.matrix.stats.MatrixStatsAggregationBuilder; import org.elasticsearch.search.aggregations.matrix.stats.MatrixStatsParser; +import org.elasticsearch.search.aggregations.matrix.stats.ParsedMatrixStats; import java.util.List; @@ -32,7 +33,10 @@ public class MatrixAggregationPlugin extends Plugin implements SearchPlugin { @Override public List getAggregations() { - return singletonList(new AggregationSpec(MatrixStatsAggregationBuilder.NAME, MatrixStatsAggregationBuilder::new, - new MatrixStatsParser()).addResultReader(InternalMatrixStats::new)); + return singletonList( + new AggregationSpec(MatrixStatsAggregationBuilder.NAME, MatrixStatsAggregationBuilder::new, new MatrixStatsParser()) + .addResultReader(InternalMatrixStats::new) + .addResultParser((p, c) -> ParsedMatrixStats.fromXContent(p, (String) c)) + ); } } diff --git a/modules/parent-join/src/main/java/org/elasticsearch/join/ParentJoinPlugin.java b/modules/parent-join/src/main/java/org/elasticsearch/join/ParentJoinPlugin.java index 83033545cfbb7..43086e17ec8db 100644 --- a/modules/parent-join/src/main/java/org/elasticsearch/join/ParentJoinPlugin.java +++ b/modules/parent-join/src/main/java/org/elasticsearch/join/ParentJoinPlugin.java @@ -23,6 +23,7 @@ import org.elasticsearch.index.mapper.Mapper; import org.elasticsearch.join.aggregations.ChildrenAggregationBuilder; import org.elasticsearch.join.aggregations.InternalChildren; +import org.elasticsearch.join.aggregations.ParsedChildren; import org.elasticsearch.join.fetch.ParentJoinFieldSubFetchPhase; import org.elasticsearch.join.mapper.ParentJoinFieldMapper; import org.elasticsearch.join.query.HasChildQueryBuilder; @@ -52,7 +53,8 @@ public List> getQueries() { public List getAggregations() { return Collections.singletonList( new AggregationSpec(ChildrenAggregationBuilder.NAME, ChildrenAggregationBuilder::new, ChildrenAggregationBuilder::parse) - .addResultReader(InternalChildren::new) + .addResultReader(InternalChildren::new) + .addResultParser((p, c) -> ParsedChildren.fromXContent(p, (String) c)) ); }