Issues with constrained parameters depending on each other #2195

Closed

torfjelde opened this issue Apr 18, 2024 · 8 comments

@torfjelde (Member) commented Apr 18, 2024

Problem

julia> using Turing

julia> @model function buggy_model()
           lb ~ Uniform(0, 0.1)
           ub ~ Uniform(0.11, 0.2)
           x ~ transformed(Normal(0, 1), inverse(Bijectors.Logit(lb, ub)))
       end

buggy_model (generic function with 2 methods)

julia> model = buggy_model();

julia> chain = sample(model, NUTS(), 1000);
┌ Info: Found initial step size
└   ϵ = 3.2
Sampling 100%|█████████████████████████████████████████████████████████████████████████████| Time: 0:00:01

julia> results = generated_quantities(model, chain); # (×) Breaks!
ERROR: DomainError with -0.05206647177072762:
log was called with a negative real argument but will only return a complex result if called with a complex argument. Try log(Complex(x)).
DomainError detected in the user `f` function. This occurs when the domain of a function is violated.
For example, `log(-1.0)` is undefined because `log` of a real number is defined to only output real
numbers, but `log` of a negative number is complex valued and therefore Julia throws a DomainError
by default. Cases to be aware of include:

* `log(x)`, `sqrt(x)`, `cbrt(x)`, etc. where `x<0`
* `x^y` for `x<0` floating point `y` (example: `(-1.0)^(1/2) == im`)
...

In contrast, if we use Prior to sample, we're good:

julia> chain_prior = sample(model, Prior(), 1000);

Sampling 100%|█████████████████████████████████████████████████████████████████████████████| Time: 0:00:00

julia> results_prior = generated_quantities(model, chain_prior); # (✓) Works because no linking needed

The issue is caused by the fact that we use `DynamicPPL.invlink!!(varinfo, model)` when constructing a transition (which is what ends up in the chain), rather than by the inference itself.
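To see concretely why invlinking with cached transforms breaks, here is a minimal plain-Julia sketch (no packages; `logit`/`invlogit` are hand-rolled stand-ins for `Bijectors.Logit` and its inverse, and the specific numbers are illustrative): a value mapped back through bounds cached from an earlier realization can fall outside the support implied by the newly accepted `lb`.

```julia
# Plain-Julia sketch of the failure mode: the transform for `x` depends on
# (lb, ub), but the cached bijector was built from an *earlier* realization.

logit(x, a, b)    = log((x - a) / (b - x))       # constrained -> unconstrained
invlogit(y, a, b) = a + (b - a) / (1 + exp(-y))  # unconstrained -> constrained

y = 2.0  # unconstrained value for `x`

# Bounds that were cached when the bijector was constructed:
x_stale = invlogit(y, 0.0, 0.1)   # lies inside (0.0, 0.1)

# Bounds from the realization the sampler actually accepted:
lb_new, ub_new = 0.09, 0.2

# The stale `x` now lies *below* the accepted lower bound, so re-linking
# it (or computing a logpdf) hits `log` of a negative argument:
x_stale < lb_new  # true
# logit(x_stale, lb_new, ub_new)  # would throw a DomainError
```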

For example, if we use AdvancedHMC.jl directly:

julia> using AdvancedHMC: AdvancedHMC

julia> f = DynamicPPL.LogDensityFunction(model);

julia> DynamicPPL.link!!(f.varinfo, f.model);

julia> chain_ahmc = sample(f, AdvancedHMC.NUTS(0.8), 1000);
[ Info: Found initial step size 3.2
Sampling 100%|███████████████████████████████| Time: 0:00:00
  iterations:                                   1000
  ratio_divergent_transitions:                  0.0
  ratio_divergent_transitions_during_adaption:  0.0
  n_steps:                                      7
  is_accept:                                    true
  acceptance_rate:                              0.7879658455930968
  log_density:                                  -5.038135476673508
  hamiltonian_energy:                           7.775565727543868
  hamiltonian_energy_error:                     -0.11294798909710124
  max_hamiltonian_energy_error:                 0.5539216379943772
  tree_depth:                                   3
  numerical_error:                              false
  step_size:                                    1.1685229504528063
  nom_step_size:                                1.1685229504528063
  is_adapt:                                     false
  mass_matrix:                                  DiagEuclideanMetric([1.0, 1.0, 1.0])

julia> function to_constrained(θ)
           lb = inverse(Bijectors.Logit(0.0, 0.1))(θ[1])
           ub = inverse(Bijectors.Logit(0.11, 0.2))(θ[2])
           x = inverse(Bijectors.Logit(lb, ub))(θ[3])
           return [lb, ub, x]
       end
to_constrained (generic function with 1 method)

julia> chain_ahmc_constrained = mapreduce(hcat, chain_ahmc) do t
           to_constrained(t.z.θ)
       end;

julia> chain_ahmc = Chains(
           permutedims(chain_ahmc_constrained),
           [:lb, :ub, :x]
       );

Visualizing the densities of the resulting chains, we also see that the one from Turing.NUTS is incorrect (the blue line), while the other two (Prior and AdvancedHMC.NUTS) coincide:

[image: density plots of `lb`, `ub`, and `x` for the three chains]

Solution?

Fixing this I think will actually be quite annoying 😕 But I do think it's worth doing.

There are a few approaches:

  1. Re-evaluate the model for every transition we end up accepting to get the distributions corresponding to that particular realization.
  2. Double the memory usage of VarInfo and always store both the linked and the invlinked realizations.
  3. Use a separate context to capture the invlinked realizations.

No matter how we do this, there is the issue that we can't support this properly for `externalsampler`, etc., which use the `LogDensityFunction`, without explicit re-evaluation of the model 😕 It still seems worth adding proper support for this in the "internal" implementations of the samplers, though.

It might be worth providing an option to force re-evaluation, in combination with, say, a warning if we notice that supports change between two different realizations.
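A sketch of what such a check could look like (the name `supports_changed` and the named-tuple realizations are purely illustrative, not an existing DynamicPPL API): compare the support-determining values across two accepted realizations and warn on mismatch.

```julia
# Illustrative-only sketch: flag when the support of `x` (determined by
# lb and ub in `buggy_model`) differs between two realizations.

function supports_changed(realization_a, realization_b)
    bounds(r) = (r.lb, r.ub)  # the support of `x` is exactly (lb, ub)
    return bounds(realization_a) != bounds(realization_b)
end

r1 = (lb = 0.02, ub = 0.15)
r2 = (lb = 0.07, ub = 0.18)

if supports_changed(r1, r2)
    @warn "Supports differ between realizations; falling back to re-evaluation."
end
```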

@yebai @devmotion @sunxd3

@torfjelde (Member Author)

This can be resolved with something like TuringLang/DynamicPPL.jl#588 + some minor changes to Turing.Inference.getparams by turning

# Extract parameter values in a simple form from the invlinked `VarInfo`.
DynamicPPL.values_as(DynamicPPL.invlink(vi, model), OrderedDict)

into

vals = if DynamicPPL.is_static(model)
    # Extract parameter values in a simple form from the invlinked `VarInfo`.
    DynamicPPL.values_as(DynamicPPL.invlink(vi, model), OrderedDict)
else
    # Re-evaluate the model completely to get invlinked parameters since
    # we can't trust the invlinked `VarInfo` to be up-to-date.
    extract_realizations(model, deepcopy(vi))
end

This then defaults to the "make sure we're doing everything correctly"-approach, but allows the user to avoid all the additional model evaluations by just doing:

model = DynamicPPL.mark_as_static(model)

before passing `model` to `sample`.
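For intuition, the fallback branch does essentially what the manual `to_constrained` helper above does: re-run the generative logic at the accepted values so every transform uses the current realization. Here is a plain-Julia stand-in (hand-rolled `invlogit`; `extract_realizations_sketch` is an illustrative name, not the proposed API):

```julia
# Re-evaluation sketch: map unconstrained values back to the constrained
# space *in model order*, so x's transform uses the freshly computed lb, ub.

invlogit(y, a, b) = a + (b - a) / (1 + exp(-y))

function extract_realizations_sketch(θ)
    lb = invlogit(θ[1], 0.0, 0.1)   # lb's support is fixed
    ub = invlogit(θ[2], 0.11, 0.2)  # so is ub's
    x  = invlogit(θ[3], lb, ub)     # x's support uses *these* lb, ub
    return (lb = lb, ub = ub, x = x)
end

vals = extract_realizations_sketch([0.3, -1.2, 2.0])
vals.lb < vals.x < vals.ub  # always true, by construction
```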

As noted in the PR, we should probably have something a bit more general that also captures when we need a fully-blown `UntypedVarInfo` to allow an arbitrary number of parameters plus changes between evaluations. But given that we will (soon™) have a more flexible approach to `UntypedVarInfo` which can grow arbitrarily (TuringLang/DynamicPPL.jl#555), this might not be so important.

@torfjelde (Member Author) commented Apr 19, 2024

Combining aforementioned PRs + TuringLang/DynamicPPL.jl#540, I imagine putting something like the following in our sample:

if DynamicPPL.has_static_constraints(model)
    model = DynamicPPL.mark_as_static(model)
end

and then continue business as usual. It will be a heuristic, of course, but it should work very well in practice. We could make this a keyword argument to allow it to be disabled.
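A hedged sketch of how such a heuristic might work (the name `has_static_constraints` comes from the proposal above, but everything else here is an illustrative stand-in): evaluate the model at a few different realizations and check whether every variable's support comes out the same.

```julia
# Illustrative sketch of a `has_static_constraints`-style heuristic.
# `supports_for` stands in for "run the model and record each support".

function has_static_constraints_sketch(supports_for; nevals = 2)
    supports = [supports_for(rand(3)) for _ in 1:nevals]
    return all(s == first(supports) for s in supports)
end

# In `buggy_model`, the support of `x` is (lb, ub), which changes per draw:
dynamic_supports(θ) = begin
    lb = 0.1 * θ[1]            # lb ∈ (0, 0.1)
    ub = 0.11 + 0.09 * θ[2]    # ub ∈ (0.11, 0.2)
    Dict(:lb => (0.0, 0.1), :ub => (0.11, 0.2), :x => (lb, ub))
end

# A model whose supports do not depend on the realization passes the check:
static_supports(_) = Dict(:lb => (0.0, 0.1), :ub => (0.11, 0.2))

has_static_constraints_sketch(dynamic_supports)  # almost surely false
has_static_constraints_sketch(static_supports)   # true
```

Being sample-based, the check can only ever give a probabilistic "yes", which is why treating it as a heuristic (with an opt-out keyword) seems right.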

@sunxd3 (Member) commented Apr 19, 2024

Just to clarify my understanding: this seems to be a `VarInfo` issue, because the distributions in the metadata are evaluated and saved, and then used during `invlink`. So when using `VarInfo`, we always assume the distribution type and support stay consistent?

Then what about `SimpleVarInfo`? Would directing users to use `SimpleVarInfo` be an option for a solution? (Of course, we'd still need a utility to check whether the constraints are static.)

@torfjelde (Member Author)

It's indeed a `VarInfo` issue, but the way we handle it with `SimpleVarInfo` is exactly to re-evaluate the model 🤷 So what I'm suggesting is to always do this by default, because that will always produce the right thing, and then allow specialization in the cases where it makes sense.

@yebai (Member) commented Apr 19, 2024

A better fix would be to remove `VarInfo` in favour of a (generalised) `SimpleVarInfo`, if that is the case!

@torfjelde (Member Author) commented Apr 19, 2024

> A better fix would be to remove VarInfo in favour of (generalised) SimpleVarInfo if that is the case!

But that doesn't address the issue!

It's not "a bug" per se on the side of `VarInfo`; it's a question of whether we have to re-evaluate the model to get the correct transform or not. Replacing `VarInfo` completely with `SimpleVarInfo` would just force us to always re-evaluate the model fully whenever we want to invlink, which is obviously very undesirable in many cases.

EDIT: Even the new VarNameVector I've been working on will also suffer from the same issue.

@torfjelde (Member Author)

For example, a good use case is #2099, where we need to (inv)link all the time to go between constrained (used for conditioning) and unconstrained (used in sampling) representations. We really shouldn't be doing this through re-evaluation of the model unless it's needed 😕

torfjelde added a commit that referenced this issue Apr 20, 2024

torfjelde added a commit that referenced this issue May 7, 2024 (#2202):

* use `values_as_in_model` to extract the parameters from a `Transition` rather than `invlink` + `values_as`
* bump BangBang compat entry
* added test from #2195 + added HypothesisTests.jl so we can compare chains properly
* deepcopy varinfo before calling `values_as_in_model` to avoid mutating the original logprob computations, etc.
* bump patch version
* fixed tests

Co-authored-by: Hong Ge <[email protected]>
@yebai (Member) commented May 7, 2024

Fixed by #2202.

@yebai yebai closed this as completed May 7, 2024