Skip to content

Commit

Permalink
Add 4096 bit Diffie-Hellman key exchange and encrypt sasl jwt transmi…
Browse files Browse the repository at this point in the history
…ssion
  • Loading branch information
zzuljin committed Jan 8, 2025
1 parent 3dae9a8 commit f25e64a
Show file tree
Hide file tree
Showing 8 changed files with 509 additions and 43 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/*
* Copyright 2025 Mishmash IO UK Ltd.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package io.mishmash.stacks.oidc.sasl;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.security.AlgorithmParameters;
import java.security.GeneralSecurityException;
import javax.crypto.Cipher;
import javax.crypto.spec.SecretKeySpec;

public class DHUtils {

public static ByteBuffer encryptChallenge(
final SecretKeySpec key,
final byte[] response)
throws GeneralSecurityException, IOException {
Cipher cipher = Cipher.getInstance("AES/CBC/PKCS5Padding");
cipher.init(Cipher.ENCRYPT_MODE, key);
byte[] encrypted = cipher.doFinal(response);
AlgorithmParameters p = cipher.getParameters();
byte[] params = cipher.getParameters().getEncoded();
ByteBuffer resp = ByteBuffer.allocate(
Integer.BYTES + params.length + encrypted.length);
resp.putInt(params.length);
resp.put(params);
resp.put(encrypted);

return resp;
}

public static byte[] decryptChallenge(
final SecretKeySpec key,
final ByteBuffer challenge)
throws GeneralSecurityException, IOException {
int paramsLen;
if (challenge.remaining() <= Integer.BYTES
|| (paramsLen = challenge.getInt()) > 16 * 1024
|| challenge.remaining() <= paramsLen
|| challenge.remaining() > paramsLen + (16 * 1024)) {
throw new IllegalArgumentException("Corrupt buffer received");
}

byte[] params = new byte[paramsLen];
challenge.get(params);
byte[] encrypted = new byte[challenge.remaining()];
challenge.get(encrypted);

AlgorithmParameters algP = AlgorithmParameters.getInstance("AES");
algP.init(params);
Cipher cipher = Cipher.getInstance("AES/CBC/PKCS5Padding");
cipher.init(Cipher.DECRYPT_MODE, key, algP);

return cipher.doFinal(encrypted);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
public class OAUTHBearerClient implements SaslClient {

private boolean isComplete = false;
private boolean checksFailed = false;
private OIDCClientPrincipal oidc;
private String authz;
private String server;
Expand All @@ -60,6 +61,11 @@ public boolean hasInitialResponse() {
@Override
public byte[] evaluateChallenge(final byte[] challenge)
throws SaslException {
if (checksFailed || challenge == null) {
throw new SaslException(
getMechanismName() + " authentication failed");
}

if (challenge.length == 0) {
// initial response, send ticket
try {
Expand All @@ -74,16 +80,17 @@ public byte[] evaluateChallenge(final byte[] challenge)
// write a final delimiter
baos.write(0x01);

// FIXME:
isComplete = true;
return baos.toByteArray();
} catch (Exception e) {
throw new SaslException(
"Could not send auth info to server", e);
getMechanismName()
+ " SASL client failed to send auth info to server", e);
}
} else {
// got an error response, should not be reached, but - confirm it
return new byte[] {0x01};
}

return null;
}

@Override
Expand All @@ -96,15 +103,19 @@ public byte[] unwrap(
final byte[] incoming,
final int offset,
final int len) throws SaslException {
throw new IllegalStateException("Sasl integrity and privacy are not supported");
throw new SaslException(
getMechanismName()
+ " SASL integrity and privacy are not supported");
}

@Override
public byte[] wrap(
final byte[] outgoing,
final int offset,
final int len) throws SaslException {
throw new IllegalStateException("Sasl integrity and privacy are not supported");
throw new SaslException(
getMechanismName()
+ " SASL integrity and privacy are not supported");
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
/*
* Copyright 2025 Mishmash IO UK Ltd.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package io.mishmash.stacks.oidc.sasl;

import java.nio.ByteBuffer;
import java.security.InvalidKeyException;
import java.security.KeyFactory;
import java.security.KeyPair;
import java.security.KeyPairGenerator;
import java.security.NoSuchAlgorithmException;
import java.security.PublicKey;
import java.security.spec.InvalidKeySpecException;
import java.security.spec.X509EncodedKeySpec;

import javax.crypto.KeyAgreement;
import javax.crypto.spec.SecretKeySpec;
import javax.security.sasl.SaslException;

import io.mishmash.stacks.oidc.login.OIDCClientPrincipal;

public class OAUTHBearerClientDH extends OAUTHBearerClient {

private int keyLength;
private KeyPairGenerator keyPairGenerator;
private KeyPair keyPair;
private KeyAgreement agreement;
private boolean keyExchangeComplete = false;
private PublicKey serverPublicKey;
private SecretKeySpec secretKey;

public OAUTHBearerClientDH(
final OIDCClientPrincipal oidcClient,
final String authzId,
final String serverName,
final int keyLen) throws SaslException {
super(oidcClient, authzId, serverName);

this.keyLength = keyLen;

try {
keyPairGenerator = KeyPairGenerator.getInstance("DH");
keyPairGenerator.initialize(keyLength);
keyPair = keyPairGenerator.generateKeyPair();
agreement = KeyAgreement.getInstance("DH");
agreement.init(keyPair.getPrivate());
} catch (NoSuchAlgorithmException | InvalidKeyException e) {
throw new SaslException(
getMechanismName()
+ " SASL client initialization error");
}
}

@Override
public byte[] evaluateChallenge(final byte[] challenge)
throws SaslException {
if (challenge == null) {
throw new SaslException(
getMechanismName() + " authentication failed");
}

if (keyExchangeComplete) {
// decrypt, pass to parent and encrypt response
return encrypt(super.evaluateChallenge(decrypt(challenge)));
} else {
if (serverPublicKey != null) {
throw new SaslException(
getMechanismName()
+ " SASL client unexpected challenge from server");
}

if (challenge.length == 0) {
// begin DH exchange, send public key
return keyPair.getPublic().getEncoded();
} else {
// received server public key
try {
KeyFactory factory = KeyFactory.getInstance("DH");
X509EncodedKeySpec keySpec =
new X509EncodedKeySpec(challenge);
serverPublicKey = factory.generatePublic(keySpec);
agreement.doPhase(serverPublicKey, true);
byte[] sharedSecret = agreement.generateSecret();
secretKey = new SecretKeySpec(
sharedSecret, 0, 16, "AES");
keyExchangeComplete = true;

// get the initial response from parent
return encrypt(super.evaluateChallenge(new byte[0]));
} catch (NoSuchAlgorithmException
| InvalidKeySpecException
| InvalidKeyException
| IllegalStateException e) {
throw new SaslException(
getMechanismName()
+ " SASL client failed during key exchange");
}
}
}
}

@Override
public String getMechanismName() {
return super.getMechanismName() + "-DH" + keyLength;
}

@Override
public void dispose() throws SaslException {
keyLength = 0;
keyPairGenerator = null;
keyPair = null;
agreement = null;
keyExchangeComplete = false;
serverPublicKey = null;
secretKey = null;

super.dispose();
}

protected byte[] encrypt(final byte[] challenge) throws SaslException {
if (challenge == null || challenge.length == 0) {
return challenge;
}

try {
return DHUtils
.encryptChallenge(secretKey, challenge)
.array();
} catch (Exception e) {
throw new SaslException(
getMechanismName()
+ " SASL client failed to encrypt");
}
}

protected byte[] decrypt(final byte[] challenge)
throws SaslException {
if (challenge == null || challenge.length == 0) {
return challenge;
}

try {
return DHUtils
.decryptChallenge(secretKey, ByteBuffer.wrap(challenge));
} catch (Exception e) {
throw new SaslException(
getMechanismName()
+ " SASL client failed to encrypt");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

package io.mishmash.stacks.oidc.sasl;

import java.security.AccessController;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;
Expand Down Expand Up @@ -51,8 +50,7 @@ public SaslClient createSaslClient(
+ " server name: " + serverName
+ " props: " + props);

// FIXME: update for Java 21
Subject subject = Subject.getSubject(AccessController.getContext());
Subject subject = Subject.current();

if (subject == null) {
throw new SaslException(
Expand All @@ -69,13 +67,38 @@ public SaslClient createSaslClient(
"Could not find an OIDC client");
}

return new OAUTHBearerClient(client,
// prioritize the first mechanism given
String mechanism = mechanisms[0];
if (OAUTHBearerProvider.MECHANISM.equals(mechanism)) {
return new OAUTHBearerClient(client,
authorizationId,
serverName);
} else if (mechanism.startsWith(
OAUTHBearerProvider.MECHANISM + "-DH")) {
int keyLen = Integer.valueOf(
mechanism.substring(
(OAUTHBearerProvider.MECHANISM + "-DH")
.length()));

switch (keyLen) {
case 4096:
return new OAUTHBearerClientDH(
client,
authorizationId,
serverName,
4096);
default:
throw new SaslException("Unsupported key length: " + keyLen);
}
} else {
throw new SaslException("Unsupported mechanism " + mechanism);
}
}

@Override
public String[] getMechanismNames(final Map<String, ?> props) {
return new String[] {OAUTHBearerProvider.MECHANISM};
return new String[] {
OAUTHBearerProvider.MECHANISM,
OAUTHBearerProvider.MECHANISM + "-DH4096"};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,11 @@ public OAUTHBearerProvider() {
OAUTHBearerClientFactory.class.getName());
put(SaslServerFactory.class.getSimpleName() + "." + MECHANISM,
OAUTHBearerServerFactory.class.getName());
put(SaslClientFactory.class.getSimpleName()
+ "." + MECHANISM + "-DH4096",
OAUTHBearerClientFactory.class.getName());
put(SaslServerFactory.class.getSimpleName()
+ "." + MECHANISM + "-DH4096",
OAUTHBearerServerFactory.class.getName());
}
}
Loading

0 comments on commit f25e64a

Please sign in to comment.