Skip to content

Commit

Permalink
Enforce symmetry in the combine method for the moments
Browse files Browse the repository at this point in the history
Do not use and maintain the FirstMoment accept variables in the combine
method. These are not required for higher order combine methods.
  • Loading branch information
aherbert committed Oct 3, 2023
1 parent d82aebb commit 4187c2d
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,15 @@ class FirstMoment implements DoubleConsumer {
/**
* Deviation of most recently added value from the previous first moment.
* Retained to prevent repeated computation in higher order moments.
* Note: This value is not used in the {@link #combine(FirstMoment)} method.
*/
protected double dev;

/**
* Deviation of most recently added value from the previous first moment,
* normalized by current sample size. Retained to prevent repeated
* computation in higher order moments.
* Note: This value is not used in the {@link #combine(FirstMoment)} method.
*/
protected double nDev;

Expand Down Expand Up @@ -183,26 +185,49 @@ public void accept(double value) {
* @return {@code this} instance after combining {@code other}.
*/
FirstMoment combine(FirstMoment other) {
if (n == 0) {
n = other.n;
nonFiniteValue = other.nonFiniteValue;
dev = other.dev;
nDev = other.nDev;
m1 = other.m1;
} else if (other.n != 0) {
n += other.n;
nonFiniteValue += other.nonFiniteValue;
dev = other.m1 * 0.5 - m1 * 0.5;
// In contrast to the accept method, here nDev can be close to MAX_VALUE
// if the weight (other.n / n) approaches 1. So we cannot yet rescale nDev and
// instead have to combine it with the scaled-down value of m1.
nDev = dev * ((double) other.n / n);
m1 = m1 * 0.5 + nDev;
// Scale up the terms.
m1 *= 2;
dev *= 2;
nDev *= 2;
nonFiniteValue += other.nonFiniteValue;
final double mu1 = this.m1;
final double mu2 = other.m1;
final long n1 = n;
final long n2 = other.n;
n = n1 + n2;
// Adjust the mean with the weighted difference:
// m1 = m1 + (m2 - m1) * n2 / (n1 + n2)
// The difference between means can be 2 * MAX_VALUE so the computation optionally
// scales by a factor of 2. Avoiding scaling if possible preserves sub-normals.
if (n1 == n2) {
// Optimisation for equal sizes: m1 = (m1 + m2) / 2
// Use scaling for a large sum
final double sum = mu1 + mu2;
m1 = Double.isFinite(sum) ?
sum * 0.5 :
mu1 * 0.5 + mu2 * 0.5;
} else {
// Use scaling for a large difference
if (Double.isFinite(mu2 - mu1)) {
m1 = combine(mu1, mu2, n1, n2);
} else {
m1 = 2 * combine(mu1 * 0.5, mu2 * 0.5, n1, n2);
}
}
return this;
}

/**
* Combine the moments. This method is used to enforce symmetry. It assumes that
* the two sizes are not identical, and at least one size is non-zero.
*
* @param m1 Moment 1.
* @param m2 Moment 2.
* @param n1 Size of sample 1.
* @param n2 Size of sample 2.
* @return the combined first moment
*/
private static double combine(double m1, double m2, long n1, long n2) {
// Note: If either size is zero the weighted difference is zero and
// the other moment is unchanged.
return n2 < n1 ?
m1 + (m2 - m1) * ((double) n2 / (n1 + n2)) :
m2 + (m1 - m2) * ((double) n1 / (n1 + n2));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,11 @@ SumOfSquaredDeviations combine(SumOfSquaredDeviations other) {
} else if (m != 0) {
// "Updating one-pass algorithm"
// See: Chan et al (1983) Equation 1.5b (modified for the mean)
final double diffOfMean = other.getFirstMoment() - m1;
final double diffOfMean = other.m1 - m1;
final double sqDiffOfMean = diffOfMean * diffOfMean;
sumSquaredDev += other.sumSquaredDev + sqDiffOfMean * (((double) n * m) / ((double) n + m));
// Enforce symmetry
sumSquaredDev = (sumSquaredDev + other.sumSquaredDev) +
sqDiffOfMean * (((double) n * m) / ((double) n + m));
}
super.combine(other);
return this;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ void testCombine(double[] array1, double[] array2) {
Mean mean1b = Mean.create();
Arrays.stream(array1).forEach(mean1b);
mean2.combine(mean1b);
TestHelper.assertEquals(expected, mean2.getAsDouble(), ULP_COMBINE, () -> "combine");
Assertions.assertEquals(mean1.getAsDouble(), mean2.getAsDouble(), () -> "combine reversed");
Assertions.assertEquals(mean1BeforeCombine, mean1b.getAsDouble());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ void testCombine(double[] array1, double[] array2) {
Variance var1b = Variance.create();
Arrays.stream(array1).forEach(var1b);
var2.combine(var1b);
TestHelper.assertEquals(expected, var2.getAsDouble(), ULP_COMBINE_ACCEPT, () -> "combine");
Assertions.assertEquals(var1.getAsDouble(), var2.getAsDouble(), () -> "combine reversed");
Assertions.assertEquals(var1BeforeCombine, var1b.getAsDouble());
}

Expand Down

0 comments on commit 4187c2d

Please sign in to comment.