Skip to content

Commit

Permalink
Allow Truncation of CustomDists
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Oct 11, 2023
1 parent 76a5de0 commit a47bb8a
Show file tree
Hide file tree
Showing 2 changed files with 190 additions and 75 deletions.
170 changes: 104 additions & 66 deletions pymc/distributions/truncated.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@
from pytensor.tensor import TensorConstant, TensorVariable
from pytensor.tensor.random.basic import NormalRV
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.type import RandomType

from pymc.distributions.continuous import TruncatedNormal, bounded_cont_transform
from pymc.distributions.dist_math import check_parameters
from pymc.distributions.distribution import (
CustomSymbolicDistRV,
Distribution,
SymbolicRandomVariable,
_moment,
Expand All @@ -38,8 +40,9 @@
from pymc.distributions.transforms import _default_transform
from pymc.exceptions import TruncationError
from pymc.logprob.abstract import _logcdf, _logprob
from pymc.logprob.basic import icdf, logcdf
from pymc.logprob.basic import icdf, logcdf, logp
from pymc.math import logdiffexp
from pymc.pytensorf import collect_default_updates
from pymc.util import check_dist_not_registered


Expand All @@ -49,7 +52,7 @@ class TruncatedRV(SymbolicRandomVariable):
that represents a truncated univariate random variable.
"""

default_output = 1
default_output = 0

def __init__(
self,
Expand All @@ -63,8 +66,13 @@ def __init__(
super().__init__(*args, **kwargs)

def update(self, node: Node):
"""Return the update mapping for the internal RNG."""
return {node.inputs[-1]: node.outputs[0]}
"""Return the update mapping for the internal RNGs.
TruncatedRVs are created in a way that the rng updats follow the same order as the input RNGs.
"""
rngs = [inp for inp in node.inputs if isinstance(inp.type, RandomType)]
next_rngs = [out for out in node.outputs if isinstance(out.type, RandomType)]
return dict(zip(rngs, next_rngs))


@singledispatch
Expand Down Expand Up @@ -141,10 +149,14 @@ class Truncated(Distribution):

@classmethod
def dist(cls, dist, lower=None, upper=None, max_n_steps: int = 10_000, **kwargs):
if not (isinstance(dist, TensorVariable) and isinstance(dist.owner.op, RandomVariable)):
if not (
isinstance(dist, TensorVariable)
and isinstance(dist.owner.op, (RandomVariable, CustomSymbolicDistRV))
):
if isinstance(dist.owner.op, SymbolicRandomVariable):
raise NotImplementedError(
f"Truncation not implemented for SymbolicRandomVariable {dist.owner.op}"
f"Truncation not implemented for SymbolicRandomVariable {dist.owner.op}.\n"
f"You can try wrapping the distribution inside a CustomDist instead."
)
raise ValueError(
f"Truncation dist must be a distribution created via the `.dist()` API, got {type(dist)}"
Expand All @@ -160,6 +172,15 @@ def dist(cls, dist, lower=None, upper=None, max_n_steps: int = 10_000, **kwargs)

return super().dist([dist, lower, upper, max_n_steps], **kwargs)

@staticmethod
def _recreate_untruncated_rv(
op, truncated_rv
) -> tuple[TensorVariable, TensorVariable, TensorVariable]:
"""Recreate (unbox) the untruncated base RV and return it alongside lower and upper"""
*rv_inputs, lower, upper = truncated_rv.owner.inputs
untruncated_rv = op.base_rv_op.make_node(*rv_inputs).default_output()
return untruncated_rv, lower, upper

Check warning on line 182 in pymc/distributions/truncated.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/truncated.py#L180-L182

Added lines #L180 - L182 were not covered by tests

@classmethod
def rv_op(cls, dist, lower, upper, max_n_steps, size=None):
# Try to use specialized Op
Expand All @@ -174,46 +195,59 @@ def rv_op(cls, dist, lower, upper, max_n_steps, size=None):
if size is None:
size = pt.broadcast_shape(dist, lower, upper)
dist = change_dist_size(dist, new_size=size)
rv_inputs = [
inp
if not isinstance(inp.type, RandomType)
else pytensor.shared(np.random.default_rng())
for inp in dist.owner.inputs
]
graph_inputs = [*rv_inputs, lower, upper]

# Variables with `_` suffix identify dummy inputs for the OpFromGraph
graph_inputs = [*dist.owner.inputs[1:], lower, upper]
graph_inputs_ = [inp.type() for inp in graph_inputs]
graph_inputs_ = [
inp.type() if not isinstance(inp.type, RandomType) else inp for inp in graph_inputs
]
*rv_inputs_, lower_, upper_ = graph_inputs_

# We will use a Shared RNG variable because Scan demands it, even though it
# would not be necessary for the OpFromGraph inverse cdf.
rng = pytensor.shared(np.random.default_rng())
rv_ = dist.owner.op.make_node(rng, *rv_inputs_).default_output()
rv_ = dist.owner.op.make_node(*rv_inputs_).default_output()

# Try to use inverted cdf sampling
# truncated_rv = icdf(rv, draw(uniform(lower, upper)))
try:
# For left truncated discrete RVs, we need to include the whole lower bound.
# This may result in draws below the truncation range, if any uniform == 0
lower_value = lower_ - 1 if dist.owner.op.dtype.startswith("int") else lower_
cdf_lower_ = pt.exp(logcdf(rv_, lower_value))
lower_value_ = lower_ - 1 if dist.dtype.startswith("int") else lower_
cdf_lower_ = pt.exp(logcdf(rv_, lower_value_))
cdf_upper_ = pt.exp(logcdf(rv_, upper_))
# It's okay to reuse the same rng here, because the rng in rv_ will not be
# used by either the logcdf of icdf functions
# We use the first RNG from the base RV, so we don't have to introduce a new one
# This is not problematic because the RNG won't be used in the RV logcdf graph
uniform_rng_ = next(inp_ for inp_ in rv_inputs_ if isinstance(inp_.type, RandomType))
uniform_next_rng_, uniform_ = pt.random.uniform(
cdf_lower_,
cdf_upper_,
rng=rng,
size=rv_inputs_[0],
rng=uniform_rng_,
size=rv_.shape,
).owner.outputs
truncated_rv_ = icdf(rv_, uniform_)
# There should be a RV in the icdf graph (the uniform draw)
truncated_rv_ = icdf(rv_, uniform_, warn_rvs=False)
return TruncatedRV(
base_rv_op=dist.owner.op,
inputs=[*graph_inputs_, rng],
outputs=[uniform_next_rng_, truncated_rv_],
inputs=graph_inputs_,
outputs=[truncated_rv_, uniform_next_rng_],
ndim_supp=0,
max_n_steps=max_n_steps,
)(*graph_inputs, rng)
)(*graph_inputs)
except NotImplementedError:
pass

# Fallback to rejection sampling
def loop_fn(truncated_rv, reject_draws, lower, upper, rng, *rv_inputs):
next_rng, new_truncated_rv = dist.owner.op.make_node(rng, *rv_inputs).outputs
# truncated_rv = zeros(rv.shape)
# reject_draws = ones(rv.shape, dtype=bool)
# while any(reject_draws):
# truncated_rv[reject_draws] = draw(rv)[reject_draws]
# reject_draws = (truncated_rv < lower) | (truncated_rv > upper)
def loop_fn(truncated_rv, reject_draws, lower, upper, *rv_inputs):
new_truncated_rv = dist.owner.op.make_node(*rv_inputs_).default_output()
# Avoid scalar boolean indexing
if truncated_rv.type.ndim == 0:
truncated_rv = new_truncated_rv
Expand All @@ -226,7 +260,7 @@ def loop_fn(truncated_rv, reject_draws, lower, upper, rng, *rv_inputs):

return (
(truncated_rv, reject_draws),
[(rng, next_rng)],
collect_default_updates(new_truncated_rv),
until(~pt.any(reject_draws)),
)

Expand All @@ -236,7 +270,7 @@ def loop_fn(truncated_rv, reject_draws, lower, upper, rng, *rv_inputs):
pt.zeros_like(rv_),
pt.ones_like(rv_, dtype=bool),
],
non_sequences=[lower_, upper_, rng, *rv_inputs_],
non_sequences=[lower_, upper_, *rv_inputs_],
n_steps=max_n_steps,
strict=True,
)
Expand All @@ -246,23 +280,30 @@ def loop_fn(truncated_rv, reject_draws, lower, upper, rng, *rv_inputs):
truncated_rv_ = TruncationCheck(f"Truncation did not converge in {max_n_steps} steps")(
truncated_rv_, convergence_
)
# Sort updates of each RNG so that they show in the same order as the input RNGs

def sort_updates(update):
rng, next_rng = update
return graph_inputs.index(rng)

next_rngs = [next_rng for rng, next_rng in sorted(updates.items(), key=sort_updates)]

return TruncatedRV(
base_rv_op=dist.owner.op,
inputs=[*graph_inputs_, rng],
outputs=[tuple(updates.values())[0], truncated_rv_],
inputs=graph_inputs_,
outputs=[truncated_rv_, *next_rngs],
ndim_supp=0,
max_n_steps=max_n_steps,
)(*graph_inputs, rng)
)(*graph_inputs)


@_change_dist_size.register(TruncatedRV)
def change_truncated_size(op, dist, new_size, expand):
*rv_inputs, lower, upper, rng = dist.owner.inputs
# Recreate the original untruncated RV
untruncated_rv = op.base_rv_op.make_node(rng, *rv_inputs).default_output()
def change_truncated_size(op: TruncatedRV, truncated_rv, new_size, expand):
*rv_inputs, lower, upper = truncated_rv.owner.inputs
untruncated_rv = op.base_rv_op.make_node(*rv_inputs).default_output()

if expand:
new_size = to_tuple(new_size) + tuple(dist.shape)
new_size = to_tuple(new_size) + tuple(truncated_rv.shape)

return Truncated.rv_op(
untruncated_rv,
Expand All @@ -274,11 +315,9 @@ def change_truncated_size(op, dist, new_size, expand):


@_moment.register(TruncatedRV)
def truncated_moment(op, rv, *inputs):
*rv_inputs, lower, upper, rng = inputs

# recreate untruncated rv and respective moment
untruncated_rv = op.base_rv_op.make_node(rng, *rv_inputs).default_output()
def truncated_moment(op: TruncatedRV, truncated_rv, *inputs):
*rv_inputs, lower, upper = inputs
untruncated_rv = op.base_rv_op.make_node(*rv_inputs).default_output()
untruncated_moment = moment(untruncated_rv)

fallback_moment = pt.switch(
Expand All @@ -299,31 +338,31 @@ def truncated_moment(op, rv, *inputs):


@_default_transform.register(TruncatedRV)
def truncated_default_transform(op, rv):
def truncated_default_transform(op, truncated_rv):
# Don't transform discrete truncated distributions
if op.base_rv_op.dtype.startswith("int"):
if truncated_rv.type.dtype.startswith("int"):
return None
# Lower and Upper are the arguments -3 and -2
return bounded_cont_transform(op, rv, bound_args_indices=(-3, -2))
# Lower and Upper are the arguments -2 and -1
return bounded_cont_transform(op, truncated_rv, bound_args_indices=(-2, -1))


@_logprob.register(TruncatedRV)
def truncated_logprob(op, values, *inputs, **kwargs):
(value,) = values

*rv_inputs, lower, upper, rng = inputs
rv_inputs = [rng, *rv_inputs]
*rv_inputs, lower, upper = inputs

base_rv_op = op.base_rv_op
logp = _logprob(base_rv_op, (value,), *rv_inputs, **kwargs)
base_rv = base_rv_op.make_node(*rv_inputs).default_output()

base_logp = logp(base_rv, value)
# For left truncated RVs, we don't want to include the lower bound in the
# normalization term
lower_value = lower - 1 if base_rv_op.dtype.startswith("int") else lower
lower_logcdf = _logcdf(base_rv_op, lower_value, *rv_inputs, **kwargs)
upper_logcdf = _logcdf(base_rv_op, upper, *rv_inputs, **kwargs)
lower_value = lower - 1 if base_rv.type.dtype.startswith("int") else lower
lower_logcdf = logcdf(base_rv, lower_value)
upper_logcdf = logcdf(base_rv, upper)

if base_rv_op.name:
logp.name = f"{base_rv_op}_logprob"
base_logp.name = f"{base_rv_op}_logprob"
lower_logcdf.name = f"{base_rv_op}_lower_logcdf"
upper_logcdf.name = f"{base_rv_op}_upper_logcdf"

Expand All @@ -338,37 +377,36 @@ def truncated_logprob(op, values, *inputs, **kwargs):
elif is_upper_bounded:
lognorm = upper_logcdf

logp = logp - lognorm
truncated_logp = base_logp - lognorm

if is_lower_bounded:
logp = pt.switch(value < lower, -np.inf, logp)
truncated_logp = pt.switch(value < lower, -np.inf, truncated_logp)

if is_upper_bounded:
logp = pt.switch(value <= upper, logp, -np.inf)
truncated_logp = pt.switch(value <= upper, truncated_logp, -np.inf)

if is_lower_bounded and is_upper_bounded:
logp = check_parameters(
logp,
truncated_logp = check_parameters(
truncated_logp,
pt.le(lower, upper),
msg="lower_bound <= upper_bound",
)

return logp
return truncated_logp


@_logcdf.register(TruncatedRV)
def truncated_logcdf(op, value, *inputs, **kwargs):
*rv_inputs, lower, upper, rng = inputs
rv_inputs = [rng, *rv_inputs]
def truncated_logcdf(op: TruncatedRV, value, *inputs, **kwargs):
*rv_inputs, lower, upper = inputs

base_rv_op = op.base_rv_op
logcdf = _logcdf(base_rv_op, value, *rv_inputs, **kwargs)
base_rv = op.base_rv_op.make_node(*rv_inputs).default_output()
base_logcdf = logcdf(base_rv, value)

# For left truncated discrete RVs, we don't want to include the lower bound in the
# normalization term
lower_value = lower - 1 if base_rv_op.dtype.startswith("int") else lower
lower_logcdf = _logcdf(base_rv_op, lower_value, *rv_inputs, **kwargs)
upper_logcdf = _logcdf(base_rv_op, upper, *rv_inputs, **kwargs)
lower_value = lower - 1 if base_rv.dtype.startswith("int") else lower
lower_logcdf = logcdf(base_rv, lower_value)
upper_logcdf = logcdf(base_rv, upper)

is_lower_bounded = not (isinstance(lower, TensorConstant) and np.all(np.isneginf(lower.value)))
is_upper_bounded = not (isinstance(upper, TensorConstant) and np.all(np.isinf(upper.value)))
Expand All @@ -381,7 +419,7 @@ def truncated_logcdf(op, value, *inputs, **kwargs):
elif is_upper_bounded:
lognorm = upper_logcdf

logcdf_numerator = logdiffexp(logcdf, lower_logcdf) if is_lower_bounded else logcdf
logcdf_numerator = logdiffexp(base_logcdf, lower_logcdf) if is_lower_bounded else base_logcdf
logcdf_trunc = logcdf_numerator - lognorm

if is_lower_bounded:
Expand Down
Loading

0 comments on commit a47bb8a

Please sign in to comment.