Skip to content

Commit

Permalink
use fix instead of condition
Browse files Browse the repository at this point in the history
  • Loading branch information
torfjelde committed Nov 20, 2023
1 parent 0578c35 commit f87e2d1
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
10 changes: 5 additions & 5 deletions src/mcmc/gibbs_new.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,16 @@ function preferred_value_type(varinfo::DynamicPPL.TypedVarInfo)
return namedtuple_compatible ? NamedTuple : OrderedDict
end

function DynamicPPL.condition(context::DynamicPPL.AbstractContext, varinfo::AbstractVarInfo)
function DynamicPPL.fix(context::DynamicPPL.AbstractContext, varinfo::AbstractVarInfo)
# TODO: Determine when it's okay to use `NamedTuple` and use that instead.
return DynamicPPL.condition(context, DynamicPPL.values_as(varinfo, preferred_value_type(varinfo)))
return DynamicPPL.fix(context, DynamicPPL.values_as(varinfo, preferred_value_type(varinfo)))
end
function DynamicPPL.condition(
function DynamicPPL.fix(
context::DynamicPPL.AbstractContext,
varinfo::AbstractVarInfo,
varinfos::AbstractVarInfo...
)
return DynamicPPL.condition(DynamicPPL.condition(context, varinfo), varinfos...)
return DynamicPPL.fix(DynamicPPL.fix(context, varinfo), varinfos...)
end


Expand Down Expand Up @@ -52,7 +52,7 @@ function make_conditional(model::Model, target_varinfo::AbstractVarInfo, varinfo
# TODO: Check if this is known at compile-time if `varinfos isa Tuple`.
# FIXME: Revert commit 53bd7072 and use `gibbs_condition` as soon as
# https://github.com/TuringLang/DynamicPPL.jl/pull/563 is merged.
return condition(
return fix(
model,
filter(Base.Fix1(!==, target_varinfo), varinfos)...
)
Expand Down
2 changes: 1 addition & 1 deletion test/mcmc/gibbs_new.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ end
# `GibbsV2` does not work with SMC samplers, e.g. `CSMC`.
# FIXME: Oooor it is (see tests below). Uncertain.
Random.seed!(100)
alg = GibbsV2(@varname(s) => CSMC(15), @varname(m) => ESS(:m))
alg = GibbsV2(@varname(s) => CSMC(15), @varname(m) => ESS())
chain = sample(gdemo(1.5, 2.0), alg, 10_000)
check_gdemo(chain)
end
Expand Down

0 comments on commit f87e2d1

Please sign in to comment.