Trainer controller

Trainer controller is a framework for controlling the trainer loop using user-defined rules and metrics.


This framework helps users define rules that capture scenarios such as criteria for stopping an ongoing training run (e.g., validation loss reaching a certain target, validation loss increasing with epochs, or training loss increasing over the last 100 steps).


Note: Evaluation loss and validation loss are the same.

  1. The trainer controller is driven by a configuration file that the user supplies at the start of training. The following sample shows how to enable the trainer controller for a training job by passing the path of an existing configuration file (here, epoch-level-eval-loss-below-threshold.yaml from the ./examples/trainercontroller_configs directory) with the --trainer_controller_config_file flag:

    python ./tuning/  \
    --trainer_controller_config_file "$EXAMPLE_CONFIGS/epoch-level-eval-loss-below-threshold.yaml" \

  2. This illustration uses epoch-level-eval-loss-below-threshold.yaml from the ./examples/trainercontroller_configs directory, shown below:

    controller_metrics:
      - name: trainer_state
        class: TrainingState
      - name: evalmetric
        class: EvalMetrics
    controllers:
      - name: epoch_level_eval_loss_below_threshold
        triggers:
          - on_epoch_end
        rule: 'evalmetric["eval_loss"] < 2.25 and trainer_state["epoch"] > 2'
        operations:
          - hfcontrols.should_training_stop

    Here is a brief primer on the above configuration. More details can be found here. Note that in the following descriptions, we use metric and metric handler interchangeably to describe a class that exposes numeric information about the training state or related computations for use in a rule, for example a rule for early termination.

    • Description: The above configuration stops the training when the evaluation loss drops below 2.25 after two epochs.

    • Metrics: The configuration uses two metrics listed under the controller_metrics section. One is named evalmetric and uses the built-in metric handler class EvalMetrics to expose the evaluation loss. The other, trainer_state, uses the built-in metric handler class TrainingState to expose the current epoch. Both are referenced by name in the rule shown above. Other metrics can be used in place of evalmetric and trainer_state. At the time of writing, the supported metric handler classes are as follows:

      • Loss: This metric handler exposes the training loss after every on_log event. See more on trainer events here.

      • TrainingState: This metric exposes the trainer state (more on trainer state can be found here). Here is an example which uses both the TrainingState and Loss metrics.

      • EvalMetrics: This metric exposes all the evaluation metrics used in the training job (e.g., evaluation/validation loss). Here is an example config which uses the eval_loss value exposed by EvalMetrics.

      • HistoryBasedMetric: This metric exposes a moving window of evaluation metrics and training loss. It is useful for writing rules over a history of values (i.e., evaluation metrics and training loss). The following examples illustrate how this metric can be used:

        • epoch-level-eval-loss-patience.yaml: This configuration performs a threshold test on the evaluation loss with a patience threshold of 2. To illustrate patience: suppose the evaluation loss lower threshold is 2 and the patience threshold is 3. The trainer controller will then not take an action (e.g., stop the training) until the rule has been true (i.e., the evaluation loss is lower than 2) three consecutive times.
        • non-decreasing-training-loss.yaml: This configuration compares the first and last values of a window of training loss samples and determines if the training loss has increased or not. If there is an increase, the training is stopped.
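The first-versus-last comparison described for non-decreasing-training-loss.yaml can be sketched in a few lines of Python. This is only an illustration, not the framework's code: the helper name `training_loss_increased` and the window size of 3 are assumptions made for the example, and `collections.deque` with `maxlen` stands in for the moving window.

```python
from collections import deque


def training_loss_increased(losses) -> bool:
    """Return True when the newest loss in the window exceeds the oldest one."""
    if len(losses) < 2:
        return False  # not enough samples to compare yet
    return losses[-1] > losses[0]


# Hypothetical window of the three most recent training-loss samples.
window = deque(maxlen=3)
outcomes = []
for loss in [2.0, 1.8, 1.7, 1.9, 2.1]:
    window.append(loss)  # the oldest sample is dropped once the window is full
    outcomes.append(training_loss_increased(window))
```

With these sample values the check stays false while the loss is falling and turns true once the newest sample in the window is larger than the oldest one.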

        Let us use the following YAML example to understand the usage of HistoryBasedMetric:

        controller_metrics:
          - name: history_window
            class: HistoryBasedMetric
            arguments:
              window_size: 2
        controllers:
          - name: epoch_level_eval_loss_patience
            triggers:
              - on_epoch_end
            rule: len(history_window["metrics"]) > 0 and history_window["metrics"]["eval_loss"][-1] > 2
            patience:
              patience_threshold: 2
            operations:
              - hfcontrols.should_training_stop

        In the above YAML, the HistoryBasedMetric is named history_window. Here is a short primer on defining rules using the HistoryBasedMetric:

        1. Treat the history_window as a Python dictionary. The structure of the data in this dictionary is:

                  {
                      "metrics": {
                          "global_step": [...],
                          "epoch": [...],
                          "eval_loss": [...],
                          "user_eval_metric_1": [...],
                          "user_eval_metric_2": [...],
                      },
                      "training_loss": {
                          "global_step": [...],
                          "epoch": [...],
                          "loss": [...],
                      }
                  }
        2. To access the first value in the window of the evaluation metric eval_loss, use history_window["metrics"]["eval_loss"][0]. In the above YAML, the last element is accessed as history_window["metrics"]["eval_loss"][-1].
        3. Similarly, history_window["metrics"]["global_step"][0] is the global_step at which that evaluation metric was generated, and history_window["metrics"]["epoch"][0] is the corresponding epoch.
        4. A similar approach is followed to access training loss (i.e., history_window["training_loss"]["loss"][0] gives the first training loss).
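The access patterns above, combined with the patience threshold, can be sketched in plain Python. This is an illustration under stated assumptions rather than the framework's implementation: `make_history_window`, `record_eval`, and `PatienceCounter` are hypothetical helpers, `collections.deque` with `maxlen` stands in for the moving window, and the counter is assumed to reset whenever the rule is false and to signal once the rule has held `threshold` consecutive times.

```python
from collections import deque


def make_history_window(window_size: int) -> dict:
    """Build an empty window with the dictionary layout described above."""
    def series():
        return deque(maxlen=window_size)  # oldest samples fall off automatically
    return {
        "metrics": {"global_step": series(), "epoch": series(), "eval_loss": series()},
        "training_loss": {"global_step": series(), "epoch": series(), "loss": series()},
    }


def record_eval(window: dict, global_step: int, epoch: float, eval_loss: float) -> None:
    """Append one evaluation sample to the window."""
    window["metrics"]["global_step"].append(global_step)
    window["metrics"]["epoch"].append(epoch)
    window["metrics"]["eval_loss"].append(eval_loss)


class PatienceCounter:
    """Hypothetical counter: signal only after `threshold` consecutive true outcomes."""

    def __init__(self, threshold: int):
        self.threshold = threshold
        self.count = 0

    def should_act(self, rule_outcome: bool) -> bool:
        # Assumption: a false outcome resets the streak.
        self.count = self.count + 1 if rule_outcome else 0
        return self.count >= self.threshold


history_window = make_history_window(window_size=2)
patience = PatienceCounter(threshold=2)
decisions = []
# Made-up (epoch, eval_loss) samples, one per simulated on_epoch_end event.
for step, (epoch, loss) in enumerate([(1, 2.6), (2, 1.9), (3, 2.4), (4, 2.3)], start=1):
    record_eval(history_window, global_step=step * 100, epoch=epoch, eval_loss=loss)
    # Same rule as in the YAML above, written directly in Python.
    rule = len(history_window["metrics"]) > 0 and history_window["metrics"]["eval_loss"][-1] > 2
    decisions.append(patience.should_act(rule))
```

In this run the rule is true at epochs 1, 3, and 4, but the action is signalled only at epoch 4, the first time the rule has held twice in a row. After the loop, the window keeps only the two most recent evaluation samples.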
    • Trigger: A trigger event decides when the rule is evaluated. This event has to be one of the trainer events listed here. The choice of event to trigger on allows for more control, e.g., over the points in training at which early termination is considered.

    • Rule: The rule is a Python expression that can use the metric names (e.g., evalmetric and trainer_state in the first configuration above) to define a boolean condition which, when satisfied, triggers the operation(s) listed in operations.

    • Operation: The operations section lists the operations to perform when the rule is satisfied (i.e., the condition becomes True). Currently, only one operation class, HFControls, is supported. In this particular example, the class and the corresponding operation name hfcontrols are not specified explicitly because they are the defaults and can be omitted. The HFControls class supports all the operations listed below. More on these operations can be found here.

      • hfcontrols.should_training_stop: Stops the training.
      • hfcontrols.should_epoch_stop: Interrupts the current epoch.
      • hfcontrols.should_save: Saves the model at the current step.
      • hfcontrols.should_evaluate: Evaluates the model at the current step.
      • hfcontrols.should_log: Logs at the current step.
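
To make the interplay of metrics, trigger, rule, and operations concrete, the following is a minimal sketch and not the framework's actual implementation. It assumes the rule string is evaluated as a Python expression with each metric name bound to a dictionary. The names `ControlFlags`, `evaluate_rule`, `apply_operations`, and `on_epoch_end` are hypothetical and introduced only for illustration; the flag names mirror the operations listed above.

```python
from dataclasses import dataclass


@dataclass
class ControlFlags:
    """Hypothetical stand-in for the flags toggled by hfcontrols operations."""
    should_training_stop: bool = False
    should_epoch_stop: bool = False
    should_save: bool = False
    should_evaluate: bool = False
    should_log: bool = False


def evaluate_rule(rule: str, metrics: dict) -> bool:
    """Evaluate a rule expression with metric names bound to their values.

    Builtins are limited to a small allow-list; this is an illustrative
    restriction, not a security boundary.
    """
    allowed_builtins = {"len": len, "min": min, "max": max, "abs": abs}
    return bool(eval(rule, {"__builtins__": allowed_builtins}, dict(metrics)))


def apply_operations(operations: list, flags: ControlFlags) -> None:
    """Set the flag named by each 'hfcontrols.<flag>' operation."""
    for op in operations:
        prefix, _, flag_name = op.partition(".")
        if prefix != "hfcontrols" or not hasattr(flags, flag_name):
            raise ValueError(f"unsupported operation: {op}")
        setattr(flags, flag_name, True)


# Rule and operations taken from the first configuration above.
RULE = 'evalmetric["eval_loss"] < 2.25 and trainer_state["epoch"] > 2'
OPERATIONS = ["hfcontrols.should_training_stop"]


def on_epoch_end(metrics: dict) -> ControlFlags:
    """Simulate the trigger: evaluate the rule and, if it holds, run the operations."""
    flags = ControlFlags()
    if evaluate_rule(RULE, metrics):
        apply_operations(OPERATIONS, flags)
    return flags
```

For example, calling `on_epoch_end({"evalmetric": {"eval_loss": 2.0}, "trainer_state": {"epoch": 3}})` returns flags with `should_training_stop` set, while the same loss at epoch 1 leaves every flag unset because the epoch condition fails.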