Rethink train design and better callbacks support #1461

Open
DhairyaLGandhi opened this issue Jan 12, 2021 · 15 comments

@DhairyaLGandhi
Member

There are a few cases where I find myself wondering whether we should make it more explicit how the train loop can be extended, so that callbacks don't have to cheat to get at things like the loss. Further, packages like FluxTraining.jl show that we lack a set of ready-made callbacks that users wouldn't need to rewrite.

So, keeping this in mind, I think using pullback instead of gradient would be a step in that direction, as would running a prehook before the optimisation step to check callback conditions. This should also fit in nicely with how we want to set up schedulers. I would also like to figure out where distributed and multi-GPU training fall in this, so we know how to proceed.
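To make the pullback point concrete, here is a small framework-free Julia sketch. It assumes a toy scalar linear model and a hand-written `loss_pullback` (a hypothetical name, not Flux or Zygote API) that mimics the shape of Zygote's `pullback`: it returns the loss together with a closure mapping a seed to the gradient. A `gradient`-style helper discards the loss, so anything that wants to log it has to run the forward pass again; with the pullback shape the loss is already in hand.

```julia
# Toy model y ≈ w*x + b (hypothetical, for illustration only)
mutable struct Linear
    w::Float64
    b::Float64
end

# Hand-written analogue of a pullback for the squared error of one datapoint:
# returns the loss and a closure mapping a seed Δ to the gradient scaled by Δ
function loss_pullback(m::Linear, x, y)
    r = m.w * x + m.b - y
    back(Δ) = (w = Δ * 2r * x, b = Δ * 2r)
    return r^2, back
end

# gradient-style helper: the loss is computed internally and then discarded
grad_only(m::Linear, x, y) = last(loss_pullback(m, x, y))(1.0)

data = [(x, 2x + 1) for x in 0.0:0.5:2.0]
m = Linear(0.0, 0.0)
losses = Float64[]

for _ in 1:200, (x, y) in data
    l, back = loss_pullback(m, x, y)
    gs = back(one(l))      # seed with one(l), as Zygote's gradient does for a scalar loss
    push!(losses, l)       # the loss is available without a second forward pass
    m.w -= 0.05 * gs.w
    m.b -= 0.05 * gs.b
end
```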

We don't necessarily want to return the losses, but perhaps a slightly more trained model? This would be in line with the direction Optimisers.jl is taking as well.
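A sketch of the "return a slightly more trained model" idea in plain Julia, assuming an immutable toy model: `train` (a hypothetical name) makes one pass over the data and returns a new model instead of mutating parameters or returning losses. This only illustrates the shape of such an API; it is not Optimisers.jl's actual interface.

```julia
# Immutable toy model: an update produces a new value instead of mutating
struct Linear
    w::Float64
    b::Float64
end

(m::Linear)(x) = m.w * x + m.b

# Gradient of the squared error for one datapoint
function grad(m::Linear, x, y)
    r = m(x) - y
    return (w = 2r * x, b = 2r)
end

# Hypothetical: one pass over the data, returning the updated model
function train(m::Linear, data; lr = 0.05)
    for (x, y) in data
        g = grad(m, x, y)
        m = Linear(m.w - lr * g.w, m.b - lr * g.b)
    end
    return m
end

data = [(x, 2x + 1) for x in 0.0:0.5:2.0]
m0 = Linear(0.0, 0.0)
# Thread the model through 200 passes; m0 itself is never modified
m1 = foldl((m, _) -> train(m, data), 1:200; init = m0)
```

Because nothing is mutated, the caller decides what to keep (the old model, the new one, or both), and recording losses becomes the caller's business rather than the loop's.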

xref #1017 #1067 #1434

cc @lorenzoh

@CarloLucibello
Member

Here I totally second @ToucheSir's opinion expressed on Discourse:

> I’m of the somewhat extreme opinion that train! should be removed outright and Flux should only expose gradient/pullback. The one-function API looks nice on the surface, but it is horribly inflexible and obscures a lot of common errors that crop up in training loops. PyTorch seems to do just fine without one too, so it’s not like removing it will cause an unmitigated UX disaster.

We should deprecate train! entirely and let higher-level packages such as FluxTraining.jl handle this.

@DhairyaLGandhi
Member Author

The loop is meant to make training easier, and we have seen it being useful in many cases. I am with you in wanting to do away with it, but I also want to make sure the rest of the tooling is available first, and actually make it more powerful without sacrificing performance. I don't think there is an ideal loop, but one that can hold most cases, and have the tooling to extend and offer a guiding principle around best practices related to various bits of training is definitely worth it.

@ToucheSir
Member

I don't mind having something like train! around, but perhaps the question is how many people are using it because that's all they need for their model, and how many are using it because it looks like what you should be using. For example, Keras has fit, so train! must == fit and thus I should use only train! (with the implication that custom loops are scary and for advanced users only).

> I don't think there is an ideal loop, but one that can hold most cases, and have the tooling to extend and offer a guiding principle around best practices related to various bits of training is definitely worth it.

I agree with this, but "hold most cases" is a slippery slope into something like https://github.com/lorenzoh/FluxTraining.jl/blob/master/src/train.jl. Again, there's the question of intent/scope. If Flux is fine with being an opinionated, multilevel framework, then something like train! makes perfect sense. If Flux is looking to be less opinionated on higher-level APIs or eschew them altogether, then train! sticks out like a sore thumb.

@atiyo
Contributor

atiyo commented Jan 12, 2021

While there might not be a single ideal loop, I think it should be possible to have customisation that fits in nicely with train! while maintaining easy-to-use defaults.

Frameworks like PyTorchLightning are quite high level, but allow for custom training loops, for example.

For something similar in Flux, we could introduce a new method for train! that accepts a function which describes a single training step.
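One way such a method could look, as a plain-Julia sketch: the outer loop stays fixed and the work done per batch is a user-supplied function. `train_with_step!` and `sgd_step!` are hypothetical names (not Flux API), and the toy linear model stands in for a real Flux model. Taking the step function as the first argument lets callers pass a custom step with Julia's `do`-block syntax.

```julia
mutable struct Linear
    w::Float64
    b::Float64
end

# Hypothetical: a fixed outer loop; `step(model, batch)` does the per-batch work
function train_with_step!(step, model, data; epochs = 1)
    for _ in 1:epochs, batch in data
        step(model, batch)
    end
    return model
end

# A default step: plain SGD on the squared error of one (x, y) pair
function sgd_step!(m::Linear, (x, y); lr = 0.05)
    r = m.w * x + m.b - y
    m.w -= lr * 2r * x
    m.b -= lr * 2r
    return r^2
end

data = [(x, 2x + 1) for x in 0.0:0.5:2.0]

# Default behaviour: pass the built-in step
m = train_with_step!(sgd_step!, Linear(0.0, 0.0), data; epochs = 200)

# Custom behaviour: a do-block that wraps the default step and also counts steps
nsteps = Ref(0)
m2 = train_with_step!(Linear(0.0, 0.0), data; epochs = 10) do model, batch
    nsteps[] += 1
    sgd_step!(model, batch)
end
```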

I have no strong feelings about the above, but thought I would raise it since PyTorchLightning's abstractions for training loops seem to have attracted some fans. The fact that PyTorchLightning's abstractions are not native to PyTorch might indicate the value in having separate high- and low-level libraries.

@ToucheSir
Member

That's what something like FluxTraining is doing already. You may want to have a look at the extensive back-and-forths we've had about training loop design on Zulip.

@darsnack
Member

I agree with @CarloLucibello here. There are so many ways to design a training loop that it's going to be impossible to handle every case. train! really serves as a pedagogical example for how each piece of the training iteration comes together in Flux.

FluxTraining itself relies on other packages for pieces of the full workflow. Trying to put even a simple portion of that into Flux seems like asking for maintenance overheads that we can't service. It also doesn't add much value to the Flux ecosystem. Julia users have no qualms about installing Flux + FluxTraining (or any other high level package). Multiple packages with the correct abstractions is how our ecosystem works.

@lorenzoh
Member

I also think that this doesn't need to be something Flux.jl handles. I would rather have a full-fledged solution, and I think that is out of scope for Flux.jl itself, considering the complexity of FluxTraining.jl.

> Multiple packages with the correct abstractions is how our ecosystem works.

+1 this, composable packages are the way to go where possible.

@CarloLucibello
Member

We could adopt a gentle deprecation path, since FluxTraining.jl is not ready for its debut (or is it?): remove train! from the docs and the model-zoo examples for Flux v0.12, deprecate it in v0.13, and remove it later.

CarloLucibello added this to the v0.12 milestone Jan 16, 2021
@DhairyaLGandhi
Member Author

I definitely think that adding the correct abstractions is an important bit. FluxTraining.jl is a very opinionated package in terms of training routines, so it's harder to justify it as a catch-all. Its flexibility, imo, should come from making the callbacks consistent and available more easily to be used directly with the same kind of semantics as Flux. I feel there is benefit to having the train function in, because it describes the semantics we expect and is sufficient for most models. But we need to message it appropriately, to suggest that it might be used in multiple ways, or that the for loop is a first class api that may be preferred for different packages and training routines, and point to examples showing it.

@DhairyaLGandhi
Member Author

This might mean that we flesh out the docs or the function and point to more directly catered packages in the ecosystem. I don't see how that takes away from the composable nature of the Julia ecosystem; rather, it formalizes how we have built the abstractions so far.

@lorenzoh
Member

Regarding FluxTraining.jl: if you want to do standard supervised learning it already works great and has a similar feature set to fastai's training loop (barring mixed-precision training and advanced data augmentation schemes).

It is also possible to write custom training loops for things like GAN training, though not always elegantly, due to how the state is set up. So there is some more design work to be done to support other training paradigms cleanly. I haven't done that yet, since I am doing pretty basic supervised training in my current work; maybe someone who works more actively with other training paradigms like GANs and self-supervised learning can weigh in on what FluxTraining.jl is missing to support those use cases.

@darsnack
Member

> the for loop is a first class api

This is something that I completely agree with.

> This might mean that we flesh out the docs or the function and point to more directly catered packages in the ecosystem

More complete examples with for loops would be a good idea.

train! is not a for loop. It is a loop internally, but the interface exposed to users is a single function. This is why there is a need for callbacks at all. The user has no access to the for loop, so callbacks are necessary to allow the user an entry point.

> making the callbacks consistent and available more easily to be used directly with the same kind of semantics as Flux

The implementation of a callback in simple for loop semantics is a function call. Doing anything more complicated would only make the for loop less simple.

More documentation for how users can write their own train! loops seems like the answer here instead of designing a callback system.
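A sketch of what such documentation might show, in plain Julia with a toy linear model standing in for a Flux model: each "callback" (recording the loss, logging, early stopping) is just a line in the loop body, and stopping early is a `break` rather than an exception. The model, learning rate and thresholds are illustrative assumptions.

```julia
mutable struct Linear
    w::Float64
    b::Float64
end

data = [(x, 2x + 1) for x in 0.0:0.5:2.0]
m = Linear(0.0, 0.0)
lr = 0.05
history = Float64[]

for epoch in 1:1_000
    epoch_loss = 0.0
    for (x, y) in data
        r = m.w * x + m.b - y
        epoch_loss += r^2
        m.w -= lr * 2r * x
        m.b -= lr * 2r
    end
    push!(history, epoch_loss)                                      # record the loss
    epoch % 50 == 0 && println("epoch $epoch: loss = $epoch_loss")  # log
    epoch_loss < 1e-8 && break                                      # early stopping
end
```

Nothing here needs a callback API: the loss, the gradients and the model are all ordinary local variables the user can inspect at any point in the loop.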

@lorenzoh
Member

> More documentation for how users can write their own train! loops seems like the answer here instead of designing a callback system.

Agreed, train! is only a small improvement API-wise but a large restriction in extensibility compared to writing a simple for-loop.

@DhairyaLGandhi
Member Author

> train! is not a for loop.

That is exactly correct, and thanks for bringing it up. This also harkens back to #1461 (comment), where part of the goal would be to expose functionality within this loop, be that through pre/posthooks or through scheduling points that give control to user-written code in other ways.

> Doing anything more complicated would only make the for loop less simple

Yes, and this thread is for weighing those schemes. I don't think having stubs would complicate the function to any meaningful degree, as long as the objective is to make tinkering with the loop possible.

> large restriction in extensibility compared to writing a simple for-loop

Since the question here is to see how to make extensible designs for the train function, I think this is subsumed?

@DhairyaLGandhi
Member Author

I was thinking of something like this to expose the loop to users. Users can add containers to hold some params, and run arbitrary code before and after the optimisation step:

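```julia
using Flux

# Proposed (not existing) API: a callable callback that records the loss
struct Callback{T}
  losses::Vector{T}
end

# Proposed callback/prehook signature:
# l:   loss at the datapoint
# ps:  params (could be skipped, but passing them avoids relying on globals)
# gs:  gradients at the datapoint, for inspection
# d:   the datapoint
# opt: the optimiser, which can be modified based on some condition
(cb::Callback)(l, ps, gs, d, opt) = push!(cb.losses, l)

# Runs before the optimisation step; throwing SkipException would skip this update
prehook(l, ps, gs, d, opt) = throw(Flux.Optimise.SkipException())

c = Callback(Float32[])
# `() -> ()` is a zero-argument callback in the style train! accepts today;
# the 5-argument callback and the `prehooks` keyword below are proposals
Flux.train!(loss, ps, data, opt, cb = [() -> (), c])
Flux.train!(loss, ps, data, opt, prehooks = prehook)
```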
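The proposed semantics can be tried out without changing Flux. The following plain-Julia sketch implements a loop with the proposed `(l, ps, gs, d, opt)` signature for both callbacks and prehooks. Everything here is hypothetical: `SkipStep` plays the role of `Flux.Optimise.SkipException`, `SGD` is a minimal mutable optimiser so that a callback can change its learning rate (as a scheduler would), and a toy linear model stands in for `ps`.

```julia
# Hypothetical stand-ins (not Flux API)
struct SkipStep <: Exception end

mutable struct SGD
    lr::Float64
end

# Toy model standing in for `ps`
mutable struct Linear
    w::Float64
    b::Float64
end

struct Callback{T}
    losses::Vector{T}
end
(cb::Callback)(l, ps, gs, d, opt) = push!(cb.losses, l)

# Proposed loop: prehooks run after the loss and gradients are computed but
# before the update; callbacks run after it. Both receive (l, ps, gs, d, opt).
function train_sketch!(m::Linear, data, opt::SGD; cb = (), prehooks = ())
    for d in data
        x, y = d
        r = m.w * x + m.b - y
        l, gs = r^2, (w = 2r * x, b = 2r)
        try
            foreach(h -> h(l, m, gs, d, opt), prehooks)
        catch e
            e isa SkipStep || rethrow()
            continue                     # skip the update and the callbacks
        end
        m.w -= opt.lr * gs.w
        m.b -= opt.lr * gs.b
        foreach(f -> f(l, m, gs, d, opt), cb)
    end
    return m
end

data = [(x, 2x + 1) for x in 0.0:0.5:2.0]
c = Callback(Float64[])
decay(l, ps, gs, d, opt) = (opt.lr *= 0.999)                     # scheduler-like callback
skip_zero(l, ps, gs, d, opt) = d[1] == 0.0 && throw(SkipStep())  # prehook
opt = SGD(0.05)
m = Linear(0.0, 0.0)
for _ in 1:100
    train_sketch!(m, data, opt; cb = (c, decay), prehooks = (skip_zero,))
end
```

Here the prehook skips every datapoint with `x == 0`, so each of the 100 passes records four losses and decays the learning rate four times.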


6 participants