
Try to save memory on aggregations #53793

Merged · Mar 23, 2020 · 8 commits
server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java
@@ -19,8 +19,17 @@

package org.elasticsearch.action.search;

import com.carrotsearch.hppc.IntArrayList;
import com.carrotsearch.hppc.ObjectObjectHashMap;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.function.IntFunction;
import java.util.function.Supplier;
import java.util.stream.Collectors;

import org.apache.lucene.index.Term;
import org.apache.lucene.search.CollectionStatistics;
@@ -58,16 +67,8 @@
import org.elasticsearch.search.suggest.Suggest.Suggestion;
import org.elasticsearch.search.suggest.completion.CompletionSuggestion;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.function.IntFunction;
import java.util.stream.Collectors;
import com.carrotsearch.hppc.IntArrayList;
import com.carrotsearch.hppc.ObjectObjectHashMap;

public final class SearchPhaseController {
private static final ScoreDoc[] EMPTY_DOCS = new ScoreDoc[0];
@@ -429,7 +430,7 @@ public ReducedQueryPhase reducedQueryPhase(Collection<? extends SearchPhaseResult>
* @see QuerySearchResult#consumeProfileResult()
*/
private ReducedQueryPhase reducedQueryPhase(Collection<? extends SearchPhaseResult> queryResults,
List<InternalAggregations> bufferedAggs, List<TopDocs> bufferedTopDocs,
List<Supplier<InternalAggregations>> bufferedAggs, List<TopDocs> bufferedTopDocs,
TopDocsStats topDocsStats, int numReducePhases, boolean isScrollRequest,
InternalAggregation.ReduceContextBuilder aggReduceContextBuilder,
boolean performFinalReduce) {
@@ -453,7 +454,7 @@ private ReducedQueryPhase reducedQueryPhase(Collection<? extends SearchPhaseResult>
final boolean hasSuggest = firstResult.suggest() != null;
final boolean hasProfileResults = firstResult.hasProfileResults();
final boolean consumeAggs;
final List<InternalAggregations> aggregationsList;
final List<Supplier<InternalAggregations>> aggregationsList;
if (bufferedAggs != null) {
consumeAggs = false;
// we already have results from intermediate reduces and just need to perform the final reduce
@@ -492,7 +493,7 @@ private ReducedQueryPhase reducedQueryPhase(Collection<? extends SearchPhaseResult>
}
}
if (consumeAggs) {
aggregationsList.add((InternalAggregations) result.consumeAggs());
aggregationsList.add(result.consumeAggs());
}
if (hasProfileResults) {
String key = result.getSearchShardTarget().toString();
@@ -508,8 +509,7 @@ private ReducedQueryPhase reducedQueryPhase(Collection<? extends SearchPhaseResult>
reducedSuggest = new Suggest(Suggest.reduce(groupedSuggestions));
reducedCompletionSuggestions = reducedSuggest.filter(CompletionSuggestion.class);
}
final InternalAggregations aggregations = aggregationsList.isEmpty() ? null : InternalAggregations.topLevelReduce(aggregationsList,
performFinalReduce ? aggReduceContextBuilder.forFinalReduction() : aggReduceContextBuilder.forPartialReduction());
final InternalAggregations aggregations = reduceAggs(aggReduceContextBuilder, performFinalReduce, aggregationsList);
final SearchProfileShardResults shardResults = profileResults.isEmpty() ? null : new SearchProfileShardResults(profileResults);
final SortedTopDocs sortedTopDocs = sortDocs(isScrollRequest, queryResults, bufferedTopDocs, topDocsStats, from, size,
reducedCompletionSuggestions);
@@ -519,6 +519,24 @@ private ReducedQueryPhase reducedQueryPhase(Collection<? extends SearchPhaseResult>
firstResult.sortValueFormats(), numReducePhases, size, from, false);
}

private InternalAggregations reduceAggs(
InternalAggregation.ReduceContextBuilder aggReduceContextBuilder,
boolean performFinalReduce,
List<Supplier<InternalAggregations>> aggregationsList
) {
/*
* Parse the aggregations, clearing the list as we go so bits backing
* the DelayableWriteable can be collected immediately.
*/
List<InternalAggregations> toReduce = new ArrayList<>(aggregationsList.size());
for (int i = 0; i < aggregationsList.size(); i++) {
toReduce.add(aggregationsList.get(i).get());
aggregationsList.set(i, null);
}
return aggregationsList.isEmpty() ? null : InternalAggregations.topLevelReduce(toReduce,
performFinalReduce ? aggReduceContextBuilder.forFinalReduction() : aggReduceContextBuilder.forPartialReduction());
}

/*
* Returns the size of the requested top documents (from + size)
*/
@@ -600,7 +618,7 @@ public InternalSearchResponse buildResponse(SearchHits hits) {
*/
static final class QueryPhaseResultConsumer extends ArraySearchPhaseResults<SearchPhaseResult> {
private final SearchShardTarget[] processedShards;
private final InternalAggregations[] aggsBuffer;
private final Supplier<InternalAggregations>[] aggsBuffer;
private final TopDocs[] topDocsBuffer;
private final boolean hasAggs;
private final boolean hasTopDocs;
@@ -642,7 +660,9 @@ private QueryPhaseResultConsumer(SearchProgressListener progressListener, Search
this.progressListener = progressListener;
this.processedShards = new SearchShardTarget[expectedResultSize];
// no need to buffer anything if we have fewer expected results than the buffer size. In this case we don't consume any results ahead of time.
this.aggsBuffer = new InternalAggregations[hasAggs ? bufferSize : 0];
@SuppressWarnings("unchecked")
Supplier<InternalAggregations>[] aggsBuffer = new Supplier[hasAggs ? bufferSize : 0];
this.aggsBuffer = aggsBuffer;
this.topDocsBuffer = new TopDocs[hasTopDocs ? bufferSize : 0];
this.hasTopDocs = hasTopDocs;
this.hasAggs = hasAggs;
@@ -665,10 +685,14 @@ private synchronized void consumeInternal(QuerySearchResult querySearchResult) {
if (querySearchResult.isNull() == false) {
if (index == bufferSize) {
if (hasAggs) {
ReduceContext reduceContext = aggReduceContextBuilder.forPartialReduction();
InternalAggregations reducedAggs = InternalAggregations.topLevelReduce(Arrays.asList(aggsBuffer), reduceContext);
Arrays.fill(aggsBuffer, null);
aggsBuffer[0] = reducedAggs;
List<InternalAggregations> aggs = new ArrayList<>(aggsBuffer.length);
for (int i = 0; i < aggsBuffer.length; i++) {
aggs.add(aggsBuffer[i].get());
aggsBuffer[i] = null; // null the buffer so it can be GCed now.
}
InternalAggregations reducedAggs = InternalAggregations.topLevelReduce(
aggs, aggReduceContextBuilder.forPartialReduction());
aggsBuffer[0] = () -> reducedAggs;
Contributor:

Should we nullify the rest of the array to make the reduced aggs eligible for GC?

Member Author:

++

Member Author:

Actually the line right above does that.

Contributor:

Right, but we keep the serialized + deserialized form until after the partial reduce. We can try to release the serialized form early with:

 List<InternalAggregations> toReduce = Arrays.stream(aggsBuffer).map(Supplier::get).collect(toList());
 Arrays.fill(aggsBuffer, null);
 InternalAggregations reducedAggs = InternalAggregations.topLevelReduce(toReduce, aggReduceContextBuilder.forPartialReduction());
 aggsBuffer[0] = () -> reducedAggs;

Or we can nullify the serialized form when the supplier is called, as discussed below.

Member Author:

Right! I noticed that right after I sent this. I'm playing with nulling the cell in the array as soon as I call get. That feels a little safer than nulling the bytes.

}
if (hasTopDocs) {
TopDocs reducedTopDocs = mergeTopDocs(Arrays.asList(topDocsBuffer),
@@ -681,12 +705,12 @@ private synchronized void consumeInternal(QuerySearchResult querySearchResult) {
index = 1;
if (hasAggs || hasTopDocs) {
progressListener.notifyPartialReduce(SearchProgressListener.buildSearchShards(processedShards),
topDocsStats.getTotalHits(), hasAggs ? aggsBuffer[0] : null, numReducePhases);
topDocsStats.getTotalHits(), hasAggs ? aggsBuffer[0].get() : null, numReducePhases);
}
}
final int i = index++;
if (hasAggs) {
aggsBuffer[i] = (InternalAggregations) querySearchResult.consumeAggs();
aggsBuffer[i] = querySearchResult.consumeAggs();
}
if (hasTopDocs) {
final TopDocsAndMaxScore topDocs = querySearchResult.consumeTopDocs(); // can't be null
@@ -698,7 +722,7 @@ private synchronized void consumeInternal(QuerySearchResult querySearchResult) {
processedShards[querySearchResult.getShardIndex()] = querySearchResult.getSearchShardTarget();
}

private synchronized List<InternalAggregations> getRemainingAggs() {
private synchronized List<Supplier<InternalAggregations>> getRemainingAggs() {
return hasAggs ? Arrays.asList(aggsBuffer).subList(0, index) : null;
}

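The consumer-side change above boils down to one pattern: buffer cheap suppliers instead of parsed objects, and null each slot the moment it is parsed during a partial reduce so the serialized bytes behind it can be collected. A minimal, self-contained sketch of that pattern (not the PR's code — class and method names are invented, and BinaryOperator stands in for InternalAggregations.topLevelReduce):

import java.util.ArrayList;
import java.util.List;
import java.util.function.BinaryOperator;
import java.util.function.Supplier;

// Sketch only: the buffer holds suppliers whose bytes are not yet parsed.
// When it fills up, parse each entry, null its slot immediately, reduce,
// and park the reduced value behind a new supplier in slot 0.
final class PartialReducer<T> {
    private final Supplier<T>[] buffer;
    private final BinaryOperator<T> combine;
    private int index;

    @SuppressWarnings("unchecked")
    PartialReducer(int bufferSize, BinaryOperator<T> combine) {
        this.buffer = new Supplier[bufferSize];
        this.combine = combine;
    }

    void consume(Supplier<T> next) {
        if (index == buffer.length) {
            List<T> toReduce = new ArrayList<>(buffer.length);
            for (int i = 0; i < buffer.length; i++) {
                toReduce.add(buffer[i].get());
                buffer[i] = null; // drop the delayed bytes as soon as they are parsed
            }
            T reduced = toReduce.stream().reduce(combine).orElseThrow();
            buffer[0] = () -> reduced; // keep only the reduced form
            index = 1;
        }
        buffer[index++] = next;
    }
}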
server/src/main/java/org/elasticsearch/common/io/stream/DelayableWriteable.java (new file)
@@ -0,0 +1,131 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.elasticsearch.common.io.stream;

import org.elasticsearch.Version;
import org.elasticsearch.common.bytes.BytesReference;

import java.io.IOException;
import java.util.function.Supplier;

/**
* A holder for {@link Writeable}s that delays reading the underlying
* {@linkplain Writeable} when it is read from a remote node.
*/
public abstract class DelayableWriteable<T extends Writeable> implements Supplier<T>, Writeable {
Member Author:

We're only using this for InternalAggregations, but it is a heck of a lot simpler to test if it is generic.

/**
* Build a {@linkplain DelayableWriteable} that wraps an existing object
* but is serialized so that deserializing it can be delayed.
*/
public static <T extends Writeable> DelayableWriteable<T> referencing(T reference) {
return new Referencing<>(reference);
}
/**
* Build a {@linkplain DelayableWriteable} that copies a buffer from
* the provided {@linkplain StreamInput} and deserializes the buffer
* when {@link Supplier#get()} is called.
*/
public static <T extends Writeable> DelayableWriteable<T> delayed(Writeable.Reader<T> reader, StreamInput in) throws IOException {
return new Delayed<>(reader, in);
}

private DelayableWriteable() {}

public abstract boolean isDelayed();

private static class Referencing<T extends Writeable> extends DelayableWriteable<T> {
private T reference;

Referencing(T reference) {
this.reference = reference;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
try (BytesStreamOutput buffer = new BytesStreamOutput()) {
reference.writeTo(buffer);
out.writeBytesReference(buffer.bytes());
}
}

@Override
public T get() {
return reference;
}

@Override
public boolean isDelayed() {
return false;
}
}

private static class Delayed<T extends Writeable> extends DelayableWriteable<T> {
private final Writeable.Reader<T> reader;
private final Version remoteVersion;
private final BytesReference serialized;
private final NamedWriteableRegistry registry;

Delayed(Writeable.Reader<T> reader, StreamInput in) throws IOException {
this.reader = reader;
remoteVersion = in.getVersion();
serialized = in.readBytesReference();
registry = in.namedWriteableRegistry();
}

@Override
public void writeTo(StreamOutput out) throws IOException {
if (out.getVersion() == remoteVersion) {
/*
* If the version *does* line up we can just copy the bytes
* which is good because this is how shard request caching
* works.
*/
out.writeBytesReference(serialized);
} else {
/*
* If the version doesn't line up then we have to deserialize
* into the Writeable and re-serialize it against the new
* output stream so it can apply any backwards compatibility
* differences in the wire protocol. This ain't efficient but
* it should be quite rare.
*/
referencing(get()).writeTo(out);
}
}

@Override
public T get() {
try {
try (StreamInput in = registry == null ?
serialized.streamInput() : new NamedWriteableAwareStreamInput(serialized.streamInput(), registry)) {
in.setVersion(remoteVersion);
return reader.read(in);
Contributor:

Should we nullify the bytes ref before returning the deserialized aggs? We could also protect against multiple calls by keeping the deserialized aggs internally on the first call.

Member Author:

I'm worried about race conditions with that. The way it is now, it is fairly simple to look at and say "there are no race conditions." I think nullifying the other references would be good enough from a GC perspective. Do you?

Contributor:

Yep, nullifying the reference should be enough, but it would be better if we could nullify after each deserialization. Otherwise you'd need to keep both the deserialized aggs and their bytes representation during the entire partial reduce, which defeats the purpose of saving memory here.

}
} catch (IOException e) {
throw new RuntimeException("unexpected error expanding aggregations", e);
}
}

@Override
public boolean isDelayed() {
return true;
}
}
}
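For reference, the memoizing alternative debated in the thread above might look like the sketch below. This is not part of the PR: the class and field names are invented, the version and registry handling that Delayed does is omitted for brevity, and synchronization stands in for the race-condition reasoning the author preferred to avoid.

import java.io.IOException;

import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.Writeable;

// Sketch only: deserialize once, release the bytes immediately,
// at the cost of taking a lock on every call.
final class MemoizingDelayed<T extends Writeable> {
    private final Writeable.Reader<T> reader;
    private BytesReference serialized; // nulled after the first get()
    private T value;

    MemoizingDelayed(Writeable.Reader<T> reader, BytesReference serialized) {
        this.reader = reader;
        this.serialized = serialized;
    }

    synchronized T get() throws IOException {
        if (value == null) {
            try (StreamInput in = serialized.streamInput()) {
                value = reader.read(in);
            }
            serialized = null; // the bytes become eligible for GC here
        }
        return value;
    }
}

The trade-off is that lock, which is why the PR keeps the simpler, stateless Delayed and instead has callers null their own references after each get().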
server/src/main/java/org/elasticsearch/common/io/stream/FilterStreamInput.java
@@ -94,4 +94,9 @@ public void setVersion(Version version) {
protected void ensureCanReadBytes(int length) throws EOFException {
delegate.ensureCanReadBytes(length);
}

@Override
public NamedWriteableRegistry namedWriteableRegistry() {
return delegate.namedWriteableRegistry();
}
}
server/src/main/java/org/elasticsearch/common/io/stream/NamedWriteableAwareStreamInput.java
@@ -52,4 +52,9 @@ public <C extends NamedWriteable> C readNamedWriteable(@SuppressWarnings("unused")
+ "] than it was read from [" + name + "].";
return c;
}

@Override
public NamedWriteableRegistry namedWriteableRegistry() {
return namedWriteableRegistry;
}
}
server/src/main/java/org/elasticsearch/common/io/stream/StreamInput.java
@@ -1093,6 +1093,14 @@ public <T extends Exception> T readException() throws IOException {
return null;
}

/**
* Get the registry of named writeables if this stream has one,
* {@code null} otherwise.
*/
public NamedWriteableRegistry namedWriteableRegistry() {
return null;
}

/**
* Reads a {@link NamedWriteable} from the current stream, by first reading its name and then looking for
* the corresponding entry in the registry by name, so that the proper object can be read and returned.
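Putting the pieces together, here is a hedged end-to-end sketch of how the new class round-trips, using only API that appears in this diff; LongBox is an invented minimal Writeable for the demo, not anything in Elasticsearch.

import java.io.IOException;

import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.DelayableWriteable;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;

public class DelayableWriteableExample {
    // Invented minimal Writeable, just for the demo.
    static final class LongBox implements Writeable {
        final long value;
        LongBox(long value) { this.value = value; }
        LongBox(StreamInput in) throws IOException { value = in.readLong(); }
        @Override public void writeTo(StreamOutput out) throws IOException { out.writeLong(value); }
    }

    public static void main(String[] args) throws IOException {
        try (BytesStreamOutput out = new BytesStreamOutput()) {
            // The sending side wraps a live object; serialization happens in writeTo.
            DelayableWriteable.referencing(new LongBox(42)).writeTo(out);
            try (StreamInput in = out.bytes().streamInput()) {
                // The receiving side copies the buffer but does not parse it yet.
                DelayableWriteable<LongBox> delayed = DelayableWriteable.delayed(LongBox::new, in);
                assert delayed.isDelayed();
                LongBox copy = delayed.get(); // deserialization happens only here
                assert copy.value == 42;
            }
        }
    }
}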