Skip to content

Commit

Permalink
Merge branch '6.0.x'
Browse files Browse the repository at this point in the history
  • Loading branch information
rstoyanchev committed Nov 15, 2023
2 parents f15b8b9 + 05c3ffb commit 3a70c71
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 61 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,9 @@ public static String toString(Collection<? extends MimeType> mimeTypes) {
*/
public static <T extends MimeType> void sortBySpecificity(List<T> mimeTypes) {
Assert.notNull(mimeTypes, "'mimeTypes' must not be null");
Assert.isTrue(mimeTypes.size() <= 50, "Too many elements");
if (mimeTypes.size() >= 50) {
throw new InvalidMimeTypeException(mimeTypes.toString(), "Too many elements");
}

bubbleSort(mimeTypes, MimeType::isLessSpecific);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -451,8 +451,9 @@ void sortBySpecificity() {
MimeType audioWave = new MimeType("audio", "wave");
MimeType audioBasicLevel = new MimeType("audio", "basic", singletonMap("level", "1"));

List<MimeType> mimeTypes = new ArrayList<>(List.of(MimeTypeUtils.ALL, audio, audioWave, audioBasic,
audioBasicLevel));
List<MimeType> mimeTypes = new ArrayList<>(
List.of(MimeTypeUtils.ALL, audio, audioWave, audioBasic, audioBasicLevel));

MimeTypeUtils.sortBySpecificity(mimeTypes);

assertThat(mimeTypes).containsExactly(audioWave, audioBasicLevel, audioBasic, audio, MimeTypeUtils.ALL);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,18 @@
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiFunction;
import java.util.stream.Collectors;

import jakarta.servlet.DispatcherType;
import jakarta.servlet.Filter;
import jakarta.servlet.ServletException;
import jakarta.servlet.ServletRequest;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletRequestWrapper;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import org.springframework.beans.factory.BeanFactoryUtils;
import org.springframework.beans.factory.InitializingBean;
Expand Down Expand Up @@ -87,6 +91,8 @@
public class HandlerMappingIntrospector
implements CorsConfigurationSource, ApplicationContextAware, InitializingBean {

private static final Log logger = LogFactory.getLog(HandlerMappingIntrospector.class.getName());

private static final String CACHED_RESULT_ATTRIBUTE =
HandlerMappingIntrospector.class.getName() + ".CachedResult";

Expand All @@ -99,6 +105,8 @@ public class HandlerMappingIntrospector

private Map<HandlerMapping, PathPatternMatchableHandlerMapping> pathPatternMappings = Collections.emptyMap();

private final CacheResultLogHelper cacheLogHelper = new CacheResultLogHelper();


@Override
public void setApplicationContext(ApplicationContext applicationContext) {
Expand Down Expand Up @@ -167,6 +175,36 @@ public List<HandlerMapping> getHandlerMappings() {
}


/**
* {@link Filter} that looks up the {@code MatchableHandlerMapping} and
* {@link CorsConfiguration} for the request proactively before delegating
* to the rest of the chain, caching the result in a request attribute, and
* restoring it after the chain returns.
* <p><strong>Note:</strong> Applications that rely on Spring Security do
* not use this component directly and should not deploy the filter instead
* allowing Spring Security to do it. Other custom security layers used in
* place of Spring Security that also rely on {@code HandlerMappingIntrospector}
* should deploy this filter ahead of other filters where lookups are
* performed, and should also make sure the filter is configured to handle
* all dispatcher types.
* @return the Filter instance to use
* @since 6.0.14
*/
public Filter createCacheFilter() {
return (request, response, chain) -> {
HandlerMappingIntrospector.CachedResult previous = setCache((HttpServletRequest) request);
try {
chain.doFilter(request, response);
}
catch (Exception ex) {
throw new ServletException("HandlerMapping introspection failed", ex);
}
finally {
resetCache(request, previous);
}
};
}

/**
* Perform a lookup and save the {@link CachedResult} as a request attribute.
* This method can be invoked from a filter before subsequent calls to
Expand All @@ -178,18 +216,18 @@ public List<HandlerMapping> getHandlerMappings() {
* @since 6.0.14
*/
@Nullable
public CachedResult setCache(HttpServletRequest request) throws ServletException {
CachedResult previous = getAttribute(request);
private CachedResult setCache(HttpServletRequest request) throws ServletException {
CachedResult previous = (CachedResult) request.getAttribute(CACHED_RESULT_ATTRIBUTE);
if (previous == null || !previous.matches(request)) {
try {
HttpServletRequest wrapped = new AttributesPreservingRequest(request);
CachedResult cachedResult = doWithHandlerMapping(wrapped, false, (mapping, executionChain) -> {
CachedResult result = doWithHandlerMapping(wrapped, false, (mapping, executionChain) -> {
MatchableHandlerMapping matchableMapping = createMatchableHandlerMapping(mapping, wrapped);
CorsConfiguration corsConfig = getCorsConfiguration(wrapped, executionChain);
return new CachedResult(request, matchableMapping, corsConfig);
});
request.setAttribute(CACHED_RESULT_ATTRIBUTE,
cachedResult != null ? cachedResult : new CachedResult(request, null, null));
(result != null ? result : new CachedResult(request, null, null)));
}
catch (Throwable ex) {
throw new ServletException("HandlerMapping introspection failed", ex);
Expand All @@ -203,7 +241,7 @@ public CachedResult setCache(HttpServletRequest request) throws ServletException
* a filter after delegating to the rest of the chain.
* @since 6.0.14
*/
public void resetCache(ServletRequest request, @Nullable CachedResult cachedResult) {
private void resetCache(ServletRequest request, @Nullable CachedResult cachedResult) {
request.setAttribute(CACHED_RESULT_ATTRIBUTE, cachedResult);
}

Expand All @@ -218,10 +256,11 @@ public void resetCache(ServletRequest request, @Nullable CachedResult cachedResu
*/
@Nullable
public MatchableHandlerMapping getMatchableHandlerMapping(HttpServletRequest request) throws Exception {
CachedResult cachedResult = getCachedResultFor(request);
if (cachedResult != null) {
return cachedResult.getHandlerMapping();
CachedResult result = CachedResult.forRequest(request);
if (result != null) {
return result.getHandlerMapping();
}
this.cacheLogHelper.logHandlerMappingCacheMiss(request);
HttpServletRequest requestToUse = new AttributesPreservingRequest(request);
return doWithHandlerMapping(requestToUse, false,
(mapping, executionChain) -> createMatchableHandlerMapping(mapping, requestToUse));
Expand All @@ -245,10 +284,11 @@ private MatchableHandlerMapping createMatchableHandlerMapping(HandlerMapping map
@Override
@Nullable
public CorsConfiguration getCorsConfiguration(HttpServletRequest request) {
CachedResult cachedResult = getCachedResultFor(request);
if (cachedResult != null) {
return cachedResult.getCorsConfig();
CachedResult result = CachedResult.forRequest(request);
if (result != null) {
return result.getCorsConfig();
}
this.cacheLogHelper.logCorsConfigCacheMiss(request);
try {
boolean ignoreException = true;
AttributesPreservingRequest requestToUse = new AttributesPreservingRequest(request);
Expand Down Expand Up @@ -312,28 +352,14 @@ private <T> T doWithHandlerMapping(
return null;
}

/**
* Return a {@link CachedResult} that matches the given request.
*/
@Nullable
private CachedResult getCachedResultFor(HttpServletRequest request) {
CachedResult result = getAttribute(request);
return (result != null && result.matches(request) ? result : null);
}

@Nullable
private static CachedResult getAttribute(HttpServletRequest request) {
return (CachedResult) request.getAttribute(CACHED_RESULT_ATTRIBUTE);
}


/**
* Container for a {@link MatchableHandlerMapping} and {@link CorsConfiguration}
* for a given request identified by dispatcher type and requestURI.
* for a given request matched by dispatcher type and requestURI.
* @since 6.0.14
*/
@SuppressWarnings("serial")
public static final class CachedResult {
private static final class CachedResult {

private final DispatcherType dispatcherType;

Expand Down Expand Up @@ -371,7 +397,53 @@ public CorsConfiguration getCorsConfig() {

@Override
public String toString() {
return "CacheValue " + this.dispatcherType + " '" + this.requestURI + "'";
return "CachedResult for " + this.dispatcherType + " dispatch to '" + this.requestURI + "'";
}


/**
* Return a {@link CachedResult} that matches the given request.
*/
@Nullable
public static CachedResult forRequest(HttpServletRequest request) {
CachedResult result = (CachedResult) request.getAttribute(CACHED_RESULT_ATTRIBUTE);
return (result != null && result.matches(request) ? result : null);
}

}


private static class CacheResultLogHelper {

private final Map<String, AtomicInteger> counters =
Map.of("MatchableHandlerMapping", new AtomicInteger(), "CorsConfiguration", new AtomicInteger());

public void logHandlerMappingCacheMiss(HttpServletRequest request) {
logCacheMiss("MatchableHandlerMapping", request);
}

public void logCorsConfigCacheMiss(HttpServletRequest request) {
logCacheMiss("CorsConfiguration", request);
}

private void logCacheMiss(String label, HttpServletRequest request) {
AtomicInteger counter = this.counters.get(label);
Assert.notNull(counter, "Expected '" + label + "' counter.");

String message = getLogMessage(label, request);

if (logger.isWarnEnabled() && counter.getAndIncrement() == 0) {
logger.warn(message + " This is logged once only at WARN level, and every time at TRACE.");
}
else if (logger.isTraceEnabled()) {
logger.trace("No CachedResult, performing " + label + " lookup instead.");
}
}

private static String getLogMessage(String label, HttpServletRequest request) {
return "Cache miss for " + request.getDispatcherType() + " dispatch to '" + request.getRequestURI() + "' " +
"(previous " + request.getAttribute(CACHED_RESULT_ATTRIBUTE) + "). " +
"Performing " + label + " lookup.";
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@
import org.springframework.web.servlet.function.RouterFunctions;
import org.springframework.web.servlet.function.ServerResponse;
import org.springframework.web.servlet.function.support.RouterFunctionMapping;
import org.springframework.web.servlet.handler.HandlerMappingIntrospector.CachedResult;
import org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerMapping;
import org.springframework.web.testfixture.servlet.MockFilterChain;
import org.springframework.web.testfixture.servlet.MockHttpServletRequest;
Expand Down Expand Up @@ -217,7 +216,7 @@ void cacheFilter() throws Exception {
MockHttpServletResponse response = new MockHttpServletResponse();

MockFilterChain filterChain = new MockFilterChain(
new TestServlet(), new CacheResultFilter(introspector), new AuthFilter(introspector, corsConfig));
new TestServlet(), introspector.createCacheFilter(), new AuthFilter(introspector, corsConfig));

filterChain.doFilter(request, response);

Expand All @@ -241,10 +240,10 @@ void cacheFilterWithNestedDispatch() throws Exception {

MockFilterChain filterChain = new MockFilterChain(
new TestServlet(),
new CacheResultFilter(introspector),
introspector.createCacheFilter(),
new AuthFilter(introspector, corsConfig1),
(req, res, chain) -> chain.doFilter(new MockHttpServletRequest("GET", "/2"), res),
new CacheResultFilter(introspector),
introspector.createCacheFilter(),
new AuthFilter(introspector, corsConfig2));

MockHttpServletResponse response = new MockHttpServletResponse();
Expand Down Expand Up @@ -372,32 +371,6 @@ public CorsConfiguration getCorsConfiguration(HttpServletRequest request) {
}


private static class CacheResultFilter implements Filter {

private final HandlerMappingIntrospector introspector;

private CacheResultFilter(HandlerMappingIntrospector introspector) {
this.introspector = introspector;
}

@Override
public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain)
throws ServletException {

CachedResult previousValue = this.introspector.setCache((HttpServletRequest) req);
try {
chain.doFilter(req, res);
}
catch (Exception ex) {
throw new ServletException("HandlerMapping introspection failed", ex);
}
finally {
this.introspector.resetCache(req, previousValue);
}
}
}


private static class AuthFilter implements Filter {

private final HandlerMappingIntrospector introspector;
Expand Down

0 comments on commit 3a70c71

Please sign in to comment.