Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Default rewrites in inner graph #6996

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 108 additions & 4 deletions pymc/logprob/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,13 @@

from pytensor import Variable
from pytensor import tensor as pt
from pytensor.graph import Apply, Op, node_rewriter
from pytensor.graph.basic import walk
from pytensor.compile.builders import OpFromGraph
from pytensor.graph import Apply, Op, clone_replace, node_rewriter
from pytensor.graph.basic import io_toposort, walk
from pytensor.graph.op import HasInnerGraph
from pytensor.link.c.type import CType
from pytensor.raise_op import CheckAndRaise
from pytensor.scan.op import Scan
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.variable import TensorVariable

Expand Down Expand Up @@ -201,18 +203,120 @@ def __str__(self):
return f"Check{{{self.msg}}}"


@node_rewriter(tracks=[CheckParameterValue])
class InnerGraphRewriter:
Copy link
Member

@ricardoV94 ricardoV94 Dec 1, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This may be overdoing it, but what do you think of implementing this as something like WalkingGraphRewriter: https://github.com/ricardoV94/pytensor/blob/9cf2d181f07dc99bbd2e7c9e2b4a3e1b0aeff034/pytensor/graph/rewriting/basic.py#L1998

A WalkingNestedGraphRewriter which applies the same node_rewriter to both the outer graph and inner graphs. The idea is you would pass the previous rewrite which doesn't distinguish between the core case or a Scan OpFromGraph. It would be the WalkingNestedGraphRewriter that would apply that logic regardless of which NodeRewriter it's given?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If not, this should at least inherit from NodeRewriter, which already takes care of enforcing a self.transform is implemented, so you don't have to.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, sounds great. I've been thinking about something like that, not sure how to best implement it.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This class should be implemented in pytensorf. The specific rewrites that use it can be implemented here though

def transform(self, node: Apply) -> Union[Variable, str, None]:
"""
Return: "remove" if node should be removed
new_node if node should be replaced by new_node
None if node should not be replaced
"""
raise NotImplementedError

def apply(self, fgraph: FunctionGraph, node: Union[Scan, OpFromGraph]):
if not isinstance(node.op, (Scan, OpFromGraph)):
raise TypeError("Expected Scan or OpFromGraph in InernerGraphRewriter")
node_inputs, node_outputs = node.op.inner_inputs, node.op.inner_outputs
local_fgraph_topo = io_toposort(node_inputs, node_outputs)
op = node.op

givens = dict()
to_remove_set = set()
for nd in local_fgraph_topo:
if nd not in to_remove_set:
if isinstance(nd.op, (Scan, OpFromGraph)):
new_node = self.rewrite(nd)
if new_node is not None:
givens.update(zip(nd.outputs, new_node.owner.outputs))
to_remove_set.add(nd)
else:
transform = self.transform(nd)
if transform == "remove":
givens[nd.outputs[0]] = nd.inputs[0]
to_remove_set.add(nd)
elif transform is not None:
givens.update(zip(nd.outputs, transform.owner.outputs))
to_remove_set.add(nd)

if len(to_remove_set) == 0:
return None
op_outs = clone_replace(node_outputs, replace=givens)
if isinstance(op, Scan):
nwScan = Scan(
node_inputs,
op_outs,
op.info,
mode=op.mode,
profile=op.profile,
truncate_gradient=op.truncate_gradient,
name=op.name,
allow_gc=op.allow_gc,
)
nw_node = nwScan(*(node.inputs))

else:
nwOpFromGraph = OpFromGraph(
node_inputs,
op_outs,
op.is_inline,
op.lop_overrides,
op.grad_overrides,
op.rop_overrides,
connection_pattern=op._connection_pattern,
name=op.name,
**op.kwargs,
)
nw_node = nwOpFromGraph(*(node.inputs))
return nw_node


class RemoveCheckParameterInnerGraph(InnerGraphRewriter):
def transform(self, node):
if isinstance(node.op, CheckParameterValue):
return "remove"
else:
return None


class ReplaceCheckParameterInnerGraph(InnerGraphRewriter):
def transform(self, node):
if isinstance(node.op, CheckParameterValue):
if node.op.can_be_replaced_by_ninf:
logp_expr, *logp_conds = node.inputs
if len(logp_conds) > 1:
logp_cond = pt.all(logp_conds)
else:
(logp_cond,) = logp_conds
new_node = pt.switch(logp_cond, logp_expr, -np.inf)

if new_node.dtype != node.outputs[0].dtype:
new_node = pt.cast(new_node, node.outputs[0].dtype)
return new_node
return None


@node_rewriter(tracks=[CheckParameterValue, Scan, OpFromGraph])
def local_remove_check_parameter(fgraph, node):
"""Rewrite that removes CheckParameterValue

This is used when compile_rv_inplace
"""
if isinstance(node.op, (Scan, OpFromGraph)):
new_node = RemoveCheckParameterInnerGraph().rewrite(node)
if new_node is None:
return None
return new_node if isinstance(new_node, list) else [new_node]
if isinstance(node.op, CheckParameterValue):
return [node.inputs[0]]


@node_rewriter(tracks=[CheckParameterValue])
@node_rewriter(tracks=[CheckParameterValue, Scan, OpFromGraph])
def local_check_parameter_to_ninf_switch(fgraph, node):
if isinstance(node.op, (Scan, OpFromGraph)):
new_node = ReplaceCheckParameterInnerGraph().rewrite(node)
if new_node is None:
return None
return new_node if isinstance(new_node, list) else [new_node]

if not node.op.can_be_replaced_by_ninf:
return None

Expand Down
134 changes: 134 additions & 0 deletions tests/test_pytensorf.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,140 @@ def test_check_parameters_can_be_replaced_by_ninf(self):
with pytest.raises(ParameterValueError, match="test"):
fn([-1, 2, 3])

def test_check_parameters_removed_from_scan(self):
def scan_step(x_0):
cond = pt.ge(x_0, 1)
x = check_parameters(x_0, cond)
x_update = collect_default_updates([x])
return x, x_update

xs, _ = scan(
fn=scan_step,
sequences=[
pt.zeros(3),
],
name="xs",
)

with pytest.raises(ParameterValueError):
pytensor.function([], xs)()

with pm.Model() as m:
pass

m.check_bounds = False
with m:
fn = compile_pymc([], xs)
assert np.all(fn() == 0)

m.check_bounds = True
with m:
fn = compile_pymc([], xs)
assert np.all(fn() == -np.inf)
Comment on lines +363 to +374
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of messing with models and compile_pymc, let's just use pytensor.function directly and pass the rewrites we want.

pytensor.function(..., mode=get_mode().including("rewrite_name"))

These tests should be in logprob/test_utils.py

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will change and move tests


def test_check_parameters_removed_from_nested_scan(self):
def inner_scan_step(x_0):
cond = pt.ge(x_0, 1)
x = check_parameters(x_0, cond)
x_update = collect_default_updates([x])
return x, x_update

def outer_scan_step(x_0):
x, _ = scan(
fn=inner_scan_step,
sequences=[
x_0,
],
name="xs",
)
x_update = collect_default_updates([x])
return x, x_update

xs, _ = scan(
fn=outer_scan_step,
sequences=[
pt.zeros((3, 2)),
],
name="xs",
)
with pytest.raises(ParameterValueError):
pytensor.function([], xs)()

with pm.Model() as m:
pass

m.check_bounds = False
with m:
fn = compile_pymc([], xs)
assert np.all(fn() == 0)

def test_check_parameters_can_be_replaced_by_ninf_in_scan(self):
def scan_step(x_0):
cond = pt.ge(x_0, 0)
x = check_parameters(x_0, cond, can_be_replaced_by_ninf=True)
x_update = collect_default_updates([x])
return x, x_update

xs, _ = scan(
fn=scan_step,
sequences=[
pt.as_tensor_variable([-1.0, 0.0, 1.0]),
],
name="xs",
)
fn = compile_pymc([], xs)
np.testing.assert_array_equal(fn(), [-np.inf, 0, 1])

def test_check_parameters_can_be_removed_from_op_from_graph(self):
x, y, z = pt.scalars("xyz")
e = x + y * z
cond = pt.ge(e, 0)
e = check_parameters(e, cond)
op = OpFromGraph([x, y, z], [e])
e2 = op(x, y, z) + op(z, y, x)

with pm.Model() as m:
pass

with pytest.raises(ParameterValueError):
pytensor.function([x, y, z], e2)(-1, -2, -3)

m.check_bounds = False
with m:
fn = compile_pymc([x, y, z], e2)
assert fn(-1, -2, -3) == 4

m.check_bounds = True
with m:
fn = compile_pymc([x, y, z], e2)
assert np.all(fn(-1.0, -2.0, -3.0) == -np.inf)

def test_check_parameters_can_be_removed_from_nested_op_from_graph(self):
x, y, z = pt.scalars("xyz")
e = x + y
cond = pt.ge(e, 1)
e = check_parameters(e, cond)
op = OpFromGraph([x, y], [e])
e2 = op(x, y) * op(x, y)
op2 = OpFromGraph([x, y], [e2])
e3 = op2(x, y) + z

with pytest.raises(ParameterValueError):
pytensor.function([x, y, z], e3)(0, 0, 2)

with pm.Model() as m:
pass

m.check_bounds = False
with m:
fn = compile_pymc([x, y, z], e3)
assert fn(0, 0, 2) == 2

m.check_bounds = True
with m:
fn = compile_pymc([x, y, z], e3)
assert np.all(fn(0.0, 0.0, 2.0) == np.inf)

def test_compile_pymc_sets_rng_updates(self):
rng = pytensor.shared(np.random.default_rng(0))
x = pm.Normal.dist(rng=rng)
Expand Down
Loading