diff --git a/training/offload_states/README.md b/training/offload_states/README.md new file mode 100644 index 000000000..4add7404b --- /dev/null +++ b/training/offload_states/README.md @@ -0,0 +1,25 @@ +# Offloading States Example + +The script `offload_states.py` demonstrates how to offload the state of a model. Here is the example usage. + +```bash +$ deepspeed --num_gpus=4 offload_states.py --hidden_dim 32768 --nlayers 4 --pin_memory --non_blocking +... +Memory usage (0): include=None, pin_memory=True, non_blocking=True alloc_before_offload=18198419456 alloc_after_offload=17763840 +Memory usage (1): include=None, pin_memory=True, non_blocking=True alloc_before_offload=18198760960 alloc_after_offload=17763840 +... +Summary: pin_memory=True non_blocking=True offload=5.643414640426636 load=2.4087101459503173 +``` + +`run_benchmark.sh` shows how to run the script with different configurations. The script outputs the time for offloading and loading the states. + +```bash +$ ./run_benchmark.sh +... +| |pin_memory=0_non_blocking=0|pin_memory=0_non_blocking=1|pin_memory=1_non_blocking=0|pin_memory=1_non_blocking=1| +|--:|---------------------------|---------------------------|---------------------------|---------------------------| +| 1|4.34 / 3.42 |4.99 / 2.37 |6.5 / 2.42 |6.0 / 2.39 | +| 2|9.9 / 3.28 |5.1 / 2.34 |6.21 / 2.42 |6.25 / 2.45 | +| 3|9.92 / 3.19 |6.71 / 2.35 |6.33 / 2.38 |5.93 / 2.42 | +| 4|9.55 / 2.82 |7.11 / 2.39 |6.9 / 2.38 |6.5 / 2.43 |... +``` diff --git a/training/offload_states/offload_states.py b/training/offload_states/offload_states.py new file mode 100644 index 000000000..f80b06e05 --- /dev/null +++ b/training/offload_states/offload_states.py @@ -0,0 +1,152 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import time +import argparse + +import deepspeed.comm as dist +from deepspeed.accelerator import get_accelerator +import torch + +import deepspeed +from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum + + +class SimpleModel(torch.nn.Module): + + def __init__(self, hidden_dim, empty_grad=False, nlayers=1): + super(SimpleModel, self).__init__() + self.linears = torch.nn.ModuleList([torch.nn.Linear(hidden_dim, hidden_dim) for _ in range(nlayers)]) + if empty_grad: + self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim) + self.cross_entropy_loss = torch.nn.CrossEntropyLoss() + + def forward(self, x, y): + for l in self.linears: + x = l(x) + return self.cross_entropy_loss(x, y) + + +def random_dataset(total_samples, hidden_dim, device, dtype): + train_data = torch.randn(total_samples, hidden_dim, device=device, dtype=dtype) + train_label = torch.empty(total_samples, dtype=torch.long, device=device).random_(hidden_dim) + train_dataset = torch.utils.data.TensorDataset(train_data, train_label) + return train_dataset + + +def random_dataloader(model, total_samples, hidden_dim, device, dtype): + batch_size = model.train_micro_batch_size_per_gpu() + train_dataset = random_dataset(total_samples, hidden_dim, device, dtype=dtype) + train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size) + return train_loader + + +def run_model(model, config_dict, hidden_dim, dtype, include, pin_memory, non_blocking, iteration, warmup): + model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) + data_loader = random_dataloader(model=model, + total_samples=iteration, + hidden_dim=hidden_dim, + device=model.device, + dtype=dtype) + + time_offload_list = [] + time_load_list = [] + + dist.barrier() + for i, batch in enumerate(data_loader): + loss = model(batch[0], batch[1]) + model.backward(loss) + model.step() + + # Start offloading + alloc_before_offload = get_accelerator().memory_allocated() + dist.barrier() + + time_start = time.time() + model.offload_states(include=include, + device=OffloadDeviceEnum.cpu, + pin_memory=pin_memory, + non_blocking=non_blocking) + dist.barrier() + time_after_offload = time.time() + alloc_after_offload = get_accelerator().memory_allocated() + assert alloc_after_offload < alloc_before_offload, f"Allocated memory should decrease after offload" + + # Load offloaded states back + model.reload_states() + dist.barrier() + time_after_load = time.time() + + time_offload_list.append(time_after_offload - time_start) + time_load_list.append(time_after_load - time_after_offload) + + assert alloc_after_offload < get_accelerator().memory_allocated( + ), f"Allocated memory should increase after offload back" + + if dist.get_rank() == 0: + print( + f"Memory usage ({i}): include={include}, pin_memory={pin_memory}, non_blocking={non_blocking} alloc_before_offload={alloc_before_offload} alloc_after_offload={alloc_after_offload}" + ) + + # remove warmup + time_offload_list = time_offload_list[warmup:] + time_load_list = time_load_list[warmup:] + + if dist.get_rank() == 0: + with open("offload_states.log", "a") as f: + offload_time = sum(time_offload_list) / len(time_offload_list) + load_time = sum(time_load_list) / len(time_load_list) + msg = f"{1 if pin_memory else 0},{1 if non_blocking else 0},{offload_time},{load_time}" + f.write(f"{msg}\n") + print(f"Summary: pin_memory={pin_memory} non_blocking={non_blocking} offload={offload_time} load={load_time}") + + # Needed in ZeRO 3. Not doing so can give memory leak + model.destroy() + + +def main(): + parser = argparse.ArgumentParser(description="Test Offload States") + parser.add_argument("--included_state", type=str, choices=[e.name for e in OffloadStateTypeEnum] + [None], default=None, help="State to include") + parser.add_argument("--pin_memory", action='store_true', help="Pin memory") + parser.add_argument("--non_blocking", action='store_true', help="Non blocking") + parser.add_argument("--nlayers", type=int, default=1, help="Number of layers") + parser.add_argument("--hidden_dim", type=int, default=1024, help="Hidden dimension") + parser.add_argument('--dtype', choices=['torch.bfloat16', 'torch.float16', 'torch.float32'], default='torch.bfloat16', help='Data type') + parser.add_argument("--local_rank", type=int, default=-1, help="Local rank") + parser.add_argument("--iteration", type=int, default=10, help="Warmup") + parser.add_argument("--warmup", type=int, default=5, help="Warmup") + + args = parser.parse_args() + + dtype = eval(args.dtype) + hidden_dim = args.hidden_dim + + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-6 + } + }, + "zero_optimization": { + "stage": 3, + }, + } + + if dtype == torch.float16: + config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} + elif dtype == torch.bfloat16: + config_dict["bf16"] = {"enabled": True} + + with deepspeed.zero.Init(config_dict_or_path=config_dict): + model = SimpleModel(hidden_dim, nlayers=args.nlayers) + + included_state = None if args.included_state is None else [OffloadStateTypeEnum[args.included_state]] + run_model(model, config_dict, hidden_dim, dtype, included_state, args.pin_memory, args.non_blocking, args.iteration, args.warmup) + + +if __name__ == "__main__": + main() diff --git a/training/offload_states/output_table.py b/training/offload_states/output_table.py new file mode 100644 index 000000000..fc1a5b840 --- /dev/null +++ b/training/offload_states/output_table.py @@ -0,0 +1,28 @@ +import pandas as pd +from pytablewriter import MarkdownTableWriter + + +def read_csv(file_path): + return pd.read_csv(file_path) + +df = read_csv('offload_states.log') +df.columns = ['pin_memory', 'non_blocking', 'offload_time', 'load_time'] + +df['ratio_string'] = df['offload_time'].round(2).astype(str) + " / " + df['load_time'].round(2).astype(str) + +result_df = pd.DataFrame({ + 'pin_memory=0_non_blocking=0': df[(df['pin_memory'] == 0) & (df['non_blocking'] == 0)]['ratio_string'].reset_index(drop=True), + 'pin_memory=0_non_blocking=1': df[(df['pin_memory'] == 0) & (df['non_blocking'] == 1)]['ratio_string'].reset_index(drop=True), + 'pin_memory=1_non_blocking=0': df[(df['pin_memory'] == 1) & (df['non_blocking'] == 0)]['ratio_string'].reset_index(drop=True), + 'pin_memory=1_non_blocking=1': df[(df['pin_memory'] == 1) & (df['non_blocking'] == 1)]['ratio_string'].reset_index(drop=True) +}) +result_df = result_df.dropna() +result_df.index = range(1, len(result_df) + 1) +result_df.index.name = 'trial' +# print(result_df) + +writer = MarkdownTableWriter() +writer.from_dataframe(result_df, + add_index_column=True, +) +writer.write_table() \ No newline at end of file diff --git a/training/offload_states/run_benchmark.sh b/training/offload_states/run_benchmark.sh new file mode 100644 index 000000000..ba18da03e --- /dev/null +++ b/training/offload_states/run_benchmark.sh @@ -0,0 +1,28 @@ +NGPUS=4 +HIDDEN_SIZE=32768 +NUM_LAYERS=4 + +TRIALS=10 + +PIN_MEMORY_OPTS=(0 1) +NON_BLOCKING_OPTS=(0 1) + +for i in $(seq 1 $TRIALS); do + for PIN_MEMORY in "${PIN_MEMORY_OPTS[@]}"; do + PIN_MEMORY_ARG="" + if [ $PIN_MEMORY -eq 1 ]; then + PIN_MEMORY_ARG="--pin_memory" + fi + + for NON_BLOCKING in "${NON_BLOCKING_OPTS[@]}"; do + NON_BLOCKING_ARG="" + if [ $NON_BLOCKING -eq 1 ]; then + NON_BLOCKING_ARG="--non_blocking" + fi + + echo "Running iteration $i" + deepspeed --num_gpus=$NGPUS offload_states.py --hidden_dim $HIDDEN_SIZE --nlayers $NUM_LAYERS $PIN_MEMORY_ARG $NON_BLOCKING_ARG + done + done +done +python output_table.py