Allow PaddingFree to work with DataCollatorForCompletionOnlyLM #78

Merged · 6 commits · Sep 5, 2024

Changes from 2 commits
22 changes: 7 additions & 15 deletions plugins/attention-and-distributed-packing/README.md
@@ -18,6 +18,13 @@ Plugin | Description | Depends | Loading | Augmentation | Callbacks
Transformers natively supports padding-free from v4.44.0 [see here](https://github.com/huggingface/transformers/pull/31629). The padding-free plugin will use the transformers implementation if compatible;
otherwise, if `transformers < v4.44.0`, the plugin will fall back to an internal implementation.
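
For a sense of the native flow, here is a minimal sketch (assuming `transformers >= v4.44.0`; the model name and `tokenized_dataset` are illustrative placeholders):

```python
# sketch: padding-free batching with the native transformers collator
from transformers import AutoModelForCausalLM, DataCollatorWithFlattening, Trainer

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",
    attn_implementation="flash_attention_2",  # padding-free relies on flash-attn kernels
)

trainer = Trainer(
    model=model,
    train_dataset=tokenized_dataset,  # placeholder: any pretokenized dataset
    data_collator=DataCollatorWithFlattening(),  # flattens the batch and adds position_ids
)
```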

## Native TRL Support for PaddingFree with DataCollatorForCompletionOnlyLM from v0.10.1
Users can use PaddingFree with untokenized data from TRL >= v0.10.1. The flattening of inputs and the addition of `position_ids` to the batch
are carried out inside `DataCollatorForCompletionOnlyLM` when the keyword `padding_free` is passed to the collator. The plugin uses the TRL implementation if compatible;
otherwise, if `trl < v0.10.1`, the plugin will fall back to an internal implementation.

If a user passes in a pretokenized dataset, the plugin will still use `DataCollatorWithFlattening` as the `collate_fn`.
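
A minimal sketch of this flow (assuming `trl >= v0.10.1`; the model name and response template are illustrative):

```python
# sketch: enabling padding-free collation on untokenized data
from transformers import AutoTokenizer
from trl import DataCollatorForCompletionOnlyLM

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")

collator = DataCollatorForCompletionOnlyLM(
    response_template="\n\nASSISTANT:",  # tokens before the template are masked from loss
    tokenizer=tokenizer,
    padding_free=True,  # flatten inputs and add position_ids to the batch
)
```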

## Running Benchmarks

To reproduce the benchmarks, simply run the following commands,
@@ -30,21 +37,6 @@ Reproduce [MultiPack on A100 80GB](scripts/benchmarks/refs_orca/a100_80gb_mp.csv)

## Known Issues

### Currently Only Supports Pre-Tokenized Dataset

The padding-free plugin currently only works with pre-tokenized datasets; this is because it is designed to replace
the data collator from `SFTTrainer` with a custom data collator that manipulates the input to the modified flash attention forward.

In some cases, the data collator for `SFTTrainer` will handle the formatting and tokenization of raw text datasets. The plugin
is currently unable to both preserve the original data collation and apply its custom data collator at the same time. This issue
will be addressed in a future commit.

In the meantime, the plugin expects the user to provide a pretokenized dataset that
- is formatted with a template for instruct-tuning cases
- is tokenized
- has template labels that are masked to exclude them from loss computation
- has the EOS token appended

### Currently Only Supports Multipack with Padding-Free

The multipack plugin currently also requires the padding-free plugin to work.
@@ -14,6 +14,7 @@

# Standard
from dataclasses import dataclass
from types import MethodType
import warnings

# Third Party
@@ -58,6 +59,26 @@ def __call__(self, features, return_tensors=None):
        return default_data_collator([ret], return_tensors)


# from https://github.com/huggingface/trl/pull/1887
def patch_torch_call_remove_padding(collate_fn):
    _old_collate_torch_call = collate_fn.torch_call

    def _torch_call_with_remove_pad(self, examples):
        batch = _old_collate_torch_call(examples)

        # logic for removing padding as found in later TRL releases
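        # - all non-padded tokens are concatenated into a single row (batch size 1)
        # - position_ids restart at 0 for each example, which the padding-free
        #   attention forward uses to recover the sequence boundaries
        # - the first label of each example is set to ignore_index so loss is
        #   not computed across example boundaries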
        attn_mask = batch.pop("attention_mask")
        batch["input_ids"] = batch["input_ids"][attn_mask.bool()].unsqueeze(0)
        batch["position_ids"] = attn_mask.cumsum(1)[attn_mask.bool()].unsqueeze(0) - 1
        batch["labels"] = batch["labels"][attn_mask.bool()].unsqueeze(0)
        batch["labels"][batch["position_ids"] == 0] = self.ignore_index

        return batch

    collate_fn.torch_call = MethodType(_torch_call_with_remove_pad, collate_fn)
    return collate_fn
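
# usage sketch (illustrative): given a DataCollatorForCompletionOnlyLM instance
# from an older TRL release,
#   collator = patch_torch_call_remove_padding(collator)
# after which collator.torch_call(examples) returns flattened, padding-free
# batches (no attention_mask; position_ids mark the sequence boundaries)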


def calculate_token_lengths(dataset, num_processes):
    return np.array(
        dataset.map(
@@ -20,6 +20,7 @@
from fms_acceleration import AccelerationPlugin
from peft import LoraConfig
from transformers import DataCollatorForSeq2Seq, TrainingArguments
from trl import DataCollatorForCompletionOnlyLM # pylint: disable=import-error
import torch


@@ -65,11 +66,13 @@ def augmentation(
            ModelPatcherTrigger,
        )

        def _collator_check_seq2seq(collate_fn):
        def _collator_check(collate_fn):
            # "The padding-free plugin currently only works with a
            # `DataCollatorForSeq2Seq` or `DataCollatorForCompletionOnlyLM`
            # collate_fn, otherwise the collation can be unreliable"
            return isinstance(collate_fn, DataCollatorForSeq2Seq)
            return isinstance(
                collate_fn, (DataCollatorForSeq2Seq, DataCollatorForCompletionOnlyLM)
            )

        # This check is done here to only patch the attention forward
        # the PR was merged here
@@ -92,12 +95,35 @@ def _collator_check_seq2seq(collate_fn):
        # Local
        from .aadp_utils import DataCollatorWithFlattening

        def _collator_replacement_builder(collate_fn):

            # in this case, replace seq2seq with flattening collator
            if isinstance(collate_fn, DataCollatorForSeq2Seq):
                return DataCollatorWithFlattening()

            # otherwise it will be DataCollatorForCompletionOnlyLM
            # - see _collator_check above
            if hasattr(collate_fn, "padding_free"):
                # in later TRL releases there is a padding_free flag
                # that turns on extra logic to support padding free. Just
                # turn it on
                collate_fn.padding_free = True
            else:
                # otherwise the trl version is old, and we need to patch
                # in the padding-free logic
                # Local
                from .aadp_utils import patch_torch_call_remove_padding

                collate_fn = patch_torch_call_remove_padding(collate_fn)

            return collate_fn

        # setup the collator
        AcceleratorPatcher.replace(
            "flattening-collator",
            AcceleratorPatcherComponent.data_collator,
            replacement=DataCollatorWithFlattening(),
            pre_requisite_check=_collator_check_seq2seq,
            replacement_builder=_collator_replacement_builder,
            pre_requisite_check=_collator_check,
        )

        if _native:
12 changes: 6 additions & 6 deletions plugins/framework/src/fms_acceleration/accelerator_patcher.py
@@ -243,13 +243,13 @@ def prepare(self, *args, device_placement=None):
            # - then we replace
            collator_replacement_rule.pre_req_check(dataloader.collate_fn)

            # FIXME: for now we just disable the replacement_builder
            assert (
                collator_replacement_rule.replacement_builder is None
            ), "Currently, replacement_builder not allowed for data collator"

            # Replace the collate_fn in dataloader
            dataloader.collate_fn = collator_replacement_rule.replacement
            if collator_replacement_rule.replacement is not None:
                dataloader.collate_fn = collator_replacement_rule.replacement
            else:
                dataloader.collate_fn = collator_replacement_rule.replacement_builder(
                    dataloader.collate_fn
                )

            # - special behavior for dataloader replacements
            # - need to know if we run the original prepare
6 changes: 3 additions & 3 deletions scripts/benchmarks/benchmark.py
@@ -206,9 +206,10 @@ def prepare_dataset(
        )
        response_template = self.response_template

        if self.kwargs["tokenize"]:
            tokenizer = AutoTokenizer.from_pretrained(model_name)
        # pass in tokenizer to use apply_chat_template
        tokenizer = AutoTokenizer.from_pretrained(model_name)

        if self.kwargs["tokenize"]:
            # for now, if pad_token_id is None, will just do a replacement
            if tokenizer.pad_token_id is None:
                tokenizer.pad_token_id = tokenizer.eos_token_id
@@ -218,7 +219,6 @@
                re.sub(r"[/-]", "_", model_name),
            )
        else:
            tokenizer = None
            save_path = DATA_JSON_NAME.format("all")

        # get the full path
1 change: 1 addition & 0 deletions scripts/benchmarks/data_processing.py
@@ -64,6 +64,7 @@ def _build_data_formatting_func(
"response_field must be specified if tokenize=True and response_template=None."

def _format(example):
nonlocal loss_masking # reference to variable in _build_data_formatting_func
        formatted_and_maybe_tokenized = tokenizer.apply_chat_template(
            [example], tokenize=tokenize
        )
26 changes: 15 additions & 11 deletions scripts/benchmarks/scenarios-orca.yaml
@@ -20,6 +20,11 @@
# allows to configure many different datasets
# - there is an option of setting tokenize is True or False

# NOTE: on tokenization
# if tokenize = True then it is a pretokenization flow; in that case, in each
# scenario below set
# - response_template: null
# - dataset_text_field: null
# otherwise if tokenize = False, do not set the above to null
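#
# e.g., the previous pretokenization flow in this file used
#   tokenize: True
#   response_template: "\n\nASSISTANT:"  # set here so template labels can be masked
# together with response_template: null and dataset_text_field: null in each
# scenario below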
data_processing:
  dataset_name: microsoft/orca-math-word-problems-200k
  chat_template: |
@@ -30,9 +35,8 @@ data_processing:
    ASSISTANT:
    {{ message['answer'] }}
    {%- endfor %}
  dataset_split: "train[:2000]"
  tokenize: True
  response_template: "\n\nASSISTANT:"
  dataset_split: "train[:8000]"
  tokenize: false

# scenarios
scenarios:
@@ -45,8 +49,8 @@ scenarios:
      packing: False
      model_name_or_path:
        - 'mistralai/Mistral-7B-v0.1'
      response_template: null
      dataset_text_field: null
      dataset_text_field: 'output'
      response_template: "\n\nASSISTANT:"

  - name: padding-free
    framework_config:
@@ -60,8 +64,8 @@
      packing: False
      model_name_or_path:
        - 'mistralai/Mistral-7B-v0.1'
      response_template: null
      dataset_text_field: null
      dataset_text_field: 'output'
      response_template: "\n\nASSISTANT:"

  - name: accelerated-peft-bnb
    framework_config:
@@ -83,8 +87,8 @@
      packing: False
      model_name_or_path:
        - 'mistralai/Mistral-7B-v0.1'
      response_template: null
      dataset_text_field: null
      dataset_text_field: 'output'
      response_template: "\n\nASSISTANT:"

  - name: accelerated-peft-gptq
    framework_config:
@@ -106,5 +110,5 @@
      packing: False
      model_name_or_path:
        - 'TheBloke/Mistral-7B-v0.1-GPTQ'
      response_template: null
      dataset_text_field: null
      dataset_text_field: 'output'
      response_template: "\n\nASSISTANT:"