allow for padding_free logic in LM data collator
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
fabianlim committed Aug 30, 2024
1 parent 32918eb commit 5fe8dfd
Showing 3 changed files with 50 additions and 4 deletions.
File 1: aadp_utils.py

@@ -15,6 +15,7 @@
# Standard
from dataclasses import dataclass
import warnings
from types import MethodType

# Third Party
from transformers import DefaultDataCollator, default_data_collator
@@ -57,6 +58,24 @@ def __call__(self, features, return_tensors=None):
ret["labels"] += [-100] + feature["input_ids"][1:]
return default_data_collator([ret], return_tensors)

# from https://github.com/huggingface/trl/pull/1887
def patch_torch_call_remove_padding(collate_fn):
_old_collate_torch_call = collate_fn.torch_call

def _torch_call_with_remove_pad(self, examples):
batch = _old_collate_torch_call(examples)

# logic for removing padding as found in later TRL releases
attn_mask = batch.pop("attention_mask")
batch["input_ids"] = batch["input_ids"][attn_mask.bool()].unsqueeze(0)
batch["position_ids"] = attn_mask.cumsum(1)[attn_mask.bool()].unsqueeze(0) - 1
batch["labels"] = batch["labels"][attn_mask.bool()].unsqueeze(0)
batch["labels"][batch["position_ids"] == 0] = self.ignore_index

return batch

collate_fn.torch_call = MethodType(_torch_call_with_remove_pad, collate_fn)
return collate_fn

def calculate_token_lengths(dataset, num_processes):
return np.array(
File 2: the padding-free plugin

@@ -20,6 +20,7 @@
from fms_acceleration import AccelerationPlugin
from peft import LoraConfig
from transformers import DataCollatorForSeq2Seq, TrainingArguments
from trl import DataCollatorForCompletionOnlyLM
import torch


@@ -65,11 +66,11 @@ def augmentation(
ModelPatcherTrigger,
)

-def _collator_check_seq2seq(collate_fn):
+def _collator_check(collate_fn):
    # the padding-free plugin currently only works with a
    # `DataCollatorForSeq2Seq` or `DataCollatorForCompletionOnlyLM`
    # collate_fn; otherwise the collation can be unreliable
-    return isinstance(collate_fn, DataCollatorForSeq2Seq)
+    return isinstance(collate_fn, (DataCollatorForSeq2Seq, DataCollatorForCompletionOnlyLM))

# This check is done here to only patch the attention forward
# the PR was merged here
@@ -92,12 +93,33 @@ def _collator_check_seq2seq(collate_fn):
# Local
from .aadp_utils import DataCollatorWithFlattening

def _collator_replacement_builder(collate_fn):

    # in this case, replace the seq2seq collator with the flattening collator
    if isinstance(collate_fn, DataCollatorForSeq2Seq):
        return DataCollatorWithFlattening()

    # otherwise it will be a DataCollatorForCompletionOnlyLM
    # - see _collator_check above
    if hasattr(collate_fn, "padding_free"):
        # later TRL releases expose a padding_free flag that turns on
        # the extra logic to support padding free, so just enable it
        collate_fn.padding_free = True
    else:
        # the TRL version is older, so patch the padding-free
        # logic in ourselves
        from .aadp_utils import patch_torch_call_remove_padding

        collate_fn = patch_torch_call_remove_padding(collate_fn)

return collate_fn

# set up the collator
AcceleratorPatcher.replace(
    "flattening-collator",
    AcceleratorPatcherComponent.data_collator,
-    replacement=DataCollatorWithFlattening(),
-    pre_requisite_check=_collator_check_seq2seq,
+    replacement_builder=_collator_replacement_builder,
+    pre_requisite_check=_collator_check,
)
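
# illustration (not part of this commit): what the replacement builder does to
# a DataCollatorForCompletionOnlyLM collate_fn; the response template and
# tokenizer below are placeholders
#
#   from trl import DataCollatorForCompletionOnlyLM
#   collate_fn = DataCollatorForCompletionOnlyLM("### Response:", tokenizer=tokenizer)
#   collate_fn = _collator_replacement_builder(collate_fn)
#   # newer TRL: collate_fn.padding_free is now True
#   # older TRL: collate_fn.torch_call is patched by patch_torch_call_remove_padding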

if _native:
File 3: scripts/benchmarks/scenarios-orca.yaml (5 additions, 0 deletions)

@@ -20,6 +20,11 @@
# allows configuring many different datasets
# - there is an option of setting tokenize to True or False

# NOTE: on tokenization
# if tokenize = True, it is a pretokenization flow, so below set
# - response_template: null
# - dataset_text_field: null
# otherwise, if tokenize = False, do not set the above to null
# (see the example below)
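# example (not part of this commit; values other than the nulls are
# hypothetical):
#   tokenize: True                      # pretokenized flow
#   response_template: null
#   dataset_text_field: null
# versus:
#   tokenize: False                     # raw-text flow
#   response_template: "### Response:"
#   dataset_text_field: output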
data_processing:
dataset_name: microsoft/orca-math-word-problems-200k
chat_template: |
