saving ADAM optimizer is broken [@save] [BSON] #737
I think the underlying problem is using `@save` on an `IdDict`:

```julia
using BSON: @save
dict = IdDict()
@save "test.bson" dict
```

and

```julia
using BSON: @save
dict = IdDict()
a = rand(5)
dict[a] = 5
@save "test.bson" dict
```
BTW, according to the documentation it should work: https://fluxml.ai/Flux.jl/stable/saving/
The underlying cause of this issue is JuliaIO/BSON.jl#43
Anything on this? I just fell for this and lost some models I was checkpointing.
You could try using JLD2 instead of BSON. We are still waiting for JuliaIO/BSON.jl#43 to be fixed.
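In case it helps anyone taking that route, here is a minimal sketch of one way to do it (hypothetical helper names, not an official API; assumes the JLD2 package and Flux's `ADAM` with its `eta`/`beta`/`state` fields). Rather than serializing the `IdDict` directly, it stores the per-parameter state as a plain vector ordered like `params(model)`:

```julia
using Flux   # assumption: Flux's ADAM with eta/beta/state fields
using JLD2   # assumption: the JLD2 package is installed

# Hypothetical helpers: store ADAM state as a plain Vector ordered
# like `ps`, sidestepping BSON/JLD2 trouble with IdDict keys.
function save_opt_state(path, opt::ADAM, ps)
    state = [get(opt.state, p, nothing) for p in ps]
    jldsave(path; eta = opt.eta, beta = opt.beta, state = state)
end

function load_opt_state!(opt::ADAM, ps, path)
    data = load(path)
    opt.eta, opt.beta = data["eta"], data["beta"]
    # Re-key the state against the live parameter arrays.
    for (p, st) in zip(ps, data["state"])
        st === nothing || (opt.state[p] = st)
    end
    return opt
end
```

This only round-trips correctly if `ps` iterates in the same order at save and load time, so the model must be reconstructed identically before restoring.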
But JLD and JLD2, as far as I can understand, can't save closures or functions in general.
But I might try that: save the model with BSON and the optimizer with JLD.
I'm not sure about JLD. But JLD2 cannot save closures.
Yeah, that might work.
JLD can't either, trust me; I just lost 2 days of work because of that.
This appears to be fixed by @dhairyagandhi96 in JuliaIO/BSON.jl#70
Can the docs explain how saving IdDicts works? Don't they rely on the object ID of their keys? How can we know if these IDs are preserved when loading the parameters? (Does …
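To illustrate the concern in plain Julia (no Flux needed): `IdDict` keys are compared by object identity (`===`), so a value-equal copy of a key is a different key.

```julia
# IdDict looks up keys by identity (objectid / ===), not by value (==)
d = IdDict()
a = [1.0, 2.0]
d[a] = "optimizer state for a"

haskey(d, a)        # true: the same object
haskey(d, copy(a))  # false: equal contents, but a different object ID
```

This is why naively serializing an `IdDict` is fragile: after a round trip, the deserialized keys are freshly allocated objects with new object IDs, so lookups against the live model parameters miss.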
IIUC this was just about being able to save structs with …

```julia
using Flux

# change opt type and IdDict field name for whatever you're using
function extract_opt_state(opt::ADAM, model)
    func = Flux.Functors.children(model)
    map(func) do child
        if Flux.isleaf(child)
            get(opt.state, child, nothing)
        else
            extract_opt_state(opt, child)
        end
    end
end

function restore_opt_state!(opt::ADAM, model, state)
    func = Flux.Functors.children(model)
    map(func, state) do child, st
        st === nothing && return
        if Flux.isleaf(child)
            opt.state[child] = st
        else
            restore_opt_state!(opt, child, st)
        end
    end
end

opt1 = ADAM(0.1)
model = Chain(
    Conv((3,), 1 => 16),
    GlobalMeanPool(),
    Flux.flatten,
    Dense(16, 2)
)
x = rand(Float32, 10, 1, 64)
ps = params(model)
grads = gradient(() -> sum(model(x)), ps)
Flux.Optimise.update!(opt1, ps, grads)

opt_state = extract_opt_state(opt1, model)
@show Flux.fmap(summary, opt_state)

opt2 = ADAM(0.1)
restore_opt_state!(opt2, model, opt_state)
@assert opt1.state == opt2.state
for p in ps
    @assert opt1.state[p] == opt2.state[p]
end
```
@ToucheSir The Flux documentation currently states that just saving the optimizer normally with BSON works; if that is incorrect, can we add your workaround to the docs instead? (Anyhow, the docs there need some clarification either way.) PS: Is Functors.children guaranteed to return the parameters in a deterministic order?
Funnily enough, the relevant docs were last updated 4 years ago (!), back before Flux used IdDicts. I could throw up a PR for the snippet if there's enough buy-in. In the meantime, if you wouldn't mind submitting one that excises that paragraph completely, we could at least avoid spreading more confusion. RE postscript: yes, it's not explicitly written down IIRC, but that's part of the contract. Most implementors of the Functors interface use …
@ToucheSir I sent a PR fixing the outdated info. Can't we somehow make … Here is one way I thought up for going about this:
|
Thanks for the PR!
That's basically what the code above does, just more structured. The problem is that object IDs are not guaranteed to be unique, so you need some kind of out-of-band information as well when (de)serializing. That works for a specific use case like Flux weights, but not so much in general for IdDicts (how do you pass this info to the deserializer?). One easier but still non-general approach would be to convert the IdDict into an array of values, ordered by the order of the parameters as they appear in your model. This is basically …
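A minimal sketch of that ordered-array idea (hypothetical helper names; assumes `ps` iterates the parameters in the same order at save and load time):

```julia
# Replace identity-keyed entries with a plain Vector ordered like `ps`,
# which serializes cleanly with any backend; rebuild the IdDict against
# the live parameter objects after loading.
state_to_vector(state::IdDict, ps) = [get(state, p, nothing) for p in ps]

function vector_to_state!(state::IdDict, ps, vec)
    for (p, st) in zip(ps, vec)
        st === nothing || (state[p] = st)
    end
    return state
end
```

The out-of-band information here is the parameter order itself, which is why the model must be rebuilt identically before restoring.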
After looking into a similar issue today, I was in fact wrong about not being able to save …
This code should probably run: