Skip to content

Commit

Permalink
pythonGH-100485: Add extended accuracy test. Switch to faster fma() b…
Browse files Browse the repository at this point in the history
…ased variant. pythonGH-101383)
  • Loading branch information
rhettinger authored and mdboom committed Jan 31, 2023
1 parent ab13fe0 commit dc5314f
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 36 deletions.
83 changes: 83 additions & 0 deletions Lib/test/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -1369,6 +1369,89 @@ def run(func, *args):
args,
)

@requires_IEEE_754
@unittest.skipIf(HAVE_DOUBLE_ROUNDING,
"sumprod() accuracy not guaranteed on machines with double rounding")
@support.cpython_only # Other implementations may choose a different algorithm
@support.requires_resource('cpu')
def test_sumprod_extended_precision_accuracy(self):
import operator
from fractions import Fraction
from itertools import starmap
from collections import namedtuple
from math import log2, exp2, fabs
from random import choices, uniform, shuffle
from statistics import median

DotExample = namedtuple('DotExample', ('x', 'y', 'target_sumprod', 'condition'))

def DotExact(x, y):
vec1 = map(Fraction, x)
vec2 = map(Fraction, y)
return sum(starmap(operator.mul, zip(vec1, vec2, strict=True)))

def Condition(x, y):
return 2.0 * DotExact(map(abs, x), map(abs, y)) / abs(DotExact(x, y))

def linspace(lo, hi, n):
width = (hi - lo) / (n - 1)
return [lo + width * i for i in range(n)]

def GenDot(n, c):
""" Algorithm 6.1 (GenDot) works as follows. The condition number (5.7) of
the dot product xT y is proportional to the degree of cancellation. In
order to achieve a prescribed cancellation, we generate the first half of
the vectors x and y randomly within a large exponent range. This range is
chosen according to the anticipated condition number. The second half of x
and y is then constructed choosing xi randomly with decreasing exponent,
and calculating yi such that some cancellation occurs. Finally, we permute
the vectors x, y randomly and calculate the achieved condition number.
"""

assert n >= 6
n2 = n // 2
x = [0.0] * n
y = [0.0] * n
b = log2(c)

# First half with exponents from 0 to |_b/2_| and random ints in between
e = choices(range(int(b/2)), k=n2)
e[0] = int(b / 2) + 1
e[-1] = 0.0

x[:n2] = [uniform(-1.0, 1.0) * exp2(p) for p in e]
y[:n2] = [uniform(-1.0, 1.0) * exp2(p) for p in e]

# Second half
e = list(map(round, linspace(b/2, 0.0 , n-n2)))
for i in range(n2, n):
x[i] = uniform(-1.0, 1.0) * exp2(e[i - n2])
y[i] = (uniform(-1.0, 1.0) * exp2(e[i - n2]) - DotExact(x, y)) / x[i]

# Shuffle
pairs = list(zip(x, y))
shuffle(pairs)
x, y = zip(*pairs)

return DotExample(x, y, DotExact(x, y), Condition(x, y))

def RelativeError(res, ex):
x, y, target_sumprod, condition = ex
n = DotExact(list(x) + [-res], list(y) + [1])
return fabs(n / target_sumprod)

def Trial(dotfunc, c, n):
ex = GenDot(10, c)
res = dotfunc(ex.x, ex.y)
return RelativeError(res, ex)

times = 1000 # Number of trials
n = 20 # Length of vectors
c = 1e30 # Target condition number

relative_err = median(Trial(math.sumprod, c, n) for i in range(times))
self.assertLess(relative_err, 1e-16)

def testModf(self):
self.assertRaises(TypeError, math.modf)

Expand Down
53 changes: 17 additions & 36 deletions Modules/mathmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -2832,12 +2832,7 @@ long_add_would_overflow(long a, long b)
}

/*
Double and triple length extended precision floating point arithmetic
based on:
A Floating-Point Technique for Extending the Available Precision
by T. J. Dekker
https://csclub.uwaterloo.ca/~pbarfuss/dekker1971.pdf
Double and triple length extended precision algorithms from:
Accurate Sum and Dot Product
by Takeshi Ogita, Siegfried M. Rump, and Shin’Ichi Oishi
Expand All @@ -2848,58 +2843,44 @@ based on:

typedef struct{ double hi; double lo; } DoubleLength;

static inline DoubleLength
twosum(double a, double b)
static DoubleLength
dl_sum(double a, double b)
{
// Rump Algorithm 3.1 Error-free transformation of the sum
/* Algorithm 3.1 Error-free transformation of the sum */
double x = a + b;
double z = x - a;
double y = (a - (x - z)) + (b - z);
return (DoubleLength) {x, y};
}

static inline DoubleLength
dl_split(double x) {
// Rump Algorithm 3.2 Error-free splitting of a floating point number
// Dekker (5.5) and (5.6).
double t = x * 134217729.0; // Veltkamp constant = 2.0 ** 27 + 1
double hi = t - (t - x);
double lo = x - hi;
return (DoubleLength) {hi, lo};
}

static inline DoubleLength
static DoubleLength
dl_mul(double x, double y)
{
// Dekker (5.12) and mul12()
DoubleLength xx = dl_split(x);
DoubleLength yy = dl_split(y);
double p = xx.hi * yy.hi;
double q = xx.hi * yy.lo + xx.lo * yy.hi;
double z = p + q;
double zz = p - z + q + xx.lo * yy.lo;
/* Algorithm 3.5. Error-free transformation of a product */
double z = x * y;
double zz = fma(x, y, -z);
return (DoubleLength) {z, zz};
}

typedef struct { double hi; double lo; double tiny; } TripleLength;

static const TripleLength tl_zero = {0.0, 0.0, 0.0};

static inline TripleLength
tl_fma(TripleLength total, double x, double y)
static TripleLength
tl_fma(double x, double y, TripleLength total)
{
// Rump Algorithm 5.10 with K=3 and using SumKVert
/* Algorithm 5.10 with SumKVert for K=3 */
DoubleLength pr = dl_mul(x, y);
DoubleLength sm = twosum(total.hi, pr.hi);
DoubleLength r1 = twosum(total.lo, pr.lo);
DoubleLength r2 = twosum(r1.hi, sm.lo);
DoubleLength sm = dl_sum(total.hi, pr.hi);
DoubleLength r1 = dl_sum(total.lo, pr.lo);
DoubleLength r2 = dl_sum(r1.hi, sm.lo);
return (TripleLength) {sm.hi, r2.hi, total.tiny + r1.lo + r2.lo};
}

static inline double
static double
tl_to_d(TripleLength total)
{
DoubleLength last = twosum(total.lo, total.hi);
DoubleLength last = dl_sum(total.lo, total.hi);
return total.tiny + last.lo + last.hi;
}

Expand Down Expand Up @@ -3066,7 +3047,7 @@ math_sumprod_impl(PyObject *module, PyObject *p, PyObject *q)
} else {
goto finalize_flt_path;
}
TripleLength new_flt_total = tl_fma(flt_total, flt_p, flt_q);
TripleLength new_flt_total = tl_fma(flt_p, flt_q, flt_total);
if (isfinite(new_flt_total.hi)) {
flt_total = new_flt_total;
flt_total_in_use = true;
Expand Down

0 comments on commit dc5314f

Please sign in to comment.