Skip to content

Commit

Permalink
Add test for refresh_token grant with public client
Browse files Browse the repository at this point in the history
Related gh-1432
  • Loading branch information
jgrandja committed Jan 10, 2024
1 parent e76fde8 commit faad0be
Showing 1 changed file with 160 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2020-2022 the original author or authors.
* Copyright 2020-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -23,6 +23,8 @@
import java.util.List;
import java.util.Set;

import jakarta.servlet.http.HttpServletRequest;

import com.nimbusds.jose.jwk.JWKSet;
import com.nimbusds.jose.jwk.source.JWKSource;
import com.nimbusds.jose.proc.SecurityContext;
Expand All @@ -34,6 +36,7 @@

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Import;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
Expand All @@ -43,16 +46,25 @@
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase;
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder;
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType;
import org.springframework.lang.Nullable;
import org.springframework.mock.http.client.MockClientHttpResponse;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.authentication.AuthenticationProvider;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.Transient;
import org.springframework.security.crypto.password.NoOpPasswordEncoder;
import org.springframework.security.crypto.password.PasswordEncoder;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.OAuth2Token;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
Expand All @@ -66,6 +78,7 @@
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.OAuth2TokenType;
import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken;
import org.springframework.security.oauth2.server.authorization.client.JdbcRegisteredClientRepository;
import org.springframework.security.oauth2.server.authorization.client.JdbcRegisteredClientRepository.RegisteredClientParametersMapper;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
Expand All @@ -77,10 +90,15 @@
import org.springframework.security.oauth2.server.authorization.test.SpringTestContextExtension;
import org.springframework.security.oauth2.server.authorization.token.JwtEncodingContext;
import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenCustomizer;
import org.springframework.security.web.SecurityFilterChain;
import org.springframework.security.web.authentication.AuthenticationConverter;
import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.test.web.servlet.MockMvc;
import org.springframework.test.web.servlet.MvcResult;
import org.springframework.util.Assert;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;

import static org.assertj.core.api.Assertions.assertThat;
import static org.hamcrest.CoreMatchers.containsString;
Expand Down Expand Up @@ -217,6 +235,32 @@ public void requestWhenRevokeAndRefreshThenAccessTokenActive() throws Exception
assertThat(accessToken.isActive()).isTrue();
}

// gh-1430
@Test
public void requestWhenRefreshTokenRequestWithPublicClientThenReturnAccessTokenResponse() throws Exception {
this.spring.register(AuthorizationServerConfigurationWithPublicClientAuthentication.class).autowire();

RegisteredClient registeredClient = TestRegisteredClients.registeredPublicClient()
.authorizationGrantType(AuthorizationGrantType.REFRESH_TOKEN)
.build();
this.registeredClientRepository.save(registeredClient);

OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build();
this.authorizationService.save(authorization);

this.mvc.perform(post(DEFAULT_TOKEN_ENDPOINT_URI)
.params(getRefreshTokenRequestParameters(authorization))
.param(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId()))
.andExpect(status().isOk())
.andExpect(header().string(HttpHeaders.CACHE_CONTROL, containsString("no-store")))
.andExpect(header().string(HttpHeaders.PRAGMA, containsString("no-cache")))
.andExpect(jsonPath("$.access_token").isNotEmpty())
.andExpect(jsonPath("$.token_type").isNotEmpty())
.andExpect(jsonPath("$.expires_in").isNotEmpty())
.andExpect(jsonPath("$.refresh_token").isNotEmpty())
.andExpect(jsonPath("$.scope").isNotEmpty());
}

private static MultiValueMap<String, String> getRefreshTokenRequestParameters(OAuth2Authorization authorization) {
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameters.set(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.REFRESH_TOKEN.getValue());
Expand Down Expand Up @@ -307,4 +351,119 @@ static class ParametersMapper extends JdbcOAuth2AuthorizationService.OAuth2Autho
}

}

@EnableWebSecurity
@Configuration(proxyBeanMethods = false)
static class AuthorizationServerConfigurationWithPublicClientAuthentication extends AuthorizationServerConfiguration {
// @formatter:off
@Bean
SecurityFilterChain authorizationServerSecurityFilterChain(
HttpSecurity http, RegisteredClientRepository registeredClientRepository) throws Exception {

OAuth2AuthorizationServerConfigurer authorizationServerConfigurer =
new OAuth2AuthorizationServerConfigurer();
authorizationServerConfigurer
.clientAuthentication(clientAuthentication ->
clientAuthentication
.authenticationConverter(
new PublicClientRefreshTokenAuthenticationConverter())
.authenticationProvider(
new PublicClientRefreshTokenAuthenticationProvider(registeredClientRepository))
);
RequestMatcher endpointsMatcher = authorizationServerConfigurer.getEndpointsMatcher();

http
.securityMatcher(endpointsMatcher)
.authorizeHttpRequests(authorize ->
authorize.anyRequest().authenticated()
)
.csrf(csrf -> csrf.ignoringRequestMatchers(endpointsMatcher))
.apply(authorizationServerConfigurer);
return http.build();
}
// @formatter:on
}

@Transient
private static final class PublicClientRefreshTokenAuthenticationToken extends OAuth2ClientAuthenticationToken {

private PublicClientRefreshTokenAuthenticationToken(String clientId) {
super(clientId, ClientAuthenticationMethod.NONE, null, null);
}

private PublicClientRefreshTokenAuthenticationToken(RegisteredClient registeredClient) {
super(registeredClient, ClientAuthenticationMethod.NONE, null);
}

}

private static final class PublicClientRefreshTokenAuthenticationConverter implements AuthenticationConverter {

@Nullable
@Override
public Authentication convert(HttpServletRequest request) {
// grant_type (REQUIRED)
String grantType = request.getParameter(OAuth2ParameterNames.GRANT_TYPE);
if (!AuthorizationGrantType.REFRESH_TOKEN.getValue().equals(grantType)) {
return null;
}

// client_id (REQUIRED)
String clientId = request.getParameter(OAuth2ParameterNames.CLIENT_ID);
if (!StringUtils.hasText(clientId)) {
return null;
}

return new PublicClientRefreshTokenAuthenticationToken(clientId);
}

}

private static final class PublicClientRefreshTokenAuthenticationProvider implements AuthenticationProvider {
private final RegisteredClientRepository registeredClientRepository;

private PublicClientRefreshTokenAuthenticationProvider(RegisteredClientRepository registeredClientRepository) {
Assert.notNull(registeredClientRepository, "registeredClientRepository cannot be null");
this.registeredClientRepository = registeredClientRepository;
}

@Override
public Authentication authenticate(Authentication authentication) throws AuthenticationException {
PublicClientRefreshTokenAuthenticationToken publicClientAuthentication =
(PublicClientRefreshTokenAuthenticationToken) authentication;

if (!ClientAuthenticationMethod.NONE.equals(publicClientAuthentication.getClientAuthenticationMethod())) {
return null;
}

String clientId = publicClientAuthentication.getPrincipal().toString();
RegisteredClient registeredClient = this.registeredClientRepository.findByClientId(clientId);
if (registeredClient == null) {
throwInvalidClient(OAuth2ParameterNames.CLIENT_ID);
}

if (!registeredClient.getClientAuthenticationMethods().contains(
publicClientAuthentication.getClientAuthenticationMethod())) {
throwInvalidClient("authentication_method");
}

return new PublicClientRefreshTokenAuthenticationToken(registeredClient);
}

@Override
public boolean supports(Class<?> authentication) {
return PublicClientRefreshTokenAuthenticationToken.class.isAssignableFrom(authentication);
}

private static void throwInvalidClient(String parameterName) {
OAuth2Error error = new OAuth2Error(
OAuth2ErrorCodes.INVALID_CLIENT,
"Public client authentication failed: " + parameterName,
null
);
throw new OAuth2AuthenticationException(error);
}

}

}

0 comments on commit faad0be

Please sign in to comment.