Skip to content

Commit

Permalink
Introduce Customizable AuthorizationFailureHandler
Browse files Browse the repository at this point in the history
  • Loading branch information
leewin12 authored and sjohnr committed Mar 1, 2024
1 parent bd345fb commit 07ac0b6
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 7 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -25,13 +25,15 @@

import org.springframework.core.log.LogMessage;
import org.springframework.http.HttpStatus;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.web.DefaultRedirectStrategy;
import org.springframework.security.web.RedirectStrategy;
import org.springframework.security.web.authentication.AuthenticationFailureHandler;
import org.springframework.security.web.savedrequest.HttpSessionRequestCache;
import org.springframework.security.web.savedrequest.RequestCache;
import org.springframework.security.web.util.ThrowableAnalyzer;
Expand Down Expand Up @@ -97,6 +99,8 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt

private RequestCache requestCache = new HttpSessionRequestCache();

private AuthenticationFailureHandler authenticationFailureHandler = this::unsuccessfulRedirectForAuthorization;

/**
* Constructs an {@code OAuth2AuthorizationRequestRedirectFilter} using the provided
* parameters.
Expand Down Expand Up @@ -163,6 +167,18 @@ public final void setRequestCache(RequestCache requestCache) {
this.requestCache = requestCache;
}

/**
* Sets the {@link AuthenticationFailureHandler} used to handle errors redirecting to
* the Authorization Server's Authorization Endpoint.
* @param authenticationFailureHandler the {@link AuthenticationFailureHandler} used
* to handle errors redirecting to the Authorization Server's Authorization Endpoint
* @since 6.3
*/
public void setAuthenticationFailureHandler(AuthenticationFailureHandler authenticationFailureHandler) {
Assert.notNull(authenticationFailureHandler, "authenticationFailureHandler cannot be null");
this.authenticationFailureHandler = authenticationFailureHandler;
}

@Override
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
throws ServletException, IOException {
Expand All @@ -174,7 +190,8 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse
}
}
catch (Exception ex) {
this.unsuccessfulRedirectForAuthorization(request, response, ex);
AuthenticationException wrappedException = new OAuth2AuthorizationRequestException(ex);
this.authenticationFailureHandler.onAuthenticationFailure(request, response, wrappedException);
return;
}
try {
Expand All @@ -199,7 +216,8 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse
this.sendRedirectForAuthorization(request, response, authorizationRequest);
}
catch (Exception failed) {
this.unsuccessfulRedirectForAuthorization(request, response, failed);
AuthenticationException wrappedException = new OAuth2AuthorizationRequestException(ex);
this.authenticationFailureHandler.onAuthenticationFailure(request, response, wrappedException);
}
return;
}
Expand All @@ -223,9 +241,10 @@ private void sendRedirectForAuthorization(HttpServletRequest request, HttpServle
}

private void unsuccessfulRedirectForAuthorization(HttpServletRequest request, HttpServletResponse response,
Exception ex) throws IOException {
LogMessage message = LogMessage.format("Authorization Request failed: %s", ex);
if (InvalidClientRegistrationIdException.class.isAssignableFrom(ex.getClass())) {
AuthenticationException ex) throws IOException {
Throwable cause = ex.getCause();
LogMessage message = LogMessage.format("Authorization Request failed: %s", cause);
if (InvalidClientRegistrationIdException.class.isAssignableFrom(cause.getClass())) {
// Log an invalid registrationId at WARN level to allow these errors to be
// tuned separately from other errors
this.logger.warn(message, ex);
Expand All @@ -250,4 +269,12 @@ protected void initExtractorMap() {

}

private static final class OAuth2AuthorizationRequestException extends AuthenticationException {

OAuth2AuthorizationRequestException(Throwable cause) {
super(cause.getMessage(), cause);
}

}

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -119,6 +119,11 @@ public void setRequestCacheWhenRequestCacheIsNullThenThrowIllegalArgumentExcepti
assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setRequestCache(null));
}

@Test
public void setAuthenticationFailureHandlerIsNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setAuthenticationFailureHandler(null));
}

@Test
public void doFilterWhenNotAuthorizationRequestThenNextFilter() throws Exception {
String requestUri = "/path";
Expand All @@ -144,6 +149,31 @@ public void doFilterWhenAuthorizationRequestWithInvalidClientThenStatusInternalS
assertThat(response.getErrorMessage()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR.getReasonPhrase());
}

@Test
public void doFilterWhenAuthorizationRequestWithInvalidClientAndCustomFailureHandlerThenCustomError()
throws Exception {
String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI + "/"
+ this.registration1.getRegistrationId() + "-invalid";
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.filter.setAuthenticationFailureHandler((request1, response1, ex) -> {
Throwable cause = ex.getCause();
if (InvalidClientRegistrationIdException.class.isAssignableFrom(cause.getClass())) {
response1.sendError(HttpStatus.BAD_REQUEST.value(), HttpStatus.BAD_REQUEST.getReasonPhrase());
}
else {
response1.sendError(HttpStatus.INTERNAL_SERVER_ERROR.value(),
HttpStatus.INTERNAL_SERVER_ERROR.getReasonPhrase());
}
});
this.filter.doFilter(request, response, filterChain);
verifyNoMoreInteractions(filterChain);
assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
assertThat(response.getErrorMessage()).isEqualTo(HttpStatus.BAD_REQUEST.getReasonPhrase());
}

@Test
public void doFilterWhenAuthorizationRequestOAuth2LoginThenRedirectForAuthorization() throws Exception {
String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI + "/"
Expand Down

0 comments on commit 07ac0b6

Please sign in to comment.