Skip to content

Commit

Permalink
Merge functionality of pytensorf and logprob/utils
Browse files Browse the repository at this point in the history
Also fixes circular imports
  • Loading branch information
ricardoV94 committed Nov 14, 2023
1 parent 0044bf1 commit b4160a9
Show file tree
Hide file tree
Showing 15 changed files with 385 additions and 686 deletions.
5 changes: 2 additions & 3 deletions pymc/gp/cov.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@
"Kron",
]

from pymc.pytensorf import constant_fold

TensorLike = Union[np.ndarray, TensorVariable]
IntSequence = Union[np.ndarray, Sequence[int]]

Expand Down Expand Up @@ -183,9 +185,6 @@ def n_dims(self) -> int:
def _slice(self, X, Xs=None):
xdims = X.shape[-1]
if isinstance(xdims, Variable):
# Circular dependency
from pymc.pytensorf import constant_fold

[xdims] = constant_fold([xdims])
if self.input_dim != xdims:
warnings.warn(
Expand Down
5 changes: 3 additions & 2 deletions pymc/gp/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,14 @@
import pytensor.tensor as pt

from pytensor.compile import SharedVariable
from pytensor.graph import ancestors
from pytensor.tensor.variable import TensorConstant
from scipy.cluster.vq import kmeans

# Avoid circular dependency when importing modelcontext
from pymc.distributions.distribution import Distribution
from pymc.model import modelcontext
from pymc.pytensorf import compile_pymc, walk_model
from pymc.pytensorf import compile_pymc

_ = Distribution # keep both pylint and black happy

Expand All @@ -48,7 +49,7 @@ def replace_with_values(vars_needed, replacements=None, model=None):
model = modelcontext(model)

inputs, input_names = [], []
for rv in walk_model(vars_needed):
for rv in ancestors(vars_needed):
if rv in model.named_vars.values() and not isinstance(rv, SharedVariable):
inputs.append(rv)
input_names.append(rv.name)
Expand Down
16 changes: 7 additions & 9 deletions pymc/logprob/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@
from pymc.logprob.rewriting import cleanup_ir, construct_ir_fgraph
from pymc.logprob.transform_value import TransformValuesRewrite
from pymc.logprob.transforms import Transform
from pymc.logprob.utils import find_rvs_in_graph, rvs_to_value_vars
from pymc.logprob.utils import rvs_in_graph
from pymc.pytensorf import replace_vars_in_graphs

TensorLike: TypeAlias = Union[Variable, float, np.ndarray]

Expand All @@ -76,7 +77,7 @@ def _find_unallowed_rvs_in_graph(graph):

return {
rv
for rv in find_rvs_in_graph(graph)
for rv in rvs_in_graph(graph)
if not isinstance(rv.owner.op, (SimulatorRV, MinibatchIndexRV))
}

Expand Down Expand Up @@ -530,11 +531,9 @@ def conditional_logp(
continue

# Replace `RandomVariable`s in the inputs with value variables.
# Also, store the results in the `replacements` map for the nodes
# that follow.
remapped_vars, _ = rvs_to_value_vars(
q_values + list(node.inputs),
initial_replacements=replacements,
remapped_vars = replace_vars_in_graphs(
graphs=q_values + list(node.inputs),
replacements=replacements,
)
q_values = remapped_vars[: len(q_values)]
q_rv_inputs = remapped_vars[len(q_values) :]
Expand Down Expand Up @@ -562,8 +561,7 @@ def conditional_logp(

logprob_vars[q_value_var] = q_logprob_var

# Recompute test values for the changes introduced by the
# replacements above.
# Recompute test values for the changes introduced by the replacements above.
if config.compute_test_value != "off":
for node in io_toposort(graph_inputs(q_logprob_vars), q_logprob_vars):
compute_test_value(node)
Expand Down
3 changes: 1 addition & 2 deletions pymc/logprob/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@

from pymc.logprob.abstract import MeasurableVariable, _logprob, _logprob_helper
from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db
from pymc.logprob.utils import replace_rvs_by_values


class MeasurableSpecifyShape(SpecifyShape):
Expand Down Expand Up @@ -107,8 +108,6 @@ class MeasurableCheckAndRaise(CheckAndRaise):

@_logprob.register(MeasurableCheckAndRaise)
def logprob_check_and_raise(op, values, inner_rv, *assertions, **kwargs):
from pymc.pytensorf import replace_rvs_by_values

(value,) = values
# transfer assertion from rv to value
assertions = replace_rvs_by_values(assertions, rvs_to_values={inner_rv: value})
Expand Down
10 changes: 2 additions & 8 deletions pymc/logprob/mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@
measurable_ir_rewrites_db,
subtensor_ops,
)
from pymc.logprob.utils import check_potential_measurability
from pymc.logprob.utils import check_potential_measurability, replace_rvs_by_values
from pymc.pytensorf import constant_fold


def is_newaxis(x):
Expand Down Expand Up @@ -255,9 +256,6 @@ def get_stack_mixture_vars(
mixture_rvs = joined_rvs.owner.inputs

elif isinstance(joined_rvs.owner.op, Join):
# TODO: Find better solution to avoid this circular dependency
from pymc.pytensorf import constant_fold

join_axis = joined_rvs.owner.inputs[0]
# TODO: Support symbolic join axes. This will raise ValueError if it's not a constant
(join_axis,) = constant_fold((join_axis,), raise_not_constant=False)
Expand Down Expand Up @@ -351,9 +349,6 @@ def logprob_MixtureRV(
comp_rvs = [comp[None] for comp in comp_rvs]
original_shape = (len(comp_rvs),)
else:
# TODO: Find better solution to avoid this circular dependency
from pymc.pytensorf import constant_fold

join_axis_val = constant_fold((join_axis,))[0].item()
original_shape = shape_tuple(comp_rvs[0])

Expand Down Expand Up @@ -544,7 +539,6 @@ def find_measurable_ifelse_mixture(fgraph, node):
@_logprob.register(MeasurableIfElse)
def logprob_ifelse(op, values, if_var, *base_rvs, **kwargs):
"""Compute the log-likelihood graph for an `IfElse`."""
from pymc.pytensorf import replace_rvs_by_values

assert len(values) * 2 == len(base_rvs)

Expand Down
2 changes: 1 addition & 1 deletion pymc/logprob/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
logprob_rewrites_db,
measurable_ir_rewrites_db,
)
from pymc.pytensorf import replace_rvs_by_values
from pymc.logprob.utils import replace_rvs_by_values


class MeasurableScan(Scan):
Expand Down
7 changes: 2 additions & 5 deletions pymc/logprob/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@
assume_measured_ir_outputs,
measurable_ir_rewrites_db,
)
from pymc.logprob.utils import check_potential_measurability
from pymc.logprob.utils import check_potential_measurability, replace_rvs_by_values
from pymc.pytensorf import constant_fold


@node_rewriter([Alloc])
Expand Down Expand Up @@ -131,7 +132,6 @@ class MeasurableMakeVector(MakeVector):
def logprob_make_vector(op, values, *base_rvs, **kwargs):
"""Compute the log-likelihood graph for a `MeasurableMakeVector`."""
# TODO: Sort out this circular dependency issue
from pymc.pytensorf import replace_rvs_by_values

(value,) = values

Expand All @@ -158,9 +158,6 @@ class MeasurableJoin(Join):
@_logprob.register(MeasurableJoin)
def logprob_join(op, values, axis, *base_rvs, **kwargs):
"""Compute the log-likelihood graph for a `Join`."""
# TODO: Find better way to avoid circular dependency
from pymc.pytensorf import constant_fold, replace_rvs_by_values

(value,) = values

base_rv_shapes = [base_var.shape[axis] for base_var in base_rvs]
Expand Down
Loading

0 comments on commit b4160a9

Please sign in to comment.