Genai/optimum support streaming output (openvinotoolkit#1290)
Support chunk streaming mode, mainly to reduce the number of decode calls and thereby improve performance.
zhaohb authored Dec 5, 2024
1 parent 3ca509f commit d294db9
Showing 3 changed files with 328 additions and 30 deletions.
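The two new command-line flags are meant to be used together: `--streaming` turns on streaming generation and `-tl`/`--tokens_len` sets how many generated tokens are buffered before each `decode` call. An illustrative invocation (model path and device are placeholders) might look like `python benchmark.py -m <model_dir> -d CPU --streaming --tokens_len 10`.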
29 changes: 25 additions & 4 deletions tools/llm_bench/benchmark.py
@@ -155,6 +155,8 @@ def get_argprser():
help='Stop the generation even if output token size does not achieve infer_count or max token size ({DEFAULT_OUTPUT_TOKEN_SIZE}}).'
)
parser.add_argument('--set_torch_thread', default=0, type=num_infer_count_type, help='Set the number of Torch thread. ')
parser.add_argument('-tl', '--tokens_len', type=int, required=False, help='The length of tokens print each time in streaming mode, chunk streaming.')
parser.add_argument('--streaming', action='store_true', help='Set whether to use streaming mode, only applicable to LLM.')

return parser.parse_args()

@@ -170,10 +172,23 @@ def get_argprser():

def main():
logging_kwargs = {"encoding": "utf-8"} if sys.version_info[1] > 8 else {}
log.basicConfig(format='[ %(levelname)s ] %(message)s', level=os.environ.get("LOGLEVEL", log.INFO), stream=sys.stdout, **logging_kwargs)
log.basicConfig(
format='[ %(levelname)s ] %(message)s',
level=os.environ.get("LOGLEVEL", log.INFO),
stream=sys.stdout,
**logging_kwargs
)
args = get_argprser()
model_path, framework, model_args, model_name = llm_bench_utils.model_utils.analyze_args(args)

if args.tokens_len is not None and not args.streaming:
log.error("--tokens_len requires --streaming to be set.")
exit(1)
if args.streaming and args.tokens_len is None:
log.error("--streaming requires --tokens_len to be set.")
exit(1)
model_path, framework, model_args, model_name = (
llm_bench_utils.model_utils.analyze_args(args)
)
# Set the device for running OpenVINO backend for torch.compile()
if model_args['torch_compile_backend']:
ov_torch_backend_device = str(args.device)
@@ -208,8 +223,14 @@ def main():
if args.memory_consumption:
mem_consumption.start_collect_mem_consumption_thread()
try:
iter_data_list, pretrain_time, iter_timestamp = CASE_TO_BENCH[model_args['use_case']](
model_path, framework, args.device, model_args, args.num_iters, mem_consumption)
if model_args['use_case'] in ['text_gen', 'code_gen']:
iter_data_list, pretrain_time, iter_timestamp = CASE_TO_BENCH[model_args['use_case']](
model_path, framework, args.device, args.tokens_len, args.streaming, model_args,
args.num_iters, mem_consumption)
else:
iter_data_list, pretrain_time, iter_timestamp = CASE_TO_BENCH[model_args['use_case']](
model_path, framework, args.device, model_args, args.num_iters,
mem_consumption)
if args.report is not None or args.report_json is not None:
model_precision = ''
if framework == 'ov':
240 changes: 238 additions & 2 deletions tools/llm_bench/llm_bench_utils/ov_utils.py
@@ -2,7 +2,7 @@
# Copyright (C) 2023-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from pathlib import Path
from transformers import AutoConfig, AutoProcessor
from transformers import AutoConfig, AutoProcessor, AutoTokenizer
from openvino.runtime import Core
import openvino as ov
import logging as log
@@ -11,9 +11,17 @@
import json
import types
from llm_bench_utils.hook_common import get_bench_hook
from llm_bench_utils.config_class import OV_MODEL_CLASSES_MAPPING, TOKENIZE_CLASSES_MAPPING, DEFAULT_MODEL_CLASSES, IMAGE_GEN_CLS
from llm_bench_utils.config_class import (
OV_MODEL_CLASSES_MAPPING,
TOKENIZE_CLASSES_MAPPING,
DEFAULT_MODEL_CLASSES,
IMAGE_GEN_CLS
)
import openvino.runtime.opset13 as opset
from transformers import pipeline
import openvino_genai as ov_genai
import queue
from transformers.generation.streamers import BaseStreamer


def generate_simplified(self, *args, **kwargs):
@@ -525,3 +533,231 @@ def is_genai_available(log_msg=False):
log.warning(ex)
return False
return True


class GenaiChunkStreamer(ov_genai.StreamerBase):
"""
A custom streamer class for handling token streaming and detokenization with buffering.
Attributes:
tokenizer (Tokenizer): The tokenizer used for encoding and decoding tokens.
tokens_cache (list): A buffer to accumulate tokens for detokenization.
text_queue (Queue): A synchronized queue for storing decoded text chunks.
print_len (int): The length of the printed text to manage incremental decoding.
"""

def __init__(self, tokenizer, tokens_len=1):
"""
Initializes the GenaiChunkStreamer with the given tokenizer.
Args:
tokenizer (Tokenizer): The tokenizer to use for encoding and decoding tokens.
"""
super().__init__()
self.tokenizer = tokenizer
self.tokens_cache = []
self.text_queue = queue.Queue()
self.print_len = 0
self.tokens_len = tokens_len

def __iter__(self):
"""
Returns the iterator object itself.
"""
return self

def __next__(self):
"""
Returns the next value from the text queue.
Returns:
str: The next decoded text chunk.
Raises:
StopIteration: If there are no more elements in the queue.
"""
value = self.text_queue.get() # get() will be blocked until a token is available.
if value is None:
raise StopIteration
return value

def get_stop_flag(self):
"""
Checks whether the generation process should be stopped.
Returns:
bool: Always returns False in this implementation.
"""
return False

def put_word(self, word: str):
"""
Puts a word into the text queue.
Args:
word (str): The word to put into the queue.
"""
self.text_queue.put(word)

def put(self, token_id: int) -> bool:
"""
Processes a token and manages the decoding buffer. Adds decoded text to the queue.
Args:
token_id (int): The token_id to process.
Returns:
bool: True if generation should be stopped, False otherwise.
"""
self.tokens_cache.append(token_id)
if len(self.tokens_cache) % self.tokens_len == 0:
text = self.tokenizer.decode(self.tokens_cache)

word = ''
if len(text) > self.print_len and '\n' == text[-1]:
# Flush the cache after the new line symbol.
word = text[self.print_len:]
self.tokens_cache = []
self.print_len = 0
elif len(text) >= 3 and text[-1] == chr(65533):
# Don't print incomplete text.
pass
elif len(text) > self.print_len:
# It is possible to have a shorter text after adding a new token.
# Print to output only if the text length has increased.
word = text[self.print_len:]
self.print_len = len(text)
self.put_word(word)

if self.get_stop_flag():
# When generation is stopped from the streamer, end() is not called, so call it here manually.
self.end()
return True # True means stop generation
else:
return False # False means continue generation
else:
return False

def end(self):
"""
Flushes residual tokens from the buffer and puts a None value in the queue to signal the end.
"""
text = self.tokenizer.decode(self.tokens_cache)
if len(text) > self.print_len:
word = text[self.print_len:]
self.put_word(word)
self.tokens_cache = []
self.print_len = 0
self.put_word(None)
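
For reference, a minimal sketch of how the new streamer might be driven, assuming the standard `openvino_genai` `LLMPipeline`/`GenerationConfig` API; the model directory, prompt, and chunk size below are placeholders:

```python
import threading
import openvino_genai as ov_genai

pipe = ov_genai.LLMPipeline("path/to/ov_model_dir", "CPU")  # placeholder model directory
streamer = GenaiChunkStreamer(pipe.get_tokenizer(), tokens_len=10)

config = ov_genai.GenerationConfig()
config.max_new_tokens = 64

# generate() blocks until completion, so run it in a worker thread
# and consume decoded chunks from the streamer's queue here.
worker = threading.Thread(target=pipe.generate,
                          args=("The weather today is", config, streamer))
worker.start()
for chunk in streamer:  # roughly one tokenizer.decode() call per tokens_len tokens
    print(chunk, end='', flush=True)
worker.join()
```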


class OptimumChunkStreamer(BaseStreamer):
"""
Simple text streamer that prints the token(s) to stdout as soon as entire words are formed.
<Tip warning={true}>
The API for the streamer classes is still under development and may change in the future.
</Tip>
Parameters:
tokenizer (`AutoTokenizer`):
The tokenizer used to decode the tokens.
skip_prompt (`bool`, *optional*, defaults to `False`):
Whether to skip the prompt to `.generate()` or not. Useful e.g. for chatbots.
decode_kwargs (`dict`, *optional*):
Additional keyword arguments to pass to the tokenizer's `decode` method.
Examples:
```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
>>> tok = AutoTokenizer.from_pretrained("openai-community/gpt2")
>>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
>>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
>>> streamer = TextStreamer(tok)
>>> # Despite returning the usual output, the streamer will also print the generated text to stdout.
>>> _ = model.generate(**inputs, streamer=streamer, max_new_tokens=20)
An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,
```
"""
def __init__(self, tokenizer: "AutoTokenizer", skip_prompt: bool = False,
tokens_len: int = 1, **decode_kwargs):
self.tokenizer = tokenizer
self.skip_prompt = skip_prompt
self.decode_kwargs = decode_kwargs
# variables used in the streaming process
self.token_cache = []
self.print_len = 0
self.next_tokens_are_prompt = True
self.tokens_len = tokens_len

def put(self, value):
"""
Receives tokens, decodes them, and prints them to stdout as soon as they form entire words.
"""
if len(value.shape) > 1 and value.shape[0] > 1:
raise ValueError("TextStreamer only supports batch size 1")
elif len(value.shape) > 1:
value = value[0]
if self.skip_prompt and self.next_tokens_are_prompt:
self.next_tokens_are_prompt = False
return
# Add the new token to the cache and decode the entire thing.
self.token_cache.extend(value.tolist())
if len(self.token_cache) % self.tokens_len == 0:
text = self.tokenizer.decode(
self.token_cache, **self.decode_kwargs
)
# After the symbol for a new line, we flush the cache.
if text.endswith("\n"):
printable_text = text[self.print_len:]
self.token_cache = []
self.print_len = 0
# If the last token is a CJK character, we print the characters.
elif len(text) > 0 and self._is_chinese_char(ord(text[-1])):
printable_text = text[self.print_len:]
self.print_len += len(printable_text)
# Otherwise, prints until the last space char (simple heuristic to avoid printing incomplete words,
# which may change with the subsequent token -- there are probably smarter ways to do this!)
else:
printable_text = text[self.print_len: text.rfind(" ") + 1]
self.print_len += len(printable_text)
self.on_finalized_text(printable_text)

def end(self):
"""Flushes any remaining cache and prints a newline to stdout."""
# Flush the cache, if it exists
if len(self.token_cache) > 0:
text = self.tokenizer.decode(
self.token_cache, **self.decode_kwargs
)
printable_text = text[self.print_len:]
self.token_cache = []
self.print_len = 0
else:
printable_text = ""
self.next_tokens_are_prompt = True
self.on_finalized_text(printable_text, stream_end=True)

def on_finalized_text(self, text: str, stream_end: bool = False):
"""Prints the new text to stdout. If the stream is ending, also prints a newline."""
print(text, flush=True, end="" if not stream_end else None)

def _is_chinese_char(self, cp):
"""Checks whether CP is the codepoint of a CJK character."""
# This defines a "chinese character" as anything in the CJK Unicode block:
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
#
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
# despite its name. The modern Korean Hangul alphabet is a different block,
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
# space-separated words, so they are not treated specially and handled
# like all of the other languages.
if (
(cp >= 0x4E00 and cp <= 0x9FFF)
or (cp >= 0x3400 and cp <= 0x4DBF) #
or (cp >= 0x20000 and cp <= 0x2A6DF) #
or (cp >= 0x2A700 and cp <= 0x2B73F) #
or (cp >= 0x2B740 and cp <= 0x2B81F) #
or (cp >= 0x2B820 and cp <= 0x2CEAF) #
or (cp >= 0xF900 and cp <= 0xFAFF)
or (cp >= 0x2F800 and cp <= 0x2FA1F) #
): #
return True
return False
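
A similar sketch for the Optimum path, assuming `optimum-intel`'s `OVModelForCausalLM` (the model id and chunk size are placeholders; `export=True` converts the model to OpenVINO IR on the fly):

```python
from optimum.intel import OVModelForCausalLM
from transformers import AutoTokenizer

model_id = "openai-community/gpt2"  # placeholder model id
tok = AutoTokenizer.from_pretrained(model_id)
model = OVModelForCausalLM.from_pretrained(model_id, export=True)

# Decode once per 10 generated tokens instead of once per token.
streamer = OptimumChunkStreamer(tok, skip_prompt=True, tokens_len=10)
inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
_ = model.generate(**inputs, streamer=streamer, max_new_tokens=40)
```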