diff --git a/spring-web/src/main/java/org/springframework/web/util/ContentCachingResponseWrapper.java b/spring-web/src/main/java/org/springframework/web/util/ContentCachingResponseWrapper.java index 2462876a2c6f..c2038fb0ecc2 100644 --- a/spring-web/src/main/java/org/springframework/web/util/ContentCachingResponseWrapper.java +++ b/spring-web/src/main/java/org/springframework/web/util/ContentCachingResponseWrapper.java @@ -60,9 +60,6 @@ public class ContentCachingResponseWrapper extends HttpServletResponseWrapper { @Nullable private Integer contentLength; - @Nullable - private String contentType; - /** * Create a new ContentCachingResponseWrapper for the given servlet response. @@ -150,28 +147,11 @@ public void setContentLengthLong(long len) { setContentLength((int) len); } - @Override - public void setContentType(@Nullable String type) { - this.contentType = type; - } - - @Override - @Nullable - public String getContentType() { - if (this.contentType != null) { - return this.contentType; - } - return super.getContentType(); - } - @Override public boolean containsHeader(String name) { if (this.contentLength != null && HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(name)) { return true; } - else if (this.contentType != null && HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(name)) { - return true; - } else { return super.containsHeader(name); } @@ -182,9 +162,6 @@ public void setHeader(String name, String value) { if (HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(name)) { this.contentLength = Integer.valueOf(value); } - else if (HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(name)) { - this.contentType = value; - } else { super.setHeader(name, value); } @@ -195,9 +172,6 @@ public void addHeader(String name, String value) { if (HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(name)) { this.contentLength = Integer.valueOf(value); } - else if (HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(name)) { - this.contentType = value; - } else { super.addHeader(name, value); } @@ -229,9 +203,6 @@ public String getHeader(String name) { if (this.contentLength != null && HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(name)) { return this.contentLength.toString(); } - else if (this.contentType != null && HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(name)) { - return this.contentType; - } else { return super.getHeader(name); } @@ -242,9 +213,6 @@ public Collection getHeaders(String name) { if (this.contentLength != null && HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(name)) { return Collections.singleton(this.contentLength.toString()); } - else if (this.contentType != null && HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(name)) { - return Collections.singleton(this.contentType); - } else { return super.getHeaders(name); } @@ -253,14 +221,9 @@ else if (this.contentType != null && HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(n @Override public Collection getHeaderNames() { Collection headerNames = super.getHeaderNames(); - if (this.contentLength != null || this.contentType != null) { + if (this.contentLength != null) { Set result = new LinkedHashSet<>(headerNames); - if (this.contentLength != null) { - result.add(HttpHeaders.CONTENT_LENGTH); - } - if (this.contentType != null) { - result.add(HttpHeaders.CONTENT_TYPE); - } + result.add(HttpHeaders.CONTENT_LENGTH); return result; } else { @@ -333,10 +296,6 @@ protected void copyBodyToResponse(boolean complete) throws IOException { } this.contentLength = null; } - if (this.contentType != null) { - rawResponse.setContentType(this.contentType); - this.contentType = null; - } } this.content.writeTo(rawResponse.getOutputStream()); this.content.reset(); diff --git a/spring-web/src/test/java/org/springframework/web/filter/ContentCachingResponseWrapperTests.java b/spring-web/src/test/java/org/springframework/web/filter/ContentCachingResponseWrapperTests.java index ec586ae3da79..091adb3e2c83 100644 --- a/spring-web/src/test/java/org/springframework/web/filter/ContentCachingResponseWrapperTests.java +++ b/spring-web/src/test/java/org/springframework/web/filter/ContentCachingResponseWrapperTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,13 +16,12 @@ package org.springframework.web.filter; -import java.util.function.BiConsumer; import java.util.stream.Stream; import jakarta.servlet.http.HttpServletResponse; +import org.junit.jupiter.api.Named; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; import org.springframework.http.MediaType; @@ -33,17 +32,17 @@ import static java.nio.charset.StandardCharsets.UTF_8; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Named.named; -import static org.junit.jupiter.params.provider.Arguments.arguments; import static org.springframework.http.HttpHeaders.CONTENT_LENGTH; import static org.springframework.http.HttpHeaders.CONTENT_TYPE; import static org.springframework.http.HttpHeaders.TRANSFER_ENCODING; /** * Unit tests for {@link ContentCachingResponseWrapper}. + * * @author Rossen Stoyanchev * @author Sam Brannen */ -public class ContentCachingResponseWrapperTests { +class ContentCachingResponseWrapperTests { @Test void copyBodyToResponse() throws Exception { @@ -119,31 +118,76 @@ void copyBodyToResponseWithPresetHeaders() throws Exception { } @ParameterizedTest(name = "[{index}] {0}") - @MethodSource("setContentTypeFunctions") - void copyBodyToResponseWithOverridingHeaders(BiConsumer setContentType) throws Exception { + @MethodSource("setContentLengthFunctions") + void copyBodyToResponseWithOverridingContentLength(SetContentLength setContentLength) throws Exception { byte[] responseBody = "Hello World".getBytes(UTF_8); int responseLength = responseBody.length; int originalContentLength = 11; int overridingContentLength = 22; - String originalContentType = MediaType.TEXT_PLAIN_VALUE; - String overridingContentType = MediaType.APPLICATION_JSON_VALUE; MockHttpServletResponse response = new MockHttpServletResponse(); response.setContentLength(originalContentLength); - response.setContentType(originalContentType); ContentCachingResponseWrapper responseWrapper = new ContentCachingResponseWrapper(response); - responseWrapper.setStatus(HttpServletResponse.SC_CREATED); responseWrapper.setContentLength(overridingContentLength); - setContentType.accept(responseWrapper, overridingContentType); - assertThat(responseWrapper.getStatus()).isEqualTo(HttpServletResponse.SC_CREATED); + setContentLength.invoke(responseWrapper, overridingContentLength); + assertThat(responseWrapper.getContentSize()).isZero(); - assertThat(responseWrapper.getHeaderNames()).containsExactlyInAnyOrder(CONTENT_TYPE, CONTENT_LENGTH); + assertThat(responseWrapper.getHeaderNames()).containsExactlyInAnyOrder(CONTENT_LENGTH); assertHeader(response, CONTENT_LENGTH, originalContentLength); assertHeader(responseWrapper, CONTENT_LENGTH, overridingContentLength); + + FileCopyUtils.copy(responseBody, responseWrapper.getOutputStream()); + assertThat(responseWrapper.getContentSize()).isEqualTo(responseLength); + + responseWrapper.copyBodyToResponse(); + + assertThat(responseWrapper.getContentSize()).isZero(); + assertThat(responseWrapper.getHeaderNames()).containsExactlyInAnyOrder(CONTENT_LENGTH); + + assertHeader(response, CONTENT_LENGTH, responseLength); + assertHeader(responseWrapper, CONTENT_LENGTH, responseLength); + + assertThat(response.getContentLength()).isEqualTo(responseLength); + assertThat(response.getContentAsByteArray()).isEqualTo(responseBody); + assertThat(response.getHeaderNames()).containsExactlyInAnyOrder(CONTENT_LENGTH); + } + + private static Stream> setContentLengthFunctions() { + return Stream.of( + named("setContentLength()", HttpServletResponse::setContentLength), + named("setContentLengthLong()", HttpServletResponse::setContentLengthLong), + named("setIntHeader()", (response, contentLength) -> response.setIntHeader(CONTENT_LENGTH, contentLength)), + named("addIntHeader()", (response, contentLength) -> response.addIntHeader(CONTENT_LENGTH, contentLength)), + named("setHeader()", (response, contentLength) -> response.setHeader(CONTENT_LENGTH, "" + contentLength)), + named("addHeader()", (response, contentLength) -> response.addHeader(CONTENT_LENGTH, "" + contentLength)) + ); + } + + @ParameterizedTest(name = "[{index}] {0}") + @MethodSource("setContentTypeFunctions") + void copyBodyToResponseWithOverridingContentType(SetContentType setContentType) throws Exception { + byte[] responseBody = "Hello World".getBytes(UTF_8); + int responseLength = responseBody.length; + String originalContentType = MediaType.TEXT_PLAIN_VALUE; + String overridingContentType = MediaType.APPLICATION_JSON_VALUE; + + MockHttpServletResponse response = new MockHttpServletResponse(); + response.setContentType(originalContentType); + + ContentCachingResponseWrapper responseWrapper = new ContentCachingResponseWrapper(response); + assertContentTypeHeader(response, originalContentType); + assertContentTypeHeader(responseWrapper, originalContentType); + + setContentType.invoke(responseWrapper, overridingContentType); + + assertThat(responseWrapper.getContentSize()).isZero(); + assertThat(responseWrapper.getHeaderNames()).containsExactlyInAnyOrder(CONTENT_TYPE); + + assertContentTypeHeader(response, overridingContentType); assertContentTypeHeader(responseWrapper, overridingContentType); FileCopyUtils.copy(responseBody, responseWrapper.getOutputStream()); @@ -151,7 +195,6 @@ void copyBodyToResponseWithOverridingHeaders(BiConsumer setContentTypeFunctions() { + private static Stream> setContentTypeFunctions() { return Stream.of( - namedArguments("setContentType()", HttpServletResponse::setContentType), - namedArguments("setHeader()", (response, contentType) -> response.setHeader(CONTENT_TYPE, contentType)), - namedArguments("addHeader()", (response, contentType) -> response.addHeader(CONTENT_TYPE, contentType)) + named("setContentType()", HttpServletResponse::setContentType), + named("setHeader()", (response, contentType) -> response.setHeader(CONTENT_TYPE, contentType)), + named("addHeader()", (response, contentType) -> response.addHeader(CONTENT_TYPE, contentType)) ); } - private static Arguments namedArguments(String name, BiConsumer setContentTypeFunction) { - return arguments(named(name, setContentTypeFunction)); - } - @Test void copyBodyToResponseWithTransferEncoding() throws Exception { byte[] responseBody = "6\r\nHello 5\r\nWorld0\r\n\r\n".getBytes(UTF_8); @@ -217,4 +255,15 @@ private void assertContentTypeHeader(HttpServletResponse response, String conten assertThat(response.getContentType()).as(CONTENT_TYPE).isEqualTo(contentType); } + + @FunctionalInterface + private interface SetContentLength { + void invoke(HttpServletResponse response, int contentLength); + } + + @FunctionalInterface + private interface SetContentType { + void invoke(HttpServletResponse response, String contentType); + } + } diff --git a/spring-web/src/test/java/org/springframework/web/filter/ShallowEtagHeaderFilterTests.java b/spring-web/src/test/java/org/springframework/web/filter/ShallowEtagHeaderFilterTests.java index a36cd55a62d1..aa153146869a 100644 --- a/spring-web/src/test/java/org/springframework/web/filter/ShallowEtagHeaderFilterTests.java +++ b/spring-web/src/test/java/org/springframework/web/filter/ShallowEtagHeaderFilterTests.java @@ -28,6 +28,7 @@ import static java.nio.charset.StandardCharsets.UTF_8; import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.http.MediaType.APPLICATION_JSON_VALUE; import static org.springframework.http.MediaType.TEXT_PLAIN_VALUE; /** @@ -36,6 +37,7 @@ * @author Arjen Poutsma * @author Brian Clozel * @author Juergen Hoeller + * @author Sam Brannen */ class ShallowEtagHeaderFilterTests { @@ -123,7 +125,7 @@ void filterMatch() throws Exception { assertThat(response.getStatus()).as("Invalid status").isEqualTo(304); assertThat(response.getHeader("ETag")).as("Invalid ETag").isEqualTo("\"0b10a8db164e0754105b7a99be72e3fe5\""); assertThat(response.containsHeader("Content-Length")).as("Response has Content-Length header").isFalse(); - assertThat(response.containsHeader("Content-Type")).as("Response has Content-Type header").isFalse(); + assertThat(response.getContentType()).as("Invalid Content-Type header").isEqualTo(TEXT_PLAIN_VALUE); assertThat(response.getContentAsByteArray()).as("Invalid content").isEmpty(); } @@ -173,11 +175,13 @@ void filterWriter() throws Exception { void filterWriterWithDisabledCaching() throws Exception { MockHttpServletRequest request = new MockHttpServletRequest("GET", "/hotels"); MockHttpServletResponse response = new MockHttpServletResponse(); + response.setContentType(TEXT_PLAIN_VALUE); byte[] responseBody = "Hello World".getBytes(UTF_8); FilterChain filterChain = (filterRequest, filterResponse) -> { assertThat(filterRequest).as("Invalid request passed").isEqualTo(request); ((HttpServletResponse) filterResponse).setStatus(HttpServletResponse.SC_OK); + filterResponse.setContentType(APPLICATION_JSON_VALUE); FileCopyUtils.copy(responseBody, filterResponse.getOutputStream()); }; @@ -186,6 +190,7 @@ void filterWriterWithDisabledCaching() throws Exception { assertThat(response.getStatus()).isEqualTo(200); assertThat(response.getHeader("ETag")).isNull(); + assertThat(response.getContentType()).as("Invalid Content-Type header").isEqualTo(APPLICATION_JSON_VALUE); assertThat(response.getContentAsByteArray()).isEqualTo(responseBody); }