Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Default rewrites in inner graph #6996

Open
wants to merge 17 commits into
base: main
Choose a base branch
from

Conversation

aerubanov
Copy link
Contributor

@aerubanov aerubanov commented Nov 7, 2023

This PR will add implementation of default rewrites for nodes with inner fgraph and related with issue #6697. For now I just added new test cases


📚 Documentation preview 📚: https://pymc--6996.org.readthedocs.build/en/6996/

@aerubanov aerubanov marked this pull request as draft November 7, 2023 12:59
Copy link

codecov bot commented Nov 7, 2023

Codecov Report

Merging #6996 (b8fde47) into main (547bcb4) will decrease coverage by 4.25%.
Report is 11 commits behind head on main.
The diff coverage is 94.11%.

❗ Current head b8fde47 differs from pull request most recent head 0c80a68. Consider uploading reports for the commit 0c80a68 to get more accurate results

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #6996      +/-   ##
==========================================
- Coverage   92.17%   87.92%   -4.25%     
==========================================
  Files         101      101              
  Lines       16849    16891      +42     
==========================================
- Hits        15530    14851     -679     
- Misses       1319     2040     +721     
Files Coverage Δ
pymc/logprob/utils.py 96.19% <94.11%> (-1.31%) ⬇️

... and 8 files with indirect coverage changes

@ricardoV94
Copy link
Member

The tests are great!

@aerubanov
Copy link
Contributor Author

Logic for removing CheckParameterValues from Scan is very similar to one we recently added in #6873 for moment calculation in Scan, so may be I will move it to some helper function later

@ricardoV94
Copy link
Member

ricardoV94 commented Nov 14, 2023

There is some refactoring going on in #6976 that will impact the location of these changes.

One thing we need to think about is nested Scans/OpFromGraph, I think this approach will only work for the first level.

PS: Agree with refactoring the functionality if it's so similar

@aerubanov
Copy link
Contributor Author

aerubanov commented Nov 15, 2023

I added recursive check for nested Scans. For now it is only remove CheckParameterValues nodes from fgraph, but I will add replacement by ninf next step

@aerubanov aerubanov force-pushed the default-rewrites-in-inner-graph branch from 0ee8404 to b33dd78 Compare November 18, 2023 15:19
@aerubanov
Copy link
Contributor Author

@ricardoV94 I added logic for replacement CheckParameterValues by ninf in Scan and rebased my branch. Now I am going to add support for OpFromGraph and want to extract common logic for rewrites in inner fgraphs - probably some helper function. I think we can add some general solution for inner fgraphs rewriting.

@aerubanov
Copy link
Contributor Author

I added support for replacements in OpFromGraph

@aerubanov
Copy link
Contributor Author

I move some common logic into InnerGraphRewrite class

@aerubanov aerubanov marked this pull request as ready for review November 30, 2023 16:48
@aerubanov
Copy link
Contributor Author

@ricardoV94 Could you please take a look when you will have some time?

Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is great. Functionality wise I think it's complete. I just have some suggestions to make it more similar to other PyTensor rewriters, and some test organization. Feel free to push back on my suggestions.

@@ -201,18 +203,120 @@ def __str__(self):
return f"Check{{{self.msg}}}"


@node_rewriter(tracks=[CheckParameterValue])
class InnerGraphRewriter:
Copy link
Member

@ricardoV94 ricardoV94 Dec 1, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This may be overdoing it, but what do you think of implementing this as something like WalkingGraphRewriter: https://github.com/ricardoV94/pytensor/blob/9cf2d181f07dc99bbd2e7c9e2b4a3e1b0aeff034/pytensor/graph/rewriting/basic.py#L1998

A WalkingNestedGraphRewriter which applies the same node_rewriter to both the outer graph and inner graphs. The idea is you would pass the previous rewrite which doesn't distinguish between the core case or a Scan OpFromGraph. It would be the WalkingNestedGraphRewriter that would apply that logic regardless of which NodeRewriter it's given?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If not, this should at least inherit from NodeRewriter, which already takes care of enforcing a self.transform is implemented, so you don't have to.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, sounds great. I've been thinking about something like that, not sure how to best implement it.

@@ -201,18 +203,120 @@ def __str__(self):
return f"Check{{{self.msg}}}"


@node_rewriter(tracks=[CheckParameterValue])
class InnerGraphRewriter:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This class should be implemented in pytensorf. The specific rewrites that use it can be implemented here though

Comment on lines +363 to +374
with pm.Model() as m:
pass

m.check_bounds = False
with m:
fn = compile_pymc([], xs)
assert np.all(fn() == 0)

m.check_bounds = True
with m:
fn = compile_pymc([], xs)
assert np.all(fn() == -np.inf)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of messing with models and compile_pymc, let's just use pytensor.function directly and pass the rewrites we want.

pytensor.function(..., mode=get_mode().including("rewrite_name"))

These tests should be in logprob/test_utils.py

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will change and move tests

pymc/logprob/utils.py Outdated Show resolved Hide resolved
Co-authored-by: Ricardo Vieira <[email protected]>
@aerubanov
Copy link
Contributor Author

I will re-implement this with WalkingNestedGraphRewriter after merging pymc-devs/pytensor#556

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants