Expose train loop to user code #1471
base: master
Conversation
Simple usage example:

```julia
struct Callback{T}
    losses::AbstractVector{T}
end

# l: loss at the datapoint
# ps: params (maybe can skip, but good to have to avoid globals)
# gs: grads at the datapoint, to inspect
# d: datapoint
# opt: modify optimiser based on some condition
(cb::Callback)(l, ps, gs, d, opt) = append!(cb.losses, l)

prehook(l, ps, gs, d, opt) = throw(Flux.Optimise.SkipException())

c = Callback(Float32[])
Flux.train!(loss, ps, data, opt, cb = [() -> (), c])
Flux.train!(loss, ps, data, opt, prehooks = prehook)
```
If adding these hooks is the desired behavior, then I think the approach is fine. But as I said in #1461, I don't think that we should try and make a simple hook system for `train!`. If a user wants non-standard loop behavior (e.g. conditionals before the update, etc.), then it is easier and cleaner to define your own for loop in your own `train!` function.
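For reference, a minimal sketch of that "write your own loop" approach, using only public Flux/Zygote calls; the toy `model`, `loss`, and `data` below are placeholders, not anything from this PR:

```julia
using Flux

# Toy setup; these stand in for the user's own model, loss, and data.
model = Dense(2, 1)
loss(x, y) = Flux.Losses.mse(model(x), y)
data  = [(rand(Float32, 2, 16), rand(Float32, 1, 16)) for _ in 1:10]
opt   = Descent(0.1)
ps    = Flux.params(model)

for (x, y) in data
    gs = gradient(() -> loss(x, y), ps)
    # Any non-standard behaviour (conditionals, schedulers, logging)
    # goes right here, with no hook machinery needed.
    Flux.Optimise.update!(opt, ps, gs)
end
```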
src/optimise/train.jl (outdated)

```diff
- update!(opt, ps, gs)
- cb()
+ gs = back(l)
+ all(train_prehooks(l, ps, gs, d, opt)) && update!(opt, ps, gs)
```
"prehooks" doesn't seem like a good name for what the intent of this hook is. It isn't immediately obvious that the hook can block a parameter update.
I would be happy to hear thoughts on a better name
Maybe "gradient check" or "update check"?
src/optimise/train.jl (outdated)

```diff
- cb()
+ gs = back(l)
+ all(train_prehooks(l, ps, gs, d, opt)) && update!(opt, ps, gs)
+ cb(l, ps, gs, d, opt)
```
This seems like a nasty list of args...we're just passing everything into the hook regardless of whether it is necessary. I feel like this is further evidence that a simple hook system is not a scalable design.
Fair enough, but the idea here was really to mirror the existing `train!` API; that was the intent. We definitely don't have to stick with that, though.
If we do want to add this PR to Flux, then maybe the argument list can be just `l, d, gs` (maybe we don't even need `d` in there). Because `ps` and `opt` will be accessible in the scope calling `train!`, the hook function can just close over them.
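A hypothetical sketch of that suggestion: the reduced `(l, d, gs)` signature is only the one floated in this comment, not an API Flux provides, and the toy model and hook name are placeholders:

```julia
using Flux

model = Dense(2, 1)                        # toy model, for illustration only
loss(x, y) = Flux.Losses.mse(model(x), y)
opt = Descent(0.1)
ps  = Flux.params(model)

# Hypothetical reduced-signature hook: it never receives `ps` or `opt`,
# but it can still adjust the optimiser because the closure captures
# `opt` from the enclosing scope.
halve_on_spike(l, d, gs) = (l > 10 && (opt.eta /= 2); true)
```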
We can expose a train context within which the `step!` takes place, and the callback has access to the local scope too.
What does that mean, sorry?
Ignore that comment. I posted before the page refreshed with your latest comment.
This isn't saying that we should be forcing more people into the train function; it's saying that we can make the train function more usable by itself. I am for proper docs, FWIW.
I think the bit that I most want is exposing the loss via
I think this PR definitely makes
I do agree that this is a major pain point of `train!`.
The thing is, from an API design perspective, our loop is pretty well engineered. There are few places to hook code into, and most of the checkpointing can easily be done in the loss function while teaching the AD to ignore it, which is pretty trivial: wrap things up in a function and don't take gradients through it. I have also proposed adding a scope in the training loop which users can work with. This would be preferable, I think, and is in line with #1017.
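As a concrete illustration of that "do it in the loss and hide it from the AD" idea, a minimal sketch assuming Zygote as the backend; the toy `model` and the `logged_loss` name are only for illustration:

```julia
using Flux, Zygote

model  = Dense(2, 1)          # toy model, for illustration
losses = Float32[]

function logged_loss(x, y)
    l = Flux.Losses.mse(model(x), y)
    # Checkpointing / logging happens inside the loss, but Zygote.ignore
    # keeps it out of the gradient computation.
    Zygote.ignore() do
        push!(losses, l)
    end
    return l
end
```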
I don't think it's a rabbit hole, because there aren't many places to add hooks to. You can have multiple prehooks and multiple callbacks already, and what you do with them is up to you. The library exposes the necessary functionality, so to speak. Doing things before the grads? I'd be interested to see an example of that come up in use.
Note to self: check out some of the things fastai does and see what we can generalise as expected places for hooking up code.
Hmmmm, this is a tough one. It strikes at the philosophy of what Flux truly is: is Flux a toolkit that you use to get things done, or is it an all-in-one application that you customize in order to perform one of the tasks it knows how to perform?

If it's the former, I support the "completely remove `train!`" option.

If it's the latter, then I think we should take a careful look around at what other people do and do our best to provide an ergonomic mapping of those hooks into the training loop. This will always be a moving target, because as new architectures and training regimens appear, this training loop will naturally shift and change. That's not necessarily a bad thing, but it will create a natural rift between the users that "paint within the box" and those that are forced to escape the confines presented by `train!`.

Anecdotally, I have never finished a project with `train!`.
Even if removing `train!` isn't on the table, the nice thing about being in Julia is that you don't have to learn some limited sub-language to interact with some incomprehensible back-end monster via its own callback API. You already know how to write a loop.
Pinging @ChrisRackauckas here to avoid the model-zoo issue getting side-tracked.
Agreed, but the point of the API discussion is not that the for loop should be the highest-level abstraction that exists anywhere; it's that it should be the highest level in Flux.jl. You don't need to look further than the numerous uses of Flux.jl in Julia, or the myriad of training-loop abstractions available in other frameworks, to see that designing a good simplified API is hard to do. It's undeniable that more keyword options in
Again, to reiterate, that was meant to mirror an existing API which had already been around for a while. There is no reason to stick to it, and an alternative has been offered already, so hopefully that takes care of that.
Not really? If I didn't tell you that we use a for loop but can't quite do training by recursion or yield-based iteration, it would be irresponsible of me as a package author. The limitations of the package, its semantics, and its supported use cases are part of this high-level abstraction as well. For folks who have been engaged, I don't doubt for a second that they'd jump into writing hugely complex routines, but for someone writing their first 200-300 lines of Julia, I wouldn't want them to have to go through that. Many savvy people on the Slack or the Zulip or anywhere would still need a hot minute to process the
I feel like we're arguing in circles a bit. I agree with you that wrapping

But as mentioned, the main discussion here is philosophical, and my reservations are not based on smaller technical details. I think it's clear where I stand on the philosophical question 😄. I tried to articulate why in the original issue, and I think @mcabbott's comments here are a good summary of that.
I think this would be a great discussion topic for an ML call :)
This is an example of doing something like #1461, where we can now add schedulers as `prehooks` and have the loss etc. available locally to users, to write better, more efficient callbacks as well as manage when updates happen via prehooks. This means prehooks can throw `SkipException` and not update the params at all, or run arbitrary code in general. Generally, it exposes the inner objects to be used by custom callbacks rather than picking them up from global scope or running into further hacks that can be bad for performance, which we have seen happen often.

Of course, this is just implementing the changes to the loop itself; the plan of documenting and advertising the `for` loop more still stands. This is to understand whether that kind of API can be adopted, which seems clean enough and helps plug some holes in our training routines.
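For instance, a sketch of the scheduler use case against the keyword proposed here; `prehooks` exists only in this PR, and the hook below assumes the same `loss`, `ps`, `data`, and `opt` as in the usage example above:

```julia
# Hypothetical: decay the learning rate on every step and let the update
# proceed. In the proposed loop, `update!` runs only when all prehooks
# return true, so returning false (or throwing SkipException) blocks it.
function decay_lr(l, ps, gs, d, opt)
    opt.eta *= 0.999
    return true
end

Flux.train!(loss, ps, data, opt, prehooks = decay_lr)
```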