Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Primitive implementation for serialization #258

Merged
merged 6 commits into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "JuliaBUGS"
uuid = "ba9fb4c0-828e-4473-b6a1-cd2560fee5bf"
version = "0.7.5"
version = "0.8.0"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand All @@ -22,6 +22,7 @@ MetaGraphsNext = "fa8bd995-216d-47f1-8a91-f3b68fbeb377"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Expand Down Expand Up @@ -67,6 +68,7 @@ MacroTools = "0.5"
MetaGraphsNext = "0.6, 0.7"
OrderedCollections = "1"
PDMats = "0.10, 0.11"
Serialization = "1.9.0"
SpecialFunctions = "2"
StaticArrays = "1.9"
Statistics = "1.9"
Expand Down
3 changes: 2 additions & 1 deletion src/JuliaBUGS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ using LogDensityProblems, LogDensityProblemsAD
using MacroTools
using OrderedCollections: OrderedDict
using Random
using Serialization: Serialization
using StaticArrays

import Base: ==, hash, Symbol, size
Expand Down Expand Up @@ -172,7 +173,7 @@ function compile(model_def::Expr, data::NamedTuple, initial_params::NamedTuple=N
values(eval_env),
),
)
return BUGSModel(g, nonmissing_eval_env, initial_params)
return BUGSModel(g, nonmissing_eval_env, model_def, data, initial_params)
end

"""
Expand Down
37 changes: 34 additions & 3 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,9 @@ end
The `BUGSModel` object is used for inference and represents the output of compilation. It implements the
[`LogDensityProblems.jl`](https://github.com/tpapp/LogDensityProblems.jl) interface.
"""
struct BUGSModel{base_model_T<:Union{<:AbstractBUGSModel,Nothing},T<:NamedTuple,TNF,TV} <:
AbstractBUGSModel
struct BUGSModel{
base_model_T<:Union{<:AbstractBUGSModel,Nothing},T<:NamedTuple,TNF,TV,data_T
} <: AbstractBUGSModel
" Indicates whether the model parameters are in the transformed space. "
transformed::Bool

Expand All @@ -74,6 +75,10 @@ struct BUGSModel{base_model_T<:Union{<:AbstractBUGSModel,Nothing},T<:NamedTuple,

"If not `Nothing`, the model is a conditioned model; otherwise, it's the model returned by `compile`."
base_model::base_model_T

# for serialization, save the original model definition and data
model_def::Expr
data::data_T
end

function Base.show(io::IO, model::BUGSModel)
Expand Down Expand Up @@ -137,7 +142,9 @@ variables(model::BUGSModel) = collect(labels(model.g))
function BUGSModel(
g::BUGSGraph,
evaluation_env::NamedTuple,
initial_params::NamedTuple=NamedTuple();
model_def::Expr,
data::NamedTuple,
initial_params::NamedTuple=NamedTuple(),
is_transformed::Bool=true,
)
flattened_graph_node_data = FlattenedGraphNodeData(g)
Expand Down Expand Up @@ -199,6 +206,8 @@ function BUGSModel(
flattened_graph_node_data,
g,
nothing,
model_def,
data,
)
end

Expand All @@ -220,9 +229,31 @@ function BUGSModel(
FlattenedGraphNodeData(g, sorted_nodes),
g,
isnothing(model.base_model) ? model : model.base_model,
model.model_def,
model.data,
)
end

function Serialization.serialize(s::Serialization.AbstractSerializer, model::BUGSModel)
Serialization.writetag(s.io, Serialization.OBJECT_TAG)
Serialization.serialize(s, typeof(model))
Serialization.serialize(s, model.transformed)
Serialization.serialize(s, model.model_def)
Serialization.serialize(s, model.data)
Serialization.serialize(s, model.evaluation_env)
return nothing
end

function Serialization.deserialize(s::Serialization.AbstractSerializer, ::Type{<:BUGSModel})
model_def = Serialization.deserialize(s)
data = Serialization.deserialize(s)
evaluation_env = Serialization.deserialize(s)
transformed = Serialization.deserialize(s)
# use evaluation_env as initialization to restore the values
model = compile(model_def, data, evaluation_env)
return settrans(model, transformed)
end

"""
initialize!(model::BUGSModel, initial_params::NamedTuple)

Expand Down
29 changes: 29 additions & 0 deletions test/model.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,32 @@
@testset "serialization" begin
(; model_def, data) = JuliaBUGS.BUGSExamples.rats
model = compile(model_def, data)
serialize("m.jls", model)
deserialized = deserialize("m.jls")
@testset "test values are correctly restored" begin
for vn in MetaGraphsNext.labels(model.g)
@test isequal(
get(model.evaluation_env, vn), get(deserialized.evaluation_env, vn)
)
end

@test model.transformed == deserialized.transformed
@test model.untransformed_param_length == deserialized.untransformed_param_length
@test model.transformed_param_length == deserialized.transformed_param_length
@test all(
model.untransformed_var_lengths[k] == deserialized.untransformed_var_lengths[k]
for k in keys(model.untransformed_var_lengths)
)
@test all(
model.transformed_var_lengths[k] == deserialized.transformed_var_lengths[k] for
k in keys(model.transformed_var_lengths)
)
@test Set(model.parameters) == Set(deserialized.parameters)
# skip testing g
@test model.model_def === deserialized.model_def
end
penelopeysm marked this conversation as resolved.
Show resolved Hide resolved
end

@testset "controlling sampling behavior for conditioned variables" begin
model_def = @bugs begin
x ~ Normal(0, 1)
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ using MacroTools
using MCMCChains
using Random
using ReverseDiff
using Serialization

AbstractMCMC.setprogress!(false)

Expand Down
Loading