Continual regression tutorial question: Laplace #107

Open · ran-weii opened this issue Jul 15, 2024 · 2 comments
Labels: help wanted (Extra attention is needed)

@ran-weii

Hi, thanks for creating this library!

I was going through the continual regression tutorial notebook and wanted to try out diag_fisher, expecting it to produce results similar to diag_vi. However, the predictive distribution did not look right, except for the first episode. Here are my results for VI (by sampling parameters) and Fisher (linearized forward):

[image: diag_vi predictive distributions]

[image: diag_fisher predictive distributions]

One thing I noticed is that for diag_fisher the prec_diag of some parameters is close to zero, which I clipped to 1e-5 because it otherwise causes an error in posteriors.laplace.diag_fisher.sample. I also de-normalized log_posterior by multiplying it by the number of data points for the diag_fisher transform. I was wondering whether this is expected or whether there's something wrong with my implementation.

Here's my code which can be subbed in after the VI training block in the notebook:

def train_for_la(dataloader, prior_mean, prior_sd, n_epochs=100, init_log_sds=None):
    seq_log_post = partial(log_posterior, prior_mean=prior_mean, prior_sd=prior_sd)
    
    # compute map estimate
    opt = torch.optim.Adam(mlp.parameters())
    for _ in range(n_epochs):
        for batch in dataloader:
            opt.zero_grad()
            loss = -seq_log_post(dict(mlp.named_parameters()), batch)[0]
            loss.backward()
            opt.step()
    
    # update laplace state
    def _log_posterior(params, batch, prior_mean, prior_sd):
        """Log posterior without per-data-point normalization: the likelihood is scaled back up by samps_per_episode."""
        x, y = batch
        y_pred = mlp_functional(params, x)
        log_post = log_likelihood(y_pred, y) * samps_per_episode + log_prior(params, prior_mean, prior_sd)
        return log_post, y_pred
    _seq_log_post = partial(_log_posterior, prior_mean=prior_mean, prior_sd=prior_sd)

    transform = posteriors.laplace.diag_fisher.build(
        _seq_log_post, per_sample=False, init_prec_diag=0.,
    )
    state = transform.init({k: v.data for k, v in dict(mlp.named_parameters()).items()})
    for batch in dataloader:
        state = transform.update(state, batch, inplace=False)
    # clip near-zero precisions so that posteriors.laplace.diag_fisher.sample does not fail
    state.prec_diag = tree_map(lambda x: torch.clip(x, 1e-5, 1e5), state.prec_diag)
    return state

# train laplace
mlp.load_state_dict(trained_params[0])
la_states = []
for i in range(n_episodes):
    seq_prior_mean = prior_mean if i == 0 else tree_map(lambda x: x.clone(), la_states[i - 1].params)
    seq_prior_sd = prior_sd if i == 0 else tree_map(
        lambda prec: torch.sqrt(1 / prec.clone() + transition_sd ** 2), la_states[i - 1].prec_diag
    )
    
    state = train_for_la(
        dataloaders[i], seq_prior_mean, seq_prior_sd, n_epochs=100, init_log_sds=None
    )

    la_states += [copy.deepcopy(state)]
    mlp.load_state_dict(la_states[i].params)

# laplace forward
def to_sd_diag_la(state, temperature=1.0):
    return tree_map(lambda x: torch.sqrt(temperature / (x + 1e-3)), state.prec_diag)

def forward_linearized(model, state, batch, temperature=1.0):
    n_linearised_test_samples = 30
    x, _ = batch
    sd_diag = to_sd_diag_la(state, temperature)

    def model_func_with_aux(p, x):
        return torch.func.functional_call(model, p, x), torch.tensor([])

    lin_mean, lin_chol, _ = posteriors.linearized_forward_diag(
        model_func_with_aux,
        state.params,
        x,
        sd_diag,
    )

    samps = torch.randn(
        lin_mean.shape[0],
        n_linearised_test_samples,
        lin_mean.shape[1],
        device=lin_mean.device,
    )
    lin_logits = lin_mean.unsqueeze(1) + samps @ lin_chol.transpose(-1, -2)
    return lin_logits

# plot laplace
fig, axes = plt.subplots(1, n_episodes, figsize=(n_episodes * 4, 4), sharex=True, sharey=True)
for i, ax in enumerate(axes):
    plot_data(ax, up_to_episode=i+1)
    with torch.no_grad():
        preds = forward_linearized(mlp, la_states[i], [plt_linsp.view(-1, 1), None]).squeeze(-1)
    
    # plot predictions
    sd = preds.std(1)
    preds = preds.mean(1)
    ax.plot(plt_linsp, preds, color='blue', alpha=1.)
    ax.fill_between(plt_linsp, preds - sd, preds + sd, color='blue', alpha=0.2)
    ax.set_title(f"After Episode {i+1}")
plt.suptitle("diag_fisher")
SamDuffield added the "help wanted" label on Jul 24, 2024

@SamDuffield (Contributor)

Ok there are a few things going on here 😅 hopefully I can help somewhat.

You should not de-normalize the log_posterior for Laplace, since it is called internally per sample and so will work as long as it is properly scaled when called with batchsize=1. (Note that in most definitions of the Laplace approximation only the likelihood appears, with no prior, although I think it's fine to also include the prior.)
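
Concretely, that means something like: with batchsize=1 the function returns log p(y_n | x_n, params) + log p(params) / N, where N is the number of data points in the episode, so that summing over the whole episode recovers the full log joint.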

I am not sure that will solve the issue here though. One thing to note is that for small datasets (like the episodes here), the full-batch gradient at the mode is often very small, so the empirical Fisher precision is also very small and the empirical Fisher covariance very large, meaning the Laplace approximation breaks. But it is curious that you get good performance in the latest episode and not in the previous ones.
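
To illustrate the collapse, here is a toy sketch in plain torch (not the posteriors API): at a mode that nearly interpolates a small dataset, the residuals, and hence the squared gradients that form the diagonal empirical Fisher, are tiny, so the precision is near zero and the implied covariance is huge.

import torch

# Toy 1-D linear model f(x) = w * x with unit-variance Gaussian likelihood.
torch.manual_seed(0)
x = torch.linspace(-1.0, 1.0, 5)
y = 2.0 * x + 1e-3 * torch.randn(5)      # tiny dataset that the mode fits almost perfectly

w = torch.tensor(2.0)                    # stand-in for the MAP estimate
per_sample_grads = (y - w * x) * x       # d/dw log N(y_n | w * x_n, 1) for each point
emp_fisher_prec = (per_sample_grads ** 2).sum()

print(emp_fisher_prec)        # on the order of 1e-6: near-zero precision
print(1.0 / emp_fisher_prec)  # on the order of 1e+6: enormous Laplace variance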

@ran-weii (Author) commented Aug 3, 2024

Hi, thanks for your help and sorry for the late reply! I've been studying this and reading the Martens paper. Like you said, it does make sense that the empirical fisher approximation breaks in this setting.

So I tried your suggestion of not de-normalizing the log posterior, with this:

def _log_posterior(params, batch, prior_mean, prior_sd):
    x, y = batch
    y_pred = mlp_functional(params, x)
    log_post = log_likelihood(y_pred, y) + log_prior(params, prior_mean, prior_sd) / samps_per_episode
    return log_post, y_pred

IMO the result is actually much worse:
[image: diag_fisher results without de-normalizing the log posterior]

I also tried approximating the true Fisher, as opposed to the empirical Fisher, by sampling targets from the model's predictions:

def _log_posterior(params, batch, prior_mean, prior_sd):
    x, y = batch
    y_pred = mlp_functional(params, x)
    log_post = log_likelihood(y_pred, y) + log_prior(params, prior_mean, prior_sd) / samps_per_episode
    return log_post, y_pred

_seq_log_post = partial(_log_posterior, prior_mean=prior_mean, prior_sd=prior_sd)

transform = posteriors.laplace.diag_fisher.build(
    _seq_log_post, per_sample=False, init_prec_diag=0.,
)
state = transform.init({k: v.data for k, v in dict(mlp.named_parameters()).items()})
for batch in dataloader:
    _batch = copy.deepcopy(batch)
    with torch.no_grad():
        x_batch, y_batch = _batch
        y_pred = mlp_functional(state.params, x_batch)
        y_sample = torch.distributions.Normal(y_pred, y_sd, validate_args=False).sample()
        _batch[1] = y_sample

    state = transform.update(state, _batch, inplace=False)

The result is pretty much the same.
[image: diag_fisher results with targets sampled from the model (true Fisher approximation)]

I also played briefly with diag_ggn and got pretty much the same thing. So I currently don't have a good theory about this. Any thoughts?
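
For reference, here is one way one might cross-check the diagonal Fisher independently of posteriors: for a Gaussian likelihood with fixed y_sd, the Fisher coincides with the GGN, and its diagonal is just the sum of squared per-point Jacobians scaled by 1 / y_sd^2. A minimal sketch with torch.func, assuming the notebook's mlp and y_sd (prior precision not included):

import torch
from torch.func import functional_call, jacrev, vmap

def exact_diag_fisher(model, params, X, y_sd):
    # Scalar network output for a single input point.
    def f(p, x):
        return functional_call(model, p, x.unsqueeze(0)).squeeze()

    # Per-point Jacobians w.r.t. every parameter, each with a leading N dimension.
    jacs = vmap(jacrev(f), in_dims=(None, 0))(params, X)

    # Diagonal likelihood Fisher (= GGN here): sum_n J_{n,i}^2 / y_sd^2.
    return {k: (j ** 2).sum(0) / y_sd ** 2 for k, j in jacs.items()}

Comparing this against state.prec_diag would only be a rough check, since the posteriors estimate is an empirical Fisher of the full log posterior (prior term included), but it should at least show whether the near-zero precisions come from the empirical approximation.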
