Skip to content

Commit

Permalink
Fixes spring-projectsgh-5893 when expired retrieve new Client Credent…
Browse files Browse the repository at this point in the history
…ials token.

Once client credentials access token has expired retrieve a new token from the OAuth2 authorization server.
These tokens can't be refreshed because they do not have a refresh token associated with. This is standard behaviour for Oauth 2 client credentails
  • Loading branch information
Warren Bailey authored and warrenbailey committed Dec 22, 2018
1 parent b5455b0 commit 321bb75
Show file tree
Hide file tree
Showing 5 changed files with 204 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ private Mono<OAuth2AuthorizedClient> authorizedClientNotLoaded(String clientRegi
});
}

private Mono<? extends OAuth2AuthorizedClient> clientCredentials(
Mono<OAuth2AuthorizedClient> clientCredentials(
ClientRegistration clientRegistration, Authentication authentication, ServerWebExchange exchange) {
OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration);
return this.clientCredentialsTokenResponseClient.getTokenResponse(grantRequest)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,12 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
private final OAuth2AuthorizedClientResolver authorizedClientResolver;

public ServerOAuth2AuthorizedClientExchangeFilterFunction(ReactiveClientRegistrationRepository clientRegistrationRepository, ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
this(authorizedClientRepository, new OAuth2AuthorizedClientResolver(clientRegistrationRepository, authorizedClientRepository));
}

ServerOAuth2AuthorizedClientExchangeFilterFunction(ServerOAuth2AuthorizedClientRepository authorizedClientRepository, OAuth2AuthorizedClientResolver authorizedClientResolver) {
this.authorizedClientRepository = authorizedClientRepository;
this.authorizedClientResolver = new OAuth2AuthorizedClientResolver(clientRegistrationRepository, authorizedClientRepository);
this.authorizedClientResolver = authorizedClientResolver;
}

/**
Expand Down Expand Up @@ -245,13 +249,30 @@ private Mono<OAuth2AuthorizedClientResolver.Request> createRequest(ClientRequest
}

private Mono<OAuth2AuthorizedClient> refreshIfNecessary(ClientRequest request, ExchangeFunction next, OAuth2AuthorizedClient authorizedClient) {
if (shouldRefresh(authorizedClient)) {
ClientRegistration clientRegistration = authorizedClient.getClientRegistration();
if (isClientCredentialsGrantType(clientRegistration) && hasTokenExpired(authorizedClient)) {
return createRequest(request)
.flatMap(r -> authorizeWithClientCredentials(clientRegistration, r));
} else if (shouldRefresh(authorizedClient)) {
return createRequest(request)
.flatMap(r -> refreshAuthorizedClient(next, authorizedClient, r));
}
return Mono.just(authorizedClient);
}

private boolean isClientCredentialsGrantType(ClientRegistration clientRegistration) {
return AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType());
}

private Mono<OAuth2AuthorizedClient> authorizeWithClientCredentials(ClientRegistration clientRegistration, OAuth2AuthorizedClientResolver.Request request) {
Authentication authentication = request.getAuthentication();
ServerWebExchange exchange = request.getExchange();

return this.authorizedClientResolver.clientCredentials(clientRegistration, authentication, exchange).
flatMap(result -> this.authorizedClientRepository.saveAuthorizedClient(result, authentication, exchange)
.thenReturn(result));
}

private Mono<OAuth2AuthorizedClient> refreshAuthorizedClient(ExchangeFunction next,
OAuth2AuthorizedClient authorizedClient, OAuth2AuthorizedClientResolver.Request r) {
ServerWebExchange exchange = r.getExchange();
Expand Down Expand Up @@ -280,6 +301,10 @@ private boolean shouldRefresh(OAuth2AuthorizedClient authorizedClient) {
if (refreshToken == null) {
return false;
}
return hasTokenExpired(authorizedClient);
}

private boolean hasTokenExpired(OAuth2AuthorizedClient authorizedClient) {
Instant now = this.clock.instant();
Instant expiresAt = authorizedClient.getAccessToken().getExpiresAt();
if (now.isAfter(expiresAt.minus(this.accessTokenExpiresSkew))) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -332,12 +332,16 @@ private OAuth2AuthorizedClient getAuthorizedClient(String clientRegistrationId,
if (clientRegistration == null) {
throw new IllegalArgumentException("Could not find ClientRegistration with id " + clientRegistrationId);
}
if (AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType())) {
if (isClientCredentialsGrantType(clientRegistration)) {
return getAuthorizedClient(clientRegistration, attrs);
}
throw new ClientAuthorizationRequiredException(clientRegistrationId);
}

private boolean isClientCredentialsGrantType(ClientRegistration clientRegistration) {
return AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType());
}


private OAuth2AuthorizedClient getAuthorizedClient(ClientRegistration clientRegistration,
Map<String, Object> attrs) {
Expand Down Expand Up @@ -366,7 +370,11 @@ private OAuth2AuthorizedClient getAuthorizedClient(ClientRegistration clientRegi
}

private Mono<OAuth2AuthorizedClient> authorizedClient(ClientRequest request, ExchangeFunction next, OAuth2AuthorizedClient authorizedClient) {
if (shouldRefresh(authorizedClient)) {
ClientRegistration clientRegistration = authorizedClient.getClientRegistration();
if (isClientCredentialsGrantType(clientRegistration) && hasTokenExpired(authorizedClient)) {
//Client credentials grant do not have refresh tokens but can expire so we need to get another one
return Mono.fromSupplier(() -> getAuthorizedClient(clientRegistration, request.attributes()));
} else if (shouldRefresh(authorizedClient)) {
return refreshAuthorizedClient(request, next, authorizedClient);
}
return Mono.just(authorizedClient);
Expand Down Expand Up @@ -407,6 +415,10 @@ private boolean shouldRefresh(OAuth2AuthorizedClient authorizedClient) {
if (refreshToken == null) {
return false;
}
return hasTokenExpired(authorizedClient);
}

private boolean hasTokenExpired(OAuth2AuthorizedClient authorizedClient) {
Instant now = this.clock.instant();
Instant expiresAt = authorizedClient.getAccessToken().getExpiresAt();
if (now.isAfter(expiresAt.minus(this.accessTokenExpiresSkew))) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.client.web.reactive.function.client.OAuth2AuthorizedClientResolver.Request;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
Expand All @@ -67,6 +68,7 @@
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyZeroInteractions;
import static org.mockito.Mockito.when;
Expand All @@ -86,6 +88,9 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
@Mock
private ReactiveClientRegistrationRepository clientRegistrationRepository;

@Mock
private OAuth2AuthorizedClientResolver oAuth2AuthorizedClientResolver;

@Mock
private ServerWebExchange serverWebExchange;

Expand Down Expand Up @@ -144,6 +149,88 @@ public void filterWhenExistingAuthorizationThenSingleAuthorizationHeader() {
assertThat(headers.get(HttpHeaders.AUTHORIZATION)).containsOnly("Bearer " + this.accessToken.getTokenValue());
}

@Test
public void filterWhenClientCredentialsTokenExpiredThenGetNewToken() {
TestingAuthenticationToken authentication = new TestingAuthenticationToken("test", "this");
ClientRegistration registration = TestClientRegistrations.clientCredentials().build();
String clientRegistrationId = registration.getClientId();

this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository, this.oAuth2AuthorizedClientResolver);

OAuth2AccessToken newAccessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
"new-token",
Instant.now(),
Instant.now().plus(Duration.ofDays(1)));
OAuth2AuthorizedClient newAuthorizedClient = new OAuth2AuthorizedClient(registration,
"principalName", newAccessToken, null);
Request r = new Request(clientRegistrationId, authentication, null);
when(this.oAuth2AuthorizedClientResolver.clientCredentials(any(), any(), any())).thenReturn(Mono.just(newAuthorizedClient));
when(this.oAuth2AuthorizedClientResolver.createDefaultedRequest(any(), any(), any())).thenReturn(Mono.just(r));

when(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty());

Instant issuedAt = Instant.now().minus(Duration.ofDays(1));
Instant accessTokenExpiresAt = issuedAt.plus(Duration.ofHours(1));

OAuth2AccessToken accessToken = new OAuth2AccessToken(this.accessToken.getTokenType(),
this.accessToken.getTokenValue(),
issuedAt,
accessTokenExpiresAt);


OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(registration,
"principalName", accessToken, null);
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
.attributes(oauth2AuthorizedClient(authorizedClient))
.build();


this.function.filter(request, this.exchange)
.subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication))
.block();

verify(this.authorizedClientRepository).saveAuthorizedClient(any(), eq(authentication), any());
verify(this.oAuth2AuthorizedClientResolver).clientCredentials(any(), any(), any());
verify(this.oAuth2AuthorizedClientResolver).createDefaultedRequest(any(), any(), any());

List<ClientRequest> requests = this.exchange.getRequests();
assertThat(requests).hasSize(1);
ClientRequest request1 = requests.get(0);
assertThat(request1.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer new-token");
assertThat(request1.url().toASCIIString()).isEqualTo("https://example.com");
assertThat(request1.method()).isEqualTo(HttpMethod.GET);
assertThat(getBody(request1)).isEmpty();
}

@Test
public void filterWhenClientCredentialsTokenNotExpiredThenUseCurrentToken() {
TestingAuthenticationToken authentication = new TestingAuthenticationToken("test", "this");
ClientRegistration registration = TestClientRegistrations.clientCredentials().build();

this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository, this.oAuth2AuthorizedClientResolver);

OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(registration,
"principalName", this.accessToken, null);
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
.attributes(oauth2AuthorizedClient(authorizedClient))
.build();

this.function.filter(request, this.exchange)
.subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication))
.block();

verify(this.oAuth2AuthorizedClientResolver, never()).clientCredentials(any(), any(), any());
verify(this.oAuth2AuthorizedClientResolver, never()).createDefaultedRequest(any(), any(), any());

List<ClientRequest> requests = this.exchange.getRequests();
assertThat(requests).hasSize(1);
ClientRequest request1 = requests.get(0);
assertThat(request1.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-0");
assertThat(request1.url().toASCIIString()).isEqualTo("https://example.com");
assertThat(request1.method()).isEqualTo(HttpMethod.GET);
assertThat(getBody(request1)).isEmpty();
}

@Test
public void filterWhenRefreshRequiredThenRefresh() {
when(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyZeroInteractions;
import static org.mockito.Mockito.when;
Expand Down Expand Up @@ -423,6 +424,80 @@ public void filterWhenRefreshRequiredThenRefresh() {
assertThat(getBody(request1)).isEmpty();
}

@Test
public void filterWhenClientCredentialsTokenNotExpiredThenUseCurrentToken() {
this.registration = TestClientRegistrations.clientCredentials().build();

this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository,
this.authorizedClientRepository);
this.function.setClientCredentialsTokenResponseClient(this.clientCredentialsTokenResponseClient);

OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
"principalName", this.accessToken, null);
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
.attributes(oauth2AuthorizedClient(authorizedClient))
.attributes(authentication(this.authentication))
.build();

this.function.filter(request, this.exchange).block();

verify(this.authorizedClientRepository, never()).saveAuthorizedClient(any(), eq(this.authentication), any(), any());

verify(clientCredentialsTokenResponseClient, never()).getTokenResponse(any());

List<ClientRequest> requests = this.exchange.getRequests();
assertThat(requests).hasSize(1);

ClientRequest request1 = requests.get(0);
assertThat(request1.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-0");
assertThat(request1.url().toASCIIString()).isEqualTo("https://example.com");
assertThat(request1.method()).isEqualTo(HttpMethod.GET);
assertThat(getBody(request1)).isEmpty();
}

@Test
public void filterWhenClientCredentialsTokenExpiredThenGetNewToken() {
this.registration = TestClientRegistrations.clientCredentials().build();

OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses
.accessTokenResponse().build();
when(this.clientCredentialsTokenResponseClient.getTokenResponse(any())).thenReturn(
accessTokenResponse);

Instant issuedAt = Instant.now().minus(Duration.ofDays(1));
Instant accessTokenExpiresAt = issuedAt.plus(Duration.ofHours(1));

this.accessToken = new OAuth2AccessToken(this.accessToken.getTokenType(),
this.accessToken.getTokenValue(),
issuedAt,
accessTokenExpiresAt);
this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository,
this.authorizedClientRepository);
this.function.setClientCredentialsTokenResponseClient(this.clientCredentialsTokenResponseClient);

OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
"principalName", this.accessToken, null);
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
.attributes(oauth2AuthorizedClient(authorizedClient))
.attributes(authentication(this.authentication))
.build();

this.function.filter(request, this.exchange).block();

verify(this.authorizedClientRepository).saveAuthorizedClient(any(), eq(this.authentication), any(), any());

verify(clientCredentialsTokenResponseClient).getTokenResponse(any());

List<ClientRequest> requests = this.exchange.getRequests();
assertThat(requests).hasSize(1);

ClientRequest request1 = requests.get(0);
assertThat(request1.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token");
assertThat(request1.url().toASCIIString()).isEqualTo("https://example.com");
assertThat(request1.method()).isEqualTo(HttpMethod.GET);
assertThat(getBody(request1)).isEmpty();
}

@Test
public void filterWhenRefreshRequiredAndEmptyReactiveSecurityContextThenSaved() {
OAuth2AccessTokenResponse response = OAuth2AccessTokenResponse.withToken("token-1")
Expand Down

0 comments on commit 321bb75

Please sign in to comment.