Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Particle Gibbs with ancestor resampling #32

Merged
merged 21 commits into from
Mar 29, 2022

Conversation

FredericWantiez
Copy link
Member

@FredericWantiez FredericWantiez commented Nov 3, 2021

Adding ancestor resampling for state space models, this works I think but it's mostly to discuss the API for the model/sampler. In this I assume something like:

mutable struct NonLinearTimeSeries <: AdvancedPS.AbstractStateSpaceModel
    X::Vector{Float64}
    θ::θ
    T::Int32
    NonLinearTimeSeries::θ) = new(Vector{Float64}(), θ, 1)
    NonLinearTimeSeries() = new(Vector{Float64}(), θ₀, 1)
end

function AdvancedPS.initialization(m::NonLinearTimeSeries)
    return Normal(0, m.θ.σ^2)
end

function AdvancedPS.transition(m::NonLinearTimeSeries, state)
    return Normal(m.θ.ν + m.θ.α * state, m.θ.σ^2)
end

function AdvancedPS.observation(m::NonLinearTimeSeries, state)
    return m.T > length(y) ? nothing : logpdf(Normal(0, exp(state)), y[m.T])
end

Edit
The model API changed slightly to make the step and state explicit in the initialization/observation functions (needed for the ancestor sampling step)

mutable struct NonLinearTimeSeries <: AdvancedPS.AbstractStateSpaceModel
    X::Vector{Float64}
    θ::θ
    NonLinearTimeSeries::θ) = new(Vector{Float64}(), θ, 1)
    NonLinearTimeSeries() = new(Vector{Float64}(), θ₀, 1)
end

AdvancedPS.initialization(m::NonLinearTimeSeries) = Normal(0, m.θ.σ^2)
AdvancedPS.transition(m::NonLinearTimeSeries, state, step) = Normal(m.θ.ν + m.θ.α * state, m.θ.σ^2)
AdvancedPS.observation(m::NonLinearTimeSeries, state, step) = logpdf(Normal(0, exp(state)), y[step])
AdvancedPS.isdone(m::NonLinearTimeSeries, step) = step > 3 # Stop the state machine

@FredericWantiez FredericWantiez changed the title PG AS - rough draft [WIP] PG AS - rough draft Nov 3, 2021
@coveralls
Copy link

coveralls commented Nov 3, 2021

Pull Request Test Coverage Report for Build 2024648543

Warning: This coverage report may be inaccurate.

This pull request's base commit is no longer the HEAD commit of its target branch. This means it includes changes from outside the original pull request, including, potentially, unrelated coverage changes.

Details

  • 114 of 127 (89.76%) changed or added relevant lines in 6 files are covered.
  • 7 unchanged lines in 2 files lost coverage.
  • Overall coverage increased (+3.2%) to 64.471%

Changes Missing Coverage Covered Lines Changed/Added Lines %
src/model.jl 29 31 93.55%
src/pgas.jl 50 52 96.15%
src/smc.jl 19 28 67.86%
Files with Coverage Reduction New Missed Lines %
src/rng.jl 3 88.89%
src/smc.jl 4 72.55%
Totals Coverage Status
Change from base Build 1959576624: 3.2%
Covered Lines: 323
Relevant Lines: 501

💛 - Coveralls

@codecov
Copy link

codecov bot commented Nov 3, 2021

Codecov Report

Merging #32 (8ce9101) into master (cd05f0f) will increase coverage by 3.42%.
The diff coverage is 89.78%.

@@            Coverage Diff             @@
##           master      #32      +/-   ##
==========================================
+ Coverage   61.26%   64.69%   +3.42%     
==========================================
  Files           6        7       +1     
  Lines         426      507      +81     
==========================================
+ Hits          261      328      +67     
- Misses        165      179      +14     
Impacted Files Coverage Δ
src/smc.jl 72.54% <67.85%> (-24.75%) ⬇️
src/model.jl 91.89% <91.42%> (-8.11%) ⬇️
src/pgas.jl 96.29% <96.29%> (ø)
src/container.jl 96.77% <100.00%> (+2.07%) ⬆️
src/resampling.jl 96.66% <100.00%> (+0.11%) ⬆️
src/rng.jl 90.00% <100.00%> (+1.11%) ⬆️

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update cd05f0f...8ce9101. Read the comment docs.

src/model.jl Outdated Show resolved Hide resolved
src/model.jl Outdated Show resolved Hide resolved
src/pgas.jl Outdated Show resolved Hide resolved
src/pgas.jl Outdated Show resolved Hide resolved
src/pgas.jl Outdated Show resolved Hide resolved
src/pgas.jl Outdated
end

function fork(trace::SSMTrace, isref::Bool)
model = deepcopy(trace.f)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume this might become memory intensive?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this is the main issue with this implementation, we keep copying state vectors around which will probably be quite heavy for long time series

src/pgas.jl Outdated Show resolved Hide resolved
src/pgas.jl Outdated Show resolved Hide resolved
src/pgas.jl Outdated Show resolved Hide resolved
@yebai yebai mentioned this pull request Dec 9, 2021
@yebai
Copy link
Member

yebai commented Feb 8, 2022

@FredericWantiez, if you can push your local changes, I’ll take a look later this week.

@yebai yebai mentioned this pull request Feb 24, 2022
@yebai yebai changed the title [WIP] PG AS - rough draft PG AS - rough draft Mar 4, 2022
Copy link
Member

@yebai yebai left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Many thanks, @FredericWantiez - very good work. I left some minor comments below.

examples/gaussian-ssm/script.jl Outdated Show resolved Hide resolved
examples/gaussian-ssm/script.jl Outdated Show resolved Hide resolved
examples/gaussian-ssm/script.jl Outdated Show resolved Hide resolved
examples/gaussian-ssm/script.jl Outdated Show resolved Hide resolved
examples/gaussian-ssm/script.jl Outdated Show resolved Hide resolved
src/model.jl Outdated Show resolved Hide resolved
src/pgas.jl Outdated Show resolved Hide resolved
src/pgas.jl Outdated Show resolved Hide resolved
src/pgas.jl Show resolved Hide resolved
src/smc.jl Outdated Show resolved Hide resolved
@yebai yebai changed the title PG AS - rough draft Particle Gibbs with ancestor resampling Mar 21, 2022
src/model.jl Outdated Show resolved Hide resolved
src/model.jl Outdated Show resolved Hide resolved
src/pgas.jl Outdated Show resolved Hide resolved
src/model.jl Outdated Show resolved Hide resolved
src/model.jl Outdated Show resolved Hide resolved
src/model.jl Outdated Show resolved Hide resolved
src/container.jl Outdated Show resolved Hide resolved
src/container.jl Outdated Show resolved Hide resolved
src/rng.jl Outdated Show resolved Hide resolved
src/rng.jl Outdated Show resolved Hide resolved
AdvancedPS.initialization(model::NonLinearTimeSeries) = f₀(model)
AdvancedPS.transition(model::NonLinearTimeSeries, state, step) = f(model, state, step)
function AdvancedPS.observation(model::NonLinearTimeSeries, state, step)
return logpdf(g(model, state, step), y[step])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a typo here? It does not seem y is in the arguments.

Suggested change
return logpdf(g(model, state, step), y[step])
return logpdf(g(model, state, step), y[step])

src/pgas.jl Show resolved Hide resolved
@yebai
Copy link
Member

yebai commented Mar 28, 2022

Many thanks, @FredericWantiez - excellent work! Only a few minor comment/clarification questions remain. Once fixed, we should be ready to go.

@yebai yebai merged commit f0baacd into TuringLang:master Mar 29, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants