From d294db9469c4795bed24ebfb1318cc68adc682e0 Mon Sep 17 00:00:00 2001 From: zhaohongbo Date: Thu, 5 Dec 2024 17:03:39 +0800 Subject: [PATCH] Genai/optimum support streaming output (#1290) Support chunk streaming mode, mainly to reduce the number of decode calls, thereby improving performance --- tools/llm_bench/benchmark.py | 29 ++- tools/llm_bench/llm_bench_utils/ov_utils.py | 240 +++++++++++++++++++- tools/llm_bench/task/text_generation.py | 89 ++++++-- 3 files changed, 328 insertions(+), 30 deletions(-) diff --git a/tools/llm_bench/benchmark.py b/tools/llm_bench/benchmark.py index fe5068b009..bd5a5716a7 100644 --- a/tools/llm_bench/benchmark.py +++ b/tools/llm_bench/benchmark.py @@ -155,6 +155,8 @@ def get_argprser(): help='Stop the generation even if output token size does not achieve infer_count or max token size ({DEFAULT_OUTPUT_TOKEN_SIZE}}).' ) parser.add_argument('--set_torch_thread', default=0, type=num_infer_count_type, help='Set the number of Torch thread. ') + parser.add_argument('-tl', '--tokens_len', type=int, required=False, help='The length of tokens print each time in streaming mode, chunk streaming.') + parser.add_argument('--streaming', action='store_true', help='Set whether to use streaming mode, only applicable to LLM.') return parser.parse_args() @@ -170,10 +172,23 @@ def get_argprser(): def main(): logging_kwargs = {"encoding": "utf-8"} if sys.version_info[1] > 8 else {} - log.basicConfig(format='[ %(levelname)s ] %(message)s', level=os.environ.get("LOGLEVEL", log.INFO), stream=sys.stdout, **logging_kwargs) + log.basicConfig( + format='[ %(levelname)s ] %(message)s', + level=os.environ.get("LOGLEVEL", log.INFO), + stream=sys.stdout, + **logging_kwargs + ) args = get_argprser() - model_path, framework, model_args, model_name = llm_bench_utils.model_utils.analyze_args(args) + if args.tokens_len is not None and not args.streaming: + log.error("--tokens_len requires --streaming to be set.") + exit(1) + if args.streaming and args.tokens_len is None: + log.error("--streaming requires --tokens_len to be set.") + exit(1) + model_path, framework, model_args, model_name = ( + llm_bench_utils.model_utils.analyze_args(args) + ) # Set the device for running OpenVINO backend for torch.compile() if model_args['torch_compile_backend']: ov_torch_backend_device = str(args.device) @@ -208,8 +223,14 @@ def main(): if args.memory_consumption: mem_consumption.start_collect_mem_consumption_thread() try: - iter_data_list, pretrain_time, iter_timestamp = CASE_TO_BENCH[model_args['use_case']]( - model_path, framework, args.device, model_args, args.num_iters, mem_consumption) + if model_args['use_case'] in ['text_gen', 'code_gen']: + iter_data_list, pretrain_time, iter_timestamp = CASE_TO_BENCH[model_args['use_case']]( + model_path, framework, args.device, args.tokens_len, args.streaming, model_args, + args.num_iters, mem_consumption) + else: + iter_data_list, pretrain_time, iter_timestamp = CASE_TO_BENCH[model_args['use_case']]( + model_path, framework, args.device, model_args, args.num_iters, + mem_consumption) if args.report is not None or args.report_json is not None: model_precision = '' if framework == 'ov': diff --git a/tools/llm_bench/llm_bench_utils/ov_utils.py b/tools/llm_bench/llm_bench_utils/ov_utils.py index 9ebd1363e3..c5fa422824 100644 --- a/tools/llm_bench/llm_bench_utils/ov_utils.py +++ b/tools/llm_bench/llm_bench_utils/ov_utils.py @@ -2,7 +2,7 @@ # Copyright (C) 2023-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 from pathlib import Path -from transformers import AutoConfig, AutoProcessor +from transformers import AutoConfig, AutoProcessor, AutoTokenizer from openvino.runtime import Core import openvino as ov import logging as log @@ -11,9 +11,17 @@ import json import types from llm_bench_utils.hook_common import get_bench_hook -from llm_bench_utils.config_class import OV_MODEL_CLASSES_MAPPING, TOKENIZE_CLASSES_MAPPING, DEFAULT_MODEL_CLASSES, IMAGE_GEN_CLS +from llm_bench_utils.config_class import ( + OV_MODEL_CLASSES_MAPPING, + TOKENIZE_CLASSES_MAPPING, + DEFAULT_MODEL_CLASSES, + IMAGE_GEN_CLS +) import openvino.runtime.opset13 as opset from transformers import pipeline +import openvino_genai as ov_genai +import queue +from transformers.generation.streamers import BaseStreamer def generate_simplified(self, *args, **kwargs): @@ -525,3 +533,231 @@ def is_genai_available(log_msg=False): log.warning(ex) return False return True + + +class GenaiChunkStreamer(ov_genai.StreamerBase): + """ + A custom streamer class for handling token streaming and detokenization with buffering. + + Attributes: + tokenizer (Tokenizer): The tokenizer used for encoding and decoding tokens. + tokens_cache (list): A buffer to accumulate tokens for detokenization. + text_queue (Queue): A synchronized queue for storing decoded text chunks. + print_len (int): The length of the printed text to manage incremental decoding. + """ + + def __init__(self, tokenizer, tokens_len=1): + """ + Initializes the IterableStreamer with the given tokenizer. + + Args: + tokenizer (Tokenizer): The tokenizer to use for encoding and decoding tokens. + """ + super().__init__() + self.tokenizer = tokenizer + self.tokens_cache = [] + self.text_queue = queue.Queue() + self.print_len = 0 + self.tokens_len = tokens_len + + def __iter__(self): + """ + Returns the iterator object itself. + """ + return self + + def __next__(self): + """ + Returns the next value from the text queue. + + Returns: + str: The next decoded text chunk. + + Raises: + StopIteration: If there are no more elements in the queue. + """ + value = self.text_queue.get() # get() will be blocked until a token is available. + if value is None: + raise StopIteration + return value + + def get_stop_flag(self): + """ + Checks whether the generation process should be stopped. + + Returns: + bool: Always returns False in this implementation. + """ + return False + + def put_word(self, word: str): + """ + Puts a word into the text queue. + + Args: + word (str): The word to put into the queue. + """ + self.text_queue.put(word) + + def put(self, token_id: int) -> bool: + """ + Processes a token and manages the decoding buffer. Adds decoded text to the queue. + + Args: + token_id (int): The token_id to process. + + Returns: + bool: True if generation should be stopped, False otherwise. + """ + self.tokens_cache.append(token_id) + if len(self.tokens_cache) % self.tokens_len == 0: + text = self.tokenizer.decode(self.tokens_cache) + + word = '' + if len(text) > self.print_len and '\n' == text[-1]: + # Flush the cache after the new line symbol. + word = text[self.print_len:] + self.tokens_cache = [] + self.print_len = 0 + elif len(text) >= 3 and text[-3:] == chr(65533): + # Don't print incomplete text. + pass + elif len(text) > self.print_len: + # It is possible to have a shorter text after adding new token. + # Print to output only if text lengh is increaesed. + word = text[self.print_len:] + self.print_len = len(text) + self.put_word(word) + + if self.get_stop_flag(): + # When generation is stopped from streamer then end is not called, need to call it here manually. + self.end() + return True # True means stop generation + else: + return False # False means continue generation + else: + return False + + def end(self): + """ + Flushes residual tokens from the buffer and puts a None value in the queue to signal the end. + """ + text = self.tokenizer.decode(self.tokens_cache) + if len(text) > self.print_len: + word = text[self.print_len:] + self.put_word(word) + self.tokens_cache = [] + self.print_len = 0 + self.put_word(None) + + +class OptimumChunkStreamer(BaseStreamer): + """ + Simple text streamer that prints the token(s) to stdout as soon as entire words are formed. + + The API for the streamer classes is still under development and may change in the future. + + Parameters: + tokenizer (`AutoTokenizer`): + The tokenized used to decode the tokens. + skip_prompt (`bool`, *optional*, defaults to `False`): + Whether to skip the prompt to `.generate()` or not. Useful e.g. for chatbots. + decode_kwargs (`dict`, *optional*): + Additional keyword arguments to pass to the tokenizer's `decode` method. + Examples: + ```python + >>> from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer + >>> tok = AutoTokenizer.from_pretrained("openai-community/gpt2") + >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2") + >>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt") + >>> streamer = TextStreamer(tok) + >>> # Despite returning the usual output, the streamer will also print the generated text to stdout. + >>> _ = model.generate(**inputs, streamer=streamer, max_new_tokens=20) + An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven, + ``` + """ + def __init__(self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, + tokens_len: int = 1, **decode_kwargs): + self.tokenizer = tokenizer + self.skip_prompt = skip_prompt + self.decode_kwargs = decode_kwargs + # variables used in the streaming process + self.token_cache = [] + self.print_len = 0 + self.next_tokens_are_prompt = True + self.tokens_len = tokens_len + + def put(self, value): + """ + Receives tokens, decodes them, and prints them to stdout as soon as they form entire words. + """ + if len(value.shape) > 1 and value.shape[0] > 1: + raise ValueError("TextStreamer only supports batch size 1") + elif len(value.shape) > 1: + value = value[0] + if self.skip_prompt and self.next_tokens_are_prompt: + self.next_tokens_are_prompt = False + return + # Add the new token to the cache and decodes the entire thing. + self.token_cache.extend(value.tolist()) + if len(self.token_cache) % self.tokens_len == 0: + text = self.tokenizer.decode( + self.token_cache, **self.decode_kwargs + ) + # After the symbol for a new line, we flush the cache. + if text.endswith("\n"): + printable_text = text[self.print_len:] + self.token_cache = [] + self.print_len = 0 + # If the last token is a CJK character, we print the characters. + elif len(text) > 0 and self._is_chinese_char(ord(text[-1])): + printable_text = text[self.print_len:] + self.print_len += len(printable_text) + # Otherwise, prints until the last space char (simple heuristic to avoid printing incomplete words, + # which may change with the subsequent token -- there are probably smarter ways to do this!) + else: + printable_text = text[self.print_len: text.rfind(" ") + 1] + self.print_len += len(printable_text) + self.on_finalized_text(printable_text) + + def end(self): + """Flushes any remaining cache and prints a newline to stdout.""" + # Flush the cache, if it exists + if len(self.token_cache) > 0: + text = self.tokenizer.decode( + self.token_cache, **self.decode_kwargs + ) + printable_text = text[self.print_len:] + self.token_cache = [] + self.print_len = 0 + else: + printable_text = "" + self.next_tokens_are_prompt = True + self.on_finalized_text(printable_text, stream_end=True) + + def on_finalized_text(self, text: str, stream_end: bool = False): + """Prints the new text to stdout. If the stream is ending, also prints a newline.""" + print(text, flush=True, end="" if not stream_end else None) + + def _is_chinese_char(self, cp): + """Checks whether CP is the codepoint of a CJK character.""" + # This defines a "chinese character" as anything in the CJK Unicode block: + # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) + # + # Note that the CJK Unicode block is NOT all Japanese and Korean characters, + # despite its name. The modern Korean Hangul alphabet is a different block, + # as is Japanese Hiragana and Katakana. Those alphabets are used to write + # space-separated words, so they are not treated specially and handled + # like the all of the other languages. + if ( + (cp >= 0x4E00 and cp <= 0x9FFF) + or (cp >= 0x3400 and cp <= 0x4DBF) # + or (cp >= 0x20000 and cp <= 0x2A6DF) # + or (cp >= 0x2A700 and cp <= 0x2B73F) # + or (cp >= 0x2B740 and cp <= 0x2B81F) # + or (cp >= 0x2B820 and cp <= 0x2CEAF) # + or (cp >= 0xF900 and cp <= 0xFAFF) + or (cp >= 0x2F800 and cp <= 0x2FA1F) # + ): # + return True + return False diff --git a/tools/llm_bench/task/text_generation.py b/tools/llm_bench/task/text_generation.py index 63ce0d8cae..5fbf950d2c 100644 --- a/tools/llm_bench/task/text_generation.py +++ b/tools/llm_bench/task/text_generation.py @@ -10,9 +10,11 @@ import llm_bench_utils.model_utils as model_utils import numpy as np import hashlib +import threading import llm_bench_utils.metrics_print as metrics_print import llm_bench_utils.output_csv from transformers import set_seed +from llm_bench_utils.ov_utils import GenaiChunkStreamer, OptimumChunkStreamer import llm_bench_utils.output_json import llm_bench_utils.output_file import llm_bench_utils.gen_output_data as gen_output_data @@ -24,7 +26,7 @@ def run_text_generation(input_text, num, model, tokenizer, args, iter_data_list, md5_list, - prompt_index, bench_hook, model_precision, proc_id, mem_consumption): + prompt_index, bench_hook, tokens_len, streaming, model_precision, proc_id, mem_consumption): set_seed(args['seed']) input_text_list = [input_text] * args['batch_size'] if args["output_dir"] is not None and num == 0: @@ -53,25 +55,48 @@ def run_text_generation(input_text, num, model, tokenizer, args, iter_data_list, mem_consumption.start_collect_memory_consumption() max_gen_tokens = DEFAULT_OUTPUT_TOKEN_SIZE if args['infer_count'] is None else args['infer_count'] start = time.perf_counter() - if args['infer_count'] is not None and args['end_token_stopping'] is False: - model.generation_config.eos_token_id = None - model.config.eos_token_id = None - result = model.generate( - **input_data, - max_new_tokens=int(max_gen_tokens), - num_beams=args['num_beams'], - use_cache=True, - eos_token_id=None, - do_sample=False - ) + if streaming: + if args['infer_count'] is not None and args['end_token_stopping'] is False: + model.generation_config.eos_token_id = None + model.config.eos_token_id = None + result = model.generate( + **input_data, + max_new_tokens=int(max_gen_tokens), + num_beams=args['num_beams'], + use_cache=True, + eos_token_id=None, + do_sample=False, + streamer=OptimumChunkStreamer(tokenizer, tokens_len=tokens_len) + ) + else: + result = model.generate( + **input_data, + max_new_tokens=int(max_gen_tokens), + num_beams=args['num_beams'], + use_cache=True, + do_sample=False, + streamer=OptimumChunkStreamer(tokenizer, tokens_len=tokens_len) + ) else: - result = model.generate( - **input_data, - max_new_tokens=int(max_gen_tokens), - num_beams=args['num_beams'], - use_cache=True, - do_sample=False - ) + if args['infer_count'] is not None and args['end_token_stopping'] is False: + model.generation_config.eos_token_id = None + model.config.eos_token_id = None + result = model.generate( + **input_data, + max_new_tokens=int(max_gen_tokens), + num_beams=args['num_beams'], + use_cache=True, + eos_token_id=None, + do_sample=False + ) + else: + result = model.generate( + **input_data, + max_new_tokens=int(max_gen_tokens), + num_beams=args['num_beams'], + use_cache=True, + do_sample=False + ) end = time.perf_counter() if (args['mem_consumption'] == 1 and num == 0) or args['mem_consumption'] == 2: mem_consumption.end_collect_momory_consumption() @@ -172,7 +197,7 @@ def run_text_generation(input_text, num, model, tokenizer, args, iter_data_list, def run_text_generation_genai(input_text, num, model, tokenizer, args, iter_data_list, md5_list, prompt_index, - streamer, model_precision, proc_id, mem_consumption): + streamer, tokens_len, streaming, model_precision, proc_id, mem_consumption): set_seed(args['seed']) input_text_list = [input_text] * args['batch_size'] if args["output_dir"] is not None and num == 0: @@ -208,7 +233,23 @@ def run_text_generation_genai(input_text, num, model, tokenizer, args, iter_data config_info += f" assistant_confidence_threshold {gen_config.assistant_confidence_threshold}" log.info(config_info) start = time.perf_counter() - generation_result = model.generate(input_text_list, gen_config) + if streaming: + text_print_streamer = GenaiChunkStreamer(model.get_tokenizer(), tokens_len) + + def token_printer(): + # Getting next elements from iterable will be blocked until a new token is available. + for word in text_print_streamer: + print(word, end='', flush=True) + printer_thread = threading.Thread(target=token_printer, daemon=True) + printer_thread.start() + generation_result = model.generate( + input_text_list, + gen_config, + streamer=text_print_streamer + ) + printer_thread.join() + else: + generation_result = model.generate(input_text_list, gen_config) end = time.perf_counter() generated_text = generation_result.texts perf_metrics = generation_result.perf_metrics @@ -300,7 +341,7 @@ def run_text_generation_genai(input_text, num, model, tokenizer, args, iter_data def run_text_generation_genai_with_stream(input_text, num, model, tokenizer, args, iter_data_list, md5_list, - prompt_index, streamer, model_precision, proc_id, mem_consumption): + prompt_index, streamer, tokens_len, streaming, model_precision, proc_id, mem_consumption): set_seed(args['seed']) input_text_list = [input_text] * args['batch_size'] if args["output_dir"] is not None and num == 0: @@ -422,7 +463,7 @@ def run_text_generation_genai_with_stream(input_text, num, model, tokenizer, arg streamer.reset() -def run_text_generation_benchmark(model_path, framework, device, args, num_iters, mem_consumption): +def run_text_generation_benchmark(model_path, framework, device, tokens_len, streaming, args, num_iters, mem_consumption): model, tokenizer, pretrain_time, bench_hook, use_genai = FW_UTILS[framework].create_text_gen_model(model_path, device, **args) model_precision = model_utils.get_model_precision(model_path.parts) iter_data_list = [] @@ -461,7 +502,7 @@ def run_text_generation_benchmark(model_path, framework, device, args, num_iters log.info(f'[warm-up][P{p_idx}] Input text: {input_text}') iter_timestamp[num][p_idx]['start'] = datetime.datetime.now().isoformat() text_gen_fn(input_text, num, model, tokenizer, args, iter_data_list, md5_list, - p_idx, bench_hook, model_precision, proc_id, mem_consumption) + p_idx, bench_hook, tokens_len, streaming, model_precision, proc_id, mem_consumption) iter_timestamp[num][p_idx]['end'] = datetime.datetime.now().isoformat() prefix = '[warm-up]' if num == 0 else '[{}]'.format(num) log.info(f"{prefix}[P{p_idx}] start: {iter_timestamp[num][p_idx]['start']}, end: {iter_timestamp[num][p_idx]['end']}")