Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unify signatures of graph_replace and clone_replace #398

Merged
merged 3 commits into from
Jul 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 105 additions & 4 deletions pytensor/compile/function/pfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""

from copy import copy
from typing import Optional
from typing import Optional, Sequence, Union, overload

from pytensor.compile.function.types import Function, UnusedInputError, orig_function
from pytensor.compile.io import In, Out
Expand All @@ -15,16 +15,117 @@
from pytensor.graph.fg import FunctionGraph


@overload
def rebuild_collect_shared(
outputs,
outputs: Variable,
inputs=None,
replace=None,
updates=None,
rebuild_strict=True,
copy_inputs_over=True,
no_default_updates=False,
clone_inner_graphs=False,
):
) -> tuple[
list[Variable],
Variable,
tuple[
dict[Variable, Variable],
dict[SharedVariable, Variable],
list[Variable],
list[SharedVariable],
],
]:
...


@overload
def rebuild_collect_shared(
outputs: Sequence[Variable],
inputs=None,
replace=None,
updates=None,
rebuild_strict=True,
copy_inputs_over=True,
no_default_updates=False,
clone_inner_graphs=False,
) -> tuple[
list[Variable],
list[Variable],
tuple[
dict[Variable, Variable],
dict[SharedVariable, Variable],
list[Variable],
list[SharedVariable],
],
]:
...


@overload
def rebuild_collect_shared(
outputs: Out,
inputs=None,
replace=None,
updates=None,
rebuild_strict=True,
copy_inputs_over=True,
no_default_updates=False,
clone_inner_graphs=False,
) -> tuple[
list[Variable],
Out,
tuple[
dict[Variable, Variable],
dict[SharedVariable, Variable],
list[Variable],
list[SharedVariable],
],
]:
...


@overload
def rebuild_collect_shared(
outputs: Sequence[Out],
inputs=None,
replace=None,
updates=None,
rebuild_strict=True,
copy_inputs_over=True,
no_default_updates=False,
clone_inner_graphs=False,
) -> tuple[
list[Variable],
list[Out],
tuple[
dict[Variable, Variable],
dict[SharedVariable, Variable],
list[Variable],
list[SharedVariable],
],
]:
...


def rebuild_collect_shared(
outputs: Union[Sequence[Variable], Variable, Out, Sequence[Out]],
inputs=None,
replace=None,
updates=None,
rebuild_strict=True,
copy_inputs_over=True,
no_default_updates=False,
clone_inner_graphs=False,
) -> tuple[
list[Variable],
Union[list[Variable], Variable, Out, list[Out]],
tuple[
dict[Variable, Variable],
dict[SharedVariable, Variable],
list[Variable],
list[SharedVariable],
],
]:
r"""Replace subgraphs of a computational graph.

It returns a set of dictionaries and lists which collect (partial?)
Expand Down Expand Up @@ -260,7 +361,7 @@ def clone_inputs(i):
return (
input_variables,
cloned_outputs,
[clone_d, update_d, update_expr, shared_inputs],
(clone_d, update_d, update_expr, shared_inputs),
)


Expand Down
131 changes: 90 additions & 41 deletions pytensor/graph/replace.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,56 @@
from functools import partial
from typing import (
Collection,
Dict,
Iterable,
List,
Optional,
Sequence,
Tuple,
Union,
cast,
)

from pytensor.graph.basic import Constant, Variable, truncated_graph_inputs
from typing import Iterable, Optional, Sequence, Union, cast, overload

from pytensor.graph.basic import Apply, Constant, Variable, truncated_graph_inputs
from pytensor.graph.fg import FunctionGraph


ReplaceTypes = Union[Iterable[tuple[Variable, Variable]], dict[Variable, Variable]]


def _format_replace(replace: Optional[ReplaceTypes] = None) -> dict[Variable, Variable]:
items: dict[Variable, Variable]
if isinstance(replace, dict):
# PyLance has issues with type resolution
items = cast(dict[Variable, Variable], replace)
elif isinstance(replace, Iterable):
items = dict(replace)
elif replace is None:
items = {}
else:
raise ValueError(
"replace is neither a dictionary, list, "
f"tuple or None ! The value provided is {replace},"
f"of type {type(replace)}"
)
return items


@overload
def clone_replace(
output: Sequence[Variable],
replace: Optional[ReplaceTypes] = None,
**rebuild_kwds,
) -> list[Variable]:
...


@overload
def clone_replace(
output: Collection[Variable],
output: Variable,
replace: Optional[
Union[Iterable[Tuple[Variable, Variable]], Dict[Variable, Variable]]
Union[Iterable[tuple[Variable, Variable]], dict[Variable, Variable]]
] = None,
**rebuild_kwds,
) -> List[Variable]:
) -> Variable:
...


def clone_replace(
output: Union[Sequence[Variable], Variable],
replace: Optional[ReplaceTypes] = None,
**rebuild_kwds,
) -> Union[list[Variable], Variable]:
"""Clone a graph and replace subgraphs within it.

It returns a copy of the initial subgraph with the corresponding
Expand All @@ -39,40 +68,49 @@ def clone_replace(
"""
from pytensor.compile.function.pfunc import rebuild_collect_shared

items: Union[List[Tuple[Variable, Variable]], Tuple[Tuple[Variable, Variable], ...]]
if isinstance(replace, dict):
items = list(replace.items())
elif isinstance(replace, (list, tuple)):
items = replace
elif replace is None:
items = []
else:
raise ValueError(
"replace is neither a dictionary, list, "
f"tuple or None ! The value provided is {replace},"
f"of type {type(replace)}"
)
items = list(_format_replace(replace).items())

tmp_replace = [(x, x.type()) for x, y in items]
new_replace = [(x, y) for ((_, x), (_, y)) in zip(tmp_replace, items)]
_, _outs, _ = rebuild_collect_shared(output, [], tmp_replace, [], **rebuild_kwds)

# TODO Explain why we call it twice ?!
_, outs, _ = rebuild_collect_shared(_outs, [], new_replace, [], **rebuild_kwds)

return cast(List[Variable], outs)
return outs


@overload
def graph_replace(
outputs: Variable,
replace: Optional[ReplaceTypes] = None,
*,
strict=True,
) -> Variable:
...


@overload
def graph_replace(
outputs: Sequence[Variable],
replace: Dict[Variable, Variable],
replace: Optional[ReplaceTypes] = None,
*,
strict=True,
) -> list[Variable]:
...


def graph_replace(
outputs: Union[Sequence[Variable], Variable],
replace: Optional[ReplaceTypes] = None,
*,
strict=True,
) -> List[Variable]:
) -> Union[list[Variable], Variable]:
"""Replace variables in ``outputs`` by ``replace``.

Parameters
----------
outputs: Sequence[Variable]
outputs: Union[Sequence[Variable], Variable]
Output graph
replace: Dict[Variable, Variable]
Replace mapping
Expand All @@ -83,20 +121,26 @@ def graph_replace(

Returns
-------
List[Variable]
Output graph with subgraphs replaced
Union[Variable, List[Variable]]
Output graph with subgraphs replaced, see function overload for the exact type

Raises
------
ValueError
If some replacemens could not be applied and strict is True
If some replacements could not be applied and strict is True
"""
as_list = False
if not isinstance(outputs, Sequence):
outputs = [outputs]
else:
as_list = True
replace_dict = _format_replace(replace)
# collect minimum graph inputs which is required to compute outputs
# and depend on replacements
# additionally remove constants, they do not matter in clone get equiv
conditions = [
c
for c in truncated_graph_inputs(outputs, replace)
for c in truncated_graph_inputs(outputs, replace_dict)
if not isinstance(c, Constant)
]
# for the function graph we need the clean graph where
Expand All @@ -117,7 +161,7 @@ def graph_replace(
# replace the conditions back
fg_replace = {equiv[c]: c for c in conditions}
# add the replacements on top of input mappings
fg_replace.update({equiv[r]: v for r, v in replace.items() if r in equiv})
fg_replace.update({equiv[r]: v for r, v in replace_dict.items() if r in equiv})
# replacements have to be done in reverse topological order so that nested
# expressions get recursively replaced correctly

Expand All @@ -126,12 +170,14 @@ def graph_replace(
# So far FunctionGraph does these replacements inplace it is thus unsafe
# apply them using fg.replace, it may change the original graph
if strict:
non_fg_replace = {r: v for r, v in replace.items() if r not in equiv}
non_fg_replace = {r: v for r, v in replace_dict.items() if r not in equiv}
if non_fg_replace:
raise ValueError(f"Some replacements were not used: {non_fg_replace}")
toposort = fg.toposort()

def toposort_key(fg: FunctionGraph, ts, pair):
def toposort_key(
fg: FunctionGraph, ts: list[Apply], pair: tuple[Variable, Variable]
) -> int:
key, _ = pair
if key.owner is not None:
return ts.index(key.owner)
Expand All @@ -148,4 +194,7 @@ def toposort_key(fg: FunctionGraph, ts, pair):
reverse=True,
)
fg.replace_all(sorted_replacements, import_missing=True)
return list(fg.outputs)
if as_list:
return list(fg.outputs)
else:
return fg.outputs[0]
11 changes: 11 additions & 0 deletions tests/graph/test_replace.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,17 @@ def test_graph_replace(self):
# the old reference is still kept
assert oc.owner.inputs[0].owner.inputs[1] is w

def test_non_list_input(self):
x = MyVariable("x")
y = MyVariable("y")
o = MyOp("xyop")(x, y)
new_x = x.clone(name="x_new")
new_y = y.clone(name="y2_new")
# test non list inputs as well
oc = graph_replace(o, {x: new_x, y: new_y})
assert oc.owner.inputs[1] is new_y
assert oc.owner.inputs[0] is new_x

def test_graph_replace_advanced(self):
x = MyVariable("x")
y = MyVariable("y")
Expand Down
Loading