Consider adding sumproduct() or dotproduct() to the math module #100485
Comments
@rhettinger I'm working on implementing this feature. I'm curious: how would one be expected to use quad precision in this context? Should I just write multiplication and addition functions for a custom quad-precision data structure?
Raymond's use of "quad precision" was an active link to an article spelling out what he means. It's a software technique for getting twice the native precision using only native-precision float operations, and for that reason it is usually called "doubled precision". Since Python's "float" is IEEE-754 "double precision" on most (all?) platforms, twice that leads to "quad". If you look at the edit history of Raymond's post, at first it had a prototype implementation in pure Python.
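As a rough illustration of the "doubled precision" idea, here is a minimal sketch of two standard error-free transformations: Knuth's TwoSum and Dekker's TwoProduct built on a Veltkamp split with the constant 2**27 + 1. Each returns a pair whose exact (unevaluated) sum equals the exact result of the float operation, so a value can be carried as hi + lo with roughly twice the native precision. The names `two_sum`, `split` and `two_product` are illustrative, and the sketch assumes IEEE-754 binary64 floats with round-to-nearest and no overflow or underflow.

```python
from fractions import Fraction

def two_sum(a, b):
    """Knuth's TwoSum: return (s, e) with s = fl(a + b) and s + e == a + b exactly."""
    s = a + b
    b_virtual = s - a
    a_virtual = s - b_virtual
    e = (a - a_virtual) + (b - b_virtual)
    return s, e

def split(x):
    """Veltkamp split of a binary64 float into two halves of at most 26 bits each."""
    t = x * 134217729.0  # 2**27 + 1
    hi = t - (t - x)
    return hi, x - hi

def two_product(a, b):
    """Dekker's TwoProduct: return (p, e) with p = fl(a * b) and p + e == a * b exactly."""
    p = a * b
    a_hi, a_lo = split(a)
    b_hi, b_lo = split(b)
    # Each partial product of 26-bit halves is exact, so the rounding error of p is recoverable.
    err1 = p - a_hi * b_hi
    err2 = err1 - a_lo * b_hi
    err3 = err2 - a_hi * b_lo
    return p, a_lo * b_lo - err3

s, e = two_sum(1e16, 1.0)
print(s == 1e16, e == 1.0)                 # True True: the 1.0 survives in e

a = 1.0 + 2.0**-30
p, e = two_product(a, a)
print(p == 1.0 + 2.0**-29, e == 2.0**-60)  # True True: p alone lost the 2**-60 term
assert Fraction(p) + Fraction(e) == Fraction(a) ** 2
```

Neither helper needs any wider hardware type; the extra precision comes entirely from recovering each operation's rounding error with a few more ordinary float operations.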
@PurityLake Please leave that for me. I've already done substantial work on it and opened this issue to get buy in before going further. |
Not needed - it would self-evidently be nice to have 😄
Here's the logic I'm planning to use (and the tests to make sure it's correct). If you all want something different or can suggest improvements, please let me know.

```python
'Pure Python implementation of sumprod() with accuracy enhancements.'

from math import fabs, fsum, frexp, isfinite
from sys import stderr

verbose = False

def accumulate(total, frac, x):
    # Add x to total, keeping the rounding error in frac (magnitude-ordered 2Sum step)
    if fabs(total) < fabs(x):
        x, total = total, x
    t = total + x
    frac += (total - t) + x
    return t, frac

def split(x, VELTKAMP_CONSTANT=float(0x8000001)):  # 2**27 + 1
    # Split x into hi + lo, each with at most 26 significant bits
    t = x * VELTKAMP_CONSTANT
    hi = t - (t - x)
    lo = x - hi
    assert not isfinite(hi) or (frexp(hi)[0] * 2**26).is_integer()
    assert not isfinite(hi) or (frexp(lo)[0] * 2**26).is_integer()
    assert not isfinite(hi) or hi + lo == x
    return hi, lo

def multiply(p, q):
    # Each partial product of 26-bit halves is exact, so sum(parts) == p * q exactly
    hp, lp = split(p)
    hq, lq = split(q)
    parts = hp*hq, hp*lq, hq*lp, lp*lq
    assert not all(map(isfinite, parts)) or fsum(parts) == p * q
    return parts

def both_int(p, q):
    return {type(p), type(q)} == {int}

def one_float_one_int(p, q):
    return {type(p), type(q)} == {int, float}

def both_float(p, q):
    return {type(p), type(q)} == {float}

def PyLong_AsLongAndOverflow(p):
    # Simulate the C API: only values that fit in a 64-bit long stay on the int path
    p = int(p)
    if p.bit_length() > 63:
        raise OverflowError
    return p

def check_long_mult_overflow(p, q):
    "Multiply longs and raise if too large for a longlong."
    result = p * q
    if result.bit_length() > 63:
        raise OverflowError
    return result

def PyFloat_AsDouble(p):
    assert type(p) in {int, float}
    return float(p)  # This can raise OverflowError

NULL = object()

def sumprod(p, q, verbose=False):
    p_stopped = q_stopped = False
    int_path_enabled = flt_path_enabled = True
    int_total_in_use = flt_total_in_use = False
    int_total = 0
    flt_total = flt_frac = 0.0
    obj_total = 0
    p_it = iter(p)
    q_it = iter(q)
    while True:
        p_i = next(p_it, NULL)
        if p_i is NULL:
            p_stopped = True
        q_i = next(q_it, NULL)
        if q_i is NULL:
            q_stopped = True
        if p_stopped != q_stopped:
            raise ValueError('Inputs are not the same length')
        finished = p_stopped & q_stopped

        if int_path_enabled:
            if not finished and both_int(p_i, q_i):
                try:
                    if verbose: print('I', end='', file=stderr)
                    int_p = PyLong_AsLongAndOverflow(p_i)
                    int_q = PyLong_AsLongAndOverflow(q_i)
                    int_total += check_long_mult_overflow(int_p, int_q)
                    int_total_in_use = True
                    continue
                except OverflowError:
                    pass
            # We're finished, overflowed, or have a non-int
            int_path_enabled = False
            if int_total_in_use:
                obj_total += int(int_total)
                int_total_in_use = False

        if flt_path_enabled:
            if not finished and (both_float(p_i, q_i) or one_float_one_int(p_i, q_i)):
                try:
                    if verbose: print('F', end='', file=stderr)
                    flt_p = PyFloat_AsDouble(p_i)
                    flt_q = PyFloat_AsDouble(q_i)
                    # Accumulate into temporaries so that a non-finite result
                    # does not destroy the running total of earlier terms
                    new_total, new_frac = flt_total, flt_frac
                    for part in multiply(flt_p, flt_q):
                        new_total, new_frac = accumulate(new_total, new_frac, part)
                    if isfinite(new_total):
                        flt_total, flt_frac = new_total, new_frac
                        flt_total_in_use = True
                        continue
                    # A non-finite value arose: fall back to the slow path
                except OverflowError:
                    pass
            # We're finished, overflowed, have a non-float, or have a non-finite value
            flt_path_enabled = False
            if flt_total_in_use:
                flt_total += flt_frac
                obj_total += float(flt_total)
                flt_total_in_use = False

        assert not int_total_in_use
        assert not flt_total_in_use
        if finished:
            return obj_total
        if verbose: print('O', end='', file=stderr)
        obj_total += p_i * q_i

if __name__ == '__main__':
    from decimal import Decimal
    from fractions import Fraction
    from itertools import product
    from math import isnan
    from random import randrange

    def baseline_sumprod(p, q):
        "This defines the target behavior including exceptions and special values."
        total = 0
        for p_i, q_i in zip(p, q, strict=True):
            total += p_i * q_i
        return total

    def run(func, *args):
        "Make comparing functions easier. Returns error status, type, and result."
        try:
            result = func(*args)
        except Exception as e:
            return 'Error', type(e), None
        if isinstance(result, (float, Decimal)) and isnan(result):
            result = 'NaN'
        return None, type(result), result

    pool = [3.75, 2.5, -1.5, 0.0, 1.0,
            float('inf'), -float('inf'), float('NaN'),
            3, -5, 7, 0, 1, 10*2**64, 10**1000,
            Fraction(7, 8), Fraction(21, 4), Fraction(-41, 2),
            Fraction(0), Fraction(1),
            Decimal('6.125'), Decimal('12.125'), Decimal('-2.125'),
            Decimal(0), Decimal(1),
            Decimal('Inf'), -Decimal('Inf'), Decimal('NaN'),
           ]

    args = ((), ())
    assert run(baseline_sumprod, *args) == run(sumprod, *args)
    print('Done 0')

    for a, b in product(pool, repeat=2):
        args = (a,), (b,)
        assert run(baseline_sumprod, *args) == run(sumprod, *args)
    print('Done 2')

    for m, n in product(range(5), repeat=2):
        args = [randrange(100)/16 for i in range(m)], [randrange(100)/16 for i in range(n)]
        assert run(baseline_sumprod, *args) == run(sumprod, *args)
    print('Done uneven')

    for a, b, c, d in product(pool, repeat=4):
        args = (a, b), (c, d)
        assert run(baseline_sumprod, *args) == run(sumprod, *args)
    print('Done 4')

    alt_pool = (1e200, 2e200, 5e200, 0.0, -1e200, -2e200, -5e200)
    for a, b, c, d, e, f in product(alt_pool, repeat=6):
        args = (a, b, c), (d, e, f)
        assert run(baseline_sumprod, *args) == run(sumprod, *args)
    print('Done overflow 6')

    for a, b, c, d, e, f in product(pool, repeat=6):
        args = (a, b, c), (d, e, f)
        assert run(baseline_sumprod, *args) == run(sumprod, *args)
    print('Done regular 6')
```
I checked in the core logic so that we could have a baseline for further improvements. The beta release is still far away. Suggestions are welcome. Here are some topics that I'm mulling over:
2SUM was faster so I put it in.
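For comparison, here is a minimal sketch of the two accumulation steps in play, assuming "2SUM" refers to Knuth's branch-free TwoSum and that the alternative is the magnitude-ordered Fast2Sum step used by the prototype's `accumulate()`. The function names are illustrative, and IEEE-754 binary64 with round-to-nearest and no overflow is assumed. Both compute the same exact error term; 2Sum simply replaces the comparison and swap with three extra additions.

```python
from fractions import Fraction
from math import fabs

def fast_two_sum(a, b):
    """Dekker's Fast2Sum, made safe by ordering the operands by magnitude first."""
    if fabs(a) < fabs(b):
        a, b = b, a
    s = a + b
    e = (a - s) + b          # exact because |a| >= |b|
    return s, e

def two_sum(a, b):
    """Knuth's branch-free 2Sum: valid for any finite a and b, no comparison needed."""
    s = a + b
    b_virtual = s - a
    a_virtual = s - b_virtual
    e = (a - a_virtual) + (b - b_virtual)
    return s, e

for a, b in [(1.0, 1e16), (0.1, 0.2), (-3.5, 1e-20)]:
    # Both return the rounded sum plus its exact rounding error.
    assert fast_two_sum(a, b) == two_sum(a, b)
    s, e = two_sum(a, b)
    assert Fraction(s) + Fraction(e) == Fraction(a) + Fraction(b)
```

Whether the branch-free form wins in practice depends on how predictable the magnitude test is for real inputs, so timing on representative data is the only reliable guide.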
Interestingly, this only makes the code two lines longer. The running time only increases by 12%. Worth it?
Please add
It would be good to include this in What's New.
Fun fact,
If we're going to have an even more efficient method, the dot product example from
…lementation is not available (#101567)
@rhettinger
Example use case (as a variant of the motivating example at the very top of this thread): I'd welcome this generalisation, but perhaps my comment is too late because
It should not be too late for such an improvement yet. Beta 1 is the point after which no new features are supposed to be added.
I was reviewing the itertools recipes to see whether some were worth promoting to built-in tools. The `dotproduct()` recipe was the best candidate. To non-matrix people this is known as `sumproduct()`, and it comes up in many non-vector applications, possibly the most common being `sum([price * quantity for price, quantity in zip(prices, quantities)])` and the second most common being weighted averages. The current version of the recipe is:
If we offered this as part of the math module, we could make a higher quality implementation.
For float inputs or mixed int/float inputs, we could multiply and sum in quad precision, making a single rounding at the end. This would make a robust building block to serve as a foundation for users to construct higher-level tools. It is also something that is difficult for them to do on their own.