diff --git a/examples_deepspeed/universal_checkpointing/README.md b/examples_deepspeed/universal_checkpointing/README.md index 5ef71503a8..341b0d113f 100644 --- a/examples_deepspeed/universal_checkpointing/README.md +++ b/examples_deepspeed/universal_checkpointing/README.md @@ -77,6 +77,40 @@ Please see the corresponding [pull request](https://github.com/microsoft/Megatro Combining sequence parallelism with data parallelism is another good use case for universal checkpointing, see [sp pull request](https://github.com/microsoft/DeepSpeed/pull/4752) for example and visualization of matching loss values. +### TensorBoard Log Analysis + +The Universal Checkpointing example includes a TensorBoard analysis script that will generate `csv` files and `png` plots across the universal checkpointing training steps for comparison of training and validation loss curves. + +After Step 3 is completed, the script may be executed as follows: +```bash +bash examples_deepspeed/universal_checkpointing/run_tb_analysis.sh z1_uni_ckpt +``` + +The script will output the following `csv` files: + - uc_out_tp_2_pp_2_dp_2_sp_1.csv + - uc_out_tp_2_pp_2_dp_1_sp_1.csv + - val_uc_out_tp_2_pp_2_dp_2_sp_1.csv + - val_uc_out_tp_2_pp_2_dp_1_sp_1.csv + +The script will also output the following `png` files: + - uc_char_training_loss.png + - uc_char_validation_loss.png + +Below is the visualization of the `png` files generated from this example. + +
+ + + *Figure 1: Training LM loss curve for first 200 training steps of Step 1 (TP=2, PP=2, DP=2) and training steps 101 to 200 of Step 3 (TP=2, PP=2, DP=1), which was loaded using the Universal Checkpoint.* +
+ +
+ + + *Figure 2: Validation LM loss curve for first 200 training steps of Step 1 (TP=2, PP=2, DP=2) and training steps 101 to 200 of Step 3 (TP=2, PP=2, DP=1), which was loaded using the Universal Checkpoint.* +
+ + ## ZeRO stage 2 training Repeat steps in ZeRO stage 1 training above with the following modifications to your job batch scripts: * Set ZERO_STAGE=2 diff --git a/examples_deepspeed/universal_checkpointing/assets/image/uc_char_training_loss.png b/examples_deepspeed/universal_checkpointing/assets/image/uc_char_training_loss.png new file mode 100644 index 0000000000..4df1ff1fc8 Binary files /dev/null and b/examples_deepspeed/universal_checkpointing/assets/image/uc_char_training_loss.png differ diff --git a/examples_deepspeed/universal_checkpointing/assets/image/uc_char_validation_loss.png b/examples_deepspeed/universal_checkpointing/assets/image/uc_char_validation_loss.png new file mode 100644 index 0000000000..5a65f6bd12 Binary files /dev/null and b/examples_deepspeed/universal_checkpointing/assets/image/uc_char_validation_loss.png differ diff --git a/examples_deepspeed/universal_checkpointing/run_tb_analysis.sh b/examples_deepspeed/universal_checkpointing/run_tb_analysis.sh index 80a31d6373..7aa988a0a0 100755 --- a/examples_deepspeed/universal_checkpointing/run_tb_analysis.sh +++ b/examples_deepspeed/universal_checkpointing/run_tb_analysis.sh @@ -22,7 +22,7 @@ python3 examples_deepspeed/universal_checkpointing/tb_analysis/tb_analysis_scrip python3 examples_deepspeed/universal_checkpointing/tb_analysis/tb_analysis_script.py \ --tb_dir $OUTPUT_PATH \ --tb_event_key "lm-loss-validation/lm loss validation" \ - --csv_name "val" \ + --csv_name "val_" \ --plot_name "uc_char_validation_loss.png" \ --plot_title "Megatron-GPT Universal Checkpointing - Validation Loss" \ --plot_y_label "Validation LM Loss" \ diff --git a/examples_deepspeed/universal_checkpointing/tb_analysis/uc_analysis.py b/examples_deepspeed/universal_checkpointing/tb_analysis/uc_analysis.py index db57920463..f5809c3dc1 100644 --- a/examples_deepspeed/universal_checkpointing/tb_analysis/uc_analysis.py +++ b/examples_deepspeed/universal_checkpointing/tb_analysis/uc_analysis.py @@ -19,7 +19,7 @@ def 
set_names(self, path_name): tp, pp, dp, sp = match.groups() self._label_name = f"Training Run: TP: {tp}, PP: {pp}, DP: {dp}" - self._csv_name = f"uc_out_tp_{tp}_pp_{pp}_dp_{dp}_sp_{sp}_val_loss" + self._csv_name = f"uc_out_tp_{tp}_pp_{pp}_dp_{dp}_sp_{sp}" def get_label_name(self): return self._label_name