Skip to content

Commit

Permalink
Replace remaining instances of @submodel
Browse files Browse the repository at this point in the history
  • Loading branch information
penelopeysm committed Dec 13, 2024
1 parent 0c266a6 commit bc92248
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 22 deletions.
2 changes: 1 addition & 1 deletion src/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ adds the `Prefix` to all parameters.
This context is useful in nested models to ensure that the names of the parameters are
unique.
See also: [`@submodel`](@ref)
See also: [`to_submodel`](@ref)
"""
struct PrefixContext{Prefix,C} <: AbstractContext
context::C
Expand Down
7 changes: 5 additions & 2 deletions src/test_utils/models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,8 @@ end

@model function demo_assume_submodel_observe_index_literal()
# Submodel prior
@submodel s, m = _prior_dot_assume()
priors ~ to_submodel(_prior_dot_assume(), false)
s, m = priors
1.5 ~ Normal(m[1], sqrt(s[1]))
2.0 ~ Normal(m[2], sqrt(s[2]))

Expand Down Expand Up @@ -475,7 +476,9 @@ end
m .~ Normal.(0, sqrt.(s))

# Submodel likelihood
@submodel _likelihood_mltivariate_observe(s, m, x)
# With to_submodel, we have to have a left-hand side variable to
# capture the result, so we just use a dummy variable
_ignore ~ to_submodel(_likelihood_mltivariate_observe(s, m, x))

return (; s=s, m=m, x=x, logp=getlogp(__varinfo__))
end
Expand Down
29 changes: 10 additions & 19 deletions test/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ module Issue537 end
return x ~ Normal()
end
@model function demo2(x, y)
@submodel demo1(x)
_ignore ~ to_submodel(demo1(x), false)
return y ~ Uniform()
end
# No observation.
Expand Down Expand Up @@ -441,7 +441,7 @@ module Issue537 end

# Check values makes sense.
@model function demo3(x, y)
@submodel demo1(x)
_ignore ~ to_submodel(demo1(x), false)
return y ~ Normal(x)
end
m = demo3(1000.0, missing)
Expand All @@ -453,12 +453,10 @@ module Issue537 end
x ~ Normal()
return x
end

@model function demo_useval(x, y)
@submodel prefix = "sub1" x1 = demo_return(x)
@submodel prefix = "sub2" x2 = demo_return(y)

return z ~ Normal(x1 + x2 + 100, 1.0)
sub1 ~ to_submodel(demo_return(x))
sub2 ~ to_submodel(demo_return(y))
return z ~ Normal(sub1 + sub2 + 100, 1.0)
end
m = demo_useval(missing, missing)
vi = VarInfo(m)
Expand All @@ -472,21 +470,18 @@ module Issue537 end
@model function AR1(num_steps, α, μ, σ, ::Type{TV}=Vector{Float64}) where {TV}
η ~ MvNormal(zeros(num_steps), I)
δ = sqrt(1 - α^2)

x = TV(undef, num_steps)
x[1] = η[1]
@inbounds for t in 2:num_steps
x[t] = @. α * x[t - 1] + δ * η[t]
end

return @. μ + σ * x
end

@model function demo(y)
α ~ Uniform()
μ ~ Normal()
σ ~ truncated(Normal(), 0, Inf)

num_steps = length(y[1])
num_obs = length(y)
@inbounds for i in 1:num_obs
Expand Down Expand Up @@ -613,14 +608,11 @@ module Issue537 end
@model demo() = x ~ Normal()
retval, svi = DynamicPPL.evaluate!!(demo(), SimpleVarInfo(), SamplingContext())

# Return-value when using `@submodel`
# Return-value when using `to_submodel`
@model inner() = x ~ Normal()
# Without assignment.
@model outer() = @submodel inner()
@test outer()() isa Real

# With assignment.
@model outer() = @submodel x = inner()
@model function outer()
_ignore ~ to_submodel(inner())
end
@test outer()() isa Real

# Edge-cases.
Expand Down Expand Up @@ -720,8 +712,7 @@ module Issue537 end
return (; x, y)
end
@model function demo_tracked_submodel()
@submodel (x, y) = demo_tracked()
return (; x, y)
vals ~ to_submodel(demo_tracked(), false)
end
for model in [demo_tracked(), demo_tracked_submodel()]
# Make sure it's runnable and `y` is present in the return-value.
Expand Down

0 comments on commit bc92248

Please sign in to comment.