Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

merge main #331

Merged
merged 2 commits into from
Aug 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions megatron/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,8 @@ def _add_learning_rate_args(parser):
'(learning rate, warmup iterations, minimum learning '
'rate, maximum number of iterations, and decay style '
'from checkpoint and ignore input arguments.')
group.add_argument('--universal-checkpoint', action='store_true',
help='Loading a universal format checkpoint.')

return parser

Expand Down
17 changes: 15 additions & 2 deletions megatron/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
mpu,
print_rank_0,
update_num_microbatches,
utils)
utils,
get_tokenizer)
from megatron.enums import PositionEmbeddingType

_CHECKPOINT_VERSION = None
Expand Down Expand Up @@ -131,6 +132,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
state_dict['checkpoint_version'] = 3.0
state_dict['iteration'] = iteration
state_dict['tokens'] = args.consumed_train_tokens
state_dict['checkpoint_info'] = _checkpoint_info()

# DeepSpeed saves the model/optimizer/scheduler
if not args.deepspeed:
Expand Down Expand Up @@ -361,7 +363,8 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
assert args.consumed_valid_samples == 0
if 'args' in state_dict:
checkpoint_args = state_dict['args']
check_checkpoint_args(checkpoint_args)
if not args.universal_checkpoint:
check_checkpoint_args(checkpoint_args)
args.consumed_train_samples = getattr(checkpoint_args,
'consumed_train_samples', 0)
update_num_microbatches(consumed_samples=args.consumed_train_samples)
Expand Down Expand Up @@ -468,3 +471,13 @@ def load_biencoder_checkpoint(model, only_query_model=False,
print(' successfully loaded {}'.format(checkpoint_name))

return model


def _checkpoint_info():
args = get_args()
tokenizer = get_tokenizer()

return {
"padded_vocab_size": args.padded_vocab_size,
"original_vocab_size": tokenizer.vocab_size,
}
68 changes: 66 additions & 2 deletions megatron/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,32 @@ def get_learning_rate_scheduler(optimizer):
return lr_scheduler


def sync_hp_to_lp(optimizer):

optimizer.update_lp_params()

# for n,p in model.named_parameters():
# print(n)

# if p._hp_mapping is not None:
# #print(f'rank {rank} fixing hp for input_layernorm')
# #p._hp_mapping.update_hp()

# hp = p._hp_mapping.hp_fragment



# torch.distributed.all_reduce(hp, op=torch.distributed.ReduceOp.AVG, group=mpu.get_tensor_model_parallel_group())

# # 3. optim states
# for key in ['exp_avg', 'exp_avg_sq']:
# optim_state_fragment = p._hp_mapping.get_optim_state_fragment(key)
# #print(f'rank {rank} before reduce optim state fragment {key} = {optim_state_fragment}')
# torch.distributed.all_reduce(optim_state_fragment, op=torch.distributed.ReduceOp.AVG, group=mpu.get_tensor_model_parallel_group())
# #print(f'rank {rank} after reduce optim state fragment {key} = {optim_state_fragment}')



def setup_model_and_optimizer(model_provider_func):
"""Setup model and optimizer."""
args = get_args()
Expand All @@ -386,12 +412,21 @@ def setup_model_and_optimizer(model_provider_func):

if args.deepspeed:
print_rank_0("DeepSpeed is enabled.")
pp = mpu.get_pipeline_model_parallel_world_size()
#pp = mpu.get_pipeline_model_parallel_world_size()

import json
import io
with io.open(args.deepspeed_config, "r", encoding="utf-8") as f:
config = json.load(f)
if args.universal_checkpoint:
config["checkpoint"] = {"load_universal": True}

model, optimizer, _, lr_scheduler = deepspeed.initialize(
model=model[0],
optimizer=optimizer,
lr_scheduler=lr_scheduler,
config=config,
args=args,
lr_scheduler=lr_scheduler
)

assert model.fp16_enabled() == args.fp16, "megatron fp16 config does not match deepspeed"
Expand All @@ -416,8 +451,37 @@ def setup_model_and_optimizer(model_provider_func):
torch.distributed.barrier()
timers('load-checkpoint').stop()
timers.log(['load-checkpoint'])


# hp -> lp
if args.deepspeed and args.universal_checkpoint:
sync_hp_to_lp(optimizer)


else:
args.iteration = 0

from .utils import dump_weights
dump_weights(f'{args.universal_checkpoint=}', args.iteration, model, optimizer)

# tp_rank = mpu.get_tensor_model_parallel_rank()
# pp_rank = mpu.get_pipeline_model_parallel_rank()
# dp_rank = mpu.get_data_parallel_rank()
# for n,p in model[0].named_parameters():
# if 'word_embeddings.weight' not in n:
# continue
# if tp_rank == 0 and pp_rank == 0:
# print(f"{tp_rank=}{pp_rank=}{dp_rank=} bf16 {n=} {p[:10]=}")
# if p._hp_mapping is not None:
# hp = p._hp_mapping.hp_fragment
# print(f'{tp_rank=}{pp_rank=}{dp_rank=} fp32 {n=} {hp[:10]=}')

# if tp_rank == 0 and pp_rank == mpu.get_pipeline_model_parallel_world_size() - 1:
# print(f"{tp_rank=}{pp_rank=}{dp_rank=} bf16 {n=} {p[:10]=}")
# if p._hp_mapping is not None:
# hp = p._hp_mapping.hp_fragment
# print(f'{tp_rank=}{pp_rank=}{dp_rank=} fp32 {n=} {hp[:10]=}')


# We only support local DDP with multiple micro-batches.
if len(model) > 1 or mpu.get_pipeline_model_parallel_world_size() > 1:
Expand Down
76 changes: 76 additions & 0 deletions megatron/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,3 +461,79 @@ def found_kill_switch():
return True
else:
return False

def get_fingerprint_header():
return f"{'min':^13} {'max':^13} {'mean':^13} {'l2 norm':^12} metadata"

def get_fingerprint(p):
return f"{p.min():13.6e} {p.max():13.6e} {p.mean():13.6e} {p.norm():12.6e}"


def dump_weights(preamble, iteration, model, optimizer, tensor=None):
tp_rank = mpu.get_tensor_model_parallel_rank()
pp_rank = mpu.get_pipeline_model_parallel_rank()
dp_rank = mpu.get_data_parallel_rank()
dp_size = mpu.get_data_parallel_world_size()
fn = f"debug-bf16-{iteration}-pp{pp_rank}-tp{tp_rank}-dp{dp_rank}-{preamble}.txt"

# only care for first and last pp stages and dp0 tp0
#if not (mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()):
# return

#if not (tp_rank == 0 and dp_rank == 0):
# return

if tensor is not None:
orig_tensor = tensor
if hasattr(tensor, "_hp_param"):
numel = tensor._hp_param.numel() # // dp_size
tensor = tensor.flatten().narrow(0, 0, numel)

#print(fn)
with open(fn, "w") as fh:
fh.write(f"{get_fingerprint_header()}\n")

if tensor is not None:
fh.write(f"{get_fingerprint(tensor)} tensor {tensor.shape}\n")
else:
for n, p in model[0].named_parameters():
fh.write(f"{get_fingerprint(p)} {n} {p.shape}\n")


return


# until we figure out how to dump the actual fp32 values don't do this
fn = f"debug-fp32-{iteration}-pp{pp_rank}-tp{tp_rank}-dp{dp_rank}-{preamble}.txt"
with open(fn, "w") as fh:
fh.write(f"{get_fingerprint_header()}\n")
if tensor is not None:
tensor = orig_tensor
if hasattr(tensor, "_hp_param"):
fh.write(f"{get_fingerprint(tensor._hp_param)} tensor {tensor._hp_param.shape}\n")
#fh.write(f"{get_fingerprint(tensor._hp_grad)} tensor grad\n")
else:
fh.write(f"{get_fingerprint(tensor)} tensor {tensor.shape}\n")
#fh.write(f"{get_fingerprint(tensor.grad)} tensor grad\n")

else:
if hasattr(model[0].module.tied_modules, "embed"):
p = model[0].module.tied_modules.embed.word_embeddings.weight._hp_param
fh.write(f"{get_fingerprint(p)} module.tied_modules.embed.word_embeddings.weight._hp_param {p.shape}\n")

# for i, param_group in enumerate(optimizer.param_groups):
# fh.write(f"{get_fingerprint(optimizer.fp32_groups_flat_partition[i])} group={i}\n")
#fh.write(f"{i}={optimizer.fp32_groups_flat_partition[i]}\n")
# if mpu.is_pipeline_first_stage():
# x = optimizer.fp32_groups_flat_partition[0]
# fh.write(f"fp32={x[:402432]}\n")
# if mpu.is_pipeline_last_stage()):
# x = optimizer.fp32_groups_flat_partition[1]
# fh.write(f"fp32={x[-402432:]}\n")

# import os
# import socket
# hostname = socket.gethostname()
# pid = os.getpid()
# global_rank = torch.distributed.get_rank()
#fn = f"debug-{iteration}-pp{pp_rank}-tp{tp_rank}-dp{dp_rank}-global{global_rank}-{preamble}-{pid}.txt"
65 changes: 40 additions & 25 deletions run_bf16.sh
Original file line number Diff line number Diff line change
Expand Up @@ -12,48 +12,58 @@ DATETIME=`date +'date_%y-%m-%d_time_%H-%M-%S'`
#DATASET_3="<PATH TO THE THIRD DATASET>"
#DATASET="0.2 ${DATASET_1} 0.3 ${DATASET_2} 0.5 ${DATASET_3}"

BASE_DATA_PATH=/data/Megatron-LM/data
#BASE_DATA_PATH=tests/data/gpt2
#DATASET=${BASE_DATA_PATH}/meg-gpt2-openwebtext_text_document
#VOCAB_PATH=${BASE_DATA_PATH}/gpt2-tiny-vocab.json
#MERGE_PATH=${BASE_DATA_PATH}/gpt2-tiny-merges.txt

BASE_DATA_PATH=/vc_data/Megatron-LM/data
DATASET=${BASE_DATA_PATH}/indexed_datasets/megatron
VOCAB_PATH=${BASE_DATA_PATH}/gpt2-vocab.json
MERGE_PATH=${BASE_DATA_PATH}/gpt2-merges.txt


script_path=$(realpath $0)
script_dir=$(dirname $script_path)
#CONFIG_JSON="$script_dir/ds_config.json"
CONFIG_JSON="/tmp/ds_config.json"
CONFIG_JSON="$script_dir/ds_config.json"
#CONFIG_JSON="/tmp/ds_config.json"

USE_DEEPSPEED=1
ZERO_STAGE=0


# Debug
#TP=4
#PP=4
#LAYERS=8
#HIDDEN=512
#SEQ=1024
#GLOBAL_BATCH=128
#WORKER_STR="-i worker-0"


TP=1
PP=1
DP=2
# Debug
DEBUG_MODE=0
if [[ $DEBUG_MODE == 1 ]]; then
LAYERS=4
HIDDEN=512
SEQ=512
EXIT_INTERVAL=3
else
HIDDEN=1024
LAYERS=24
SEQ=1024
EXIT_INTERVAL=10
fi

TP=2
PP=2
DP=4
WORLD_SIZE=$((TP*PP*DP))
HIDDEN=1024
LAYERS=24
SEQ=1024
GLOBAL_BATCH=1
WORKER_STR=""
GLOBAL_BATCH=4

MICRO_BATCH=1
TRAIN_ITERS=100000
CHECKPOINT_PATH=checkpoints/gpt2/tp${TP}_pp${PP}_dp${DP}
LOAD_CHECKPOINT_PATH=checkpoints/gpt2/tp${TP}_pp${PP}_dp${DP}

LR=6.0e-4
MIN_LR=6.0e-5
DTYPE="bf16"
EXP_DIR=${HOME}/experiments/results/bf16
LOG_DIR="${EXP_DIR}/tensorboard/tp${TP}_pp${PP}_dp${DP}_hd${HIDDEN}_nl${LAYERS}_gbsz${GLOBAL_BATCH}_mbsz${MICRO_BATCH}_z${ZERO_STAGE}_LR_${LR}_${MIN_LR}_${DTYPE}_fix3"
EXP_DIR=${HOME}/experiments/results/ckpt_reshape
LOG_DIR="${EXP_DIR}/tensorboard/tp${TP}_pp${PP}_dp${DP}_hd${HIDDEN}_nl${LAYERS}_gbsz${GLOBAL_BATCH}_mbsz${MICRO_BATCH}_z${ZERO_STAGE}_LR_${LR}_${MIN_LR}_${DTYPE}_cont"
mkdir -p $LOG_DIR

while [[ $# -gt 0 ]]
Expand Down Expand Up @@ -89,7 +99,7 @@ options=" \
--max-position-embeddings $SEQ \
--micro-batch-size $MICRO_BATCH \
--global-batch-size $GLOBAL_BATCH \
--train-iters 1000 \
--train-iters $TRAIN_ITERS \
--lr $LR \
--min-lr $MIN_LR \
--lr-decay-style cosine \
Expand All @@ -99,7 +109,7 @@ options=" \
--data-path ${DATASET} \
--vocab-file ${VOCAB_PATH} \
--merge-file ${MERGE_PATH} \
--save-interval 10000 \
--save-interval 1000 \
--split 98,2,0 \
--clip-grad 1.0 \
--weight-decay 0.1 \
Expand All @@ -108,7 +118,12 @@ options=" \
--init-method-std 0.006 \
--${DTYPE} \
--checkpoint-activations \
--exit-interval 10000 \
--exit-interval ${EXIT_INTERVAL} \
--save ${CHECKPOINT_PATH} \
--load ${LOAD_CHECKPOINT_PATH} \
--position-embedding-type alibi \
--override-lr-scheduler \
--embed-layernorm \
--tensorboard-dir $LOG_DIR
"

Expand Down Expand Up @@ -151,7 +166,7 @@ cat <<EOT > $CONFIG_JSON
}
EOT

WORKER_STR="--num_nodes 1 --num_gpus $WORLD_SIZE"
#WORKER_STR="--num_nodes 1 --num_gpus $WORLD_SIZE"
#WORKER_STR="-i worker-0:0,1,2,3"
#run_cmd="deepspeed -i worker-0:0,1,2,3 ${DIR}/pretrain_gpt.py $@ ${options}"
#run_cmd="deepspeed -i worker-0 ${DIR}/pretrain_gpt.py $@ ${options}"
Expand Down
Loading