From c8459bc95ea9e964203cb5860f14501443ff409f Mon Sep 17 00:00:00 2001
From: Yu Chin Fabian Lim
Date: Thu, 5 Sep 2024 13:52:53 +0800
Subject: [PATCH] Allow PaddingFree to work with
 DataCollatorForCompletionOnlyLM (#78)

* allow for padding_free logic in LM data collator

Signed-off-by: Yu Chin Fabian Lim

* minor fixes to support non-pretok benchmarks

Signed-off-by: 1000850000 user

* addressed code review

Signed-off-by: 1000850000 user

* added trl dependency

Signed-off-by: 1000850000 user

* fixes to installation of aadp

Signed-off-by: 1000850000 user

* updated orca pf benchmarks

Signed-off-by: 1000850000 user

---------

Signed-off-by: Yu Chin Fabian Lim
Signed-off-by: 1000850000 user
Co-authored-by: 1000850000 user
---
 .../README.md                                 | 22 ++---
 .../pyproject.toml                            |  2 +-
 .../src/fms_acceleration_aadp/aadp_utils.py   | 21 +++++
 .../framework_plugin_padding_free.py          | 34 ++++++-
 .../attention-and-distributed-packing/tox.ini |  1 -
 .../fms_acceleration/accelerator_patcher.py   | 14 +--
 plugins/framework/src/fms_acceleration/cli.py |  3 +
 scripts/benchmarks/benchmark.py               |  6 +-
 scripts/benchmarks/data_processing.py         |  6 +-
 scripts/benchmarks/refs_orca/a100_80gb_pf.csv | 82 ++++++++---------
 .../benchmarks/refs_orca/requirements_pf.txt  | 88 +++++++++++++++++++
 scripts/benchmarks/scenarios-orca.yaml        | 26 +++---
 tox.ini                                       |  2 +-
 13 files changed, 223 insertions(+), 84 deletions(-)
 create mode 100644 scripts/benchmarks/refs_orca/requirements_pf.txt

diff --git a/plugins/attention-and-distributed-packing/README.md b/plugins/attention-and-distributed-packing/README.md
index 1da9b612..a246df79 100644
--- a/plugins/attention-and-distributed-packing/README.md
+++ b/plugins/attention-and-distributed-packing/README.md
@@ -18,6 +18,13 @@ Plugin | Description | Depends | Loading | Augmentation | Callbacks
 Transformers natively supports padding-free from v4.44.0 [see here](https://github.com/huggingface/transformers/pull/31629). The padding-free plugin will use the transformers library if compatible,
 otherwise if `transformers < v4.44.0` the plugin will use an internal implementation instead.
 
+## Native TRL Support for PaddingFree with DataCollatorForCompletionOnlyLM from v0.10.1
+From TRL >= v0.10.1, users can use PaddingFree with untokenized data. The flattening of inputs and the addition of `position_ids` to the batch
+are carried out inside `DataCollatorForCompletionOnlyLM` when the keyword `padding_free` is passed to the collator. The plugin uses the TRL library if compatible,
+otherwise if `trl < v0.10.1` the plugin will use an internal implementation instead.
+
+If a user passes in a pretokenized dataset, the plugin will still use `DataCollatorWithFlattening` as the `collate_fn`.
+
 ## Running Benchmarks
 
 To reproduce the benchmarks, simply run the following commands,
@@ -30,21 +37,6 @@ Reproduce [MultiPack on A100 80GB](scripts/benchmarks/refs_orca/a100_80gb_mp.csv
 
 ## Known Issues
 
-### Currently Only Supports Pre-Tokenized Dataset
-
-The padding-free plugin currently only works with pre-tokenized datasets, this is because it is currently designed to replace
-the data collator from `SFTTrainer` with a custom data collator to manipulate the input to the modified flash attention forward.
-
-There are some cases, the data collator for SFTTrainer will handle the formatting and tokenization from raw text datasets. The plugin
-is currently unable to both handle the original data collation and apply its custom data collator over it as the same time. This issue
-will be addressed in a future commit to support this case. 
-
-In the meantime, the plugin expects the user to provide a pretokenized dataset that
-- is formatted with a template for instruct-tuning cases
-- is tokenized
-- has template labels that are masked to exclude from loss computation
-- has eos token appended
-
 ### Currenly Only Supports Multipack with Padding-Free
 
 The multipack plugin currently also requires the padding-free plugin to work.
diff --git a/plugins/attention-and-distributed-packing/pyproject.toml b/plugins/attention-and-distributed-packing/pyproject.toml
index 3675fa26..166625f4 100644
--- a/plugins/attention-and-distributed-packing/pyproject.toml
+++ b/plugins/attention-and-distributed-packing/pyproject.toml
@@ -23,7 +23,7 @@ classifiers=[
   "Programming Language :: Python :: 3.11",
 ]
 
-dependencies = ["numba"]
+dependencies = ["numba", "trl"]
 
 [tool.hatch.build.targets.wheel]
 only-include = ["src/fms_acceleration_aadp"]
diff --git a/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/aadp_utils.py b/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/aadp_utils.py
index 10d6e93a..6ccaf88f 100644
--- a/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/aadp_utils.py
+++ b/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/aadp_utils.py
@@ -14,6 +14,7 @@
 
 # Standard
 from dataclasses import dataclass
+from types import MethodType
 import warnings
 
 # Third Party
@@ -58,6 +59,26 @@ def __call__(self, features, return_tensors=None):
         return default_data_collator([ret], return_tensors)
 
 
+# from https://github.com/huggingface/trl/pull/1887
+def patch_torch_call_remove_padding(collate_fn):
+    _old_collate_torch_call = collate_fn.torch_call
+
+    def _torch_call_with_remove_pad(self, examples):
+        batch = _old_collate_torch_call(examples)
+
+        # logic for removing padding as found in later TRL releases
+        attn_mask = batch.pop("attention_mask")
+        batch["input_ids"] = batch["input_ids"][attn_mask.bool()].unsqueeze(0)
+        batch["position_ids"] = attn_mask.cumsum(1)[attn_mask.bool()].unsqueeze(0) - 1
+        batch["labels"] = batch["labels"][attn_mask.bool()].unsqueeze(0)
+        batch["labels"][batch["position_ids"] == 0] = self.ignore_index
+
+        return batch
+
+    collate_fn.torch_call = MethodType(_torch_call_with_remove_pad, collate_fn)
+    return collate_fn
+
+
 def calculate_token_lengths(dataset, num_processes):
     return np.array(
         dataset.map(
diff --git a/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/framework_plugin_padding_free.py b/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/framework_plugin_padding_free.py
index dc60e5aa..f2dbed87 100644
--- a/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/framework_plugin_padding_free.py
+++ b/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/framework_plugin_padding_free.py
@@ -20,6 +20,7 @@
 from fms_acceleration import AccelerationPlugin
 from peft import LoraConfig
 from transformers import DataCollatorForSeq2Seq, TrainingArguments
+from trl import DataCollatorForCompletionOnlyLM  # pylint: disable=import-error
 import torch
 
 
@@ -65,11 +66,13 @@ def augmentation(
             ModelPatcherTrigger,
         )
 
-        def _collator_check_seq2seq(collate_fn):
+        def _collator_check(collate_fn):
             # "The padding-free plugin currently only works with a
             # `DataCollatorForSeq2Seq` collate_fn,
             # otherwise the collation can be unreliable"
-            return isinstance(collate_fn, DataCollatorForSeq2Seq)
+            return isinstance(
+                collate_fn, (DataCollatorForSeq2Seq, DataCollatorForCompletionOnlyLM)
+            )
 
         # This check is done here to only patch the attention forward
         # the PR was merged here
@@ -92,12 +95,35 @@ def _collator_check_seq2seq(collate_fn):
         # Local
         from .aadp_utils import DataCollatorWithFlattening
 
+        def _collator_replacement_builder(collate_fn):
+
+            # in this case, replace seq2seq with flattening collator
+            if isinstance(collate_fn, DataCollatorForSeq2Seq):
+                return DataCollatorWithFlattening()
+
+            # otherwise it will be DataCollatorForCompletionOnlyLM
+            # - see _collator_check above
+            if hasattr(collate_fn, "padding_free"):
+                # in the later TRL releases there is a padding_free flag
+                # that turns on extra logic to support padding free. Just
+                # turn it on
+                collate_fn.padding_free = True
+            else:
+                # otherwise trl version is old, and we need to patch
+                # in padding free logic
+                # Local
+                from .aadp_utils import patch_torch_call_remove_padding
+
+                collate_fn = patch_torch_call_remove_padding(collate_fn)
+
+            return collate_fn
+
         # setup the collator
         AcceleratorPatcher.replace(
             "flattening-collator",
             AcceleratorPatcherComponent.data_collator,
-            replacement=DataCollatorWithFlattening(),
-            pre_requisite_check=_collator_check_seq2seq,
+            replacement_builder=_collator_replacement_builder,
+            pre_requisite_check=_collator_check,
         )
 
         if _native:
diff --git a/plugins/attention-and-distributed-packing/tox.ini b/plugins/attention-and-distributed-packing/tox.ini
index 2eff9357..f95ee738 100644
--- a/plugins/attention-and-distributed-packing/tox.ini
+++ b/plugins/attention-and-distributed-packing/tox.ini
@@ -5,7 +5,6 @@ envlist = py, lint
 deps =
     pytest>=7
     -e {toxinidir}
-skip_install = true
 commands =
     # install the dependencies here to ensure
 
diff --git a/plugins/framework/src/fms_acceleration/accelerator_patcher.py b/plugins/framework/src/fms_acceleration/accelerator_patcher.py
index 6864e1a1..4e49ff46 100644
--- a/plugins/framework/src/fms_acceleration/accelerator_patcher.py
+++ b/plugins/framework/src/fms_acceleration/accelerator_patcher.py
@@ -243,13 +243,15 @@ def prepare(self, *args, device_placement=None):
             # - then we replace
             collator_replacement_rule.pre_req_check(dataloader.collate_fn)
 
-            # FIXME: for now we just disable the replacement_builder
-            assert (
-                collator_replacement_rule.replacement_builder is None
-            ), "Currently, replacement_builder not allowed for data collator"
-            # Replace the collate_fn in dataloader
-            dataloader.collate_fn = collator_replacement_rule.replacement
+            if collator_replacement_rule.replacement is not None:
+                dataloader.collate_fn = collator_replacement_rule.replacement
+            else:
+                dataloader.collate_fn = (
+                    collator_replacement_rule.replacement_builder(
+                        dataloader.collate_fn
+                    )
+                )
 
             # - special behavior for dataloader replacements
             # - need to know if we run the original prepare
 
diff --git a/plugins/framework/src/fms_acceleration/cli.py b/plugins/framework/src/fms_acceleration/cli.py
index 687a8c28..8e28fff2 100644
--- a/plugins/framework/src/fms_acceleration/cli.py
+++ b/plugins/framework/src/fms_acceleration/cli.py
@@ -42,6 +42,9 @@ def install_plugin(
     assert len(pkg_name) == 1, "Please specify exactly one plugin to install"
     pkg_name = pkg_name[0]
 
+    # if toxinidir is specified in path, replace with cwd
+    pkg_name = pkg_name.format(toxinidir=os.getcwd())
+
     # take the flags
     args = [x for x in args if x.startswith("-")]
 
diff --git a/scripts/benchmarks/benchmark.py b/scripts/benchmarks/benchmark.py
index d21b2fbe..b8c4915d 100644
--- a/scripts/benchmarks/benchmark.py
+++ b/scripts/benchmarks/benchmark.py
@@ -206,9 +206,11 @@ def prepare_dataset(
         )
 
         response_template = self.response_template
-        if 
self.kwargs["tokenize"]: + if ( + self.kwargs['tokenize'] + or (not self.kwargs['tokenize'] and self.kwargs['chat_template']) + ): tokenizer = AutoTokenizer.from_pretrained(model_name) - # for now, if pad_token_id is None, will just do a replacement if tokenizer.pad_token_id is None: tokenizer.pad_token_id = tokenizer.eos_token_id diff --git a/scripts/benchmarks/data_processing.py b/scripts/benchmarks/data_processing.py index 3125ca44..f763cbdb 100644 --- a/scripts/benchmarks/data_processing.py +++ b/scripts/benchmarks/data_processing.py @@ -64,6 +64,10 @@ def _build_data_formatting_func( "response_field must be specified if tokenize=True and response_template=None." def _format(example): + # `nonlocal` is used because the format_fn will be passed to dataset.map and + # `loss_masking` needs to be bounded by `nonlocal` otherwise the spawned + # processes will have no reference to it + nonlocal loss_masking formatted_and_maybe_tokenized = tokenizer.apply_chat_template( [example], tokenize=tokenize ) @@ -84,8 +88,6 @@ def _format(example): 'labels': [ ignore_index ] * len(formatted_and_maybe_tokenized) + response } - loss_masking = instruction_mask_loss(tokenizer, response_template) - if not loss_masking: return {key: formatted_and_maybe_tokenized} return loss_masking(formatted_and_maybe_tokenized) diff --git a/scripts/benchmarks/refs_orca/a100_80gb_pf.csv b/scripts/benchmarks/refs_orca/a100_80gb_pf.csv index c055de88..a9aec32b 100644 --- a/scripts/benchmarks/refs_orca/a100_80gb_pf.csv +++ b/scripts/benchmarks/refs_orca/a100_80gb_pf.csv @@ -1,41 +1,41 @@ -fp16,framework_config,learning_rate,lora_alpha,lora_dropout,mem_nvidia_mem_reserved,mem_peak_torch_mem_alloc_in_bytes,mem_torch_mem_alloc_in_bytes,model_name_or_path,num_gpus,peft_method,per_device_train_batch_size,r,target_modules,torch_dtype,train_loss,train_runtime,train_samples_per_second,train_steps_per_second,train_tokens_per_second -,none,2e-5,,,77527.0,72468863488,43468103168,mistralai/Mistral-7B-v0.1,1,,4,,,float16,0.36955297470092774,362.6052,5.516,0.689,2383.937 -,none,2e-5,,,54982.0,38899449344,28984259072,mistralai/Mistral-7B-v0.1,2,,2,,,float16,0.3714465112686157,320.0386,6.249,0.781,1102.661 -,none,2e-5,,,76911.0,72465051648,43467904512,mistralai/Mistral-7B-v0.1,1,,8,,,float16,0.3604792728424072,356.7945,5.605,0.35,2933.722 -,none,2e-5,,,58821.0,42812754432,28984268288,mistralai/Mistral-7B-v0.1,2,,4,,,float16,0.3584610557556152,231.2066,8.65,0.541,1930.1 -,aadp-padding-free,2e-5,,,71665.0,72470858752,43468621312,mistralai/Mistral-7B-v0.1,1,,4,,,float16,0.3722459201812744,291.8317,6.853,0.857,1874.076 -,aadp-padding-free,2e-5,,,53231.0,38670022656,28984259072,mistralai/Mistral-7B-v0.1,2,,2,,,float16,0.37080725002288817,305.9316,6.537,0.817,886.777 -,aadp-padding-free,2e-5,,,75107.0,72452382720,43467883008,mistralai/Mistral-7B-v0.1,1,,8,,,float16,0.365696475982666,213.0649,9.387,0.587,2566.894 -,aadp-padding-free,2e-5,,,54301.0,39207462400,28984429056,mistralai/Mistral-7B-v0.1,2,,4,,,float16,0.36816050434112546,176.9167,11.305,0.707,1584.933 -True,accelerated-peft-bnb,2e-4,16,0.1,13927.0,10322870272,4306494976,mistralai/Mistral-7B-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.3517864627838135,341.8255,5.851,0.731,2528.858 -True,accelerated-peft-bnb,2e-4,16,0.1,7789.0,6435678720,2244413952,mistralai/Mistral-7B-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.3503192720413208,767.1497,2.607,0.326,460.007 
-True,accelerated-peft-bnb,2e-4,16,0.1,23271.0,16366288896,4306296320,mistralai/Mistral-7B-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.3807634754180908,317.0715,6.308,0.394,3301.262 -True,accelerated-peft-bnb,2e-4,16,0.1,11944.0,9927788032,2244423168,mistralai/Mistral-7B-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.3727553825378418,392.5799,5.095,0.318,1136.716 -True,accelerated-peft-bnb-padding-free,2e-4,16,0.1,7905.0,6284557312,4306259456,mistralai/Mistral-7B-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.3511854820251465,342.8506,5.833,0.729,1595.199 -True,accelerated-peft-bnb-padding-free,2e-4,16,0.1,6277.0,4901515776,2244413952,mistralai/Mistral-7B-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.353749755859375,772.1829,2.59,0.324,351.333 -True,accelerated-peft-bnb-padding-free,2e-4,16,0.1,11113.0,6883823104,4307159552,mistralai/Mistral-7B-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.3688219118118286,210.9553,9.481,0.593,2592.564 -True,accelerated-peft-bnb-padding-free,2e-4,16,0.1,6972.0,5302929408,2244420096,mistralai/Mistral-7B-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.37023702239990236,398.2732,5.022,0.314,704.042 -True,accelerated-peft-bnb-foak,2e-4,16,0.1,12751.0,9080075776,4306494976,mistralai/Mistral-7B-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.3516845245361328,314.1839,6.366,0.796,2751.344 -True,accelerated-peft-bnb-foak,2e-4,16,0.1,7943.0,6377985024,2244413952,mistralai/Mistral-7B-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.3619532585144043,348.4855,5.739,0.717,1012.65 -True,accelerated-peft-bnb-foak,2e-4,16,0.1,21411.0,13907961856,4306296320,mistralai/Mistral-7B-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.36872802734375,272.9149,7.328,0.458,3835.394 -True,accelerated-peft-bnb-foak,2e-4,16,0.1,11558.0,9761232384,2244423168,mistralai/Mistral-7B-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.3788349189758301,190.089,10.521,0.658,2347.595 -True,accelerated-peft-bnb-foak-padding-free,2e-4,16,0.1,7039.0,5898254848,4306259456,mistralai/Mistral-7B-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.5028951711654663,245.3348,8.152,1.019,2229.26 -True,accelerated-peft-bnb-foak-padding-free,2e-4,16,0.1,6407.0,4856763904,2244413952,mistralai/Mistral-7B-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.4248905544281006,307.9769,6.494,0.812,880.888 -True,accelerated-peft-bnb-foak-padding-free,2e-4,16,0.1,9223.0,6381574656,4306274816,mistralai/Mistral-7B-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.7981027908325196,167.693,11.927,0.745,3261.406 -True,accelerated-peft-bnb-foak-padding-free,2e-4,16,0.1,6950.0,5258480640,2244616704,mistralai/Mistral-7B-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.6221811065673828,155.5657,12.856,0.804,1802.46 -True,accelerated-peft-autogptq,2e-4,16,0.1,13269.0,10353179648,4336804352,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.3607017431259155,312.6812,6.396,0.8,2764.567 -True,accelerated-peft-autogptq,2e-4,16,0.1,8276.0,6452275200,2261091840,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.3964675521850586,767.574,2.606,0.326,459.752 -True,accelerated-peft-autogptq,2e-4,16,0.1,23229.0,16396598272,4336605696,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.37328079795837404,305.2517,6.552,0.409,3429.091 -True,accelerated-peft-autogptq,2e-4,16,0.1,12347.0,9945317888,2261101056,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,4,16,q_proj k_proj 
v_proj o_proj,float16,0.37490358924865724,388.2367,5.151,0.322,1149.433 -True,accelerated-peft-autogptq-padding-free,2e-4,16,0.1,7307.0,6314883072,4336585216,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.4573829107284546,304.2352,6.574,0.822,1797.672 -True,accelerated-peft-autogptq-padding-free,2e-4,16,0.1,6791.0,4916809216,2261091840,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.4595833406448364,751.4294,2.662,0.333,361.036 -True,accelerated-peft-autogptq-padding-free,2e-4,16,0.1,11083.0,6914132480,4337468928,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.5659423522949218,194.3331,10.292,0.643,2814.318 -True,accelerated-peft-autogptq-padding-free,2e-4,16,0.1,7347.0,5320475648,2261114368,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.38012803745269774,386.4168,5.176,0.323,725.644 -True,accelerated-peft-autogptq-foak,2e-4,16,0.1,12825.0,9110761472,4337180672,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.47571081733703613,287.4231,6.958,0.87,3007.51 -True,accelerated-peft-autogptq-foak,2e-4,16,0.1,8359.0,6395514880,2261091840,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.42341567993164064,359.3769,5.565,0.696,981.961 -True,accelerated-peft-autogptq-foak,2e-4,16,0.1,21441.0,13938271232,4336605696,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.3734574165344238,260.8468,7.667,0.479,4012.837 -True,accelerated-peft-autogptq-foak,2e-4,16,0.1,12080.0,9778762240,2261101056,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.4276853837966919,186.9887,10.696,0.668,2386.519 -True,accelerated-peft-autogptq-foak-padding-free,2e-4,16,0.1,6895.0,5930145792,4336568832,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.5187568559646606,225.1688,8.882,1.11,2428.911 -True,accelerated-peft-autogptq-foak-padding-free,2e-4,16,0.1,6816.0,4874150400,2261091840,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.43749408531188966,311.8489,6.413,0.802,869.95 -True,accelerated-peft-autogptq-foak-padding-free,2e-4,16,0.1,9135.0,6411245056,4337468928,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.7783932838439941,157.2511,12.719,0.795,3477.972 -True,accelerated-peft-autogptq-foak-padding-free,2e-4,16,0.1,7477.0,5276182528,2261097984,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.6172291069030762,158.2566,12.638,0.79,1771.812 +epoch,fp16,framework_config,learning_rate,lora_alpha,lora_dropout,mem_nvidia_mem_reserved,mem_peak_torch_mem_alloc_in_bytes,mem_torch_mem_alloc_in_bytes,model_name_or_path,num_gpus,peft_method,per_device_train_batch_size,r,target_modules,torch_dtype,train_loss,train_runtime,train_samples_per_second,train_steps_per_second,train_tokens_per_second +1.0,,none,2e-5,,,80909.0,72468057600.0,43467546624.0,mistralai/Mistral-7B-v0.1,1,,4,,,float16,0.325950190782547,1503.8936,5.32,0.665,2368.247 +1.0,,none,2e-5,,,56868.0,39993946624.0,28984043520.0,mistralai/Mistral-7B-v0.1,2,,2,,,float16,0.3231666214466095,1285.8919,6.221,0.778,1147.187 +1.0,,none,2e-5,,,80331.0,72470203904.0,43467592704.0,mistralai/Mistral-7B-v0.1,1,,8,,,float16,0.3189650802612305,1455.3148,5.497,0.344,2881.774 +1.0,,none,2e-5,,,65559.0,42135790080.0,28984861696.0,mistralai/Mistral-7B-v0.1,2,,4,,,float16,0.31443149185180663,936.3469,8.544,0.534,1888.505 
+1.0,,aadp-padding-free,2e-5,,,74353.0,72468633600.0,43468121088.0,mistralai/Mistral-7B-v0.1,1,,4,,,float16,0.3279020154476166,1155.985,6.921,0.865,2001.795 +1.0,,aadp-padding-free,2e-5,,,54675.0,39052670976.0,28984342016.0,mistralai/Mistral-7B-v0.1,2,,2,,,float16,0.3226120362281799,1197.1233,6.683,0.835,972.229 +1.0,,aadp-padding-free,2e-5,,,78087.0,72462073344.0,43467552768.0,mistralai/Mistral-7B-v0.1,1,,8,,,float16,0.3190155177116394,903.3507,8.856,0.553,2561.624 +1.0,,aadp-padding-free,2e-5,,,55821.0,39294218752.0,28984041984.0,mistralai/Mistral-7B-v0.1,2,,4,,,float16,0.3147675342559814,727.4014,10.998,0.687,1577.283 +1.0,True,accelerated-peft-bnb,2e-4,16,0.1,17729.0,11081566208.0,4306167808.0,mistralai/Mistral-7B-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.30459334325790405,1401.894,5.707,0.713,2540.557 +1.0,True,accelerated-peft-bnb,2e-4,16,0.1,9844.0,6439074816.0,2244927488.0,mistralai/Mistral-7B-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.3045877537727356,3060.9332,2.614,0.327,481.931 +1.0,True,accelerated-peft-bnb,2e-4,16,0.1,30463.0,17775202304.0,4306213888.0,mistralai/Mistral-7B-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.31843464851379394,1294.5409,6.18,0.386,3239.672 +1.0,True,accelerated-peft-bnb,2e-4,16,0.1,17129.0,9156645888.0,2244385792.0,mistralai/Mistral-7B-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.3206515383720398,1635.9996,4.89,0.306,1080.866 +1.0,True,accelerated-peft-bnb-padding-free,2e-4,16,0.1,10809.0,6460631552.0,4306152448.0,mistralai/Mistral-7B-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.30525984621047975,1328.0551,6.024,0.753,1742.431 +1.0,True,accelerated-peft-bnb-padding-free,2e-4,16,0.1,6933.0,4908467712.0,2244365824.0,mistralai/Mistral-7B-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.3044490315914154,2997.8816,2.669,0.334,388.233 +1.0,True,accelerated-peft-bnb-padding-free,2e-4,16,0.1,13933.0,7270007808.0,4306526208.0,mistralai/Mistral-7B-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.31760683155059816,853.203,9.376,0.586,2712.186 +1.0,True,accelerated-peft-bnb-padding-free,2e-4,16,0.1,7693.0,5431026688.0,2244368896.0,mistralai/Mistral-7B-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.3171143732070923,1511.9468,5.291,0.331,758.835 +1.0,True,accelerated-peft-bnb-foak,2e-4,16,0.1,16365.0,9696364032.0,4306970624.0,mistralai/Mistral-7B-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.30473455238342284,1271.1661,6.293,0.787,2801.83 +1.0,True,accelerated-peft-bnb-foak,2e-4,16,0.1,9251.0,6383434752.0,2244370432.0,mistralai/Mistral-7B-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.3177356355190277,1337.038,5.983,0.748,1103.303 +1.0,True,accelerated-peft-bnb-foak,2e-4,16,0.1,27769.0,15030454784.0,4306213888.0,mistralai/Mistral-7B-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.3168714098930359,1113.9074,7.182,0.449,3765.024 +1.0,True,accelerated-peft-bnb-foak,2e-4,16,0.1,15724.0,9015170048.0,2244467712.0,mistralai/Mistral-7B-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.3279157266616821,760.7837,10.515,0.657,2324.308 +1.0,True,accelerated-peft-bnb-foak-padding-free,2e-4,16,0.1,9043.0,6051350016.0,4306152448.0,mistralai/Mistral-7B-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.3044880964756012,1219.9642,6.558,0.82,1896.814 +1.0,True,accelerated-peft-bnb-foak-padding-free,2e-4,16,0.1,6847.0,4865420800.0,2244365824.0,mistralai/Mistral-7B-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.3189001739025116,1332.9121,6.002,0.75,873.184 
+1.0,True,accelerated-peft-bnb-foak-padding-free,2e-4,16,0.1,11343.0,6702468096.0,4306526208.0,mistralai/Mistral-7B-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.317565279006958,803.4723,9.957,0.622,2880.056 +1.0,True,accelerated-peft-bnb-foak-padding-free,2e-4,16,0.1,7279.0,5381966848.0,2244368896.0,mistralai/Mistral-7B-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.3281474394798279,720.8593,11.098,0.694,1591.598 +1.0,True,accelerated-peft-autogptq,2e-4,16,0.1,17307.0,11117607424.0,4336477184.0,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.30721326303482055,1280.3884,6.248,0.781,2788.229 +1.0,True,accelerated-peft-autogptq,2e-4,16,0.1,10278.0,6458277376.0,2261621760.0,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.3083444125652313,3015.2557,2.653,0.332,490.622 +1.0,True,accelerated-peft-autogptq,2e-4,16,0.1,30747.0,17816344576.0,4336523264.0,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.3211287794113159,1238.8866,6.457,0.404,3392.0 +1.0,True,accelerated-peft-autogptq,2e-4,16,0.1,17537.0,9177420288.0,2261063680.0,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.32159506034851076,1553.3816,5.15,0.322,1141.09 +1.0,True,accelerated-peft-autogptq-padding-free,2e-4,16,0.1,10451.0,6495941632.0,4336461824.0,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.3068936700820923,1156.9306,6.915,0.864,2007.357 +1.0,True,accelerated-peft-autogptq-padding-free,2e-4,16,0.1,7418.0,4927370240.0,2261043712.0,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.3077240388393402,2927.4903,2.733,0.342,398.994 +1.0,True,accelerated-peft-autogptq-padding-free,2e-4,16,0.1,13995.0,7305247232.0,4336483328.0,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.3207572865486145,798.4771,10.019,0.626,2908.503 +1.0,True,accelerated-peft-autogptq-padding-free,2e-4,16,0.1,8266.0,5451901440.0,2261046784.0,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.3210757732391357,1516.0155,5.277,0.33,759.554 +1.0,True,accelerated-peft-autogptq-foak,2e-4,16,0.1,15995.0,9730710528.0,4336477184.0,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.30693142366409304,1184.6303,6.753,0.844,3013.612 +1.0,True,accelerated-peft-autogptq-foak,2e-4,16,0.1,9583.0,6404390400.0,2261048320.0,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.32040582585334776,1427.4259,5.604,0.701,1036.377 +1.0,True,accelerated-peft-autogptq-foak,2e-4,16,0.1,30553.0,15069368832.0,4336523264.0,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.3212864685058594,1058.3667,7.559,0.472,3970.556 +1.0,True,accelerated-peft-autogptq-foak,2e-4,16,0.1,16012.0,9036583424.0,2261882880.0,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.3336929588317871,759.5563,10.532,0.658,2333.662 +1.0,True,accelerated-peft-autogptq-foak-padding-free,2e-4,16,0.1,8715.0,6084141056.0,4336461824.0,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.3071680905818939,1128.5951,7.088,0.886,2057.756 +1.0,True,accelerated-peft-autogptq-foak-padding-free,2e-4,16,0.1,7354.0,4884587520.0,2261043712.0,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.3200926628112793,1330.477,6.013,0.752,877.92 
+1.0,True,accelerated-peft-autogptq-foak-padding-free,2e-4,16,0.1,11267.0,6737581056.0,4336770048.0,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.32083065795898436,752.3622,10.633,0.665,3086.775 +1.0,True,accelerated-peft-autogptq-foak-padding-free,2e-4,16,0.1,7772.0,5402608128.0,2261046784.0,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.33387429666519164,738.3533,10.835,0.677,1559.546 diff --git a/scripts/benchmarks/refs_orca/requirements_pf.txt b/scripts/benchmarks/refs_orca/requirements_pf.txt new file mode 100644 index 00000000..e30eaa34 --- /dev/null +++ b/scripts/benchmarks/refs_orca/requirements_pf.txt @@ -0,0 +1,88 @@ +accelerate==0.33.0 +aiohappyeyeballs==2.4.0 +aiohttp==3.10.5 +aiosignal==1.3.1 +async-timeout==4.0.3 +attrs==24.2.0 +bitsandbytes==0.43.3 +certifi==2024.8.30 +charset-normalizer==3.3.2 +contourpy==1.3.0 +cycler==0.12.1 +datasets==2.21.0 +dill==0.3.8 +docstring_parser==0.16 +einops==0.8.0 +filelock==3.15.4 +fire==0.6.0 +flash-attn==2.6.3 +-e git+https://github.com/foundation-model-stack/fms-acceleration.git@4831b35b3f2d8c8116e55faed3d050a51ef962a0#egg=fms_acceleration&subdirectory=plugins/framework +-e git+https://github.com/foundation-model-stack/fms-acceleration.git@4831b35b3f2d8c8116e55faed3d050a51ef962a0#egg=fms_acceleration_aadp&subdirectory=plugins/attention-and-distributed-packing +-e git+https://github.com/foundation-model-stack/fms-acceleration.git@4831b35b3f2d8c8116e55faed3d050a51ef962a0#egg=fms_acceleration_foak&subdirectory=plugins/fused-ops-and-kernels +-e git+https://github.com/foundation-model-stack/fms-acceleration.git@4831b35b3f2d8c8116e55faed3d050a51ef962a0#egg=fms_acceleration_peft&subdirectory=plugins/accelerated-peft +fms-hf-tuning @ git+https://github.com/foundation-model-stack/fms-hf-tuning.git@05b40037a3876c2001dda8535dfa3f9b83714b8e +fonttools==4.53.1 +frozenlist==1.4.1 +fsspec==2024.6.1 +huggingface-hub==0.24.6 +idna==3.8 +Jinja2==3.1.4 +kiwisolver==1.4.5 +llvmlite==0.43.0 +markdown-it-py==3.0.0 +MarkupSafe==2.1.5 +matplotlib==3.9.2 +mdurl==0.1.2 +mpmath==1.3.0 +multidict==6.0.5 +multiprocess==0.70.16 +networkx==3.3 +numba==0.60.0 +numpy==1.26.4 +nvidia-cublas-cu12==12.1.3.1 +nvidia-cuda-cupti-cu12==12.1.105 +nvidia-cuda-nvrtc-cu12==12.1.105 +nvidia-cuda-runtime-cu12==12.1.105 +nvidia-cudnn-cu12==9.1.0.70 +nvidia-cufft-cu12==11.0.2.54 +nvidia-curand-cu12==10.3.2.106 +nvidia-cusolver-cu12==11.4.5.107 +nvidia-cusparse-cu12==12.1.0.106 +nvidia-nccl-cu12==2.20.5 +nvidia-nvjitlink-cu12==12.6.68 +nvidia-nvtx-cu12==12.1.105 +packaging==24.1 +pandas==2.2.2 +peft==0.12.0 +pillow==10.4.0 +protobuf==5.28.0 +psutil==6.0.0 +pyarrow==17.0.0 +Pygments==2.18.0 +pyparsing==3.1.4 +python-dateutil==2.9.0.post0 +pytz==2024.1 +PyYAML==6.0.2 +regex==2024.7.24 +requests==2.32.3 +rich==13.8.0 +safetensors==0.4.4 +sentencepiece==0.2.0 +shtab==1.7.1 +simpleeval==0.9.13 +six==1.16.0 +sympy==1.13.2 +termcolor==2.4.0 +threadpoolctl==3.5.0 +tokenizers==0.19.1 +torch==2.4.0 +tqdm==4.66.5 +transformers==4.44.2 +triton==3.0.0 +trl==0.10.1 +typing_extensions==4.12.2 +tyro==0.8.10 +tzdata==2024.1 +urllib3==2.2.2 +xxhash==3.5.0 +yarl==1.9.7 diff --git a/scripts/benchmarks/scenarios-orca.yaml b/scripts/benchmarks/scenarios-orca.yaml index df6ac097..9bb5f913 100644 --- a/scripts/benchmarks/scenarios-orca.yaml +++ b/scripts/benchmarks/scenarios-orca.yaml @@ -20,6 +20,11 @@ # allows to configure many different datasets # - there is an option of setting tokenize is True or False +# NOTE: on tokenization +# if 
tokenize = True then it's a pretokenization flow, so below set
+# - response_template: null
+# - dataset_text_field: null
+# otherwise if tokenize = False, do not set the above to null
 data_processing:
   dataset_name: microsoft/orca-math-word-problems-200k
   chat_template: |
@@ -30,9 +35,8 @@ data_processing:
 
       ASSISTANT: {{ message['answer'] }}
     {%- endfor %}
-  dataset_split: "train[:2000]"
-  tokenize: True
-  response_template: "\n\nASSISTANT:"
+  dataset_split: "train[:8000]"
+  tokenize: false
 
 # scenarios
 scenarios:
@@ -45,8 +49,8 @@ scenarios:
       packing: False
     model_name_or_path:
       - 'mistralai/Mistral-7B-v0.1'
-    response_template: null
-    dataset_text_field: null
+    dataset_text_field: 'output'
+    response_template: "\n\nASSISTANT:"
 
   - name: padding-free
     framework_config:
@@ -60,8 +64,8 @@ scenarios:
       packing: False
     model_name_or_path:
       - 'mistralai/Mistral-7B-v0.1'
-    response_template: null
-    dataset_text_field: null
+    dataset_text_field: 'output'
+    response_template: "\n\nASSISTANT:"
 
   - name: accelerated-peft-bnb
     framework_config:
@@ -83,8 +87,8 @@ scenarios:
      packing: False
     model_name_or_path:
       - 'mistralai/Mistral-7B-v0.1'
-    response_template: null
-    dataset_text_field: null
+    dataset_text_field: 'output'
+    response_template: "\n\nASSISTANT:"
 
   - name: accelerated-peft-gptq
     framework_config:
@@ -106,5 +110,5 @@ scenarios:
       packing: False
     model_name_or_path:
       - 'TheBloke/Mistral-7B-v0.1-GPTQ'
-    response_template: null
-    dataset_text_field: null
+    dataset_text_field: 'output'
+    response_template: "\n\nASSISTANT:"
diff --git a/tox.ini b/tox.ini
index 52f9bdb3..d29f1f24 100644
--- a/tox.ini
+++ b/tox.ini
@@ -38,7 +38,7 @@ commands =
     # NOTE: when there are more plugins install here
     python -m fms_acceleration.cli install -e {toxinidir}/plugins/accelerated-peft
     python -m fms_acceleration.cli install -e {toxinidir}/plugins/fused-ops-and-kernels
-    python -m fms_acceleration.cli install -e {toxinidir}/plugins/attention_and_distributed_packing
+    python -m fms_acceleration.cli install -e {toxinidir}/plugins/attention-and-distributed-packing
 
     # run the benchmark script
     bash scripts/run_benchmarks.sh {posargs:"1 2" "4 8" benchmark_outputs}
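
The standalone sketch below is not part of the patch. It illustrates, under stated assumptions, the padding-removal transform that the patched `torch_call` in `patch_torch_call_remove_padding` applies to a collated batch. The example tensors, the pad token id of 0, and the `IGNORE_INDEX` constant (standing in for the collator's `ignore_index`) are illustrative assumptions, not values taken from the patch; only `torch` is required to run it.

```python
# Sketch of the padding-removal logic used by the patched torch_call above.
# All concrete values here are illustrative assumptions.
import torch

IGNORE_INDEX = -100  # assumed stand-in for the collator's ignore_index

# A padded batch of two sequences (0 is the assumed pad token id).
batch = {
    "input_ids": torch.tensor([[5, 6, 7, 0], [8, 9, 0, 0]]),
    "attention_mask": torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]]),
    "labels": torch.tensor(
        [[5, 6, 7, IGNORE_INDEX], [8, 9, IGNORE_INDEX, IGNORE_INDEX]]
    ),
}

# Drop padding and flatten all remaining tokens into a single row.
attn_mask = batch.pop("attention_mask")
batch["input_ids"] = batch["input_ids"][attn_mask.bool()].unsqueeze(0)
# position_ids restart at 0 for every original sequence.
batch["position_ids"] = attn_mask.cumsum(1)[attn_mask.bool()].unsqueeze(0) - 1
batch["labels"] = batch["labels"][attn_mask.bool()].unsqueeze(0)
# Mask the first label of each sequence so the loss never crosses a boundary.
batch["labels"][batch["position_ids"] == 0] = IGNORE_INDEX

print(batch["input_ids"])     # tensor([[5, 6, 7, 8, 9]])
print(batch["position_ids"])  # tensor([[0, 1, 2, 0, 1]])
print(batch["labels"])        # tensor([[-100, 6, 7, -100, 9]])
```

Flattening all non-pad tokens into one row while resetting `position_ids` at each sequence start is what lets the padding-free flash-attention path treat the single flattened row as several independent sequences, which is the behavior both the TRL `padding_free=True` flag and the internal patch above aim to provide.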