Skip to content

Commit

Permalink
Add step annotate and improve AADHandleConditionalAccessFilter for Co…
Browse files Browse the repository at this point in the history
…nditional Access. (Azure#19681)
  • Loading branch information
han-gao authored Mar 8, 2021
1 parent e9d211e commit ce0f46a
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,7 @@

import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientManager;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProviderBuilder;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction;
import org.springframework.web.reactive.function.client.WebClient;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ public <T extends OAuth2AuthorizedClient> T loadAuthorizedClient(String registra
request.setAttribute(oboAuthorizedClientAttributeName, (T) oAuth2AuthorizedClient);
return (T) oAuth2AuthorizedClient;
} catch (ExecutionException exception) {
// Handle conditional access policy for obo flow.
// Handle conditional access policy, step 1.
// A user interaction is required, but we are in a web API, and therefore, we need to report back to the
// client through a 'WWW-Authenticate' header https://tools.ietf.org/html/rfc6750#section-3.1
Optional.of(exception)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@

package com.azure.spring.aad.webapp;

import com.azure.spring.aad.AADClientRegistrationRepository;
import com.azure.spring.autoconfigure.aad.Constants;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
import org.springframework.util.StringUtils;
import org.springframework.web.filter.OncePerRequestFilter;
import org.springframework.web.reactive.function.client.WebClientResponseException;
Expand All @@ -32,11 +32,11 @@
*/
public class AADHandleConditionalAccessFilter extends OncePerRequestFilter {

private static final Logger LOGGER = LoggerFactory.getLogger(AADHandleConditionalAccessFilter.class);

@Override
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response,
FilterChain filterChain) throws IOException, ServletException {
// Handle conditional access policy, step 2.
try {
filterChain.doFilter(request, response);
} catch (Exception exception) {
Expand All @@ -53,18 +53,19 @@ protected void doFilterInternal(HttpServletRequest request, HttpServletResponse
if (authParameters != null && authParameters.containsKey(Constants.CONDITIONAL_ACCESS_POLICY_CLAIMS)) {
request.getSession().setAttribute(Constants.CONDITIONAL_ACCESS_POLICY_CLAIMS,
authParameters.get(Constants.CONDITIONAL_ACCESS_POLICY_CLAIMS));
response.setStatus(302);
try {
response.sendRedirect(Constants.DEFAULT_AUTHORITY_ENDPOINT_URI);
} catch (IOException e) {
LOGGER.error("Failed to redirect at this response.", exception);
}
return;
// OAuth2AuthorizationRequestRedirectFilter will catch this exception to re-authorize.
throw new ClientAuthorizationRequiredException(AADClientRegistrationRepository.AZURE_CLIENT_REGISTRATION_ID);
}
throw exception;
}
}

/**
* Get claims filed form the header to re-authorize.
*
* @param wwwAuthenticateHeader httpHeader
* @return authParametersMap
*/
private Map<String, String> parseAuthParameters(String wwwAuthenticateHeader) {
return Stream.of(wwwAuthenticateHeader)
.filter(header -> !StringUtils.isEmpty(header))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
import org.springframework.security.oauth2.client.OAuth2AuthorizationContext;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider;
Expand Down Expand Up @@ -68,6 +69,9 @@ public <T extends OAuth2AuthorizedClient> T loadAuthorizedClient(String id,

if (repo.isClientNeedConsentWhenLogin(id)) {
OAuth2AuthorizedClient azureClient = loadAuthorizedClient(getAzureClientId(), principal, request);
if (azureClient == null) {
throw new ClientAuthorizationRequiredException(AADClientRegistrationRepository.AZURE_CLIENT_REGISTRATION_ID);
}
OAuth2AuthorizedClient fakeAuthzClient = createFakeAuthzClient(azureClient, id, principal);
OAuth2AuthorizationContext.Builder contextBuilder =
OAuth2AuthorizationContext.withAuthorizedClient(fakeAuthzClient);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserRequest;
import org.springframework.security.oauth2.client.oidc.web.logout.OidcClientInitiatedLogoutSuccessHandler;
import org.springframework.security.oauth2.client.userinfo.OAuth2UserService;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestRedirectFilter;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestResolver;
import org.springframework.security.oauth2.core.oidc.user.OidcUser;
import org.springframework.security.web.access.ExceptionTranslationFilter;
import org.springframework.security.web.authentication.logout.LogoutSuccessHandler;
import org.springframework.util.StringUtils;

Expand Down Expand Up @@ -51,7 +51,7 @@ protected void configure(HttpSecurity http) throws Exception {
.logout()
.logoutSuccessHandler(oidcLogoutSuccessHandler())
.and()
.addFilterBefore(new AADHandleConditionalAccessFilter(), ExceptionTranslationFilter.class);
.addFilterAfter(new AADHandleConditionalAccessFilter(), OAuth2AuthorizationRequestRedirectFilter.class);
// @formatter:off
}

Expand Down

0 comments on commit ce0f46a

Please sign in to comment.