diff --git a/spring-core/src/main/java/org/springframework/util/MimeTypeUtils.java b/spring-core/src/main/java/org/springframework/util/MimeTypeUtils.java index 27bbc609fc8f..fb23809994dd 100644 --- a/spring-core/src/main/java/org/springframework/util/MimeTypeUtils.java +++ b/spring-core/src/main/java/org/springframework/util/MimeTypeUtils.java @@ -362,7 +362,9 @@ public static String toString(Collection mimeTypes) { */ public static void sortBySpecificity(List mimeTypes) { Assert.notNull(mimeTypes, "'mimeTypes' must not be null"); - Assert.isTrue(mimeTypes.size() <= 50, "Too many elements"); + if (mimeTypes.size() >= 50) { + throw new InvalidMimeTypeException(mimeTypes.toString(), "Too many elements"); + } bubbleSort(mimeTypes, MimeType::isLessSpecific); } diff --git a/spring-core/src/test/java/org/springframework/util/MimeTypeTests.java b/spring-core/src/test/java/org/springframework/util/MimeTypeTests.java index 90b1fea3fa20..e582bc60ec40 100644 --- a/spring-core/src/test/java/org/springframework/util/MimeTypeTests.java +++ b/spring-core/src/test/java/org/springframework/util/MimeTypeTests.java @@ -451,8 +451,9 @@ void sortBySpecificity() { MimeType audioWave = new MimeType("audio", "wave"); MimeType audioBasicLevel = new MimeType("audio", "basic", singletonMap("level", "1")); - List mimeTypes = new ArrayList<>(List.of(MimeTypeUtils.ALL, audio, audioWave, audioBasic, - audioBasicLevel)); + List mimeTypes = new ArrayList<>( + List.of(MimeTypeUtils.ALL, audio, audioWave, audioBasic, audioBasicLevel)); + MimeTypeUtils.sortBySpecificity(mimeTypes); assertThat(mimeTypes).containsExactly(audioWave, audioBasicLevel, audioBasic, audio, MimeTypeUtils.ALL); diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/handler/HandlerMappingIntrospector.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/handler/HandlerMappingIntrospector.java index cd5eef8f5059..2847b62eaaa1 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/handler/HandlerMappingIntrospector.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/handler/HandlerMappingIntrospector.java @@ -24,14 +24,18 @@ import java.util.List; import java.util.Map; import java.util.Properties; +import java.util.concurrent.atomic.AtomicInteger; import java.util.function.BiFunction; import java.util.stream.Collectors; import jakarta.servlet.DispatcherType; +import jakarta.servlet.Filter; import jakarta.servlet.ServletException; import jakarta.servlet.ServletRequest; import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletRequestWrapper; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; import org.springframework.beans.factory.BeanFactoryUtils; import org.springframework.beans.factory.InitializingBean; @@ -87,6 +91,8 @@ public class HandlerMappingIntrospector implements CorsConfigurationSource, ApplicationContextAware, InitializingBean { + private static final Log logger = LogFactory.getLog(HandlerMappingIntrospector.class.getName()); + private static final String CACHED_RESULT_ATTRIBUTE = HandlerMappingIntrospector.class.getName() + ".CachedResult"; @@ -99,6 +105,8 @@ public class HandlerMappingIntrospector private Map pathPatternMappings = Collections.emptyMap(); + private final CacheResultLogHelper cacheLogHelper = new CacheResultLogHelper(); + @Override public void setApplicationContext(ApplicationContext applicationContext) { @@ -167,6 +175,36 @@ public List getHandlerMappings() { } + /** + * {@link Filter} that looks up the {@code MatchableHandlerMapping} and + * {@link CorsConfiguration} for the request proactively before delegating + * to the rest of the chain, caching the result in a request attribute, and + * restoring it after the chain returns. + *

Note: Applications that rely on Spring Security do + * not use this component directly and should not deploy the filter instead + * allowing Spring Security to do it. Other custom security layers used in + * place of Spring Security that also rely on {@code HandlerMappingIntrospector} + * should deploy this filter ahead of other filters where lookups are + * performed, and should also make sure the filter is configured to handle + * all dispatcher types. + * @return the Filter instance to use + * @since 6.0.14 + */ + public Filter createCacheFilter() { + return (request, response, chain) -> { + HandlerMappingIntrospector.CachedResult previous = setCache((HttpServletRequest) request); + try { + chain.doFilter(request, response); + } + catch (Exception ex) { + throw new ServletException("HandlerMapping introspection failed", ex); + } + finally { + resetCache(request, previous); + } + }; + } + /** * Perform a lookup and save the {@link CachedResult} as a request attribute. * This method can be invoked from a filter before subsequent calls to @@ -178,18 +216,18 @@ public List getHandlerMappings() { * @since 6.0.14 */ @Nullable - public CachedResult setCache(HttpServletRequest request) throws ServletException { - CachedResult previous = getAttribute(request); + private CachedResult setCache(HttpServletRequest request) throws ServletException { + CachedResult previous = (CachedResult) request.getAttribute(CACHED_RESULT_ATTRIBUTE); if (previous == null || !previous.matches(request)) { try { HttpServletRequest wrapped = new AttributesPreservingRequest(request); - CachedResult cachedResult = doWithHandlerMapping(wrapped, false, (mapping, executionChain) -> { + CachedResult result = doWithHandlerMapping(wrapped, false, (mapping, executionChain) -> { MatchableHandlerMapping matchableMapping = createMatchableHandlerMapping(mapping, wrapped); CorsConfiguration corsConfig = getCorsConfiguration(wrapped, executionChain); return new CachedResult(request, matchableMapping, corsConfig); }); request.setAttribute(CACHED_RESULT_ATTRIBUTE, - cachedResult != null ? cachedResult : new CachedResult(request, null, null)); + (result != null ? result : new CachedResult(request, null, null))); } catch (Throwable ex) { throw new ServletException("HandlerMapping introspection failed", ex); @@ -203,7 +241,7 @@ public CachedResult setCache(HttpServletRequest request) throws ServletException * a filter after delegating to the rest of the chain. * @since 6.0.14 */ - public void resetCache(ServletRequest request, @Nullable CachedResult cachedResult) { + private void resetCache(ServletRequest request, @Nullable CachedResult cachedResult) { request.setAttribute(CACHED_RESULT_ATTRIBUTE, cachedResult); } @@ -218,10 +256,11 @@ public void resetCache(ServletRequest request, @Nullable CachedResult cachedResu */ @Nullable public MatchableHandlerMapping getMatchableHandlerMapping(HttpServletRequest request) throws Exception { - CachedResult cachedResult = getCachedResultFor(request); - if (cachedResult != null) { - return cachedResult.getHandlerMapping(); + CachedResult result = CachedResult.forRequest(request); + if (result != null) { + return result.getHandlerMapping(); } + this.cacheLogHelper.logHandlerMappingCacheMiss(request); HttpServletRequest requestToUse = new AttributesPreservingRequest(request); return doWithHandlerMapping(requestToUse, false, (mapping, executionChain) -> createMatchableHandlerMapping(mapping, requestToUse)); @@ -245,10 +284,11 @@ private MatchableHandlerMapping createMatchableHandlerMapping(HandlerMapping map @Override @Nullable public CorsConfiguration getCorsConfiguration(HttpServletRequest request) { - CachedResult cachedResult = getCachedResultFor(request); - if (cachedResult != null) { - return cachedResult.getCorsConfig(); + CachedResult result = CachedResult.forRequest(request); + if (result != null) { + return result.getCorsConfig(); } + this.cacheLogHelper.logCorsConfigCacheMiss(request); try { boolean ignoreException = true; AttributesPreservingRequest requestToUse = new AttributesPreservingRequest(request); @@ -312,28 +352,14 @@ private T doWithHandlerMapping( return null; } - /** - * Return a {@link CachedResult} that matches the given request. - */ - @Nullable - private CachedResult getCachedResultFor(HttpServletRequest request) { - CachedResult result = getAttribute(request); - return (result != null && result.matches(request) ? result : null); - } - - @Nullable - private static CachedResult getAttribute(HttpServletRequest request) { - return (CachedResult) request.getAttribute(CACHED_RESULT_ATTRIBUTE); - } - /** * Container for a {@link MatchableHandlerMapping} and {@link CorsConfiguration} - * for a given request identified by dispatcher type and requestURI. + * for a given request matched by dispatcher type and requestURI. * @since 6.0.14 */ @SuppressWarnings("serial") - public static final class CachedResult { + private static final class CachedResult { private final DispatcherType dispatcherType; @@ -371,7 +397,53 @@ public CorsConfiguration getCorsConfig() { @Override public String toString() { - return "CacheValue " + this.dispatcherType + " '" + this.requestURI + "'"; + return "CachedResult for " + this.dispatcherType + " dispatch to '" + this.requestURI + "'"; + } + + + /** + * Return a {@link CachedResult} that matches the given request. + */ + @Nullable + public static CachedResult forRequest(HttpServletRequest request) { + CachedResult result = (CachedResult) request.getAttribute(CACHED_RESULT_ATTRIBUTE); + return (result != null && result.matches(request) ? result : null); + } + + } + + + private static class CacheResultLogHelper { + + private final Map counters = + Map.of("MatchableHandlerMapping", new AtomicInteger(), "CorsConfiguration", new AtomicInteger()); + + public void logHandlerMappingCacheMiss(HttpServletRequest request) { + logCacheMiss("MatchableHandlerMapping", request); + } + + public void logCorsConfigCacheMiss(HttpServletRequest request) { + logCacheMiss("CorsConfiguration", request); + } + + private void logCacheMiss(String label, HttpServletRequest request) { + AtomicInteger counter = this.counters.get(label); + Assert.notNull(counter, "Expected '" + label + "' counter."); + + String message = getLogMessage(label, request); + + if (logger.isWarnEnabled() && counter.getAndIncrement() == 0) { + logger.warn(message + " This is logged once only at WARN level, and every time at TRACE."); + } + else if (logger.isTraceEnabled()) { + logger.trace("No CachedResult, performing " + label + " lookup instead."); + } + } + + private static String getLogMessage(String label, HttpServletRequest request) { + return "Cache miss for " + request.getDispatcherType() + " dispatch to '" + request.getRequestURI() + "' " + + "(previous " + request.getAttribute(CACHED_RESULT_ATTRIBUTE) + "). " + + "Performing " + label + " lookup."; } } diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/handler/HandlerMappingIntrospectorTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/handler/HandlerMappingIntrospectorTests.java index 94ff2671a693..600ae8b88655 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/handler/HandlerMappingIntrospectorTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/handler/HandlerMappingIntrospectorTests.java @@ -52,7 +52,6 @@ import org.springframework.web.servlet.function.RouterFunctions; import org.springframework.web.servlet.function.ServerResponse; import org.springframework.web.servlet.function.support.RouterFunctionMapping; -import org.springframework.web.servlet.handler.HandlerMappingIntrospector.CachedResult; import org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerMapping; import org.springframework.web.testfixture.servlet.MockFilterChain; import org.springframework.web.testfixture.servlet.MockHttpServletRequest; @@ -217,7 +216,7 @@ void cacheFilter() throws Exception { MockHttpServletResponse response = new MockHttpServletResponse(); MockFilterChain filterChain = new MockFilterChain( - new TestServlet(), new CacheResultFilter(introspector), new AuthFilter(introspector, corsConfig)); + new TestServlet(), introspector.createCacheFilter(), new AuthFilter(introspector, corsConfig)); filterChain.doFilter(request, response); @@ -241,10 +240,10 @@ void cacheFilterWithNestedDispatch() throws Exception { MockFilterChain filterChain = new MockFilterChain( new TestServlet(), - new CacheResultFilter(introspector), + introspector.createCacheFilter(), new AuthFilter(introspector, corsConfig1), (req, res, chain) -> chain.doFilter(new MockHttpServletRequest("GET", "/2"), res), - new CacheResultFilter(introspector), + introspector.createCacheFilter(), new AuthFilter(introspector, corsConfig2)); MockHttpServletResponse response = new MockHttpServletResponse(); @@ -372,32 +371,6 @@ public CorsConfiguration getCorsConfiguration(HttpServletRequest request) { } - private static class CacheResultFilter implements Filter { - - private final HandlerMappingIntrospector introspector; - - private CacheResultFilter(HandlerMappingIntrospector introspector) { - this.introspector = introspector; - } - - @Override - public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain) - throws ServletException { - - CachedResult previousValue = this.introspector.setCache((HttpServletRequest) req); - try { - chain.doFilter(req, res); - } - catch (Exception ex) { - throw new ServletException("HandlerMapping introspection failed", ex); - } - finally { - this.introspector.resetCache(req, previousValue); - } - } - } - - private static class AuthFilter implements Filter { private final HandlerMappingIntrospector introspector;