[Refactoring] Training functions and TrainConfig can be made more modular #570
Comments
Hey @debrevitatevitae, thanks for raising this. For the refactoring, I had a few ideas and suggestions.
The simple outline of the train function would look something like this.
I had two thoughts where I needed input, though. Q2: Do we need the separate functions train_with_grad and train_without_grad, or can we have a single Train (function/class) that is used for either based on a user-defined argument? In other words, does the trainer need to be a function, and do we need separate train_with_grad/train_without_grad functions? Let me know what you think. Once these ideas are refined, I will start working on a PR for this.
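The single entry point asked about in Q2 could be sketched as follows. This is a minimal, hypothetical outline (the names and signatures are assumptions, not the actual qadence API), with the two existing loops stubbed out:

```python
def _train_with_grad(model, dataloader, config, loss_fn):
    # stand-in for the existing gradient-based training loop
    return "with_grad"


def _train_without_grad(model, dataloader, config, loss_fn):
    # stand-in for the existing gradient-free training loop
    return "without_grad"


def train(model, dataloader, config, loss_fn, *, use_grad: bool = True):
    """Single entry point: dispatch to the gradient-based or
    gradient-free loop based on a user-defined flag."""
    branch = _train_with_grad if use_grad else _train_without_grad
    return branch(model, dataloader, config, loss_fn)
```

With this shape, the public API stays a single `train`, and the split between the two loops becomes an internal detail.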
I think a Train class is fine, but it will need a bit of readapting of code in other places. For Q2, if these can be factored into one, that would be better. I'd like to see a prototype here.
Hey, so here is how we can refactor this code. We will have a common Trainer class, where:
1) ConfigManager: a class whose object handles training configurations, including initialization, logging, loading, saving, etc.
Below are the codes for all three of these. This is still rough and I am cleaning up the edges; I am sharing it to get reviews on the structure of the code, and am open to suggestions.
The biggest benefit of having it as a class will be inheritance. E.g., if in another library someone wants to add an extra callback after each epoch of training, they can just inherit the Trainer class and modify it accordingly. Note that we can modify the training logic in run_train_iter to switch from _with_grad to _without_grad.
Note: the Trainer class takes train_dataloader, val_dataloader, test_dataloader separately, which is different from the previous method. This is debatable; I separated them for a cleaner division of operations and methods, but let me know if you think otherwise. Thanks! @chMoussa @inafergra @smitchaudhary @debrevitatevitae @RolandMacDoland
ConfigManager: to manage TrainConfigurations
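The original code for the ConfigManager was not preserved in this thread; a rough sketch of the idea described above (all names here are assumptions, and TrainConfig is trimmed to a stand-in) might look like:

```python
import json
from dataclasses import dataclass, asdict


@dataclass
class TrainConfig:
    # trimmed stand-in for the real TrainConfig
    max_iter: int = 100
    checkpoint_every: int = 10
    folder: str = "./runs"


class ConfigManager:
    """Owns initialization, validation, saving and loading of a
    TrainConfig, so that TrainConfig.__post_init__ stays trivial."""

    def __init__(self, config: TrainConfig):
        self.config = config
        self._validate()

    def _validate(self) -> None:
        # validation moves here, out of __post_init__
        if self.config.max_iter <= 0:
            raise ValueError("max_iter must be positive")

    def save(self, path: str) -> None:
        with open(path, "w") as f:
            json.dump(asdict(self.config), f)

    @classmethod
    def load(cls, path: str) -> "ConfigManager":
        with open(path) as f:
            return cls(TrainConfig(**json.load(f)))
```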
CallbacksManager:
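The CallbacksManager code was likewise lost from this thread; a deliberately generic sketch in the spirit described above (hypothetical names, not tied to the train function) could be:

```python
from typing import Callable, List, Optional


class CallbacksManager:
    """Generic registry of callbacks: register any callable and run
    them in order, with no knowledge of the training loop itself."""

    def __init__(self, callbacks: Optional[List[Callable]] = None):
        self.callbacks = list(callbacks or [])

    def add(self, callback: Callable) -> None:
        self.callbacks.append(callback)

    def run(self, trainer) -> None:
        # each callback receives the trainer, so it can inspect state
        for cb in self.callbacks:
            cb(trainer)
```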
Trainer:
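The Trainer code is also missing from the scrape; a minimal sketch of the structure described above (separate dataloaders, a run_train_iter method where the _with_grad/_without_grad switch would live; every name here is an assumption) might be:

```python
class Trainer:
    """Orchestrates the training loop. Takes train/val/test
    dataloaders separately, unlike the previous train functions."""

    def __init__(self, model, optimizer, loss_fn,
                 train_dataloader=None, val_dataloader=None,
                 test_dataloader=None, max_iter: int = 100):
        self.model = model
        self.optimizer = optimizer
        self.loss_fn = loss_fn
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.test_dataloader = test_dataloader
        self.max_iter = max_iter
        self.history = []

    def run_train_iter(self, batch):
        # switching between _with_grad and _without_grad logic would
        # happen here, inside a single overridable method
        loss = self.loss_fn(self.model, batch)
        self.history.append(loss)
        return loss

    def fit(self):
        for _ in range(self.max_iter):
            for batch in self.train_dataloader:
                self.run_train_iter(batch)
        return self.model
```

Subclasses in other libraries could then override `run_train_iter` (or a future per-epoch hook) without touching the rest of the loop, which is the inheritance benefit mentioned above.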
Tagging @arthurfaria to keep him in the loop
@mlahariya Thanks for the effort. I would have more to say on the topic of callbacks. First, I would suggest the callback manager have one attribute that is a mapping str -> list of callbacks, where str is one of a set of predefined strings. Second, when introducing callbacks, the idea was to quickly define at which place in the training loop they would be run, without worrying about their type. That would also mean introducing a mapping of callbacks in the Trainer to identify which callbacks should run in given on_train and on_val functions. Or should that actually be introduced in the callback manager itself?
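The first suggestion above, a manager whose single attribute maps predefined hook names to callback lists, could be sketched like this (the hook names are assumptions for illustration):

```python
from typing import Callable, Dict, List

HOOKS = ("on_train_start", "on_train_end", "on_epoch_start",
         "on_epoch_end", "on_val_start", "on_val_end")


class CallbacksManager:
    """One attribute: a mapping from a predefined hook name to the
    list of callbacks to run at that point of the loop."""

    def __init__(self):
        self.callbacks: Dict[str, List[Callable]] = {h: [] for h in HOOKS}

    def register(self, hook: str, callback: Callable) -> None:
        if hook not in self.callbacks:
            raise ValueError(f"unknown hook {hook!r}; expected one of {HOOKS}")
        self.callbacks[hook].append(callback)

    def run(self, hook: str, trainer=None) -> None:
        for cb in self.callbacks[hook]:
            cb(trainer)
```

The Trainer would then simply call `manager.run("on_epoch_end", self)` at the matching point, so users never worry about the type of an individual callback.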
@chMoussa Callbacks would be run in the on_train/on_val functions. I am hesitant to define this in the current callback manager, mainly because it makes the manager very specialized to this train function; it would be better to keep it generic. Yet we could also try a completely different approach: we can modify the base callback class so that the segmentation into on_epoch_start/on_train_start hooks is stored there, and use those hooks during the run. This would be in line with the overall structure of the training process, and the callback manager would be simplified as well. Let me review both of these, and I will select the one that makes the most sense. I will start committing changes in the branch.
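The alternative mentioned here, pushing the hook segmentation into the base callback class itself, might look like this (a hedged sketch; class and method names are assumptions):

```python
class Callback:
    """Base class carrying the hook segmentation itself: the manager
    or Trainer only needs to call the right method at each point."""

    def on_train_start(self, trainer): ...
    def on_train_end(self, trainer): ...
    def on_epoch_start(self, trainer): ...
    def on_epoch_end(self, trainer): ...
    def on_val_start(self, trainer): ...
    def on_val_end(self, trainer): ...


class EpochCounter(Callback):
    """Example user callback overriding a single hook."""

    def __init__(self):
        self.epochs_seen = 0

    def on_epoch_end(self, trainer):
        self.epochs_seen += 1
```

With this design the manager indeed simplifies: it just keeps a flat list of Callback objects and invokes, say, `cb.on_epoch_end(trainer)` on each at the right moment.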
@mlahariya As I'm not an ML guy, I can't provide insightful comments here. Generally, it looks like a reasonable refactor and I support it. Happy for you to proceed once you've found a common ground that suits everybody.
Describe the feature
Refactoring the training system (training functions and training configuration) in more internal classes or functions.
It should be implemented because
Currently, there are implementation and readability problems with the training functions and TrainConfig:
- The training functions are too long (train_grad counts more than 360 lines of code!). Therefore, on one hand they have too many responsibilities, which makes them hard to maintain and extend; on the other hand, they are hard to read.
- TrainConfig unloads too much responsibility onto __post_init__, which causes unwanted behavior such as this one.
Modularization and responsibility separation will help the readability, maintenance and extensibility of the training system.
Additional context
No response
Would you like to work on this issue?
None