Enable loading ckpt for t0 finetuning #309
base: main
@@ -664,6 +664,8 @@ def _add_checkpointing_args(parser):
                        help='Do not load optimizer when loading checkpoint.')
     group.add_argument('--no-load-rng', action='store_true', default=None,
                        help='Do not load rng state when loading checkpoint.')
+    group.add_argument('--reset-progress', action='store_true', default=None,
+                       help='Reset iteration to 0 & do not load args.')
     group.add_argument('--finetune', action='store_true',
                        help='Load model for finetuning. Do not load optimizer '
                        'or rng state from checkpoint and set iteration to 0. '
@@ -928,13 +930,16 @@ def __call__(self, parser, args, values, option_string=None):
                        help='Mask loss for the end of document tokens.')
     group.add_argument('--loss-on-targets-only', action='store_true',
                        help='Mask loss on input sequence.')
+    group.add_argument('--norm-target-loss', action='store_true',
+                       help='Normalize the loss per target. Used for multi-task finetuning with packing.')
     group.add_argument('--reweight-loss-based-on-position-frequency', action="store_true",
                        help='Some objectives require us to sample loss_mask. This might introduce bias towards '
                             'specific positions. This option tries to un-bias the loss by reweighting loss on specific '
                             'positions based on how frequently we train on that position.'
                             'This is mostly used for prefix_lm training')
     group.add_argument("--noise-density", type=float, default=None, help="Span corruption noise density")
     group.add_argument("--mean-noise-span-length", type=int, default=None, help="Span corruption mean noise span length")
+    group.add_argument("--prefixlm", action='store_true', help="Whether to train a PrefixLM - To be used with finetune t0")
Review comment: Yeah, actually let's remove that option. I don't think we've trained one successfully. We'll probably do it once people have shown that it works, but in another PR IMO.
     return parser
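The `--norm-target-loss` help text above only says "normalize the loss per target"; as an illustration, here is a minimal sketch of what that could mean for packed multi-task samples, assuming per-token `losses`, a 0/1 `loss_mask`, and the packed `decoder_segment_ids` from this PR's dataset. The helper name and the exact reduction are assumptions for illustration, not the PR's implementation.

```python
import torch

def norm_target_loss_sketch(losses, loss_mask, segment_ids):
    """Average the loss within each packed target, then across targets.
    All tensors are [batch, seq]. Hypothetical helper, not the PR's code."""
    per_example = []
    for b in range(losses.shape[0]):
        total, num_targets = torch.tensor(0.0), 0
        for seg in segment_ids[b].unique():
            if seg == 0:  # segment id 0 is padding in the packed format
                continue
            mask = (segment_ids[b] == seg).float() * loss_mask[b]
            n_target_tokens = mask.sum()
            if n_target_tokens > 0:
                total = total + (losses[b] * mask).sum() / n_target_tokens
                num_targets += 1
        per_example.append(total / max(num_targets, 1))
    return torch.stack(per_example).mean()
```

This way a long target does not dominate the packed sequence's loss the way a plain token-level mean would.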
@@ -274,8 +274,8 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
     load_dir = getattr(args, load_arg)

     if args.deepspeed:
-        load_optimizer_states = False if args.no_load_optim else True
-        loaded_dir, state_dict = model[0].load_checkpoint(load_dir, load_optimizer_states=load_optimizer_states)
+        load_optimizer_states = not args.no_load_optim
Review comment: Just use no_load_optim directly in the method.
+        loaded_dir, state_dict = model[0].load_checkpoint(load_dir, load_module_only=not load_optimizer_states, load_optimizer_states=load_optimizer_states, load_lr_scheduler_states=load_optimizer_states)
     if loaded_dir is None:
         print_rank_0('WARNING: could not find the metadata file {} '.format(
             load_dir))
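For reference, the reviewer's suggestion above would amount to dropping the `load_optimizer_states` temporary and passing `args.no_load_optim` straight through. A hedged sketch of that shape, with keyword arguments taken from the diff above rather than re-checked against DeepSpeed's API:

```python
# Sketch of "use no_load_optim directly": same call as in the diff, inlined.
loaded_dir, state_dict = model[0].load_checkpoint(
    load_dir,
    load_module_only=args.no_load_optim,
    load_optimizer_states=not args.no_load_optim,
    load_lr_scheduler_states=not args.no_load_optim,
)
```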
@@ -342,7 +342,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
     set_checkpoint_version(state_dict.get('checkpoint_version', 0))

     # Set iteration.
-    if args.finetune or release:
+    if args.finetune or release or args.reset_progress:
Review comment: Why is it that we didn't set finetune to True?
         iteration = 0
     else:
         try:
@@ -361,7 +361,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
         # Check arguments.
         assert args.consumed_train_samples == 0
         assert args.consumed_valid_samples == 0
-        if 'args' in state_dict:
+        if 'args' in state_dict and not args.reset_progress:
Review comment: Can you add a comment? Typically this is only used because the metadata loading mechanism screws with us.
             checkpoint_args = state_dict['args']
             if not args.universal_checkpoint:
                 check_checkpoint_args(checkpoint_args)
@@ -480,4 +480,4 @@ def _checkpoint_info():
     return {
         "padded_vocab_size": args.padded_vocab_size,
         "original_vocab_size": tokenizer.vocab_size,
-    }
+    }
@@ -358,7 +358,7 @@ def pack_samples(self, items):
             decoder_tokens[cur_len: cur_len + input_token_len] = token_dict["input_tokens"]
             decoder_tokens[cur_len + input_token_len: cur_len + total_len] = token_dict["target_tokens"]
             decoder_segment_ids[cur_len: cur_len + total_len] = item_num
-            decoder_is_inputs[cur_len: cur_len + input_token_len] = 1  # inputs
+            decoder_is_inputs[cur_len: cur_len + input_token_len] = True  # inputs
             # targets are already 0 at init, no need to update `decoder_is_inputs`

             item_num += 1
@@ -399,7 +399,7 @@ def _build_index_mappings(
     shuffle_idx_filename = _filename + '_decoder_packed_shuffle_idx.npy'

     # Build the indexed mapping if not exist.
-    if torch.distributed.get_rank() == 0:
+    if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
Review comment: Why do you need that?

Reply: Afaik you added this code; I think it was for running tests or something.

Reply: Arf, probably because I wanted to use the data loader only ... Maybe let's remove it for now, because we should be assuming that torch.distributed is always initialized, at least in Meg-DS IMO.
         if (not os.path.isfile(sample_idx_filename)) or \
            (not os.path.isfile(shuffle_idx_filename)):
@@ -437,15 +437,16 @@ def _build_index_mappings(
         print_rank_0(' > elasped time to build and save shuffle-idx and sample-idx mapping'
                      ' (seconds): {:4f}'.format(time.time() - start_time))

-    # This should be a barrier but nccl barrier assumes
-    # device_index=rank which is not the case for model
-    # parallel case
-    counts = torch.cuda.LongTensor([1])
-    torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
-    torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group())
-    assert counts[0].item() == (
-        torch.distributed.get_world_size() //
-        torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group()))
+    if torch.distributed.is_initialized():
Review comment: Ditto.
+        # This should be a barrier but nccl barrier assumes
+        # device_index=rank which is not the case for model
+        # parallel case
+        counts = torch.cuda.LongTensor([1])
+        torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
+        torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group())
+        assert counts[0].item() == (
+            torch.distributed.get_world_size() //
+            torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group()))

     # Load mappings.
     start_time = time.time()
@@ -261,11 +261,18 @@ def get_packed_attention_mask(is_causal: bool, causal_mask: torch.Tensor, decode
         - segment_ids: torch.IntTensor [batch_size, sequence_length]
     Returns:
         - attention_mask: torch.BoolTensor [batch_size, 1, sequence_length, sequence_length]
+
+    Input example for the mask examples:
+    att_mask_batch = 1
+    seq_length = 7
+    decoder_is_inputs = torch.tensor([[1, 1, 0, 1, 1, 0, 0]])
+    segment_ids = torch.tensor([[1, 1, 1, 2, 2, 2, 0]])
+    causal_mask = torch.tril(torch.ones(att_mask_batch, seq_length, seq_length)).view(att_mask_batch, 1, seq_length, seq_length)
     """

     """Causal Inputs Mask:
-    mask = [[[[1, 1, 0, 0, 0, 0, 0],
-              [1, 1, 0, 0, 0, 0, 0],
+    mask = [[[[1, 1, 0, 1, 1, 0, 0],
+              [1, 1, 0, 1, 1, 0, 0],
              [1, 1, 1, 0, 0, 0, 0],
              [1, 1, 1, 1, 1, 0, 0],
              [1, 1, 1, 1, 1, 0, 0],
@@ -299,7 +306,7 @@ def get_packed_attention_mask(is_causal: bool, causal_mask: torch.Tensor, decode
              [0, 0, 0, 1, 1, 1, 0],
              [0, 0, 0, 1, 1, 1, 0],
              [0, 0, 0, 1, 1, 1, 0],
-             [0, 0, 0, 0, 0, 0, 0]]]]
+             [0, 0, 0, 0, 0, 0, 1]]]]
""" | ||||||
segment_mask = segment_ids[:, None, :, None] == segment_ids[:, None, None, :] | ||||||
|
||||||
|
@@ -311,13 +318,22 @@ def get_packed_attention_mask(is_causal: bool, causal_mask: torch.Tensor, decode
              [0, 0, 0, 1, 1, 0, 0],
              [0, 0, 0, 1, 1, 1, 0],
              [0, 0, 0, 0, 0, 0, 0]]]]
+
+    If is_causal=True:
+    mask = [[[[1, 0, 0, 0, 0, 0, 0],
+              [1, 1, 0, 0, 0, 0, 0],
+              [1, 1, 1, 0, 0, 0, 0],
+              [0, 0, 0, 1, 0, 0, 0],
+              [0, 0, 0, 1, 1, 0, 0],
+              [0, 0, 0, 1, 1, 1, 0],
+              [0, 0, 0, 0, 0, 0, 0]]]]
Review comment: Suggested change

Reply: I don't think there is a

Reply: Hum, I'm wondering if this doesn't screw something up. Essentially you're going to compute softmax on a row with only zeros ...

Reply: The last row & last col are the attention scores of the last token with respect to the last token. Since the last token is masked out in our loss_mask, it doesn't matter I think.

Reply: No, you compute softmax; what should be the result of the softmax of a row full of masked-out values ...? It feels like that would return lots of NaNs.

Reply: Don't we fill it with

Reply: You can try writing a test, but I would be pretty sure that the actual results are 0 (with the current kernel).
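To ground the NaN question in the thread above, here is a quick check of what plain PyTorch softmax does on a fully masked row when masked positions are filled with -inf. This is only an illustration of the concern; the fused softmax kernel discussed here may behave differently (the thread expects it to return zeros), and the padding token's row is excluded by loss_mask anyway.

```python
import torch

# One attention-score row where every key position is masked out,
# as for the trailing padding token in the docstring example.
scores = torch.zeros(1, 7)
fully_masked = torch.ones(1, 7, dtype=torch.bool)  # True = do not attend

probs = torch.softmax(scores.masked_fill(fully_masked, float("-inf")), dim=-1)
print(probs)  # all NaN in eager PyTorch: exp(-inf) sums to 0, giving 0/0
```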
""" | ||||||
attention_mask = causal_inputs_mask * padding_mask * segment_mask | ||||||
|
||||||
# Convert attention mask to binary: | ||||||
attention_mask = (attention_mask < 0.5) | ||||||
attention_mask = causal_inputs_mask * padding_mask * segment_mask | ||||||
|
||||||
return attention_mask | ||||||
# True for places we do not want to attend to | ||||||
return ~attention_mask | ||||||
|
||||||
 def param_size(parameter):
     return parameter.ds_numel if hasattr(parameter, 'ds_id') else parameter.nelement()
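Putting the docstring example and the new return convention together, the following standalone sketch builds the packed mask end to end. Only `segment_mask` and the final combination appear verbatim in this diff; the `causal_inputs_mask` and `padding_mask` formulas below are inferred from the docstring examples and should be treated as assumptions.

```python
import torch

# Inputs from the docstring example added in this diff.
att_mask_batch, seq_length = 1, 7
decoder_is_inputs = torch.tensor([[1, 1, 0, 1, 1, 0, 0]], dtype=torch.bool)
segment_ids = torch.tensor([[1, 1, 1, 2, 2, 2, 0]])
causal_mask = torch.tril(
    torch.ones(att_mask_batch, seq_length, seq_length)
).view(att_mask_batch, 1, seq_length, seq_length).bool()

# Causal attention, plus bidirectional attention between input tokens (inferred).
causal_inputs_mask = causal_mask | (
    decoder_is_inputs[:, None, :, None] & decoder_is_inputs[:, None, None, :]
)
# Assumption: padding is wherever segment_ids == 0.
not_padding = segment_ids != 0
padding_mask = not_padding[:, None, :, None] & not_padding[:, None, None, :]
# Tokens only attend within their own packed segment (formula from the diff).
segment_mask = segment_ids[:, None, :, None] == segment_ids[:, None, None, :]

attention_mask = causal_inputs_mask & padding_mask & segment_mask
# As in the new return statement: True marks positions we must NOT attend to.
print((~attention_mask).int()[0, 0])
```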
@@ -7,7 +7,7 @@
 import deepspeed
 import torch

-import finetune_t0_non_causal_decoder
+import finetune_t0
 from megatron import global_vars, get_tokenizer, initialize_megatron, get_args
 from megatron.data import mlm_dataset, mtf_dataset, decoder_packed_mtf_dataset
 from megatron.data.data_samplers import build_pretraining_data_loader
@@ -241,7 +241,7 @@ def test_decoder_packed_mtf_dataloader(self):
         last_padding_size = len([None for segment_id in items["decoder_segment_ids"][micro_batch_size - 1] if segment_id == 0])

-    def test_finetune_t0_non_causal_decoder_get_batch_pipe(self):
+    def test_finetune_t0_get_batch_pipe(self):
Review comment: Yeah, let's make it so that the script is causal-decoder specific. Let's figure out the non-causal decoder later on.
         command_args = get_default_args()
         command_args["--position-embedding-type"] = "alibi"
@@ -263,7 +263,7 @@ def test_finetune_t0_non_causal_decoder_get_batch_pipe(self):
             special_tokens_ids={tokenizer.pad}
         )

-        (tokens, position_ids, attention_mask), (labels, loss_mask) = finetune_t0_non_causal_decoder.get_batch_pipe(data)
+        (tokens, position_ids, attention_mask), (labels, loss_mask) = finetune_t0.get_batch_pipe(data)

         tokens = tokens.cpu()
         position_ids = position_ids.cpu()
@@ -82,13 +82,18 @@ def encode(self, json_line):
         ids = {}
         for key in self.args.json_keys:
             text = data[key]
+            if self.args.prepend_space:
+                text = f" {text}"
             doc_ids = []
             for sentence in Encoder.splitter.tokenize(text):
                 sentence_ids = Encoder.tokenizer.tokenize(sentence)
                 if len(sentence_ids) > 0:
                     doc_ids.append(sentence_ids)
-            if len(doc_ids) > 0 and self.args.append_eod:
-                doc_ids[-1].append(Encoder.tokenizer.eod)
+            if len(doc_ids) > 0:
+                if self.args.append_eod:
+                    doc_ids[-1].append(Encoder.tokenizer.eod)
+                elif self.args.append_bos:
+                    doc_ids[-1].append(Encoder.tokenizer.bos)
             ids[key] = doc_ids
         return ids, len(json_line)
@@ -117,6 +122,10 @@ def get_args():
                        help='Path to the BPE merge file (if necessary).')
     group.add_argument('--append-eod', action='store_true',
                        help='Append an <eod> token to the end of a document.')
+    group.add_argument('--append-bos', action='store_true',
+                       help='Append a bos token to the end of a document.')
+    group.add_argument('--prepend-space', action='store_true',
+                       help='Prepends a space to the beginning of a document')
Review comment: Add a mention of the context in which it's useful; typically it is when you compute targets.
group.add_argument("--tokenizer-name-or-path", type=str, default=None, | ||
help="Name or path of the huggingface tokenizer.") | ||
group.add_argument('--make-vocab-size-divisible-by', type=int, default=128, | ||
|
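Regarding the comment above about when `--prepend-space` matters: with byte-level BPE tokenizers, a target preprocessed as its own document loses the leading-space marker it would carry when it follows a prompt, so its token ids change. A small illustration using the Hugging Face GPT-2 tokenizer (picked purely as an example; it is not necessarily the tokenizer used here):

```python
from transformers import AutoTokenizer

# Example tokenizer only; any byte-level BPE tokenizer shows the same effect.
tok = AutoTokenizer.from_pretrained("gpt2")

# Without the leading space the target is tokenized as if it started a document;
# with --prepend-space it is tokenized the way it appears after a prompt, so the
# target ids line up with what the model actually has to predict.
print(tok("Paris")["input_ids"])
print(tok(" Paris")["input_ids"])  # different ids because of the leading-space marker
```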
Review comment: Maybe reshaping to the original structure is a better API? It's better to have the same shapes as the labels IMO (we still flatten everything anyway).
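A sketch of the shape convention being suggested in that last comment: keep the loss mask in the same [batch, seq] shape as the labels at the get_batch_pipe boundary, and flatten only inside the loss reduction. The function name and exact reduction below are illustrative assumptions, not this PR's code.

```python
import torch

def reduce_masked_loss(per_token_losses: torch.Tensor, loss_mask: torch.Tensor) -> torch.Tensor:
    """Both arguments are [batch, seq], matching the labels' shape; flattening
    stays an internal detail of the reduction rather than part of the API."""
    loss_mask = loss_mask.reshape(-1).float()
    return torch.sum(per_token_losses.reshape(-1) * loss_mask) / loss_mask.sum()
```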