Commit: Add support for weighted train

thomasw21 committed Jul 6, 2022
1 parent 3d5d151 commit 66ce0cf
Showing 2 changed files with 38 additions and 3 deletions.
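
In short: finetune_t0_non_causal_decoder.py gains a second data-loading path. Instead of a single --data-path, the --(train|valid|test)-weighted-split-paths arguments describe named groups of datasets, each with per-dataset sampling weights and document splits, built through build_dataset_group. decoder_packed_mtf_dataset.py additionally logs a warning whenever a sample longer than the packed sequence length is skipped.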
37 changes: 35 additions & 2 deletions finetune_t0_non_causal_decoder.py
@@ -1,10 +1,9 @@
"""Multitask Finetuning T0"""

from multiprocessing.sharedctypes import Value
import torch

from megatron import get_args, get_tokenizer, print_rank_0, mpu
from megatron.data.decoder_packed_mtf_dataset import build_train_valid_test_datasets
from megatron.data.decoder_packed_mtf_dataset import build_train_valid_test_datasets, build_dataset_group
from megatron.enums import PositionEmbeddingType, AttnMaskType
from megatron.model import GPTModelPipe
from megatron.training import pretrain
@@ -123,6 +122,40 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
         seed=args.seed,
         skip_warmup=(not args.mmap_warmup)
     )
+    # Option 2 of data loading using --(train|valid|test)-weighted-split-paths
+    elif args.train_weighted_split_paths:
+        assigned_train_valid_test = []
+        if args.train_weighted_split_paths is not None:
+            train_ds = []
+            assigned_train_valid_test.append("train")
+        if args.valid_weighted_split_paths is not None:
+            valid_ds = []
+            assigned_train_valid_test.append("valid")
+        if args.test_weighted_split_paths is not None:
+            test_ds = []
+            assigned_train_valid_test.append("test")
+
+        for s in assigned_train_valid_test:
+            data_groups = zip(eval(f"args.{s}_weighted_split_paths"),
+                              eval(f"args.{s}_weighted_split_weights"),
+                              eval(f"args.{s}_weighted_split_splits"),
+                              eval(f"args.{s}_weighted_split_names"))
+            for paths, weights, splits, name in data_groups:
+                d = build_dataset_group(
+                    dataset_group_name=name,
+                    paths=paths,
+                    weights=weights,
+                    splits=splits,
+                    data_impl=args.data_impl,
+                    train_valid_test_num_samples=train_val_test_num_samples,
+                    seq_length=args.seq_length + 1,
+                    pad_token=tokenizer.pad,
+                    eos_token=tokenizer.eos,
+                    seed=args.seed,
+                    skip_warmup=(not args.mmap_warmup),
+                    train_valid_test=s
+                )
+                eval(f"{s}_ds").append(d)
     else:
         raise NotImplementedError("No dataloading argument passed")

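The elif branch pairs four parallel argument lists positionally: each dataset group is described by its paths, its per-dataset sampling weights, its document split ranges, and a group name. Below is a minimal standalone sketch of that pairing with hypothetical argument values; `getattr(args, name)` performs the same dynamic lookup as the commit's `eval(f"args.{name}")`.

```python
# Standalone sketch (hypothetical values, not part of the commit) of how the
# --train-weighted-split-* argument lists are zipped into dataset groups.
from types import SimpleNamespace

args = SimpleNamespace(
    train_weighted_split_paths=[["data/p3", "data/xp3"]],  # one group, two datasets
    train_weighted_split_weights=[[0.3, 0.7]],             # sampling weights inside the group
    train_weighted_split_splits=[["0:0.95", "0:0.95"]],    # document range per dataset
    train_weighted_split_names=["train_mix"],              # one name per group
)

s = "train"
data_groups = zip(
    getattr(args, f"{s}_weighted_split_paths"),
    getattr(args, f"{s}_weighted_split_weights"),
    getattr(args, f"{s}_weighted_split_splits"),
    getattr(args, f"{s}_weighted_split_names"),
)
for paths, weights, splits, name in data_groups:
    # build_dataset_group(...) would be called here with these values
    print(name, list(zip(paths, weights, splits)))
    # -> train_mix [('data/p3', 0.3, '0:0.95'), ('data/xp3', 0.7, '0:0.95')]
```

Note also `seq_length=args.seq_length + 1` in the `build_dataset_group` call: packing one extra token presumably lets input and target sequences of length `seq_length` be derived by shifting, as is usual for decoder training.
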
4 changes: 3 additions & 1 deletion megatron/data/decoder_packed_mtf_dataset.py
@@ -4,13 +4,14 @@
 import numpy as np
 import torch

-from megatron import print_rank_0, mpu
+from megatron import print_rank_0, mpu, logging
 from megatron.data.blendable_dataset import BlendableDataset
 from megatron.data.dataset_utils import get_datasets_weights_and_num_samples, get_split_by_range_, \
     get_train_valid_test_split_
 from megatron.data.mtf_dataset import MTFDataset
 from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset

+logger = logging.get_logger(__name__)

 def build_train_valid_test_datasets(
     data_prefix,
@@ -487,6 +488,7 @@ def _build_sample_idx(mtf_dataset, document_ids, seq_length, row_offset, old_sam
             # TODO @thomasw21 handle the case where a single sample cannot fit inside a row. We can
             # - silently skip that value [currently implemented]
             # - truncate to `seq_length`, and keep the right part
+            logger.warning(f"Skipping sample id={document_id}. Maximum sequence length: {seq_length}, sample length: {tok_len}")
             current_sample_start = current_sample_end + 1  # skipping
             row_length = 0
             continue
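The warning added here makes the packing policy of `_build_sample_idx` visible: samples are packed greedily into rows of `seq_length` tokens, and any single sample longer than a row is dropped rather than truncated. A self-contained sketch of that rule (hypothetical helper, not the repository's code):

```python
# Greedy row packing with skip-on-overflow, mirroring the policy the
# warning above describes. pack_rows is a hypothetical illustration.
def pack_rows(sample_lengths, seq_length):
    rows, row, row_length = [], [], 0
    for document_id, tok_len in enumerate(sample_lengths):
        if tok_len > seq_length:
            # A single sample that cannot fit in one row is skipped entirely.
            print(f"Skipping sample id={document_id}. Maximum sequence length: "
                  f"{seq_length}, sample length: {tok_len}")
            continue
        if row_length + tok_len > seq_length:
            rows.append(row)          # current row is full; start a new one
            row, row_length = [], 0
        row.append(document_id)
        row_length += tok_len
    if row:
        rows.append(row)
    return rows

print(pack_rows([3, 5, 9, 4, 4], seq_length=8))  # [[0, 1], [3, 4]] -- sample 2 skipped
```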
