Skip to content

Commit

Permalink
Support content negotiation for RFC 7807
Browse files Browse the repository at this point in the history
Closes gh-28189
  • Loading branch information
rstoyanchev committed May 9, 2022
1 parent f3fd8f9 commit 78ab4d7
Show file tree
Hide file tree
Showing 10 changed files with 228 additions and 33 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2021 the original author or authors.
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -38,6 +38,7 @@
import org.springframework.core.codec.Hints;
import org.springframework.http.HttpLogging;
import org.springframework.http.MediaType;
import org.springframework.http.ProblemDetail;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.lang.Nullable;
Expand Down Expand Up @@ -89,15 +90,24 @@ public abstract class Jackson2CodecSupport {

private final List<MimeType> mimeTypes;

private final List<MimeType> problemDetailMimeTypes;


/**
* Constructor with a Jackson {@link ObjectMapper} to use.
*/
protected Jackson2CodecSupport(ObjectMapper objectMapper, MimeType... mimeTypes) {
Assert.notNull(objectMapper, "ObjectMapper must not be null");
this.defaultObjectMapper = objectMapper;
this.mimeTypes = !ObjectUtils.isEmpty(mimeTypes) ?
List.of(mimeTypes) : DEFAULT_MIME_TYPES;
this.mimeTypes = (!ObjectUtils.isEmpty(mimeTypes) ? List.of(mimeTypes) : DEFAULT_MIME_TYPES);
this.problemDetailMimeTypes = initProblemDetailMediaTypes(this.mimeTypes);
}

private static List<MimeType> initProblemDetailMediaTypes(List<MimeType> supportedMimeTypes) {
List<MimeType> mimeTypes = new ArrayList<>();
mimeTypes.add(MediaType.APPLICATION_PROBLEM_JSON);
mimeTypes.addAll(supportedMimeTypes);
return Collections.unmodifiableList(mimeTypes);
}


Expand Down Expand Up @@ -180,7 +190,10 @@ protected List<MimeType> getMimeTypes(ResolvableType elementType) {
result.addAll(entry.getValue().keySet());
}
}
return (CollectionUtils.isEmpty(result) ? getMimeTypes() : result);
if (!CollectionUtils.isEmpty(result)) {
return result;
}
return (ProblemDetail.class.isAssignableFrom(elementClass) ? this.problemDetailMimeTypes : getMimeTypes());
}

protected boolean supportsMimeType(@Nullable MimeType mimeType) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2021 the original author or authors.
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -100,12 +100,12 @@ protected AbstractHttpMessageConverter(Charset defaultCharset, MediaType... supp
*/
public void setSupportedMediaTypes(List<MediaType> supportedMediaTypes) {
Assert.notEmpty(supportedMediaTypes, "MediaType List must not be empty");
this.supportedMediaTypes = new ArrayList<>(supportedMediaTypes);
this.supportedMediaTypes = Collections.unmodifiableList(new ArrayList<>(supportedMediaTypes));
}

@Override
public List<MediaType> getSupportedMediaTypes() {
return Collections.unmodifiableList(this.supportedMediaTypes);
return this.supportedMediaTypes;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
import org.springframework.http.HttpInputMessage;
import org.springframework.http.HttpOutputMessage;
import org.springframework.http.MediaType;
import org.springframework.http.ProblemDetail;
import org.springframework.http.converter.AbstractGenericHttpMessageConverter;
import org.springframework.http.converter.HttpMessageConversionException;
import org.springframework.http.converter.HttpMessageConverter;
Expand Down Expand Up @@ -92,6 +93,8 @@ public abstract class AbstractJackson2HttpMessageConverter extends AbstractGener
}


private List<MediaType> problemDetailMediaTypes = Collections.singletonList(MediaType.APPLICATION_PROBLEM_JSON);

protected ObjectMapper defaultObjectMapper;

@Nullable
Expand Down Expand Up @@ -122,6 +125,19 @@ protected AbstractJackson2HttpMessageConverter(ObjectMapper objectMapper, MediaT
}


@Override
public void setSupportedMediaTypes(List<MediaType> supportedMediaTypes) {
this.problemDetailMediaTypes = initProblemDetailMediaTypes(supportedMediaTypes);
super.setSupportedMediaTypes(supportedMediaTypes);
}

private List<MediaType> initProblemDetailMediaTypes(List<MediaType> supportedMediaTypes) {
List<MediaType> mediaTypes = new ArrayList<>();
mediaTypes.add(MediaType.APPLICATION_PROBLEM_JSON);
mediaTypes.addAll(supportedMediaTypes);
return Collections.unmodifiableList(mediaTypes);
}

/**
* Configure the main {@code ObjectMapper} to use for Object conversion.
* If not set, a default {@link ObjectMapper} instance is created.
Expand Down Expand Up @@ -198,7 +214,11 @@ public List<MediaType> getSupportedMediaTypes(Class<?> clazz) {
result.addAll(entry.getValue().keySet());
}
}
return (CollectionUtils.isEmpty(result) ? getSupportedMediaTypes() : result);
if (!CollectionUtils.isEmpty(result)) {
return result;
}
return (ProblemDetail.class.isAssignableFrom(clazz) ?
this.problemDetailMediaTypes : getSupportedMediaTypes());
}

private Map<Class<?>, Map<MediaType, ObjectMapper>> getObjectMapperRegistrations() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,14 +111,25 @@ protected ReactiveAdapter getAdapter(HandlerResult result) {
}

/**
* Select the best media type for the current request through a content negotiation algorithm.
* Select the best media type for the current request through a content
* negotiation algorithm.
* @param exchange the current request
* @param producibleTypesSupplier the media types that can be produced for the current request
* @param producibleTypesSupplier the media types producible for the request
* @return the selected media type, or {@code null} if none
*/
@Nullable
protected MediaType selectMediaType(ServerWebExchange exchange, Supplier<List<MediaType>> producibleTypesSupplier) {
return selectMediaType(exchange, producibleTypesSupplier, getAcceptableTypes(exchange));
}

/**
* Variant of {@link #selectMediaType(ServerWebExchange, Supplier)} with a
* given list of requested (acceptable) media types.
*/
@Nullable
protected MediaType selectMediaType(
ServerWebExchange exchange, Supplier<List<MediaType>> producibleTypesSupplier) {
ServerWebExchange exchange, Supplier<List<MediaType>> producibleTypesSupplier,
List<MediaType> acceptableTypes) {

MediaType contentType = exchange.getResponse().getHeaders().getContentType();
if (contentType != null && contentType.isConcrete()) {
Expand All @@ -128,7 +139,6 @@ protected MediaType selectMediaType(
return contentType;
}

List<MediaType> acceptableTypes = getAcceptableTypes(exchange);
List<MediaType> producibleTypes = getProducibleTypes(exchange, producibleTypesSupplier);

Set<MediaType> compatibleMediaTypes = new LinkedHashSet<>();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2021 the original author or authors.
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -17,6 +17,7 @@
package org.springframework.web.reactive.result.method.annotation;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Set;

Expand All @@ -32,6 +33,7 @@
import org.springframework.core.codec.Hints;
import org.springframework.http.HttpStatusCode;
import org.springframework.http.MediaType;
import org.springframework.http.ProblemDetail;
import org.springframework.http.codec.HttpMessageWriter;
import org.springframework.http.converter.HttpMessageNotWritableException;
import org.springframework.lang.Nullable;
Expand All @@ -57,6 +59,9 @@ public abstract class AbstractMessageWriterResultHandler extends HandlerResultHa

private final List<HttpMessageWriter<?>> messageWriters;

private final List<MediaType> problemMediaTypes =
Arrays.asList(MediaType.APPLICATION_PROBLEM_JSON, MediaType.APPLICATION_PROBLEM_XML);


/**
* Constructor with {@link HttpMessageWriter HttpMessageWriters} and a
Expand Down Expand Up @@ -161,6 +166,12 @@ protected Mono<Void> writeBody(@Nullable Object body, MethodParameter bodyParame
}
throw ex;
}

// Fall back on RFC 7807 format for ProblemDetail
if (bestMediaType == null && elementType.toClass().equals(ProblemDetail.class)) {
bestMediaType = selectMediaType(exchange, () -> getMediaTypesFor(elementType), this.problemMediaTypes);
}

if (bestMediaType != null) {
String logPrefix = exchange.getLogPrefix();
if (logger.isDebugEnabled()) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,13 +16,16 @@

package org.springframework.web.reactive.result.method.annotation;

import java.net.URI;
import java.util.List;

import reactor.core.publisher.Mono;

import org.springframework.core.MethodParameter;
import org.springframework.core.ReactiveAdapterRegistry;
import org.springframework.core.annotation.AnnotatedElementUtils;
import org.springframework.http.HttpStatusCode;
import org.springframework.http.ProblemDetail;
import org.springframework.http.codec.HttpMessageWriter;
import org.springframework.web.bind.annotation.ResponseBody;
import org.springframework.web.reactive.HandlerResult;
Expand Down Expand Up @@ -83,6 +86,13 @@ public boolean supports(HandlerResult result) {
public Mono<Void> handleResult(ServerWebExchange exchange, HandlerResult result) {
Object body = result.getReturnValue();
MethodParameter bodyTypeParameter = result.getReturnTypeSource();
if (body instanceof ProblemDetail detail) {
exchange.getResponse().setStatusCode(HttpStatusCode.valueOf(detail.getStatus()));
if (detail.getInstance() == null) {
URI path = URI.create(exchange.getRequest().getPath().value());
detail.setInstance(path);
}
}
return writeBody(body, bodyTypeParameter, exchange);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.springframework.web.reactive.result.method.annotation;

import java.lang.reflect.Method;
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;

Expand All @@ -25,23 +26,31 @@
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;

import org.springframework.core.codec.ByteBufferEncoder;
import org.springframework.core.codec.CharSequenceEncoder;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.ProblemDetail;
import org.springframework.http.codec.EncoderHttpMessageWriter;
import org.springframework.http.codec.HttpMessageWriter;
import org.springframework.http.codec.ResourceHttpMessageWriter;
import org.springframework.http.codec.json.Jackson2JsonEncoder;
import org.springframework.http.codec.xml.Jaxb2XmlEncoder;
import org.springframework.lang.Nullable;
import org.springframework.stereotype.Controller;
import org.springframework.web.bind.annotation.ResponseBody;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.reactive.HandlerResult;
import org.springframework.web.reactive.accept.RequestedContentTypeResolver;
import org.springframework.web.reactive.accept.RequestedContentTypeResolverBuilder;
import org.springframework.web.testfixture.server.MockServerWebExchange;

import static java.nio.charset.StandardCharsets.UTF_8;
import static org.assertj.core.api.Assertions.assertThat;
import static org.springframework.web.testfixture.http.server.reactive.MockServerHttpRequest.get;
import static org.springframework.web.testfixture.method.ResolvableMethod.on;

/**
Expand Down Expand Up @@ -82,7 +91,7 @@ public void supports() {
testSupports(controller, method);

method = on(TestController.class).annotNotPresent(ResponseBody.class).resolveMethod("doWork");
HandlerResult handlerResult = getHandlerResult(controller, method);
HandlerResult handlerResult = getHandlerResult(controller, null, method);
assertThat(this.resultHandler.supports(handlerResult)).isFalse();
}

Expand All @@ -105,20 +114,60 @@ public void supportsRestController() {
}

private void testSupports(Object controller, Method method) {
HandlerResult handlerResult = getHandlerResult(controller, method);
HandlerResult handlerResult = getHandlerResult(controller, null, method);
assertThat(this.resultHandler.supports(handlerResult)).isTrue();
}

private HandlerResult getHandlerResult(Object controller, Method method) {
HandlerMethod handlerMethod = new HandlerMethod(controller, method);
return new HandlerResult(handlerMethod, null, handlerMethod.getReturnType());
@Test
void problemDetailContentNegotiation() {

// Default
MockServerWebExchange exchange = MockServerWebExchange.from(get("/path"));
testProblemDetailMediaType(exchange, MediaType.APPLICATION_PROBLEM_JSON);

// JSON requested
exchange = MockServerWebExchange.from(get("/path").accept(MediaType.APPLICATION_JSON));
testProblemDetailMediaType(exchange, MediaType.APPLICATION_JSON);

// No match fallback
exchange = MockServerWebExchange.from(get("/path").accept(MediaType.APPLICATION_PDF));
testProblemDetailMediaType(exchange, MediaType.APPLICATION_PROBLEM_JSON);
}

private void testProblemDetailMediaType(MockServerWebExchange exchange, MediaType expectedMediaType) {
ProblemDetail problemDetail = ProblemDetail.forStatus(HttpStatus.BAD_REQUEST);

Method method = on(TestRestController.class).returning(ProblemDetail.class).resolveMethod();
HandlerResult result = getHandlerResult(new TestRestController(), problemDetail, method);

this.resultHandler.handleResult(exchange, result).block(Duration.ofSeconds(5));

assertThat(exchange.getResponse().getStatusCode()).isEqualTo(HttpStatus.BAD_REQUEST);
assertThat(exchange.getResponse().getHeaders().getContentType()).isEqualTo(expectedMediaType);
assertResponseBody(exchange,
"{\"type\":\"about:blank\"," +
"\"title\":\"Bad Request\"," +
"\"status\":400," +
"\"detail\":null," +
"\"instance\":\"/path\"}");
}

@Test
public void defaultOrder() {
assertThat(this.resultHandler.getOrder()).isEqualTo(100);
}

private HandlerResult getHandlerResult(Object controller, @Nullable Object returnValue, Method method) {
HandlerMethod handlerMethod = new HandlerMethod(controller, method);
return new HandlerResult(handlerMethod, returnValue, handlerMethod.getReturnType());
}

private void assertResponseBody(MockServerWebExchange exchange, @Nullable String responseBody) {
StepVerifier.create(exchange.getResponse().getBody())
.consumeNextWith(buf -> assertThat(buf.toString(UTF_8)).isEqualTo(responseBody))
.expectComplete()
.verify();
}


@RestController
Expand All @@ -142,6 +191,11 @@ public Single<String> handleToSingleString() {
public Completable handleToCompletable() {
return null;
}

public ProblemDetail handleToProblemDetail() {
return null;
}

}


Expand Down
Loading

0 comments on commit 78ab4d7

Please sign in to comment.