Skip to content

Commit

Permalink
Implemented code+tests for both imperative & reactive codelines for i…
Browse files Browse the repository at this point in the history
  • Loading branch information
mkheck committed May 3, 2019
1 parent d86550f commit 5a94d18
Show file tree
Hide file tree
Showing 5 changed files with 199 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
Expand All @@ -42,6 +43,9 @@

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;

/**
* An implementation of a {@link HandlerMethodArgumentResolver} that is capable
Expand Down Expand Up @@ -69,6 +73,10 @@ public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMeth
private OAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> clientCredentialsTokenResponseClient =
new DefaultClientCredentialsTokenResponseClient();

private Clock clock = Clock.systemUTC();
private Duration accessTokenExpiresSkew = Duration.ofMinutes(1);


/**
* Constructs an {@code OAuth2AuthorizedClientArgumentResolver} using the provided parameters.
*
Expand Down Expand Up @@ -105,18 +113,29 @@ public Object resolveArgument(MethodParameter parameter,
"@RegisteredOAuth2AuthorizedClient(registrationId = \"client1\").");
}

ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId);
if (clientRegistration == null) {
return null;
}

Authentication principal = SecurityContextHolder.getContext().getAuthentication();
HttpServletRequest servletRequest = webRequest.getNativeRequest(HttpServletRequest.class);

OAuth2AuthorizedClient authorizedClient = this.authorizedClientRepository.loadAuthorizedClient(
clientRegistrationId, principal, servletRequest);
if (authorizedClient != null) {
return authorizedClient;
}
if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(clientRegistration.getAuthorizationGrantType())) {
// MH TODO: Refresh token
}

ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId);
if (clientRegistration == null) {
return null;
if (AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType())) {
if (hasTokenExpired(authorizedClient)) {
HttpServletResponse servletResponse = webRequest.getNativeResponse(HttpServletResponse.class);
authorizedClient = this.authorizeClientCredentialsClient(clientRegistration, servletRequest, servletResponse);
}
}

return authorizedClient;
}

if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(clientRegistration.getAuthorizationGrantType())) {
Expand Down Expand Up @@ -172,6 +191,24 @@ private OAuth2AuthorizedClient authorizeClientCredentialsClient(ClientRegistrati
return authorizedClient;
}

private boolean shouldRefreshToken(OAuth2AuthorizedClient authorizedClient) {
if (this.authorizedClientRepository == null) {
return false;
}
OAuth2RefreshToken refreshToken = authorizedClient.getRefreshToken();
if (refreshToken == null) {
return false;
}
return hasTokenExpired(authorizedClient);
}

private boolean hasTokenExpired(OAuth2AuthorizedClient authorizedClient) {
Instant now = this.clock.instant();
Instant expiresAt = authorizedClient.getAccessToken().getExpiresAt();

return now.isAfter(expiresAt.minus(this.accessTokenExpiresSkew));
}

/**
* Sets the client used when requesting an access token credential at the Token Endpoint for the {@code client_credentials} grant.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ public OAuth2AuthorizedClientResolver(
* If true, a default {@link OAuth2AuthorizedClient} can be discovered from the current Authentication. It is
* recommended to be cautious with this feature since all HTTP requests will receive the access token if it can be
* resolved from the current Authentication.
*
* @param defaultOAuth2AuthorizedClient true if a default {@link OAuth2AuthorizedClient} should be used, else false.
* Default is false.
*/
Expand All @@ -80,6 +81,7 @@ public void setDefaultOAuth2AuthorizedClient(boolean defaultOAuth2AuthorizedClie
/**
* If set, will be used as the default {@link ClientRegistration#getRegistrationId()}. It is
* recommended to be cautious with this feature since all HTTP requests will receive the access token.
*
* @param clientRegistrationId the id to use
*/
public void setDefaultClientRegistrationId(String clientRegistrationId) {
Expand All @@ -89,6 +91,7 @@ public void setDefaultClientRegistrationId(String clientRegistrationId) {
/**
* Sets the {@link ReactiveOAuth2AccessTokenResponseClient} to be used for getting an {@link OAuth2AuthorizedClient} for
* client_credentials grant.
*
* @param clientCredentialsTokenResponseClient the client to use
*/
public void setClientCredentialsTokenResponseClient(
Expand All @@ -98,7 +101,7 @@ public void setClientCredentialsTokenResponseClient(
}

Mono<Request> createDefaultedRequest(String clientRegistrationId,
Authentication authentication, ServerWebExchange exchange) {
Authentication authentication, ServerWebExchange exchange) {
Mono<Authentication> defaultedAuthentication = Mono.justOrEmpty(authentication)
.switchIfEmpty(currentAuthentication());

Expand All @@ -124,14 +127,14 @@ Mono<OAuth2AuthorizedClient> loadAuthorizedClient(Request request) {

private Mono<OAuth2AuthorizedClient> authorizedClientNotLoaded(String clientRegistrationId, Authentication authentication, ServerWebExchange exchange) {
return this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId)
.switchIfEmpty(Mono.error(() -> new IllegalArgumentException("Client Registration with id " + clientRegistrationId + " was not found")))
.flatMap(clientRegistration -> {
if (AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType())) {
return clientCredentials(clientRegistration, authentication, exchange);
}
return Mono.error(() -> new ClientAuthorizationRequiredException(clientRegistrationId));
});
}
.switchIfEmpty(Mono.error(() -> new IllegalArgumentException("Client Registration with id " + clientRegistrationId + " was not found")))
.flatMap(clientRegistration -> {
if (AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType())) {
return clientCredentials(clientRegistration, authentication, exchange);
}
return Mono.error(() -> new ClientAuthorizationRequiredException(clientRegistrationId));
});
}

Mono<OAuth2AuthorizedClient> clientCredentials(
ClientRegistration clientRegistration, Authentication authentication, ServerWebExchange exchange) {
Expand All @@ -149,6 +152,7 @@ private Mono<OAuth2AuthorizedClient> clientCredentialsResponse(ClientRegistratio

/**
* Attempts to load the client registration id from the current {@link Authentication}
*
* @return
*/
private Mono<String> clientRegistrationId(Mono<Authentication> authentication) {
Expand Down Expand Up @@ -176,7 +180,7 @@ static class Request {
private final ServerWebExchange exchange;

public Request(String clientRegistrationId, Authentication authentication,
ServerWebExchange exchange) {
ServerWebExchange exchange) {
this.clientRegistrationId = clientRegistrationId;
this.authentication = authentication;
this.exchange = exchange;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,15 @@
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.util.Assert;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono;

import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.Optional;

/**
Expand All @@ -57,6 +61,10 @@ class OAuth2AuthorizedClientResolver {

private String defaultClientRegistrationId;

private Clock clock = Clock.systemUTC();
private Duration accessTokenExpiresSkew = Duration.ofMinutes(1);


public OAuth2AuthorizedClientResolver(
ReactiveClientRegistrationRepository clientRegistrationRepository,
ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
Expand All @@ -70,6 +78,7 @@ public OAuth2AuthorizedClientResolver(
* If true, a default {@link OAuth2AuthorizedClient} can be discovered from the current Authentication. It is
* recommended to be cautious with this feature since all HTTP requests will receive the access token if it can be
* resolved from the current Authentication.
*
* @param defaultOAuth2AuthorizedClient true if a default {@link OAuth2AuthorizedClient} should be used, else false.
* Default is false.
*/
Expand All @@ -80,6 +89,7 @@ public void setDefaultOAuth2AuthorizedClient(boolean defaultOAuth2AuthorizedClie
/**
* If set, will be used as the default {@link ClientRegistration#getRegistrationId()}. It is
* recommended to be cautious with this feature since all HTTP requests will receive the access token.
*
* @param clientRegistrationId the id to use
*/
public void setDefaultClientRegistrationId(String clientRegistrationId) {
Expand All @@ -89,6 +99,7 @@ public void setDefaultClientRegistrationId(String clientRegistrationId) {
/**
* Sets the {@link ReactiveOAuth2AccessTokenResponseClient} to be used for getting an {@link OAuth2AuthorizedClient} for
* client_credentials grant.
*
* @param clientCredentialsTokenResponseClient the client to use
*/
public void setClientCredentialsTokenResponseClient(
Expand All @@ -98,7 +109,7 @@ public void setClientCredentialsTokenResponseClient(
}

Mono<Request> createDefaultedRequest(String clientRegistrationId,
Authentication authentication, ServerWebExchange exchange) {
Authentication authentication, ServerWebExchange exchange) {
Mono<Authentication> defaultedAuthentication = Mono.justOrEmpty(authentication)
.switchIfEmpty(currentAuthentication());

Expand All @@ -120,19 +131,27 @@ Mono<OAuth2AuthorizedClient> loadAuthorizedClient(Request request) {
Authentication authentication = request.getAuthentication();
ServerWebExchange exchange = request.getExchange();
return this.authorizedClientRepository.loadAuthorizedClient(clientRegistrationId, authentication, exchange)
.switchIfEmpty(authorizedClientNotLoaded(clientRegistrationId, authentication, exchange));
.switchIfEmpty(authorizedClientNotLoaded(clientRegistrationId, authentication, exchange))
.flatMap(client -> {
if (hasTokenExpired(client)) {
return authorizedClientNotLoaded(clientRegistrationId, authentication, exchange);
} else {
return Mono.just(client);
}
});

}

private Mono<OAuth2AuthorizedClient> authorizedClientNotLoaded(String clientRegistrationId, Authentication authentication, ServerWebExchange exchange) {
return this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId)
.switchIfEmpty(Mono.error(() -> new IllegalArgumentException("Client Registration with id " + clientRegistrationId + " was not found")))
.flatMap(clientRegistration -> {
if (AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType())) {
return clientCredentials(clientRegistration, authentication, exchange);
}
return Mono.error(() -> new ClientAuthorizationRequiredException(clientRegistrationId));
});
}
.switchIfEmpty(Mono.error(() -> new IllegalArgumentException("Client Registration with id " + clientRegistrationId + " was not found")))
.flatMap(clientRegistration -> {
if (AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType())) {
return clientCredentials(clientRegistration, authentication, exchange);
}
return Mono.error(() -> new ClientAuthorizationRequiredException(clientRegistrationId));
});
}

private Mono<? extends OAuth2AuthorizedClient> clientCredentials(
ClientRegistration clientRegistration, Authentication authentication, ServerWebExchange exchange) {
Expand All @@ -148,8 +167,31 @@ private Mono<OAuth2AuthorizedClient> clientCredentialsResponse(ClientRegistratio
.thenReturn(authorizedClient);
}

private boolean shouldRefreshToken(OAuth2AuthorizedClient authorizedClient) {
if (this.authorizedClientRepository == null) {
return false;
}
OAuth2RefreshToken refreshToken = authorizedClient.getRefreshToken();
if (refreshToken == null) {
return false;
}
return hasTokenExpired(authorizedClient);
}

private boolean hasTokenExpired(OAuth2AuthorizedClient authorizedClient) {
Instant now = this.clock.instant();
if (authorizedClient.getAccessToken() == null) {
return false; // Test scenario: authorizedClient has no accessToken
} else {
Instant expiresAt = authorizedClient.getAccessToken().getExpiresAt();

return now.isAfter(expiresAt.minus(this.accessTokenExpiresSkew));
}
}

/**
* Attempts to load the client registration id from the current {@link Authentication}
*
* @return
*/
private Mono<String> clientRegistrationId(Mono<Authentication> authentication) {
Expand Down Expand Up @@ -177,7 +219,7 @@ static class Request {
private final ServerWebExchange exchange;

public Request(String clientRegistrationId, Authentication authentication,
ServerWebExchange exchange) {
ServerWebExchange exchange) {
this.clientRegistrationId = clientRegistrationId;
this.authentication = authentication;
this.exchange = exchange;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import org.springframework.core.MethodParameter;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.security.authentication.TestingAuthenticationToken;
Expand All @@ -42,7 +44,9 @@
import org.springframework.web.context.request.ServletWebRequest;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.lang.reflect.Method;
import java.time.Instant;

import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat;
Expand Down Expand Up @@ -104,7 +108,8 @@ public void setup() {
when(this.authorizedClientRepository.loadAuthorizedClient(
eq(this.registration1.getRegistrationId()), any(Authentication.class), any(HttpServletRequest.class)))
.thenReturn(this.authorizedClient1);
this.authorizedClient2 = new OAuth2AuthorizedClient(this.registration2, this.principalName, mock(OAuth2AccessToken.class));
this.authorizedClient2 = new OAuth2AuthorizedClient(this.registration2, this.principalName, mock(OAuth2AccessToken.class, withSettings()
.name("expiresAt").defaultAnswer((Answer<Instant>) invocation -> Instant.now())));
when(this.authorizedClientRepository.loadAuthorizedClient(
eq(this.registration2.getRegistrationId()), any(Authentication.class), any(HttpServletRequest.class)))
.thenReturn(this.authorizedClient2);
Expand Down Expand Up @@ -230,6 +235,32 @@ public void resolveArgumentWhenAuthorizedClientNotFoundForClientCredentialsClien
eq(authorizedClient), eq(this.authentication), any(HttpServletRequest.class), eq(null));
}

@Test
public void resolveArgumentClientCredentialsExpireReacquireToken() throws Exception {
OAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> clientCredentialsTokenResponseClient =
mock(OAuth2AccessTokenResponseClient.class);
this.argumentResolver.setClientCredentialsTokenResponseClient(clientCredentialsTokenResponseClient);
OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse
.withToken("access-token-1234")
.tokenType(OAuth2AccessToken.TokenType.BEARER)
.expiresIn(3600)
.build();
when(clientCredentialsTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse);

MethodParameter methodParameter = this.getMethodParameter("clientCredentialsClient", OAuth2AuthorizedClient.class);

OAuth2AuthorizedClient authorizedClient = (OAuth2AuthorizedClient) this.argumentResolver.resolveArgument(
methodParameter, null, new ServletWebRequest(this.request), null);

assertThat(authorizedClient).isNotNull();
assertThat(authorizedClient.getClientRegistration()).isSameAs(this.registration2);
assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principalName);
assertThat(authorizedClient.getAccessToken()).isSameAs(accessTokenResponse.getAccessToken());

verify(this.authorizedClientRepository).saveAuthorizedClient(
eq(authorizedClient), eq(this.authentication), any(HttpServletRequest.class), eq(null));
}

private MethodParameter getMethodParameter(String methodName, Class<?>... paramTypes) {
Method method = ReflectionUtils.findMethod(TestController.class, methodName, paramTypes);
return new MethodParameter(method, 0);
Expand Down
Loading

0 comments on commit 5a94d18

Please sign in to comment.