use streamer for metrics calculation #874

Merged
113 changes: 112 additions & 1 deletion llm_bench/python/benchmark.py
@@ -216,6 +216,112 @@ def run_text_generation(input_text, num, model, tokenizer, args, iter_data_list,
        bench_hook.clear_time_infer_list()


def run_text_generation_genai_with_streamer(input_text, num, model, tokenizer, args, iter_data_list, md5_list, prompt_index, streamer, model_precision, proc_id):
    set_seed(args['seed'])
    input_text_list = [input_text] * args['batch_size']
    if args["output_dir"] is not None and num == 0:
        for bs_index, in_text in enumerate(input_text_list):
            llm_bench_utils.output_file.output_input_text(in_text, args, model_precision, prompt_index, bs_index, proc_id)
    pt_inputs = tokenizer(input_text_list, return_tensors="pt")
    input_token_size = pt_inputs.input_ids.shape[1]
    pipe_tokenizer = model.get_tokenizer()
    tok_encode_start = time.perf_counter()
    input_data = pipe_tokenizer.encode(input_text_list)
    tok_encode_end = time.perf_counter()
    tok_encode_time = (tok_encode_end - tok_encode_start) * 1000
    if args['batch_size'] > 1:
        out_str = '[warm-up]' if num == 0 else '[{}]'.format(num)
        out_str += " Batch_size={}, ".format(args['batch_size'])
        out_str += 'all input token size after padding: {} * {}, '.format(input_token_size, args['batch_size'])
        if args['infer_count'] is not None:
            out_str += 'all max_output_token_size: {} * {}'.format(args['infer_count'], args['batch_size'])
        log.info(out_str)

    max_rss_mem_consumption = ''
    max_uss_mem_consumption = ''
    max_shared_mem_consumption = ''
    if (args['mem_consumption'] == 1 and num == 0) or args['mem_consumption'] == 2:
        mem_consumption.start_collect_memory_consumption()
    max_gen_tokens = DEFAULT_OUTPUT_TOKEN_SIZE if args['infer_count'] is None else args['infer_count']
    streamer.reset()
    start = time.perf_counter()
    generated_tokens = model.generate(input_data, max_new_tokens=max_gen_tokens, num_beams=args["num_beams"], streamer=streamer).tokens
    end = time.perf_counter()
    if (args['mem_consumption'] == 1 and num == 0) or args['mem_consumption'] == 2:
        mem_consumption.end_collect_momory_consumption()
        max_rss_mem_consumption, max_shared_mem_consumption, max_uss_mem_consumption = mem_consumption.get_max_memory_consumption()
        mem_consumption.clear_max_memory_consumption()

    generation_time = end - start
    tok_decode_start = time.perf_counter()
    generated_text = pipe_tokenizer.decode(generated_tokens)
    tok_decode_end = time.perf_counter()
    tok_decode_time = (tok_decode_end - tok_decode_start) * 1000
    # Only text generation needs to subtract the input length, because generated_text may include the input text
    num_tokens = 0
    result_md5_list = []
    for bs_idx in range(args['batch_size']):
        generated_text_len = len(generated_tokens[bs_idx])
        num_tokens += generated_text_len
        if generated_text_len > max_gen_tokens:
            log.error('Output token size is over max output token size!')
        result_text = generated_text[bs_idx]
        if args["output_dir"] is not None:
            llm_bench_utils.output_file.output_gen_text(result_text, args, model_precision, prompt_index, num, bs_idx, proc_id)
        result_md5_list.append(hashlib.new("md5", result_text.encode(), usedforsecurity=False).hexdigest())
    if len(md5_list[num]) == 0:
        md5_list[num] = {prompt_index : result_md5_list}
    else:
        md5_list[num][prompt_index] = result_md5_list
    per_token_time = generation_time * 1000 / (num_tokens / args['batch_size'])
    tm_list = streamer.get_time_list()
    log.debug('latency of all tokens:')
    [log.debug('[{}]{:.4f}'.format(idx, tm)) for idx, tm in enumerate(tm_list)]
    iter_data = gen_iterate_data(
        num,
        input_token_size * args['batch_size'],
        len(tm_list),
        num_tokens,
        generation_time,
        per_token_time,
        result_md5_list,
        max_rss_mem=max_rss_mem_consumption,
        max_shared_mem=max_shared_mem_consumption,
        max_uss_mem=max_uss_mem_consumption,
        prompt_idx=prompt_index,
        tokenization_time=(tok_encode_time, tok_decode_time)
    )
    iter_data_list.append(iter_data)
    llm_bench_utils.metrics_print.print_metrics(
        num,
        iter_data,
        tm_list,
        [],
        warm_up=(num == 0),
        max_rss_mem=max_rss_mem_consumption,
        max_shared_mem=max_shared_mem_consumption,
        max_uss_mem=max_uss_mem_consumption,
        tokenization_time=(tok_encode_time, tok_decode_time),
        batch_size=args['batch_size']
    )
    if num > 0:
        prev_md5 = md5_list[num - 1][prompt_index]
        if result_md5_list != prev_md5:
            log.warning(f"[{num}] Prompt[{prompt_index}]'s md5 {result_md5_list} "
                        f"is different from md5 of the {num - 1} iteration {prev_md5}")
            llm_bench_utils.metrics_print.print_generated(num, warm_up=(num == 0), generated=generated_text[0])
            if num == 1:
                # if the device is CPU, throw exception
                if args['devices'].lower().startswith('cpu') is True:
                    assert (result_md5_list == prev_md5)
            else:
                # throw exception
                assert (result_md5_list == prev_md5)
    else:
        llm_bench_utils.metrics_print.print_generated(num, warm_up=(num == 0), generated=generated_text[0])
    streamer.reset()


def run_text_generation_genai(input_text, num, model, tokenizer, args, iter_data_list, md5_list, prompt_index, streamer, model_precision, proc_id):
    set_seed(args['seed'])
    input_text_list = [input_text] * args['batch_size']
@@ -341,7 +447,12 @@ def run_text_generation_benchmark(model_path, framework, device, args, num_iters
f'prompt nums: {len(text_list)}, prompt idx: {prompt_idx_list}')

    # if num_iters == 0, just output warm-up data
    text_gen_fn = run_text_generation if not use_genai else run_text_generation_genai
    if not use_genai:
        text_gen_fn = run_text_generation
    elif bench_hook is not None:
        text_gen_fn = run_text_generation_genai_with_streamer
    else:
        text_gen_fn = run_text_generation_genai
    proc_id = os.getpid()
    if args['subsequent'] is False:
        for num in range(num_iters + 1):
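As an illustration of what the streamer-based path enables, here is a small sketch of how a per-token time list such as the one returned by `streamer.get_time_list()` could be summarized into first-token and subsequent-token latencies. The helper name `summarize_token_times` and the exact aggregation are assumptions made for illustration; in this PR the list is simply passed to `llm_bench_utils.metrics_print.print_metrics` for reporting.

```python
# Hypothetical helper (not part of this PR) showing how a streamer's
# per-token time list could be summarized.
from statistics import mean


def summarize_token_times(time_list):
    """time_list: seconds between consecutive tokens, e.g. TokenStreamer.get_time_list()."""
    if not time_list:
        return {}
    first_token_ms = time_list[0] * 1000                      # time to first token
    rest = time_list[1:]
    other_tokens_ms = mean(rest) * 1000 if rest else 0.0      # average latency of subsequent tokens
    return {
        'first_token_latency_ms': first_token_ms,
        'other_tokens_avg_latency_ms': other_tokens_ms,
        'generated_tokens': len(time_list),
    }


# Example: three tokens generated 0.5 s, 0.04 s and 0.05 s apart
print(summarize_token_times([0.5, 0.04, 0.05]))
# {'first_token_latency_ms': 500.0, 'other_tokens_avg_latency_ms': 45.0, 'generated_tokens': 3}
```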
34 changes: 32 additions & 2 deletions llm_bench/python/llm_bench_utils/ov_utils.py
@@ -86,7 +86,7 @@ def build_ov_tokenizer(hf_tokenizer):
    try:
        from openvino_tokenizers import convert_tokenizer
    except ImportError:
        log.warn("OV Tokenizer is unavailable, tokenizer conversion will be skipped")
        log.warning("OV Tokenizer is unavailable, tokenizer conversion will be skipped")
        return hf_tokenizer

    ov_tokenizer, ov_detokenizer = convert_tokenizer(hf_tokenizer, with_detokenizer=True)
@@ -191,7 +191,37 @@ def create_genai_text_gen_model(model_path, device, ov_config, **kwargs):
    end = time.perf_counter()
    log.info(f'Pipeline initialization time: {end - start:.2f}s')

    return llm_pipe, tokenizer, end - start, None, True
    class TokenStreamer(openvino_genai.StreamerBase):
        def __init__(self, tokenizer):
            openvino_genai.StreamerBase.__init__(self)
            self.tokenizer = tokenizer
            self.token_generation_time = []
            self.generated_tokens = []
            self.start_time = time.perf_counter()

        def put(self, token_id):
            self.token_generation_time.append(time.perf_counter() - self.start_time)
            self.generated_tokens.append(token_id)
            self.start_time = time.perf_counter()
            return False

        def reset(self):
            self.token_generation_time = []
            self.generated_tokens = []
            self.start_time = time.perf_counter()

        def end(self):
            pass

        def get_tokens(self):
            return self.generated_tokens

        def get_time_list(self):
            return self.token_generation_time

    streamer = TokenStreamer(llm_pipe.get_tokenizer()) if "NPU" in device.upper() else None

    return llm_pipe, tokenizer, end - start, streamer, True


def convert_ov_tokenizer(tokenizer_path):
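For context, a minimal standalone sketch of how a `StreamerBase` callback like the `TokenStreamer` added above plugs into `openvino_genai` text generation. The model path, device, prompt, and the `TimingStreamer` name are placeholders rather than part of this change; the intent is only to show where `put()` is invoked during `generate()` and how the collected gaps can be read back afterwards.

```python
# Minimal usage sketch (not part of the PR): measure per-token latency with a
# callback streamer. The model path and prompt are placeholders.
import time

import openvino_genai


class TimingStreamer(openvino_genai.StreamerBase):
    """Records the time gap before each generated token."""

    def __init__(self):
        openvino_genai.StreamerBase.__init__(self)
        self.token_times = []
        self.start_time = time.perf_counter()

    def put(self, token_id):
        now = time.perf_counter()
        self.token_times.append(now - self.start_time)
        self.start_time = now
        return False  # False means "do not stop generation"

    def end(self):
        pass


pipe = openvino_genai.LLMPipeline("path/to/exported/model", "CPU")
streamer = TimingStreamer()
pipe.generate("What is OpenVINO?", max_new_tokens=32, streamer=streamer)
print(f"first token: {streamer.token_times[0] * 1000:.1f} ms, "
      f"tokens generated: {len(streamer.token_times)}")
```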