Default rewrites in inner graph #6996
@@ -43,11 +43,13 @@
 from pytensor import Variable
 from pytensor import tensor as pt
-from pytensor.graph import Apply, Op, node_rewriter
-from pytensor.graph.basic import walk
+from pytensor.compile.builders import OpFromGraph
+from pytensor.graph import Apply, Op, clone_replace, node_rewriter
+from pytensor.graph.basic import io_toposort, walk
 from pytensor.graph.op import HasInnerGraph
 from pytensor.link.c.type import CType
 from pytensor.raise_op import CheckAndRaise
+from pytensor.scan.op import Scan
 from pytensor.tensor.random.op import RandomVariable
 from pytensor.tensor.variable import TensorVariable
@@ -201,18 +203,120 @@ def __str__(self):
         return f"Check{{{self.msg}}}"
 
 
-@node_rewriter(tracks=[CheckParameterValue])
+class InnerGraphRewriter:
+    def transform(self, node: Apply) -> Union[Variable, str, None]:
+        """
+        Return: "remove" if node should be removed,
+                new_node if node should be replaced by new_node,
+                None if node should not be replaced.
+        """
+        raise NotImplementedError
+
+    def rewrite(self, node: Apply):
+        if not isinstance(node.op, (Scan, OpFromGraph)):
+            raise TypeError("Expected a Scan or OpFromGraph node in InnerGraphRewriter")
+        node_inputs, node_outputs = node.op.inner_inputs, node.op.inner_outputs
+        local_fgraph_topo = io_toposort(node_inputs, node_outputs)
+        op = node.op
+
+        # Collect replacements for the inner graph.
+        givens = dict()
+        to_remove_set = set()
+        for nd in local_fgraph_topo:
+            if nd not in to_remove_set:
+                if isinstance(nd.op, (Scan, OpFromGraph)):
+                    # Recurse into nested inner graphs.
+                    new_node = self.rewrite(nd)
+                    if new_node is not None:
+                        new_outputs = (
+                            new_node if isinstance(new_node, list) else new_node.owner.outputs
+                        )
+                        givens.update(zip(nd.outputs, new_outputs))
+                        to_remove_set.add(nd)
+                else:
+                    transform = self.transform(nd)
+                    if transform == "remove":
+                        givens[nd.outputs[0]] = nd.inputs[0]
+                        to_remove_set.add(nd)
+                    elif transform is not None:
+                        givens.update(zip(nd.outputs, transform.owner.outputs))
+                        to_remove_set.add(nd)
+
+        if len(to_remove_set) == 0:
+            return None
+        op_outs = clone_replace(node_outputs, replace=givens)
+        if isinstance(op, Scan):
+            nwScan = Scan(
+                node_inputs,
+                op_outs,
+                op.info,
+                mode=op.mode,
+                profile=op.profile,
+                truncate_gradient=op.truncate_gradient,
+                name=op.name,
+                allow_gc=op.allow_gc,
+            )
+            nw_node = nwScan(*node.inputs)
+
+        else:
+            nwOpFromGraph = OpFromGraph(
+                node_inputs,
+                op_outs,
+                op.is_inline,
+                op.lop_overrides,
+                op.grad_overrides,
+                op.rop_overrides,
+                connection_pattern=op._connection_pattern,
+                name=op.name,
+                **op.kwargs,
+            )
+            nw_node = nwOpFromGraph(*node.inputs)
+        return nw_node
+
+
+class RemoveCheckParameterInnerGraph(InnerGraphRewriter):
+    def transform(self, node):
+        if isinstance(node.op, CheckParameterValue):
+            return "remove"
+        else:
+            return None
+
+
+class ReplaceCheckParameterInnerGraph(InnerGraphRewriter):
+    def transform(self, node):
+        if isinstance(node.op, CheckParameterValue):
+            if node.op.can_be_replaced_by_ninf:
+                logp_expr, *logp_conds = node.inputs
+                if len(logp_conds) > 1:
+                    logp_cond = pt.all(logp_conds)
+                else:
+                    (logp_cond,) = logp_conds
+                new_node = pt.switch(logp_cond, logp_expr, -np.inf)
+
+                if new_node.dtype != node.outputs[0].dtype:
+                    new_node = pt.cast(new_node, node.outputs[0].dtype)
+                return new_node
+        return None
+
+
+@node_rewriter(tracks=[CheckParameterValue, Scan, OpFromGraph])
 def local_remove_check_parameter(fgraph, node):
     """Rewrite that removes CheckParameterValue
 
     This is used when compile_rv_inplace
     """
+    if isinstance(node.op, (Scan, OpFromGraph)):
+        new_node = RemoveCheckParameterInnerGraph().rewrite(node)
+        if new_node is None:
+            return None
+        return new_node if isinstance(new_node, list) else [new_node]
     if isinstance(node.op, CheckParameterValue):
         return [node.inputs[0]]
 
 
-@node_rewriter(tracks=[CheckParameterValue])
+@node_rewriter(tracks=[CheckParameterValue, Scan, OpFromGraph])
 def local_check_parameter_to_ninf_switch(fgraph, node):
+    if isinstance(node.op, (Scan, OpFromGraph)):
+        new_node = ReplaceCheckParameterInnerGraph().rewrite(node)
+        if new_node is None:
+            return None
+        return new_node if isinstance(new_node, list) else [new_node]
+
     if not node.op.can_be_replaced_by_ninf:
         return None

Review comment on class InnerGraphRewriter: This class should be implemented in
@@ -342,6 +342,140 @@ def test_check_parameters_can_be_replaced_by_ninf(self):
         with pytest.raises(ParameterValueError, match="test"):
             fn([-1, 2, 3])
 
+    def test_check_parameters_removed_from_scan(self):
+        def scan_step(x_0):
+            cond = pt.ge(x_0, 1)
+            x = check_parameters(x_0, cond)
+            x_update = collect_default_updates([x])
+            return x, x_update
+
+        xs, _ = scan(
+            fn=scan_step,
+            sequences=[
+                pt.zeros(3),
+            ],
+            name="xs",
+        )
+
+        with pytest.raises(ParameterValueError):
+            pytensor.function([], xs)()
+
+        with pm.Model() as m:
+            pass
+
+        m.check_bounds = False
+        with m:
+            fn = compile_pymc([], xs)
+            assert np.all(fn() == 0)
+
+        m.check_bounds = True
+        with m:
+            fn = compile_pymc([], xs)
+            assert np.all(fn() == -np.inf)
Review comment on lines +363 to +374: Instead of messing with models, use pytensor.function(..., mode=get_mode().including("rewrite_name")). These tests should be in

Reply: I will change and move the tests.
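For illustration, a test written along those lines might look roughly like the sketch below. It is only a sketch: the check_parameters import path and the assumption that the rewrite is registered in pytensor's rewrite database under the name "local_check_parameter_to_ninf_switch" (as pymc registers the existing rewrites) are assumptions, not part of this PR.

# Sketch only: exercise the rewrite through a compilation mode instead of
# toggling check_bounds on a pm.Model.
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.mode import get_default_mode

# Assumed import location for check_parameters.
from pymc.distributions.dist_math import check_parameters


def test_ninf_switch_applied_inside_op_from_graph():
    x = pt.scalar("x")
    y = pt.scalar("y")
    e = check_parameters(x + y, pt.ge(x + y, 0), can_be_replaced_by_ninf=True)
    out = OpFromGraph([x, y], [e])(x, y)

    # Assumes the rewrite is registered under this name in pytensor's optdb.
    mode = get_default_mode().including("local_check_parameter_to_ninf_switch")
    fn = pytensor.function([x, y], out, mode=mode)

    assert fn(1.0, 2.0) == 3.0
    assert fn(-1.0, -2.0) == -np.inf

This avoids depending on pm.Model and check_bounds at all, which is what the comment above suggests.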
+    def test_check_parameters_removed_from_nested_scan(self):
+        def inner_scan_step(x_0):
+            cond = pt.ge(x_0, 1)
+            x = check_parameters(x_0, cond)
+            x_update = collect_default_updates([x])
+            return x, x_update
+
+        def outer_scan_step(x_0):
+            x, _ = scan(
+                fn=inner_scan_step,
+                sequences=[
+                    x_0,
+                ],
+                name="xs",
+            )
+            x_update = collect_default_updates([x])
+            return x, x_update
+
+        xs, _ = scan(
+            fn=outer_scan_step,
+            sequences=[
+                pt.zeros((3, 2)),
+            ],
+            name="xs",
+        )
+        with pytest.raises(ParameterValueError):
+            pytensor.function([], xs)()
+
+        with pm.Model() as m:
+            pass
+
+        m.check_bounds = False
+        with m:
+            fn = compile_pymc([], xs)
+            assert np.all(fn() == 0)
+
+    def test_check_parameters_can_be_replaced_by_ninf_in_scan(self):
+        def scan_step(x_0):
+            cond = pt.ge(x_0, 0)
+            x = check_parameters(x_0, cond, can_be_replaced_by_ninf=True)
+            x_update = collect_default_updates([x])
+            return x, x_update
+
+        xs, _ = scan(
+            fn=scan_step,
+            sequences=[
+                pt.as_tensor_variable([-1.0, 0.0, 1.0]),
+            ],
+            name="xs",
+        )
+        fn = compile_pymc([], xs)
+        np.testing.assert_array_equal(fn(), [-np.inf, 0, 1])
+
+    def test_check_parameters_can_be_removed_from_op_from_graph(self):
+        x, y, z = pt.scalars("xyz")
+        e = x + y * z
+        cond = pt.ge(e, 0)
+        e = check_parameters(e, cond)
+        op = OpFromGraph([x, y, z], [e])
+        e2 = op(x, y, z) + op(z, y, x)
+
+        with pm.Model() as m:
+            pass
+
+        with pytest.raises(ParameterValueError):
+            pytensor.function([x, y, z], e2)(-1, -2, -3)
+
+        m.check_bounds = False
+        with m:
+            fn = compile_pymc([x, y, z], e2)
+            assert fn(-1, -2, -3) == 4
+
+        m.check_bounds = True
+        with m:
+            fn = compile_pymc([x, y, z], e2)
+            assert np.all(fn(-1.0, -2.0, -3.0) == -np.inf)
+
+    def test_check_parameters_can_be_removed_from_nested_op_from_graph(self):
+        x, y, z = pt.scalars("xyz")
+        e = x + y
+        cond = pt.ge(e, 1)
+        e = check_parameters(e, cond)
+        op = OpFromGraph([x, y], [e])
+        e2 = op(x, y) * op(x, y)
+        op2 = OpFromGraph([x, y], [e2])
+        e3 = op2(x, y) + z
+
+        with pytest.raises(ParameterValueError):
+            pytensor.function([x, y, z], e3)(0, 0, 2)
+
+        with pm.Model() as m:
+            pass
+
+        m.check_bounds = False
+        with m:
+            fn = compile_pymc([x, y, z], e3)
+            assert fn(0, 0, 2) == 2
+
+        m.check_bounds = True
+        with m:
+            fn = compile_pymc([x, y, z], e3)
+            assert np.all(fn(0.0, 0.0, 2.0) == np.inf)
+
     def test_compile_pymc_sets_rng_updates(self):
         rng = pytensor.shared(np.random.default_rng(0))
         x = pm.Normal.dist(rng=rng)
Review comment: This may be overdoing it, but what do you think of implementing this as something like WalkingGraphRewriter (https://github.com/ricardoV94/pytensor/blob/9cf2d181f07dc99bbd2e7c9e2b4a3e1b0aeff034/pytensor/graph/rewriting/basic.py#L1998): a WalkingNestedGraphRewriter which applies the same node_rewriter to both the outer graph and inner graphs? The idea is you would pass the previous rewrite, which doesn't distinguish between the core case and a Scan / OpFromGraph. It would be the WalkingNestedGraphRewriter that applies that logic regardless of which NodeRewriter it's given.

Review comment: If not, this should at least inherit from NodeRewriter, which already takes care of enforcing that a self.transform is implemented, so you don't have to.

Reply: Hmm, sounds great. I've been thinking about something like that, but I'm not sure how to best implement it.
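For concreteness, one possible shape of that idea is sketched below as a NodeRewriter that wraps another NodeRewriter and descends into inner graphs. The class name, the OpFromGraph-only rebuild, and the assumption that the wrapped rewriter returns a list of replacement variables are illustrative assumptions, not pytensor API or an agreed design.

# Illustrative sketch, not pytensor API: apply `core_rewriter` to ordinary nodes
# and re-apply the same logic inside Scan/OpFromGraph inner graphs.
from pytensor.compile.builders import OpFromGraph
from pytensor.graph import clone_replace
from pytensor.graph.basic import io_toposort
from pytensor.graph.rewriting.basic import NodeRewriter
from pytensor.scan.op import Scan


class WalkingNestedNodeRewriter(NodeRewriter):
    """Wrap a NodeRewriter and descend into inner graphs (hypothetical)."""

    def __init__(self, core_rewriter: NodeRewriter):
        self.core_rewriter = core_rewriter

    def transform(self, fgraph, node):
        if not isinstance(node.op, (Scan, OpFromGraph)):
            # Plain node: the wrapped rewriter decides what, if anything, to do.
            # (Assumes it returns None/False or a list of replacement variables.)
            return self.core_rewriter.transform(fgraph, node)

        # Inner-graph node: collect replacements proposed by this same rewriter,
        # so nested Scan/OpFromGraph nodes are handled recursively.
        inner_inputs = node.op.inner_inputs
        inner_outputs = node.op.inner_outputs
        replacements = {}
        for inner_node in io_toposort(inner_inputs, inner_outputs):
            new_outputs = self.transform(fgraph, inner_node)
            if new_outputs:
                replacements.update(zip(inner_node.outputs, new_outputs))

        if not replacements:
            return None

        new_inner_outputs = clone_replace(inner_outputs, replace=replacements)
        if isinstance(node.op, OpFromGraph):
            # A real implementation would forward name/inline/override options,
            # and rebuild a Scan from op.info as InnerGraphRewriter does above.
            new_op = OpFromGraph(inner_inputs, new_inner_outputs)
        else:
            raise NotImplementedError("Scan rebuild omitted from this sketch")

        new_outs = new_op(*node.inputs)
        return new_outs if isinstance(new_outs, list) else [new_outs]

A WalkingGraphRewriter could then wrap an instance of this class to run it over the whole outer graph; whether that is preferable to this PR's InnerGraphRewriter helper is exactly the open question in the thread above.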