From f9b2999118f72635febb3e9380ccebc7cbb36b31 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Denis?= Date: Thu, 9 May 2024 22:14:21 +0200 Subject: [PATCH] Implement custom JWT assertion signing (#1001) - Can be used to sign with KMS services instead of local private key --- .../com/okta/sdk/client/ClientBuilder.java | 11 ++++ .../sdk/impl/client/DefaultClientBuilder.java | 18 ++++- .../sdk/impl/config/ClientConfiguration.java | 20 ++++++ .../AccessTokenRetrieverServiceImpl.java | 65 +++++++++++++++++-- 4 files changed, 106 insertions(+), 8 deletions(-) diff --git a/api/src/main/java/com/okta/sdk/client/ClientBuilder.java b/api/src/main/java/com/okta/sdk/client/ClientBuilder.java index eb952de09dd..131e9b8384d 100644 --- a/api/src/main/java/com/okta/sdk/client/ClientBuilder.java +++ b/api/src/main/java/com/okta/sdk/client/ClientBuilder.java @@ -25,6 +25,7 @@ import java.nio.file.Path; import java.security.PrivateKey; import java.util.Set; +import java.util.function.UnaryOperator; /** * @@ -241,6 +242,16 @@ public interface ClientBuilder { */ ClientBuilder setPrivateKey(PrivateKey privateKey); + /** + * Allows specifying a custom signer for signing JWT token, instead of using a locally stored private key. + * + * @param jwtSigner the JWT signer instance. + * @return the ClientBuilder instance for method chaining. + * + * @since 16.x.x + */ + ClientBuilder setCustomJwtSigner(UnaryOperator jwtSigner, String algorithm); + /** * Allows specifying the user obtained OAuth2 access token to be used by the SDK. * The SDK will NOT obtain access token automatically (using the supplied private key) diff --git a/impl/src/main/java/com/okta/sdk/impl/client/DefaultClientBuilder.java b/impl/src/main/java/com/okta/sdk/impl/client/DefaultClientBuilder.java index 0a6074e244a..dbd374b0251 100644 --- a/impl/src/main/java/com/okta/sdk/impl/client/DefaultClientBuilder.java +++ b/impl/src/main/java/com/okta/sdk/impl/client/DefaultClientBuilder.java @@ -51,6 +51,7 @@ import com.okta.sdk.impl.retry.OktaHttpRequestRetryStrategy; import com.okta.sdk.resource.model.GroupProfile; +import io.jsonwebtoken.security.SecureDigestAlgorithm; import org.apache.hc.client5.http.auth.AuthScope; import org.apache.hc.client5.http.auth.UsernamePasswordCredentials; import org.apache.hc.client5.http.config.ConnectionConfig; @@ -81,6 +82,7 @@ import java.security.PrivateKey; import java.util.*; import java.util.concurrent.TimeUnit; +import java.util.function.UnaryOperator; import java.util.stream.Collectors; /** @@ -447,8 +449,12 @@ private void validateOAuth2ClientConfig(ClientConfiguration clientConfiguration) "At least one scope is required"); String privateKey = clientConfiguration.getPrivateKey(); String oAuth2AccessToken = clientConfiguration.getOAuth2AccessToken(); - Assert.isTrue(Objects.nonNull(privateKey) || Objects.nonNull(oAuth2AccessToken), - "Either Private Key (or) Access Token must be supplied for OAuth2 Authentication mode"); + UnaryOperator jwtSigner = clientConfiguration.getJwtSigner(); + String jwtSigningAlgorithm = clientConfiguration.getJwtSigningAlgorithm(); + Assert.isTrue(Objects.nonNull(privateKey) || Objects.nonNull(oAuth2AccessToken) + || Objects.nonNull(jwtSigner) && Objects.nonNull(jwtSigningAlgorithm), + "Either Private Key (or) Access Token (or) JWT Signer + Algorithm" + + " must be supplied for OAuth2 Authentication mode"); if (Strings.hasText(privateKey) && !ConfigUtil.hasPrivateKeyContentWrapper(privateKey)) { // privateKey is a file path, check if the file exists @@ -575,6 +581,14 @@ private String readFromInputStream(InputStream inputStream) throws IOException { return resultStringBuilder.toString(); } + @Override + public ClientBuilder setCustomJwtSigner(UnaryOperator jwtSigner, String algorithm) { + Assert.notNull(jwtSigner, "jwtSigner cannot be null."); + Assert.notNull(algorithm, "algorithm cannot be null."); + clientConfig.setJwtSigner(jwtSigner, algorithm); + return this; + } + @Override public ClientBuilder setClientId(String clientId) { ConfigurationValidator.assertClientId(clientId); diff --git a/impl/src/main/java/com/okta/sdk/impl/config/ClientConfiguration.java b/impl/src/main/java/com/okta/sdk/impl/config/ClientConfiguration.java index b108761b880..3ed3140ebf0 100644 --- a/impl/src/main/java/com/okta/sdk/impl/config/ClientConfiguration.java +++ b/impl/src/main/java/com/okta/sdk/impl/config/ClientConfiguration.java @@ -28,6 +28,7 @@ import java.util.LinkedHashMap; import java.util.Map; import java.util.Set; +import java.util.function.UnaryOperator; /** * This class holds the default configuration properties. @@ -54,6 +55,8 @@ public class ClientConfiguration extends HttpClientConfiguration { private String privateKey; private String oAuth2AccessToken; private String kid; + private UnaryOperator jwtSigner; + private String jwtSigningAlgorithm; public String getApiToken() { return apiToken; @@ -151,6 +154,23 @@ public void setKid(String kid) { this.kid = kid; } + public UnaryOperator getJwtSigner() { + return jwtSigner; + } + + public void setJwtSigner(UnaryOperator jwtSigner, String algorithm) { + this.jwtSigner = jwtSigner; + this.jwtSigningAlgorithm = algorithm; + } + + public String getJwtSigningAlgorithm() { + return jwtSigningAlgorithm; + } + + public boolean hasCustomJwtSigner() { + return jwtSigner != null && jwtSigningAlgorithm != null; + } + /** * Time to idle for cache manager in seconds * @return seconds until time to idle expires diff --git a/impl/src/main/java/com/okta/sdk/impl/oauth2/AccessTokenRetrieverServiceImpl.java b/impl/src/main/java/com/okta/sdk/impl/oauth2/AccessTokenRetrieverServiceImpl.java index 591faaaaca2..5d35d4970e6 100644 --- a/impl/src/main/java/com/okta/sdk/impl/oauth2/AccessTokenRetrieverServiceImpl.java +++ b/impl/src/main/java/com/okta/sdk/impl/oauth2/AccessTokenRetrieverServiceImpl.java @@ -25,25 +25,32 @@ import com.okta.sdk.impl.api.DefaultClientCredentialsResolver; import com.okta.sdk.impl.config.ClientConfiguration; import com.okta.sdk.impl.util.ConfigUtil; +import com.okta.sdk.resource.client.ApiClient; +import com.okta.sdk.resource.client.ApiException; +import com.okta.sdk.resource.model.HttpMethod; import io.jsonwebtoken.JwtBuilder; import io.jsonwebtoken.Jwts; +import io.jsonwebtoken.security.SecureDigestAlgorithm; +import io.jsonwebtoken.security.SecureRequest; +import io.jsonwebtoken.security.SecurityException; +import io.jsonwebtoken.security.VerifySecureDigestRequest; import org.bouncycastle.asn1.pkcs.PrivateKeyInfo; import org.bouncycastle.openssl.PEMKeyPair; import org.bouncycastle.openssl.PEMParser; import org.bouncycastle.openssl.jcajce.JcaPEMKeyConverter; -import com.okta.sdk.resource.client.ApiClient; -import com.okta.sdk.resource.client.ApiException; -import com.okta.sdk.resource.model.HttpMethod; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.io.InputStream; import java.io.Reader; import java.io.StringReader; import java.nio.charset.Charset; import java.nio.file.Files; import java.nio.file.Paths; import java.security.InvalidKeyException; +import java.security.Key; import java.security.KeyPair; import java.security.PrivateKey; import java.time.Instant; @@ -60,8 +67,48 @@ public class AccessTokenRetrieverServiceImpl implements AccessTokenRetrieverServ static final String TOKEN_URI = "/oauth2/v1/token"; + private static final KeyPair DUMMY_KEY_PAIR = Jwts.SIG.RS256.keyPair().build(); + + /** + * Custom SecureDigestAlgorithm that delegates signature to the jwtSigner in tokenClientConfiguration + */ + private class CustomJwtSigningAlgorithm implements SecureDigestAlgorithm { + @Override + public byte[] digest(SecureRequest request) throws SecurityException { + try { + byte[] bytes = readAllBytes(request.getPayload()); + return tokenClientConfiguration.getJwtSigner().apply(bytes); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + //to replace with InputStream.readAllBytes after migrating to Java 9+ + private byte[] readAllBytes(InputStream payload) throws IOException { + ByteArrayOutputStream buffer = new ByteArrayOutputStream(); + int nRead; + byte[] data = new byte[16384]; + while ((nRead = payload.read(data, 0, data.length)) != -1) { + buffer.write(data, 0, nRead); + } + return buffer.toByteArray(); + } + + @Override + public boolean verify(VerifySecureDigestRequest request) throws SecurityException { + //no need to verify JWTs + throw new UnsupportedOperationException(); + } + + @Override + public String getId() { + return tokenClientConfiguration.getJwtSigningAlgorithm(); + } + } + private final ClientConfiguration tokenClientConfiguration; private final ApiClient apiClient; + private final CustomJwtSigningAlgorithm customJwtSigningAlgorithm = new CustomJwtSigningAlgorithm(); public AccessTokenRetrieverServiceImpl(ClientConfiguration apiClientConfiguration, ApiClient apiClient) { Assert.notNull(apiClientConfiguration, "apiClientConfiguration must not be null."); @@ -133,7 +180,6 @@ public OAuth2AccessToken getOAuth2AccessToken() throws IOException, InvalidKeyEx */ String createSignedJWT() throws InvalidKeyException, IOException { String clientId = tokenClientConfiguration.getClientId(); - PrivateKey privateKey = parsePrivateKey(getPemReader()); Instant now = Instant.now(); JwtBuilder builder = Jwts.builder() @@ -142,8 +188,14 @@ String createSignedJWT() throws InvalidKeyException, IOException { .expiration(Date.from(now.plus(50, ChronoUnit.MINUTES))) // see Javadoc .issuer(clientId) .subject(clientId) - .claim("jti", UUID.randomUUID().toString()) - .signWith(privateKey); + .claim("jti", UUID.randomUUID().toString()); + + if (tokenClientConfiguration.hasCustomJwtSigner()) { + //JwtBuilder requires a key to be passed, even if it's actually not used by the algorithm + builder.signWith(DUMMY_KEY_PAIR.getPrivate(), customJwtSigningAlgorithm); + } else { + builder = builder.signWith(parsePrivateKey(getPemReader())); + } if (Strings.hasText(tokenClientConfiguration.getKid())) { builder.header().add("kid", tokenClientConfiguration.getKid()); @@ -248,6 +300,7 @@ ClientConfiguration constructTokenClientConfig(ClientConfiguration apiClientConf tokenClientConfiguration.setClientId(apiClientConfiguration.getClientId()); tokenClientConfiguration.setScopes(apiClientConfiguration.getScopes()); tokenClientConfiguration.setPrivateKey(apiClientConfiguration.getPrivateKey()); + tokenClientConfiguration.setJwtSigner(apiClientConfiguration.getJwtSigner(), apiClientConfiguration.getJwtSigningAlgorithm()); tokenClientConfiguration.setKid(apiClientConfiguration.getKid()); // setting this to '0' will disable this check and only 'retryMaxAttempts' will be effective