Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix infinite loop bug in doctest from #36581 #38209

Merged
merged 5 commits into from
Jun 22, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 35 additions & 10 deletions src/sage/stats/distributions/discrete_gaussian_lattice.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,6 @@ def _normalisation_factor_zz(self, tau=None, prec=None):
\exp(-|x|_2^2 / (2\sigma^2))`, i.e. the normalization factor such that the sum
over all probabilities is 1 for `B`, via Poisson summation.


INPUT:

- ``tau`` -- (default: ``None``) all vectors `v` with `|v|_2^2 \leq
Expand Down Expand Up @@ -235,7 +234,7 @@ def _normalisation_factor_zz(self, tau=None, prec=None):
sage: while v not in counter:
....: add_samples(1000)

sage: while abs(m*f(v)*1.0/nf/counter[v] - 1.0) >= 0.2: # long time
sage: while abs(m*f(v)*1.0/nf/counter[v] - 1.0) >= 0.2: # long time
....: add_samples(1000)

sage: DGL = distributions.DiscreteGaussianDistributionLatticeSampler
Expand All @@ -250,9 +249,14 @@ def _normalisation_factor_zz(self, tau=None, prec=None):

sage: M = Matrix(ZZ, [[1, 3, 0], [-2, 5, 1], [3, -4, 2]])
sage: D = DGL(M, 1.7)
sage: D._normalisation_factor_zz() # long time
sage: D._normalisation_factor_zz() # long time
7247.1975...

sage: Sigma = Matrix(ZZ, [[5, -2, 4], [-2, 10, -5], [4, -5, 5]])
sage: D = DGL(ZZ^3, Sigma, [7, 2, 5])
sage: D._normalisation_factor_zz()
78.6804...

sage: M = Matrix(ZZ, [[1, 3, 0], [-2, 5, 1]])
sage: D = DGL(M, 3)
sage: D._normalisation_factor_zz()
Expand Down Expand Up @@ -300,7 +304,10 @@ def f_or_hat(x):
from sage.functions.log import log
basis = self.B.LLL()
base = vector(ZZ, [v.round() for v in basis.solve_left(self._c)])
BOUND = max(1, (self._RR(log(10**4, self.n)).ceil() - 1) // 2)
# BOUND is the largest integer such that |coords| <= 10^4
# However, this might still drift from true value for larger lattices
# So optimally one should fix the TODO above
BOUND = max(1, (self._RR(10**(4 / self.n)).ceil() - 1) // 2)
if BOUND > 10:
BOUND = 10
coords = itertools.product(range(-BOUND, BOUND + 1), repeat=self.n)
Expand Down Expand Up @@ -431,6 +438,11 @@ def __init__(self, B, sigma=1, c=0, r=None, precision=None, sigma_basis=False):
- ``sigma_basis`` -- (default: ``False``) When set, ``sigma`` is treated as
a (row) basis, i.e. the covariance matrix is computed by `\Sigma = SS^T`

.. TODO::

Rename class methods like ``.f`` and hide most of them
(at least behind something like ``.data``).

EXAMPLES::

sage: n = 2; sigma = 3.0
Expand All @@ -440,8 +452,7 @@ def __init__(self, B, sigma=1, c=0, r=None, precision=None, sigma_basis=False):
56.5486677646...

sage: from collections import defaultdict
sage: counter = defaultdict(Integer)
sage: m = 0
sage: counter = defaultdict(Integer); m = 0
sage: def add_samples(i):
....: global counter, m
....: for _ in range(i):
Expand All @@ -454,7 +465,7 @@ def __init__(self, B, sigma=1, c=0, r=None, precision=None, sigma_basis=False):
sage: while abs(m*f(v)*1.0/nf/counter[v] - 1.0) >= 0.1: # needs sage.symbolic
....: add_samples(1000)

sage: counter = defaultdict(Integer)
sage: counter = defaultdict(Integer); m = 0
sage: v = vector(ZZ, n, (0, 0))
sage: v.set_immutable()
sage: while v not in counter:
Expand All @@ -478,10 +489,24 @@ def __init__(self, B, sigma=1, c=0, r=None, precision=None, sigma_basis=False):
sage: Sigma = Matrix(ZZ, [[5, -2, 4], [-2, 10, -5], [4, -5, 5]])
sage: c = vector(ZZ, [7, 2, 5])
sage: D = distributions.DiscreteGaussianDistributionLatticeSampler(ZZ^n, Sigma, c)
sage: f = D.f
sage: nf = D._normalisation_factor_zz(); nf # This has not been properly implemented
63.76927...
sage: while v not in counter: add_samples(1000)
sage: while abs(m*f(v)*1.0/nf/counter[v] - 1.0) >= 0.1: add_samples(1000)
78.6804...

We can compute the expected number of samples before sampling a vector::

sage: v = vector(ZZ, n, (11, 4, 8))
sage: v.set_immutable()
sage: 1 / (f(v) / nf)
2553.9461...

sage: counter = defaultdict(Integer); m = 0
sage: while v not in counter:
....: add_samples(1000)
sage: sum(counter.values()) # random
3000
sage: while abs(m*f(v)*1.0/nf/counter[v] - 1.0) >= 0.1: # needs sage.symbolic
....: add_samples(1000)

If the covariance provided is not positive definite, an error is thrown::

Expand Down
Loading