diff --git a/spring-web/src/main/java/org/springframework/web/util/ContentCachingResponseWrapper.java b/spring-web/src/main/java/org/springframework/web/util/ContentCachingResponseWrapper.java
index 017c0f635628..2462876a2c6f 100644
--- a/spring-web/src/main/java/org/springframework/web/util/ContentCachingResponseWrapper.java
+++ b/spring-web/src/main/java/org/springframework/web/util/ContentCachingResponseWrapper.java
@@ -21,10 +21,10 @@
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.io.UnsupportedEncodingException;
-import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
-import java.util.List;
+import java.util.LinkedHashSet;
+import java.util.Set;
import jakarta.servlet.ServletOutputStream;
import jakarta.servlet.WriteListener;
@@ -43,6 +43,7 @@
*
Used e.g. by {@link org.springframework.web.filter.ShallowEtagHeaderFilter}.
*
* @author Juergen Hoeller
+ * @author Sam Brannen
* @since 4.1.3
* @see ContentCachingRequestWrapper
*/
@@ -157,16 +158,19 @@ public void setContentType(@Nullable String type) {
@Override
@Nullable
public String getContentType() {
- return this.contentType;
+ if (this.contentType != null) {
+ return this.contentType;
+ }
+ return super.getContentType();
}
@Override
public boolean containsHeader(String name) {
- if (HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(name)) {
- return this.contentLength != null;
+ if (this.contentLength != null && HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(name)) {
+ return true;
}
- else if (HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(name)) {
- return this.contentType != null;
+ else if (this.contentType != null && HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(name)) {
+ return true;
}
else {
return super.containsHeader(name);
@@ -222,10 +226,10 @@ public void addIntHeader(String name, int value) {
@Override
@Nullable
public String getHeader(String name) {
- if (HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(name)) {
- return (this.contentLength != null) ? this.contentLength.toString() : null;
+ if (this.contentLength != null && HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(name)) {
+ return this.contentLength.toString();
}
- else if (HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(name)) {
+ else if (this.contentType != null && HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(name)) {
return this.contentType;
}
else {
@@ -235,12 +239,11 @@ else if (HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(name)) {
@Override
public Collection getHeaders(String name) {
- if (HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(name)) {
- return this.contentLength != null ? Collections.singleton(this.contentLength.toString()) :
- Collections.emptySet();
+ if (this.contentLength != null && HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(name)) {
+ return Collections.singleton(this.contentLength.toString());
}
- else if (HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(name)) {
- return this.contentType != null ? Collections.singleton(this.contentType) : Collections.emptySet();
+ else if (this.contentType != null && HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(name)) {
+ return Collections.singleton(this.contentType);
}
else {
return super.getHeaders(name);
@@ -251,7 +254,7 @@ else if (HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(name)) {
public Collection getHeaderNames() {
Collection headerNames = super.getHeaderNames();
if (this.contentLength != null || this.contentType != null) {
- List result = new ArrayList<>(headerNames);
+ Set result = new LinkedHashSet<>(headerNames);
if (this.contentLength != null) {
result.add(HttpHeaders.CONTENT_LENGTH);
}
@@ -330,7 +333,7 @@ protected void copyBodyToResponse(boolean complete) throws IOException {
}
this.contentLength = null;
}
- if (complete || this.contentType != null) {
+ if (this.contentType != null) {
rawResponse.setContentType(this.contentType);
this.contentType = null;
}
diff --git a/spring-web/src/test/java/org/springframework/web/filter/ContentCachingResponseWrapperTests.java b/spring-web/src/test/java/org/springframework/web/filter/ContentCachingResponseWrapperTests.java
index 1c11b5a969b1..ec586ae3da79 100644
--- a/spring-web/src/test/java/org/springframework/web/filter/ContentCachingResponseWrapperTests.java
+++ b/spring-web/src/test/java/org/springframework/web/filter/ContentCachingResponseWrapperTests.java
@@ -16,21 +16,32 @@
package org.springframework.web.filter;
+import java.util.function.BiConsumer;
+import java.util.stream.Stream;
+
import jakarta.servlet.http.HttpServletResponse;
import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
+import org.springframework.http.MediaType;
import org.springframework.util.FileCopyUtils;
import org.springframework.web.testfixture.servlet.MockHttpServletResponse;
import org.springframework.web.util.ContentCachingResponseWrapper;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.assertj.core.api.Assertions.assertThat;
+import static org.junit.jupiter.api.Named.named;
+import static org.junit.jupiter.params.provider.Arguments.arguments;
import static org.springframework.http.HttpHeaders.CONTENT_LENGTH;
+import static org.springframework.http.HttpHeaders.CONTENT_TYPE;
import static org.springframework.http.HttpHeaders.TRANSFER_ENCODING;
/**
* Unit tests for {@link ContentCachingResponseWrapper}.
* @author Rossen Stoyanchev
+ * @author Sam Brannen
*/
public class ContentCachingResponseWrapperTests {
@@ -49,6 +60,124 @@ void copyBodyToResponse() throws Exception {
assertThat(response.getContentAsByteArray()).isEqualTo(responseBody);
}
+ @Test
+ void copyBodyToResponseWithPresetHeaders() throws Exception {
+ String PUZZLE = "puzzle";
+ String ENIGMA = "enigma";
+ String NUMBER = "number";
+ String MAGIC = "42";
+
+ byte[] responseBody = "Hello World".getBytes(UTF_8);
+ int responseLength = responseBody.length;
+ int originalContentLength = 999;
+ String contentType = MediaType.APPLICATION_JSON_VALUE;
+
+ MockHttpServletResponse response = new MockHttpServletResponse();
+ response.setContentType(contentType);
+ response.setContentLength(originalContentLength);
+ response.setHeader(PUZZLE, ENIGMA);
+ response.setIntHeader(NUMBER, 42);
+
+ ContentCachingResponseWrapper responseWrapper = new ContentCachingResponseWrapper(response);
+ responseWrapper.setStatus(HttpServletResponse.SC_CREATED);
+
+ assertThat(responseWrapper.getStatus()).isEqualTo(HttpServletResponse.SC_CREATED);
+ assertThat(responseWrapper.getContentSize()).isZero();
+ assertThat(responseWrapper.getHeaderNames())
+ .containsExactlyInAnyOrder(PUZZLE, NUMBER, CONTENT_TYPE, CONTENT_LENGTH);
+
+ assertHeader(responseWrapper, PUZZLE, ENIGMA);
+ assertHeader(responseWrapper, NUMBER, MAGIC);
+ assertHeader(responseWrapper, CONTENT_LENGTH, originalContentLength);
+ assertContentTypeHeader(responseWrapper, contentType);
+
+ FileCopyUtils.copy(responseBody, responseWrapper.getOutputStream());
+ assertThat(responseWrapper.getContentSize()).isEqualTo(responseLength);
+
+ responseWrapper.copyBodyToResponse();
+
+ assertThat(responseWrapper.getStatus()).isEqualTo(HttpServletResponse.SC_CREATED);
+ assertThat(responseWrapper.getContentSize()).isZero();
+ assertThat(responseWrapper.getHeaderNames())
+ .containsExactlyInAnyOrder(PUZZLE, NUMBER, CONTENT_TYPE, CONTENT_LENGTH);
+
+ assertHeader(responseWrapper, PUZZLE, ENIGMA);
+ assertHeader(responseWrapper, NUMBER, MAGIC);
+ assertHeader(responseWrapper, CONTENT_LENGTH, responseLength);
+ assertContentTypeHeader(responseWrapper, contentType);
+
+ assertThat(response.getStatus()).isEqualTo(HttpServletResponse.SC_CREATED);
+ assertThat(response.getContentLength()).isEqualTo(responseLength);
+ assertThat(response.getContentAsByteArray()).isEqualTo(responseBody);
+ assertThat(response.getHeaderNames())
+ .containsExactlyInAnyOrder(PUZZLE, NUMBER, CONTENT_TYPE, CONTENT_LENGTH);
+
+ assertHeader(response, PUZZLE, ENIGMA);
+ assertHeader(response, NUMBER, MAGIC);
+ assertHeader(response, CONTENT_LENGTH, responseLength);
+ assertContentTypeHeader(response, contentType);
+ }
+
+ @ParameterizedTest(name = "[{index}] {0}")
+ @MethodSource("setContentTypeFunctions")
+ void copyBodyToResponseWithOverridingHeaders(BiConsumer setContentType) throws Exception {
+ byte[] responseBody = "Hello World".getBytes(UTF_8);
+ int responseLength = responseBody.length;
+ int originalContentLength = 11;
+ int overridingContentLength = 22;
+ String originalContentType = MediaType.TEXT_PLAIN_VALUE;
+ String overridingContentType = MediaType.APPLICATION_JSON_VALUE;
+
+ MockHttpServletResponse response = new MockHttpServletResponse();
+ response.setContentLength(originalContentLength);
+ response.setContentType(originalContentType);
+
+ ContentCachingResponseWrapper responseWrapper = new ContentCachingResponseWrapper(response);
+ responseWrapper.setStatus(HttpServletResponse.SC_CREATED);
+ responseWrapper.setContentLength(overridingContentLength);
+ setContentType.accept(responseWrapper, overridingContentType);
+
+ assertThat(responseWrapper.getStatus()).isEqualTo(HttpServletResponse.SC_CREATED);
+ assertThat(responseWrapper.getContentSize()).isZero();
+ assertThat(responseWrapper.getHeaderNames()).containsExactlyInAnyOrder(CONTENT_TYPE, CONTENT_LENGTH);
+
+ assertHeader(response, CONTENT_LENGTH, originalContentLength);
+ assertHeader(responseWrapper, CONTENT_LENGTH, overridingContentLength);
+ assertContentTypeHeader(response, originalContentType);
+ assertContentTypeHeader(responseWrapper, overridingContentType);
+
+ FileCopyUtils.copy(responseBody, responseWrapper.getOutputStream());
+ assertThat(responseWrapper.getContentSize()).isEqualTo(responseLength);
+
+ responseWrapper.copyBodyToResponse();
+
+ assertThat(responseWrapper.getStatus()).isEqualTo(HttpServletResponse.SC_CREATED);
+ assertThat(responseWrapper.getContentSize()).isZero();
+ assertThat(responseWrapper.getHeaderNames()).containsExactlyInAnyOrder(CONTENT_TYPE, CONTENT_LENGTH);
+
+ assertHeader(response, CONTENT_LENGTH, responseLength);
+ assertHeader(responseWrapper, CONTENT_LENGTH, responseLength);
+ assertContentTypeHeader(response, overridingContentType);
+ assertContentTypeHeader(responseWrapper, overridingContentType);
+
+ assertThat(response.getStatus()).isEqualTo(HttpServletResponse.SC_CREATED);
+ assertThat(response.getContentLength()).isEqualTo(responseLength);
+ assertThat(response.getContentAsByteArray()).isEqualTo(responseBody);
+ assertThat(response.getHeaderNames()).containsExactlyInAnyOrder(CONTENT_TYPE, CONTENT_LENGTH);
+ }
+
+ private static Stream setContentTypeFunctions() {
+ return Stream.of(
+ namedArguments("setContentType()", HttpServletResponse::setContentType),
+ namedArguments("setHeader()", (response, contentType) -> response.setHeader(CONTENT_TYPE, contentType)),
+ namedArguments("addHeader()", (response, contentType) -> response.addHeader(CONTENT_TYPE, contentType))
+ );
+ }
+
+ private static Arguments namedArguments(String name, BiConsumer setContentTypeFunction) {
+ return arguments(named(name, setContentTypeFunction));
+ }
+
@Test
void copyBodyToResponseWithTransferEncoding() throws Exception {
byte[] responseBody = "6\r\nHello 5\r\nWorld0\r\n\r\n".getBytes(UTF_8);
@@ -66,6 +195,10 @@ void copyBodyToResponseWithTransferEncoding() throws Exception {
assertThat(response.getContentAsByteArray()).isEqualTo(responseBody);
}
+ private void assertHeader(HttpServletResponse response, String header, int value) {
+ assertHeader(response, header, Integer.toString(value));
+ }
+
private void assertHeader(HttpServletResponse response, String header, String value) {
if (value == null) {
assertThat(response.containsHeader(header)).as(header).isFalse();
@@ -79,4 +212,9 @@ private void assertHeader(HttpServletResponse response, String header, String va
}
}
+ private void assertContentTypeHeader(HttpServletResponse response, String contentType) {
+ assertHeader(response, CONTENT_TYPE, contentType);
+ assertThat(response.getContentType()).as(CONTENT_TYPE).isEqualTo(contentType);
+ }
+
}