Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Decrease the timing differences when trimming zeros from DH TLS PMS #129

Merged
merged 3 commits into from
Sep 15, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ method.

### Improvements
* Stricter guarantees about which curves are used for EC key generation. [PR #127](https://github.com/corretto/amazon-corretto-crypto-provider/pull/127)
* Reduce timing signal from trimming zeros of TLSPremasterSecrets from DH KeyAgreement. [PR #129](https://github.com/corretto/amazon-corretto-crypto-provider/pull/129)

### Patches
* Add version gating to some tests introduced in 1.5.0 [PR #128](https://github.com/corretto/amazon-corretto-crypto-provider/pull/128)
Expand Down
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ set(ACCP_SRC
src/com/amazon/corretto/crypto/provider/AccessibleByteArrayOutputStream.java
src/com/amazon/corretto/crypto/provider/AesCtrDrbg.java
src/com/amazon/corretto/crypto/provider/AesGcmSpi.java
src/com/amazon/corretto/crypto/provider/ConstantTime.java
src/com/amazon/corretto/crypto/provider/EcGen.java
src/com/amazon/corretto/crypto/provider/EvpKeyAgreement.java
src/com/amazon/corretto/crypto/provider/EvpKeyType.java
Expand Down Expand Up @@ -528,6 +529,7 @@ add_jar(
tst/com/amazon/corretto/crypto/provider/test/AESGenerativeTest.java
tst/com/amazon/corretto/crypto/provider/test/AesCtrDrbgTest.java
tst/com/amazon/corretto/crypto/provider/test/AesTest.java
tst/com/amazon/corretto/crypto/provider/test/ConstantTimeTests.java
tst/com/amazon/corretto/crypto/provider/test/EcGenTest.java
tst/com/amazon/corretto/crypto/provider/test/EvpKeyAgreementTest.java
tst/com/amazon/corretto/crypto/provider/test/EvpKeyAgreementSpecificTest.java
Expand Down
69 changes: 69 additions & 0 deletions src/com/amazon/corretto/crypto/provider/ConstantTime.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

package com.amazon.corretto.crypto.provider;

/**
* Contains several constant time utilities
*/
final class ConstantTime {
private ConstantTime() {
// Prevent instantiation
}

/**
* Equivalent to {@code val != 0 ? 1 : 0}
*/
static final int isNonZero(int val) {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very minor bikeshed: Is it possible to use byte as the return type instead of int to save on memory usage? (If the answer is "No", that's fine.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Technically, yes.

However, Java automatically casts all bit-wise operations to int or larger. So, I'm not sure if it would practically decrease memory usage and it would make everything a bit more annoying to read (due to large numbers of casts).

return ((val | -val) >>> 31) & 0x01; // Unsigned bitshift
SalusaSecondus marked this conversation as resolved.
Show resolved Hide resolved
}

/**
* Equivalent to {@code val == 0 ? 1 : 0}
*/
static final int isZero(int val) {
return 1 - isNonZero(val);
SalusaSecondus marked this conversation as resolved.
Show resolved Hide resolved
}

/**
* Equivalent to {@code val < 0 ? 1 : 0}
*/
static final int isNegative(int val) {
return (val >>> 31) & 0x01;
SalusaSecondus marked this conversation as resolved.
Show resolved Hide resolved
}

/**
* Equivalent to {@code x == y ? 1 : 0}
*/
static final int equal(int x, int y) {
final int difference = x - y;
SalusaSecondus marked this conversation as resolved.
Show resolved Hide resolved
// Difference is 0 iff x == y
return isZero(difference);
}

/**
* Equivalent to {@code x > y ? 1 : 0}
*/
static final int gt(int x, int y) {
// Convert to long to avoid underflow
final long xl = x;
final long yl = y;
final long difference = yl - xl;
// If xl > yl, then difference is negative.
// Thus, we can just return the sign-bit
return (int) ((difference >>> 63) & 0x01); // Unsigned bitshift
SalusaSecondus marked this conversation as resolved.
Show resolved Hide resolved
}

/**
* Equivalent to {@code selector != 0 ? a : b}
*/
static final int select(int selector, int a, int b) {
final int mask = isZero(selector) - 1;
// Mask == -1 (all bits 1) iff selector != 0
// Mask == 0 (all bits 0) iff selector == 0

final int combined = a ^ b;

return b ^ (combined & mask);
SalusaSecondus marked this conversation as resolved.
Show resolved Hide resolved
}
}
32 changes: 25 additions & 7 deletions src/com/amazon/corretto/crypto/provider/EvpKeyAgreement.java
Original file line number Diff line number Diff line change
Expand Up @@ -196,17 +196,35 @@ protected void reset() {
}

private static byte[] trimZeros(final byte[] secret) {
// According to other implementations, we don't appear
// to need to worry about timing leaks of this data.
int bytesToTrim = 0;
while (bytesToTrim < secret.length && secret[bytesToTrim] == 0) {
bytesToTrim++;
int foundNonZero = 0;
for (int x = 0; x < secret.length; x++) {
final int currByte = secret[x];
// Have we found something that isn't a zero?
foundNonZero |= currByte;

// foundNonZero == 0 iff we have not see any non-zero bytes
// Thus, we should update bytesToTrim iff foundNonZero == 0
final int shouldUpdateTrim = ConstantTime.isZero(foundNonZero);
bytesToTrim = ConstantTime.select(shouldUpdateTrim, x + 1, bytesToTrim);
SalusaSecondus marked this conversation as resolved.
Show resolved Hide resolved
}

if (bytesToTrim == 0) {
return secret;
// Allocating arrays of different lengths always risks non-constant time operation.
// There is no way to avoid this.
final byte[] result = new byte[secret.length - bytesToTrim];

// We'll always do the same number of byte copies, but the leading zeros will be overwritten by valid ones.
// While the memory access pattern won't be identical, there is no way to completely avoid this.
SalusaSecondus marked this conversation as resolved.
Show resolved Hide resolved
for (int x = 0; x < secret.length; x++) {
final int realIndex = x - bytesToTrim;

final int notYetValid = ConstantTime.isNegative(realIndex);

final int indexToUpdate = ConstantTime.select(notYetValid, 0, realIndex);
result[indexToUpdate] = secret[x];
SalusaSecondus marked this conversation as resolved.
Show resolved Hide resolved
}
return Arrays.copyOfRange(secret, bytesToTrim, secret.length);

return result;
}

static class ECDH extends EvpKeyAgreement {
Expand Down
117 changes: 117 additions & 0 deletions tst/com/amazon/corretto/crypto/provider/test/ConstantTimeTests.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

package com.amazon.corretto.crypto.provider.test;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.params.provider.Arguments.arguments;

import java.util.ArrayList;
import java.util.List;

import org.junit.jupiter.api.parallel.Execution;
import org.junit.jupiter.api.parallel.ExecutionMode;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;


// Note: We don't actually test that these methods are constant time, just that they give the correct answers.
@ExtendWith(TestResultLogger.class)
@Execution(ExecutionMode.CONCURRENT)
public class ConstantTimeTests {
// A few common values which when combined can trigger edge cases
private static final int[] TEST_VALUES = {Integer.MIN_VALUE, Integer.MIN_VALUE + 1, -2, -1, 0, 1, 2, Integer.MAX_VALUE - 1, Integer.MAX_VALUE};
private static final Class<?> CONSTANT_TIME_CLASS;
static {
try {
CONSTANT_TIME_CLASS = Class.forName("com.amazon.corretto.crypto.provider.ConstantTime");
} catch (Exception ex) {
throw new RuntimeException(ex);
}
}


public static List<Arguments> testPairs() {
final List<Arguments> result = new ArrayList<>();
for (int a : TEST_VALUES) {
for (int b : TEST_VALUES) {
result.add(arguments(a, b));
}
}
return result;
}

public static int[] testSingles() {
return TEST_VALUES;
}

@ParameterizedTest
@MethodSource("testSingles")
public void testIsNonZero(int val) {
final int expected = val != 0 ? 1 : 0;
assertEquals(expected, sneaky("isNonZero", val));
}

@ParameterizedTest
@MethodSource("testSingles")
public void testIsZero(int val) {
final int expected = val == 0 ? 1 : 0;
assertEquals(expected, sneaky("isZero", val));
}

@ParameterizedTest
@MethodSource("testSingles")
public void testIsNegative(int val) {
final int expected = val < 0 ? 1 : 0;
assertEquals(expected, sneaky("isNegative", val));
}

@ParameterizedTest
@MethodSource("testPairs")
public void testEqual(int x, int y) {
final int expected = x == y ? 1 : 0;
assertEquals(expected, sneaky("equal", x, y));
}

@ParameterizedTest
@MethodSource("testPairs")
public void testGt(int x, int y) {
final int expected = x > y ? 1 : 0;
assertEquals(expected, sneaky("gt", x, y));
}

@ParameterizedTest
@MethodSource("testSingles")
public void testSelect(int selector) {
final int a = 10;
final int b = 11;
final int expected = selector != 0 ? a : b;
assertEquals(expected, sneaky("select", selector, a, b));
}

private static int sneaky(String name, int a) {
try {
return TestUtil.sneakyInvoke_int(CONSTANT_TIME_CLASS, name, a);
} catch (final Throwable t) {
throw new AssertionError(t);
}
}

private static int sneaky(String name, int a, int b) {
try {
return TestUtil.sneakyInvoke_int(CONSTANT_TIME_CLASS, name, a, b);
} catch (final Throwable t) {
throw new AssertionError(t);
}
}

private static int sneaky(String name, int a, int b, int c) {
try {
return TestUtil.sneakyInvoke_int(CONSTANT_TIME_CLASS, name, a, b, c);
} catch (final Throwable t) {
throw new AssertionError(t);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ private void printNotice(final String notice, ExtensionContext context, String d
StringWriter causeText = new StringWriter();
if (cause != null) {
try (PrintWriter printWriter = new PrintWriter(causeText)) {
printWriter.append(" @ ").append(getFailureLocation(cause).toString());
printWriter.append(" @ ").append(String.valueOf(getFailureLocation(cause)));
// Don't print out traces for Assert.* failures which throw subclasses of AssertionError.
// Just for thrown exceptions
if (AssertionError.class.equals(cause.getClass()) ||
Expand Down