-
-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Default rewrites in inner graph #6996
base: main
Are you sure you want to change the base?
Default rewrites in inner graph #6996
Conversation
Codecov Report
Additional details and impacted files@@ Coverage Diff @@
## main #6996 +/- ##
==========================================
- Coverage 92.17% 87.92% -4.25%
==========================================
Files 101 101
Lines 16849 16891 +42
==========================================
- Hits 15530 14851 -679
- Misses 1319 2040 +721
|
The tests are great! |
8e5d012
to
cbdafd2
Compare
Logic for removing |
There is some refactoring going on in #6976 that will impact the location of these changes. One thing we need to think about is nested Scans/OpFromGraph, I think this approach will only work for the first level. PS: Agree with refactoring the functionality if it's so similar |
I added recursive check for nested |
0ee8404
to
b33dd78
Compare
@ricardoV94 I added logic for replacement |
I added support for replacements in OpFromGraph |
I move some common logic into |
@ricardoV94 Could you please take a look when you will have some time? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is great. Functionality wise I think it's complete. I just have some suggestions to make it more similar to other PyTensor rewriters, and some test organization. Feel free to push back on my suggestions.
@@ -201,18 +203,120 @@ def __str__(self): | |||
return f"Check{{{self.msg}}}" | |||
|
|||
|
|||
@node_rewriter(tracks=[CheckParameterValue]) | |||
class InnerGraphRewriter: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This may be overdoing it, but what do you think of implementing this as something like WalkingGraphRewriter: https://github.com/ricardoV94/pytensor/blob/9cf2d181f07dc99bbd2e7c9e2b4a3e1b0aeff034/pytensor/graph/rewriting/basic.py#L1998
A WalkingNestedGraphRewriter
which applies the same node_rewriter
to both the outer graph and inner graphs. The idea is you would pass the previous rewrite which doesn't distinguish between the core case or a Scan
OpFromGraph
. It would be the WalkingNestedGraphRewriter
that would apply that logic regardless of which NodeRewriter
it's given?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If not, this should at least inherit from NodeRewriter
, which already takes care of enforcing a self.transform
is implemented, so you don't have to.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm, sounds great. I've been thinking about something like that, not sure how to best implement it.
@@ -201,18 +203,120 @@ def __str__(self): | |||
return f"Check{{{self.msg}}}" | |||
|
|||
|
|||
@node_rewriter(tracks=[CheckParameterValue]) | |||
class InnerGraphRewriter: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This class should be implemented in pytensorf
. The specific rewrites that use it can be implemented here though
with pm.Model() as m: | ||
pass | ||
|
||
m.check_bounds = False | ||
with m: | ||
fn = compile_pymc([], xs) | ||
assert np.all(fn() == 0) | ||
|
||
m.check_bounds = True | ||
with m: | ||
fn = compile_pymc([], xs) | ||
assert np.all(fn() == -np.inf) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of messing with models and compile_pymc
, let's just use pytensor.function
directly and pass the rewrites we want.
pytensor.function(..., mode=get_mode().including("rewrite_name"))
These tests should be in logprob/test_utils.py
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I will change and move tests
Co-authored-by: Ricardo Vieira <[email protected]>
I will re-implement this with |
This PR will add implementation of default rewrites for nodes with inner
fgraph
and related with issue #6697. For now I just added new test cases📚 Documentation preview 📚: https://pymc--6996.org.readthedocs.build/en/6996/