Skip to content

Commit

Permalink
Remove deprecated function rvs_to_value_vars
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Nov 14, 2023
1 parent c5115ee commit 0044bf1
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 105 deletions.
60 changes: 0 additions & 60 deletions pymc/pytensorf.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,66 +271,6 @@ def expand_replace(var):
return graphs, replacements


def rvs_to_value_vars(
graphs: Iterable[Variable],
apply_transforms: bool = True,
**kwargs,
) -> List[Variable]:
"""Clone and replace random variables in graphs with their value variables.
This will *not* recompute test values in the resulting graphs.
Parameters
----------
graphs
The graphs in which to perform the replacements.
apply_transforms
If ``True``, apply each value variable's transform.
"""
warnings.warn(
"rvs_to_value_vars is deprecated. Use model.replace_rvs_by_values instead",
FutureWarning,
)

def populate_replacements(
random_var: TensorVariable, replacements: Dict[TensorVariable, TensorVariable]
) -> List[TensorVariable]:
# Populate replacements dict with {rv: value} pairs indicating which graph
# RVs should be replaced by what value variables.

value_var = getattr(
random_var.tag, "observations", getattr(random_var.tag, "value_var", None)
)

# No value variable to replace RV with
if value_var is None:
return []

transform = getattr(value_var.tag, "transform", None)
if transform is not None and apply_transforms:
# We want to replace uses of the RV by the back-transformation of its value
value_var = transform.backward(value_var, *random_var.owner.inputs)

replacements[random_var] = value_var

# Also walk the graph of the value variable to make any additional replacements
# if that is not a simple input variable
return [value_var]

# Clone original graphs
inputs = [i for i in graph_inputs(graphs) if not isinstance(i, Constant)]
equiv = clone_get_equiv(inputs, graphs, False, False, {})
graphs = [equiv[n] for n in graphs]

graphs, _ = _replace_vars_in_graphs(
graphs,
replacement_fn=populate_replacements,
**kwargs,
)

return graphs


def replace_rvs_by_values(
graphs: Sequence[TensorVariable],
*,
Expand Down
69 changes: 24 additions & 45 deletions tests/test_pytensorf.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@
replace_rng_nodes,
replace_rvs_by_values,
reseed_rngs,
rvs_to_value_vars,
walk_model,
)
from pymc.testing import assert_no_rvs
Expand Down Expand Up @@ -671,8 +670,7 @@ def test_constant_fold_raises():
class TestReplaceRVsByValues:
@pytest.mark.parametrize("symbolic_rv", (False, True))
@pytest.mark.parametrize("apply_transforms", (True, False))
@pytest.mark.parametrize("test_deprecated_fn", (True, False))
def test_basic(self, symbolic_rv, apply_transforms, test_deprecated_fn):
def test_basic(self, symbolic_rv, apply_transforms):
# Interval transform between last two arguments
interval = (
Interval(bounds_fn=lambda *args: (args[-2], args[-1])) if apply_transforms else None
Expand All @@ -696,15 +694,11 @@ def test_basic(self, symbolic_rv, apply_transforms, test_deprecated_fn):
b_value_var = m.rvs_to_values[b]
c_value_var = m.rvs_to_values[c]

if test_deprecated_fn:
with pytest.warns(FutureWarning, match="Use model.replace_rvs_by_values instead"):
(res,) = rvs_to_value_vars((d,), apply_transforms=apply_transforms)
else:
(res,) = replace_rvs_by_values(
(d,),
rvs_to_values=m.rvs_to_values,
rvs_to_transforms=m.rvs_to_transforms,
)
(res,) = replace_rvs_by_values(
(d,),
rvs_to_values=m.rvs_to_values,
rvs_to_transforms=m.rvs_to_transforms,
)

assert res.owner.op == pt.add
log_output = res.owner.inputs[0]
Expand Down Expand Up @@ -740,8 +734,7 @@ def test_basic(self, symbolic_rv, apply_transforms, test_deprecated_fn):
else:
assert a_value_var not in res_ancestors

@pytest.mark.parametrize("test_deprecated_fn", (True, False))
def test_unvalued_rv(self, test_deprecated_fn):
def test_unvalued_rv(self):
with pm.Model() as m:
x = pm.Normal("x")
y = pm.Normal.dist(x)
Expand All @@ -751,15 +744,11 @@ def test_unvalued_rv(self, test_deprecated_fn):
x_value = m.rvs_to_values[x]
z_value = m.rvs_to_values[z]

if test_deprecated_fn:
with pytest.warns(FutureWarning, match="Use model.replace_rvs_by_values instead"):
(res,) = rvs_to_value_vars((out,))
else:
(res,) = replace_rvs_by_values(
(out,),
rvs_to_values=m.rvs_to_values,
rvs_to_transforms=m.rvs_to_transforms,
)
(res,) = replace_rvs_by_values(
(out,),
rvs_to_values=m.rvs_to_values,
rvs_to_transforms=m.rvs_to_transforms,
)

assert res.owner.op == pt.add
assert res.owner.inputs[0] is z_value
Expand All @@ -769,8 +758,7 @@ def test_unvalued_rv(self, test_deprecated_fn):
assert res_y.owner.op == pt.random.normal
assert res_y.owner.inputs[3] is x_value

@pytest.mark.parametrize("test_deprecated_fn", (True, False))
def test_no_change_inplace(self, test_deprecated_fn):
def test_no_change_inplace(self):
# Test that calling rvs_to_value_vars in models with nested transformations
# does not change the original rvs in place. See issue #5172
with pm.Model() as m:
Expand All @@ -784,22 +772,17 @@ def test_no_change_inplace(self, test_deprecated_fn):
before = pytensor.clone_replace(m.free_RVs)

# This call would change the model free_RVs in place in #5172
if test_deprecated_fn:
with pytest.warns(FutureWarning, match="Use model.replace_rvs_by_values instead"):
rvs_to_value_vars(m.potentials)
else:
replace_rvs_by_values(
m.potentials,
rvs_to_values=m.rvs_to_values,
rvs_to_transforms=m.rvs_to_transforms,
)
replace_rvs_by_values(
m.potentials,
rvs_to_values=m.rvs_to_values,
rvs_to_transforms=m.rvs_to_transforms,
)

after = pytensor.clone_replace(m.free_RVs)
assert equal_computations(before, after)

@pytest.mark.parametrize("test_deprecated_fn", (True, False))
@pytest.mark.parametrize("reversed", (False, True))
def test_interdependent_transformed_rvs(self, reversed, test_deprecated_fn):
def test_interdependent_transformed_rvs(self, reversed):
# Test that nested transformed variables, whose transformed values depend on other
# RVs are properly replaced
with pm.Model() as m:
Expand All @@ -815,15 +798,11 @@ def test_interdependent_transformed_rvs(self, reversed, test_deprecated_fn):
if reversed:
rvs = rvs[::-1]

if test_deprecated_fn:
with pytest.warns(FutureWarning, match="Use model.replace_rvs_by_values instead"):
transform_values = rvs_to_value_vars(rvs)
else:
transform_values = replace_rvs_by_values(
rvs,
rvs_to_values=m.rvs_to_values,
rvs_to_transforms=m.rvs_to_transforms,
)
transform_values = replace_rvs_by_values(
rvs,
rvs_to_values=m.rvs_to_values,
rvs_to_transforms=m.rvs_to_transforms,
)

for transform_value in transform_values:
assert_no_rvs(transform_value)
Expand Down

0 comments on commit 0044bf1

Please sign in to comment.