MTF train script (#295)
Co-authored-by: Lintang Sutawika <[email protected]>
Co-authored-by: Muennighoff <[email protected]>
4 people authored Jul 5, 2022
1 parent e1c479e commit 3d5d151
Showing 28 changed files with 1,394 additions and 695 deletions.
142 changes: 142 additions & 0 deletions finetune_t0_non_causal_decoder.py
@@ -0,0 +1,142 @@
"""Multitask Finetuning T0"""

from multiprocessing.sharedctypes import Value
import torch

from megatron import get_args, get_tokenizer, print_rank_0, mpu
from megatron.data.decoder_packed_mtf_dataset import build_train_valid_test_datasets
from megatron.enums import PositionEmbeddingType, AttnMaskType
from megatron.model import GPTModelPipe
from megatron.training import pretrain
from megatron.utils import get_ltor_masks_and_position_ids, get_packed_attention_mask

import deepspeed
from deepspeed.runtime.utils import see_memory_usage

try:
from torch.distributed.elastic.multiprocessing.errors import record
except ImportError:
# noop
def record(fn):
return fn

def model_provider(pre_process=True, post_process=True):
    """Build the model."""

    print_rank_0("building GPT model ...")
    see_memory_usage("Before Building Model", force=True)

    args = get_args()

    with deepspeed.zero.Init(data_parallel_group=mpu.get_data_parallel_group(),
                             remote_device=None if args.remote_device == "none" else args.remote_device,
                             config_dict_or_path=args.deepspeed_config,
                             enabled=args.zero_stage == 3,
                             mpu=mpu):
        if args.deepspeed:
            model = GPTModelPipe(
                num_tokentypes=0,
                parallel_output=True,
                attn_mask_type=AttnMaskType.custom
            )
            # This is a hack to give us a reference to get_batch_pipe from within training.py
            # We need to call model.set_batch_fn after deepspeed.initialize
            model._megatron_batch_fn = get_batch_pipe
        else:
            raise NotImplementedError("DeepSpeed is required for T0")

    see_memory_usage("After Building Model", force=True)
    return model

def get_batch_pipe(data):
    """
    Modification of `get_batch` that works on `next(data_iterator)` instead of `data_iterator`, in packed fashion.

    data:
        decoder_token_ids   = [[6, 7, 8, 3, 4, 5, 0]]
        decoder_segment_ids = [[1, 1, 1, 2, 2, 2, 0]]
        decoder_is_inputs   = [[1, 1, 0, 1, 1, 0, 0]]
    """
    args = get_args()
    tokenizer = get_tokenizer()

    # Broadcast data.
    data_b = mpu.broadcast_data(["decoder_token_ids", "decoder_segment_ids"], data, torch.int64)
    data_c = mpu.broadcast_data(["decoder_is_inputs"], data, torch.bool)

    # Unpack.
    tokens_ = data_b["decoder_token_ids"].long()
    labels = tokens_[:, 1:].contiguous()
    tokens = tokens_[:, :-1].contiguous()

    segment_ids = data_b["decoder_segment_ids"].long()[:, :-1]
    decoder_is_inputs = data_c["decoder_is_inputs"][:, :-1]

    # Get the masks and position ids.
    causal_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
        tokens,
        tokenizer.eod,
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss,
        prefix_indices=None,
        loss_on_targets_only=False  # This is done below
    )
    # Only compute loss over causal target tokens, i.e. ignore input tokens & padding.
    loss_on_targets_only = ~data_c["decoder_is_inputs"][:, 1:]
    loss_on_non_pad_only = (tokens != tokenizer.pad)
    loss_mask *= loss_on_targets_only * loss_on_non_pad_only

    attention_mask = get_packed_attention_mask(
        # Run a non-causal decoder.
        is_causal=False,
        causal_mask=~(causal_mask.bool()),
        decoder_is_inputs=decoder_is_inputs.bool(),
        segment_ids=segment_ids.long(),
    )

    if args.position_embedding_type not in [PositionEmbeddingType.alibi, PositionEmbeddingType.rotary]:
        raise NotImplementedError("Absolute positional embeddings require us to reset position_ids accordingly.")

    return (tokens, position_ids, attention_mask), (labels, loss_mask)
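
For intuition, here are two minimal, self-contained sketches (editor's additions, not part of this commit) that replay the masking logic above on the toy packed example from the docstring. The pad id of 0, the all-ones starting loss mask, and the hand-rolled attention pattern are illustrative assumptions; in the script itself the real values come from `get_ltor_masks_and_position_ids` and `get_packed_attention_mask`.

import torch

# 1) Targets-only loss mask: keep only positions that predict a target token
#    and whose input token is not padding.
decoder_token_ids = torch.tensor([[6, 7, 8, 3, 4, 5, 0]])
decoder_is_inputs = torch.tensor([[1, 1, 0, 1, 1, 0, 0]], dtype=torch.bool)

tokens = decoder_token_ids[:, :-1]  # model input:    [[6, 7, 8, 3, 4, 5]]
labels = decoder_token_ids[:, 1:]   # shifted target: [[7, 8, 3, 4, 5, 0]]

loss_mask = torch.ones_like(tokens, dtype=torch.float)  # stand-in for the mask from get_ltor_masks_and_position_ids
loss_mask *= ~decoder_is_inputs[:, 1:]  # keep only positions predicting target tokens
loss_mask *= tokens != 0                # drop padding positions (pad id 0 assumed)
# loss_mask is now tensor([[0., 1., 0., 0., 1., 1.]])

# 2) Packed non-causal attention: position i may attend to position j iff both
#    belong to the same packed document and j is either an input token or j <= i.
segment_ids = torch.tensor([[1, 1, 1, 2, 2, 2, 0]])[:, :-1]
is_inputs = decoder_is_inputs[:, :-1]

same_segment = segment_ids[:, :, None] == segment_ids[:, None, :]  # block-diagonal over documents
causal = torch.tril(torch.ones(1, 6, 6, dtype=torch.bool))         # j <= i
bidirectional_inputs = is_inputs[:, None, :].expand(-1, 6, -1)     # input tokens are visible both ways
allowed = same_segment & (causal | bidirectional_inputs)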


def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build train, valid, and test datasets."""
    args = get_args()
    train_ds, valid_ds, test_ds = None, None, None

    tokenizer = get_tokenizer()

    print_rank_0("> building train, validation, and test datasets for T0 ...")
    # Option 1 of data loading using --data-path
    if args.data_path:
        # TODO: Not yet compatible with dataset weights (will break at `prefixes, weights = analyze_data_prefix(args.data_path)`)
        train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
            data_prefix=args.data_path,
            data_impl=args.data_impl,
            splits_string=args.split,
            seq_length=args.seq_length + 1,
            pad_token=tokenizer.pad,
            eos_token=tokenizer.eos,
            train_valid_test_num_samples=train_val_test_num_samples,
            seed=args.seed,
            skip_warmup=(not args.mmap_warmup)
        )
    else:
        raise NotImplementedError("No dataloading argument passed")

    print_rank_0("> finished creating T0 datasets ...")
    return train_ds, valid_ds, test_ds

@record
def main():
    pretrain(
        train_valid_test_datasets_provider,
        model_provider,
        forward_step_func=None,
        args_defaults={}
    )


if __name__ == "__main__":
    main()
2 changes: 1 addition & 1 deletion megatron/arguments.py
@@ -557,7 +557,7 @@ def _add_training_args(parser):
                       'please refer https://github.com/facebookresearch/bitsandbytes.',
                       dest='use_bnb_optimizer')
    group.add_argument('--dataloader-type', type=str, default=None,
                       choices=['single', 'cyclic', 'decoder_packed'],
                       choices=['single', 'cyclic'],
                       help='Single pass vs multiple pass data loader')
    group.add_argument('--cpu-optimizer', action='store_true',
                       help='Run optimizer on CPU')
163 changes: 3 additions & 160 deletions megatron/data/data_samplers.py
@@ -15,77 +15,11 @@

"""Dataloaders."""

from functools import partial

import numpy as np
import torch

from megatron import get_args, get_tokenizer
from megatron import get_args
from megatron import mpu
from megatron.data.mtf_dataset import MTFDataset


def pack_samples(items, max_seq_len: int, micro_batch_size: int, pad_token: int):
    """
    Greedily packs samples.

    Items:
        [
            {
                'input_tokens': array([6, 7]),
                'target_tokens': array([8])
            },
            {
                'input_tokens': array([3, 4]),
                'target_tokens': array([5])
            }
        ]

    Output:
        decoder_target_tokens = [[6, 7, 8, 3, 4, 5, <pad>]]: Concatenation of the samples' tokens, followed by padding tokens.
        decoder_segment_ids = [[1, 1, 1, 2, 2, 2, 0]]: Segment ids mark which original document each token belongs to.
        decoder_causal_attention = [[1, 1, 0, 1, 1, 0, 0]]: `1` depicts input tokens, `0` depicts target tokens.
    """
    decoder_target_tokens = np.full((micro_batch_size, max_seq_len), pad_token)
    decoder_segment_ids = np.zeros((micro_batch_size, max_seq_len))
    decoder_causal_attention = np.zeros((micro_batch_size, max_seq_len))

    batch_num = 0
    # `0` is reserved for padding.
    item_num = 1
    cur_len = 0
    for token_dict in items:
        input_token_len = len(token_dict["input_tokens"])
        target_token_len = len(token_dict["target_tokens"])
        total_len = input_token_len + target_token_len
        if cur_len + total_len > max_seq_len:
            len_diff = max_seq_len - cur_len
            # Padding
            if len_diff > 0:
                decoder_target_tokens[batch_num][cur_len: max_seq_len] = pad_token
                decoder_segment_ids[batch_num][cur_len: max_seq_len] = 0
                decoder_causal_attention[batch_num][cur_len: max_seq_len] = 0
            batch_num += 1
            assert batch_num < micro_batch_size
            item_num = 1
            cur_len = 0

        decoder_target_tokens[batch_num][cur_len: cur_len + input_token_len] = token_dict["input_tokens"]
        decoder_target_tokens[batch_num][cur_len + input_token_len: cur_len + total_len] = token_dict["target_tokens"]
        decoder_segment_ids[batch_num][cur_len: cur_len + total_len] = item_num
        decoder_causal_attention[batch_num][cur_len: cur_len + input_token_len] = 1  # input
        decoder_causal_attention[batch_num][cur_len + input_token_len: cur_len + total_len] = 0  # target

        item_num += 1
        cur_len += total_len
        assert cur_len < max_seq_len

    return {
        "decoder_target_tokens": decoder_target_tokens,
        "decoder_segment_ids": decoder_segment_ids,
        "decoder_causal_attention": decoder_causal_attention,
    }
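
For reference, a quick usage sketch (editor's addition, not part of the commit) of what the removed `pack_samples` produces for the two samples in its docstring; `pad_token=0` is an assumption here:

import numpy as np

items = [
    {"input_tokens": np.array([6, 7]), "target_tokens": np.array([8])},
    {"input_tokens": np.array([3, 4]), "target_tokens": np.array([5])},
]
packed = pack_samples(items, max_seq_len=7, micro_batch_size=1, pad_token=0)
# packed["decoder_target_tokens"]    -> [[6 7 8 3 4 5 0]]
# packed["decoder_segment_ids"]      -> [[1. 1. 1. 2. 2. 2. 0.]]
# packed["decoder_causal_attention"] -> [[1. 1. 0. 1. 1. 0. 0.]]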
from megatron.data.decoder_packed_mtf_dataset import DecoderPackedMTFDataset


@@ -110,41 +44,23 @@ def build_pretraining_data_loader(dataset, consumed_samples, num_workers=None):
            micro_batch_size=args.micro_batch_size,
            data_parallel_rank=mpu.get_data_parallel_rank(),
            data_parallel_size=mpu.get_data_parallel_world_size())
    elif args.dataloader_type == 'decoder_packed':
        assert isinstance(dataset, MTFDataset)
        batch_sampler = MegatronDecoderPackedText2TextRandomSampler(
            sequence_length=args.seq_length + 1,
            dataset=dataset,
            total_samples=len(dataset),
            consumed_samples=consumed_samples,
            micro_batch_size=args.micro_batch_size,
            data_parallel_rank=mpu.get_data_parallel_rank(),
            data_parallel_size=mpu.get_data_parallel_world_size())
    else:
        raise Exception('{} dataloader type is not supported.'.format(
            args.dataloader_type))

    if num_workers is None:
        num_workers = args.num_workers

    collate_fn = None
    if args.dataloader_type == 'decoder_packed':
        assert isinstance(dataset, MTFDataset)
        pad_token = get_tokenizer().pad
        collate_fn = partial(pack_samples, max_seq_len=args.seq_length + 1, micro_batch_size=args.micro_batch_size,
                             pad_token=pad_token)

    # Torch dataloader.
    return torch.utils.data.DataLoader(
        dataset,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        generator=torch.Generator().manual_seed(args.seed),
        collate_fn=collate_fn,
        collate_fn=None,
        pin_memory=True
    )


class MegatronPretrainingSampler:

    def __init__(self, total_samples, consumed_samples, micro_batch_size,
@@ -246,76 +162,3 @@ def __iter__(self):
                self.consumed_samples += self.micro_batch_times_data_parallel_size
                yield batch
                batch = []


class MegatronDecoderPackedText2TextRandomSampler(object):
    """
    Converts a two-stream dataset with `input_tokens` and `target_tokens` into batches whose
    samples can be greedily packed into single decoder sequences.

    To be used with `pack_samples` as collate_fn.
    """

    def __init__(self, sequence_length, dataset, total_samples, consumed_samples, micro_batch_size,
                 data_parallel_rank, data_parallel_size):
        # Keep a copy of input params for later use.
        self.dataset = dataset
        self.sequence_length = sequence_length
        self.total_samples = total_samples
        self.consumed_samples = consumed_samples
        self.micro_batch_size = micro_batch_size
        self.data_parallel_rank = data_parallel_rank
        self.data_parallel_size = data_parallel_size
        self.micro_batch_times_data_parallel_size = \
            self.micro_batch_size * data_parallel_size
        self.last_batch_size = \
            self.total_samples % self.micro_batch_times_data_parallel_size

        # Sanity checks.
        assert self.total_samples > 0, \
            'no sample to consume: {}'.format(self.total_samples)
        assert self.micro_batch_size > 0
        assert data_parallel_size > 0
        assert self.data_parallel_rank < data_parallel_size, \
            'data_parallel_rank should be smaller than data size: {}, ' \
            '{}'.format(self.data_parallel_rank, data_parallel_size)

    def __len__(self):
        return self.total_samples

    def __iter__(self):
        active_total_samples = self.total_samples - self.last_batch_size
        self.epoch = self.consumed_samples // active_total_samples
        current_epoch_samples = self.consumed_samples % active_total_samples
        assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0

        # Data sharding and random sampling.
        bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \
            * self.micro_batch_size
        bucket_offset = current_epoch_samples // self.data_parallel_size
        start_idx = self.data_parallel_rank * bucket_size

        g = torch.Generator()
        g.manual_seed(self.epoch)

        random_idx = torch.randperm(bucket_size, generator=g).tolist()
        idx_range = [start_idx + x for x in random_idx[bucket_offset:]]

        batch = []
        batch_count = 0
        token_lens = 0
        # The last batch, if not complete, will be dropped.
        for idx in idx_range:
            tok_len = len(self.dataset[idx]['input_tokens']) + len(self.dataset[idx]['target_tokens'])
            if token_lens + tok_len > self.sequence_length:
                batch_count += 1
                token_lens = 0

                if batch_count == self.micro_batch_size:
                    self.consumed_samples += self.micro_batch_times_data_parallel_size
                    yield batch
                    batch_count = 0
                    batch = []
            else:
                token_lens += tok_len

            batch.append(idx)
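
To make the packing behavior concrete, here is a toy run (editor's sketch, not part of the commit); the four-sample dataset and its token lengths are invented for illustration:

# Each toy item mimics the two-stream format the sampler expects:
# 3 input tokens + 1 target token = 4 tokens per sample.
toy_dataset = [
    {"input_tokens": [1, 1, 1], "target_tokens": [2]}
    for _ in range(4)
]

sampler = MegatronDecoderPackedText2TextRandomSampler(
    sequence_length=8,
    dataset=toy_dataset,
    total_samples=len(toy_dataset),
    consumed_samples=0,
    micro_batch_size=1,
    data_parallel_rank=0,
    data_parallel_size=1,
)

# Each yielded batch is a list of dataset indices whose samples together fill
# (at most) one packed row of `sequence_length` tokens; `pack_samples` then
# concatenates them into that row. An incomplete final batch is dropped.
for indices in sampler:
    print(indices)  # e.g. a list of two indices, since 2 * 4 tokens fill a row of 8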