setval! is being weird #167

Closed
torfjelde opened this issue Sep 27, 2020 · 15 comments

@torfjelde
Member

torfjelde commented Sep 27, 2020

So I'm currently making a PR for generated_quantities, and ran into the following:

julia> using DynamicPPL, Turing

julia> Turing.turnprogress(false);
[ Info: [Turing]: progress logging is disabled globally
[ Info: [AdvancedVI]: global PROGRESS is set as false

julia> @model function demo_fails(xs, ::Type{TV} = Vector{Float64}) where {TV}
           m = TV(undef, 2)
           for i in 1:2
               m[i] ~ Normal(0, 1)
           end

           for i in eachindex(xs)
               xs[i] ~ Normal(m[1], 1.)
           end

           return (m, )
       end;

julia> xs = randn(3);

julia> model_fails = demo_fails(xs);

julia> chain_fails = sample(model_fails, NUTS(0.65), 100);
┌ Info: Found initial step size
└   ϵ = 1.6

julia> var_info = VarInfo(model_fails);

julia> spl = SampleFromPrior();

julia> θ_0 = var_info[spl]
2-element Array{Float64,1}:
 -0.34786841389137324
  0.4703320351984347

julia> res_0 = model_fails(var_info, spl)
([-0.34786841389137324, 0.4703320351984347],)

julia> DynamicPPL.setval!(var_info, chain_fails, 1, 1);

julia> θ_1 = var_info[spl] # <= note that the value has NOT changed since `θ_0`
2-element Array{Float64,1}:
 -0.34786841389137324
  0.4703320351984347

julia> res_1 = model_fails(var_info, spl) # <= has NOT changed since `res_0`!!!
([-0.34786841389137324, 0.4703320351984347],)
 ...

In contrast, the following works just fine:

julia> @model function demo_works(xs)
           m ~ MvNormal(2, 1.)
           for i in eachindex(xs)
               xs[i] ~ Normal(m[1], 1.)
           end

           return (m, )
       end;

julia> model_works = demo_works(xs);

julia> chain_works = sample(model_works, NUTS(0.65), 100);
┌ Info: Found initial step size
└   ϵ = 1.6

julia> var_info = VarInfo(model_works);

julia> spl = SampleFromPrior();

julia> θ_0 = var_info[spl]
2-element Array{Float64,1}:
 -0.38315867392034214
  1.2157641175253535

julia> res_0 = model_works(var_info, spl)
([-0.38315867392034214, 1.2157641175253535],)

julia> DynamicPPL.setval!(var_info, chain_works, 1, 1);

julia> θ_1 = var_info[spl] # <= note that the value has changed since `θ_0`
2-element Array{Float64,1}:
 -0.5500762603588643
 -0.8341982280432694

julia> res_1 = model_works(var_info, spl) # <= has indeed changed since `res_0`
([-0.5500762603588643, -0.8341982280432694],)
 ...

So in the first example, setval! fails: the values from the chain are never written into var_info.

@torfjelde
Member Author

Seems like this is the offending line: https://github.com/TuringLang/DynamicPPL.jl/blob/master/src/varinfo.jl#L1183..L1183

When run on the failing example, indices is empty.

@torfjelde
Member Author

Uhhh:

julia> regex = r"^m[1]$|^m[1]\["
r"^m[1]$|^m[1]\["

julia> match(regex, "m[1]") === nothing
true

@torfjelde
Member Author

Crap, of course:

julia> regex = r"m\[1\]$|m\[1\]\["
r"m\[1\]$|m\[1\]\["

julia> match(regex, "m[1]")
RegexMatch("m[1]")

So we need to escape the names!
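
For anyone reading along, a minimal sketch of why the unescaped pattern misses: inside a regex, `[1]` is a character class that matches the single character `1`, so the pattern is really looking for `m1`. The sample names below are made up for illustration and are not taken from the chain.

```julia
# Unescaped: `[1]` is a character class, so this pattern matches "m1", not "m[1]".
unescaped = r"^m[1]$|^m[1]\["

# Escaped: `\[` and `\]` match the literal brackets of an indexed name.
escaped = r"^m\[1\]$|^m\[1\]\["

# Illustrative names only (assumed, not real chain output).
for name in ["m[1]", "m[1][2]", "m1", "m[10]"]
    println(rpad(name, 8),
            " unescaped: ", occursin(unescaped, name),
            "  escaped: ", occursin(escaped, name))
end
```

Only the escaped pattern accepts `m[1]` and `m[1][2]`, and it still rejects `m[10]`.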

@devmotion
Member

You have to escape [ and ] by writing

julia> regex = r"^m\[1\]$|^m\[1\]\["

@torfjelde
Member Author

Lol @devmotion , you didn't get that until NOW? #5secondsfaster

@devmotion
Member

Could we maybe use subsumes instead of reimplementing this logic with regex matching?

@torfjelde
Member Author

We don't have VarName for keys though, right?

@devmotion
Member

No, I realized that it's probably not useful, exactly due to this issue.

@torfjelde
Member Author

Immediate solution:

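# Escape regex metacharacters in `s` so the result matches `s` literally.
function regex_escape(s::AbstractString)
    # Prefix each special character (and any whitespace) with a backslash.
    res = replace(s, r"([()[\]{}?*+\-|^\$\\.&~#\s=!<>|:])" => s"\\\1")
    # Write a NUL character as the escape sequence \0.
    return replace(res, "\0" => "\\0")
end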

from JuliaLang/julia#29643

But you're right though, Regex isn't the nicest way to deal with this stuff...
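
A rough sketch of how the escaping helper could be combined with the pattern from earlier in the thread. `name_pattern` is a hypothetical name used only for illustration; it is not claimed to be what `setval!` does internally. `regex_escape` is repeated from the comment above so the snippet runs on its own.

```julia
# Copied from the comment above so that this snippet is self-contained.
function regex_escape(s::AbstractString)
    res = replace(s, r"([()[\]{}?*+\-|^\$\\.&~#\s=!<>|:])" => s"\\\1")
    return replace(res, "\0" => "\\0")
end

# Hypothetical helper: match `name` exactly, or `name` followed by further indexing.
function name_pattern(name::AbstractString)
    escaped = regex_escape(name)
    return Regex("^" * escaped * "\$|^" * escaped * "\\[")
end

re = name_pattern("m[1]")
println(re.pattern)              # inspect the generated pattern
println(occursin(re, "m[1]"))    # exact name
println(occursin(re, "m[1][2]")) # nested indexing
println(occursin(re, "m[10]"))   # a different index
```

The escaping has to happen before the anchors and the trailing `\[` are added, otherwise those would be escaped as well.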

@torfjelde
Member Author

torfjelde commented Sep 27, 2020

But it does fix the issue:

julia> using DynamicPPL, Turing
[ Info: Precompiling DynamicPPL [366bfd00-2699-11ea-058f-f148b4cae6d8]
[ Info: Precompiling Turing [fce5fe82-541a-59a6-adf8-730c64b5f9a0]

julia> Turing.turnprogress(false)
[ Info: [Turing]: progress logging is disabled globally
[ Info: [AdvancedVI]: global PROGRESS is set as false
false

julia> @model function demo(xs, ::Type{TV} = Vector{Float64}) where {TV}
           m = TV(undef, 2)
           for i in 1:2
               m[i] ~ Normal(0, 1)
           end

           for i in eachindex(xs)
               xs[i] ~ Normal(m[1], 1.)
           end

           return (m, )
       end
demo (generic function with 2 methods)

julia> xs = randn(3);

julia> model = demo(xs);

julia> chain = sample(model, NUTS(0.65), 100);
┌ Info: Found initial step size
└   ϵ = 1.6

julia> generated_quantities(model, chain)
50×1 Array{Tuple{Array{Float64,1}},2}:
 ([-0.6466888684208906, -1.0605924818561454],)
 ([-0.6466888684208906, -1.0605924818561454],)
 ([-0.21895508638907632, 1.6653331143718089],)
 ([-0.21895508638907632, 1.6653331143718089],)
 ([-0.2997309317443484, -0.029323999401321466],)
 ([-1.0387336145893835, 0.7184211285862931],)
 ([-0.21765620075312997, 0.01436077650117068],)
 ([0.529802098448724, -0.2980663360869661],)
 ([0.11811237219422727, -0.5270924388346214],)
 ([-1.0033004645889954, 1.0369564237981046],)
 ([-0.6387891971861386, 0.6617567920524946],)
 ([-0.010075207468817182, 0.38321779739786466],)
 ...

@torfjelde
Member Author

Should I just make a PR with this solution? Seems like this is something that ought to be fixed asap...

@torfjelde
Member Author

Will also add some tests, so this doesn't happen in the future

@devmotion
Member

Maybe it would be safer to use

string_vn = string(vn)
string_vn_indexing = string_vn * "["
findall(keys) do x
    string_x = string(x)
    return string_x == string_vn || startswith(string_x, string_vn_indexing)
end
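
To try the snippet above in isolation, it can be wrapped in a function, as in the sketch below. `matching_indices` and the sample names are assumptions made for illustration; in the real code `keys` and `vn` come from the chain and the `VarInfo`.

```julia
# Hypothetical wrapper around the string comparison suggested above.
# `ks` is any collection of names; `vn` is the variable name to look up.
function matching_indices(ks, vn)
    string_vn = string(vn)
    string_vn_indexing = string_vn * "["
    return findall(ks) do x
        string_x = string(x)
        # Either the exact name, or the name followed by further indexing.
        string_x == string_vn || startswith(string_x, string_vn_indexing)
    end
end

# Made-up names for illustration.
names_in_chain = ["m[1]", "m[2]", "m[1][3]", "m[10]", "lp"]

println(matching_indices(names_in_chain, "m[1]"))       # indices of "m[1]" and "m[1][3]"
println(matching_indices(["m", "mu", "m[1]"], "m"))     # "mu" is not picked up
```

Since no regex is involved, brackets and other special characters in names need no escaping at all.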

@torfjelde
Member Author

Nice, that indeed seems safer 👍

bors bot pushed a commit that referenced this issue Sep 28, 2020
This PR adds `generated_quantities` as discussed in TuringLang/Turing.jl#1335 + adds a fix for #167.
@devmotion
Member

Fixed by #168.
