Skip to content

Commit

Permalink
Pass seed to find_map and set global seeds in _iter_sample and `_…
Browse files Browse the repository at this point in the history
…prepare_iter_sample`

This reverts some changes in 47b61de which wrongly disabled global seeding in some sampling contexts that still depended on it.
  • Loading branch information
ricardoV94 authored and twiecki committed May 19, 2022
1 parent 56ad6a9 commit b91283e
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 2 deletions.
10 changes: 8 additions & 2 deletions pymc/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -969,6 +969,9 @@ def _iter_sample(
if draws < 1:
raise ValueError("Argument `draws` must be greater than 0.")

if random_seed is not None:
np.random.seed(random_seed)

try:
step = CompoundStep(step)
except TypeError:
Expand Down Expand Up @@ -1229,6 +1232,9 @@ def _prepare_iter_population(
if draws < 1:
raise ValueError("Argument `draws` should be above 0.")

if random_seed is not None:
np.random.seed(random_seed)

# The initialization of traces, samplers and points must happen in the right order:
# 1. population of points is created
# 2. steppers are initialized and linked to the points object
Expand Down Expand Up @@ -2511,7 +2517,7 @@ def init_nuts(
cov = approx.std.eval() ** 2
potential = quadpotential.QuadPotentialDiag(cov)
elif init == "advi_map":
start = pm.find_MAP(include_transformed=True)
start = pm.find_MAP(include_transformed=True, seed=seeds[0])
approx = pm.MeanField(model=model, start=start)
pm.fit(
random_seed=seeds[0],
Expand All @@ -2526,7 +2532,7 @@ def init_nuts(
cov = approx.std.eval() ** 2
potential = quadpotential.QuadPotentialDiag(cov)
elif init == "map":
start = pm.find_MAP(include_transformed=True)
start = pm.find_MAP(include_transformed=True, seed=seeds[0])
cov = pm.find_hessian(point=start)
initial_points = [start] * chains
potential = quadpotential.QuadPotentialFull(cov)
Expand Down
48 changes: 48 additions & 0 deletions pymc/tests/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,54 @@ def setup_method(self):
super().setup_method()
self.model, self.start, self.step, _ = simple_init()

@pytest.mark.parametrize("init", ("jitter+adapt_diag", "advi", "map"))
@pytest.mark.parametrize("cores", (1, 2))
@pytest.mark.parametrize(
"chains, seeds",
[
(1, None),
(1, 1),
(1, [1]),
(2, None),
(2, 1),
(2, [1, 2]),
],
)
def test_random_seed(self, chains, seeds, cores, init):
with pm.Model(rng_seeder=3):
x = pm.Normal("x", 0, 10, initval="prior")
tr1 = pm.sample(
chains=chains,
random_seed=seeds,
cores=cores,
init=init,
tune=0,
draws=10,
return_inferencedata=False,
compute_convergence_checks=False,
)
tr2 = pm.sample(
chains=chains,
random_seed=seeds,
cores=cores,
init=init,
tune=0,
draws=10,
return_inferencedata=False,
compute_convergence_checks=False,
)

allequal = np.all(tr1["x"] == tr2["x"])
if seeds is None:
assert not allequal
# TODO: ADVI init methods are not correctly seeded, as they rely on the state of
# the model RandomState/Generators which is updated in place when the function
# is compiled and evaluated. This elif branch must be removed once this is fixed
elif init == "advi":
assert not allequal
else:
assert allequal

def test_sample_does_not_set_seed(self):
# This tests that when random_seed is None, the global seed is not affected
random_numbers = []
Expand Down

0 comments on commit b91283e

Please sign in to comment.