Skip to content

Commit

Permalink
Implement custom JWT assertion signing (#1001)
Browse files Browse the repository at this point in the history
- Can be used to sign with KMS services instead of local private key
  • Loading branch information
clementdenis authored May 9, 2024
1 parent 596a52a commit f9b2999
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 8 deletions.
11 changes: 11 additions & 0 deletions api/src/main/java/com/okta/sdk/client/ClientBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import java.nio.file.Path;
import java.security.PrivateKey;
import java.util.Set;
import java.util.function.UnaryOperator;

/**
*
Expand Down Expand Up @@ -241,6 +242,16 @@ public interface ClientBuilder {
*/
ClientBuilder setPrivateKey(PrivateKey privateKey);

/**
* Allows specifying a custom signer for signing JWT token, instead of using a locally stored private key.
*
* @param jwtSigner the JWT signer instance.
* @return the ClientBuilder instance for method chaining.
*
* @since 16.x.x
*/
ClientBuilder setCustomJwtSigner(UnaryOperator<byte[]> jwtSigner, String algorithm);

/**
* Allows specifying the user obtained OAuth2 access token to be used by the SDK.
* The SDK will NOT obtain access token automatically (using the supplied private key)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@

import com.okta.sdk.impl.retry.OktaHttpRequestRetryStrategy;
import com.okta.sdk.resource.model.GroupProfile;
import io.jsonwebtoken.security.SecureDigestAlgorithm;
import org.apache.hc.client5.http.auth.AuthScope;
import org.apache.hc.client5.http.auth.UsernamePasswordCredentials;
import org.apache.hc.client5.http.config.ConnectionConfig;
Expand Down Expand Up @@ -81,6 +82,7 @@
import java.security.PrivateKey;
import java.util.*;
import java.util.concurrent.TimeUnit;
import java.util.function.UnaryOperator;
import java.util.stream.Collectors;

/**
Expand Down Expand Up @@ -447,8 +449,12 @@ private void validateOAuth2ClientConfig(ClientConfiguration clientConfiguration)
"At least one scope is required");
String privateKey = clientConfiguration.getPrivateKey();
String oAuth2AccessToken = clientConfiguration.getOAuth2AccessToken();
Assert.isTrue(Objects.nonNull(privateKey) || Objects.nonNull(oAuth2AccessToken),
"Either Private Key (or) Access Token must be supplied for OAuth2 Authentication mode");
UnaryOperator<byte[]> jwtSigner = clientConfiguration.getJwtSigner();
String jwtSigningAlgorithm = clientConfiguration.getJwtSigningAlgorithm();
Assert.isTrue(Objects.nonNull(privateKey) || Objects.nonNull(oAuth2AccessToken)
|| Objects.nonNull(jwtSigner) && Objects.nonNull(jwtSigningAlgorithm),
"Either Private Key (or) Access Token (or) JWT Signer + Algorithm" +
" must be supplied for OAuth2 Authentication mode");

if (Strings.hasText(privateKey) && !ConfigUtil.hasPrivateKeyContentWrapper(privateKey)) {
// privateKey is a file path, check if the file exists
Expand Down Expand Up @@ -575,6 +581,14 @@ private String readFromInputStream(InputStream inputStream) throws IOException {
return resultStringBuilder.toString();
}

@Override
public ClientBuilder setCustomJwtSigner(UnaryOperator<byte[]> jwtSigner, String algorithm) {
Assert.notNull(jwtSigner, "jwtSigner cannot be null.");
Assert.notNull(algorithm, "algorithm cannot be null.");
clientConfig.setJwtSigner(jwtSigner, algorithm);
return this;
}

@Override
public ClientBuilder setClientId(String clientId) {
ConfigurationValidator.assertClientId(clientId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;
import java.util.function.UnaryOperator;

/**
* This class holds the default configuration properties.
Expand All @@ -54,6 +55,8 @@ public class ClientConfiguration extends HttpClientConfiguration {
private String privateKey;
private String oAuth2AccessToken;
private String kid;
private UnaryOperator<byte[]> jwtSigner;
private String jwtSigningAlgorithm;

public String getApiToken() {
return apiToken;
Expand Down Expand Up @@ -151,6 +154,23 @@ public void setKid(String kid) {
this.kid = kid;
}

public UnaryOperator<byte[]> getJwtSigner() {
return jwtSigner;
}

public void setJwtSigner(UnaryOperator<byte[]> jwtSigner, String algorithm) {
this.jwtSigner = jwtSigner;
this.jwtSigningAlgorithm = algorithm;
}

public String getJwtSigningAlgorithm() {
return jwtSigningAlgorithm;
}

public boolean hasCustomJwtSigner() {
return jwtSigner != null && jwtSigningAlgorithm != null;
}

/**
* Time to idle for cache manager in seconds
* @return seconds until time to idle expires
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,25 +25,32 @@
import com.okta.sdk.impl.api.DefaultClientCredentialsResolver;
import com.okta.sdk.impl.config.ClientConfiguration;
import com.okta.sdk.impl.util.ConfigUtil;
import com.okta.sdk.resource.client.ApiClient;
import com.okta.sdk.resource.client.ApiException;
import com.okta.sdk.resource.model.HttpMethod;
import io.jsonwebtoken.JwtBuilder;
import io.jsonwebtoken.Jwts;
import io.jsonwebtoken.security.SecureDigestAlgorithm;
import io.jsonwebtoken.security.SecureRequest;
import io.jsonwebtoken.security.SecurityException;
import io.jsonwebtoken.security.VerifySecureDigestRequest;
import org.bouncycastle.asn1.pkcs.PrivateKeyInfo;
import org.bouncycastle.openssl.PEMKeyPair;
import org.bouncycastle.openssl.PEMParser;
import org.bouncycastle.openssl.jcajce.JcaPEMKeyConverter;
import com.okta.sdk.resource.client.ApiClient;
import com.okta.sdk.resource.client.ApiException;
import com.okta.sdk.resource.model.HttpMethod;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.Reader;
import java.io.StringReader;
import java.nio.charset.Charset;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.security.InvalidKeyException;
import java.security.Key;
import java.security.KeyPair;
import java.security.PrivateKey;
import java.time.Instant;
Expand All @@ -60,8 +67,48 @@ public class AccessTokenRetrieverServiceImpl implements AccessTokenRetrieverServ

static final String TOKEN_URI = "/oauth2/v1/token";

private static final KeyPair DUMMY_KEY_PAIR = Jwts.SIG.RS256.keyPair().build();

/**
* Custom SecureDigestAlgorithm that delegates signature to the jwtSigner in tokenClientConfiguration
*/
private class CustomJwtSigningAlgorithm implements SecureDigestAlgorithm<PrivateKey, Key> {
@Override
public byte[] digest(SecureRequest<InputStream, PrivateKey> request) throws SecurityException {
try {
byte[] bytes = readAllBytes(request.getPayload());
return tokenClientConfiguration.getJwtSigner().apply(bytes);
} catch (IOException e) {
throw new RuntimeException(e);
}
}

//to replace with InputStream.readAllBytes after migrating to Java 9+
private byte[] readAllBytes(InputStream payload) throws IOException {
ByteArrayOutputStream buffer = new ByteArrayOutputStream();
int nRead;
byte[] data = new byte[16384];
while ((nRead = payload.read(data, 0, data.length)) != -1) {
buffer.write(data, 0, nRead);
}
return buffer.toByteArray();
}

@Override
public boolean verify(VerifySecureDigestRequest<Key> request) throws SecurityException {
//no need to verify JWTs
throw new UnsupportedOperationException();
}

@Override
public String getId() {
return tokenClientConfiguration.getJwtSigningAlgorithm();
}
}

private final ClientConfiguration tokenClientConfiguration;
private final ApiClient apiClient;
private final CustomJwtSigningAlgorithm customJwtSigningAlgorithm = new CustomJwtSigningAlgorithm();

public AccessTokenRetrieverServiceImpl(ClientConfiguration apiClientConfiguration, ApiClient apiClient) {
Assert.notNull(apiClientConfiguration, "apiClientConfiguration must not be null.");
Expand Down Expand Up @@ -133,7 +180,6 @@ public OAuth2AccessToken getOAuth2AccessToken() throws IOException, InvalidKeyEx
*/
String createSignedJWT() throws InvalidKeyException, IOException {
String clientId = tokenClientConfiguration.getClientId();
PrivateKey privateKey = parsePrivateKey(getPemReader());
Instant now = Instant.now();

JwtBuilder builder = Jwts.builder()
Expand All @@ -142,8 +188,14 @@ String createSignedJWT() throws InvalidKeyException, IOException {
.expiration(Date.from(now.plus(50, ChronoUnit.MINUTES))) // see Javadoc
.issuer(clientId)
.subject(clientId)
.claim("jti", UUID.randomUUID().toString())
.signWith(privateKey);
.claim("jti", UUID.randomUUID().toString());

if (tokenClientConfiguration.hasCustomJwtSigner()) {
//JwtBuilder requires a key to be passed, even if it's actually not used by the algorithm
builder.signWith(DUMMY_KEY_PAIR.getPrivate(), customJwtSigningAlgorithm);
} else {
builder = builder.signWith(parsePrivateKey(getPemReader()));
}

if (Strings.hasText(tokenClientConfiguration.getKid())) {
builder.header().add("kid", tokenClientConfiguration.getKid());
Expand Down Expand Up @@ -248,6 +300,7 @@ ClientConfiguration constructTokenClientConfig(ClientConfiguration apiClientConf
tokenClientConfiguration.setClientId(apiClientConfiguration.getClientId());
tokenClientConfiguration.setScopes(apiClientConfiguration.getScopes());
tokenClientConfiguration.setPrivateKey(apiClientConfiguration.getPrivateKey());
tokenClientConfiguration.setJwtSigner(apiClientConfiguration.getJwtSigner(), apiClientConfiguration.getJwtSigningAlgorithm());
tokenClientConfiguration.setKid(apiClientConfiguration.getKid());

// setting this to '0' will disable this check and only 'retryMaxAttempts' will be effective
Expand Down

0 comments on commit f9b2999

Please sign in to comment.