Skip to content

Commit

Permalink
Removed clock/expiration checking per Joe for spring-projects#6609
Browse files Browse the repository at this point in the history
  • Loading branch information
mkheck committed May 7, 2019
1 parent ff0c229 commit a8e2b6b
Show file tree
Hide file tree
Showing 4 changed files with 2 additions and 166 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
Expand All @@ -43,9 +42,6 @@

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;

/**
* An implementation of a {@link HandlerMethodArgumentResolver} that is capable
Expand Down Expand Up @@ -73,10 +69,6 @@ public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMeth
private OAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> clientCredentialsTokenResponseClient =
new DefaultClientCredentialsTokenResponseClient();

private Clock clock = Clock.systemUTC();
private Duration accessTokenExpiresSkew = Duration.ofMinutes(1);


/**
* Constructs an {@code OAuth2AuthorizedClientArgumentResolver} using the provided parameters.
*
Expand Down Expand Up @@ -124,17 +116,6 @@ public Object resolveArgument(MethodParameter parameter,
OAuth2AuthorizedClient authorizedClient = this.authorizedClientRepository.loadAuthorizedClient(
clientRegistrationId, principal, servletRequest);
if (authorizedClient != null) {
if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(clientRegistration.getAuthorizationGrantType())) {
// MH TODO: Refresh token
}

if (AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType())) {
if (hasTokenExpired(authorizedClient)) {
HttpServletResponse servletResponse = webRequest.getNativeResponse(HttpServletResponse.class);
authorizedClient = this.authorizeClientCredentialsClient(clientRegistration, servletRequest, servletResponse);
}
}

return authorizedClient;
}

Expand Down Expand Up @@ -191,24 +172,6 @@ private OAuth2AuthorizedClient authorizeClientCredentialsClient(ClientRegistrati
return authorizedClient;
}

private boolean shouldRefreshToken(OAuth2AuthorizedClient authorizedClient) {
if (this.authorizedClientRepository == null) {
return false;
}
OAuth2RefreshToken refreshToken = authorizedClient.getRefreshToken();
if (refreshToken == null) {
return false;
}
return hasTokenExpired(authorizedClient);
}

private boolean hasTokenExpired(OAuth2AuthorizedClient authorizedClient) {
Instant now = this.clock.instant();
Instant expiresAt = authorizedClient.getAccessToken().getExpiresAt();

return now.isAfter(expiresAt.minus(this.accessTokenExpiresSkew));
}

/**
* Sets the client used when requesting an access token credential at the Token Endpoint for the {@code client_credentials} grant.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,11 @@
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.util.Assert;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono;

import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.Optional;

/**
Expand All @@ -61,10 +57,6 @@ class OAuth2AuthorizedClientResolver {

private String defaultClientRegistrationId;

private Clock clock = Clock.systemUTC();
private Duration accessTokenExpiresSkew = Duration.ofMinutes(1);


public OAuth2AuthorizedClientResolver(
ReactiveClientRegistrationRepository clientRegistrationRepository,
ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
Expand All @@ -78,7 +70,6 @@ public OAuth2AuthorizedClientResolver(
* If true, a default {@link OAuth2AuthorizedClient} can be discovered from the current Authentication. It is
* recommended to be cautious with this feature since all HTTP requests will receive the access token if it can be
* resolved from the current Authentication.
*
* @param defaultOAuth2AuthorizedClient true if a default {@link OAuth2AuthorizedClient} should be used, else false.
* Default is false.
*/
Expand All @@ -89,7 +80,6 @@ public void setDefaultOAuth2AuthorizedClient(boolean defaultOAuth2AuthorizedClie
/**
* If set, will be used as the default {@link ClientRegistration#getRegistrationId()}. It is
* recommended to be cautious with this feature since all HTTP requests will receive the access token.
*
* @param clientRegistrationId the id to use
*/
public void setDefaultClientRegistrationId(String clientRegistrationId) {
Expand Down Expand Up @@ -131,15 +121,7 @@ Mono<OAuth2AuthorizedClient> loadAuthorizedClient(Request request) {
Authentication authentication = request.getAuthentication();
ServerWebExchange exchange = request.getExchange();
return this.authorizedClientRepository.loadAuthorizedClient(clientRegistrationId, authentication, exchange)
.switchIfEmpty(authorizedClientNotLoaded(clientRegistrationId, authentication, exchange))
.flatMap(client -> {
if (hasTokenExpired(client)) {
return authorizedClientNotLoaded(clientRegistrationId, authentication, exchange);
} else {
return Mono.just(client);
}
});

.switchIfEmpty(authorizedClientNotLoaded(clientRegistrationId, authentication, exchange));
}

private Mono<OAuth2AuthorizedClient> authorizedClientNotLoaded(String clientRegistrationId, Authentication authentication, ServerWebExchange exchange) {
Expand Down Expand Up @@ -167,28 +149,6 @@ private Mono<OAuth2AuthorizedClient> clientCredentialsResponse(ClientRegistratio
.thenReturn(authorizedClient);
}

private boolean shouldRefreshToken(OAuth2AuthorizedClient authorizedClient) {
if (this.authorizedClientRepository == null) {
return false;
}
OAuth2RefreshToken refreshToken = authorizedClient.getRefreshToken();
if (refreshToken == null) {
return false;
}
return hasTokenExpired(authorizedClient);
}

private boolean hasTokenExpired(OAuth2AuthorizedClient authorizedClient) {
Instant now = this.clock.instant();
if (authorizedClient.getAccessToken() == null) {
return false; // Test scenario: authorizedClient has no accessToken
} else {
Instant expiresAt = authorizedClient.getAccessToken().getExpiresAt();

return now.isAfter(expiresAt.minus(this.accessTokenExpiresSkew));
}
}

/**
* Attempts to load the client registration id from the current {@link Authentication}
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.mockito.stubbing.Answer;
import org.springframework.core.MethodParameter;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.security.authentication.TestingAuthenticationToken;
Expand All @@ -44,7 +43,6 @@

import javax.servlet.http.HttpServletRequest;
import java.lang.reflect.Method;
import java.time.Instant;

import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat;
Expand Down Expand Up @@ -104,8 +102,7 @@ public void setup() {
when(this.authorizedClientRepository.loadAuthorizedClient(
eq(this.registration1.getRegistrationId()), any(Authentication.class), any(HttpServletRequest.class)))
.thenReturn(this.authorizedClient1);
this.authorizedClient2 = new OAuth2AuthorizedClient(this.registration2, this.principalName, mock(OAuth2AccessToken.class, withSettings()
.name("expiresAt").defaultAnswer((Answer<Instant>) invocation -> Instant.now())));
this.authorizedClient2 = new OAuth2AuthorizedClient(this.registration2, this.principalName, mock(OAuth2AccessToken.class));
when(this.authorizedClientRepository.loadAuthorizedClient(
eq(this.registration2.getRegistrationId()), any(Authentication.class), any(HttpServletRequest.class)))
.thenReturn(this.authorizedClient2);
Expand Down Expand Up @@ -231,32 +228,6 @@ public void resolveArgumentWhenAuthorizedClientNotFoundForClientCredentialsClien
eq(authorizedClient), eq(this.authentication), any(HttpServletRequest.class), eq(null));
}

@Test
public void resolveArgumentClientCredentialsExpireReacquireToken() throws Exception {
OAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> clientCredentialsTokenResponseClient =
mock(OAuth2AccessTokenResponseClient.class);
this.argumentResolver.setClientCredentialsTokenResponseClient(clientCredentialsTokenResponseClient);
OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse
.withToken("access-token-1234")
.tokenType(OAuth2AccessToken.TokenType.BEARER)
.expiresIn(3600)
.build();
when(clientCredentialsTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse);

MethodParameter methodParameter = this.getMethodParameter("clientCredentialsClient", OAuth2AuthorizedClient.class);

OAuth2AuthorizedClient authorizedClient = (OAuth2AuthorizedClient) this.argumentResolver.resolveArgument(
methodParameter, null, new ServletWebRequest(this.request), null);

assertThat(authorizedClient).isNotNull();
assertThat(authorizedClient.getClientRegistration()).isSameAs(this.registration2);
assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principalName);
assertThat(authorizedClient.getAccessToken()).isSameAs(accessTokenResponse.getAccessToken());

verify(this.authorizedClientRepository).saveAuthorizedClient(
eq(authorizedClient), eq(this.authentication), any(HttpServletRequest.class), eq(null));
}

private MethodParameter getMethodParameter(String methodName, Class<?>... paramTypes) {
Method method = ReflectionUtils.findMethod(TestController.class, methodName, paramTypes);
return new MethodParameter(method, 0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,22 +29,13 @@
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest;
import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.util.ReflectionUtils;
import reactor.core.publisher.Mono;
import reactor.util.context.Context;

import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;

import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
Expand Down Expand Up @@ -145,39 +136,6 @@ public void resolveArgumentWhenOAuth2AuthorizedClientNotFoundThenThrowClientAuth
.isInstanceOf(ClientAuthorizationRequiredException.class);
}

@Test
public void resolveArgumentClientCredentialsExpireReacquireToken() { //throws Exception {
ReactiveOAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> clientCredentialsTokenResponseClient =
mock(ReactiveOAuth2AccessTokenResponseClient.class);
setClientCredentialsTokenResponseClient(clientCredentialsTokenResponseClient);

OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse
.withToken("access-token-1234")
.tokenType(OAuth2AccessToken.TokenType.BEARER)
.expiresIn(0)
.build();

ClientRegistration registration = ClientRegistration.withRegistrationId("client2")
.clientId("client-2")
.clientSecret("secret")
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS)
.scope("read", "write")
.tokenUri("https://provider.com/oauth2/token")
.build();
when(clientCredentialsTokenResponseClient.getTokenResponse(any())).thenReturn(Mono.just(accessTokenResponse));

OAuth2AuthorizedClient authorizedClient2 = new OAuth2AuthorizedClient(registration, authentication.getPrincipal().toString(), accessTokenResponse.getAccessToken());
when(this.authorizedClientRepository.loadAuthorizedClient(anyString(), any(Authentication.class), any())).thenReturn(Mono.just(authorizedClient2));
when(this.authorizedClientRepository.saveAuthorizedClient(any(OAuth2AuthorizedClient.class), any(Authentication.class), any())).thenReturn(Mono.empty());
when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(Mono.just(registration));

MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient2", OAuth2AuthorizedClient.class);
OAuth2AuthorizedClient resolvedClient = (OAuth2AuthorizedClient) resolveArgument(methodParameter);
assertThat(resolvedClient).isNotSameAs(authorizedClient2);
assertThat(resolvedClient).isEqualToComparingFieldByField(authorizedClient2);
}

private Object resolveArgument(MethodParameter methodParameter) {
return this.argumentResolver.resolveArgument(methodParameter, null, null)
.subscriberContext(this.authentication == null ? Context.empty() : ReactiveSecurityContextHolder.withAuthentication(this.authentication))
Expand All @@ -190,26 +148,10 @@ private MethodParameter getMethodParameter(String methodName, Class<?>... paramT
return new MethodParameter(method, 0);
}

private void setClientCredentialsTokenResponseClient(ReactiveOAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> clientCredentialsTokenResponseClient) {
try {
Field clientResolverField = OAuth2AuthorizedClientArgumentResolver.class.getDeclaredField("authorizedClientResolver");
clientResolverField.setAccessible(true);
OAuth2AuthorizedClientResolver clientResolver = (OAuth2AuthorizedClientResolver) clientResolverField.get(this.argumentResolver);

Method setClientCredsTokenRespClientMethod = OAuth2AuthorizedClientResolver.class.getMethod("setClientCredentialsTokenResponseClient", ReactiveOAuth2AccessTokenResponseClient.class);
setClientCredsTokenRespClientMethod.invoke(clientResolver, clientCredentialsTokenResponseClient);
} catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException | NoSuchFieldException e) {
e.printStackTrace();
}
}

static class TestController {
void paramTypeAuthorizedClient(@RegisteredOAuth2AuthorizedClient("client1") OAuth2AuthorizedClient authorizedClient) {
}

void paramTypeAuthorizedClient2(@RegisteredOAuth2AuthorizedClient("client2") OAuth2AuthorizedClient authorizedClient) {
}

void paramTypeAuthorizedClientWithoutAnnotation(OAuth2AuthorizedClient authorizedClient) {
}

Expand Down

0 comments on commit a8e2b6b

Please sign in to comment.