Missing feature: Multivariate normal distribution with a different covariance matrix for each particle #55
Thanks for opening this issue.
Maybe also one additional idea.
This should do: the broadcasting should be smart enough to handle everything.

```python
import numpy as np
import particles.distributions as dists


def log_det_chol(chol):
    # log|det L| for a (batch of) Cholesky factor(s) L; equals 0.5 * log det(cov)
    diags = np.diagonal(chol, axis1=-2, axis2=-1)
    return np.sum(np.log(np.abs(diags)), -1)


def mvn_logpdf(x, loc, chol):
    d = loc.shape[-1]
    diff = np.asarray(x) - loc  # plain broadcasting, no Python loop needed
    # solve_triangular doesn't accept batched matrices, so use the batched general solver;
    # the trailing axis makes diff a stack of column vectors on any NumPy version
    z = np.linalg.solve(chol, diff[..., np.newaxis])[..., 0]
    const = -0.5 * d * np.log(2 * np.pi) - log_det_chol(chol)
    return const - 0.5 * np.sum(z * z, -1)


def mvn_sample(loc, chol, size=1):
    # Do some checks on size here
    d = loc.shape[-1]
    # the output shape must account for the batch dims of both loc and chol
    shape = np.broadcast_shapes(loc.shape, chol.shape[:-1], (size, d))
    loc = np.broadcast_to(loc, shape)
    eps = np.random.randn(*shape)
    return loc + np.einsum("...ij,...j->...i", chol, eps)


class Generalized_MV_Normal(dists.ProbDist):
    """Multivariate normal, with dim >= 2, allowing for a different cov matrix for each
    particle.
    """

    def __init__(self, loc=None, cov=None):
        self.dim = cov.shape[-1]
        self.loc = np.zeros(self.dim) if loc is None else np.asarray(loc)
        self.chol = np.linalg.cholesky(cov)

    def rvs(self, size=1):
        return mvn_sample(self.loc, self.chol, size)

    def logpdf(self, x):
        return mvn_logpdf(x, self.loc, self.chol)
```
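For reference, this is the standard Cholesky form of the Gaussian log-density that `mvn_logpdf` evaluates, written per particle $n$ with $\Sigma_n = L_n L_n^\top$ (nothing here is specific to particles):

```math
\log \mathcal{N}(x_n;\mu_n,\Sigma_n) = -\frac{d}{2}\log(2\pi) - \sum_{i=1}^{d}\log (L_n)_{ii} - \frac{1}{2}\lVert z_n\rVert^2,
\qquad z_n = L_n^{-1}(x_n-\mu_n)
```

It follows from $\log\det\Sigma_n = 2\sum_i \log (L_n)_{ii}$ and $(x_n-\mu_n)^\top\Sigma_n^{-1}(x_n-\mu_n) = \lVert L_n^{-1}(x_n-\mu_n)\rVert^2$. Likewise `mvn_sample` uses $x_n = \mu_n + L_n\varepsilon_n$ with $\varepsilon_n\sim\mathcal{N}(0,I_d)$, which has covariance $L_n L_n^\top = \Sigma_n$.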
Ah, sorry, my bad. One way to fix my code:
where basically I replaced $B_1$ by $S$ minus its expectation.
The current version of `MvNormal` does not allow for a covariance matrix that varies across the particles. I implemented this for a colleague, but I don't like it so much because there is a loop over $n$, so it may be slow:
Since this is the second time someone is asking for something like this, I am going to open an issue and try to think of ways to make the above code more efficient (numba?).
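To make the efficiency question concrete, here is a minimal NumPy-only sketch (no numba, no particles dependency) comparing a per-particle loop with a batched version of the same log-density. The names `logpdf_loop`, `logpdf_batched` and `sample_batched` are hypothetical; the loop version is only a stand-in for the kind of code described above, not the code from #54. It assumes `loc` has shape `(N, d)` and `cov` has shape `(N, d, d)`.

```python
import numpy as np


def logpdf_loop(x, loc, cov):
    """Reference version: one particle at a time (a Python loop over n)."""
    n, d = x.shape
    out = np.empty(n)
    for i in range(n):
        diff = x[i] - loc[i]
        _, logdet = np.linalg.slogdet(cov[i])
        quad = diff @ np.linalg.solve(cov[i], diff)
        out[i] = -0.5 * (d * np.log(2 * np.pi) + logdet + quad)
    return out


def logpdf_batched(x, loc, cov):
    """Vectorised version: all N particles go through stacked LAPACK calls."""
    d = x.shape[-1]
    chol = np.linalg.cholesky(cov)                          # (N, d, d)
    diff = x - loc                                          # (N, d)
    z = np.linalg.solve(chol, diff[..., None])[..., 0]      # L^{-1}(x - mu), (N, d)
    half_logdet = np.sum(np.log(np.diagonal(chol, axis1=-2, axis2=-1)), axis=-1)
    return -0.5 * d * np.log(2 * np.pi) - half_logdet - 0.5 * np.sum(z * z, axis=-1)


def sample_batched(loc, cov, rng):
    """Draw one point per particle: x_n = mu_n + L_n eps_n."""
    chol = np.linalg.cholesky(cov)
    eps = rng.standard_normal(loc.shape)
    return loc + np.einsum("nij,nj->ni", chol, eps)


rng = np.random.default_rng(0)
N, d = 1000, 3
A = rng.standard_normal((N, d, d))
cov = A @ np.transpose(A, (0, 2, 1)) + d * np.eye(d)  # one SPD matrix per particle
loc = rng.standard_normal((N, d))
x = sample_batched(loc, cov, rng)
print(np.allclose(logpdf_loop(x, loc, cov), logpdf_batched(x, loc, cov)))  # True
```

The batched version pushes the loop over $n$ into NumPy's stacked `cholesky`/`solve` routines; how it compares with a numba-compiled loop has not been measured here.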
Originally posted by @nchopin in #54 (comment)