Sampling a 1D Gaussian mixture #30
Hi,

```python
import jax
import jax.numpy as jnp
from jax import random

import numpyro
import numpyro.distributions as dist
from numpyro.contrib.nested_sampling import NestedSampler


class MixtureModel(dist.Distribution):
    support = dist.constraints.real_vector

    def __init__(self, locs, scales, weights):
        self.locs, self.scales, self.weights = locs, scales, weights
        super().__init__(batch_shape=locs.shape[:-1], event_shape=locs.shape[-1:])
        self.loc = locs.T
        self.scale = scales.T
        self.weights = weights.T
        norm = jnp.sum(self.weights)
        self.weights = self.weights / norm  # normalize the mixture weights
        self.num_distr = len(locs)

    def sample(self, key, sample_shape=()):
        # a dummy sample, only used to initialize the samplers
        return jnp.zeros(sample_shape + self.shape())

    def log_prob(self, x):
        # weighted sum of the component Gaussian densities, computed in log space
        log_probs = jax.scipy.stats.norm.logpdf(x, loc=self.loc, scale=self.scale)
        return jax.scipy.special.logsumexp(jnp.log(self.weights) + log_probs, axis=0)[0]


def model():
    x = numpyro.sample("x", dist.Uniform(-3.0, 3.0))
    numpyro.sample("val", MixtureModel(jnp.array([0.0, 1.5]),
                                       jnp.array([0.5, 0.1]),
                                       jnp.array([8.0, 2.0])), obs=x)


ns = NestedSampler(model)
ns.run(random.PRNGKey(2))
samples = ns.get_samples(random.PRNGKey(3), num_samples=100_000)
```

Well, first the distribution of `x` looks like what I was expecting (although I do not understand why the output is the `x` variable and not `val`, but I guess this is related to the NumPyro wrapping). But the summary is:
Now, I know the integral value is 0.63167, so I do not understand the logZ that is reported. Do you see any reason why the sampler is not doing so well here (nb. HMC (NUTS) is doing well), and why the logZ also seems off? Thanks
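Since the problem is one-dimensional, the evidence can also be cross-checked by direct quadrature. Below is a minimal sketch, assuming the same mixture parameters as in the model above; the grid covers the Uniform(-3, 3) prior support and the resolution is an arbitrary choice:

```python
import jax.numpy as jnp
import jax.scipy.stats

# Mixture parameters from the model above (weights normalized as in MixtureModel)
locs = jnp.array([0.0, 1.5])
scales = jnp.array([0.5, 0.1])
weights = jnp.array([8.0, 2.0])
weights = weights / weights.sum()

# Riemann-sum quadrature of Z = integral of prior(x) * likelihood(x) dx,
# where prior(x) = 1/6 on (-3, 3) and the likelihood is the mixture density.
xs = jnp.linspace(-3.0, 3.0, 10_001)
mix_pdf = jnp.sum(
    weights * jax.scipy.stats.norm.pdf(xs[:, None], loc=locs, scale=scales),
    axis=-1,
)
dx = xs[1] - xs[0]
Z = jnp.sum(mix_pdf * (1.0 / 6.0)) * dx

print(Z, jnp.log(Z))  # direct values to compare with the sampler's logZ estimate
```

If the wrapper's `ns.print_summary()` runs for you, the logZ it reports can then be compared against this quadrature value directly.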
Add: I have significantly increased `num_live_points` from 1_000 (the default) to 10_000, and the sampling looks better. But the rule of thumb `num_live_points = 50 * (# posterior modes) * (D + 1)` gives here 50 × 2 × (1 + 1) = 200. (nb. there are still summary crashes.)
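For reference, here is a sketch of how the live-point count can be passed through the NumPyro wrapper. The `constructor_kwargs` dict is forwarded to the underlying jaxns sampler, so the accepted keys depend on the installed jaxns version; treat the exact keyword as an assumption:

```python
# Tuning sketch: forward num_live_points to jaxns through the NumPyro wrapper.
# The rule of thumb above gives 50 * 2 * (1 + 1) = 200 live points;
# 10_000 is the larger value that made the sampling look better here.
ns = NestedSampler(model, constructor_kwargs={"num_live_points": 10_000})
ns.run(random.PRNGKey(2))
samples = ns.get_samples(random.PRNGKey(3), num_samples=100_000)
```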