Cleanup logprob module #7443

Merged: 4 commits, Aug 5, 2024
8 changes: 2 additions & 6 deletions pymc/distributions/distribution.py
@@ -50,7 +50,7 @@
rv_size_is_none,
shape_from_dims,
)
from pymc.logprob.abstract import MeasurableVariable, _icdf, _logcdf, _logprob
from pymc.logprob.abstract import MeasurableOp, _icdf, _logcdf, _logprob
from pymc.logprob.basic import logp
from pymc.logprob.rewriting import logprob_rewrites_db
from pymc.printing import str_for_dist
@@ -228,7 +228,7 @@ def __get__(self, instance, type_):
return descr_get(instance, type_)


class SymbolicRandomVariable(OpFromGraph):
class SymbolicRandomVariable(MeasurableOp, OpFromGraph):
"""Symbolic Random Variable

This is a subclass of `OpFromGraph`, which is used to encapsulate the symbolic
@@ -624,10 +624,6 @@ def dist(
return rv_out


# Let PyMC know that the SymbolicRandomVariable has a logprob.
MeasurableVariable.register(SymbolicRandomVariable)


@node_rewriter([SymbolicRandomVariable])
def inline_symbolic_random_variable(fgraph, node):
"""
30 changes: 23 additions & 7 deletions pymc/logprob/abstract.py
@@ -35,6 +35,7 @@
# SOFTWARE.

import abc
import warnings

from collections.abc import Sequence
from functools import singledispatch
@@ -46,6 +47,17 @@
from pytensor.tensor.random.op import RandomVariable


def __getattr__(name):
    if name == "MeasurableVariable":
        warnings.warn(
            f"{name} has been deprecated in favor of MeasurableOp. Importing will fail in a future release.",
            FutureWarning,
        )
        return MeasurableOpMixin

    raise AttributeError(f"module {__name__} has no attribute {name}")

(Codecov flagged the new warning and return on lines 52 and 56 of pymc/logprob/abstract.py as not covered by tests.)
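For context, this uses the module-level `__getattr__` hook from PEP 562. A minimal standalone sketch of the same deprecation pattern, with hypothetical module and class names:

```python
# hypothetical_module.py -- sketch of a PEP 562 deprecation shim
import warnings


class NewThing:
    """The replacement class."""


def __getattr__(name):
    # Called only when normal module attribute lookup fails,
    # i.e. for names that no longer exist, such as deprecated aliases.
    if name == "OldThing":
        warnings.warn(
            f"{name} has been deprecated in favor of NewThing.",
            FutureWarning,
        )
        return NewThing
    raise AttributeError(f"module {__name__} has no attribute {name}")
```

With this in place, `from hypothetical_module import OldThing` still works but emits a `FutureWarning`, while genuinely unknown names keep raising `AttributeError`.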


@singledispatch
def _logprob(
op: Op,
@@ -131,14 +143,21 @@
return rv_icdf


class MeasurableVariable(abc.ABC):
    """A variable that can be assigned a measure/log-probability"""
class MeasurableOp(abc.ABC):
    """An operation whose outputs can be assigned a measure/log-probability"""


MeasurableVariable.register(RandomVariable)
MeasurableOp.register(RandomVariable)
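For context, a minimal sketch (not part of this diff, stand-in class names) of what `abc.ABC.register` does here: `RandomVariable` passes `isinstance` checks against `MeasurableOp` without its inheritance chain changing:

```python
import abc


class MeasurableOp(abc.ABC):
    """Stand-in for the ABC defined above."""


class SomeRandomVariable:
    """Stand-in for pytensor's RandomVariable; its MRO is untouched."""


# Virtual subclassing: isinstance/issubclass succeed without inheritance.
MeasurableOp.register(SomeRandomVariable)

assert isinstance(SomeRandomVariable(), MeasurableOp)
assert issubclass(SomeRandomVariable, MeasurableOp)
```

`SymbolicRandomVariable` (in distribution.py above) takes the other route and inherits from `MeasurableOp` directly; both routes satisfy the same `isinstance` checks, but only real inheritance also picks up behavior such as the mixin `__str__` below.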

class MeasurableOpMixin(MeasurableOp):
    """MeasurableOp Mixin with a distinctive string representation"""

    def __str__(self):
        return f"Measurable{super().__str__()}"
Member Author: This allows seeing when a Measurable replacement was actually performed in the IR graph. Most of the time we use a subclass of the pre-existing Op, which would not otherwise be distinguishable in the string representation.
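A small sketch (stand-in classes, not the real Ops) of what the mixin changes in printed IR graphs:

```python
class CumOp:
    def __str__(self):
        return "CumOp{axis=0}"


class MeasurableOpMixin:
    def __str__(self):
        # Cooperative super() call: delegates to the wrapped Op's __str__.
        return f"Measurable{super().__str__()}"


class MeasurableCumsum(MeasurableOpMixin, CumOp):
    pass


print(CumOp())             # CumOp{axis=0}
print(MeasurableCumsum())  # MeasurableCumOp{axis=0}
```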



class MeasurableElemwise(Elemwise):
class MeasurableElemwise(MeasurableOpMixin, Elemwise):
"""Base class for Measurable Elemwise variables"""

valid_scalar_types: tuple[MetaType, ...] = ()
@@ -150,6 +169,3 @@
f"Acceptable types are {self.valid_scalar_types}"
)
super().__init__(scalar_op, *args, **kwargs)


MeasurableVariable.register(MeasurableElemwise)
4 changes: 2 additions & 2 deletions pymc/logprob/basic.py
@@ -56,7 +56,7 @@
from pytensor.tensor.variable import TensorVariable

from pymc.logprob.abstract import (
MeasurableVariable,
MeasurableOp,
_icdf_helper,
_logcdf_helper,
_logprob,
@@ -522,7 +522,7 @@ def conditional_logp(
while q:
node = q.popleft()

if not isinstance(node.op, MeasurableVariable):
if not isinstance(node.op, MeasurableOp):
continue

q_values = [replacements[q_rv] for q_rv in node.outputs if q_rv in updated_rv_values]
14 changes: 4 additions & 10 deletions pymc/logprob/checks.py
@@ -42,18 +42,15 @@
from pytensor.tensor import TensorVariable
from pytensor.tensor.shape import SpecifyShape

from pymc.logprob.abstract import MeasurableVariable, _logprob, _logprob_helper
from pymc.logprob.abstract import MeasurableOp, MeasurableOpMixin, _logprob, _logprob_helper
from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db
from pymc.logprob.utils import replace_rvs_by_values


class MeasurableSpecifyShape(SpecifyShape):
class MeasurableSpecifyShape(MeasurableOpMixin, SpecifyShape):
"""A placeholder used to specify a log-likelihood for a specify-shape sub-graph."""


MeasurableVariable.register(MeasurableSpecifyShape)


@_logprob.register(MeasurableSpecifyShape)
def logprob_specify_shape(op, values, inner_rv, *shapes, **kwargs):
(value,) = values
@@ -80,7 +77,7 @@ def find_measurable_specify_shapes(fgraph, node) -> list[TensorVariable] | None:

if not (
base_rv.owner
and isinstance(base_rv.owner.op, MeasurableVariable)
and isinstance(base_rv.owner.op, MeasurableOp)
and base_rv not in rv_map_feature.rv_values
):
return None # pragma: no cover
@@ -99,13 +96,10 @@ def find_measurable_specify_shapes(fgraph, node) -> list[TensorVariable] | None:
)
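For illustration, a sketch of the user-facing effect, assuming current `pm.logp` semantics: the rewrite lets the logprob machinery see through `specify_shape`:

```python
import pymc as pm
import pytensor.tensor as pt

rv = pm.Normal.dist(0, 1, shape=3)
shaped_rv = pt.specify_shape(rv, (3,))

# Dispatched through MeasurableSpecifyShape to the underlying normal:
logp = pm.logp(shaped_rv, [0.0, 0.0, 0.0])
print(logp.eval())
```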


class MeasurableCheckAndRaise(CheckAndRaise):
class MeasurableCheckAndRaise(MeasurableOpMixin, CheckAndRaise):
"""A placeholder used to specify a log-likelihood for an assert sub-graph."""


MeasurableVariable.register(MeasurableCheckAndRaise)


@_logprob.register(MeasurableCheckAndRaise)
def logprob_check_and_raise(op, values, inner_rv, *assertions, **kwargs):
(value,) = values
7 changes: 2 additions & 5 deletions pymc/logprob/cumsum.py
@@ -41,17 +41,14 @@
from pytensor.tensor import TensorVariable
from pytensor.tensor.extra_ops import CumOp

from pymc.logprob.abstract import MeasurableVariable, _logprob, _logprob_helper
from pymc.logprob.abstract import MeasurableOpMixin, _logprob, _logprob_helper
from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db


class MeasurableCumsum(CumOp):
class MeasurableCumsum(MeasurableOpMixin, CumOp):
"""A placeholder used to specify a log-likelihood for a cumsum sub-graph."""


MeasurableVariable.register(MeasurableCumsum)


@_logprob.register(MeasurableCumsum)
def logprob_cumsum(op, values, base_rv, **kwargs):
"""Compute the log-likelihood graph for a `Cumsum`."""
15 changes: 5 additions & 10 deletions pymc/logprob/mixture.py
@@ -67,7 +67,8 @@

from pymc.logprob.abstract import (
MeasurableElemwise,
MeasurableVariable,
MeasurableOp,
MeasurableOpMixin,
_logprob,
_logprob_helper,
)
@@ -217,7 +218,7 @@ def rv_pull_down(x: TensorVariable) -> TensorVariable:
return fgraph.outputs[0]


class MixtureRV(Op):
class MixtureRV(MeasurableOpMixin, Op):
"""A placeholder used to specify a log-likelihood for a mixture sub-graph."""

__props__ = ("indices_end_idx", "out_dtype", "out_broadcastable")
@@ -235,9 +236,6 @@ def perform(self, node, inputs, outputs):
raise NotImplementedError("This is a stand-in Op.") # pragma: no cover


MeasurableVariable.register(MixtureRV)


def get_stack_mixture_vars(
node: Apply,
) -> tuple[list[TensorVariable] | None, int | None]:
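For context, a sketch of the stack-then-index mixture graph this module targets, using `conditional_logp` (assuming the usual value-variable convention from the pymc test suite):

```python
import pymc as pm
import pytensor.tensor as pt
from pymc.logprob.basic import conditional_logp

# A two-component mixture built by indexing a stack of RVs:
idx_rv = pm.Bernoulli.dist(0.5)
mix_rv = pt.stack([pm.Normal.dist(-5, 1), pm.Normal.dist(5, 1)])[idx_rv]

idx_vv = idx_rv.clone()
mix_vv = mix_rv.clone()

# Maps each value variable to its conditional logp term; the indexing
# node is rewritten into a MixtureRV along the way.
logp_terms = conditional_logp({idx_rv: idx_vv, mix_rv: mix_vv})
```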
@@ -457,13 +455,10 @@ def logprob_switch_mixture(op, values, switch_cond, component_true, component_false
)


class MeasurableIfElse(IfElse):
class MeasurableIfElse(MeasurableOpMixin, IfElse):
"""Measurable subclass of IfElse operator."""


MeasurableVariable.register(MeasurableIfElse)


@node_rewriter([IfElse])
def useless_ifelse_outputs(fgraph, node):
"""Remove outputs that are shared across the IfElse branches."""
@@ -512,7 +507,7 @@ def find_measurable_ifelse_mixture(fgraph, node):
base_rvs = assume_measured_ir_outputs(valued_rvs, base_rvs)
if len(base_rvs) != op.n_outs * 2:
return None
if not all(var.owner and isinstance(var.owner.op, MeasurableVariable) for var in base_rvs):
if not all(var.owner and isinstance(var.owner.op, MeasurableOp) for var in base_rvs):
return None

return MeasurableIfElse(n_outs=op.n_outs).make_node(if_var, *base_rvs).outputs
22 changes: 5 additions & 17 deletions pymc/logprob/order.py
@@ -48,7 +48,7 @@
from pytensor.tensor.variable import TensorVariable

from pymc.logprob.abstract import (
MeasurableVariable,
MeasurableOpMixin,
_logcdf_helper,
_logprob,
_logprob_helper,
@@ -59,20 +59,14 @@
from pymc.pytensorf import constant_fold


class MeasurableMax(Max):
class MeasurableMax(MeasurableOpMixin, Max):
"""A placeholder used to specify a log-likelihood for a max sub-graph."""


MeasurableVariable.register(MeasurableMax)


class MeasurableMaxDiscrete(Max):
class MeasurableMaxDiscrete(MeasurableOpMixin, Max):
"""A placeholder used to specify a log-likelihood for sub-graphs of maxima of discrete variables"""


MeasurableVariable.register(MeasurableMaxDiscrete)


@node_rewriter([Max])
def find_measurable_max(fgraph: FunctionGraph, node: Apply) -> list[TensorVariable] | None:
rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
@@ -162,21 +156,15 @@ def max_logprob_discrete(op, values, base_rv, **kwargs):
return logprob


class MeasurableMaxNeg(Max):
class MeasurableMaxNeg(MeasurableOpMixin, Max):
"""A placeholder used to specify a log-likelihood for a max(neg(x)) sub-graph.
This shows up in the graph of min, which is (neg(max(neg(x)))."""


MeasurableVariable.register(MeasurableMaxNeg)


class MeasurableDiscreteMaxNeg(Max):
class MeasurableDiscreteMaxNeg(MeasurableOpMixin, Max):
"""A placeholder used to specify a log-likelihood for sub-graphs of negative maxima of discrete variables"""


MeasurableVariable.register(MeasurableDiscreteMaxNeg)
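For illustration, a sketch of the identity these classes encode, min(x) = -max(-x), assuming `pt.min` of an iid vector is handled by these rewrites:

```python
import pymc as pm
import pytensor.tensor as pt

rv = pm.Normal.dist(0, 1, shape=5)

# pt.min is expressed as neg(max(neg(x))), which find_measurable_max_neg
# then makes measurable:
min_logp = pm.logp(pt.min(rv), 0.0)
print(min_logp.eval())
```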


@node_rewriter(tracks=[Max])
def find_measurable_max_neg(fgraph: FunctionGraph, node: Apply) -> list[TensorVariable] | None:
rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
57 changes: 5 additions & 52 deletions pymc/logprob/rewriting.py
@@ -38,8 +38,6 @@
from collections import deque
from collections.abc import Collection, Sequence

import pytensor.tensor as pt

from pytensor import config
from pytensor.compile.mode import optdb
from pytensor.graph.basic import (
@@ -82,8 +80,8 @@
)
from pytensor.tensor.variable import TensorVariable

from pymc.logprob.abstract import MeasurableVariable
from pymc.logprob.utils import DiracDelta, indices_from_subtensor
from pymc.logprob.abstract import MeasurableOp
from pymc.logprob.utils import DiracDelta

inc_subtensor_ops = (IncSubtensor, AdvancedIncSubtensor, AdvancedIncSubtensor1)
subtensor_ops = (AdvancedSubtensor, AdvancedSubtensor1, Subtensor)
@@ -140,7 +138,7 @@ def apply(self, fgraph):
continue
# This is where we filter only those nodes we care about:
# Nodes that have variables that we want to measure and are not yet measurable
if isinstance(node.op, MeasurableVariable):
if isinstance(node.op, MeasurableOp):
continue
if not any(out in rv_map_feature.needs_measuring for out in node.outputs):
continue
@@ -156,7 +154,7 @@
node_rewriter, "__name__", ""
)
# If we converted to a MeasurableOp we're done here!
if node not in fgraph.apply_nodes or isinstance(node.op, MeasurableVariable):
if node not in fgraph.apply_nodes or isinstance(node.op, MeasurableOp):
# go to next node
break

@@ -275,7 +273,7 @@ def request_measurable(self, vars: Sequence[Variable]) -> list[Variable]:
# Input vars or valued vars can't be measured for derived expressions
if not var.owner or var in self.rv_values:
continue
if isinstance(var.owner.op, MeasurableVariable):
if isinstance(var.owner.op, MeasurableOp):
measurable.append(var)
else:
self.needs_measuring.add(var)
@@ -313,50 +311,6 @@ def remove_DiracDelta(fgraph, node):
return [dd_val]


@node_rewriter(inc_subtensor_ops)
def incsubtensor_rv_replace(fgraph, node):
r"""Replace `*IncSubtensor*` `Op`\s and their value variables for log-probability calculations.

This is used to derive the log-probability graph for ``Y[idx] = data``, where
``Y`` is a `RandomVariable`, ``idx`` indices, and ``data`` some arbitrary data.

To compute the log-probability of a statement like ``Y[idx] = data``, we must
first realize that our objective is equivalent to computing ``logprob(Y, z)``,
where ``z = pt.set_subtensor(y[idx], data)`` and ``y`` is the value variable
for ``Y``.

In other words, the log-probability for an `*IncSubtensor*` is the log-probability
of the underlying `RandomVariable` evaluated at ``data`` for the indices
given by ``idx`` and at the value variable for ``~idx``.

This provides a means of specifying "missing data", for instance.
"""
rv_map_feature: PreserveRVMappings | None = getattr(fgraph, "preserve_rv_mappings", None)

if rv_map_feature is None:
return None # pragma: no cover

rv_var = node.outputs[0]
if rv_var not in rv_map_feature.rv_values:
return None # pragma: no cover

base_rv_var = node.inputs[0]

if not rv_map_feature.request_measurable([base_rv_var]):
return None

data = node.inputs[1]
idx = indices_from_subtensor(getattr(node.op, "idx_list", None), node.inputs[2:])

# Create a new value variable with the indices `idx` set to `data`
value_var = rv_map_feature.rv_values[rv_var]
new_value_var = pt.set_subtensor(value_var[idx], data)
rv_map_feature.update_rv_maps(rv_var, new_value_var, base_rv_var)

# Return the `RandomVariable` being indexed
return [base_rv_var]
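For reference, a small sketch (hypothetical values) of the `pt.set_subtensor` mechanics the removed rewrite relied on to build the combined value variable:

```python
import pytensor.tensor as pt

y = pt.vector("y", shape=(3,))

# Fix entries 0 and 2 to observed data; entry 1 stays a free value:
z = pt.set_subtensor(y[[0, 2]], [1.0, 2.0])
print(z.eval({y: [0.0, 5.0, 0.0]}))  # [1. 5. 2.]
```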


logprob_rewrites_db = SequenceDB()
logprob_rewrites_db.name = "logprob_rewrites_db"
# Introduce sigmoid. We do it before canonicalization so that useless `mul` ops are removed next
@@ -377,7 +331,6 @@ def incsubtensor_rv_replace(fgraph, node):
# (or eventually) the graph outputs. Often this is done by lifting other `Op`s
# "up" through the random/measurable variables and into their inputs.
measurable_ir_rewrites_db.register("subtensor_lift", local_subtensor_rv_lift, "basic")
measurable_ir_rewrites_db.register("incsubtensor_lift", incsubtensor_rv_replace, "basic")

logprob_rewrites_db.register("post-canonicalize", optdb.query("+canonicalize"), "basic")
