From 7699f582cec9fcc3f6ecdce40fac99f382065289 Mon Sep 17 00:00:00 2001 From: Tushar Mittal Date: Sat, 28 Mar 2020 00:02:40 +0530 Subject: [PATCH] Add QuadPotentialFullAdapt in pm.sample init --- pymc3/sampling.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/pymc3/sampling.py b/pymc3/sampling.py index ca1773fb27e..eb4e7d00022 100644 --- a/pymc3/sampling.py +++ b/pymc3/sampling.py @@ -275,6 +275,7 @@ def sample( * advi_map: Initialize ADVI with MAP and use MAP as starting point. * map: Use the MAP as starting point. This is discouraged. * nuts: Run NUTS and estimate posterior mean and mass matrix from the trace. + * adapt_full: Adapt a dense mass matrix using the sample covariances step: function or iterable of functions A step function or collection of functions. If there are variables without step methods, step methods for those variables will be assigned automatically. By default the NUTS step @@ -1866,6 +1867,7 @@ def init_nuts( * map: Use the MAP as starting point. This is discouraged. * nuts: Run NUTS and estimate posterior mean and mass matrix from the trace. + * adapt_full: Adapt a dense mass matrix using the sample covariances chains: int Number of jobs to start. n_init: int @@ -2006,6 +2008,14 @@ def init_nuts( cov = np.atleast_1d(pm.trace_cov(init_trace)) start = list(np.random.choice(init_trace, chains)) potential = quadpotential.QuadPotentialFull(cov) + elif init == "adapt_full": + start = [model.test_point] * chains + mean = np.mean([model.dict_to_array(vals) for vals in start], axis=0) + init_trace = pm.sample( + draws=n_init, step=pm.NUTS(), tune=n_init // 2, random_seed=random_seed + ) + cov = np.atleast_1d(pm.trace_cov(init_trace)) + potential = quadpotential.QuadPotentialFullAdapt(model.ndim, mean, cov, 10) else: raise ValueError("Unknown initializer: {}.".format(init))