# benchmark_trtllm.py
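"""Benchmark a TensorRT-LLM model served by Triton Inference Server over gRPC.

Measures end-to-end latency over --n runs and, with --streaming, also
first-token latency and token throughput per request.
"""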
import argparse
import threading
import time
from functools import partial

import numpy as np
import tritonclient.grpc as grpcclient
from tqdm import tqdm
from transformers import AutoTokenizer
from tritonclient.utils import np_to_triton_dtype

from utils import calculate_mean, generate_inputs


def _input(name: str, data: np.ndarray) -> grpcclient.InferInput:
    t = grpcclient.InferInput(name, data.shape, np_to_triton_dtype(data.dtype))
    t.set_data_from_numpy(data)
    return t


def warmup(model_name, client):
    """Send one short request so the first measured run does not pay model warm-up cost."""
    batch_size = 1
    inputs = [
        _input("text", np.array(['hello world, this is to warm up'] * batch_size,
                                dtype=object).reshape(-1, 1)),
        _input("max_output_len", np.array([[32]] * batch_size, dtype=np.int32)),
    ]
    outputs = [grpcclient.InferRequestedOutput("output")]
    client.infer(model_name, inputs, outputs=outputs)


# Shared state between the gRPC stream callback (which runs on the client's
# stream thread) and the streaming benchmark loop in benchmark_triton().
start_time = None
first_token_time = None
end_time = None
output = None


def stream_callback(result, error):
    # print('stream_callback')
    global first_token_time
    global end_time
    global output
    if error:
        raise error
    end_time = time.time()
    output = result.as_numpy('output')
    if first_token_time is None:
        first_token_time = end_time


def send_batch(client, model_name, prompts, max_output_len):
    n = len(prompts)
    threads = []

    def send(client, model_name, prompt, max_output_len):
        inputs = [
            _input("text", np.array([prompt], dtype=object).reshape(-1, 1)),
            _input("max_output_len", np.array([[max_output_len]], dtype=np.int32)),
        ]
        client.infer(model_name, inputs, outputs=[grpcclient.InferRequestedOutput("output")])

    # Create and start n threads
    for i in range(n):
        thread = threading.Thread(target=send, args=(client, model_name, prompts[i], max_output_len))
        threads.append(thread)
        thread.start()

    # Wait for all threads to finish
    for thread in threads:
        thread.join()


def benchmark_triton(
    model_name,
    tokenizer_path,
    max_output_len,
    batch_size,
    input_len,
    streaming,
    n,
    addr="localhost:8001",
):
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
    prompts = generate_inputs(tokenizer, input_len, batch_size)
    print(f"Prompt: {prompts[0][:32]}..{prompts[0][-512:]}")
    inputs = [
        _input("text", np.array(prompts, dtype=object).reshape(-1, 1)),
        _input("max_output_len", np.array([[max_output_len]] * batch_size, dtype=np.int32)),
    ]

    if streaming:
        global first_token_time, end_time
        first_token_latency = [0] * n
        throughput = [0] * n
        latency = [0] * n
        output_tokens = 0
        for i in tqdm(range(n)):
            first_token_time = None
            with grpcclient.InferenceServerClient(addr, verbose=False) as client:
                start_time = time.time()
                client.start_stream(callback=partial(stream_callback))
                client.async_stream_infer(model_name, inputs)
                # stop_stream() closes the request stream and waits for the remaining
                # responses, so the callback has populated first_token_time, end_time
                # and output by the time it returns.
                client.stop_stream()
            first_token_latency[i] = first_token_time - start_time
            latency[i] = end_time - start_time
            tokens = 0
            for ot in output:
                output_len = len(tokenizer.encode(ot[0].decode())) - 1  # get rid of the start token
                output_tokens += output_len
                tokens += input_len + output_len
            throughput[i] = tokens / latency[i]
        print('first_token_latency: ', calculate_mean(first_token_latency))
        print('avg_output_len: ', int(output_tokens / n))
        print('latency: ', calculate_mean(latency))
        print('throughput: ', calculate_mean(throughput))
        return

    with grpcclient.InferenceServerClient(addr, verbose=False) as client:
        print('warm up')
        warmup(model_name, client)
        print('done warm up')
        latency = [0] * n
        for i in tqdm(range(n)):
            start_time = time.time()
            # response = client.infer(model_name, inputs, outputs=[grpcclient.InferRequestedOutput("output")])
            send_batch(client, model_name, prompts, max_output_len)
            end_time = time.time()
            latency[i] = end_time - start_time
            # outputs = response.as_numpy("output")
            # generated_text = outputs[0][0].decode()
            # # Print the output to compare with each framework
            # # print(f"Generated text: {generated_text[:32]}..{generated_text[-32:]}")
            # print(f"Generated text: {generated_text[:512]}")
            # tokens = tokenizer.encode(outputs[0][0].decode())
            # print('output_tokens:', len(tokens))
        print(f'latency: {calculate_mean(latency)}')


parser = argparse.ArgumentParser(description="Benchmark")
# Add arguments to the parser
parser.add_argument("--model_name", type=str, default='llama-2-70b-chat-hf-ft-streaming')
parser.add_argument("--tokenizer_path", type=str, default='/models/triton/llama-2-70b-hf-ft_tokenizer/1/')
parser.add_argument("--batch_size", type=int, default=1)
parser.add_argument("--max_output_len", type=int, default=32)
parser.add_argument("--input_len", type=int, default=1)
parser.add_argument("--n", type=int, default=50)
parser.add_argument("--streaming", action='store_true', default=False, help="Whether or not to stream")

# Parse the command-line arguments
args = parser.parse_args()

print('\n=============== Argument ===============')
for key in vars(args):
    print('{}: {}'.format(key, vars(args)[key]))
print('========================================')

benchmark_triton(model_name=args.model_name,
                 tokenizer_path=args.tokenizer_path,
                 max_output_len=args.max_output_len,
                 input_len=args.input_len,
                 batch_size=args.batch_size,
                 streaming=args.streaming,
                 n=args.n)

## python3 benchmark_trtllm.py --model_name llama-2-70b-hf-ft --input_len 1 --batch_size 1 --max_output_len 2048
## python3 benchmark_trtllm.py --model_name llama-2-70b-hf-ft --input_len 1024 --max_output_len 1024 --batch_size 32
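## Illustrative streaming invocation (assumes a streaming-enabled model such as the
## default llama-2-70b-chat-hf-ft-streaming is deployed on the same Triton server):
## python3 benchmark_trtllm.py --model_name llama-2-70b-chat-hf-ft-streaming --input_len 128 --max_output_len 256 --batch_size 1 --streaming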