diff --git a/scripts/benchmarks/benchmark.py b/scripts/benchmarks/benchmark.py index 59e6c547..91d52601 100644 --- a/scripts/benchmarks/benchmark.py +++ b/scripts/benchmarks/benchmark.py @@ -11,12 +11,15 @@ # Third Party from tqdm import tqdm -from transformers import AutoConfig, HfArgumentParser, TrainingArguments +from transformers import AutoConfig, AutoTokenizer, HfArgumentParser, TrainingArguments import datasets import pandas as pd import torch import yaml +# First Party +from scripts.benchmarks.data_processing import build_data_formatting_func + """ This benchmarking script 1. Prepares a standard BenchmarkDataset @@ -26,19 +29,6 @@ 4. Consolidates the experiment results into a summary """ -PROMPT_DICT = { - "prompt_input": ( - "Below is an instruction that describes a task, paired with an input that provides further context. " - "Write a response that appropriately completes the request.\n\n" - "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:" - ), - "prompt_no_input": ( - "Below is an instruction that describes a task. " - "Write a response that appropriately completes the request.\n\n" - "### Instruction:\n{instruction}\n\n### Response:" - ), -} - COMMAND_PYTHON = "python" COMMAND_ACCELERATE = "accelerate launch --config_file {accelerate_config_path} --num_processes={num_processes} --main_process_port={process_port}" FMS_TRAINER = "-m tuning.sft_trainer" @@ -50,6 +40,7 @@ FILE_SHELL_COMMAND = "command.sh" FILE_SCRIPT_ARGS = "script.json" FILE_SUMMARY_CSV = "raw_summary.csv" +DATA_JSON_NAME = "cache_{}.json" DIR_BENCHMARKS = os.path.dirname(os.path.realpath(__file__)) DIR_PREFIX_EXPERIMENT = "exp" @@ -86,12 +77,17 @@ HF_TRAINER_LOG_GPU_STAGE_TRAIN = "train_mem_gpu" KEYWORD_PEAKED_DELTA = "peaked_delta" KEYWORD_ALLOC_DELTA = "alloc_delta" -HF_ARG_SKIP_MEMORY_METRIC = "--skip_memory_metrics" +HF_ARG_TRAINING_DATA_PATH = "training_data_path" +HF_ARG_RESPONSE_TEMPLATE = "response_template" +HF_ARG_SKIP_MEMORY_METRIC = "skip_memory_metrics" RESULT_FIELD_ALLOCATED_GPU_MEM = "mem_torch_mem_alloc_in_bytes" RESULT_FIELD_PEAK_ALLOCATED_GPU_MEM = "mem_peak_torch_mem_alloc_in_bytes" ERROR_MESSAGES = "error_messages" DRY_RUN_MESSAGE = "dry_run" +SCENARIOS_STANZA_SCN = "scenarios" +SCENARIOS_STANZA_DATA = "data_processing" # optional + def extract_gpu_memory_metrics(output_metrics) -> Tuple[float]: """ @@ -157,43 +153,80 @@ def get_hf_arguments_with_no_value(dataclass_types): TRUE_FALSE_ARGUMENTS = get_hf_arguments_with_no_value(dataclass_types=TrainingArguments) -def format_fn(example, input_key: str = "input", output_key: str = "output"): - prompt_input, prompt_no_input = ( - PROMPT_DICT["prompt_input"], - PROMPT_DICT["prompt_no_input"], - ) - output = ( - prompt_input.format_map(example) - if example.get(input_key, "") != "" - else prompt_no_input.format_map(example) - ) - output = f"{output} {example[output_key]}" - return {output_key: output} - - class BenchmarkDataset: def __init__( self, - dataset_name: str, - format_fn: Callable, - unused_columns: List[str] = ["instruction", "input"], + data_save_path: str, + dataset_name: str = "yahma/alpaca-cleaned", + dataset_split: str = "train", + formatting: str = "instruct", + tokenize: bool = False, + input_field: str = "input", + dataset_text_field: str = "output", + chat_template: str = None, ) -> None: - self.dataset_name = dataset_name - self.dataset = self.prepare_dataset(format_fn, unused_columns=unused_columns) - def save_to_path(self, save_path: str): - self.dataset.to_json(save_path) + self.dataset_split = datasets.load_dataset(dataset_name, split=dataset_split) + + self.kwargs = { + "formatting": formatting, + "tokenize": tokenize, + "input_field": input_field, + "dataset_text_field": dataset_text_field, + "chat_template": chat_template, + } + self.training_paths = {} # cache to store the training paths + self.data_save_path = data_save_path def prepare_dataset( self, - format_fn: Callable = None, - dataset_split: str = "train", - unused_columns: List[str] = None, + model_name: str, + response_template: str = None, ): - ds = datasets.load_dataset(self.dataset_name) - if format_fn: - ds = ds[dataset_split].map(format_fn, remove_columns=unused_columns) - return ds + if model_name in self.training_paths: + return self.training_paths[model_name] + + if self.kwargs["tokenize"]: + tokenizer = AutoTokenizer.from_pretrained(model_name) + + # for now, if pad_token_id is None, will just do a replacement + if tokenizer.pad_token_id is None: + tokenizer.pad_token_id = tokenizer.eos_token_id + + # replace some special characters in the model name + save_path = DATA_JSON_NAME.format( + re.sub(r"[/-]", "_", model_name), + ) + else: + tokenizer = None + save_path = DATA_JSON_NAME.format("all") + + # get the full path + save_path = os.path.join(self.data_save_path, save_path) + + # build the formatting func + format_fn, kwargs = build_data_formatting_func( + tokenizer, + **self.kwargs, + features=set(self.dataset_split.features), + response_template=response_template, + ) + + if "chat_template" in self.kwargs: + print("*** CHAT TEMPLATE *****") + print(self.kwargs["chat_template"]) + + print(f"Preparing dataset '{save_path}'") + + # call the map + ds = self.dataset_split.map(format_fn, **kwargs) + + # save it + ds.to_json(save_path) + + # store in cache + self.training_paths[model_name] = save_path + return save_path def convert_keypairs_to_map(keypairs: List): @@ -602,10 +635,9 @@ def get_peak_mem_usage_by_device_id(gpu_logs: pd.DataFrame): return peak_values.sub(initial_values), device_name -def prepare_arguments(args): +def prepare_arguments(args, benchmark_dataset: BenchmarkDataset): defaults = ConfigUtils.read_yaml(args.defaults_config_path) - defaults["training_data_path"] = args.dataset_save_path - scenarios = ConfigUtils.read_yaml(args.scenarios_config_path)["scenarios"] + scenarios = ConfigUtils.read_yaml(args.scenarios_config_path)[SCENARIOS_STANZA_SCN] acceleration_config_map = convert_keypairs_to_map( args.acceleration_framework_config_keypairs ) @@ -647,6 +679,20 @@ def prepare_arguments(args): if args.preload_models and len(products) > 0: scenario.preload_models() + # handle the dataset + for x in products: + # prepare the dataset + training_path = benchmark_dataset.prepare_dataset( + x["model_name_or_path"], + ( + x[HF_ARG_RESPONSE_TEMPLATE] + if HF_ARG_RESPONSE_TEMPLATE in x + else constants.get(HF_ARG_RESPONSE_TEMPLATE) + ), + ) + # update + x[HF_ARG_TRAINING_DATA_PATH] = training_path + for ( num_gpus, framework_config, @@ -672,7 +718,7 @@ def generate_list_of_experiments( expr_arg_w_outputdir = exp_arg + [ "--output_dir", os.path.join(experiment_output_dir, hf_products_dir), - HF_ARG_SKIP_MEMORY_METRIC, + "--" + HF_ARG_SKIP_MEMORY_METRIC, not log_memory_in_trainer, ] expr_cls = Experiment if not dry_run else DryRunExperiment @@ -801,10 +847,16 @@ def main(args): args.log_nvidia_smi = False # 1. Prepares a standard BenchmarkDataset - # TODO: consider caching the json file + # - the preperation of the dataset is deferred to when 'prepare_dataset' is called + # - try to read the data_processing stanza of + dataset_processing_args = ConfigUtils.read_yaml(args.scenarios_config_path).get( + SCENARIOS_STANZA_DATA, {} + ) if not args.no_data_processing: - benchmark_dataset = BenchmarkDataset(args.dataset_name, format_fn) - benchmark_dataset.save_to_path(args.dataset_save_path) + benchmark_dataset = BenchmarkDataset( + args.dataset_save_path, + **dataset_processing_args, + ) # dump out the script arguments os.makedirs(args.results_output_path, exist_ok=True) @@ -812,7 +864,7 @@ def main(args): json.dump(vars(args), f, indent=4, sort_keys=True) # 2. Prepares a list of experiment arguments from a set of configs - experiment_args = prepare_arguments(args) + experiment_args = prepare_arguments(args, benchmark_dataset) # 3. Builds a list of experiment objects to run based on the set of experiment arguments experiment_stats = {} @@ -948,16 +1000,10 @@ def main(args): default=f"{DIR_BENCHMARKS}/defaults.yaml", help="path to defaults config file", ) - parser.add_argument( - "--dataset_name", - type=str, - default="yahma/alpaca-cleaned", - help="dataset to benchmark on", - ) parser.add_argument( "--dataset_save_path", type=str, - default=f"{DIR_BENCHMARKS}/data/cache.json", + default=f"{DIR_BENCHMARKS}/data", help="dataset cache path", ) parser.add_argument( diff --git a/scripts/benchmarks/compare_with_reference.py b/scripts/benchmarks/compare_with_reference.py index e974dbdc..953ead5c 100644 --- a/scripts/benchmarks/compare_with_reference.py +++ b/scripts/benchmarks/compare_with_reference.py @@ -37,6 +37,7 @@ RAW_FILENAME = "raw_summary.csv" OUTLIERS_FILENAME = "outliers.csv" + def plot_chart(ax, x, y, title, xlabel, ylabel): ax.scatter(x, y, s=10) ax.set_title(title, fontsize=8) diff --git a/scripts/benchmarks/data_processing.py b/scripts/benchmarks/data_processing.py new file mode 100644 index 00000000..1a860bbe --- /dev/null +++ b/scripts/benchmarks/data_processing.py @@ -0,0 +1,196 @@ +# Standard +from typing import Callable, Dict, List + +# Third Party +from transformers import PreTrainedTokenizer +from trl import DataCollatorForCompletionOnlyLM + +DEFAULT_FIELDS = ["input_ids", "attention_mask", "labels"] + + +def build_data_formatting_func( + tokenizer: PreTrainedTokenizer = None, + formatting: str = "instruct", + tokenize: bool = False, + input_field: str = "input", + dataset_text_field: str = "output", + features: List = None, + response_template: str = None, + chat_template: str = None, +): + if tokenizer is None or chat_template is None: + return _build_data_formatting_func_without_chat_template( + tokenizer, + formatting, + tokenize, + input_field, + dataset_text_field, + features, + response_template, + ) + + return _build_data_formatting_func( + tokenizer, + tokenize, + chat_template, + dataset_text_field, + features, + response_template, + ) + + +# this one uses the chat template and tokenizer +def _build_data_formatting_func( + tokenizer: PreTrainedTokenizer, + tokenize: bool = False, + chat_template: str = None, + dataset_text_field: str = "output", + features: List = None, + response_template: str = None, +): + + tokenizer.chat_template = chat_template + + loss_masking = None + if tokenize and response_template is not None: + loss_masking = instruction_mask_loss(tokenizer, response_template) + + def _format(example): + formatted_and_maybe_tokenized = tokenizer.apply_chat_template( + [example], tokenize=tokenize + ) + key = "input_ids" if tokenize else dataset_text_field + if not loss_masking: + return {key: formatted_and_maybe_tokenized} + return loss_masking(formatted_and_maybe_tokenized) + + return _format, {"remove_columns": features.difference(set(DEFAULT_FIELDS))} + + +# ---- NOTE: remove this eventually and move to check templates ---- +PROMPT_DICT = { + "prompt_input": ( + "Below is an instruction that describes a task, paired with an input that provides further context. " + "Write a response that appropriately completes the request.\n\n" + "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:" + ), + "prompt_no_input": ( + "Below is an instruction that describes a task. " + "Write a response that appropriately completes the request.\n\n" + "### Instruction:\n{instruction}\n\n### Response:" + ), +} + +# combine functions +# c = combine(a, b) then c(i) = b(a(i)) +FUNC = Callable[[Dict], Dict] + + +def combine_functions(*funcs: FUNC) -> FUNC: + def _combine(x): + for f in funcs: + x = f(x) + return x + + return _combine + + +def _build_data_formatting_func_without_chat_template( + tokenizer: PreTrainedTokenizer = None, + formatting: str = "instruct", + tokenize: bool = False, + input_field: str = "input", + dataset_text_field: str = "output", + features: List = None, + response_template: str = None, +): + # FIFO + funcs = [] + + if features is None: + features = set() + + if formatting == "instruct": + funcs.append( + instruction_formatter( + input_field=input_field, dataset_text_field=dataset_text_field + ) + ) + + if tokenize: + funcs.append(tokenization(tokenizer, dataset_text_field=dataset_text_field)) + + if formatting == "instruct" and response_template: + funcs.append(instruction_mask_loss(tokenizer, response_template)) + + if len(funcs) == 0: + raise ValueError("Unable to build a data formatting recipe") + + return combine_functions(*funcs), { + "remove_columns": features.union( + set([input_field, dataset_text_field]) + ).difference(set(DEFAULT_FIELDS)) + } + + +def instruction_formatter( + input_field: str = "input", dataset_text_field: str = "output" +): + def format_fn(example: Dict): + prompt_input, prompt_no_input = ( + PROMPT_DICT["prompt_input"], + PROMPT_DICT["prompt_no_input"], + ) + output = ( + prompt_input.format_map(example) + if example.get(input_field, "") != "" + else prompt_no_input.format_map(example) + ) + output = f"{output} {example[dataset_text_field]}" + return {dataset_text_field: output} + + return format_fn + + +def tokenization(tokenizer: PreTrainedTokenizer, dataset_text_field: str = "output"): + def _tokenize(example): + text_field = example[dataset_text_field] + tokenizer.eos_token + return tokenizer(text_field) + + return _tokenize + + +# ---- NOTE: remove this eventually and move to check templates ---- + + +def instruction_mask_loss( + tokenizer: PreTrainedTokenizer, + response_template: str, + take_from_index: int = 2, +): + + print(f"Applying loss masking to reponse template '{response_template}'") + + # cheat, use the data collator to mask the loss tokens + response_template_ids = tokenizer.encode( + response_template, add_special_tokens=False + ) + + # this ignores the first + if len(response_template_ids) > take_from_index: + response_template_ids = response_template_ids[take_from_index:] + print( + f"Taking response_ids[{take_from_index}:] from '{len(response_template_ids)}' response tokens" + ) + + collator = DataCollatorForCompletionOnlyLM( + response_template_ids, tokenizer=tokenizer, ignore_index=-100 + ) + + def collate_example(example): + # single example + collated_example = collator([example], return_tensors="pt") + # flatten the additional dim + return {k: v.view(-1) for k, v in collated_example.items()} + + return collate_example diff --git a/scripts/benchmarks/scenarios-pretok.yaml b/scripts/benchmarks/scenarios-pretok.yaml new file mode 100644 index 00000000..b7c9a442 --- /dev/null +++ b/scripts/benchmarks/scenarios-pretok.yaml @@ -0,0 +1,62 @@ +# This file holds a sample full-finetuning scenario and +# demonstrates various pretokenization scenarios + +# the data_processing stanza is optional +# - if it is missing, then the defaults is to use alpaca +# with instruct formatting and no tokenization + +# - this is an older style method which does not rely on +# chat templates, this will also do instruct formatting +# - but if tokenize = True, this works only if +# sft_trainer accepts pretokenized dataset +# data_processing: +# dataset_name: yahma/alpaca-cleaned +# formatting: "instruct" +# tokenize: True +# input_field: input + +# - this is the new style, with the chat templates for formatting +# - this is the best approach to keep things flexible and +# allows to configure many different datasets +# - there is an option of setting tokenize is True or False +data_processing: + dataset_name: yahma/alpaca-cleaned + chat_template: | + {%- for message in messages %} + {% if message['input'] != '' %} + Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. + + {% else %} + Below is an instruction that describes a task. Write a response that appropriately completes the request. + + {% endif %} + ### Instruction: + {{ message['instruction'] }} + + {% if message['input'] != '' %} + ### Input: + {{ message['input'] }} + + {% endif %} + ### Response: + {{ message['output'] + eos_token }} + {% endfor %} + tokenize: True + +# scenarios +scenarios: + - name: full-finetuning + arguments: + learning_rate: 2e-5 + model_name_or_path: + - 'mistralai/Mistral-7B-v0.1' + torch_dtype: float16 + + - name: padding-free + framework_config: + - ilab-padding-free + arguments: + learning_rate: 2e-5 + model_name_or_path: + - 'mistralai/Mistral-7B-v0.1' + torch_dtype: float16 \ No newline at end of file diff --git a/scripts/run_benchmarks.sh b/scripts/run_benchmarks.sh index c6098228..5fb83b99 100644 --- a/scripts/run_benchmarks.sh +++ b/scripts/run_benchmarks.sh @@ -25,7 +25,7 @@ SCNTAG_PEFT_AUTOGPTQ=accelerated-peft-gptq # ------------- OTHER CONFIGS ----------------- # data will be cached in here -DATA_CACHE=data/cache.json +DATA_CACHE=data # final result placed here BENCH_RESULT_FILE=benchmarks.csv @@ -44,7 +44,7 @@ MEMORY_LOGGING=${MEMORY_LOGGING:-"all"} NUM_GPUS_MATRIX=${1-"1 2"} RESULT_DIR=${2:-"benchmark_outputs"} SCENARIOS_CONFIG=${3:-$SCENARIOS_CONFIG} -SCENARIOS_FILTER=${4:-$SCNTAG_PEFT_AUTOGPTQ} +SCENARIOS_FILTER=${4-$SCNTAG_PEFT_AUTOGPTQ} echo "NUM_GPUS_MATRIX: $NUM_GPUS_MATRIX" echo "RESULT_DIR: $RESULT_DIR"