Skip to content

Commit

Permalink
Fix negative bandwidth test and add online code path test. (gh-118600)
Browse files Browse the repository at this point in the history
  • Loading branch information
rhettinger authored May 5, 2024
1 parent 9c13d9e commit 5092ea2
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 9 deletions.
5 changes: 2 additions & 3 deletions Lib/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -1791,9 +1791,8 @@ def kde_random(data, h, kernel='normal', *, seed=None):
if h <= 0.0:
raise StatisticsError(f'Bandwidth h must be positive, not {h=!r}')

try:
kernel_invcdf = _kernel_invcdfs[kernel]
except KeyError:
kernel_invcdf = _kernel_invcdfs.get(kernel)
if kernel_invcdf is None:
raise StatisticsError(f'Unknown kernel name: {kernel!r}')

prng = _random.Random(seed)
Expand Down
28 changes: 22 additions & 6 deletions Lib/test/test_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2402,7 +2402,7 @@ def integrate(func, low, high, steps=10_000):
with self.assertRaises(StatisticsError):
kde(sample, h=0.0) # Zero bandwidth
with self.assertRaises(StatisticsError):
kde(sample, h=0.0) # Negative bandwidth
kde(sample, h=-1.0) # Negative bandwidth
with self.assertRaises(TypeError):
kde(sample, h='str') # Wrong bandwidth type
with self.assertRaises(StatisticsError):
Expand All @@ -2426,6 +2426,14 @@ def integrate(func, low, high, steps=10_000):
self.assertEqual(f_hat(-1.0), 1/2)
self.assertEqual(f_hat(1.0), 1/2)

# Test online updates to data

data = [1, 2]
f_hat = kde(data, 5.0, 'triangular')
self.assertEqual(f_hat(100), 0.0)
data.append(100)
self.assertGreater(f_hat(100), 0.0)

def test_kde_kernel_invcdfs(self):
kernel_invcdfs = statistics._kernel_invcdfs
kde = statistics.kde
Expand Down Expand Up @@ -2462,7 +2470,7 @@ def test_kde_random(self):
with self.assertRaises(TypeError):
kde_random(iter(sample), 1.5) # Data is not a sequence
with self.assertRaises(StatisticsError):
kde_random(sample, h=0.0) # Zero bandwidth
kde_random(sample, h=-1.0) # Zero bandwidth
with self.assertRaises(StatisticsError):
kde_random(sample, h=0.0) # Negative bandwidth
with self.assertRaises(TypeError):
Expand All @@ -2474,10 +2482,10 @@ def test_kde_random(self):

h = 1.5
kernel = 'cosine'
prng = kde_random(sample, h, kernel)
self.assertEqual(prng.__name__, 'rand')
self.assertIn(kernel, prng.__doc__)
self.assertIn(repr(h), prng.__doc__)
rand = kde_random(sample, h, kernel)
self.assertEqual(rand.__name__, 'rand')
self.assertIn(kernel, rand.__doc__)
self.assertIn(repr(h), rand.__doc__)

# Approximate distribution test: Compare a random sample to the expected distribution

Expand Down Expand Up @@ -2507,6 +2515,14 @@ def p_expected(x):
for x in xarr:
self.assertTrue(math.isclose(p_observed(x), p_expected(x), abs_tol=0.0005))

# Test online updates to data

data = [1, 2]
rand = kde_random(data, 5, 'triangular')
self.assertLess(max([rand() for i in range(5000)]), 10)
data.append(100)
self.assertGreater(max(rand() for i in range(5000)), 10)


class TestQuantiles(unittest.TestCase):

Expand Down

0 comments on commit 5092ea2

Please sign in to comment.