Skip to content

Commit

Permalink
Update Inference Benchmarking Scripts - Support AML
Browse files Browse the repository at this point in the history
  • Loading branch information
lekurile committed Mar 1, 2024
1 parent 8182a8b commit b32dff3
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 37 deletions.
4 changes: 4 additions & 0 deletions benchmarks/inference/mii/src/plot_effective_throughput.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
(2600, 60),
(2600, 128),
]
# TODO: Remove hardcoded values, generalize, source from logs

prompt_gen_pairs_test = [(2600, 60)]

Expand Down Expand Up @@ -237,6 +238,8 @@ def display_results(model_size, tp, bs, sla_token_gen, prompt, gen, log_dir, out
raise NotImplementedError("This script is not up to date")
args = get_args()

# TODO: Generalize this
# TODO: carry over code from plot_th_lat.py
if args.test:
tp_sizes = tp_sizes_test
prompt_gen_pairs = prompt_gen_pairs_test
Expand All @@ -248,6 +251,7 @@ def display_results(model_size, tp, bs, sla_token_gen, prompt, gen, log_dir, out
for tp in tps:
for prompt, gen in prompt_gen_pairs:
for sla_token_gen in SLA_GEN_TOKENS_PER_SEC:
# TODO: standardize names
display_results(
model_size,
tp,
Expand Down
104 changes: 69 additions & 35 deletions benchmarks/inference/mii/src/plot_th_lat.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--log_dir", type=Path, default="./results")
parser.add_argument("--out_dir", type=Path, default="./plots/throughput_latency")
parser.add_argument("--backend", type=str, choices=["aml", "fastgen", "vllm"], default=["aml", "fastgen", "vllm"], \
nargs="+", help="Specify the backends to generate plots for")
args = parser.parse_args()
return args

Expand All @@ -32,65 +34,96 @@ def extract_values(file_pattern):
clients = []
throughputs = []
latencies = []
extra_args = {}
for f in files:
prof_args, response_details = read_json(f)
summary = get_summary(prof_args, response_details)
clients.append(prof_args["num_clients"])
throughputs.append(summary.throughput)
latencies.append(summary.latency)

return clients, throughputs, latencies
if "aml" in args.backend:
extra_args["aml_api_url"] = prof_args["aml_api_url"]
extra_args["deployment_name"] = prof_args["deployment_name"]

return clients, throughputs, latencies, extra_args


def output_charts(model, tp_size, bs, replicas, prompt, gen, log_dir, out_dir):
out_dir.mkdir(parents=True, exist_ok=True)

result_file_pattern = f"{model}-tp{tp_size}-bs{bs}-replicas{replicas}-prompt{prompt}-gen{gen}-clients*.json"
mii_file_pattern = f"{log_dir}/fastgen/{result_file_pattern}"
vllm_file_pattern = f"{log_dir}/vllm/{result_file_pattern}"

_, mii_throughputs, mii_latencies = extract_values(mii_file_pattern)
_, vllm_throughputs, vllm_latencies = extract_values(vllm_file_pattern)

# Plotting the scatter plot
plt.figure(figsize=(6, 4))

if len(vllm_throughputs) > 0:
# vLLM plot formatting
if "vllm" in args.backend:
vllm_file_pattern = f"{log_dir}/vllm/{result_file_pattern}"
_, vllm_throughputs, vllm_latencies, _ = extract_values(vllm_file_pattern)
if len(vllm_throughputs) > 0:
plt.scatter(
vllm_throughputs, vllm_latencies, label=f"vLLM", marker="x", color="orange"
)
fit_vllm_x_list = np.arange(min(vllm_throughputs), max(vllm_throughputs), 0.01)
vllm_vllm_model = np.polyfit(vllm_throughputs, vllm_latencies, 3)
vllm_model_fn = np.poly1d(vllm_vllm_model)
plt.plot(
fit_vllm_x_list,
vllm_model_fn(fit_vllm_x_list),
color="orange",
alpha=0.5,
linestyle="--",
)

# FastGen plot formatting
if "fastgen" in args.backend:
mii_file_pattern = f"{log_dir}/fastgen/{result_file_pattern}"
_, mii_throughputs, mii_latencies, _ = extract_values(mii_file_pattern)
plt.scatter(
vllm_throughputs, vllm_latencies, label=f"vLLM", marker="x", color="orange"
mii_throughputs,
mii_latencies,
label=f"DeepSpeed FastGen",
marker="o",
color="blue",
)
fit_vllm_x_list = np.arange(min(vllm_throughputs), max(vllm_throughputs), 0.01)
vllm_vllm_model = np.polyfit(vllm_throughputs, vllm_latencies, 3)
vllm_model_fn = np.poly1d(vllm_vllm_model)
fit_mii_x_list = np.arange(min(mii_throughputs), max(mii_throughputs), 0.01)
mii_fit_model = np.polyfit(mii_throughputs, mii_latencies, 3)
mii_model_fn = np.poly1d(mii_fit_model)
plt.plot(
fit_vllm_x_list,
vllm_model_fn(fit_vllm_x_list),
color="orange",
fit_mii_x_list,
mii_model_fn(fit_mii_x_list),
color="blue",
alpha=0.5,
linestyle="--",
)

plt.scatter(
mii_throughputs,
mii_latencies,
label=f"DeepSpeed FastGen",
marker="o",
color="blue",
)
fit_mii_x_list = np.arange(min(mii_throughputs), max(mii_throughputs), 0.01)
mii_fit_model = np.polyfit(mii_throughputs, mii_latencies, 3)
mii_model_fn = np.poly1d(mii_fit_model)
plt.plot(
fit_mii_x_list,
mii_model_fn(fit_mii_x_list),
color="blue",
alpha=0.5,
linestyle="--",
)
# AML plot formatting
if "aml" in args.backend:
aml_file_pattern = f"{log_dir}/aml/{result_file_pattern}"
_, aml_throughputs, aml_latencies, aml_args = extract_values(aml_file_pattern)
aml_endpoint_name = re.match('^https://(.+?)\.', aml_args["aml_api_url"]).groups()[0]
aml_deployment_name = aml_args["deployment_name"]
plt.scatter(
aml_throughputs,
aml_latencies,
label=f"AML {aml_endpoint_name.capitalize()}:{aml_deployment_name}",
marker="o",
color="purple",
)
fit_aml_x_list = np.arange(min(aml_throughputs), max(aml_throughputs), 0.01)
aml_fit_model = np.polyfit(aml_throughputs, aml_latencies, 3)
aml_model_fn = np.poly1d(aml_fit_model)
plt.plot(
fit_aml_x_list,
aml_model_fn(fit_aml_x_list),
color="purple",
alpha=0.5,
linestyle="--",
)

# Generic plot formatting
plt.title(f"Model {model}, Prompt: {prompt}, Generation: {gen}, TP: {tp_size}")
plt.xlabel("Throughput (queries/s)", fontsize=14)
plt.ylabel("Latency", fontsize=14)
plt.ylabel("Latency (s)", fontsize=14)
plt.legend()
plt.grid(True)
plt.tight_layout()
Expand All @@ -112,7 +145,8 @@ def output_charts(model, tp_size, bs, replicas, prompt, gen, log_dir, out_dir):
result_re = re.compile(
r"(.+)-tp(\d+)-bs(\d+)-replicas(\d+)-prompt(\d+)-gen(\d+)-clients.*.json"
)
for f in os.listdir(os.path.join(args.log_dir, "fastgen")):

for f in os.listdir(os.path.join(args.log_dir, args.backend[1])):
match = result_re.match(f)
if match:
result_params.add(match.groups())
Expand Down
4 changes: 2 additions & 2 deletions benchmarks/inference/mii/src/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,11 @@ def start_vllm_server(args: argparse.Namespace) -> None:
break
if "error" in line.lower():
p.terminate()
stop_vllm_server()
stop_vllm_server(args)
raise RuntimeError(f"Error starting VLLM server: {line}")
if time.time() - start_time > timeout_after:
p.terminate()
stop_vllm_server()
stop_vllm_server(args)
raise TimeoutError("Timed out waiting for VLLM server to start")
time.sleep(0.01)

Expand Down

0 comments on commit b32dff3

Please sign in to comment.