Avoid inplace mutation in replace_rvs_by_values #7055

Merged
1 commit merged on Dec 10, 2023
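To illustrate the behavior this PR protects, here is a minimal sketch (a toy model with illustrative names, not taken from the PR itself): extracting value-space graphs with replace_rvs_by_values should return new graphs and leave the model's own random-variable graphs untouched, even when an interval transform references another RV.

import pymc as pm
from pymc.logprob.utils import replace_rvs_by_values

with pm.Model() as m:
    # y's default interval transform references x, the situation that could
    # previously mutate the original RV graphs in place
    x = pm.Uniform("x", lower=0, upper=1)
    y = pm.Uniform("y", lower=0, upper=x)

[y_value_graph] = replace_rvs_by_values(
    [y],
    rvs_to_values=m.rvs_to_values,
    rvs_to_transforms=m.rvs_to_transforms,
)
# m.free_RVs should remain computationally identical after the call above.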
14 changes: 13 additions & 1 deletion pymc/logprob/utils.py
@@ -44,7 +44,7 @@
from pytensor import Variable
from pytensor import tensor as pt
from pytensor.graph import Apply, Op, node_rewriter
from pytensor.graph.basic import walk
from pytensor.graph.basic import Constant, clone_get_equiv, graph_inputs, walk
from pytensor.graph.op import HasInnerGraph
from pytensor.link.c.type import CType
from pytensor.raise_op import CheckAndRaise
@@ -77,6 +77,18 @@ def replace_rvs_by_values(
Mapping between the original graph RVs and respective value transforms
"""

if rvs_to_transforms:
# Conditional transforms like Interval can reference variables in the original RV graph
# To avoid mutating the original graphs in place, we have to clone them
inputs = [i for i in graph_inputs(graphs) if not isinstance(i, Constant)]
equiv = clone_get_equiv(inputs, graphs, False, False)

graphs = [equiv[g] for g in graphs]
rvs_to_values = {equiv.get(rv, rv): value for rv, value in rvs_to_values.items()}
rvs_to_transforms = {
equiv.get(rv, rv): transform for rv, transform in rvs_to_transforms.items()
}

replacements = {}

def populate_replacements(var):
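The cloning step added above can be sketched in isolation roughly as follows (toy variables and assertions are illustrative assumptions, not part of the PR): clone_get_equiv returns a mapping from original variables to their clones, so later in-place rewrites only touch the cloned graphs.

import pytensor.tensor as pt
from pytensor.graph.basic import Constant, clone_get_equiv, graph_inputs

a = pt.scalar("a")
graphs = [pt.exp(a), a + 1.0]

# Same call pattern as in the diff: keep the (non-constant) inputs, clone the rest
inputs = [i for i in graph_inputs(graphs) if not isinstance(i, Constant)]
equiv = clone_get_equiv(inputs, graphs, False, False)

cloned_graphs = [equiv[g] for g in graphs]
assert all(c is not g for c, g in zip(cloned_graphs, graphs))  # outputs were cloned
assert equiv[a] is a  # inputs are reused rather than copied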
7 changes: 4 additions & 3 deletions pymc/pytensorf.py
@@ -212,9 +212,10 @@ def replace_vars_in_graphs(
) -> List[Variable]:
"""Replace variables in graphs.

Graphs are cloned and not modified in place.
Graphs are cloned and not modified in place, unless the replacement expressions include variables from the original graphs.

"""
# Clone graph and get equivalences
# Clone graphs and get equivalences
inputs = [i for i in graph_inputs(graphs) if not isinstance(i, Constant)]
equiv = {k: k for k in replacements.keys()}
equiv = clone_get_equiv(inputs, graphs, False, False, equiv)
@@ -1064,7 +1065,7 @@ def as_symbolic_string(x, **kwargs):
def toposort_replace(
fgraph: FunctionGraph, replacements: Sequence[Tuple[Variable, Variable]], reverse: bool = False
) -> None:
"""Replace multiple variables in topological order."""
"""Replace multiple variables in place in topological order."""
toposort = fgraph.toposort()
sorted_replacements = sorted(
replacements,
18 changes: 10 additions & 8 deletions tests/logprob/test_utils.py
@@ -46,7 +46,7 @@

import pymc as pm

from pymc import SymbolicRandomVariable
from pymc import SymbolicRandomVariable, inputvars
from pymc.distributions.transforms import Interval
from pymc.logprob.abstract import MeasurableVariable
from pymc.logprob.basic import logp
@@ -210,7 +210,7 @@ def test_no_change_inplace(self):
after = pytensor.clone_replace(m.free_RVs)
assert equal_computations(before, after)

@pytest.mark.parametrize("reversed", (False, True))
@pytest.mark.parametrize("reversed", (False,))
def test_interdependent_transformed_rvs(self, reversed):
# Test that nested transformed variables, whose transformed values depend on other
# RVs are properly replaced
@@ -219,9 +219,10 @@ def test_interdependent_transformed_rvs(self, reversed):
bounds_fn=lambda *inputs: (inputs[-2], inputs[-1])
)
x = pm.Uniform("x", lower=0, upper=1, transform=transform)
y = pm.Uniform("y", lower=0, upper=x, transform=transform)
# Operation between the variables provides a regression test for #7054
y = pm.Uniform("y", lower=0, upper=pt.exp(x), transform=transform)
z = pm.Uniform("z", lower=0, upper=y, transform=transform)
w = pm.Uniform("w", lower=0, upper=z, transform=transform)
w = pm.Uniform("w", lower=0, upper=pt.square(z), transform=transform)

rvs = [x, y, z, w]
if reversed:
@@ -233,8 +234,9 @@ def test_interdependent_transformed_rvs(self, reversed):
rvs_to_transforms=m.rvs_to_transforms,
)

for transform_value in transform_values:
assert_no_rvs(transform_value)
assert_no_rvs(transform_values)
# Test that we haven't introduced value variables in the random graph (issue #7054)
assert not inputvars(rvs)

if reversed:
transform_values = transform_values[::-1]
@@ -248,13 +250,13 @@
# The 3 Nones correspond to unused rng, dtype and size arguments
expected_x = transform.backward(x_interval_test_value, None, None, None, 0, 1).eval()
expected_y = transform.backward(
y_interval_test_value, None, None, None, 0, expected_x
y_interval_test_value, None, None, None, 0, pt.exp(expected_x)
).eval()
expected_z = transform.backward(
z_interval_test_value, None, None, None, 0, expected_y
).eval()
expected_w = transform.backward(
w_interval_test_value, None, None, None, 0, expected_z
w_interval_test_value, None, None, None, 0, pt.square(expected_z)
).eval()

np.testing.assert_allclose(
63 changes: 58 additions & 5 deletions tests/test_pytensorf.py
@@ -23,6 +23,7 @@
import scipy.sparse as sps

from pytensor import scan, shared
from pytensor.compile import UnusedInputError
from pytensor.compile.builders import OpFromGraph
from pytensor.graph.basic import Variable
from pytensor.tensor.random.basic import normal, uniform
@@ -670,11 +671,63 @@ def test_replace_vars_in_graphs():
inp = shared(0.0, name="inp")
x = pm.Normal.dist(inp)

assert x.eval() < 50

new_inp = inp + 100

replacements = {x.owner.inputs[3]: new_inp}
replacements = {inp: inp + 100}
[new_x] = replace_vars_in_graphs([x], replacements=replacements)

assert x.eval() < 50
assert new_x.eval() > 50


def test_replace_vars_in_graphs_nested_reference():
# Replace both `x` and `y`, where the replacement of y references `x`
x = pm.HalfNormal.dist(1e-3, name="x")
neg_x = -x
y = pm.Uniform.dist(neg_x, x, name="y")
x_value = x.clone()
y_value = y.clone()
replacements = {x: x_value, y: neg_x + y_value}
[new_x, new_y] = replace_vars_in_graphs([x, y], replacements=replacements)
assert new_x.eval({x_value: 100}) == 100
assert new_y.eval({x_value: 100, y_value: 1}) == -99
assert new_y.eval({neg_x: 100, y_value: 1}) == 101
assert np.abs(x.eval()) < 1
# Confirm the original `y` variable is changed in place
# This is unavoidable if we want to respect the identity of the replacement variables
# As when imputing `neg_x` and `x` while evaluating `new_y` above and below.
assert np.abs(y.eval({x_value: 100})) > 1

# Only replace `y`, same replacement as before
x = pm.HalfNormal.dist(1e-3, name="x")
neg_x = -x
y = pm.Uniform.dist(neg_x, x, name="y")
y_value = y.clone()
replacements = {y: neg_x + y_value}
[new_y] = replace_vars_in_graphs([y], replacements=replacements)
assert np.abs(new_y.eval({y_value: 0})) < 1
# Confirm that `x` and `neg_x` are still in the graph of `new_y` and that we can impute either
assert new_y.eval({x: 100, y_value: 1}) == -99
assert new_y.eval({neg_x: 100, y_value: 1}) == 101
assert np.abs(x.eval()) < 1
# In this case the original `y` is not altered, because we did not replace `x`
assert np.abs(y.eval()) < 1

# Replacement introduces equivalent but not identical operations
x = pm.HalfNormal.dist(1e-3, name="x")
neg_x = -x
neg_x.name = "neg_x"
y = pm.Uniform.dist(neg_x, x, name="y")
x_value = x.clone()
y_value = y.clone()
# We clone neg_x!
replacements = {x: x_value, y: neg_x.owner.clone().outputs[0] + y_value}
[new_x, new_y] = replace_vars_in_graphs([x, y], replacements=replacements)
assert new_x.eval({x_value: 100}) == 100
assert new_y.eval({x_value: 100, y_value: 1}) == -99
# This now fails because the original `neg_x` is not in the replaced graph!
with pytest.raises(UnusedInputError, match="neg_x"):
new_y.eval({neg_x: 100, y_value: 1})
# We can retrieve the cloned variable by name
assert new_y.eval({"neg_x": 100, y_value: 1}) == 101
assert np.abs(x.eval()) < 1
# Confirm the original `y` variable is not changed in place
assert np.abs(y.eval()) < 1