diff --git a/server/src/main/java/org/elasticsearch/rest/RestController.java b/server/src/main/java/org/elasticsearch/rest/RestController.java index 729e04f19a62c..2ebee9c59482e 100644 --- a/server/src/main/java/org/elasticsearch/rest/RestController.java +++ b/server/src/main/java/org/elasticsearch/rest/RestController.java @@ -26,7 +26,6 @@ import org.elasticsearch.common.path.PathTrie; import org.elasticsearch.common.recycler.Recycler; import org.elasticsearch.common.util.Maps; -import org.elasticsearch.common.util.concurrent.RunOnce; import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Releasable; @@ -41,6 +40,7 @@ import org.elasticsearch.rest.RestHandler.Route; import org.elasticsearch.tasks.Task; import org.elasticsearch.telemetry.tracing.Tracer; +import org.elasticsearch.transport.Transports; import org.elasticsearch.usage.SearchUsageHolder; import org.elasticsearch.usage.UsageService; import org.elasticsearch.xcontent.XContentBuilder; @@ -60,6 +60,7 @@ import java.util.SortedMap; import java.util.TreeMap; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Supplier; import static org.elasticsearch.indices.SystemIndices.EXTERNAL_SYSTEM_INDEX_ACCESS_CONTROL_HEADER_KEY; @@ -826,9 +827,13 @@ public void sendResponse(RestResponse response) { if (response.isChunked() == false) { methodHandlers.addResponseStats(response.content().length()); } else { - final var wrapped = new EncodedLengthTrackingChunkedRestResponseBody(response.chunkedContent(), methodHandlers); + final var responseLengthRecorder = new ResponseLengthRecorder(methodHandlers); final var headers = response.getHeaders(); - response = RestResponse.chunked(response.status(), wrapped, Releasables.wrap(wrapped, response)); + response = RestResponse.chunked( + response.status(), + new EncodedLengthTrackingChunkedRestResponseBody(response.chunkedContent(), responseLengthRecorder), + Releasables.wrap(responseLengthRecorder, response) + ); for (final var header : headers.entrySet()) { for (final var value : header.getValue()) { response.addHeader(header.getKey(), value); @@ -857,15 +862,44 @@ private void close() { } } - private static class EncodedLengthTrackingChunkedRestResponseBody implements ChunkedRestResponseBody, Releasable { + private static class ResponseLengthRecorder extends AtomicReference implements Releasable { + private long responseLength; + + private ResponseLengthRecorder(MethodHandlers methodHandlers) { + super(methodHandlers); + } + + @Override + public void close() { + // closed just before sending the last chunk, and also when the whole RestResponse is closed since the client might abort the + // connection before we send the last chunk, in which case we won't have recorded the response in the + // stats yet; thus we need run-once semantics here: + final var methodHandlers = getAndSet(null); + if (methodHandlers != null) { + // if we started sending chunks then we're closed on the transport worker, no need for sync + assert responseLength == 0L || Transports.assertTransportThread(); + methodHandlers.addResponseStats(responseLength); + } + } + + void addChunkLength(long chunkLength) { + assert chunkLength >= 0L : chunkLength; + assert Transports.assertTransportThread(); // always called on the transport worker, no need for sync + responseLength += chunkLength; + } + } + + private static class EncodedLengthTrackingChunkedRestResponseBody implements ChunkedRestResponseBody { private final ChunkedRestResponseBody delegate; - private final RunOnce onCompletion; - private long encodedLength = 0; + private final ResponseLengthRecorder responseLengthRecorder; - private EncodedLengthTrackingChunkedRestResponseBody(ChunkedRestResponseBody delegate, MethodHandlers methodHandlers) { + private EncodedLengthTrackingChunkedRestResponseBody( + ChunkedRestResponseBody delegate, + ResponseLengthRecorder responseLengthRecorder + ) { this.delegate = delegate; - this.onCompletion = new RunOnce(() -> methodHandlers.addResponseStats(encodedLength)); + this.responseLengthRecorder = responseLengthRecorder; } @Override @@ -876,9 +910,9 @@ public boolean isDone() { @Override public ReleasableBytesReference encodeChunk(int sizeHint, Recycler recycler) throws IOException { final ReleasableBytesReference bytesReference = delegate.encodeChunk(sizeHint, recycler); - encodedLength += bytesReference.length(); + responseLengthRecorder.addChunkLength(bytesReference.length()); if (isDone()) { - onCompletion.run(); + responseLengthRecorder.close(); } return bytesReference; } @@ -887,13 +921,6 @@ public ReleasableBytesReference encodeChunk(int sizeHint, Recycler rec public String getResponseContentTypeString() { return delegate.getResponseContentTypeString(); } - - @Override - public void close() { - // the client might close the connection before we send the last chunk, in which case we won't have recorded the response in the - // stats yet, so we do it now: - onCompletion.run(); - } } private static CircuitBreaker inFlightRequestsBreaker(CircuitBreakerService circuitBreakerService) {