From ccaba64320f57db618c217aeb0f8c3b33507080c Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 9 Jul 2024 09:33:00 -0400 Subject: [PATCH] Profiler improvements (#355) --- examples/offline_profile.py | 137 +++++++++++------- neuralmagic/tools/profiler/print_table.py | 2 +- neuralmagic/tools/profiler/visualize_trace.py | 15 +- 3 files changed, 92 insertions(+), 62 deletions(-) diff --git a/examples/offline_profile.py b/examples/offline_profile.py index 1c95b5bed451c..e77c55d88fcd6 100644 --- a/examples/offline_profile.py +++ b/examples/offline_profile.py @@ -12,19 +12,21 @@ BATCH_SIZE_DEFAULT = 1 PROMPT_LEN_DEFAULT = 256 -MAX_SEQ_LEN_DEFAULT = 1024 +OUTPUT_LEN_DEFAULT = 2 @dataclass class ProfileContext: model: str + tokenizer: str model_revision: str - sparsity: str quantization: str - max_seq_len: int + max_model_len: int max_num_batched_tokens: int prompt_len: int + output_len: int batch_size: int + dtype: str tensor_parallel_size: int allow_cuda_graphs: bool @@ -38,26 +40,29 @@ def run_profile(context: ProfileContext, csv_output: Optional[str], # Create sampling params sampling_params = SamplingParams(temperature=0.8, top_p=0.95, - max_tokens=8, + max_tokens=context.output_len, ignore_eos=True) + # Sparsity is in the future # Create LLM - llm = LLM( - model=context.model, - revision=context.model_revision, - sparsity=context.sparsity, - enforce_eager=not context.allow_cuda_graphs, - tensor_parallel_size=context.tensor_parallel_size, - gpu_memory_utilization=0.9, - max_model_len=context.max_seq_len, - quantization=context.quantization, - max_num_batched_tokens=context.max_num_batched_tokens, - ) + llm = LLM(model=context.model, + tokenizer=context.tokenizer + if context.tokenizer is not None else context.model, + revision=context.model_revision, + enforce_eager=not context.allow_cuda_graphs, + tensor_parallel_size=context.tensor_parallel_size, + gpu_memory_utilization=0.9, + max_model_len=context.max_model_len, + quantization=context.quantization, + dtype=context.dtype, + max_num_batched_tokens=context.max_num_batched_tokens) batch_size = context.batch_size prompt_len = context.prompt_len + output_len = context.output_len scheduler_config = llm.llm_engine.scheduler_config + max_model_len = llm.llm_engine.model_config.max_model_len max_num_batched_tokens = scheduler_config.max_num_batched_tokens max_num_seqs = scheduler_config.max_num_seqs @@ -75,6 +80,15 @@ def run_profile(context: ProfileContext, csv_output: Optional[str], f"max_num_seqs ({max_num_seqs}) and therefore cannot be run in a " f"single profile step, please choose a smaller batch size") sys.exit(-1) + print("llm.llm_engine.model_config.max_model_len: ", + llm.llm_engine.model_config.max_model_len) + if prompt_len + output_len > llm.llm_engine.model_config.max_model_len: + print( + f"ERROR: chosen prompt_len + output_len ({prompt_len} + " + f"{output_len} = {prompt_len + output_len}) is larger than the " + f"model's max_model_len ({max_model_len}), please choose a smaller " + f"prompt_len or output_len, or increase --max-model-len") + sys.exit(-1) for i in range(batch_size): prompt_token_ids = torch.randint( @@ -89,11 +103,14 @@ def run_profile(context: ProfileContext, csv_output: Optional[str], with nm_profile() as prefill_prof: llm.llm_engine.step() # First step is prefill - with nm_profile() as decode_prof: - llm.llm_engine.step() + decode_results_list = [] + for _ in range(context.output_len - 1): + with nm_profile() as decode_prof: + llm.llm_engine.step() + decode_results_list.append(decode_prof.results) prefill_results = prefill_prof.results - decode_results = decode_prof.results + has_decode = len(decode_results_list) > 0 print("=" * 80) print(f"= Prefill Model Table " @@ -101,13 +118,16 @@ def run_profile(context: ProfileContext, csv_output: Optional[str], print("=" * 80) print() prefill_results.print_model_table() - print() - print("=" * 80) - print(f"= Decode Model Table " - f"(prompt_len={prompt_len}, batch_size={batch_size})") - print("=" * 80) - print() - decode_results.print_model_table() + + if has_decode: + print() + print("=" * 80) + print(f"= First Decode Step Model Table " + f"(prompt_len={prompt_len}, batch_size={batch_size})") + print("=" * 80) + print() + decode_results_list[0].print_model_table() + print() print("=" * 80) print(f"= Prefill Summary Table " @@ -115,13 +135,14 @@ def run_profile(context: ProfileContext, csv_output: Optional[str], print("=" * 80) print() prefill_results.print_summary_table() - print() - print("=" * 80) - print(f"= Decode Summary Table " - f"(prompt_len={prompt_len}, batch_size={batch_size})") - print("=" * 80) - print() - decode_results.print_summary_table() + if has_decode: + print() + print("=" * 80) + print(f"= First Decode Step Summary Table " + f"(prompt_len={prompt_len}, batch_size={batch_size})") + print("=" * 80) + print() + decode_results_list[0].print_summary_table() if csv_output: csv_filename_base = csv_output.rstrip(".csv") @@ -129,10 +150,12 @@ def run_profile(context: ProfileContext, csv_output: Optional[str], csv_filename_base + "_prefill_model_table.csv") prefill_results.export_summary_stats_table_csv( csv_filename_base + "_prefill_summary_table.csv") - decode_results.export_model_stats_table_csv(\ - csv_filename_base + "_decode_model_table.csv") - decode_results.export_summary_stats_table_csv( - csv_filename_base + "_decode_summary_table.csv") + + if has_decode: + decode_results_list[0].export_model_stats_table_csv(\ + csv_filename_base + "_decode_model_table.csv") + decode_results_list[0].export_summary_stats_table_csv( + csv_filename_base + "_decode_summary_table.csv") if json_output: cuda_devices = [ @@ -149,9 +172,12 @@ def run_profile(context: ProfileContext, csv_output: Optional[str], **asdict(context) }, "prefill": prefill_results.convert_stats_to_dict(), - "decode": decode_results.convert_stats_to_dict() } + if has_decode: + for idx, dr in enumerate(decode_results_list): + json_dict[f"decode_{idx + 1}"] = dr.convert_stats_to_dict() + with open(json_output.rstrip(".json") + ".json", "w+") as f: json.dump(json_dict, f, indent=2) pass @@ -165,6 +191,11 @@ def run_profile(context: ProfileContext, csv_output: Optional[str], type=str, required=True, help='The name or path of a HuggingFace Transformers model.') + parser.add_argument("--tokenizer", + type=str, + default=None, + help="path to the tokenizer") + parser.add_argument("--model-revision", type=str, default=None) parser.add_argument( "--csv", @@ -180,29 +211,23 @@ def run_profile(context: ProfileContext, csv_output: Optional[str], type=str, default=None, help="Export the results as a json file. This should be the filename") - parser.add_argument( - "--sparsity", - "-s", - type=str, - choices=[None, 'sparse_w16a16', 'semi_structured_sparse_w16a16'], - help="Method used to compress sparse weights. If " - "None, we first check the `sparsity_config` attribute" - "in the model config file. If that is None we assume" - "the model weights are dense") parser.add_argument( "--quantization", "-q", type=str, - choices=['awq', 'gptq', 'squeezellm', 'marlin', None], + choices=['awq', 'gptq', 'squeezellm', 'marlin', 'smoothquant', None], default=None, - help="The method used to quantize the model weights, " - "options are \"marlin\", \"awq\", \"gptq\" and \"squeezellm\"") + help="The method used to quantize the model weights, options are " + "\"marlin\", \"awq\", \"gptq\", \"squeezellm\", \"smoothquant\"") + parser.add_argument("--dtype", + type=str, + default='auto', + help="model dtype") parser.add_argument( - "--max-seq-len", + "--max-model-len", type=int, - default=MAX_SEQ_LEN_DEFAULT, - help=f"Maximum length of a sequence (including prompt and output), " - f"default={MAX_SEQ_LEN_DEFAULT}") + default=None, + help="Maximum length of a sequence (including prompt and output)") parser.add_argument( "--max-num-batched-tokens", type=int, @@ -216,6 +241,12 @@ def run_profile(context: ProfileContext, csv_output: Optional[str], default=PROMPT_LEN_DEFAULT, help=f"Length of the random prompt to use when profiling, all batched " f"requests use the same prompt_len, default={PROMPT_LEN_DEFAULT}") + parser.add_argument( + "--output-len", + type=int, + default=OUTPUT_LEN_DEFAULT, + help= + f"Number of output decode steps to run, default={OUTPUT_LEN_DEFAULT}") parser.add_argument("--batch-size", type=int, default=BATCH_SIZE_DEFAULT, diff --git a/neuralmagic/tools/profiler/print_table.py b/neuralmagic/tools/profiler/print_table.py index 6e56d8d64aa9a..9081583a9f95d 100644 --- a/neuralmagic/tools/profiler/print_table.py +++ b/neuralmagic/tools/profiler/print_table.py @@ -34,7 +34,7 @@ def get_entries(node, curr_depth=0): "examples/offline_profile.py") parser.add_argument("--phase", type=str, - choices=["prefill", "decode"], + choices=["prefill", "decode_1"], required=True, help="The phase to print the table for.") parser.add_argument("--table", diff --git a/neuralmagic/tools/profiler/visualize_trace.py b/neuralmagic/tools/profiler/visualize_trace.py index fa59690da43cf..fd5659161b046 100644 --- a/neuralmagic/tools/profiler/visualize_trace.py +++ b/neuralmagic/tools/profiler/visualize_trace.py @@ -118,7 +118,7 @@ def get_entries_at_depth(depth, for root in profile_data["prefill"]["summary_stats"]: get_entries_at_depth(depth, prefill_entries_and_traces, root) - for root in profile_data["decode"]["summary_stats"]: + for root in profile_data["decode_1"]["summary_stats"]: get_entries_at_depth(depth, decode_entries_and_traces, root) def attempt_to_make_names_unique(entries_and_traces): @@ -199,12 +199,11 @@ def plot_metric(metric: str, ax, add_totals=False): shorten_plot_legend_strings(legend, 50) context = profile_data["context"] - plt.suptitle( - f"{context['model']}\n" - f"Batch={context['batch_size']}, " - f"PromptLen={context['prompt_len']}, " - f"NumGpus={context['tensor_parallel_size']}" - f"{', Sparsity ' + context['sparsity'] if context['sparsity'] else ''}" - ) + sparsity = context.get('sparsity', None) + plt.suptitle(f"{context['model']}\n" + f"Batch={context['batch_size']}, " + f"PromptLen={context['prompt_len']}, " + f"NumGpus={context['tensor_parallel_size']}" + f"{', Sparsity ' + sparsity if sparsity else ''}") plt.savefig(output, bbox_inches='tight') print("Created: ", output)