Skip to content

Commit

Permalink
Do not cache Content-Type in ContentCachingResponseWrapper
Browse files Browse the repository at this point in the history
Based on feedback from several members of the community, we have
decided to revert the caching of the Content-Type header that was
introduced in ContentCachingResponseWrapper in 375e0e6.

This commit therefore completely removes Content-Type caching in
ContentCachingResponseWrapper and updates the existing tests
accordingly.

To provide guards against future regressions in this area, this commit
also introduces explicit tests for the 6 ways to set the content length
in ContentCachingResponseWrapper and modifies a test in
ShallowEtagHeaderFilterTests to ensure that a Content-Type header set
directly on ContentCachingResponseWrapper is propagated to the
underlying response even if content caching is disabled for the
ShallowEtagHeaderFilter.

See spring-projectsgh-32039
See spring-projectsgh-32317
Closes spring-projectsgh-32321
sbrannen committed Feb 28, 2024
1 parent 629c560 commit d1b3107
Showing 3 changed files with 81 additions and 68 deletions.
Original file line number Diff line number Diff line change
@@ -60,9 +60,6 @@ public class ContentCachingResponseWrapper extends HttpServletResponseWrapper {
@Nullable
private Integer contentLength;

@Nullable
private String contentType;


/**
* Create a new ContentCachingResponseWrapper for the given servlet response.
@@ -150,28 +147,11 @@ public void setContentLengthLong(long len) {
setContentLength((int) len);
}

@Override
public void setContentType(@Nullable String type) {
this.contentType = type;
}

@Override
@Nullable
public String getContentType() {
if (this.contentType != null) {
return this.contentType;
}
return super.getContentType();
}

@Override
public boolean containsHeader(String name) {
if (this.contentLength != null && HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(name)) {
return true;
}
else if (this.contentType != null && HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(name)) {
return true;
}
else {
return super.containsHeader(name);
}
@@ -182,9 +162,6 @@ public void setHeader(String name, String value) {
if (HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(name)) {
this.contentLength = Integer.valueOf(value);
}
else if (HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(name)) {
this.contentType = value;
}
else {
super.setHeader(name, value);
}
@@ -195,9 +172,6 @@ public void addHeader(String name, String value) {
if (HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(name)) {
this.contentLength = Integer.valueOf(value);
}
else if (HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(name)) {
this.contentType = value;
}
else {
super.addHeader(name, value);
}
@@ -229,9 +203,6 @@ public String getHeader(String name) {
if (this.contentLength != null && HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(name)) {
return this.contentLength.toString();
}
else if (this.contentType != null && HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(name)) {
return this.contentType;
}
else {
return super.getHeader(name);
}
@@ -242,9 +213,6 @@ public Collection<String> getHeaders(String name) {
if (this.contentLength != null && HttpHeaders.CONTENT_LENGTH.equalsIgnoreCase(name)) {
return Collections.singleton(this.contentLength.toString());
}
else if (this.contentType != null && HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(name)) {
return Collections.singleton(this.contentType);
}
else {
return super.getHeaders(name);
}
@@ -253,14 +221,9 @@ else if (this.contentType != null && HttpHeaders.CONTENT_TYPE.equalsIgnoreCase(n
@Override
public Collection<String> getHeaderNames() {
Collection<String> headerNames = super.getHeaderNames();
if (this.contentLength != null || this.contentType != null) {
if (this.contentLength != null) {
Set<String> result = new LinkedHashSet<>(headerNames);
if (this.contentLength != null) {
result.add(HttpHeaders.CONTENT_LENGTH);
}
if (this.contentType != null) {
result.add(HttpHeaders.CONTENT_TYPE);
}
result.add(HttpHeaders.CONTENT_LENGTH);
return result;
}
else {
@@ -333,10 +296,6 @@ protected void copyBodyToResponse(boolean complete) throws IOException {
}
this.contentLength = null;
}
if (this.contentType != null) {
rawResponse.setContentType(this.contentType);
this.contentType = null;
}
}
this.content.writeTo(rawResponse.getOutputStream());
this.content.reset();
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2023 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -16,13 +16,12 @@

package org.springframework.web.filter;

import java.util.function.BiConsumer;
import java.util.stream.Stream;

import jakarta.servlet.http.HttpServletResponse;
import org.junit.jupiter.api.Named;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;

import org.springframework.http.MediaType;
@@ -33,17 +32,17 @@
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Named.named;
import static org.junit.jupiter.params.provider.Arguments.arguments;
import static org.springframework.http.HttpHeaders.CONTENT_LENGTH;
import static org.springframework.http.HttpHeaders.CONTENT_TYPE;
import static org.springframework.http.HttpHeaders.TRANSFER_ENCODING;

/**
* Unit tests for {@link ContentCachingResponseWrapper}.
*
* @author Rossen Stoyanchev
* @author Sam Brannen
*/
public class ContentCachingResponseWrapperTests {
class ContentCachingResponseWrapperTests {

@Test
void copyBodyToResponse() throws Exception {
@@ -119,39 +118,83 @@ void copyBodyToResponseWithPresetHeaders() throws Exception {
}

@ParameterizedTest(name = "[{index}] {0}")
@MethodSource("setContentTypeFunctions")
void copyBodyToResponseWithOverridingHeaders(BiConsumer<HttpServletResponse, String> setContentType) throws Exception {
@MethodSource("setContentLengthFunctions")
void copyBodyToResponseWithOverridingContentLength(SetContentLength setContentLength) throws Exception {
byte[] responseBody = "Hello World".getBytes(UTF_8);
int responseLength = responseBody.length;
int originalContentLength = 11;
int overridingContentLength = 22;
String originalContentType = MediaType.TEXT_PLAIN_VALUE;
String overridingContentType = MediaType.APPLICATION_JSON_VALUE;

MockHttpServletResponse response = new MockHttpServletResponse();
response.setContentLength(originalContentLength);
response.setContentType(originalContentType);

ContentCachingResponseWrapper responseWrapper = new ContentCachingResponseWrapper(response);
responseWrapper.setStatus(HttpServletResponse.SC_CREATED);
responseWrapper.setContentLength(overridingContentLength);
setContentType.accept(responseWrapper, overridingContentType);

assertThat(responseWrapper.getStatus()).isEqualTo(HttpServletResponse.SC_CREATED);
setContentLength.invoke(responseWrapper, overridingContentLength);

assertThat(responseWrapper.getContentSize()).isZero();
assertThat(responseWrapper.getHeaderNames()).containsExactlyInAnyOrder(CONTENT_TYPE, CONTENT_LENGTH);
assertThat(responseWrapper.getHeaderNames()).containsExactlyInAnyOrder(CONTENT_LENGTH);

assertHeader(response, CONTENT_LENGTH, originalContentLength);
assertHeader(responseWrapper, CONTENT_LENGTH, overridingContentLength);

FileCopyUtils.copy(responseBody, responseWrapper.getOutputStream());
assertThat(responseWrapper.getContentSize()).isEqualTo(responseLength);

responseWrapper.copyBodyToResponse();

assertThat(responseWrapper.getContentSize()).isZero();
assertThat(responseWrapper.getHeaderNames()).containsExactlyInAnyOrder(CONTENT_LENGTH);

assertHeader(response, CONTENT_LENGTH, responseLength);
assertHeader(responseWrapper, CONTENT_LENGTH, responseLength);

assertThat(response.getContentLength()).isEqualTo(responseLength);
assertThat(response.getContentAsByteArray()).isEqualTo(responseBody);
assertThat(response.getHeaderNames()).containsExactlyInAnyOrder(CONTENT_LENGTH);
}

private static Stream<Named<SetContentLength>> setContentLengthFunctions() {
return Stream.of(
named("setContentLength()", HttpServletResponse::setContentLength),
named("setContentLengthLong()", HttpServletResponse::setContentLengthLong),
named("setIntHeader()", (response, contentLength) -> response.setIntHeader(CONTENT_LENGTH, contentLength)),
named("addIntHeader()", (response, contentLength) -> response.addIntHeader(CONTENT_LENGTH, contentLength)),
named("setHeader()", (response, contentLength) -> response.setHeader(CONTENT_LENGTH, "" + contentLength)),
named("addHeader()", (response, contentLength) -> response.addHeader(CONTENT_LENGTH, "" + contentLength))
);
}

@ParameterizedTest(name = "[{index}] {0}")
@MethodSource("setContentTypeFunctions")
void copyBodyToResponseWithOverridingContentType(SetContentType setContentType) throws Exception {
byte[] responseBody = "Hello World".getBytes(UTF_8);
int responseLength = responseBody.length;
String originalContentType = MediaType.TEXT_PLAIN_VALUE;
String overridingContentType = MediaType.APPLICATION_JSON_VALUE;

MockHttpServletResponse response = new MockHttpServletResponse();
response.setContentType(originalContentType);

ContentCachingResponseWrapper responseWrapper = new ContentCachingResponseWrapper(response);

assertContentTypeHeader(response, originalContentType);
assertContentTypeHeader(responseWrapper, originalContentType);

setContentType.invoke(responseWrapper, overridingContentType);

assertThat(responseWrapper.getContentSize()).isZero();
assertThat(responseWrapper.getHeaderNames()).containsExactlyInAnyOrder(CONTENT_TYPE);

assertContentTypeHeader(response, overridingContentType);
assertContentTypeHeader(responseWrapper, overridingContentType);

FileCopyUtils.copy(responseBody, responseWrapper.getOutputStream());
assertThat(responseWrapper.getContentSize()).isEqualTo(responseLength);

responseWrapper.copyBodyToResponse();

assertThat(responseWrapper.getStatus()).isEqualTo(HttpServletResponse.SC_CREATED);
assertThat(responseWrapper.getContentSize()).isZero();
assertThat(responseWrapper.getHeaderNames()).containsExactlyInAnyOrder(CONTENT_TYPE, CONTENT_LENGTH);

@@ -160,24 +203,19 @@ void copyBodyToResponseWithOverridingHeaders(BiConsumer<HttpServletResponse, Str
assertContentTypeHeader(response, overridingContentType);
assertContentTypeHeader(responseWrapper, overridingContentType);

assertThat(response.getStatus()).isEqualTo(HttpServletResponse.SC_CREATED);
assertThat(response.getContentLength()).isEqualTo(responseLength);
assertThat(response.getContentAsByteArray()).isEqualTo(responseBody);
assertThat(response.getHeaderNames()).containsExactlyInAnyOrder(CONTENT_TYPE, CONTENT_LENGTH);
}

private static Stream<Arguments> setContentTypeFunctions() {
private static Stream<Named<SetContentType>> setContentTypeFunctions() {
return Stream.of(
namedArguments("setContentType()", HttpServletResponse::setContentType),
namedArguments("setHeader()", (response, contentType) -> response.setHeader(CONTENT_TYPE, contentType)),
namedArguments("addHeader()", (response, contentType) -> response.addHeader(CONTENT_TYPE, contentType))
named("setContentType()", HttpServletResponse::setContentType),
named("setHeader()", (response, contentType) -> response.setHeader(CONTENT_TYPE, contentType)),
named("addHeader()", (response, contentType) -> response.addHeader(CONTENT_TYPE, contentType))
);
}

private static Arguments namedArguments(String name, BiConsumer<HttpServletResponse, String> setContentTypeFunction) {
return arguments(named(name, setContentTypeFunction));
}

@Test
void copyBodyToResponseWithTransferEncoding() throws Exception {
byte[] responseBody = "6\r\nHello 5\r\nWorld0\r\n\r\n".getBytes(UTF_8);
@@ -217,4 +255,15 @@ private void assertContentTypeHeader(HttpServletResponse response, String conten
assertThat(response.getContentType()).as(CONTENT_TYPE).isEqualTo(contentType);
}


@FunctionalInterface
private interface SetContentLength {
void invoke(HttpServletResponse response, int contentLength);
}

@FunctionalInterface
private interface SetContentType {
void invoke(HttpServletResponse response, String contentType);
}

}
Original file line number Diff line number Diff line change
@@ -28,6 +28,7 @@

import static java.nio.charset.StandardCharsets.UTF_8;
import static org.assertj.core.api.Assertions.assertThat;
import static org.springframework.http.MediaType.APPLICATION_JSON_VALUE;
import static org.springframework.http.MediaType.TEXT_PLAIN_VALUE;

/**
@@ -36,6 +37,7 @@
* @author Arjen Poutsma
* @author Brian Clozel
* @author Juergen Hoeller
* @author Sam Brannen
*/
class ShallowEtagHeaderFilterTests {

@@ -123,7 +125,7 @@ void filterMatch() throws Exception {
assertThat(response.getStatus()).as("Invalid status").isEqualTo(304);
assertThat(response.getHeader("ETag")).as("Invalid ETag").isEqualTo("\"0b10a8db164e0754105b7a99be72e3fe5\"");
assertThat(response.containsHeader("Content-Length")).as("Response has Content-Length header").isFalse();
assertThat(response.containsHeader("Content-Type")).as("Response has Content-Type header").isFalse();
assertThat(response.getContentType()).as("Invalid Content-Type header").isEqualTo(TEXT_PLAIN_VALUE);
assertThat(response.getContentAsByteArray()).as("Invalid content").isEmpty();
}

@@ -173,11 +175,13 @@ void filterWriter() throws Exception {
void filterWriterWithDisabledCaching() throws Exception {
MockHttpServletRequest request = new MockHttpServletRequest("GET", "/hotels");
MockHttpServletResponse response = new MockHttpServletResponse();
response.setContentType(TEXT_PLAIN_VALUE);

byte[] responseBody = "Hello World".getBytes(UTF_8);
FilterChain filterChain = (filterRequest, filterResponse) -> {
assertThat(filterRequest).as("Invalid request passed").isEqualTo(request);
((HttpServletResponse) filterResponse).setStatus(HttpServletResponse.SC_OK);
filterResponse.setContentType(APPLICATION_JSON_VALUE);
FileCopyUtils.copy(responseBody, filterResponse.getOutputStream());
};

@@ -186,6 +190,7 @@ void filterWriterWithDisabledCaching() throws Exception {

assertThat(response.getStatus()).isEqualTo(200);
assertThat(response.getHeader("ETag")).isNull();
assertThat(response.getContentType()).as("Invalid Content-Type header").isEqualTo(APPLICATION_JSON_VALUE);
assertThat(response.getContentAsByteArray()).isEqualTo(responseBody);
}

0 comments on commit d1b3107

Please sign in to comment.