From b32dff39f887e6b347dbdb7b52f6f8b4874e3c5e Mon Sep 17 00:00:00 2001
From: Lev Kurilenko
Date: Fri, 1 Mar 2024 01:04:54 +0000
Subject: [PATCH] Update Inference Benchmarking Scripts - Support AML
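
Add a --backend flag to plot_th_lat.py so throughput/latency plots can be
generated for any subset of the aml, fastgen, and vllm result logs, label
AML series with the endpoint and deployment name recorded in the logs, and
pass the parsed args through to stop_vllm_server() in server.py. Example
invocation (assuming the benchmarks/inference/mii working directory):

    python src/plot_th_lat.py --log_dir ./results --backend aml fastgen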
---
 .../mii/src/plot_effective_throughput.py      |   4 +
 benchmarks/inference/mii/src/plot_th_lat.py   | 104 ++++++++++++------
 benchmarks/inference/mii/src/server.py        |   4 +-
 3 files changed, 75 insertions(+), 37 deletions(-)

diff --git a/benchmarks/inference/mii/src/plot_effective_throughput.py b/benchmarks/inference/mii/src/plot_effective_throughput.py
index efa471c76..97ae16f68 100644
--- a/benchmarks/inference/mii/src/plot_effective_throughput.py
+++ b/benchmarks/inference/mii/src/plot_effective_throughput.py
@@ -27,6 +27,7 @@
     (2600, 60),
     (2600, 128),
 ]
+# TODO: Remove hardcoded values, generalize, source from logs
 prompt_gen_pairs_test = [(2600, 60)]
 
 
@@ -237,6 +238,8 @@ def display_results(model_size, tp, bs, sla_token_gen, prompt, gen, log_dir, out
     raise NotImplementedError("This script is not up to date")
     args = get_args()
 
+    # TODO: Generalize this
+    # TODO: carry over code from plot_th_lat.py
     if args.test:
         tp_sizes = tp_sizes_test
         prompt_gen_pairs = prompt_gen_pairs_test
@@ -248,6 +251,7 @@ def display_results(model_size, tp, bs, sla_token_gen, prompt, gen, log_dir, out
     for tp in tps:
         for prompt, gen in prompt_gen_pairs:
             for sla_token_gen in SLA_GEN_TOKENS_PER_SEC:
+                # TODO: standardize names
                 display_results(
                     model_size,
                     tp,
diff --git a/benchmarks/inference/mii/src/plot_th_lat.py b/benchmarks/inference/mii/src/plot_th_lat.py
index 9aa292ca6..fb5a7dbbe 100644
--- a/benchmarks/inference/mii/src/plot_th_lat.py
+++ b/benchmarks/inference/mii/src/plot_th_lat.py
@@ -19,6 +19,8 @@ def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument("--log_dir", type=Path, default="./results")
     parser.add_argument("--out_dir", type=Path, default="./plots/throughput_latency")
+    parser.add_argument("--backend", type=str, choices=["aml", "fastgen", "vllm"],
+        default=["aml", "fastgen", "vllm"], nargs="+", help="Specify the backends to generate plots for")
     args = parser.parse_args()
     return args
 
@@ -32,6 +34,7 @@ def extract_values(file_pattern):
     clients = []
     throughputs = []
     latencies = []
+    extra_args = {}
     for f in files:
         prof_args, response_details = read_json(f)
         summary = get_summary(prof_args, response_details)
@@ -39,58 +42,88 @@ def extract_values(file_pattern):
         clients.append(summary.clients)
         throughputs.append(summary.throughput)
         latencies.append(summary.latency)
 
-    return clients, throughputs, latencies
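+    # AML result files record the endpoint URL and deployment name in their
+    # profile args; carry them through so output_charts can label AML series.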
+    if files and "aml_api_url" in prof_args:
+        extra_args["aml_api_url"] = prof_args["aml_api_url"]
+        extra_args["deployment_name"] = prof_args["deployment_name"]
+
+    return clients, throughputs, latencies, extra_args
 
 
 def output_charts(model, tp_size, bs, replicas, prompt, gen, log_dir, out_dir):
     out_dir.mkdir(parents=True, exist_ok=True)
 
     result_file_pattern = f"{model}-tp{tp_size}-bs{bs}-replicas{replicas}-prompt{prompt}-gen{gen}-clients*.json"
-    mii_file_pattern = f"{log_dir}/fastgen/{result_file_pattern}"
-    vllm_file_pattern = f"{log_dir}/vllm/{result_file_pattern}"
-
-    _, mii_throughputs, mii_latencies = extract_values(mii_file_pattern)
-    _, vllm_throughputs, vllm_latencies = extract_values(vllm_file_pattern)
 
     # Plotting the scatter plot
     plt.figure(figsize=(6, 4))
 
-    if len(vllm_throughputs) > 0:
-        plt.scatter(
-            vllm_throughputs, vllm_latencies, label=f"vLLM", marker="x", color="orange"
-        )
-        fit_vllm_x_list = np.arange(min(vllm_throughputs), max(vllm_throughputs), 0.01)
-        vllm_vllm_model = np.polyfit(vllm_throughputs, vllm_latencies, 3)
-        vllm_model_fn = np.poly1d(vllm_vllm_model)
-        plt.plot(
-            fit_vllm_x_list,
-            vllm_model_fn(fit_vllm_x_list),
-            color="orange",
-            alpha=0.5,
-            linestyle="--",
-        )
-
-    plt.scatter(
-        mii_throughputs,
-        mii_latencies,
-        label=f"DeepSpeed FastGen",
-        marker="o",
-        color="blue",
-    )
-    fit_mii_x_list = np.arange(min(mii_throughputs), max(mii_throughputs), 0.01)
-    mii_fit_model = np.polyfit(mii_throughputs, mii_latencies, 3)
-    mii_model_fn = np.poly1d(mii_fit_model)
-    plt.plot(
-        fit_mii_x_list,
-        mii_model_fn(fit_mii_x_list),
-        color="blue",
-        alpha=0.5,
-        linestyle="--",
-    )
+    # vLLM plot formatting
+    if "vllm" in args.backend:
+        vllm_file_pattern = f"{log_dir}/vllm/{result_file_pattern}"
+        _, vllm_throughputs, vllm_latencies, _ = extract_values(vllm_file_pattern)
+        if len(vllm_throughputs) > 0:
+            plt.scatter(
+                vllm_throughputs, vllm_latencies, label="vLLM", marker="x", color="orange"
+            )
+            fit_vllm_x_list = np.arange(min(vllm_throughputs), max(vllm_throughputs), 0.01)
+            vllm_fit_model = np.polyfit(vllm_throughputs, vllm_latencies, 3)
+            vllm_model_fn = np.poly1d(vllm_fit_model)
+            plt.plot(
+                fit_vllm_x_list,
+                vllm_model_fn(fit_vllm_x_list),
+                color="orange",
+                alpha=0.5,
+                linestyle="--",
+            )
+
+    # FastGen plot formatting
+    if "fastgen" in args.backend:
+        mii_file_pattern = f"{log_dir}/fastgen/{result_file_pattern}"
+        _, mii_throughputs, mii_latencies, _ = extract_values(mii_file_pattern)
+        if len(mii_throughputs) > 0:
+            plt.scatter(
+                mii_throughputs,
+                mii_latencies,
+                label="DeepSpeed FastGen",
+                marker="o",
+                color="blue",
+            )
+            fit_mii_x_list = np.arange(min(mii_throughputs), max(mii_throughputs), 0.01)
+            mii_fit_model = np.polyfit(mii_throughputs, mii_latencies, 3)
+            mii_model_fn = np.poly1d(mii_fit_model)
+            plt.plot(
+                fit_mii_x_list,
+                mii_model_fn(fit_mii_x_list),
+                color="blue",
+                alpha=0.5,
+                linestyle="--",
+            )
+
+    # AML plot formatting
+    if "aml" in args.backend:
+        aml_file_pattern = f"{log_dir}/aml/{result_file_pattern}"
+        _, aml_throughputs, aml_latencies, aml_args = extract_values(aml_file_pattern)
+        if len(aml_throughputs) > 0:
+            aml_endpoint_name = re.match(r"^https://(.+?)\.", aml_args["aml_api_url"]).group(1)
+            aml_deployment_name = aml_args["deployment_name"]
+            plt.scatter(
+                aml_throughputs,
+                aml_latencies,
+                label=f"AML {aml_endpoint_name.capitalize()}:{aml_deployment_name}",
+                marker="o",
+                color="purple",
+            )
+            fit_aml_x_list = np.arange(min(aml_throughputs), max(aml_throughputs), 0.01)
+            aml_fit_model = np.polyfit(aml_throughputs, aml_latencies, 3)
+            aml_model_fn = np.poly1d(aml_fit_model)
+            plt.plot(
+                fit_aml_x_list,
+                aml_model_fn(fit_aml_x_list),
+                color="purple",
+                alpha=0.5,
+                linestyle="--",
+            )
 
+    # Generic plot formatting
     plt.title(f"Model {model}, Prompt: {prompt}, Generation: {gen}, TP: {tp_size}")
     plt.xlabel("Throughput (queries/s)", fontsize=14)
-    plt.ylabel("Latency", fontsize=14)
+    plt.ylabel("Latency (s)", fontsize=14)
     plt.legend()
     plt.grid(True)
     plt.tight_layout()
@@ -112,7 +145,8 @@ def output_charts(model, tp_size, bs, replicas, prompt, gen, log_dir, out_dir):
     result_re = re.compile(
         r"(.+)-tp(\d+)-bs(\d+)-replicas(\d+)-prompt(\d+)-gen(\d+)-clients.*.json"
     )
-    for f in os.listdir(os.path.join(args.log_dir, "fastgen")):
-        match = result_re.match(f)
-        if match:
-            result_params.add(match.groups())
+
+    # Collect result parameter combinations from every requested backend's logs
+    for backend in args.backend:
+        backend_dir = os.path.join(args.log_dir, backend)
+        if not os.path.isdir(backend_dir):
+            continue
+        for f in os.listdir(backend_dir):
+            match = result_re.match(f)
+            if match:
+                result_params.add(match.groups())
diff --git a/benchmarks/inference/mii/src/server.py b/benchmarks/inference/mii/src/server.py
index ec04338b5..1fee4311f 100644
--- a/benchmarks/inference/mii/src/server.py
+++ b/benchmarks/inference/mii/src/server.py
@@ -49,11 +49,11 @@ def start_vllm_server(args: argparse.Namespace) -> None:
                 break
             if "error" in line.lower():
                 p.terminate()
-                stop_vllm_server()
+                stop_vllm_server(args)
                 raise RuntimeError(f"Error starting VLLM server: {line}")
             if time.time() - start_time > timeout_after:
                 p.terminate()
-                stop_vllm_server()
+                stop_vllm_server(args)
                 raise TimeoutError("Timed out waiting for VLLM server to start")
             time.sleep(0.01)