Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge old optimisations from spdz2k-develop #385

Merged
merged 19 commits into from
Feb 18, 2022
Merged
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package dk.alexandra.fresco.framework.util;

import java.math.BigInteger;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

public class MathUtils {
Expand Down Expand Up @@ -113,4 +115,20 @@ private static Pair<BigInteger, Integer> expressAsProductOfPowerOfTwo(BigInteger
return new Pair<>(q, s);
}

/**
* Turns input value into bits in big-endian order. <p> If the actual bit length of the value is
* smaller than numBits, the result is padded with 0s. If the bit length is larger only the first
* numBits bits are used. </p>
*/
public static List<BigInteger> toBits(BigInteger value, int numBits) {
List<BigInteger> bits = new ArrayList<>(numBits);
for (int b = 0; b < numBits; b++) {
boolean boolBit = value.testBit(b);
BigInteger bit = boolBit ? BigInteger.ONE : BigInteger.ZERO;
bits.add(bit);
}
Collections.reverse(bits);
return bits;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
* Holds the most crucial properties about the finite field we are working within.
*/
public class BasicNumericContext {
// TODO temporary hardcoded statistical security parameter
private static final int DEFAULT_STATISTICAL_SECURITY = 40;

private final int statisticalSecurityParam;
private final int maxBitLength;
private final int myId;
private final int noOfParties;
Expand All @@ -23,14 +26,32 @@ public class BasicNumericContext {
* @param noOfParties number of parties in computation
* @param fieldDefinition the field definition used in the application
* @param defaultFixedPointPrecision the fixed point precision when using the fixed point library
* @param statisticalSecurityParam the statistical security parameter
*/
public BasicNumericContext(int maxBitLength, int myId, int noOfParties,
FieldDefinition fieldDefinition, int defaultFixedPointPrecision) {
FieldDefinition fieldDefinition, int defaultFixedPointPrecision, int statisticalSecurityParam) {
this.maxBitLength = maxBitLength;
this.myId = myId;
this.noOfParties = noOfParties;
this.fieldDefinition = fieldDefinition;
this.defaultFixedPointPrecision = defaultFixedPointPrecision;
this.statisticalSecurityParam = statisticalSecurityParam;
}

/**
* Construct a new BasicNumericContext.
*
* @param maxBitLength The maximum length in bits that the numbers in the application will
* have.
* @param myId my party id
* @param noOfParties number of parties in computation
* @param fieldDefinition the field definition used in the application
* @param defaultFixedPointPrecision the fixed point precision when using the fixed point library
*/
public BasicNumericContext(int maxBitLength, int myId, int noOfParties,
FieldDefinition fieldDefinition, int defaultFixedPointPrecision) {
this(maxBitLength, myId, noOfParties, fieldDefinition, defaultFixedPointPrecision,
DEFAULT_STATISTICAL_SECURITY);
}

/**
Expand Down Expand Up @@ -80,5 +101,11 @@ public int getDefaultFixedPointPrecision() {
return defaultFixedPointPrecision;
}

/**
* Returns the statistical security parameter.
*/
public int getStatisticalSecurityParam() {
return this.statisticalSecurityParam;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,11 @@ public void testModularSqrtNoSqrt() {
MathUtils.modularSqrt(new BigInteger("23"), modulus);
}

@Test
public void testToBits() {
BigInteger value = BigInteger.valueOf(11);
List<BigInteger> expected = Arrays.asList(BigInteger.ONE, BigInteger.ZERO, BigInteger.ONE, BigInteger.ONE);
assertEquals(expected, MathUtils.toBits(value, 4));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,20 @@
import org.junit.Test;

public class BasicNumericContextTest {

private BasicNumericContext context;
private int maxBitLength = 16;
private int myId = 2;
private int noOfParties = 2;
private FieldDefinition fieldDefinition = mock(FieldDefinition.class);
private int precision = 32;
private int statisticalSecurityParameter = 8;

@Before
public void setup() {
when(fieldDefinition.getModulus()).thenReturn(BigInteger.ONE);
context = new BasicNumericContext(maxBitLength, myId, noOfParties, fieldDefinition, precision);
context = new BasicNumericContext(maxBitLength, myId, noOfParties, fieldDefinition, precision,
statisticalSecurityParameter);
}

@Test
Expand Down Expand Up @@ -52,4 +55,10 @@ public void getMyId() {
public void getNoOfParties() {
Assert.assertEquals(context.getNoOfParties(), noOfParties);
}

@Test
public void getStatisticalSecurityParameter() {
Assert.assertEquals(context.getStatisticalSecurityParam(), statisticalSecurityParameter);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import dk.alexandra.fresco.framework.util.Pair;
import dk.alexandra.fresco.framework.value.SInt;
import dk.alexandra.fresco.lib.common.compare.Comparison;
import dk.alexandra.fresco.lib.common.compare.Comparison.Algorithm;
import dk.alexandra.fresco.lib.common.compare.DefaultComparison;
import java.io.IOException;
import java.math.BigInteger;
Expand Down Expand Up @@ -50,7 +51,7 @@ public DRes<BigInteger> buildComputation(ProtocolBuilderNumeric producer) {
Pair<DRes<SInt>, DRes<SInt>> input = new Pair<>(x1, x2);
return () -> input;
}).seq((seq, input) -> {
DRes<SInt> equals = Comparison.using(seq).equals(32, input.getFirst(), input.getSecond());
DRes<SInt> equals = Comparison.using(seq).equals(input.getFirst(), input.getSecond(), 32, Algorithm.CONST_ROUNDS);
DRes<BigInteger> open = seq.numeric().open(equals);
return open;
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ public void testInputCmdLine() throws Exception {
try {
InputSumExample.main(
new String[]{"-i", "1", "-p", "1:localhost:8081", "-p", "2:localhost:8082", "-s",
"dummyArithmetic"});
"spdz"});
} catch (IOException e) {
System.exit(-1);
}
Expand All @@ -136,7 +136,7 @@ public void testInputCmdLine() throws Exception {
try {
InputSumExample.main(
new String[]{"-i", "2", "-p", "1:localhost:8081", "-p", "2:localhost:8082", "-s",
"dummyArithmetic"});
"spdz"});
} catch (IOException e) {
System.exit(-1);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ DRes<SInt>, DRes<SInt>, ProtocolBuilderNumeric> numeric(

return new KeyedCompareAndSwap<>(leftKeyAndValue, rightKeyAndValue,
(a, b, builder) -> Comparison
.using(builder).compareLEQ(b, a),
.using(builder).compareLT(b, a),
(a, b, builder) -> builder.numeric().add(a, b),
(a, b, builder) -> builder.numeric().sub(a, b),
(a, b, builder) -> builder.numeric().add(a, b),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@
import dk.alexandra.fresco.framework.DRes;
import dk.alexandra.fresco.framework.builder.ComputationDirectory;
import dk.alexandra.fresco.framework.builder.numeric.ProtocolBuilderNumeric;
import dk.alexandra.fresco.framework.util.Pair;
import dk.alexandra.fresco.framework.value.SInt;
import dk.alexandra.fresco.lib.common.util.SIntPair;
import java.math.BigInteger;
import java.util.List;

/**
* Interface for comparing numeric values.
Expand All @@ -21,57 +25,121 @@ static Comparison using(ProtocolBuilderNumeric builder) {
}

/**
* Compares two values and return x == y
* @param bitLength The maximum bit-length of the numbers to compare.
* @param x The first number
* @param y The second number
* @return A deferred result computing x == y
* The different algorithms supported by Fresco. The enum is used to decide of whether an
* algorithm running in constant rounds or logarithmic rounds should be used. In general the
* logarithmic round choice is the fastest.
*/
DRes<SInt> equals(int bitLength, DRes<SInt> x, DRes<SInt> y);

enum Algorithm {
LOG_ROUNDS, CONST_ROUNDS
}

/**
* Computes x == y.
*
* @param x input
* @param y input
* @return A deferred result computing x == y. Result will be either [1] (true) or [0] (false).
* @param x the first input
* @param y the second input
* @param bitlength the amount of bits to do the equality test on. Must be less than or equal to
* the max bitlength allowed
* @param algorithm the algorithm to use
* @return A deferred result computing x' == y'. Where x' and y' represent the {@code bitlength}
* least significant bits of x, respectively y. Result will be either [1] (true) or [0] (false).
*/
DRes<SInt> equals(DRes<SInt> x, DRes<SInt> y, int bitlength, Algorithm algorithm);

/**
* Call to {@link #equals(DRes, DRes, int, Algorithm)} with default comparison algorithm.
*/
default DRes<SInt> equals(DRes<SInt> x, DRes<SInt> y, int bitlength) {
return equals(x, y, bitlength, Algorithm.LOG_ROUNDS);
}

/**
* Call to {@link #equals(DRes, DRes, int, Algorithm)} with default comparison algorithm, checking
* equality of all bits.
*/
DRes<SInt> equals(DRes<SInt> x, DRes<SInt> y);


/**
* Computes if x <= y.
*
* @param x the first input
* @param y the second input
* @return A deferred result computing x <= y. Result will be either [1] (true) or [0] (false).
*/
@Deprecated
DRes<SInt> compareLEQ(DRes<SInt> x, DRes<SInt> y);

/**
* Computes if x1 &le; x2.
* @param x1 input
* @param x2 input
* @return A deferred result computing x1 &le; x2. Result will be either [1] (true) or [0] (false).
* Computes if x < y.
*
* @param x the first input
* @param y the second input
* @param algorithm the algorithm to use
* @return A deferred result computing x < y. Result will be either [1] (true) or [0] (false).
*/
DRes<SInt> compareLEQ(DRes<SInt> x1, DRes<SInt> x2);
DRes<SInt> compareLT(DRes<SInt> x, DRes<SInt> y, Algorithm algorithm);

/**
* Compares if x1 &le; x2, but with twice the possible bit-length.
* Requires that the maximum bit length is set to something that can handle
* this scenario. It has to be at least less than half the modulus bit size.
*
* @param x1 input
* @param x2 input
* @return A deferred result computing x1 &le; x2. Result will be either [1] (true) or [0] (false).
* Call to {@link #compareLT(DRes, DRes, Algorithm)} with default comparison algorithm.
*/
DRes<SInt> compareLEQLong(DRes<SInt> x1, DRes<SInt> x2);
default DRes<SInt> compareLT(DRes<SInt> x, DRes<SInt> y) {
return compareLT(x, y, Algorithm.LOG_ROUNDS);
}

/**
* Computes if the bit decomposition of an open value is less than the bit decomposition of a
* secret value.
*
* @param openValue open value which will be decomposed into bits and compared to secretBits
* @param secretBits secret value decomposed into bits
*/
DRes<SInt> compareLTBits(BigInteger openValue, DRes<List<DRes<SInt>>> secretBits);

/**
* Compares if x <= y, but with twice the possible bit-length. Requires that the maximum bit
* length is set to something that can handle this scenario. It has to be at least less than half
* the modulus bit size.
*
* @param x the first input
* @param y the second input
* @return A deferred result computing x <= y. Result will be either [1] (true) or [0] (false).
*/
@Deprecated
DRes<SInt> compareLEQLong(DRes<SInt> x, DRes<SInt> y);

/**
* Computes the sign of the value (positive or negative)
*
*
* @param x The value to compute the sign off
* @return A deferred result computing the sign. Result will be 1 if the value is positive
* (including 0) and -1 if negative.
* (including 0) and -1 if negative.
*/
DRes<SInt> sign(DRes<SInt> x);

/**
* Test for equality with zero for a bitLength-bit number (positive or negative)
*
* @param x the value to test against zero
* @param bitLength bitlength
* @return A deferred result computing x == 0. Result will be either [1] (true) or [0] (false)
* @param bitlength the amount of bits to do the zero-test on. Must be less than or equal to the
* modulus bitlength
* @param algorithm the algorithm to use for zero-equality test
* @return A deferred result computing x' == 0 where x' is the {@code bitlength} least significant
* bits of x. Result will be either [1] (true) or [0] (false)
*/
DRes<SInt> compareZero(DRes<SInt> x, int bitLength);
DRes<SInt> compareZero(DRes<SInt> x, int bitlength, Algorithm algorithm);

/**
* Computes the index of the minimum element in a list and the element itself. <p>The index is
* expressed as a list of bits where all bits are 0 except for the bit at the index of the minimum
* element, which is set to 1.</p>
*/
DRes<Pair<List<DRes<SInt>>, SInt>> argMin(List<DRes<SInt>> xs);

/**
* Call to {@link #compareZero(DRes, int, Algorithm)} with default comparison algorithm.
*/
default DRes<SInt> compareZero(DRes<SInt> x, int bitlength) {
return compareZero(x, bitlength, Algorithm.LOG_ROUNDS);
}

}
Loading