Try to save memory on aggregations (backport of elastic#53793)
This delays deserializing the aggregation responses until *right*
before we merge the objects.
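In other words: rather than materializing every shard's aggregation tree as responses arrive, each buffered result keeps its serialized bytes behind a Supplier and is deserialized only as the merge consumes it. A minimal sketch of the pattern (illustrative names, not code from this commit):

// Hypothetical sketch of the memory-saving idea.
// Before: each buffered response holds a fully deserialized object tree.
// After: each buffered response holds compact bytes behind a Supplier,
// expanded one at a time as the merge consumes them.
import java.util.List;
import java.util.function.BinaryOperator;
import java.util.function.Supplier;

class DelayedMergeSketch {
    static <T> T merge(List<Supplier<T>> buffered, BinaryOperator<T> combine) {
        T merged = null;
        for (int i = 0; i < buffered.size(); i++) {
            T next = buffered.get(i).get(); // deserialize *right* before merging
            buffered.set(i, null);          // backing bytes can be GCed now
            merged = merged == null ? next : combine.apply(merged, next);
        }
        return merged;
    }
}

Nulling each slot as it is consumed lets the bytes behind it become collectable immediately, instead of keeping every expanded tree alive until the whole merge finishes.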
nik9000 committed Mar 23, 2020
1 parent eda023e commit f35c346
Showing 8 changed files with 440 additions and 65 deletions.
org/elasticsearch/action/search/SearchPhaseController.java
@@ -19,8 +19,17 @@

package org.elasticsearch.action.search;

import com.carrotsearch.hppc.IntArrayList;
import com.carrotsearch.hppc.ObjectObjectHashMap;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.function.IntFunction;
import java.util.function.Supplier;
import java.util.stream.Collectors;

import org.apache.lucene.index.Term;
import org.apache.lucene.search.CollectionStatistics;
@@ -58,16 +67,8 @@
import org.elasticsearch.search.suggest.Suggest.Suggestion;
import org.elasticsearch.search.suggest.completion.CompletionSuggestion;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.function.IntFunction;
import java.util.stream.Collectors;
import com.carrotsearch.hppc.IntArrayList;
import com.carrotsearch.hppc.ObjectObjectHashMap;

public final class SearchPhaseController {
private static final ScoreDoc[] EMPTY_DOCS = new ScoreDoc[0];
@@ -429,7 +430,7 @@ public ReducedQueryPhase reducedQueryPhase(Collection<? extends SearchPhaseResul
* @see QuerySearchResult#consumeProfileResult()
*/
private ReducedQueryPhase reducedQueryPhase(Collection<? extends SearchPhaseResult> queryResults,
List<InternalAggregations> bufferedAggs, List<TopDocs> bufferedTopDocs,
List<Supplier<InternalAggregations>> bufferedAggs, List<TopDocs> bufferedTopDocs,
TopDocsStats topDocsStats, int numReducePhases, boolean isScrollRequest,
InternalAggregation.ReduceContextBuilder aggReduceContextBuilder,
boolean performFinalReduce) {
@@ -453,7 +454,7 @@ private ReducedQueryPhase reducedQueryPhase(Collection<? extends SearchPhaseResu
final boolean hasSuggest = firstResult.suggest() != null;
final boolean hasProfileResults = firstResult.hasProfileResults();
final boolean consumeAggs;
final List<InternalAggregations> aggregationsList;
final List<Supplier<InternalAggregations>> aggregationsList;
if (bufferedAggs != null) {
consumeAggs = false;
// we already have results from intermediate reduces and just need to perform the final reduce
@@ -492,7 +493,7 @@ private ReducedQueryPhase reducedQueryPhase(Collection<? extends SearchPhaseResu
}
}
if (consumeAggs) {
aggregationsList.add((InternalAggregations) result.consumeAggs());
aggregationsList.add(result.consumeAggs());
}
if (hasProfileResults) {
String key = result.getSearchShardTarget().toString();
@@ -508,8 +509,7 @@ private ReducedQueryPhase reducedQueryPhase(Collection<? extends SearchPhaseResu
reducedSuggest = new Suggest(Suggest.reduce(groupedSuggestions));
reducedCompletionSuggestions = reducedSuggest.filter(CompletionSuggestion.class);
}
final InternalAggregations aggregations = aggregationsList.isEmpty() ? null : InternalAggregations.topLevelReduce(aggregationsList,
performFinalReduce ? aggReduceContextBuilder.forFinalReduction() : aggReduceContextBuilder.forPartialReduction());
final InternalAggregations aggregations = reduceAggs(aggReduceContextBuilder, performFinalReduce, aggregationsList);
final SearchProfileShardResults shardResults = profileResults.isEmpty() ? null : new SearchProfileShardResults(profileResults);
final SortedTopDocs sortedTopDocs = sortDocs(isScrollRequest, queryResults, bufferedTopDocs, topDocsStats, from, size,
reducedCompletionSuggestions);
@@ -519,6 +519,24 @@ private ReducedQueryPhase reducedQueryPhase(Collection<? extends SearchPhaseResu
firstResult.sortValueFormats(), numReducePhases, size, from, false);
}

private InternalAggregations reduceAggs(
InternalAggregation.ReduceContextBuilder aggReduceContextBuilder,
boolean performFinalReduce,
List<Supplier<InternalAggregations>> aggregationsList
) {
/*
* Parse the aggregations, clearing the list as we go so bits backing
* the DelayableWriteable can be collected immediately.
*/
List<InternalAggregations> toReduce = new ArrayList<>(aggregationsList.size());
for (int i = 0; i < aggregationsList.size(); i++) {
toReduce.add(aggregationsList.get(i).get());
aggregationsList.set(i, null);
}
return aggregationsList.isEmpty() ? null : InternalAggregations.topLevelReduce(toReduce,
performFinalReduce ? aggReduceContextBuilder.forFinalReduction() : aggReduceContextBuilder.forPartialReduction());
}

/*
* Returns the size of the requested top documents (from + size)
*/
@@ -600,7 +618,7 @@ public InternalSearchResponse buildResponse(SearchHits hits) {
*/
static final class QueryPhaseResultConsumer extends ArraySearchPhaseResults<SearchPhaseResult> {
private final SearchShardTarget[] processedShards;
private final InternalAggregations[] aggsBuffer;
private final Supplier<InternalAggregations>[] aggsBuffer;
private final TopDocs[] topDocsBuffer;
private final boolean hasAggs;
private final boolean hasTopDocs;
@@ -642,7 +660,9 @@ private QueryPhaseResultConsumer(SearchProgressListener progressListener, Search
this.progressListener = progressListener;
this.processedShards = new SearchShardTarget[expectedResultSize];
// no need to buffer anything if we have less expected results. in this case we don't consume any results ahead of time.
this.aggsBuffer = new InternalAggregations[hasAggs ? bufferSize : 0];
@SuppressWarnings("unchecked")
Supplier<InternalAggregations>[] aggsBuffer = new Supplier[hasAggs ? bufferSize : 0];
this.aggsBuffer = aggsBuffer;
this.topDocsBuffer = new TopDocs[hasTopDocs ? bufferSize : 0];
this.hasTopDocs = hasTopDocs;
this.hasAggs = hasAggs;
@@ -665,10 +685,14 @@ private synchronized void consumeInternal(QuerySearchResult querySearchResult) {
if (querySearchResult.isNull() == false) {
if (index == bufferSize) {
if (hasAggs) {
ReduceContext reduceContext = aggReduceContextBuilder.forPartialReduction();
InternalAggregations reducedAggs = InternalAggregations.topLevelReduce(Arrays.asList(aggsBuffer), reduceContext);
Arrays.fill(aggsBuffer, null);
aggsBuffer[0] = reducedAggs;
List<InternalAggregations> aggs = new ArrayList<>(aggsBuffer.length);
for (int i = 0; i < aggsBuffer.length; i++) {
aggs.add(aggsBuffer[i].get());
aggsBuffer[i] = null; // null the buffer so it can be GCed now.
}
InternalAggregations reducedAggs = InternalAggregations.topLevelReduce(
aggs, aggReduceContextBuilder.forPartialReduction());
aggsBuffer[0] = () -> reducedAggs;
}
if (hasTopDocs) {
TopDocs reducedTopDocs = mergeTopDocs(Arrays.asList(topDocsBuffer),
Expand All @@ -681,12 +705,12 @@ private synchronized void consumeInternal(QuerySearchResult querySearchResult) {
index = 1;
if (hasAggs || hasTopDocs) {
progressListener.notifyPartialReduce(SearchProgressListener.buildSearchShards(processedShards),
topDocsStats.getTotalHits(), hasAggs ? aggsBuffer[0] : null, numReducePhases);
topDocsStats.getTotalHits(), hasAggs ? aggsBuffer[0].get() : null, numReducePhases);
}
}
final int i = index++;
if (hasAggs) {
aggsBuffer[i] = (InternalAggregations) querySearchResult.consumeAggs();
aggsBuffer[i] = querySearchResult.consumeAggs();
}
if (hasTopDocs) {
final TopDocsAndMaxScore topDocs = querySearchResult.consumeTopDocs(); // can't be null
Expand All @@ -698,7 +722,7 @@ private synchronized void consumeInternal(QuerySearchResult querySearchResult) {
processedShards[querySearchResult.getShardIndex()] = querySearchResult.getSearchShardTarget();
}

private synchronized List<InternalAggregations> getRemainingAggs() {
private synchronized List<Supplier<InternalAggregations>> getRemainingAggs() {
return hasAggs ? Arrays.asList(aggsBuffer).subList(0, index) : null;
}

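The QueryPhaseResultConsumer changes above follow a fixed-size buffer pattern: delayed results accumulate until the buffer fills, then a partial reduce collapses the batch into slot 0. A simplified, hedged sketch of that control flow (names are stand-ins, not the real Elasticsearch API):

// Illustrative reduction of the fixed-buffer partial-reduce pattern used by
// QueryPhaseResultConsumer above; names and types are simplified stand-ins.
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;
import java.util.function.Supplier;

class PartialReduceBuffer<T> {
    private final Supplier<T>[] buffer;
    private final Function<List<T>, T> reduce;
    private int index;

    @SuppressWarnings("unchecked")
    PartialReduceBuffer(int bufferSize, Function<List<T>, T> reduce) {
        this.buffer = new Supplier[bufferSize];
        this.reduce = reduce;
    }

    synchronized void consume(Supplier<T> delayed) {
        if (index == buffer.length) {
            List<T> batch = new ArrayList<>(buffer.length);
            for (int i = 0; i < buffer.length; i++) {
                batch.add(buffer[i].get()); // materialize just in time
                buffer[i] = null;           // release the delayed bytes
            }
            T reduced = reduce.apply(batch);
            buffer[0] = () -> reduced; // collapsed partial result takes slot 0
            index = 1;
        }
        buffer[index++] = delayed;
    }
}

Re-wrapping the reduced result as a Supplier (buffer[0] = () -> reduced) keeps the buffer's element type uniform, so already-reduced and still-delayed entries can sit side by side.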
org/elasticsearch/common/io/stream/DelayableWriteable.java
@@ -0,0 +1,131 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.elasticsearch.common.io.stream;

import org.elasticsearch.Version;
import org.elasticsearch.common.bytes.BytesReference;

import java.io.IOException;
import java.util.function.Supplier;

/**
* A holder for {@link Writeable}s that delays reading the underlying
* {@linkplain Writeable} when it is read from a remote node.
*/
public abstract class DelayableWriteable<T extends Writeable> implements Supplier<T>, Writeable {
/**
* Build a {@linkplain DelayableWriteable} that wraps an existing object
* but is serialized so that deserializing it can be delayed.
*/
public static <T extends Writeable> DelayableWriteable<T> referencing(T reference) {
return new Referencing<>(reference);
}
/**
* Build a {@linkplain DelayableWriteable} that copies a buffer from
* the provided {@linkplain StreamInput} and deserializes the buffer
* when {@link Supplier#get()} is called.
*/
public static <T extends Writeable> DelayableWriteable<T> delayed(Writeable.Reader<T> reader, StreamInput in) throws IOException {
return new Delayed<>(reader, in);
}

private DelayableWriteable() {}

public abstract boolean isDelayed();

private static class Referencing<T extends Writeable> extends DelayableWriteable<T> {
private T reference;

Referencing(T reference) {
this.reference = reference;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
try (BytesStreamOutput buffer = new BytesStreamOutput()) {
reference.writeTo(buffer);
out.writeBytesReference(buffer.bytes());
}
}

@Override
public T get() {
return reference;
}

@Override
public boolean isDelayed() {
return false;
}
}

private static class Delayed<T extends Writeable> extends DelayableWriteable<T> {
private final Writeable.Reader<T> reader;
private final Version remoteVersion;
private final BytesReference serialized;
private final NamedWriteableRegistry registry;

Delayed(Writeable.Reader<T> reader, StreamInput in) throws IOException {
this.reader = reader;
remoteVersion = in.getVersion();
serialized = in.readBytesReference();
registry = in.namedWriteableRegistry();
}

@Override
public void writeTo(StreamOutput out) throws IOException {
if (out.getVersion() == remoteVersion) {
/*
* If the version *does* line up we can just copy the bytes
* which is good because this is how shard request caching
* works.
*/
out.writeBytesReference(serialized);
} else {
/*
* If the version doesn't line up then we have to deserialize
* into the Writeable and re-serialize it against the new
* output stream so it can apply any backwards compatibility
* differences in the wire protocol. This ain't efficient but
* it should be quite rare.
*/
referencing(get()).writeTo(out);
}
}

@Override
public T get() {
try {
try (StreamInput in = registry == null ?
serialized.streamInput() : new NamedWriteableAwareStreamInput(serialized.streamInput(), registry)) {
in.setVersion(remoteVersion);
return reader.read(in);
}
} catch (IOException e) {
throw new RuntimeException("unexpected error expanding aggregations", e);
}
}

@Override
public boolean isDelayed() {
return true;
}
}
}
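A hedged usage sketch of the two factories above. The reader reference (InternalAggregations::new) and the stream parameters are assumptions for illustration, not part of this file:

// Hypothetical usage; assumes InternalAggregations exposes a StreamInput
// constructor and that `out` and `in` are live stream handles.
void roundTrip(InternalAggregations aggs, StreamOutput out, StreamInput in) throws IOException {
    // Sending side: wrap the live object; it is serialized when writeTo runs.
    DelayableWriteable<InternalAggregations> sent = DelayableWriteable.referencing(aggs);
    sent.writeTo(out); // writes a length-prefixed byte blob

    // Receiving side: copy only the raw bytes; parsing waits until get().
    DelayableWriteable<InternalAggregations> received =
        DelayableWriteable.delayed(InternalAggregations::new, in);
    assert received.isDelayed();
    InternalAggregations expanded = received.get(); // deserialized on demand
}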
org/elasticsearch/common/io/stream/FilterStreamInput.java
@@ -94,4 +94,9 @@ public void setVersion(Version version) {
protected void ensureCanReadBytes(int length) throws EOFException {
delegate.ensureCanReadBytes(length);
}

@Override
public NamedWriteableRegistry namedWriteableRegistry() {
return delegate.namedWriteableRegistry();
}
}
org/elasticsearch/common/io/stream/NamedWriteableAwareStreamInput.java
@@ -52,4 +52,9 @@ public <C extends NamedWriteable> C readNamedWriteable(@SuppressWarnings("unused
+ "] than it was read from [" + name + "].";
return c;
}

@Override
public NamedWriteableRegistry namedWriteableRegistry() {
return namedWriteableRegistry;
}
}
org/elasticsearch/common/io/stream/StreamInput.java
@@ -1097,6 +1097,14 @@ public <T extends Exception> T readException() throws IOException {
return null;
}

/**
* Get the registry of named writeables if this stream has one,
* {@code null} otherwise.
*/
public NamedWriteableRegistry namedWriteableRegistry() {
return null;
}

/**
* Reads a {@link NamedWriteable} from the current stream, by first reading its name and then looking for
* the corresponding entry in the registry by name, so that the proper object can be read and returned.
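This default is the hook Delayed.get() relies on: a bare StreamInput reports no registry, while wrappers that carry one, like the two overrides above, expose it. A minimal consumer-side fragment mirroring Delayed.get(), with local names assumed from the code above:

// Re-wrap buffered bytes with the registry captured at read time,
// falling back to a plain stream when none was available.
StreamInput in = registry == null
    ? serialized.streamInput()
    : new NamedWriteableAwareStreamInput(serialized.streamInput(), registry);
in.setVersion(remoteVersion);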