saving ADAM optimizer is broken [@save] [BSON] #737

Closed
aldopareja opened this issue Apr 15, 2019 · 18 comments

Comments

@aldopareja

This code should probably run, but it currently fails:

using Flux
using BSON: @save
opt = ADAM(0.1)
@save "some_optim.bson" opt

@devmotion
Contributor

I think the underlying problem is

using BSON: @save
dict = IdDict()
@save "test.bson" dict

and

using BSON: @save
dict = IdDict()
a = rand(5)
dict[a] = 5
@save "test.bson" dict

@devmotion
Contributor

BTW, according to the documentation it should work: https://fluxml.ai/Flux.jl/stable/saving/

@DilumAluthge
Member

The underlying cause of this issue is JuliaIO/BSON.jl#43

@bhvieira
Contributor

Any update on this? I just ran into it and lost some models I was checkpointing.

@DilumAluthge
Member

You could try using JLD2 instead of BSON.

We are still waiting for JuliaIO/BSON.jl#43 to be fixed.

@bhvieira
Contributor

But JLD and JLD2, as far as I can understand, can't save closures or functions in general.

@bhvieira
Contributor

But I might try that, save the model with BSON and the optim with JLD

@DilumAluthge
Member

I'm not sure about JLD. But JLD2 cannot save closures.

@DilumAluthge
Member

> But I might try that, save the model with BSON and the optim with JLD

Yeah that might work.

@bhvieira
Contributor

> I'm not sure about JLD. But JLD2 cannot save closures.

JLD can't either. Trust me, I just lost 2 days of work because of that.

@CarloLucibello
Member

This appears to be fixed by @dhairyagandhi96 in JuliaIO/BSON.jl#70

@NightMachinery
Contributor

Can the docs explain how saving IdDicts works? Don't they rely on the object ID of their keys? How can we know if these IDs are preserved when loading the parameters? (Does BSON do some magic to use the same IDs, or is this somehow guaranteed for any Julian serialization method?)
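
For context, here is a minimal sketch of why the question matters. It uses only Base Julia (no Flux or BSON), and the names `w`, `w_copy`, and `state` are made up for illustration. An IdDict looks keys up by object identity (`===`), so an array with equal contents that is a different object does not find the entry:

```julia
# Base Julia only; `w` stands in for a model parameter array.
w = rand(Float32, 3)
state = IdDict{Any,Any}()
state[w] = "optimizer state for w"

# Equal contents, but a different object (as after a naive save/load round trip).
w_copy = copy(w)

@show w == w_copy            # compares values
@show w === w_copy           # compares identity
@show haskey(state, w)       # lookup with the original object
@show haskey(state, w_copy)  # lookup with the copy
```

So any serialization scheme has to re-link the loaded keys to the loaded parameter objects in some way; whether a given package does so is a separate question.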

@ToucheSir
Member

IIUC this was just about being able to save structs with IdDicts at all without errors, rather than being able to make use of them later. If you want to do the latter, give the following a try:

using Flux

# change opt type and IdDict field name for whatever you're using
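# Walk the model tree and collect the optimizer state for each leaf.
# The result mirrors the model's structure; leaves without state map to `nothing`.
function extract_opt_state(opt::ADAM, model)
    children = Flux.Functors.children(model)
    map(children) do child
        if Flux.isleaf(child)
            get(opt.state, child, nothing)
        else
            extract_opt_state(opt, child)
        end
    end
end

# Walk the model and the extracted state in parallel, re-keying each saved
# entry by the corresponding leaf of `model`.
function restore_opt_state!(opt::ADAM, model, state)
    children = Flux.Functors.children(model)
    map(children, state) do child, st
        st === nothing && return
        if Flux.isleaf(child)
            opt.state[child] = st
        else
            restore_opt_state!(opt, child, st)
        end
    end
end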

opt1 = ADAM(0.1)
model = Chain(
    Conv((3,), 1 => 16),
    GlobalMeanPool(),
    Flux.flatten,
    Dense(16, 2)
)

x = rand(Float32, 10, 1, 64)
ps = params(model)
grads = gradient(() -> sum(model(x)), ps)
Flux.Optimise.update!(opt1, ps, grads)

opt_state = extract_opt_state(opt1, model)
@show Flux.fmap(summary, opt_state)

opt2 = ADAM(0.1)
restore_opt_state!(opt2, model, opt_state)
@assert opt1.state == opt2.state
for p in ps
    @assert opt1.state[p] == opt2.state[p]
end

@NightMachinery
Contributor

@ToucheSir The Flux documentation currently states that just saving the optimizer normally with BSON works. If that is incorrect, can we add your workaround to the docs instead? (Either way, the docs there need some clarification.)

PS: Is Functors.children guaranteed to return the parameters in a deterministic order?

@ToucheSir
Member

Funnily enough, the relevant docs were last updated 4 years ago (!), back before Flux used IdDicts.

I could throw up a PR for the snippet if there's enough buy-in. In the meantime, if you wouldn't mind submitting one that excises that paragraph completely, we could at least avoid spreading more confusion.

RE postscript: yes. It's not explicitly written down IIRC, but it's part of the contract. Most implementors of the Functors interface use @functor, which maintains a known, fixed order: https://fluxml.ai/Functors.jl/stable/.

@NightMachinery
Contributor

@ToucheSir I sent a PR, fixing the outdated info.


Can't we somehow make IdDicts properly saveable via BSON? That would compose better.

Here is one way I thought up for going about this:

  • Check if all the keys of the IdDict are present in the objects we are going to save to the BSON
  • Put some ID tag on these objects (possibly their objectid)
  • Store these IDs for keys of the IdDict
  • When loading an IdDict, we can construct a new IdDict and populate it from the newly loaded objects

@ToucheSir
Member

Thanks for the PR!

> Here is one way I thought up for going about this:

That's basically what the code above does, just more structured. The problem is that object IDs are not guaranteed to be unique, so you need some kind of out-of-band information as well when (de)serializing. That works for a specific use-case like Flux weights, but not so in general for IdDicts (how do you pass this info to the deserializer?).

One easier but still non-general approach would be to convert the IdDict into an array of values, ordered by the order of parameters as they appear in your model. This is basically Flux.loadparams!, but without the flattening step so that it works with all data types. Again though, it will not work with arbitrary IdDicts because iteration order is not guaranteed.
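
A rough sketch of that array-based idea, using only Base Julia. The helpers `state_to_vector` and `vector_to_state` are hypothetical names, not part of Flux or BSON, and the tuple of arrays stands in for model parameters listed in a fixed, known order:

```julia
# Hypothetical helpers (not Flux or BSON API).
# `keys_in_order` is an ordered collection of the objects used as IdDict keys.
state_to_vector(state::IdDict, keys_in_order) =
    [get(state, k, nothing) for k in keys_in_order]

function vector_to_state(values, keys_in_order)
    state = IdDict{Any,Any}()
    for (k, v) in zip(keys_in_order, values)
        v === nothing || (state[k] = v)   # skip keys that had no state
    end
    return state
end

# Stand-in "model": two parameter arrays in a fixed order.
old_params = (rand(Float32, 2, 2), rand(Float32, 2))
old_state = IdDict{Any,Any}()
old_state[old_params[1]] = :state_for_W
old_state[old_params[2]] = :state_for_b

# A plain Vector has no identity-based keys, so it can be serialized as usual.
saved = state_to_vector(old_state, old_params)

# After loading, the parameters are new objects with the same values.
new_params = map(copy, old_params)
new_state = vector_to_state(saved, new_params)

@assert new_state[new_params[1]] == :state_for_W
@assert new_state[new_params[2]] == :state_for_b
@assert !haskey(new_state, old_params[1])   # old objects are no longer keys
```

The sketch only works because both sides agree on the order of the keys, which is the out-of-band information mentioned above.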

@ToucheSir
Member

After looking into a similar issue today, I was in fact wrong about not being able to save IdDicts. The current docs suggest saving the model and optimizer together because BSON is smart enough to cache values and insert links when saving, but only if it knows everything to be saved up front: https://github.com/JuliaIO/BSON.jl/blob/3b4a2cebda0afae11aab310f0a4d12b6a5234160/src/write.jl#L71. Let's update #1762 to reflect that.
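
A minimal sketch of saving both together, assuming a Flux version from the era of this thread (where `ADAM` keeps its state in an IdDict field named `state`) and a BSON.jl release that includes the fix referenced earlier. The file name `checkpoint.bson` and the toy model are placeholders:

```julia
using Flux
using BSON: @save, @load

model = Dense(4, 2)
opt = ADAM(0.1)

# One update step so the optimizer holds state for each parameter.
x = rand(Float32, 4, 8)
ps = Flux.params(model)
grads = gradient(() -> sum(model(x)), ps)
Flux.Optimise.update!(opt, ps, grads)

# Save both in a single call, so BSON sees the parameter arrays and the
# optimizer state that refers to them at the same time.
@save "checkpoint.bson" model opt

# Later, possibly in a fresh session: load both back together.
@load "checkpoint.bson" model opt

# Check for yourself whether the restored state is keyed by the restored parameters.
@show all(haskey(opt.state, p) for p in Flux.params(model))
```

Saving `model` and `opt` in two separate `@save` calls would give BSON no way to link the two, which matches the caveat above.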
