Skip to content

Commit

Permalink
Extract ResponseLengthRecorder in REST layer (#104836)
Browse files Browse the repository at this point in the history
Related to #104752, we should not be confounding the lifecycle of the
`EncodedLengthTrackingChunkedRestResponseBody` with the lifecycle of the
overall `RestResponse`. This commit introduces a separate object to
record the total response length whose lifecycle is independent of that
of the response body.
  • Loading branch information
DaveCTurner authored Jan 29, 2024
1 parent 4f4e613 commit ed6e8ec
Showing 1 changed file with 44 additions and 17 deletions.
61 changes: 44 additions & 17 deletions server/src/main/java/org/elasticsearch/rest/RestController.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import org.elasticsearch.common.path.PathTrie;
import org.elasticsearch.common.recycler.Recycler;
import org.elasticsearch.common.util.Maps;
import org.elasticsearch.common.util.concurrent.RunOnce;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Releasable;
Expand All @@ -41,6 +40,7 @@
import org.elasticsearch.rest.RestHandler.Route;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.telemetry.tracing.Tracer;
import org.elasticsearch.transport.Transports;
import org.elasticsearch.usage.SearchUsageHolder;
import org.elasticsearch.usage.UsageService;
import org.elasticsearch.xcontent.XContentBuilder;
Expand All @@ -60,6 +60,7 @@
import java.util.SortedMap;
import java.util.TreeMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;

import static org.elasticsearch.indices.SystemIndices.EXTERNAL_SYSTEM_INDEX_ACCESS_CONTROL_HEADER_KEY;
Expand Down Expand Up @@ -826,9 +827,13 @@ public void sendResponse(RestResponse response) {
if (response.isChunked() == false) {
methodHandlers.addResponseStats(response.content().length());
} else {
final var wrapped = new EncodedLengthTrackingChunkedRestResponseBody(response.chunkedContent(), methodHandlers);
final var responseLengthRecorder = new ResponseLengthRecorder(methodHandlers);
final var headers = response.getHeaders();
response = RestResponse.chunked(response.status(), wrapped, Releasables.wrap(wrapped, response));
response = RestResponse.chunked(
response.status(),
new EncodedLengthTrackingChunkedRestResponseBody(response.chunkedContent(), responseLengthRecorder),
Releasables.wrap(responseLengthRecorder, response)
);
for (final var header : headers.entrySet()) {
for (final var value : header.getValue()) {
response.addHeader(header.getKey(), value);
Expand Down Expand Up @@ -857,15 +862,44 @@ private void close() {
}
}

private static class EncodedLengthTrackingChunkedRestResponseBody implements ChunkedRestResponseBody, Releasable {
private static class ResponseLengthRecorder extends AtomicReference<MethodHandlers> implements Releasable {
private long responseLength;

private ResponseLengthRecorder(MethodHandlers methodHandlers) {
super(methodHandlers);
}

@Override
public void close() {
// closed just before sending the last chunk, and also when the whole RestResponse is closed since the client might abort the
// connection before we send the last chunk, in which case we won't have recorded the response in the
// stats yet; thus we need run-once semantics here:
final var methodHandlers = getAndSet(null);
if (methodHandlers != null) {
// if we started sending chunks then we're closed on the transport worker, no need for sync
assert responseLength == 0L || Transports.assertTransportThread();
methodHandlers.addResponseStats(responseLength);
}
}

void addChunkLength(long chunkLength) {
assert chunkLength >= 0L : chunkLength;
assert Transports.assertTransportThread(); // always called on the transport worker, no need for sync
responseLength += chunkLength;
}
}

private static class EncodedLengthTrackingChunkedRestResponseBody implements ChunkedRestResponseBody {

private final ChunkedRestResponseBody delegate;
private final RunOnce onCompletion;
private long encodedLength = 0;
private final ResponseLengthRecorder responseLengthRecorder;

private EncodedLengthTrackingChunkedRestResponseBody(ChunkedRestResponseBody delegate, MethodHandlers methodHandlers) {
private EncodedLengthTrackingChunkedRestResponseBody(
ChunkedRestResponseBody delegate,
ResponseLengthRecorder responseLengthRecorder
) {
this.delegate = delegate;
this.onCompletion = new RunOnce(() -> methodHandlers.addResponseStats(encodedLength));
this.responseLengthRecorder = responseLengthRecorder;
}

@Override
Expand All @@ -876,9 +910,9 @@ public boolean isDone() {
@Override
public ReleasableBytesReference encodeChunk(int sizeHint, Recycler<BytesRef> recycler) throws IOException {
final ReleasableBytesReference bytesReference = delegate.encodeChunk(sizeHint, recycler);
encodedLength += bytesReference.length();
responseLengthRecorder.addChunkLength(bytesReference.length());
if (isDone()) {
onCompletion.run();
responseLengthRecorder.close();
}
return bytesReference;
}
Expand All @@ -887,13 +921,6 @@ public ReleasableBytesReference encodeChunk(int sizeHint, Recycler<BytesRef> rec
public String getResponseContentTypeString() {
return delegate.getResponseContentTypeString();
}

@Override
public void close() {
// the client might close the connection before we send the last chunk, in which case we won't have recorded the response in the
// stats yet, so we do it now:
onCompletion.run();
}
}

private static CircuitBreaker inFlightRequestsBreaker(CircuitBreakerService circuitBreakerService) {
Expand Down

0 comments on commit ed6e8ec

Please sign in to comment.