-
-
Notifications
You must be signed in to change notification settings - Fork 608
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add explicit train!
, unify update!
, and auto-translate the two Adam
s
#2082
Conversation
d80bf53
to
bbc0f85
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
src/train.jl
Outdated
* Instead of `loss` being a function which typically accepts two arguments | ||
(the input `x` and expected output `y` from each element of `data`) | ||
now it should typically accept three, the first of which is the `model` itself. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We never restricted specifically to 2 arguments (and we don't seem to restrict to 3 now either). I think the change is
- <= v0.13:
loss
accepts as many arguments aslength(first(data))
- > v0.13:
loss
accepts at least 1 argument, the model, and can accept additionalN
additional arguments whereN = length(first(data))
I think the distinction is important, since for things like language models, length(first(data)) == 1
(conceivably), and the loss handles taking the single data argument and turning it into a supervised problem.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yea I don't know how to word this. I think all the doc examples have 2 & need 3, so wrote "typically". Maybe it needs more explanation.
<= v0.13: loss accepts as many arguments as length(first(data))
It's weirder than that, because if first(data)
isn't a tuple, then it isn't splatted. I made the new code simply demand a tuple, else error, so that (at least for now) there is just one path.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've removed the "3 arguments" bit, to say just "Instead of loss
being a function which accepts only the data, now it must also accept the model
itself, as the first argument."
Yes. I simplified this to just the I have comments on #2083 which I should tidy up and post there. I think the major question here is whether, when making a new and incompatible method for
|
So the confusing features of
One alternative would be to change it to
which calls Since the keyword |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My only hesitancy around this PR involve choices made to keep train!
as close to the old one as possible. So, do not require users to capture the return value, etc.
I would prefer returning both the model and the state from train!
. In v0.13.X, we do not need to document this change to the behavior. We silently upgrade to full Optimisers.jl support and throw deprecation warnings whenever Params
and AbstractOptimiser
are used. Then some versions down the line, we kill support for implicit.
Also, is the a reason for a separate Train
submodule?
(weight = Leaf(Momentum{Float64}(0.1, 0.9), Float32[-2.018 3.027]), bias = Leaf(Momentum{Float64}(0.1, 0.9), Float32[-10.09]), σ = ()) | ||
``` | ||
""" | ||
function setup(rule::Optimisers.AbstractRule, model) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am hesitant to create a function in the Flux namespace that clashes with Optimisers.jl. It is hard enough already to keep track of where "Flux functions" actually come from.
Why not extend Optimisers.setup
for Flux.Optimise.AbstractOptimiser
and remove the mutability check? I am guessing this is to guard against immutable models since train!
does not return the model?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I thought I did (in addition) extend Optimisers.setup
that way, but in fact I made an error. Can change it.
Yes the guard against immutable models is the point. All Flux models are assumed mutable right now, and this just makes the check explicit.
I don't love the collision, but neither name is exported, and the consequences of using the wrong one are (I think) slight. You lose the safety check but any model which does work with Flux.setup will also work correctly with Optimisers.setup.
We can of course make train!
return the model. But this isn't enough, as you also have to re-do your code to keep not discard the returned model. It's a bit awkward.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
BTW, I picture Flux 0.14 deleting Flux.Optimise and exporting Adam etc from Optimisers.jl.
Code that goes using Flux; opt = Flux.setup(Adam(), model); train!(loss, model, data, opt)
will work equally well on 0.13 and 0.14. You don't have to load Optimisers.jl yourself at all, and all will be safe.
If you do load Optimisers.jl yourself and use its functions, then you have opted into the model, _ = update!(opt, model, grad)
thing where you are supposed to get back the new model.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code that goes using Flux; opt = Flux.setup(Adam(), model); train!(loss, model, data, opt) will work equally well on 0.13 and 0.14.
I guess the question is should this be the case? I do think there should be a patch release of v0.13.X that accepts Optimisers.Adam
, etc. and upgrades Flux.Optimise.Adam
with a warning. This will allow train!
to work like quoted above too. But in v0.14, I was expecting that we force people to start using model = train!(...)
. Previously, train!
and update!
worked similarly (mutating optimizers and model), and we could say train!
is "just" a loop. Diverging how they work seems worse than a minor code refactor on a breaking release. Especially given people will get warnings from before.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I do think there should be a patch release of v0.13.X that accepts Optimisers.Adam, etc. and upgrades Flux.Optimise.Adam with a warning
Yes. That's what this PR makes.
I was expecting that we force people to start using
model = train!(...)
Especially given people will get warnings from before
But how? You want train!
not to mutate, so that everyone will wonder why their model isn't training, and why it's called train!
? Or worse to make it return a copy and write NaN into the old model to trash it? These seem awful to me, deliberate breakage for which we gain nothing.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay, let's forget my suggestions about the warnings, as I agree about the orthogonality.
One option here is to return the model from train!
which would allow for immutable models to work. Mutable models still don't need to capture the return value to work. So, we don't force people to do model = train!(...)
. And we still have Flux.setup
here to work in the reverse direction: warn if any leaf is immutable.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But I agree that if we are adding Flux.setup
, then this seems like something that can be revisited later too.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One thing we should consider is changing Optimisers. Maybe its present function should be called update!!
as that's something of a convention for "tries to mutate but may fail".
Then in [email protected], we can introduce a new function update!
which demands mutability, fails on immutable parameters. And that's the one we identify with Flux's function.
That's now FluxML/Optimisers.jl#116
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think train!
should guarantee mutation, this is the widespread julian convention. We can have a train
and a train!!
for non-mutating and mutate-if-possible versions.
In that case, whether it returns the model or not hasn't great relevance. Base functions such as replace!
and map!
return the mutated input. Maybe just for REPL usage convenience? In our case returning the model in the repl would just be an annoyance I guess.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, the present one returns nothing
, which as you say means you don't get a screenful of stuff, and also serves as a reminder that it mutates the model.
I think I'd be happiest if update!
did the same. I mean Flux.update!
does now, but after unifying with Optimisers.update!
too.
I understand the attraction of state, model = update(state, model, grad)
but IMO it's a pain to remember the order, and update!
is now in a weird place where it does guarantee to mutate the state
, but not the model.
See suggestions above for how we might change it. But one of the goals is making the transition as easy as possible, which argues for keeping the weird
How could you require this? You could encourage it, but since code without it will continue to work, it seems tricky to keep saying "you really ought to...".
Note that there is no need to return the state. This is guaranteed to be mutated, even for a model of immutable arrays. We could change
Yes, that's precisely what this PR amis to do. It makes claims that 0.14 will remove this. But how soon that is, we can see.
No very strong one, there's a Losses module. And an Optimise module which (1) is a terrible near-clash, and (2) contains neatly everything which the no-more-implicit change will delete entirely. Would be fine to remove the sub-module, maybe that's better? |
70242cb
to
b8c6192
Compare
train!
without removing implicit onetrain!
, unify update!
, and auto-translate the two Adam
s
* `data` must iterate tuples, otherwise you get an error. | ||
(Previously non-tuple types were not splatted into the loss. | ||
Pass in `((d,) for d in data)` to simulate this.) | ||
* `opt` should be the result of [`Flux.setup`](@ref). Using an optimiser |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does Documenter need Flux.Train.setup
instead?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe. I just tried, and I think neither this new train!
nor setup
appear in the docs at present. That section needs to be re-worked for explicit parameters.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Doc update is #2114, will need to be rebased on this & then checked.
…d you can just use update! instead really.
6825fd5
to
f3e1559
Compare
I would keep the |
conditional on fixing the failing test LGTM |
gs = gradient(marg -> marg(x), m) | ||
@test gs isa Tuple | ||
@test_throws ErrorException Flux.update!(Flux.Adam(), Flux.params(m), gs) # friendly | ||
@test_throws ErrorException Flux.update!(Flux.Adam(), Flux.params(m), gs[1]) # friendly | ||
@test_throws ErrorException Flux.update!(Flux.Adam(), m, gs) # friendly | ||
@test_throws ErrorException Flux.update!(Flux.Adam(), m, gs[1]) # friendly | ||
s = Flux.setup(Adam(), m) | ||
@info "ignore this warning, just testing an upgrade path:" | ||
Flux.update!(s, m, gs) # Chain + Tuple can be unambiguously sorted out |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Most recent commits add some friendly errors to most ways you could use update!
wrong, by mixing up implicit & explicit bits, or forgetting to call setup
.
The dimensions of these model parameters depend on the number of inputs and outputs. Since models can have hundreds of inputs and several layers, it helps to have a function to collect the parameters into the data structure Flux expects: | ||
|
||
```jldoctest overview; filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?" | ||
julia> parameters = Flux.params(predict) | ||
Params([Float32[0.9066542], Float32[0.0]]) | ||
``` | ||
|
||
These are the parameters Flux will change, one step at a time, to improve predictions. At each step, the contents of this `Params` object changes too, since it is just a collection of references to the mutable arrays inside the model: | ||
|
||
```jldoctest overview | ||
julia> predict.weight in parameters, predict.bias in parameters | ||
(true, true) | ||
``` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also new, remove all params
from "overview".
This is a non-breaking alternative to #2029, and part of #1986's goal to kill off Flux.Optimise.
It adds an explicit
train!(loss, model, data, opt)
, which uses Optimisers.jl inside, and needsopt = Flux.setup(Adam(), model)
before use. Tries to provide a friendly upgrade path:setup
accepts the oldAdam()
, which is unchanged & can still be used with implicittrain!
. I hope this makes it easy to translate working code, as new & old will work in the same session. Then v0.14 will export the newAdam()
instead, but code won't change.setup
, it will still train, but will warn you if state is being lost.d in data
which isn't a Tuple. This avoids the weirdbatchmaybe
thing, which sometimes splats and sometimes doesn't.Flux.update! === Optimisers.update!
, so you can't use the wrong one. (At present there are twosetup
s.)Needs doc updates, but probably not this PR. I think this would free us to remove every mention of implicit parameters from the docs, during 0.13.
Then 0.14 can delete Flux.Optimise and the
train!(..., ::Params, ...)
function completely.This
train!
would like to have methods likemse(m, x, y) = mse(m(x), y)
for all the loss functions, to allowtrain!(mse, model, data, opt)
rather than defining a trivial wrapper function every time. Not this PR though. (Now #2090.)The meaning of
data
is still the same as before -- it's an iterator usually over tuples, which are usually splatted into the loss. This is (IMO) a confusing feature oftrain!
, and perhaps the implicit / explicit break is the time to fix that too. One possibility would be to take arrays not an iteratortrain!(mse, model, X, Y; opt)
:It also adds an explicit way of changing the AD used, viaRemoved in bbc0f85, for now, to make things more orthogonal. The macro was the same as this bit of 2029.@train_autodiff
. RFC, I guess. Tests for it run on Tracker & Yota.PR Checklist