Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor flows module #24

Merged
merged 4 commits into from
Aug 2, 2023
Merged

Refactor flows module #24

merged 4 commits into from
Aug 2, 2023

Conversation

francois-rozet
Copy link
Member

The zuko/flows.py file has become too long. It makes it hard to navigate the code and add new features. I think splitting the module into sub-modules would help greatly. This PR is a first attempt to refactor the flows module.

@francois-rozet francois-rozet mentioned this pull request Jul 30, 2023
@francois-rozet francois-rozet force-pushed the refactor branch 2 times, most recently from 1d49e3f to 601a309 Compare August 1, 2023 14:08
* Split the flows module into many sub-modules
* Rename {Distribution,Transform}Module to {Distribution,Transform}Factory
* Spread the bisection bounds of monotonic transformations
* New rotation transformation (RotationTransform)
* New gaussianization transformation (GaussianizationTransform)
* New gaussianization flow (GF)
@francois-rozet
Copy link
Member Author

francois-rozet commented Aug 1, 2023

Hey @simonschnake, I have updated the README in this PR, do you think the new introduction is good? Feel free to make comments on the PR, if you find the time.

@simonschnake
Copy link
Contributor

simonschnake commented Aug 2, 2023

Hey @francois-rozet,
I think the README very well written and good understandable.
Are you planing to switch the naming from TransformModule to TranformFactory?

I would change one sentence in the README:
“This design allows for flows that behave like distributions while retaining the benefits of Module.“
->
“This design enables flows to act like distributions while maintaining the feature of trainable weights inherent in a Module.”
That is a bit more precise.

Another thing is spot was

flow = FlowModule(
    transforms=[
        MaskedAutoregressiveTransform(3, 5, hidden_features=[128] * 3),
        Unconditional(PermutationTransform, torch.randperm(3), buffer=True),
        MaskedAutoregressiveTransform(3, 5, hidden_features=[128] * 3),
    ],
    base=Unconditional(
        DiagNormal,
        torch.zeros(3),
        torch.ones(3),
        buffer=True,
    ),
)

Here you take as an example, to make the second layer a PermutationTransform. One problem could be, that a new reader would think that this is the way to include permutations in a MaskedAutoregressiveFlow and not that this is a build in functionality of the MaskedAutoregressiveTransform. I would propose to pick a different example or include a sentence which points to that.

These are just small things. Overall, I think it is a very solid piece of work.

@francois-rozet
Copy link
Member Author

Thanks! I addressed your comments for the README. I will merge soon.

@francois-rozet francois-rozet marked this pull request as ready for review August 2, 2023 19:57
@francois-rozet francois-rozet merged commit f7e4f85 into master Aug 2, 2023
4 checks passed
@francois-rozet francois-rozet deleted the refactor branch August 2, 2023 20:00
@francois-rozet
Copy link
Member Author

The new API documentation is up at https://zuko.readthedocs.io/en/stable/api/zuko.flows.html 🔥 I think it is much more readable than before. I also added a few warnings about invertibility. Next step are tutorials.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants