Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

subset and merge for VarInfo #543

Closed
wants to merge 20 commits into from
Closed

subset and merge for VarInfo #543

wants to merge 20 commits into from

Conversation

torfjelde
Copy link
Member

@torfjelde torfjelde commented Oct 8, 2023

To implement a Gibbs sampler based on condition a la TuringLang/Turing.jl#2099 we need two missing operations for VarInfo:

  • subset(varinfo, varnames): creates a VarInfo from varinfo based only on varnames.
  • merge(varinfo_left, varinfo_right): attempts to merge two instances of VarInfo into a single VarInfo. These can include both shared and not-shared variables.

Here's an example:

julia> using Distributions, DynamicPPL

julia> @model function demo()
           s ~ InverseGamma(2, 3)
           m ~ Normal(0, sqrt(s))
           x = Vector{Float64}(undef, 2)
           x[1] ~ Normal(m, sqrt(s))
           x[2] ~ Normal(m, sqrt(s))
       end

demo (generic function with 2 methods)

julia> model = demo();

julia> varinfo = VarInfo(model);

julia> keys(varinfo)
4-element Vector{VarName}:
 s
 m
 x[1]
 x[2]

julia> varinfo[@varname(s)] = 1;

julia> varinfo[@varname(m)] = 2;

julia> varinfo[@varname(x[1])] = 3;

julia> varinfo[@varname(x[2])] = 4;

julia> varinfo[[@varname(s), @varname(m), @varname(x[1]), @varname(x[2])]]
4-element Vector{Float64}:
 1.0
 2.0
 3.0
 4.0

julia> # Extract one with only `m`.
       varinfo_subset1 = subset(varinfo, [@varname(m),]);


julia> keys(varinfo_subset1)
1-element Vector{VarName{:m, Setfield.IdentityLens}}:
 m

julia> varinfo_subset1[@varname(m)]
2.0

julia> # Extract one with both `s` and `x[2]`.
       varinfo_subset2 = subset(varinfo, [@varname(s), @varname(x[2])]);

julia> keys(varinfo_subset2)
2-element Vector{VarName}:
 s
 x[2]

julia> varinfo_subset2[[@varname(s), @varname(x[2])]]
2-element Vector{Float64}:
 1.0
 4.0

julia> # Merge the two.
       varinfo_subset_merged = merge(varinfo_subset1, varinfo_subset2);

julia> keys(varinfo_subset_merged)
3-element Vector{VarName}:
 m
 s
 x[2]

julia> varinfo_subset_merged[[@varname(s), @varname(m), @varname(x[2])]]
3-element Vector{Float64}:
 1.0
 2.0
 4.0

julia> # Merge the two with the original.
       varinfo_merged = merge(varinfo, varinfo_subset_merged);

julia> keys(varinfo_merged)
4-element Vector{VarName}:
 s
 m
 x[1]
 x[2]

julia> varinfo_merged[[@varname(s), @varname(m), @varname(x[1]), @varname(x[2])]]
4-element Vector{Float64}:
 1.0
 2.0
 3.0
 4.0

Note that none of these are, at the moment, implemented for SimpleVarInfo. This is both because (1) we don't need it yet, and (2) this will support a much smaller set of combinations.

@torfjelde
Copy link
Member Author

Note that this is based on #542 (because I wanted to test the Gibbs PR with this)

@torfjelde torfjelde force-pushed the torfjelde/varinfo-ops branch from bde00d0 to ac91cfc Compare October 8, 2023 23:18
src/varinfo.jl Outdated Show resolved Hide resolved
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
@torfjelde torfjelde force-pushed the torfjelde/varinfo-ops branch from ac91cfc to 743c8b6 Compare October 8, 2023 23:21
Comment on lines +452 to +455
push!(
vals.args,
:(merge_metadata(metadata_left.$sym, metadata_right.$sym))
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
push!(
vals.args,
:(merge_metadata(metadata_left.$sym, metadata_right.$sym))
)
push!(vals.args, :(merge_metadata(metadata_left.$sym, metadata_right.$sym)))

@torfjelde
Copy link
Member Author

Replaced by #544 to avoid dependence on #542

@torfjelde torfjelde closed this Oct 9, 2023
@yebai yebai deleted the torfjelde/varinfo-ops branch October 12, 2023 12:10
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant