Move the content of ad.jl from Turing.jl to here #571

Merged: 28 commits into master from sunxd/move_ad, Feb 14, 2024

Conversation

Member

@sunxd3 sunxd3 commented Jan 23, 2024

Twin PR from Turing.jl

Description:

  • Make ADTypes and LogDensityProblemsAD direct dependencies; both packages are small, so this should be fine
  • Add package extensions for ForwardDiff and ReverseDiff that define LogDensityProblemsAD.ADgradient(::ADType, ::LogDensityFunction)
  • Add some AD tests using TestUtils.DEMO_MODELS
  • Fix some bugs so that the tests pass (with help from @torfjelde)
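
For readers unfamiliar with the two-argument ADgradient call mentioned above, a minimal usage sketch follows. It assumes package versions from around the time of this PR; the `demo` model and the evaluation point `θ` are made up for illustration and do not come from the PR or its tests.

```julia
using DynamicPPL, Distributions
using ADTypes, ForwardDiff
using LogDensityProblems, LogDensityProblemsAD

# Illustrative model (an assumption, not one of TestUtils.DEMO_MODELS).
@model function demo(y)
    μ ~ Normal(0, 1)
    y ~ Normal(μ, 1)
end

model = demo(1.5)

# Wrap the model so it implements the LogDensityProblems interface.
ℓ = DynamicPPL.LogDensityFunction(model)

# Two-argument form: the AD backend is always stated explicitly.
∇ℓ = LogDensityProblemsAD.ADgradient(ADTypes.AutoForwardDiff(), ℓ)

# Evaluate the log density and its gradient at an arbitrary point.
θ = [0.3]
logp, grad = LogDensityProblems.logdensity_and_gradient(∇ℓ, θ)
```

Passing `ADTypes.AutoReverseDiff()` instead (with ReverseDiff loaded) should select the other backend in the same way.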

@sunxd3 sunxd3 changed the title from "Move the content of ad.jl and ad_utils.jl from Turing.jl to here" to "[WIP] Move the content of ad.jl and ad_utils.jl from Turing.jl to here" Jan 23, 2024
@sunxd3 sunxd3 marked this pull request as draft January 23, 2024 14:14
ext/DynamicPPLADTypesExt.jl Outdated Show resolved Hide resolved
test/ext/DynamicPPLADTypesExt.jl Outdated Show resolved Hide resolved

# AD related code
getADType(spl::Sampler) = getADType(spl.alg)
getADType(::SampleFromPrior) = ADTypes.AutoForwardDiff(; chunksize=0)
Member

I know this is copied from Turing, but it still seems wrong to me. I think we should either remove it (is it actually needed?) or make the AD type adjustable (as with the HMC algorithms, it could be stored as a field of the struct).

Member Author

I left these functions in Turing.

In DynamicPPL we should probably define only two-argument ADgradient methods, so that the caller always has to state which ADType to use.
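
The two options raised in this thread can be sketched in plain Julia. Every name below (`ADBackend`, `ForwardMode`, `ReverseMode`, `PriorSampler`, `GradientWrapper`, `make_gradient`) is a hypothetical stand-in and not part of DynamicPPL, Turing, or ADTypes; the sketch only illustrates the design choice.

```julia
# Hypothetical stand-ins for AD backend types (real code would use ADTypes).
abstract type ADBackend end
struct ForwardMode <: ADBackend end
struct ReverseMode <: ADBackend end

# Option 1: store the AD backend as a field, so it can be adjusted per sampler
# instead of being hard-coded in a getADType method.
struct PriorSampler{AD<:ADBackend}
    adtype::AD
end
PriorSampler() = PriorSampler(ForwardMode())  # default, but overridable
getadtype(spl::PriorSampler) = spl.adtype

# Option 2: only a two-argument constructor exists, so the backend must
# always be given explicitly by the caller.
struct GradientWrapper{AD<:ADBackend,F}
    adtype::AD
    f::F
end
make_gradient(adtype::ADBackend, f) = GradientWrapper(adtype, f)

spl = PriorSampler(ReverseMode())
wrapped = make_gradient(getadtype(spl), sin)
```

Because no one-argument `make_gradient(f)` method is defined, omitting the backend raises a `MethodError` rather than silently falling back to a default.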

@sunxd3 sunxd3 changed the title from "[WIP] Move the content of ad.jl and ad_utils.jl from Turing.jl to here" to "[WIP] Move the content of ad.jl from Turing.jl to here" Jan 26, 2024
src/simple_varinfo.jl Outdated Show resolved Hide resolved
sunxd3 and others added 2 commits January 30, 2024 17:07
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
@sunxd3 sunxd3 marked this pull request as ready for review January 30, 2024 17:10
@sunxd3 sunxd3 requested review from torfjelde and devmotion January 30, 2024 17:16
@sunxd3
Member Author

sunxd3 commented Jan 30, 2024

Looks like the changes break something, will investigate later

@coveralls

coveralls commented Jan 31, 2024

Pull Request Test Coverage Report for Build 7775567755

  • 24 of 30 (80.0%) changed or added relevant lines in 5 files are covered.
  • 43 unchanged lines in 10 files lost coverage.
  • Overall coverage increased (+0.05%) to 84.695%

Changes Missing Coverage:

File                               Covered Lines   Changed/Added Lines   %
src/DynamicPPL.jl                  1               3                     33.33%
ext/DynamicPPLForwardDiffExt.jl    14              18                    77.78%

Files with Coverage Reduction:

File                               New Missed Lines   %
src/model.jl                       1                  89.22%
src/context_implementations.jl     2                  63.77%
src/model_utils.jl                 2                  37.5%
src/prob_macro.jl                  2                  87.59%
src/utils.jl                       2                  82.28%
src/test_utils.jl                  3                  87.16%
src/DynamicPPL.jl                  5                  22.22%
src/simple_varinfo.jl              6                  80.66%
src/varinfo.jl                     8                  92.6%
src/threadsafe.jl                  12                 53.1%

Totals:

Change from base Build 7755514813: 0.05%
Covered Lines: 2706
Relevant Lines: 3195

💛 - Coveralls


codecov bot commented Jan 31, 2024

Codecov Report

Attention: 2 lines in your changes are missing coverage. Please review.

Comparison is base (c33eeae) 84.32% compared to head (7b84ba1) 84.37%.

Files                              Patch %   Lines
ext/DynamicPPLForwardDiffExt.jl    88.88%    2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master     #571      +/-   ##
==========================================
+ Coverage   84.32%   84.37%   +0.05%     
==========================================
  Files          26       28       +2     
  Lines        3183     3207      +24     
==========================================
+ Hits         2684     2706      +22     
- Misses        499      501       +2     


@sunxd3
Member Author

sunxd3 commented Feb 2, 2024

@yebai @torfjelde @devmotion tests are passing, another look?

test/ad.jl Outdated Show resolved Hide resolved
Co-authored-by: Tor Erlend Fjelde <[email protected]>
@sunxd3 sunxd3 changed the title from "[WIP] Move the content of ad.jl from Turing.jl to here" to "Move the content of ad.jl from Turing.jl to here" Feb 2, 2024
@sunxd3
Member Author

sunxd3 commented Feb 2, 2024

Okay, more errors, fixing...

@sunxd3
Member Author

sunxd3 commented Feb 2, 2024

The error with ReverseDiff was that, for UntypedVarInfo, getindex(varinfo, ::Colon) can return a Vector{Real}.

Given such an input, the gradients are initialized as Int zeros, because zero(D) returns the Int 0 when D is Real. See https://github.com/JuliaDiff/ReverseDiff.jl/blob/c982cde5494fc166965a9d04691f390d9e3073fd/src/tracked.jl#L473
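
The underlying Base behaviour can be reproduced without ReverseDiff or DynamicPPL. The snippet below is only an illustration: the vector `x` is made up, and narrowing the element type with `map(identity, x)` is one possible workaround for a vector whose elements all share a concrete type, not necessarily the fix adopted in this PR.

```julia
# zero of the abstract type Real is an Int, not a floating-point zero.
z = zero(Real)
println(typeof(z))

# A vector with abstract element type, as UntypedVarInfo can produce.
x = Real[1.0, 2.5, -0.3]

# Buffers allocated from the abstract element type are filled with Int zeros.
buf = zeros(eltype(x), length(x))
println(typeof(buf), " with elements of type ", typeof(first(buf)))

# Narrowing the element type first gives floating-point zeros instead.
x_concrete = map(identity, x)
buf_concrete = zeros(eltype(x_concrete), length(x_concrete))
println(typeof(buf_concrete))
```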

test/ad.jl Outdated Show resolved Hide resolved
sunxd3 and others added 3 commits February 2, 2024 17:42
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
test/ad.jl Outdated Show resolved Hide resolved
sunxd3 and others added 2 commits February 4, 2024 12:32
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
@sunxd3
Member Author

sunxd3 commented Feb 4, 2024

I disabled the Zygote tests; they just seem too temperamental. The tests with ForwardDiff and ReverseDiff are all passing now.

Another look @torfjelde?

Member

@torfjelde torfjelde left a comment

Starting to look real nice! A few more changes, but after that I think we'll be good to go :) Nice work!

It might be nice to have @devmotion take a look, though, as he's been more involved in the transition to ADTypes.

Project.toml Outdated Show resolved Hide resolved
src/varinfo.jl Outdated Show resolved Hide resolved
ext/DynamicPPLReverseDiffExt.jl Outdated Show resolved Hide resolved
test/ad.jl Outdated Show resolved Hide resolved
test/ad.jl Outdated Show resolved Hide resolved
@sunxd3
Member Author

sunxd3 commented Feb 5, 2024

@torfjelde thanks for the suggestions and the help in debugging earlier. I made the updates accordingly.

@devmotion a quick look maybe?

@yebai yebai merged commit b924a17 into master Feb 14, 2024
13 checks passed
@yebai yebai deleted the sunxd/move_ad branch February 14, 2024 15:18
@yebai
Member

yebai commented Feb 14, 2024

Thanks @sunxd3 and @torfjelde!

@torfjelde
Member

Just a heads up wrt. these types of changes for the future: we usually make such releases breaking if we worry that they might have downstream effects (even when the change itself is not technically breaking), so as to avoid issues like TuringLang/Turing.jl#2173.

@sunxd3
Member Author

sunxd3 commented Feb 21, 2024

@torfjelde breaking for DynamicPPL or both?

5 participants