Enable loading ckpt for t0 finetuning #309

Open · wants to merge 16 commits into base: main
19 changes: 16 additions & 3 deletions finetune_t0_non_causal_decoder.py → finetune_t0.py
@@ -47,6 +47,15 @@ def model_provider(pre_process=True, post_process=True):
see_memory_usage(f"After Building Model", force=True)
return model


def fast_normalize(loss_mask: torch.Tensor):
"""
Turn loss_mask from [0,0,0,1,1,0,0,1,0,0,1,1,1] into [0,0,0,0.5,0.5,0,0,1,0,0,1/3,1/3,1/3], i.e. rescale each consecutive run of 1s so that it sums to 1.
"""
_, inverse_indices, counts = torch.unique_consecutive(loss_mask, return_inverse=True, return_counts=True)
counts = torch.gather(dim=0, index=inverse_indices, input=counts)
return loss_mask / counts

def get_batch_pipe(data):
"""
Modification of `get_batch` to work on `next(data_iterator)` instead of `data_iterator` & in packed fashion
@@ -83,17 +92,21 @@ def get_batch_pipe(data):
)
# Only compute loss over causal target tokens, i.e. ignore input_tokens & padding
loss_on_targets_only = ~data_c["decoder_is_inputs"][:, 1:]
- loss_on_non_pad_only = (tokens != tokenizer.pad)
+ loss_on_non_pad_only = (labels != tokenizer.pad)
loss_mask *= loss_on_targets_only * loss_on_non_pad_only

attention_mask = get_packed_attention_mask(
# Run non-causal decoder
- is_causal=False,
- causal_mask=~(causal_mask.bool()),
+ is_causal=not(args.prefixlm),
+ causal_mask=~(causal_mask.bool()), # Turn back into tril being ones
decoder_is_inputs=decoder_is_inputs.bool(),
segment_ids=segment_ids.long(),
)

if args.norm_target_loss:
loss_mask = loss_mask.view(-1)
loss_mask = fast_normalize(loss_mask)
Comment on lines +107 to +108
Member:
Maybe reshaping to the original structure is a better API? It's better to have the same shapes as labels IMO (we still flatten everything afterwards anyway).


if args.position_embedding_type not in [PositionEmbeddingType.alibi, PositionEmbeddingType.rotary]:
raise NotImplementedError("absolute positional embeddings require us to reset position_ids accordingly.")

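For reference, a quick standalone sanity check of the new fast_normalize helper (a sketch, with the function body copied from the diff above; not part of the PR's tests): each consecutive run of 1s in the loss mask is rescaled to sum to 1, so every packed target sequence contributes one unit of weight regardless of its length.

    import torch

    def fast_normalize(loss_mask: torch.Tensor):
        # Length of the run each position belongs to, broadcast back onto the positions.
        _, inverse_indices, counts = torch.unique_consecutive(loss_mask, return_inverse=True, return_counts=True)
        counts = torch.gather(dim=0, index=inverse_indices, input=counts)
        return loss_mask / counts

    loss_mask = torch.tensor([0., 0., 0., 1., 1., 0., 0., 1., 0., 0., 1., 1., 1.])
    print(fast_normalize(loss_mask))
    # tensor([0.0000, 0.0000, 0.0000, 0.5000, 0.5000, 0.0000, 0.0000, 1.0000,
    #         0.0000, 0.0000, 0.3333, 0.3333, 0.3333])
    print(fast_normalize(loss_mask).sum())  # tensor(3.) -- one unit of weight per target span
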
5 changes: 5 additions & 0 deletions megatron/arguments.py
@@ -664,6 +664,8 @@ def _add_checkpointing_args(parser):
help='Do not load optimizer when loading checkpoint.')
group.add_argument('--no-load-rng', action='store_true', default=None,
help='Do not load rng state when loading checkpoint.')
group.add_argument('--reset-progress', action='store_true', default=None,
help='Reset iteration to 0 & do not load args.')
group.add_argument('--finetune', action='store_true',
help='Load model for finetuning. Do not load optimizer '
'or rng state from checkpoint and set iteration to 0. '
@@ -928,13 +930,16 @@ def __call__(self, parser, args, values, option_string=None):
help='Mask loss for the end of document tokens.')
group.add_argument('--loss-on-targets-only', action='store_true',
help='Mask loss on input sequence.')
group.add_argument('--norm-target-loss', action='store_true',
help='Normalize the loss per target. Used for multi-task finetuning with packing.')
group.add_argument('--reweight-loss-based-on-position-frequency', action="store_true",
help='Some objectives require us to sample loss_mask. This might introduce bias towards '
'specific positions. This option tries to un-bias the loss by reweighting loss on specific '
'positions based on how frequently we train on that position.'
'This is mostly used for prefix_lm training')
group.add_argument("--noise-density", type=float, default=None, help="Span corruption noise density")
group.add_argument("--mean-noise-span-length", type=int, default=None, help="Span corruption mean noise span length")
group.add_argument("--prefixlm", action='store_true', help="Whether to train a PrefixLM - To be used with finetune t0")
Member:
Yeah actually let's remove that option. I don't think we've trained one successfully. We'll probably add it, since people have shown that it works, but in another PR IMO.



return parser
10 changes: 5 additions & 5 deletions megatron/checkpointing.py
@@ -274,8 +274,8 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
load_dir = getattr(args, load_arg)

if args.deepspeed:
- load_optimizer_states = False if args.no_load_optim else True
- loaded_dir, state_dict = model[0].load_checkpoint(load_dir, load_optimizer_states=load_optimizer_states)
+ load_optimizer_states = not args.no_load_optim
Member:
Just use no_load_optim directly in the method

+ loaded_dir, state_dict = model[0].load_checkpoint(load_dir, load_module_only=not load_optimizer_states, load_optimizer_states=load_optimizer_states, load_lr_scheduler_states=load_optimizer_states)
if loaded_dir is None:
print_rank_0('WARNING: could not find the metadata file {} '.format(
load_dir))
@@ -342,7 +342,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
set_checkpoint_version(state_dict.get('checkpoint_version', 0))

# Set iteration.
- if args.finetune or release:
+ if args.finetune or release or args.reset_progress:
Member:
Why is it that we didn't set finetune to True?

iteration = 0
else:
try:
@@ -361,7 +361,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
# Check arguments.
assert args.consumed_train_samples == 0
assert args.consumed_valid_samples == 0
- if 'args' in state_dict:
+ if 'args' in state_dict and not args.reset_progress:
Member:
Can you add a comment? Typically this is only used because the metadata loading mechanism screws with us.

checkpoint_args = state_dict['args']
if not args.universal_checkpoint:
check_checkpoint_args(checkpoint_args)
@@ -480,4 +480,4 @@ def _checkpoint_info():
return {
"padded_vocab_size": args.padded_vocab_size,
"original_vocab_size": tokenizer.vocab_size,
}
}
23 changes: 12 additions & 11 deletions megatron/data/decoder_packed_mtf_dataset.py
@@ -358,7 +358,7 @@ def pack_samples(self, items):
decoder_tokens[cur_len: cur_len + input_token_len] = token_dict["input_tokens"]
decoder_tokens[cur_len + input_token_len: cur_len + total_len] = token_dict["target_tokens"]
decoder_segment_ids[cur_len: cur_len + total_len] = item_num
- decoder_is_inputs[cur_len: cur_len + input_token_len] = 1 # inputs
+ decoder_is_inputs[cur_len: cur_len + input_token_len] = True # inputs
# targets are already 0 at init, no need to update `decoder_is_inputs`

item_num += 1
@@ -399,7 +399,7 @@ def _build_index_mappings(
shuffle_idx_filename = _filename + '_decoder_packed_shuffle_idx.npy'

# Build the indexed mapping if not exist.
- if torch.distributed.get_rank() == 0:
+ if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
Member:
Why do you need that?

Collaborator Author:
Afaik you added this code; I think it was for running tests or something.

Member:
Arf, probably because I wanted to use the data loader only... Maybe let's remove it for now, because we should assume that torch distributed is always initialized, at least in Meg-DS IMO.

if (not os.path.isfile(sample_idx_filename)) or \
(not os.path.isfile(shuffle_idx_filename)):

@@ -437,15 +437,16 @@ def _build_index_mappings(
print_rank_0(' > elasped time to build and save shuffle-idx and sample-idx mapping'
' (seconds): {:4f}'.format(time.time() - start_time))

- # This should be a barrier but nccl barrier assumes
- # device_index=rank which is not the case for model
- # parallel case
- counts = torch.cuda.LongTensor([1])
- torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
- torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group())
- assert counts[0].item() == (
-     torch.distributed.get_world_size() //
-     torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group()))
+ if torch.distributed.is_initialized():
Member:
Ditto

+     # This should be a barrier but nccl barrier assumes
+     # device_index=rank which is not the case for model
+     # parallel case
+     counts = torch.cuda.LongTensor([1])
+     torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
+     torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group())
+     assert counts[0].item() == (
+         torch.distributed.get_world_size() //
+         torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group()))

# Load mappings.
start_time = time.time()
8 changes: 6 additions & 2 deletions megatron/model/gpt_model.py
@@ -15,7 +15,6 @@

"""GPT-2 model."""

- from functools import partial
import torch

from megatron import get_args
@@ -186,6 +185,10 @@ def CrossEntropy(output, labels):
else:
average_tokens_per_sample = sequence_length
expected_number_of_tokens = average_tokens_per_sample * micro_batch_size
elif args.norm_target_loss:
expected_num_of_target_seqs = loss_mask.sum()
loss = torch.sum(losses.view(-1) * loss_mask) / expected_num_of_target_seqs
return loss
else:
expected_number_of_tokens = loss_mask.sum()

@@ -252,7 +255,8 @@ def _to_float16(inputs):
args.num_layers),
layer_number=layer_idx,
# TODO: Change naming of class from GPT to something that encapsulate prefix lm.
- self_attn_mask_type=attn_mask_type))
+ self_attn_mask_type=attn_mask_type)
+ )

# Undo data format change
def undo(x):
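To make the new norm_target_loss branch concrete, a small numeric sketch (made-up values; it assumes loss_mask has already been normalized by fast_normalize from finetune_t0.py, so its sum equals the number of packed target spans):

    import torch

    per_token_ce = torch.tensor([0.0, 0.0, 2.0, 4.0, 0.0, 6.0, 0.0, 0.0])   # losses.view(-1)
    loss_mask    = torch.tensor([0.0, 0.0, 0.5, 0.5, 0.0, 1/3, 1/3, 1/3])   # normalized loss mask

    expected_num_of_target_seqs = loss_mask.sum()                           # 2 target spans
    loss = torch.sum(per_token_ce * loss_mask) / expected_num_of_target_seqs
    print(loss)  # tensor(2.5000): span means are 3.0 and 2.0, averaged over the 2 spans

Each target span therefore contributes its mean token loss, regardless of how many tokens it contains.
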
3 changes: 2 additions & 1 deletion megatron/tokenizer/tokenizer.py
@@ -380,7 +380,8 @@ def mask(self):

@property
def bos(self):
- raise NotImplementedError("Missing <bos>")
+ candidate = self.tokenizer.bos_token_id
+ return self._check_token_candidate(candidate)

@property
def eos(self):
30 changes: 23 additions & 7 deletions megatron/utils.py
@@ -261,11 +261,18 @@ def get_packed_attention_mask(is_causal: bool, causal_mask: torch.Tensor, decode
- segment_ids: torch.IntTensor [batch_size, sequence_length]
Returns:
- attention_mask: torch.BoolTensor [batch_size, 1, sequence_length, sequence_length]

Input example for the mask examples:
att_mask_batch = 1
seq_length = 7
decoder_is_inputs = torch.tensor([[1, 1, 0, 1, 1, 0, 0]])
segment_ids = torch.tensor([[1, 1, 1, 2, 2, 2, 0]])
causal_mask = torch.tril(torch.ones(att_mask_batch, seq_length, seq_length)).view(att_mask_batch, 1, seq_length, seq_length)
"""

"""Causal Inputs Mask:
- mask = [[[[1, 1, 0, 0, 0, 0, 0],
- [1, 1, 0, 0, 0, 0, 0],
+ mask = [[[[1, 1, 0, 1, 1, 0, 0],
+ [1, 1, 0, 1, 1, 0, 0],
Muennighoff marked this conversation as resolved.
[1, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 0, 0],
[1, 1, 1, 1, 1, 0, 0],
@@ -299,7 +306,7 @@ def get_packed_attention_mask(is_causal: bool, causal_mask: torch.Tensor, decode
[0, 0, 0, 1, 1, 1, 0],
[0, 0, 0, 1, 1, 1, 0],
[0, 0, 0, 1, 1, 1, 0],
- [0, 0, 0, 0, 0, 0, 0]]]]
+ [0, 0, 0, 0, 0, 0, 1]]]]
Muennighoff marked this conversation as resolved.
"""
segment_mask = segment_ids[:, None, :, None] == segment_ids[:, None, None, :]

@@ -311,13 +318,22 @@ def get_packed_attention_mask(is_causal: bool, causal_mask: torch.Tensor, decode
[0, 0, 0, 1, 1, 0, 0],
[0, 0, 0, 1, 1, 1, 0],
[0, 0, 0, 0, 0, 0, 0]]]]

If is_causal=True:
mask = [[[[1, 0, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 0],
[0, 0, 0, 1, 0, 0, 0],
[0, 0, 0, 1, 1, 0, 0],
[0, 0, 0, 1, 1, 1, 0],
[0, 0, 0, 0, 0, 0, 0]]]]
Member:
Suggested change:
- [0, 0, 0, 0, 0, 0, 0]]]]
+ [0, 0, 0, 0, 0, 0, 1]]]]

Collaborator Author:
I don't think there is a 1, because the last row & column is 100% padding.

Member:
Hum I'm wondering if this doesn't screw something up. Essentially you're going to compute softmax on a row with only zeros ...

Collaborator Author:
The last row & last col are the attention scores of the last token with respect to the last token. Since the last token is masked out in our loss_mask it doesn't matter I think.
Also it's a row with only -inf, no?

Member:
No, you compute softmax; what should the softmax of a row full of masked-out values be? It feels like that would return lots of NaNs.

Collaborator Author:
Don't we fill it with -inf?
And the softmax of a row where all values are the same is just 1/n, no? Where would it cause NaNs?

Member (@thomasw21, Jul 12, 2022):
You can try writing a test but I would be pretty sure that the actual results are 0. (with current kernel)


"""
- attention_mask = causal_inputs_mask * padding_mask * segment_mask

- # Convert attention mask to binary:
- attention_mask = (attention_mask < 0.5)
+ attention_mask = causal_inputs_mask * padding_mask * segment_mask

- return attention_mask
+ # True for places we do not want to attend to
+ return ~attention_mask

def param_size(parameter):
return parameter.ds_numel if hasattr(parameter, 'ds_id') else parameter.nelement()
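On the open question in the thread above, a standalone check is easy to run (a sketch assuming the plain PyTorch path of masked_fill with -inf followed by softmax; the fused ScaledMaskedSoftmax kernel mentioned at the end of the thread may instead return zeros):

    import torch

    # Attention scores for the all-padding position, with every key masked out
    # (mask convention as returned above: True = do not attend).
    scores = torch.zeros(1, 7)
    mask = torch.ones(1, 7, dtype=torch.bool)

    probs = torch.softmax(scores.masked_fill(mask, float("-inf")), dim=-1)
    print(probs)  # all NaN: exp(-inf) sums to 0, so the normalization divides 0 by 0

So the eager implementation does produce NaNs for a fully masked row; whether that stays harmless downstream is exactly what the thread is debating.
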
6 changes: 3 additions & 3 deletions tests/test_dataloaders.py
@@ -7,7 +7,7 @@
import deepspeed
import torch

- import finetune_t0_non_causal_decoder
+ import finetune_t0
from megatron import global_vars, get_tokenizer, initialize_megatron, get_args
from megatron.data import mlm_dataset, mtf_dataset, decoder_packed_mtf_dataset
from megatron.data.data_samplers import build_pretraining_data_loader
@@ -241,7 +241,7 @@ def test_decoder_packed_mtf_dataloader(self):
last_padding_size = len([None for segment_id in items["decoder_segment_ids"][micro_batch_size - 1] if segment_id == 0])


- def test_finetune_t0_non_causal_decoder_get_batch_pipe(self):
+ def test_finetune_t0_get_batch_pipe(self):
Member:
Yeah let's make it so that the script is causal decoder specific. Let's figure out non causal decoder later on.

command_args = get_default_args()
command_args["--position-embedding-type"] = "alibi"

special_tokens_ids={tokenizer.pad}
)

- (tokens, position_ids, attention_mask), (labels, loss_mask) = finetune_t0_non_causal_decoder.get_batch_pipe(data)
+ (tokens, position_ids, attention_mask), (labels, loss_mask) = finetune_t0.get_batch_pipe(data)

tokens = tokens.cpu()
position_ids = position_ids.cpu()
4 changes: 2 additions & 2 deletions tests/test_model.py
@@ -20,7 +20,7 @@
from megatron.training import setup_model_and_optimizer
import pretrain_gpt
import pretrain_prefix_lm
- import finetune_t0_non_causal_decoder
+ import finetune_t0


def get_default_args(test_file_dir: str):
@@ -456,7 +456,7 @@ def test_non_causal_decoder_model_with_packed_input_passed_with_attention_mask_i
vocab_size=args.padded_vocab_size,
special_tokens_ids={tokenizer.pad}
)
- model, _, _ = setup_model_and_optimizer(finetune_t0_non_causal_decoder.model_provider)
+ model, _, _ = setup_model_and_optimizer(finetune_t0.model_provider)
model = model[0]
model._config.train_micro_batch_size_per_gpu = args.micro_batch_size
model.set_train_batch_size(args.micro_batch_size)
2 changes: 1 addition & 1 deletion tests/test_training.py
@@ -557,7 +557,7 @@ def test_training_t0(self):
--deepspeed-activation-checkpointing
""".split()

- script = [f"{self.src_dir}/finetune_t0_non_causal_decoder.py"]
+ script = [f"{self.src_dir}/finetune_t0.py"]
launcher = get_launcher(num_gpus)

cmd = launcher + script + args + ds_args
13 changes: 11 additions & 2 deletions tools/preprocess_data.py
@@ -82,13 +82,18 @@ def encode(self, json_line):
ids = {}
for key in self.args.json_keys:
text = data[key]
if self.args.prepend_space:
text = f" {text}"
doc_ids = []
for sentence in Encoder.splitter.tokenize(text):
sentence_ids = Encoder.tokenizer.tokenize(sentence)
if len(sentence_ids) > 0:
doc_ids.append(sentence_ids)
- if len(doc_ids) > 0 and self.args.append_eod:
-     doc_ids[-1].append(Encoder.tokenizer.eod)
+ if len(doc_ids) > 0:
+     if self.args.append_eod:
+         doc_ids[-1].append(Encoder.tokenizer.eod)
+     elif self.args.append_bos:
+         doc_ids[-1].append(Encoder.tokenizer.bos)
ids[key] = doc_ids
return ids, len(json_line)

@@ -117,6 +122,10 @@ def get_args():
help='Path to the BPE merge file (if necessary).')
group.add_argument('--append-eod', action='store_true',
help='Append an <eod> token to the end of a document.')
group.add_argument('--append-bos', action='store_true',
help='Append a bos token to the end of a document.')
group.add_argument('--prepend-space', action='store_true',
help='Prepends a space to the beginning of a document')
Member:
Add a mention of the context in which it's useful; typically it is when you compute targets.

group.add_argument("--tokenizer-name-or-path", type=str, default=None,
help="Name or path of the huggingface tokenizer.")
group.add_argument('--make-vocab-size-divisible-by', type=int, default=128,
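Regarding the comment above on when --prepend-space is useful: when targets are tokenized on their own and later packed directly after an input, a leading space makes the target's first token match how that word would be tokenized mid-sentence. A small illustration (the HF GPT-2 tokenizer is used purely as a stand-in; the tokenizer paired with this PR may behave differently):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")
    print(tok.encode("Paris"))   # ids for the word when it starts a string
    print(tok.encode(" Paris"))  # usually different ids: the form the model sees right after a prompt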