Introduce value variables in logprob IR #7491
Conversation
Force-pushed from 36312a2 to 32bb231
The failing test should pass after the changes in #7480.
Force-pushed from 32bb231 to def56fe
Also introduce MeasurableOpMixin for string representation
Force-pushed from def56fe to f777864
Codecov Report
Attention: Patch coverage is …

Additional details and impacted files:

```diff
@@            Coverage Diff             @@
##             main    #7491      +/-   ##
==========================================
+ Coverage   92.15%   92.42%   +0.27%
==========================================
  Files         103      103
  Lines       17208    17104     -104
==========================================
- Hits        15858    15809      -49
+ Misses       1350     1295      -55
```
This is pretty substantial.
Hi @ricardoV94, great PR as always, especially with so much extension of the abstractions :)

Many of my comments are requests to clarify parts of the code, plus me thinking out loud as I went over the PR.

Are you able to add a test showing an IR graph after applying rewrites from `early_measurable_ir_rewrites_db` but before `measurable_ir_rewrites_db`? Perhaps this could highlight where and why exactly `PromisedValuedRV`s are needed.
pymc/logprob/abstract.py (outdated)

```python
class PromisedValuedRV(Op):
    r"""Marks a variable as being promised a valued variable in the logprob method."""
Why can't we directly provide a corresponding `ValuedRV` during the rewrite? Just asking to clarify. I see that this is (only?) used for `Join`s and `MakeVector`s.
Revisiting this comment after skimming a bit more: I now see that there is a set of newly defined early measurable IR rewrites, and I imagine that these promised valued RVs are temporarily substituted into the intermediate graph before the other logprob rewrites are applied. Why does this approach prevent breaking interdependencies?
We're trying to fix the same problem of not breaking dependencies. The tricky thing with `Join`/`MakeVector` is that they combine multiple RVs, potentially interdependent, into a single composite valued node. Only in the logp function is this value split and sent to each component, but we still want to prevent the same outer problem that motivated this PR. If you have:

```python
import pytensor.tensor as pt

x1 = pt.random.normal() * 5
x2 = pt.random.normal(x1 * 8)
xs = pt.stack([x1, x2])
xs_vv = pt.vector("xs_vv", shape=(2,))
```
By the time you get to the logp of `xs_vv`, you want to do basically the same thing the new test does, plus concatenate the logp terms. To prevent the conditioning from breaking during the IR rewrites, we want to introduce the `ValuedRV` nodes. We have to do this as an early rewrite, before anything comes along and breaks it. In normal situations we also do it before any rewrites, in the `construct_ir_fgraph` code, but for `Join`/`stack` this already counts as a form of inference.

Now, why `PromisedValuedRV` and not plain `ValuedRV`? Just for convenience: in the end we still want a function from `xs_vv` to `stack([logp(x1), logp(x2)])` (sketched below), and if I split the values, I wouldn't know how to stack them later in the loop in `conditional_logp` (or if this were a single `logp` call), or even know that I needed to. So we just "promise" there will be values, to keep rewrites from breaking the dependency that the logp function will require.

Another way of thinking about it: we are basically trying to truncate the graphs between valued nodes and do manipulations only within these subgraphs, never across them. We could literally have split the graph, done IR rewrites in each piece, and collected the logp expressions; what we are doing is identical. Either way, we would still need a further split for the RVs inside `Join`/`MakeVector`.
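To make the "promise" concrete, here is a hand-written sketch of the logp graph we ultimately want for the example above, using `pm.logp` for the component densities. This shows the target shape, not the literal rewrite output:

```python
# Hand-written target logp for the stack example; not the literal rewrite output.
import pymc as pm
import pytensor.tensor as pt

xs_vv = pt.vector("xs_vv", shape=(2,))

# The single stacked value is split and routed to each component:
# x1's density is evaluated at xs_vv[0], and x2's density is evaluated at
# xs_vv[1] while being conditioned on xs_vv[0] through its mean.
logp_x1 = pm.logp(pm.Normal.dist() * 5, xs_vv[0])
logp_x2 = pm.logp(pm.Normal.dist(xs_vv[0] * 8), xs_vv[1])
logp_xs = pt.stack([logp_x1, logp_x2])
```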
```python
# Context reconstructed from the discussion below:
# a_base = pm.Normal.dist(); a = a_base * 5; b = pm.Normal.dist(a * 8)
# a_value = a.type()
b_value = b.type()
logp_b = conditional_logp({a: a_value, b: b_value})[b_value]

assert_no_rvs(logp_b)
```
Brainstorming to see if I understand here. From #6917, IIUC `assert_no_rvs(logp_b)` would throw a warning (or an error?).

When would we want `b` to be rewritten as `pm.Normal.dist(a_base * 40)`? When, say, `a_base_value = a_base.type()` is provided in `conditional_logp`? My understanding is that, while both situations are mathematically equivalent, PyMC's log-prob inference does not work well when value variables (`a_value`) of deterministic transformations (`a`) of `RandomVariable`s are provided in lieu of the base variables' own value variables (`a_base.type()`, not provided here). Is this correct?
Yes, the general problem is that a second variable cannot depend on a transformation of the value variable. We don't have the machinery to do that sort of inversion (other than the ad-hoc `transform_value` rewrite). For example, the following is not supported:

```python
import pytensor.tensor as pt
from pymc.logprob import conditional_logp

x = pt.random.normal()
x_exp = pt.exp(x)
y = pt.random.normal(loc=x)
x_exp_vv = pt.scalar("x_exp_vv")
y_vv = pt.scalar("y_vv")
conditional_logp({x_exp: x_exp_vv, y: y_vv})
```
There is nothing wrong with it in principle, but we don't have the machinery to figure out that the density of `y` should depend on a (log) transform of `x_exp_vv`. It would be nice to have, but I haven't come across an elegant solution. The changes in this PR prevent our rewrites from accidentally introducing such indirections, by changing the IR of this test example to something like:
```python
a_base = pm.Normal.dist()
a = valued_rv(a_base * 5, a_value)
b = valued_rv(pm.Normal.dist(a * 8), b_value)
```
Since there are no default PyTensor rewrites that know what to do with a `valued_rv`, there is no risk of "mixing" information before and after the conditioning points (in this case, constant-folding `5 * 8 = 40` in the graph of `b`).
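To spell out the hazard the marker guards against, here is a schematic of what an unconstrained algebraic rewrite could do. This is an assumed illustration of generic constant folding, not actual PyMC rewrite output:

```python
# Schematic of the hazard, assuming generic algebraic rewriting;
# not actual PyMC rewrite output.
#
# Before rewrites (a is a conditioning point):
#   a = a_base * 5
#   b = Normal.dist(a * 8)
#
# A constant-folding pass that ignores conditioning points could legally
# produce:
#   b = Normal.dist(a_base * 40)
#
# Now b no longer references a, so logp(b | a_value) can no longer substitute
# a_value for a: the dependency that conditional_logp needs has been erased.
# Wrapping a in valued_rv(a_base * 5, a_value) blocks this, because no default
# rewrite knows how to move computation through a valued_rv node.
```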
This avoids rewrites across conditioning points that could break dependencies. Also extend logprob derivation of Scans with multiple valued output types.
Force-pushed from 778272d to bac9f6e
@larryshamalama thanks for the thoughtful questions! I cleaned up the answers I gave above and added them in the docstrings of …
Description
This supersedes the stale #6918. It introduces `ValuedRV` nodes in the IR so rewrites can transparently reason about the conditioning points. The main purpose is to simplify the IR rewrite logic.

It also prevents default PyTensor rewrites from breaking the dependency on valued RVs (which was behind the bug in #6917).

It also fixes some limitations in derived Scans and makes their derivation stricter. For instance, #6909 now fails explicitly instead of silently returning a wrong result.
Related Issue
📚 Documentation preview 📚: https://pymc--7491.org.readthedocs.build/en/7491/