
[Refactoring] Training functions and TrainConfig can be made more modular #570

Closed
debrevitatevitae opened this issue Sep 17, 2024 · 8 comments · Fixed by #593
Assignees
mlahariya
Labels
refactoring Refactoring of legacy code

Comments

@debrevitatevitae
Collaborator

Describe the feature

Refactor the training system (training functions and training configuration) into more modular internal classes and functions.

It should be implemented because

Currently, there are implementation and readability problems with the training functions and TrainConfig:

  • Training functions are too long (train_grad alone is more than 360 lines of code!). On one hand, they have too many responsibilities, which makes them hard to maintain and extend; on the other hand, they are hard to read.
  • TrainConfig offloads too much responsibility onto __post_init__, which causes unwanted behavior such as this one.

Modularization and separation of responsibilities will improve the readability, maintainability, and extensibility of the training system.

Additional context

No response

Would you like to work on this issue?

None

@debrevitatevitae added the refactoring (Refactoring of legacy code) label and removed the feature (New feature or request) label on Sep 17, 2024
@mlahariya mlahariya self-assigned this Oct 3, 2024
@mlahariya
Collaborator

mlahariya commented Oct 3, 2024

Hey @debrevitatevitae, thanks for raising this. For the refactoring, I have a few ideas and suggestions.

  1. Logger: We can move the logging part of the training functions out into a separate class (Logger). An instance of this class would take care of all logging operations, including defining callbacks and providing methods to log at different steps of the training loop.
  2. ConfigHandler: We can define a config handler class that collects the methods for handling TrainConfig. This would help in the initial and final stages of training, where we can define separate methods based on the configuration provided by the user. We can also move the __post_init__ logic into this class, allowing us to resolve this issue.

A simple outline of the train function would look something like this:

Trainer

Train(args):
  logger = Logger(config)
  confighandler = ConfigHandler(config)

  # training
  # and logging

  # end logging and close the logger
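
For concreteness, a minimal Python sketch of this outline could look as follows. Logger and ConfigHandler are the hypothetical classes from the suggestion above, and the method names (setup, log_iteration, close) are purely illustrative, not existing qadence APIs:

def train(model, dataloader, optimizer, config, loss_fn):
    logger = Logger(config)                # owns callbacks and all logging operations
    confighandler = ConfigHandler(config)  # folder / writer / checkpoint handling

    confighandler.setup()
    for iteration in range(config.max_iter):
        loss, metrics = loss_fn(model, next(iter(dataloader)))
        # ... optimization step ...
        logger.log_iteration(iteration, loss, metrics)

    logger.close()
    return model, optimizer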

I had two thoughts where I need input, though:
Q1: Currently this is defined as a train function. Can we move it to a Train class (a structure similar to the Trainer class from PyTorch Lightning)?
-- You would be able to call trainer.fit().
-- fit/step/log and other methods would become available for the user to modify later on.

Q2: Do we need the separate functions train_with_grad and train_without_grad? Can we have a single Train (function/class) that can be used for either, based on a user-defined argument?

So, do we need the trainer to be a function, and do we need separate functions for train_with_grad/train_without_grad?

Let me know what you think. Once these ideas are refined - I will start working on a PR for this.
@chMoussa @Roland-djee @smitchaudhary @n-toscano
Thanks
M

@chMoussa
Collaborator

chMoussa commented Oct 7, 2024

> @mlahariya's proposal quoted above (omitted here to avoid duplication).

I think a Train class is fine, but it will need a bit of code readapting in other places. For Q2, if these can be factored into one, it would be better. I'd like to see a prototype here.
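
For instance, the two entry points could be folded into a single function behind a flag. The sketch below is purely illustrative (signatures are simplified, and train_with_grad/train_without_grad stand for the existing qadence functions):

def train(model, dataloader, optimizer, config, loss_fn, use_grad: bool = True):
    # Single user-facing entry point; the flag selects the existing implementation.
    if use_grad:
        return train_with_grad(model, dataloader, optimizer, config, loss_fn=loss_fn)
    return train_without_grad(model, dataloader, optimizer, config, loss_fn=loss_fn)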

@mlahariya
Collaborator

mlahariya commented Oct 17, 2024

Hey,

So here is how we can refactor this code. We will have a common Trainer class, where .fit() would allow us to train the model. We can achieve this with the following:

1) ConfigManager: A class whose object handles the training configuration, including initialization, logging, loading, saving, etc.
2) CallbacksManager: A class whose object handles the different types of possible callbacks. I notice that we have a bunch of different callbacks, called at different steps of the training process; this will allow us to centralize and modularize them.
3) Trainer: A trainer class with all the support for training, along with a fit function. It should have properties (model/optimizer/config, etc.) and also allow for modifications at different steps of the training.

Below is the code for all three of these. It is still rough and I am cleaning up the edges. I am sharing it to get reviews on the structure of the code, and I am open to suggestions.

Example
Previously the training required

    data = to_dataloader(x, y, batch_size=batch_size, infinite=True)
    train_with_grad(model, data, optimizer, config, loss_fn=loss_fn)

After this implementation you would need to change this to

    train_data = to_dataloader(x, y, batch_size=batch_size, infinite=True)
    trainer = Trainer(model, optimizer, config, loss_fn=loss_fn, train_dataloader=train_data)
    trainer.fit()

The biggest benefit of having it as a class will be inheritance. E.g. if, in another library, someone wants to add an extra callback after each epoch of training, they can just inherit from the Trainer class and override the on_epoch_end() method, and that's it!
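
A hedged sketch of such an extension (the subclass name and the printed message are illustrative; on_epoch_end follows the Trainer draft below):

class MyTrainer(Trainer):
    def on_epoch_end(self, loss_metrics):
        super().on_epoch_end(loss_metrics)
        # Extra per-epoch behaviour added by the subclass, e.g. custom logging.
        print(f"epoch {self.current_epoch}: loss = {loss_metrics[-1]['loss']:.4f}")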

Note: we can modify the training logic in run_train_iter to switch from _with_grad to _without_grad.
Note: training is divided into epochs -> batches.

  • Currently the training logic refers to iterations as epochs. We can address this by renaming max_iter to max_epochs.
  • The batch_size could be the batch_size of the dataloader.

Note: the Trainer class takes train_dataloader, val_dataloader and test_dataloader separately, which is different from the previous method. This is debatable; I separated them for a cleaner division of operations and methods, but let me know if you think otherwise.
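
For example, usage with separate loaders might look like this (a sketch; the loader names are illustrative):

trainer = Trainer(
    model,
    optimizer,
    config,
    loss_fn="mse",
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    test_dataloader=test_loader,
)
trainer.fit()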

Thanks!

@chMoussa @inafergra @smitchaudhary @debrevitatevitae @RolandMacDoland

ConfigManager: To manage training configurations (TrainConfig)

# Assumed imports for this sketch (ExperimentTrackingTool, load_checkpoint and logger
# are existing qadence helpers; their import paths are omitted here):
import datetime
import math
import os
from pathlib import Path
from typing import Union

from torch.utils.tensorboard import SummaryWriter

class ConfigManager:
    def __init__(self, config: TrainConfig):
        self.config = config
        self.writer = None
        self.best_val_loss = math.inf
        self.init_iter = 0

    def initialize_config(self):
        self._initialize_folder()
        self._initialize_writer()

    def _initialize_folder(self):
        if self.config.folder:
            self.config.folder = self._create_subfolder(self.config.folder)

    def _create_subfolder(self, folder: Union[str, Path]) -> Path:
        folder_path = Path(folder)
        if self.config.create_subfolder_per_run:
            subfolder_name = datetime.datetime.now().strftime("%Y%m%dT%H%M%S") + "_" + hex(os.getpid())[2:]
            folder_path = folder_path / subfolder_name
        folder_path.mkdir(parents=True, exist_ok=True)
        return folder_path

    def _initialize_writer(self):
        if self.config.tracking_tool == ExperimentTrackingTool.TENSORBOARD:
            self.writer = SummaryWriter(log_dir=str(self.config.folder))
        elif self.config.tracking_tool == ExperimentTrackingTool.MLFLOW:
            self.writer = self.config.mlflow_config
        else:
            self.writer = None

    def load_checkpoint(self, model, optimizer, device):
        log_device = "cpu" if device is None else device
        init_iter = self.init_iter
        if self.config.folder:
            model, optimizer, init_iter = load_checkpoint(self.config.folder, model, optimizer, device=log_device)
            logger.debug(f"Loaded model and optimizer from {self.config.folder}")
        return model, optimizer, init_iter

    def write_checkpoint(self, iteration, is_best=False):
        # Add logging. Omitted for readability.
        pass

    def log_hyperparameters(self, metrics):
        # Add logging. Omitted for readability.
        pass

    def log_training_metrics(self, loss, metrics, iteration):
        # Add logging. Omitted for readability.
        pass

    def log_model(self, model, dataloader):
        # Add logging. Omitted for readability.
        pass

    def close_writer(self):
        if self.writer:
            if self.config.tracking_tool == ExperimentTrackingTool.TENSORBOARD:
                self.writer.close()
            elif self.config.tracking_tool == ExperimentTrackingTool.MLFLOW:
                self.writer.end_run()


CallbacksManager:
For now, I have only included methods for 4 callback types (print/write/plot/checkpoint); more can be added later. It can simply be used via .run(opt_res, type); a short usage sketch follows the class below.

class CallbacksManager:
    def __init__(self, writer, config):
        self.print_callbacks: List[Callback] = []
        self.write_callbacks: List[Callback] = []
        self.plot_callbacks: List[Callback] = []
        self.checkpoint_callbacks: List[Callback] = []
        self.writer = writer
        self.config = config

    def add_print_callback(self):
        # Add callbacks. Omitted for readability.
        pass

    def add_write_callback(self):
        # Add callbacks. Omitted for readability.
        pass

    def add_plot_callback(self):
        # Add callbacks. Omitted for readability.
        pass

    def add_checkpoint_callback(self, checkpoint_at):
        # checkpoint_at could be "best" or an iteration number.
        # Add callbacks. Omitted for readability.
        pass

    def initialize_callbacks(self, checkpoint_at=None):
        self.add_print_callback()
        self.add_write_callback()
        self.add_plot_callback()
        self.add_checkpoint_callback(checkpoint_at)

    def run(self, opt_result, callback_type: str):
        # Run callbacks of the specified type
        if callback_type == "print":
            for callback in self.print_callbacks:
                callback(opt_result)
        elif callback_type == "write":
            for callback in self.write_callbacks:
                callback(opt_result)
        elif callback_type == "plot":
            for callback in self.plot_callbacks:
                callback(opt_result)
        elif callback_type == "checkpoint":
            for callback in self.checkpoint_callbacks:
                callback(opt_result)
        else:
            raise ValueError(f"Unknown callback type: {callback_type}")

Trainer:

  • First we init, define properties, and validate the inputs.
  • Second we add the training methods. The main training algorithm is implemented in .fit().
  • I have also defined default loss functions (e.g. mse_loss, etc.) that could be used out-of-the-box in qadence via the loss_fn argument; a sketch of such a loss function is shown after the first block of the class below.

# Assumed imports for this sketch; TrainConfig, OptimizeResult and logger come from
# qadence's ml_tools (imports omitted), and mse_loss / cross_entropy_loss are the
# default loss functions mentioned above.
from typing import Callable, Union

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from rich.progress import (
    BarColumn,
    Progress,
    TaskProgressColumn,
    TextColumn,
    TimeRemainingColumn,
)

class Trainer:
   def __init__(
       self,
       model: nn.Module = None,
       optimizer: optim.Optimizer = None,
       config: TrainConfig = None,
       loss_fn: Union[None, Callable, str] = None,
       train_dataloader: Union[None, DataLoader] = None,
       val_dataloader: Union[None, DataLoader] = None,
       test_dataloader: Union[None, DataLoader] = None,
   ):

       self._model = model
       self._optimizer = optimizer
       self._config = config
       self._train_dataloader = train_dataloader
       self._val_dataloader = val_dataloader
       self._test_dataloader = test_dataloader
       self.loss_fn = self.get_loss_fn(loss_fn)

   def _validate_dataloader(self, dataloader: Union[None, DataLoader], dataloader_type: str):
       # Add validations for dataloaders. Omitted for readability.
       pass

   def get_loss_fn(self, loss_fn: Union[None, Callable, str]) -> Callable:
       if callable(loss_fn):
           return loss_fn
       elif isinstance(loss_fn, str):
           if loss_fn == "mse":
               return mse_loss
           elif loss_fn == "cross_entropy":
               return cross_entropy_loss
           else:
               raise ValueError(f"Unsupported loss function: {loss_fn}")
       else:
           return mse_loss

   @property
   def model(self):
       if self._model is None:
           raise ValueError("Model has not been set.")
       return self._model

   @model.setter
   def model(self, model: nn.Module):
       self._model = model

   # Other Properties. Omitted for readability.
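
As an illustration of the loss_fn convention assumed by get_loss_fn above, a default mse_loss could look roughly like the sketch below; it assumes a batch is an (inputs, targets) pair and that a loss function returns a (loss, metrics) pair (matching build_optimize_result further down), which may differ from the final implementation:

def mse_loss(model: nn.Module, batch) -> tuple:
    # Default MSE loss; the batch is assumed to be an (inputs, targets) pair.
    inputs, targets = batch
    predictions = model(inputs)
    loss = torch.nn.functional.mse_loss(predictions, targets)
    metrics = {"mse": loss.item()}
    return loss, metrics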




# Continuation of the Trainer class above (split into two blocks for readability).
class Trainer(Trainer):
   def __init__(
       self,
       model: nn.Module = None,
       optimizer: optim.Optimizer = None,
       config: TrainConfig = None,
       loss_fn: Union[None, Callable, str] = None,
       train_dataloader: DataLoader = None,
       val_dataloader: Union[None, DataLoader] = None,
       test_dataloader: Union[None, DataLoader] = None,
       batch_size: int = None,
       writer=None,
   ):
       super().__init__(
           model=model,
           optimizer=optimizer,
           config=config,
           loss_fn=loss_fn,
           train_dataloader=train_dataloader,
           val_dataloader=val_dataloader,
           test_dataloader=test_dataloader,
       )
       self.current_epoch = 0
       self.global_step = 0
       self.best_val_loss = float("inf")
       self.batch_size = batch_size if batch_size else len(train_dataloader.dataset)

        self.callback_manager = CallbacksManager(writer, config)
       self.config_manager = ConfigManager(config)

       self.progress = Progress(
           TextColumn("[progress.description]{task.description}"),
           BarColumn(),
           TaskProgressColumn(),
           TimeRemainingColumn(elapsed_when_finished=True),
       )

   # Training loop  ------------------------------------------------------------------------------------
   def fit(self):
       self.on_train_start()
       train_loss_metrics = []

       with self.progress:
           task = self.progress.add_task("Training", total=self.config.max_iter)
           for epoch in range(self.config.max_iter):
               try:
                   self.on_epoch_start()
                   self.current_epoch = epoch
                   epoch_loss_metrics = self.run_training_epoch(self.train_dataloader)
                   train_loss_metrics.extend(epoch_loss_metrics)

                   if self.val_dataloader:
                       self.run_validation_epoch(self.val_dataloader)

                   self.on_epoch_end(epoch_loss_metrics)
                   self.progress.advance(task)
               except KeyboardInterrupt:
                   logger.info("Terminating training gracefully after the current iteration.")
                   break

       self.on_train_end(train_loss_metrics)

   def on_train_start(self):
       self.callback_manager.initialize_callbacks()
       self.config_manager.initialize_config()
        self.model, self.optimizer, self.global_step = self.config_manager.load_checkpoint(
            self.model, self.optimizer, device="cpu"
        )

   def on_epoch_start(self):
       pass

    def on_epoch_end(self, loss_metrics):
        # This is how we can use the callback_manager.
        self.callback_manager.run(self.build_optimize_result(loss_metrics), "plot")
        self.callback_manager.run(self.build_optimize_result(loss_metrics), "write")
        self.config_manager.log_training_metrics(loss_metrics[-1]['loss'], loss_metrics[-1], self.current_epoch)

    def on_train_end(self, loss_metrics):
        self.callback_manager.run(self.build_optimize_result(loss_metrics), "checkpoint")
        self.config_manager.close_writer()

   # Train ------------------------------------------------------------------------------------
   def run_training_epoch(self, dataloader):
       self.model.train()
       epoch_loss_metrics = []

       for batch_idx, batch in enumerate(dataloader):
           self.on_train_iter_start(batch)
           loss_metrics = self.run_train_iter(batch)
           epoch_loss_metrics.append(loss_metrics)  # Accumulate loss
           self.on_train_iter_end(loss_metrics)

       return epoch_loss_metrics

    def run_train_iter(self, batch):
        loss_metrics = self.loss_fn(self.model, batch)

        # Optimization logic sits here.

        self.global_step += 1
        return loss_metrics

    def on_train_iter_start(self, batch):
        pass

    def on_train_iter_end(self, loss_metrics):
        pass

   # Validation  ------------------------------------------------------------------------------------
   def run_validation_epoch(self, dataloader):
       self.model.eval()
       epoch_loss_metrics = []

       for batch_idx, batch in enumerate(dataloader):
           self.on_val_iter_start(batch)
           loss_metrics = self.run_validation_iter(batch)
           epoch_loss_metrics.append(loss_metrics)  # Accumulate validation loss
           self.on_val_iter_end(loss_metrics)

       return epoch_loss_metrics

   def run_validation_iter(self, batch):
       with torch.no_grad():
           loss_metrics = self.loss_fn(self.model, batch)
       return loss_metrics

    def on_val_iter_start(self, batch):
        pass

    def on_val_iter_end(self, loss_metrics):
        pass

   # Testing  ------------------------------------------------------------------------------------
   def run_test_epoch(self):
       if self.test_dataloader:
           self.model.eval()
           epoch_loss_metrics = []

           for batch_idx, batch in enumerate(self.test_dataloader):
               self.on_test_iter_start(batch)
               loss_metrics = self.run_test_iter(batch)
               epoch_loss_metrics.append(loss_metrics)  # Accumulate test loss
               self.on_test_iter_end(loss_metrics)

           return epoch_loss_metrics

   def run_test_iter(self, batch):
       with torch.no_grad():
           loss_metrics = self.loss_fn(self.model, batch)
       return loss_metrics

    def on_test_iter_start(self, batch):
        pass

    def on_test_iter_end(self, loss_metrics):
        pass
   
   # Other methods  ------------------------------------------------------------------------------------
    def build_optimize_result(self, loss_metrics=None):
        loss, metrics = (None, None) if loss_metrics is None else loss_metrics
        return OptimizeResult(self.global_step, self.model, self.optimizer, loss, metrics)



@inafergra
Collaborator

Tagging @arthurfaria to keep him in the loop

@chMoussa
Collaborator

@mlahariya Thanks for the effort. I have a bit more to say on the topic of callbacks.

First, I would suggest that the callback manager have one attribute that is a mapping str -> list of callbacks, where str would be one of the predefined strings (print, etc.). I would also introduce a generic add_callback method where the str can be specified, and you could keep the more specialised ones.
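
A rough sketch of that structure (hypothetical, just to make the suggestion concrete):

from collections import defaultdict

class CallbacksManager:
    def __init__(self, writer, config):
        self.writer = writer
        self.config = config
        # Mapping from callback type ("print", "write", "plot", "checkpoint", ...) to its callbacks.
        self.callbacks: dict[str, list] = defaultdict(list)

    def add_callback(self, callback_type: str, callback) -> None:
        self.callbacks[callback_type].append(callback)

    def run(self, opt_result, callback_type: str) -> None:
        for callback in self.callbacks[callback_type]:
            callback(opt_result)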

Second, when introducing callbacks, the idea was to quickly define at which place of the training they would be run, without worrying about their type. That would also mean introducing, in the Trainer, a mapping of callbacks that identifies which callbacks are run in a given on_train or on_val function. Or should that actually be introduced in the callback manager itself?

@mlahariya
Collaborator

mlahariya commented Oct 21, 2024

@chMoussa
Sure, we can add the add_callback method to the callback manager, and we can also use a dict structure to store the callbacks.

Callbacks would be run in the on_train/on_val functions. I am reluctant to define this in the current callback manager, mainly because it makes the manager very specialized to this train function; it would be better to keep it generic.

Yet, we can also try a completely different approach: we can modify the base callback class so that it stores where in the training it should run (on_epoch_start, on_train_start, etc.) and use that during the run. This would be in line with the overall structure of the training process, and it would also simplify the callback manager.
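
A rough, hypothetical sketch of that alternative (names are illustrative, not the existing qadence Callback API):

class Callback:
    def __init__(self, on: str = "on_epoch_end"):
        # The callback itself records the hook at which it should run.
        self.on = on

    def __call__(self, opt_result) -> None:
        raise NotImplementedError

class CallbacksManager:
    def __init__(self, callbacks: list):
        self.callbacks = callbacks

    def run(self, hook: str, opt_result) -> None:
        # Run every callback registered for this hook, regardless of its type.
        for cb in self.callbacks:
            if cb.on == hook:
                cb(opt_result)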

Let me review both of these, and I will select the one that makes the most sense. I will start committing changes to the branch.

@RolandMacDoland
Collaborator

@mlahariya As I'm not an ML guy, I can't provide insightful comments here. Generally, it looks like a reasonable refactor and I support it. Happy for you to proceed whenever you've found a common ground that suits everybody.

@mlahariya
Collaborator

mlahariya commented Oct 30, 2024

Hey.

Here is the final refactor that I think could be used: all log-related activities should pass through the callback system, a config manager handles the training configuration, and a trainer supports training the models.

I will create an MR and start committing changes.

[Diagram of the proposed architecture: Trainer, ConfigManager, and the callback system]
