Skip to content

Commit

Permalink
Merge pull request #175 from Julia-Tempering/type-piracy
Browse files Browse the repository at this point in the history
Fix `==` piracy
  • Loading branch information
miguelbiron authored Nov 7, 2023
2 parents dd986f2 + 0d3a3e1 commit 225ffbc
Show file tree
Hide file tree
Showing 9 changed files with 170 additions and 146 deletions.
9 changes: 8 additions & 1 deletion docs/src/correctness.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,15 @@ parallelization might get used internally e.g. to parallelize likelihood evaluat
Here the check passed successfully as expected. But what
if you had a third-party target distribution that is not multi-threaded friendly?
For example, some code sometimes writes to global variables or
other non-thread safe constructs. In such situation, you can still use your thread-naive
other non-thread safe constructs. In such situation, you can still use your thread-naive
target over MPI *processes*.
For example, if the thread-unsafety comes from the use of global variables, then each
process will have its own copy of the global variables.

!!! note "Failed equality check"
If you are using a custom struct that is either mutable or contains
mutable fields, it
is possible that the check will fail even if your implementation is sound.
This is caused by `==` falling back to `===` for your type, which is too strict
for the purpose of comparing two deserialized checkpoints. See
[`recursive_equal`](@ref) for instructions on how to prevent this behavior.
10 changes: 2 additions & 8 deletions ext/PigeonsBridgeStanExt/state.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,9 @@ function Pigeons.variable(state::Pigeons.StanState, name::Symbol)
end
end

Pigeons.step!(explorer::AutoMALA, replica, shared, state::Pigeons.StanState) =
Pigeons.step!(explorer, replica, shared, state.unconstrained_parameters)

Pigeons.step!(explorer::Pigeons.HamiltonianSampler, replica, shared, state::Pigeons.StanState) =
Pigeons.step!(explorer, replica, shared, state.unconstrained_parameters)



# Ask BridgeStan for the parameter names of the Stan model behind `log_potential`,
# with both `include_tp` and `include_gq` enabled.
function Pigeons.variable_names(::Pigeons.StanState, log_potential)
    model = Pigeons.stan_model(log_potential)
    return BridgeStan.param_names(model; include_tp = true, include_gq = true)
end


Expand All @@ -40,10 +35,9 @@ function Pigeons.slice_sample!(h::SliceSampler, state::Pigeons.StanState, log_po
end


Base.:(==)(a::StanLogPotential, b::StanLogPotential) =
Pigeons.recursive_equal(a::StanLogPotential, b::StanLogPotential) =
a.data == b.data && BridgeStan.name(a.model) == BridgeStan.name(b.model)
# TODO: Fix type piracy
Base.:(==)(a::StanRNG, b::StanRNG) = Pigeons.recursive_equal(a, b)
Pigeons.recursive_equal(a::StanRNG, b::StanRNG) = Pigeons._recursive_equal(a, b)

# Evaluate the log potential on the unconstrained parameters held by a `StanState`.
function (log_potential::Pigeons.ScaledPrecisionNormalLogPotential)(x::Pigeons.StanState)
    return log_potential(x.unconstrained_parameters)
end
Random.rand!(rng::AbstractRNG, state::Pigeons.StanState{Vector{Float64}}, log_potential::Pigeons.ScaledPrecisionNormalLogPotential) =
Expand Down
42 changes: 13 additions & 29 deletions ext/PigeonsDynamicPPLExt/state.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,30 +32,22 @@ function Pigeons.variable_names(state::DynamicPPL.TypedVarInfo, _)
all_names = fieldnames(typeof(state.metadata))
for var_name in all_names
var = state.metadata[var_name].vals
if var isa Number || (var isa Array && length(var) == 1)
if var isa Number || (var isa AbstractArray && length(var) == 1)
push!(result, var_name)
elseif var isa Array
elseif var isa AbstractArray
# flatten vector names following Turing convention
l = length(var)
for i in 1:l
for i in eachindex(var)
var_and_index_name =
Symbol(var_name, "[", join(ind2sub(size(var), i), ","), "]")
push!(result, var_and_index_name)
end
else
error()
error("don't know how to handle var `$var_name` of type $(typeof(var))")
end
end
return result
end

function Pigeons.step!(explorer::AutoMALA, replica, shared, vi::DynamicPPL.TypedVarInfo)
log_potential = Pigeons.find_log_potential(replica, shared.tempering, shared)
state = DynamicPPL.getall(vi)
Pigeons._extract_commons_and_run_auto_mala!(explorer, replica, shared, log_potential, state)
DynamicPPL.setall!(replica.state, state)
end

function Pigeons.slice_sample!(h::SliceSampler, state::DynamicPPL.TypedVarInfo, log_potential, cached_lp, replica)
cached_lp = Pigeons.cached_log_potential(log_potential, state, cached_lp)
for i in 1:length(state.metadata)
Expand All @@ -75,21 +67,13 @@ function Pigeons.step!(explorer::Pigeons.HamiltonianSampler, replica, shared, vi
end


## TODO: This is type piracy and should be fixed upstream
function Base.:(==)(a::DynamicPPL.TypedVarInfo, b::DynamicPPL.TypedVarInfo)
# as of Jan 2023, DynamicPPL does not supply == for TypedVarInfo
if length(a.metadata) != length(b.metadata)
return false
end
for i in 1:length(a.metadata)
if a.metadata[i].vals != b.metadata[i].vals
return false
end
end
return true
end
# As of Nov 2023, DynamicPPL does not supply `==` for `TypedVarInfo`. Two states
# are considered equal when they have the same number of metadata entries, the
# same variable names, and the same values as returned by `DynamicPPL.getall`.
function Pigeons.recursive_equal(a::DynamicPPL.TypedVarInfo, b::DynamicPPL.TypedVarInfo)
    length(a.metadata) == length(b.metadata) || return false
    # The second argument of `variable_names` is not used for `TypedVarInfo`.
    # The call is qualified with `Pigeons.` like every other Pigeons function
    # referenced from this extension, so it does not rely on the name being
    # exported or imported.
    Pigeons.variable_names(a, 1) == Pigeons.variable_names(b, 1) || return false
    return DynamicPPL.getall(a) == DynamicPPL.getall(b)
end


Base.:(==)(a::TuringLogPotential, b::TuringLogPotential) = Pigeons.recursive_equal(a, b)
# TODO: Fix type piracy
Base.:(==)(a::DynamicPPL.Model, b::DynamicPPL.Model) = Pigeons.recursive_equal(a, b)
Base.:(==)(a::DynamicPPL.ConditionContext, b::DynamicPPL.ConditionContext) = Pigeons.recursive_equal(a, b)
# These types are compared field by field through the generic helper.
function Pigeons.recursive_equal(
    a::Union{TuringLogPotential,DynamicPPL.Model,DynamicPPL.ConditionContext},
    b
)
    return Pigeons._recursive_equal(a, b)
end
1 change: 0 additions & 1 deletion src/Pigeons.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ import OnlineStats._fit!
import OnlineStats.value
import OnlineStats._merge!
import Random.rand!
import Base.(==)
import Base.keys
import Statistics.mean
import Statistics.var
Expand Down
6 changes: 3 additions & 3 deletions src/evidence/stepping_stone.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,18 @@ end
# Determine which chains to use for normalization constant estimation

# For one-leg: all chains
stepping_stone_keys(::PT, log_sum_ratios, ::NonReversiblePT) = keys(log_sum_ratios)
stepping_stone_keys(::PT, log_sum_ratios::GroupBy, ::NonReversiblePT) = keys(log_sum_ratios.value)

# use only the variational leg for 2-legs PT
# rationale: this should give lower error for a given amount of compute, since
# the KL should be lower between target and variational
function stepping_stone_keys(pt::PT, log_sum_ratios, ::StabilizedPT)
function stepping_stone_keys(pt::PT, log_sum_ratios::GroupBy, ::StabilizedPT)
# Note: we rely on the variational leg being in increasing order
# (the roles of 2 legs were swapped on 2023/07/20)
indexer = pt.shared.tempering.indexer
variational_indices = Set(variational_leg_indices(indexer))
result = Vector{Tuple{Int, Int}}()
for (i, j) in keys(log_sum_ratios)
for (i, j) in keys(log_sum_ratios.value)
if i in variational_indices && j in variational_indices
push!(result, (i, j))
end
Expand Down
107 changes: 62 additions & 45 deletions src/pt/checks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,10 @@ compare_checkpoints(checkpoint_folder1, checkpoint_folder2, immutables) =
end
end

function compare_serialized(file1, file2, immutables = nothing)
function compare_serialized(file1, file2)
first = deserialize(file1)
second = deserialize(file2)
if first != second
if !recursive_equal(first, second)
error(
"""
detected non-reproducibility, to investigate, type in the REPL:
Expand All @@ -96,20 +96,72 @@ function compare_serialized(file1, file2, immutables = nothing)
first = deserialize("$file1");
second = deserialize("$file2");
─────────────────────────────────
If you are using custom stuct, either mutable or containing
mutables, you may just need to add custom ==, see
If you are using a custom struct, either mutable or containing
mutables, you may just need to extend `recursive_equal`; see
src/pt/checks.jl.
"""
)
end
end

function Base.:(==)(a::GroupBy, b::GroupBy)

"""
$SIGNATURES

Return `true` when `a` and `b` are considered equal for the purpose of comparing
two checkpoints.

The fallback defined here simply calls `==`; more specific methods are dispatched
for certain types. Those are necessary because for some mutable structs (and even
immutable ones with mutable fields) `==` actually ends up calling `===`, which is
too strict when checking that two checkpoints are equal.

If you are using a custom struct and encounter a failed correctness check, you may
need to provide a dedicated equality check for this type. In most cases it will be
enough to overload `recursive_equal` as follows
```julia
Pigeons.recursive_equal(a::MyType, b::MyType) = Pigeons._recursive_equal(a,b)
```
For examples of more specific checks, refer to the code of `PigeonsBridgeStanExt`.
"""
function recursive_equal(a, b)
    return a == b
end

#=
Types compared with the default recursive helper `_recursive_equal`.
The list is not exhaustive: some types in Pigeons' extensions also call
`_recursive_equal`.
=#
const RecursiveEqualInnerType = Union{
    StanState,
    SplittableRandom,
    Replica,
    Augmentation,
    AutoMALA,
    SliceSampler,
    Compose,
    Mix,
    Iterators,
    Schedule,
    DEO,
    BlangTarget,
    NonReversiblePT,
    InterpolatingPath,
    InterpolatedLogPotential,
    RoundTripRecorder,
    OnlineStateRecorder,
    LocalBarrier,
    NamedTuple,
    Vector{<:InterpolatedLogPotential},
    Vector{<:Replica},
    Tuple,
    Inputs,
}

function recursive_equal(a::RecursiveEqualInnerType, b::RecursiveEqualInnerType)
    return _recursive_equal(a, b)
end
"""
    _recursive_equal(a, b, exclude = ())

Default recursive implementation used by `recursive_equal`.

When `a` and `b` have the same concrete type, every field whose name is not listed
in `exclude` is compared with `recursive_equal`; the first differing field is
printed and `false` is returned. Arrays of the same concrete type are compared
element by element. Arguments of different types are never equal.
"""
function _recursive_equal(a::T, b::T, exclude::NTuple{N,Symbol}=()) where {T,N}
    for f in fieldnames(T)
        if !(f in exclude || recursive_equal(getfield(a, f), getfield(b, f)))
            println("$f is different between a and b:\n\ta.f=$(getfield(a, f))\n\tb.f=$(getfield(b, f))")
            return false
        end
    end
    return true
end

# Arrays need their own method: `fieldnames` of an `Array` type is empty on
# Julia ≤ 1.10 and only lists internal storage fields on Julia ≥ 1.11, so the
# field-by-field loop above would never look at the elements. `exclude` is
# accepted for signature compatibility but has no meaning for arrays.
function _recursive_equal(
    a::T, b::T, exclude::NTuple{N,Symbol}=()
) where {T<:AbstractArray,N}
    axes(a) == axes(b) || return false
    return all(recursive_equal(x, y) for (x, y) in zip(a, b))
end

# generic case catches difference in types of a and b
_recursive_equal(a, b, exclude=nothing) = false

# `Shared` is compared field by field, leaving out its `reports` field
function recursive_equal(a::Shared, b::Shared)
    ignored_fields = (:reports,)
    return _recursive_equal(a, b, ignored_fields)
end

#=
leaf methods of recursive_equal: these do not need to be recursive but are still
needed in place of the default `==`.
=#
function recursive_equal(a::GroupBy, b::GroupBy)
# as of Jan 2023, OnlineStats uses a default method of
# descending into the fields, which is somehow not valid for GroupBy,
# probably due to nondeterminism of the underlying OrderedCollections.OrderedDict
common_keys = keys(a)
if common_keys != keys(b)
common_keys = keys(a.value)
if common_keys != keys(b.value)
return false
end
for key in common_keys
Expand All @@ -121,10 +173,7 @@ function Base.:(==)(a::GroupBy, b::GroupBy)
end

# CovMatrix contains a cache matrix, which is NaN until value(.) is called
Base.:(==)(a::CovMatrix, b::CovMatrix) = value(a) == value(b)

Base.keys(a::GroupBy) = keys(a.value)

# Compare two `CovMatrix` through `value(.)` rather than through their fields.
function recursive_equal(a::CovMatrix, b::CovMatrix)
    value_a = value(a)
    value_b = value(b)
    return value_a == value_b
end

#=
Since the state resides in different processes, there is no generic way to
Expand All @@ -135,37 +184,5 @@ But we still want to perform checks on the rest of the PT state
TODO: in the future, add an optional get_hash() in the Stream protocol
to improve this.
=#
Base.:(==)(a::StreamState, b::StreamState) = true
Base.:(==)(a::NonReproducible, b::NonReproducible) = true

# mutable (incl imm with mut fields) structs do not have a nice ===, overload those:
# TODO: This is type-piracy we need to fix this
Base.:(==)(a::StanState, b::StanState) = recursive_equal(a, b)
Base.:(==)(a::SplittableRandom, b::SplittableRandom) = recursive_equal(a, b)
Base.:(==)(a::Replica, b::Replica) = recursive_equal(a, b)
Base.:(==)(a::Augmentation, b::Augmentation) = recursive_equal(a, b)
Base.:(==)(a::AutoMALA, b::AutoMALA) = recursive_equal(a, b)
Base.:(==)(a::SliceSampler, b::SliceSampler) = recursive_equal(a, b)
Base.:(==)(a::Compose, b::Compose) = recursive_equal(a, b)
Base.:(==)(a::Mix, b::Mix) = recursive_equal(a, b)
Base.:(==)(a::Iterators, b::Iterators) = recursive_equal(a, b)
Base.:(==)(a::Schedule, b::Schedule) = recursive_equal(a, b)
Base.:(==)(a::DEO, b::DEO) = recursive_equal(a, b)
Base.:(==)(a::Shared, b::Shared) = recursive_equal(a, b, [:reports])
Base.:(==)(a::BlangTarget, b::BlangTarget) = recursive_equal(a, b)
Base.:(==)(a::NonReversiblePT, b::NonReversiblePT) = recursive_equal(a, b)
Base.:(==)(a::InterpolatingPath, b::InterpolatingPath) = recursive_equal(a, b)
Base.:(==)(a::InterpolatedLogPotential, b::InterpolatedLogPotential) = recursive_equal(a, b)
Base.:(==)(a::RoundTripRecorder, b::RoundTripRecorder) = recursive_equal(a, b)
Base.:(==)(a::OnlineStateRecorder, b::OnlineStateRecorder) = recursive_equal(a, b)
Base.:(==)(a::LocalBarrier, b::LocalBarrier) = recursive_equal(a, b)


function recursive_equal(a::T, b::T, exclude = []) where {T}
for f in fieldnames(T)
if !(f in exclude) && (getfield(a, f) != getfield(b, f))
return false
end
end
return true
end
# Any two `StreamState` are reported as equal.
function recursive_equal(a::StreamState, b::StreamState)
    return true
end

# Likewise, any two `NonReproducible` are reported as equal.
function recursive_equal(a::NonReproducible, b::NonReproducible)
    return true
end
6 changes: 3 additions & 3 deletions test/test_checkpoint.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
function compare_pts(p1, p2)
@test p1.replicas == p2.replicas
@test p1.shared == p2.shared
@test p1.reduced_recorders == p2.reduced_recorders
@test Pigeons.recursive_equal(p1.replicas, p2.replicas)
@test Pigeons.recursive_equal(p1.shared, p2.shared)
@test Pigeons.recursive_equal(p1.reduced_recorders, p2.reduced_recorders)
end

@testset "Checkpoints" begin
Expand Down
2 changes: 1 addition & 1 deletion test/test_lazy_target.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,5 @@ include("supporting/lazy.jl")
pt1 = load(r)
pt2 = pigeons(target = toy_mvn_target(1))

@test pt1.replicas == pt2.replicas
@test Pigeons.recursive_equal(pt1.replicas, pt2.replicas)
end
Loading

0 comments on commit 225ffbc

Please sign in to comment.