Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement custom JWT assertion signing (#1001) #1215

Merged
merged 3 commits into from
May 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions api/src/main/java/com/okta/sdk/client/ClientBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import java.nio.file.Path;
import java.security.PrivateKey;
import java.util.Set;
import java.util.function.UnaryOperator;

/**
*
Expand Down Expand Up @@ -241,6 +242,16 @@ public interface ClientBuilder {
*/
ClientBuilder setPrivateKey(PrivateKey privateKey);

/**
* Allows specifying a custom signer for signing JWT token, instead of using a locally stored private key.
*
* @param jwtSigner the JWT signer instance.
* @return the ClientBuilder instance for method chaining.
*
* @since 16.x.x
*/
ClientBuilder setCustomJwtSigner(UnaryOperator<byte[]> jwtSigner, String algorithm);

/**
* Allows specifying the user obtained OAuth2 access token to be used by the SDK.
* The SDK will NOT obtain access token automatically (using the supplied private key)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
{{!
Copyright 2021-Present Okta, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
}}
{{!
Based on https://github.com/OpenAPITools/openapi-generator/blob/v6.6.0/modules/openapi-generator/src/main/resources/Java/typeInfoAnnotation.mustache
- Add defaultImpl to deserialize to base type if discriminator is null or unknown
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
import java.security.PrivateKey;
import java.util.*;
import java.util.concurrent.TimeUnit;
import java.util.function.UnaryOperator;
import java.util.stream.Collectors;

/**
Expand Down Expand Up @@ -447,8 +448,12 @@ private void validateOAuth2ClientConfig(ClientConfiguration clientConfiguration)
"At least one scope is required");
String privateKey = clientConfiguration.getPrivateKey();
String oAuth2AccessToken = clientConfiguration.getOAuth2AccessToken();
Assert.isTrue(Objects.nonNull(privateKey) || Objects.nonNull(oAuth2AccessToken),
"Either Private Key (or) Access Token must be supplied for OAuth2 Authentication mode");
UnaryOperator<byte[]> jwtSigner = clientConfiguration.getJwtSigner();
String jwtSigningAlgorithm = clientConfiguration.getJwtSigningAlgorithm();
Assert.isTrue(Objects.nonNull(privateKey) || Objects.nonNull(oAuth2AccessToken)
|| Objects.nonNull(jwtSigner) && Objects.nonNull(jwtSigningAlgorithm),
"Either Private Key (or) Access Token (or) JWT Signer + Algorithm" +
" must be supplied for OAuth2 Authentication mode");

if (Strings.hasText(privateKey) && !ConfigUtil.hasPrivateKeyContentWrapper(privateKey)) {
// privateKey is a file path, check if the file exists
Expand Down Expand Up @@ -575,6 +580,14 @@ private String readFromInputStream(InputStream inputStream) throws IOException {
return resultStringBuilder.toString();
}

@Override
public ClientBuilder setCustomJwtSigner(UnaryOperator<byte[]> jwtSigner, String algorithm) {
Assert.notNull(jwtSigner, "jwtSigner cannot be null.");
Assert.notNull(algorithm, "algorithm cannot be null.");
clientConfig.setJwtSigner(jwtSigner, algorithm);
return this;
}

@Override
public ClientBuilder setClientId(String clientId) {
ConfigurationValidator.assertClientId(clientId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;
import java.util.function.UnaryOperator;

/**
* This class holds the default configuration properties.
Expand All @@ -54,6 +55,8 @@ public class ClientConfiguration extends HttpClientConfiguration {
private String privateKey;
private String oAuth2AccessToken;
private String kid;
private UnaryOperator<byte[]> jwtSigner;
private String jwtSigningAlgorithm;

public String getApiToken() {
return apiToken;
Expand Down Expand Up @@ -151,6 +154,23 @@ public void setKid(String kid) {
this.kid = kid;
}

public UnaryOperator<byte[]> getJwtSigner() {
return jwtSigner;
}

public void setJwtSigner(UnaryOperator<byte[]> jwtSigner, String algorithm) {
this.jwtSigner = jwtSigner;
this.jwtSigningAlgorithm = algorithm;
}

public String getJwtSigningAlgorithm() {
return jwtSigningAlgorithm;
}

public boolean hasCustomJwtSigner() {
return jwtSigner != null && jwtSigningAlgorithm != null;
}

/**
* Time to idle for cache manager in seconds
* @return seconds until time to idle expires
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,25 +25,32 @@
import com.okta.sdk.impl.api.DefaultClientCredentialsResolver;
import com.okta.sdk.impl.config.ClientConfiguration;
import com.okta.sdk.impl.util.ConfigUtil;
import com.okta.sdk.resource.client.ApiClient;
import com.okta.sdk.resource.client.ApiException;
import com.okta.sdk.resource.model.HttpMethod;
import io.jsonwebtoken.JwtBuilder;
import io.jsonwebtoken.Jwts;
import io.jsonwebtoken.security.SecureDigestAlgorithm;
import io.jsonwebtoken.security.SecureRequest;
import io.jsonwebtoken.security.SecurityException;
import io.jsonwebtoken.security.VerifySecureDigestRequest;
import org.bouncycastle.asn1.pkcs.PrivateKeyInfo;
import org.bouncycastle.openssl.PEMKeyPair;
import org.bouncycastle.openssl.PEMParser;
import org.bouncycastle.openssl.jcajce.JcaPEMKeyConverter;
import com.okta.sdk.resource.client.ApiClient;
import com.okta.sdk.resource.client.ApiException;
import com.okta.sdk.resource.model.HttpMethod;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.Reader;
import java.io.StringReader;
import java.nio.charset.Charset;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.security.InvalidKeyException;
import java.security.Key;
import java.security.KeyPair;
import java.security.PrivateKey;
import java.time.Instant;
Expand All @@ -60,8 +67,48 @@ public class AccessTokenRetrieverServiceImpl implements AccessTokenRetrieverServ

static final String TOKEN_URI = "/oauth2/v1/token";

private static final KeyPair DUMMY_KEY_PAIR = Jwts.SIG.RS256.keyPair().build();

/**
* Custom SecureDigestAlgorithm that delegates signature to the jwtSigner in tokenClientConfiguration
*/
private class CustomJwtSigningAlgorithm implements SecureDigestAlgorithm<PrivateKey, Key> {
@Override
public byte[] digest(SecureRequest<InputStream, PrivateKey> request) throws SecurityException {
try {
byte[] bytes = readAllBytes(request.getPayload());
return tokenClientConfiguration.getJwtSigner().apply(bytes);
} catch (IOException e) {
throw new RuntimeException(e);
}
}

//to replace with InputStream.readAllBytes after migrating to Java 9+
private byte[] readAllBytes(InputStream payload) throws IOException {
ByteArrayOutputStream buffer = new ByteArrayOutputStream();
int nRead;
byte[] data = new byte[16384];
while ((nRead = payload.read(data, 0, data.length)) != -1) {
buffer.write(data, 0, nRead);
}
return buffer.toByteArray();
}

@Override
public boolean verify(VerifySecureDigestRequest<Key> request) throws SecurityException {
//no need to verify JWTs
throw new UnsupportedOperationException();
}

@Override
public String getId() {
return tokenClientConfiguration.getJwtSigningAlgorithm();
}
}

private final ClientConfiguration tokenClientConfiguration;
private final ApiClient apiClient;
private final CustomJwtSigningAlgorithm customJwtSigningAlgorithm = new CustomJwtSigningAlgorithm();

public AccessTokenRetrieverServiceImpl(ClientConfiguration apiClientConfiguration, ApiClient apiClient) {
Assert.notNull(apiClientConfiguration, "apiClientConfiguration must not be null.");
Expand Down Expand Up @@ -133,7 +180,6 @@ public OAuth2AccessToken getOAuth2AccessToken() throws IOException, InvalidKeyEx
*/
String createSignedJWT() throws InvalidKeyException, IOException {
String clientId = tokenClientConfiguration.getClientId();
PrivateKey privateKey = parsePrivateKey(getPemReader());
Instant now = Instant.now();

JwtBuilder builder = Jwts.builder()
Expand All @@ -142,8 +188,14 @@ String createSignedJWT() throws InvalidKeyException, IOException {
.expiration(Date.from(now.plus(50, ChronoUnit.MINUTES))) // see Javadoc
.issuer(clientId)
.subject(clientId)
.claim("jti", UUID.randomUUID().toString())
.signWith(privateKey);
.claim("jti", UUID.randomUUID().toString());

if (tokenClientConfiguration.hasCustomJwtSigner()) {
//JwtBuilder requires a key to be passed, even if it's actually not used by the algorithm
builder.signWith(DUMMY_KEY_PAIR.getPrivate(), customJwtSigningAlgorithm);
} else {
builder = builder.signWith(parsePrivateKey(getPemReader()));
}

if (Strings.hasText(tokenClientConfiguration.getKid())) {
builder.header().add("kid", tokenClientConfiguration.getKid());
Expand Down Expand Up @@ -248,6 +300,7 @@ ClientConfiguration constructTokenClientConfig(ClientConfiguration apiClientConf
tokenClientConfiguration.setClientId(apiClientConfiguration.getClientId());
tokenClientConfiguration.setScopes(apiClientConfiguration.getScopes());
tokenClientConfiguration.setPrivateKey(apiClientConfiguration.getPrivateKey());
tokenClientConfiguration.setJwtSigner(apiClientConfiguration.getJwtSigner(), apiClientConfiguration.getJwtSigningAlgorithm());
tokenClientConfiguration.setKid(apiClientConfiguration.getKid());

// setting this to '0' will disable this check and only 'retryMaxAttempts' will be effective
Expand Down