Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix failing default transform for LKJCorr #7065

Merged
merged 11 commits into from
Dec 13, 2023
9 changes: 8 additions & 1 deletion pymc/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1524,6 +1524,13 @@
lkjcorr = LKJCorrRV()


class MultivariateIntervalTransform(Interval):
name = "interval"

def log_jac_det(self, *args):
return super().log_jac_det(*args).sum(-1)

Check warning on line 1531 in pymc/distributions/multivariate.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/multivariate.py#L1531

Added line #L1531 was not covered by tests


class LKJCorr(BoundedContinuous):
r"""
The LKJ (Lewandowski, Kurowicka and Joe) log-likelihood.
Expand Down Expand Up @@ -1623,7 +1630,7 @@

@_default_transform.register(LKJCorr)
def lkjcorr_default_transform(op, rv):
return Interval(floatX(-1.0), floatX(1.0))
return MultivariateIntervalTransform(floatX(-1.0), floatX(1.0))

Check warning on line 1633 in pymc/distributions/multivariate.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/multivariate.py#L1633

Added line #L1633 was not covered by tests


class MatrixNormalRV(RandomVariable):
Expand Down
6 changes: 6 additions & 0 deletions tests/distributions/test_multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2121,6 +2121,12 @@ def ref_rand(size, n, eta):
size=1000,
)

def test_default_transform(self):
with pm.Model() as m:
x = pm.LKJCorr("x", n=2, eta=1, shape=(3, 2))
assert isinstance(m.rvs_to_transforms[x], MultivariateIntervalTransform)
juanitorduz marked this conversation as resolved.
Show resolved Hide resolved
assert m.logp(sum=False)[0].shape == (3,)


class TestLKJCholeskyCov(BaseTestDistributionRandom):
pymc_dist = _LKJCholeskyCov
Expand Down
Loading