From e2178c66b39553ea1c7228345694410ab17f9039 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Sep 2023 13:22:34 +0100 Subject: [PATCH] Fixes for SimpleVarInfo with `Ref` (#527) * added missing getlogp impl for SimpleVarInfo with Ref * included SimpleVarInfo with Ref in the TestUtils.setup_varinfos * bump patch version * moved impls of acclogp!! and setlogp!! for SimpleVarInfo next to each other --- Project.toml | 2 +- src/simple_varinfo.jl | 18 ++++++++++-------- src/test_utils.jl | 12 +++++++++--- 3 files changed, 20 insertions(+), 12 deletions(-) diff --git a/Project.toml b/Project.toml index a20e3546a..7931c5820 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.23.14" +version = "0.23.15" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 025b4aad7..a9d38fb07 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -259,17 +259,11 @@ end Base.isempty(vi::SimpleVarInfo) = isempty(vi.values) getlogp(vi::SimpleVarInfo) = vi.logp +getlogp(vi::SimpleVarInfo{<:Any,<:Ref}) = vi.logp[] + setlogp!!(vi::SimpleVarInfo, logp) = Setfield.@set vi.logp = logp acclogp!!(vi::SimpleVarInfo, logp) = Setfield.@set vi.logp = getlogp(vi) + logp -""" - keys(vi::SimpleVarInfo) - -Return an iterator of keys present in `vi`. -""" -Base.keys(vi::SimpleVarInfo) = keys(vi.values) -Base.keys(vi::SimpleVarInfo{<:NamedTuple}) = map(k -> VarName{k}(), keys(vi.values)) - function setlogp!!(vi::SimpleVarInfo{<:Any,<:Ref}, logp) vi.logp[] = logp return vi @@ -280,6 +274,14 @@ function acclogp!!(vi::SimpleVarInfo{<:Any,<:Ref}, logp) return vi end +""" + keys(vi::SimpleVarInfo) + +Return an iterator of keys present in `vi`. +""" +Base.keys(vi::SimpleVarInfo) = keys(vi.values) +Base.keys(vi::SimpleVarInfo{<:NamedTuple}) = map(k -> VarName{k}(), keys(vi.values)) + function Base.show(io::IO, ::MIME"text/plain", svi::SimpleVarInfo) if !(svi.transformation isa NoTransformation) print(io, "Transformed ") diff --git a/src/test_utils.jl b/src/test_utils.jl index 6604b5df1..5028699f2 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -43,16 +43,22 @@ Return a tuple of instances for different implementations of `AbstractVarInfo` w each `vi`, supposedly, satisfying `vi[vn] == get(example_values, vn)` for `vn` in `varnames`. """ function setup_varinfos(model::Model, example_values::NamedTuple, varnames) - # <:VarInfo + # VarInfo vi_untyped = VarInfo() model(vi_untyped) vi_typed = DynamicPPL.TypedVarInfo(vi_untyped) - # <:SimpleVarInfo + # SimpleVarInfo svi_typed = SimpleVarInfo(example_values) svi_untyped = SimpleVarInfo(OrderedDict()) + # SimpleVarInfo{<:Any,<:Ref} + svi_typed_ref = SimpleVarInfo(example_values, Ref(getlogp(svi_typed))) + svi_untyped_ref = SimpleVarInfo(OrderedDict(), Ref(getlogp(svi_untyped))) + lp = getlogp(vi_typed) - return map((vi_untyped, vi_typed, svi_typed, svi_untyped)) do vi + return map(( + vi_untyped, vi_typed, svi_typed, svi_untyped, svi_typed_ref, svi_untyped_ref + )) do vi # Set them all to the same values. DynamicPPL.setlogp!!(update_values!!(vi, example_values, varnames), lp) end