Skip to content

Commit

Permalink
Lenient treatment of malformed Accept header for @ExceptionHandler
Browse files Browse the repository at this point in the history
  • Loading branch information
rstoyanchev authored and lxbzmy committed Mar 26, 2022
1 parent fc35f84 commit 00177ea
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.springframework.core.ReactiveAdapterRegistry;
import org.springframework.core.ResolvableType;
import org.springframework.core.codec.Hints;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.codec.HttpMessageWriter;
import org.springframework.http.converter.HttpMessageNotWritableException;
Expand Down Expand Up @@ -144,7 +145,20 @@ protected Mono<Void> writeBody(@Nullable Object body, MethodParameter bodyParame
return Mono.from((Publisher<Void>) publisher);
}

MediaType bestMediaType = selectMediaType(exchange, () -> getMediaTypesFor(elementType));
MediaType bestMediaType;
try {
bestMediaType = selectMediaType(exchange, () -> getMediaTypesFor(elementType));
}
catch (NotAcceptableStatusException ex) {
HttpStatus statusCode = exchange.getResponse().getStatusCode();
if (statusCode != null && statusCode.isError()) {
if (logger.isDebugEnabled()) {
logger.debug("Ignoring error response content (if any). " + ex.getReason());
}
return Mono.empty();
}
throw ex;
}
if (bestMediaType != null) {
String logPrefix = exchange.getLogPrefix();
if (logger.isDebugEnabled()) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -320,7 +320,20 @@ private Mono<? extends Void> render(List<View> views, Map<String, Object> model,
}
}
List<MediaType> mediaTypes = getMediaTypes(views);
MediaType bestMediaType = selectMediaType(exchange, () -> mediaTypes);
MediaType bestMediaType;
try {
bestMediaType = selectMediaType(exchange, () -> mediaTypes);
}
catch (NotAcceptableStatusException ex) {
HttpStatus statusCode = exchange.getResponse().getStatusCode();
if (statusCode != null && statusCode.isError()) {
if (logger.isDebugEnabled()) {
logger.debug("Ignoring error response content (if any). " + ex.getReason());
}
return Mono.empty();
}
throw ex;
}
if (bestMediaType != null) {
for (View view : views) {
for (MediaType mediaType : view.getSupportedMediaTypes()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,14 @@
import org.springframework.web.reactive.HandlerResult;
import org.springframework.web.reactive.accept.RequestedContentTypeResolver;
import org.springframework.web.reactive.accept.RequestedContentTypeResolverBuilder;
import org.springframework.web.testfixture.http.server.reactive.MockServerHttpResponse;
import org.springframework.web.testfixture.server.MockServerWebExchange;

import static java.nio.charset.StandardCharsets.UTF_8;
import static org.assertj.core.api.Assertions.assertThat;
import static org.springframework.core.ResolvableType.forClassWithGenerics;
import static org.springframework.http.MediaType.APPLICATION_JSON;
import static org.springframework.http.ResponseEntity.notFound;
import static org.springframework.http.ResponseEntity.ok;
import static org.springframework.web.reactive.HandlerMapping.PRODUCIBLE_MEDIA_TYPES_ATTRIBUTE;
import static org.springframework.web.testfixture.http.server.reactive.MockServerHttpRequest.get;
import static org.springframework.web.testfixture.method.ResolvableMethod.on;
Expand Down Expand Up @@ -199,7 +199,7 @@ public void responseEntityHeaders() throws Exception {
}

@Test
public void handleResponseEntityWithNullBody() throws Exception {
public void handleResponseEntityWithNullBody() {
Object returnValue = Mono.just(notFound().build());
MethodParameter type = on(TestController.class).resolveReturnType(Mono.class, entity(String.class));
HandlerResult result = handlerResult(returnValue, type);
Expand All @@ -211,23 +211,23 @@ public void handleResponseEntityWithNullBody() throws Exception {
}

@Test
public void handleReturnTypes() throws Exception {
Object returnValue = ok("abc");
public void handleReturnTypes() {
Object returnValue = ResponseEntity.ok("abc");
MethodParameter returnType = on(TestController.class).resolveReturnType(entity(String.class));
testHandle(returnValue, returnType);

returnType = on(TestController.class).resolveReturnType(Object.class);
testHandle(returnValue, returnType);

returnValue = Mono.just(ok("abc"));
returnValue = Mono.just(ResponseEntity.ok("abc"));
returnType = on(TestController.class).resolveReturnType(Mono.class, entity(String.class));
testHandle(returnValue, returnType);

returnValue = Mono.just(ok("abc"));
returnValue = Mono.just(ResponseEntity.ok("abc"));
returnType = on(TestController.class).resolveReturnType(Single.class, entity(String.class));
testHandle(returnValue, returnType);

returnValue = Mono.just(ok("abc"));
returnValue = Mono.just(ResponseEntity.ok("abc"));
returnType = on(TestController.class).resolveReturnType(CompletableFuture.class, entity(String.class));
testHandle(returnValue, returnType);
}
Expand All @@ -239,7 +239,7 @@ public void handleReturnValueLastModified() throws Exception {
long timestamp = currentTime.toEpochMilli();
MockServerWebExchange exchange = MockServerWebExchange.from(get("/path").ifModifiedSince(timestamp));

ResponseEntity<String> entity = ok().lastModified(oneMinAgo.toEpochMilli()).body("body");
ResponseEntity<String> entity = ResponseEntity.ok().lastModified(oneMinAgo.toEpochMilli()).body("body");
MethodParameter returnType = on(TestController.class).resolveReturnType(entity(String.class));
HandlerResult result = handlerResult(entity, returnType);
this.resultHandler.handleResult(exchange, result).block(Duration.ofSeconds(5));
Expand All @@ -252,7 +252,7 @@ public void handleReturnValueEtag() throws Exception {
String etagValue = "\"deadb33f8badf00d\"";
MockServerWebExchange exchange = MockServerWebExchange.from(get("/path").ifNoneMatch(etagValue));

ResponseEntity<String> entity = ok().eTag(etagValue).body("body");
ResponseEntity<String> entity = ResponseEntity.ok().eTag(etagValue).body("body");
MethodParameter returnType = on(TestController.class).resolveReturnType(entity(String.class));
HandlerResult result = handlerResult(entity, returnType);
this.resultHandler.handleResult(exchange, result).block(Duration.ofSeconds(5));
Expand All @@ -264,7 +264,7 @@ public void handleReturnValueEtag() throws Exception {
public void handleReturnValueEtagInvalidIfNoneMatch() throws Exception {
MockServerWebExchange exchange = MockServerWebExchange.from(get("/path").ifNoneMatch("unquoted"));

ResponseEntity<String> entity = ok().eTag("\"deadb33f8badf00d\"").body("body");
ResponseEntity<String> entity = ResponseEntity.ok().eTag("\"deadb33f8badf00d\"").body("body");
MethodParameter returnType = on(TestController.class).resolveReturnType(entity(String.class));
HandlerResult result = handlerResult(entity, returnType);
this.resultHandler.handleResult(exchange, result).block(Duration.ofSeconds(5));
Expand All @@ -285,7 +285,7 @@ public void handleReturnValueETagAndLastModified() throws Exception {
.ifModifiedSince(currentTime.toEpochMilli())
);

ResponseEntity<String> entity = ok().eTag(eTag).lastModified(oneMinAgo.toEpochMilli()).body("body");
ResponseEntity<String> entity = ResponseEntity.ok().eTag(eTag).lastModified(oneMinAgo.toEpochMilli()).body("body");
MethodParameter returnType = on(TestController.class).resolveReturnType(entity(String.class));
HandlerResult result = handlerResult(entity, returnType);
this.resultHandler.handleResult(exchange, result).block(Duration.ofSeconds(5));
Expand All @@ -306,7 +306,7 @@ public void handleReturnValueChangedETagAndLastModified() throws Exception {
.ifModifiedSince(currentTime.toEpochMilli())
);

ResponseEntity<String> entity = ok().eTag(newEtag).lastModified(oneMinAgo.toEpochMilli()).body("body");
ResponseEntity<String> entity = ResponseEntity.ok().eTag(newEtag).lastModified(oneMinAgo.toEpochMilli()).body("body");
MethodParameter returnType = on(TestController.class).resolveReturnType(entity(String.class));
HandlerResult result = handlerResult(entity, returnType);
this.resultHandler.handleResult(exchange, result).block(Duration.ofSeconds(5));
Expand All @@ -320,7 +320,7 @@ public void handleMonoWithWildcardBodyType() throws Exception {
exchange.getAttributes().put(PRODUCIBLE_MEDIA_TYPES_ATTRIBUTE, Collections.singleton(APPLICATION_JSON));

MethodParameter type = on(TestController.class).resolveReturnType(Mono.class, ResponseEntity.class);
HandlerResult result = new HandlerResult(new TestController(), Mono.just(ok().body("body")), type);
HandlerResult result = new HandlerResult(new TestController(), Mono.just(ResponseEntity.ok().body("body")), type);

this.resultHandler.handleResult(exchange, result).block(Duration.ofSeconds(5));

Expand Down Expand Up @@ -399,7 +399,7 @@ public void handleWithProducibleContentTypeShouldFailWithServerError() {
}

@Test // gh-26212
public void handleWithObjectMapperByTypeRegistration() throws Exception {
public void handleWithObjectMapperByTypeRegistration() {
MediaType halFormsMediaType = MediaType.parseMediaType("application/prs.hal-forms+json");
MediaType halMediaType = MediaType.parseMediaType("application/hal+json");

Expand Down Expand Up @@ -429,6 +429,22 @@ public void handleWithObjectMapperByTypeRegistration() throws Exception {
"}");
}

@Test // gh-24539
public void malformedAcceptHeader() {
ResponseEntity<String> value = ResponseEntity.badRequest().body("Foo");
MethodParameter returnType = on(TestController.class).resolveReturnType(entity(String.class));
HandlerResult result = handlerResult(value, returnType);
MockServerWebExchange exchange = MockServerWebExchange.from(get("/path").header("Accept", "null"));

this.resultHandler.handleResult(exchange, result).block(Duration.ofSeconds(5));
MockServerHttpResponse response = exchange.getResponse();
response.setComplete().block();

assertThat(response.getStatusCode()).isEqualTo(HttpStatus.BAD_REQUEST);
assertThat(response.getHeaders().getContentType()).isNull();
assertResponseBodyIsEmpty(exchange);
}


private void testHandle(Object returnValue, MethodParameter returnType) {
MockServerWebExchange exchange = MockServerWebExchange.from(get("/path"));
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2020 the original author or authors.
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -213,7 +213,21 @@ protected <T> void writeWithMessageConverters(@Nullable T value, MethodParameter
}
else {
HttpServletRequest request = inputMessage.getServletRequest();
List<MediaType> acceptableTypes = getAcceptableMediaTypes(request);
List<MediaType> acceptableTypes;
try {
acceptableTypes = getAcceptableMediaTypes(request);
}
catch (HttpMediaTypeNotAcceptableException ex) {
int series = outputMessage.getServletResponse().getStatus() / 100;
if (body == null || series == 4 || series == 5) {
if (logger.isDebugEnabled()) {
logger.debug("Ignoring error response content (if any). " + ex);
}
logger.debug(ex.getMessage());
return;
}
throw ex;
}
List<MediaType> producibleTypes = getProducibleMediaTypes(request, valueType, targetType);

if (body != null && producibleTypes.isEmpty()) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2020 the original author or authors.
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -209,7 +209,6 @@ public void handleReturnValueCharSequence() throws Exception {

@Test // SPR-13423
public void handleReturnValueWithETagAndETagFilter() throws Exception {

String eTagValue = "\"deadb33f8badf00d\"";
String content = "body";

Expand Down Expand Up @@ -242,6 +241,25 @@ public void handleReturnValueWithETagAndETagFilter() throws Exception {
assertThat(this.servletResponse.getContentAsString()).isEqualTo(content);
}

@Test // gh-24539
public void handleReturnValueWithMalformedAcceptHeader() throws Exception {
webRequest.getNativeRequest(MockHttpServletRequest.class).addHeader("Accept", "null");

List<HttpMessageConverter<?>>converters = new ArrayList<>();
converters.add(new ByteArrayHttpMessageConverter());
converters.add(new StringHttpMessageConverter());

Method method = getClass().getDeclaredMethod("handle");
MethodParameter returnType = new MethodParameter(method, -1);
ResponseEntity<String> returnValue = ResponseEntity.badRequest().body("Foo");

HttpEntityMethodProcessor processor = new HttpEntityMethodProcessor(converters);
processor.handleReturnValue(returnValue, returnType, mavContainer, webRequest);

assertThat(servletResponse.getStatus()).isEqualTo(400);
assertThat(servletResponse.getHeader("Content-Type")).isNull();
assertThat(servletResponse.getContentAsString()).isEmpty();
}

@SuppressWarnings("unused")
private void handle(HttpEntity<List<SimpleBean>> arg1, HttpEntity<SimpleBean> arg2) {
Expand Down

0 comments on commit 00177ea

Please sign in to comment.