diff --git a/src/main/java/org/broadinstitute/hellbender/utils/solver/RobustBrentSolver.java b/src/main/java/org/broadinstitute/hellbender/utils/solver/RobustBrentSolver.java deleted file mode 100644 index dd624271744..00000000000 --- a/src/main/java/org/broadinstitute/hellbender/utils/solver/RobustBrentSolver.java +++ /dev/null @@ -1,164 +0,0 @@ -package org.broadinstitute.hellbender.utils.solver; - -import com.google.common.annotations.VisibleForTesting; -import org.apache.commons.math3.analysis.UnivariateFunction; -import org.apache.commons.math3.analysis.solvers.AbstractUnivariateSolver; -import org.apache.commons.math3.analysis.solvers.BrentSolver; -import org.apache.commons.math3.exception.NoBracketingException; -import org.apache.commons.math3.exception.TooManyEvaluationsException; -import org.apache.commons.math3.util.FastMath; -import org.broadinstitute.hellbender.utils.Utils; - -import javax.annotation.Nullable; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.stream.Collectors; -import java.util.stream.IntStream; - -/** - * A robust version of the Brent solver that tries to avoid spurious non-bracketing conditions. - * - * The root is bracketed by searching the solution search interval. If multiple roots are found and - * a non-null merit function is provided, the root with maximum merit is chosen. If no merit function - * is provided or if merits are equal, the leftmost root is chosen by convention. - * - * @author Mehrtash Babadi <mehrtash@broadinstitute.org> - */ -public final class RobustBrentSolver extends AbstractUnivariateSolver { - private final UnivariateFunction meritFunc; - private final int numBisections; - private final int depth; - - public RobustBrentSolver(final double relativeAccuracy, - final double absoluteAccuracy, - final double functionValueAccuracy, - @Nullable final UnivariateFunction meritFunc, - final int numBisections, - final int depth) { - super(relativeAccuracy, absoluteAccuracy, functionValueAccuracy); - this.meritFunc = meritFunc; - this.numBisections = numBisections; - this.depth = depth; - } - - @Override - public double solve(final int maxEval, - final UnivariateFunction objFunc, - final double min, - final double max) - throws TooManyEvaluationsException, NoBracketingException { - setup(maxEval, objFunc, min, max, min); /* the last parameter is not actually used */ - return doSolve(); - } - - @Override - protected double doSolve() throws TooManyEvaluationsException, NoBracketingException { - final double min = getMin(); - final double max = getMax(); - final double[] xSearchGrid = createHybridSearchGrid(min, max, numBisections, depth); - final double[] fSearchGrid = Arrays.stream(xSearchGrid).map(this::computeObjectiveValue).toArray(); - - /* find bracketing intervals on the search grid */ - final List bracketsList = detectBrackets(xSearchGrid, fSearchGrid); - if (bracketsList.isEmpty()) { - throw new NoBracketingException(min, max, fSearchGrid[0], fSearchGrid[fSearchGrid.length-1]); - } - final BrentSolver solver = new BrentSolver(getRelativeAccuracy(), getAbsoluteAccuracy(), getFunctionValueAccuracy()); - final List roots = bracketsList.stream() - .map(b -> solver.solve(getMaxEvaluations(), this::computeObjectiveValue, b.min, b.max, 0.5 * (b.min + b.max))) - .collect(Collectors.toList()); - if (roots.size() == 1 || meritFunc == null) { - return roots.get(0); - } - final double[] merits = roots.stream().mapToDouble(meritFunc::value).toArray(); - final int bestRootIndex = IntStream.range(0, roots.size()) - .boxed() - .max((i, j) -> (int) (merits[i] - merits[j])) - .get(); - return roots.get(bestRootIndex); - } - - /** - * Generates a hybrid search grid concentrated around {@code min}. The hybrid grid starts with a base-2 - * logarithmic grid. Each grid element is further divided uniformly into {@code refinementOrder} intervals. - * - * @param min left endpoint - * @param max right endpoint - * @param logSubdivisions number of logarithmic refinements - * @param uniformSubdivisions number of uniform refinements - * @return a double array of grid points - */ - @VisibleForTesting - double[] createHybridSearchGrid(final double min, final double max, final int logSubdivisions, final int uniformSubdivisions) { - final double[] baseGrid = createLogarithmicGrid(min, max, logSubdivisions, 2); - final int nBaseGrid = logSubdivisions + 2; - if (uniformSubdivisions > 1) { - final double[] refinedGrid = new double[(nBaseGrid - 1) * uniformSubdivisions + 1]; - for (int j = 0; j < nBaseGrid - 1; j++) { - final double len = (baseGrid[j + 1] - baseGrid[j]) / uniformSubdivisions; - for (int k = 0; k < uniformSubdivisions; k++) { - refinedGrid[j * uniformSubdivisions + k] = baseGrid[j] + k * len; - } - } - refinedGrid[(nBaseGrid - 1) * uniformSubdivisions] = max; - return refinedGrid; - } else { - return baseGrid; - } - } - - /** - * Creates a logarithmic grid concentrated around {@code min}. - * - * @param min left endpoint - * @param max right endpoint - * @param subdivisions number of logarithmic subdivisions - * @param base logarithm base - * @return a double array of grid points of length {@code subdivisions} + 2 - */ - private double[] createLogarithmicGrid(final double min, final double max, final int subdivisions, final double base) { - Utils.validateArg(base > 1, "The logarithm base must be greater than 1"); - final double[] grid = new double[subdivisions + 2]; - grid[0] = 0; - grid[subdivisions + 1] = max - min; - for (int j = subdivisions; j > 0; j--) { - grid[j] = grid[j+1] / base; - } - for (int j = 0; j < subdivisions + 2; j++) { - grid[j] += min; - } - return grid; - } - - @VisibleForTesting - static List detectBrackets(final double[] x, final double[] f) { - final List brackets = new ArrayList<>(); - final double[] signs = new double[f.length]; - for (int i = 0; i < f.length; i++) { - signs[i] = FastMath.signum(f[i]); - } - double prevSignum = signs[0]; - int prevIdx = 0; - int idx = 1; - while (idx < f.length) { - if (signs[idx]*prevSignum <= 0) { - brackets.add(new Bracket(x[prevIdx], x[idx])); - prevIdx = idx; - prevSignum = signs[idx]; - } - idx++; - } - return brackets; - } - - @VisibleForTesting - static final class Bracket { - final double min, max; - - Bracket(final double min, final double max) { - this.min = min; - this.max = max; - } - } -} diff --git a/src/main/java/org/broadinstitute/hellbender/utils/solver/SynchronizedUnivariateSolver.java b/src/main/java/org/broadinstitute/hellbender/utils/solver/SynchronizedUnivariateSolver.java deleted file mode 100644 index fdc7eaead08..00000000000 --- a/src/main/java/org/broadinstitute/hellbender/utils/solver/SynchronizedUnivariateSolver.java +++ /dev/null @@ -1,311 +0,0 @@ -package org.broadinstitute.hellbender.utils.solver; - -import org.apache.commons.math3.analysis.solvers.AbstractUnivariateSolver; -import org.apache.commons.math3.exception.NoBracketingException; -import org.apache.commons.math3.exception.TooManyEvaluationsException; -import org.apache.commons.math3.util.FastMath; -import org.broadinstitute.hellbender.utils.Utils; -import org.broadinstitute.hellbender.utils.param.ParamUtils; - -import java.util.*; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.locks.Condition; -import java.util.concurrent.locks.Lock; -import java.util.concurrent.locks.ReentrantLock; -import java.util.function.Function; -import java.util.stream.Collectors; -import java.util.stream.IntStream; - -/** - * This class implements a synchronized univariate solver for solving multiple independent equations. - * It is to be used in situations where function queries have a costly overhead, though, simultaneous - * queries of multiple functions have the same overhead as single queries. - * - * Consider the tasking of solving N independent equations: - * - * f_1(x_1) = 0, - * f_2(x_2) = 0, - * ... - * f_N(x_N) = 0 - * - * One approach is to solve these equations sequentially. In certain situations, each function - * evaluation may be cheap but could entail a costly overhead (e.g. if the functions are evaluated - * in a distributed architecture). It is desirable to minimize this overhead by bundling as many - * function calls as possible, and querying the function in "chunks". - * - * Consider the ideal situation where function evaluations are infinitely cheap, however, each query has - * a considerable overhead time of \tau. Also, let us assume that the overhead of simultaneously - * querying {f_1(x_1), ..., f_N(x_N)} is the same as that of a single query, i.e. f_i(x_i). If the - * univariate solver requires k queries on average, the overhead cost of the sequential approach is - * O(k N \tau). By making simultaneous queries, this class reduces the overhead to O(k \tau). - * - * This is achieved by instantiating N threads for the N univariate solvers, accumulating their queries - * and suspending them until all threads announce their required query. - * - * TODO github/gatk-protected issue #853 - * @implNote In the current implementation, we make a thread for each solver. This is fine if the number - * of equations is reasonably small (< 200). In the future, the class must take a max number - * of threads and limit concurrency. - * - * @author Mehrtash Babadi <mehrtash@broadinstitute.org> - */ -public final class SynchronizedUnivariateSolver { - - /** - * Default value for the absolute accuracy of function evaluations - */ - private static final double DEFAULT_FUNCTION_ACCURACY = 1e-15; - - /** - * Stores queries from instantiated solvers - */ - private final ConcurrentHashMap queries; - - /** - * Stores function evaluations on {@link #queries} - */ - private final ConcurrentHashMap results; - - /** - * Number of queries before making a function call - */ - private final int numberOfQueriesBeforeCalling; - - /** - * The objective functions - */ - private final Function, Map> func; - - /** - * A list of solver jobs - */ - private final List jobDescriptions; - private final List solverDescriptions; - private final Function solverFactory; - private final Set jobIndices; - - private final Lock resultsLock = new ReentrantLock(); - private final Condition resultsAvailable = resultsLock.newCondition(); - private CountDownLatch solversCountDownLatch; - - /** - * Public constructor for invoking one of {@link AbstractUnivariateSolver} - * - * @param func the objective function (must be able to evaluate multiple univariate functions in one query) - * @param numberOfQueriesBeforeCalling Number of queries before making a function call (the default value is - * the number of equations) - */ - public SynchronizedUnivariateSolver(final Function, Map> func, - final Function solverFactory, - final int numberOfQueriesBeforeCalling) { - this.func = Utils.nonNull(func); - this.solverFactory = Utils.nonNull(solverFactory); - this.numberOfQueriesBeforeCalling = ParamUtils.isPositive(numberOfQueriesBeforeCalling, "Number of queries" + - " before calling function evaluations must be positive"); - - queries = new ConcurrentHashMap<>(numberOfQueriesBeforeCalling); - results = new ConcurrentHashMap<>(numberOfQueriesBeforeCalling); - jobDescriptions = new ArrayList<>(); - solverDescriptions = new ArrayList<>(); - jobIndices = new HashSet<>(); - } - - /** - * Add a solver jobDescription - * - * @param index a unique index for the equation - * @param min lower bound of the root - * @param max upper bound of the root - * @param x0 initial guess - * @param absoluteAccuracy absolute accuracy - * @param relativeAccuracy relative accuracy - * @param functionValueAccuracy function value accuracy - * @param maxEval maximum number of allowed evaluations - */ - public void add(final int index, final double min, final double max, final double x0, - final double absoluteAccuracy, final double relativeAccuracy, - final double functionValueAccuracy, final int maxEval) { - if (jobIndices.contains(index)) { - throw new IllegalArgumentException("A jobDescription with index " + index + " already exists; jobDescription indices must" + - " be unique"); - } - if (x0 <= min || x0 >= max) { - throw new IllegalArgumentException(String.format("The initial guess \"%f\" for equation number \"%d\" is" + - " must lie inside the provided search bracket [%f, %f]", x0, index, min, max)); - } - jobDescriptions.add(new UnivariateSolverJobDescription(index, min, max, x0, maxEval)); - solverDescriptions.add(new UnivariateSolverSpecifications(absoluteAccuracy, relativeAccuracy, functionValueAccuracy)); - } - - /** - * Add a solver jobDescription using the default function accuracy {@link #DEFAULT_FUNCTION_ACCURACY} - * - * @param index a unique index for the equation - * @param min lower bound of the root - * @param max upper bound of the root - * @param x0 initial guess - * @param absoluteAccuracy absolute accuracy - * @param relativeAccuracy relative accuracy - * @param maxEval maximum number of allowed evaluations - */ - public void add(final int index, final double min, final double max, final double x0, - final double absoluteAccuracy, final double relativeAccuracy, - final int maxEval) { - add(index, min, max, x0, absoluteAccuracy, relativeAccuracy, DEFAULT_FUNCTION_ACCURACY, maxEval); - } - - /** - * Solve the equations - * - * @return a map from equation indices to the summary of results - * @throws InterruptedException if any of the solver threads are interrupted - */ - public Map solve() throws InterruptedException { - if (jobDescriptions.isEmpty()) { - return Collections.emptyMap(); - } - final Map solvers = new HashMap<>(jobDescriptions.size()); - solversCountDownLatch = new CountDownLatch(jobDescriptions.size()); - IntStream.range(0, jobDescriptions.size()) - .forEach(jobIdx -> solvers.put(jobDescriptions.get(jobIdx).getIndex(), - new SolverWorker(solverDescriptions.get(jobIdx), jobDescriptions.get(jobIdx)))); - - /* start solver threads */ - solvers.values().forEach(worker -> new Thread(worker).start()); - - /* wait for all workers to finish */ - solversCountDownLatch.await(); - - return solvers.entrySet() - .stream() - .collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().getSummary())); - } - - /** - * Require an evaluation of equation {@param index} at {$param x} - * - * @param index equation index - * @param x equation argument - * @return evaluated function value - * @throws InterruptedException if the waiting thread is interrupted - */ - private double evaluate(final int index, final double x) throws InterruptedException { - queries.put(index, x); - resultsLock.lock(); - final double value; - try { - fetchResults(); - while (!results.containsKey(index)) { - resultsAvailable.await(); - } - value = results.get(index); - results.remove(index); - } finally { - resultsLock.unlock(); - } - return value; - } - - /** - * Check if enough queries are in. If so, make a call and signal the waiting threads - */ - private void fetchResults() { - resultsLock.lock(); - try { - if (queries.size() >= FastMath.min(numberOfQueriesBeforeCalling, solversCountDownLatch.getCount())) { - results.putAll(func.apply(queries)); - queries.clear(); - resultsAvailable.signalAll(); - } - } finally { - resultsLock.unlock(); - } - } - - public enum UnivariateSolverStatus { - /** - * Solution could not be bracketed - */ - NO_BRACKETING, - - /** - * Too many function evaluations - */ - TOO_MANY_EVALUATIONS, - - /** - * The solver found the root successfully - */ - SUCCESS, - - /** - * The status is not determined yet - */ - TBD - } - - /** - * Stores the summary of a univariate solver jobDescription - */ - public final class UnivariateSolverSummary { - public final double x; - public final int evaluations; - public final UnivariateSolverStatus status; - - UnivariateSolverSummary(final double x, final int evaluations, final UnivariateSolverStatus status) { - this.x = x; - this.evaluations = evaluations; - this.status = status; - } - } - - /** - * A runnable version a solver - */ - private final class SolverWorker implements Runnable { - final UnivariateSolverJobDescription jobDescription; - final UnivariateSolverSpecifications solverDescription; - UnivariateSolverStatus status; - private UnivariateSolverSummary summary; - - SolverWorker(final UnivariateSolverSpecifications solverDescription, - final UnivariateSolverJobDescription jobDescription) { - this.solverDescription = solverDescription; - this.jobDescription = jobDescription; - status = UnivariateSolverStatus.TBD; - } - - @Override - public void run() { - double sol; - final AbstractUnivariateSolver abstractSolver = solverFactory.apply(solverDescription); - try { - sol = abstractSolver.solve(jobDescription.getMaxEvaluations(), x -> { - final double value; - try { - value = evaluate(jobDescription.getIndex(), x); - } catch (final InterruptedException ex) { - throw new RuntimeException(String.format("Evaluation of equation (n=%d) was interrupted --" + - " can not continue", jobDescription.getIndex())); - } - return value; - }, jobDescription.getMin(), jobDescription.getMax(), jobDescription.getInitialGuess()); - status = UnivariateSolverStatus.SUCCESS; - } catch (final NoBracketingException ex) { - status = UnivariateSolverStatus.NO_BRACKETING; - sol = Double.NaN; - } catch (final TooManyEvaluationsException ex) { - status = UnivariateSolverStatus.TOO_MANY_EVALUATIONS; - sol = Double.NaN; - } - summary = new UnivariateSolverSummary(sol, abstractSolver.getEvaluations(), status); - solversCountDownLatch.countDown(); - fetchResults(); - } - - UnivariateSolverSummary getSummary() { - return Utils.nonNull(summary, "Solver summary is not available"); - } - } -} diff --git a/src/main/java/org/broadinstitute/hellbender/utils/solver/UnivariateSolverJobDescription.java b/src/main/java/org/broadinstitute/hellbender/utils/solver/UnivariateSolverJobDescription.java deleted file mode 100644 index a65c32653dd..00000000000 --- a/src/main/java/org/broadinstitute/hellbender/utils/solver/UnivariateSolverJobDescription.java +++ /dev/null @@ -1,67 +0,0 @@ -package org.broadinstitute.hellbender.utils.solver; - -import org.apache.commons.math3.exception.TooManyEvaluationsException; - -/** - * This class stores the description of a solver job. - * - * @author Mehrtash Babadi <mehrtash@broadinstitute.org> - */ - -public final class UnivariateSolverJobDescription { - - /** - * Job index (an arbitrary task-specific integer index) - */ - private final int index; - - /** - * Maximum function evaluations; solvers will throw a {@link TooManyEvaluationsException} exception if function - * evaluations exceed this number - */ - private final int maxEvaluations; - - /** - * Left endpoint for an interval that brackets the root - */ - private final double min; - - /** - * Right endpoint for an interval that brackets the root - */ - private final double max; - - /** - * Initial guess (must lie inside the interval) - */ - private final double x0; - - public UnivariateSolverJobDescription(final int index, final double min, final double max, final double x0, - final int maxEvaluations) { - this.index = index; - this.min = min; - this.max = max; - this.x0 = x0; - this.maxEvaluations = maxEvaluations; - } - - public int getIndex() { - return index; - } - - public int getMaxEvaluations() { - return maxEvaluations; - } - - public double getMin() { - return min; - } - - public double getMax() { - return max; - } - - public double getInitialGuess() { - return x0; - } -} \ No newline at end of file diff --git a/src/main/java/org/broadinstitute/hellbender/utils/solver/UnivariateSolverSpecifications.java b/src/main/java/org/broadinstitute/hellbender/utils/solver/UnivariateSolverSpecifications.java deleted file mode 100644 index 4fb50c780c2..00000000000 --- a/src/main/java/org/broadinstitute/hellbender/utils/solver/UnivariateSolverSpecifications.java +++ /dev/null @@ -1,30 +0,0 @@ -package org.broadinstitute.hellbender.utils.solver; - -/** - * Accuracy specifications of a univariate solver - * - * @author Mehrtash Babadi <mehrtash@broadinstitute.org> - */ -public final class UnivariateSolverSpecifications { - - private final double absoluteAccuracy, relativeAccuracy, functionValueAccuracy; - - public UnivariateSolverSpecifications(final double absoluteAccuracy, final double relativeAccuracy, - final double functionValueAccuracy) { - this.absoluteAccuracy = absoluteAccuracy; - this.relativeAccuracy = relativeAccuracy; - this.functionValueAccuracy = functionValueAccuracy; - } - - public double getAbsoluteAccuracy() { - return absoluteAccuracy; - } - - public double getRelativeAccuracy() { - return relativeAccuracy; - } - - public double getFunctionValueAccuracy() { - return functionValueAccuracy; - } -} diff --git a/src/test/java/org/broadinstitute/hellbender/utils/solver/RobustBrentSolverUnitTest.java b/src/test/java/org/broadinstitute/hellbender/utils/solver/RobustBrentSolverUnitTest.java deleted file mode 100644 index d37d3da951d..00000000000 --- a/src/test/java/org/broadinstitute/hellbender/utils/solver/RobustBrentSolverUnitTest.java +++ /dev/null @@ -1,69 +0,0 @@ -package org.broadinstitute.hellbender.utils.solver; - -import org.apache.commons.math3.analysis.UnivariateFunction; -import org.apache.commons.math3.analysis.solvers.BrentSolver; -import org.apache.commons.math3.exception.NoBracketingException; -import org.apache.commons.math3.util.FastMath; -import org.broadinstitute.hellbender.GATKBaseTest; -import org.testng.Assert; -import org.testng.annotations.Test; - -import java.util.List; - -/** - * Unit tests for {@link RobustBrentSolver} - * - * @author Mehrtash Babadi <mehrtash@broadinstitute.org> - */ -public class RobustBrentSolverUnitTest extends GATKBaseTest { - - private static final double DEF_ABS_ACC = 1e-6; - private static final double DEF_REL_ACC = 1e-6; - private static final double DEF_F_ACC = 1e-15; - - @Test - public void gridTest() { - final RobustBrentSolver solver = new RobustBrentSolver(DEF_REL_ACC, DEF_REL_ACC, DEF_F_ACC, null, 1, 1); - final double[] x = solver.createHybridSearchGrid(0, 1, 10, 2); - Assert.assertEquals(x.length, 23); - Assert.assertEquals(x[0], 0, 1e-12); - Assert.assertEquals(x[22], 1, 1e-12); - } - - @Test - public void detectBracketsTest() { - final List brackets = RobustBrentSolver.detectBrackets( - new double[] {0, 1, 2, 3, 4, 5, 6, 7}, - new double[] {0, 1, -1, -5, 6, 7, -1, 0}); - Assert.assertEquals(brackets.size(), 5); - } - - /** - * Test on a 4th degree polynomial with 4 real roots at x = 0, 1, 2, 3. This objective function is positive for - * large enough positive and negative values of its arguments. Therefore, the simple Brent solver complains that - * the search interval does not bracket a root. The robust Brent solver, however, subdivides the given search - * interval and finds a bracketing sub-interval. - * - * The "best" root according to the given merit function (set to the anti-derivative of the objective function) - * is in fact the one at x = 0. We require the robust solver to output x = 0, and the simple solver to fail. - */ - @Test - public void simpleTest() { - final UnivariateFunction objFunc = x -> 30 * x * (x - 1) * (x - 2) * (x - 3); - final UnivariateFunction meritFunc = x -> 6 * FastMath.pow(x, 5) - 45 * FastMath.pow(x, 4) + 110 * FastMath.pow(x, 3) - - 90 * FastMath.pow(x, 2); - final RobustBrentSolver solverRobust = new RobustBrentSolver(DEF_REL_ACC, DEF_REL_ACC, DEF_F_ACC, - meritFunc, 4, 1); - final BrentSolver solverSimple = new BrentSolver(DEF_REL_ACC, DEF_REL_ACC, DEF_F_ACC); - final double xRobust = solverRobust.solve(100, objFunc, -1, 4); - Assert.assertEquals(xRobust, 0, DEF_ABS_ACC); - boolean simpleSolverFails = false; - try { - /* this will fail */ - solverSimple.solve(100, objFunc, -1, 4); - } catch (final NoBracketingException ex) { - simpleSolverFails = true; - } - Assert.assertTrue(simpleSolverFails); - } -} diff --git a/src/test/java/org/broadinstitute/hellbender/utils/solver/SynchronizedUnivariateSolverUnitTest.java b/src/test/java/org/broadinstitute/hellbender/utils/solver/SynchronizedUnivariateSolverUnitTest.java deleted file mode 100644 index 3a70ab9283e..00000000000 --- a/src/test/java/org/broadinstitute/hellbender/utils/solver/SynchronizedUnivariateSolverUnitTest.java +++ /dev/null @@ -1,81 +0,0 @@ -package org.broadinstitute.hellbender.utils.solver; - -import org.apache.commons.lang3.tuple.ImmutablePair; -import org.apache.commons.math3.analysis.solvers.AbstractUnivariateSolver; -import org.apache.commons.math3.analysis.solvers.BrentSolver; -import org.apache.commons.math3.util.FastMath; -import org.broadinstitute.hellbender.GATKBaseTest; -import org.testng.Assert; -import org.testng.annotations.Test; - -import java.util.Map; -import java.util.function.Function; -import java.util.stream.Collectors; - -/** - * Unit tests for {@link SynchronizedUnivariateSolver} - * - * @author Mehrtash Babadi <mehrtash@broadinstitute.org> - */ -public class SynchronizedUnivariateSolverUnitTest extends GATKBaseTest { - - private static final Function SOLVER_FACTORY = desc -> new BrentSolver(desc.getAbsoluteAccuracy(), - desc.getRelativeAccuracy(), desc.getFunctionValueAccuracy()); - - @Test - public void testFewEquations() throws InterruptedException { - final Function, Map> func = arg -> - arg.entrySet().stream() - .map(entry -> { - final int index = entry.getKey(); - final double x = entry.getValue(); - switch (index) { - case 1: - return ImmutablePair.of(index, x*x - 3); - case 2: - return ImmutablePair.of(index, x*x*x - 4); - case 3: - return ImmutablePair.of(index, x - 5); - default: - return null; - } - }).collect(Collectors.toMap(p -> p.left, p -> p.right)); - - - final SynchronizedUnivariateSolver solver = new SynchronizedUnivariateSolver(func, SOLVER_FACTORY, 3); - solver.add(1, 0, 4, 3.5, 1e-7, 1e-7, 20); - solver.add(2, 0, 3, 0.5, 1e-7, 1e-7, 20); - solver.add(3, 0, 10, 0.6, 1e-7, 1e-7, 20); - - final Map sol = solver.solve(); - Assert.assertEquals(sol.get(1).x, 1.732050, 1e-6); - Assert.assertEquals(sol.get(2).x, 1.587401, 1e-6); - Assert.assertEquals(sol.get(3).x, 5.000000, 1e-6); - } - - @Test - public void testManyEquations() throws InterruptedException { - testManyEquationsInstance(10); - testManyEquationsInstance(100); - testManyEquationsInstance(1000); - } - - private void testManyEquationsInstance(final int numEquations) throws InterruptedException { - final Function, Map> func = arg -> - arg.entrySet().stream() - .map(entry -> { - final int index = entry.getKey(); - final double x = entry.getValue(); - return ImmutablePair.of(index, FastMath.pow(x, index) - index); - }).collect(Collectors.toMap(p -> p.left, p -> p.right)); - final SynchronizedUnivariateSolver solver = new SynchronizedUnivariateSolver(func, SOLVER_FACTORY, numEquations); - for (int n = 1; n <= numEquations; n++) { - solver.add(n, 0, 2, 0.5, 1e-7, 1e-7, 100); - } - final Map sol = solver.solve(); - for (int n = 1; n <= numEquations; n++) { - Assert.assertEquals(sol.get(n).x, FastMath.pow(n, 1.0/n), 1e-6); - } - } -}