Commit 0dc0562

minor fixes to support non-pretok benchmarks
Signed-off-by: 1000850000 user <[email protected]>
achew010 committed Aug 30, 2024
1 parent 5fe8dfd commit 0dc0562
Showing 6 changed files with 34 additions and 28 deletions.
@@ -14,8 +14,8 @@
 
 # Standard
 from dataclasses import dataclass
-import warnings
 from types import MethodType
+import warnings
 
 # Third Party
 from transformers import DefaultDataCollator, default_data_collator
@@ -58,6 +58,7 @@ def __call__(self, features, return_tensors=None):
             ret["labels"] += [-100] + feature["input_ids"][1:]
         return default_data_collator([ret], return_tensors)
 
+
 # from https://github.com/huggingface/trl/pull/1887
 def patch_torch_call_remove_padding(collate_fn):
     _old_collate_torch_call = collate_fn.torch_call
@@ -77,6 +78,7 @@ def _torch_call_with_remove_pad(self, examples):
     collate_fn.torch_call = MethodType(_torch_call_with_remove_pad, collate_fn)
     return collate_fn
 
+
 def calculate_token_lengths(dataset, num_processes):
     return np.array(
         dataset.map(
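
For readers unfamiliar with the pattern in patch_torch_call_remove_padding: it saves the collator's bound torch_call, wraps it, and rebinds the wrapper onto the same instance with MethodType. A minimal self-contained sketch; the toy collator and the pad-stripping step are illustrative assumptions, only the save-wrap-rebind structure comes from the diff:

from types import MethodType


class ToyCollator:
    # stand-in for a TRL-style collator; pad_token_id and torch_call are assumed names
    pad_token_id = 0

    def torch_call(self, examples):
        # pretend "collation": just wrap the raw rows in a batch dict
        return {"input_ids": examples}


def patch_torch_call_remove_padding_sketch(collate_fn):
    _old_collate_torch_call = collate_fn.torch_call  # keep the original bound method

    def _torch_call_with_remove_pad(self, examples):
        batch = _old_collate_torch_call(examples)  # run the original collation first
        # illustrative post-step: drop pad tokens from every row
        batch["input_ids"] = [
            [tok for tok in row if tok != self.pad_token_id]
            for row in batch["input_ids"]
        ]
        return batch

    # rebind so the wrapper becomes this instance's torch_call
    collate_fn.torch_call = MethodType(_torch_call_with_remove_pad, collate_fn)
    return collate_fn


collator = patch_torch_call_remove_padding_sketch(ToyCollator())
print(collator.torch_call([[1, 5, 0, 0], [2, 3, 4, 0]]))
# -> {'input_ids': [[1, 5], [2, 3, 4]]}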
@@ -20,7 +20,7 @@
 from fms_acceleration import AccelerationPlugin
 from peft import LoraConfig
 from transformers import DataCollatorForSeq2Seq, TrainingArguments
-from trl import DataCollatorForCompletionOnlyLM
+from trl import DataCollatorForCompletionOnlyLM  # pylint: disable=import-error
 import torch
 
 
@@ -70,7 +70,9 @@ def _collator_check(collate_fn):
             # "The padding-free plugin currently only works with a
             # `DataCollatorForSeq2Seq` collate_fn,
             # otherwise the collation can be unreliable"
-            return isinstance(collate_fn, (DataCollatorForSeq2Seq, DataCollatorForCompletionOnlyLM))
+            return isinstance(
+                collate_fn, (DataCollatorForSeq2Seq, DataCollatorForCompletionOnlyLM)
+            )
 
         # This check is done here to only patch the attention forward
         # the PR was merged here
@@ -98,18 +100,20 @@ def _collator_replacement_builder(collate_fn):
             # in this case, replace seq2seq with flattening collator
             if isinstance(collate_fn, DataCollatorForSeq2Seq):
                 return DataCollatorWithFlattening()
+
             # otherwise it will be DataCollatorForCompletionOnlyLM
             # - see _collator_check above
-            if hasattr(collate_fn, 'padding_free'):
+            if hasattr(collate_fn, "padding_free"):
                 # in the later TRL releases there is a padding_free flag
-                # that turns on extra logic to support padding free. Just 
+                # that turns on extra logic to support padding free. Just
                 # turn it on
                 collate_fn.padding_free = True
             else:
-                # otherwise trl version is old, and we need to patch 
+                # otherwise trl version is old, and we need to patch
                 # in padding free logic
                 # Local
                 from .aadp_utils import patch_torch_call_remove_padding
+
                 collate_fn = patch_torch_call_remove_padding(collate_fn)
 
             return collate_fn
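
The replacement builder above has two arms: a DataCollatorForSeq2Seq is swapped for DataCollatorWithFlattening outright, while a DataCollatorForCompletionOnlyLM is kept and made padding-free. What "flattening" buys is visible in the first file's __call__: examples are concatenated into a single row, with -100 masking the first token of each example so no loss crosses sequence boundaries. A pure-Python sketch of that behavior, assumed from the __call__ lines shown above (the input_ids accumulation is inferred, only the labels line appears in the diff):

def flatten(features):
    # mirrors the __call__ body in the first diff, minus tensor conversion
    ret = {"input_ids": [], "labels": []}
    for feature in features:
        ret["input_ids"] += feature["input_ids"]
        # -100 at each boundary: no loss across concatenated examples
        ret["labels"] += [-100] + feature["input_ids"][1:]
    return ret


print(flatten([{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}]))
# -> {'input_ids': [1, 2, 3, 4, 5], 'labels': [-100, 2, 3, -100, 5]}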
plugins/framework/src/fms_acceleration/accelerator_patcher.py (12 changes: 6 additions & 6 deletions)
@@ -243,13 +243,13 @@ def prepare(self, *args, device_placement=None):
                 # - then we replace
                 collator_replacement_rule.pre_req_check(dataloader.collate_fn)
 
-                # FIXME: for now we just disable the replacement_builder
-                assert (
-                    collator_replacement_rule.replacement_builder is None
-                ), "Currently, replacement_builder not allowed for data collator"
-
                 # Replace the collate_fn in dataloader
-                dataloader.collate_fn = collator_replacement_rule.replacement
+                if collator_replacement_rule.replacement is not None:
+                    dataloader.collate_fn = collator_replacement_rule.replacement
+                else:
+                    dataloader.collate_fn = collator_replacement_rule.replacement_builder(
+                        dataloader.collate_fn
+                    )
 
                 # - special behavior for dataloader replacements
                 # - need to know if we run the original prepare
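
This change removes the hard assert and lets a rule supply either a ready-made replacement object or a replacement_builder that derives the new collate_fn from the old one, which the padding-free plugin needs since it mutates or wraps the incoming collator. A hedged sketch of the dispatch; CollatorReplacementRule is a hypothetical stand-in for the real rule class:

from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass
class CollatorReplacementRule:
    # hypothetical stand-in for the rule object consumed by prepare()
    pre_req_check: Callable[[Any], None]
    replacement: Optional[Any] = None
    replacement_builder: Optional[Callable[[Any], Any]] = None


def swap_collator(rule: CollatorReplacementRule, collate_fn: Any) -> Any:
    rule.pre_req_check(collate_fn)
    if rule.replacement is not None:
        return rule.replacement  # fixed object: drop-in swap
    # builder path: the new collator is computed from the old one
    return rule.replacement_builder(collate_fn)


# e.g. a builder that keeps the existing collator instead of replacing it
rule = CollatorReplacementRule(
    pre_req_check=lambda fn: None,
    replacement_builder=lambda fn: fn,  # identity builder, for illustration
)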
scripts/benchmarks/benchmark.py (8 changes: 4 additions & 4 deletions)
@@ -206,9 +206,10 @@ def prepare_dataset(
         )
         response_template = self.response_template
 
-        if self.kwargs["tokenize"]:
-            tokenizer = AutoTokenizer.from_pretrained(model_name)
+        # pass in the tokenizer to use apply_chat_template
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        if self.kwargs["tokenize"]:
             # for now, if pad_token_id is None, will just do a replacement
             if tokenizer.pad_token_id is None:
                 tokenizer.pad_token_id = tokenizer.eos_token_id
@@ -218,7 +219,6 @@
                 re.sub(r"[/-]", "_", model_name),
             )
         else:
-            tokenizer = None
             save_path = DATA_JSON_NAME.format("all")
 
         # get the full path
@@ -238,7 +238,7 @@
 
         print(f"Preparing dataset '{save_path}'")
 
-        # call the map
+        # call the mapk
         ds = self.dataset_split.map(format_fn, **kwargs)
 
         # save it
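
The tokenizer load is hoisted out of the tokenize branch because apply_chat_template is needed in both modes: it can emit token ids, or, with tokenize=False, the formatted string that the non-pretokenized path saves. A minimal sketch; the model name and template here are illustrative, and a chat_template is set explicitly since base models may not ship one:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
tokenizer.chat_template = (
    "{%- for message in messages %}"
    "{{ message['question'] }}\n\nASSISTANT:\n{{ message['answer'] }}"
    "{%- endfor %}"
)

example = {"question": "What is 2 + 2?", "answer": "4"}
text = tokenizer.apply_chat_template([example], tokenize=False)  # formatted string
ids = tokenizer.apply_chat_template([example], tokenize=True)    # token ids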
scripts/benchmarks/data_processing.py (1 change: 1 addition & 0 deletions)
@@ -64,6 +64,7 @@ def _build_data_formatting_func(
         "response_field must be specified if tokenize=True and response_template=None."
 
     def _format(example):
+        nonlocal loss_masking  # reference to variable in _build_data_formatting_func
         formatted_and_maybe_tokenized = tokenizer.apply_chat_template(
             [example], tokenize=tokenize
         )
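
The single added line matters because of Python scoping: if _format assigns to loss_masking anywhere in its body, Python treats the name as local to _format unless it is declared nonlocal, and a read before that assignment raises UnboundLocalError. A stripped-down illustration; the flag update shown is assumed, not from the diff:

def build_formatter():
    loss_masking = True  # lives in the enclosing function's scope

    def _format(example):
        nonlocal loss_masking  # rebind the enclosing variable, not a new local
        if "response" not in example:
            loss_masking = False  # persists across later calls to _format
        return example, loss_masking

    return _format


fmt = build_formatter()
print(fmt({"prompt": "hi"}))  # -> ({'prompt': 'hi'}, False)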
scripts/benchmarks/scenarios-orca.yaml (21 changes: 10 additions & 11 deletions)
@@ -35,9 +35,8 @@ data_processing:
     ASSISTANT:
     {{ message['answer'] }}
     {%- endfor %}
-  dataset_split: "train[:2000]"
-  tokenize: True
-  response_template: "\n\nASSISTANT:"
+  dataset_split: "train[:8000]"
+  tokenize: false
 
 # scenarios
 scenarios:
@@ -50,8 +49,8 @@ scenarios:
       packing: False
     model_name_or_path:
       - 'mistralai/Mistral-7B-v0.1'
-    response_template: null
-    dataset_text_field: null
+    dataset_text_field: 'output'
+    response_template: "\n\nASSISTANT:"
 
   - name: padding-free
     framework_config:
@@ -65,8 +64,8 @@ scenarios:
       packing: False
     model_name_or_path:
       - 'mistralai/Mistral-7B-v0.1'
-    response_template: null
-    dataset_text_field: null
+    dataset_text_field: 'output'
+    response_template: "\n\nASSISTANT:"
 
   - name: accelerated-peft-bnb
     framework_config:
@@ -88,8 +87,8 @@ scenarios:
       packing: False
     model_name_or_path:
      - 'mistralai/Mistral-7B-v0.1'
-    response_template: null
-    dataset_text_field: null
+    dataset_text_field: 'output'
+    response_template: "\n\nASSISTANT:"
 
   - name: accelerated-peft-gptq
     framework_config:
@@ -111,5 +110,5 @@ scenarios:
       packing: False
     model_name_or_path:
       - 'TheBloke/Mistral-7B-v0.1-GPTQ'
-    response_template: null
-    dataset_text_field: null
+    dataset_text_field: 'output'
+    response_template: "\n\nASSISTANT:"
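
The yaml now points each scenario at raw text instead of pretokenized ids: tokenize: false leaves formatting to the trainer, dataset_text_field names the text column, and response_template tells the completion-only collator where the answer begins. A hedged sketch of how those keys plausibly reach TRL; the wiring shown is illustrative, only the key values come from the yaml:

from transformers import AutoTokenizer
from trl import DataCollatorForCompletionOnlyLM

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
collator = DataCollatorForCompletionOnlyLM(
    response_template="\n\nASSISTANT:",  # value from the scenarios above
    tokenizer=tokenizer,
)
# tokens up to and including the template are labeled -100, so the loss is
# computed only on the assistant answer that follows it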
