Allow PaddingFree to work with DataCollatorForCompletionOnlyLM (#78)
* allow for padding_free logic in LM data collator

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* minor fixes to support non-pretok benchmarks

Signed-off-by: 1000850000 user <[email protected]>

* addressed code review

Signed-off-by: 1000850000 user <[email protected]>

* added trl dependency

Signed-off-by: 1000850000 user <[email protected]>

* fixes to installation of aadp

Signed-off-by: 1000850000 user <[email protected]>

* updated orca pf benchmarks

Signed-off-by: 1000850000 user <[email protected]>

---------

Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: 1000850000 user <[email protected]>
Co-authored-by: 1000850000 user <[email protected]>
fabianlim and achew010 committed Sep 6, 2024
1 parent d082809 commit c8459bc
Showing 13 changed files with 223 additions and 84 deletions.
22 changes: 7 additions & 15 deletions plugins/attention-and-distributed-packing/README.md
@@ -18,6 +18,13 @@ Plugin | Description | Depends | Loading | Augmentation | Callbacks
Transformers natively supports padding-free from v4.44.0 ([see here](https://github.com/huggingface/transformers/pull/31629)). The padding-free plugin will use the transformers implementation if compatible;
otherwise, for `transformers < v4.44.0`, the plugin will use an internal implementation instead.

## Native TRL Support for PaddingFree with DataCollatorForCompletionOnlyLM from v0.10.1
Users can use PaddingFree with untokenized data from TRL >= v0.10.1. The flattening of inputs and the addition of `position_ids` to the batch
are carried out inside `DataCollatorForCompletionOnlyLM` when the keyword `padding_free` is passed to the collator. The plugin uses the TRL library if compatible;
otherwise, for `trl < v0.10.1`, the plugin will use an internal implementation instead.

If a user passes in a pretokenized dataset, the plugin will still use `DataCollatorWithFlattening` as the `collate_fn`.
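
For illustration, a minimal sketch of how such a collator might be constructed (assuming `trl >= v0.10.1`; the model name and response template below are placeholders):

```python
from transformers import AutoTokenizer
from trl import DataCollatorForCompletionOnlyLM

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")  # placeholder model

# with trl >= v0.10.1 the collator accepts `padding_free`; it then drops the
# attention mask and returns flattened `input_ids` together with `position_ids`
collator = DataCollatorForCompletionOnlyLM(
    response_template="### Response:",  # placeholder template
    tokenizer=tokenizer,
    padding_free=True,
)

# when the padding-free plugin is active, its collator replacement detects a
# `DataCollatorForCompletionOnlyLM` passed to `SFTTrainer` and sets
# `padding_free = True` on it automatically
```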

## Running Benchmarks

To reproduce the benchmarks, simply run the following commands,
@@ -30,21 +37,6 @@ Reproduce [MultiPack on A100 80GB](scripts/benchmarks/refs_orca/a100_80gb_mp.csv

## Known Issues

### Currently Only Supports Pre-Tokenized Dataset

The padding-free plugin currently only works with pre-tokenized datasets, this is because it is currently designed to replace
the data collator from `SFTTrainer` with a custom data collator to manipulate the input to the modified flash attention forward.

There are some cases, the data collator for SFTTrainer will handle the formatting and tokenization from raw text datasets. The plugin
is currently unable to both handle the original data collation and apply its custom data collator over it as the same time. This issue
will be addressed in a future commit to support this case.

In the meantime, the plugin expects the user to provide a pretokenized dataset that
- is formatted with a template for instruct-tuning cases
- is tokenized
- has template labels that are masked to exclude from loss computation
- has eos token appended

### Currently Only Supports Multipack with Padding-Free

The multipack plugin currently also requires the padding-free plugin to work.
2 changes: 1 addition & 1 deletion plugins/attention-and-distributed-packing/pyproject.toml
@@ -23,7 +23,7 @@ classifiers=[
"Programming Language :: Python :: 3.11",
]

dependencies = ["numba"]
dependencies = ["numba", "trl"]

[tool.hatch.build.targets.wheel]
only-include = ["src/fms_acceleration_aadp"]
@@ -14,6 +14,7 @@

# Standard
from dataclasses import dataclass
from types import MethodType
import warnings

# Third Party
@@ -58,6 +59,26 @@ def __call__(self, features, return_tensors=None):
return default_data_collator([ret], return_tensors)


# from https://github.com/huggingface/trl/pull/1887
def patch_torch_call_remove_padding(collate_fn):
_old_collate_torch_call = collate_fn.torch_call

def _torch_call_with_remove_pad(self, examples):
batch = _old_collate_torch_call(examples)

# logic for removing padding as found in later TRL releases
attn_mask = batch.pop("attention_mask")
batch["input_ids"] = batch["input_ids"][attn_mask.bool()].unsqueeze(0)
batch["position_ids"] = attn_mask.cumsum(1)[attn_mask.bool()].unsqueeze(0) - 1
batch["labels"] = batch["labels"][attn_mask.bool()].unsqueeze(0)
batch["labels"][batch["position_ids"] == 0] = self.ignore_index

return batch

collate_fn.torch_call = MethodType(_torch_call_with_remove_pad, collate_fn)
return collate_fn
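
# Illustrative example of the patched torch_call (values assumed): for a padded batch
#   input_ids      = [[1, 2, 3, 0], [4, 5, 0, 0]]
#   attention_mask = [[1, 1, 1, 0], [1, 1, 0, 0]]
# the patch drops the attention mask and returns the flattened
#   input_ids    = [[1, 2, 3, 4, 5]]
#   position_ids = [[0, 1, 2, 0, 1]]
# and sets the label at each position_ids == 0 boundary to ignore_index, so the
# loss does not leak across packed sequences.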


def calculate_token_lengths(dataset, num_processes):
return np.array(
dataset.map(
@@ -20,6 +20,7 @@
from fms_acceleration import AccelerationPlugin
from peft import LoraConfig
from transformers import DataCollatorForSeq2Seq, TrainingArguments
from trl import DataCollatorForCompletionOnlyLM # pylint: disable=import-error
import torch


@@ -65,11 +66,13 @@ def augmentation(
ModelPatcherTrigger,
)

def _collator_check_seq2seq(collate_fn):
def _collator_check(collate_fn):
# "The padding-free plugin currently only works with a
# `DataCollatorForSeq2Seq` collate_fn,
# otherwise the collation can be unreliable"
return isinstance(collate_fn, DataCollatorForSeq2Seq)
return isinstance(
collate_fn, (DataCollatorForSeq2Seq, DataCollatorForCompletionOnlyLM)
)

# This check is done here to only patch the attention forward
# the PR was merged here
@@ -92,12 +95,35 @@ def _collator_check_seq2seq(collate_fn):
# Local
from .aadp_utils import DataCollatorWithFlattening

def _collator_replacement_builder(collate_fn):

# in this case, replace seq2seq with flattening collator
if isinstance(collate_fn, DataCollatorForSeq2Seq):
return DataCollatorWithFlattening()

# otherwise it will be DataCollatorForCompletionOnlyLM
# - see _collator_check above
if hasattr(collate_fn, "padding_free"):
# in the later TRL releases there is a padding_free flag
# that turns on extra logic to support padding free. Just
# turn it on
collate_fn.padding_free = True
else:
# otherwise trl version is old, and we need to patch
# in padding free logic
# Local
from .aadp_utils import patch_torch_call_remove_padding

collate_fn = patch_torch_call_remove_padding(collate_fn)

return collate_fn

# setup the collator
AcceleratorPatcher.replace(
"flattening-collator",
AcceleratorPatcherComponent.data_collator,
replacement=DataCollatorWithFlattening(),
pre_requisite_check=_collator_check_seq2seq,
replacement_builder=_collator_replacement_builder,
pre_requisite_check=_collator_check,
)

if _native:
1 change: 0 additions & 1 deletion plugins/attention-and-distributed-packing/tox.ini
@@ -5,7 +5,6 @@ envlist = py, lint
deps =
pytest>=7
-e {toxinidir}
skip_install = true
commands =

# install the dependencies here to ensure
14 changes: 8 additions & 6 deletions plugins/framework/src/fms_acceleration/accelerator_patcher.py
@@ -243,13 +243,15 @@ def prepare(self, *args, device_placement=None):
# - then we replace
collator_replacement_rule.pre_req_check(dataloader.collate_fn)

# FIXME: for now we just disable the replacement_builder
assert (
collator_replacement_rule.replacement_builder is None
), "Currently, replacement_builder not allowed for data collator"

# Replace the collate_fn in dataloader
dataloader.collate_fn = collator_replacement_rule.replacement
if collator_replacement_rule.replacement is not None:
dataloader.collate_fn = collator_replacement_rule.replacement
else:
dataloader.collate_fn = (
collator_replacement_rule.replacement_builder(
dataloader.collate_fn
)
)

# - special behavior for dataloader replacements
# - need to know if we run the original prepare
3 changes: 3 additions & 0 deletions plugins/framework/src/fms_acceleration/cli.py
@@ -42,6 +42,9 @@ def install_plugin(
assert len(pkg_name) == 1, "Please specify exactly one plugin to install"
pkg_name = pkg_name[0]

# if toxinidir is specified in path, replace with cwd
pkg_name = pkg_name.format(toxinidir=os.getcwd())
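# e.g. a hypothetical "{toxinidir}/plugins/attention-and-distributed-packing"
# becomes "<current working dir>/plugins/attention-and-distributed-packing"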

# take the flags
args = [x for x in args if x.startswith("-")]

6 changes: 4 additions & 2 deletions scripts/benchmarks/benchmark.py
@@ -206,9 +206,11 @@ def prepare_dataset(
)
response_template = self.response_template

if self.kwargs["tokenize"]:
if (
self.kwargs['tokenize']
or (not self.kwargs['tokenize'] and self.kwargs['chat_template'])
):
tokenizer = AutoTokenizer.from_pretrained(model_name)

# for now, if pad_token_id is None, will just do a replacement
if tokenizer.pad_token_id is None:
tokenizer.pad_token_id = tokenizer.eos_token_id
6 changes: 4 additions & 2 deletions scripts/benchmarks/data_processing.py
@@ -64,6 +64,10 @@ def _build_data_formatting_func(
"response_field must be specified if tokenize=True and response_template=None."

def _format(example):
# `nonlocal` is used because the format_fn will be passed to dataset.map and
# `loss_masking` needs to be bound via `nonlocal`, otherwise the spawned
# processes will have no reference to it
nonlocal loss_masking
formatted_and_maybe_tokenized = tokenizer.apply_chat_template(
[example], tokenize=tokenize
)
@@ -84,8 +88,6 @@ def _format(example):
'labels': [ ignore_index ] * len(formatted_and_maybe_tokenized) + response
}

loss_masking = instruction_mask_loss(tokenizer, response_template)

if not loss_masking:
return {key: formatted_and_maybe_tokenized}
return loss_masking(formatted_and_maybe_tokenized)