From 4be566062defa249435c4d72eb106fe7b933e023 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Wed, 11 Aug 2021 18:04:55 -0500 Subject: [PATCH] Update Spark key negotiation protocol --- common/network-common/pom.xml | 4 + .../network/crypto/AuthClientBootstrap.java | 6 +- .../spark/network/crypto/AuthEngine.java | 420 ++++++++---------- .../{ServerResponse.java => AuthMessage.java} | 56 +-- .../spark/network/crypto/AuthRpcHandler.java | 6 +- .../spark/network/crypto/ClientChallenge.java | 101 ----- .../org/apache/spark/network/crypto/README.md | 217 ++++----- .../spark/network/crypto/AuthEngineSuite.java | 182 +++++--- .../network/crypto/AuthMessagesSuite.java | 46 +- dev/deps/spark-deps-hadoop-2.6 | 1 + dev/deps/spark-deps-hadoop-2.7 | 1 + dev/deps/spark-deps-hadoop-3.1 | 1 + pom.xml | 6 + 13 files changed, 432 insertions(+), 615 deletions(-) rename common/network-common/src/main/java/org/apache/spark/network/crypto/{ServerResponse.java => AuthMessage.java} (53%) delete mode 100644 common/network-common/src/main/java/org/apache/spark/network/crypto/ClientChallenge.java diff --git a/common/network-common/pom.xml b/common/network-common/pom.xml index cd57c43aae549..d585185263ece 100644 --- a/common/network-common/pom.xml +++ b/common/network-common/pom.xml @@ -85,6 +85,10 @@ org.apache.commons commons-crypto + + com.google.crypto.tink + tink + diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthClientBootstrap.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthClientBootstrap.java index 737e1871c519d..15869894032c6 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthClientBootstrap.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthClientBootstrap.java @@ -98,15 +98,15 @@ private void doSparkAuth(TransportClient client, Channel channel) String secretKey = secretKeyHolder.getSecretKey(appId); try (AuthEngine engine = new AuthEngine(appId, secretKey, conf)) { - ClientChallenge challenge = engine.challenge(); + AuthMessage challenge = engine.challenge(); ByteBuf challengeData = Unpooled.buffer(challenge.encodedLength()); challenge.encode(challengeData); ByteBuffer responseData = client.sendRpcSync(challengeData.nioBuffer(), conf.authRTTimeoutMs()); - ServerResponse response = ServerResponse.decodeMessage(responseData); + AuthMessage response = AuthMessage.decodeMessage(responseData); - engine.validate(response); + engine.deriveSessionCipher(challenge, response); engine.sessionCipher().addToChannel(channel); } } diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java index 64fdb32a67ada..078d9ceb317b8 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java @@ -17,134 +17,216 @@ package org.apache.spark.network.crypto; +import javax.crypto.spec.SecretKeySpec; import java.io.Closeable; -import java.io.IOException; -import java.math.BigInteger; import java.security.GeneralSecurityException; import java.util.Arrays; import java.util.Properties; -import javax.crypto.Cipher; -import javax.crypto.SecretKey; -import javax.crypto.SecretKeyFactory; -import javax.crypto.ShortBufferException; -import javax.crypto.spec.IvParameterSpec; -import javax.crypto.spec.PBEKeySpec; -import javax.crypto.spec.SecretKeySpec; -import static java.nio.charset.StandardCharsets.UTF_8; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; import com.google.common.primitives.Bytes; -import org.apache.commons.crypto.cipher.CryptoCipher; -import org.apache.commons.crypto.cipher.CryptoCipherFactory; -import org.apache.commons.crypto.random.CryptoRandom; -import org.apache.commons.crypto.random.CryptoRandomFactory; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - +import com.google.crypto.tink.subtle.AesGcmJce; +import com.google.crypto.tink.subtle.Hkdf; +import com.google.crypto.tink.subtle.Random; +import com.google.crypto.tink.subtle.X25519; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import static java.nio.charset.StandardCharsets.UTF_8; import org.apache.spark.network.util.TransportConf; /** - * A helper class for abstracting authentication and key negotiation details. This is used by - * both client and server sides, since the operations are basically the same. + * A helper class for abstracting authentication and key negotiation details. + * This supports a forward-secure authentication protocol based on X25519 Diffie-Hellman Key + * Exchange, using a pre-shared key to derive an AES-GCM key encrypting key. */ class AuthEngine implements Closeable { - - private static final Logger LOG = LoggerFactory.getLogger(AuthEngine.class); - private static final BigInteger ONE = new BigInteger(new byte[] { 0x1 }); - - private final byte[] appId; - private final char[] secret; + public static final byte[] INPUT_IV_INFO = "inputIv".getBytes(UTF_8); + public static final byte[] OUTPUT_IV_INFO = "outputIv".getBytes(UTF_8); + private static final String MAC_ALGORITHM = "HMACSHA256"; + private static final int AES_GCM_KEY_SIZE_BYTES = 16; + private static final byte[] EMPTY_TRANSCRIPT = new byte[0]; + + private final String appId; + private final byte[] preSharedSecret; private final TransportConf conf; private final Properties cryptoConf; - private final CryptoRandom random; - - private byte[] authNonce; - - @VisibleForTesting - byte[] challenge; + private byte[] clientPrivateKey; private TransportCipher sessionCipher; - private CryptoCipher encryptor; - private CryptoCipher decryptor; - AuthEngine(String appId, String secret, TransportConf conf) throws GeneralSecurityException { - this.appId = appId.getBytes(UTF_8); + AuthEngine(String appId, String preSharedSecret, TransportConf conf) { + Preconditions.checkNotNull(appId); + Preconditions.checkNotNull(preSharedSecret); + this.appId = appId; + this.preSharedSecret = preSharedSecret.getBytes(UTF_8); this.conf = conf; this.cryptoConf = conf.cryptoConf(); - this.secret = secret.toCharArray(); - this.random = CryptoRandomFactory.getCryptoRandom(cryptoConf); + } + + @VisibleForTesting + void setClientPrivateKey(byte[] privateKey) { + this.clientPrivateKey = privateKey; } /** - * Create the client challenge. + * This method will derive a key from a pre-shared secret, a random salt, and an arbitrary + * transcript. It will then use that derived key to AES-GCM encrypt an ephemeral X25519 public + * key. * - * @return A challenge to be sent the remote side. + * @param ephemeralX25519PublicKey Ephemeral X25519 Public Key to encrypt under a derived key. + * @param transcript Optional byte array representing a protocol transcript, which + * is mixed into the key derivation and included as AES-GCM + * associated authenticated data (AAD). + * @return An encrypted ephemeral X25519 public key. + * @throws GeneralSecurityException If HKDF key deriviation or AES-GCM encryption fails. */ - ClientChallenge challenge() throws GeneralSecurityException { - this.authNonce = randomBytes(conf.encryptionKeyLength() / Byte.SIZE); - SecretKeySpec authKey = generateKey(conf.keyFactoryAlgorithm(), conf.keyFactoryIterations(), - authNonce, conf.encryptionKeyLength()); - initializeForAuth(conf.cipherTransformation(), authNonce, authKey); - - this.challenge = randomBytes(conf.encryptionKeyLength() / Byte.SIZE); - return new ClientChallenge(new String(appId, UTF_8), - conf.keyFactoryAlgorithm(), - conf.keyFactoryIterations(), - conf.cipherTransformation(), - conf.encryptionKeyLength(), - authNonce, - challenge(appId, authNonce, challenge)); + private AuthMessage encryptEphemeralPublicKey( + byte[] ephemeralX25519PublicKey, + byte[] transcript) throws GeneralSecurityException { + // This non-secret salt is used in the HKDF key derivations and will be sent in plaintext as + // part of the AES-GCM encrypted X25519 public key. It will be included as additional + // associated data (AAD). + byte[] nonSecretSalt = Random.randBytes(AES_GCM_KEY_SIZE_BYTES); + // Mix in the app ID, salt, and transcript into HKDF and use it as AES-GCM AAD + byte[] aadState = Bytes.concat(appId.getBytes(UTF_8), nonSecretSalt, transcript); + // Use HKDF to derive an AES_GCM key from the pre-shared key, non-secret salt, and AAD state + byte[] derivedKeyEncryptingKey = Hkdf.computeHkdf( + MAC_ALGORITHM, + preSharedSecret, + nonSecretSalt, + aadState, + AES_GCM_KEY_SIZE_BYTES); + // AES-GCM encrypt the X25519 public key and include the app ID, salt, and transcript as AAD + byte[] aesGcmCiphertext = new AesGcmJce(derivedKeyEncryptingKey) + .encrypt(ephemeralX25519PublicKey, aadState); + return new AuthMessage(appId, nonSecretSalt, aesGcmCiphertext); } /** - * Validates the client challenge, and create the encryption backend for the channel from the - * parameters sent by the client. + * This method will derive a key from a pre-shared secret, a random salt, and an arbitrary + * transcript. It will then use that derived key to AES-GCM encrypt an ephemeral X25519 + * public key. * - * @param clientChallenge The challenge from the client. - * @return A response to be sent to the client. + * @param encryptedPublicKey An X25519 public key to decrypt with a derived key + * @param transcript Optional byte array representing a protocol transcript, which is + * mixed into the key derivation and included as AES-GCM associated + * authenticated data (AAD). + * @return A decrypted ephemeral public key + * @throws GeneralSecurityException If decryption fails, notably if authenticated checks fails. */ - ServerResponse respond(ClientChallenge clientChallenge) - throws GeneralSecurityException { - - SecretKeySpec authKey = generateKey(clientChallenge.kdf, clientChallenge.iterations, - clientChallenge.nonce, clientChallenge.keyLength); - initializeForAuth(clientChallenge.cipher, clientChallenge.nonce, authKey); - - byte[] challenge = validateChallenge(clientChallenge.nonce, clientChallenge.challenge); - byte[] response = challenge(appId, clientChallenge.nonce, rawResponse(challenge)); - byte[] sessionNonce = randomBytes(conf.encryptionKeyLength() / Byte.SIZE); - byte[] inputIv = randomBytes(conf.ivLength()); - byte[] outputIv = randomBytes(conf.ivLength()); + private byte[] decryptEphemeralPublicKey( + AuthMessage encryptedPublicKey, + byte[] transcript) throws GeneralSecurityException { + Preconditions.checkArgument(appId.equals(encryptedPublicKey.appId)); + // Mix in the app ID, salt, and transcript into HKDF and use it as AES-GCM AAD + byte[] aadState = Bytes.concat(appId.getBytes(UTF_8), encryptedPublicKey.salt, transcript); + // Use HKDF to derive an AES_GCM key from the pre-shared key, non-secret salt, and AAD state + byte[] derivedKeyEncryptingKey = Hkdf.computeHkdf( + MAC_ALGORITHM, + preSharedSecret, + encryptedPublicKey.salt, + aadState, + AES_GCM_KEY_SIZE_BYTES); + // If the AES-GCM payload is modified at all or if the AAD state does not match, decryption + // will throw a GeneralSecurityException. + return new AesGcmJce(derivedKeyEncryptingKey) + .decrypt(encryptedPublicKey.ciphertext, aadState); + } - SecretKeySpec sessionKey = generateKey(clientChallenge.kdf, clientChallenge.iterations, - sessionNonce, clientChallenge.keyLength); - this.sessionCipher = new TransportCipher(cryptoConf, clientChallenge.cipher, sessionKey, - inputIv, outputIv); + /** + * Encrypt an ephemeral X25519 public key to be sent to the server as a challenge. + * + * @return An encrypted client ephemeral public key to be sent to the server. + */ + AuthMessage challenge() throws GeneralSecurityException { + setClientPrivateKey(X25519.generatePrivateKey()); + return encryptEphemeralPublicKey( + X25519.publicFromPrivate(clientPrivateKey), + EMPTY_TRANSCRIPT); + } - // Note the IVs are swapped in the response. - return new ServerResponse(response, encrypt(sessionNonce), encrypt(outputIv), encrypt(inputIv)); + /** + * Validates the client challenge by decrypting the ephemeral X25519 public key, computing a + * shared secret from it, then encrypting a server ephemeral X25519 public key for the client. + * + * @param encryptedClientPublicKey The encrypted public key from the client to be decrypted. + * @return An encrypted server ephemeral public key to be sent to the client. + */ + AuthMessage response(AuthMessage encryptedClientPublicKey) throws GeneralSecurityException { + Preconditions.checkArgument(appId.equals(encryptedClientPublicKey.appId)); + // Compute a shared secret given the client public key and the server private key + byte[] clientPublicKey = + decryptEphemeralPublicKey(encryptedClientPublicKey, EMPTY_TRANSCRIPT); + // Generate an ephemeral X25519 private key. + byte[] serverEphemeralPrivateKey = X25519.generatePrivateKey(); + // Encrypt the X25519 public key with a key derived from the preSharedSecret and transcript + AuthMessage ephemeralServerPublicKey = encryptEphemeralPublicKey( + X25519.publicFromPrivate(serverEphemeralPrivateKey), + getTranscript(encryptedClientPublicKey)); + // Compute a shared secret given the client public key and the server private key + byte[] sharedSecret = + X25519.computeSharedSecret(serverEphemeralPrivateKey, clientPublicKey); + byte[] challengeResponseTranscript = + getTranscript(encryptedClientPublicKey, ephemeralServerPublicKey); + this.sessionCipher = + generateTransportCipher(sharedSecret, false, challengeResponseTranscript); + return ephemeralServerPublicKey; } /** * Validates the server response and initializes the cipher to use for the session. * - * @param serverResponse The response from the server. + * @param encryptedClientPublicKey The encrypted ephemeral public key from the client. + * @param encryptedServerPublicKey The encrypted ephemeral public key from the server. */ - void validate(ServerResponse serverResponse) throws GeneralSecurityException { - byte[] response = validateChallenge(authNonce, serverResponse.response); - - byte[] expected = rawResponse(challenge); - Preconditions.checkArgument(Arrays.equals(expected, response)); - - byte[] nonce = decrypt(serverResponse.nonce); - byte[] inputIv = decrypt(serverResponse.inputIv); - byte[] outputIv = decrypt(serverResponse.outputIv); - - SecretKeySpec sessionKey = generateKey(conf.keyFactoryAlgorithm(), conf.keyFactoryIterations(), - nonce, conf.encryptionKeyLength()); - this.sessionCipher = new TransportCipher(cryptoConf, conf.cipherTransformation(), sessionKey, - inputIv, outputIv); + void deriveSessionCipher(AuthMessage encryptedClientPublicKey, + AuthMessage encryptedServerPublicKey) throws GeneralSecurityException { + Preconditions.checkArgument(appId.equals(encryptedClientPublicKey.appId)); + Preconditions.checkArgument(appId.equals(encryptedServerPublicKey.appId)); + // Compute a shared secret given the server public key and the client private key, + // mixing in the protocol transcript. + byte[] serverPublicKey = decryptEphemeralPublicKey( + encryptedServerPublicKey, + getTranscript(encryptedClientPublicKey)); + // Compute a shared secret given the client public key and the server private key + byte[] sharedSecret = X25519.computeSharedSecret(clientPrivateKey, serverPublicKey); + byte[] challengeResponseTranscript = + getTranscript(encryptedClientPublicKey, encryptedServerPublicKey); + this.sessionCipher = + generateTransportCipher(sharedSecret, true, challengeResponseTranscript); + } + + private TransportCipher generateTransportCipher( + byte[] sharedSecret, + boolean isClient, + byte[] transcript) throws GeneralSecurityException { + byte[] clientIv = Hkdf.computeHkdf( + MAC_ALGORITHM, + sharedSecret, + transcript, // Passing this as the HKDF salt + INPUT_IV_INFO, // This is the HKDF info field used to differentiate IV values + AES_GCM_KEY_SIZE_BYTES); + byte[] serverIv = Hkdf.computeHkdf( + MAC_ALGORITHM, + sharedSecret, + transcript, // Passing this as the HKDF salt + OUTPUT_IV_INFO, // This is the HKDF info field used to differentiate IV values + AES_GCM_KEY_SIZE_BYTES); + SecretKeySpec sessionKey = new SecretKeySpec(sharedSecret, "AES"); + return new TransportCipher( + cryptoConf, + conf.cipherTransformation(), + sessionKey, + isClient ? clientIv : serverIv, // If it's the client, use the client IV first + isClient ? serverIv : clientIv); + } + + private byte[] getTranscript(AuthMessage... encryptedPublicKeys) { + ByteBuf transcript = Unpooled.buffer( + Arrays.stream(encryptedPublicKeys).mapToInt(k -> k.encodedLength()).sum()); + Arrays.stream(encryptedPublicKeys).forEachOrdered(k -> k.encode(transcript)); + return transcript.array(); } TransportCipher sessionCipher() { @@ -153,163 +235,7 @@ TransportCipher sessionCipher() { } @Override - public void close() throws IOException { - // Close ciphers (by calling "doFinal()" with dummy data) and the random instance so that - // internal state is cleaned up. Error handling here is just for paranoia, and not meant to - // accurately report the errors when they happen. - RuntimeException error = null; - byte[] dummy = new byte[8]; - if (encryptor != null) { - try { - doCipherOp(Cipher.ENCRYPT_MODE, dummy, true); - } catch (Exception e) { - error = new RuntimeException(e); - } - encryptor = null; - } - if (decryptor != null) { - try { - doCipherOp(Cipher.DECRYPT_MODE, dummy, true); - } catch (Exception e) { - error = new RuntimeException(e); - } - decryptor = null; - } - random.close(); - - if (error != null) { - throw error; - } - } + public void close() { - @VisibleForTesting - byte[] challenge(byte[] appId, byte[] nonce, byte[] challenge) throws GeneralSecurityException { - return encrypt(Bytes.concat(appId, nonce, challenge)); } - - @VisibleForTesting - byte[] rawResponse(byte[] challenge) { - BigInteger orig = new BigInteger(challenge); - BigInteger response = orig.add(ONE); - return response.toByteArray(); - } - - private byte[] decrypt(byte[] in) throws GeneralSecurityException { - return doCipherOp(Cipher.DECRYPT_MODE, in, false); - } - - private byte[] encrypt(byte[] in) throws GeneralSecurityException { - return doCipherOp(Cipher.ENCRYPT_MODE, in, false); - } - - private void initializeForAuth(String cipher, byte[] nonce, SecretKeySpec key) - throws GeneralSecurityException { - - // commons-crypto currently only supports ciphers that require an initial vector; so - // create a dummy vector so that we can initialize the ciphers. In the future, if - // different ciphers are supported, this will have to be configurable somehow. - byte[] iv = new byte[conf.ivLength()]; - System.arraycopy(nonce, 0, iv, 0, Math.min(nonce.length, iv.length)); - - CryptoCipher _encryptor = CryptoCipherFactory.getCryptoCipher(cipher, cryptoConf); - _encryptor.init(Cipher.ENCRYPT_MODE, key, new IvParameterSpec(iv)); - this.encryptor = _encryptor; - - CryptoCipher _decryptor = CryptoCipherFactory.getCryptoCipher(cipher, cryptoConf); - _decryptor.init(Cipher.DECRYPT_MODE, key, new IvParameterSpec(iv)); - this.decryptor = _decryptor; - } - - /** - * Validates an encrypted challenge as defined in the protocol, and returns the byte array - * that corresponds to the actual challenge data. - */ - private byte[] validateChallenge(byte[] nonce, byte[] encryptedChallenge) - throws GeneralSecurityException { - - byte[] challenge = decrypt(encryptedChallenge); - checkSubArray(appId, challenge, 0); - checkSubArray(nonce, challenge, appId.length); - return Arrays.copyOfRange(challenge, appId.length + nonce.length, challenge.length); - } - - private SecretKeySpec generateKey(String kdf, int iterations, byte[] salt, int keyLength) - throws GeneralSecurityException { - - SecretKeyFactory factory = SecretKeyFactory.getInstance(kdf); - PBEKeySpec spec = new PBEKeySpec(secret, salt, iterations, keyLength); - - long start = System.nanoTime(); - SecretKey key = factory.generateSecret(spec); - long end = System.nanoTime(); - - LOG.debug("Generated key with {} iterations in {} us.", conf.keyFactoryIterations(), - (end - start) / 1000); - - return new SecretKeySpec(key.getEncoded(), conf.keyAlgorithm()); - } - - private byte[] doCipherOp(int mode, byte[] in, boolean isFinal) - throws GeneralSecurityException { - - CryptoCipher cipher; - switch (mode) { - case Cipher.ENCRYPT_MODE: - cipher = encryptor; - break; - case Cipher.DECRYPT_MODE: - cipher = decryptor; - break; - default: - throw new IllegalArgumentException(String.valueOf(mode)); - } - - Preconditions.checkState(cipher != null, "Cipher is invalid because of previous error."); - - try { - int scale = 1; - while (true) { - int size = in.length * scale; - byte[] buffer = new byte[size]; - try { - int outSize = isFinal ? cipher.doFinal(in, 0, in.length, buffer, 0) - : cipher.update(in, 0, in.length, buffer, 0); - if (outSize != buffer.length) { - byte[] output = new byte[outSize]; - System.arraycopy(buffer, 0, output, 0, output.length); - return output; - } else { - return buffer; - } - } catch (ShortBufferException e) { - // Try again with a bigger buffer. - scale *= 2; - } - } - } catch (InternalError ie) { - // SPARK-25535. The commons-cryto library will throw InternalError if something goes wrong, - // and leave bad state behind in the Java wrappers, so it's not safe to use them afterwards. - if (mode == Cipher.ENCRYPT_MODE) { - this.encryptor = null; - } else { - this.decryptor = null; - } - throw ie; - } - } - - private byte[] randomBytes(int count) { - byte[] bytes = new byte[count]; - random.nextBytes(bytes); - return bytes; - } - - /** Checks that the "test" array is in the data array starting at the given offset. */ - private void checkSubArray(byte[] test, byte[] data, int offset) { - Preconditions.checkArgument(data.length >= test.length + offset); - for (int i = 0; i < test.length; i++) { - Preconditions.checkArgument(test[i] == data[i + offset]); - } - } - } diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/ServerResponse.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthMessage.java similarity index 53% rename from common/network-common/src/main/java/org/apache/spark/network/crypto/ServerResponse.java rename to common/network-common/src/main/java/org/apache/spark/network/crypto/AuthMessage.java index caf3a0f3b38cc..76690cbc4c29f 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/ServerResponse.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthMessage.java @@ -21,65 +21,55 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; - import org.apache.spark.network.protocol.Encodable; import org.apache.spark.network.protocol.Encoders; /** - * Server's response to client's challenge. + * A message sent in the forward secure authentication protocol, containing an app ID, a salt for + * key derivation, and an encrypted payload. * - * Please see crypto/README.md for more details. + * Please see crypto/README.md for more details of implementation. */ -public class ServerResponse implements Encodable { +class AuthMessage implements Encodable { /** Serialization tag used to catch incorrect payloads. */ private static final byte TAG_BYTE = (byte) 0xFB; - public final byte[] response; - public final byte[] nonce; - public final byte[] inputIv; - public final byte[] outputIv; + public final String appId; + public final byte[] salt; + public final byte[] ciphertext; - public ServerResponse( - byte[] response, - byte[] nonce, - byte[] inputIv, - byte[] outputIv) { - this.response = response; - this.nonce = nonce; - this.inputIv = inputIv; - this.outputIv = outputIv; + AuthMessage(String appId, byte[] salt, byte[] ciphertext) { + this.appId = appId; + this.salt = salt; + this.ciphertext = ciphertext; } @Override public int encodedLength() { return 1 + - Encoders.ByteArrays.encodedLength(response) + - Encoders.ByteArrays.encodedLength(nonce) + - Encoders.ByteArrays.encodedLength(inputIv) + - Encoders.ByteArrays.encodedLength(outputIv); + Encoders.Strings.encodedLength(appId) + + Encoders.ByteArrays.encodedLength(salt) + + Encoders.ByteArrays.encodedLength(ciphertext); } @Override public void encode(ByteBuf buf) { buf.writeByte(TAG_BYTE); - Encoders.ByteArrays.encode(buf, response); - Encoders.ByteArrays.encode(buf, nonce); - Encoders.ByteArrays.encode(buf, inputIv); - Encoders.ByteArrays.encode(buf, outputIv); + Encoders.Strings.encode(buf, appId); + Encoders.ByteArrays.encode(buf, salt); + Encoders.ByteArrays.encode(buf, ciphertext); } - public static ServerResponse decodeMessage(ByteBuffer buffer) { + public static AuthMessage decodeMessage(ByteBuffer buffer) { ByteBuf buf = Unpooled.wrappedBuffer(buffer); if (buf.readByte() != TAG_BYTE) { - throw new IllegalArgumentException("Expected ServerResponse, received something else."); + throw new IllegalArgumentException("Expected ClientChallenge, received something else."); } - return new ServerResponse( - Encoders.ByteArrays.decode(buf), - Encoders.ByteArrays.decode(buf), - Encoders.ByteArrays.decode(buf), - Encoders.ByteArrays.decode(buf)); + return new AuthMessage( + Encoders.Strings.decode(buf), // AppID + Encoders.ByteArrays.decode(buf), // Salt + Encoders.ByteArrays.decode(buf)); // Ciphertext } - } diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java index dd31c955350f1..549ee4df467c0 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java @@ -84,9 +84,9 @@ protected boolean doAuthChallenge( int position = message.position(); int limit = message.limit(); - ClientChallenge challenge; + AuthMessage challenge; try { - challenge = ClientChallenge.decodeMessage(message); + challenge = AuthMessage.decodeMessage(message); LOG.debug("Received new auth challenge for client {}.", channel.remoteAddress()); } catch (RuntimeException e) { if (conf.saslFallback()) { @@ -113,7 +113,7 @@ protected boolean doAuthChallenge( "Trying to authenticate non-registered app %s.", challenge.appId); LOG.debug("Authenticating challenge for app {}.", challenge.appId); engine = new AuthEngine(challenge.appId, secret, conf); - ServerResponse response = engine.respond(challenge); + AuthMessage response = engine.response(challenge); ByteBuf responseData = Unpooled.buffer(response.encodedLength()); response.encode(responseData); callback.onSuccess(responseData.nioBuffer()); diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/ClientChallenge.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/ClientChallenge.java deleted file mode 100644 index 819b8a7efbdba..0000000000000 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/ClientChallenge.java +++ /dev/null @@ -1,101 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.crypto; - -import java.nio.ByteBuffer; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.Unpooled; - -import org.apache.spark.network.protocol.Encodable; -import org.apache.spark.network.protocol.Encoders; - -/** - * The client challenge message, used to initiate authentication. - * - * Please see crypto/README.md for more details of implementation. - */ -public class ClientChallenge implements Encodable { - /** Serialization tag used to catch incorrect payloads. */ - private static final byte TAG_BYTE = (byte) 0xFA; - - public final String appId; - public final String kdf; - public final int iterations; - public final String cipher; - public final int keyLength; - public final byte[] nonce; - public final byte[] challenge; - - public ClientChallenge( - String appId, - String kdf, - int iterations, - String cipher, - int keyLength, - byte[] nonce, - byte[] challenge) { - this.appId = appId; - this.kdf = kdf; - this.iterations = iterations; - this.cipher = cipher; - this.keyLength = keyLength; - this.nonce = nonce; - this.challenge = challenge; - } - - @Override - public int encodedLength() { - return 1 + 4 + 4 + - Encoders.Strings.encodedLength(appId) + - Encoders.Strings.encodedLength(kdf) + - Encoders.Strings.encodedLength(cipher) + - Encoders.ByteArrays.encodedLength(nonce) + - Encoders.ByteArrays.encodedLength(challenge); - } - - @Override - public void encode(ByteBuf buf) { - buf.writeByte(TAG_BYTE); - Encoders.Strings.encode(buf, appId); - Encoders.Strings.encode(buf, kdf); - buf.writeInt(iterations); - Encoders.Strings.encode(buf, cipher); - buf.writeInt(keyLength); - Encoders.ByteArrays.encode(buf, nonce); - Encoders.ByteArrays.encode(buf, challenge); - } - - public static ClientChallenge decodeMessage(ByteBuffer buffer) { - ByteBuf buf = Unpooled.wrappedBuffer(buffer); - - if (buf.readByte() != TAG_BYTE) { - throw new IllegalArgumentException("Expected ClientChallenge, received something else."); - } - - return new ClientChallenge( - Encoders.Strings.decode(buf), - Encoders.Strings.decode(buf), - buf.readInt(), - Encoders.Strings.decode(buf), - buf.readInt(), - Encoders.ByteArrays.decode(buf), - Encoders.ByteArrays.decode(buf)); - } - -} diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/README.md b/common/network-common/src/main/java/org/apache/spark/network/crypto/README.md index 14df703270498..78e7459b9995d 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/README.md +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/README.md @@ -1,158 +1,101 @@ -Spark Auth Protocol and AES Encryption Support +Forward Secure Auth Protocol ============================================== -This file describes an auth protocol used by Spark as a more secure alternative to DIGEST-MD5. This -protocol is built on symmetric key encryption, based on the assumption that the two endpoints being -authenticated share a common secret, which is how Spark authentication currently works. The protocol -provides mutual authentication, meaning that after the negotiation both parties know that the remote -side knows the shared secret. The protocol is influenced by the ISO/IEC 9798 protocol, although it's -not an implementation of it. +This file describes a forward secure authentication protocol which may be used by Spark. This +protocol is essentially ephemeral Diffie-Hellman key exchange using Curve25519, referred to as +X25519. -This protocol could be replaced with TLS PSK, except no PSK ciphers are available in the currently -released JREs. +Both client and server share a (possibly low-entropy) pre-shared secret that is used to derive a +key-encrypting key using HKDF. This will mix in any preceding protocol transcript. -The protocol aims at solving the following shortcomings in Spark's current usage of DIGEST-MD5: - -- MD5 is an aging hash algorithm with known weaknesses, and a more secure alternative is desired. -- DIGEST-MD5 has a pre-defined set of ciphers for which it can generate keys. The only - viable, supported cipher these days is 3DES, and a more modern alternative is desired. -- Encrypting AES session keys with 3DES doesn't solve the issue, since the weakest link - in the negotiation would still be MD5 and 3DES. - -The protocol assumes that the shared secret is generated and distributed in a secure manner. - -The protocol always negotiates encryption keys. If encryption is not desired, the existing -SASL-based authentication, or no authentication at all, can be chosen instead. - -When messages are described below, it's expected that the implementation should support -arbitrary sizes for fields that don't have a fixed size. +The key-encrypting key is used to encrypt an X25519 public key with AES-GCM. This is intended to +authenticate the message exchange between the parties and there is no expectation of secrecy for +the public key. This protocol utilizes GCM's associated authenticated data (AAD) field to include +metadata and the prior protocol transcript, to bind each round with all preceding rounds. Client Challenge ---------------- -The auth negotiation is started by the client. The client starts by generating an encryption -key based on the application's shared secret, and a nonce. - - KEY = KDF(SECRET, SALT, KEY_LENGTH) - -Where: -- KDF(): a key derivation function that takes a secret, a salt, a configurable number of - iterations, and a configurable key length. -- SALT: a byte sequence used to salt the key derivation function. -- KEY_LENGTH: length of the encryption key to generate. - +The auth negotiation is started by the client. Given an application ID, the client starts by +generating a random 16-byte salt value and deriving a key encryption key: -The client generates a message with the following content: + preSharedKey = lookupKey(appId) + nonSecretSalt = Random(16 bytes) + aadState = Concat(appId, nonSecretSalt) + keyEncryptingKey = HKDF(preSharedKey, nonSecretSalt, aadState) - CLIENT_CHALLENGE = ( - APP_ID, - KDF, - ITERATIONS, - CIPHER, - KEY_LENGTH, - ANONCE, - ENC(APP_ID || ANONCE || CHALLENGE)) +This key encryption key is then used to encrypt an ephemeral X25519 public key. -Where: + clientKeyPair = X25519.generate() + randomIV = Random(16 bytes) + ciphertext = AES-GCM-Encrypt( + key = keyEncryptingKey, + iv = randomIV, + plaintext = clientKeyPair.publicKey(), + aad = aadState) + clientChallenge = (appId, nonSecretSalt, randomIV, ciphertext) -- APP_ID: the application ID which the server uses to identify the shared secret. -- KDF: the key derivation function described above. -- ITERATIONS: number of iterations to run the KDF when generating keys. -- CIPHER: the cipher used to encrypt data. -- KEY_LENGTH: length of the encryption keys to generate, in bits. -- ANONCE: the nonce used as the salt when generating the auth key. -- ENC(): an encryption function that uses the cipher and the generated key. This function - will also be used in the definition of other messages below. -- CHALLENGE: a byte sequence used as a challenge to the server. -- ||: concatenation operator. - -When strings are used where byte arrays are expected, the UTF-8 representation of the string -is assumed. - -To respond to the challenge, the server should consider the byte array as representing an -arbitrary-length integer, and respond with the value of the integer plus one. +Note that the App ID and non-secret salt are bound to the ciphertext both through HKDF key +derivation and AES-GCM AAD. We are not relying on keeping the client public key secret and could +alternatively compute a MAC rather than encrypting with AES-GCM. +The client sends this challenge to a server. Server Response And Challenge ----------------------------- -Once the client challenge is received, the server will generate the same auth key by -using the same algorithm the client has used. It will then verify the client challenge: -if the APP_ID and ANONCE fields match, the server knows that the client has the shared -secret. The server then creates a response to the client challenge, to prove that it also -has the secret key, and provides parameters to be used when creating the session key. - -The following describes the response from the server: - - SERVER_CHALLENGE = ( - ENC(APP_ID || ANONCE || RESPONSE), - ENC(SNONCE), - ENC(INIV), - ENC(OUTIV)) - -Where: - -- RESPONSE: the server's response to the client challenge. -- SNONCE: a nonce to be used as salt when generating the session key. -- INIV: initialization vector used to initialize the input channel of the client. -- OUTIV: initialization vector used to initialize the output channel of the client. - -At this point the server considers the client to be authenticated, and will try to -decrypt any data further sent by the client using the session key. - - -Default Algorithms ------------------- - -Configuration options are available for the KDF and cipher algorithms to use. - -The default KDF is "PBKDF2WithHmacSHA1". Users should be able to select any algorithm -from those supported by the `javax.crypto.SecretKeyFactory` class, as long as they support -PBEKeySpec when generating keys. The default number of iterations was chosen to take a -reasonable amount of time on modern CPUs. See the documentation in TransportConf for more -details. - -The default cipher algorithm is "AES/CTR/NoPadding". Users should be able to select any -algorithm supported by the commons-crypto library. It should allow the cipher to operate -in stream mode. - -The default key length is 128 (bits). - - -Implementation Details ----------------------- - -The commons-crypto library currently only supports AES ciphers, and requires an initialization -vector (IV). This first version of the protocol does not explicitly include the IV in the client -challenge message. Instead, the IV should be derived from the nonce, including the needed bytes, and -padding the IV with zeroes in case the nonce is not long enough. - -Future versions of the protocol might add support for new ciphers and explicitly include needed -configuration parameters in the messages. - - -Threat Assessment +Once the client challenge is received, the server will derive the same key encryption key and +recover the client's public key: + + assert(appId = clientChallenge.appId) + preSharedKey = lookupKey(appId) + aadState = Concat(appId, clientChallenge.nonSecretSalt) + keyEncryptingKey = HKDF(preSharedKey, nonSecretSalt, aadState) + clientPublicKey = AES-GCM-Decrypt( + key = keyEncryptingKey, + iv = clientChallenge.randomIV, + ciphertext = clientChallenge.ciphertext, + aad = aadState) + +The server can then send its own ephemeral public key to the client, encrypted under a key derived +from the pre-shared key and the protocol transcript so far: + + preSharedKey = lookupKey(appId) + nonSecretSalt = Random(16 bytes) + aadState = Concat(appId, nonSecretSalt, clientChallenge) + keyEncryptingKey = HKDF(preSharedKey, nonSecretSalt, aadState) + randomIV = Random(16 bytes) + serverKeyPair = X25519.generate() + ciphertext = AES-GCM-Encrypt( + key = keyEncryptingKey, + iv = randomIV, + plaintext = serverKeyPair.publicKey(), + aad = aadState) + serverResponse = (appId, nonSecretSalt, randomIV, ciphertext) + +Now that the server has the client's ephemeral public key, it can generate its own ephemeral +keypair and compute a shared secret. + + sharedSecret = X25519.computeSharedSecret(clientPublicKey, serverKeyPair.privateKey()) + +With the shared secret, the server will also generate two initialization vectors to be used for +inbound and outbound streams. These IVs are not secret and will be bound to the preceding protocol +transcript in order to be deterministic by both parties. + + clientIv = HKDF(sharedSecret, salt=transcript, info="clientIv") + serverIv = HKDF(sharedSecret, salt=transcript, info="serverIv") + +The server can then send its response to the client, who can decrypt the server's ephemeral public +key, and reconstruct the same shared secret and IVs. + +Security Comments ----------------- -The protocol is secure against different forms of attack: - -* Eavesdropping: the protocol is built on the assumption that it's computationally infeasible - to calculate the original secret from the encrypted messages. Neither the secret nor any - encryption keys are transmitted on the wire, encrypted or not. - -* Man-in-the-middle: because the protocol performs mutual authentication, both ends need to - know the shared secret to be able to decrypt session data. Even if an attacker is able to insert a - malicious "proxy" between endpoints, the attacker won't be able to read any of the data exchanged - between client and server, nor insert arbitrary commands for the server to execute. +This protocol is essentially a [NNpsk0](http://www.noiseprotocol.org/noise.html#pattern-modifiers) +pattern in the [Noise framework](http://www.noiseprotocol.org/) built around ECDHE using X25519 as +the underlying curve. If the pre-shared key is compromised, it does not allow for recovery of past +sessions. It would, however, allow impersonation of future sessions. -* Replay attacks: the use of nonces when generating keys prevents an attacker from being able to - just replay messages sniffed from the communication channel. +In the event of a pre-shared key compromise, messages would still be confidential from a passive +observer. Only active adversaries spoofing a session would be able to recover plaintext. -An attacker may replay the client challenge and successfully "prove" to a server that it "knows" the -shared secret. But the attacker won't be able to decrypt the server's response, and thus won't be -able to generate a session key, which will make it hard to craft a valid, encrypted message that the -server will be able to understand. This will cause the server to close the connection as soon as the -attacker tries to send any command to the server. The attacker can just hold the channel open for -some time, which will be closed when the server times out the channel. These issues could be -separately mitigated by adding a shorter timeout for the first message after authentication, and -potentially by adding host blacklists if a possible attack is detected from a particular host. diff --git a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java index 382b7337d715f..fbd8a55bc1b20 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java @@ -19,30 +19,42 @@ import java.nio.ByteBuffer; import java.nio.channels.WritableByteChannel; +import java.security.GeneralSecurityException; import java.util.Arrays; import java.util.Map; -import java.security.InvalidKeyException; import java.util.Random; -import static java.nio.charset.StandardCharsets.UTF_8; - import com.google.common.collect.ImmutableMap; +import com.google.crypto.tink.subtle.Hex; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.netty.channel.FileRegion; +import org.apache.spark.network.util.ByteArrayWritableChannel; +import org.apache.spark.network.util.MapConfigProvider; +import org.apache.spark.network.util.TransportConf; +import static org.junit.Assert.*; import org.junit.BeforeClass; import org.junit.Test; +import static org.mockito.Mockito.*; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; -import static org.junit.Assert.*; -import static org.mockito.Mockito.*; - -import org.apache.spark.network.util.ByteArrayWritableChannel; -import org.apache.spark.network.util.MapConfigProvider; -import org.apache.spark.network.util.TransportConf; public class AuthEngineSuite { + private static final String clientPrivate = + "efe6b68b3fce92158e3637f6ef9d937e75558928dd4b401de04b43d300a73186"; + private static final String clientChallengeHex = + "fb00000005617070496400000010890b6e960f48e998777267a7e4e623220000003c48ad7dc7ec9466da9" + + "3bda9f11488dc9404050e02c661d87d67c782444944c6e369b27e0a416c30845a2d9e64271511ca98b41d" + + "65f8c426e18ff380f6"; + private static final String serverResponseHex = + "fb00000005617070496400000010708451c9dd2792c97c1ca66e6df449ef0000003c64fe899ecdaf458d4" + + "e25e9d5c5a380b8e6d1a184692fac065ed84f8592c18e9629f9c636809dca2ffc041f20346eb53db78738" + + "08ecad08b46b5ee3ff"; + private static final String sharedKey = + "31963f15a320d5c90333f7ecf5cf3a31c7eaf151de07fef8494663a9f47cfd31"; + private static final String inputIv = "fc6a5dc8b90a9dad8f54f08b51a59ed2"; + private static final String outputIv = "a72709baf00785cad6329ce09f631f71"; private static TransportConf conf; @BeforeClass @@ -56,9 +68,9 @@ public void testAuthEngine() throws Exception { AuthEngine server = new AuthEngine("appId", "secret", conf); try { - ClientChallenge clientChallenge = client.challenge(); - ServerResponse serverResponse = server.respond(clientChallenge); - client.validate(serverResponse); + AuthMessage clientChallenge = client.challenge(); + AuthMessage serverResponse = server.response(clientChallenge); + client.deriveSessionCipher(clientChallenge, serverResponse); TransportCipher serverCipher = server.sessionCipher(); TransportCipher clientCipher = client.sessionCipher(); @@ -72,50 +84,113 @@ public void testAuthEngine() throws Exception { } } - @Test - public void testMismatchedSecret() throws Exception { - AuthEngine client = new AuthEngine("appId", "secret", conf); - AuthEngine server = new AuthEngine("appId", "different_secret", conf); + @Test(expected = IllegalArgumentException.class) + public void testCorruptChallengeAppId() throws Exception { + + try (AuthEngine client = new AuthEngine("appId", "secret", conf); + AuthEngine server = new AuthEngine("appId", "secret", conf)) { + AuthMessage clientChallenge = client.challenge(); + AuthMessage corruptChallenge = + new AuthMessage("junk", clientChallenge.salt, clientChallenge.ciphertext); + AuthMessage serverResponse = server.response(corruptChallenge); + } + } - ClientChallenge clientChallenge = client.challenge(); - try { - server.respond(clientChallenge); - fail("Should have failed to validate response."); - } catch (IllegalArgumentException e) { - // Expected. + @Test(expected = GeneralSecurityException.class) + public void testCorruptChallengeSalt() throws Exception { + + try (AuthEngine client = new AuthEngine("appId", "secret", conf); + AuthEngine server = new AuthEngine("appId", "secret", conf)) { + AuthMessage clientChallenge = client.challenge(); + clientChallenge.salt[0] ^= 1; + AuthMessage serverResponse = server.response(clientChallenge); } } - @Test(expected = IllegalArgumentException.class) - public void testWrongAppId() throws Exception { - AuthEngine engine = new AuthEngine("appId", "secret", conf); - ClientChallenge challenge = engine.challenge(); - - byte[] badChallenge = engine.challenge(new byte[] { 0x00 }, challenge.nonce, - engine.rawResponse(engine.challenge)); - engine.respond(new ClientChallenge(challenge.appId, challenge.kdf, challenge.iterations, - challenge.cipher, challenge.keyLength, challenge.nonce, badChallenge)); + @Test(expected = GeneralSecurityException.class) + public void testCorruptChallengeCiphertext() throws Exception { + + try (AuthEngine client = new AuthEngine("appId", "secret", conf); + AuthEngine server = new AuthEngine("appId", "secret", conf)) { + AuthMessage clientChallenge = client.challenge(); + clientChallenge.ciphertext[0] ^= 1; + AuthMessage serverResponse = server.response(clientChallenge); + } } @Test(expected = IllegalArgumentException.class) - public void testWrongNonce() throws Exception { - AuthEngine engine = new AuthEngine("appId", "secret", conf); - ClientChallenge challenge = engine.challenge(); - - byte[] badChallenge = engine.challenge(challenge.appId.getBytes(UTF_8), new byte[] { 0x00 }, - engine.rawResponse(engine.challenge)); - engine.respond(new ClientChallenge(challenge.appId, challenge.kdf, challenge.iterations, - challenge.cipher, challenge.keyLength, challenge.nonce, badChallenge)); + public void testCorruptResponseAppId() throws Exception { + + try (AuthEngine client = new AuthEngine("appId", "secret", conf); + AuthEngine server = new AuthEngine("appId", "secret", conf)) { + AuthMessage clientChallenge = client.challenge(); + AuthMessage serverResponse = server.response(clientChallenge); + AuthMessage corruptResponse = + new AuthMessage("junk", serverResponse.salt, serverResponse.ciphertext); + client.deriveSessionCipher(clientChallenge, corruptResponse); + } } - @Test(expected = IllegalArgumentException.class) - public void testBadChallenge() throws Exception { - AuthEngine engine = new AuthEngine("appId", "secret", conf); - ClientChallenge challenge = engine.challenge(); + @Test(expected = GeneralSecurityException.class) + public void testCorruptResponseSalt() throws Exception { + + try (AuthEngine client = new AuthEngine("appId", "secret", conf); + AuthEngine server = new AuthEngine("appId", "secret", conf)) { + AuthMessage clientChallenge = client.challenge(); + AuthMessage serverResponse = server.response(clientChallenge); + serverResponse.salt[0] ^= 1; + client.deriveSessionCipher(clientChallenge, serverResponse); + } + } + + @Test(expected = GeneralSecurityException.class) + public void testCorruptServerCiphertext() throws Exception { + + try (AuthEngine client = new AuthEngine("appId", "secret", conf); + AuthEngine server = new AuthEngine("appId", "secret", conf)) { + AuthMessage clientChallenge = client.challenge(); + AuthMessage serverResponse = server.response(clientChallenge); + serverResponse.ciphertext[0] ^= 1; + client.deriveSessionCipher(clientChallenge, serverResponse); + } + } + + @Test + public void testFixedChallenge() throws Exception { + try (AuthEngine server = new AuthEngine("appId", "secret", conf)) { + AuthMessage clientChallenge = + AuthMessage.decodeMessage(ByteBuffer.wrap(Hex.decode(clientChallengeHex))); + // This tests that the server will accept an old challenge as expected. However, + // it will generate a fresh ephemeral keypair, so we can't replay an old session. + AuthMessage freshServerResponse = server.response(clientChallenge); + } + } - byte[] badChallenge = new byte[challenge.challenge.length]; - engine.respond(new ClientChallenge(challenge.appId, challenge.kdf, challenge.iterations, - challenge.cipher, challenge.keyLength, challenge.nonce, badChallenge)); + @Test + public void testFixedChallengeResponse() throws Exception { + try (AuthEngine client = new AuthEngine("appId", "secret", conf)) { + byte[] clientPrivateKey = Hex.decode(clientPrivate); + client.setClientPrivateKey(clientPrivateKey); + AuthMessage clientChallenge = + AuthMessage.decodeMessage(ByteBuffer.wrap(Hex.decode(clientChallengeHex))); + AuthMessage serverResponse = + AuthMessage.decodeMessage(ByteBuffer.wrap(Hex.decode(serverResponseHex))); + // Verify that the client will accept an old transcript. + client.deriveSessionCipher(clientChallenge, serverResponse); + TransportCipher clientCipher = client.sessionCipher(); + assertEquals(Hex.encode(clientCipher.getKey().getEncoded()), sharedKey); + assertEquals(Hex.encode(clientCipher.getInputIv()), inputIv); + assertEquals(Hex.encode(clientCipher.getOutputIv()), outputIv); + } + } + + @Test(expected = GeneralSecurityException.class) + public void testMismatchedSecret() throws Exception { + try (AuthEngine client = new AuthEngine("appId", "secret", conf); + AuthEngine server = new AuthEngine("appId", "different_secret", conf)) { + AuthMessage clientChallenge = client.challenge(); + server.response(clientChallenge); + } } @Test @@ -123,9 +198,9 @@ public void testEncryptedMessage() throws Exception { AuthEngine client = new AuthEngine("appId", "secret", conf); AuthEngine server = new AuthEngine("appId", "secret", conf); try { - ClientChallenge clientChallenge = client.challenge(); - ServerResponse serverResponse = server.respond(clientChallenge); - client.validate(serverResponse); + AuthMessage clientChallenge = client.challenge(); + AuthMessage serverResponse = server.response(clientChallenge); + client.deriveSessionCipher(clientChallenge, serverResponse); TransportCipher cipher = server.sessionCipher(); TransportCipher.EncryptionHandler handler = new TransportCipher.EncryptionHandler(cipher); @@ -151,9 +226,9 @@ public void testEncryptedMessageWhenTransferringZeroBytes() throws Exception { AuthEngine client = new AuthEngine("appId", "secret", conf); AuthEngine server = new AuthEngine("appId", "secret", conf); try { - ClientChallenge clientChallenge = client.challenge(); - ServerResponse serverResponse = server.respond(clientChallenge); - client.validate(serverResponse); + AuthMessage clientChallenge = client.challenge(); + AuthMessage serverResponse = server.response(clientChallenge); + client.deriveSessionCipher(clientChallenge, serverResponse); TransportCipher cipher = server.sessionCipher(); TransportCipher.EncryptionHandler handler = new TransportCipher.EncryptionHandler(cipher); @@ -193,7 +268,7 @@ public Long answer(InvocationOnMock invocationOnMock) throws Throwable { } } - @Test(expected = InvalidKeyException.class) + @Test(expected = AssertionError.class) public void testBadKeySize() throws Exception { Map mconf = ImmutableMap.of("spark.network.crypto.keyLength", "42"); TransportConf conf = new TransportConf("rpc", new MapConfigProvider(mconf)); @@ -201,7 +276,6 @@ public void testBadKeySize() throws Exception { try (AuthEngine engine = new AuthEngine("appId", "secret", conf)) { engine.challenge(); fail("Should have failed to create challenge message."); - // Call close explicitly to make sure it's idempotent. engine.close(); } diff --git a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthMessagesSuite.java b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthMessagesSuite.java index a90ff247da4fc..baed940369151 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthMessagesSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthMessagesSuite.java @@ -17,15 +17,11 @@ package org.apache.spark.network.crypto; -import java.nio.ByteBuffer; -import java.util.Arrays; - import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertArrayEquals; import org.junit.Test; -import static org.junit.Assert.*; - -import org.apache.spark.network.protocol.Encodable; public class AuthMessagesSuite { @@ -42,39 +38,15 @@ private static byte[] byteArray() { } return bytes; } - private static int integer() { - return COUNTER++; - } - - @Test - public void testClientChallenge() { - ClientChallenge msg = new ClientChallenge(string(), string(), integer(), string(), integer(), - byteArray(), byteArray()); - ClientChallenge decoded = ClientChallenge.decodeMessage(encode(msg)); - - assertEquals(msg.appId, decoded.appId); - assertEquals(msg.kdf, decoded.kdf); - assertEquals(msg.iterations, decoded.iterations); - assertEquals(msg.cipher, decoded.cipher); - assertEquals(msg.keyLength, decoded.keyLength); - assertTrue(Arrays.equals(msg.nonce, decoded.nonce)); - assertTrue(Arrays.equals(msg.challenge, decoded.challenge)); - } - @Test - public void testServerResponse() { - ServerResponse msg = new ServerResponse(byteArray(), byteArray(), byteArray(), byteArray()); - ServerResponse decoded = ServerResponse.decodeMessage(encode(msg)); - assertTrue(Arrays.equals(msg.response, decoded.response)); - assertTrue(Arrays.equals(msg.nonce, decoded.nonce)); - assertTrue(Arrays.equals(msg.inputIv, decoded.inputIv)); - assertTrue(Arrays.equals(msg.outputIv, decoded.outputIv)); - } - - private ByteBuffer encode(Encodable msg) { + public void testPublicKeyEncodeDecode() { + AuthMessage msg = new AuthMessage(string(), byteArray(), byteArray()); ByteBuf buf = Unpooled.buffer(); msg.encode(buf); - return buf.nioBuffer(); - } + AuthMessage decoded = AuthMessage.decodeMessage(buf.nioBuffer()); + assertEquals(msg.appId, decoded.appId); + assertArrayEquals(msg.salt, decoded.salt); + assertArrayEquals(msg.ciphertext, decoded.ciphertext); + } } diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index a83dc31e366a5..57fb136bc0cb9 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -188,6 +188,7 @@ stax-api/1.0.1//stax-api-1.0.1.jar stream/2.7.0//stream-2.7.0.jar stringtemplate/3.2.1//stringtemplate-3.2.1.jar super-csv/2.2.0//super-csv-2.2.0.jar +tink/1.6.0//tink-1.6.0.jar univocity-parsers/2.7.3//univocity-parsers-2.7.3.jar validation-api/1.1.0.Final//validation-api-1.1.0.Final.jar xbean-asm6-shaded/4.8//xbean-asm6-shaded-4.8.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index eb6305de37e74..29077c701dbc7 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -189,6 +189,7 @@ stax-api/1.0.1//stax-api-1.0.1.jar stream/2.7.0//stream-2.7.0.jar stringtemplate/3.2.1//stringtemplate-3.2.1.jar super-csv/2.2.0//super-csv-2.2.0.jar +tink/1.6.0//tink-1.6.0.jar univocity-parsers/2.7.3//univocity-parsers-2.7.3.jar validation-api/1.1.0.Final//validation-api-1.1.0.Final.jar xbean-asm6-shaded/4.8//xbean-asm6-shaded-4.8.jar diff --git a/dev/deps/spark-deps-hadoop-3.1 b/dev/deps/spark-deps-hadoop-3.1 index b9db185f6d61d..0552b6ce71204 100644 --- a/dev/deps/spark-deps-hadoop-3.1 +++ b/dev/deps/spark-deps-hadoop-3.1 @@ -209,6 +209,7 @@ stax2-api/3.1.4//stax2-api-3.1.4.jar stream/2.7.0//stream-2.7.0.jar stringtemplate/3.2.1//stringtemplate-3.2.1.jar super-csv/2.2.0//super-csv-2.2.0.jar +tink/1.6.0//tink-1.6.0.jar token-provider/1.0.1//token-provider-1.0.1.jar univocity-parsers/2.7.3//univocity-parsers-2.7.3.jar validation-api/1.1.0.Final//validation-api-1.1.0.Final.jar diff --git a/pom.xml b/pom.xml index 6fe0a16900aea..889776fe18593 100644 --- a/pom.xml +++ b/pom.xml @@ -187,6 +187,7 @@ 2.8 1.8 1.0.0 + 1.6.0