Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix failing default transform for LKJCorr #7065

Merged
merged 11 commits into from
Dec 13, 2023
12 changes: 11 additions & 1 deletion pymc/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1524,6 +1524,13 @@
lkjcorr = LKJCorrRV()


class MultivariateIntervalTransform(Interval):
name = "interval"

def log_jac_det(self, *args):
return super().log_jac_det(*args).sum(-1)

Check warning on line 1531 in pymc/distributions/multivariate.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/multivariate.py#L1531

Added line #L1531 was not covered by tests


class LKJCorr(BoundedContinuous):
r"""
The LKJ (Lewandowski, Kurowicka and Joe) log-likelihood.
Expand Down Expand Up @@ -1592,6 +1599,9 @@
TensorVariable
"""

if value.ndim > 1:
raise NotImplementedError("LKJCorr logp is only implemented for vector values (ndim=1)")

Check warning on line 1603 in pymc/distributions/multivariate.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/multivariate.py#L1602-L1603

Added lines #L1602 - L1603 were not covered by tests

# TODO: PyTensor does not have a `triu_indices`, so we can only work with constant
# n (or else find a different expression)
if not isinstance(n, Constant):
Expand Down Expand Up @@ -1623,7 +1633,7 @@

@_default_transform.register(LKJCorr)
def lkjcorr_default_transform(op, rv):
return Interval(floatX(-1.0), floatX(1.0))
return MultivariateIntervalTransform(floatX(-1.0), floatX(1.0))

Check warning on line 1636 in pymc/distributions/multivariate.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/multivariate.py#L1636

Added line #L1636 was not covered by tests


class MatrixNormalRV(RandomVariable):
Expand Down
13 changes: 6 additions & 7 deletions pymc/logprob/transform_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,12 @@ def transformed_value_logprob(op, values, *rv_outs, use_jacobian=True, **kwargs)
raise NotImplementedError(
f"Univariate transform {transform} cannot be applied to multivariate {rv_op}"
)
else:
# Check there is no broadcasting between logp and jacobian
if logp.type.broadcastable != log_jac_det.type.broadcastable:
raise ValueError(
f"The logp of {rv_op} and log_jac_det of {transform} are not allowed to broadcast together. "
"There is a bug in the implementation of either one."
)
# Check there is no broadcasting between logp and jacobian
if logp.type.broadcastable != log_jac_det.type.broadcastable:
raise ValueError(
f"The logp of {rv_op} and log_jac_det of {transform} are not allowed to broadcast together. "
"There is a bug in the implementation of either one."
)

if use_jacobian:
if value.name:
Expand Down
22 changes: 21 additions & 1 deletion tests/distributions/test_multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.

import functools as ft
import re
import warnings

import numpy as np
Expand All @@ -33,6 +32,7 @@
import pymc as pm

from pymc.distributions.multivariate import (
MultivariateIntervalTransform,
_LKJCholeskyCov,
_OrderedMultinomial,
posdef,
Expand Down Expand Up @@ -2122,6 +2122,26 @@ def ref_rand(size, n, eta):
)


@pytest.mark.parametrize(
argnames="shape",
argvalues=[
(2,),
pytest.param(
(3, 2),
marks=pytest.mark.xfail(
raises=NotImplementedError,
reason="LKJCorr logp is only implemented for vector values (ndim=1)",
),
),
],
)
def test_default_transform(shape):
juanitorduz marked this conversation as resolved.
Show resolved Hide resolved
with pm.Model() as m:
x = pm.LKJCorr("x", n=2, eta=1, shape=shape)
assert isinstance(m.rvs_to_transforms[x], MultivariateIntervalTransform)
assert m.logp(sum=False)[0].type.shape == shape[:-1]


class TestLKJCholeskyCov(BaseTestDistributionRandom):
pymc_dist = _LKJCholeskyCov
pymc_dist_params = {"n": 3, "eta": 1.0, "sd_dist": pm.DiracDelta.dist([0.5, 1.0, 2.0])}
Expand Down
Loading