diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/TokenExchangeOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/TokenExchangeOAuth2AuthorizedClientProvider.java new file mode 100644 index 00000000000..256ced675ab --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/TokenExchangeOAuth2AuthorizedClientProvider.java @@ -0,0 +1,176 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client; + +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.util.function.Function; + +import org.springframework.lang.Nullable; +import org.springframework.security.oauth2.client.endpoint.DefaultTokenExchangeTokenResponseClient; +import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; +import org.springframework.security.oauth2.client.endpoint.TokenExchangeGrantRequest; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.OAuth2AuthorizationException; +import org.springframework.security.oauth2.core.OAuth2Token; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.util.Assert; + +/** + * An implementation of an {@link OAuth2AuthorizedClientProvider} for the + * {@link AuthorizationGrantType#TOKEN_EXCHANGE token-exchange} grant. + * + * @author Steve Riesenberg + * @since 6.3 + * @see OAuth2AuthorizedClientProvider + * @see DefaultTokenExchangeTokenResponseClient + */ +public final class TokenExchangeOAuth2AuthorizedClientProvider implements OAuth2AuthorizedClientProvider { + + private OAuth2AccessTokenResponseClient accessTokenResponseClient = new DefaultTokenExchangeTokenResponseClient(); + + private Function subjectTokenResolver = this::resolveSubjectToken; + + private Function actorTokenResolver = (context) -> null; + + private Duration clockSkew = Duration.ofSeconds(60); + + private Clock clock = Clock.systemUTC(); + + /** + * Attempt to authorize (or re-authorize) the + * {@link OAuth2AuthorizationContext#getClientRegistration() client} in the provided + * {@code context}. Returns {@code null} if authorization (or re-authorization) is not + * supported, e.g. the client's {@link ClientRegistration#getAuthorizationGrantType() + * authorization grant type} is not {@link AuthorizationGrantType#TOKEN_EXCHANGE + * token-exchange} OR the {@link OAuth2AuthorizedClient#getAccessToken() access token} + * is not expired. + * @param context the context that holds authorization-specific state for the client + * @return the {@link OAuth2AuthorizedClient} or {@code null} if authorization is not + * supported + */ + @Override + @Nullable + public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) { + Assert.notNull(context, "context cannot be null"); + ClientRegistration clientRegistration = context.getClientRegistration(); + if (!AuthorizationGrantType.TOKEN_EXCHANGE.equals(clientRegistration.getAuthorizationGrantType())) { + return null; + } + OAuth2AuthorizedClient authorizedClient = context.getAuthorizedClient(); + if (authorizedClient != null && !hasTokenExpired(authorizedClient.getAccessToken())) { + // If client is already authorized but access token is NOT expired than no + // need for re-authorization + return null; + } + OAuth2Token subjectToken = this.subjectTokenResolver.apply(context); + if (subjectToken == null) { + return null; + } + + OAuth2Token actorToken = this.actorTokenResolver.apply(context); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, subjectToken, + actorToken); + OAuth2AccessTokenResponse tokenResponse = getTokenResponse(clientRegistration, grantRequest); + + return new OAuth2AuthorizedClient(clientRegistration, context.getPrincipal().getName(), + tokenResponse.getAccessToken()); + } + + private OAuth2Token resolveSubjectToken(OAuth2AuthorizationContext context) { + if (context.getPrincipal().getPrincipal() instanceof OAuth2Token accessToken) { + return accessToken; + } + return null; + } + + private OAuth2AccessTokenResponse getTokenResponse(ClientRegistration clientRegistration, + TokenExchangeGrantRequest tokenExchangeGrantRequest) { + try { + return this.accessTokenResponseClient.getTokenResponse(tokenExchangeGrantRequest); + } + catch (OAuth2AuthorizationException ex) { + throw new ClientAuthorizationException(ex.getError(), clientRegistration.getRegistrationId(), ex); + } + } + + private boolean hasTokenExpired(OAuth2Token token) { + return this.clock.instant().isAfter(token.getExpiresAt().minus(this.clockSkew)); + } + + /** + * Sets the client used when requesting an access token credential at the Token + * Endpoint for the {@code token-exchange} grant. + * @param accessTokenResponseClient the client used when requesting an access token + * credential at the Token Endpoint for the {@code token-exchange} grant + */ + public void setAccessTokenResponseClient( + OAuth2AccessTokenResponseClient accessTokenResponseClient) { + Assert.notNull(accessTokenResponseClient, "accessTokenResponseClient cannot be null"); + this.accessTokenResponseClient = accessTokenResponseClient; + } + + /** + * Sets the resolver used for resolving the {@link OAuth2Token subject token}. + * @param subjectTokenResolver the resolver used for resolving the {@link OAuth2Token + * subject token} + */ + public void setSubjectTokenResolver(Function subjectTokenResolver) { + Assert.notNull(subjectTokenResolver, "subjectTokenResolver cannot be null"); + this.subjectTokenResolver = subjectTokenResolver; + } + + /** + * Sets the resolver used for resolving the {@link OAuth2Token actor token}. + * @param actorTokenResolver the resolver used for resolving the {@link OAuth2Token + * actor token} + */ + public void setActorTokenResolver(Function actorTokenResolver) { + Assert.notNull(actorTokenResolver, "actorTokenResolver cannot be null"); + this.actorTokenResolver = actorTokenResolver; + } + + /** + * Sets the maximum acceptable clock skew, which is used when checking the + * {@link OAuth2AuthorizedClient#getAccessToken() access token} expiry. The default is + * 60 seconds. + * + *

+ * An access token is considered expired if + * {@code OAuth2AccessToken#getExpiresAt() - clockSkew} is before the current time + * {@code clock#instant()}. + * @param clockSkew the maximum acceptable clock skew + */ + public void setClockSkew(Duration clockSkew) { + Assert.notNull(clockSkew, "clockSkew cannot be null"); + Assert.isTrue(clockSkew.getSeconds() >= 0, "clockSkew must be >= 0"); + this.clockSkew = clockSkew; + } + + /** + * Sets the {@link Clock} used in {@link Instant#now(Clock)} when checking the access + * token expiry. + * @param clock the clock + */ + public void setClock(Clock clock) { + Assert.notNull(clock, "clock cannot be null"); + this.clock = clock; + } + +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultTokenExchangeTokenResponseClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultTokenExchangeTokenResponseClient.java new file mode 100644 index 00000000000..f191b5f7669 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultTokenExchangeTokenResponseClient.java @@ -0,0 +1,126 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.endpoint; + +import java.util.Arrays; + +import org.springframework.core.convert.converter.Converter; +import org.springframework.http.RequestEntity; +import org.springframework.http.ResponseEntity; +import org.springframework.http.converter.FormHttpMessageConverter; +import org.springframework.http.converter.HttpMessageConverter; +import org.springframework.security.oauth2.client.http.OAuth2ErrorResponseErrorHandler; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.OAuth2AuthorizationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter; +import org.springframework.util.Assert; +import org.springframework.web.client.ResponseErrorHandler; +import org.springframework.web.client.RestClientException; +import org.springframework.web.client.RestOperations; +import org.springframework.web.client.RestTemplate; + +/** + * The default implementation of an {@link OAuth2AccessTokenResponseClient} for the + * {@link AuthorizationGrantType#TOKEN_EXCHANGE token-exchange} grant. This implementation + * uses a {@link RestOperations} when requesting an access token credential at the + * Authorization Server's Token Endpoint. + * + * @author Steve Riesenberg + * @since 6.3 + * @see OAuth2AccessTokenResponseClient + * @see TokenExchangeGrantRequest + * @see OAuth2AccessTokenResponse + * @see Section + * 2.1 Request + * @see Section + * 2.2 Response + */ +public final class DefaultTokenExchangeTokenResponseClient + implements OAuth2AccessTokenResponseClient { + + private static final String INVALID_TOKEN_RESPONSE_ERROR_CODE = "invalid_token_response"; + + private Converter> requestEntityConverter = new ClientAuthenticationMethodValidatingRequestEntityConverter<>( + new TokenExchangeGrantRequestEntityConverter()); + + private RestOperations restOperations; + + public DefaultTokenExchangeTokenResponseClient() { + RestTemplate restTemplate = new RestTemplate( + Arrays.asList(new FormHttpMessageConverter(), new OAuth2AccessTokenResponseHttpMessageConverter())); + restTemplate.setErrorHandler(new OAuth2ErrorResponseErrorHandler()); + this.restOperations = restTemplate; + } + + @Override + public OAuth2AccessTokenResponse getTokenResponse(TokenExchangeGrantRequest tokenExchangeGrantRequest) { + Assert.notNull(tokenExchangeGrantRequest, "tokenExchangeGrantRequest cannot be null"); + RequestEntity requestEntity = this.requestEntityConverter.convert(tokenExchangeGrantRequest); + ResponseEntity responseEntity = getResponse(requestEntity); + + return responseEntity.getBody(); + } + + private ResponseEntity getResponse(RequestEntity request) { + try { + return this.restOperations.exchange(request, OAuth2AccessTokenResponse.class); + } + catch (RestClientException ex) { + OAuth2Error oauth2Error = new OAuth2Error(INVALID_TOKEN_RESPONSE_ERROR_CODE, + "An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response: " + + ex.getMessage(), + null); + throw new OAuth2AuthorizationException(oauth2Error, ex); + } + } + + /** + * Sets the {@link Converter} used for converting the + * {@link TokenExchangeGrantRequest} to a {@link RequestEntity} representation of the + * OAuth 2.0 Access Token Request. + * @param requestEntityConverter the {@link Converter} used for converting to a + * {@link RequestEntity} representation of the Access Token Request + */ + public void setRequestEntityConverter( + Converter> requestEntityConverter) { + Assert.notNull(requestEntityConverter, "requestEntityConverter cannot be null"); + this.requestEntityConverter = requestEntityConverter; + } + + /** + * Sets the {@link RestOperations} used when requesting the OAuth 2.0 Access Token + * Response. + * + *

+ * NOTE: At a minimum, the supplied {@code restOperations} must be configured + * with the following: + *

    + *
  1. {@link HttpMessageConverter}'s - {@link FormHttpMessageConverter} and + * {@link OAuth2AccessTokenResponseHttpMessageConverter}
  2. + *
  3. {@link ResponseErrorHandler} - {@link OAuth2ErrorResponseErrorHandler}
  4. + *
+ * @param restOperations the {@link RestOperations} used when requesting the Access + * Token Response + */ + public void setRestOperations(RestOperations restOperations) { + Assert.notNull(restOperations, "restOperations cannot be null"); + this.restOperations = restOperations; + } + +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/TokenExchangeGrantRequest.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/TokenExchangeGrantRequest.java new file mode 100644 index 00000000000..ff835973308 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/TokenExchangeGrantRequest.java @@ -0,0 +1,76 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.endpoint; + +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.OAuth2Token; +import org.springframework.util.Assert; + +/** + * A Token Exchange Grant request that holds the {@link OAuth2Token subject token} and + * optional {@link OAuth2Token actor token}. + * + * @author Steve Riesenberg + * @since 6.3 + * @see AbstractOAuth2AuthorizationGrantRequest + * @see ClientRegistration + * @see OAuth2Token + * @see Section + * 1.1 Delegation vs. Impersonation Semantics + * @see Section + * 2.1 Request + * @see Section + * 2.2 Response + */ +public class TokenExchangeGrantRequest extends AbstractOAuth2AuthorizationGrantRequest { + + private final OAuth2Token subjectToken; + + private final OAuth2Token actorToken; + + /** + * Constructs a {@code JwtBearerGrantRequest} using the provided parameters. + * @param clientRegistration the client registration + * @param subjectToken the subject token + * @param actorToken the actor token + */ + public TokenExchangeGrantRequest(ClientRegistration clientRegistration, OAuth2Token subjectToken, + OAuth2Token actorToken) { + super(AuthorizationGrantType.TOKEN_EXCHANGE, clientRegistration); + Assert.notNull(subjectToken, "subjectToken cannot be null"); + this.subjectToken = subjectToken; + this.actorToken = actorToken; + } + + /** + * Returns the {@link OAuth2Token subject token}. + * @return the {@link OAuth2Token subject token} + */ + public OAuth2Token getSubjectToken() { + return this.subjectToken; + } + + /** + * Returns the {@link OAuth2Token actor token}. + * @return the {@link OAuth2Token actor token} + */ + public OAuth2Token getActorToken() { + return this.actorToken; + } + +} \ No newline at end of file diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/TokenExchangeGrantRequestEntityConverter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/TokenExchangeGrantRequestEntityConverter.java new file mode 100644 index 00000000000..c8f72e4adb4 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/TokenExchangeGrantRequestEntityConverter.java @@ -0,0 +1,79 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.endpoint; + +import org.springframework.http.RequestEntity; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.OAuth2Token; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.util.CollectionUtils; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.util.StringUtils; + +/** + * An implementation of an {@link AbstractOAuth2AuthorizationGrantRequestEntityConverter} + * that converts the provided {@link TokenExchangeGrantRequest} to a {@link RequestEntity} + * representation of an OAuth 2.0 Access Token Request for the Token Exchange Grant. + * + * @author Steve Riesenberg + * @since 6.3 + * @see AbstractOAuth2AuthorizationGrantRequestEntityConverter + * @see TokenExchangeGrantRequest + * @see RequestEntity + * @see Section + * 1.1 Delegation vs. Impersonation Semantics + */ +public class TokenExchangeGrantRequestEntityConverter + extends AbstractOAuth2AuthorizationGrantRequestEntityConverter { + + private static final String ACCESS_TOKEN_TYPE_VALUE = "urn:ietf:params:oauth:token-type:access_token"; + + private static final String JWT_TOKEN_TYPE_VALUE = "urn:ietf:params:oauth:token-type:jwt"; + + @Override + protected MultiValueMap createParameters(TokenExchangeGrantRequest grantRequest) { + ClientRegistration clientRegistration = grantRequest.getClientRegistration(); + MultiValueMap parameters = new LinkedMultiValueMap<>(); + parameters.add(OAuth2ParameterNames.GRANT_TYPE, grantRequest.getGrantType().getValue()); + parameters.add(OAuth2ParameterNames.REQUESTED_TOKEN_TYPE, ACCESS_TOKEN_TYPE_VALUE); + OAuth2Token subjectToken = grantRequest.getSubjectToken(); + parameters.add(OAuth2ParameterNames.SUBJECT_TOKEN, subjectToken.getTokenValue()); + parameters.add(OAuth2ParameterNames.SUBJECT_TOKEN_TYPE, tokenType(subjectToken)); + OAuth2Token actorToken = grantRequest.getActorToken(); + if (actorToken != null) { + parameters.add(OAuth2ParameterNames.ACTOR_TOKEN, actorToken.getTokenValue()); + parameters.add(OAuth2ParameterNames.ACTOR_TOKEN_TYPE, tokenType(actorToken)); + } + if (!CollectionUtils.isEmpty(clientRegistration.getScopes())) { + parameters.add(OAuth2ParameterNames.SCOPE, + StringUtils.collectionToDelimitedString(clientRegistration.getScopes(), " ")); + } + if (ClientAuthenticationMethod.CLIENT_SECRET_POST.equals(clientRegistration.getClientAuthenticationMethod())) { + parameters.add(OAuth2ParameterNames.CLIENT_ID, clientRegistration.getClientId()); + parameters.add(OAuth2ParameterNames.CLIENT_SECRET, clientRegistration.getClientSecret()); + } + return parameters; + } + + private static String tokenType(OAuth2Token token) { + return (token instanceof Jwt) ? JWT_TOKEN_TYPE_VALUE : ACCESS_TOKEN_TYPE_VALUE; + } + +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/TokenExchangeOAuth2AuthorizedClientProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/TokenExchangeOAuth2AuthorizedClientProviderTests.java new file mode 100644 index 00000000000..b73243be74d --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/TokenExchangeOAuth2AuthorizedClientProviderTests.java @@ -0,0 +1,384 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client; + +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.util.function.Function; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; + +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; +import org.springframework.security.oauth2.client.endpoint.TokenExchangeGrantRequest; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2AuthorizationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.OAuth2Token; +import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; + +/** + * Tests for {@link TokenExchangeOAuth2AuthorizedClientProvider}. + * + * @author Steve Riesenberg + */ +public class TokenExchangeOAuth2AuthorizedClientProviderTests { + + private TokenExchangeOAuth2AuthorizedClientProvider authorizedClientProvider; + + private OAuth2AccessTokenResponseClient accessTokenResponseClient; + + private ClientRegistration clientRegistration; + + private OAuth2Token subjectToken; + + private OAuth2Token actorToken; + + private Authentication principal; + + @BeforeEach + public void setUp() { + this.authorizedClientProvider = new TokenExchangeOAuth2AuthorizedClientProvider(); + this.accessTokenResponseClient = mock(OAuth2AccessTokenResponseClient.class); + this.authorizedClientProvider.setAccessTokenResponseClient(this.accessTokenResponseClient); + // @formatter:off + this.clientRegistration = ClientRegistration.withRegistrationId("token-exchange") + .clientId("client-id") + .clientSecret("client-secret") + .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_BASIC) + .authorizationGrantType(AuthorizationGrantType.TOKEN_EXCHANGE) + .scope("read", "write") + .tokenUri("https://example.com/oauth2/token") + .build(); + // @formatter:on + this.subjectToken = TestOAuth2AccessTokens.scopes("read", "write"); + this.actorToken = TestOAuth2AccessTokens.noScopes(); + this.principal = new TestingAuthenticationToken(this.subjectToken, this.subjectToken); + } + + @Test + public void setAccessTokenResponseClientWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.setAccessTokenResponseClient(null)) + .withMessage("accessTokenResponseClient cannot be null"); + // @formatter:on + } + + @Test + public void setSubjectTokenResolverWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.setSubjectTokenResolver(null)) + .withMessage("subjectTokenResolver cannot be null"); + // @formatter:on + } + + @Test + public void setActorTokenResolverWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.setActorTokenResolver(null)) + .withMessage("actorTokenResolver cannot be null"); + // @formatter:on + } + + @Test + public void setClockSkewWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.setClockSkew(null)) + .withMessage("clockSkew cannot be null"); + // @formatter:on + } + + @Test + public void setClockSkewWhenNegativeSecondsThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.setClockSkew(Duration.ofSeconds(-1))) + .withMessage("clockSkew must be >= 0"); + // @formatter:on + } + + @Test + public void setClockWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.setClock(null)) + .withMessage("clock cannot be null"); + // @formatter:on + } + + @Test + public void authorizeWhenContextIsNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizedClientProvider.authorize(null)) + .withMessage("context cannot be null"); + // @formatter:on + } + + @Test + public void authorizeWhenNotTokenExchangeThenUnableToAuthorize() { + ClientRegistration clientRegistration = TestClientRegistrations.clientCredentials().build(); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withClientRegistration(clientRegistration) + .principal(this.principal) + .build(); + // @formatter:on + assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); + verifyNoInteractions(this.accessTokenResponseClient); + } + + @Test + public void authorizeWhenTokenExchangeAndTokenNotExpiredThenNotReauthorized() { + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, + this.principal.getName(), TestOAuth2AccessTokens.scopes("read", "write")); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withAuthorizedClient(authorizedClient) + .principal(this.principal) + .build(); + // @formatter:on + assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); + verifyNoInteractions(this.accessTokenResponseClient); + } + + @Test + public void authorizeWhenInvalidRequestThenThrowClientAuthorizationException() { + given(this.accessTokenResponseClient.getTokenResponse(any(TokenExchangeGrantRequest.class))) + .willThrow(new OAuth2AuthorizationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST))); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withClientRegistration(this.clientRegistration) + .principal(this.principal) + .build(); + // @formatter:on + + // @formatter:off + assertThatExceptionOfType(ClientAuthorizationException.class) + .isThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext)) + .withMessageContaining(OAuth2ErrorCodes.INVALID_REQUEST); + // @formatter:on + ArgumentCaptor grantRequestCaptor = ArgumentCaptor + .forClass(TokenExchangeGrantRequest.class); + verify(this.accessTokenResponseClient).getTokenResponse(grantRequestCaptor.capture()); + TokenExchangeGrantRequest grantRequest = grantRequestCaptor.getValue(); + assertThat(grantRequest.getSubjectToken()).isEqualTo(this.subjectToken); + assertThat(grantRequest.getActorToken()).isNull(); + } + + @Test + public void authorizeWhenTokenExchangeAndTokenExpiredThenReauthorized() { + Instant now = Instant.now(); + Instant issuedAt = now.minus(Duration.ofMinutes(60)); + Instant expiresAt = now.minus(Duration.ofMinutes(30)); + OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "access-token-1234", + issuedAt, expiresAt); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, + this.principal.getName(), accessToken); + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + given(this.accessTokenResponseClient.getTokenResponse(any(TokenExchangeGrantRequest.class))) + .willReturn(accessTokenResponse); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withAuthorizedClient(authorizedClient) + .principal(this.principal) + .build(); + // @formatter:on + OAuth2AuthorizedClient reauthorizedClient = this.authorizedClientProvider.authorize(authorizationContext); + assertThat(reauthorizedClient).isNotNull(); + assertThat(reauthorizedClient).isNotEqualTo(authorizedClient); + assertThat(reauthorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); + assertThat(reauthorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); + assertThat(reauthorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); + ArgumentCaptor grantRequestCaptor = ArgumentCaptor + .forClass(TokenExchangeGrantRequest.class); + verify(this.accessTokenResponseClient).getTokenResponse(grantRequestCaptor.capture()); + TokenExchangeGrantRequest grantRequest = grantRequestCaptor.getValue(); + assertThat(grantRequest.getSubjectToken()).isEqualTo(this.subjectToken); + assertThat(grantRequest.getActorToken()).isNull(); + } + + @Test + public void authorizeWhenTokenExchangeAndTokenNotExpiredButClockSkewForcesExpiryThenReauthorized() { + Instant now = Instant.now(); + Instant issuedAt = now.minus(Duration.ofMinutes(60)); + Instant expiresAt = now.plus(Duration.ofMinutes(1)); + OAuth2AccessToken expiresInOneMinAccessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, + "access-token-1234", issuedAt, expiresAt); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, + this.principal.getName(), expiresInOneMinAccessToken); + // Shorten the lifespan of the access token by 90 seconds, which will ultimately + // force it to expire on the client + this.authorizedClientProvider.setClockSkew(Duration.ofSeconds(90)); + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + given(this.accessTokenResponseClient.getTokenResponse(any(TokenExchangeGrantRequest.class))) + .willReturn(accessTokenResponse); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withAuthorizedClient(authorizedClient) + .principal(this.principal) + .build(); + // @formatter:on + OAuth2AuthorizedClient reauthorizedClient = this.authorizedClientProvider.authorize(authorizationContext); + assertThat(reauthorizedClient).isNotNull(); + assertThat(reauthorizedClient).isNotEqualTo(authorizedClient); + assertThat(reauthorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); + assertThat(reauthorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); + assertThat(reauthorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); + ArgumentCaptor grantRequestCaptor = ArgumentCaptor + .forClass(TokenExchangeGrantRequest.class); + verify(this.accessTokenResponseClient).getTokenResponse(grantRequestCaptor.capture()); + TokenExchangeGrantRequest grantRequest = grantRequestCaptor.getValue(); + assertThat(grantRequest.getSubjectToken()).isEqualTo(this.subjectToken); + assertThat(grantRequest.getActorToken()).isNull(); + } + + @Test + public void authorizeWhenTokenExchangeAndNotAuthorizedAndSubjectTokenDoesNotResolveThenUnableToAuthorize() { + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withClientRegistration(this.clientRegistration) + .principal(new TestingAuthenticationToken("user", "password")) + .build(); + // @formatter:on + assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); + verifyNoInteractions(this.accessTokenResponseClient); + } + + @Test + public void authorizeWhenTokenExchangeAndNotAuthorizedAndSubjectTokenResolvesThenAuthorized() { + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + given(this.accessTokenResponseClient.getTokenResponse(any(TokenExchangeGrantRequest.class))) + .willReturn(accessTokenResponse); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withClientRegistration(this.clientRegistration) + .principal(this.principal) + .build(); + // @formatter:on + OAuth2AuthorizedClient authorizedClient = this.authorizedClientProvider.authorize(authorizationContext); + assertThat(authorizedClient).isNotNull(); + assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); + assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); + assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); + ArgumentCaptor grantRequestCaptor = ArgumentCaptor + .forClass(TokenExchangeGrantRequest.class); + verify(this.accessTokenResponseClient).getTokenResponse(grantRequestCaptor.capture()); + TokenExchangeGrantRequest grantRequest = grantRequestCaptor.getValue(); + assertThat(grantRequest.getSubjectToken()).isEqualTo(this.subjectToken); + assertThat(grantRequest.getActorToken()).isNull(); + } + + @Test + public void authorizeWhenCustomSubjectTokenResolverSetThenUsed() { + Function subjectTokenResolver = mock(Function.class); + given(subjectTokenResolver.apply(any(OAuth2AuthorizationContext.class))).willReturn(this.subjectToken); + this.authorizedClientProvider.setSubjectTokenResolver(subjectTokenResolver); + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + given(this.accessTokenResponseClient.getTokenResponse(any(TokenExchangeGrantRequest.class))) + .willReturn(accessTokenResponse); + TestingAuthenticationToken principal = new TestingAuthenticationToken("user", "password"); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withClientRegistration(this.clientRegistration) + .principal(principal) + .build(); + // @formatter:on + OAuth2AuthorizedClient authorizedClient = this.authorizedClientProvider.authorize(authorizationContext); + assertThat(authorizedClient).isNotNull(); + assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); + assertThat(authorizedClient.getPrincipalName()).isEqualTo(principal.getName()); + assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); + verify(subjectTokenResolver).apply(authorizationContext); + ArgumentCaptor grantRequestCaptor = ArgumentCaptor + .forClass(TokenExchangeGrantRequest.class); + verify(this.accessTokenResponseClient).getTokenResponse(grantRequestCaptor.capture()); + TokenExchangeGrantRequest grantRequest = grantRequestCaptor.getValue(); + assertThat(grantRequest.getSubjectToken()).isEqualTo(this.subjectToken); + assertThat(grantRequest.getActorToken()).isNull(); + } + + @Test + public void authorizeWhenCustomActorTokenResolverSetThenUsed() { + Function actorTokenResolver = mock(Function.class); + given(actorTokenResolver.apply(any(OAuth2AuthorizationContext.class))).willReturn(this.actorToken); + this.authorizedClientProvider.setActorTokenResolver(actorTokenResolver); + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + given(this.accessTokenResponseClient.getTokenResponse(any(TokenExchangeGrantRequest.class))) + .willReturn(accessTokenResponse); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withClientRegistration(this.clientRegistration) + .principal(this.principal) + .build(); + // @formatter:on + OAuth2AuthorizedClient authorizedClient = this.authorizedClientProvider.authorize(authorizationContext); + assertThat(authorizedClient).isNotNull(); + assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); + assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); + assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); + verify(actorTokenResolver).apply(authorizationContext); + ArgumentCaptor grantRequestCaptor = ArgumentCaptor + .forClass(TokenExchangeGrantRequest.class); + verify(this.accessTokenResponseClient).getTokenResponse(grantRequestCaptor.capture()); + TokenExchangeGrantRequest grantRequest = grantRequestCaptor.getValue(); + assertThat(grantRequest.getSubjectToken()).isEqualTo(this.subjectToken); + assertThat(grantRequest.getActorToken()).isEqualTo(this.actorToken); + } + + @Test + public void authorizeWhenClockSetThenCalled() { + Clock clock = mock(Clock.class); + given(clock.instant()).willReturn(Instant.now()); + this.authorizedClientProvider.setClock(clock); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, + this.principal.getName(), TestOAuth2AccessTokens.noScopes()); + // @formatter:off + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext + .withAuthorizedClient(authorizedClient) + .principal(this.principal) + .build(); + // @formatter:on + assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); + verify(clock).instant(); + } + +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultTokenExchangeTokenResponseClientTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultTokenExchangeTokenResponseClientTests.java new file mode 100644 index 00000000000..d440a75c21b --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultTokenExchangeTokenResponseClientTests.java @@ -0,0 +1,489 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.endpoint; + +import java.io.IOException; +import java.net.URLEncoder; +import java.nio.charset.StandardCharsets; +import java.time.Instant; + +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import okhttp3.mockwebserver.RecordedRequest; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import org.springframework.core.convert.converter.Converter; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.RequestEntity; +import org.springframework.http.ResponseEntity; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2AuthorizationException; +import org.springframework.security.oauth2.core.OAuth2Token; +import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.jwt.TestJwts; +import org.springframework.util.StringUtils; +import org.springframework.web.client.RestOperations; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; + +/** + * Tests for {@link DefaultJwtBearerTokenResponseClient}. + * + * @author Steve Riesenberg + */ +public class DefaultTokenExchangeTokenResponseClientTests { + + private static final String ACCESS_TOKEN_TYPE_VALUE = "urn:ietf:params:oauth:token-type:access_token"; + + private static final String JWT_TOKEN_TYPE_VALUE = "urn:ietf:params:oauth:token-type:jwt"; + + private DefaultTokenExchangeTokenResponseClient tokenResponseClient; + + private ClientRegistration.Builder clientRegistration; + + private OAuth2Token subjectToken; + + private OAuth2Token actorToken; + + private MockWebServer server; + + @BeforeEach + public void setUp() throws IOException { + this.tokenResponseClient = new DefaultTokenExchangeTokenResponseClient(); + this.server = new MockWebServer(); + this.server.start(); + String tokenUri = this.server.url("/oauth2/token").toString(); + // @formatter:off + this.clientRegistration = TestClientRegistrations.clientCredentials() + .clientId("client-1") + .clientSecret("secret") + .authorizationGrantType(AuthorizationGrantType.TOKEN_EXCHANGE) + .tokenUri(tokenUri) + .scope("read", "write"); + // @formatter:on + this.subjectToken = TestOAuth2AccessTokens.scopes("read", "write"); + this.actorToken = null; + } + + @AfterEach + public void cleanUp() throws IOException { + this.server.shutdown(); + } + + @Test + public void setRequestEntityConverterWhenConverterIsNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.tokenResponseClient.setRequestEntityConverter(null)) + .withMessage("requestEntityConverter cannot be null"); + // @formatter:on + } + + @Test + public void setRestOperationsWhenRestOperationsIsNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.tokenResponseClient.setRestOperations(null)) + .withMessage("restOperations cannot be null"); + // @formatter:on + } + + @Test + public void getTokenResponseWhenRequestIsNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(null)) + .withMessage("tokenExchangeGrantRequest cannot be null"); + // @formatter:on + } + + @Test + public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"read write\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + Instant expiresAtBefore = Instant.now().plusSeconds(3600); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(), + this.subjectToken, this.actorToken); + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(grantRequest); + Instant expiresAtAfter = Instant.now().plusSeconds(3600); + RecordedRequest recordedRequest = this.server.takeRequest(); + assertThat(recordedRequest.getMethod()).isEqualTo(HttpMethod.POST.toString()); + assertThat(recordedRequest.getHeader(HttpHeaders.ACCEPT)) + .isEqualTo(MediaType.APPLICATION_JSON_VALUE + ";charset=UTF-8"); + assertThat(recordedRequest.getHeader(HttpHeaders.CONTENT_TYPE)) + .isEqualTo(MediaType.APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8"); + String formParameters = recordedRequest.getBody().readUtf8(); + // @formatter:off + assertThat(formParameters).contains( + param(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.TOKEN_EXCHANGE.getValue()), + param(OAuth2ParameterNames.REQUESTED_TOKEN_TYPE, ACCESS_TOKEN_TYPE_VALUE), + param(OAuth2ParameterNames.SUBJECT_TOKEN, this.subjectToken.getTokenValue()), + param(OAuth2ParameterNames.SUBJECT_TOKEN_TYPE, ACCESS_TOKEN_TYPE_VALUE), + param(OAuth2ParameterNames.SCOPE, StringUtils.collectionToDelimitedString(this.clientRegistration.build().getScopes(), " ")) + ); + // @formatter:on + assertThat(accessTokenResponse.getAccessToken().getTokenValue()).isEqualTo("access-token-1234"); + assertThat(accessTokenResponse.getAccessToken().getTokenType()).isEqualTo(OAuth2AccessToken.TokenType.BEARER); + assertThat(accessTokenResponse.getAccessToken().getExpiresAt()).isBetween(expiresAtBefore, expiresAtAfter); + assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactlyInAnyOrder("read", "write"); + assertThat(accessTokenResponse.getRefreshToken()).isNull(); + } + + @Test + public void getTokenResponseWhenSubjectTokenIsJwtThenSubjectTokenTypeIsJwt() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"read write\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + Instant expiresAtBefore = Instant.now().plusSeconds(3600); + this.subjectToken = TestJwts.jwt().build(); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(), + this.subjectToken, this.actorToken); + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(grantRequest); + Instant expiresAtAfter = Instant.now().plusSeconds(3600); + RecordedRequest recordedRequest = this.server.takeRequest(); + assertThat(recordedRequest.getMethod()).isEqualTo(HttpMethod.POST.toString()); + assertThat(recordedRequest.getHeader(HttpHeaders.ACCEPT)) + .isEqualTo(MediaType.APPLICATION_JSON_VALUE + ";charset=UTF-8"); + assertThat(recordedRequest.getHeader(HttpHeaders.CONTENT_TYPE)) + .isEqualTo(MediaType.APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8"); + String formParameters = recordedRequest.getBody().readUtf8(); + // @formatter:off + assertThat(formParameters).contains( + param(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.TOKEN_EXCHANGE.getValue()), + param(OAuth2ParameterNames.REQUESTED_TOKEN_TYPE, ACCESS_TOKEN_TYPE_VALUE), + param(OAuth2ParameterNames.SUBJECT_TOKEN, this.subjectToken.getTokenValue()), + param(OAuth2ParameterNames.SUBJECT_TOKEN_TYPE, JWT_TOKEN_TYPE_VALUE), + param(OAuth2ParameterNames.SCOPE, StringUtils.collectionToDelimitedString(this.clientRegistration.build().getScopes(), " ")) + ); + // @formatter:on + assertThat(accessTokenResponse.getAccessToken().getTokenValue()).isEqualTo("access-token-1234"); + assertThat(accessTokenResponse.getAccessToken().getTokenType()).isEqualTo(OAuth2AccessToken.TokenType.BEARER); + assertThat(accessTokenResponse.getAccessToken().getExpiresAt()).isBetween(expiresAtBefore, expiresAtAfter); + assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactlyInAnyOrder("read", "write"); + assertThat(accessTokenResponse.getRefreshToken()).isNull(); + } + + @Test + public void getTokenResponseWhenActorTokenIsNotNullThenActorParametersAreSent() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"read write\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + Instant expiresAtBefore = Instant.now().plusSeconds(3600); + this.actorToken = TestOAuth2AccessTokens.noScopes(); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(), + this.subjectToken, this.actorToken); + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(grantRequest); + Instant expiresAtAfter = Instant.now().plusSeconds(3600); + RecordedRequest recordedRequest = this.server.takeRequest(); + assertThat(recordedRequest.getMethod()).isEqualTo(HttpMethod.POST.toString()); + assertThat(recordedRequest.getHeader(HttpHeaders.ACCEPT)) + .isEqualTo(MediaType.APPLICATION_JSON_VALUE + ";charset=UTF-8"); + assertThat(recordedRequest.getHeader(HttpHeaders.CONTENT_TYPE)) + .isEqualTo(MediaType.APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8"); + String formParameters = recordedRequest.getBody().readUtf8(); + // @formatter:off + assertThat(formParameters).contains( + param(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.TOKEN_EXCHANGE.getValue()), + param(OAuth2ParameterNames.REQUESTED_TOKEN_TYPE, ACCESS_TOKEN_TYPE_VALUE), + param(OAuth2ParameterNames.SUBJECT_TOKEN, this.subjectToken.getTokenValue()), + param(OAuth2ParameterNames.SUBJECT_TOKEN_TYPE, ACCESS_TOKEN_TYPE_VALUE), + param(OAuth2ParameterNames.ACTOR_TOKEN, this.actorToken.getTokenValue()), + param(OAuth2ParameterNames.ACTOR_TOKEN_TYPE, ACCESS_TOKEN_TYPE_VALUE), + param(OAuth2ParameterNames.SCOPE, StringUtils.collectionToDelimitedString(this.clientRegistration.build().getScopes(), " ")) + ); + // @formatter:on + assertThat(accessTokenResponse.getAccessToken().getTokenValue()).isEqualTo("access-token-1234"); + assertThat(accessTokenResponse.getAccessToken().getTokenType()).isEqualTo(OAuth2AccessToken.TokenType.BEARER); + assertThat(accessTokenResponse.getAccessToken().getExpiresAt()).isBetween(expiresAtBefore, expiresAtAfter); + assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactlyInAnyOrder("read", "write"); + assertThat(accessTokenResponse.getRefreshToken()).isNull(); + } + + @Test + public void getTokenResponseWhenActorTokenIsJwtThenActorTokenTypeIsJwt() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"read write\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + Instant expiresAtBefore = Instant.now().plusSeconds(3600); + this.actorToken = TestJwts.jwt().build(); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(), + this.subjectToken, this.actorToken); + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(grantRequest); + Instant expiresAtAfter = Instant.now().plusSeconds(3600); + RecordedRequest recordedRequest = this.server.takeRequest(); + assertThat(recordedRequest.getMethod()).isEqualTo(HttpMethod.POST.toString()); + assertThat(recordedRequest.getHeader(HttpHeaders.ACCEPT)) + .isEqualTo(MediaType.APPLICATION_JSON_VALUE + ";charset=UTF-8"); + assertThat(recordedRequest.getHeader(HttpHeaders.CONTENT_TYPE)) + .isEqualTo(MediaType.APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8"); + String formParameters = recordedRequest.getBody().readUtf8(); + // @formatter:off + assertThat(formParameters).contains( + param(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.TOKEN_EXCHANGE.getValue()), + param(OAuth2ParameterNames.REQUESTED_TOKEN_TYPE, ACCESS_TOKEN_TYPE_VALUE), + param(OAuth2ParameterNames.SUBJECT_TOKEN, this.subjectToken.getTokenValue()), + param(OAuth2ParameterNames.SUBJECT_TOKEN_TYPE, ACCESS_TOKEN_TYPE_VALUE), + param(OAuth2ParameterNames.ACTOR_TOKEN, this.actorToken.getTokenValue()), + param(OAuth2ParameterNames.ACTOR_TOKEN_TYPE, JWT_TOKEN_TYPE_VALUE), + param(OAuth2ParameterNames.SCOPE, StringUtils.collectionToDelimitedString(this.clientRegistration.build().getScopes(), " ")) + ); + // @formatter:on + assertThat(accessTokenResponse.getAccessToken().getTokenValue()).isEqualTo("access-token-1234"); + assertThat(accessTokenResponse.getAccessToken().getTokenType()).isEqualTo(OAuth2AccessToken.TokenType.BEARER); + assertThat(accessTokenResponse.getAccessToken().getExpiresAt()).isBetween(expiresAtBefore, expiresAtAfter); + assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactlyInAnyOrder("read", "write"); + assertThat(accessTokenResponse.getRefreshToken()).isNull(); + } + + @Test + public void getTokenResponseWhenAuthenticationClientSecretBasicThenAuthorizationHeaderIsSent() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(), + this.subjectToken, this.actorToken); + this.tokenResponseClient.getTokenResponse(grantRequest); + RecordedRequest recordedRequest = this.server.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).startsWith("Basic "); + } + + @Test + public void getTokenResponseWhenAuthenticationClientSecretPostThenFormParametersAreSentSent() throws Exception { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + ClientRegistration clientRegistration = this.clientRegistration + .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST) + .build(); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, this.subjectToken, + this.actorToken); + this.tokenResponseClient.getTokenResponse(grantRequest); + RecordedRequest recordedRequest = this.server.takeRequest(); + String formParameters = recordedRequest.getBody().readUtf8(); + assertThat(formParameters).contains("client_id=client-1", "client_secret=secret"); + } + + @Test + public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthorizationException() { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"not-bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(), + this.subjectToken, this.actorToken); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(grantRequest)) + .withMessageContaining("[invalid_token_response]") + .withMessageContaining("An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response") + .havingRootCause().withMessage("tokenType cannot be null"); + // @formatter:on + } + + @Test + public void getTokenResponseWhenSuccessResponseIncludesScopeThenAccessTokenHasResponseScope() { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"read\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(), + this.subjectToken, this.actorToken); + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(grantRequest); + assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly("read"); + } + + @Test + public void getTokenResponseWhenSuccessResponseDoesNotIncludeScopeThenAccessTokenHasNoScope() { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(), + this.subjectToken, this.actorToken); + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(grantRequest); + assertThat(accessTokenResponse.getAccessToken().getScopes()).isEmpty(); + } + + @Test + public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthorizationException() { + String accessTokenErrorResponse = "{\"error\": \"invalid_grant\"}"; + this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(400)); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(), + this.subjectToken, this.actorToken); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(grantRequest)) + .withMessageContaining("[invalid_grant]"); + // @formatter:on + } + + @Test + public void getTokenResponseWhenServerErrorResponseThenThrowOAuth2AuthorizationException() { + this.server.enqueue(new MockResponse().setResponseCode(500)); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(), + this.subjectToken, this.actorToken); + // @formatter:off + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(grantRequest)) + .withMessageContaining("[invalid_token_response]") + .withMessageContaining("An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response"); + // @formatter:on + } + + @Test + public void getTokenResponseWhenCustomClientAuthenticationMethodThenIllegalArgument() { + ClientRegistration clientRegistration = this.clientRegistration + .clientAuthenticationMethod(new ClientAuthenticationMethod("basic")) + .build(); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, this.subjectToken, + this.actorToken); + // @formatter:off + assertThatExceptionOfType(IllegalArgumentException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(grantRequest)) + .withMessageContaining("This class supports `client_secret_basic`, `client_secret_post`, and `none` by default."); + // @formatter:on + } + + @Test + public void getTokenResponseWhenUnsupportedClientAuthenticationMethodThenIllegalArgument() { + ClientRegistration clientRegistration = this.clientRegistration + .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_JWT) + .build(); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, this.subjectToken, + this.actorToken); + // @formatter:off + assertThatExceptionOfType(IllegalArgumentException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(grantRequest)) + .withMessageContaining("This class supports `client_secret_basic`, `client_secret_post`, and `none` by default."); + // @formatter:on + } + + @Test + public void getTokenResponseWhenCustomRequestEntityConverterSetThenUsed() { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + Converter> requestEntityConverter = spy( + TokenExchangeGrantRequestEntityConverter.class); + this.tokenResponseClient.setRequestEntityConverter(requestEntityConverter); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(), + this.subjectToken, this.actorToken); + this.tokenResponseClient.getTokenResponse(grantRequest); + verify(requestEntityConverter).convert(grantRequest); + } + + @Test + public void getTokenResponseWhenCustomRestOperationsSetThenUsed() { + // @formatter:off + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + RestOperations restOperations = mock(RestOperations.class); + given(restOperations.exchange(any(RequestEntity.class), eq(OAuth2AccessTokenResponse.class))) + .willReturn(new ResponseEntity<>(HttpStatus.OK)); + this.tokenResponseClient.setRestOperations(restOperations); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(this.clientRegistration.build(), + this.subjectToken, this.actorToken); + this.tokenResponseClient.getTokenResponse(grantRequest); + verify(restOperations).exchange(any(RequestEntity.class), eq(OAuth2AccessTokenResponse.class)); + } + + private MockResponse jsonResponse(String json) { + return new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE).setBody(json); + } + + private static String param(String parameterName, String parameterValue) { + return "%s=%s".formatted(parameterName, URLEncoder.encode(parameterValue, StandardCharsets.UTF_8)); + } + +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/TokenExchangeGrantRequestEntityConverterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/TokenExchangeGrantRequestEntityConverterTests.java new file mode 100644 index 00000000000..8a77a66dfb6 --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/TokenExchangeGrantRequestEntityConverterTests.java @@ -0,0 +1,308 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.endpoint; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.InOrder; + +import org.springframework.core.convert.converter.Converter; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.MediaType; +import org.springframework.http.RequestEntity; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.OAuth2Token; +import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.jwt.TestJwts; +import org.springframework.util.MultiValueMap; +import org.springframework.util.StringUtils; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.inOrder; +import static org.mockito.Mockito.mock; + +/** + * Tests for {@link TokenExchangeGrantRequestEntityConverter}. + * + * @author Steve Riesenberg + */ +public class TokenExchangeGrantRequestEntityConverterTests { + + private static final String ACCESS_TOKEN_TYPE_VALUE = "urn:ietf:params:oauth:token-type:access_token"; + + private static final String JWT_TOKEN_TYPE_VALUE = "urn:ietf:params:oauth:token-type:jwt"; + + private TokenExchangeGrantRequestEntityConverter converter; + + private OAuth2Token subjectToken; + + private OAuth2Token actorToken; + + @BeforeEach + public void setUp() { + this.converter = new TokenExchangeGrantRequestEntityConverter(); + this.subjectToken = TestOAuth2AccessTokens.scopes("read", "write"); + this.actorToken = null; + } + + @Test + public void setHeadersConverterWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.converter.setHeadersConverter(null)) + .withMessage("headersConverter cannot be null"); + // @formatter:on + } + + @Test + public void addHeadersConverterWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.converter.addHeadersConverter(null)) + .withMessage("headersConverter cannot be null"); + // @formatter:on + } + + @Test + public void setParametersConverterWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.converter.setParametersConverter(null)) + .withMessage("parametersConverter cannot be null"); + // @formatter:on + } + + @Test + public void addParametersConverterWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.converter.addParametersConverter(null)) + .withMessage("parametersConverter cannot be null"); + // @formatter:on + } + + @Test + public void convertWhenHeadersConverterSetThenCalled() { + Converter headersConverter1 = mock(Converter.class); + this.converter.setHeadersConverter(headersConverter1); + Converter headersConverter2 = mock(Converter.class); + this.converter.addHeadersConverter(headersConverter2); + // @formatter:off + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration() + .authorizationGrantType(AuthorizationGrantType.TOKEN_EXCHANGE) + .scope("read", "write") + .build(); + // @formatter:on + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, this.subjectToken, + this.actorToken); + this.converter.convert(grantRequest); + InOrder inOrder = inOrder(headersConverter1, headersConverter2); + inOrder.verify(headersConverter1).convert(grantRequest); + inOrder.verify(headersConverter2).convert(grantRequest); + } + + @Test + public void convertWhenParametersConverterSetThenCalled() { + Converter> parametersConverter1 = mock( + Converter.class); + this.converter.setParametersConverter(parametersConverter1); + Converter> parametersConverter2 = mock( + Converter.class); + this.converter.addParametersConverter(parametersConverter2); + // @formatter:off + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration() + .authorizationGrantType(AuthorizationGrantType.TOKEN_EXCHANGE) + .scope("read", "write") + .build(); + // @formatter:on + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, this.subjectToken, + this.actorToken); + this.converter.convert(grantRequest); + InOrder inOrder = inOrder(parametersConverter1, parametersConverter2); + inOrder.verify(parametersConverter1).convert(any(TokenExchangeGrantRequest.class)); + inOrder.verify(parametersConverter2).convert(any(TokenExchangeGrantRequest.class)); + } + + @SuppressWarnings("unchecked") + @Test + public void convertWhenGrantRequestValidThenConverts() { + // @formatter:off + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration() + .authorizationGrantType(AuthorizationGrantType.TOKEN_EXCHANGE) + .scope("read", "write") + .build(); + // @formatter:on + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, this.subjectToken, + this.actorToken); + RequestEntity requestEntity = this.converter.convert(grantRequest); + assertThat(requestEntity).isNotNull(); + assertThat(requestEntity.getMethod()).isEqualTo(HttpMethod.POST); + assertThat(requestEntity.getUrl().toASCIIString()) + .isEqualTo(clientRegistration.getProviderDetails().getTokenUri()); + HttpHeaders headers = requestEntity.getHeaders(); + assertThat(headers.getAccept()) + .contains(MediaType.valueOf(MediaType.APPLICATION_JSON_VALUE + ";charset=UTF-8")); + assertThat(headers.getContentType()) + .isEqualTo(MediaType.valueOf(MediaType.APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8")); + assertThat(headers.getFirst(HttpHeaders.AUTHORIZATION)).startsWith("Basic "); + MultiValueMap formParameters = (MultiValueMap) requestEntity.getBody(); + assertThat(formParameters).isNotNull(); + assertThat(formParameters.getFirst(OAuth2ParameterNames.GRANT_TYPE)) + .isEqualTo(AuthorizationGrantType.TOKEN_EXCHANGE.getValue()); + assertThat(formParameters.getFirst(OAuth2ParameterNames.REQUESTED_TOKEN_TYPE)) + .isEqualTo(ACCESS_TOKEN_TYPE_VALUE); + assertThat(formParameters.getFirst(OAuth2ParameterNames.SUBJECT_TOKEN)) + .isEqualTo(this.subjectToken.getTokenValue()); + assertThat(formParameters.getFirst(OAuth2ParameterNames.SUBJECT_TOKEN_TYPE)).isEqualTo(ACCESS_TOKEN_TYPE_VALUE); + assertThat(formParameters.getFirst(OAuth2ParameterNames.SCOPE)) + .isEqualTo(StringUtils.collectionToDelimitedString(clientRegistration.getScopes(), " ")); + } + + @SuppressWarnings("unchecked") + @Test + public void convertWhenClientAuthenticationMethodIsClientSecretPostThenClientIdAndSecretParametersPresent() { + // @formatter:off + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration() + .authorizationGrantType(AuthorizationGrantType.TOKEN_EXCHANGE) + .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST) + .scope("read", "write") + .build(); + // @formatter:on + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, this.subjectToken, + this.actorToken); + RequestEntity requestEntity = this.converter.convert(grantRequest); + assertThat(requestEntity).isNotNull(); + assertThat(requestEntity.getMethod()).isEqualTo(HttpMethod.POST); + assertThat(requestEntity.getUrl().toASCIIString()) + .isEqualTo(clientRegistration.getProviderDetails().getTokenUri()); + HttpHeaders headers = requestEntity.getHeaders(); + assertThat(headers.getAccept()) + .contains(MediaType.valueOf(MediaType.APPLICATION_JSON_VALUE + ";charset=UTF-8")); + assertThat(headers.getContentType()) + .isEqualTo(MediaType.valueOf(MediaType.APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8")); + assertThat(headers.getFirst(HttpHeaders.AUTHORIZATION)).isNull(); + MultiValueMap formParameters = (MultiValueMap) requestEntity.getBody(); + assertThat(formParameters).isNotNull(); + assertThat(formParameters.getFirst(OAuth2ParameterNames.GRANT_TYPE)) + .isEqualTo(AuthorizationGrantType.TOKEN_EXCHANGE.getValue()); + assertThat(formParameters.getFirst(OAuth2ParameterNames.REQUESTED_TOKEN_TYPE)) + .isEqualTo(ACCESS_TOKEN_TYPE_VALUE); + assertThat(formParameters.getFirst(OAuth2ParameterNames.SUBJECT_TOKEN)) + .isEqualTo(this.subjectToken.getTokenValue()); + assertThat(formParameters.getFirst(OAuth2ParameterNames.SUBJECT_TOKEN_TYPE)).isEqualTo(ACCESS_TOKEN_TYPE_VALUE); + assertThat(formParameters.getFirst(OAuth2ParameterNames.SCOPE)) + .isEqualTo(StringUtils.collectionToDelimitedString(clientRegistration.getScopes(), " ")); + assertThat(formParameters.getFirst(OAuth2ParameterNames.CLIENT_ID)).isEqualTo(clientRegistration.getClientId()); + assertThat(formParameters.getFirst(OAuth2ParameterNames.CLIENT_SECRET)) + .isEqualTo(clientRegistration.getClientSecret()); + } + + @SuppressWarnings("unchecked") + @Test + public void convertWhenActorTokenIsNotNullThenActorTokenParametersPresent() { + // @formatter:off + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration() + .authorizationGrantType(AuthorizationGrantType.TOKEN_EXCHANGE) + .scope("read", "write") + .build(); + // @formatter:on + this.actorToken = TestOAuth2AccessTokens.noScopes(); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, this.subjectToken, + this.actorToken); + RequestEntity requestEntity = this.converter.convert(grantRequest); + assertThat(requestEntity).isNotNull(); + MultiValueMap formParameters = (MultiValueMap) requestEntity.getBody(); + assertThat(formParameters).isNotNull(); + assertThat(formParameters.getFirst(OAuth2ParameterNames.GRANT_TYPE)) + .isEqualTo(AuthorizationGrantType.TOKEN_EXCHANGE.getValue()); + assertThat(formParameters.getFirst(OAuth2ParameterNames.REQUESTED_TOKEN_TYPE)) + .isEqualTo(ACCESS_TOKEN_TYPE_VALUE); + assertThat(formParameters.getFirst(OAuth2ParameterNames.SUBJECT_TOKEN)) + .isEqualTo(this.subjectToken.getTokenValue()); + assertThat(formParameters.getFirst(OAuth2ParameterNames.SUBJECT_TOKEN_TYPE)).isEqualTo(ACCESS_TOKEN_TYPE_VALUE); + assertThat(formParameters.getFirst(OAuth2ParameterNames.ACTOR_TOKEN)) + .isEqualTo(this.actorToken.getTokenValue()); + assertThat(formParameters.getFirst(OAuth2ParameterNames.ACTOR_TOKEN_TYPE)).isEqualTo(ACCESS_TOKEN_TYPE_VALUE); + assertThat(formParameters.getFirst(OAuth2ParameterNames.SCOPE)) + .isEqualTo(StringUtils.collectionToDelimitedString(clientRegistration.getScopes(), " ")); + } + + @SuppressWarnings("unchecked") + @Test + public void convertWhenSubjectTokenIsJwtThenSubjectTokenTypeIsJwt() { + // @formatter:off + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration() + .authorizationGrantType(AuthorizationGrantType.TOKEN_EXCHANGE) + .scope("read", "write") + .build(); + // @formatter:on + this.subjectToken = TestJwts.jwt().build(); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, this.subjectToken, + this.actorToken); + RequestEntity requestEntity = this.converter.convert(grantRequest); + assertThat(requestEntity).isNotNull(); + MultiValueMap formParameters = (MultiValueMap) requestEntity.getBody(); + assertThat(formParameters).isNotNull(); + assertThat(formParameters.getFirst(OAuth2ParameterNames.GRANT_TYPE)) + .isEqualTo(AuthorizationGrantType.TOKEN_EXCHANGE.getValue()); + assertThat(formParameters.getFirst(OAuth2ParameterNames.REQUESTED_TOKEN_TYPE)) + .isEqualTo(ACCESS_TOKEN_TYPE_VALUE); + assertThat(formParameters.getFirst(OAuth2ParameterNames.SUBJECT_TOKEN)) + .isEqualTo(this.subjectToken.getTokenValue()); + assertThat(formParameters.getFirst(OAuth2ParameterNames.SUBJECT_TOKEN_TYPE)).isEqualTo(JWT_TOKEN_TYPE_VALUE); + assertThat(formParameters.getFirst(OAuth2ParameterNames.SCOPE)) + .isEqualTo(StringUtils.collectionToDelimitedString(clientRegistration.getScopes(), " ")); + } + + @SuppressWarnings("unchecked") + @Test + public void convertWhenActorTokenIsJwtThenActorTokenTypeIsJwt() { + // @formatter:off + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration() + .authorizationGrantType(AuthorizationGrantType.TOKEN_EXCHANGE) + .scope("read", "write") + .build(); + // @formatter:on + this.actorToken = TestJwts.jwt().build(); + TokenExchangeGrantRequest grantRequest = new TokenExchangeGrantRequest(clientRegistration, this.subjectToken, + this.actorToken); + RequestEntity requestEntity = this.converter.convert(grantRequest); + assertThat(requestEntity).isNotNull(); + MultiValueMap formParameters = (MultiValueMap) requestEntity.getBody(); + assertThat(formParameters).isNotNull(); + assertThat(formParameters.getFirst(OAuth2ParameterNames.GRANT_TYPE)) + .isEqualTo(AuthorizationGrantType.TOKEN_EXCHANGE.getValue()); + assertThat(formParameters.getFirst(OAuth2ParameterNames.REQUESTED_TOKEN_TYPE)) + .isEqualTo(ACCESS_TOKEN_TYPE_VALUE); + assertThat(formParameters.getFirst(OAuth2ParameterNames.SUBJECT_TOKEN)) + .isEqualTo(this.subjectToken.getTokenValue()); + assertThat(formParameters.getFirst(OAuth2ParameterNames.SUBJECT_TOKEN_TYPE)).isEqualTo(ACCESS_TOKEN_TYPE_VALUE); + assertThat(formParameters.getFirst(OAuth2ParameterNames.ACTOR_TOKEN)) + .isEqualTo(this.actorToken.getTokenValue()); + assertThat(formParameters.getFirst(OAuth2ParameterNames.ACTOR_TOKEN_TYPE)).isEqualTo(JWT_TOKEN_TYPE_VALUE); + assertThat(formParameters.getFirst(OAuth2ParameterNames.SCOPE)) + .isEqualTo(StringUtils.collectionToDelimitedString(clientRegistration.getScopes(), " ")); + } + +} diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/AuthorizationGrantType.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/AuthorizationGrantType.java index 07b16e2b741..e1321bd7595 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/AuthorizationGrantType.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/AuthorizationGrantType.java @@ -69,6 +69,12 @@ public final class AuthorizationGrantType implements Serializable { public static final AuthorizationGrantType DEVICE_CODE = new AuthorizationGrantType( "urn:ietf:params:oauth:grant-type:device_code"); + /** + * @since 6.3 + */ + public static final AuthorizationGrantType TOKEN_EXCHANGE = new AuthorizationGrantType( + "urn:ietf:params:oauth:grant-type:token-exchange"); + private final String value; /** diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2ParameterNames.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2ParameterNames.java index d387b482d94..41f63d9a23e 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2ParameterNames.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2ParameterNames.java @@ -182,6 +182,42 @@ public final class OAuth2ParameterNames { */ public static final String INTERVAL = "interval"; + /** + * {@code requested_token_type} - used in Token Exchange Access Token Request. + * @since 6.3 + */ + public static final String REQUESTED_TOKEN_TYPE = "requested_token_type"; + + /** + * {@code issued_token_type} - used in Token Exchange Access Token Response. + * @since 6.3 + */ + private static final String ISSUED_TOKEN_TYPE = "issued_token_type"; + + /** + * {@code subject_token} - used in Token Exchange Access Token Request. + * @since 6.3 + */ + public static final String SUBJECT_TOKEN = "subject_token"; + + /** + * {@code subject_token_type} - used in Token Exchange Access Token Request. + * @since 6.3 + */ + public static final String SUBJECT_TOKEN_TYPE = "subject_token_type"; + + /** + * {@code actor_token} - used in Token Exchange Request. + * @since 6.3 + */ + public static final String ACTOR_TOKEN = "actor_token"; + + /** + * {@code actor_token_type} - used in Token Exchange Access Token Request. + * @since 6.3 + */ + public static final String ACTOR_TOKEN_TYPE = "actor_token_type"; + private OAuth2ParameterNames() { }