From 5092ea238e28c7d099c662d416b2a96fdbea4790 Mon Sep 17 00:00:00 2001 From: Raymond Hettinger Date: Sun, 5 May 2024 12:29:23 -0500 Subject: [PATCH] Fix negative bandwidth test and add online code path test. (gh-118600) --- Lib/statistics.py | 5 ++--- Lib/test/test_statistics.py | 28 ++++++++++++++++++++++------ 2 files changed, 24 insertions(+), 9 deletions(-) diff --git a/Lib/statistics.py b/Lib/statistics.py index f3ce2d8b6b442a..c2f4fe8e054d3d 100644 --- a/Lib/statistics.py +++ b/Lib/statistics.py @@ -1791,9 +1791,8 @@ def kde_random(data, h, kernel='normal', *, seed=None): if h <= 0.0: raise StatisticsError(f'Bandwidth h must be positive, not {h=!r}') - try: - kernel_invcdf = _kernel_invcdfs[kernel] - except KeyError: + kernel_invcdf = _kernel_invcdfs.get(kernel) + if kernel_invcdf is None: raise StatisticsError(f'Unknown kernel name: {kernel!r}') prng = _random.Random(seed) diff --git a/Lib/test/test_statistics.py b/Lib/test/test_statistics.py index a60791e9b6e1f5..40680759d456ac 100644 --- a/Lib/test/test_statistics.py +++ b/Lib/test/test_statistics.py @@ -2402,7 +2402,7 @@ def integrate(func, low, high, steps=10_000): with self.assertRaises(StatisticsError): kde(sample, h=0.0) # Zero bandwidth with self.assertRaises(StatisticsError): - kde(sample, h=0.0) # Negative bandwidth + kde(sample, h=-1.0) # Negative bandwidth with self.assertRaises(TypeError): kde(sample, h='str') # Wrong bandwidth type with self.assertRaises(StatisticsError): @@ -2426,6 +2426,14 @@ def integrate(func, low, high, steps=10_000): self.assertEqual(f_hat(-1.0), 1/2) self.assertEqual(f_hat(1.0), 1/2) + # Test online updates to data + + data = [1, 2] + f_hat = kde(data, 5.0, 'triangular') + self.assertEqual(f_hat(100), 0.0) + data.append(100) + self.assertGreater(f_hat(100), 0.0) + def test_kde_kernel_invcdfs(self): kernel_invcdfs = statistics._kernel_invcdfs kde = statistics.kde @@ -2462,7 +2470,7 @@ def test_kde_random(self): with self.assertRaises(TypeError): kde_random(iter(sample), 1.5) # Data is not a sequence with self.assertRaises(StatisticsError): - kde_random(sample, h=0.0) # Zero bandwidth + kde_random(sample, h=-1.0) # Zero bandwidth with self.assertRaises(StatisticsError): kde_random(sample, h=0.0) # Negative bandwidth with self.assertRaises(TypeError): @@ -2474,10 +2482,10 @@ def test_kde_random(self): h = 1.5 kernel = 'cosine' - prng = kde_random(sample, h, kernel) - self.assertEqual(prng.__name__, 'rand') - self.assertIn(kernel, prng.__doc__) - self.assertIn(repr(h), prng.__doc__) + rand = kde_random(sample, h, kernel) + self.assertEqual(rand.__name__, 'rand') + self.assertIn(kernel, rand.__doc__) + self.assertIn(repr(h), rand.__doc__) # Approximate distribution test: Compare a random sample to the expected distribution @@ -2507,6 +2515,14 @@ def p_expected(x): for x in xarr: self.assertTrue(math.isclose(p_observed(x), p_expected(x), abs_tol=0.0005)) + # Test online updates to data + + data = [1, 2] + rand = kde_random(data, 5, 'triangular') + self.assertLess(max([rand() for i in range(5000)]), 10) + data.append(100) + self.assertGreater(max(rand() for i in range(5000)), 10) + class TestQuantiles(unittest.TestCase):