From c4402b3ea0983f4dc797a8e82b2a543ddb5f013a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 15 Apr 2024 16:35:40 +0100 Subject: [PATCH 1/3] use `subsumes` in `subset` to allow more flexibility in subsetting varinfos --- src/simple_varinfo.jl | 15 +++++++++------ src/varinfo.jl | 6 +++++- test/varinfo.jl | 27 ++++++++++++++++++++++++++- 3 files changed, 40 insertions(+), 8 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index ad37130d6..2458c12f3 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -430,18 +430,21 @@ function subset(varinfo::SimpleVarInfo, vns::AbstractVector{<:VarName}) end function _subset(x::AbstractDict, vns) - # NOTE: This requires `vns` to be explicitly present in `x`. - if any(!Base.Fix1(haskey, x), vns) + vns_present = collect(keys(x)) + vns_found = mapreduce(vcat, vns) do vn + return map(Base.Fix1(subsumes, vn), vs_present) + end + + # NOTE: This `vns` to be subsume varnames explicitly present in `x`. + if isempty(vns_found) throw( ArgumentError( - "Cannot subset `AbstractDict` with `VarName` that is not an explicit key. 
" * - "For example, if `keys(x) == [@varname(x[1])]`, then subsetting with " * - "`@varname(x[1])` is allowed, but subsetting with `@varname(x)` is not.", + "Cannot subset `AbstractDict` with `VarName` which does not subsume any keys.", ), ) end C = ConstructionBase.constructorof(typeof(x)) - return C(vn => x[vn] for vn in vns) + return C(vn => x[vn] for vn in vns_found) end function _subset(x::NamedTuple, vns) diff --git a/src/varinfo.jl b/src/varinfo.jl index c8c46ee27..e9b496dac 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -264,8 +264,12 @@ function subset(varinfo::TypedVarInfo, vns::AbstractVector{<:VarName}) return VarInfo(NamedTuple{syms}(metadatas), varinfo.logp, varinfo.num_produce) end -function subset(metadata::Metadata, vns::AbstractVector{<:VarName}) +function subset(metadata::Metadata, vns_given::AbstractVector{<:VarName}) # TODO: Should we error if `vns` contains a variable that is not in `metadata`? + # For each `vn` in `vns`, get the variables subsumed by `vn`. + vns = mapreduce(vcat, vns_given) do vn + filter(Base.Fix1(subsumes, vn), metadata.vns) + end indices_for_vns = map(Base.Fix1(getindex, metadata.idcs), vns) indices = Dict(vn => i for (i, vn) in enumerate(vns)) # Construct new `vals` and `ranges`. diff --git a/test/varinfo.jl b/test/varinfo.jl index 71e341767..15566a596 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -483,7 +483,14 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) [@varname(s), @varname(m), @varname(x[2])], [@varname(s), @varname(x[1]), @varname(x[2])], [@varname(m), @varname(x[1]), @varname(x[2])], - [@varname(s), @varname(m), @varname(x[1]), @varname(x[2])], + ] + + # Patterns requiring `subsumes`. 
+ vns_supported_with_subsumes = [ + [@varname(s), @varname(x)] => [@varname(s), @varname(x[1]), @varname(x[2])], + [@varname(m), @varname(x)] => [@varname(m), @varname(x[1]), @varname(x[2])], + [@varname(s), @varname(m), @varname(x)] => + [@varname(s), @varname(m), @varname(x[1]), @varname(x[2])], ] # `SimpleaVarInfo` only supports subsetting using the varnames as they appear @@ -516,6 +523,24 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) # Values should be the same. @test [varinfo_merged[vn] for vn in vns] == [varinfo[vn] for vn in vns] end + + @testset "$(convert(Vector{VarName}, vns_subset))" for ( + vns_subset, vns_target + ) in vns_supported_with_subsumes + varinfo_subset = subset(varinfo, vns_subset) + # Should now only contain the variables in `vns_target`. + check_varinfo_keys(varinfo_subset, vns_target) + # Values should be the same. + @test [varinfo_subset[vn] for vn in vns_target] == [varinfo[vn] for vn in vns_target] + + # `merge` with the original. + varinfo_merged = merge(varinfo, varinfo_subset) + vns_merged = keys(varinfo_merged) + # Should be equivalent. + check_varinfo_keys(varinfo_merged, vns) + # Values should be the same. + @test [varinfo_merged[vn] for vn in vns] == [varinfo[vn] for vn in vns] + end end # For certain varinfos we should have errors. 
From 80370f415722291c1c4eda473208cdcee758b31c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 16 Apr 2024 08:33:40 +0100 Subject: [PATCH 2/3] fixed bug in `subset` for `AbstractDict` --- src/simple_varinfo.jl | 2 +- test/varinfo.jl | 9 --------- 2 files changed, 1 insertion(+), 10 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 2458c12f3..b128b234d 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -432,7 +432,7 @@ end function _subset(x::AbstractDict, vns) vns_present = collect(keys(x)) vns_found = mapreduce(vcat, vns) do vn - return map(Base.Fix1(subsumes, vn), vs_present) + return filter(Base.Fix1(subsumes, vn), vns_present) end # NOTE: This `vns` to be subsume varnames explicitly present in `x`. diff --git a/test/varinfo.jl b/test/varinfo.jl index 15566a596..9e6781dd0 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -551,15 +551,6 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) varinfo, [@varname(s), @varname(m), @varname(x[1])] ) end - # `SimpleVarInfo{<:AbstractDict}` can only handle varnames as they appear in the model. 
- varinfo = varinfos[findfirst( - Base.Fix2(isa, SimpleVarInfo{<:AbstractDict}), varinfos - )] - @testset "$(short_varinfo_name(varinfo)): failure cases" begin - @test_throws ArgumentError subset( - varinfo, [@varname(s), @varname(m), @varname(x)] - ) - end end @testset "merge" begin From 8bdf5f6a76952fa134eecdf79984393ea3db6e32 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 16 Apr 2024 08:45:15 +0100 Subject: [PATCH 3/3] fixed bug where `subset` wasn't properly tested on `SimpleVarInfo` --- src/simple_varinfo.jl | 2 +- src/threadsafe.jl | 1 + test/varinfo.jl | 18 +++++++++++------- 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index b128b234d..ec079fac8 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -459,7 +459,7 @@ function _subset(x::NamedTuple, vns) end syms = map(getsym, vns) - return NamedTuple{Tuple(syms)}(Tuple(map(Base.Fix2(getindex, x), syms))) + return NamedTuple{Tuple(syms)}(Tuple(map(Base.Fix1(getindex, x), syms))) end # `merge` diff --git a/src/threadsafe.jl b/src/threadsafe.jl index fb1cc1c0c..fe03fd3fd 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -200,6 +200,7 @@ function BangBang.empty!!(vi::ThreadSafeVarInfo) return resetlogp!!(Setfield.@set!(vi.varinfo = empty!!(vi.varinfo))) end +values_as(vi::ThreadSafeVarInfo) = values_as(vi.varinfo) values_as(vi::ThreadSafeVarInfo, ::Type{T}) where {T} = values_as(vi.varinfo, T) function unset_flag!(vi::ThreadSafeVarInfo, vn::VarName, flag::String) diff --git a/test/varinfo.jl b/test/varinfo.jl index 9e6781dd0..806b3409b 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -493,20 +493,24 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) [@varname(s), @varname(m), @varname(x[1]), @varname(x[2])], ] - # `SimpleaVarInfo` only supports subsetting using the varnames as they appear + # `SimpleVarInfo` only supports subsetting using the varnames as they appear # in the model. 
vns_supported_simple = filter(∈(vns), vns_supported_standard) - @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos_standard + @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos # All variables. check_varinfo_keys(varinfo, vns) # Added a `convert` to make the naming of the testsets a bit more readable. - vns_supported = if varinfo isa DynamicPPL.SimpleOrThreadSafeSimple - vns_supported_simple - else - vns_supported_standard - end + # `SimpleVarInfo{<:NamedTuple}` only supports subsetting with "simple" varnames, + # i.e. `VarName{sym}()` without any indexing, etc. + vns_supported = + if varinfo isa DynamicPPL.SimpleOrThreadSafeSimple && + values_as(varinfo) isa NamedTuple + vns_supported_simple + else + vns_supported_standard + end @testset "$(convert(Vector{VarName}, vns_subset))" for vns_subset in vns_supported varinfo_subset = subset(varinfo, vns_subset)