Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TensorBoard analysis script to Universal Checkpointing Example #345

Merged
merged 24 commits into from
Feb 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions examples_deepspeed/universal_checkpointing/run_tb_analysis.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#!/bin/bash
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

# Directory holding the TensorBoard output. Falls back to "z1_uni_ckpt" when
# the first argument is missing or empty.
OUTPUT_PATH="${1:-z1_uni_ckpt}"

# Training Loss
# No --csv_name is passed here, so the CSV files use the script's default
# (empty) name prefix.
# "$OUTPUT_PATH" is quoted so a path containing spaces or glob characters is
# passed to the script as a single argument.
python3 examples_deepspeed/universal_checkpointing/tb_analysis/tb_analysis_script.py \
    --tb_dir "$OUTPUT_PATH" \
    --tb_event_key "lm-loss-training/lm loss" \
    --plot_name "uc_char_training_loss.png" \
    --plot_title "Megatron-GPT Universal Checkpointing - Training Loss" \
    --use_sns

# Validation Loss
# --csv_name "val" prefixes the CSV file names for this run.
python3 examples_deepspeed/universal_checkpointing/tb_analysis/tb_analysis_script.py \
    --tb_dir "$OUTPUT_PATH" \
    --tb_event_key "lm-loss-validation/lm loss validation" \
    --csv_name "val" \
    --plot_name "uc_char_validation_loss.png" \
    --plot_title "Megatron-GPT Universal Checkpointing - Validation Loss" \
    --plot_y_label "Validation LM Loss" \
    --use_sns
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import abc
from abc import ABC


class TensorBoardAnalysis(ABC):
    """Abstract interface for analyzers that name the outputs of a TensorBoard run.

    Concrete analyzers must implement every abstract method below; the class
    cannot be instantiated directly.
    """

    def __init__(self):
        # All three fields start unset; concrete analyzers fill them in.
        self._name = self._label_name = self._csv_name = None

    @abc.abstractmethod
    def set_names(self, path_name):
        """Update the analyzer's names for the log located at ``path_name``."""

    @abc.abstractmethod
    def get_label_name(self):
        """Return the label to use for the current log."""

    @abc.abstractmethod
    def get_csv_filename(self):
        """Return the CSV file name (without extension) for the current log."""

    @abc.abstractmethod
    def path_regex(self):
        """Return the regular expression this analyzer applies to log paths."""
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from argparse import ArgumentParser

# Command-line options for the TensorBoard analysis script. Only --tb_dir is
# mandatory; every other option has a default.
parser = ArgumentParser()

parser.add_argument(
    "--tb_dir",
    type=str,
    required=True,
    help="Directory for tensorboard output",
)
parser.add_argument(
    "--analyzer",
    type=str,
    choices=["universal_checkpointing"],
    default="universal_checkpointing",
    help="Specify the analyzer to use",
)

# Optional overrides for the event key and the plot's text/file name.
parser.add_argument(
    "--tb_event_key",
    type=str,
    default="lm-loss-training/lm loss",
    help="Optional override of the TensorBoard event key",
)
parser.add_argument(
    "--plot_title",
    type=str,
    default="Megatron-GPT Universal Checkpointing",
    help="Optional override of the plot title",
)
parser.add_argument(
    "--plot_x_label",
    type=str,
    default="Training Step",
    help="Optional override of the plot x-label",
)
parser.add_argument(
    "--plot_y_label",
    type=str,
    default="LM Loss",
    help="Optional override of the plot y-label",
)
parser.add_argument(
    "--plot_name",
    type=str,
    default="uni_ckpt_char.png",
    help="Optional override of the plot file name",
)

# Boolean switches; each is False unless given on the command line.
parser.add_argument(
    "--skip_plot",
    action='store_true',
    help="Skip generation of plot file",
)
parser.add_argument(
    "--skip_csv",
    action='store_true',
    help="Skip generation of csv files",
)
parser.add_argument(
    "--use_sns",
    action='store_true',
    help="Use the SNS library to format plot",
)

parser.add_argument(
    "--csv_name",
    type=str,
    default="",
    help="Unique name for CSV files",
)
lekurile marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os
import re
import pandas as pd
import matplotlib.pyplot as plt
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
from utils import get_analyzer, find_files
from arguments import parser

# The command line is parsed once, at module load; main() reads the result
# from this module-level namespace.
args = parser.parse_args()

# seaborn is imported only when --use_sns is given, so it is not needed
# otherwise.
if args.use_sns:
    import seaborn as sns

    sns.set()

def main():
    """Plot (and optionally export to CSV) one scalar series per TensorBoard log.

    Reads its configuration from the module-level ``args`` namespace. Every
    directory under ``args.tb_dir`` that contains an event file is loaded,
    the scalar series stored under ``args.tb_event_key`` is added to a shared
    plot, and - unless ``--skip_csv`` is set - written to its own CSV file.
    Unless ``--skip_plot`` is set the combined plot is saved to
    ``args.plot_name``.

    Raises:
        ValueError: Propagated from ``get_analyzer`` for an unsupported
            analyzer, or from the analyzer's ``set_names`` when a log path
            does not match the analyzer's expected pattern.
    """
    target_affix = 'events.out.tfevents'
    # find_files returns the directories that contain matching event files.
    tb_log_paths = find_files(args.tb_dir, target_affix)

    analyzer = get_analyzer(args.analyzer)

    for tb_path in tb_log_paths:
        print(f"Processing: {tb_path}")
        analyzer.set_names(tb_path)

        event_accumulator = EventAccumulator(tb_path)
        event_accumulator.Reload()

        # Scalars() raises KeyError when the tag was never logged in this run.
        # Skip such a run instead of aborting the analysis of all the others.
        try:
            events = event_accumulator.Scalars(args.tb_event_key)
        except KeyError:
            print(f"Skipping {tb_path}: no scalar events found for key '{args.tb_event_key}'")
            continue

        steps = [event.step for event in events]
        values = [event.value for event in events]

        plt.plot(steps, values, label=f'{analyzer.get_label_name()}')

        if not args.skip_csv:
            df = pd.DataFrame({"step": steps, "value": values})
            df.to_csv(f"{args.csv_name}{analyzer.get_csv_filename()}.csv")

    if not args.skip_plot:
        plt.legend()
        plt.title(args.plot_title)
        plt.xlabel(args.plot_x_label)
        plt.ylabel(args.plot_y_label)
        plt.savefig(args.plot_name)


if __name__ == "__main__":
    main()
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import re
from abstract_analysis import TensorBoardAnalysis


class UniversalCheckpointingAnalysis(TensorBoardAnalysis):
    """Analyzer that derives plot labels and CSV names from a log path.

    The path must contain, in this order, the markers ``tp<N>``, ``pp<N>``,
    ``dp<N>`` and ``sp<N>`` (each followed by one or more digits); the
    captured numbers are embedded in the label and CSV file name.
    """

    def __init__(self):
        # Initialise the base-class fields first so the getters return None,
        # rather than raising AttributeError, when called before set_names().
        super().__init__()
        self._name = "universal_checkpointing"

    def set_names(self, path_name):
        """Extract the tp/pp/dp/sp numbers from ``path_name`` and build the names.

        Args:
            path_name (str): Path of the TensorBoard log being processed.

        Raises:
            ValueError: If ``path_name`` does not match ``path_regex()``.
        """
        pattern = self.path_regex()
        match = re.match(pattern, path_name)
        if not match:
            raise ValueError(f"Path ({path_name}) did not match regex ({pattern})")
        tp, pp, dp, sp = match.groups()

        self._label_name = f"Training Run: TP: {tp}, PP: {pp}, DP: {dp}"
        # NOTE(review): the "_val_loss" suffix is used whatever event key is
        # being analysed; kept unchanged so existing output names stay stable.
        self._csv_name = f"uc_out_tp_{tp}_pp_{pp}_dp_{dp}_sp_{sp}_val_loss"

    def get_label_name(self):
        """Return the label built by the last ``set_names`` call (None before it)."""
        return self._label_name

    def get_csv_filename(self):
        """Return the CSV base name built by the last ``set_names`` call (None before it)."""
        return self._csv_name

    def path_regex(self):
        """Return the pattern used to pull the tp/pp/dp/sp numbers out of a path."""
        # Raw string: "\d" in a normal string literal is an invalid escape
        # sequence and triggers a warning on newer Python versions.
        return r'.*tp(\d+).*pp(\d+).*dp(\d+).*sp(\d+)'
32 changes: 32 additions & 0 deletions examples_deepspeed/universal_checkpointing/tb_analysis/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os
from uc_analysis import UniversalCheckpointingAnalysis


def find_files(directory, file_affix):
    """
    Finds the directories that contain files starting with a given affix.

    The tree rooted at ``directory`` is traversed with os.walk(). A directory
    is reported once, no matter how many of its files match. The comparison
    of file names against the affix is case-insensitive.

    Args:
        directory (str): The path to the directory to search.
        file_affix (str): The prefix a file name must start with.
    Returns:
        list: Paths of the directories holding at least one matching file,
            in os.walk() traversal order. Empty if nothing matches.
    """
    # Lowercase the affix once instead of once per file.
    affix = file_affix.lower()
    return [
        root
        for root, _, files in os.walk(directory)
        if any(filename.lower().startswith(affix) for filename in files)
    ]

def get_analyzer(analyzer_name):
    """Return a new analyzer instance for ``analyzer_name``.

    Args:
        analyzer_name (str): Name of the analyzer; only
            'universal_checkpointing' is supported.
    Returns:
        UniversalCheckpointingAnalysis: A freshly constructed analyzer.
    Raises:
        ValueError: If ``analyzer_name`` is not a supported analyzer.
    """
    # Reject unknown names up front so the supported case reads straight through.
    if analyzer_name != 'universal_checkpointing':
        raise ValueError(f"Unsupported analyzer {analyzer_name}")
    return UniversalCheckpointingAnalysis()