diff --git a/spring-web/src/main/java/org/springframework/web/util/ContentCachingResponseWrapper.java b/spring-web/src/main/java/org/springframework/web/util/ContentCachingResponseWrapper.java index 1487809d942d..2462876a2c6f 100644 --- a/spring-web/src/main/java/org/springframework/web/util/ContentCachingResponseWrapper.java +++ b/spring-web/src/main/java/org/springframework/web/util/ContentCachingResponseWrapper.java @@ -21,10 +21,10 @@ import java.io.OutputStreamWriter; import java.io.PrintWriter; import java.io.UnsupportedEncodingException; -import java.util.ArrayList; import java.util.Collection; import java.util.Collections; -import java.util.List; +import java.util.LinkedHashSet; +import java.util.Set; import jakarta.servlet.ServletOutputStream; import jakarta.servlet.WriteListener; @@ -254,7 +254,7 @@ else if (this.contentType != null && HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(n public Collection getHeaderNames() { Collection headerNames = super.getHeaderNames(); if (this.contentLength != null || this.contentType != null) { - List result = new ArrayList<>(headerNames); + Set result = new LinkedHashSet<>(headerNames); if (this.contentLength != null) { result.add(HttpHeaders.CONTENT_LENGTH); } diff --git a/spring-web/src/test/java/org/springframework/web/filter/ContentCachingResponseWrapperTests.java b/spring-web/src/test/java/org/springframework/web/filter/ContentCachingResponseWrapperTests.java index e63f206f9568..4be7fb318b8d 100644 --- a/spring-web/src/test/java/org/springframework/web/filter/ContentCachingResponseWrapperTests.java +++ b/spring-web/src/test/java/org/springframework/web/filter/ContentCachingResponseWrapperTests.java @@ -16,8 +16,14 @@ package org.springframework.web.filter; +import java.util.function.BiConsumer; +import java.util.stream.Stream; + import jakarta.servlet.http.HttpServletResponse; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import org.springframework.http.MediaType; import org.springframework.util.FileCopyUtils; @@ -26,6 +32,8 @@ import static java.nio.charset.StandardCharsets.UTF_8; import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Named.named; +import static org.junit.jupiter.params.provider.Arguments.arguments; import static org.springframework.http.HttpHeaders.CONTENT_LENGTH; import static org.springframework.http.HttpHeaders.CONTENT_TYPE; import static org.springframework.http.HttpHeaders.TRANSFER_ENCODING; @@ -44,11 +52,11 @@ void copyBodyToResponse() throws Exception { MockHttpServletResponse response = new MockHttpServletResponse(); ContentCachingResponseWrapper responseWrapper = new ContentCachingResponseWrapper(response); - responseWrapper.setStatus(HttpServletResponse.SC_OK); + responseWrapper.setStatus(HttpServletResponse.SC_CREATED); FileCopyUtils.copy(responseBody, responseWrapper.getOutputStream()); responseWrapper.copyBodyToResponse(); - assertThat(response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); + assertThat(response.getStatus()).isEqualTo(HttpServletResponse.SC_CREATED); assertThat(response.getContentLength()).isGreaterThan(0); assertThat(response.getContentAsByteArray()).isEqualTo(responseBody); } @@ -61,69 +69,114 @@ void copyBodyToResponseWithPresetHeaders() throws Exception { String MAGIC = "42"; byte[] responseBody = "Hello World".getBytes(UTF_8); - String responseLength = Integer.toString(responseBody.length); + int responseLength = responseBody.length; + int originalContentLength = 999; String contentType = MediaType.APPLICATION_JSON_VALUE; MockHttpServletResponse response = new MockHttpServletResponse(); response.setContentType(contentType); - response.setContentLength(999); + response.setContentLength(originalContentLength); response.setHeader(PUZZLE, ENIGMA); response.setIntHeader(NUMBER, 42); ContentCachingResponseWrapper responseWrapper = new ContentCachingResponseWrapper(response); - responseWrapper.setStatus(HttpServletResponse.SC_OK); + responseWrapper.setStatus(HttpServletResponse.SC_CREATED); - assertThat(responseWrapper.getStatus()).isEqualTo(HttpServletResponse.SC_OK); + assertThat(responseWrapper.getStatus()).isEqualTo(HttpServletResponse.SC_CREATED); assertThat(responseWrapper.getContentSize()).isZero(); assertThat(responseWrapper.getHeaderNames()) .containsExactlyInAnyOrder(PUZZLE, NUMBER, CONTENT_TYPE, CONTENT_LENGTH); - assertThat(responseWrapper.containsHeader(PUZZLE)).as(PUZZLE).isTrue(); - assertThat(responseWrapper.getHeader(PUZZLE)).as(PUZZLE).isEqualTo(ENIGMA); - assertThat(responseWrapper.getHeaders(PUZZLE)).as(PUZZLE).containsExactly(ENIGMA); - - assertThat(responseWrapper.containsHeader(NUMBER)).as(NUMBER).isTrue(); - assertThat(responseWrapper.getHeader(NUMBER)).as(NUMBER).isEqualTo(MAGIC); - assertThat(responseWrapper.getHeaders(NUMBER)).as(NUMBER).containsExactly(MAGIC); - - assertThat(responseWrapper.containsHeader(CONTENT_TYPE)).as(CONTENT_TYPE).isTrue(); - assertThat(responseWrapper.getHeader(CONTENT_TYPE)).as(CONTENT_TYPE).isEqualTo(contentType); - assertThat(responseWrapper.getHeaders(CONTENT_TYPE)).as(CONTENT_TYPE).containsExactly(contentType); - assertThat(responseWrapper.getContentType()).as(CONTENT_TYPE).isEqualTo(contentType); - - assertThat(responseWrapper.containsHeader(CONTENT_LENGTH)).as(CONTENT_LENGTH).isTrue(); - assertThat(responseWrapper.getHeader(CONTENT_LENGTH)).as(CONTENT_LENGTH).isEqualTo("999"); - assertThat(responseWrapper.getHeaders(CONTENT_LENGTH)).as(CONTENT_LENGTH).containsExactly("999"); + assertHeader(responseWrapper, PUZZLE, ENIGMA); + assertHeader(responseWrapper, NUMBER, MAGIC); + assertHeader(responseWrapper, CONTENT_LENGTH, originalContentLength); + assertContentTypeHeader(responseWrapper, contentType); FileCopyUtils.copy(responseBody, responseWrapper.getOutputStream()); + assertThat(responseWrapper.getContentSize()).isEqualTo(responseLength); + responseWrapper.copyBodyToResponse(); + assertThat(responseWrapper.getStatus()).isEqualTo(HttpServletResponse.SC_CREATED); + assertThat(responseWrapper.getContentSize()).isZero(); assertThat(responseWrapper.getHeaderNames()) .containsExactlyInAnyOrder(PUZZLE, NUMBER, CONTENT_TYPE, CONTENT_LENGTH); - assertThat(response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); - assertThat(response.getContentType()).isEqualTo(contentType); - assertThat(response.getContentLength()).isEqualTo(responseBody.length); + assertHeader(responseWrapper, PUZZLE, ENIGMA); + assertHeader(responseWrapper, NUMBER, MAGIC); + assertHeader(responseWrapper, CONTENT_LENGTH, responseLength); + assertContentTypeHeader(responseWrapper, contentType); + + assertThat(response.getStatus()).isEqualTo(HttpServletResponse.SC_CREATED); + assertThat(response.getContentLength()).isEqualTo(responseLength); assertThat(response.getContentAsByteArray()).isEqualTo(responseBody); assertThat(response.getHeaderNames()) .containsExactlyInAnyOrder(PUZZLE, NUMBER, CONTENT_TYPE, CONTENT_LENGTH); - assertThat(response.containsHeader(PUZZLE)).as(PUZZLE).isTrue(); - assertThat(response.getHeader(PUZZLE)).as(PUZZLE).isEqualTo(ENIGMA); - assertThat(response.getHeaders(PUZZLE)).as(PUZZLE).containsExactly(ENIGMA); + assertHeader(response, PUZZLE, ENIGMA); + assertHeader(response, NUMBER, MAGIC); + assertHeader(response, CONTENT_LENGTH, responseLength); + assertContentTypeHeader(response, contentType); + } + + @ParameterizedTest(name = "[{index}] {0}") + @MethodSource("setContentTypeFunctions") + void copyBodyToResponseWithOverridingHeaders(BiConsumer setContentType) throws Exception { + byte[] responseBody = "Hello World".getBytes(UTF_8); + int responseLength = responseBody.length; + int originalContentLength = 11; + int overridingContentLength = 22; + String originalContentType = MediaType.TEXT_PLAIN_VALUE; + String overridingContentType = MediaType.APPLICATION_JSON_VALUE; - assertThat(response.containsHeader(NUMBER)).as(NUMBER).isTrue(); - assertThat(response.getHeader(NUMBER)).as(NUMBER).isEqualTo(MAGIC); - assertThat(response.getHeaders(NUMBER)).as(NUMBER).containsExactly(MAGIC); + MockHttpServletResponse response = new MockHttpServletResponse(); + response.setContentLength(originalContentLength); + response.setContentType(originalContentType); - assertThat(response.containsHeader(CONTENT_TYPE)).as(CONTENT_TYPE).isTrue(); - assertThat(response.getHeader(CONTENT_TYPE)).as(CONTENT_TYPE).isEqualTo(contentType); - assertThat(response.getHeaders(CONTENT_TYPE)).as(CONTENT_TYPE).containsExactly(contentType); - assertThat(response.getContentType()).as(CONTENT_TYPE).isEqualTo(contentType); + ContentCachingResponseWrapper responseWrapper = new ContentCachingResponseWrapper(response); + responseWrapper.setStatus(HttpServletResponse.SC_CREATED); + responseWrapper.setContentLength(overridingContentLength); + setContentType.accept(responseWrapper, overridingContentType); + + assertThat(responseWrapper.getStatus()).isEqualTo(HttpServletResponse.SC_CREATED); + assertThat(responseWrapper.getContentSize()).isZero(); + assertThat(responseWrapper.getHeaderNames()).containsExactlyInAnyOrder(CONTENT_TYPE, CONTENT_LENGTH); + + assertHeader(response, CONTENT_LENGTH, originalContentLength); + assertHeader(responseWrapper, CONTENT_LENGTH, overridingContentLength); + assertContentTypeHeader(response, originalContentType); + assertContentTypeHeader(responseWrapper, overridingContentType); + + FileCopyUtils.copy(responseBody, responseWrapper.getOutputStream()); + assertThat(responseWrapper.getContentSize()).isEqualTo(responseLength); - assertThat(response.containsHeader(CONTENT_LENGTH)).as(CONTENT_LENGTH).isTrue(); - assertThat(response.getHeader(CONTENT_LENGTH)).as(CONTENT_LENGTH).isEqualTo(responseLength); - assertThat(response.getHeaders(CONTENT_LENGTH)).as(CONTENT_LENGTH).containsExactly(responseLength); + responseWrapper.copyBodyToResponse(); + + assertThat(responseWrapper.getStatus()).isEqualTo(HttpServletResponse.SC_CREATED); + assertThat(responseWrapper.getContentSize()).isZero(); + assertThat(responseWrapper.getHeaderNames()).containsExactlyInAnyOrder(CONTENT_TYPE, CONTENT_LENGTH); + + assertHeader(response, CONTENT_LENGTH, responseLength); + assertHeader(responseWrapper, CONTENT_LENGTH, responseLength); + assertContentTypeHeader(response, overridingContentType); + assertContentTypeHeader(responseWrapper, overridingContentType); + + assertThat(response.getStatus()).isEqualTo(HttpServletResponse.SC_CREATED); + assertThat(response.getContentLength()).isEqualTo(responseLength); + assertThat(response.getContentAsByteArray()).isEqualTo(responseBody); + assertThat(response.getHeaderNames()).containsExactlyInAnyOrder(CONTENT_TYPE, CONTENT_LENGTH); + } + + private static Stream setContentTypeFunctions() { + return Stream.of( + namedArguments("setContentType()", HttpServletResponse::setContentType), + namedArguments("setHeader()", (response, contentType) -> response.setHeader(CONTENT_TYPE, contentType)), + namedArguments("addHeader()", (response, contentType) -> response.addHeader(CONTENT_TYPE, contentType)) + ); + } + + private static Arguments namedArguments(String name, BiConsumer setContentTypeFunction) { + return arguments(named(name, setContentTypeFunction)); } @Test @@ -132,15 +185,37 @@ void copyBodyToResponseWithTransferEncoding() throws Exception { MockHttpServletResponse response = new MockHttpServletResponse(); ContentCachingResponseWrapper responseWrapper = new ContentCachingResponseWrapper(response); - responseWrapper.setStatus(HttpServletResponse.SC_OK); + responseWrapper.setStatus(HttpServletResponse.SC_CREATED); responseWrapper.setHeader(TRANSFER_ENCODING, "chunked"); FileCopyUtils.copy(responseBody, responseWrapper.getOutputStream()); responseWrapper.copyBodyToResponse(); - assertThat(response.getStatus()).isEqualTo(HttpServletResponse.SC_OK); - assertThat(response.getHeader(TRANSFER_ENCODING)).isEqualTo("chunked"); - assertThat(response.getHeader(CONTENT_LENGTH)).isNull(); + assertThat(response.getStatus()).isEqualTo(HttpServletResponse.SC_CREATED); + assertHeader(response, TRANSFER_ENCODING, "chunked"); + assertHeader(response, CONTENT_LENGTH, null); assertThat(response.getContentAsByteArray()).isEqualTo(responseBody); } + private void assertHeader(HttpServletResponse response, String header, int value) { + assertHeader(response, header, Integer.toString(value)); + } + + private void assertHeader(HttpServletResponse response, String header, String value) { + if (value == null) { + assertThat(response.containsHeader(header)).as(header).isFalse(); + assertThat(response.getHeader(header)).as(header).isNull(); + assertThat(response.getHeaders(header)).as(header).isEmpty(); + } + else { + assertThat(response.containsHeader(header)).as(header).isTrue(); + assertThat(response.getHeader(header)).as(header).isEqualTo(value); + assertThat(response.getHeaders(header)).as(header).containsExactly(value); + } + } + + private void assertContentTypeHeader(HttpServletResponse response, String contentType) { + assertHeader(response, CONTENT_TYPE, contentType); + assertThat(response.getContentType()).as(CONTENT_TYPE).isEqualTo(contentType); + } + }