Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GH-100485: Add math.sumprod() #100677

Merged
merged 43 commits into from
Jan 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
5776b0f
Stub function.
rhettinger Jan 1, 2023
8ffd49e
Core fuctionality without error handling
rhettinger Jan 1, 2023
f5a3ecb
Test core functionality.
rhettinger Jan 2, 2023
37fcb0a
Add docs
rhettinger Jan 2, 2023
6cfc5cc
Check for non-iterable inputs
rhettinger Jan 2, 2023
fbeca6b
Handle uneven lengths and error returns from PyIter_Next().
rhettinger Jan 2, 2023
40cfc8f
Handle errors arising during the multiplications or additions.
rhettinger Jan 2, 2023
2ac24f5
Test special values
rhettinger Jan 2, 2023
1a64fd8
Add blurb
rhettinger Jan 2, 2023
5f4e493
Regenerate clinic files
rhettinger Jan 2, 2023
31fa258
Beautify with booleans.
rhettinger Jan 3, 2023
d19a38b
Add int_path. Needs work on refcount and addition overflow.
rhettinger Jan 3, 2023
a28fd71
Add more assertions and refcount fixes
rhettinger Jan 3, 2023
f998ee4
Only test for ints when iterators are not stopped.
rhettinger Jan 3, 2023
94a9cbd
Test long addition for overflow
rhettinger Jan 3, 2023
c778865
Improve wording regarding extended precision.
rhettinger Jan 3, 2023
6d619dd
Add the "finished" summary variable for readability.
rhettinger Jan 3, 2023
83dbf96
Add flt path
rhettinger Jan 4, 2023
c2e5f3d
add flt/int path
rhettinger Jan 4, 2023
b8c9127
.
rhettinger Jan 4, 2023
fdaf31b
Merge branch 'sumprod4' into sumprod
rhettinger Jan 4, 2023
4f22a7f
Update docs and docstrings.
rhettinger Jan 4, 2023
7d7fd87
Fix typo
rhettinger Jan 4, 2023
c16f90a
Use extended precision for float/float and int/float cases.
rhettinger Jan 4, 2023
c52bb84
Brevity is the soul of wit.
rhettinger Jan 4, 2023
bf8059c
Beautification.
rhettinger Jan 4, 2023
1326e83
Various minor improvments
rhettinger Jan 5, 2023
9b03468
Fully encapsulate the dl (double length) logic.
rhettinger Jan 5, 2023
8b88f4d
Add performance comment.
rhettinger Jan 5, 2023
83ca929
Write tight.
rhettinger Jan 5, 2023
708f314
Abbreviate comment.
rhettinger Jan 5, 2023
90f1ee4
Remove case for int/int while float in use.
rhettinger Jan 5, 2023
02035dc
Beautify dl_add()
rhettinger Jan 5, 2023
9b1edcc
Add stress tests.
rhettinger Jan 5, 2023
22249d6
Update itertools recipe to reflect the promotion.
rhettinger Jan 5, 2023
ec50404
Remove doctest for dotproduct().
rhettinger Jan 5, 2023
19dfd08
Speed-up iterator access. Let float be selected by bools.
rhettinger Jan 6, 2023
13f1d4f
Add assertion for a key checkpoint
rhettinger Jan 6, 2023
a07e03d
Move dl_mul() out to a separate function. Dekker (5.8).
rhettinger Jan 6, 2023
d274325
Add the smaller magnitude component first.
rhettinger Jan 6, 2023
67f82ca
Use Dekker (5.12) directly. Avoids calling dl_add() and saves a compa…
rhettinger Jan 7, 2023
5d165be
Update timing in comment.
rhettinger Jan 7, 2023
2e772cb
Add in the large component first.
rhettinger Jan 7, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 2 additions & 9 deletions Doc/library/itertools.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ by combining :func:`map` and :func:`count` to form ``map(f, count())``.
These tools and their built-in counterparts also work well with the high-speed
functions in the :mod:`operator` module. For example, the multiplication
operator can be mapped across two vectors to form an efficient dot-product:
``sum(map(operator.mul, vector1, vector2))``.
``sum(starmap(operator.mul, zip(vec1, vec2, strict=True)))``.


**Infinite iterators:**
Expand Down Expand Up @@ -838,10 +838,6 @@ which incur interpreter overhead.
"Returns the sequence elements n times"
return chain.from_iterable(repeat(tuple(iterable), n))

def dotproduct(vec1, vec2):
"Compute a sum of products."
return sum(starmap(operator.mul, zip(vec1, vec2, strict=True)))

def convolve(signal, kernel):
# See: https://betterexplained.com/articles/intuitive-convolution/
# convolve(data, [0.25, 0.25, 0.25, 0.25]) --> Moving average (blur)
Expand All @@ -852,7 +848,7 @@ which incur interpreter overhead.
window = collections.deque([0], maxlen=n) * n
for x in chain(signal, repeat(0, n-1)):
window.append(x)
yield dotproduct(kernel, window)
yield math.sumprod(kernel, window)

def polynomial_from_roots(roots):
"""Compute a polynomial's coefficients from its roots.
Expand Down Expand Up @@ -1211,9 +1207,6 @@ which incur interpreter overhead.
>>> list(ncycles('abc', 3))
['a', 'b', 'c', 'a', 'b', 'c', 'a', 'b', 'c']

>>> dotproduct([1,2,3], [4,5,6])
32

>>> data = [20, 40, 24, 32, 20, 28, 16]
>>> list(convolve(data, [0.25, 0.25, 0.25, 0.25]))
[5.0, 15.0, 21.0, 29.0, 29.0, 26.0, 24.0, 16.0, 11.0, 4.0]
Expand Down
16 changes: 16 additions & 0 deletions Doc/library/math.rst
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,22 @@ Number-theoretic and representation functions
.. versionadded:: 3.7


.. function:: sumprod(p, q)

Return the sum of products of values from two iterables *p* and *q*.

Raises :exc:`ValueError` if the inputs do not have the same length.

Roughly equivalent to::

sum(itertools.starmap(operator.mul, zip(p, q, strict=true)))

For float and mixed int/float inputs, the intermediate products
and sums are computed with extended precision.

.. versionadded:: 3.12


.. function:: trunc(x)

Return *x* with the fractional part
Expand Down
166 changes: 166 additions & 0 deletions Lib/test/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from test.support import verbose, requires_IEEE_754
from test import support
import unittest
import fractions
import itertools
import decimal
import math
Expand Down Expand Up @@ -1202,6 +1203,171 @@ def testLog10(self):
self.assertEqual(math.log(INF), INF)
self.assertTrue(math.isnan(math.log10(NAN)))

def testSumProd(self):
sumprod = math.sumprod
Decimal = decimal.Decimal
Fraction = fractions.Fraction

# Core functionality
self.assertEqual(sumprod(iter([10, 20, 30]), (1, 2, 3)), 140)
self.assertEqual(sumprod([1.5, 2.5], [3.5, 4.5]), 16.5)
self.assertEqual(sumprod([], []), 0)

# Type preservation and coercion
for v in [
(10, 20, 30),
(1.5, -2.5),
(Fraction(3, 5), Fraction(4, 5)),
(Decimal(3.5), Decimal(4.5)),
(2.5, 10), # float/int
(2.5, Fraction(3, 5)), # float/fraction
(25, Fraction(3, 5)), # int/fraction
(25, Decimal(4.5)), # int/decimal
]:
for p, q in [(v, v), (v, v[::-1])]:
with self.subTest(p=p, q=q):
expected = sum(p_i * q_i for p_i, q_i in zip(p, q, strict=True))
actual = sumprod(p, q)
self.assertEqual(expected, actual)
self.assertEqual(type(expected), type(actual))

# Bad arguments
self.assertRaises(TypeError, sumprod) # No args
self.assertRaises(TypeError, sumprod, []) # One arg
self.assertRaises(TypeError, sumprod, [], [], []) # Three args
self.assertRaises(TypeError, sumprod, None, [10]) # Non-iterable
self.assertRaises(TypeError, sumprod, [10], None) # Non-iterable

# Uneven lengths
self.assertRaises(ValueError, sumprod, [10, 20], [30])
self.assertRaises(ValueError, sumprod, [10], [20, 30])

# Error in iterator
def raise_after(n):
for i in range(n):
yield i
raise RuntimeError
with self.assertRaises(RuntimeError):
sumprod(range(10), raise_after(5))
with self.assertRaises(RuntimeError):
sumprod(raise_after(5), range(10))

# Error in multiplication
class BadMultiply:
def __mul__(self, other):
raise RuntimeError
def __rmul__(self, other):
raise RuntimeError
with self.assertRaises(RuntimeError):
sumprod([10, BadMultiply(), 30], [1, 2, 3])
with self.assertRaises(RuntimeError):
sumprod([1, 2, 3], [10, BadMultiply(), 30])

# Error in addition
with self.assertRaises(TypeError):
sumprod(['abc', 3], [5, 10])
with self.assertRaises(TypeError):
sumprod([5, 10], ['abc', 3])

# Special values should give the same as the pure python recipe
self.assertEqual(sumprod([10.1, math.inf], [20.2, 30.3]), math.inf)
self.assertEqual(sumprod([10.1, math.inf], [math.inf, 30.3]), math.inf)
self.assertEqual(sumprod([10.1, math.inf], [math.inf, math.inf]), math.inf)
self.assertEqual(sumprod([10.1, -math.inf], [20.2, 30.3]), -math.inf)
self.assertTrue(math.isnan(sumprod([10.1, math.inf], [-math.inf, math.inf])))
self.assertTrue(math.isnan(sumprod([10.1, math.nan], [20.2, 30.3])))
self.assertTrue(math.isnan(sumprod([10.1, math.inf], [math.nan, 30.3])))
self.assertTrue(math.isnan(sumprod([10.1, math.inf], [20.3, math.nan])))

# Error cases that arose during development
args = ((-5, -5, 10), (1.5, 4611686018427387904, 2305843009213693952))
self.assertEqual(sumprod(*args), 0.0)


@requires_IEEE_754
@unittest.skipIf(HAVE_DOUBLE_ROUNDING,
"sumprod() accuracy not guaranteed on machines with double rounding")
@support.cpython_only # Other implementations may choose a different algorithm
def test_sumprod_accuracy(self):
sumprod = math.sumprod
self.assertEqual(sumprod([0.1] * 10, [1]*10), 1.0)
self.assertEqual(sumprod([0.1] * 20, [True, False] * 10), 1.0)
self.assertEqual(sumprod([1.0, 10E100, 1.0, -10E100], [1.0]*4), 2.0)

def test_sumprod_stress(self):
sumprod = math.sumprod
product = itertools.product
Decimal = decimal.Decimal
Fraction = fractions.Fraction

class Int(int):
def __add__(self, other):
return Int(int(self) + int(other))
def __mul__(self, other):
return Int(int(self) * int(other))
__radd__ = __add__
__rmul__ = __mul__
def __repr__(self):
return f'Int({int(self)})'

class Flt(float):
def __add__(self, other):
return Int(int(self) + int(other))
def __mul__(self, other):
return Int(int(self) * int(other))
__radd__ = __add__
__rmul__ = __mul__
def __repr__(self):
return f'Flt({int(self)})'

def baseline_sumprod(p, q):
"""This defines the target behavior including expections and special values.
However, it is subject to rounding errors, so float inputs should be exactly
representable with only a few bits.
"""
total = 0
for p_i, q_i in zip(p, q, strict=True):
total += p_i * q_i
return total

def run(func, *args):
"Make comparing functions easier. Returns error status, type, and result."
try:
result = func(*args)
except (AssertionError, NameError):
raise
except Exception as e:
return type(e), None, 'None'
return None, type(result), repr(result)

pools = [
(-5, 10, -2**20, 2**31, 2**40, 2**61, 2**62, 2**80, 1.5, Int(7)),
(5.25, -3.5, 4.75, 11.25, 400.5, 0.046875, 0.25, -1.0, -0.078125),
(-19.0*2**500, 11*2**1000, -3*2**1500, 17*2*333,
5.25, -3.25, -3.0*2**(-333), 3, 2**513),
(3.75, 2.5, -1.5, float('inf'), -float('inf'), float('NaN'), 14,
9, 3+4j, Flt(13), 0.0),
(13.25, -4.25, Decimal('10.5'), Decimal('-2.25'), Fraction(13, 8),
Fraction(-11, 16), 4.75 + 0.125j, 97, -41, Int(3)),
(Decimal('6.125'), Decimal('12.375'), Decimal('-2.75'), Decimal(0),
Decimal('Inf'), -Decimal('Inf'), Decimal('NaN'), 12, 13.5),
(-2.0 ** -1000, 11*2**1000, 3, 7, -37*2**32, -2*2**-537, -2*2**-538,
2*2**-513),
(-7 * 2.0 ** -510, 5 * 2.0 ** -520, 17, -19.0, -6.25),
(11.25, -3.75, -0.625, 23.375, True, False, 7, Int(5)),
]

for pool in pools:
for size in range(4):
for args1 in product(pool, repeat=size):
for args2 in product(pool, repeat=size):
args = (args1, args2)
self.assertEqual(
run(baseline_sumprod, *args),
run(sumprod, *args),
args,
)

def testModf(self):
self.assertRaises(TypeError, math.modf)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add math.sumprod() to compute the sum of products.
39 changes: 38 additions & 1 deletion Modules/clinic/mathmodule.c.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading